Merge PR #525 (v6.9.27)

This commit is contained in:
Luis Pater
2026-04-16 03:16:28 +08:00
68 changed files with 3075 additions and 3239 deletions

View File

@@ -14,7 +14,6 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
@@ -188,7 +187,7 @@ func PassthroughHeadersEnabled(cfg *config.SDKConfig) bool {
func requestExecutionMetadata(ctx context.Context) map[string]any {
// Idempotency-Key is an optional client-supplied header used to correlate retries.
// It is forwarded as execution metadata; when absent we generate a UUID.
// Only include it if the client explicitly provides it.
key := ""
if ctx != nil {
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
@@ -196,7 +195,7 @@ func requestExecutionMetadata(ctx context.Context) map[string]any {
}
}
if key == "" {
key = uuid.NewString()
return make(map[string]any)
}
meta := map[string]any{idempotencyKeyMetadataKey: key}

View File

@@ -17,7 +17,6 @@ type ManagementTokenRequester interface {
RequestGeminiCLIToken(*gin.Context)
RequestCodexToken(*gin.Context)
RequestAntigravityToken(*gin.Context)
RequestQwenToken(*gin.Context)
RequestKimiToken(*gin.Context)
RequestIFlowToken(*gin.Context)
RequestIFlowCookieToken(*gin.Context)
@@ -52,10 +51,6 @@ func (m *managementTokenRequester) RequestAntigravityToken(c *gin.Context) {
m.handler.RequestAntigravityToken(c)
}
func (m *managementTokenRequester) RequestQwenToken(c *gin.Context) {
m.handler.RequestQwenToken(c)
}
func (m *managementTokenRequester) RequestKimiToken(c *gin.Context) {
m.handler.RequestKimiToken(c)
}

View File

@@ -39,7 +39,7 @@ func (a *KiloAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
}
kilocodeAuth := kilo.NewKiloAuth()
fmt.Println("Initiating Kilo device authentication...")
resp, err := kilocodeAuth.InitiateDeviceFlow(ctx)
if err != nil {
@@ -48,7 +48,7 @@ func (a *KiloAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
fmt.Printf("Please visit: %s\n", resp.VerificationURL)
fmt.Printf("And enter code: %s\n", resp.Code)
fmt.Println("Waiting for authorization...")
status, err := kilocodeAuth.PollForToken(ctx, resp.Code)
if err != nil {
@@ -68,7 +68,7 @@ func (a *KiloAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
for i, org := range profile.Orgs {
fmt.Printf("[%d] %s (%s)\n", i+1, org.Name, org.ID)
}
if opts.Prompt != nil {
input, err := opts.Prompt("Enter the number of the organization: ")
if err != nil {
@@ -108,7 +108,7 @@ func (a *KiloAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
metadata := map[string]any{
"email": status.UserEmail,
"organization_id": orgID,
"model": defaults.Model,
"model": defaults.Model,
}
return &coreauth.Auth{

View File

@@ -1,113 +0,0 @@
package auth
import (
"context"
"fmt"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
// legacy client removed
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
// QwenAuthenticator implements the device flow login for Qwen accounts.
type QwenAuthenticator struct{}
// NewQwenAuthenticator constructs a Qwen authenticator.
func NewQwenAuthenticator() *QwenAuthenticator {
return &QwenAuthenticator{}
}
func (a *QwenAuthenticator) Provider() string {
return "qwen"
}
func (a *QwenAuthenticator) RefreshLead() *time.Duration {
return new(20 * time.Minute)
}
func (a *QwenAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
if cfg == nil {
return nil, fmt.Errorf("cliproxy auth: configuration is required")
}
if ctx == nil {
ctx = context.Background()
}
if opts == nil {
opts = &LoginOptions{}
}
authSvc := qwen.NewQwenAuth(cfg)
deviceFlow, err := authSvc.InitiateDeviceFlow(ctx)
if err != nil {
return nil, fmt.Errorf("qwen device flow initiation failed: %w", err)
}
authURL := deviceFlow.VerificationURIComplete
if !opts.NoBrowser {
fmt.Println("Opening browser for Qwen authentication")
if !browser.IsAvailable() {
log.Warn("No browser available; please open the URL manually")
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
} else if err = browser.OpenURL(authURL); err != nil {
log.Warnf("Failed to open browser automatically: %v", err)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
}
} else {
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
}
fmt.Println("Waiting for Qwen authentication...")
tokenData, err := authSvc.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier)
if err != nil {
return nil, fmt.Errorf("qwen authentication failed: %w", err)
}
tokenStorage := authSvc.CreateTokenStorage(tokenData)
email := ""
if opts.Metadata != nil {
email = opts.Metadata["email"]
if email == "" {
email = opts.Metadata["alias"]
}
}
if email == "" && opts.Prompt != nil {
email, err = opts.Prompt("Please input your email address or alias for Qwen:")
if err != nil {
return nil, err
}
}
email = strings.TrimSpace(email)
if email == "" {
return nil, &EmailRequiredError{Prompt: "Please provide an email address or alias for Qwen."}
}
tokenStorage.Email = email
// no legacy client construction
fileName := fmt.Sprintf("qwen-%s.json", tokenStorage.Email)
metadata := map[string]any{
"email": tokenStorage.Email,
}
fmt.Println("Qwen authentication successful")
return &coreauth.Auth{
ID: fileName,
Provider: a.Provider(),
FileName: fileName,
Storage: tokenStorage,
Metadata: metadata,
}, nil
}

View File

@@ -1,19 +0,0 @@
package auth
import (
"testing"
"time"
)
func TestQwenAuthenticator_RefreshLeadIsSane(t *testing.T) {
lead := NewQwenAuthenticator().RefreshLead()
if lead == nil {
t.Fatal("RefreshLead() = nil, want non-nil")
}
if *lead <= 0 {
t.Fatalf("RefreshLead() = %s, want > 0", *lead)
}
if *lead > 30*time.Minute {
t.Fatalf("RefreshLead() = %s, want <= %s", *lead, 30*time.Minute)
}
}

View File

@@ -9,7 +9,6 @@ import (
func init() {
registerRefreshLead("codex", func() Authenticator { return NewCodexAuthenticator() })
registerRefreshLead("claude", func() Authenticator { return NewClaudeAuthenticator() })
registerRefreshLead("qwen", func() Authenticator { return NewQwenAuthenticator() })
registerRefreshLead("iflow", func() Authenticator { return NewIFlowAuthenticator() })
registerRefreshLead("gemini", func() Authenticator { return NewGeminiAuthenticator() })
registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() })

View File

@@ -0,0 +1,453 @@
package auth
import (
"container/heap"
"context"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
type authAutoRefreshLoop struct {
manager *Manager
interval time.Duration
concurrency int
mu sync.Mutex
queue refreshMinHeap
index map[string]*refreshHeapItem
dirty map[string]struct{}
wakeCh chan struct{}
jobs chan string
}
func newAuthAutoRefreshLoop(manager *Manager, interval time.Duration, concurrency int) *authAutoRefreshLoop {
if interval <= 0 {
interval = refreshCheckInterval
}
if concurrency <= 0 {
concurrency = refreshMaxConcurrency
}
jobBuffer := concurrency * 4
if jobBuffer < 64 {
jobBuffer = 64
}
return &authAutoRefreshLoop{
manager: manager,
interval: interval,
concurrency: concurrency,
index: make(map[string]*refreshHeapItem),
dirty: make(map[string]struct{}),
wakeCh: make(chan struct{}, 1),
jobs: make(chan string, jobBuffer),
}
}
func (l *authAutoRefreshLoop) queueReschedule(authID string) {
if l == nil || authID == "" {
return
}
l.mu.Lock()
l.dirty[authID] = struct{}{}
l.mu.Unlock()
select {
case l.wakeCh <- struct{}{}:
default:
}
}
func (l *authAutoRefreshLoop) run(ctx context.Context) {
if l == nil || l.manager == nil {
return
}
workers := l.concurrency
if workers <= 0 {
workers = refreshMaxConcurrency
}
for i := 0; i < workers; i++ {
go l.worker(ctx)
}
l.loop(ctx)
}
func (l *authAutoRefreshLoop) worker(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case authID := <-l.jobs:
if authID == "" {
continue
}
l.manager.refreshAuth(ctx, authID)
l.queueReschedule(authID)
}
}
}
func (l *authAutoRefreshLoop) rebuild(now time.Time) {
type entry struct {
id string
next time.Time
}
entries := make([]entry, 0)
l.manager.mu.RLock()
for id, auth := range l.manager.auths {
next, ok := nextRefreshCheckAt(now, auth, l.interval)
if !ok {
continue
}
entries = append(entries, entry{id: id, next: next})
}
l.manager.mu.RUnlock()
l.mu.Lock()
l.queue = l.queue[:0]
l.index = make(map[string]*refreshHeapItem, len(entries))
for _, e := range entries {
item := &refreshHeapItem{id: e.id, next: e.next}
heap.Push(&l.queue, item)
l.index[e.id] = item
}
l.mu.Unlock()
}
func (l *authAutoRefreshLoop) loop(ctx context.Context) {
timer := time.NewTimer(time.Hour)
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
defer timer.Stop()
var timerCh <-chan time.Time
l.resetTimer(timer, &timerCh, time.Now())
for {
select {
case <-ctx.Done():
return
case <-l.wakeCh:
now := time.Now()
l.applyDirty(now)
l.resetTimer(timer, &timerCh, now)
case <-timerCh:
now := time.Now()
l.handleDue(ctx, now)
l.applyDirty(now)
l.resetTimer(timer, &timerCh, now)
}
}
}
func (l *authAutoRefreshLoop) resetTimer(timer *time.Timer, timerCh *<-chan time.Time, now time.Time) {
next, ok := l.peek()
if !ok {
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
*timerCh = nil
return
}
wait := next.Sub(now)
if wait < 0 {
wait = 0
}
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
timer.Reset(wait)
*timerCh = timer.C
}
func (l *authAutoRefreshLoop) peek() (time.Time, bool) {
l.mu.Lock()
defer l.mu.Unlock()
if len(l.queue) == 0 {
return time.Time{}, false
}
return l.queue[0].next, true
}
func (l *authAutoRefreshLoop) handleDue(ctx context.Context, now time.Time) {
due := l.popDue(now)
if len(due) == 0 {
return
}
if log.IsLevelEnabled(log.DebugLevel) {
log.Debugf("auto-refresh scheduler due auths: %d", len(due))
}
for _, authID := range due {
l.handleDueAuth(ctx, now, authID)
}
}
func (l *authAutoRefreshLoop) popDue(now time.Time) []string {
l.mu.Lock()
defer l.mu.Unlock()
var due []string
for len(l.queue) > 0 {
item := l.queue[0]
if item == nil || item.next.After(now) {
break
}
popped := heap.Pop(&l.queue).(*refreshHeapItem)
if popped == nil {
continue
}
delete(l.index, popped.id)
due = append(due, popped.id)
}
return due
}
func (l *authAutoRefreshLoop) handleDueAuth(ctx context.Context, now time.Time, authID string) {
if authID == "" {
return
}
manager := l.manager
manager.mu.RLock()
auth := manager.auths[authID]
if auth == nil {
manager.mu.RUnlock()
return
}
next, shouldSchedule := nextRefreshCheckAt(now, auth, l.interval)
shouldRefresh := manager.shouldRefresh(auth, now)
exec := manager.executors[auth.Provider]
manager.mu.RUnlock()
if !shouldSchedule {
l.remove(authID)
return
}
if !shouldRefresh {
l.upsert(authID, next)
return
}
if exec == nil {
l.upsert(authID, now.Add(l.interval))
return
}
if !manager.markRefreshPending(authID, now) {
manager.mu.RLock()
auth = manager.auths[authID]
next, shouldSchedule = nextRefreshCheckAt(now, auth, l.interval)
manager.mu.RUnlock()
if shouldSchedule {
l.upsert(authID, next)
} else {
l.remove(authID)
}
return
}
select {
case <-ctx.Done():
return
case l.jobs <- authID:
}
}
func (l *authAutoRefreshLoop) applyDirty(now time.Time) {
dirty := l.drainDirty()
if len(dirty) == 0 {
return
}
for _, authID := range dirty {
l.manager.mu.RLock()
auth := l.manager.auths[authID]
next, ok := nextRefreshCheckAt(now, auth, l.interval)
l.manager.mu.RUnlock()
if !ok {
l.remove(authID)
continue
}
l.upsert(authID, next)
}
}
func (l *authAutoRefreshLoop) drainDirty() []string {
l.mu.Lock()
defer l.mu.Unlock()
if len(l.dirty) == 0 {
return nil
}
out := make([]string, 0, len(l.dirty))
for authID := range l.dirty {
out = append(out, authID)
delete(l.dirty, authID)
}
return out
}
func (l *authAutoRefreshLoop) upsert(authID string, next time.Time) {
if authID == "" || next.IsZero() {
return
}
l.mu.Lock()
defer l.mu.Unlock()
if item, ok := l.index[authID]; ok && item != nil {
item.next = next
heap.Fix(&l.queue, item.index)
return
}
item := &refreshHeapItem{id: authID, next: next}
heap.Push(&l.queue, item)
l.index[authID] = item
}
func (l *authAutoRefreshLoop) remove(authID string) {
if authID == "" {
return
}
l.mu.Lock()
defer l.mu.Unlock()
item, ok := l.index[authID]
if !ok || item == nil {
return
}
heap.Remove(&l.queue, item.index)
delete(l.index, authID)
}
func nextRefreshCheckAt(now time.Time, auth *Auth, interval time.Duration) (time.Time, bool) {
if auth == nil || auth.Disabled {
return time.Time{}, false
}
accountType, _ := auth.AccountInfo()
if accountType == "api_key" {
return time.Time{}, false
}
if !auth.NextRefreshAfter.IsZero() && now.Before(auth.NextRefreshAfter) {
return auth.NextRefreshAfter, true
}
if evaluator, ok := auth.Runtime.(RefreshEvaluator); ok && evaluator != nil {
if interval <= 0 {
interval = refreshCheckInterval
}
return now.Add(interval), true
}
lastRefresh := auth.LastRefreshedAt
if lastRefresh.IsZero() {
if ts, ok := authLastRefreshTimestamp(auth); ok {
lastRefresh = ts
}
}
expiry, hasExpiry := auth.ExpirationTime()
if pref := authPreferredInterval(auth); pref > 0 {
candidates := make([]time.Time, 0, 2)
if hasExpiry && !expiry.IsZero() {
if !expiry.After(now) || expiry.Sub(now) <= pref {
return now, true
}
candidates = append(candidates, expiry.Add(-pref))
}
if lastRefresh.IsZero() {
return now, true
}
candidates = append(candidates, lastRefresh.Add(pref))
next := candidates[0]
for _, candidate := range candidates[1:] {
if candidate.Before(next) {
next = candidate
}
}
if !next.After(now) {
return now, true
}
return next, true
}
provider := strings.ToLower(auth.Provider)
lead := ProviderRefreshLead(provider, auth.Runtime)
if lead == nil {
return time.Time{}, false
}
if hasExpiry && !expiry.IsZero() {
dueAt := expiry.Add(-*lead)
if !dueAt.After(now) {
return now, true
}
return dueAt, true
}
if !lastRefresh.IsZero() {
dueAt := lastRefresh.Add(*lead)
if !dueAt.After(now) {
return now, true
}
return dueAt, true
}
return now, true
}
type refreshHeapItem struct {
id string
next time.Time
index int
}
type refreshMinHeap []*refreshHeapItem
func (h refreshMinHeap) Len() int { return len(h) }
func (h refreshMinHeap) Less(i, j int) bool {
return h[i].next.Before(h[j].next)
}
func (h refreshMinHeap) Swap(i, j int) {
h[i], h[j] = h[j], h[i]
h[i].index = i
h[j].index = j
}
func (h *refreshMinHeap) Push(x any) {
item, ok := x.(*refreshHeapItem)
if !ok || item == nil {
return
}
item.index = len(*h)
*h = append(*h, item)
}
func (h *refreshMinHeap) Pop() any {
old := *h
n := len(old)
if n == 0 {
return (*refreshHeapItem)(nil)
}
item := old[n-1]
item.index = -1
*h = old[:n-1]
return item
}

View File

@@ -0,0 +1,137 @@
package auth
import (
"strings"
"testing"
"time"
)
type testRefreshEvaluator struct{}
func (testRefreshEvaluator) ShouldRefresh(time.Time, *Auth) bool { return false }
func setRefreshLeadFactory(t *testing.T, provider string, factory func() *time.Duration) {
t.Helper()
key := strings.ToLower(strings.TrimSpace(provider))
refreshLeadMu.Lock()
prev, hadPrev := refreshLeadFactories[key]
if factory == nil {
delete(refreshLeadFactories, key)
} else {
refreshLeadFactories[key] = factory
}
refreshLeadMu.Unlock()
t.Cleanup(func() {
refreshLeadMu.Lock()
if hadPrev {
refreshLeadFactories[key] = prev
} else {
delete(refreshLeadFactories, key)
}
refreshLeadMu.Unlock()
})
}
func TestNextRefreshCheckAt_DisabledUnschedule(t *testing.T) {
now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
auth := &Auth{ID: "a1", Provider: "test", Disabled: true}
if _, ok := nextRefreshCheckAt(now, auth, 15*time.Minute); ok {
t.Fatalf("nextRefreshCheckAt() ok = true, want false")
}
}
func TestNextRefreshCheckAt_APIKeyUnschedule(t *testing.T) {
now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
auth := &Auth{ID: "a1", Provider: "test", Attributes: map[string]string{"api_key": "k"}}
if _, ok := nextRefreshCheckAt(now, auth, 15*time.Minute); ok {
t.Fatalf("nextRefreshCheckAt() ok = true, want false")
}
}
func TestNextRefreshCheckAt_NextRefreshAfterGate(t *testing.T) {
now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
nextAfter := now.Add(30 * time.Minute)
auth := &Auth{
ID: "a1",
Provider: "test",
NextRefreshAfter: nextAfter,
Metadata: map[string]any{"email": "x@example.com"},
}
got, ok := nextRefreshCheckAt(now, auth, 15*time.Minute)
if !ok {
t.Fatalf("nextRefreshCheckAt() ok = false, want true")
}
if !got.Equal(nextAfter) {
t.Fatalf("nextRefreshCheckAt() = %s, want %s", got, nextAfter)
}
}
func TestNextRefreshCheckAt_PreferredInterval_PicksEarliestCandidate(t *testing.T) {
now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
expiry := now.Add(20 * time.Minute)
auth := &Auth{
ID: "a1",
Provider: "test",
LastRefreshedAt: now,
Metadata: map[string]any{
"email": "x@example.com",
"expires_at": expiry.Format(time.RFC3339),
"refresh_interval_seconds": 900, // 15m
},
}
got, ok := nextRefreshCheckAt(now, auth, 15*time.Minute)
if !ok {
t.Fatalf("nextRefreshCheckAt() ok = false, want true")
}
want := expiry.Add(-15 * time.Minute)
if !got.Equal(want) {
t.Fatalf("nextRefreshCheckAt() = %s, want %s", got, want)
}
}
func TestNextRefreshCheckAt_ProviderLead_Expiry(t *testing.T) {
now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
expiry := now.Add(time.Hour)
lead := 10 * time.Minute
setRefreshLeadFactory(t, "provider-lead-expiry", func() *time.Duration {
d := lead
return &d
})
auth := &Auth{
ID: "a1",
Provider: "provider-lead-expiry",
Metadata: map[string]any{
"email": "x@example.com",
"expires_at": expiry.Format(time.RFC3339),
},
}
got, ok := nextRefreshCheckAt(now, auth, 15*time.Minute)
if !ok {
t.Fatalf("nextRefreshCheckAt() ok = false, want true")
}
want := expiry.Add(-lead)
if !got.Equal(want) {
t.Fatalf("nextRefreshCheckAt() = %s, want %s", got, want)
}
}
func TestNextRefreshCheckAt_RefreshEvaluatorFallback(t *testing.T) {
now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
interval := 15 * time.Minute
auth := &Auth{
ID: "a1",
Provider: "test",
Metadata: map[string]any{"email": "x@example.com"},
Runtime: testRefreshEvaluator{},
}
got, ok := nextRefreshCheckAt(now, auth, interval)
if !ok {
t.Fatalf("nextRefreshCheckAt() ok = false, want true")
}
want := now.Add(interval)
if !got.Equal(want) {
t.Fatalf("nextRefreshCheckAt() = %s, want %s", got, want)
}
}

View File

@@ -105,6 +105,13 @@ type Selector interface {
Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error)
}
// StoppableSelector is an optional interface for selectors that hold resources.
// Selectors that implement this interface will have Stop called during shutdown.
type StoppableSelector interface {
Selector
Stop()
}
// Hook captures lifecycle callbacks for observing auth changes.
type Hook interface {
// OnAuthRegistered fires when a new auth is registered.
@@ -162,8 +169,8 @@ type Manager struct {
rtProvider RoundTripperProvider
// Auto refresh state
refreshCancel context.CancelFunc
refreshSemaphore chan struct{}
refreshCancel context.CancelFunc
refreshLoop *authAutoRefreshLoop
}
// NewManager constructs a manager with optional custom selector and hook.
@@ -182,7 +189,6 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager {
auths: make(map[string]*Auth),
providerOffsets: make(map[string]int),
modelPoolOffsets: make(map[string]int),
refreshSemaphore: make(chan struct{}, refreshMaxConcurrency),
}
// atomic.Value requires non-nil initial value.
manager.runtimeConfig.Store(&internalconfig.Config{})
@@ -214,6 +220,16 @@ func (m *Manager) syncScheduler() {
m.syncSchedulerFromSnapshot(m.snapshotAuths())
}
func (m *Manager) snapshotAuths() []*Auth {
m.mu.RLock()
defer m.mu.RUnlock()
out := make([]*Auth, 0, len(m.auths))
for _, a := range m.auths {
out = append(out, a.Clone())
}
return out
}
// RefreshSchedulerEntry re-upserts a single auth into the scheduler so that its
// supportedModelSet is rebuilt from the current global model registry state.
// This must be called after models have been registered for a newly added auth,
@@ -1088,6 +1104,7 @@ func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
if m.scheduler != nil {
m.scheduler.upsertAuth(authClone)
}
m.queueRefreshReschedule(auth.ID)
_ = m.persist(ctx, auth)
m.hook.OnAuthRegistered(ctx, auth.Clone())
return auth.Clone(), nil
@@ -1118,6 +1135,7 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
if m.scheduler != nil {
m.scheduler.upsertAuth(authClone)
}
m.queueRefreshReschedule(auth.ID)
_ = m.persist(ctx, auth)
m.hook.OnAuthUpdated(ctx, auth.Clone())
return auth.Clone(), nil
@@ -2890,80 +2908,60 @@ func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duratio
if interval <= 0 {
interval = refreshCheckInterval
}
if m.refreshCancel != nil {
m.refreshCancel()
m.refreshCancel = nil
m.mu.Lock()
cancelPrev := m.refreshCancel
m.refreshCancel = nil
m.refreshLoop = nil
m.mu.Unlock()
if cancelPrev != nil {
cancelPrev()
}
ctx, cancel := context.WithCancel(parent)
m.refreshCancel = cancel
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
m.checkRefreshes(ctx)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
m.checkRefreshes(ctx)
}
}
}()
ctx, cancelCtx := context.WithCancel(parent)
workers := refreshMaxConcurrency
if cfg, ok := m.runtimeConfig.Load().(*internalconfig.Config); ok && cfg != nil && cfg.AuthAutoRefreshWorkers > 0 {
workers = cfg.AuthAutoRefreshWorkers
}
loop := newAuthAutoRefreshLoop(m, interval, workers)
m.mu.Lock()
m.refreshCancel = cancelCtx
m.refreshLoop = loop
m.mu.Unlock()
loop.rebuild(time.Now())
go loop.run(ctx)
}
// StopAutoRefresh cancels the background refresh loop, if running.
// It also stops the selector if it implements StoppableSelector.
func (m *Manager) StopAutoRefresh() {
if m.refreshCancel != nil {
m.refreshCancel()
m.refreshCancel = nil
m.mu.Lock()
cancel := m.refreshCancel
m.refreshCancel = nil
m.refreshLoop = nil
m.mu.Unlock()
if cancel != nil {
cancel()
}
// Stop selector if it implements StoppableSelector (e.g., SessionAffinitySelector)
if stoppable, ok := m.selector.(StoppableSelector); ok {
stoppable.Stop()
}
}
func (m *Manager) checkRefreshes(ctx context.Context) {
// log.Debugf("checking refreshes")
now := time.Now()
snapshot := m.snapshotAuths()
for _, a := range snapshot {
typ, _ := a.AccountInfo()
if typ != "api_key" {
if !m.shouldRefresh(a, now) {
continue
}
log.Debugf("checking refresh for %s, %s, %s", a.Provider, a.ID, typ)
if exec := m.executorFor(a.Provider); exec == nil {
continue
}
if !m.markRefreshPending(a.ID, now) {
continue
}
go m.refreshAuthWithLimit(ctx, a.ID)
}
}
}
func (m *Manager) refreshAuthWithLimit(ctx context.Context, id string) {
if m.refreshSemaphore == nil {
m.refreshAuth(ctx, id)
func (m *Manager) queueRefreshReschedule(authID string) {
if m == nil || authID == "" {
return
}
select {
case m.refreshSemaphore <- struct{}{}:
defer func() { <-m.refreshSemaphore }()
case <-ctx.Done():
return
}
m.refreshAuth(ctx, id)
}
func (m *Manager) snapshotAuths() []*Auth {
m.mu.RLock()
defer m.mu.RUnlock()
out := make([]*Auth, 0, len(m.auths))
for _, a := range m.auths {
out = append(out, a.Clone())
loop := m.refreshLoop
m.mu.RUnlock()
if loop == nil {
return
}
return out
loop.queueReschedule(authID)
}
func (m *Manager) shouldRefresh(a *Auth, now time.Time) bool {
@@ -3173,16 +3171,20 @@ func lookupMetadataTime(meta map[string]any, keys ...string) (time.Time, bool) {
func (m *Manager) markRefreshPending(id string, now time.Time) bool {
m.mu.Lock()
defer m.mu.Unlock()
auth, ok := m.auths[id]
if !ok || auth == nil || auth.Disabled {
m.mu.Unlock()
return false
}
if !auth.NextRefreshAfter.IsZero() && now.Before(auth.NextRefreshAfter) {
m.mu.Unlock()
return false
}
auth.NextRefreshAfter = now.Add(refreshPendingBackoff)
m.auths[id] = auth
m.mu.Unlock()
m.queueRefreshReschedule(id)
return true
}
@@ -3209,16 +3211,21 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) {
log.Debugf("refreshed %s, %s, %v", auth.Provider, auth.ID, err)
now := time.Now()
if err != nil {
shouldReschedule := false
m.mu.Lock()
if current := m.auths[id]; current != nil {
current.NextRefreshAfter = now.Add(refreshFailureBackoff)
current.LastError = &Error{Message: err.Error()}
m.auths[id] = current
shouldReschedule = true
if m.scheduler != nil {
m.scheduler.upsertAuth(current.Clone())
}
}
m.mu.Unlock()
if shouldReschedule {
m.queueRefreshReschedule(id)
}
return
}
if updated == nil {

View File

@@ -69,18 +69,18 @@ func TestManager_ShouldRetryAfterError_UsesOAuthModelAliasForCooldown(t *testing
m := NewManager(nil, nil, nil)
m.SetRetryConfig(3, 30*time.Second, 0)
m.SetOAuthModelAlias(map[string][]internalconfig.OAuthModelAlias{
"qwen": {
{Name: "qwen3.6-plus", Alias: "coder-model"},
"iflow": {
{Name: "deepseek-v3.1", Alias: "pool-model"},
},
})
routeModel := "coder-model"
upstreamModel := "qwen3.6-plus"
routeModel := "pool-model"
upstreamModel := "deepseek-v3.1"
next := time.Now().Add(5 * time.Second)
auth := &Auth{
ID: "auth-1",
Provider: "qwen",
Provider: "iflow",
ModelStates: map[string]*ModelState{
upstreamModel: {
Unavailable: true,
@@ -99,7 +99,7 @@ func TestManager_ShouldRetryAfterError_UsesOAuthModelAliasForCooldown(t *testing
}
_, _, maxWait := m.retrySettings()
wait, shouldRetry := m.shouldRetryAfterError(&Error{HTTPStatus: 429, Message: "quota"}, 0, []string{"qwen"}, routeModel, maxWait)
wait, shouldRetry := m.shouldRetryAfterError(&Error{HTTPStatus: 429, Message: "quota"}, 0, []string{"iflow"}, routeModel, maxWait)
if !shouldRetry {
t.Fatalf("expected shouldRetry=true, got false (wait=%v)", wait)
}

View File

@@ -265,7 +265,7 @@ func modelAliasChannel(auth *Auth) string {
// and auth kind. Returns empty string if the provider/authKind combination doesn't support
// OAuth model alias (e.g., API key authentication).
//
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot, kimi.
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, iflow, kiro, github-copilot, kimi.
func OAuthModelAliasChannel(provider, authKind string) string {
provider = strings.ToLower(strings.TrimSpace(provider))
authKind = strings.ToLower(strings.TrimSpace(authKind))
@@ -289,7 +289,7 @@ func OAuthModelAliasChannel(provider, authKind string) string {
return ""
}
return "codex"
case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow", "kiro", "github-copilot", "kimi":
case "gemini-cli", "aistudio", "antigravity", "iflow", "kiro", "github-copilot", "kimi":
return provider
default:
return ""

View File

@@ -184,8 +184,6 @@ func createAuthForChannel(channel string) *Auth {
return &Auth{Provider: "aistudio"}
case "antigravity":
return &Auth{Provider: "antigravity"}
case "qwen":
return &Auth{Provider: "qwen"}
case "iflow":
return &Auth{Provider: "iflow"}
case "kimi":

View File

@@ -215,10 +215,10 @@ func TestManagerExecuteCount_OpenAICompatAliasPoolStopsOnInvalidRequest(t *testi
invalidErr := &Error{HTTPStatus: http.StatusUnprocessableEntity, Message: "unprocessable entity"}
executor := &openAICompatPoolExecutor{
id: "pool",
countErrors: map[string]error{"qwen3.5-plus": invalidErr},
countErrors: map[string]error{"deepseek-v3.1": invalidErr},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "deepseek-v3.1", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
@@ -227,18 +227,18 @@ func TestManagerExecuteCount_OpenAICompatAliasPoolStopsOnInvalidRequest(t *testi
t.Fatalf("execute count error = %v, want %v", err, invalidErr)
}
got := executor.CountModels()
if len(got) != 1 || got[0] != "qwen3.5-plus" {
if len(got) != 1 || got[0] != "deepseek-v3.1" {
t.Fatalf("count calls = %v, want only first invalid model", got)
}
}
func TestResolveModelAliasPoolFromConfigModels(t *testing.T) {
models := []modelAliasEntry{
internalconfig.OpenAICompatibilityModel{Name: "qwen3.5-plus", Alias: "claude-opus-4.66"},
internalconfig.OpenAICompatibilityModel{Name: "deepseek-v3.1", Alias: "claude-opus-4.66"},
internalconfig.OpenAICompatibilityModel{Name: "glm-5", Alias: "claude-opus-4.66"},
internalconfig.OpenAICompatibilityModel{Name: "kimi-k2.5", Alias: "claude-opus-4.66"},
}
got := resolveModelAliasPoolFromConfigModels("claude-opus-4.66(8192)", models)
want := []string{"qwen3.5-plus(8192)", "glm-5(8192)", "kimi-k2.5(8192)"}
want := []string{"deepseek-v3.1(8192)", "glm-5(8192)", "kimi-k2.5(8192)"}
if len(got) != len(want) {
t.Fatalf("pool len = %d, want %d (%v)", len(got), len(want), got)
}
@@ -253,7 +253,7 @@ func TestManagerExecute_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T) {
alias := "claude-opus-4.66"
executor := &openAICompatPoolExecutor{id: "pool"}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "deepseek-v3.1", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
@@ -268,7 +268,7 @@ func TestManagerExecute_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T) {
}
got := executor.ExecuteModels()
want := []string{"qwen3.5-plus", "glm-5", "qwen3.5-plus"}
want := []string{"deepseek-v3.1", "glm-5", "deepseek-v3.1"}
if len(got) != len(want) {
t.Fatalf("execute calls = %v, want %v", got, want)
}
@@ -284,10 +284,10 @@ func TestManagerExecute_OpenAICompatAliasPoolStopsOnBadRequest(t *testing.T) {
invalidErr := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: malformed payload"}
executor := &openAICompatPoolExecutor{
id: "pool",
executeErrors: map[string]error{"qwen3.5-plus": invalidErr},
executeErrors: map[string]error{"deepseek-v3.1": invalidErr},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "deepseek-v3.1", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
@@ -296,7 +296,7 @@ func TestManagerExecute_OpenAICompatAliasPoolStopsOnBadRequest(t *testing.T) {
t.Fatalf("execute error = %v, want %v", err, invalidErr)
}
got := executor.ExecuteModels()
if len(got) != 1 || got[0] != "qwen3.5-plus" {
if len(got) != 1 || got[0] != "deepseek-v3.1" {
t.Fatalf("execute calls = %v, want only first invalid model", got)
}
}
@@ -309,10 +309,10 @@ func TestManagerExecute_OpenAICompatAliasPoolFallsBackOnModelSupportBadRequest(t
}
executor := &openAICompatPoolExecutor{
id: "pool",
executeErrors: map[string]error{"qwen3.5-plus": modelSupportErr},
executeErrors: map[string]error{"deepseek-v3.1": modelSupportErr},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "deepseek-v3.1", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
@@ -324,7 +324,7 @@ func TestManagerExecute_OpenAICompatAliasPoolFallsBackOnModelSupportBadRequest(t
t.Fatalf("payload = %q, want %q", string(resp.Payload), "glm-5")
}
got := executor.ExecuteModels()
want := []string{"qwen3.5-plus", "glm-5"}
want := []string{"deepseek-v3.1", "glm-5"}
if len(got) != len(want) {
t.Fatalf("execute calls = %v, want %v", got, want)
}
@@ -338,7 +338,7 @@ func TestManagerExecute_OpenAICompatAliasPoolFallsBackOnModelSupportBadRequest(t
if !ok || updated == nil {
t.Fatalf("expected auth to remain registered")
}
state := updated.ModelStates["qwen3.5-plus"]
state := updated.ModelStates["deepseek-v3.1"]
if state == nil {
t.Fatalf("expected suspended upstream model state")
}
@@ -355,10 +355,10 @@ func TestManagerExecute_OpenAICompatAliasPoolFallsBackOnModelSupportUnprocessabl
}
executor := &openAICompatPoolExecutor{
id: "pool",
executeErrors: map[string]error{"qwen3.5-plus": modelSupportErr},
executeErrors: map[string]error{"deepseek-v3.1": modelSupportErr},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "deepseek-v3.1", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
@@ -370,7 +370,7 @@ func TestManagerExecute_OpenAICompatAliasPoolFallsBackOnModelSupportUnprocessabl
t.Fatalf("payload = %q, want %q", string(resp.Payload), "glm-5")
}
got := executor.ExecuteModels()
want := []string{"qwen3.5-plus", "glm-5"}
want := []string{"deepseek-v3.1", "glm-5"}
if len(got) != len(want) {
t.Fatalf("execute calls = %v, want %v", got, want)
}
@@ -385,10 +385,10 @@ func TestManagerExecute_OpenAICompatAliasPoolFallsBackWithinSameAuth(t *testing.
alias := "claude-opus-4.66"
executor := &openAICompatPoolExecutor{
id: "pool",
executeErrors: map[string]error{"qwen3.5-plus": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}},
executeErrors: map[string]error{"deepseek-v3.1": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "deepseek-v3.1", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
@@ -400,7 +400,7 @@ func TestManagerExecute_OpenAICompatAliasPoolFallsBackWithinSameAuth(t *testing.
t.Fatalf("payload = %q, want %q", string(resp.Payload), "glm-5")
}
got := executor.ExecuteModels()
want := []string{"qwen3.5-plus", "glm-5"}
want := []string{"deepseek-v3.1", "glm-5"}
for i := range want {
if got[i] != want[i] {
t.Fatalf("execute call %d model = %q, want %q", i, got[i], want[i])
@@ -413,11 +413,11 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolRetriesOnEmptyBootstrap(t *te
executor := &openAICompatPoolExecutor{
id: "pool",
streamPayloads: map[string][]cliproxyexecutor.StreamChunk{
"qwen3.5-plus": {},
"deepseek-v3.1": {},
},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "deepseek-v3.1", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
@@ -436,7 +436,7 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolRetriesOnEmptyBootstrap(t *te
t.Fatalf("payload = %q, want %q", string(payload), "glm-5")
}
got := executor.StreamModels()
want := []string{"qwen3.5-plus", "glm-5"}
want := []string{"deepseek-v3.1", "glm-5"}
for i := range want {
if got[i] != want[i] {
t.Fatalf("stream call %d model = %q, want %q", i, got[i], want[i])
@@ -448,10 +448,10 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolFallsBackBeforeFirstByte(t *t
alias := "claude-opus-4.66"
executor := &openAICompatPoolExecutor{
id: "pool",
streamFirstErrors: map[string]error{"qwen3.5-plus": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}},
streamFirstErrors: map[string]error{"deepseek-v3.1": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "deepseek-v3.1", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
@@ -470,7 +470,7 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolFallsBackBeforeFirstByte(t *t
t.Fatalf("payload = %q, want %q", string(payload), "glm-5")
}
got := executor.StreamModels()
want := []string{"qwen3.5-plus", "glm-5"}
want := []string{"deepseek-v3.1", "glm-5"}
for i := range want {
if got[i] != want[i] {
t.Fatalf("stream call %d model = %q, want %q", i, got[i], want[i])
@@ -486,10 +486,10 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidRequest(t *test
invalidErr := &Error{HTTPStatus: http.StatusUnprocessableEntity, Message: "unprocessable entity"}
executor := &openAICompatPoolExecutor{
id: "pool",
streamFirstErrors: map[string]error{"qwen3.5-plus": invalidErr},
streamFirstErrors: map[string]error{"deepseek-v3.1": invalidErr},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "deepseek-v3.1", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
@@ -498,7 +498,7 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidRequest(t *test
t.Fatalf("execute stream error = %v, want %v", err, invalidErr)
}
got := executor.StreamModels()
if len(got) != 1 || got[0] != "qwen3.5-plus" {
if len(got) != 1 || got[0] != "deepseek-v3.1" {
t.Fatalf("stream calls = %v, want only first invalid model", got)
}
}
@@ -511,10 +511,10 @@ func TestManagerExecute_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLaterReques
}
executor := &openAICompatPoolExecutor{
id: "pool",
executeErrors: map[string]error{"qwen3.5-plus": modelSupportErr},
executeErrors: map[string]error{"deepseek-v3.1": modelSupportErr},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "deepseek-v3.1", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
@@ -529,7 +529,7 @@ func TestManagerExecute_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLaterReques
}
got := executor.ExecuteModels()
want := []string{"qwen3.5-plus", "glm-5", "glm-5", "glm-5"}
want := []string{"deepseek-v3.1", "glm-5", "glm-5", "glm-5"}
if len(got) != len(want) {
t.Fatalf("execute calls = %v, want %v", got, want)
}
@@ -548,10 +548,10 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLater
}
executor := &openAICompatPoolExecutor{
id: "pool",
streamFirstErrors: map[string]error{"qwen3.5-plus": modelSupportErr},
streamFirstErrors: map[string]error{"deepseek-v3.1": modelSupportErr},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "deepseek-v3.1", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
@@ -569,7 +569,7 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLater
}
got := executor.StreamModels()
want := []string{"qwen3.5-plus", "glm-5", "glm-5", "glm-5"}
want := []string{"deepseek-v3.1", "glm-5", "glm-5", "glm-5"}
if len(got) != len(want) {
t.Fatalf("stream calls = %v, want %v", got, want)
}
@@ -584,7 +584,7 @@ func TestManagerExecuteCount_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T
alias := "claude-opus-4.66"
executor := &openAICompatPoolExecutor{id: "pool"}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "deepseek-v3.1", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
@@ -599,7 +599,7 @@ func TestManagerExecuteCount_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T
}
got := executor.CountModels()
want := []string{"qwen3.5-plus", "glm-5"}
want := []string{"deepseek-v3.1", "glm-5"}
for i := range want {
if got[i] != want[i] {
t.Fatalf("count call %d model = %q, want %q", i, got[i], want[i])
@@ -615,10 +615,10 @@ func TestManagerExecuteCount_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLaterR
}
executor := &openAICompatPoolExecutor{
id: "pool",
countErrors: map[string]error{"qwen3.5-plus": modelSupportErr},
countErrors: map[string]error{"deepseek-v3.1": modelSupportErr},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "deepseek-v3.1", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
@@ -633,7 +633,7 @@ func TestManagerExecuteCount_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLaterR
}
got := executor.CountModels()
want := []string{"qwen3.5-plus", "glm-5", "glm-5", "glm-5"}
want := []string{"deepseek-v3.1", "glm-5", "glm-5", "glm-5"}
if len(got) != len(want) {
t.Fatalf("count calls = %v, want %v", got, want)
}
@@ -650,7 +650,7 @@ func TestManagerExecute_OpenAICompatAliasPoolBlockedAuthDoesNotConsumeRetryBudge
OpenAICompatibility: []internalconfig.OpenAICompatibility{{
Name: "pool",
Models: []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "deepseek-v3.1", Alias: alias},
{Name: "glm-5", Alias: alias},
},
}},
@@ -701,7 +701,7 @@ func TestManagerExecute_OpenAICompatAliasPoolBlockedAuthDoesNotConsumeRetryBudge
HTTPStatus: http.StatusBadRequest,
Message: "invalid_request_error: The requested model is not supported.",
}
for _, upstreamModel := range []string{"qwen3.5-plus", "glm-5"} {
for _, upstreamModel := range []string{"deepseek-v3.1", "glm-5"} {
m.MarkResult(context.Background(), Result{
AuthID: badAuth.ID,
Provider: "pool",
@@ -733,10 +733,10 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidBootstrap(t *te
invalidErr := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: malformed payload"}
executor := &openAICompatPoolExecutor{
id: "pool",
streamFirstErrors: map[string]error{"qwen3.5-plus": invalidErr},
streamFirstErrors: map[string]error{"deepseek-v3.1": invalidErr},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "deepseek-v3.1", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
@@ -750,7 +750,7 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidBootstrap(t *te
if streamResult != nil {
t.Fatalf("streamResult = %#v, want nil on invalid bootstrap", streamResult)
}
if got := executor.StreamModels(); len(got) != 1 || got[0] != "qwen3.5-plus" {
if got := executor.StreamModels(); len(got) != 1 || got[0] != "deepseek-v3.1" {
t.Fatalf("stream calls = %v, want only first upstream model", got)
}
}

View File

@@ -4,15 +4,21 @@ import (
"context"
"encoding/json"
"fmt"
"hash/fnv"
"math"
"math/rand/v2"
"net/http"
"regexp"
"sort"
"strconv"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
@@ -420,3 +426,448 @@ func isAuthBlockedForModel(auth *Auth, model string, now time.Time) (bool, block
}
return false, blockReasonNone, time.Time{}
}
// sessionPattern matches Claude Code user_id format:
// user_{hash}_account__session_{uuid}
var sessionPattern = regexp.MustCompile(`_session_([a-f0-9-]+)$`)
// SessionAffinitySelector wraps another selector with session-sticky behavior.
// It extracts session ID from multiple sources and maintains session-to-auth
// mappings with automatic failover when the bound auth becomes unavailable.
type SessionAffinitySelector struct {
fallback Selector
cache *SessionCache
}
// SessionAffinityConfig configures the session affinity selector.
type SessionAffinityConfig struct {
Fallback Selector
TTL time.Duration
}
// NewSessionAffinitySelector creates a new session-aware selector.
func NewSessionAffinitySelector(fallback Selector) *SessionAffinitySelector {
return NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{
Fallback: fallback,
TTL: time.Hour,
})
}
// NewSessionAffinitySelectorWithConfig creates a selector with custom configuration.
func NewSessionAffinitySelectorWithConfig(cfg SessionAffinityConfig) *SessionAffinitySelector {
if cfg.Fallback == nil {
cfg.Fallback = &RoundRobinSelector{}
}
if cfg.TTL <= 0 {
cfg.TTL = time.Hour
}
return &SessionAffinitySelector{
fallback: cfg.Fallback,
cache: NewSessionCache(cfg.TTL),
}
}
// Pick selects an auth with session affinity when possible.
// Priority for session ID extraction:
// 1. metadata.user_id (Claude Code format) - highest priority
// 2. X-Session-ID header
// 3. metadata.user_id (non-Claude Code format)
// 4. conversation_id field
// 5. Hash-based fallback from messages
//
// Note: The cache key includes provider, session ID, and model to handle cases where
// a session uses multiple models (e.g., gemini-2.5-pro and gemini-3-flash-preview)
// that may be supported by different auth credentials, and to avoid cross-provider conflicts.
func (s *SessionAffinitySelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
entry := selectorLogEntry(ctx)
primaryID, fallbackID := extractSessionIDs(opts.Headers, opts.OriginalRequest, opts.Metadata)
if primaryID == "" {
entry.Debugf("session-affinity: no session ID extracted, falling back to default selector | provider=%s model=%s", provider, model)
return s.fallback.Pick(ctx, provider, model, opts, auths)
}
now := time.Now()
available, err := getAvailableAuths(auths, provider, model, now)
if err != nil {
return nil, err
}
cacheKey := provider + "::" + primaryID + "::" + model
if cachedAuthID, ok := s.cache.GetAndRefresh(cacheKey); ok {
for _, auth := range available {
if auth.ID == cachedAuthID {
entry.Infof("session-affinity: cache hit | session=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), auth.ID, provider, model)
return auth, nil
}
}
// Cached auth not available, reselect via fallback selector for even distribution
auth, err := s.fallback.Pick(ctx, provider, model, opts, auths)
if err != nil {
return nil, err
}
s.cache.Set(cacheKey, auth.ID)
entry.Infof("session-affinity: cache hit but auth unavailable, reselected | session=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), auth.ID, provider, model)
return auth, nil
}
if fallbackID != "" && fallbackID != primaryID {
fallbackKey := provider + "::" + fallbackID + "::" + model
if cachedAuthID, ok := s.cache.Get(fallbackKey); ok {
for _, auth := range available {
if auth.ID == cachedAuthID {
s.cache.Set(cacheKey, auth.ID)
entry.Infof("session-affinity: fallback cache hit | session=%s fallback=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), truncateSessionID(fallbackID), auth.ID, provider, model)
return auth, nil
}
}
}
}
auth, err := s.fallback.Pick(ctx, provider, model, opts, auths)
if err != nil {
return nil, err
}
s.cache.Set(cacheKey, auth.ID)
entry.Infof("session-affinity: cache miss, new binding | session=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), auth.ID, provider, model)
return auth, nil
}
func selectorLogEntry(ctx context.Context) *log.Entry {
if ctx == nil {
return log.NewEntry(log.StandardLogger())
}
if reqID := logging.GetRequestID(ctx); reqID != "" {
return log.WithField("request_id", reqID)
}
return log.NewEntry(log.StandardLogger())
}
// truncateSessionID shortens session ID for logging (first 8 chars + "...")
func truncateSessionID(id string) string {
if len(id) <= 20 {
return id
}
return id[:8] + "..."
}
// Stop releases resources held by the selector.
func (s *SessionAffinitySelector) Stop() {
if s.cache != nil {
s.cache.Stop()
}
}
// InvalidateAuth removes all session bindings for a specific auth.
// Called when an auth becomes rate-limited or unavailable.
func (s *SessionAffinitySelector) InvalidateAuth(authID string) {
if s.cache != nil {
s.cache.InvalidateAuth(authID)
}
}
// ExtractSessionID extracts session identifier from multiple sources.
// Priority order:
// 1. metadata.user_id (Claude Code format with _session_{uuid}) - highest priority for Claude Code clients
// 2. X-Session-ID header
// 3. metadata.user_id (non-Claude Code format)
// 4. conversation_id field in request body
// 5. Stable hash from first few messages content (fallback)
func ExtractSessionID(headers http.Header, payload []byte, metadata map[string]any) string {
primary, _ := extractSessionIDs(headers, payload, metadata)
return primary
}
// extractSessionIDs returns (primaryID, fallbackID) for session affinity.
// primaryID: full hash including assistant response (stable after first turn)
// fallbackID: short hash without assistant (used to inherit binding from first turn)
func extractSessionIDs(headers http.Header, payload []byte, metadata map[string]any) (string, string) {
// 1. metadata.user_id with Claude Code session format (highest priority)
if len(payload) > 0 {
userID := gjson.GetBytes(payload, "metadata.user_id").String()
if userID != "" {
// Old format: user_{hash}_account__session_{uuid}
if matches := sessionPattern.FindStringSubmatch(userID); len(matches) >= 2 {
id := "claude:" + matches[1]
return id, ""
}
// New format: JSON object with session_id field
// e.g. {"device_id":"...","account_uuid":"...","session_id":"uuid"}
if len(userID) > 0 && userID[0] == '{' {
if sid := gjson.Get(userID, "session_id").String(); sid != "" {
return "claude:" + sid, ""
}
}
}
}
// 2. X-Session-ID header
if headers != nil {
if sid := headers.Get("X-Session-ID"); sid != "" {
return "header:" + sid, ""
}
}
if len(payload) == 0 {
return "", ""
}
// 3. metadata.user_id (non-Claude Code format)
userID := gjson.GetBytes(payload, "metadata.user_id").String()
if userID != "" {
return "user:" + userID, ""
}
// 4. conversation_id field
if convID := gjson.GetBytes(payload, "conversation_id").String(); convID != "" {
return "conv:" + convID, ""
}
// 5. Hash-based fallback from message content
return extractMessageHashIDs(payload)
}
func extractMessageHashIDs(payload []byte) (primaryID, fallbackID string) {
var systemPrompt, firstUserMsg, firstAssistantMsg string
// OpenAI/Claude messages format
messages := gjson.GetBytes(payload, "messages")
if messages.Exists() && messages.IsArray() {
messages.ForEach(func(_, msg gjson.Result) bool {
role := msg.Get("role").String()
content := extractMessageContent(msg.Get("content"))
if content == "" {
return true
}
switch role {
case "system":
if systemPrompt == "" {
systemPrompt = truncateString(content, 100)
}
case "user":
if firstUserMsg == "" {
firstUserMsg = truncateString(content, 100)
}
case "assistant":
if firstAssistantMsg == "" {
firstAssistantMsg = truncateString(content, 100)
}
}
if systemPrompt != "" && firstUserMsg != "" && firstAssistantMsg != "" {
return false
}
return true
})
}
// Claude API: top-level "system" field (array or string)
if systemPrompt == "" {
topSystem := gjson.GetBytes(payload, "system")
if topSystem.Exists() {
if topSystem.IsArray() {
topSystem.ForEach(func(_, part gjson.Result) bool {
if text := part.Get("text").String(); text != "" && systemPrompt == "" {
systemPrompt = truncateString(text, 100)
return false
}
return true
})
} else if topSystem.Type == gjson.String {
systemPrompt = truncateString(topSystem.String(), 100)
}
}
}
// Gemini format
if systemPrompt == "" && firstUserMsg == "" {
sysInstr := gjson.GetBytes(payload, "systemInstruction.parts")
if sysInstr.Exists() && sysInstr.IsArray() {
sysInstr.ForEach(func(_, part gjson.Result) bool {
if text := part.Get("text").String(); text != "" && systemPrompt == "" {
systemPrompt = truncateString(text, 100)
return false
}
return true
})
}
contents := gjson.GetBytes(payload, "contents")
if contents.Exists() && contents.IsArray() {
contents.ForEach(func(_, msg gjson.Result) bool {
role := msg.Get("role").String()
msg.Get("parts").ForEach(func(_, part gjson.Result) bool {
text := part.Get("text").String()
if text == "" {
return true
}
switch role {
case "user":
if firstUserMsg == "" {
firstUserMsg = truncateString(text, 100)
}
case "model":
if firstAssistantMsg == "" {
firstAssistantMsg = truncateString(text, 100)
}
}
return false
})
if firstUserMsg != "" && firstAssistantMsg != "" {
return false
}
return true
})
}
}
// OpenAI Responses API format (v1/responses)
if systemPrompt == "" && firstUserMsg == "" {
if instr := gjson.GetBytes(payload, "instructions").String(); instr != "" {
systemPrompt = truncateString(instr, 100)
}
input := gjson.GetBytes(payload, "input")
if input.Exists() && input.IsArray() {
input.ForEach(func(_, item gjson.Result) bool {
itemType := item.Get("type").String()
if itemType == "reasoning" {
return true
}
// Skip non-message typed items (function_call, function_call_output, etc.)
// but allow items with no type that have a role (inline message format).
if itemType != "" && itemType != "message" {
return true
}
role := item.Get("role").String()
if itemType == "" && role == "" {
return true
}
// Handle both string content and array content (multimodal).
content := item.Get("content")
var text string
if content.Type == gjson.String {
text = content.String()
} else {
text = extractResponsesAPIContent(content)
}
if text == "" {
return true
}
switch role {
case "developer", "system":
if systemPrompt == "" {
systemPrompt = truncateString(text, 100)
}
case "user":
if firstUserMsg == "" {
firstUserMsg = truncateString(text, 100)
}
case "assistant":
if firstAssistantMsg == "" {
firstAssistantMsg = truncateString(text, 100)
}
}
if firstUserMsg != "" && firstAssistantMsg != "" {
return false
}
return true
})
}
}
if systemPrompt == "" && firstUserMsg == "" {
return "", ""
}
shortHash := computeSessionHash(systemPrompt, firstUserMsg, "")
if firstAssistantMsg == "" {
return shortHash, ""
}
fullHash := computeSessionHash(systemPrompt, firstUserMsg, firstAssistantMsg)
return fullHash, shortHash
}
func computeSessionHash(systemPrompt, userMsg, assistantMsg string) string {
h := fnv.New64a()
if systemPrompt != "" {
h.Write([]byte("sys:" + systemPrompt + "\n"))
}
if userMsg != "" {
h.Write([]byte("usr:" + userMsg + "\n"))
}
if assistantMsg != "" {
h.Write([]byte("ast:" + assistantMsg + "\n"))
}
return fmt.Sprintf("msg:%016x", h.Sum64())
}
func truncateString(s string, maxLen int) string {
if len(s) > maxLen {
return s[:maxLen]
}
return s
}
// extractMessageContent extracts text content from a message content field.
// Handles both string content and array content (multimodal messages).
// For array content, extracts text from all text-type elements.
func extractMessageContent(content gjson.Result) string {
// String content: "Hello world"
if content.Type == gjson.String {
return content.String()
}
// Array content: [{"type":"text","text":"Hello"},{"type":"image",...}]
if content.IsArray() {
var texts []string
content.ForEach(func(_, part gjson.Result) bool {
// Handle Claude format: {"type":"text","text":"content"}
if part.Get("type").String() == "text" {
if text := part.Get("text").String(); text != "" {
texts = append(texts, text)
}
}
// Handle OpenAI format: {"type":"text","text":"content"}
// Same structure as Claude, already handled above
return true
})
if len(texts) > 0 {
return strings.Join(texts, " ")
}
}
return ""
}
func extractResponsesAPIContent(content gjson.Result) string {
if !content.IsArray() {
return ""
}
var texts []string
content.ForEach(func(_, part gjson.Result) bool {
partType := part.Get("type").String()
if partType == "input_text" || partType == "output_text" || partType == "text" {
if text := part.Get("text").String(); text != "" {
texts = append(texts, text)
}
}
return true
})
if len(texts) > 0 {
return strings.Join(texts, " ")
}
return ""
}
// extractSessionID is kept for backward compatibility.
// Deprecated: Use ExtractSessionID instead.
func extractSessionID(payload []byte) string {
return ExtractSessionID(nil, payload, nil)
}

View File

@@ -4,7 +4,9 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"sync"
"testing"
"time"
@@ -458,6 +460,159 @@ func TestRoundRobinSelectorPick_GeminiCLICredentialGrouping(t *testing.T) {
}
}
func TestExtractSessionID(t *testing.T) {
t.Parallel()
tests := []struct {
name string
payload string
want string
}{
{
name: "valid_claude_code_format",
payload: `{"metadata":{"user_id":"user_3f221fe75652cf9a89a31647f16274bb8036a9b85ac4dc226a4df0efec8dc04d_account__session_ac980658-63bd-4fb3-97ba-8da64cb1e344"}}`,
want: "claude:ac980658-63bd-4fb3-97ba-8da64cb1e344",
},
{
name: "json_user_id_with_session_id",
payload: `{"metadata":{"user_id":"{\"device_id\":\"be82c3aee1e0c2d74535bacc85f9f559228f02dd8a17298cf522b71e6c375714\",\"account_uuid\":\"\",\"session_id\":\"e26d4046-0f88-4b09-bb5b-f863ab5fb24e\"}"}}`,
want: "claude:e26d4046-0f88-4b09-bb5b-f863ab5fb24e",
},
{
name: "json_user_id_without_session_id",
payload: `{"metadata":{"user_id":"{\"device_id\":\"abc123\"}"}}`,
want: `user:{"device_id":"abc123"}`,
},
{
name: "no_session_but_user_id",
payload: `{"metadata":{"user_id":"user_abc123"}}`,
want: "user:user_abc123",
},
{
name: "conversation_id",
payload: `{"conversation_id":"conv-12345"}`,
want: "conv:conv-12345",
},
{
name: "no_metadata",
payload: `{"model":"claude-3"}`,
want: "",
},
{
name: "empty_payload",
payload: ``,
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractSessionID([]byte(tt.payload))
if got != tt.want {
t.Errorf("extractSessionID() = %q, want %q", got, tt.want)
}
})
}
}
func TestSessionAffinitySelector_SameSessionSameAuth(t *testing.T) {
t.Parallel()
fallback := &RoundRobinSelector{}
selector := NewSessionAffinitySelector(fallback)
auths := []*Auth{
{ID: "auth-a"},
{ID: "auth-b"},
{ID: "auth-c"},
}
// Use valid UUID format for session ID
payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_ac980658-63bd-4fb3-97ba-8da64cb1e344"}}`)
opts := cliproxyexecutor.Options{OriginalRequest: payload}
// Same session should always pick the same auth
first, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths)
if err != nil {
t.Fatalf("Pick() error = %v", err)
}
if first == nil {
t.Fatalf("Pick() returned nil")
}
// Verify consistency: same session, same auths -> same result
for i := 0; i < 10; i++ {
got, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths)
if err != nil {
t.Fatalf("Pick() #%d error = %v", i, err)
}
if got.ID != first.ID {
t.Fatalf("Pick() #%d auth.ID = %q, want %q (same session should pick same auth)", i, got.ID, first.ID)
}
}
}
func TestSessionAffinitySelector_NoSessionFallback(t *testing.T) {
t.Parallel()
fallback := &FillFirstSelector{}
selector := NewSessionAffinitySelector(fallback)
auths := []*Auth{
{ID: "auth-b"},
{ID: "auth-a"},
{ID: "auth-c"},
}
// No session in payload, should fallback to FillFirstSelector (picks "auth-a" after sorting)
payload := []byte(`{"model":"claude-3"}`)
opts := cliproxyexecutor.Options{OriginalRequest: payload}
got, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths)
if err != nil {
t.Fatalf("Pick() error = %v", err)
}
if got.ID != "auth-a" {
t.Fatalf("Pick() auth.ID = %q, want %q (should fallback to FillFirst)", got.ID, "auth-a")
}
}
func TestSessionAffinitySelector_DifferentSessionsDifferentAuths(t *testing.T) {
t.Parallel()
fallback := &RoundRobinSelector{}
selector := NewSessionAffinitySelector(fallback)
auths := []*Auth{
{ID: "auth-a"},
{ID: "auth-b"},
{ID: "auth-c"},
}
// Use valid UUID format for session IDs
session1 := []byte(`{"metadata":{"user_id":"user_xxx_account__session_11111111-1111-1111-1111-111111111111"}}`)
session2 := []byte(`{"metadata":{"user_id":"user_xxx_account__session_22222222-2222-2222-2222-222222222222"}}`)
opts1 := cliproxyexecutor.Options{OriginalRequest: session1}
opts2 := cliproxyexecutor.Options{OriginalRequest: session2}
auth1, _ := selector.Pick(context.Background(), "claude", "claude-3", opts1, auths)
auth2, _ := selector.Pick(context.Background(), "claude", "claude-3", opts2, auths)
// Different sessions may or may not pick different auths (depends on hash collision)
// But each session should be consistent
for i := 0; i < 5; i++ {
got1, _ := selector.Pick(context.Background(), "claude", "claude-3", opts1, auths)
got2, _ := selector.Pick(context.Background(), "claude", "claude-3", opts2, auths)
if got1.ID != auth1.ID {
t.Fatalf("session1 Pick() #%d inconsistent: got %q, want %q", i, got1.ID, auth1.ID)
}
if got2.ID != auth2.ID {
t.Fatalf("session2 Pick() #%d inconsistent: got %q, want %q", i, got2.ID, auth2.ID)
}
}
}
func TestRoundRobinSelectorPick_SingleParentFallsBackToFlat(t *testing.T) {
t.Parallel()
@@ -494,6 +649,57 @@ func TestRoundRobinSelectorPick_SingleParentFallsBackToFlat(t *testing.T) {
}
}
func TestSessionAffinitySelector_FailoverWhenAuthUnavailable(t *testing.T) {
t.Parallel()
fallback := &RoundRobinSelector{}
selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{
Fallback: fallback,
TTL: time.Minute,
})
defer selector.Stop()
auths := []*Auth{
{ID: "auth-a"},
{ID: "auth-b"},
{ID: "auth-c"},
}
payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_failover-test-uuid"}}`)
opts := cliproxyexecutor.Options{OriginalRequest: payload}
// First pick establishes binding
first, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths)
if err != nil {
t.Fatalf("Pick() error = %v", err)
}
// Remove the bound auth from available list (simulating rate limit)
availableWithoutFirst := make([]*Auth, 0, len(auths)-1)
for _, a := range auths {
if a.ID != first.ID {
availableWithoutFirst = append(availableWithoutFirst, a)
}
}
// With failover enabled, should pick a new auth
second, err := selector.Pick(context.Background(), "claude", "claude-3", opts, availableWithoutFirst)
if err != nil {
t.Fatalf("Pick() after failover error = %v", err)
}
if second.ID == first.ID {
t.Fatalf("Pick() after failover returned same auth %q, expected different", first.ID)
}
// Subsequent picks should consistently return the new binding
for i := 0; i < 5; i++ {
got, _ := selector.Pick(context.Background(), "claude", "claude-3", opts, availableWithoutFirst)
if got.ID != second.ID {
t.Fatalf("Pick() #%d after failover inconsistent: got %q, want %q", i, got.ID, second.ID)
}
}
}
func TestRoundRobinSelectorPick_MixedVirtualAndNonVirtualFallsBackToFlat(t *testing.T) {
t.Parallel()
@@ -527,3 +733,629 @@ func TestRoundRobinSelectorPick_MixedVirtualAndNonVirtualFallsBackToFlat(t *test
}
}
}
func TestExtractSessionID_ClaudeCodePriorityOverHeader(t *testing.T) {
t.Parallel()
// Claude Code metadata.user_id should have highest priority, even when X-Session-ID header is present
headers := make(http.Header)
headers.Set("X-Session-ID", "header-session-id")
payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_ac980658-63bd-4fb3-97ba-8da64cb1e344"}}`)
got := ExtractSessionID(headers, payload, nil)
want := "claude:ac980658-63bd-4fb3-97ba-8da64cb1e344"
if got != want {
t.Errorf("ExtractSessionID() = %q, want %q (Claude Code should have highest priority over header)", got, want)
}
}
func TestExtractSessionID_ClaudeCodePriorityOverIdempotencyKey(t *testing.T) {
t.Parallel()
// Claude Code metadata.user_id should have highest priority, even when idempotency_key is present
metadata := map[string]any{"idempotency_key": "idem-12345"}
payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_ac980658-63bd-4fb3-97ba-8da64cb1e344"}}`)
got := ExtractSessionID(nil, payload, metadata)
want := "claude:ac980658-63bd-4fb3-97ba-8da64cb1e344"
if got != want {
t.Errorf("ExtractSessionID() = %q, want %q (Claude Code should have highest priority over idempotency_key)", got, want)
}
}
func TestExtractSessionID_Headers(t *testing.T) {
t.Parallel()
headers := make(http.Header)
headers.Set("X-Session-ID", "my-explicit-session")
got := ExtractSessionID(headers, nil, nil)
want := "header:my-explicit-session"
if got != want {
t.Errorf("ExtractSessionID() with header = %q, want %q", got, want)
}
}
// TestExtractSessionID_IdempotencyKey verifies that idempotency_key is intentionally
// ignored for session affinity (it's auto-generated per-request, causing cache misses).
func TestExtractSessionID_IdempotencyKey(t *testing.T) {
t.Parallel()
metadata := map[string]any{"idempotency_key": "idem-12345"}
got := ExtractSessionID(nil, nil, metadata)
// idempotency_key is disabled - should return empty (no payload to hash)
if got != "" {
t.Errorf("ExtractSessionID() with idempotency_key = %q, want empty (idempotency_key is disabled)", got)
}
}
func TestExtractSessionID_MessageHashFallback(t *testing.T) {
t.Parallel()
// First request (user only) generates short hash
firstRequestPayload := []byte(`{"messages":[{"role":"user","content":"Hello world"}]}`)
shortHash := ExtractSessionID(nil, firstRequestPayload, nil)
if shortHash == "" {
t.Error("ExtractSessionID() first request should return short hash")
}
if !strings.HasPrefix(shortHash, "msg:") {
t.Errorf("ExtractSessionID() = %q, want prefix 'msg:'", shortHash)
}
// Multi-turn with assistant generates full hash (different from short hash)
multiTurnPayload := []byte(`{"messages":[
{"role":"user","content":"Hello world"},
{"role":"assistant","content":"Hi! How can I help?"},
{"role":"user","content":"Tell me a joke"}
]}`)
fullHash := ExtractSessionID(nil, multiTurnPayload, nil)
if fullHash == "" {
t.Error("ExtractSessionID() multi-turn should return full hash")
}
if fullHash == shortHash {
t.Error("Full hash should differ from short hash (includes assistant)")
}
// Same multi-turn payload should produce same hash
fullHash2 := ExtractSessionID(nil, multiTurnPayload, nil)
if fullHash != fullHash2 {
t.Errorf("ExtractSessionID() not stable: got %q then %q", fullHash, fullHash2)
}
}
func TestExtractSessionID_ClaudeAPITopLevelSystem(t *testing.T) {
t.Parallel()
// Claude API: system prompt in top-level "system" field (array format)
arraySystem := []byte(`{
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
"system": [{"type": "text", "text": "You are Claude Code"}]
}`)
got1 := ExtractSessionID(nil, arraySystem, nil)
if got1 == "" || !strings.HasPrefix(got1, "msg:") {
t.Errorf("ExtractSessionID() with array system = %q, want msg:* prefix", got1)
}
// Claude API: system prompt in top-level "system" field (string format)
stringSystem := []byte(`{
"messages": [{"role": "user", "content": "Hello"}],
"system": "You are Claude Code"
}`)
got2 := ExtractSessionID(nil, stringSystem, nil)
if got2 == "" || !strings.HasPrefix(got2, "msg:") {
t.Errorf("ExtractSessionID() with string system = %q, want msg:* prefix", got2)
}
// Multi-turn with top-level system should produce stable hash
multiTurn := []byte(`{
"messages": [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi!"},
{"role": "user", "content": "Help me"}
],
"system": "You are Claude Code"
}`)
got3 := ExtractSessionID(nil, multiTurn, nil)
if got3 == "" {
t.Error("ExtractSessionID() multi-turn with top-level system should return hash")
}
if got3 == got2 {
t.Error("Multi-turn hash should differ from first-turn hash (includes assistant)")
}
}
func TestExtractSessionID_GeminiFormat(t *testing.T) {
t.Parallel()
// Gemini format with systemInstruction and contents
payload := []byte(`{
"systemInstruction": {"parts": [{"text": "You are a helpful assistant."}]},
"contents": [
{"role": "user", "parts": [{"text": "Hello Gemini"}]},
{"role": "model", "parts": [{"text": "Hi there!"}]}
]
}`)
got := ExtractSessionID(nil, payload, nil)
if got == "" {
t.Error("ExtractSessionID() with Gemini format should return hash-based session ID")
}
if !strings.HasPrefix(got, "msg:") {
t.Errorf("ExtractSessionID() = %q, want prefix 'msg:'", got)
}
// Same payload should produce same hash
got2 := ExtractSessionID(nil, payload, nil)
if got != got2 {
t.Errorf("ExtractSessionID() not stable: got %q then %q", got, got2)
}
// Different user message should produce different hash
differentPayload := []byte(`{
"systemInstruction": {"parts": [{"text": "You are a helpful assistant."}]},
"contents": [
{"role": "user", "parts": [{"text": "Hello different"}]},
{"role": "model", "parts": [{"text": "Hi there!"}]}
]
}`)
got3 := ExtractSessionID(nil, differentPayload, nil)
if got == got3 {
t.Errorf("ExtractSessionID() should produce different hash for different user message")
}
}
func TestExtractSessionID_OpenAIResponsesAPI(t *testing.T) {
t.Parallel()
firstTurn := []byte(`{
"instructions": "You are Codex, based on GPT-5.",
"input": [
{"type": "message", "role": "developer", "content": [{"type": "input_text", "text": "system instructions"}]},
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "hi"}]}
]
}`)
got1 := ExtractSessionID(nil, firstTurn, nil)
if got1 == "" {
t.Error("ExtractSessionID() should return hash for OpenAI Responses API format")
}
if !strings.HasPrefix(got1, "msg:") {
t.Errorf("ExtractSessionID() = %q, want prefix 'msg:'", got1)
}
secondTurn := []byte(`{
"instructions": "You are Codex, based on GPT-5.",
"input": [
{"type": "message", "role": "developer", "content": [{"type": "input_text", "text": "system instructions"}]},
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "hi"}]},
{"type": "reasoning", "summary": [{"type": "summary_text", "text": "thinking..."}], "encrypted_content": "xxx"},
{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "Hello!"}]},
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "what can you do"}]}
]
}`)
got2 := ExtractSessionID(nil, secondTurn, nil)
if got2 == "" {
t.Error("ExtractSessionID() should return hash for second turn")
}
if got1 == got2 {
t.Log("First turn and second turn have different hashes (expected: second includes assistant)")
}
thirdTurn := []byte(`{
"instructions": "You are Codex, based on GPT-5.",
"input": [
{"type": "message", "role": "developer", "content": [{"type": "input_text", "text": "system instructions"}]},
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "hi"}]},
{"type": "reasoning", "summary": [{"type": "summary_text", "text": "thinking..."}], "encrypted_content": "xxx"},
{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "Hello!"}]},
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "what can you do"}]},
{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "I can help with..."}]},
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "thanks"}]}
]
}`)
got3 := ExtractSessionID(nil, thirdTurn, nil)
if got2 != got3 {
t.Errorf("Second and third turn should have same hash (same first assistant): got %q vs %q", got2, got3)
}
}
func TestSessionAffinitySelector_ThreeScenarios(t *testing.T) {
t.Parallel()
fallback := &RoundRobinSelector{}
selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{
Fallback: fallback,
TTL: time.Minute,
})
defer selector.Stop()
auths := []*Auth{{ID: "auth-a"}, {ID: "auth-b"}, {ID: "auth-c"}}
testCases := []struct {
name string
scenario string
payload []byte
}{
{
name: "OpenAI_Scenario1_NewRequest",
scenario: "new",
payload: []byte(`{"messages":[{"role":"system","content":"You are helpful"},{"role":"user","content":"Hello"}]}`),
},
{
name: "OpenAI_Scenario2_SecondTurn",
scenario: "second",
payload: []byte(`{"messages":[{"role":"system","content":"You are helpful"},{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi there!"},{"role":"user","content":"Help me"}]}`),
},
{
name: "OpenAI_Scenario3_ManyTurns",
scenario: "many",
payload: []byte(`{"messages":[{"role":"system","content":"You are helpful"},{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi there!"},{"role":"user","content":"Help me"},{"role":"assistant","content":"Sure!"},{"role":"user","content":"Thanks"}]}`),
},
{
name: "Gemini_Scenario1_NewRequest",
scenario: "new",
payload: []byte(`{"systemInstruction":{"parts":[{"text":"You are helpful"}]},"contents":[{"role":"user","parts":[{"text":"Hello Gemini"}]}]}`),
},
{
name: "Gemini_Scenario2_SecondTurn",
scenario: "second",
payload: []byte(`{"systemInstruction":{"parts":[{"text":"You are helpful"}]},"contents":[{"role":"user","parts":[{"text":"Hello Gemini"}]},{"role":"model","parts":[{"text":"Hi!"}]},{"role":"user","parts":[{"text":"Help"}]}]}`),
},
{
name: "Gemini_Scenario3_ManyTurns",
scenario: "many",
payload: []byte(`{"systemInstruction":{"parts":[{"text":"You are helpful"}]},"contents":[{"role":"user","parts":[{"text":"Hello Gemini"}]},{"role":"model","parts":[{"text":"Hi!"}]},{"role":"user","parts":[{"text":"Help"}]},{"role":"model","parts":[{"text":"Sure!"}]},{"role":"user","parts":[{"text":"Thanks"}]}]}`),
},
{
name: "Claude_Scenario1_NewRequest",
scenario: "new",
payload: []byte(`{"messages":[{"role":"user","content":"Hello Claude"}]}`),
},
{
name: "Claude_Scenario2_SecondTurn",
scenario: "second",
payload: []byte(`{"messages":[{"role":"user","content":"Hello Claude"},{"role":"assistant","content":"Hello!"},{"role":"user","content":"Help me"}]}`),
},
{
name: "Claude_Scenario3_ManyTurns",
scenario: "many",
payload: []byte(`{"messages":[{"role":"user","content":"Hello Claude"},{"role":"assistant","content":"Hello!"},{"role":"user","content":"Help"},{"role":"assistant","content":"Sure!"},{"role":"user","content":"Thanks"}]}`),
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
opts := cliproxyexecutor.Options{OriginalRequest: tc.payload}
picked, err := selector.Pick(context.Background(), "provider", "model", opts, auths)
if err != nil {
t.Fatalf("Pick() error = %v", err)
}
if picked == nil {
t.Fatal("Pick() returned nil")
}
t.Logf("%s: picked %s", tc.name, picked.ID)
})
}
t.Run("Scenario2And3_SameAuth", func(t *testing.T) {
openaiS2 := []byte(`{"messages":[{"role":"system","content":"Stable test"},{"role":"user","content":"First msg"},{"role":"assistant","content":"Response"},{"role":"user","content":"Second"}]}`)
openaiS3 := []byte(`{"messages":[{"role":"system","content":"Stable test"},{"role":"user","content":"First msg"},{"role":"assistant","content":"Response"},{"role":"user","content":"Second"},{"role":"assistant","content":"More"},{"role":"user","content":"Third"}]}`)
opts2 := cliproxyexecutor.Options{OriginalRequest: openaiS2}
opts3 := cliproxyexecutor.Options{OriginalRequest: openaiS3}
picked2, _ := selector.Pick(context.Background(), "test", "model", opts2, auths)
picked3, _ := selector.Pick(context.Background(), "test", "model", opts3, auths)
if picked2.ID != picked3.ID {
t.Errorf("Scenario2 and Scenario3 should pick same auth: got %s vs %s", picked2.ID, picked3.ID)
}
})
t.Run("Scenario1To2_InheritBinding", func(t *testing.T) {
s1 := []byte(`{"messages":[{"role":"system","content":"Inherit test"},{"role":"user","content":"Initial"}]}`)
s2 := []byte(`{"messages":[{"role":"system","content":"Inherit test"},{"role":"user","content":"Initial"},{"role":"assistant","content":"Reply"},{"role":"user","content":"Continue"}]}`)
opts1 := cliproxyexecutor.Options{OriginalRequest: s1}
opts2 := cliproxyexecutor.Options{OriginalRequest: s2}
picked1, _ := selector.Pick(context.Background(), "inherit", "model", opts1, auths)
picked2, _ := selector.Pick(context.Background(), "inherit", "model", opts2, auths)
if picked1.ID != picked2.ID {
t.Errorf("Scenario2 should inherit Scenario1 binding: got %s vs %s", picked1.ID, picked2.ID)
}
})
}
func TestSessionAffinitySelector_MultiModelSession(t *testing.T) {
t.Parallel()
fallback := &RoundRobinSelector{}
selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{
Fallback: fallback,
TTL: time.Minute,
})
defer selector.Stop()
// auth-a supports only model-a, auth-b supports only model-b
authA := &Auth{ID: "auth-a"}
authB := &Auth{ID: "auth-b"}
// Same session ID for all requests
payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_multi-model-test"}}`)
opts := cliproxyexecutor.Options{OriginalRequest: payload}
// Request model-a with only auth-a available for that model
authsForModelA := []*Auth{authA}
pickedA, err := selector.Pick(context.Background(), "provider", "model-a", opts, authsForModelA)
if err != nil {
t.Fatalf("Pick() for model-a error = %v", err)
}
if pickedA.ID != "auth-a" {
t.Fatalf("Pick() for model-a = %q, want auth-a", pickedA.ID)
}
// Request model-b with only auth-b available for that model
authsForModelB := []*Auth{authB}
pickedB, err := selector.Pick(context.Background(), "provider", "model-b", opts, authsForModelB)
if err != nil {
t.Fatalf("Pick() for model-b error = %v", err)
}
if pickedB.ID != "auth-b" {
t.Fatalf("Pick() for model-b = %q, want auth-b", pickedB.ID)
}
// Switch back to model-a - should still get auth-a (separate binding per model)
pickedA2, err := selector.Pick(context.Background(), "provider", "model-a", opts, authsForModelA)
if err != nil {
t.Fatalf("Pick() for model-a (2nd) error = %v", err)
}
if pickedA2.ID != "auth-a" {
t.Fatalf("Pick() for model-a (2nd) = %q, want auth-a", pickedA2.ID)
}
// Verify bindings are stable for multiple calls
for i := 0; i < 5; i++ {
gotA, _ := selector.Pick(context.Background(), "provider", "model-a", opts, authsForModelA)
gotB, _ := selector.Pick(context.Background(), "provider", "model-b", opts, authsForModelB)
if gotA.ID != "auth-a" {
t.Fatalf("Pick() #%d for model-a = %q, want auth-a", i, gotA.ID)
}
if gotB.ID != "auth-b" {
t.Fatalf("Pick() #%d for model-b = %q, want auth-b", i, gotB.ID)
}
}
}
func TestExtractSessionID_MultimodalContent(t *testing.T) {
t.Parallel()
// First request generates short hash
firstRequestPayload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"Hello world"},{"type":"image","source":{"data":"..."}}]}]}`)
shortHash := ExtractSessionID(nil, firstRequestPayload, nil)
if shortHash == "" {
t.Error("ExtractSessionID() first request should return short hash")
}
if !strings.HasPrefix(shortHash, "msg:") {
t.Errorf("ExtractSessionID() = %q, want prefix 'msg:'", shortHash)
}
// Multi-turn generates full hash
multiTurnPayload := []byte(`{"messages":[
{"role":"user","content":[{"type":"text","text":"Hello world"},{"type":"image","source":{"data":"..."}}]},
{"role":"assistant","content":"I see an image!"},
{"role":"user","content":"What is it?"}
]}`)
fullHash := ExtractSessionID(nil, multiTurnPayload, nil)
if fullHash == "" {
t.Error("ExtractSessionID() multimodal multi-turn should return full hash")
}
if fullHash == shortHash {
t.Error("Full hash should differ from short hash")
}
// Different user content produces different hash
differentPayload := []byte(`{"messages":[
{"role":"user","content":[{"type":"text","text":"Different content"}]},
{"role":"assistant","content":"I see something different!"}
]}`)
differentHash := ExtractSessionID(nil, differentPayload, nil)
if fullHash == differentHash {
t.Errorf("ExtractSessionID() should produce different hash for different content")
}
}
func TestSessionAffinitySelector_CrossProviderIsolation(t *testing.T) {
t.Parallel()
fallback := &RoundRobinSelector{}
selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{
Fallback: fallback,
TTL: time.Minute,
})
defer selector.Stop()
authClaude := &Auth{ID: "auth-claude"}
authGemini := &Auth{ID: "auth-gemini"}
// Same session ID for both providers
payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_cross-provider-test"}}`)
opts := cliproxyexecutor.Options{OriginalRequest: payload}
// Request via claude provider
pickedClaude, err := selector.Pick(context.Background(), "claude", "claude-3", opts, []*Auth{authClaude})
if err != nil {
t.Fatalf("Pick() for claude error = %v", err)
}
if pickedClaude.ID != "auth-claude" {
t.Fatalf("Pick() for claude = %q, want auth-claude", pickedClaude.ID)
}
// Same session but via gemini provider should get different auth
pickedGemini, err := selector.Pick(context.Background(), "gemini", "gemini-2.5-pro", opts, []*Auth{authGemini})
if err != nil {
t.Fatalf("Pick() for gemini error = %v", err)
}
if pickedGemini.ID != "auth-gemini" {
t.Fatalf("Pick() for gemini = %q, want auth-gemini", pickedGemini.ID)
}
// Verify both bindings remain stable
for i := 0; i < 5; i++ {
gotC, _ := selector.Pick(context.Background(), "claude", "claude-3", opts, []*Auth{authClaude})
gotG, _ := selector.Pick(context.Background(), "gemini", "gemini-2.5-pro", opts, []*Auth{authGemini})
if gotC.ID != "auth-claude" {
t.Fatalf("Pick() #%d for claude = %q, want auth-claude", i, gotC.ID)
}
if gotG.ID != "auth-gemini" {
t.Fatalf("Pick() #%d for gemini = %q, want auth-gemini", i, gotG.ID)
}
}
}
func TestSessionCache_GetAndRefresh(t *testing.T) {
t.Parallel()
cache := NewSessionCache(100 * time.Millisecond)
defer cache.Stop()
cache.Set("session1", "auth1")
// Verify initial value
got, ok := cache.GetAndRefresh("session1")
if !ok || got != "auth1" {
t.Fatalf("GetAndRefresh() = %q, %v, want auth1, true", got, ok)
}
// Wait half TTL and access again (should refresh)
time.Sleep(60 * time.Millisecond)
got, ok = cache.GetAndRefresh("session1")
if !ok || got != "auth1" {
t.Fatalf("GetAndRefresh() after 60ms = %q, %v, want auth1, true", got, ok)
}
// Wait another 60ms (total 120ms from original, but TTL refreshed at 60ms)
// Entry should still be valid because TTL was refreshed
time.Sleep(60 * time.Millisecond)
got, ok = cache.GetAndRefresh("session1")
if !ok || got != "auth1" {
t.Fatalf("GetAndRefresh() after refresh = %q, %v, want auth1, true (TTL should have been refreshed)", got, ok)
}
// Now wait full TTL without access
time.Sleep(110 * time.Millisecond)
got, ok = cache.GetAndRefresh("session1")
if ok {
t.Fatalf("GetAndRefresh() after expiry = %q, %v, want '', false", got, ok)
}
}
func TestSessionAffinitySelector_RoundRobinDistribution(t *testing.T) {
t.Parallel()
fallback := &RoundRobinSelector{}
selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{
Fallback: fallback,
TTL: time.Minute,
})
defer selector.Stop()
auths := []*Auth{
{ID: "auth-a"},
{ID: "auth-b"},
{ID: "auth-c"},
}
sessionCount := 12
counts := make(map[string]int)
for i := 0; i < sessionCount; i++ {
payload := []byte(fmt.Sprintf(`{"metadata":{"user_id":"user_xxx_account__session_%08d-0000-0000-0000-000000000000"}}`, i))
opts := cliproxyexecutor.Options{OriginalRequest: payload}
got, err := selector.Pick(context.Background(), "provider", "model", opts, auths)
if err != nil {
t.Fatalf("Pick() session %d error = %v", i, err)
}
counts[got.ID]++
}
expected := sessionCount / len(auths)
for _, auth := range auths {
got := counts[auth.ID]
if got != expected {
t.Errorf("auth %s got %d sessions, want %d (round-robin should distribute evenly)", auth.ID, got, expected)
}
}
}
func TestSessionAffinitySelector_Concurrent(t *testing.T) {
t.Parallel()
fallback := &RoundRobinSelector{}
selector := NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{
Fallback: fallback,
TTL: time.Minute,
})
defer selector.Stop()
auths := []*Auth{
{ID: "auth-a"},
{ID: "auth-b"},
{ID: "auth-c"},
}
payload := []byte(`{"metadata":{"user_id":"user_xxx_account__session_concurrent-test"}}`)
opts := cliproxyexecutor.Options{OriginalRequest: payload}
// First pick to establish binding
first, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths)
if err != nil {
t.Fatalf("Initial Pick() error = %v", err)
}
expectedID := first.ID
start := make(chan struct{})
var wg sync.WaitGroup
errCh := make(chan error, 1)
goroutines := 32
iterations := 50
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
<-start
for j := 0; j < iterations; j++ {
got, err := selector.Pick(context.Background(), "claude", "claude-3", opts, auths)
if err != nil {
select {
case errCh <- err:
default:
}
return
}
if got.ID != expectedID {
select {
case errCh <- fmt.Errorf("concurrent Pick() returned %q, want %q", got.ID, expectedID):
default:
}
return
}
}
}()
}
close(start)
wg.Wait()
select {
case err := <-errCh:
t.Fatalf("concurrent Pick() error = %v", err)
default:
}
}

View File

@@ -0,0 +1,152 @@
package auth
import (
"sync"
"time"
)
// sessionEntry stores auth binding with expiration.
type sessionEntry struct {
authID string
expiresAt time.Time
}
// SessionCache provides TTL-based session to auth mapping with automatic cleanup.
type SessionCache struct {
mu sync.RWMutex
entries map[string]sessionEntry
ttl time.Duration
stopCh chan struct{}
}
// NewSessionCache creates a cache with the specified TTL.
// A background goroutine periodically cleans expired entries.
func NewSessionCache(ttl time.Duration) *SessionCache {
if ttl <= 0 {
ttl = 30 * time.Minute
}
c := &SessionCache{
entries: make(map[string]sessionEntry),
ttl: ttl,
stopCh: make(chan struct{}),
}
go c.cleanupLoop()
return c
}
// Get retrieves the auth ID bound to a session, if still valid.
// Does NOT refresh the TTL on access.
func (c *SessionCache) Get(sessionID string) (string, bool) {
if sessionID == "" {
return "", false
}
c.mu.RLock()
entry, ok := c.entries[sessionID]
c.mu.RUnlock()
if !ok {
return "", false
}
if time.Now().After(entry.expiresAt) {
c.mu.Lock()
delete(c.entries, sessionID)
c.mu.Unlock()
return "", false
}
return entry.authID, true
}
// GetAndRefresh retrieves the auth ID bound to a session and refreshes TTL on hit.
// This extends the binding lifetime for active sessions.
func (c *SessionCache) GetAndRefresh(sessionID string) (string, bool) {
if sessionID == "" {
return "", false
}
now := time.Now()
c.mu.Lock()
entry, ok := c.entries[sessionID]
if !ok {
c.mu.Unlock()
return "", false
}
if now.After(entry.expiresAt) {
delete(c.entries, sessionID)
c.mu.Unlock()
return "", false
}
// Refresh TTL on successful access
entry.expiresAt = now.Add(c.ttl)
c.entries[sessionID] = entry
c.mu.Unlock()
return entry.authID, true
}
// Set binds a session to an auth ID with TTL refresh.
func (c *SessionCache) Set(sessionID, authID string) {
if sessionID == "" || authID == "" {
return
}
c.mu.Lock()
c.entries[sessionID] = sessionEntry{
authID: authID,
expiresAt: time.Now().Add(c.ttl),
}
c.mu.Unlock()
}
// Invalidate removes a specific session binding.
func (c *SessionCache) Invalidate(sessionID string) {
if sessionID == "" {
return
}
c.mu.Lock()
delete(c.entries, sessionID)
c.mu.Unlock()
}
// InvalidateAuth removes all sessions bound to a specific auth ID.
// Used when an auth becomes unavailable.
func (c *SessionCache) InvalidateAuth(authID string) {
if authID == "" {
return
}
c.mu.Lock()
for sid, entry := range c.entries {
if entry.authID == authID {
delete(c.entries, sid)
}
}
c.mu.Unlock()
}
// Stop terminates the background cleanup goroutine.
func (c *SessionCache) Stop() {
select {
case <-c.stopCh:
default:
close(c.stopCh)
}
}
func (c *SessionCache) cleanupLoop() {
ticker := time.NewTicker(c.ttl / 2)
defer ticker.Stop()
for {
select {
case <-c.stopCh:
return
case <-ticker.C:
c.cleanup()
}
}
}
func (c *SessionCache) cleanup() {
now := time.Now()
c.mu.Lock()
for sid, entry := range c.entries {
if now.After(entry.expiresAt) {
delete(c.entries, sid)
}
}
c.mu.Unlock()
}

View File

@@ -6,6 +6,7 @@ package cliproxy
import (
"fmt"
"strings"
"time"
configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
@@ -208,8 +209,17 @@ func (b *Builder) Build() (*Service, error) {
}
strategy := ""
sessionAffinity := false
sessionAffinityTTL := time.Hour
if b.cfg != nil {
strategy = strings.ToLower(strings.TrimSpace(b.cfg.Routing.Strategy))
// Support both legacy ClaudeCodeSessionAffinity and new universal SessionAffinity
sessionAffinity = b.cfg.Routing.ClaudeCodeSessionAffinity || b.cfg.Routing.SessionAffinity
if ttlStr := strings.TrimSpace(b.cfg.Routing.SessionAffinityTTL); ttlStr != "" {
if parsed, err := time.ParseDuration(ttlStr); err == nil && parsed > 0 {
sessionAffinityTTL = parsed
}
}
}
var selector coreauth.Selector
switch strategy {
@@ -219,6 +229,14 @@ func (b *Builder) Build() (*Service, error) {
selector = &coreauth.RoundRobinSelector{}
}
// Wrap with session affinity if enabled (failover is always on)
if sessionAffinity {
selector = coreauth.NewSessionAffinitySelectorWithConfig(coreauth.SessionAffinityConfig{
Fallback: selector,
TTL: sessionAffinityTTL,
})
}
coreManager = coreauth.NewManager(tokenStore, selector, nil)
}
// Attach a default RoundTripper provider so providers can opt-in per-auth transports.

View File

@@ -118,7 +118,6 @@ func newDefaultAuthManager() *sdkAuth.Manager {
sdkAuth.NewGeminiAuthenticator(),
sdkAuth.NewCodexAuthenticator(),
sdkAuth.NewClaudeAuthenticator(),
sdkAuth.NewQwenAuthenticator(),
sdkAuth.NewGitLabAuthenticator(),
)
}
@@ -435,8 +434,6 @@ func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace
s.coreManager.RegisterExecutor(executor.NewAntigravityExecutor(s.cfg))
case "claude":
s.coreManager.RegisterExecutor(executor.NewClaudeExecutor(s.cfg))
case "qwen":
s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg))
case "iflow":
s.coreManager.RegisterExecutor(executor.NewIFlowExecutor(s.cfg))
case "kimi":
@@ -639,9 +636,13 @@ func (s *Service) Run(ctx context.Context) error {
var watcherWrapper *WatcherWrapper
reloadCallback := func(newCfg *config.Config) {
previousStrategy := ""
var previousSessionAffinity bool
var previousSessionAffinityTTL string
s.cfgMu.RLock()
if s.cfg != nil {
previousStrategy = strings.ToLower(strings.TrimSpace(s.cfg.Routing.Strategy))
previousSessionAffinity = s.cfg.Routing.ClaudeCodeSessionAffinity || s.cfg.Routing.SessionAffinity
previousSessionAffinityTTL = s.cfg.Routing.SessionAffinityTTL
}
s.cfgMu.RUnlock()
@@ -665,7 +666,15 @@ func (s *Service) Run(ctx context.Context) error {
}
previousStrategy = normalizeStrategy(previousStrategy)
nextStrategy = normalizeStrategy(nextStrategy)
if s.coreManager != nil && previousStrategy != nextStrategy {
nextSessionAffinity := newCfg.Routing.ClaudeCodeSessionAffinity || newCfg.Routing.SessionAffinity
nextSessionAffinityTTL := newCfg.Routing.SessionAffinityTTL
selectorChanged := previousStrategy != nextStrategy ||
previousSessionAffinity != nextSessionAffinity ||
previousSessionAffinityTTL != nextSessionAffinityTTL
if s.coreManager != nil && selectorChanged {
var selector coreauth.Selector
switch nextStrategy {
case "fill-first":
@@ -673,6 +682,20 @@ func (s *Service) Run(ctx context.Context) error {
default:
selector = &coreauth.RoundRobinSelector{}
}
if nextSessionAffinity {
ttl := time.Hour
if ttlStr := strings.TrimSpace(nextSessionAffinityTTL); ttlStr != "" {
if parsed, err := time.ParseDuration(ttlStr); err == nil && parsed > 0 {
ttl = parsed
}
}
selector = coreauth.NewSessionAffinitySelectorWithConfig(coreauth.SessionAffinityConfig{
Fallback: selector,
TTL: ttl,
})
}
s.coreManager.SetSelector(selector)
}
@@ -939,9 +962,6 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
}
}
models = applyExcludedModels(models, excluded)
case "qwen":
models = registry.GetQwenModels()
models = applyExcludedModels(models, excluded)
case "iflow":
models = registry.GetIFlowModels()
models = applyExcludedModels(models, excluded)