From 105a21548f15b6c07f9c76aa916486e07af6262d Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Wed, 1 Apr 2026 13:17:10 +0800 Subject: [PATCH] fix(codex): centralize session management with global store and add tests for executor session lifecycle --- .../executor/codex_websockets_executor.go | 128 +++++++++++++++--- .../codex_websockets_executor_store_test.go | 48 +++++++ sdk/cliproxy/service.go | 1 + 3 files changed, 159 insertions(+), 18 deletions(-) create mode 100644 internal/runtime/executor/codex_websockets_executor_store_test.go diff --git a/internal/runtime/executor/codex_websockets_executor.go b/internal/runtime/executor/codex_websockets_executor.go index afc255e3..dc9a8a79 100644 --- a/internal/runtime/executor/codex_websockets_executor.go +++ b/internal/runtime/executor/codex_websockets_executor.go @@ -46,10 +46,18 @@ const ( type CodexWebsocketsExecutor struct { *CodexExecutor - sessMu sync.Mutex + store *codexWebsocketSessionStore +} + +type codexWebsocketSessionStore struct { + mu sync.Mutex sessions map[string]*codexWebsocketSession } +var globalCodexWebsocketSessionStore = &codexWebsocketSessionStore{ + sessions: make(map[string]*codexWebsocketSession), +} + type codexWebsocketSession struct { sessionID string @@ -73,7 +81,7 @@ type codexWebsocketSession struct { func NewCodexWebsocketsExecutor(cfg *config.Config) *CodexWebsocketsExecutor { return &CodexWebsocketsExecutor{ CodexExecutor: NewCodexExecutor(cfg), - sessions: make(map[string]*codexWebsocketSession), + store: globalCodexWebsocketSessionStore, } } @@ -1058,16 +1066,23 @@ func (e *CodexWebsocketsExecutor) getOrCreateSession(sessionID string) *codexWeb if sessionID == "" { return nil } - e.sessMu.Lock() - defer e.sessMu.Unlock() - if e.sessions == nil { - e.sessions = make(map[string]*codexWebsocketSession) + if e == nil { + return nil } - if sess, ok := e.sessions[sessionID]; ok && sess != nil { + store := e.store + if store == nil { + store = globalCodexWebsocketSessionStore + } + store.mu.Lock() + defer store.mu.Unlock() + if store.sessions == nil { + store.sessions = make(map[string]*codexWebsocketSession) + } + if sess, ok := store.sessions[sessionID]; ok && sess != nil { return sess } sess := &codexWebsocketSession{sessionID: sessionID} - e.sessions[sessionID] = sess + store.sessions[sessionID] = sess return sess } @@ -1213,14 +1228,20 @@ func (e *CodexWebsocketsExecutor) CloseExecutionSession(sessionID string) { return } if sessionID == cliproxyauth.CloseAllExecutionSessionsID { - e.closeAllExecutionSessions("executor_replaced") + // Executor replacement can happen during hot reload (config/credential changes). + // Do not force-close upstream websocket sessions here, otherwise in-flight + // downstream websocket requests get interrupted. return } - e.sessMu.Lock() - sess := e.sessions[sessionID] - delete(e.sessions, sessionID) - e.sessMu.Unlock() + store := e.store + if store == nil { + store = globalCodexWebsocketSessionStore + } + store.mu.Lock() + sess := store.sessions[sessionID] + delete(store.sessions, sessionID) + store.mu.Unlock() e.closeExecutionSession(sess, "session_closed") } @@ -1230,15 +1251,19 @@ func (e *CodexWebsocketsExecutor) closeAllExecutionSessions(reason string) { return } - e.sessMu.Lock() - sessions := make([]*codexWebsocketSession, 0, len(e.sessions)) - for sessionID, sess := range e.sessions { - delete(e.sessions, sessionID) + store := e.store + if store == nil { + store = globalCodexWebsocketSessionStore + } + store.mu.Lock() + sessions := make([]*codexWebsocketSession, 0, len(store.sessions)) + for sessionID, sess := range store.sessions { + delete(store.sessions, sessionID) if sess != nil { sessions = append(sessions, sess) } } - e.sessMu.Unlock() + store.mu.Unlock() for i := range sessions { e.closeExecutionSession(sessions[i], reason) @@ -1246,6 +1271,10 @@ func (e *CodexWebsocketsExecutor) closeAllExecutionSessions(reason string) { } func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSession, reason string) { + closeCodexWebsocketSession(sess, reason) +} + +func closeCodexWebsocketSession(sess *codexWebsocketSession, reason string) { if sess == nil { return } @@ -1286,6 +1315,69 @@ func logCodexWebsocketDisconnected(sessionID string, authID string, wsURL string log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL), strings.TrimSpace(reason)) } +// CloseCodexWebsocketSessionsForAuthID closes all active Codex upstream websocket sessions +// associated with the supplied auth ID. +func CloseCodexWebsocketSessionsForAuthID(authID string, reason string) { + authID = strings.TrimSpace(authID) + if authID == "" { + return + } + reason = strings.TrimSpace(reason) + if reason == "" { + reason = "auth_removed" + } + + store := globalCodexWebsocketSessionStore + if store == nil { + return + } + + type sessionItem struct { + sessionID string + sess *codexWebsocketSession + } + + store.mu.Lock() + items := make([]sessionItem, 0, len(store.sessions)) + for sessionID, sess := range store.sessions { + items = append(items, sessionItem{sessionID: sessionID, sess: sess}) + } + store.mu.Unlock() + + matches := make([]sessionItem, 0) + for i := range items { + sess := items[i].sess + if sess == nil { + continue + } + sess.connMu.Lock() + sessAuthID := strings.TrimSpace(sess.authID) + sess.connMu.Unlock() + if sessAuthID == authID { + matches = append(matches, items[i]) + } + } + if len(matches) == 0 { + return + } + + toClose := make([]*codexWebsocketSession, 0, len(matches)) + store.mu.Lock() + for i := range matches { + current, ok := store.sessions[matches[i].sessionID] + if !ok || current == nil || current != matches[i].sess { + continue + } + delete(store.sessions, matches[i].sessionID) + toClose = append(toClose, current) + } + store.mu.Unlock() + + for i := range toClose { + closeCodexWebsocketSession(toClose[i], reason) + } +} + // CodexAutoExecutor routes Codex requests to the websocket transport only when: // 1. The downstream transport is websocket, and // 2. The selected auth enables websockets. diff --git a/internal/runtime/executor/codex_websockets_executor_store_test.go b/internal/runtime/executor/codex_websockets_executor_store_test.go new file mode 100644 index 00000000..1a23fa31 --- /dev/null +++ b/internal/runtime/executor/codex_websockets_executor_store_test.go @@ -0,0 +1,48 @@ +package executor + +import ( + "testing" + + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" +) + +func TestCodexWebsocketsExecutor_SessionStoreSurvivesExecutorReplacement(t *testing.T) { + sessionID := "test-session-store-survives-replace" + + globalCodexWebsocketSessionStore.mu.Lock() + delete(globalCodexWebsocketSessionStore.sessions, sessionID) + globalCodexWebsocketSessionStore.mu.Unlock() + + exec1 := NewCodexWebsocketsExecutor(nil) + sess1 := exec1.getOrCreateSession(sessionID) + if sess1 == nil { + t.Fatalf("expected session to be created") + } + + exec2 := NewCodexWebsocketsExecutor(nil) + sess2 := exec2.getOrCreateSession(sessionID) + if sess2 == nil { + t.Fatalf("expected session to be available across executors") + } + if sess1 != sess2 { + t.Fatalf("expected the same session instance across executors") + } + + exec1.CloseExecutionSession(cliproxyauth.CloseAllExecutionSessionsID) + + globalCodexWebsocketSessionStore.mu.Lock() + _, stillPresent := globalCodexWebsocketSessionStore.sessions[sessionID] + globalCodexWebsocketSessionStore.mu.Unlock() + if !stillPresent { + t.Fatalf("expected session to remain after executor replacement close marker") + } + + exec2.CloseExecutionSession(sessionID) + + globalCodexWebsocketSessionStore.mu.Lock() + _, presentAfterClose := globalCodexWebsocketSessionStore.sessions[sessionID] + globalCodexWebsocketSessionStore.mu.Unlock() + if presentAfterClose { + t.Fatalf("expected session to be removed after explicit close") + } +} diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index ffbd7289..3103554a 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -335,6 +335,7 @@ func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) { log.Errorf("failed to disable auth %s: %v", id, err) } if strings.EqualFold(strings.TrimSpace(existing.Provider), "codex") { + executor.CloseCodexWebsocketSessionsForAuthID(existing.ID, "auth_removed") s.ensureExecutorsForAuth(existing) } }