mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-05-13 23:41:36 +00:00
Merge branch 'main' into plus
This commit is contained in:
151
sdk/api/handlers/claude/gitlab_duo_handler_test.go
Normal file
151
sdk/api/handlers/claude/gitlab_duo_handler_test.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
runtimeexecutor "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestClaudeMessagesWithGitLabDuoAnthropicGateway(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var gotPath, gotAuthHeader, gotRealmHeader string
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotAuthHeader = r.Header.Get("Authorization")
|
||||
gotRealmHeader = r.Header.Get("X-Gitlab-Realm")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","role":"assistant","model":"claude-sonnet-4-5","content":[{"type":"tool_use","id":"toolu_1","name":"Bash","input":{"cmd":"ls"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":11,"output_tokens":4}}`))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
manager, _ := registerGitLabDuoAnthropicAuth(t, upstream.URL)
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
h := NewClaudeCodeAPIHandler(base)
|
||||
router := gin.New()
|
||||
router.POST("/v1/messages", h.ClaudeMessages)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{
|
||||
"model":"claude-sonnet-4-5",
|
||||
"max_tokens":128,
|
||||
"messages":[{"role":"user","content":"list files"}],
|
||||
"tools":[{"name":"Bash","description":"run bash","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}},"required":["cmd"]}}]
|
||||
}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Anthropic-Version", "2023-06-01")
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d body=%s", resp.Code, http.StatusOK, resp.Body.String())
|
||||
}
|
||||
if gotPath != "/v1/proxy/anthropic/v1/messages" {
|
||||
t.Fatalf("path = %q, want %q", gotPath, "/v1/proxy/anthropic/v1/messages")
|
||||
}
|
||||
if gotAuthHeader != "Bearer gateway-token" {
|
||||
t.Fatalf("authorization = %q, want Bearer gateway-token", gotAuthHeader)
|
||||
}
|
||||
if gotRealmHeader != "saas" {
|
||||
t.Fatalf("x-gitlab-realm = %q, want saas", gotRealmHeader)
|
||||
}
|
||||
if !strings.Contains(resp.Body.String(), `"tool_use"`) {
|
||||
t.Fatalf("expected tool_use response, got %s", resp.Body.String())
|
||||
}
|
||||
if !strings.Contains(resp.Body.String(), `"Bash"`) {
|
||||
t.Fatalf("expected Bash tool in response, got %s", resp.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeMessagesStreamWithGitLabDuoAnthropicGateway(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var gotPath string
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("event: message_start\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"model\":\"claude-sonnet-4-5\",\"content\":[],\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":0,\"output_tokens\":0}}}\n\n"))
|
||||
_, _ = w.Write([]byte("event: content_block_start\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"))
|
||||
_, _ = w.Write([]byte("event: content_block_delta\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"hello from duo\"}}\n\n"))
|
||||
_, _ = w.Write([]byte("event: message_delta\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":10,\"output_tokens\":3}}\n\n"))
|
||||
_, _ = w.Write([]byte("event: message_stop\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n"))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
manager, _ := registerGitLabDuoAnthropicAuth(t, upstream.URL)
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
h := NewClaudeCodeAPIHandler(base)
|
||||
router := gin.New()
|
||||
router.POST("/v1/messages", h.ClaudeMessages)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{
|
||||
"model":"claude-sonnet-4-5",
|
||||
"stream":true,
|
||||
"max_tokens":64,
|
||||
"messages":[{"role":"user","content":"hello"}]
|
||||
}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Anthropic-Version", "2023-06-01")
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d body=%s", resp.Code, http.StatusOK, resp.Body.String())
|
||||
}
|
||||
if gotPath != "/v1/proxy/anthropic/v1/messages" {
|
||||
t.Fatalf("path = %q, want %q", gotPath, "/v1/proxy/anthropic/v1/messages")
|
||||
}
|
||||
if got := resp.Header().Get("Content-Type"); got != "text/event-stream" {
|
||||
t.Fatalf("content-type = %q, want text/event-stream", got)
|
||||
}
|
||||
if !strings.Contains(resp.Body.String(), "event: content_block_delta") {
|
||||
t.Fatalf("expected streamed claude event, got %s", resp.Body.String())
|
||||
}
|
||||
if !strings.Contains(resp.Body.String(), "hello from duo") {
|
||||
t.Fatalf("expected streamed text, got %s", resp.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func registerGitLabDuoAnthropicAuth(t *testing.T, upstreamURL string) (*coreauth.Manager, string) {
|
||||
t.Helper()
|
||||
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
manager.RegisterExecutor(runtimeexecutor.NewGitLabExecutor(&internalconfig.Config{}))
|
||||
|
||||
auth := &coreauth.Auth{
|
||||
ID: "gitlab-duo-claude-handler-test",
|
||||
Provider: "gitlab",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{
|
||||
"duo_gateway_base_url": upstreamURL,
|
||||
"duo_gateway_token": "gateway-token",
|
||||
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
}
|
||||
registered, err := manager.Register(context.Background(), auth)
|
||||
if err != nil {
|
||||
t.Fatalf("register auth: %v", err)
|
||||
}
|
||||
|
||||
registry.GetGlobalRegistry().RegisterClient(registered.ID, registered.Provider, runtimeexecutor.GitLabModelsFromAuth(registered))
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(registered.ID)
|
||||
})
|
||||
return manager, registered.ID
|
||||
}
|
||||
@@ -274,10 +274,11 @@ type BaseAPIHandler struct {
|
||||
// Returns:
|
||||
// - *BaseAPIHandler: A new API handlers instance
|
||||
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager) *BaseAPIHandler {
|
||||
return &BaseAPIHandler{
|
||||
h := &BaseAPIHandler{
|
||||
Cfg: cfg,
|
||||
AuthManager: authManager,
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
// UpdateClients updates the handlers' client list and configuration.
|
||||
|
||||
46
sdk/api/handlers/openai/endpoint_compat.go
Normal file
46
sdk/api/handlers/openai/endpoint_compat.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
)
|
||||
|
||||
const (
|
||||
openAIChatEndpoint = "/chat/completions"
|
||||
openAIResponsesEndpoint = "/responses"
|
||||
)
|
||||
|
||||
func resolveEndpointOverride(modelName, requestedEndpoint string) (string, bool) {
|
||||
if modelName == "" {
|
||||
return "", false
|
||||
}
|
||||
info := registry.GetGlobalRegistry().GetModelInfo(modelName, "")
|
||||
if info == nil {
|
||||
baseModel := thinking.ParseSuffix(modelName).ModelName
|
||||
if baseModel != "" && baseModel != modelName {
|
||||
info = registry.GetGlobalRegistry().GetModelInfo(baseModel, "")
|
||||
}
|
||||
}
|
||||
if info == nil || len(info.SupportedEndpoints) == 0 {
|
||||
return "", false
|
||||
}
|
||||
if endpointListContains(info.SupportedEndpoints, requestedEndpoint) {
|
||||
return "", false
|
||||
}
|
||||
if requestedEndpoint == openAIChatEndpoint && endpointListContains(info.SupportedEndpoints, openAIResponsesEndpoint) {
|
||||
return openAIResponsesEndpoint, true
|
||||
}
|
||||
if requestedEndpoint == openAIResponsesEndpoint && endpointListContains(info.SupportedEndpoints, openAIChatEndpoint) {
|
||||
return openAIChatEndpoint, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func endpointListContains(items []string, value string) bool {
|
||||
for _, item := range items {
|
||||
if item == value {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
29
sdk/api/handlers/openai/endpoint_compat_test.go
Normal file
29
sdk/api/handlers/openai/endpoint_compat_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
)
|
||||
|
||||
func TestResolveEndpointOverride_StripsThinkingSuffix(t *testing.T) {
|
||||
const clientID = "test-endpoint-compat-suffix"
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient(clientID, "github-copilot", []*registry.ModelInfo{
|
||||
{
|
||||
ID: "test-gemini-chat-only",
|
||||
SupportedEndpoints: []string{openAIChatEndpoint},
|
||||
},
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
reg.UnregisterClient(clientID)
|
||||
})
|
||||
|
||||
override, ok := resolveEndpointOverride("test-gemini-chat-only(high)", openAIResponsesEndpoint)
|
||||
if !ok {
|
||||
t.Fatalf("expected endpoint override to be resolved")
|
||||
}
|
||||
if override != openAIChatEndpoint {
|
||||
t.Fatalf("override endpoint = %q, want %q", override, openAIChatEndpoint)
|
||||
}
|
||||
}
|
||||
143
sdk/api/handlers/openai/gitlab_duo_handler_test.go
Normal file
143
sdk/api/handlers/openai/gitlab_duo_handler_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
runtimeexecutor "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestOpenAIChatCompletionsWithGitLabDuoOpenAIGateway(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var gotPath, gotAuthHeader, gotRealmHeader string
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotAuthHeader = r.Header.Get("Authorization")
|
||||
gotRealmHeader = r.Header.Get("X-Gitlab-Realm")
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"gpt-5-codex\"}}\n\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.output_text.delta\",\"delta\":\"hello from duo openai\"}\n\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"gpt-5-codex\",\"status\":\"completed\",\"output\":[{\"type\":\"message\",\"id\":\"msg_1\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"hello from duo openai\"}]}],\"usage\":{\"input_tokens\":11,\"output_tokens\":4,\"total_tokens\":15}}}\n\n"))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
manager := registerGitLabDuoOpenAIAuth(t, upstream.URL)
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
h := NewOpenAIAPIHandler(base)
|
||||
router := gin.New()
|
||||
router.POST("/v1/chat/completions", h.ChatCompletions)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{
|
||||
"model":"gpt-5-codex",
|
||||
"messages":[{"role":"user","content":"hello"}]
|
||||
}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d body=%s", resp.Code, http.StatusOK, resp.Body.String())
|
||||
}
|
||||
if gotPath != "/v1/proxy/openai/v1/responses" {
|
||||
t.Fatalf("path = %q, want %q", gotPath, "/v1/proxy/openai/v1/responses")
|
||||
}
|
||||
if gotAuthHeader != "Bearer gateway-token" {
|
||||
t.Fatalf("authorization = %q, want Bearer gateway-token", gotAuthHeader)
|
||||
}
|
||||
if gotRealmHeader != "saas" {
|
||||
t.Fatalf("x-gitlab-realm = %q, want saas", gotRealmHeader)
|
||||
}
|
||||
if !strings.Contains(resp.Body.String(), `"content":"hello from duo openai"`) {
|
||||
t.Fatalf("expected translated chat completion, got %s", resp.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesStreamWithGitLabDuoOpenAIGateway(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var gotPath, gotAuthHeader string
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotAuthHeader = r.Header.Get("Authorization")
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"gpt-5-codex\"}}\n\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.output_text.delta\",\"delta\":\"streamed duo output\"}\n\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"gpt-5-codex\",\"status\":\"completed\",\"output\":[{\"type\":\"message\",\"id\":\"msg_1\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"streamed duo output\"}]}],\"usage\":{\"input_tokens\":10,\"output_tokens\":3,\"total_tokens\":13}}}\n\n"))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
manager := registerGitLabDuoOpenAIAuth(t, upstream.URL)
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
h := NewOpenAIResponsesAPIHandler(base)
|
||||
router := gin.New()
|
||||
router.POST("/v1/responses", h.Responses)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(`{
|
||||
"model":"gpt-5-codex",
|
||||
"stream":true,
|
||||
"input":"hello"
|
||||
}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d body=%s", resp.Code, http.StatusOK, resp.Body.String())
|
||||
}
|
||||
if gotPath != "/v1/proxy/openai/v1/responses" {
|
||||
t.Fatalf("path = %q, want %q", gotPath, "/v1/proxy/openai/v1/responses")
|
||||
}
|
||||
if gotAuthHeader != "Bearer gateway-token" {
|
||||
t.Fatalf("authorization = %q, want Bearer gateway-token", gotAuthHeader)
|
||||
}
|
||||
if got := resp.Header().Get("Content-Type"); got != "text/event-stream" {
|
||||
t.Fatalf("content-type = %q, want text/event-stream", got)
|
||||
}
|
||||
if !strings.Contains(resp.Body.String(), `"type":"response.output_text.delta"`) {
|
||||
t.Fatalf("expected streamed responses delta, got %s", resp.Body.String())
|
||||
}
|
||||
if !strings.Contains(resp.Body.String(), `"type":"response.completed"`) {
|
||||
t.Fatalf("expected streamed responses completion, got %s", resp.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func registerGitLabDuoOpenAIAuth(t *testing.T, upstreamURL string) *coreauth.Manager {
|
||||
t.Helper()
|
||||
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
manager.RegisterExecutor(runtimeexecutor.NewGitLabExecutor(&internalconfig.Config{}))
|
||||
|
||||
auth := &coreauth.Auth{
|
||||
ID: "gitlab-duo-openai-handler-test",
|
||||
Provider: "gitlab",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{
|
||||
"duo_gateway_base_url": upstreamURL,
|
||||
"duo_gateway_token": "gateway-token",
|
||||
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
|
||||
"model_provider": "openai",
|
||||
"model_name": "gpt-5-codex",
|
||||
},
|
||||
}
|
||||
registered, err := manager.Register(context.Background(), auth)
|
||||
if err != nil {
|
||||
t.Fatalf("register auth: %v", err)
|
||||
}
|
||||
|
||||
registry.GetGlobalRegistry().RegisterClient(registered.ID, registered.Provider, runtimeexecutor.GitLabModelsFromAuth(registered))
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(registered.ID)
|
||||
})
|
||||
return manager
|
||||
}
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
codexconverter "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/codex/openai/chat-completions"
|
||||
responsesconverter "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/openai/openai/responses"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -112,6 +113,23 @@ func (h *OpenAIAPIHandler) ChatCompletions(c *gin.Context) {
|
||||
streamResult := gjson.GetBytes(rawJSON, "stream")
|
||||
stream := streamResult.Type == gjson.True
|
||||
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
if overrideEndpoint, ok := resolveEndpointOverride(modelName, openAIChatEndpoint); ok && overrideEndpoint == openAIResponsesEndpoint {
|
||||
originalChat := rawJSON
|
||||
if shouldTreatAsResponsesFormat(rawJSON) {
|
||||
// Already responses-style payload; no conversion needed.
|
||||
} else {
|
||||
rawJSON = codexconverter.ConvertOpenAIRequestToCodex(modelName, rawJSON, stream)
|
||||
}
|
||||
stream = gjson.GetBytes(rawJSON, "stream").Bool()
|
||||
if stream {
|
||||
h.handleStreamingResponseViaResponses(c, rawJSON, originalChat)
|
||||
} else {
|
||||
h.handleNonStreamingResponseViaResponses(c, rawJSON, originalChat)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Some clients send OpenAI Responses-format payloads to /v1/chat/completions.
|
||||
// Convert them to Chat Completions so downstream translators preserve tool metadata.
|
||||
if shouldTreatAsResponsesFormat(rawJSON) {
|
||||
@@ -245,6 +263,76 @@ func convertCompletionsRequestToChatCompletions(rawJSON []byte) []byte {
|
||||
return out
|
||||
}
|
||||
|
||||
func convertResponsesObjectToChatCompletion(ctx context.Context, modelName string, originalChatJSON, responsesRequestJSON, responsesPayload []byte) []byte {
|
||||
if len(responsesPayload) == 0 {
|
||||
return nil
|
||||
}
|
||||
wrapped := wrapResponsesPayloadAsCompleted(responsesPayload)
|
||||
if len(wrapped) == 0 {
|
||||
return nil
|
||||
}
|
||||
var param any
|
||||
converted := codexconverter.ConvertCodexResponseToOpenAINonStream(ctx, modelName, originalChatJSON, responsesRequestJSON, wrapped, ¶m)
|
||||
if len(converted) == 0 {
|
||||
return nil
|
||||
}
|
||||
return converted
|
||||
}
|
||||
|
||||
func wrapResponsesPayloadAsCompleted(payload []byte) []byte {
|
||||
if gjson.GetBytes(payload, "type").Exists() {
|
||||
return payload
|
||||
}
|
||||
if gjson.GetBytes(payload, "object").String() != "response" {
|
||||
return payload
|
||||
}
|
||||
wrapped := `{"type":"response.completed","response":{}}`
|
||||
wrapped, _ = sjson.SetRaw(wrapped, "response", string(payload))
|
||||
return []byte(wrapped)
|
||||
}
|
||||
|
||||
func writeConvertedResponsesChunk(c *gin.Context, ctx context.Context, modelName string, originalChatJSON, responsesRequestJSON, chunk []byte, param *any) {
|
||||
outputs := codexconverter.ConvertCodexResponseToOpenAI(ctx, modelName, originalChatJSON, responsesRequestJSON, chunk, param)
|
||||
for _, out := range outputs {
|
||||
if len(out) == 0 {
|
||||
continue
|
||||
}
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", out)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OpenAIAPIHandler) forwardResponsesAsChatStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage, ctx context.Context, modelName string, originalChatJSON, responsesRequestJSON []byte, param *any) {
|
||||
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||
WriteChunk: func(chunk []byte) {
|
||||
outputs := codexconverter.ConvertCodexResponseToOpenAI(ctx, modelName, originalChatJSON, responsesRequestJSON, chunk, param)
|
||||
for _, out := range outputs {
|
||||
if len(out) == 0 {
|
||||
continue
|
||||
}
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", out)
|
||||
}
|
||||
},
|
||||
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
|
||||
if errMsg == nil {
|
||||
return
|
||||
}
|
||||
status := http.StatusInternalServerError
|
||||
if errMsg.StatusCode > 0 {
|
||||
status = errMsg.StatusCode
|
||||
}
|
||||
errText := http.StatusText(status)
|
||||
if errMsg.Error != nil && errMsg.Error.Error() != "" {
|
||||
errText = errMsg.Error.Error()
|
||||
}
|
||||
body := handlers.BuildErrorResponseBody(status, errText)
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(body))
|
||||
},
|
||||
WriteDone: func() {
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// convertChatCompletionsResponseToCompletions converts chat completions API response back to completions format.
|
||||
// This ensures the completions endpoint returns data in the expected format.
|
||||
//
|
||||
@@ -442,6 +530,31 @@ func (h *OpenAIAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON []
|
||||
cliCancel()
|
||||
}
|
||||
|
||||
func (h *OpenAIAPIHandler) handleNonStreamingResponseViaResponses(c *gin.Context, rawJSON []byte, originalChatJSON []byte) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, OpenaiResponse, modelName, rawJSON, h.GetAlt(c))
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
cliCancel(errMsg.Error)
|
||||
return
|
||||
}
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
converted := convertResponsesObjectToChatCompletion(cliCtx, modelName, originalChatJSON, rawJSON, resp)
|
||||
if converted == nil {
|
||||
h.WriteErrorResponse(c, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Error: fmt.Errorf("failed to convert response to chat completion format"),
|
||||
})
|
||||
cliCancel(fmt.Errorf("response conversion failed"))
|
||||
return
|
||||
}
|
||||
_, _ = c.Writer.Write(converted)
|
||||
cliCancel()
|
||||
}
|
||||
|
||||
// handleStreamingResponse handles streaming responses for Gemini models.
|
||||
// It establishes a streaming connection with the backend service and forwards
|
||||
// the response chunks to the client in real-time using Server-Sent Events.
|
||||
@@ -518,6 +631,69 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OpenAIAPIHandler) handleStreamingResponseViaResponses(c *gin.Context, rawJSON []byte, originalChatJSON []byte) {
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: "Streaming not supported",
|
||||
Type: "server_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, OpenaiResponse, modelName, rawJSON, h.GetAlt(c))
|
||||
var param any
|
||||
|
||||
setSSEHeaders := func() {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
// Peek for first usable chunk
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cliCancel(c.Request.Context().Err())
|
||||
return
|
||||
case errMsg, ok := <-errChan:
|
||||
if !ok {
|
||||
errChan = nil
|
||||
continue
|
||||
}
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
if errMsg != nil {
|
||||
cliCancel(errMsg.Error)
|
||||
} else {
|
||||
cliCancel(nil)
|
||||
}
|
||||
return
|
||||
case chunk, ok := <-dataChan:
|
||||
if !ok {
|
||||
setSSEHeaders()
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
cliCancel(nil)
|
||||
return
|
||||
}
|
||||
|
||||
setSSEHeaders()
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
writeConvertedResponsesChunk(c, cliCtx, modelName, originalChatJSON, rawJSON, chunk, ¶m)
|
||||
flusher.Flush()
|
||||
|
||||
h.forwardResponsesAsChatStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan, cliCtx, modelName, originalChatJSON, rawJSON, ¶m)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleCompletionsNonStreamingResponse handles non-streaming completions responses.
|
||||
// It converts completions request to chat completions format, sends to backend,
|
||||
// then converts the response back to completions format before sending to client.
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
responsesconverter "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/openai/openai/responses"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -257,7 +258,21 @@ func (h *OpenAIResponsesAPIHandler) Responses(c *gin.Context) {
|
||||
|
||||
// Check if the client requested a streaming response.
|
||||
streamResult := gjson.GetBytes(rawJSON, "stream")
|
||||
if streamResult.Type == gjson.True {
|
||||
stream := streamResult.Type == gjson.True
|
||||
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
if overrideEndpoint, ok := resolveEndpointOverride(modelName, openAIResponsesEndpoint); ok && overrideEndpoint == openAIChatEndpoint {
|
||||
chatJSON := responsesconverter.ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName, rawJSON, stream)
|
||||
stream = gjson.GetBytes(chatJSON, "stream").Bool()
|
||||
if stream {
|
||||
h.handleStreamingResponseViaChat(c, rawJSON, chatJSON)
|
||||
} else {
|
||||
h.handleNonStreamingResponseViaChat(c, rawJSON, chatJSON)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if stream {
|
||||
h.handleStreamingResponse(c, rawJSON)
|
||||
} else {
|
||||
h.handleNonStreamingResponse(c, rawJSON)
|
||||
@@ -335,6 +350,32 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponse(c *gin.Context, r
|
||||
cliCancel()
|
||||
}
|
||||
|
||||
func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponseViaChat(c *gin.Context, originalResponsesJSON, chatJSON []byte) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
modelName := gjson.GetBytes(chatJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, OpenAI, modelName, chatJSON, "")
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
cliCancel(errMsg.Error)
|
||||
return
|
||||
}
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
var param any
|
||||
converted := responsesconverter.ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream(cliCtx, modelName, originalResponsesJSON, originalResponsesJSON, resp, ¶m)
|
||||
if len(converted) == 0 {
|
||||
h.WriteErrorResponse(c, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Error: fmt.Errorf("failed to convert chat completion response to responses format"),
|
||||
})
|
||||
cliCancel(fmt.Errorf("response conversion failed"))
|
||||
return
|
||||
}
|
||||
_, _ = c.Writer.Write(converted)
|
||||
cliCancel()
|
||||
}
|
||||
|
||||
// handleStreamingResponse handles streaming responses for Gemini models.
|
||||
// It establishes a streaming connection with the backend service and forwards
|
||||
// the response chunks to the client in real-time using Server-Sent Events.
|
||||
@@ -414,6 +455,118 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OpenAIResponsesAPIHandler) handleStreamingResponseViaChat(c *gin.Context, originalResponsesJSON, chatJSON []byte) {
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: "Streaming not supported",
|
||||
Type: "server_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
modelName := gjson.GetBytes(chatJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, OpenAI, modelName, chatJSON, "")
|
||||
var param any
|
||||
|
||||
setSSEHeaders := func() {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cliCancel(c.Request.Context().Err())
|
||||
return
|
||||
case errMsg, ok := <-errChan:
|
||||
if !ok {
|
||||
errChan = nil
|
||||
continue
|
||||
}
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
if errMsg != nil {
|
||||
cliCancel(errMsg.Error)
|
||||
} else {
|
||||
cliCancel(nil)
|
||||
}
|
||||
return
|
||||
case chunk, ok := <-dataChan:
|
||||
if !ok {
|
||||
setSSEHeaders()
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
flusher.Flush()
|
||||
cliCancel(nil)
|
||||
return
|
||||
}
|
||||
|
||||
setSSEHeaders()
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
writeChatAsResponsesChunk(c, cliCtx, modelName, originalResponsesJSON, chunk, ¶m)
|
||||
flusher.Flush()
|
||||
|
||||
h.forwardChatAsResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan, cliCtx, modelName, originalResponsesJSON, ¶m)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func writeChatAsResponsesChunk(c *gin.Context, ctx context.Context, modelName string, originalResponsesJSON, chunk []byte, param *any) {
|
||||
outputs := responsesconverter.ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx, modelName, originalResponsesJSON, originalResponsesJSON, chunk, param)
|
||||
for _, out := range outputs {
|
||||
if len(out) == 0 {
|
||||
continue
|
||||
}
|
||||
if bytes.HasPrefix(out, []byte("event:")) {
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
}
|
||||
_, _ = c.Writer.Write(out)
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OpenAIResponsesAPIHandler) forwardChatAsResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage, ctx context.Context, modelName string, originalResponsesJSON []byte, param *any) {
|
||||
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||
WriteChunk: func(chunk []byte) {
|
||||
outputs := responsesconverter.ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx, modelName, originalResponsesJSON, originalResponsesJSON, chunk, param)
|
||||
for _, out := range outputs {
|
||||
if len(out) == 0 {
|
||||
continue
|
||||
}
|
||||
if bytes.HasPrefix(out, []byte("event:")) {
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
}
|
||||
_, _ = c.Writer.Write(out)
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
}
|
||||
},
|
||||
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
|
||||
if errMsg == nil {
|
||||
return
|
||||
}
|
||||
status := http.StatusInternalServerError
|
||||
if errMsg.StatusCode > 0 {
|
||||
status = errMsg.StatusCode
|
||||
}
|
||||
errText := http.StatusText(status)
|
||||
if errMsg.Error != nil && errMsg.Error.Error() != "" {
|
||||
errText = errMsg.Error.Error()
|
||||
}
|
||||
body := handlers.BuildErrorResponseBody(status, errText)
|
||||
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(body))
|
||||
},
|
||||
WriteDone: func() {
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage, framer *responsesSSEFramer) {
|
||||
if framer == nil {
|
||||
framer = &responsesSSEFramer{}
|
||||
|
||||
95
sdk/auth/codebuddy.go
Normal file
95
sdk/auth/codebuddy.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codebuddy"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// CodeBuddyAuthenticator implements the browser OAuth polling flow for CodeBuddy.
|
||||
type CodeBuddyAuthenticator struct{}
|
||||
|
||||
// NewCodeBuddyAuthenticator constructs a new CodeBuddy authenticator.
|
||||
func NewCodeBuddyAuthenticator() Authenticator {
|
||||
return &CodeBuddyAuthenticator{}
|
||||
}
|
||||
|
||||
// Provider returns the provider key for codebuddy.
|
||||
func (CodeBuddyAuthenticator) Provider() string {
|
||||
return "codebuddy"
|
||||
}
|
||||
|
||||
// codeBuddyRefreshLead is the duration before token expiry when a refresh should be attempted.
|
||||
var codeBuddyRefreshLead = 24 * time.Hour
|
||||
|
||||
// RefreshLead returns how soon before expiry a refresh should be attempted.
|
||||
// CodeBuddy tokens have a long validity period, so we refresh 24 hours before expiry.
|
||||
func (CodeBuddyAuthenticator) RefreshLead() *time.Duration {
|
||||
return &codeBuddyRefreshLead
|
||||
}
|
||||
|
||||
// Login initiates the browser OAuth flow for CodeBuddy.
|
||||
func (a CodeBuddyAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("codebuddy: configuration is required")
|
||||
}
|
||||
if opts == nil {
|
||||
opts = &LoginOptions{}
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
authSvc := codebuddy.NewCodeBuddyAuth(cfg)
|
||||
|
||||
authState, err := authSvc.FetchAuthState(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("codebuddy: failed to fetch auth state: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("\nPlease open the following URL in your browser to login:\n\n %s\n\n", authState.AuthURL)
|
||||
fmt.Println("Waiting for authorization...")
|
||||
|
||||
if !opts.NoBrowser {
|
||||
if browser.IsAvailable() {
|
||||
if errOpen := browser.OpenURL(authState.AuthURL); errOpen != nil {
|
||||
log.Debugf("codebuddy: failed to open browser: %v", errOpen)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
storage, err := authSvc.PollForToken(ctx, authState.State)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("codebuddy: %s: %w", codebuddy.GetUserFriendlyMessage(err), err)
|
||||
}
|
||||
|
||||
fmt.Printf("\nSuccessfully logged in! (User ID: %s)\n", storage.UserID)
|
||||
|
||||
authID := fmt.Sprintf("codebuddy-%s.json", storage.UserID)
|
||||
|
||||
label := storage.UserID
|
||||
if label == "" {
|
||||
label = "codebuddy-user"
|
||||
}
|
||||
|
||||
return &coreauth.Auth{
|
||||
ID: authID,
|
||||
Provider: a.Provider(),
|
||||
FileName: authID,
|
||||
Label: label,
|
||||
Storage: storage,
|
||||
Metadata: map[string]any{
|
||||
"access_token": storage.AccessToken,
|
||||
"refresh_token": storage.RefreshToken,
|
||||
"user_id": storage.UserID,
|
||||
"domain": storage.Domain,
|
||||
"expires_in": storage.ExpiresIn,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
98
sdk/auth/cursor.go
Normal file
98
sdk/auth/cursor.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
cursorauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// CursorAuthenticator implements OAuth PKCE login for Cursor.
|
||||
type CursorAuthenticator struct{}
|
||||
|
||||
// NewCursorAuthenticator constructs a new Cursor authenticator.
|
||||
func NewCursorAuthenticator() Authenticator {
|
||||
return &CursorAuthenticator{}
|
||||
}
|
||||
|
||||
// Provider returns the provider key for cursor.
|
||||
func (CursorAuthenticator) Provider() string {
|
||||
return "cursor"
|
||||
}
|
||||
|
||||
// RefreshLead returns the time before expiry when a refresh should be attempted.
|
||||
func (CursorAuthenticator) RefreshLead() *time.Duration {
|
||||
d := 10 * time.Minute
|
||||
return &d
|
||||
}
|
||||
|
||||
// Login initiates the Cursor PKCE authentication flow.
|
||||
func (a CursorAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("cursor auth: configuration is required")
|
||||
}
|
||||
if opts == nil {
|
||||
opts = &LoginOptions{}
|
||||
}
|
||||
|
||||
// Generate PKCE auth parameters
|
||||
authParams, err := cursorauth.GenerateAuthParams()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cursor: failed to generate auth params: %w", err)
|
||||
}
|
||||
|
||||
// Display the login URL
|
||||
log.Info("Starting Cursor authentication...")
|
||||
log.Infof("Please visit this URL to log in: %s", authParams.LoginURL)
|
||||
|
||||
// Try to open the browser automatically
|
||||
if !opts.NoBrowser {
|
||||
if browser.IsAvailable() {
|
||||
if errOpen := browser.OpenURL(authParams.LoginURL); errOpen != nil {
|
||||
log.Warnf("Failed to open browser automatically: %v", errOpen)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("Waiting for Cursor authorization...")
|
||||
|
||||
// Poll for the auth result
|
||||
tokens, err := cursorauth.PollForAuth(ctx, authParams.UUID, authParams.Verifier)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cursor: authentication failed: %w", err)
|
||||
}
|
||||
|
||||
expiresAt := cursorauth.GetTokenExpiry(tokens.AccessToken)
|
||||
|
||||
// Auto-identify account from JWT sub claim
|
||||
sub := cursorauth.ParseJWTSub(tokens.AccessToken)
|
||||
subHash := cursorauth.SubToShortHash(sub)
|
||||
|
||||
log.Info("Cursor authentication successful!")
|
||||
|
||||
metadata := map[string]any{
|
||||
"type": "cursor",
|
||||
"access_token": tokens.AccessToken,
|
||||
"refresh_token": tokens.RefreshToken,
|
||||
"expires_at": expiresAt.Format(time.RFC3339),
|
||||
"timestamp": time.Now().UnixMilli(),
|
||||
}
|
||||
if sub != "" {
|
||||
metadata["sub"] = sub
|
||||
}
|
||||
|
||||
fileName := cursorauth.CredentialFileName("", subHash)
|
||||
|
||||
return &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: a.Provider(),
|
||||
FileName: fileName,
|
||||
Label: cursorauth.DisplayLabel("", subHash),
|
||||
Metadata: metadata,
|
||||
}, nil
|
||||
}
|
||||
@@ -237,6 +237,15 @@ func (s *FileTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth,
|
||||
if disabled {
|
||||
status = cliproxyauth.StatusDisabled
|
||||
}
|
||||
|
||||
// Calculate NextRefreshAfter from expires_at (20 minutes before expiry)
|
||||
var nextRefreshAfter time.Time
|
||||
if expiresAtStr, ok := metadata["expires_at"].(string); ok && expiresAtStr != "" {
|
||||
if expiresAt, err := time.Parse(time.RFC3339, expiresAtStr); err == nil {
|
||||
nextRefreshAfter = expiresAt.Add(-20 * time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: id,
|
||||
Provider: provider,
|
||||
@@ -249,7 +258,7 @@ func (s *FileTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth,
|
||||
CreatedAt: info.ModTime(),
|
||||
UpdatedAt: info.ModTime(),
|
||||
LastRefreshedAt: time.Time{},
|
||||
NextRefreshAfter: time.Time{},
|
||||
NextRefreshAfter: nextRefreshAfter,
|
||||
}
|
||||
if email, ok := metadata["email"].(string); ok && email != "" {
|
||||
auth.Attributes["email"] = email
|
||||
|
||||
136
sdk/auth/github_copilot.go
Normal file
136
sdk/auth/github_copilot.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// GitHubCopilotAuthenticator implements the OAuth device flow login for GitHub Copilot.
|
||||
type GitHubCopilotAuthenticator struct{}
|
||||
|
||||
// NewGitHubCopilotAuthenticator constructs a new GitHub Copilot authenticator.
|
||||
func NewGitHubCopilotAuthenticator() Authenticator {
|
||||
return &GitHubCopilotAuthenticator{}
|
||||
}
|
||||
|
||||
// Provider returns the provider key for github-copilot.
|
||||
func (GitHubCopilotAuthenticator) Provider() string {
|
||||
return "github-copilot"
|
||||
}
|
||||
|
||||
// RefreshLead returns nil since GitHub OAuth tokens don't expire in the traditional sense.
|
||||
// The token remains valid until the user revokes it or the Copilot subscription expires.
|
||||
func (GitHubCopilotAuthenticator) RefreshLead() *time.Duration {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Login initiates the GitHub device flow authentication for Copilot access.
|
||||
func (a GitHubCopilotAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("cliproxy auth: configuration is required")
|
||||
}
|
||||
if opts == nil {
|
||||
opts = &LoginOptions{}
|
||||
}
|
||||
|
||||
authSvc := copilot.NewCopilotAuth(cfg)
|
||||
|
||||
// Start the device flow
|
||||
fmt.Println("Starting GitHub Copilot authentication...")
|
||||
deviceCode, err := authSvc.StartDeviceFlow(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("github-copilot: failed to start device flow: %w", err)
|
||||
}
|
||||
|
||||
// Display the user code and verification URL
|
||||
fmt.Printf("\nTo authenticate, please visit: %s\n", deviceCode.VerificationURI)
|
||||
fmt.Printf("And enter the code: %s\n\n", deviceCode.UserCode)
|
||||
|
||||
// Try to open the browser automatically
|
||||
if !opts.NoBrowser {
|
||||
if browser.IsAvailable() {
|
||||
if errOpen := browser.OpenURL(deviceCode.VerificationURI); errOpen != nil {
|
||||
log.Warnf("Failed to open browser automatically: %v", errOpen)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("Waiting for GitHub authorization...")
|
||||
fmt.Printf("(This will timeout in %d seconds if not authorized)\n", deviceCode.ExpiresIn)
|
||||
|
||||
// Wait for user authorization
|
||||
authBundle, err := authSvc.WaitForAuthorization(ctx, deviceCode)
|
||||
if err != nil {
|
||||
errMsg := copilot.GetUserFriendlyMessage(err)
|
||||
return nil, fmt.Errorf("github-copilot: %s", errMsg)
|
||||
}
|
||||
|
||||
// Verify the token can get a Copilot API token
|
||||
fmt.Println("Verifying Copilot access...")
|
||||
apiToken, err := authSvc.GetCopilotAPIToken(ctx, authBundle.TokenData.AccessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("github-copilot: failed to verify Copilot access - you may not have an active Copilot subscription: %w", err)
|
||||
}
|
||||
|
||||
// Create the token storage
|
||||
tokenStorage := authSvc.CreateTokenStorage(authBundle)
|
||||
|
||||
// Build metadata with token information for the executor
|
||||
metadata := map[string]any{
|
||||
"type": "github-copilot",
|
||||
"username": authBundle.Username,
|
||||
"email": authBundle.Email,
|
||||
"name": authBundle.Name,
|
||||
"access_token": authBundle.TokenData.AccessToken,
|
||||
"token_type": authBundle.TokenData.TokenType,
|
||||
"scope": authBundle.TokenData.Scope,
|
||||
"timestamp": time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
if apiToken.ExpiresAt > 0 {
|
||||
metadata["api_token_expires_at"] = apiToken.ExpiresAt
|
||||
}
|
||||
|
||||
fileName := fmt.Sprintf("github-copilot-%s.json", authBundle.Username)
|
||||
|
||||
label := authBundle.Email
|
||||
if label == "" {
|
||||
label = authBundle.Username
|
||||
}
|
||||
|
||||
fmt.Printf("\nGitHub Copilot authentication successful for user: %s\n", authBundle.Username)
|
||||
|
||||
return &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: a.Provider(),
|
||||
FileName: fileName,
|
||||
Label: label,
|
||||
Storage: tokenStorage,
|
||||
Metadata: metadata,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RefreshGitHubCopilotToken validates and returns the current token status.
|
||||
// GitHub OAuth tokens don't need traditional refresh - we just validate they still work.
|
||||
func RefreshGitHubCopilotToken(ctx context.Context, cfg *config.Config, storage *copilot.CopilotTokenStorage) error {
|
||||
if storage == nil || storage.AccessToken == "" {
|
||||
return fmt.Errorf("no token available")
|
||||
}
|
||||
|
||||
authSvc := copilot.NewCopilotAuth(cfg)
|
||||
|
||||
// Validate the token can still get a Copilot API token
|
||||
_, err := authSvc.GetCopilotAPIToken(ctx, storage.AccessToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("token validation failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
482
sdk/auth/gitlab.go
Normal file
482
sdk/auth/gitlab.go
Normal file
@@ -0,0 +1,482 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
gitlabauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gitlab"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
gitLabLoginModeMetadataKey = "login_mode"
|
||||
gitLabLoginModeOAuth = "oauth"
|
||||
gitLabLoginModePAT = "pat"
|
||||
gitLabBaseURLMetadataKey = "base_url"
|
||||
gitLabOAuthClientIDMetadataKey = "oauth_client_id"
|
||||
gitLabOAuthClientSecretMetadataKey = "oauth_client_secret"
|
||||
gitLabPersonalAccessTokenMetadataKey = "personal_access_token"
|
||||
)
|
||||
|
||||
var gitLabRefreshLead = 5 * time.Minute
|
||||
|
||||
type GitLabAuthenticator struct {
|
||||
CallbackPort int
|
||||
}
|
||||
|
||||
func NewGitLabAuthenticator() *GitLabAuthenticator {
|
||||
return &GitLabAuthenticator{CallbackPort: gitlabauth.DefaultCallbackPort}
|
||||
}
|
||||
|
||||
func (a *GitLabAuthenticator) Provider() string {
|
||||
return "gitlab"
|
||||
}
|
||||
|
||||
func (a *GitLabAuthenticator) RefreshLead() *time.Duration {
|
||||
return &gitLabRefreshLead
|
||||
}
|
||||
|
||||
func (a *GitLabAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("cliproxy auth: configuration is required")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if opts == nil {
|
||||
opts = &LoginOptions{}
|
||||
}
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(opts.Metadata[gitLabLoginModeMetadataKey])) {
|
||||
case "", gitLabLoginModeOAuth:
|
||||
return a.loginOAuth(ctx, cfg, opts)
|
||||
case gitLabLoginModePAT:
|
||||
return a.loginPAT(ctx, cfg, opts)
|
||||
default:
|
||||
return nil, fmt.Errorf("gitlab auth: unsupported login mode %q", opts.Metadata[gitLabLoginModeMetadataKey])
|
||||
}
|
||||
}
|
||||
|
||||
func (a *GitLabAuthenticator) loginOAuth(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
client := gitlabauth.NewAuthClient(cfg)
|
||||
baseURL := a.resolveString(opts, gitLabBaseURLMetadataKey, gitlabauth.DefaultBaseURL)
|
||||
clientID, err := a.requireInput(opts, gitLabOAuthClientIDMetadataKey, "Enter GitLab OAuth application client ID: ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientSecret, err := a.optionalInput(opts, gitLabOAuthClientSecretMetadataKey, "Enter GitLab OAuth application client secret (press Enter for public PKCE app): ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
callbackPort := a.CallbackPort
|
||||
if opts.CallbackPort > 0 {
|
||||
callbackPort = opts.CallbackPort
|
||||
}
|
||||
redirectURI := gitlabauth.RedirectURL(callbackPort)
|
||||
|
||||
pkceCodes, err := gitlabauth.GeneratePKCECodes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
state, err := misc.GenerateRandomState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab state generation failed: %w", err)
|
||||
}
|
||||
|
||||
oauthServer := gitlabauth.NewOAuthServer(callbackPort)
|
||||
if err := oauthServer.Start(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
stopCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if stopErr := oauthServer.Stop(stopCtx); stopErr != nil {
|
||||
log.Warnf("gitlab oauth server stop error: %v", stopErr)
|
||||
}
|
||||
}()
|
||||
|
||||
authURL, err := client.GenerateAuthURL(baseURL, clientID, redirectURI, state, pkceCodes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !opts.NoBrowser {
|
||||
fmt.Println("Opening browser for GitLab Duo authentication")
|
||||
if !browser.IsAvailable() {
|
||||
log.Warn("No browser available; please open the URL manually")
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
} else if err = browser.OpenURL(authURL); err != nil {
|
||||
log.Warnf("Failed to open browser automatically: %v", err)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
}
|
||||
} else {
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
}
|
||||
|
||||
fmt.Println("Waiting for GitLab OAuth callback...")
|
||||
|
||||
callbackCh := make(chan *gitlabauth.OAuthResult, 1)
|
||||
callbackErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
result, waitErr := oauthServer.WaitForCallback(5 * time.Minute)
|
||||
if waitErr != nil {
|
||||
callbackErrCh <- waitErr
|
||||
return
|
||||
}
|
||||
callbackCh <- result
|
||||
}()
|
||||
|
||||
var result *gitlabauth.OAuthResult
|
||||
var manualPromptTimer *time.Timer
|
||||
var manualPromptC <-chan time.Time
|
||||
if opts.Prompt != nil {
|
||||
manualPromptTimer = time.NewTimer(15 * time.Second)
|
||||
manualPromptC = manualPromptTimer.C
|
||||
defer manualPromptTimer.Stop()
|
||||
}
|
||||
|
||||
waitForCallback:
|
||||
for {
|
||||
select {
|
||||
case result = <-callbackCh:
|
||||
break waitForCallback
|
||||
case err = <-callbackErrCh:
|
||||
return nil, err
|
||||
case <-manualPromptC:
|
||||
manualPromptC = nil
|
||||
if manualPromptTimer != nil {
|
||||
manualPromptTimer.Stop()
|
||||
}
|
||||
input, promptErr := opts.Prompt("Paste the GitLab callback URL (or press Enter to keep waiting): ")
|
||||
if promptErr != nil {
|
||||
return nil, promptErr
|
||||
}
|
||||
parsed, parseErr := misc.ParseOAuthCallback(input)
|
||||
if parseErr != nil {
|
||||
return nil, parseErr
|
||||
}
|
||||
if parsed == nil {
|
||||
continue
|
||||
}
|
||||
result = &gitlabauth.OAuthResult{
|
||||
Code: parsed.Code,
|
||||
State: parsed.State,
|
||||
Error: parsed.Error,
|
||||
}
|
||||
break waitForCallback
|
||||
}
|
||||
}
|
||||
|
||||
if result.Error != "" {
|
||||
return nil, fmt.Errorf("gitlab oauth returned error: %s", result.Error)
|
||||
}
|
||||
if result.State != state {
|
||||
return nil, fmt.Errorf("gitlab auth: state mismatch")
|
||||
}
|
||||
|
||||
tokenResp, err := client.ExchangeCodeForTokens(ctx, baseURL, clientID, clientSecret, redirectURI, result.Code, pkceCodes.CodeVerifier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
accessToken := strings.TrimSpace(tokenResp.AccessToken)
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("gitlab auth: missing access token")
|
||||
}
|
||||
|
||||
user, err := client.GetCurrentUser(ctx, baseURL, accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
direct, err := client.FetchDirectAccess(ctx, baseURL, accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
identifier := gitLabAccountIdentifier(user)
|
||||
fileName := fmt.Sprintf("gitlab-%s.json", sanitizeGitLabFileName(identifier))
|
||||
metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModeOAuth, tokenResp, direct)
|
||||
metadata["auth_kind"] = "oauth"
|
||||
metadata[gitLabOAuthClientIDMetadataKey] = clientID
|
||||
metadata["username"] = strings.TrimSpace(user.Username)
|
||||
if email := strings.TrimSpace(primaryGitLabEmail(user)); email != "" {
|
||||
metadata["email"] = email
|
||||
}
|
||||
metadata["name"] = strings.TrimSpace(user.Name)
|
||||
|
||||
fmt.Println("GitLab Duo authentication successful")
|
||||
|
||||
return &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: a.Provider(),
|
||||
FileName: fileName,
|
||||
Label: identifier,
|
||||
Metadata: metadata,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *GitLabAuthenticator) loginPAT(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
client := gitlabauth.NewAuthClient(cfg)
|
||||
baseURL := a.resolveString(opts, gitLabBaseURLMetadataKey, gitlabauth.DefaultBaseURL)
|
||||
token, err := a.requireInput(opts, gitLabPersonalAccessTokenMetadataKey, "Enter GitLab personal access token: ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := client.GetCurrentUser(ctx, baseURL, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = client.GetPersonalAccessTokenSelf(ctx, baseURL, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
direct, err := client.FetchDirectAccess(ctx, baseURL, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
identifier := gitLabAccountIdentifier(user)
|
||||
fileName := fmt.Sprintf("gitlab-%s-pat.json", sanitizeGitLabFileName(identifier))
|
||||
metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModePAT, nil, direct)
|
||||
metadata["auth_kind"] = "personal_access_token"
|
||||
metadata[gitLabPersonalAccessTokenMetadataKey] = strings.TrimSpace(token)
|
||||
metadata["token_preview"] = maskGitLabToken(token)
|
||||
metadata["username"] = strings.TrimSpace(user.Username)
|
||||
if email := strings.TrimSpace(primaryGitLabEmail(user)); email != "" {
|
||||
metadata["email"] = email
|
||||
}
|
||||
metadata["name"] = strings.TrimSpace(user.Name)
|
||||
|
||||
fmt.Println("GitLab Duo PAT authentication successful")
|
||||
|
||||
return &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: a.Provider(),
|
||||
FileName: fileName,
|
||||
Label: identifier + " (PAT)",
|
||||
Metadata: metadata,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildGitLabAuthMetadata(baseURL, mode string, tokenResp *gitlabauth.TokenResponse, direct *gitlabauth.DirectAccessResponse) map[string]any {
|
||||
metadata := map[string]any{
|
||||
"type": "gitlab",
|
||||
"auth_method": strings.TrimSpace(mode),
|
||||
gitLabBaseURLMetadataKey: gitlabauth.NormalizeBaseURL(baseURL),
|
||||
"last_refresh": time.Now().UTC().Format(time.RFC3339),
|
||||
"refresh_interval_seconds": 240,
|
||||
}
|
||||
if tokenResp != nil {
|
||||
metadata["access_token"] = strings.TrimSpace(tokenResp.AccessToken)
|
||||
if refreshToken := strings.TrimSpace(tokenResp.RefreshToken); refreshToken != "" {
|
||||
metadata["refresh_token"] = refreshToken
|
||||
}
|
||||
if tokenType := strings.TrimSpace(tokenResp.TokenType); tokenType != "" {
|
||||
metadata["token_type"] = tokenType
|
||||
}
|
||||
if scope := strings.TrimSpace(tokenResp.Scope); scope != "" {
|
||||
metadata["scope"] = scope
|
||||
}
|
||||
if expiry := gitlabauth.TokenExpiry(time.Now(), tokenResp); !expiry.IsZero() {
|
||||
metadata["oauth_expires_at"] = expiry.Format(time.RFC3339)
|
||||
}
|
||||
}
|
||||
mergeGitLabDirectAccessMetadata(metadata, direct)
|
||||
return metadata
|
||||
}
|
||||
|
||||
func mergeGitLabDirectAccessMetadata(metadata map[string]any, direct *gitlabauth.DirectAccessResponse) {
|
||||
if metadata == nil || direct == nil {
|
||||
return
|
||||
}
|
||||
if base := strings.TrimSpace(direct.BaseURL); base != "" {
|
||||
metadata["duo_gateway_base_url"] = base
|
||||
}
|
||||
if token := strings.TrimSpace(direct.Token); token != "" {
|
||||
metadata["duo_gateway_token"] = token
|
||||
}
|
||||
if direct.ExpiresAt > 0 {
|
||||
expiry := time.Unix(direct.ExpiresAt, 0).UTC()
|
||||
metadata["duo_gateway_expires_at"] = expiry.Format(time.RFC3339)
|
||||
now := time.Now().UTC()
|
||||
if ttl := expiry.Sub(now); ttl > 0 {
|
||||
interval := int(ttl.Seconds()) / 2
|
||||
switch {
|
||||
case interval < 60:
|
||||
interval = 60
|
||||
case interval > 240:
|
||||
interval = 240
|
||||
}
|
||||
metadata["refresh_interval_seconds"] = interval
|
||||
}
|
||||
}
|
||||
if len(direct.Headers) > 0 {
|
||||
headers := make(map[string]string, len(direct.Headers))
|
||||
for key, value := range direct.Headers {
|
||||
key = strings.TrimSpace(key)
|
||||
value = strings.TrimSpace(value)
|
||||
if key == "" || value == "" {
|
||||
continue
|
||||
}
|
||||
headers[key] = value
|
||||
}
|
||||
if len(headers) > 0 {
|
||||
metadata["duo_gateway_headers"] = headers
|
||||
}
|
||||
}
|
||||
if direct.ModelDetails != nil {
|
||||
modelDetails := map[string]any{}
|
||||
if provider := strings.TrimSpace(direct.ModelDetails.ModelProvider); provider != "" {
|
||||
modelDetails["model_provider"] = provider
|
||||
metadata["model_provider"] = provider
|
||||
}
|
||||
if model := strings.TrimSpace(direct.ModelDetails.ModelName); model != "" {
|
||||
modelDetails["model_name"] = model
|
||||
metadata["model_name"] = model
|
||||
}
|
||||
if len(modelDetails) > 0 {
|
||||
metadata["model_details"] = modelDetails
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *GitLabAuthenticator) resolveString(opts *LoginOptions, key, fallback string) string {
|
||||
if opts != nil && opts.Metadata != nil {
|
||||
if value := strings.TrimSpace(opts.Metadata[key]); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
for _, envKey := range gitLabEnvKeys(key) {
|
||||
if raw, ok := os.LookupEnv(envKey); ok {
|
||||
if trimmed := strings.TrimSpace(raw); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(fallback) != "" {
|
||||
return fallback
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (a *GitLabAuthenticator) requireInput(opts *LoginOptions, key, prompt string) (string, error) {
|
||||
if value := a.resolveString(opts, key, ""); value != "" {
|
||||
return value, nil
|
||||
}
|
||||
if opts != nil && opts.Prompt != nil {
|
||||
value, err := opts.Prompt(prompt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return trimmed, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("gitlab auth: missing required %s", key)
|
||||
}
|
||||
|
||||
func (a *GitLabAuthenticator) optionalInput(opts *LoginOptions, key, prompt string) (string, error) {
|
||||
if value := a.resolveString(opts, key, ""); value != "" {
|
||||
return value, nil
|
||||
}
|
||||
if opts != nil && opts.Prompt != nil {
|
||||
value, err := opts.Prompt(prompt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(value), nil
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func primaryGitLabEmail(user *gitlabauth.User) string {
|
||||
if user == nil {
|
||||
return ""
|
||||
}
|
||||
if value := strings.TrimSpace(user.Email); value != "" {
|
||||
return value
|
||||
}
|
||||
return strings.TrimSpace(user.PublicEmail)
|
||||
}
|
||||
|
||||
func gitLabAccountIdentifier(user *gitlabauth.User) string {
|
||||
if user == nil {
|
||||
return "user"
|
||||
}
|
||||
for _, value := range []string{user.Username, primaryGitLabEmail(user), user.Name} {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
return "user"
|
||||
}
|
||||
|
||||
func sanitizeGitLabFileName(value string) string {
|
||||
value = strings.TrimSpace(strings.ToLower(value))
|
||||
if value == "" {
|
||||
return "user"
|
||||
}
|
||||
var builder strings.Builder
|
||||
lastDash := false
|
||||
for _, r := range value {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
builder.WriteRune(r)
|
||||
lastDash = false
|
||||
case r >= '0' && r <= '9':
|
||||
builder.WriteRune(r)
|
||||
lastDash = false
|
||||
case r == '-' || r == '_' || r == '.':
|
||||
builder.WriteRune(r)
|
||||
lastDash = false
|
||||
default:
|
||||
if !lastDash {
|
||||
builder.WriteRune('-')
|
||||
lastDash = true
|
||||
}
|
||||
}
|
||||
}
|
||||
result := strings.Trim(builder.String(), "-")
|
||||
if result == "" {
|
||||
return "user"
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func maskGitLabToken(token string) string {
|
||||
trimmed := strings.TrimSpace(token)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if len(trimmed) <= 8 {
|
||||
return trimmed
|
||||
}
|
||||
return trimmed[:4] + "..." + trimmed[len(trimmed)-4:]
|
||||
}
|
||||
|
||||
func gitLabEnvKeys(key string) []string {
|
||||
switch strings.TrimSpace(key) {
|
||||
case gitLabBaseURLMetadataKey:
|
||||
return []string{"GITLAB_BASE_URL"}
|
||||
case gitLabOAuthClientIDMetadataKey:
|
||||
return []string{"GITLAB_OAUTH_CLIENT_ID"}
|
||||
case gitLabOAuthClientSecretMetadataKey:
|
||||
return []string{"GITLAB_OAUTH_CLIENT_SECRET"}
|
||||
case gitLabPersonalAccessTokenMetadataKey:
|
||||
return []string{"GITLAB_PERSONAL_ACCESS_TOKEN"}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
66
sdk/auth/gitlab_test.go
Normal file
66
sdk/auth/gitlab_test.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
func TestGitLabAuthenticatorLoginPAT(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/v4/user":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": 42,
|
||||
"username": "duo-user",
|
||||
"email": "duo@example.com",
|
||||
"name": "Duo User",
|
||||
})
|
||||
case "/api/v4/personal_access_tokens/self":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": 5,
|
||||
"name": "CLIProxyAPI",
|
||||
"scopes": []string{"api"},
|
||||
})
|
||||
case "/api/v4/code_suggestions/direct_access":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"base_url": "https://cloud.gitlab.example.com",
|
||||
"token": "gateway-token",
|
||||
"expires_at": 1710003600,
|
||||
"headers": map[string]string{"X-Gitlab-Realm": "saas"},
|
||||
"model_details": map[string]any{
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
})
|
||||
default:
|
||||
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
authenticator := NewGitLabAuthenticator()
|
||||
record, err := authenticator.Login(context.Background(), &config.Config{}, &LoginOptions{
|
||||
Metadata: map[string]string{
|
||||
"login_mode": "pat",
|
||||
"base_url": srv.URL,
|
||||
"personal_access_token": "glpat-test-token",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Login() error = %v", err)
|
||||
}
|
||||
if record.Provider != "gitlab" {
|
||||
t.Fatalf("expected gitlab provider, got %q", record.Provider)
|
||||
}
|
||||
if got := record.Metadata["model_name"]; got != "claude-sonnet-4-5" {
|
||||
t.Fatalf("expected discovered model, got %#v", got)
|
||||
}
|
||||
if got := record.Metadata["auth_kind"]; got != "personal_access_token" {
|
||||
t.Fatalf("expected personal_access_token auth kind, got %#v", got)
|
||||
}
|
||||
}
|
||||
121
sdk/auth/kilo.go
Normal file
121
sdk/auth/kilo.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kilo"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
// KiloAuthenticator implements the login flow for Kilo AI accounts.
|
||||
type KiloAuthenticator struct{}
|
||||
|
||||
// NewKiloAuthenticator constructs a Kilo authenticator.
|
||||
func NewKiloAuthenticator() *KiloAuthenticator {
|
||||
return &KiloAuthenticator{}
|
||||
}
|
||||
|
||||
func (a *KiloAuthenticator) Provider() string {
|
||||
return "kilo"
|
||||
}
|
||||
|
||||
func (a *KiloAuthenticator) RefreshLead() *time.Duration {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Login manages the device flow authentication for Kilo AI.
|
||||
func (a *KiloAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("cliproxy auth: configuration is required")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if opts == nil {
|
||||
opts = &LoginOptions{}
|
||||
}
|
||||
|
||||
kilocodeAuth := kilo.NewKiloAuth()
|
||||
|
||||
fmt.Println("Initiating Kilo device authentication...")
|
||||
resp, err := kilocodeAuth.InitiateDeviceFlow(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initiate device flow: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Please visit: %s\n", resp.VerificationURL)
|
||||
fmt.Printf("And enter code: %s\n", resp.Code)
|
||||
|
||||
fmt.Println("Waiting for authorization...")
|
||||
status, err := kilocodeAuth.PollForToken(ctx, resp.Code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authentication failed: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Authentication successful for %s\n", status.UserEmail)
|
||||
|
||||
profile, err := kilocodeAuth.GetProfile(ctx, status.Token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch profile: %w", err)
|
||||
}
|
||||
|
||||
var orgID string
|
||||
if len(profile.Orgs) > 1 {
|
||||
fmt.Println("Multiple organizations found. Please select one:")
|
||||
for i, org := range profile.Orgs {
|
||||
fmt.Printf("[%d] %s (%s)\n", i+1, org.Name, org.ID)
|
||||
}
|
||||
|
||||
if opts.Prompt != nil {
|
||||
input, err := opts.Prompt("Enter the number of the organization: ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var choice int
|
||||
_, err = fmt.Sscan(input, &choice)
|
||||
if err == nil && choice > 0 && choice <= len(profile.Orgs) {
|
||||
orgID = profile.Orgs[choice-1].ID
|
||||
} else {
|
||||
orgID = profile.Orgs[0].ID
|
||||
fmt.Printf("Invalid choice, defaulting to %s\n", profile.Orgs[0].Name)
|
||||
}
|
||||
} else {
|
||||
orgID = profile.Orgs[0].ID
|
||||
fmt.Printf("Non-interactive mode, defaulting to organization: %s\n", profile.Orgs[0].Name)
|
||||
}
|
||||
} else if len(profile.Orgs) == 1 {
|
||||
orgID = profile.Orgs[0].ID
|
||||
}
|
||||
|
||||
defaults, err := kilocodeAuth.GetDefaults(ctx, status.Token, orgID)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: failed to fetch defaults: %v\n", err)
|
||||
defaults = &kilo.Defaults{}
|
||||
}
|
||||
|
||||
ts := &kilo.KiloTokenStorage{
|
||||
Token: status.Token,
|
||||
OrganizationID: orgID,
|
||||
Model: defaults.Model,
|
||||
Email: status.UserEmail,
|
||||
Type: "kilo",
|
||||
}
|
||||
|
||||
fileName := kilo.CredentialFileName(status.UserEmail)
|
||||
metadata := map[string]any{
|
||||
"email": status.UserEmail,
|
||||
"organization_id": orgID,
|
||||
"model": defaults.Model,
|
||||
}
|
||||
|
||||
return &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: a.Provider(),
|
||||
FileName: fileName,
|
||||
Storage: ts,
|
||||
Metadata: metadata,
|
||||
}, nil
|
||||
}
|
||||
458
sdk/auth/kiro.go
Normal file
458
sdk/auth/kiro.go
Normal file
@@ -0,0 +1,458 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
// extractKiroIdentifier extracts a meaningful identifier for file naming.
|
||||
// Returns account name if provided, otherwise profile ARN ID, then client ID.
|
||||
// All extracted values are sanitized to prevent path injection attacks.
|
||||
func extractKiroIdentifier(accountName, profileArn, clientID string) string {
|
||||
// Priority 1: Use account name if provided
|
||||
if accountName != "" {
|
||||
return kiroauth.SanitizeEmailForFilename(accountName)
|
||||
}
|
||||
|
||||
// Priority 2: Use profile ARN ID part (sanitized to prevent path injection)
|
||||
if profileArn != "" {
|
||||
parts := strings.Split(profileArn, "/")
|
||||
if len(parts) >= 2 {
|
||||
// Sanitize the ARN component to prevent path traversal
|
||||
return kiroauth.SanitizeEmailForFilename(parts[len(parts)-1])
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 3: Use client ID (for IDC auth without email/profileArn)
|
||||
if clientID != "" {
|
||||
return kiroauth.SanitizeEmailForFilename(clientID)
|
||||
}
|
||||
|
||||
// Fallback: timestamp
|
||||
return fmt.Sprintf("%d", time.Now().UnixNano()%100000)
|
||||
}
|
||||
|
||||
// KiroAuthenticator implements OAuth authentication for Kiro with Google login.
|
||||
type KiroAuthenticator struct{}
|
||||
|
||||
// NewKiroAuthenticator constructs a Kiro authenticator.
|
||||
func NewKiroAuthenticator() *KiroAuthenticator {
|
||||
return &KiroAuthenticator{}
|
||||
}
|
||||
|
||||
// Provider returns the provider key for the authenticator.
|
||||
func (a *KiroAuthenticator) Provider() string {
|
||||
return "kiro"
|
||||
}
|
||||
|
||||
// RefreshLead indicates how soon before expiry a refresh should be attempted.
|
||||
// Set to 20 minutes for proactive refresh before token expiry.
|
||||
func (a *KiroAuthenticator) RefreshLead() *time.Duration {
|
||||
d := 20 * time.Minute
|
||||
return &d
|
||||
}
|
||||
|
||||
// createAuthRecord creates an auth record from token data.
|
||||
func (a *KiroAuthenticator) createAuthRecord(tokenData *kiroauth.KiroTokenData, source string) (*coreauth.Auth, error) {
|
||||
// Parse expires_at
|
||||
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||
if err != nil {
|
||||
expiresAt = time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
// Determine label and identifier based on auth method
|
||||
// Generate sequence number for uniqueness
|
||||
seq := time.Now().UnixNano() % 100000
|
||||
|
||||
var label, idPart string
|
||||
if tokenData.AuthMethod == "idc" {
|
||||
label = "kiro-idc"
|
||||
// Priority: email > startUrl identifier > sequence only
|
||||
// Email is unique, so no sequence needed when email is available
|
||||
if tokenData.Email != "" {
|
||||
idPart = kiroauth.SanitizeEmailForFilename(tokenData.Email)
|
||||
} else if tokenData.StartURL != "" {
|
||||
identifier := kiroauth.ExtractIDCIdentifier(tokenData.StartURL)
|
||||
if identifier != "" {
|
||||
idPart = fmt.Sprintf("%s-%05d", identifier, seq)
|
||||
} else {
|
||||
idPart = fmt.Sprintf("%05d", seq)
|
||||
}
|
||||
} else {
|
||||
idPart = fmt.Sprintf("%05d", seq)
|
||||
}
|
||||
} else {
|
||||
label = fmt.Sprintf("kiro-%s", source)
|
||||
idPart = extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn, tokenData.ClientID)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
fileName := fmt.Sprintf("%s-%s.json", label, idPart)
|
||||
|
||||
metadata := map[string]any{
|
||||
"type": "kiro",
|
||||
"access_token": tokenData.AccessToken,
|
||||
"refresh_token": tokenData.RefreshToken,
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"expires_at": tokenData.ExpiresAt,
|
||||
"auth_method": tokenData.AuthMethod,
|
||||
"provider": tokenData.Provider,
|
||||
"client_id": tokenData.ClientID,
|
||||
"client_secret": tokenData.ClientSecret,
|
||||
"email": tokenData.Email,
|
||||
}
|
||||
|
||||
// Add IDC-specific fields if present
|
||||
if tokenData.StartURL != "" {
|
||||
metadata["start_url"] = tokenData.StartURL
|
||||
}
|
||||
if tokenData.Region != "" {
|
||||
metadata["region"] = tokenData.Region
|
||||
}
|
||||
|
||||
attributes := map[string]string{
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"source": source,
|
||||
"email": tokenData.Email,
|
||||
}
|
||||
|
||||
// Add IDC-specific attributes if present
|
||||
if tokenData.AuthMethod == "idc" {
|
||||
attributes["source"] = "aws-idc"
|
||||
if tokenData.StartURL != "" {
|
||||
attributes["start_url"] = tokenData.StartURL
|
||||
}
|
||||
if tokenData.Region != "" {
|
||||
attributes["region"] = tokenData.Region
|
||||
}
|
||||
}
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kiro",
|
||||
FileName: fileName,
|
||||
Label: label,
|
||||
Status: coreauth.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Metadata: metadata,
|
||||
Attributes: attributes,
|
||||
// NextRefreshAfter: 20 minutes before expiry
|
||||
NextRefreshAfter: expiresAt.Add(-20 * time.Minute),
|
||||
}
|
||||
|
||||
if tokenData.Email != "" {
|
||||
fmt.Printf("\n✓ Kiro authentication completed successfully! (Account: %s)\n", tokenData.Email)
|
||||
} else {
|
||||
fmt.Println("\n✓ Kiro authentication completed successfully!")
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// Login performs OAuth login for Kiro with AWS (Builder ID or IDC).
|
||||
// This shows a method selection prompt and handles both flows.
|
||||
func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("kiro auth: configuration is required")
|
||||
}
|
||||
|
||||
// Extract IDC options from metadata if present
|
||||
var idcOpts *kiroauth.IDCLoginOptions
|
||||
if opts != nil && opts.Metadata != nil {
|
||||
if startURL := opts.Metadata["start-url"]; startURL != "" {
|
||||
idcOpts = &kiroauth.IDCLoginOptions{
|
||||
StartURL: startURL,
|
||||
Region: opts.Metadata["region"],
|
||||
UseDeviceCode: opts.Metadata["flow"] == "device",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Use the unified method selection flow (Builder ID or IDC)
|
||||
ssoClient := kiroauth.NewSSOOIDCClient(cfg)
|
||||
tokenData, err := ssoClient.LoginWithMethodSelection(ctx, idcOpts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login failed: %w", err)
|
||||
}
|
||||
|
||||
return a.createAuthRecord(tokenData, "aws")
|
||||
}
|
||||
|
||||
// LoginWithAuthCode performs OAuth login for Kiro with AWS Builder ID using authorization code flow.
|
||||
// This provides a better UX than device code flow as it uses automatic browser callback.
|
||||
func (a *KiroAuthenticator) LoginWithAuthCode(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("kiro auth: configuration is required")
|
||||
}
|
||||
|
||||
oauth := kiroauth.NewKiroOAuth(cfg)
|
||||
|
||||
// Use AWS Builder ID authorization code flow
|
||||
tokenData, err := oauth.LoginWithBuilderIDAuthCode(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login failed: %w", err)
|
||||
}
|
||||
|
||||
// Parse expires_at
|
||||
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||
if err != nil {
|
||||
expiresAt = time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
// Extract identifier for file naming
|
||||
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn, tokenData.ClientID)
|
||||
|
||||
now := time.Now()
|
||||
fileName := fmt.Sprintf("kiro-aws-%s.json", idPart)
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kiro",
|
||||
FileName: fileName,
|
||||
Label: "kiro-aws",
|
||||
Status: coreauth.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Metadata: map[string]any{
|
||||
"type": "kiro",
|
||||
"access_token": tokenData.AccessToken,
|
||||
"refresh_token": tokenData.RefreshToken,
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"expires_at": tokenData.ExpiresAt,
|
||||
"auth_method": tokenData.AuthMethod,
|
||||
"provider": tokenData.Provider,
|
||||
"client_id": tokenData.ClientID,
|
||||
"client_secret": tokenData.ClientSecret,
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
Attributes: map[string]string{
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"source": "aws-builder-id-authcode",
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
// NextRefreshAfter: 20 minutes before expiry
|
||||
NextRefreshAfter: expiresAt.Add(-20 * time.Minute),
|
||||
}
|
||||
|
||||
if tokenData.Email != "" {
|
||||
fmt.Printf("\n✓ Kiro authentication completed successfully! (Account: %s)\n", tokenData.Email)
|
||||
} else {
|
||||
fmt.Println("\n✓ Kiro authentication completed successfully!")
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// LoginWithGoogle performs OAuth login for Kiro with Google.
|
||||
// NOTE: Google login is not available for third-party applications due to AWS Cognito restrictions.
|
||||
// Please use AWS Builder ID or import your token from Kiro IDE.
|
||||
func (a *KiroAuthenticator) LoginWithGoogle(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
return nil, fmt.Errorf("Google login is not available for third-party applications due to AWS Cognito restrictions.\n\nAlternatives:\n 1. Use AWS Builder ID: cliproxy kiro --builder-id\n 2. Import token from Kiro IDE: cliproxy kiro --import\n\nTo get a token from Kiro IDE:\n 1. Open Kiro IDE and login with Google\n 2. Find: ~/.kiro/kiro-auth-token.json\n 3. Run: cliproxy kiro --import")
|
||||
}
|
||||
|
||||
// LoginWithGitHub performs OAuth login for Kiro with GitHub.
|
||||
// NOTE: GitHub login is not available for third-party applications due to AWS Cognito restrictions.
|
||||
// Please use AWS Builder ID or import your token from Kiro IDE.
|
||||
func (a *KiroAuthenticator) LoginWithGitHub(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
return nil, fmt.Errorf("GitHub login is not available for third-party applications due to AWS Cognito restrictions.\n\nAlternatives:\n 1. Use AWS Builder ID: cliproxy kiro --builder-id\n 2. Import token from Kiro IDE: cliproxy kiro --import\n\nTo get a token from Kiro IDE:\n 1. Open Kiro IDE and login with GitHub\n 2. Find: ~/.kiro/kiro-auth-token.json\n 3. Run: cliproxy kiro --import")
|
||||
}
|
||||
|
||||
// ImportFromKiroIDE imports token from Kiro IDE's token file.
|
||||
func (a *KiroAuthenticator) ImportFromKiroIDE(ctx context.Context, cfg *config.Config) (*coreauth.Auth, error) {
|
||||
tokenData, err := kiroauth.LoadKiroIDEToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load Kiro IDE token: %w", err)
|
||||
}
|
||||
|
||||
// Parse expires_at
|
||||
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||
if err != nil {
|
||||
expiresAt = time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
// Extract email from JWT if not already set (for imported tokens)
|
||||
if tokenData.Email == "" {
|
||||
tokenData.Email = kiroauth.ExtractEmailFromJWT(tokenData.AccessToken)
|
||||
}
|
||||
|
||||
// Extract identifier for file naming
|
||||
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn, tokenData.ClientID)
|
||||
// Sanitize provider to prevent path traversal (defense-in-depth)
|
||||
provider := kiroauth.SanitizeEmailForFilename(strings.ToLower(strings.TrimSpace(tokenData.Provider)))
|
||||
if provider == "" {
|
||||
provider = "imported" // Fallback for legacy tokens without provider
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
fileName := fmt.Sprintf("kiro-%s-%s.json", provider, idPart)
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kiro",
|
||||
FileName: fileName,
|
||||
Label: fmt.Sprintf("kiro-%s", provider),
|
||||
Status: coreauth.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Metadata: map[string]any{
|
||||
"type": "kiro",
|
||||
"access_token": tokenData.AccessToken,
|
||||
"refresh_token": tokenData.RefreshToken,
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"expires_at": tokenData.ExpiresAt,
|
||||
"auth_method": tokenData.AuthMethod,
|
||||
"provider": tokenData.Provider,
|
||||
"client_id": tokenData.ClientID,
|
||||
"client_secret": tokenData.ClientSecret,
|
||||
"client_id_hash": tokenData.ClientIDHash,
|
||||
"email": tokenData.Email,
|
||||
"region": tokenData.Region,
|
||||
"start_url": tokenData.StartURL,
|
||||
},
|
||||
Attributes: map[string]string{
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"source": "kiro-ide-import",
|
||||
"email": tokenData.Email,
|
||||
"region": tokenData.Region,
|
||||
},
|
||||
// NextRefreshAfter: 20 minutes before expiry
|
||||
NextRefreshAfter: expiresAt.Add(-20 * time.Minute),
|
||||
}
|
||||
|
||||
// Display the email if extracted
|
||||
if tokenData.Email != "" {
|
||||
fmt.Printf("\n✓ Imported Kiro token from IDE (Provider: %s, Account: %s)\n", tokenData.Provider, tokenData.Email)
|
||||
} else {
|
||||
fmt.Printf("\n✓ Imported Kiro token from IDE (Provider: %s)\n", tokenData.Provider)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// Refresh refreshes an expired Kiro token using AWS SSO OIDC.
|
||||
func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return nil, fmt.Errorf("invalid auth record")
|
||||
}
|
||||
|
||||
refreshToken, ok := auth.Metadata["refresh_token"].(string)
|
||||
if !ok || refreshToken == "" {
|
||||
return nil, fmt.Errorf("refresh token not found")
|
||||
}
|
||||
|
||||
clientID, _ := auth.Metadata["client_id"].(string)
|
||||
clientSecret, _ := auth.Metadata["client_secret"].(string)
|
||||
clientIDHash, _ := auth.Metadata["client_id_hash"].(string)
|
||||
authMethod, _ := auth.Metadata["auth_method"].(string)
|
||||
startURL, _ := auth.Metadata["start_url"].(string)
|
||||
region, _ := auth.Metadata["region"].(string)
|
||||
|
||||
// For Enterprise Kiro IDE (IDC auth), try to load clientId/clientSecret from device registration
|
||||
// if they are missing from metadata. This handles the case where token was imported without
|
||||
// clientId/clientSecret but has clientIdHash.
|
||||
if (clientID == "" || clientSecret == "") && clientIDHash != "" {
|
||||
if loadedClientID, loadedClientSecret, err := loadDeviceRegistrationCredentials(clientIDHash); err == nil {
|
||||
clientID = loadedClientID
|
||||
clientSecret = loadedClientSecret
|
||||
}
|
||||
}
|
||||
|
||||
var tokenData *kiroauth.KiroTokenData
|
||||
var err error
|
||||
|
||||
ssoClient := kiroauth.NewSSOOIDCClient(cfg)
|
||||
|
||||
// Use SSO OIDC refresh for AWS Builder ID or IDC, otherwise use Kiro's OAuth refresh endpoint
|
||||
switch {
|
||||
case clientID != "" && clientSecret != "" && authMethod == "idc" && region != "":
|
||||
// IDC refresh with region-specific endpoint
|
||||
tokenData, err = ssoClient.RefreshTokenWithRegion(ctx, clientID, clientSecret, refreshToken, region, startURL)
|
||||
case clientID != "" && clientSecret != "" && (authMethod == "builder-id" || authMethod == "idc"):
|
||||
// Builder ID or IDC refresh with default endpoint (us-east-1)
|
||||
tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken)
|
||||
default:
|
||||
// Fallback to Kiro's refresh endpoint (for social auth: Google/GitHub)
|
||||
oauth := kiroauth.NewKiroOAuth(cfg)
|
||||
tokenData, err = oauth.RefreshToken(ctx, refreshToken)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token refresh failed: %w", err)
|
||||
}
|
||||
|
||||
// Parse expires_at
|
||||
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||
if err != nil {
|
||||
expiresAt = time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
// Clone auth to avoid mutating the input parameter
|
||||
updated := auth.Clone()
|
||||
now := time.Now()
|
||||
updated.UpdatedAt = now
|
||||
updated.LastRefreshedAt = now
|
||||
updated.Metadata["access_token"] = tokenData.AccessToken
|
||||
updated.Metadata["refresh_token"] = tokenData.RefreshToken
|
||||
updated.Metadata["expires_at"] = tokenData.ExpiresAt
|
||||
updated.Metadata["last_refresh"] = now.Format(time.RFC3339) // For double-check optimization
|
||||
// Store clientId/clientSecret if they were loaded from device registration
|
||||
if clientID != "" && updated.Metadata["client_id"] == nil {
|
||||
updated.Metadata["client_id"] = clientID
|
||||
}
|
||||
if clientSecret != "" && updated.Metadata["client_secret"] == nil {
|
||||
updated.Metadata["client_secret"] = clientSecret
|
||||
}
|
||||
// NextRefreshAfter: 20 minutes before expiry
|
||||
updated.NextRefreshAfter = expiresAt.Add(-20 * time.Minute)
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// loadDeviceRegistrationCredentials loads clientId and clientSecret from device registration file.
|
||||
// This is used when refreshing tokens that were imported without clientId/clientSecret.
|
||||
func loadDeviceRegistrationCredentials(clientIDHash string) (clientID, clientSecret string, err error) {
|
||||
if clientIDHash == "" {
|
||||
return "", "", fmt.Errorf("clientIdHash is empty")
|
||||
}
|
||||
|
||||
// Sanitize clientIdHash to prevent path traversal
|
||||
if strings.Contains(clientIDHash, "/") || strings.Contains(clientIDHash, "\\") || strings.Contains(clientIDHash, "..") {
|
||||
return "", "", fmt.Errorf("invalid clientIdHash: contains path separator")
|
||||
}
|
||||
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to get home directory: %w", err)
|
||||
}
|
||||
|
||||
deviceRegPath := filepath.Join(homeDir, ".aws", "sso", "cache", clientIDHash+".json")
|
||||
data, err := os.ReadFile(deviceRegPath)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to read device registration file: %w", err)
|
||||
}
|
||||
|
||||
var deviceReg struct {
|
||||
ClientID string `json:"clientId"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &deviceReg); err != nil {
|
||||
return "", "", fmt.Errorf("failed to parse device registration: %w", err)
|
||||
}
|
||||
|
||||
if deviceReg.ClientID == "" || deviceReg.ClientSecret == "" {
|
||||
return "", "", fmt.Errorf("device registration missing clientId or clientSecret")
|
||||
}
|
||||
|
||||
return deviceReg.ClientID, deviceReg.ClientSecret, nil
|
||||
}
|
||||
@@ -74,3 +74,16 @@ func (m *Manager) Login(ctx context.Context, provider string, cfg *config.Config
|
||||
}
|
||||
return record, savedPath, nil
|
||||
}
|
||||
|
||||
// SaveAuth persists an auth record directly without going through the login flow.
|
||||
func (m *Manager) SaveAuth(record *coreauth.Auth, cfg *config.Config) (string, error) {
|
||||
if m.store == nil {
|
||||
return "", fmt.Errorf("no store configured")
|
||||
}
|
||||
if cfg != nil {
|
||||
if dirSetter, ok := m.store.(interface{ SetBaseDir(string) }); ok {
|
||||
dirSetter.SetBaseDir(cfg.AuthDir)
|
||||
}
|
||||
}
|
||||
return m.store.Save(context.Background(), record)
|
||||
}
|
||||
|
||||
@@ -15,6 +15,11 @@ func init() {
|
||||
registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() })
|
||||
registerRefreshLead("antigravity", func() Authenticator { return NewAntigravityAuthenticator() })
|
||||
registerRefreshLead("kimi", func() Authenticator { return NewKimiAuthenticator() })
|
||||
registerRefreshLead("kiro", func() Authenticator { return NewKiroAuthenticator() })
|
||||
registerRefreshLead("github-copilot", func() Authenticator { return NewGitHubCopilotAuthenticator() })
|
||||
registerRefreshLead("gitlab", func() Authenticator { return NewGitLabAuthenticator() })
|
||||
registerRefreshLead("codebuddy", func() Authenticator { return NewCodeBuddyAuthenticator() })
|
||||
registerRefreshLead("cursor", func() Authenticator { return NewCursorAuthenticator() })
|
||||
}
|
||||
|
||||
func registerRefreshLead(provider string, factory func() Authenticator) {
|
||||
|
||||
@@ -63,7 +63,7 @@ const (
|
||||
refreshCheckInterval = 5 * time.Second
|
||||
refreshMaxConcurrency = 16
|
||||
refreshPendingBackoff = time.Minute
|
||||
refreshFailureBackoff = 5 * time.Minute
|
||||
refreshFailureBackoff = 1 * time.Minute
|
||||
quotaBackoffBase = time.Second
|
||||
quotaBackoffMax = 30 * time.Minute
|
||||
)
|
||||
@@ -3168,7 +3168,9 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) {
|
||||
updated.Runtime = auth.Runtime
|
||||
}
|
||||
updated.LastRefreshedAt = now
|
||||
updated.NextRefreshAfter = time.Time{}
|
||||
// Preserve NextRefreshAfter set by the Authenticator
|
||||
// If the Authenticator set a reasonable refresh time, it should not be overwritten
|
||||
// If the Authenticator did not set it (zero value), shouldRefresh will use default logic
|
||||
updated.LastError = nil
|
||||
updated.UpdatedAt = now
|
||||
_, _ = m.Update(ctx, updated)
|
||||
|
||||
@@ -265,7 +265,7 @@ func modelAliasChannel(auth *Auth) string {
|
||||
// and auth kind. Returns empty string if the provider/authKind combination doesn't support
|
||||
// OAuth model alias (e.g., API key authentication).
|
||||
//
|
||||
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kimi.
|
||||
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot, kimi.
|
||||
func OAuthModelAliasChannel(provider, authKind string) string {
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
authKind = strings.ToLower(strings.TrimSpace(authKind))
|
||||
@@ -289,7 +289,7 @@ func OAuthModelAliasChannel(provider, authKind string) string {
|
||||
return ""
|
||||
}
|
||||
return "codex"
|
||||
case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow", "kimi":
|
||||
case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow", "kiro", "github-copilot", "kimi":
|
||||
return provider
|
||||
default:
|
||||
return ""
|
||||
|
||||
@@ -43,6 +43,15 @@ func TestResolveOAuthUpstreamModel_SuffixPreservation(t *testing.T) {
|
||||
input: "gemini-2.5-pro",
|
||||
want: "gemini-2.5-pro-exp-03-25",
|
||||
},
|
||||
{
|
||||
name: "kiro alias resolves",
|
||||
aliases: map[string][]internalconfig.OAuthModelAlias{
|
||||
"kiro": {{Name: "kiro-claude-sonnet-4-5", Alias: "sonnet"}},
|
||||
},
|
||||
channel: "kiro",
|
||||
input: "sonnet",
|
||||
want: "kiro-claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "config suffix takes priority",
|
||||
aliases: map[string][]internalconfig.OAuthModelAlias{
|
||||
@@ -70,6 +79,24 @@ func TestResolveOAuthUpstreamModel_SuffixPreservation(t *testing.T) {
|
||||
input: "gemini-2.5-pro(none)",
|
||||
want: "gemini-2.5-pro-exp-03-25(none)",
|
||||
},
|
||||
{
|
||||
name: "github-copilot suffix preserved",
|
||||
aliases: map[string][]internalconfig.OAuthModelAlias{
|
||||
"github-copilot": {{Name: "claude-opus-4.6", Alias: "opus"}},
|
||||
},
|
||||
channel: "github-copilot",
|
||||
input: "opus(medium)",
|
||||
want: "claude-opus-4.6(medium)",
|
||||
},
|
||||
{
|
||||
name: "github-copilot no suffix",
|
||||
aliases: map[string][]internalconfig.OAuthModelAlias{
|
||||
"github-copilot": {{Name: "claude-opus-4.6", Alias: "opus"}},
|
||||
},
|
||||
channel: "github-copilot",
|
||||
input: "opus",
|
||||
want: "claude-opus-4.6",
|
||||
},
|
||||
{
|
||||
name: "kimi suffix preserved",
|
||||
aliases: map[string][]internalconfig.OAuthModelAlias{
|
||||
@@ -163,6 +190,10 @@ func createAuthForChannel(channel string) *Auth {
|
||||
return &Auth{Provider: "iflow"}
|
||||
case "kimi":
|
||||
return &Auth{Provider: "kimi"}
|
||||
case "kiro":
|
||||
return &Auth{Provider: "kiro"}
|
||||
case "github-copilot":
|
||||
return &Auth{Provider: "github-copilot"}
|
||||
default:
|
||||
return &Auth{Provider: channel}
|
||||
}
|
||||
@@ -176,6 +207,22 @@ func TestOAuthModelAliasChannel_Kimi(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthModelAliasChannel_GitHubCopilot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if got := OAuthModelAliasChannel("github-copilot", ""); got != "github-copilot" {
|
||||
t.Fatalf("OAuthModelAliasChannel() = %q, want %q", got, "github-copilot")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthModelAliasChannel_Kiro(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if got := OAuthModelAliasChannel("kiro", ""); got != "kiro" {
|
||||
t.Fatalf("OAuthModelAliasChannel() = %q, want %q", got, "kiro")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyOAuthModelAlias_SuffixPreservation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -592,6 +592,11 @@ func (m *scheduledAuthMeta) supportsModel(modelKey string) bool {
|
||||
if modelKey == "" {
|
||||
return true
|
||||
}
|
||||
// Cursor acts as a universal proxy supporting multiple model families.
|
||||
// Allow any model to be routed to cursor auth.
|
||||
if m.providerKey == "cursor" {
|
||||
return true
|
||||
}
|
||||
if len(m.supportedModelSet) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -418,8 +418,41 @@ func (a *Auth) AccountInfo() (string, string) {
|
||||
}
|
||||
}
|
||||
|
||||
// For GitHub provider (including github-copilot), return username
|
||||
if strings.HasPrefix(strings.ToLower(a.Provider), "github") {
|
||||
if a.Metadata != nil {
|
||||
if username, ok := a.Metadata["username"].(string); ok {
|
||||
username = strings.TrimSpace(username)
|
||||
if username != "" {
|
||||
return "oauth", username
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check metadata for email first (OAuth-style auth)
|
||||
if a.Metadata != nil {
|
||||
if method, ok := a.Metadata["auth_method"].(string); ok {
|
||||
switch strings.ToLower(strings.TrimSpace(method)) {
|
||||
case "oauth":
|
||||
for _, key := range []string{"email", "username", "name"} {
|
||||
if value, okValue := a.Metadata[key].(string); okValue {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return "oauth", trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
case "pat", "personal_access_token":
|
||||
for _, key := range []string{"username", "email", "name", "token_preview"} {
|
||||
if value, okValue := a.Metadata[key].(string); okValue {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return "personal_access_token", trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
return "personal_access_token", ""
|
||||
}
|
||||
}
|
||||
if v, ok := a.Metadata["email"].(string); ok {
|
||||
email := strings.TrimSpace(v)
|
||||
if email != "" {
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
|
||||
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
@@ -100,6 +101,16 @@ func (s *Service) RegisterUsagePlugin(plugin usage.Plugin) {
|
||||
usage.RegisterPlugin(plugin)
|
||||
}
|
||||
|
||||
// GetWatcher returns the underlying WatcherWrapper instance.
|
||||
// This allows external components (e.g., RefreshManager) to interact with the watcher.
|
||||
// Returns nil if the service or watcher is not initialized.
|
||||
func (s *Service) GetWatcher() *WatcherWrapper {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return s.watcher
|
||||
}
|
||||
|
||||
// newDefaultAuthManager creates a default authentication manager with all supported providers.
|
||||
func newDefaultAuthManager() *sdkAuth.Manager {
|
||||
return sdkAuth.NewManager(
|
||||
@@ -108,6 +119,7 @@ func newDefaultAuthManager() *sdkAuth.Manager {
|
||||
sdkAuth.NewCodexAuthenticator(),
|
||||
sdkAuth.NewClaudeAuthenticator(),
|
||||
sdkAuth.NewQwenAuthenticator(),
|
||||
sdkAuth.NewGitLabAuthenticator(),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -429,6 +441,18 @@ func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace
|
||||
s.coreManager.RegisterExecutor(executor.NewIFlowExecutor(s.cfg))
|
||||
case "kimi":
|
||||
s.coreManager.RegisterExecutor(executor.NewKimiExecutor(s.cfg))
|
||||
case "kiro":
|
||||
s.coreManager.RegisterExecutor(executor.NewKiroExecutor(s.cfg))
|
||||
case "kilo":
|
||||
s.coreManager.RegisterExecutor(executor.NewKiloExecutor(s.cfg))
|
||||
case "cursor":
|
||||
s.coreManager.RegisterExecutor(executor.NewCursorExecutor(s.cfg))
|
||||
case "github-copilot":
|
||||
s.coreManager.RegisterExecutor(executor.NewGitHubCopilotExecutor(s.cfg))
|
||||
case "codebuddy":
|
||||
s.coreManager.RegisterExecutor(executor.NewCodeBuddyExecutor(s.cfg))
|
||||
case "gitlab":
|
||||
s.coreManager.RegisterExecutor(executor.NewGitLabExecutor(s.cfg))
|
||||
default:
|
||||
providerKey := strings.ToLower(strings.TrimSpace(a.Provider))
|
||||
if providerKey == "" {
|
||||
@@ -678,6 +702,18 @@ func (s *Service) Run(ctx context.Context) error {
|
||||
}
|
||||
watcherWrapper.SetConfig(s.cfg)
|
||||
|
||||
// 方案 A: 连接 Kiro 后台刷新器回调到 Watcher
|
||||
// 当后台刷新器成功刷新 token 后,立即通知 Watcher 更新内存中的 Auth 对象
|
||||
// 这解决了后台刷新与内存 Auth 对象之间的时间差问题
|
||||
kiroauth.GetRefreshManager().SetOnTokenRefreshed(func(tokenID string, tokenData *kiroauth.KiroTokenData) {
|
||||
if tokenData == nil || watcherWrapper == nil {
|
||||
return
|
||||
}
|
||||
log.Debugf("kiro refresh callback: notifying watcher for token %s", tokenID)
|
||||
watcherWrapper.NotifyTokenRefreshed(tokenID, tokenData.AccessToken, tokenData.RefreshToken, tokenData.ExpiresAt)
|
||||
})
|
||||
log.Debug("kiro: connected background refresh callback to watcher")
|
||||
|
||||
watcherCtx, watcherCancel := context.WithCancel(context.Background())
|
||||
s.watcherCancel = watcherCancel
|
||||
if err = watcherWrapper.Start(watcherCtx); err != nil {
|
||||
@@ -912,6 +948,28 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
case "kimi":
|
||||
models = registry.GetKimiModels()
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "cursor":
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
models = executor.FetchCursorModels(ctx, a, s.cfg)
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "github-copilot":
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
models = executor.FetchGitHubCopilotModels(ctx, a, s.cfg)
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "kiro":
|
||||
models = s.fetchKiroModels(a)
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "kilo":
|
||||
models = executor.FetchKiloModels(context.Background(), a, s.cfg)
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "gitlab":
|
||||
models = executor.GitLabModelsFromAuth(a)
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "codebuddy":
|
||||
models = registry.GetCodeBuddyModels()
|
||||
models = applyExcludedModels(models, excluded)
|
||||
default:
|
||||
// Handle OpenAI-compatibility providers by name using config
|
||||
if s.cfg != nil {
|
||||
@@ -1527,3 +1585,216 @@ func applyOAuthModelAlias(cfg *config.Config, provider, authKind string, models
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// fetchKiroModels attempts to dynamically fetch Kiro models from the API.
|
||||
// If dynamic fetch fails, it falls back to static registry.GetKiroModels().
|
||||
func (s *Service) fetchKiroModels(a *coreauth.Auth) []*ModelInfo {
|
||||
if a == nil {
|
||||
log.Debug("kiro: auth is nil, using static models")
|
||||
return registry.GetKiroModels()
|
||||
}
|
||||
|
||||
// Extract token data from auth attributes
|
||||
tokenData := s.extractKiroTokenData(a)
|
||||
if tokenData == nil || tokenData.AccessToken == "" {
|
||||
log.Debug("kiro: no valid token data in auth, using static models")
|
||||
return registry.GetKiroModels()
|
||||
}
|
||||
|
||||
// Create KiroAuth instance
|
||||
kAuth := kiroauth.NewKiroAuth(s.cfg)
|
||||
if kAuth == nil {
|
||||
log.Warn("kiro: failed to create KiroAuth instance, using static models")
|
||||
return registry.GetKiroModels()
|
||||
}
|
||||
|
||||
// Use timeout context for API call
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Attempt to fetch dynamic models
|
||||
apiModels, err := kAuth.ListAvailableModels(ctx, tokenData)
|
||||
if err != nil {
|
||||
log.Warnf("kiro: failed to fetch dynamic models: %v, using static models", err)
|
||||
return registry.GetKiroModels()
|
||||
}
|
||||
|
||||
if len(apiModels) == 0 {
|
||||
log.Debug("kiro: API returned no models, using static models")
|
||||
return registry.GetKiroModels()
|
||||
}
|
||||
|
||||
// Convert API models to ModelInfo
|
||||
models := convertKiroAPIModels(apiModels)
|
||||
|
||||
// Generate agentic variants
|
||||
models = generateKiroAgenticVariants(models)
|
||||
|
||||
log.Infof("kiro: successfully fetched %d models from API (including agentic variants)", len(models))
|
||||
return models
|
||||
}
|
||||
|
||||
// extractKiroTokenData extracts KiroTokenData from auth attributes and metadata.
|
||||
// It supports both config-based tokens (stored in Attributes) and file-based tokens (stored in Metadata).
|
||||
func (s *Service) extractKiroTokenData(a *coreauth.Auth) *kiroauth.KiroTokenData {
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var accessToken, profileArn, refreshToken string
|
||||
|
||||
// Priority 1: Try to get from Attributes (config.yaml source)
|
||||
if a.Attributes != nil {
|
||||
accessToken = strings.TrimSpace(a.Attributes["access_token"])
|
||||
profileArn = strings.TrimSpace(a.Attributes["profile_arn"])
|
||||
refreshToken = strings.TrimSpace(a.Attributes["refresh_token"])
|
||||
}
|
||||
|
||||
// Priority 2: If not found in Attributes, try Metadata (JSON file source)
|
||||
if accessToken == "" && a.Metadata != nil {
|
||||
if at, ok := a.Metadata["access_token"].(string); ok {
|
||||
accessToken = strings.TrimSpace(at)
|
||||
}
|
||||
if pa, ok := a.Metadata["profile_arn"].(string); ok {
|
||||
profileArn = strings.TrimSpace(pa)
|
||||
}
|
||||
if rt, ok := a.Metadata["refresh_token"].(string); ok {
|
||||
refreshToken = strings.TrimSpace(rt)
|
||||
}
|
||||
}
|
||||
|
||||
// access_token is required
|
||||
if accessToken == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &kiroauth.KiroTokenData{
|
||||
AccessToken: accessToken,
|
||||
ProfileArn: profileArn,
|
||||
RefreshToken: refreshToken,
|
||||
}
|
||||
}
|
||||
|
||||
// convertKiroAPIModels converts Kiro API models to ModelInfo slice.
|
||||
func convertKiroAPIModels(apiModels []*kiroauth.KiroModel) []*ModelInfo {
|
||||
if len(apiModels) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
models := make([]*ModelInfo, 0, len(apiModels))
|
||||
|
||||
for _, m := range apiModels {
|
||||
if m == nil || m.ModelID == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create model ID with kiro- prefix
|
||||
modelID := "kiro-" + normalizeKiroModelID(m.ModelID)
|
||||
|
||||
info := &ModelInfo{
|
||||
ID: modelID,
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: formatKiroDisplayName(m.ModelName, m.RateMultiplier),
|
||||
Description: m.Description,
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
}
|
||||
|
||||
if m.MaxInputTokens > 0 {
|
||||
info.ContextLength = m.MaxInputTokens
|
||||
}
|
||||
|
||||
models = append(models, info)
|
||||
}
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
// normalizeKiroModelID normalizes a Kiro model ID by converting dots to dashes
|
||||
// and removing common prefixes.
|
||||
func normalizeKiroModelID(modelID string) string {
|
||||
// Remove common prefixes
|
||||
modelID = strings.TrimPrefix(modelID, "anthropic.")
|
||||
modelID = strings.TrimPrefix(modelID, "amazon.")
|
||||
|
||||
// Replace dots with dashes for consistency
|
||||
modelID = strings.ReplaceAll(modelID, ".", "-")
|
||||
|
||||
// Replace underscores with dashes
|
||||
modelID = strings.ReplaceAll(modelID, "_", "-")
|
||||
|
||||
return strings.ToLower(modelID)
|
||||
}
|
||||
|
||||
// formatKiroDisplayName formats the display name with rate multiplier info.
|
||||
func formatKiroDisplayName(modelName string, rateMultiplier float64) string {
|
||||
if modelName == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
displayName := "Kiro " + modelName
|
||||
if rateMultiplier > 0 && rateMultiplier != 1.0 {
|
||||
displayName += fmt.Sprintf(" (%.1fx credit)", rateMultiplier)
|
||||
}
|
||||
|
||||
return displayName
|
||||
}
|
||||
|
||||
// generateKiroAgenticVariants generates agentic variants for Kiro models.
|
||||
// Agentic variants have optimized system prompts for coding agents.
|
||||
func generateKiroAgenticVariants(models []*ModelInfo) []*ModelInfo {
|
||||
if len(models) == 0 {
|
||||
return models
|
||||
}
|
||||
|
||||
result := make([]*ModelInfo, 0, len(models)*2)
|
||||
result = append(result, models...)
|
||||
|
||||
for _, m := range models {
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if already an agentic variant
|
||||
if strings.HasSuffix(m.ID, "-agentic") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip auto models from agentic variant generation
|
||||
if strings.Contains(m.ID, "-auto") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create agentic variant
|
||||
agentic := &ModelInfo{
|
||||
ID: m.ID + "-agentic",
|
||||
Object: m.Object,
|
||||
Created: m.Created,
|
||||
OwnedBy: m.OwnedBy,
|
||||
Type: m.Type,
|
||||
DisplayName: m.DisplayName + " (Agentic)",
|
||||
Description: m.Description + " - Optimized for coding agents (chunked writes)",
|
||||
ContextLength: m.ContextLength,
|
||||
MaxCompletionTokens: m.MaxCompletionTokens,
|
||||
}
|
||||
|
||||
// Copy thinking support if present
|
||||
if m.Thinking != nil {
|
||||
agentic.Thinking = ®istry.ThinkingSupport{
|
||||
Min: m.Thinking.Min,
|
||||
Max: m.Thinking.Max,
|
||||
ZeroAllowed: m.Thinking.ZeroAllowed,
|
||||
DynamicAllowed: m.Thinking.DynamicAllowed,
|
||||
}
|
||||
}
|
||||
|
||||
result = append(result, agentic)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
86
sdk/cliproxy/service_gitlab_models_test.go
Normal file
86
sdk/cliproxy/service_gitlab_models_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package cliproxy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestRegisterModelsForAuth_GitLabUsesDiscoveredModels(t *testing.T) {
|
||||
service := &Service{cfg: &config.Config{}}
|
||||
auth := &coreauth.Auth{
|
||||
ID: "gitlab-auth.json",
|
||||
Provider: "gitlab",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{
|
||||
"model_details": map[string]any{
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.UnregisterClient(auth.ID)
|
||||
t.Cleanup(func() { reg.UnregisterClient(auth.ID) })
|
||||
|
||||
service.registerModelsForAuth(auth)
|
||||
models := reg.GetModelsForClient(auth.ID)
|
||||
if len(models) < 2 {
|
||||
t.Fatalf("expected stable alias and discovered model, got %d entries", len(models))
|
||||
}
|
||||
|
||||
seenAlias := false
|
||||
seenDiscovered := false
|
||||
for _, model := range models {
|
||||
switch model.ID {
|
||||
case "gitlab-duo":
|
||||
seenAlias = true
|
||||
case "claude-sonnet-4-5":
|
||||
seenDiscovered = true
|
||||
}
|
||||
}
|
||||
if !seenAlias || !seenDiscovered {
|
||||
t.Fatalf("expected gitlab-duo and discovered model, got %+v", models)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterModelsForAuth_GitLabIncludesAgenticCatalog(t *testing.T) {
|
||||
service := &Service{cfg: &config.Config{}}
|
||||
auth := &coreauth.Auth{
|
||||
ID: "gitlab-agentic-auth.json",
|
||||
Provider: "gitlab",
|
||||
Status: coreauth.StatusActive,
|
||||
}
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.UnregisterClient(auth.ID)
|
||||
t.Cleanup(func() { reg.UnregisterClient(auth.ID) })
|
||||
|
||||
service.registerModelsForAuth(auth)
|
||||
models := reg.GetModelsForClient(auth.ID)
|
||||
if len(models) < 5 {
|
||||
t.Fatalf("expected stable alias plus built-in agentic catalog, got %d entries", len(models))
|
||||
}
|
||||
|
||||
required := map[string]bool{
|
||||
"gitlab-duo": false,
|
||||
"duo-chat-opus-4-6": false,
|
||||
"duo-chat-haiku-4-5": false,
|
||||
"duo-chat-sonnet-4-5": false,
|
||||
"duo-chat-opus-4-5": false,
|
||||
"duo-chat-gpt-5-codex": false,
|
||||
}
|
||||
for _, model := range models {
|
||||
if _, ok := required[model.ID]; ok {
|
||||
required[model.ID] = true
|
||||
}
|
||||
}
|
||||
for id, seen := range required {
|
||||
if !seen {
|
||||
t.Fatalf("expected built-in GitLab Duo model %q, got %+v", id, models)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -90,3 +90,26 @@ func TestApplyOAuthModelAlias_ForkAddsMultipleAliases(t *testing.T) {
|
||||
t.Fatalf("expected forked model name %q, got %q", "models/g5-2", out[2].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyOAuthModelAlias_DefaultGitHubCopilotAliasViaSanitize(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
models := []*ModelInfo{
|
||||
{ID: "claude-opus-4.6", Name: "models/claude-opus-4.6"},
|
||||
}
|
||||
|
||||
out := applyOAuthModelAlias(cfg, "github-copilot", "oauth", models)
|
||||
if len(out) != 2 {
|
||||
t.Fatalf("expected 2 models (original + default alias), got %d", len(out))
|
||||
}
|
||||
if out[0].ID != "claude-opus-4.6" {
|
||||
t.Fatalf("expected first model id %q, got %q", "claude-opus-4.6", out[0].ID)
|
||||
}
|
||||
if out[1].ID != "claude-opus-4-6" {
|
||||
t.Fatalf("expected second model id %q, got %q", "claude-opus-4-6", out[1].ID)
|
||||
}
|
||||
if out[1].Name != "models/claude-opus-4-6" {
|
||||
t.Fatalf("expected aliased model name %q, got %q", "models/claude-opus-4-6", out[1].Name)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,6 +89,7 @@ type WatcherWrapper struct {
|
||||
snapshotAuths func() []*coreauth.Auth
|
||||
setUpdateQueue func(queue chan<- watcher.AuthUpdate)
|
||||
dispatchRuntimeUpdate func(update watcher.AuthUpdate) bool
|
||||
notifyTokenRefreshed func(tokenID, accessToken, refreshToken, expiresAt string) // 方案 A: 后台刷新通知
|
||||
}
|
||||
|
||||
// Start proxies to the underlying watcher Start implementation.
|
||||
@@ -146,3 +147,16 @@ func (w *WatcherWrapper) SetAuthUpdateQueue(queue chan<- watcher.AuthUpdate) {
|
||||
}
|
||||
w.setUpdateQueue(queue)
|
||||
}
|
||||
|
||||
// NotifyTokenRefreshed 通知 Watcher 后台刷新器已更新 token
|
||||
// 这是方案 A 的核心方法,用于解决后台刷新与内存 Auth 对象的时间差问题
|
||||
// tokenID: token 文件名(如 kiro-xxx.json)
|
||||
// accessToken: 新的 access token
|
||||
// refreshToken: 新的 refresh token
|
||||
// expiresAt: 新的过期时间(RFC3339 格式)
|
||||
func (w *WatcherWrapper) NotifyTokenRefreshed(tokenID, accessToken, refreshToken, expiresAt string) {
|
||||
if w == nil || w.notifyTokenRefreshed == nil {
|
||||
return
|
||||
}
|
||||
w.notifyTokenRefreshed(tokenID, accessToken, refreshToken, expiresAt)
|
||||
}
|
||||
|
||||
@@ -31,5 +31,8 @@ func defaultWatcherFactory(configPath, authDir string, reload func(*config.Confi
|
||||
dispatchRuntimeUpdate: func(update watcher.AuthUpdate) bool {
|
||||
return w.DispatchRuntimeAuthUpdate(update)
|
||||
},
|
||||
notifyTokenRefreshed: func(tokenID, accessToken, refreshToken, expiresAt string) {
|
||||
w.NotifyTokenRefreshed(tokenID, accessToken, refreshToken, expiresAt)
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user