This commit is contained in:
Luis Pater
2026-04-02 12:15:33 +08:00
65 changed files with 4994 additions and 1800 deletions

View File

@@ -10,6 +10,7 @@ import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"github.com/gin-gonic/gin"
@@ -22,6 +23,25 @@ import (
"github.com/tidwall/sjson"
)
func writeResponsesSSEChunk(w io.Writer, chunk []byte) {
if w == nil || len(chunk) == 0 {
return
}
if _, err := w.Write(chunk); err != nil {
return
}
if bytes.HasSuffix(chunk, []byte("\n\n")) {
return
}
suffix := []byte("\n\n")
if bytes.HasSuffix(chunk, []byte("\n")) {
suffix = []byte("\n")
}
if _, err := w.Write(suffix); err != nil {
return
}
}
// OpenAIResponsesAPIHandler contains the handlers for OpenAIResponses API endpoints.
// It holds a pool of clients to interact with the backend service.
type OpenAIResponsesAPIHandler struct {
@@ -271,11 +291,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
// Write first chunk logic (matching forwardResponsesStream)
if bytes.HasPrefix(chunk, []byte("event:")) {
_, _ = c.Writer.Write([]byte("\n"))
}
_, _ = c.Writer.Write(chunk)
_, _ = c.Writer.Write([]byte("\n"))
writeResponsesSSEChunk(c.Writer, chunk)
flusher.Flush()
// Continue
@@ -400,11 +416,7 @@ func (h *OpenAIResponsesAPIHandler) forwardChatAsResponsesStream(c *gin.Context,
func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
WriteChunk: func(chunk []byte) {
if bytes.HasPrefix(chunk, []byte("event:")) {
_, _ = c.Writer.Write([]byte("\n"))
}
_, _ = c.Writer.Write(chunk)
_, _ = c.Writer.Write([]byte("\n"))
writeResponsesSSEChunk(c.Writer, chunk)
},
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
if errMsg == nil {

View File

@@ -0,0 +1,52 @@
package openai
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
func TestForwardResponsesStreamSeparatesDataOnlySSEChunks(t *testing.T) {
gin.SetMode(gin.TestMode)
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, nil)
h := NewOpenAIResponsesAPIHandler(base)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
t.Fatalf("expected gin writer to implement http.Flusher")
}
data := make(chan []byte, 2)
errs := make(chan *interfaces.ErrorMessage)
data <- []byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"function_call\",\"arguments\":\"{}\"}}")
data <- []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[]}}")
close(data)
close(errs)
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs)
body := recorder.Body.String()
parts := strings.Split(strings.TrimSpace(body), "\n\n")
if len(parts) != 2 {
t.Fatalf("expected 2 SSE events, got %d. Body: %q", len(parts), body)
}
expectedPart1 := "data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"function_call\",\"arguments\":\"{}\"}}"
if parts[0] != expectedPart1 {
t.Errorf("unexpected first event.\nGot: %q\nWant: %q", parts[0], expectedPart1)
}
expectedPart2 := "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[]}}"
if parts[1] != expectedPart2 {
t.Errorf("unexpected second event.\nGot: %q\nWant: %q", parts[1], expectedPart2)
}
}

View File

@@ -33,9 +33,6 @@ const (
wsDoneMarker = "[DONE]"
wsTurnStateHeader = "x-codex-turn-state"
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
wsPayloadLogMaxSize = 2048
wsBodyLogMaxSize = 64 * 1024
wsBodyLogTruncated = "\n[websocket log truncated]\n"
)
var responsesWebsocketUpgrader = websocket.Upgrader{
@@ -55,14 +52,14 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
return
}
passthroughSessionID := uuid.NewString()
clientRemoteAddr := ""
if c != nil && c.Request != nil {
clientRemoteAddr = strings.TrimSpace(c.Request.RemoteAddr)
}
log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientRemoteAddr)
downstreamSessionKey := websocketDownstreamSessionKey(c.Request)
retainResponsesWebsocketToolCaches(downstreamSessionKey)
clientIP := websocketClientAddress(c)
log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientIP)
var wsTerminateErr error
var wsBodyLog strings.Builder
defer func() {
releaseResponsesWebsocketToolCaches(downstreamSessionKey)
if wsTerminateErr != nil {
// log.Infof("responses websocket: session closing id=%s reason=%v", passthroughSessionID, wsTerminateErr)
} else {
@@ -167,6 +164,9 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
}
continue
}
requestJSON = repairResponsesWebsocketToolCalls(downstreamSessionKey, requestJSON)
updatedLastRequest = bytes.Clone(requestJSON)
lastRequest = updatedLastRequest
modelName := gjson.GetBytes(requestJSON, "model").String()
@@ -203,6 +203,13 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
}
}
func websocketClientAddress(c *gin.Context) string {
if c == nil || c.Request == nil {
return ""
}
return strings.TrimSpace(c.ClientIP())
}
func websocketUpgradeHeaders(req *http.Request) http.Header {
headers := http.Header{}
if req == nil {
@@ -277,6 +284,15 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
}
}
// Compaction can cause clients to replace local websocket history with a new
// compact transcript on the next `response.create`. When the input already
// contains historical model output items, treating it as an incremental append
// duplicates stale turn-state and can leave late orphaned function_call items.
if shouldReplaceWebsocketTranscript(rawJSON, nextInput) {
normalized := normalizeResponseTranscriptReplacement(rawJSON, lastRequest)
return normalized, bytes.Clone(normalized), nil
}
// Websocket v2 mode uses response.create with previous_response_id + incremental input.
// Do not expand it into a full input transcript; upstream expects the incremental payload.
if allowIncrementalInputWithPreviousResponseID {
@@ -318,6 +334,10 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
Error: fmt.Errorf("invalid request input: %w", errMerge),
}
}
dedupedInput, errDedupeFunctionCalls := dedupeFunctionCallsByCallID(mergedInput)
if errDedupeFunctionCalls == nil {
mergedInput = dedupedInput
}
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
if errDelete != nil {
@@ -348,6 +368,91 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
return normalized, bytes.Clone(normalized), nil
}
func shouldReplaceWebsocketTranscript(rawJSON []byte, nextInput gjson.Result) bool {
requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String())
if requestType != wsRequestTypeCreate && requestType != wsRequestTypeAppend {
return false
}
if strings.TrimSpace(gjson.GetBytes(rawJSON, "previous_response_id").String()) != "" {
return false
}
if !nextInput.Exists() || !nextInput.IsArray() {
return false
}
for _, item := range nextInput.Array() {
switch strings.TrimSpace(item.Get("type").String()) {
case "function_call":
return true
case "message":
role := strings.TrimSpace(item.Get("role").String())
if role == "assistant" {
return true
}
}
}
return false
}
func normalizeResponseTranscriptReplacement(rawJSON []byte, lastRequest []byte) []byte {
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
if errDelete != nil {
normalized = bytes.Clone(rawJSON)
}
normalized, _ = sjson.DeleteBytes(normalized, "previous_response_id")
if !gjson.GetBytes(normalized, "model").Exists() {
modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
if modelName != "" {
normalized, _ = sjson.SetBytes(normalized, "model", modelName)
}
}
if !gjson.GetBytes(normalized, "instructions").Exists() {
instructions := gjson.GetBytes(lastRequest, "instructions")
if instructions.Exists() {
normalized, _ = sjson.SetRawBytes(normalized, "instructions", []byte(instructions.Raw))
}
}
normalized, _ = sjson.SetBytes(normalized, "stream", true)
return bytes.Clone(normalized)
}
func dedupeFunctionCallsByCallID(rawArray string) (string, error) {
rawArray = strings.TrimSpace(rawArray)
if rawArray == "" {
return "[]", nil
}
var items []json.RawMessage
if errUnmarshal := json.Unmarshal([]byte(rawArray), &items); errUnmarshal != nil {
return "", errUnmarshal
}
seenCallIDs := make(map[string]struct{}, len(items))
filtered := make([]json.RawMessage, 0, len(items))
for _, item := range items {
if len(item) == 0 {
continue
}
itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String())
if itemType == "function_call" {
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
if callID != "" {
if _, ok := seenCallIDs[callID]; ok {
continue
}
seenCallIDs[callID] = struct{}{}
}
}
filtered = append(filtered, item)
}
out, errMarshal := json.Marshal(filtered)
if errMarshal != nil {
return "", errMarshal
}
return string(out), nil
}
func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, metadata map[string]any) bool {
if len(attributes) > 0 {
if raw := strings.TrimSpace(attributes["websockets"]); raw != "" {
@@ -613,6 +718,10 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
) ([]byte, error) {
completed := false
completedOutput := []byte("[]")
downstreamSessionKey := ""
if c != nil && c.Request != nil {
downstreamSessionKey = websocketDownstreamSessionKey(c.Request)
}
for {
select {
@@ -690,6 +799,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
payloads := websocketJSONPayloadsFromChunk(chunk)
for i := range payloads {
recordResponsesWebsocketToolCallsFromPayload(downstreamSessionKey, payloads[i])
eventType := gjson.GetBytes(payloads[i], "type").String()
if eventType == wsEventTypeCompleted {
completed = true
@@ -837,71 +947,18 @@ func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []
if builder == nil {
return
}
if builder.Len() >= wsBodyLogMaxSize {
return
}
trimmedPayload := bytes.TrimSpace(payload)
if len(trimmedPayload) == 0 {
return
}
if builder.Len() > 0 {
if !appendWebsocketLogString(builder, "\n") {
return
}
builder.WriteString("\n")
}
if !appendWebsocketLogString(builder, "websocket.") {
return
}
if !appendWebsocketLogString(builder, eventType) {
return
}
if !appendWebsocketLogString(builder, "\n") {
return
}
if !appendWebsocketLogBytes(builder, trimmedPayload, len(wsBodyLogTruncated)) {
appendWebsocketLogString(builder, wsBodyLogTruncated)
return
}
appendWebsocketLogString(builder, "\n")
}
func appendWebsocketLogString(builder *strings.Builder, value string) bool {
if builder == nil {
return false
}
remaining := wsBodyLogMaxSize - builder.Len()
if remaining <= 0 {
return false
}
if len(value) <= remaining {
builder.WriteString(value)
return true
}
builder.WriteString(value[:remaining])
return false
}
func appendWebsocketLogBytes(builder *strings.Builder, value []byte, reserveForSuffix int) bool {
if builder == nil {
return false
}
remaining := wsBodyLogMaxSize - builder.Len()
if remaining <= 0 {
return false
}
if len(value) <= remaining {
builder.Write(value)
return true
}
limit := remaining - reserveForSuffix
if limit < 0 {
limit = 0
}
if limit > len(value) {
limit = len(value)
}
builder.Write(value[:limit])
return false
builder.WriteString("websocket.")
builder.WriteString(eventType)
builder.WriteString("\n")
builder.Write(trimmedPayload)
builder.WriteString("\n")
}
func websocketPayloadEventType(payload []byte) string {
@@ -917,15 +974,8 @@ func websocketPayloadPreview(payload []byte) string {
if len(trimmedPayload) == 0 {
return "<empty>"
}
preview := trimmedPayload
if len(preview) > wsPayloadLogMaxSize {
preview = preview[:wsPayloadLogMaxSize]
}
previewText := strings.ReplaceAll(string(preview), "\n", "\\n")
previewText := strings.ReplaceAll(string(trimmedPayload), "\n", "\\n")
previewText = strings.ReplaceAll(previewText, "\r", "\\r")
if len(trimmedPayload) > wsPayloadLogMaxSize {
return fmt.Sprintf("%s...(truncated,total=%d)", previewText, len(trimmedPayload))
}
return previewText
}

View File

@@ -10,6 +10,7 @@ import (
"strings"
"sync"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
@@ -27,6 +28,12 @@ type websocketCaptureExecutor struct {
payloads [][]byte
}
type websocketCompactionCaptureExecutor struct {
mu sync.Mutex
streamPayloads [][]byte
compactPayload []byte
}
type orderedWebsocketSelector struct {
mu sync.Mutex
order []string
@@ -126,6 +133,52 @@ func (e *websocketCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth,
return nil, errors.New("not implemented")
}
func (e *websocketCompactionCaptureExecutor) Identifier() string { return "test-provider" }
func (e *websocketCompactionCaptureExecutor) Execute(_ context.Context, _ *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (coreexecutor.Response, error) {
e.mu.Lock()
e.compactPayload = bytes.Clone(req.Payload)
e.mu.Unlock()
if opts.Alt != "responses/compact" {
return coreexecutor.Response{}, fmt.Errorf("unexpected non-compact execute alt: %q", opts.Alt)
}
return coreexecutor.Response{Payload: []byte(`{"id":"cmp-1","object":"response.compaction"}`)}, nil
}
func (e *websocketCompactionCaptureExecutor) ExecuteStream(_ context.Context, _ *coreauth.Auth, req coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) {
e.mu.Lock()
callIndex := len(e.streamPayloads)
e.streamPayloads = append(e.streamPayloads, bytes.Clone(req.Payload))
e.mu.Unlock()
var payload []byte
switch callIndex {
case 0:
payload = []byte(`{"type":"response.completed","response":{"id":"resp-1","output":[{"type":"function_call","id":"fc-1","call_id":"call-1","name":"tool"}]}}`)
case 1:
payload = []byte(`{"type":"response.completed","response":{"id":"resp-2","output":[{"type":"message","id":"assistant-1"}]}}`)
default:
payload = []byte(`{"type":"response.completed","response":{"id":"resp-3","output":[{"type":"message","id":"assistant-2"}]}}`)
}
chunks := make(chan coreexecutor.StreamChunk, 1)
chunks <- coreexecutor.StreamChunk{Payload: payload}
close(chunks)
return &coreexecutor.StreamResult{Chunks: chunks}, nil
}
func (e *websocketCompactionCaptureExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
return auth, nil
}
func (e *websocketCompactionCaptureExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, errors.New("not implemented")
}
func (e *websocketCompactionCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) {
return nil, errors.New("not implemented")
}
func TestNormalizeResponsesWebsocketRequestCreate(t *testing.T) {
raw := []byte(`{"type":"response.create","model":"test-model","stream":false,"input":[{"type":"message","id":"msg-1"}]}`)
@@ -339,33 +392,6 @@ func TestAppendWebsocketEvent(t *testing.T) {
}
}
func TestAppendWebsocketEventTruncatesAtLimit(t *testing.T) {
var builder strings.Builder
payload := bytes.Repeat([]byte("x"), wsBodyLogMaxSize)
appendWebsocketEvent(&builder, "request", payload)
got := builder.String()
if len(got) > wsBodyLogMaxSize {
t.Fatalf("body log len = %d, want <= %d", len(got), wsBodyLogMaxSize)
}
if !strings.Contains(got, wsBodyLogTruncated) {
t.Fatalf("expected truncation marker in body log")
}
}
func TestAppendWebsocketEventNoGrowthAfterLimit(t *testing.T) {
var builder strings.Builder
appendWebsocketEvent(&builder, "request", bytes.Repeat([]byte("x"), wsBodyLogMaxSize))
initial := builder.String()
appendWebsocketEvent(&builder, "response", []byte(`{"type":"response.completed"}`))
if builder.String() != initial {
t.Fatalf("builder grew after reaching limit")
}
}
func TestSetWebsocketRequestBody(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
@@ -390,6 +416,108 @@ func TestSetWebsocketRequestBody(t *testing.T) {
}
}
func TestRepairResponsesWebsocketToolCallsInsertsCachedOutput(t *testing.T) {
cache := newWebsocketToolOutputCache(time.Minute, 10)
sessionKey := "session-1"
cacheWarm := []byte(`{"previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","output":"ok"}]}`)
warmed := repairResponsesWebsocketToolCallsWithCache(cache, sessionKey, cacheWarm)
if gjson.GetBytes(warmed, "input.0.call_id").String() != "call-1" {
t.Fatalf("expected warmup output to remain")
}
raw := []byte(`{"input":[{"type":"function_call","call_id":"call-1","name":"tool"},{"type":"message","id":"msg-1"}]}`)
repaired := repairResponsesWebsocketToolCallsWithCache(cache, sessionKey, raw)
input := gjson.GetBytes(repaired, "input").Array()
if len(input) != 3 {
t.Fatalf("repaired input len = %d, want 3", len(input))
}
if input[0].Get("type").String() != "function_call" || input[0].Get("call_id").String() != "call-1" {
t.Fatalf("unexpected first item: %s", input[0].Raw)
}
if input[1].Get("type").String() != "function_call_output" || input[1].Get("call_id").String() != "call-1" {
t.Fatalf("missing inserted output: %s", input[1].Raw)
}
if input[2].Get("type").String() != "message" || input[2].Get("id").String() != "msg-1" {
t.Fatalf("unexpected trailing item: %s", input[2].Raw)
}
}
func TestRepairResponsesWebsocketToolCallsDropsOrphanFunctionCall(t *testing.T) {
cache := newWebsocketToolOutputCache(time.Minute, 10)
sessionKey := "session-1"
raw := []byte(`{"input":[{"type":"function_call","call_id":"call-1","name":"tool"},{"type":"message","id":"msg-1"}]}`)
repaired := repairResponsesWebsocketToolCallsWithCache(cache, sessionKey, raw)
input := gjson.GetBytes(repaired, "input").Array()
if len(input) != 1 {
t.Fatalf("repaired input len = %d, want 1", len(input))
}
if input[0].Get("type").String() != "message" || input[0].Get("id").String() != "msg-1" {
t.Fatalf("unexpected remaining item: %s", input[0].Raw)
}
}
func TestRepairResponsesWebsocketToolCallsInsertsCachedCallForOrphanOutput(t *testing.T) {
outputCache := newWebsocketToolOutputCache(time.Minute, 10)
callCache := newWebsocketToolOutputCache(time.Minute, 10)
sessionKey := "session-1"
callCache.record(sessionKey, "call-1", []byte(`{"type":"function_call","call_id":"call-1","name":"tool"}`))
raw := []byte(`{"input":[{"type":"function_call_output","call_id":"call-1","output":"ok"},{"type":"message","id":"msg-1"}]}`)
repaired := repairResponsesWebsocketToolCallsWithCaches(outputCache, callCache, sessionKey, raw)
input := gjson.GetBytes(repaired, "input").Array()
if len(input) != 3 {
t.Fatalf("repaired input len = %d, want 3", len(input))
}
if input[0].Get("type").String() != "function_call" || input[0].Get("call_id").String() != "call-1" {
t.Fatalf("missing inserted call: %s", input[0].Raw)
}
if input[1].Get("type").String() != "function_call_output" || input[1].Get("call_id").String() != "call-1" {
t.Fatalf("unexpected output item: %s", input[1].Raw)
}
if input[2].Get("type").String() != "message" || input[2].Get("id").String() != "msg-1" {
t.Fatalf("unexpected trailing item: %s", input[2].Raw)
}
}
func TestRepairResponsesWebsocketToolCallsDropsOrphanOutputWhenCallMissing(t *testing.T) {
outputCache := newWebsocketToolOutputCache(time.Minute, 10)
callCache := newWebsocketToolOutputCache(time.Minute, 10)
sessionKey := "session-1"
raw := []byte(`{"input":[{"type":"function_call_output","call_id":"call-1","output":"ok"},{"type":"message","id":"msg-1"}]}`)
repaired := repairResponsesWebsocketToolCallsWithCaches(outputCache, callCache, sessionKey, raw)
input := gjson.GetBytes(repaired, "input").Array()
if len(input) != 1 {
t.Fatalf("repaired input len = %d, want 1", len(input))
}
if input[0].Get("type").String() != "message" || input[0].Get("id").String() != "msg-1" {
t.Fatalf("unexpected remaining item: %s", input[0].Raw)
}
}
func TestRecordResponsesWebsocketToolCallsFromPayloadWithCache(t *testing.T) {
cache := newWebsocketToolOutputCache(time.Minute, 10)
sessionKey := "session-1"
payload := []byte(`{"type":"response.completed","response":{"id":"resp-1","output":[{"type":"function_call","id":"fc-1","call_id":"call-1","name":"tool","arguments":"{}"}]}}`)
recordResponsesWebsocketToolCallsFromPayloadWithCache(cache, sessionKey, payload)
cached, ok := cache.get(sessionKey, "call-1")
if !ok {
t.Fatalf("expected cached tool call")
}
if gjson.GetBytes(cached, "type").String() != "function_call" || gjson.GetBytes(cached, "call_id").String() != "call-1" {
t.Fatalf("unexpected cached tool call: %s", cached)
}
}
func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
gin.SetMode(gin.TestMode)
@@ -593,6 +721,31 @@ func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) {
}
}
func TestWebsocketClientAddressUsesGinClientIP(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, engine := gin.CreateTestContext(recorder)
if err := engine.SetTrustedProxies([]string{"0.0.0.0/0", "::/0"}); err != nil {
t.Fatalf("SetTrustedProxies: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/v1/responses/ws", nil)
req.RemoteAddr = "172.18.0.1:34282"
req.Header.Set("X-Forwarded-For", "203.0.113.7")
c.Request = req
if got := websocketClientAddress(c); got != strings.TrimSpace(c.ClientIP()) {
t.Fatalf("websocketClientAddress = %q, ClientIP = %q", got, c.ClientIP())
}
}
func TestWebsocketClientAddressReturnsEmptyForNilContext(t *testing.T) {
if got := websocketClientAddress(nil); got != "" {
t.Fatalf("websocketClientAddress(nil) = %q, want empty", got)
}
}
func TestResponsesWebsocketPinsOnlyWebsocketCapableAuth(t *testing.T) {
gin.SetMode(gin.TestMode)
@@ -662,3 +815,183 @@ func TestResponsesWebsocketPinsOnlyWebsocketCapableAuth(t *testing.T) {
t.Fatalf("selected auth IDs = %v, want [auth-sse auth-ws]", got)
}
}
func TestNormalizeResponsesWebsocketRequestTreatsTranscriptReplacementAsReset(t *testing.T) {
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"},{"type":"function_call","id":"fc-1","call_id":"call-1"},{"type":"function_call_output","id":"tool-out-1","call_id":"call-1"},{"type":"message","id":"assistant-1","role":"assistant"}]}`)
lastResponseOutput := []byte(`[
{"type":"message","id":"assistant-1","role":"assistant"}
]`)
raw := []byte(`{"type":"response.create","input":[{"type":"function_call","id":"fc-compact","call_id":"call-1","name":"tool"},{"type":"message","id":"msg-2"}]}`)
normalized, next, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
if errMsg != nil {
t.Fatalf("unexpected error: %v", errMsg.Error)
}
if gjson.GetBytes(normalized, "previous_response_id").Exists() {
t.Fatalf("previous_response_id must not exist in transcript replacement mode")
}
items := gjson.GetBytes(normalized, "input").Array()
if len(items) != 2 {
t.Fatalf("replacement input len = %d, want 2: %s", len(items), normalized)
}
if items[0].Get("id").String() != "fc-compact" || items[1].Get("id").String() != "msg-2" {
t.Fatalf("replacement transcript was not preserved as-is: %s", normalized)
}
if !bytes.Equal(next, normalized) {
t.Fatalf("next request snapshot should match replacement request")
}
}
func TestNormalizeResponsesWebsocketRequestDoesNotTreatDeveloperMessageAsReplacement(t *testing.T) {
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`)
lastResponseOutput := []byte(`[
{"type":"message","id":"assistant-1","role":"assistant"}
]`)
raw := []byte(`{"type":"response.create","input":[{"type":"message","id":"dev-1","role":"developer"},{"type":"message","id":"msg-2"}]}`)
normalized, next, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
if errMsg != nil {
t.Fatalf("unexpected error: %v", errMsg.Error)
}
items := gjson.GetBytes(normalized, "input").Array()
if len(items) != 4 {
t.Fatalf("merged input len = %d, want 4: %s", len(items), normalized)
}
if items[0].Get("id").String() != "msg-1" ||
items[1].Get("id").String() != "assistant-1" ||
items[2].Get("id").String() != "dev-1" ||
items[3].Get("id").String() != "msg-2" {
t.Fatalf("developer follow-up should preserve merge behavior: %s", normalized)
}
if !bytes.Equal(next, normalized) {
t.Fatalf("next request snapshot should match merged request")
}
}
func TestNormalizeResponsesWebsocketRequestDropsDuplicateFunctionCallsByCallID(t *testing.T) {
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"function_call","id":"fc-1","call_id":"call-1"},{"type":"function_call_output","id":"tool-out-1","call_id":"call-1"}]}`)
lastResponseOutput := []byte(`[
{"type":"function_call","id":"fc-1","call_id":"call-1","name":"tool"}
]`)
raw := []byte(`{"type":"response.create","input":[{"type":"message","id":"msg-2"}]}`)
normalized, _, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
if errMsg != nil {
t.Fatalf("unexpected error: %v", errMsg.Error)
}
items := gjson.GetBytes(normalized, "input").Array()
if len(items) != 3 {
t.Fatalf("merged input len = %d, want 3: %s", len(items), normalized)
}
if items[0].Get("id").String() != "fc-1" ||
items[1].Get("id").String() != "tool-out-1" ||
items[2].Get("id").String() != "msg-2" {
t.Fatalf("unexpected merged input order: %s", normalized)
}
}
func TestResponsesWebsocketCompactionResetsTurnStateOnTranscriptReplacement(t *testing.T) {
gin.SetMode(gin.TestMode)
executor := &websocketCompactionCaptureExecutor{}
manager := coreauth.NewManager(nil, nil, nil)
manager.RegisterExecutor(executor)
auth := &coreauth.Auth{ID: "auth-sse", Provider: executor.Identifier(), Status: coreauth.StatusActive}
if _, err := manager.Register(context.Background(), auth); err != nil {
t.Fatalf("Register auth: %v", err)
}
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
})
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
h := NewOpenAIResponsesAPIHandler(base)
router := gin.New()
router.GET("/v1/responses/ws", h.ResponsesWebsocket)
router.POST("/v1/responses/compact", h.Compact)
server := httptest.NewServer(router)
defer server.Close()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("dial websocket: %v", err)
}
defer func() {
if errClose := conn.Close(); errClose != nil {
t.Fatalf("close websocket: %v", errClose)
}
}()
requests := []string{
`{"type":"response.create","model":"test-model","input":[{"type":"message","id":"msg-1"}]}`,
`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`,
}
for i := range requests {
if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(requests[i])); errWrite != nil {
t.Fatalf("write websocket message %d: %v", i+1, errWrite)
}
_, payload, errReadMessage := conn.ReadMessage()
if errReadMessage != nil {
t.Fatalf("read websocket message %d: %v", i+1, errReadMessage)
}
if got := gjson.GetBytes(payload, "type").String(); got != wsEventTypeCompleted {
t.Fatalf("message %d payload type = %s, want %s", i+1, got, wsEventTypeCompleted)
}
}
compactResp, errPost := server.Client().Post(
server.URL+"/v1/responses/compact",
"application/json",
strings.NewReader(`{"model":"test-model","input":[{"type":"message","id":"summary-1"}]}`),
)
if errPost != nil {
t.Fatalf("compact request failed: %v", errPost)
}
if errClose := compactResp.Body.Close(); errClose != nil {
t.Fatalf("close compact response body: %v", errClose)
}
if compactResp.StatusCode != http.StatusOK {
t.Fatalf("compact status = %d, want %d", compactResp.StatusCode, http.StatusOK)
}
// Simulate a post-compaction client turn that replaces local history with a compacted transcript.
// The websocket handler must treat this as a state reset, not append it to stale pre-compaction state.
postCompact := `{"type":"response.create","input":[{"type":"function_call","id":"fc-compact","call_id":"call-1","name":"tool"},{"type":"message","id":"msg-2"}]}`
if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(postCompact)); errWrite != nil {
t.Fatalf("write post-compact websocket message: %v", errWrite)
}
_, payload, errReadMessage := conn.ReadMessage()
if errReadMessage != nil {
t.Fatalf("read post-compact websocket message: %v", errReadMessage)
}
if got := gjson.GetBytes(payload, "type").String(); got != wsEventTypeCompleted {
t.Fatalf("post-compact payload type = %s, want %s", got, wsEventTypeCompleted)
}
executor.mu.Lock()
defer executor.mu.Unlock()
if executor.compactPayload == nil {
t.Fatalf("compact payload was not captured")
}
if len(executor.streamPayloads) != 3 {
t.Fatalf("stream payload count = %d, want 3", len(executor.streamPayloads))
}
merged := executor.streamPayloads[2]
items := gjson.GetBytes(merged, "input").Array()
if len(items) != 2 {
t.Fatalf("merged input len = %d, want 2: %s", len(items), merged)
}
if items[0].Get("id").String() != "fc-compact" ||
items[1].Get("id").String() != "msg-2" {
t.Fatalf("unexpected post-compact input order: %s", merged)
}
if items[0].Get("call_id").String() != "call-1" {
t.Fatalf("post-compact function call id = %s, want call-1", items[0].Get("call_id").String())
}
}

View File

@@ -0,0 +1,402 @@
package openai
import (
"encoding/json"
"net/http"
"strings"
"sync"
"time"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const (
websocketToolOutputCacheMaxPerSession = 256
websocketToolOutputCacheTTL = 30 * time.Minute
)
var defaultWebsocketToolOutputCache = newWebsocketToolOutputCache(0, websocketToolOutputCacheMaxPerSession)
var defaultWebsocketToolCallCache = newWebsocketToolOutputCache(0, websocketToolOutputCacheMaxPerSession)
var defaultWebsocketToolSessionRefs = newWebsocketToolSessionRefCounter()
type websocketToolOutputCache struct {
mu sync.Mutex
ttl time.Duration
maxPerSession int
sessions map[string]*websocketToolOutputSession
}
type websocketToolOutputSession struct {
lastSeen time.Time
outputs map[string]json.RawMessage
order []string
}
func newWebsocketToolOutputCache(ttl time.Duration, maxPerSession int) *websocketToolOutputCache {
if ttl < 0 {
ttl = websocketToolOutputCacheTTL
}
if maxPerSession <= 0 {
maxPerSession = websocketToolOutputCacheMaxPerSession
}
return &websocketToolOutputCache{
ttl: ttl,
maxPerSession: maxPerSession,
sessions: make(map[string]*websocketToolOutputSession),
}
}
func (c *websocketToolOutputCache) record(sessionKey string, callID string, item json.RawMessage) {
sessionKey = strings.TrimSpace(sessionKey)
callID = strings.TrimSpace(callID)
if sessionKey == "" || callID == "" || c == nil {
return
}
now := time.Now()
c.mu.Lock()
defer c.mu.Unlock()
c.cleanupLocked(now)
session, ok := c.sessions[sessionKey]
if !ok || session == nil {
session = &websocketToolOutputSession{
lastSeen: now,
outputs: make(map[string]json.RawMessage),
}
c.sessions[sessionKey] = session
}
session.lastSeen = now
if _, exists := session.outputs[callID]; !exists {
session.order = append(session.order, callID)
}
session.outputs[callID] = append(json.RawMessage(nil), item...)
for len(session.order) > c.maxPerSession {
evict := session.order[0]
session.order = session.order[1:]
delete(session.outputs, evict)
}
}
func (c *websocketToolOutputCache) get(sessionKey string, callID string) (json.RawMessage, bool) {
sessionKey = strings.TrimSpace(sessionKey)
callID = strings.TrimSpace(callID)
if sessionKey == "" || callID == "" || c == nil {
return nil, false
}
now := time.Now()
c.mu.Lock()
defer c.mu.Unlock()
c.cleanupLocked(now)
session, ok := c.sessions[sessionKey]
if !ok || session == nil {
return nil, false
}
session.lastSeen = now
item, ok := session.outputs[callID]
if !ok || len(item) == 0 {
return nil, false
}
return append(json.RawMessage(nil), item...), true
}
func (c *websocketToolOutputCache) cleanupLocked(now time.Time) {
if c == nil || c.ttl <= 0 {
return
}
for key, session := range c.sessions {
if session == nil {
delete(c.sessions, key)
continue
}
if now.Sub(session.lastSeen) > c.ttl {
delete(c.sessions, key)
}
}
}
func (c *websocketToolOutputCache) deleteSession(sessionKey string) {
sessionKey = strings.TrimSpace(sessionKey)
if sessionKey == "" || c == nil {
return
}
c.mu.Lock()
defer c.mu.Unlock()
delete(c.sessions, sessionKey)
}
func websocketDownstreamSessionKey(req *http.Request) string {
if req == nil {
return ""
}
if requestID := strings.TrimSpace(req.Header.Get("X-Client-Request-Id")); requestID != "" {
return requestID
}
if raw := strings.TrimSpace(req.Header.Get("X-Codex-Turn-Metadata")); raw != "" {
if sessionID := strings.TrimSpace(gjson.Get(raw, "session_id").String()); sessionID != "" {
return sessionID
}
}
if sessionID := strings.TrimSpace(req.Header.Get("Session_id")); sessionID != "" {
return sessionID
}
return ""
}
type websocketToolSessionRefCounter struct {
mu sync.Mutex
counts map[string]int
}
func newWebsocketToolSessionRefCounter() *websocketToolSessionRefCounter {
return &websocketToolSessionRefCounter{counts: make(map[string]int)}
}
func (c *websocketToolSessionRefCounter) acquire(sessionKey string) {
sessionKey = strings.TrimSpace(sessionKey)
if sessionKey == "" || c == nil {
return
}
c.mu.Lock()
defer c.mu.Unlock()
c.counts[sessionKey]++
}
func (c *websocketToolSessionRefCounter) release(sessionKey string) bool {
sessionKey = strings.TrimSpace(sessionKey)
if sessionKey == "" || c == nil {
return false
}
c.mu.Lock()
defer c.mu.Unlock()
count := c.counts[sessionKey]
if count <= 1 {
delete(c.counts, sessionKey)
return true
}
c.counts[sessionKey] = count - 1
return false
}
func retainResponsesWebsocketToolCaches(sessionKey string) {
if defaultWebsocketToolSessionRefs == nil {
return
}
defaultWebsocketToolSessionRefs.acquire(sessionKey)
}
func releaseResponsesWebsocketToolCaches(sessionKey string) {
if defaultWebsocketToolSessionRefs == nil {
return
}
if !defaultWebsocketToolSessionRefs.release(sessionKey) {
return
}
if defaultWebsocketToolOutputCache != nil {
defaultWebsocketToolOutputCache.deleteSession(sessionKey)
}
if defaultWebsocketToolCallCache != nil {
defaultWebsocketToolCallCache.deleteSession(sessionKey)
}
}
func repairResponsesWebsocketToolCalls(sessionKey string, payload []byte) []byte {
return repairResponsesWebsocketToolCallsWithCaches(defaultWebsocketToolOutputCache, defaultWebsocketToolCallCache, sessionKey, payload)
}
func repairResponsesWebsocketToolCallsWithCache(cache *websocketToolOutputCache, sessionKey string, payload []byte) []byte {
return repairResponsesWebsocketToolCallsWithCaches(cache, nil, sessionKey, payload)
}
func repairResponsesWebsocketToolCallsWithCaches(outputCache, callCache *websocketToolOutputCache, sessionKey string, payload []byte) []byte {
sessionKey = strings.TrimSpace(sessionKey)
if sessionKey == "" || outputCache == nil || len(payload) == 0 {
return payload
}
input := gjson.GetBytes(payload, "input")
if !input.Exists() || !input.IsArray() {
return payload
}
allowOrphanOutputs := strings.TrimSpace(gjson.GetBytes(payload, "previous_response_id").String()) != ""
updatedRaw, errRepair := repairResponsesToolCallsArray(outputCache, callCache, sessionKey, input.Raw, allowOrphanOutputs)
if errRepair != nil || updatedRaw == "" || updatedRaw == input.Raw {
return payload
}
updated, errSet := sjson.SetRawBytes(payload, "input", []byte(updatedRaw))
if errSet != nil {
return payload
}
return updated
}
func repairResponsesToolCallsArray(outputCache, callCache *websocketToolOutputCache, sessionKey string, rawArray string, allowOrphanOutputs bool) (string, error) {
rawArray = strings.TrimSpace(rawArray)
if rawArray == "" {
return "[]", nil
}
var items []json.RawMessage
if errUnmarshal := json.Unmarshal([]byte(rawArray), &items); errUnmarshal != nil {
return "", errUnmarshal
}
// First pass: record tool outputs and remember which call_ids have outputs in this payload.
outputPresent := make(map[string]struct{}, len(items))
callPresent := make(map[string]struct{}, len(items))
for _, item := range items {
if len(item) == 0 {
continue
}
itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String())
switch itemType {
case "function_call_output":
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
if callID == "" {
continue
}
outputPresent[callID] = struct{}{}
outputCache.record(sessionKey, callID, item)
case "function_call":
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
if callID == "" {
continue
}
callPresent[callID] = struct{}{}
if callCache != nil {
callCache.record(sessionKey, callID, item)
}
}
}
filtered := make([]json.RawMessage, 0, len(items))
insertedCalls := make(map[string]struct{}, len(items))
for _, item := range items {
if len(item) == 0 {
continue
}
itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String())
if itemType == "function_call_output" {
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
if callID == "" {
// Upstream rejects tool outputs without a call_id; drop it.
continue
}
if allowOrphanOutputs {
filtered = append(filtered, item)
continue
}
if _, ok := callPresent[callID]; ok {
filtered = append(filtered, item)
continue
}
if callCache != nil {
if cached, ok := callCache.get(sessionKey, callID); ok {
if _, already := insertedCalls[callID]; !already {
filtered = append(filtered, cached)
insertedCalls[callID] = struct{}{}
callPresent[callID] = struct{}{}
}
filtered = append(filtered, item)
continue
}
}
// Drop orphaned function_call_output items; upstream rejects transcripts with missing calls.
continue
}
if itemType != "function_call" {
filtered = append(filtered, item)
continue
}
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
if callID == "" {
// Upstream rejects tool calls without a call_id; drop it.
continue
}
if _, ok := outputPresent[callID]; ok {
filtered = append(filtered, item)
continue
}
if cached, ok := outputCache.get(sessionKey, callID); ok {
filtered = append(filtered, item)
filtered = append(filtered, cached)
outputPresent[callID] = struct{}{}
continue
}
// Drop orphaned function_call items; upstream rejects transcripts with missing outputs.
}
out, errMarshal := json.Marshal(filtered)
if errMarshal != nil {
return "", errMarshal
}
return string(out), nil
}
func recordResponsesWebsocketToolCallsFromPayload(sessionKey string, payload []byte) {
recordResponsesWebsocketToolCallsFromPayloadWithCache(defaultWebsocketToolCallCache, sessionKey, payload)
}
func recordResponsesWebsocketToolCallsFromPayloadWithCache(cache *websocketToolOutputCache, sessionKey string, payload []byte) {
sessionKey = strings.TrimSpace(sessionKey)
if sessionKey == "" || cache == nil || len(payload) == 0 {
return
}
eventType := strings.TrimSpace(gjson.GetBytes(payload, "type").String())
switch eventType {
case "response.completed":
output := gjson.GetBytes(payload, "response.output")
if !output.Exists() || !output.IsArray() {
return
}
for _, item := range output.Array() {
if strings.TrimSpace(item.Get("type").String()) != "function_call" {
continue
}
callID := strings.TrimSpace(item.Get("call_id").String())
if callID == "" {
continue
}
cache.record(sessionKey, callID, json.RawMessage(item.Raw))
}
case "response.output_item.added", "response.output_item.done":
item := gjson.GetBytes(payload, "item")
if !item.Exists() || !item.IsObject() {
return
}
if strings.TrimSpace(item.Get("type").String()) != "function_call" {
return
}
callID := strings.TrimSpace(item.Get("call_id").String())
if callID == "" {
return
}
cache.record(sessionKey, callID, json.RawMessage(item.Raw))
}
}

View File

@@ -8,6 +8,7 @@ import (
"io"
"net/http"
"path/filepath"
"sort"
"strconv"
"strings"
"sync"
@@ -437,6 +438,31 @@ func (m *Manager) executionModelCandidates(auth *Auth, routeModel string) []stri
return []string{resolved}
}
func (m *Manager) selectionModelForAuth(auth *Auth, routeModel string) string {
requestedModel := rewriteModelForAuth(routeModel, auth)
if strings.TrimSpace(requestedModel) == "" {
requestedModel = strings.TrimSpace(routeModel)
}
resolvedModel := m.applyOAuthModelAlias(auth, requestedModel)
if strings.TrimSpace(resolvedModel) == "" {
resolvedModel = requestedModel
}
return resolvedModel
}
func (m *Manager) selectionModelKeyForAuth(auth *Auth, routeModel string) string {
return canonicalModelKey(m.selectionModelForAuth(auth, routeModel))
}
func (m *Manager) stateModelForExecution(auth *Auth, routeModel, upstreamModel string, pooled bool) string {
stateModel := executionResultModel(routeModel, upstreamModel, pooled)
selectionModel := m.selectionModelForAuth(auth, routeModel)
if canonicalModelKey(selectionModel) == canonicalModelKey(upstreamModel) && strings.TrimSpace(selectionModel) != "" {
return strings.TrimSpace(upstreamModel)
}
return stateModel
}
func executionResultModel(routeModel, upstreamModel string, pooled bool) string {
if pooled {
if resolved := strings.TrimSpace(upstreamModel); resolved != "" {
@@ -449,14 +475,14 @@ func executionResultModel(routeModel, upstreamModel string, pooled bool) string
return strings.TrimSpace(upstreamModel)
}
func filterExecutionModels(auth *Auth, routeModel string, candidates []string, pooled bool) []string {
func (m *Manager) filterExecutionModels(auth *Auth, routeModel string, candidates []string, pooled bool) []string {
if len(candidates) == 0 {
return nil
}
now := time.Now()
out := make([]string, 0, len(candidates))
for _, upstreamModel := range candidates {
stateModel := executionResultModel(routeModel, upstreamModel, pooled)
stateModel := m.stateModelForExecution(auth, routeModel, upstreamModel, pooled)
blocked, _, _ := isAuthBlockedForModel(auth, stateModel, now)
if blocked {
continue
@@ -469,7 +495,7 @@ func filterExecutionModels(auth *Auth, routeModel string, candidates []string, p
func (m *Manager) preparedExecutionModels(auth *Auth, routeModel string) ([]string, bool) {
candidates := m.executionModelCandidates(auth, routeModel)
pooled := len(candidates) > 1
return filterExecutionModels(auth, routeModel, candidates, pooled), pooled
return m.filterExecutionModels(auth, routeModel, candidates, pooled), pooled
}
func (m *Manager) prepareExecutionModels(auth *Auth, routeModel string) []string {
@@ -477,6 +503,83 @@ func (m *Manager) prepareExecutionModels(auth *Auth, routeModel string) []string
return models
}
func (m *Manager) availableAuthsForRouteModel(auths []*Auth, provider, routeModel string, now time.Time) ([]*Auth, error) {
if len(auths) == 0 {
return nil, &Error{Code: "auth_not_found", Message: "no auth candidates"}
}
availableByPriority := make(map[int][]*Auth)
cooldownCount := 0
var earliest time.Time
for _, candidate := range auths {
checkModel := m.selectionModelForAuth(candidate, routeModel)
blocked, reason, next := isAuthBlockedForModel(candidate, checkModel, now)
if !blocked {
priority := authPriority(candidate)
availableByPriority[priority] = append(availableByPriority[priority], candidate)
continue
}
if reason == blockReasonCooldown {
cooldownCount++
if !next.IsZero() && (earliest.IsZero() || next.Before(earliest)) {
earliest = next
}
}
}
if len(availableByPriority) == 0 {
if cooldownCount == len(auths) && !earliest.IsZero() {
providerForError := provider
if providerForError == "mixed" {
providerForError = ""
}
resetIn := earliest.Sub(now)
if resetIn < 0 {
resetIn = 0
}
return nil, newModelCooldownError(routeModel, providerForError, resetIn)
}
return nil, &Error{Code: "auth_unavailable", Message: "no auth available"}
}
bestPriority := 0
found := false
for priority := range availableByPriority {
if !found || priority > bestPriority {
bestPriority = priority
found = true
}
}
available := availableByPriority[bestPriority]
if len(available) > 1 {
sort.Slice(available, func(i, j int) bool { return available[i].ID < available[j].ID })
}
return available, nil
}
func selectionArgForSelector(selector Selector, routeModel string) string {
if isBuiltInSelector(selector) {
return ""
}
return routeModel
}
func (m *Manager) authSupportsRouteModel(registryRef *registry.ModelRegistry, auth *Auth, routeModel string) bool {
if registryRef == nil || auth == nil {
return true
}
routeKey := canonicalModelKey(routeModel)
if routeKey == "" {
return true
}
if registryRef.ClientSupportsModel(auth.ID, routeKey) {
return true
}
selectionKey := m.selectionModelKeyForAuth(auth, routeModel)
return selectionKey != "" && selectionKey != routeKey && registryRef.ClientSupportsModel(auth.ID, selectionKey)
}
func discardStreamChunks(ch <-chan cliproxyexecutor.StreamChunk) {
if ch == nil {
return
@@ -627,7 +730,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi
}
var lastErr error
for idx, execModel := range execModels {
resultModel := executionResultModel(routeModel, execModel, pooled)
resultModel := m.stateModelForExecution(auth, routeModel, execModel, pooled)
execReq := req
execReq.Model = execModel
streamResult, errStream := executor.ExecuteStream(ctx, auth, execReq, opts)
@@ -1107,7 +1210,7 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
attempted[auth.ID] = struct{}{}
var authErr error
for _, upstreamModel := range models {
resultModel := executionResultModel(routeModel, upstreamModel, pooled)
resultModel := m.stateModelForExecution(auth, routeModel, upstreamModel, pooled)
execReq := req
execReq.Model = upstreamModel
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
@@ -1185,7 +1288,7 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
attempted[auth.ID] = struct{}{}
var authErr error
for _, upstreamModel := range models {
resultModel := executionResultModel(routeModel, upstreamModel, pooled)
resultModel := m.stateModelForExecution(auth, routeModel, upstreamModel, pooled)
execReq := req
execReq.Model = upstreamModel
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
@@ -1734,77 +1837,79 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
}
} else {
if result.Model != "" {
state := ensureModelState(auth, result.Model)
state.Unavailable = true
state.Status = StatusError
state.UpdatedAt = now
if result.Error != nil {
state.LastError = cloneError(result.Error)
state.StatusMessage = result.Error.Message
auth.LastError = cloneError(result.Error)
auth.StatusMessage = result.Error.Message
}
if !isRequestScopedNotFoundResultError(result.Error) {
state := ensureModelState(auth, result.Model)
state.Unavailable = true
state.Status = StatusError
state.UpdatedAt = now
if result.Error != nil {
state.LastError = cloneError(result.Error)
state.StatusMessage = result.Error.Message
auth.LastError = cloneError(result.Error)
auth.StatusMessage = result.Error.Message
}
statusCode := statusCodeFromResult(result.Error)
if isModelSupportResultError(result.Error) {
next := now.Add(12 * time.Hour)
state.NextRetryAfter = next
suspendReason = "model_not_supported"
shouldSuspendModel = true
} else {
switch statusCode {
case 401:
next := now.Add(30 * time.Minute)
state.NextRetryAfter = next
suspendReason = "unauthorized"
shouldSuspendModel = true
case 402, 403:
next := now.Add(30 * time.Minute)
state.NextRetryAfter = next
suspendReason = "payment_required"
shouldSuspendModel = true
case 404:
statusCode := statusCodeFromResult(result.Error)
if isModelSupportResultError(result.Error) {
next := now.Add(12 * time.Hour)
state.NextRetryAfter = next
suspendReason = "not_found"
suspendReason = "model_not_supported"
shouldSuspendModel = true
case 429:
var next time.Time
backoffLevel := state.Quota.BackoffLevel
if result.RetryAfter != nil {
next = now.Add(*result.RetryAfter)
} else {
cooldown, nextLevel := nextQuotaCooldown(backoffLevel, quotaCooldownDisabledForAuth(auth))
if cooldown > 0 {
next = now.Add(cooldown)
}
backoffLevel = nextLevel
}
state.NextRetryAfter = next
state.Quota = QuotaState{
Exceeded: true,
Reason: "quota",
NextRecoverAt: next,
BackoffLevel: backoffLevel,
}
suspendReason = "quota"
shouldSuspendModel = true
setModelQuota = true
case 408, 500, 502, 503, 504:
if quotaCooldownDisabledForAuth(auth) {
state.NextRetryAfter = time.Time{}
} else {
next := now.Add(1 * time.Minute)
} else {
switch statusCode {
case 401:
next := now.Add(30 * time.Minute)
state.NextRetryAfter = next
suspendReason = "unauthorized"
shouldSuspendModel = true
case 402, 403:
next := now.Add(30 * time.Minute)
state.NextRetryAfter = next
suspendReason = "payment_required"
shouldSuspendModel = true
case 404:
next := now.Add(12 * time.Hour)
state.NextRetryAfter = next
suspendReason = "not_found"
shouldSuspendModel = true
case 429:
var next time.Time
backoffLevel := state.Quota.BackoffLevel
if result.RetryAfter != nil {
next = now.Add(*result.RetryAfter)
} else {
cooldown, nextLevel := nextQuotaCooldown(backoffLevel, quotaCooldownDisabledForAuth(auth))
if cooldown > 0 {
next = now.Add(cooldown)
}
backoffLevel = nextLevel
}
state.NextRetryAfter = next
state.Quota = QuotaState{
Exceeded: true,
Reason: "quota",
NextRecoverAt: next,
BackoffLevel: backoffLevel,
}
suspendReason = "quota"
shouldSuspendModel = true
setModelQuota = true
case 408, 500, 502, 503, 504:
if quotaCooldownDisabledForAuth(auth) {
state.NextRetryAfter = time.Time{}
} else {
next := now.Add(1 * time.Minute)
state.NextRetryAfter = next
}
default:
state.NextRetryAfter = time.Time{}
}
default:
state.NextRetryAfter = time.Time{}
}
}
auth.Status = StatusError
auth.UpdatedAt = now
updateAggregatedAvailability(auth, now)
auth.Status = StatusError
auth.UpdatedAt = now
updateAggregatedAvailability(auth, now)
}
} else {
applyAuthFailureState(auth, result.Error, result.RetryAfter, now)
}
@@ -2056,11 +2161,29 @@ func isModelSupportResultError(err *Error) bool {
return isModelSupportErrorMessage(err.Message)
}
func isRequestScopedNotFoundMessage(message string) bool {
if message == "" {
return false
}
lower := strings.ToLower(message)
return strings.Contains(lower, "item with id") &&
strings.Contains(lower, "not found") &&
strings.Contains(lower, "items are not persisted when `store` is set to false")
}
func isRequestScopedNotFoundResultError(err *Error) bool {
if err == nil || statusCodeFromResult(err) != http.StatusNotFound {
return false
}
return isRequestScopedNotFoundMessage(err.Message)
}
// isRequestInvalidError returns true if the error represents a client request
// error that should not be retried. Specifically, it treats 400 responses with
// "invalid_request_error" and all 422 responses as request-shape failures,
// where switching auths or pooled upstream models will not help. Model-support
// errors are excluded so routing can fall through to another auth or upstream.
// "invalid_request_error", request-scoped 404 item misses caused by `store=false`,
// and all 422 responses as request-shape failures, where switching auths or
// pooled upstream models will not help. Model-support errors are excluded so
// routing can fall through to another auth or upstream.
func isRequestInvalidError(err error) bool {
if err == nil {
return false
@@ -2072,6 +2195,8 @@ func isRequestInvalidError(err error) bool {
switch status {
case http.StatusBadRequest:
return strings.Contains(err.Error(), "invalid_request_error")
case http.StatusNotFound:
return isRequestScopedNotFoundMessage(err.Error())
case http.StatusUnprocessableEntity:
return true
default:
@@ -2083,6 +2208,9 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati
if auth == nil {
return
}
if isRequestScopedNotFoundResultError(resultErr) {
return
}
auth.Unavailable = true
auth.Status = StatusError
auth.UpdatedAt = now
@@ -2246,6 +2374,13 @@ func shouldRetrySchedulerPick(err error) bool {
return authErr.Code == "auth_not_found" || authErr.Code == "auth_unavailable"
}
func (m *Manager) routeAwareSelectionRequired(auth *Auth, routeModel string) bool {
if auth == nil || strings.TrimSpace(routeModel) == "" {
return false
}
return m.selectionModelKeyForAuth(auth, routeModel) != canonicalModelKey(routeModel)
}
func (m *Manager) pickNextLegacy(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
@@ -2275,7 +2410,7 @@ func (m *Manager) pickNextLegacy(ctx context.Context, provider, model string, op
if _, used := tried[candidate.ID]; used {
continue
}
if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(candidate.ID, modelKey) {
if modelKey != "" && !m.authSupportsRouteModel(registryRef, candidate, model) {
continue
}
candidates = append(candidates, candidate)
@@ -2284,7 +2419,12 @@ func (m *Manager) pickNextLegacy(ctx context.Context, provider, model string, op
m.mu.RUnlock()
return nil, nil, &Error{Code: "auth_not_found", Message: "no auth available"}
}
selected, errPick := m.selector.Pick(ctx, provider, model, opts, candidates)
available, errAvailable := m.availableAuthsForRouteModel(candidates, provider, model, time.Now())
if errAvailable != nil {
m.mu.RUnlock()
return nil, nil, errAvailable
}
selected, errPick := m.selector.Pick(ctx, provider, selectionArgForSelector(m.selector, model), opts, available)
if errPick != nil {
m.mu.RUnlock()
return nil, nil, errPick
@@ -2310,6 +2450,22 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
if !m.useSchedulerFastPath() {
return m.pickNextLegacy(ctx, provider, model, opts, tried)
}
if strings.TrimSpace(model) != "" {
m.mu.RLock()
for _, candidate := range m.auths {
if candidate == nil || candidate.Provider != provider || candidate.Disabled {
continue
}
if _, used := tried[candidate.ID]; used {
continue
}
if m.routeAwareSelectionRequired(candidate, model) {
m.mu.RUnlock()
return m.pickNextLegacy(ctx, provider, model, opts, tried)
}
}
m.mu.RUnlock()
}
executor, okExecutor := m.Executor(provider)
if !okExecutor {
return nil, nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
@@ -2383,7 +2539,7 @@ func (m *Manager) pickNextMixedLegacy(ctx context.Context, providers []string, m
if _, ok := m.executors[providerKey]; !ok {
continue
}
if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(candidate.ID, modelKey) {
if modelKey != "" && !m.authSupportsRouteModel(registryRef, candidate, model) {
continue
}
candidates = append(candidates, candidate)
@@ -2392,7 +2548,12 @@ func (m *Manager) pickNextMixedLegacy(ctx context.Context, providers []string, m
m.mu.RUnlock()
return nil, nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
}
selected, errPick := m.selector.Pick(ctx, "mixed", model, opts, candidates)
available, errAvailable := m.availableAuthsForRouteModel(candidates, "mixed", model, time.Now())
if errAvailable != nil {
m.mu.RUnlock()
return nil, nil, "", errAvailable
}
selected, errPick := m.selector.Pick(ctx, "mixed", selectionArgForSelector(m.selector, model), opts, available)
if errPick != nil {
m.mu.RUnlock()
return nil, nil, "", errPick
@@ -2444,6 +2605,29 @@ func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model s
if len(eligibleProviders) == 0 {
return nil, nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
}
if strings.TrimSpace(model) != "" {
providerSet := make(map[string]struct{}, len(eligibleProviders))
for _, providerKey := range eligibleProviders {
providerSet[providerKey] = struct{}{}
}
m.mu.RLock()
for _, candidate := range m.auths {
if candidate == nil || candidate.Disabled {
continue
}
if _, ok := providerSet[strings.TrimSpace(strings.ToLower(candidate.Provider))]; !ok {
continue
}
if _, used := tried[candidate.ID]; used {
continue
}
if m.routeAwareSelectionRequired(candidate, model) {
m.mu.RUnlock()
return m.pickNextMixedLegacy(ctx, providers, model, opts, tried)
}
}
m.mu.RUnlock()
}
selected, providerKey, errPick := m.scheduler.pickMixed(ctx, eligibleProviders, model, opts, tried)
if errPick != nil && model != "" && shouldRetrySchedulerPick(errPick) {

View File

@@ -0,0 +1,111 @@
package auth
import (
"context"
"net/http"
"sync"
"testing"
"time"
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
type aliasRoutingExecutor struct {
id string
mu sync.Mutex
executeModels []string
}
func (e *aliasRoutingExecutor) Identifier() string { return e.id }
func (e *aliasRoutingExecutor) Execute(_ context.Context, _ *Auth, req cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
e.mu.Lock()
e.executeModels = append(e.executeModels, req.Model)
e.mu.Unlock()
return cliproxyexecutor.Response{Payload: []byte(req.Model)}, nil
}
func (e *aliasRoutingExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
return nil, &Error{HTTPStatus: http.StatusNotImplemented, Message: "ExecuteStream not implemented"}
}
func (e *aliasRoutingExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
return auth, nil
}
func (e *aliasRoutingExecutor) CountTokens(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, &Error{HTTPStatus: http.StatusNotImplemented, Message: "CountTokens not implemented"}
}
func (e *aliasRoutingExecutor) HttpRequest(context.Context, *Auth, *http.Request) (*http.Response, error) {
return nil, &Error{HTTPStatus: http.StatusNotImplemented, Message: "HttpRequest not implemented"}
}
func (e *aliasRoutingExecutor) ExecuteModels() []string {
e.mu.Lock()
defer e.mu.Unlock()
out := make([]string, len(e.executeModels))
copy(out, e.executeModels)
return out
}
func TestManagerExecute_OAuthAliasBypassesBlockedRouteModel(t *testing.T) {
const (
provider = "antigravity"
routeModel = "claude-opus-4-6"
targetModel = "claude-opus-4-6-thinking"
)
manager := NewManager(nil, nil, nil)
executor := &aliasRoutingExecutor{id: provider}
manager.RegisterExecutor(executor)
manager.SetOAuthModelAlias(map[string][]internalconfig.OAuthModelAlias{
provider: {{
Name: targetModel,
Alias: routeModel,
Fork: true,
}},
})
auth := &Auth{
ID: "oauth-alias-auth",
Provider: provider,
Status: StatusActive,
ModelStates: map[string]*ModelState{
routeModel: {
Unavailable: true,
Status: StatusError,
NextRetryAfter: time.Now().Add(1 * time.Hour),
},
},
}
if _, errRegister := manager.Register(context.Background(), auth); errRegister != nil {
t.Fatalf("register auth: %v", errRegister)
}
reg := registry.GetGlobalRegistry()
reg.RegisterClient(auth.ID, provider, []*registry.ModelInfo{{ID: routeModel}, {ID: targetModel}})
t.Cleanup(func() {
reg.UnregisterClient(auth.ID)
})
manager.RefreshSchedulerEntry(auth.ID)
resp, errExecute := manager.Execute(context.Background(), []string{provider}, cliproxyexecutor.Request{Model: routeModel}, cliproxyexecutor.Options{})
if errExecute != nil {
t.Fatalf("execute error = %v, want success", errExecute)
}
if string(resp.Payload) != targetModel {
t.Fatalf("execute payload = %q, want %q", string(resp.Payload), targetModel)
}
gotModels := executor.ExecuteModels()
if len(gotModels) != 1 {
t.Fatalf("execute models len = %d, want 1", len(gotModels))
}
if gotModels[0] != targetModel {
t.Fatalf("execute model = %q, want %q", gotModels[0], targetModel)
}
}

View File

@@ -12,6 +12,8 @@ import (
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
const requestScopedNotFoundMessage = "Item with id 'rs_0b5f3eb6f51f175c0169ca74e4a85881998539920821603a74' not found. Items are not persisted when `store` is set to false. Try again with `store` set to true, or remove this item from your input."
func TestManager_ShouldRetryAfterError_RespectsAuthRequestRetryOverride(t *testing.T) {
m := NewManager(nil, nil, nil)
m.SetRetryConfig(3, 30*time.Second, 0)
@@ -447,3 +449,114 @@ func TestManager_MarkResult_RespectsAuthDisableCoolingOverride(t *testing.T) {
t.Fatalf("expected NextRetryAfter to be zero when disable_cooling=true, got %v", state.NextRetryAfter)
}
}
func TestManager_MarkResult_RequestScopedNotFoundDoesNotCooldownAuth(t *testing.T) {
m := NewManager(nil, nil, nil)
auth := &Auth{
ID: "auth-1",
Provider: "openai",
}
if _, errRegister := m.Register(context.Background(), auth); errRegister != nil {
t.Fatalf("register auth: %v", errRegister)
}
model := "gpt-4.1"
m.MarkResult(context.Background(), Result{
AuthID: auth.ID,
Provider: auth.Provider,
Model: model,
Success: false,
Error: &Error{
HTTPStatus: http.StatusNotFound,
Message: requestScopedNotFoundMessage,
},
})
updated, ok := m.GetByID(auth.ID)
if !ok || updated == nil {
t.Fatalf("expected auth to be present")
}
if updated.Unavailable {
t.Fatalf("expected request-scoped 404 to keep auth available")
}
if !updated.NextRetryAfter.IsZero() {
t.Fatalf("expected request-scoped 404 to keep auth cooldown unset, got %v", updated.NextRetryAfter)
}
if state := updated.ModelStates[model]; state != nil {
t.Fatalf("expected request-scoped 404 to avoid model cooldown state, got %#v", state)
}
}
func TestManager_RequestScopedNotFoundStopsRetryWithoutSuspendingAuth(t *testing.T) {
m := NewManager(nil, nil, nil)
executor := &authFallbackExecutor{
id: "openai",
executeErrors: map[string]error{
"aa-bad-auth": &Error{
HTTPStatus: http.StatusNotFound,
Message: requestScopedNotFoundMessage,
},
},
}
m.RegisterExecutor(executor)
model := "gpt-4.1"
badAuth := &Auth{ID: "aa-bad-auth", Provider: "openai"}
goodAuth := &Auth{ID: "bb-good-auth", Provider: "openai"}
reg := registry.GetGlobalRegistry()
reg.RegisterClient(badAuth.ID, "openai", []*registry.ModelInfo{{ID: model}})
reg.RegisterClient(goodAuth.ID, "openai", []*registry.ModelInfo{{ID: model}})
t.Cleanup(func() {
reg.UnregisterClient(badAuth.ID)
reg.UnregisterClient(goodAuth.ID)
})
if _, errRegister := m.Register(context.Background(), badAuth); errRegister != nil {
t.Fatalf("register bad auth: %v", errRegister)
}
if _, errRegister := m.Register(context.Background(), goodAuth); errRegister != nil {
t.Fatalf("register good auth: %v", errRegister)
}
_, errExecute := m.Execute(context.Background(), []string{"openai"}, cliproxyexecutor.Request{Model: model}, cliproxyexecutor.Options{})
if errExecute == nil {
t.Fatal("expected request-scoped not-found error")
}
errResult, ok := errExecute.(*Error)
if !ok {
t.Fatalf("expected *Error, got %T", errExecute)
}
if errResult.HTTPStatus != http.StatusNotFound {
t.Fatalf("status = %d, want %d", errResult.HTTPStatus, http.StatusNotFound)
}
if errResult.Message != requestScopedNotFoundMessage {
t.Fatalf("message = %q, want %q", errResult.Message, requestScopedNotFoundMessage)
}
got := executor.ExecuteCalls()
want := []string{badAuth.ID}
if len(got) != len(want) {
t.Fatalf("execute calls = %v, want %v", got, want)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("execute call %d auth = %q, want %q", i, got[i], want[i])
}
}
updatedBad, ok := m.GetByID(badAuth.ID)
if !ok || updatedBad == nil {
t.Fatalf("expected bad auth to remain registered")
}
if updatedBad.Unavailable {
t.Fatalf("expected request-scoped 404 to keep bad auth available")
}
if !updatedBad.NextRetryAfter.IsZero() {
t.Fatalf("expected request-scoped 404 to keep bad auth cooldown unset, got %v", updatedBad.NextRetryAfter)
}
if state := updatedBad.ModelStates[model]; state != nil {
t.Fatalf("expected request-scoped 404 to avoid bad auth model cooldown state, got %#v", state)
}
}

View File

@@ -219,6 +219,19 @@ func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model
if len(normalized) == 0 {
return nil, "", &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
if len(normalized) == 1 {
// When a single provider is eligible, reuse pickSingle so provider-specific preferences
// (for example Codex websocket transport) are applied consistently.
providerKey := normalized[0]
picked, errPick := s.pickSingle(ctx, providerKey, model, opts, tried)
if errPick != nil {
return nil, "", errPick
}
if picked == nil {
return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
}
return picked, providerKey, nil
}
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
modelKey := canonicalModelKey(model)
@@ -293,12 +306,46 @@ func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model
}
cursorKey := strings.Join(normalized, ",") + ":" + modelKey
start := 0
if len(normalized) > 0 {
start = s.mixedCursors[cursorKey] % len(normalized)
weights := make([]int, len(normalized))
segmentStarts := make([]int, len(normalized))
segmentEnds := make([]int, len(normalized))
totalWeight := 0
for providerIndex, shard := range candidateShards {
segmentStarts[providerIndex] = totalWeight
if shard != nil {
weights[providerIndex] = shard.readyCountAtPriorityLocked(false, bestPriority)
}
totalWeight += weights[providerIndex]
segmentEnds[providerIndex] = totalWeight
}
if totalWeight == 0 {
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
}
startSlot := s.mixedCursors[cursorKey] % totalWeight
startProviderIndex := -1
for providerIndex := range normalized {
if weights[providerIndex] == 0 {
continue
}
if startSlot < segmentEnds[providerIndex] {
startProviderIndex = providerIndex
break
}
}
if startProviderIndex < 0 {
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
}
slot := startSlot
for offset := 0; offset < len(normalized); offset++ {
providerIndex := (start + offset) % len(normalized)
providerIndex := (startProviderIndex + offset) % len(normalized)
if weights[providerIndex] == 0 {
continue
}
if providerIndex != startProviderIndex {
slot = segmentStarts[providerIndex]
}
providerKey := normalized[providerIndex]
shard := candidateShards[providerIndex]
if shard == nil {
@@ -308,7 +355,7 @@ func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model
if picked == nil {
continue
}
s.mixedCursors[cursorKey] = providerIndex + 1
s.mixedCursors[cursorKey] = slot + 1
return picked, providerKey, nil
}
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
@@ -667,16 +714,25 @@ func (m *modelScheduler) highestReadyPriorityLocked(preferWebsocket bool, predic
if m == nil {
return 0, false
}
if preferWebsocket {
// When downstream is websocket and Codex supports websocket transport, prefer websocket-enabled
// credentials even if they are in a lower priority tier than HTTP-only credentials.
for _, priority := range m.priorityOrder {
bucket := m.readyByPriority[priority]
if bucket == nil {
continue
}
if bucket.ws.pickFirst(predicate) != nil {
return priority, true
}
}
}
for _, priority := range m.priorityOrder {
bucket := m.readyByPriority[priority]
if bucket == nil {
continue
}
view := &bucket.all
if preferWebsocket && len(bucket.ws.flat) > 0 {
view = &bucket.ws
}
if view.pickFirst(predicate) != nil {
if bucket.all.pickFirst(predicate) != nil {
return priority, true
}
}
@@ -694,7 +750,7 @@ func (m *modelScheduler) pickReadyAtPriorityLocked(preferWebsocket bool, priorit
return nil
}
view := &bucket.all
if preferWebsocket && len(bucket.ws.flat) > 0 {
if preferWebsocket && bucket.ws.pickFirst(predicate) != nil {
view = &bucket.ws
}
var picked *scheduledAuth
@@ -709,6 +765,20 @@ func (m *modelScheduler) pickReadyAtPriorityLocked(preferWebsocket bool, priorit
return picked.auth
}
func (m *modelScheduler) readyCountAtPriorityLocked(preferWebsocket bool, priority int) int {
if m == nil {
return 0
}
bucket := m.readyByPriority[priority]
if bucket == nil {
return 0
}
if preferWebsocket && len(bucket.ws.flat) > 0 {
return len(bucket.ws.flat)
}
return len(bucket.all.flat)
}
// unavailableErrorLocked returns the correct unavailable or cooldown error for the shard.
func (m *modelScheduler) unavailableErrorLocked(provider, model string, predicate func(*scheduledAuth) bool) error {
now := time.Now()

View File

@@ -208,7 +208,33 @@ func TestSchedulerPick_CodexWebsocketPrefersWebsocketEnabledSubset(t *testing.T)
}
}
func TestSchedulerPick_MixedProvidersUsesProviderRotationOverReadyCandidates(t *testing.T) {
func TestSchedulerPick_CodexWebsocketPrefersWebsocketEnabledAcrossPriorities(t *testing.T) {
t.Parallel()
scheduler := newSchedulerForTest(
&RoundRobinSelector{},
&Auth{ID: "codex-http", Provider: "codex", Attributes: map[string]string{"priority": "10"}},
&Auth{ID: "codex-ws-a", Provider: "codex", Attributes: map[string]string{"priority": "0", "websockets": "true"}},
&Auth{ID: "codex-ws-b", Provider: "codex", Attributes: map[string]string{"priority": "0", "websockets": "true"}},
)
ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background())
want := []string{"codex-ws-a", "codex-ws-b", "codex-ws-a"}
for index, wantID := range want {
got, errPick := scheduler.pickSingle(ctx, "codex", "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickSingle() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickSingle() #%d auth = nil", index)
}
if got.ID != wantID {
t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantID)
}
}
}
func TestSchedulerPick_MixedProvidersUsesWeightedProviderRotationOverReadyCandidates(t *testing.T) {
t.Parallel()
scheduler := newSchedulerForTest(
@@ -218,8 +244,8 @@ func TestSchedulerPick_MixedProvidersUsesProviderRotationOverReadyCandidates(t *
&Auth{ID: "claude-a", Provider: "claude"},
)
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
wantProviders := []string{"gemini", "gemini", "claude", "gemini"}
wantIDs := []string{"gemini-a", "gemini-b", "claude-a", "gemini-a"}
for index := range wantProviders {
got, provider, errPick := scheduler.pickMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
@@ -272,7 +298,7 @@ func TestSchedulerPick_MixedProvidersPrefersHighestPriorityTier(t *testing.T) {
}
}
func TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation(t *testing.T) {
func TestManager_PickNextMixed_UsesWeightedProviderRotationBeforeCredentialRotation(t *testing.T) {
t.Parallel()
manager := NewManager(nil, &RoundRobinSelector{}, nil)
@@ -288,8 +314,8 @@ func TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation(t *t
t.Fatalf("Register(claude-a) error = %v", errRegister)
}
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
wantProviders := []string{"gemini", "gemini", "claude", "gemini"}
wantIDs := []string{"gemini-a", "gemini-b", "claude-a", "gemini-a"}
for index := range wantProviders {
got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, map[string]struct{}{})
if errPick != nil {
@@ -399,8 +425,8 @@ func TestManager_PickNextMixed_UsesSchedulerRotation(t *testing.T) {
t.Fatalf("Register(claude-a) error = %v", errRegister)
}
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
wantProviders := []string{"gemini", "gemini", "claude", "gemini"}
wantIDs := []string{"gemini-a", "gemini-b", "claude-a", "gemini-a"}
for index := range wantProviders {
got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
if errPick != nil {

View File

@@ -347,6 +347,7 @@ func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {
log.Errorf("failed to disable auth %s: %v", id, err)
}
if strings.EqualFold(strings.TrimSpace(existing.Provider), "codex") {
executor.CloseCodexWebsocketSessionsForAuthID(existing.ID, "auth_removed")
s.ensureExecutorsForAuth(existing)
}
}