mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-05-13 23:41:36 +00:00
Merge pull request #2816 from sususu98/feat/session-affinity
feat(session-affinity): add session-sticky routing for multi-account load balancing
This commit is contained in:
@@ -103,6 +103,13 @@ quota-exceeded:
|
|||||||
# Routing strategy for selecting credentials when multiple match.
|
# Routing strategy for selecting credentials when multiple match.
|
||||||
routing:
|
routing:
|
||||||
strategy: "round-robin" # round-robin (default), fill-first
|
strategy: "round-robin" # round-robin (default), fill-first
|
||||||
|
# Enable universal session-sticky routing for all clients.
|
||||||
|
# Session IDs are extracted from: X-Session-ID header, Idempotency-Key,
|
||||||
|
# metadata.user_id, conversation_id, or a hash of the first few messages.
|
||||||
|
# Automatic failover is always enabled when bound auth becomes unavailable.
|
||||||
|
session-affinity: false # default: false
|
||||||
|
# How long session-to-auth bindings are retained. Default: 1h
|
||||||
|
session-affinity-ttl: "1h"
|
||||||
|
|
||||||
# When true, enable authentication for the WebSocket API (/v1/ws).
|
# When true, enable authentication for the WebSocket API (/v1/ws).
|
||||||
ws-auth: false
|
ws-auth: false
|
||||||
|
|||||||
@@ -216,6 +216,22 @@ type RoutingConfig struct {
|
|||||||
// Strategy selects the credential selection strategy.
|
// Strategy selects the credential selection strategy.
|
||||||
// Supported values: "round-robin" (default), "fill-first".
|
// Supported values: "round-robin" (default), "fill-first".
|
||||||
Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"`
|
Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"`
|
||||||
|
|
||||||
|
// ClaudeCodeSessionAffinity enables session-sticky routing for Claude Code clients.
|
||||||
|
// When enabled, requests with the same session ID (extracted from metadata.user_id)
|
||||||
|
// are routed to the same auth credential when available.
|
||||||
|
// Deprecated: Use SessionAffinity instead for universal session support.
|
||||||
|
ClaudeCodeSessionAffinity bool `yaml:"claude-code-session-affinity,omitempty" json:"claude-code-session-affinity,omitempty"`
|
||||||
|
|
||||||
|
// SessionAffinity enables universal session-sticky routing for all clients.
|
||||||
|
// Session IDs are extracted from multiple sources:
|
||||||
|
// X-Session-ID header, Idempotency-Key, metadata.user_id, conversation_id, or message hash.
|
||||||
|
// Automatic failover is always enabled when bound auth becomes unavailable.
|
||||||
|
SessionAffinity bool `yaml:"session-affinity,omitempty" json:"session-affinity,omitempty"`
|
||||||
|
|
||||||
|
// SessionAffinityTTL specifies how long session-to-auth bindings are retained.
|
||||||
|
// Default: 1h. Accepts duration strings like "30m", "1h", "2h30m".
|
||||||
|
SessionAffinityTTL string `yaml:"session-affinity-ttl,omitempty" json:"session-affinity-ttl,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// OAuthModelAlias defines a model ID alias for a specific channel.
|
// OAuthModelAlias defines a model ID alias for a specific channel.
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
@@ -188,7 +187,7 @@ func PassthroughHeadersEnabled(cfg *config.SDKConfig) bool {
|
|||||||
|
|
||||||
func requestExecutionMetadata(ctx context.Context) map[string]any {
|
func requestExecutionMetadata(ctx context.Context) map[string]any {
|
||||||
// Idempotency-Key is an optional client-supplied header used to correlate retries.
|
// Idempotency-Key is an optional client-supplied header used to correlate retries.
|
||||||
// It is forwarded as execution metadata; when absent we generate a UUID.
|
// Only include it if the client explicitly provides it.
|
||||||
key := ""
|
key := ""
|
||||||
if ctx != nil {
|
if ctx != nil {
|
||||||
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
|
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
|
||||||
@@ -196,7 +195,7 @@ func requestExecutionMetadata(ctx context.Context) map[string]any {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if key == "" {
|
if key == "" {
|
||||||
key = uuid.NewString()
|
return make(map[string]any)
|
||||||
}
|
}
|
||||||
|
|
||||||
meta := map[string]any{idempotencyKeyMetadataKey: key}
|
meta := map[string]any{idempotencyKeyMetadataKey: key}
|
||||||
|
|||||||
@@ -105,6 +105,13 @@ type Selector interface {
|
|||||||
Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error)
|
Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StoppableSelector is an optional interface for selectors that hold resources.
|
||||||
|
// Selectors that implement this interface will have Stop called during shutdown.
|
||||||
|
type StoppableSelector interface {
|
||||||
|
Selector
|
||||||
|
Stop()
|
||||||
|
}
|
||||||
|
|
||||||
// Hook captures lifecycle callbacks for observing auth changes.
|
// Hook captures lifecycle callbacks for observing auth changes.
|
||||||
type Hook interface {
|
type Hook interface {
|
||||||
// OnAuthRegistered fires when a new auth is registered.
|
// OnAuthRegistered fires when a new auth is registered.
|
||||||
@@ -2928,6 +2935,7 @@ func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duratio
|
|||||||
}
|
}
|
||||||
|
|
||||||
// StopAutoRefresh cancels the background refresh loop, if running.
|
// StopAutoRefresh cancels the background refresh loop, if running.
|
||||||
|
// It also stops the selector if it implements StoppableSelector.
|
||||||
func (m *Manager) StopAutoRefresh() {
|
func (m *Manager) StopAutoRefresh() {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
cancel := m.refreshCancel
|
cancel := m.refreshCancel
|
||||||
@@ -2937,6 +2945,10 @@ func (m *Manager) StopAutoRefresh() {
|
|||||||
if cancel != nil {
|
if cancel != nil {
|
||||||
cancel()
|
cancel()
|
||||||
}
|
}
|
||||||
|
// Stop selector if it implements StoppableSelector (e.g., SessionAffinitySelector)
|
||||||
|
if stoppable, ok := m.selector.(StoppableSelector); ok {
|
||||||
|
stoppable.Stop()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) queueRefreshReschedule(authID string) {
|
func (m *Manager) queueRefreshReschedule(authID string) {
|
||||||
|
|||||||
@@ -4,15 +4,21 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"hash/fnv"
|
||||||
"math"
|
"math"
|
||||||
"math/rand/v2"
|
"math/rand/v2"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"regexp"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
)
|
)
|
||||||
@@ -420,3 +426,448 @@ func isAuthBlockedForModel(auth *Auth, model string, now time.Time) (bool, block
|
|||||||
}
|
}
|
||||||
return false, blockReasonNone, time.Time{}
|
return false, blockReasonNone, time.Time{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sessionPattern captures the session identifier at the end of a Claude Code
// metadata.user_id value shaped like:
//
//	user_{hash}_account__session_{uuid}
//
// Submatch 1 is the run of lowercase hex digits and dashes that follows
// "_session_" and extends to the end of the string.
var sessionPattern = regexp.MustCompile("_session_([a-f0-9-]+)$")
|
||||||
|
|
||||||
|
// SessionAffinitySelector wraps another selector with session-sticky behavior.
|
||||||
|
// It extracts session ID from multiple sources and maintains session-to-auth
|
||||||
|
// mappings with automatic failover when the bound auth becomes unavailable.
|
||||||
|
type SessionAffinitySelector struct {
|
||||||
|
fallback Selector
|
||||||
|
cache *SessionCache
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionAffinityConfig configures the session affinity selector.
|
||||||
|
type SessionAffinityConfig struct {
|
||||||
|
Fallback Selector
|
||||||
|
TTL time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSessionAffinitySelector creates a new session-aware selector.
|
||||||
|
func NewSessionAffinitySelector(fallback Selector) *SessionAffinitySelector {
|
||||||
|
return NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{
|
||||||
|
Fallback: fallback,
|
||||||
|
TTL: time.Hour,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSessionAffinitySelectorWithConfig creates a selector with custom configuration.
|
||||||
|
func NewSessionAffinitySelectorWithConfig(cfg SessionAffinityConfig) *SessionAffinitySelector {
|
||||||
|
if cfg.Fallback == nil {
|
||||||
|
cfg.Fallback = &RoundRobinSelector{}
|
||||||
|
}
|
||||||
|
if cfg.TTL <= 0 {
|
||||||
|
cfg.TTL = time.Hour
|
||||||
|
}
|
||||||
|
return &SessionAffinitySelector{
|
||||||
|
fallback: cfg.Fallback,
|
||||||
|
cache: NewSessionCache(cfg.TTL),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pick selects an auth with session affinity when possible.
|
||||||
|
// Priority for session ID extraction:
|
||||||
|
// 1. metadata.user_id (Claude Code format) - highest priority
|
||||||
|
// 2. X-Session-ID header
|
||||||
|
// 3. metadata.user_id (non-Claude Code format)
|
||||||
|
// 4. conversation_id field
|
||||||
|
// 5. Hash-based fallback from messages
|
||||||
|
//
|
||||||
|
// Note: The cache key includes provider, session ID, and model to handle cases where
|
||||||
|
// a session uses multiple models (e.g., gemini-2.5-pro and gemini-3-flash-preview)
|
||||||
|
// that may be supported by different auth credentials, and to avoid cross-provider conflicts.
|
||||||
|
func (s *SessionAffinitySelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
||||||
|
entry := selectorLogEntry(ctx)
|
||||||
|
primaryID, fallbackID := extractSessionIDs(opts.Headers, opts.OriginalRequest, opts.Metadata)
|
||||||
|
if primaryID == "" {
|
||||||
|
entry.Debugf("session-affinity: no session ID extracted, falling back to default selector | provider=%s model=%s", provider, model)
|
||||||
|
return s.fallback.Pick(ctx, provider, model, opts, auths)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
available, err := getAvailableAuths(auths, provider, model, now)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cacheKey := provider + "::" + primaryID + "::" + model
|
||||||
|
|
||||||
|
if cachedAuthID, ok := s.cache.GetAndRefresh(cacheKey); ok {
|
||||||
|
for _, auth := range available {
|
||||||
|
if auth.ID == cachedAuthID {
|
||||||
|
entry.Infof("session-affinity: cache hit | session=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), auth.ID, provider, model)
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Cached auth not available, reselect via fallback selector for even distribution
|
||||||
|
auth, err := s.fallback.Pick(ctx, provider, model, opts, auths)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s.cache.Set(cacheKey, auth.ID)
|
||||||
|
entry.Infof("session-affinity: cache hit but auth unavailable, reselected | session=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), auth.ID, provider, model)
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if fallbackID != "" && fallbackID != primaryID {
|
||||||
|
fallbackKey := provider + "::" + fallbackID + "::" + model
|
||||||
|
if cachedAuthID, ok := s.cache.Get(fallbackKey); ok {
|
||||||
|
for _, auth := range available {
|
||||||
|
if auth.ID == cachedAuthID {
|
||||||
|
s.cache.Set(cacheKey, auth.ID)
|
||||||
|
entry.Infof("session-affinity: fallback cache hit | session=%s fallback=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), truncateSessionID(fallbackID), auth.ID, provider, model)
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auth, err := s.fallback.Pick(ctx, provider, model, opts, auths)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s.cache.Set(cacheKey, auth.ID)
|
||||||
|
entry.Infof("session-affinity: cache miss, new binding | session=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), auth.ID, provider, model)
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func selectorLogEntry(ctx context.Context) *log.Entry {
|
||||||
|
if ctx == nil {
|
||||||
|
return log.NewEntry(log.StandardLogger())
|
||||||
|
}
|
||||||
|
if reqID := logging.GetRequestID(ctx); reqID != "" {
|
||||||
|
return log.WithField("request_id", reqID)
|
||||||
|
}
|
||||||
|
return log.NewEntry(log.StandardLogger())
|
||||||
|
}
|
||||||
|
|
||||||
|
// truncateSessionID abbreviates a session ID for log output.
//
// IDs of at most 20 bytes are returned unchanged. Longer IDs are cut to their
// first 8 bytes followed by "...". If that cut would fall inside a multi-byte
// UTF-8 sequence, the partial rune is dropped so the logged value stays valid
// UTF-8.
func truncateSessionID(id string) string {
	const (
		maxFullLen = 20 // IDs up to this many bytes are logged in full
		prefixLen  = 8  // bytes kept from longer IDs
	)
	if len(id) <= maxFullLen {
		return id
	}
	cut := prefixLen
	// id[cut] is the first byte that would be dropped. While it is a UTF-8
	// continuation byte (10xxxxxx) the cut sits inside a rune, so back up to
	// the start of that rune.
	for cut > 0 && id[cut]&0xC0 == 0x80 {
		cut--
	}
	return id[:cut] + "..."
}
|
||||||
|
|
||||||
|
// Stop releases resources held by the selector.
|
||||||
|
func (s *SessionAffinitySelector) Stop() {
|
||||||
|
if s.cache != nil {
|
||||||
|
s.cache.Stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateAuth removes all session bindings for a specific auth.
|
||||||
|
// Called when an auth becomes rate-limited or unavailable.
|
||||||
|
func (s *SessionAffinitySelector) InvalidateAuth(authID string) {
|
||||||
|
if s.cache != nil {
|
||||||
|
s.cache.InvalidateAuth(authID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractSessionID extracts session identifier from multiple sources.
|
||||||
|
// Priority order:
|
||||||
|
// 1. metadata.user_id (Claude Code format with _session_{uuid}) - highest priority for Claude Code clients
|
||||||
|
// 2. X-Session-ID header
|
||||||
|
// 3. metadata.user_id (non-Claude Code format)
|
||||||
|
// 4. conversation_id field in request body
|
||||||
|
// 5. Stable hash from first few messages content (fallback)
|
||||||
|
func ExtractSessionID(headers http.Header, payload []byte, metadata map[string]any) string {
|
||||||
|
primary, _ := extractSessionIDs(headers, payload, metadata)
|
||||||
|
return primary
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractSessionIDs returns (primaryID, fallbackID) for session affinity.
|
||||||
|
// primaryID: full hash including assistant response (stable after first turn)
|
||||||
|
// fallbackID: short hash without assistant (used to inherit binding from first turn)
|
||||||
|
func extractSessionIDs(headers http.Header, payload []byte, metadata map[string]any) (string, string) {
|
||||||
|
// 1. metadata.user_id with Claude Code session format (highest priority)
|
||||||
|
if len(payload) > 0 {
|
||||||
|
userID := gjson.GetBytes(payload, "metadata.user_id").String()
|
||||||
|
if userID != "" {
|
||||||
|
// Old format: user_{hash}_account__session_{uuid}
|
||||||
|
if matches := sessionPattern.FindStringSubmatch(userID); len(matches) >= 2 {
|
||||||
|
id := "claude:" + matches[1]
|
||||||
|
return id, ""
|
||||||
|
}
|
||||||
|
// New format: JSON object with session_id field
|
||||||
|
// e.g. {"device_id":"...","account_uuid":"...","session_id":"uuid"}
|
||||||
|
if len(userID) > 0 && userID[0] == '{' {
|
||||||
|
if sid := gjson.Get(userID, "session_id").String(); sid != "" {
|
||||||
|
return "claude:" + sid, ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. X-Session-ID header
|
||||||
|
if headers != nil {
|
||||||
|
if sid := headers.Get("X-Session-ID"); sid != "" {
|
||||||
|
return "header:" + sid, ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(payload) == 0 {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. metadata.user_id (non-Claude Code format)
|
||||||
|
userID := gjson.GetBytes(payload, "metadata.user_id").String()
|
||||||
|
if userID != "" {
|
||||||
|
return "user:" + userID, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. conversation_id field
|
||||||
|
if convID := gjson.GetBytes(payload, "conversation_id").String(); convID != "" {
|
||||||
|
return "conv:" + convID, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Hash-based fallback from message content
|
||||||
|
return extractMessageHashIDs(payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractMessageHashIDs(payload []byte) (primaryID, fallbackID string) {
|
||||||
|
var systemPrompt, firstUserMsg, firstAssistantMsg string
|
||||||
|
|
||||||
|
// OpenAI/Claude messages format
|
||||||
|
messages := gjson.GetBytes(payload, "messages")
|
||||||
|
if messages.Exists() && messages.IsArray() {
|
||||||
|
messages.ForEach(func(_, msg gjson.Result) bool {
|
||||||
|
role := msg.Get("role").String()
|
||||||
|
content := extractMessageContent(msg.Get("content"))
|
||||||
|
if content == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
switch role {
|
||||||
|
case "system":
|
||||||
|
if systemPrompt == "" {
|
||||||
|
systemPrompt = truncateString(content, 100)
|
||||||
|
}
|
||||||
|
case "user":
|
||||||
|
if firstUserMsg == "" {
|
||||||
|
firstUserMsg = truncateString(content, 100)
|
||||||
|
}
|
||||||
|
case "assistant":
|
||||||
|
if firstAssistantMsg == "" {
|
||||||
|
firstAssistantMsg = truncateString(content, 100)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if systemPrompt != "" && firstUserMsg != "" && firstAssistantMsg != "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Claude API: top-level "system" field (array or string)
|
||||||
|
if systemPrompt == "" {
|
||||||
|
topSystem := gjson.GetBytes(payload, "system")
|
||||||
|
if topSystem.Exists() {
|
||||||
|
if topSystem.IsArray() {
|
||||||
|
topSystem.ForEach(func(_, part gjson.Result) bool {
|
||||||
|
if text := part.Get("text").String(); text != "" && systemPrompt == "" {
|
||||||
|
systemPrompt = truncateString(text, 100)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
} else if topSystem.Type == gjson.String {
|
||||||
|
systemPrompt = truncateString(topSystem.String(), 100)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gemini format
|
||||||
|
if systemPrompt == "" && firstUserMsg == "" {
|
||||||
|
sysInstr := gjson.GetBytes(payload, "systemInstruction.parts")
|
||||||
|
if sysInstr.Exists() && sysInstr.IsArray() {
|
||||||
|
sysInstr.ForEach(func(_, part gjson.Result) bool {
|
||||||
|
if text := part.Get("text").String(); text != "" && systemPrompt == "" {
|
||||||
|
systemPrompt = truncateString(text, 100)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
contents := gjson.GetBytes(payload, "contents")
|
||||||
|
if contents.Exists() && contents.IsArray() {
|
||||||
|
contents.ForEach(func(_, msg gjson.Result) bool {
|
||||||
|
role := msg.Get("role").String()
|
||||||
|
msg.Get("parts").ForEach(func(_, part gjson.Result) bool {
|
||||||
|
text := part.Get("text").String()
|
||||||
|
if text == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
switch role {
|
||||||
|
case "user":
|
||||||
|
if firstUserMsg == "" {
|
||||||
|
firstUserMsg = truncateString(text, 100)
|
||||||
|
}
|
||||||
|
case "model":
|
||||||
|
if firstAssistantMsg == "" {
|
||||||
|
firstAssistantMsg = truncateString(text, 100)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
if firstUserMsg != "" && firstAssistantMsg != "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenAI Responses API format (v1/responses)
|
||||||
|
if systemPrompt == "" && firstUserMsg == "" {
|
||||||
|
if instr := gjson.GetBytes(payload, "instructions").String(); instr != "" {
|
||||||
|
systemPrompt = truncateString(instr, 100)
|
||||||
|
}
|
||||||
|
|
||||||
|
input := gjson.GetBytes(payload, "input")
|
||||||
|
if input.Exists() && input.IsArray() {
|
||||||
|
input.ForEach(func(_, item gjson.Result) bool {
|
||||||
|
itemType := item.Get("type").String()
|
||||||
|
if itemType == "reasoning" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// Skip non-message typed items (function_call, function_call_output, etc.)
|
||||||
|
// but allow items with no type that have a role (inline message format).
|
||||||
|
if itemType != "" && itemType != "message" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
role := item.Get("role").String()
|
||||||
|
if itemType == "" && role == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle both string content and array content (multimodal).
|
||||||
|
content := item.Get("content")
|
||||||
|
var text string
|
||||||
|
if content.Type == gjson.String {
|
||||||
|
text = content.String()
|
||||||
|
} else {
|
||||||
|
text = extractResponsesAPIContent(content)
|
||||||
|
}
|
||||||
|
if text == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
switch role {
|
||||||
|
case "developer", "system":
|
||||||
|
if systemPrompt == "" {
|
||||||
|
systemPrompt = truncateString(text, 100)
|
||||||
|
}
|
||||||
|
case "user":
|
||||||
|
if firstUserMsg == "" {
|
||||||
|
firstUserMsg = truncateString(text, 100)
|
||||||
|
}
|
||||||
|
case "assistant":
|
||||||
|
if firstAssistantMsg == "" {
|
||||||
|
firstAssistantMsg = truncateString(text, 100)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if firstUserMsg != "" && firstAssistantMsg != "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if systemPrompt == "" && firstUserMsg == "" {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
|
||||||
|
shortHash := computeSessionHash(systemPrompt, firstUserMsg, "")
|
||||||
|
if firstAssistantMsg == "" {
|
||||||
|
return shortHash, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
fullHash := computeSessionHash(systemPrompt, firstUserMsg, firstAssistantMsg)
|
||||||
|
return fullHash, shortHash
|
||||||
|
}
|
||||||
|
|
||||||
|
// computeSessionHash returns a session key of the form "msg:" followed by 16
// hex digits: the 64-bit FNV-1a hash of the non-empty inputs.
//
// Each non-empty field is written as "<tag>:<byte length>:<value>\n" with the
// tags "sys", "usr" and "ast". The length prefix keeps the encoding
// unambiguous when a value itself contains a newline or a tag-like sequence,
// so that, for example, a system prompt of "a\nusr:b" does not hash the same
// as a system prompt "a" combined with a user message "b". Empty fields are
// skipped entirely.
func computeSessionHash(systemPrompt, userMsg, assistantMsg string) string {
	h := fnv.New64a()
	writeField := func(tag, value string) {
		if value == "" {
			return
		}
		// hash.Hash.Write never returns an error, so the result is not checked.
		fmt.Fprintf(h, "%s:%d:%s\n", tag, len(value), value)
	}
	writeField("sys", systemPrompt)
	writeField("usr", userMsg)
	writeField("ast", assistantMsg)
	return fmt.Sprintf("msg:%016x", h.Sum64())
}
|
||||||
|
|
||||||
|
// truncateString returns s limited to at most maxLen bytes.
//
// A non-positive maxLen yields "". When the limit falls inside a multi-byte
// UTF-8 sequence the partial rune is dropped, so the result may be up to
// three bytes shorter than maxLen but is never cut mid-rune.
func truncateString(s string, maxLen int) string {
	if maxLen <= 0 {
		return ""
	}
	if len(s) <= maxLen {
		return s
	}
	cut := maxLen
	// s[cut] is the first dropped byte; while it is a UTF-8 continuation byte
	// (10xxxxxx) the cut sits inside a rune, so back up to that rune's start.
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut]
}
|
||||||
|
|
||||||
|
// extractMessageContent extracts text content from a message content field.
|
||||||
|
// Handles both string content and array content (multimodal messages).
|
||||||
|
// For array content, extracts text from all text-type elements.
|
||||||
|
func extractMessageContent(content gjson.Result) string {
|
||||||
|
// String content: "Hello world"
|
||||||
|
if content.Type == gjson.String {
|
||||||
|
return content.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Array content: [{"type":"text","text":"Hello"},{"type":"image",...}]
|
||||||
|
if content.IsArray() {
|
||||||
|
var texts []string
|
||||||
|
content.ForEach(func(_, part gjson.Result) bool {
|
||||||
|
// Handle Claude format: {"type":"text","text":"content"}
|
||||||
|
if part.Get("type").String() == "text" {
|
||||||
|
if text := part.Get("text").String(); text != "" {
|
||||||
|
texts = append(texts, text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Handle OpenAI format: {"type":"text","text":"content"}
|
||||||
|
// Same structure as Claude, already handled above
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
if len(texts) > 0 {
|
||||||
|
return strings.Join(texts, " ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractResponsesAPIContent(content gjson.Result) string {
|
||||||
|
if !content.IsArray() {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
var texts []string
|
||||||
|
content.ForEach(func(_, part gjson.Result) bool {
|
||||||
|
partType := part.Get("type").String()
|
||||||
|
if partType == "input_text" || partType == "output_text" || partType == "text" {
|
||||||
|
if text := part.Get("text").String(); text != "" {
|
||||||
|
texts = append(texts, text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
if len(texts) > 0 {
|
||||||
|
return strings.Join(texts, " ")
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractSessionID is kept for backward compatibility.
|
||||||
|
// Deprecated: Use ExtractSessionID instead.
|
||||||
|
func extractSessionID(payload []byte) string {
|
||||||
|
return ExtractSessionID(nil, payload, nil)
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,7 +4,9 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -458,6 +460,159 @@ func TestRoundRobinSelectorPick_GeminiCLICredentialGrouping(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExtractSessionID(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
payload string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid_claude_code_format",
|
||||||
|
payload: `{"metadata":{"user_id":"user_3f221fe75652cf9a89a31647f16274bb8036a9b85ac4dc226a4df0efec8dc04d_account__session_ac980658-63bd-4fb3-97ba-8da64cb1e344"}}`,
|
||||||
|
want: "claude:ac980658-63bd-4fb3-97ba-8da64cb1e344",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "json_user_id_with_session_id",
|
||||||
|
payload: `{"metadata":{"user_id":"{\"device_id\":\"be82c3aee1e0c2d74535bacc85f9f559228f02dd8a17298cf522b71e6c375714\",\"account_uuid\":\"\",\"session_id\":\"e26d4046-0f88-4b09-bb5b-f863ab5fb24e\"}"}}`,
|
||||||
|
want: "claude:e26d4046-0f88-4b09-bb5b-f863ab5fb24e",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "json_user_id_without_session_id",
|
||||||
|
payload: `{"metadata":{"user_id":"{\"device_id\":\"abc123\"}"}}`,
|
||||||
|
want: `user:{"device_id":"abc123"}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no_session_but_user_id",
|
||||||
|
payload: `{"metadata":{"user_id":"user_abc123"}}`,
|
||||||
|
want: "user:user_abc123",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "conversation_id",
|
||||||
|
payload: `{"conversation_id":"conv-12345"}`,
|
||||||
|
want: "conv:conv-12345",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no_metadata",
|
||||||
|
payload: `{"model":"claude-3"}`,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty_payload",
|
||||||
|
payload: ``,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := extractSessionID([]byte(tt.payload))
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("extractSessionID() = %q, want %q", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionAffinitySelector_SameSessionSameAuth(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
fallback := &RoundRobinSelector{}
|
||||||
|
selector := NewSessionAffinitySelector(fallback)
|
||||||
|
|
||||||
|
auths := []*Auth{
|
||||||
|
{ID: "auth-a"},
|
||||||
|
{ID: "auth-b"},
|
||||||
|
{ID: "auth-c"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use valid UUID format for session ID
|
||||||
|
payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_ac980658-63bd-4fb3-97ba-8da64cb1e344"}}`)
|
||||||
|
opts := cliproxyexecutor.Options{OriginalRequest: payload}
|
||||||
|
|
||||||
|
// Same session should always pick the same auth
|
||||||
|
first, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Pick() error = %v", err)
|
||||||
|
}
|
||||||
|
if first == nil {
|
||||||
|
t.Fatalf("Pick() returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify consistency: same session, same auths -> same result
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
got, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Pick() #%d error = %v", i, err)
|
||||||
|
}
|
||||||
|
if got.ID != first.ID {
|
||||||
|
t.Fatalf("Pick() #%d auth.ID = %q, want %q (same session should pick same auth)", i, got.ID, first.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionAffinitySelector_NoSessionFallback(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
fallback := &FillFirstSelector{}
|
||||||
|
selector := NewSessionAffinitySelector(fallback)
|
||||||
|
|
||||||
|
auths := []*Auth{
|
||||||
|
{ID: "auth-b"},
|
||||||
|
{ID: "auth-a"},
|
||||||
|
{ID: "auth-c"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// No session in payload, should fallback to FillFirstSelector (picks "auth-a" after sorting)
|
||||||
|
payload := []byte(`{"model":"claude-3"}`)
|
||||||
|
opts := cliproxyexecutor.Options{OriginalRequest: payload}
|
||||||
|
|
||||||
|
got, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Pick() error = %v", err)
|
||||||
|
}
|
||||||
|
if got.ID != "auth-a" {
|
||||||
|
t.Fatalf("Pick() auth.ID = %q, want %q (should fallback to FillFirst)", got.ID, "auth-a")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionAffinitySelector_DifferentSessionsDifferentAuths(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
fallback := &RoundRobinSelector{}
|
||||||
|
selector := NewSessionAffinitySelector(fallback)
|
||||||
|
|
||||||
|
auths := []*Auth{
|
||||||
|
{ID: "auth-a"},
|
||||||
|
{ID: "auth-b"},
|
||||||
|
{ID: "auth-c"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use valid UUID format for session IDs
|
||||||
|
session1 := []byte(`{"metadata":{"user_id":"user_xxx_account__session_11111111-1111-1111-1111-111111111111"}}`)
|
||||||
|
session2 := []byte(`{"metadata":{"user_id":"user_xxx_account__session_22222222-2222-2222-2222-222222222222"}}`)
|
||||||
|
|
||||||
|
opts1 := cliproxyexecutor.Options{OriginalRequest: session1}
|
||||||
|
opts2 := cliproxyexecutor.Options{OriginalRequest: session2}
|
||||||
|
|
||||||
|
auth1, _ := selector.Pick(context.Background(), "claude", "claude-3", opts1, auths)
|
||||||
|
auth2, _ := selector.Pick(context.Background(), "claude", "claude-3", opts2, auths)
|
||||||
|
|
||||||
|
// Different sessions may or may not pick different auths (depends on hash collision)
|
||||||
|
// But each session should be consistent
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
got1, _ := selector.Pick(context.Background(), "claude", "claude-3", opts1, auths)
|
||||||
|
got2, _ := selector.Pick(context.Background(), "claude", "claude-3", opts2, auths)
|
||||||
|
if got1.ID != auth1.ID {
|
||||||
|
t.Fatalf("session1 Pick() #%d inconsistent: got %q, want %q", i, got1.ID, auth1.ID)
|
||||||
|
}
|
||||||
|
if got2.ID != auth2.ID {
|
||||||
|
t.Fatalf("session2 Pick() #%d inconsistent: got %q, want %q", i, got2.ID, auth2.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestRoundRobinSelectorPick_SingleParentFallsBackToFlat(t *testing.T) {
|
func TestRoundRobinSelectorPick_SingleParentFallsBackToFlat(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@@ -494,6 +649,57 @@ func TestRoundRobinSelectorPick_SingleParentFallsBackToFlat(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSessionAffinitySelector_FailoverWhenAuthUnavailable(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
fallback := &RoundRobinSelector{}
|
||||||
|
selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{
|
||||||
|
Fallback: fallback,
|
||||||
|
TTL: time.Minute,
|
||||||
|
})
|
||||||
|
defer selector.Stop()
|
||||||
|
|
||||||
|
auths := []*Auth{
|
||||||
|
{ID: "auth-a"},
|
||||||
|
{ID: "auth-b"},
|
||||||
|
{ID: "auth-c"},
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_failover-test-uuid"}}`)
|
||||||
|
opts := cliproxyexecutor.Options{OriginalRequest: payload}
|
||||||
|
|
||||||
|
// First pick establishes binding
|
||||||
|
first, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Pick() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the bound auth from available list (simulating rate limit)
|
||||||
|
availableWithoutFirst := make([]*Auth, 0, len(auths)-1)
|
||||||
|
for _, a := range auths {
|
||||||
|
if a.ID != first.ID {
|
||||||
|
availableWithoutFirst = append(availableWithoutFirst, a)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// With failover enabled, should pick a new auth
|
||||||
|
second, err := selector.Pick(context.Background(), "claude", "claude-3", opts, availableWithoutFirst)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Pick() after failover error = %v", err)
|
||||||
|
}
|
||||||
|
if second.ID == first.ID {
|
||||||
|
t.Fatalf("Pick() after failover returned same auth %q, expected different", first.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subsequent picks should consistently return the new binding
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
got, _ := selector.Pick(context.Background(), "claude", "claude-3", opts, availableWithoutFirst)
|
||||||
|
if got.ID != second.ID {
|
||||||
|
t.Fatalf("Pick() #%d after failover inconsistent: got %q, want %q", i, got.ID, second.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestRoundRobinSelectorPick_MixedVirtualAndNonVirtualFallsBackToFlat(t *testing.T) {
|
func TestRoundRobinSelectorPick_MixedVirtualAndNonVirtualFallsBackToFlat(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@@ -527,3 +733,629 @@ func TestRoundRobinSelectorPick_MixedVirtualAndNonVirtualFallsBackToFlat(t *test
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
func TestExtractSessionID_ClaudeCodePriorityOverHeader(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Claude Code metadata.user_id should have highest priority, even when X-Session-ID header is present
|
||||||
|
headers := make(http.Header)
|
||||||
|
headers.Set("X-Session-ID", "header-session-id")
|
||||||
|
|
||||||
|
payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_ac980658-63bd-4fb3-97ba-8da64cb1e344"}}`)
|
||||||
|
|
||||||
|
got := ExtractSessionID(headers, payload, nil)
|
||||||
|
want := "claude:ac980658-63bd-4fb3-97ba-8da64cb1e344"
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("ExtractSessionID() = %q, want %q (Claude Code should have highest priority over header)", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractSessionID_ClaudeCodePriorityOverIdempotencyKey(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Claude Code metadata.user_id should have highest priority, even when idempotency_key is present
|
||||||
|
metadata := map[string]any{"idempotency_key": "idem-12345"}
|
||||||
|
payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_ac980658-63bd-4fb3-97ba-8da64cb1e344"}}`)
|
||||||
|
|
||||||
|
got := ExtractSessionID(nil, payload, metadata)
|
||||||
|
want := "claude:ac980658-63bd-4fb3-97ba-8da64cb1e344"
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("ExtractSessionID() = %q, want %q (Claude Code should have highest priority over idempotency_key)", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractSessionID_Headers(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
headers := make(http.Header)
|
||||||
|
headers.Set("X-Session-ID", "my-explicit-session")
|
||||||
|
|
||||||
|
got := ExtractSessionID(headers, nil, nil)
|
||||||
|
want := "header:my-explicit-session"
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("ExtractSessionID() with header = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestExtractSessionID_IdempotencyKey verifies that idempotency_key is intentionally
|
||||||
|
// ignored for session affinity (it's auto-generated per-request, causing cache misses).
|
||||||
|
func TestExtractSessionID_IdempotencyKey(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
metadata := map[string]any{"idempotency_key": "idem-12345"}
|
||||||
|
|
||||||
|
got := ExtractSessionID(nil, nil, metadata)
|
||||||
|
// idempotency_key is disabled - should return empty (no payload to hash)
|
||||||
|
if got != "" {
|
||||||
|
t.Errorf("ExtractSessionID() with idempotency_key = %q, want empty (idempotency_key is disabled)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractSessionID_MessageHashFallback(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// First request (user only) generates short hash
|
||||||
|
firstRequestPayload := []byte(`{"messages":[{"role":"user","content":"Hello world"}]}`)
|
||||||
|
shortHash := ExtractSessionID(nil, firstRequestPayload, nil)
|
||||||
|
if shortHash == "" {
|
||||||
|
t.Error("ExtractSessionID() first request should return short hash")
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(shortHash, "msg:") {
|
||||||
|
t.Errorf("ExtractSessionID() = %q, want prefix 'msg:'", shortHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multi-turn with assistant generates full hash (different from short hash)
|
||||||
|
multiTurnPayload := []byte(`{"messages":[
|
||||||
|
{"role":"user","content":"Hello world"},
|
||||||
|
{"role":"assistant","content":"Hi! How can I help?"},
|
||||||
|
{"role":"user","content":"Tell me a joke"}
|
||||||
|
]}`)
|
||||||
|
fullHash := ExtractSessionID(nil, multiTurnPayload, nil)
|
||||||
|
if fullHash == "" {
|
||||||
|
t.Error("ExtractSessionID() multi-turn should return full hash")
|
||||||
|
}
|
||||||
|
if fullHash == shortHash {
|
||||||
|
t.Error("Full hash should differ from short hash (includes assistant)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Same multi-turn payload should produce same hash
|
||||||
|
fullHash2 := ExtractSessionID(nil, multiTurnPayload, nil)
|
||||||
|
if fullHash != fullHash2 {
|
||||||
|
t.Errorf("ExtractSessionID() not stable: got %q then %q", fullHash, fullHash2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractSessionID_ClaudeAPITopLevelSystem(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Claude API: system prompt in top-level "system" field (array format)
|
||||||
|
arraySystem := []byte(`{
|
||||||
|
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
|
||||||
|
"system": [{"type": "text", "text": "You are Claude Code"}]
|
||||||
|
}`)
|
||||||
|
got1 := ExtractSessionID(nil, arraySystem, nil)
|
||||||
|
if got1 == "" || !strings.HasPrefix(got1, "msg:") {
|
||||||
|
t.Errorf("ExtractSessionID() with array system = %q, want msg:* prefix", got1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Claude API: system prompt in top-level "system" field (string format)
|
||||||
|
stringSystem := []byte(`{
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
|
"system": "You are Claude Code"
|
||||||
|
}`)
|
||||||
|
got2 := ExtractSessionID(nil, stringSystem, nil)
|
||||||
|
if got2 == "" || !strings.HasPrefix(got2, "msg:") {
|
||||||
|
t.Errorf("ExtractSessionID() with string system = %q, want msg:* prefix", got2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multi-turn with top-level system should produce stable hash
|
||||||
|
multiTurn := []byte(`{
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
{"role": "assistant", "content": "Hi!"},
|
||||||
|
{"role": "user", "content": "Help me"}
|
||||||
|
],
|
||||||
|
"system": "You are Claude Code"
|
||||||
|
}`)
|
||||||
|
got3 := ExtractSessionID(nil, multiTurn, nil)
|
||||||
|
if got3 == "" {
|
||||||
|
t.Error("ExtractSessionID() multi-turn with top-level system should return hash")
|
||||||
|
}
|
||||||
|
if got3 == got2 {
|
||||||
|
t.Error("Multi-turn hash should differ from first-turn hash (includes assistant)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractSessionID_GeminiFormat(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Gemini format with systemInstruction and contents
|
||||||
|
payload := []byte(`{
|
||||||
|
"systemInstruction": {"parts": [{"text": "You are a helpful assistant."}]},
|
||||||
|
"contents": [
|
||||||
|
{"role": "user", "parts": [{"text": "Hello Gemini"}]},
|
||||||
|
{"role": "model", "parts": [{"text": "Hi there!"}]}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
got := ExtractSessionID(nil, payload, nil)
|
||||||
|
if got == "" {
|
||||||
|
t.Error("ExtractSessionID() with Gemini format should return hash-based session ID")
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(got, "msg:") {
|
||||||
|
t.Errorf("ExtractSessionID() = %q, want prefix 'msg:'", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Same payload should produce same hash
|
||||||
|
got2 := ExtractSessionID(nil, payload, nil)
|
||||||
|
if got != got2 {
|
||||||
|
t.Errorf("ExtractSessionID() not stable: got %q then %q", got, got2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Different user message should produce different hash
|
||||||
|
differentPayload := []byte(`{
|
||||||
|
"systemInstruction": {"parts": [{"text": "You are a helpful assistant."}]},
|
||||||
|
"contents": [
|
||||||
|
{"role": "user", "parts": [{"text": "Hello different"}]},
|
||||||
|
{"role": "model", "parts": [{"text": "Hi there!"}]}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
got3 := ExtractSessionID(nil, differentPayload, nil)
|
||||||
|
if got == got3 {
|
||||||
|
t.Errorf("ExtractSessionID() should produce different hash for different user message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractSessionID_OpenAIResponsesAPI(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
firstTurn := []byte(`{
|
||||||
|
"instructions": "You are Codex, based on GPT-5.",
|
||||||
|
"input": [
|
||||||
|
{"type": "message", "role": "developer", "content": [{"type": "input_text", "text": "system instructions"}]},
|
||||||
|
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "hi"}]}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
got1 := ExtractSessionID(nil, firstTurn, nil)
|
||||||
|
if got1 == "" {
|
||||||
|
t.Error("ExtractSessionID() should return hash for OpenAI Responses API format")
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(got1, "msg:") {
|
||||||
|
t.Errorf("ExtractSessionID() = %q, want prefix 'msg:'", got1)
|
||||||
|
}
|
||||||
|
|
||||||
|
secondTurn := []byte(`{
|
||||||
|
"instructions": "You are Codex, based on GPT-5.",
|
||||||
|
"input": [
|
||||||
|
{"type": "message", "role": "developer", "content": [{"type": "input_text", "text": "system instructions"}]},
|
||||||
|
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "hi"}]},
|
||||||
|
{"type": "reasoning", "summary": [{"type": "summary_text", "text": "thinking..."}], "encrypted_content": "xxx"},
|
||||||
|
{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "Hello!"}]},
|
||||||
|
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "what can you do"}]}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
got2 := ExtractSessionID(nil, secondTurn, nil)
|
||||||
|
if got2 == "" {
|
||||||
|
t.Error("ExtractSessionID() should return hash for second turn")
|
||||||
|
}
|
||||||
|
|
||||||
|
if got1 == got2 {
|
||||||
|
t.Log("First turn and second turn have different hashes (expected: second includes assistant)")
|
||||||
|
}
|
||||||
|
|
||||||
|
thirdTurn := []byte(`{
|
||||||
|
"instructions": "You are Codex, based on GPT-5.",
|
||||||
|
"input": [
|
||||||
|
{"type": "message", "role": "developer", "content": [{"type": "input_text", "text": "system instructions"}]},
|
||||||
|
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "hi"}]},
|
||||||
|
{"type": "reasoning", "summary": [{"type": "summary_text", "text": "thinking..."}], "encrypted_content": "xxx"},
|
||||||
|
{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "Hello!"}]},
|
||||||
|
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "what can you do"}]},
|
||||||
|
{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "I can help with..."}]},
|
||||||
|
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "thanks"}]}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
got3 := ExtractSessionID(nil, thirdTurn, nil)
|
||||||
|
if got2 != got3 {
|
||||||
|
t.Errorf("Second and third turn should have same hash (same first assistant): got %q vs %q", got2, got3)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionAffinitySelector_ThreeScenarios(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
fallback := &RoundRobinSelector{}
|
||||||
|
selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{
|
||||||
|
Fallback: fallback,
|
||||||
|
TTL: time.Minute,
|
||||||
|
})
|
||||||
|
defer selector.Stop()
|
||||||
|
|
||||||
|
auths := []*Auth{{ID: "auth-a"}, {ID: "auth-b"}, {ID: "auth-c"}}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
scenario string
|
||||||
|
payload []byte
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "OpenAI_Scenario1_NewRequest",
|
||||||
|
scenario: "new",
|
||||||
|
payload: []byte(`{"messages":[{"role":"system","content":"You are helpful"},{"role":"user","content":"Hello"}]}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "OpenAI_Scenario2_SecondTurn",
|
||||||
|
scenario: "second",
|
||||||
|
payload: []byte(`{"messages":[{"role":"system","content":"You are helpful"},{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi there!"},{"role":"user","content":"Help me"}]}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "OpenAI_Scenario3_ManyTurns",
|
||||||
|
scenario: "many",
|
||||||
|
payload: []byte(`{"messages":[{"role":"system","content":"You are helpful"},{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi there!"},{"role":"user","content":"Help me"},{"role":"assistant","content":"Sure!"},{"role":"user","content":"Thanks"}]}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Gemini_Scenario1_NewRequest",
|
||||||
|
scenario: "new",
|
||||||
|
payload: []byte(`{"systemInstruction":{"parts":[{"text":"You are helpful"}]},"contents":[{"role":"user","parts":[{"text":"Hello Gemini"}]}]}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Gemini_Scenario2_SecondTurn",
|
||||||
|
scenario: "second",
|
||||||
|
payload: []byte(`{"systemInstruction":{"parts":[{"text":"You are helpful"}]},"contents":[{"role":"user","parts":[{"text":"Hello Gemini"}]},{"role":"model","parts":[{"text":"Hi!"}]},{"role":"user","parts":[{"text":"Help"}]}]}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Gemini_Scenario3_ManyTurns",
|
||||||
|
scenario: "many",
|
||||||
|
payload: []byte(`{"systemInstruction":{"parts":[{"text":"You are helpful"}]},"contents":[{"role":"user","parts":[{"text":"Hello Gemini"}]},{"role":"model","parts":[{"text":"Hi!"}]},{"role":"user","parts":[{"text":"Help"}]},{"role":"model","parts":[{"text":"Sure!"}]},{"role":"user","parts":[{"text":"Thanks"}]}]}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Claude_Scenario1_NewRequest",
|
||||||
|
scenario: "new",
|
||||||
|
payload: []byte(`{"messages":[{"role":"user","content":"Hello Claude"}]}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Claude_Scenario2_SecondTurn",
|
||||||
|
scenario: "second",
|
||||||
|
payload: []byte(`{"messages":[{"role":"user","content":"Hello Claude"},{"role":"assistant","content":"Hello!"},{"role":"user","content":"Help me"}]}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Claude_Scenario3_ManyTurns",
|
||||||
|
scenario: "many",
|
||||||
|
payload: []byte(`{"messages":[{"role":"user","content":"Hello Claude"},{"role":"assistant","content":"Hello!"},{"role":"user","content":"Help"},{"role":"assistant","content":"Sure!"},{"role":"user","content":"Thanks"}]}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
opts := cliproxyexecutor.Options{OriginalRequest: tc.payload}
|
||||||
|
picked, err := selector.Pick(context.Background(), "provider", "model", opts, auths)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Pick() error = %v", err)
|
||||||
|
}
|
||||||
|
if picked == nil {
|
||||||
|
t.Fatal("Pick() returned nil")
|
||||||
|
}
|
||||||
|
t.Logf("%s: picked %s", tc.name, picked.ID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("Scenario2And3_SameAuth", func(t *testing.T) {
|
||||||
|
openaiS2 := []byte(`{"messages":[{"role":"system","content":"Stable test"},{"role":"user","content":"First msg"},{"role":"assistant","content":"Response"},{"role":"user","content":"Second"}]}`)
|
||||||
|
openaiS3 := []byte(`{"messages":[{"role":"system","content":"Stable test"},{"role":"user","content":"First msg"},{"role":"assistant","content":"Response"},{"role":"user","content":"Second"},{"role":"assistant","content":"More"},{"role":"user","content":"Third"}]}`)
|
||||||
|
|
||||||
|
opts2 := cliproxyexecutor.Options{OriginalRequest: openaiS2}
|
||||||
|
opts3 := cliproxyexecutor.Options{OriginalRequest: openaiS3}
|
||||||
|
|
||||||
|
picked2, _ := selector.Pick(context.Background(), "test", "model", opts2, auths)
|
||||||
|
picked3, _ := selector.Pick(context.Background(), "test", "model", opts3, auths)
|
||||||
|
|
||||||
|
if picked2.ID != picked3.ID {
|
||||||
|
t.Errorf("Scenario2 and Scenario3 should pick same auth: got %s vs %s", picked2.ID, picked3.ID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Scenario1To2_InheritBinding", func(t *testing.T) {
|
||||||
|
s1 := []byte(`{"messages":[{"role":"system","content":"Inherit test"},{"role":"user","content":"Initial"}]}`)
|
||||||
|
s2 := []byte(`{"messages":[{"role":"system","content":"Inherit test"},{"role":"user","content":"Initial"},{"role":"assistant","content":"Reply"},{"role":"user","content":"Continue"}]}`)
|
||||||
|
|
||||||
|
opts1 := cliproxyexecutor.Options{OriginalRequest: s1}
|
||||||
|
opts2 := cliproxyexecutor.Options{OriginalRequest: s2}
|
||||||
|
|
||||||
|
picked1, _ := selector.Pick(context.Background(), "inherit", "model", opts1, auths)
|
||||||
|
picked2, _ := selector.Pick(context.Background(), "inherit", "model", opts2, auths)
|
||||||
|
|
||||||
|
if picked1.ID != picked2.ID {
|
||||||
|
t.Errorf("Scenario2 should inherit Scenario1 binding: got %s vs %s", picked1.ID, picked2.ID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionAffinitySelector_MultiModelSession(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
fallback := &RoundRobinSelector{}
|
||||||
|
selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{
|
||||||
|
Fallback: fallback,
|
||||||
|
TTL: time.Minute,
|
||||||
|
})
|
||||||
|
defer selector.Stop()
|
||||||
|
|
||||||
|
// auth-a supports only model-a, auth-b supports only model-b
|
||||||
|
authA := &Auth{ID: "auth-a"}
|
||||||
|
authB := &Auth{ID: "auth-b"}
|
||||||
|
|
||||||
|
// Same session ID for all requests
|
||||||
|
payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_multi-model-test"}}`)
|
||||||
|
opts := cliproxyexecutor.Options{OriginalRequest: payload}
|
||||||
|
|
||||||
|
// Request model-a with only auth-a available for that model
|
||||||
|
authsForModelA := []*Auth{authA}
|
||||||
|
pickedA, err := selector.Pick(context.Background(), "provider", "model-a", opts, authsForModelA)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Pick() for model-a error = %v", err)
|
||||||
|
}
|
||||||
|
if pickedA.ID != "auth-a" {
|
||||||
|
t.Fatalf("Pick() for model-a = %q, want auth-a", pickedA.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request model-b with only auth-b available for that model
|
||||||
|
authsForModelB := []*Auth{authB}
|
||||||
|
pickedB, err := selector.Pick(context.Background(), "provider", "model-b", opts, authsForModelB)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Pick() for model-b error = %v", err)
|
||||||
|
}
|
||||||
|
if pickedB.ID != "auth-b" {
|
||||||
|
t.Fatalf("Pick() for model-b = %q, want auth-b", pickedB.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Switch back to model-a - should still get auth-a (separate binding per model)
|
||||||
|
pickedA2, err := selector.Pick(context.Background(), "provider", "model-a", opts, authsForModelA)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Pick() for model-a (2nd) error = %v", err)
|
||||||
|
}
|
||||||
|
if pickedA2.ID != "auth-a" {
|
||||||
|
t.Fatalf("Pick() for model-a (2nd) = %q, want auth-a", pickedA2.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify bindings are stable for multiple calls
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
gotA, _ := selector.Pick(context.Background(), "provider", "model-a", opts, authsForModelA)
|
||||||
|
gotB, _ := selector.Pick(context.Background(), "provider", "model-b", opts, authsForModelB)
|
||||||
|
if gotA.ID != "auth-a" {
|
||||||
|
t.Fatalf("Pick() #%d for model-a = %q, want auth-a", i, gotA.ID)
|
||||||
|
}
|
||||||
|
if gotB.ID != "auth-b" {
|
||||||
|
t.Fatalf("Pick() #%d for model-b = %q, want auth-b", i, gotB.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractSessionID_MultimodalContent(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// First request generates short hash
|
||||||
|
firstRequestPayload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"Hello world"},{"type":"image","source":{"data":"..."}}]}]}`)
|
||||||
|
shortHash := ExtractSessionID(nil, firstRequestPayload, nil)
|
||||||
|
if shortHash == "" {
|
||||||
|
t.Error("ExtractSessionID() first request should return short hash")
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(shortHash, "msg:") {
|
||||||
|
t.Errorf("ExtractSessionID() = %q, want prefix 'msg:'", shortHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multi-turn generates full hash
|
||||||
|
multiTurnPayload := []byte(`{"messages":[
|
||||||
|
{"role":"user","content":[{"type":"text","text":"Hello world"},{"type":"image","source":{"data":"..."}}]},
|
||||||
|
{"role":"assistant","content":"I see an image!"},
|
||||||
|
{"role":"user","content":"What is it?"}
|
||||||
|
]}`)
|
||||||
|
fullHash := ExtractSessionID(nil, multiTurnPayload, nil)
|
||||||
|
if fullHash == "" {
|
||||||
|
t.Error("ExtractSessionID() multimodal multi-turn should return full hash")
|
||||||
|
}
|
||||||
|
if fullHash == shortHash {
|
||||||
|
t.Error("Full hash should differ from short hash")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Different user content produces different hash
|
||||||
|
differentPayload := []byte(`{"messages":[
|
||||||
|
{"role":"user","content":[{"type":"text","text":"Different content"}]},
|
||||||
|
{"role":"assistant","content":"I see something different!"}
|
||||||
|
]}`)
|
||||||
|
differentHash := ExtractSessionID(nil, differentPayload, nil)
|
||||||
|
if fullHash == differentHash {
|
||||||
|
t.Errorf("ExtractSessionID() should produce different hash for different content")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionAffinitySelector_CrossProviderIsolation(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
fallback := &RoundRobinSelector{}
|
||||||
|
selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{
|
||||||
|
Fallback: fallback,
|
||||||
|
TTL: time.Minute,
|
||||||
|
})
|
||||||
|
defer selector.Stop()
|
||||||
|
|
||||||
|
authClaude := &Auth{ID: "auth-claude"}
|
||||||
|
authGemini := &Auth{ID: "auth-gemini"}
|
||||||
|
|
||||||
|
// Same session ID for both providers
|
||||||
|
payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_cross-provider-test"}}`)
|
||||||
|
opts := cliproxyexecutor.Options{OriginalRequest: payload}
|
||||||
|
|
||||||
|
// Request via claude provider
|
||||||
|
pickedClaude, err := selector.Pick(context.Background(), "claude", "claude-3", opts, []*Auth{authClaude})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Pick() for claude error = %v", err)
|
||||||
|
}
|
||||||
|
if pickedClaude.ID != "auth-claude" {
|
||||||
|
t.Fatalf("Pick() for claude = %q, want auth-claude", pickedClaude.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Same session but via gemini provider should get different auth
|
||||||
|
pickedGemini, err := selector.Pick(context.Background(), "gemini", "gemini-2.5-pro", opts, []*Auth{authGemini})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Pick() for gemini error = %v", err)
|
||||||
|
}
|
||||||
|
if pickedGemini.ID != "auth-gemini" {
|
||||||
|
t.Fatalf("Pick() for gemini = %q, want auth-gemini", pickedGemini.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify both bindings remain stable
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
gotC, _ := selector.Pick(context.Background(), "claude", "claude-3", opts, []*Auth{authClaude})
|
||||||
|
gotG, _ := selector.Pick(context.Background(), "gemini", "gemini-2.5-pro", opts, []*Auth{authGemini})
|
||||||
|
if gotC.ID != "auth-claude" {
|
||||||
|
t.Fatalf("Pick() #%d for claude = %q, want auth-claude", i, gotC.ID)
|
||||||
|
}
|
||||||
|
if gotG.ID != "auth-gemini" {
|
||||||
|
t.Fatalf("Pick() #%d for gemini = %q, want auth-gemini", i, gotG.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionCache_GetAndRefresh(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
cache := NewSessionCache(100 * time.Millisecond)
|
||||||
|
defer cache.Stop()
|
||||||
|
|
||||||
|
cache.Set("session1", "auth1")
|
||||||
|
|
||||||
|
// Verify initial value
|
||||||
|
got, ok := cache.GetAndRefresh("session1")
|
||||||
|
if !ok || got != "auth1" {
|
||||||
|
t.Fatalf("GetAndRefresh() = %q, %v, want auth1, true", got, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait half TTL and access again (should refresh)
|
||||||
|
time.Sleep(60 * time.Millisecond)
|
||||||
|
got, ok = cache.GetAndRefresh("session1")
|
||||||
|
if !ok || got != "auth1" {
|
||||||
|
t.Fatalf("GetAndRefresh() after 60ms = %q, %v, want auth1, true", got, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait another 60ms (total 120ms from original, but TTL refreshed at 60ms)
|
||||||
|
// Entry should still be valid because TTL was refreshed
|
||||||
|
time.Sleep(60 * time.Millisecond)
|
||||||
|
got, ok = cache.GetAndRefresh("session1")
|
||||||
|
if !ok || got != "auth1" {
|
||||||
|
t.Fatalf("GetAndRefresh() after refresh = %q, %v, want auth1, true (TTL should have been refreshed)", got, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now wait full TTL without access
|
||||||
|
time.Sleep(110 * time.Millisecond)
|
||||||
|
got, ok = cache.GetAndRefresh("session1")
|
||||||
|
if ok {
|
||||||
|
t.Fatalf("GetAndRefresh() after expiry = %q, %v, want '', false", got, ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionAffinitySelector_RoundRobinDistribution(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
fallback := &RoundRobinSelector{}
|
||||||
|
selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{
|
||||||
|
Fallback: fallback,
|
||||||
|
TTL: time.Minute,
|
||||||
|
})
|
||||||
|
defer selector.Stop()
|
||||||
|
|
||||||
|
auths := []*Auth{
|
||||||
|
{ID: "auth-a"},
|
||||||
|
{ID: "auth-b"},
|
||||||
|
{ID: "auth-c"},
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionCount := 12
|
||||||
|
counts := make(map[string]int)
|
||||||
|
for i := 0; i < sessionCount; i++ {
|
||||||
|
payload := []byte(fmt.Sprintf(`{"metadata":{"user_id":"user_xxx_account__session_%08d-0000-0000-0000-000000000000"}}`, i))
|
||||||
|
opts := cliproxyexecutor.Options{OriginalRequest: payload}
|
||||||
|
got, err := selector.Pick(context.Background(), "provider", "model", opts, auths)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Pick() session %d error = %v", i, err)
|
||||||
|
}
|
||||||
|
counts[got.ID]++
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := sessionCount / len(auths)
|
||||||
|
for _, auth := range auths {
|
||||||
|
got := counts[auth.ID]
|
||||||
|
if got != expected {
|
||||||
|
t.Errorf("auth %s got %d sessions, want %d (round-robin should distribute evenly)", auth.ID, got, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionAffinitySelector_Concurrent(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
fallback := &RoundRobinSelector{}
|
||||||
|
selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{
|
||||||
|
Fallback: fallback,
|
||||||
|
TTL: time.Minute,
|
||||||
|
})
|
||||||
|
defer selector.Stop()
|
||||||
|
|
||||||
|
auths := []*Auth{
|
||||||
|
{ID: "auth-a"},
|
||||||
|
{ID: "auth-b"},
|
||||||
|
{ID: "auth-c"},
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_concurrent-test"}}`)
|
||||||
|
opts := cliproxyexecutor.Options{OriginalRequest: payload}
|
||||||
|
|
||||||
|
// First pick to establish binding
|
||||||
|
first, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Initial Pick() error = %v", err)
|
||||||
|
}
|
||||||
|
expectedID := first.ID
|
||||||
|
|
||||||
|
start := make(chan struct{})
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
|
||||||
|
goroutines := 32
|
||||||
|
iterations := 50
|
||||||
|
for i := 0; i < goroutines; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
<-start
|
||||||
|
for j := 0; j < iterations; j++ {
|
||||||
|
got, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths)
|
||||||
|
if err != nil {
|
||||||
|
select {
|
||||||
|
case errCh <- err:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got.ID != expectedID {
|
||||||
|
select {
|
||||||
|
case errCh <- fmt.Errorf("concurrent Pick() returned %q, want %q", got.ID, expectedID):
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
close(start)
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errCh:
|
||||||
|
t.Fatalf("concurrent Pick() error = %v", err)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
152
sdk/cliproxy/auth/session_cache.go
Normal file
152
sdk/cliproxy/auth/session_cache.go
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// sessionEntry is a single cached binding: the auth credential chosen for a
// session together with the instant after which the binding is stale.
type sessionEntry struct {
	// authID identifies the auth the session is bound to.
	authID string
	// expiresAt is the deadline of the binding; lookups treat the entry as
	// expired once the current time is after this instant.
	expiresAt time.Time
}
|
||||||
|
|
||||||
|
// SessionCache provides TTL-based session to auth mapping with automatic cleanup.
|
||||||
|
type SessionCache struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
entries map[string]sessionEntry
|
||||||
|
ttl time.Duration
|
||||||
|
stopCh chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSessionCache creates a cache with the specified TTL.
|
||||||
|
// A background goroutine periodically cleans expired entries.
|
||||||
|
func NewSessionCache(ttl time.Duration) *SessionCache {
|
||||||
|
if ttl <= 0 {
|
||||||
|
ttl = 30 * time.Minute
|
||||||
|
}
|
||||||
|
c := &SessionCache{
|
||||||
|
entries: make(map[string]sessionEntry),
|
||||||
|
ttl: ttl,
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
go c.cleanupLoop()
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves the auth ID bound to a session, if still valid.
|
||||||
|
// Does NOT refresh the TTL on access.
|
||||||
|
func (c *SessionCache) Get(sessionID string) (string, bool) {
|
||||||
|
if sessionID == "" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
c.mu.RLock()
|
||||||
|
entry, ok := c.entries[sessionID]
|
||||||
|
c.mu.RUnlock()
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
if time.Now().After(entry.expiresAt) {
|
||||||
|
c.mu.Lock()
|
||||||
|
delete(c.entries, sessionID)
|
||||||
|
c.mu.Unlock()
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return entry.authID, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAndRefresh retrieves the auth ID bound to a session and refreshes TTL on hit.
|
||||||
|
// This extends the binding lifetime for active sessions.
|
||||||
|
func (c *SessionCache) GetAndRefresh(sessionID string) (string, bool) {
|
||||||
|
if sessionID == "" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
c.mu.Lock()
|
||||||
|
entry, ok := c.entries[sessionID]
|
||||||
|
if !ok {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
if now.After(entry.expiresAt) {
|
||||||
|
delete(c.entries, sessionID)
|
||||||
|
c.mu.Unlock()
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
// Refresh TTL on successful access
|
||||||
|
entry.expiresAt = now.Add(c.ttl)
|
||||||
|
c.entries[sessionID] = entry
|
||||||
|
c.mu.Unlock()
|
||||||
|
return entry.authID, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set binds a session to an auth ID with TTL refresh.
|
||||||
|
func (c *SessionCache) Set(sessionID, authID string) {
|
||||||
|
if sessionID == "" || authID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
c.entries[sessionID] = sessionEntry{
|
||||||
|
authID: authID,
|
||||||
|
expiresAt: time.Now().Add(c.ttl),
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalidate removes a specific session binding.
|
||||||
|
func (c *SessionCache) Invalidate(sessionID string) {
|
||||||
|
if sessionID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
delete(c.entries, sessionID)
|
||||||
|
c.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateAuth removes all sessions bound to a specific auth ID.
|
||||||
|
// Used when an auth becomes unavailable.
|
||||||
|
func (c *SessionCache) InvalidateAuth(authID string) {
|
||||||
|
if authID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
for sid, entry := range c.entries {
|
||||||
|
if entry.authID == authID {
|
||||||
|
delete(c.entries, sid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop terminates the background cleanup goroutine.
|
||||||
|
func (c *SessionCache) Stop() {
|
||||||
|
select {
|
||||||
|
case <-c.stopCh:
|
||||||
|
default:
|
||||||
|
close(c.stopCh)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SessionCache) cleanupLoop() {
|
||||||
|
ticker := time.NewTicker(c.ttl / 2)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.stopCh:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
c.cleanup()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SessionCache) cleanup() {
|
||||||
|
now := time.Now()
|
||||||
|
c.mu.Lock()
|
||||||
|
for sid, entry := range c.entries {
|
||||||
|
if now.After(entry.expiresAt) {
|
||||||
|
delete(c.entries, sid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
}
|
||||||
@@ -6,6 +6,7 @@ package cliproxy
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
|
configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
|
||||||
@@ -208,8 +209,17 @@ func (b *Builder) Build() (*Service, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
strategy := ""
|
strategy := ""
|
||||||
|
sessionAffinity := false
|
||||||
|
sessionAffinityTTL := time.Hour
|
||||||
if b.cfg != nil {
|
if b.cfg != nil {
|
||||||
strategy = strings.ToLower(strings.TrimSpace(b.cfg.Routing.Strategy))
|
strategy = strings.ToLower(strings.TrimSpace(b.cfg.Routing.Strategy))
|
||||||
|
// Support both legacy ClaudeCodeSessionAffinity and new universal SessionAffinity
|
||||||
|
sessionAffinity = b.cfg.Routing.ClaudeCodeSessionAffinity || b.cfg.Routing.SessionAffinity
|
||||||
|
if ttlStr := strings.TrimSpace(b.cfg.Routing.SessionAffinityTTL); ttlStr != "" {
|
||||||
|
if parsed, err := time.ParseDuration(ttlStr); err == nil && parsed > 0 {
|
||||||
|
sessionAffinityTTL = parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
var selector coreauth.Selector
|
var selector coreauth.Selector
|
||||||
switch strategy {
|
switch strategy {
|
||||||
@@ -219,6 +229,14 @@ func (b *Builder) Build() (*Service, error) {
|
|||||||
selector = &coreauth.RoundRobinSelector{}
|
selector = &coreauth.RoundRobinSelector{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Wrap with session affinity if enabled (failover is always on)
|
||||||
|
if sessionAffinity {
|
||||||
|
selector = coreauth.NewSessionAffinitySelectorWithConfig(coreauth.SessionAffinityConfig{
|
||||||
|
Fallback: selector,
|
||||||
|
TTL: sessionAffinityTTL,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
coreManager = coreauth.NewManager(tokenStore, selector, nil)
|
coreManager = coreauth.NewManager(tokenStore, selector, nil)
|
||||||
}
|
}
|
||||||
// Attach a default RoundTripper provider so providers can opt-in per-auth transports.
|
// Attach a default RoundTripper provider so providers can opt-in per-auth transports.
|
||||||
|
|||||||
@@ -612,9 +612,13 @@ func (s *Service) Run(ctx context.Context) error {
|
|||||||
var watcherWrapper *WatcherWrapper
|
var watcherWrapper *WatcherWrapper
|
||||||
reloadCallback := func(newCfg *config.Config) {
|
reloadCallback := func(newCfg *config.Config) {
|
||||||
previousStrategy := ""
|
previousStrategy := ""
|
||||||
|
var previousSessionAffinity bool
|
||||||
|
var previousSessionAffinityTTL string
|
||||||
s.cfgMu.RLock()
|
s.cfgMu.RLock()
|
||||||
if s.cfg != nil {
|
if s.cfg != nil {
|
||||||
previousStrategy = strings.ToLower(strings.TrimSpace(s.cfg.Routing.Strategy))
|
previousStrategy = strings.ToLower(strings.TrimSpace(s.cfg.Routing.Strategy))
|
||||||
|
previousSessionAffinity = s.cfg.Routing.ClaudeCodeSessionAffinity || s.cfg.Routing.SessionAffinity
|
||||||
|
previousSessionAffinityTTL = s.cfg.Routing.SessionAffinityTTL
|
||||||
}
|
}
|
||||||
s.cfgMu.RUnlock()
|
s.cfgMu.RUnlock()
|
||||||
|
|
||||||
@@ -638,7 +642,15 @@ func (s *Service) Run(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
previousStrategy = normalizeStrategy(previousStrategy)
|
previousStrategy = normalizeStrategy(previousStrategy)
|
||||||
nextStrategy = normalizeStrategy(nextStrategy)
|
nextStrategy = normalizeStrategy(nextStrategy)
|
||||||
if s.coreManager != nil && previousStrategy != nextStrategy {
|
|
||||||
|
nextSessionAffinity := newCfg.Routing.ClaudeCodeSessionAffinity || newCfg.Routing.SessionAffinity
|
||||||
|
nextSessionAffinityTTL := newCfg.Routing.SessionAffinityTTL
|
||||||
|
|
||||||
|
selectorChanged := previousStrategy != nextStrategy ||
|
||||||
|
previousSessionAffinity != nextSessionAffinity ||
|
||||||
|
previousSessionAffinityTTL != nextSessionAffinityTTL
|
||||||
|
|
||||||
|
if s.coreManager != nil && selectorChanged {
|
||||||
var selector coreauth.Selector
|
var selector coreauth.Selector
|
||||||
switch nextStrategy {
|
switch nextStrategy {
|
||||||
case "fill-first":
|
case "fill-first":
|
||||||
@@ -646,6 +658,20 @@ func (s *Service) Run(ctx context.Context) error {
|
|||||||
default:
|
default:
|
||||||
selector = &coreauth.RoundRobinSelector{}
|
selector = &coreauth.RoundRobinSelector{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if nextSessionAffinity {
|
||||||
|
ttl := time.Hour
|
||||||
|
if ttlStr := strings.TrimSpace(nextSessionAffinityTTL); ttlStr != "" {
|
||||||
|
if parsed, err := time.ParseDuration(ttlStr); err == nil && parsed > 0 {
|
||||||
|
ttl = parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
selector = coreauth.NewSessionAffinitySelectorWithConfig(coreauth.SessionAffinityConfig{
|
||||||
|
Fallback: selector,
|
||||||
|
TTL: ttl,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
s.coreManager.SetSelector(selector)
|
s.coreManager.SetSelector(selector)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user