diff --git a/internal/api/handlers/management/api_tools.go b/internal/api/handlers/management/api_tools.go index de546ea8..cb4805e9 100644 --- a/internal/api/handlers/management/api_tools.go +++ b/internal/api/handlers/management/api_tools.go @@ -11,6 +11,7 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" "github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil" @@ -636,6 +637,11 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper { if proxyStr := strings.TrimSpace(auth.ProxyURL); proxyStr != "" { proxyCandidates = append(proxyCandidates, proxyStr) } + if h != nil && h.cfg != nil { + if proxyStr := strings.TrimSpace(proxyURLFromAPIKeyConfig(h.cfg, auth)); proxyStr != "" { + proxyCandidates = append(proxyCandidates, proxyStr) + } + } } if h != nil && h.cfg != nil { if proxyStr := strings.TrimSpace(h.cfg.ProxyURL); proxyStr != "" { @@ -658,6 +664,123 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper { return clone } +type apiKeyConfigEntry interface { + GetAPIKey() string + GetBaseURL() string +} + +func resolveAPIKeyConfig[T apiKeyConfigEntry](entries []T, auth *coreauth.Auth) *T { + if auth == nil || len(entries) == 0 { + return nil + } + attrKey, attrBase := "", "" + if auth.Attributes != nil { + attrKey = strings.TrimSpace(auth.Attributes["api_key"]) + attrBase = strings.TrimSpace(auth.Attributes["base_url"]) + } + for i := range entries { + entry := &entries[i] + cfgKey := strings.TrimSpace((*entry).GetAPIKey()) + cfgBase := strings.TrimSpace((*entry).GetBaseURL()) + if attrKey != "" && attrBase != "" { + if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) { + return entry + } + continue + } + if attrKey != "" && strings.EqualFold(cfgKey, attrKey) { + if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) { + return entry + } + } + if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) { + return entry + } + } + if attrKey != "" { + for i := range entries { + entry := &entries[i] + if strings.EqualFold(strings.TrimSpace((*entry).GetAPIKey()), attrKey) { + return entry + } + } + } + return nil +} + +func proxyURLFromAPIKeyConfig(cfg *config.Config, auth *coreauth.Auth) string { + if cfg == nil || auth == nil { + return "" + } + authKind, authAccount := auth.AccountInfo() + if !strings.EqualFold(strings.TrimSpace(authKind), "api_key") { + return "" + } + + attrs := auth.Attributes + compatName := "" + providerKey := "" + if len(attrs) > 0 { + compatName = strings.TrimSpace(attrs["compat_name"]) + providerKey = strings.TrimSpace(attrs["provider_key"]) + } + if compatName != "" || strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility") { + return resolveOpenAICompatAPIKeyProxyURL(cfg, auth, strings.TrimSpace(authAccount), providerKey, compatName) + } + + switch strings.ToLower(strings.TrimSpace(auth.Provider)) { + case "gemini": + if entry := resolveAPIKeyConfig(cfg.GeminiKey, auth); entry != nil { + return strings.TrimSpace(entry.ProxyURL) + } + case "claude": + if entry := resolveAPIKeyConfig(cfg.ClaudeKey, auth); entry != nil { + return strings.TrimSpace(entry.ProxyURL) + } + case "codex": + if entry := resolveAPIKeyConfig(cfg.CodexKey, auth); entry != nil { + return strings.TrimSpace(entry.ProxyURL) + } + } + return "" +} + +func resolveOpenAICompatAPIKeyProxyURL(cfg *config.Config, auth *coreauth.Auth, apiKey, providerKey, compatName string) string { + if cfg == nil || auth == nil { + return "" + } + apiKey = strings.TrimSpace(apiKey) + if apiKey == "" { + return "" + } + candidates := make([]string, 0, 3) + if v := strings.TrimSpace(compatName); v != "" { + candidates = append(candidates, v) + } + if v := strings.TrimSpace(providerKey); v != "" { + candidates = append(candidates, v) + } + if v := strings.TrimSpace(auth.Provider); v != "" { + candidates = append(candidates, v) + } + + for i := range cfg.OpenAICompatibility { + compat := &cfg.OpenAICompatibility[i] + for _, candidate := range candidates { + if candidate != "" && strings.EqualFold(strings.TrimSpace(candidate), compat.Name) { + for j := range compat.APIKeyEntries { + entry := &compat.APIKeyEntries[j] + if strings.EqualFold(strings.TrimSpace(entry.APIKey), apiKey) { + return strings.TrimSpace(entry.ProxyURL) + } + } + return "" + } + } + } + return "" +} + func buildProxyTransport(proxyStr string) *http.Transport { transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr) if errBuild != nil { diff --git a/internal/api/handlers/management/api_tools_test.go b/internal/api/handlers/management/api_tools_test.go index 6ed98c6e..b27fe639 100644 --- a/internal/api/handlers/management/api_tools_test.go +++ b/internal/api/handlers/management/api_tools_test.go @@ -58,6 +58,105 @@ func TestAPICallTransportInvalidAuthFallsBackToGlobalProxy(t *testing.T) { } } +func TestAPICallTransportAPIKeyAuthFallsBackToConfigProxyURL(t *testing.T) { + t.Parallel() + + h := &Handler{ + cfg: &config.Config{ + SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}, + GeminiKey: []config.GeminiKey{{ + APIKey: "gemini-key", + ProxyURL: "http://gemini-proxy.example.com:8080", + }}, + ClaudeKey: []config.ClaudeKey{{ + APIKey: "claude-key", + ProxyURL: "http://claude-proxy.example.com:8080", + }}, + CodexKey: []config.CodexKey{{ + APIKey: "codex-key", + ProxyURL: "http://codex-proxy.example.com:8080", + }}, + OpenAICompatibility: []config.OpenAICompatibility{{ + Name: "bohe", + BaseURL: "https://bohe.example.com", + APIKeyEntries: []config.OpenAICompatibilityAPIKey{{ + APIKey: "compat-key", + ProxyURL: "http://compat-proxy.example.com:8080", + }}, + }}, + }, + } + + cases := []struct { + name string + auth *coreauth.Auth + wantProxy string + }{ + { + name: "gemini", + auth: &coreauth.Auth{ + Provider: "gemini", + Attributes: map[string]string{"api_key": "gemini-key"}, + }, + wantProxy: "http://gemini-proxy.example.com:8080", + }, + { + name: "claude", + auth: &coreauth.Auth{ + Provider: "claude", + Attributes: map[string]string{"api_key": "claude-key"}, + }, + wantProxy: "http://claude-proxy.example.com:8080", + }, + { + name: "codex", + auth: &coreauth.Auth{ + Provider: "codex", + Attributes: map[string]string{"api_key": "codex-key"}, + }, + wantProxy: "http://codex-proxy.example.com:8080", + }, + { + name: "openai-compatibility", + auth: &coreauth.Auth{ + Provider: "bohe", + Attributes: map[string]string{ + "api_key": "compat-key", + "compat_name": "bohe", + "provider_key": "bohe", + }, + }, + wantProxy: "http://compat-proxy.example.com:8080", + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + transport := h.apiCallTransport(tc.auth) + httpTransport, ok := transport.(*http.Transport) + if !ok { + t.Fatalf("transport type = %T, want *http.Transport", transport) + } + + req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil) + if errRequest != nil { + t.Fatalf("http.NewRequest returned error: %v", errRequest) + } + + proxyURL, errProxy := httpTransport.Proxy(req) + if errProxy != nil { + t.Fatalf("httpTransport.Proxy returned error: %v", errProxy) + } + if proxyURL == nil || proxyURL.String() != tc.wantProxy { + t.Fatalf("proxy URL = %v, want %s", proxyURL, tc.wantProxy) + } + }) + } +} + func TestAuthByIndexDistinguishesSharedAPIKeysAcrossProviders(t *testing.T) { t.Parallel()