feat(session-affinity): add session-sticky routing for multi-account load balancing

When multiple auth credentials are configured, requests from the same
session are now routed to the same credential, improving upstream prompt
cache hit rates and maintaining context continuity.

Core components:
- SessionAffinitySelector: wraps RoundRobin/FillFirst selectors with
  session-to-auth binding; automatic failover when bound auth is
  unavailable, re-binding via the fallback selector for even distribution
- SessionCache: TTL-based in-memory cache with background cleanup
  goroutine, supporting per-session and per-auth invalidation
- StoppableSelector interface: lifecycle hook for selectors holding
  resources, called during Manager.StopAutoRefresh()

Session ID extraction priority (extractSessionIDs):
1. metadata.user_id with Claude Code session format (old
   user_{hash}_account__session_{uuid} string and new JSON object with a
   session_id field)
2. X-Session-ID header (generic client support)
3. metadata.user_id (non-Claude format, used as-is)
4. conversation_id field
5. Stable FNV hash from system prompt + first user/assistant messages
   (fallback for clients with no explicit session ID); returns both a
   full hash (primaryID) and a short hash without assistant content
   (fallbackID) to inherit bindings from the first turn

Multi-format message hash covers OpenAI messages, Claude system array,
Gemini contents/systemInstruction, and OpenAI Responses API input items
(including inline messages with role but no type field).

Configuration (config.yaml routing section):
- session-affinity: bool (default false)
- session-affinity-ttl: duration string (default "1h")
- claude-code-session-affinity: bool (deprecated alias for session-affinity)
All three fields trigger selector rebuild on config hot reload.

Side effect: Idempotency-Key header is no longer auto-generated with a
random UUID when absent — only forwarded when explicitly provided by the
client, to avoid polluting session hash extraction.
This commit is contained in:
sususu98
2026-04-15 00:48:08 +08:00
parent a4c1e32ff6
commit 7c24d54ca8
9 changed files with 1517 additions and 4 deletions

View File

@@ -14,7 +14,6 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
@@ -188,7 +187,7 @@ func PassthroughHeadersEnabled(cfg *config.SDKConfig) bool {
func requestExecutionMetadata(ctx context.Context) map[string]any {
// Idempotency-Key is an optional client-supplied header used to correlate retries.
// It is forwarded as execution metadata; when absent we generate a UUID.
// Only include it if the client explicitly provides it.
key := ""
if ctx != nil {
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
@@ -196,7 +195,7 @@ func requestExecutionMetadata(ctx context.Context) map[string]any {
}
}
if key == "" {
key = uuid.NewString()
return make(map[string]any)
}
meta := map[string]any{idempotencyKeyMetadataKey: key}

View File

@@ -105,6 +105,13 @@ type Selector interface {
Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error)
}
// StoppableSelector is an optional interface for selectors that hold resources.
// Selectors that implement this interface will have Stop called during shutdown.
// It embeds Selector, so an implementation must provide Pick as well as Stop.
type StoppableSelector interface {
Selector
// Stop releases whatever resources the selector holds.
Stop()
}
// Hook captures lifecycle callbacks for observing auth changes.
type Hook interface {
// OnAuthRegistered fires when a new auth is registered.
@@ -2928,6 +2935,7 @@ func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duratio
}
// StopAutoRefresh cancels the background refresh loop, if running.
// It also stops the selector if it implements StoppableSelector.
func (m *Manager) StopAutoRefresh() {
m.mu.Lock()
cancel := m.refreshCancel
@@ -2937,6 +2945,10 @@ func (m *Manager) StopAutoRefresh() {
if cancel != nil {
cancel()
}
// Stop selector if it implements StoppableSelector (e.g., SessionAffinitySelector)
if stoppable, ok := m.selector.(StoppableSelector); ok {
stoppable.Stop()
}
}
func (m *Manager) queueRefreshReschedule(authID string) {

View File

@@ -4,15 +4,21 @@ import (
"context"
"encoding/json"
"fmt"
"hash/fnv"
"math"
"math/rand/v2"
"net/http"
"regexp"
"sort"
"strconv"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
@@ -420,3 +426,448 @@ func isAuthBlockedForModel(auth *Auth, model string, now time.Time) (bool, block
}
return false, blockReasonNone, time.Time{}
}
// sessionPattern matches Claude Code user_id format:
// user_{hash}_account__session_{uuid}
// Only the trailing "_session_" marker is anchored: the single capture group
// takes the run of lowercase hex digits and dashes that ends the string, so
// a suffix containing any other character does not match.
var sessionPattern = regexp.MustCompile(`_session_([a-f0-9-]+)$`)
// SessionAffinitySelector wraps another selector with session-sticky behavior.
// It extracts session ID from multiple sources and maintains session-to-auth
// mappings with automatic failover when the bound auth becomes unavailable.
type SessionAffinitySelector struct {
// fallback picks an auth when there is no session ID, no binding yet, or
// the bound auth is not currently available.
fallback Selector
// cache stores the session-key to auth-ID bindings.
cache *SessionCache
}
// SessionAffinityConfig configures the session affinity selector.
type SessionAffinityConfig struct {
// Fallback is the wrapped selector; NewSessionAffinitySelectorWithConfig
// substitutes a RoundRobinSelector when it is nil.
Fallback Selector
// TTL is passed to the session cache; NewSessionAffinitySelectorWithConfig
// substitutes one hour when it is zero or negative.
TTL time.Duration
}
// NewSessionAffinitySelector builds a session-aware selector around fallback
// using a one-hour binding TTL. It is shorthand for
// NewSessionAffinitySelectorWithConfig with only Fallback and TTL set.
func NewSessionAffinitySelector(fallback Selector) *SessionAffinitySelector {
	var cfg SessionAffinityConfig
	cfg.Fallback = fallback
	cfg.TTL = time.Hour
	return NewSessionAffinitySelectorWithConfig(cfg)
}
// NewSessionAffinitySelectorWithConfig builds a session-aware selector from cfg.
// A nil Fallback is replaced by a RoundRobinSelector and a non-positive TTL by
// one hour. The returned selector owns a new SessionCache created with that TTL.
func NewSessionAffinitySelectorWithConfig(cfg SessionAffinityConfig) *SessionAffinitySelector {
	wrapped := cfg.Fallback
	if wrapped == nil {
		wrapped = &RoundRobinSelector{}
	}

	ttl := cfg.TTL
	if ttl <= 0 {
		ttl = time.Hour
	}

	selector := &SessionAffinitySelector{fallback: wrapped}
	selector.cache = NewSessionCache(ttl)
	return selector
}
// Pick returns the auth bound to the request's session when that auth is
// currently available, and otherwise delegates to the fallback selector and
// records the result as the session's binding.
//
// Session IDs come from extractSessionIDs, whose priority order is:
//  1. metadata.user_id (Claude Code format)
//  2. X-Session-ID header
//  3. metadata.user_id (non-Claude Code format)
//  4. conversation_id field
//  5. Hash-based fallback from messages
//
// Bindings are keyed by provider, session ID, and model, so one session that
// uses several models (or providers) keeps an independent binding for each.
// When no session ID can be extracted the fallback selector is used directly
// and nothing is cached.
func (s *SessionAffinitySelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
	logger := selectorLogEntry(ctx)

	primaryID, fallbackID := extractSessionIDs(opts.Headers, opts.OriginalRequest, opts.Metadata)
	if primaryID == "" {
		logger.Debugf("session-affinity: no session ID extracted, falling back to default selector | provider=%s model=%s", provider, model)
		return s.fallback.Pick(ctx, provider, model, opts, auths)
	}

	available, err := getAvailableAuths(auths, provider, model, time.Now())
	if err != nil {
		return nil, err
	}

	// keyFor builds the cache key for a session ID under this provider/model.
	keyFor := func(sessionID string) string {
		return provider + "::" + sessionID + "::" + model
	}
	// findAvailable returns the first available auth with the given ID, or nil.
	findAvailable := func(authID string) *Auth {
		for _, candidate := range available {
			if candidate.ID == authID {
				return candidate
			}
		}
		return nil
	}

	primaryKey := keyFor(primaryID)
	boundID, bound := s.cache.GetAndRefresh(primaryKey)
	switch {
	case bound:
		if auth := findAvailable(boundID); auth != nil {
			logger.Infof("session-affinity: cache hit | session=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), auth.ID, provider, model)
			return auth, nil
		}
		// Bound auth is not available right now: fall through and reselect.
	case fallbackID != "" && fallbackID != primaryID:
		// No binding under the primary ID yet; try to inherit the one stored
		// under the fallback ID and copy it to the primary key.
		if inheritedID, ok := s.cache.Get(keyFor(fallbackID)); ok {
			if auth := findAvailable(inheritedID); auth != nil {
				s.cache.Set(primaryKey, auth.ID)
				logger.Infof("session-affinity: fallback cache hit | session=%s fallback=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), truncateSessionID(fallbackID), auth.ID, provider, model)
				return auth, nil
			}
		}
	}

	// Either there was no usable binding or the bound auth is unavailable:
	// let the fallback selector choose so load stays evenly distributed.
	picked, err := s.fallback.Pick(ctx, provider, model, opts, auths)
	if err != nil {
		return nil, err
	}
	s.cache.Set(primaryKey, picked.ID)
	if bound {
		logger.Infof("session-affinity: cache hit but auth unavailable, reselected | session=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), picked.ID, provider, model)
	} else {
		logger.Infof("session-affinity: cache miss, new binding | session=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), picked.ID, provider, model)
	}
	return picked, nil
}
// selectorLogEntry returns a log entry for selector messages. When ctx is
// non-nil and carries a request ID, the entry is tagged with a "request_id"
// field; otherwise a plain entry on the standard logger is returned.
func selectorLogEntry(ctx context.Context) *log.Entry {
	if ctx != nil {
		if requestID := logging.GetRequestID(ctx); requestID != "" {
			return log.WithField("request_id", requestID)
		}
	}
	return log.NewEntry(log.StandardLogger())
}
// truncateSessionID shortens a session ID for logging.
// IDs of at most 20 bytes are returned unchanged. Longer IDs are cut to a
// prefix of at most 8 bytes followed by "...". The cut always lands on a rune
// boundary, so an ID containing multi-byte UTF-8 (for example a client-supplied
// user_id) never produces a log line with a split, invalid character.
func truncateSessionID(id string) string {
	const (
		maxFullLen = 20 // IDs up to this many bytes are logged in full
		prefixLen  = 8  // upper bound, in bytes, on the kept prefix
	)
	if len(id) <= maxFullLen {
		return id
	}
	// Ranging over a string yields the byte offset at which each rune starts;
	// keep the largest such offset that does not exceed prefixLen.
	cut := 0
	for i := range id {
		if i > prefixLen {
			break
		}
		cut = i
	}
	return id[:cut] + "..."
}
// Stop shuts down the selector's session cache, if it has one.
// It is a no-op when the cache is nil.
func (s *SessionAffinitySelector) Stop() {
	if s.cache == nil {
		return
	}
	s.cache.Stop()
}
// InvalidateAuth drops every cached session binding that points at authID,
// so those sessions are re-bound on their next request. Intended for use when
// an auth becomes rate-limited or otherwise unavailable. It is a no-op when
// the selector has no cache.
func (s *SessionAffinitySelector) InvalidateAuth(authID string) {
	if s.cache == nil {
		return
	}
	s.cache.InvalidateAuth(authID)
}
// ExtractSessionID returns the primary session identifier for a request, or
// "" when none can be derived. Sources are tried in this order:
//  1. metadata.user_id in Claude Code format (a trailing _session_{uuid}, or
//     a JSON object with a session_id field)
//  2. X-Session-ID header
//  3. metadata.user_id in any other format
//  4. conversation_id field in the request body
//  5. Stable hash of the first few messages' content (fallback)
//
// The secondary (fallback) ID produced by extractSessionIDs is discarded.
func ExtractSessionID(headers http.Header, payload []byte, metadata map[string]any) string {
	sessionID, _ := extractSessionIDs(headers, payload, metadata)
	return sessionID
}
// extractSessionIDs returns (primaryID, fallbackID) for session affinity.
//
// For explicit identifiers (Claude Code session, X-Session-ID header, plain
// user_id, conversation_id) only primaryID is set, prefixed with "claude:",
// "header:", "user:" or "conv:" respectively. Only the message-hash path
// (extractMessageHashIDs) can also return a fallbackID. Both results are ""
// when nothing usable is found.
//
// The metadata argument is accepted but not consulted.
func extractSessionIDs(headers http.Header, payload []byte, metadata map[string]any) (string, string) {
	// metadata.user_id feeds both step 1 and step 3, so read it once.
	var userID string
	if len(payload) > 0 {
		userID = gjson.GetBytes(payload, "metadata.user_id").String()
	}

	// 1. metadata.user_id carrying a Claude Code session (highest priority).
	if userID != "" {
		// Old format: user_{hash}_account__session_{uuid}
		if m := sessionPattern.FindStringSubmatch(userID); len(m) > 1 {
			return "claude:" + m[1], ""
		}
		// New format: JSON object with a session_id field, e.g.
		// {"device_id":"...","account_uuid":"...","session_id":"uuid"}
		if strings.HasPrefix(userID, "{") {
			if sessionID := gjson.Get(userID, "session_id").String(); sessionID != "" {
				return "claude:" + sessionID, ""
			}
		}
	}

	// 2. X-Session-ID header.
	if headers != nil {
		if headerID := headers.Get("X-Session-ID"); headerID != "" {
			return "header:" + headerID, ""
		}
	}

	switch {
	case len(payload) == 0:
		// Nothing left to inspect without a body.
		return "", ""
	case userID != "":
		// 3. metadata.user_id that is not in Claude Code format, used as-is.
		return "user:" + userID, ""
	}

	// 4. conversation_id field.
	if conversationID := gjson.GetBytes(payload, "conversation_id").String(); conversationID != "" {
		return "conv:" + conversationID, ""
	}

	// 5. Hash-based fallback from message content.
	return extractMessageHashIDs(payload)
}
// extractMessageHashIDs derives session IDs from message content when the
// request carries no explicit session identifier.
//
// It collects up to three snippets, each truncated to 100 bytes: the first
// system prompt, the first user message and the first assistant message.
// Request shapes are probed in this order:
//   - "messages" array (OpenAI / Claude)
//   - top-level "system" field (Claude), only if no system prompt was found yet
//   - "systemInstruction.parts" and "contents" (Gemini), only if neither a
//     system prompt nor a user message was found yet
//   - "instructions" and "input" (OpenAI Responses API), under the same condition
//
// Returns:
//   - ("", "") when neither a system prompt nor a user message was found
//   - (shortHash, "") when there is no assistant message; shortHash covers
//     system + user only
//   - (fullHash, shortHash) otherwise; fullHash additionally covers the first
//     assistant message, and shortHash is returned as the fallback ID
func extractMessageHashIDs(payload []byte) (primaryID, fallbackID string) {
var systemPrompt, firstUserMsg, firstAssistantMsg string
// OpenAI/Claude messages format
messages := gjson.GetBytes(payload, "messages")
if messages.Exists() && messages.IsArray() {
messages.ForEach(func(_, msg gjson.Result) bool {
role := msg.Get("role").String()
content := extractMessageContent(msg.Get("content"))
// Messages without extractable text are skipped.
if content == "" {
return true
}
// Only the first message of each role is recorded.
switch role {
case "system":
if systemPrompt == "" {
systemPrompt = truncateString(content, 100)
}
case "user":
if firstUserMsg == "" {
firstUserMsg = truncateString(content, 100)
}
case "assistant":
if firstAssistantMsg == "" {
firstAssistantMsg = truncateString(content, 100)
}
}
// Stop iterating once all three snippets are known.
if systemPrompt != "" && firstUserMsg != "" && firstAssistantMsg != "" {
return false
}
return true
})
}
// Claude API: top-level "system" field (array or string)
if systemPrompt == "" {
topSystem := gjson.GetBytes(payload, "system")
if topSystem.Exists() {
if topSystem.IsArray() {
// Use the first array element that has a non-empty "text".
topSystem.ForEach(func(_, part gjson.Result) bool {
if text := part.Get("text").String(); text != "" && systemPrompt == "" {
systemPrompt = truncateString(text, 100)
return false
}
return true
})
} else if topSystem.Type == gjson.String {
systemPrompt = truncateString(topSystem.String(), 100)
}
}
}
// Gemini format
if systemPrompt == "" && firstUserMsg == "" {
sysInstr := gjson.GetBytes(payload, "systemInstruction.parts")
if sysInstr.Exists() && sysInstr.IsArray() {
// Use the first part that has a non-empty "text".
sysInstr.ForEach(func(_, part gjson.Result) bool {
if text := part.Get("text").String(); text != "" && systemPrompt == "" {
systemPrompt = truncateString(text, 100)
return false
}
return true
})
}
contents := gjson.GetBytes(payload, "contents")
if contents.Exists() && contents.IsArray() {
contents.ForEach(func(_, msg gjson.Result) bool {
role := msg.Get("role").String()
// Within one content entry only the first part with non-empty text is
// considered; the inner loop stops right after it regardless of role.
msg.Get("parts").ForEach(func(_, part gjson.Result) bool {
text := part.Get("text").String()
if text == "" {
return true
}
switch role {
case "user":
if firstUserMsg == "" {
firstUserMsg = truncateString(text, 100)
}
case "model":
// Gemini's "model" role fills the assistant slot.
if firstAssistantMsg == "" {
firstAssistantMsg = truncateString(text, 100)
}
}
return false
})
if firstUserMsg != "" && firstAssistantMsg != "" {
return false
}
return true
})
}
}
// OpenAI Responses API format (v1/responses)
if systemPrompt == "" && firstUserMsg == "" {
// "instructions" takes the system-prompt slot; a later developer/system
// input item only fills it when "instructions" is empty.
if instr := gjson.GetBytes(payload, "instructions").String(); instr != "" {
systemPrompt = truncateString(instr, 100)
}
input := gjson.GetBytes(payload, "input")
if input.Exists() && input.IsArray() {
input.ForEach(func(_, item gjson.Result) bool {
itemType := item.Get("type").String()
if itemType == "reasoning" {
return true
}
// Skip non-message typed items (function_call, function_call_output, etc.)
// but allow items with no type that have a role (inline message format).
if itemType != "" && itemType != "message" {
return true
}
role := item.Get("role").String()
if itemType == "" && role == "" {
return true
}
// Handle both string content and array content (multimodal).
content := item.Get("content")
var text string
if content.Type == gjson.String {
text = content.String()
} else {
text = extractResponsesAPIContent(content)
}
if text == "" {
return true
}
switch role {
case "developer", "system":
if systemPrompt == "" {
systemPrompt = truncateString(text, 100)
}
case "user":
if firstUserMsg == "" {
firstUserMsg = truncateString(text, 100)
}
case "assistant":
if firstAssistantMsg == "" {
firstAssistantMsg = truncateString(text, 100)
}
}
if firstUserMsg != "" && firstAssistantMsg != "" {
return false
}
return true
})
}
}
// Nothing to hash: no session ID can be derived from this payload.
if systemPrompt == "" && firstUserMsg == "" {
return "", ""
}
// shortHash deliberately leaves out the assistant message.
shortHash := computeSessionHash(systemPrompt, firstUserMsg, "")
if firstAssistantMsg == "" {
return shortHash, ""
}
fullHash := computeSessionHash(systemPrompt, firstUserMsg, firstAssistantMsg)
return fullHash, shortHash
}
// computeSessionHash returns "msg:" followed by the 16-hex-digit FNV-1a 64-bit
// hash of the given snippets. Each non-empty snippet is fed to the hash as
// "<tag><text>\n" with tag "sys:", "usr:" or "ast:"; empty snippets contribute
// nothing, so the same text in a different role yields a different hash.
func computeSessionHash(systemPrompt, userMsg, assistantMsg string) string {
	segments := [...]struct{ tag, text string }{
		{"sys:", systemPrompt},
		{"usr:", userMsg},
		{"ast:", assistantMsg},
	}

	hasher := fnv.New64a()
	for _, seg := range segments {
		if seg.text == "" {
			continue
		}
		hasher.Write([]byte(seg.tag + seg.text + "\n"))
	}
	return fmt.Sprintf("msg:%016x", hasher.Sum64())
}
// truncateString returns s unchanged when it is at most maxLen bytes long,
// and otherwise its first maxLen bytes. The cut is byte-based.
func truncateString(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}
	return s[:maxLen]
}
// extractMessageContent returns the text of a message "content" field.
//   - String content ("Hello world") is returned as-is.
//   - Array content ([{"type":"text","text":"Hello"},{"type":"image",...}])
//     yields the non-empty "text" of every element whose "type" is "text",
//     joined with single spaces. Claude and OpenAI share this part shape.
//   - Anything else, or an array with no usable text, yields "".
func extractMessageContent(content gjson.Result) string {
	switch {
	case content.Type == gjson.String:
		return content.String()
	case content.IsArray():
		var joined strings.Builder
		content.ForEach(func(_, item gjson.Result) bool {
			if item.Get("type").String() != "text" {
				return true
			}
			text := item.Get("text").String()
			if text == "" {
				return true
			}
			if joined.Len() > 0 {
				joined.WriteByte(' ')
			}
			joined.WriteString(text)
			return true
		})
		return joined.String()
	default:
		return ""
	}
}
// extractResponsesAPIContent returns the text carried by an array-valued
// "content" field of an OpenAI Responses API input item. It joins, with single
// spaces, the non-empty "text" of every part whose "type" is "input_text",
// "output_text" or "text". Non-array content, or an array with no such text,
// yields "".
func extractResponsesAPIContent(content gjson.Result) string {
	if !content.IsArray() {
		return ""
	}

	var joined strings.Builder
	content.ForEach(func(_, item gjson.Result) bool {
		switch item.Get("type").String() {
		case "input_text", "output_text", "text":
			text := item.Get("text").String()
			if text == "" {
				break
			}
			if joined.Len() > 0 {
				joined.WriteByte(' ')
			}
			joined.WriteString(text)
		}
		return true
	})
	return joined.String()
}
// extractSessionID is kept for backward compatibility.
// Deprecated: Use ExtractSessionID instead.
func extractSessionID(payload []byte) string {
return ExtractSessionID(nil, payload, nil)
}

View File

@@ -4,7 +4,9 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"sync"
"testing"
"time"
@@ -458,6 +460,159 @@ func TestRoundRobinSelectorPick_GeminiCLICredentialGrouping(t *testing.T) {
}
}
// TestExtractSessionID checks body-only session ID extraction through the
// deprecated extractSessionID wrapper: both Claude Code user_id formats, plain
// user_id, conversation_id, and payloads with nothing to extract.
func TestExtractSessionID(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name    string
		payload string
		want    string
	}
	cases := []testCase{
		{
			name:    "valid_claude_code_format",
			payload: `{"metadata":{"user_id":"user_3f221fe75652cf9a89a31647f16274bb8036a9b85ac4dc226a4df0efec8dc04d_account__session_ac980658-63bd-4fb3-97ba-8da64cb1e344"}}`,
			want:    "claude:ac980658-63bd-4fb3-97ba-8da64cb1e344",
		},
		{
			name:    "json_user_id_with_session_id",
			payload: `{"metadata":{"user_id":"{\"device_id\":\"be82c3aee1e0c2d74535bacc85f9f559228f02dd8a17298cf522b71e6c375714\",\"account_uuid\":\"\",\"session_id\":\"e26d4046-0f88-4b09-bb5b-f863ab5fb24e\"}"}}`,
			want:    "claude:e26d4046-0f88-4b09-bb5b-f863ab5fb24e",
		},
		{
			name:    "json_user_id_without_session_id",
			payload: `{"metadata":{"user_id":"{\"device_id\":\"abc123\"}"}}`,
			want:    `user:{"device_id":"abc123"}`,
		},
		{
			name:    "no_session_but_user_id",
			payload: `{"metadata":{"user_id":"user_abc123"}}`,
			want:    "user:user_abc123",
		},
		{
			name:    "conversation_id",
			payload: `{"conversation_id":"conv-12345"}`,
			want:    "conv:conv-12345",
		},
		{
			name:    "no_metadata",
			payload: `{"model":"claude-3"}`,
			want:    "",
		},
		{
			name:    "empty_payload",
			payload: ``,
			want:    "",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := extractSessionID([]byte(tc.payload)); got != tc.want {
				t.Errorf("extractSessionID() = %q, want %q", got, tc.want)
			}
		})
	}
}
// TestSessionAffinitySelector_SameSessionSameAuth verifies that repeated picks
// for one session keep returning the auth chosen on the first pick.
func TestSessionAffinitySelector_SameSessionSameAuth(t *testing.T) {
	t.Parallel()

	fallback := &RoundRobinSelector{}
	selector := NewSessionAffinitySelector(fallback)
	// Release the selector's session cache when the test ends so nothing it
	// holds outlives the test.
	defer selector.Stop()

	auths := []*Auth{
		{ID: "auth-a"},
		{ID: "auth-b"},
		{ID: "auth-c"},
	}

	// Use valid UUID format for session ID
	payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_ac980658-63bd-4fb3-97ba-8da64cb1e344"}}`)
	opts := cliproxyexecutor.Options{OriginalRequest: payload}

	// Same session should always pick the same auth
	first, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths)
	if err != nil {
		t.Fatalf("Pick() error = %v", err)
	}
	if first == nil {
		t.Fatalf("Pick() returned nil")
	}

	// Verify consistency: same session, same auths -> same result
	for i := 0; i < 10; i++ {
		got, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths)
		if err != nil {
			t.Fatalf("Pick() #%d error = %v", i, err)
		}
		if got.ID != first.ID {
			t.Fatalf("Pick() #%d auth.ID = %q, want %q (same session should pick same auth)", i, got.ID, first.ID)
		}
	}
}
// TestSessionAffinitySelector_NoSessionFallback verifies that a request with
// no extractable session ID is handed straight to the wrapped selector.
func TestSessionAffinitySelector_NoSessionFallback(t *testing.T) {
	t.Parallel()

	fallback := &FillFirstSelector{}
	selector := NewSessionAffinitySelector(fallback)
	// Release the selector's session cache when the test ends so nothing it
	// holds outlives the test.
	defer selector.Stop()

	auths := []*Auth{
		{ID: "auth-b"},
		{ID: "auth-a"},
		{ID: "auth-c"},
	}

	// No session in payload, should fallback to FillFirstSelector (picks "auth-a" after sorting)
	payload := []byte(`{"model":"claude-3"}`)
	opts := cliproxyexecutor.Options{OriginalRequest: payload}

	got, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths)
	if err != nil {
		t.Fatalf("Pick() error = %v", err)
	}
	if got.ID != "auth-a" {
		t.Fatalf("Pick() auth.ID = %q, want %q (should fallback to FillFirst)", got.ID, "auth-a")
	}
}
// TestSessionAffinitySelector_DifferentSessionsDifferentAuths verifies that two
// distinct sessions each keep their own stable binding. The two sessions are
// not required to land on different auths; only per-session consistency is
// asserted.
func TestSessionAffinitySelector_DifferentSessionsDifferentAuths(t *testing.T) {
	t.Parallel()

	fallback := &RoundRobinSelector{}
	selector := NewSessionAffinitySelector(fallback)
	// Release the selector's session cache when the test ends so nothing it
	// holds outlives the test.
	defer selector.Stop()

	auths := []*Auth{
		{ID: "auth-a"},
		{ID: "auth-b"},
		{ID: "auth-c"},
	}

	// Use valid UUID format for session IDs
	session1 := []byte(`{"metadata":{"user_id":"user_xxx_account__session_11111111-1111-1111-1111-111111111111"}}`)
	session2 := []byte(`{"metadata":{"user_id":"user_xxx_account__session_22222222-2222-2222-2222-222222222222"}}`)

	opts1 := cliproxyexecutor.Options{OriginalRequest: session1}
	opts2 := cliproxyexecutor.Options{OriginalRequest: session2}

	// Check every error: a failed Pick returns a nil auth, and dereferencing
	// it below would turn a clear failure into a panic.
	auth1, err := selector.Pick(context.Background(), "claude", "claude-3", opts1, auths)
	if err != nil {
		t.Fatalf("session1 Pick() error = %v", err)
	}
	auth2, err := selector.Pick(context.Background(), "claude", "claude-3", opts2, auths)
	if err != nil {
		t.Fatalf("session2 Pick() error = %v", err)
	}

	// Different sessions may or may not pick different auths,
	// but each session should be consistent.
	for i := 0; i < 5; i++ {
		got1, err := selector.Pick(context.Background(), "claude", "claude-3", opts1, auths)
		if err != nil {
			t.Fatalf("session1 Pick() #%d error = %v", i, err)
		}
		got2, err := selector.Pick(context.Background(), "claude", "claude-3", opts2, auths)
		if err != nil {
			t.Fatalf("session2 Pick() #%d error = %v", i, err)
		}
		if got1.ID != auth1.ID {
			t.Fatalf("session1 Pick() #%d inconsistent: got %q, want %q", i, got1.ID, auth1.ID)
		}
		if got2.ID != auth2.ID {
			t.Fatalf("session2 Pick() #%d inconsistent: got %q, want %q", i, got2.ID, auth2.ID)
		}
	}
}
func TestRoundRobinSelectorPick_SingleParentFallsBackToFlat(t *testing.T) {
t.Parallel()
@@ -494,6 +649,57 @@ func TestRoundRobinSelectorPick_SingleParentFallsBackToFlat(t *testing.T) {
}
}
// TestSessionAffinitySelector_FailoverWhenAuthUnavailable verifies that when
// the auth bound to a session disappears from the candidate list, the session
// is re-bound to a different auth and then stays on that new auth.
func TestSessionAffinitySelector_FailoverWhenAuthUnavailable(t *testing.T) {
	t.Parallel()

	fallback := &RoundRobinSelector{}
	selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{
		Fallback: fallback,
		TTL:      time.Minute,
	})
	defer selector.Stop()

	auths := []*Auth{
		{ID: "auth-a"},
		{ID: "auth-b"},
		{ID: "auth-c"},
	}

	payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_failover-test-uuid"}}`)
	opts := cliproxyexecutor.Options{OriginalRequest: payload}

	// First pick establishes binding
	first, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths)
	if err != nil {
		t.Fatalf("Pick() error = %v", err)
	}

	// Remove the bound auth from available list (simulating rate limit)
	availableWithoutFirst := make([]*Auth, 0, len(auths)-1)
	for _, a := range auths {
		if a.ID != first.ID {
			availableWithoutFirst = append(availableWithoutFirst, a)
		}
	}

	// With failover enabled, should pick a new auth
	second, err := selector.Pick(context.Background(), "claude", "claude-3", opts, availableWithoutFirst)
	if err != nil {
		t.Fatalf("Pick() after failover error = %v", err)
	}
	if second.ID == first.ID {
		t.Fatalf("Pick() after failover returned same auth %q, expected different", first.ID)
	}

	// Subsequent picks should consistently return the new binding.
	// The error is checked so a failed Pick reports itself instead of
	// panicking on a nil auth.
	for i := 0; i < 5; i++ {
		got, err := selector.Pick(context.Background(), "claude", "claude-3", opts, availableWithoutFirst)
		if err != nil {
			t.Fatalf("Pick() #%d after failover error = %v", i, err)
		}
		if got.ID != second.ID {
			t.Fatalf("Pick() #%d after failover inconsistent: got %q, want %q", i, got.ID, second.ID)
		}
	}
}
func TestRoundRobinSelectorPick_MixedVirtualAndNonVirtualFallsBackToFlat(t *testing.T) {
t.Parallel()
@@ -527,3 +733,629 @@ func TestRoundRobinSelectorPick_MixedVirtualAndNonVirtualFallsBackToFlat(t *test
}
}
}
// TestExtractSessionID_ClaudeCodePriorityOverHeader verifies that a Claude Code
// session in metadata.user_id wins over an X-Session-ID header on the same
// request.
func TestExtractSessionID_ClaudeCodePriorityOverHeader(t *testing.T) {
	t.Parallel()

	const want = "claude:ac980658-63bd-4fb3-97ba-8da64cb1e344"

	body := []byte(`{"metadata":{"user_id":"user_xxx_account__session_ac980658-63bd-4fb3-97ba-8da64cb1e344"}}`)
	hdr := http.Header{}
	hdr.Set("X-Session-ID", "header-session-id")

	if got := ExtractSessionID(hdr, body, nil); got != want {
		t.Errorf("ExtractSessionID() = %q, want %q (Claude Code should have highest priority over header)", got, want)
	}
}
// TestExtractSessionID_ClaudeCodePriorityOverIdempotencyKey verifies that a
// Claude Code session in metadata.user_id is returned even when the execution
// metadata carries an idempotency_key.
func TestExtractSessionID_ClaudeCodePriorityOverIdempotencyKey(t *testing.T) {
	t.Parallel()

	const want = "claude:ac980658-63bd-4fb3-97ba-8da64cb1e344"

	body := []byte(`{"metadata":{"user_id":"user_xxx_account__session_ac980658-63bd-4fb3-97ba-8da64cb1e344"}}`)
	execMeta := map[string]any{"idempotency_key": "idem-12345"}

	if got := ExtractSessionID(nil, body, execMeta); got != want {
		t.Errorf("ExtractSessionID() = %q, want %q (Claude Code should have highest priority over idempotency_key)", got, want)
	}
}
// TestExtractSessionID_Headers verifies that an X-Session-ID header alone, with
// no request body, yields a "header:"-prefixed session ID.
func TestExtractSessionID_Headers(t *testing.T) {
	t.Parallel()

	const want = "header:my-explicit-session"

	hdr := http.Header{}
	hdr.Set("X-Session-ID", "my-explicit-session")

	if got := ExtractSessionID(hdr, nil, nil); got != want {
		t.Errorf("ExtractSessionID() with header = %q, want %q", got, want)
	}
}
// TestExtractSessionID_IdempotencyKey verifies that an idempotency_key in the
// execution metadata is deliberately not used as a session ID: with no headers
// and no payload to hash, extraction must come back empty.
func TestExtractSessionID_IdempotencyKey(t *testing.T) {
	t.Parallel()

	execMeta := map[string]any{"idempotency_key": "idem-12345"}

	// idempotency_key is disabled - should return empty (no payload to hash)
	if got := ExtractSessionID(nil, nil, execMeta); got != "" {
		t.Errorf("ExtractSessionID() with idempotency_key = %q, want empty (idempotency_key is disabled)", got)
	}
}
// TestExtractSessionID_MessageHashFallback covers the content-hash fallback for
// OpenAI-style "messages": a first-turn request yields a "msg:" hash, a
// multi-turn request yields a different hash (it includes the assistant
// reply), and hashing the same payload twice is stable.
func TestExtractSessionID_MessageHashFallback(t *testing.T) {
	t.Parallel()

	// First request (user only) generates short hash
	firstTurn := []byte(`{"messages":[{"role":"user","content":"Hello world"}]}`)
	short := ExtractSessionID(nil, firstTurn, nil)
	if short == "" {
		t.Error("ExtractSessionID() first request should return short hash")
	}
	if !strings.HasPrefix(short, "msg:") {
		t.Errorf("ExtractSessionID() = %q, want prefix 'msg:'", short)
	}

	// Multi-turn with assistant generates full hash (different from short hash)
	multiTurn := []byte(`{"messages":[
{"role":"user","content":"Hello world"},
{"role":"assistant","content":"Hi! How can I help?"},
{"role":"user","content":"Tell me a joke"}
]}`)
	full := ExtractSessionID(nil, multiTurn, nil)
	if full == "" {
		t.Error("ExtractSessionID() multi-turn should return full hash")
	}
	if full == short {
		t.Error("Full hash should differ from short hash (includes assistant)")
	}

	// Same multi-turn payload should produce same hash
	if again := ExtractSessionID(nil, multiTurn, nil); full != again {
		t.Errorf("ExtractSessionID() not stable: got %q then %q", full, again)
	}
}
// TestExtractSessionID_ClaudeAPITopLevelSystem covers Claude API payloads whose
// system prompt lives in the top-level "system" field, as an array of text
// parts or as a plain string, and checks that a multi-turn request hashes
// differently from the first turn.
func TestExtractSessionID_ClaudeAPITopLevelSystem(t *testing.T) {
	t.Parallel()

	// Claude API: system prompt in top-level "system" field (array format)
	arraySystem := []byte(`{
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
"system": [{"type": "text", "text": "You are Claude Code"}]
}`)
	// An empty result has no "msg:" prefix, so one prefix check covers both.
	fromArray := ExtractSessionID(nil, arraySystem, nil)
	if !strings.HasPrefix(fromArray, "msg:") {
		t.Errorf("ExtractSessionID() with array system = %q, want msg:* prefix", fromArray)
	}

	// Claude API: system prompt in top-level "system" field (string format)
	stringSystem := []byte(`{
"messages": [{"role": "user", "content": "Hello"}],
"system": "You are Claude Code"
}`)
	fromString := ExtractSessionID(nil, stringSystem, nil)
	if !strings.HasPrefix(fromString, "msg:") {
		t.Errorf("ExtractSessionID() with string system = %q, want msg:* prefix", fromString)
	}

	// Multi-turn with top-level system should produce stable hash
	multiTurn := []byte(`{
"messages": [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi!"},
{"role": "user", "content": "Help me"}
],
"system": "You are Claude Code"
}`)
	fromMultiTurn := ExtractSessionID(nil, multiTurn, nil)
	if fromMultiTurn == "" {
		t.Error("ExtractSessionID() multi-turn with top-level system should return hash")
	}
	if fromMultiTurn == fromString {
		t.Error("Multi-turn hash should differ from first-turn hash (includes assistant)")
	}
}
// TestExtractSessionID_GeminiFormat covers Gemini payloads (systemInstruction
// plus contents): extraction yields a stable "msg:" hash, and changing the
// first user message changes the hash.
func TestExtractSessionID_GeminiFormat(t *testing.T) {
	t.Parallel()

	// Gemini format with systemInstruction and contents
	base := []byte(`{
"systemInstruction": {"parts": [{"text": "You are a helpful assistant."}]},
"contents": [
{"role": "user", "parts": [{"text": "Hello Gemini"}]},
{"role": "model", "parts": [{"text": "Hi there!"}]}
]
}`)
	baseID := ExtractSessionID(nil, base, nil)
	if baseID == "" {
		t.Error("ExtractSessionID() with Gemini format should return hash-based session ID")
	}
	if !strings.HasPrefix(baseID, "msg:") {
		t.Errorf("ExtractSessionID() = %q, want prefix 'msg:'", baseID)
	}

	// Same payload should produce same hash
	if repeatID := ExtractSessionID(nil, base, nil); baseID != repeatID {
		t.Errorf("ExtractSessionID() not stable: got %q then %q", baseID, repeatID)
	}

	// Different user message should produce different hash
	changed := []byte(`{
"systemInstruction": {"parts": [{"text": "You are a helpful assistant."}]},
"contents": [
{"role": "user", "parts": [{"text": "Hello different"}]},
{"role": "model", "parts": [{"text": "Hi there!"}]}
]
}`)
	if changedID := ExtractSessionID(nil, changed, nil); baseID == changedID {
		t.Errorf("ExtractSessionID() should produce different hash for different user message")
	}
}
// TestExtractSessionID_OpenAIResponsesAPI covers OpenAI Responses API payloads
// ("instructions" plus typed "input" items). It checks that:
//   - a first turn yields a "msg:" hash,
//   - a second turn, which adds an assistant reply, yields a different hash,
//   - a third turn yields the same hash as the second, because only the first
//     assistant message is hashed.
func TestExtractSessionID_OpenAIResponsesAPI(t *testing.T) {
	t.Parallel()

	firstTurn := []byte(`{
"instructions": "You are Codex, based on GPT-5.",
"input": [
{"type": "message", "role": "developer", "content": [{"type": "input_text", "text": "system instructions"}]},
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "hi"}]}
]
}`)
	got1 := ExtractSessionID(nil, firstTurn, nil)
	if got1 == "" {
		t.Error("ExtractSessionID() should return hash for OpenAI Responses API format")
	}
	if !strings.HasPrefix(got1, "msg:") {
		t.Errorf("ExtractSessionID() = %q, want prefix 'msg:'", got1)
	}

	secondTurn := []byte(`{
"instructions": "You are Codex, based on GPT-5.",
"input": [
{"type": "message", "role": "developer", "content": [{"type": "input_text", "text": "system instructions"}]},
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "hi"}]},
{"type": "reasoning", "summary": [{"type": "summary_text", "text": "thinking..."}], "encrypted_content": "xxx"},
{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "Hello!"}]},
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "what can you do"}]}
]
}`)
	got2 := ExtractSessionID(nil, secondTurn, nil)
	if got2 == "" {
		t.Error("ExtractSessionID() should return hash for second turn")
	}
	// The second turn includes the assistant reply, so its hash must differ
	// from the first turn's. Equal hashes are a failure, not just a log line.
	if got1 == got2 {
		t.Errorf("First and second turn should have different hashes (second includes assistant): both %q", got1)
	}

	thirdTurn := []byte(`{
"instructions": "You are Codex, based on GPT-5.",
"input": [
{"type": "message", "role": "developer", "content": [{"type": "input_text", "text": "system instructions"}]},
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "hi"}]},
{"type": "reasoning", "summary": [{"type": "summary_text", "text": "thinking..."}], "encrypted_content": "xxx"},
{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "Hello!"}]},
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "what can you do"}]},
{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "I can help with..."}]},
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "thanks"}]}
]
}`)
	got3 := ExtractSessionID(nil, thirdTurn, nil)
	if got2 != got3 {
		t.Errorf("Second and third turn should have same hash (same first assistant): got %q vs %q", got2, got3)
	}
}
// TestSessionAffinitySelector_ThreeScenarios drives the selector with payloads
// that carry no explicit session ID, at three conversation stages ("new",
// "second", "many") for the payload shapes named in each case.
//
// The table-driven part is a smoke test: it only requires Pick to succeed and
// return a non-nil auth. The two trailing subtests assert stickiness:
//   - a second-turn and a later-turn request of one conversation pick the same auth;
//   - a second-turn request picks the same auth as the first request of that
//     conversation.
func TestSessionAffinitySelector_ThreeScenarios(t *testing.T) {
	t.Parallel()
	fallback := &RoundRobinSelector{}
	selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{
		Fallback: fallback,
		TTL:      time.Minute,
	})
	defer selector.Stop()
	auths := []*Auth{{ID: "auth-a"}, {ID: "auth-b"}, {ID: "auth-c"}}

	// mustPick fails the calling (sub)test on an error or nil result, so that
	// later assertions never dereference a nil *Auth.
	mustPick := func(t *testing.T, provider string, payload []byte) *Auth {
		t.Helper()
		opts := cliproxyexecutor.Options{OriginalRequest: payload}
		picked, err := selector.Pick(context.Background(), provider, "model", opts, auths)
		if err != nil {
			t.Fatalf("Pick() error = %v", err)
		}
		if picked == nil {
			t.Fatal("Pick() returned nil")
		}
		return picked
	}

	testCases := []struct {
		name     string
		scenario string
		payload  []byte
	}{
		{
			name:     "OpenAI_Scenario1_NewRequest",
			scenario: "new",
			payload:  []byte(`{"messages":[{"role":"system","content":"You are helpful"},{"role":"user","content":"Hello"}]}`),
		},
		{
			name:     "OpenAI_Scenario2_SecondTurn",
			scenario: "second",
			payload:  []byte(`{"messages":[{"role":"system","content":"You are helpful"},{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi there!"},{"role":"user","content":"Help me"}]}`),
		},
		{
			name:     "OpenAI_Scenario3_ManyTurns",
			scenario: "many",
			payload:  []byte(`{"messages":[{"role":"system","content":"You are helpful"},{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi there!"},{"role":"user","content":"Help me"},{"role":"assistant","content":"Sure!"},{"role":"user","content":"Thanks"}]}`),
		},
		{
			name:     "Gemini_Scenario1_NewRequest",
			scenario: "new",
			payload:  []byte(`{"systemInstruction":{"parts":[{"text":"You are helpful"}]},"contents":[{"role":"user","parts":[{"text":"Hello Gemini"}]}]}`),
		},
		{
			name:     "Gemini_Scenario2_SecondTurn",
			scenario: "second",
			payload:  []byte(`{"systemInstruction":{"parts":[{"text":"You are helpful"}]},"contents":[{"role":"user","parts":[{"text":"Hello Gemini"}]},{"role":"model","parts":[{"text":"Hi!"}]},{"role":"user","parts":[{"text":"Help"}]}]}`),
		},
		{
			name:     "Gemini_Scenario3_ManyTurns",
			scenario: "many",
			payload:  []byte(`{"systemInstruction":{"parts":[{"text":"You are helpful"}]},"contents":[{"role":"user","parts":[{"text":"Hello Gemini"}]},{"role":"model","parts":[{"text":"Hi!"}]},{"role":"user","parts":[{"text":"Help"}]},{"role":"model","parts":[{"text":"Sure!"}]},{"role":"user","parts":[{"text":"Thanks"}]}]}`),
		},
		{
			name:     "Claude_Scenario1_NewRequest",
			scenario: "new",
			payload:  []byte(`{"messages":[{"role":"user","content":"Hello Claude"}]}`),
		},
		{
			name:     "Claude_Scenario2_SecondTurn",
			scenario: "second",
			payload:  []byte(`{"messages":[{"role":"user","content":"Hello Claude"},{"role":"assistant","content":"Hello!"},{"role":"user","content":"Help me"}]}`),
		},
		{
			name:     "Claude_Scenario3_ManyTurns",
			scenario: "many",
			payload:  []byte(`{"messages":[{"role":"user","content":"Hello Claude"},{"role":"assistant","content":"Hello!"},{"role":"user","content":"Help"},{"role":"assistant","content":"Sure!"},{"role":"user","content":"Thanks"}]}`),
		},
	}
	for _, tc := range testCases {
		tc := tc // per-iteration copy for toolchains without per-iteration loop variables
		t.Run(tc.name, func(t *testing.T) {
			picked := mustPick(t, "provider", tc.payload)
			t.Logf("%s [scenario=%s]: picked %s", tc.name, tc.scenario, picked.ID)
		})
	}

	t.Run("Scenario2And3_SameAuth", func(t *testing.T) {
		openaiS2 := []byte(`{"messages":[{"role":"system","content":"Stable test"},{"role":"user","content":"First msg"},{"role":"assistant","content":"Response"},{"role":"user","content":"Second"}]}`)
		openaiS3 := []byte(`{"messages":[{"role":"system","content":"Stable test"},{"role":"user","content":"First msg"},{"role":"assistant","content":"Response"},{"role":"user","content":"Second"},{"role":"assistant","content":"More"},{"role":"user","content":"Third"}]}`)
		picked2 := mustPick(t, "test", openaiS2)
		picked3 := mustPick(t, "test", openaiS3)
		if picked2.ID != picked3.ID {
			t.Errorf("Scenario2 and Scenario3 should pick same auth: got %s vs %s", picked2.ID, picked3.ID)
		}
	})

	t.Run("Scenario1To2_InheritBinding", func(t *testing.T) {
		s1 := []byte(`{"messages":[{"role":"system","content":"Inherit test"},{"role":"user","content":"Initial"}]}`)
		s2 := []byte(`{"messages":[{"role":"system","content":"Inherit test"},{"role":"user","content":"Initial"},{"role":"assistant","content":"Reply"},{"role":"user","content":"Continue"}]}`)
		picked1 := mustPick(t, "inherit", s1)
		picked2 := mustPick(t, "inherit", s2)
		if picked1.ID != picked2.ID {
			t.Errorf("Scenario2 should inherit Scenario1 binding: got %s vs %s", picked1.ID, picked2.ID)
		}
	})
}
// TestSessionAffinitySelector_MultiModelSession checks that one session can be
// served by different auths depending on the requested model.
//
// Model support is simulated through the candidate list: requests for model-a
// are offered only auth-a and requests for model-b only auth-b. The test
// requires that a pick made for one model never prevents or changes the pick
// for the other model, including when the two are interleaved repeatedly.
func TestSessionAffinitySelector_MultiModelSession(t *testing.T) {
	t.Parallel()
	fallback := &RoundRobinSelector{}
	selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{
		Fallback: fallback,
		TTL:      time.Minute,
	})
	defer selector.Stop()

	authA := &Auth{ID: "auth-a"}
	authB := &Auth{ID: "auth-b"}
	authsForModelA := []*Auth{authA}
	authsForModelB := []*Auth{authB}

	// The same session ID is used for every request.
	payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_multi-model-test"}}`)
	opts := cliproxyexecutor.Options{OriginalRequest: payload}

	// Request model-a with only auth-a available for that model.
	pickedA, err := selector.Pick(context.Background(), "provider", "model-a", opts, authsForModelA)
	if err != nil {
		t.Fatalf("Pick() for model-a error = %v", err)
	}
	if pickedA == nil {
		t.Fatal("Pick() for model-a returned nil auth")
	}
	if pickedA.ID != "auth-a" {
		t.Fatalf("Pick() for model-a = %q, want auth-a", pickedA.ID)
	}

	// Request model-b with only auth-b available for that model.
	pickedB, err := selector.Pick(context.Background(), "provider", "model-b", opts, authsForModelB)
	if err != nil {
		t.Fatalf("Pick() for model-b error = %v", err)
	}
	if pickedB == nil {
		t.Fatal("Pick() for model-b returned nil auth")
	}
	if pickedB.ID != "auth-b" {
		t.Fatalf("Pick() for model-b = %q, want auth-b", pickedB.ID)
	}

	// Switching back to model-a must still yield auth-a.
	pickedA2, err := selector.Pick(context.Background(), "provider", "model-a", opts, authsForModelA)
	if err != nil {
		t.Fatalf("Pick() for model-a (2nd) error = %v", err)
	}
	if pickedA2 == nil {
		t.Fatal("Pick() for model-a (2nd) returned nil auth")
	}
	if pickedA2.ID != "auth-a" {
		t.Fatalf("Pick() for model-a (2nd) = %q, want auth-a", pickedA2.ID)
	}

	// Interleave both models several times; errors are checked before the
	// result is dereferenced so a failure is reported instead of panicking.
	for i := 0; i < 5; i++ {
		gotA, errA := selector.Pick(context.Background(), "provider", "model-a", opts, authsForModelA)
		if errA != nil {
			t.Fatalf("Pick() #%d for model-a error = %v", i, errA)
		}
		if gotA == nil {
			t.Fatalf("Pick() #%d for model-a returned nil auth", i)
		}
		gotB, errB := selector.Pick(context.Background(), "provider", "model-b", opts, authsForModelB)
		if errB != nil {
			t.Fatalf("Pick() #%d for model-b error = %v", i, errB)
		}
		if gotB == nil {
			t.Fatalf("Pick() #%d for model-b returned nil auth", i)
		}
		if gotA.ID != "auth-a" {
			t.Fatalf("Pick() #%d for model-a = %q, want auth-a", i, gotA.ID)
		}
		if gotB.ID != "auth-b" {
			t.Fatalf("Pick() #%d for model-b = %q, want auth-b", i, gotB.ID)
		}
	}
}
// TestExtractSessionID_MultimodalContent covers message content given as an
// array of typed parts (text plus image) instead of a plain string.
//
// It requires that:
//   - a single-turn request yields a non-empty ID with the "msg:" prefix;
//   - the same conversation with an assistant reply yields a different ID;
//   - a conversation with different user text yields yet another ID.
func TestExtractSessionID_MultimodalContent(t *testing.T) {
	t.Parallel()

	// All calls pass nil for the first and third ExtractSessionID arguments.
	idFor := func(body string) string {
		return ExtractSessionID(nil, []byte(body), nil)
	}

	// Single user turn.
	shortHash := idFor(`{"messages":[{"role":"user","content":[{"type":"text","text":"Hello world"},{"type":"image","source":{"data":"..."}}]}]}`)
	if len(shortHash) == 0 {
		t.Error("ExtractSessionID() first request should return short hash")
	}
	if hasPrefix := strings.HasPrefix(shortHash, "msg:"); !hasPrefix {
		t.Errorf("ExtractSessionID() = %q, want prefix 'msg:'", shortHash)
	}

	// Same first turn, now followed by an assistant reply and a second user turn.
	fullHash := idFor(`{"messages":[
{"role":"user","content":[{"type":"text","text":"Hello world"},{"type":"image","source":{"data":"..."}}]},
{"role":"assistant","content":"I see an image!"},
{"role":"user","content":"What is it?"}
]}`)
	switch fullHash {
	case "":
		t.Error("ExtractSessionID() multimodal multi-turn should return full hash")
	}
	if shortHash == fullHash {
		t.Error("Full hash should differ from short hash")
	}

	// A conversation with different text must not collide with the one above.
	differentHash := idFor(`{"messages":[
{"role":"user","content":[{"type":"text","text":"Different content"}]},
{"role":"assistant","content":"I see something different!"}
]}`)
	if differentHash == fullHash {
		t.Error("ExtractSessionID() should produce different hash for different content")
	}
}
// TestSessionAffinitySelector_CrossProviderIsolation checks that one session
// ID used against two different providers is served by the auth offered for
// each provider.
//
// Each provider is offered exactly one candidate auth, so the test requires
// that a pick made under "claude" never prevents or changes the pick made
// under "gemini" (and vice versa), including when the two are interleaved.
func TestSessionAffinitySelector_CrossProviderIsolation(t *testing.T) {
	t.Parallel()
	fallback := &RoundRobinSelector{}
	selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{
		Fallback: fallback,
		TTL:      time.Minute,
	})
	defer selector.Stop()

	authClaude := &Auth{ID: "auth-claude"}
	authGemini := &Auth{ID: "auth-gemini"}

	// The same session ID is used for both providers.
	payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_cross-provider-test"}}`)
	opts := cliproxyexecutor.Options{OriginalRequest: payload}

	// Request via the claude provider.
	pickedClaude, err := selector.Pick(context.Background(), "claude", "claude-3", opts, []*Auth{authClaude})
	if err != nil {
		t.Fatalf("Pick() for claude error = %v", err)
	}
	if pickedClaude == nil {
		t.Fatal("Pick() for claude returned nil auth")
	}
	if pickedClaude.ID != "auth-claude" {
		t.Fatalf("Pick() for claude = %q, want auth-claude", pickedClaude.ID)
	}

	// Same session via the gemini provider must get the gemini auth.
	pickedGemini, err := selector.Pick(context.Background(), "gemini", "gemini-2.5-pro", opts, []*Auth{authGemini})
	if err != nil {
		t.Fatalf("Pick() for gemini error = %v", err)
	}
	if pickedGemini == nil {
		t.Fatal("Pick() for gemini returned nil auth")
	}
	if pickedGemini.ID != "auth-gemini" {
		t.Fatalf("Pick() for gemini = %q, want auth-gemini", pickedGemini.ID)
	}

	// Interleave both providers several times; errors are checked before the
	// result is dereferenced so a failure is reported instead of panicking.
	for i := 0; i < 5; i++ {
		gotC, errC := selector.Pick(context.Background(), "claude", "claude-3", opts, []*Auth{authClaude})
		if errC != nil {
			t.Fatalf("Pick() #%d for claude error = %v", i, errC)
		}
		if gotC == nil {
			t.Fatalf("Pick() #%d for claude returned nil auth", i)
		}
		gotG, errG := selector.Pick(context.Background(), "gemini", "gemini-2.5-pro", opts, []*Auth{authGemini})
		if errG != nil {
			t.Fatalf("Pick() #%d for gemini error = %v", i, errG)
		}
		if gotG == nil {
			t.Fatalf("Pick() #%d for gemini returned nil auth", i)
		}
		if gotC.ID != "auth-claude" {
			t.Fatalf("Pick() #%d for claude = %q, want auth-claude", i, gotC.ID)
		}
		if gotG.ID != "auth-gemini" {
			t.Fatalf("Pick() #%d for gemini = %q, want auth-gemini", i, gotG.ID)
		}
	}
}
// TestSessionCache_GetAndRefresh verifies the sliding expiry of GetAndRefresh
// using a 100ms TTL: every hit must push the deadline forward, and an entry
// left untouched for longer than the TTL must be reported as missing.
//
// The test is wall-clock based (three sleeps, 230ms in total).
func TestSessionCache_GetAndRefresh(t *testing.T) {
	t.Parallel()

	const (
		ttl         = 100 * time.Millisecond
		withinTTL   = 60 * time.Millisecond  // shorter than ttl, longer than ttl/2
		beyondTTL   = 110 * time.Millisecond // longer than ttl
		sessionKey  = "session1"
		boundAuthID = "auth1"
	)

	cache := NewSessionCache(ttl)
	defer cache.Stop()
	cache.Set(sessionKey, boundAuthID)

	// Immediately after Set the binding must be visible.
	authID, found := cache.GetAndRefresh(sessionKey)
	if !(found && authID == boundAuthID) {
		t.Fatalf("GetAndRefresh() = %q, %v, want auth1, true", authID, found)
	}

	// 60ms later the entry is still inside its TTL; this hit refreshes it.
	time.Sleep(withinTTL)
	authID, found = cache.GetAndRefresh(sessionKey)
	if !(found && authID == boundAuthID) {
		t.Fatalf("GetAndRefresh() after 60ms = %q, %v, want auth1, true", authID, found)
	}

	// 120ms after Set, which is past the original deadline but only 60ms after
	// the refresh, so the entry must still be alive.
	time.Sleep(withinTTL)
	authID, found = cache.GetAndRefresh(sessionKey)
	if !(found && authID == boundAuthID) {
		t.Fatalf("GetAndRefresh() after refresh = %q, %v, want auth1, true (TTL should have been refreshed)", authID, found)
	}

	// Leave the entry untouched for more than a full TTL: it must be gone.
	time.Sleep(beyondTTL)
	if authID, found = cache.GetAndRefresh(sessionKey); found {
		t.Fatalf("GetAndRefresh() after expiry = %q, %v, want '', false", authID, found)
	}
}
func TestSessionAffinitySelector_RoundRobinDistribution(t *testing.T) {
t.Parallel()
fallback := &RoundRobinSelector{}
selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{
Fallback: fallback,
TTL: time.Minute,
})
defer selector.Stop()
auths := []*Auth{
{ID: "auth-a"},
{ID: "auth-b"},
{ID: "auth-c"},
}
sessionCount := 12
counts := make(map[string]int)
for i := 0; i < sessionCount; i++ {
payload := []byte(fmt.Sprintf(`{"metadata":{"user_id":"user_xxx_account__session_%08d-0000-0000-0000-000000000000"}}`, i))
opts := cliproxyexecutor.Options{OriginalRequest: payload}
got, err := selector.Pick(context.Background(), "provider", "model", opts, auths)
if err != nil {
t.Fatalf("Pick() session %d error = %v", i, err)
}
counts[got.ID]++
}
expected := sessionCount / len(auths)
for _, auth := range auths {
got := counts[auth.ID]
if got != expected {
t.Errorf("auth %s got %d sessions, want %d (round-robin should distribute evenly)", auth.ID, got, expected)
}
}
}
func TestSessionAffinitySelector_Concurrent(t *testing.T) {
t.Parallel()
fallback := &RoundRobinSelector{}
selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{
Fallback: fallback,
TTL: time.Minute,
})
defer selector.Stop()
auths := []*Auth{
{ID: "auth-a"},
{ID: "auth-b"},
{ID: "auth-c"},
}
payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_concurrent-test"}}`)
opts := cliproxyexecutor.Options{OriginalRequest: payload}
// First pick to establish binding
first, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths)
if err != nil {
t.Fatalf("Initial Pick() error = %v", err)
}
expectedID := first.ID
start := make(chan struct{})
var wg sync.WaitGroup
errCh := make(chan error, 1)
goroutines := 32
iterations := 50
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
<-start
for j := 0; j < iterations; j++ {
got, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths)
if err != nil {
select {
case errCh <- err:
default:
}
return
}
if got.ID != expectedID {
select {
case errCh <- fmt.Errorf("concurrent Pick() returned %q, want %q", got.ID, expectedID):
default:
}
return
}
}
}()
}
close(start)
wg.Wait()
select {
case err := <-errCh:
t.Fatalf("concurrent Pick() error = %v", err)
default:
}
}

View File

@@ -0,0 +1,152 @@
package auth
import (
"sync"
"time"
)
// sessionEntry is the value SessionCache keeps per session ID: the auth the
// session is bound to and the instant at which that binding stops being valid.
type sessionEntry struct {
// authID is the ID of the auth the session is bound to.
authID string
// expiresAt is the binding's deadline; the cache treats the entry as expired
// once the current time is after it.
expiresAt time.Time
}
// SessionCache provides TTL-based session to auth mapping with automatic cleanup.
// Values are created with NewSessionCache, which also starts the background
// cleanup goroutine; Stop ends that goroutine.
type SessionCache struct {
// mu guards entries.
mu sync.RWMutex
// entries maps a session ID to its auth binding and expiry.
entries map[string]sessionEntry
// ttl is the lifetime given to an entry by Set and by a GetAndRefresh hit.
ttl time.Duration
// stopCh is closed by Stop to make the cleanup loop return.
stopCh chan struct{}
}
// NewSessionCache returns an empty cache whose entries live for ttl and starts
// the goroutine that periodically drops expired entries. A ttl of zero or less
// is replaced by 30 minutes. Call Stop to end the goroutine.
func NewSessionCache(ttl time.Duration) *SessionCache {
	const defaultTTL = 30 * time.Minute

	cache := new(SessionCache)
	cache.entries = map[string]sessionEntry{}
	cache.stopCh = make(chan struct{})
	cache.ttl = ttl
	if cache.ttl <= 0 {
		cache.ttl = defaultTTL
	}

	go cache.cleanupLoop()
	return cache
}
// Get retrieves the auth ID bound to a session, if still valid.
// Does NOT refresh the TTL on access.
//
// It returns ("", false) for an empty sessionID, an unknown session, or an
// expired binding. An expired binding is removed from the cache.
func (c *SessionCache) Get(sessionID string) (string, bool) {
	if sessionID == "" {
		return "", false
	}
	now := time.Now()

	// Fast path: a shared lock is enough to serve a live entry.
	c.mu.RLock()
	entry, ok := c.entries[sessionID]
	c.mu.RUnlock()
	if !ok {
		return "", false
	}
	if !now.After(entry.expiresAt) {
		return entry.authID, true
	}

	// Slow path: the entry looked expired. The read lock has been released, so
	// another goroutine may have replaced or refreshed the binding meanwhile.
	// Re-read under the write lock and only delete what is still expired;
	// otherwise a fresh binding would be thrown away.
	c.mu.Lock()
	defer c.mu.Unlock()
	current, ok := c.entries[sessionID]
	if !ok {
		return "", false
	}
	if now.After(current.expiresAt) {
		delete(c.entries, sessionID)
		return "", false
	}
	return current.authID, true
}
// GetAndRefresh looks up the auth ID bound to sessionID and, on a hit, moves
// the binding's deadline to now + ttl so that active sessions stay bound.
//
// It returns ("", false) for an empty sessionID, an unknown session, or an
// expired binding; an expired binding is removed.
func (c *SessionCache) GetAndRefresh(sessionID string) (string, bool) {
	if len(sessionID) == 0 {
		return "", false
	}
	now := time.Now()

	c.mu.Lock()
	defer c.mu.Unlock()

	bound, found := c.entries[sessionID]
	switch {
	case !found:
		return "", false
	case now.After(bound.expiresAt):
		delete(c.entries, sessionID)
		return "", false
	}

	// Hit: store the same binding with a fresh deadline.
	c.entries[sessionID] = sessionEntry{
		authID:    bound.authID,
		expiresAt: now.Add(c.ttl),
	}
	return bound.authID, true
}
// Set binds sessionID to authID, replacing any existing binding, and gives the
// entry a deadline of now + ttl. A call with an empty sessionID or an empty
// authID is ignored.
func (c *SessionCache) Set(sessionID, authID string) {
	if len(sessionID) == 0 || len(authID) == 0 {
		return
	}
	binding := sessionEntry{authID: authID}

	c.mu.Lock()
	defer c.mu.Unlock()
	binding.expiresAt = time.Now().Add(c.ttl)
	c.entries[sessionID] = binding
}
// Invalidate drops the binding for sessionID, if any. An empty sessionID is
// ignored.
func (c *SessionCache) Invalidate(sessionID string) {
	if len(sessionID) == 0 {
		return
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	delete(c.entries, sessionID)
}
// InvalidateAuth drops every session binding that points at authID, leaving
// bindings to other auths untouched. An empty authID is ignored.
func (c *SessionCache) InvalidateAuth(authID string) {
	if len(authID) == 0 {
		return
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	for id := range c.entries {
		if c.entries[id].authID != authID {
			continue
		}
		// Deleting the current key while ranging over a map is permitted in Go.
		delete(c.entries, id)
	}
}
// Stop terminates the background cleanup goroutine by closing stopCh.
//
// It is safe to call more than once and from several goroutines at the same
// time: the "already closed?" check and the close happen under the cache
// mutex, so two callers can never both reach close (which would panic).
func (c *SessionCache) Stop() {
	c.mu.Lock()
	defer c.mu.Unlock()
	select {
	case <-c.stopCh:
		// Already stopped; nothing to do.
	default:
		close(c.stopCh)
	}
}
// cleanupLoop runs cleanup every ttl/2 until stopCh is closed. It is meant to
// run on its own goroutine.
//
// The interval is floored at one millisecond: for a TTL of 1ns, ttl/2 is zero,
// and time.NewTicker panics on a non-positive duration. Expiry itself is
// unaffected by the floor because lookups check each entry's deadline.
func (c *SessionCache) cleanupLoop() {
	const minInterval = time.Millisecond

	interval := c.ttl / 2
	if interval < minInterval {
		interval = minInterval
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-c.stopCh:
			return
		case <-ticker.C:
			c.cleanup()
		}
	}
}
// cleanup removes every entry whose deadline lies before the moment the sweep
// started. The whole sweep runs under the write lock.
func (c *SessionCache) cleanup() {
	cutoff := time.Now()

	c.mu.Lock()
	defer c.mu.Unlock()
	for id, bound := range c.entries {
		if bound.expiresAt.Before(cutoff) {
			delete(c.entries, id)
		}
	}
}

View File

@@ -6,6 +6,7 @@ package cliproxy
import (
"fmt"
"strings"
"time"
configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
@@ -208,8 +209,17 @@ func (b *Builder) Build() (*Service, error) {
}
strategy := ""
sessionAffinity := false
sessionAffinityTTL := time.Hour
if b.cfg != nil {
strategy = strings.ToLower(strings.TrimSpace(b.cfg.Routing.Strategy))
// Support both legacy ClaudeCodeSessionAffinity and new universal SessionAffinity
sessionAffinity = b.cfg.Routing.ClaudeCodeSessionAffinity || b.cfg.Routing.SessionAffinity
if ttlStr := strings.TrimSpace(b.cfg.Routing.SessionAffinityTTL); ttlStr != "" {
if parsed, err := time.ParseDuration(ttlStr); err == nil && parsed > 0 {
sessionAffinityTTL = parsed
}
}
}
var selector coreauth.Selector
switch strategy {
@@ -219,6 +229,14 @@ func (b *Builder) Build() (*Service, error) {
selector = &coreauth.RoundRobinSelector{}
}
// Wrap with session affinity if enabled (failover is always on)
if sessionAffinity {
selector = coreauth.NewSessionAffinitySelectorWithConfig(coreauth.SessionAffinityConfig{
Fallback: selector,
TTL: sessionAffinityTTL,
})
}
coreManager = coreauth.NewManager(tokenStore, selector, nil)
}
// Attach a default RoundTripper provider so providers can opt-in per-auth transports.

View File

@@ -612,9 +612,13 @@ func (s *Service) Run(ctx context.Context) error {
var watcherWrapper *WatcherWrapper
reloadCallback := func(newCfg *config.Config) {
previousStrategy := ""
var previousSessionAffinity bool
var previousSessionAffinityTTL string
s.cfgMu.RLock()
if s.cfg != nil {
previousStrategy = strings.ToLower(strings.TrimSpace(s.cfg.Routing.Strategy))
previousSessionAffinity = s.cfg.Routing.ClaudeCodeSessionAffinity || s.cfg.Routing.SessionAffinity
previousSessionAffinityTTL = s.cfg.Routing.SessionAffinityTTL
}
s.cfgMu.RUnlock()
@@ -638,7 +642,15 @@ func (s *Service) Run(ctx context.Context) error {
}
previousStrategy = normalizeStrategy(previousStrategy)
nextStrategy = normalizeStrategy(nextStrategy)
if s.coreManager != nil && previousStrategy != nextStrategy {
nextSessionAffinity := newCfg.Routing.ClaudeCodeSessionAffinity || newCfg.Routing.SessionAffinity
nextSessionAffinityTTL := newCfg.Routing.SessionAffinityTTL
selectorChanged := previousStrategy != nextStrategy ||
previousSessionAffinity != nextSessionAffinity ||
previousSessionAffinityTTL != nextSessionAffinityTTL
if s.coreManager != nil && selectorChanged {
var selector coreauth.Selector
switch nextStrategy {
case "fill-first":
@@ -646,6 +658,20 @@ func (s *Service) Run(ctx context.Context) error {
default:
selector = &coreauth.RoundRobinSelector{}
}
if nextSessionAffinity {
ttl := time.Hour
if ttlStr := strings.TrimSpace(nextSessionAffinityTTL); ttlStr != "" {
if parsed, err := time.ParseDuration(ttlStr); err == nil && parsed > 0 {
ttl = parsed
}
}
selector = coreauth.NewSessionAffinitySelectorWithConfig(coreauth.SessionAffinityConfig{
Fallback: selector,
TTL: ttl,
})
}
s.coreManager.SetSelector(selector)
}