diff --git a/cmd/server/main.go b/cmd/server/main.go index 4148cd06..f66c12ee 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -79,6 +79,8 @@ func main() { var kiloLogin bool var iflowLogin bool var iflowCookie bool + var gitlabLogin bool + var gitlabTokenLogin bool var noBrowser bool var oauthCallbackPort int var antigravityLogin bool @@ -111,6 +113,8 @@ func main() { flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow") flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth") flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie") + flag.BoolVar(&gitlabLogin, "gitlab-login", false, "Login to GitLab Duo using OAuth") + flag.BoolVar(&gitlabTokenLogin, "gitlab-token-login", false, "Login to GitLab Duo using a personal access token") flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth") flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)") flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)") @@ -527,6 +531,10 @@ func main() { cmd.DoIFlowLogin(cfg, options) } else if iflowCookie { cmd.DoIFlowCookieAuth(cfg, options) + } else if gitlabLogin { + cmd.DoGitLabLogin(cfg, options) + } else if gitlabTokenLogin { + cmd.DoGitLabTokenLogin(cfg, options) } else if kimiLogin { cmd.DoKimiLogin(cfg, options) } else if kiroLogin { diff --git a/internal/auth/gitlab/gitlab.go b/internal/auth/gitlab/gitlab.go new file mode 100644 index 00000000..7be2a141 --- /dev/null +++ b/internal/auth/gitlab/gitlab.go @@ -0,0 +1,492 @@ +package gitlab + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strconv" + "strings" + "sync" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" +) + +const ( + DefaultBaseURL = "https://gitlab.com" + DefaultCallbackPort = 17171 + defaultOAuthScope = "api read_user" +) + +type PKCECodes struct { + CodeVerifier string + CodeChallenge string +} + +type OAuthResult struct { + Code string + State string + Error string +} + +type OAuthServer struct { + server *http.Server + port int + resultChan chan *OAuthResult + errorChan chan error + mu sync.Mutex + running bool +} + +type TokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + RefreshToken string `json:"refresh_token"` + Scope string `json:"scope"` + CreatedAt int64 `json:"created_at"` + ExpiresIn int `json:"expires_in"` +} + +type User struct { + ID int64 `json:"id"` + Username string `json:"username"` + Name string `json:"name"` + Email string `json:"email"` + PublicEmail string `json:"public_email"` +} + +type PersonalAccessTokenSelf struct { + ID int64 `json:"id"` + Name string `json:"name"` + Scopes []string `json:"scopes"` + UserID int64 `json:"user_id"` +} + +type ModelDetails struct { + ModelProvider string `json:"model_provider"` + ModelName string `json:"model_name"` +} + +type DirectAccessResponse struct { + BaseURL string `json:"base_url"` + Token string `json:"token"` + ExpiresAt int64 `json:"expires_at"` + Headers map[string]string `json:"headers"` + ModelDetails *ModelDetails `json:"model_details,omitempty"` +} + +type DiscoveredModel struct { + ModelProvider string + ModelName string +} + +type AuthClient struct { + httpClient *http.Client +} + +func NewAuthClient(cfg *config.Config) *AuthClient { + client := &http.Client{} + if cfg != nil { + client = util.SetProxy(&cfg.SDKConfig, client) + } + return &AuthClient{httpClient: client} +} + +func NormalizeBaseURL(raw string) string { + value := strings.TrimSpace(raw) + if value == "" { + return DefaultBaseURL + } + if !strings.Contains(value, "://") { + value = "https://" + value + } + value = strings.TrimRight(value, "/") + return value +} + +func TokenExpiry(now time.Time, token *TokenResponse) time.Time { + if token == nil { + return time.Time{} + } + if token.CreatedAt > 0 && token.ExpiresIn > 0 { + return time.Unix(token.CreatedAt+int64(token.ExpiresIn), 0).UTC() + } + if token.ExpiresIn > 0 { + return now.UTC().Add(time.Duration(token.ExpiresIn) * time.Second) + } + return time.Time{} +} + +func GeneratePKCECodes() (*PKCECodes, error) { + verifierBytes := make([]byte, 32) + if _, err := rand.Read(verifierBytes); err != nil { + return nil, fmt.Errorf("gitlab pkce generation failed: %w", err) + } + verifier := base64.RawURLEncoding.EncodeToString(verifierBytes) + sum := sha256.Sum256([]byte(verifier)) + challenge := base64.RawURLEncoding.EncodeToString(sum[:]) + return &PKCECodes{ + CodeVerifier: verifier, + CodeChallenge: challenge, + }, nil +} + +func NewOAuthServer(port int) *OAuthServer { + return &OAuthServer{ + port: port, + resultChan: make(chan *OAuthResult, 1), + errorChan: make(chan error, 1), + } +} + +func (s *OAuthServer) Start() error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.running { + return fmt.Errorf("gitlab oauth server already running") + } + if !s.isPortAvailable() { + return fmt.Errorf("port %d is already in use", s.port) + } + + mux := http.NewServeMux() + mux.HandleFunc("/auth/callback", s.handleCallback) + + s.server = &http.Server{ + Addr: fmt.Sprintf(":%d", s.port), + Handler: mux, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + } + s.running = true + + go func() { + if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + s.errorChan <- err + } + }() + + time.Sleep(100 * time.Millisecond) + return nil +} + +func (s *OAuthServer) Stop(ctx context.Context) error { + s.mu.Lock() + defer s.mu.Unlock() + if !s.running || s.server == nil { + return nil + } + defer func() { + s.running = false + s.server = nil + }() + return s.server.Shutdown(ctx) +} + +func (s *OAuthServer) WaitForCallback(timeout time.Duration) (*OAuthResult, error) { + select { + case result := <-s.resultChan: + return result, nil + case err := <-s.errorChan: + return nil, err + case <-time.After(timeout): + return nil, fmt.Errorf("timeout waiting for OAuth callback") + } +} + +func (s *OAuthServer) handleCallback(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + query := r.URL.Query() + if errParam := strings.TrimSpace(query.Get("error")); errParam != "" { + s.sendResult(&OAuthResult{Error: errParam}) + http.Error(w, errParam, http.StatusBadRequest) + return + } + code := strings.TrimSpace(query.Get("code")) + state := strings.TrimSpace(query.Get("state")) + if code == "" || state == "" { + s.sendResult(&OAuthResult{Error: "missing_code_or_state"}) + http.Error(w, "missing code or state", http.StatusBadRequest) + return + } + s.sendResult(&OAuthResult{Code: code, State: state}) + _, _ = w.Write([]byte("GitLab authentication received. You can close this tab.")) +} + +func (s *OAuthServer) sendResult(result *OAuthResult) { + select { + case s.resultChan <- result: + default: + log.Debug("gitlab oauth result channel full, dropping callback result") + } +} + +func (s *OAuthServer) isPortAvailable() bool { + listener, err := net.Listen("tcp", fmt.Sprintf(":%d", s.port)) + if err != nil { + return false + } + _ = listener.Close() + return true +} + +func RedirectURL(port int) string { + return fmt.Sprintf("http://localhost:%d/auth/callback", port) +} + +func (c *AuthClient) GenerateAuthURL(baseURL, clientID, redirectURI, state string, pkce *PKCECodes) (string, error) { + if pkce == nil { + return "", fmt.Errorf("gitlab auth URL generation failed: PKCE codes are required") + } + if strings.TrimSpace(clientID) == "" { + return "", fmt.Errorf("gitlab auth URL generation failed: client ID is required") + } + baseURL = NormalizeBaseURL(baseURL) + params := url.Values{ + "client_id": {strings.TrimSpace(clientID)}, + "response_type": {"code"}, + "redirect_uri": {strings.TrimSpace(redirectURI)}, + "scope": {defaultOAuthScope}, + "state": {strings.TrimSpace(state)}, + "code_challenge": {pkce.CodeChallenge}, + "code_challenge_method": {"S256"}, + } + return fmt.Sprintf("%s/oauth/authorize?%s", baseURL, params.Encode()), nil +} + +func (c *AuthClient) ExchangeCodeForTokens(ctx context.Context, baseURL, clientID, clientSecret, redirectURI, code, codeVerifier string) (*TokenResponse, error) { + form := url.Values{ + "grant_type": {"authorization_code"}, + "client_id": {strings.TrimSpace(clientID)}, + "code": {strings.TrimSpace(code)}, + "redirect_uri": {strings.TrimSpace(redirectURI)}, + "code_verifier": {strings.TrimSpace(codeVerifier)}, + } + if secret := strings.TrimSpace(clientSecret); secret != "" { + form.Set("client_secret", secret) + } + return c.postToken(ctx, NormalizeBaseURL(baseURL)+"/oauth/token", form) +} + +func (c *AuthClient) RefreshTokens(ctx context.Context, baseURL, clientID, clientSecret, refreshToken string) (*TokenResponse, error) { + form := url.Values{ + "grant_type": {"refresh_token"}, + "refresh_token": {strings.TrimSpace(refreshToken)}, + } + if clientID = strings.TrimSpace(clientID); clientID != "" { + form.Set("client_id", clientID) + } + if secret := strings.TrimSpace(clientSecret); secret != "" { + form.Set("client_secret", secret) + } + return c.postToken(ctx, NormalizeBaseURL(baseURL)+"/oauth/token", form) +} + +func (c *AuthClient) postToken(ctx context.Context, tokenURL string, form url.Values) (*TokenResponse, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(form.Encode())) + if err != nil { + return nil, fmt.Errorf("gitlab token request failed: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("gitlab token request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("gitlab token response read failed: %w", err) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("gitlab token request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + var token TokenResponse + if err := json.Unmarshal(body, &token); err != nil { + return nil, fmt.Errorf("gitlab token response decode failed: %w", err) + } + return &token, nil +} + +func (c *AuthClient) GetCurrentUser(ctx context.Context, baseURL, token string) (*User, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, NormalizeBaseURL(baseURL)+"/api/v4/user", nil) + if err != nil { + return nil, fmt.Errorf("gitlab user request failed: %w", err) + } + req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token)) + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("gitlab user request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("gitlab user response read failed: %w", err) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("gitlab user request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + var user User + if err := json.Unmarshal(body, &user); err != nil { + return nil, fmt.Errorf("gitlab user response decode failed: %w", err) + } + return &user, nil +} + +func (c *AuthClient) GetPersonalAccessTokenSelf(ctx context.Context, baseURL, token string) (*PersonalAccessTokenSelf, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, NormalizeBaseURL(baseURL)+"/api/v4/personal_access_tokens/self", nil) + if err != nil { + return nil, fmt.Errorf("gitlab PAT self request failed: %w", err) + } + req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token)) + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("gitlab PAT self request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("gitlab PAT self response read failed: %w", err) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("gitlab PAT self request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + var pat PersonalAccessTokenSelf + if err := json.Unmarshal(body, &pat); err != nil { + return nil, fmt.Errorf("gitlab PAT self response decode failed: %w", err) + } + return &pat, nil +} + +func (c *AuthClient) FetchDirectAccess(ctx context.Context, baseURL, token string) (*DirectAccessResponse, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, NormalizeBaseURL(baseURL)+"/api/v4/code_suggestions/direct_access", nil) + if err != nil { + return nil, fmt.Errorf("gitlab direct access request failed: %w", err) + } + req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token)) + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("gitlab direct access request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("gitlab direct access response read failed: %w", err) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("gitlab direct access request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + var direct DirectAccessResponse + if err := json.Unmarshal(body, &direct); err != nil { + return nil, fmt.Errorf("gitlab direct access response decode failed: %w", err) + } + if direct.Headers == nil { + direct.Headers = make(map[string]string) + } + return &direct, nil +} + +func ExtractDiscoveredModels(metadata map[string]any) []DiscoveredModel { + if len(metadata) == 0 { + return nil + } + + models := make([]DiscoveredModel, 0, 4) + seen := make(map[string]struct{}) + appendModel := func(provider, name string) { + provider = strings.TrimSpace(provider) + name = strings.TrimSpace(name) + if name == "" { + return + } + key := strings.ToLower(provider + "\x00" + name) + if _, ok := seen[key]; ok { + return + } + seen[key] = struct{}{} + models = append(models, DiscoveredModel{ + ModelProvider: provider, + ModelName: name, + }) + } + + if raw, ok := metadata["model_details"]; ok { + appendDiscoveredModels(raw, appendModel) + } + appendModel(stringValue(metadata["model_provider"]), stringValue(metadata["model_name"])) + + for _, key := range []string{"models", "supported_models", "discovered_models"} { + if raw, ok := metadata[key]; ok { + appendDiscoveredModels(raw, appendModel) + } + } + + return models +} + +func appendDiscoveredModels(raw any, appendModel func(provider, name string)) { + switch typed := raw.(type) { + case map[string]any: + appendModel(stringValue(typed["model_provider"]), stringValue(typed["model_name"])) + appendModel(stringValue(typed["provider"]), stringValue(typed["name"])) + if nested, ok := typed["models"]; ok { + appendDiscoveredModels(nested, appendModel) + } + case []any: + for _, item := range typed { + appendDiscoveredModels(item, appendModel) + } + case []string: + for _, item := range typed { + appendModel("", item) + } + case string: + appendModel("", typed) + } +} + +func stringValue(raw any) string { + switch typed := raw.(type) { + case string: + return strings.TrimSpace(typed) + case fmt.Stringer: + return strings.TrimSpace(typed.String()) + case json.Number: + return typed.String() + case int: + return strconv.Itoa(typed) + case int64: + return strconv.FormatInt(typed, 10) + case float64: + return strconv.FormatInt(int64(typed), 10) + default: + return "" + } +} diff --git a/internal/auth/gitlab/gitlab_test.go b/internal/auth/gitlab/gitlab_test.go new file mode 100644 index 00000000..aa4c0b2b --- /dev/null +++ b/internal/auth/gitlab/gitlab_test.go @@ -0,0 +1,72 @@ +package gitlab + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" +) + +func TestNormalizeBaseURL(t *testing.T) { + tests := []struct { + name string + in string + want string + }{ + {name: "default", in: "", want: DefaultBaseURL}, + {name: "plain host", in: "gitlab.example.com", want: "https://gitlab.example.com"}, + {name: "trim trailing slash", in: "https://gitlab.example.com/", want: "https://gitlab.example.com"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := NormalizeBaseURL(tc.in); got != tc.want { + t.Fatalf("NormalizeBaseURL(%q) = %q, want %q", tc.in, got, tc.want) + } + }) + } +} + +func TestFetchDirectAccess_ParsesModelDetails(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Fatalf("expected POST, got %s", r.Method) + } + if got := r.Header.Get("Authorization"); got != "Bearer pat-123" { + t.Fatalf("expected Authorization header, got %q", got) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "base_url":"https://gateway.gitlab.example.com/v1", + "token":"duo-gateway-token", + "expires_at":2000000000, + "headers":{ + "X-Gitlab-Realm":"saas", + "X-Gitlab-Host-Name":"gitlab.example.com" + }, + "model_details":{ + "model_provider":"anthropic", + "model_name":"claude-sonnet-4-5" + } + }`)) + })) + defer server.Close() + + client := &AuthClient{httpClient: server.Client()} + direct, err := client.FetchDirectAccess(context.Background(), server.URL, "pat-123") + if err != nil { + t.Fatalf("FetchDirectAccess returned error: %v", err) + } + if direct.BaseURL != "https://gateway.gitlab.example.com/v1" { + t.Fatalf("unexpected base_url %q", direct.BaseURL) + } + if direct.Token != "duo-gateway-token" { + t.Fatalf("unexpected token %q", direct.Token) + } + if direct.ModelDetails == nil || direct.ModelDetails.ModelName != "claude-sonnet-4-5" { + t.Fatalf("unexpected model details: %+v", direct.ModelDetails) + } + if direct.Headers["X-Gitlab-Realm"] != "saas" { + t.Fatalf("expected X-Gitlab-Realm header, got %+v", direct.Headers) + } +} diff --git a/internal/cmd/auth_manager.go b/internal/cmd/auth_manager.go index 2a3407be..ea7a0532 100644 --- a/internal/cmd/auth_manager.go +++ b/internal/cmd/auth_manager.go @@ -23,6 +23,7 @@ func newAuthManager() *sdkAuth.Manager { sdkAuth.NewKiroAuthenticator(), sdkAuth.NewGitHubCopilotAuthenticator(), sdkAuth.NewKiloAuthenticator(), + sdkAuth.NewGitLabAuthenticator(), ) return manager } diff --git a/internal/cmd/gitlab_login.go b/internal/cmd/gitlab_login.go new file mode 100644 index 00000000..9384bec1 --- /dev/null +++ b/internal/cmd/gitlab_login.go @@ -0,0 +1,69 @@ +package cmd + +import ( + "context" + "fmt" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" +) + +func DoGitLabLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + promptFn := options.Prompt + if promptFn == nil { + promptFn = defaultProjectPrompt() + } + + manager := newAuthManager() + authOpts := &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + CallbackPort: options.CallbackPort, + Metadata: map[string]string{ + "login_mode": "oauth", + }, + Prompt: promptFn, + } + + _, savedPath, err := manager.Login(context.Background(), "gitlab", cfg, authOpts) + if err != nil { + fmt.Printf("GitLab Duo authentication failed: %v\n", err) + return + } + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + fmt.Println("GitLab Duo authentication successful!") +} + +func DoGitLabTokenLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + promptFn := options.Prompt + if promptFn == nil { + promptFn = defaultProjectPrompt() + } + + manager := newAuthManager() + authOpts := &sdkAuth.LoginOptions{ + Metadata: map[string]string{ + "login_mode": "pat", + }, + Prompt: promptFn, + } + + _, savedPath, err := manager.Login(context.Background(), "gitlab", cfg, authOpts) + if err != nil { + fmt.Printf("GitLab Duo PAT authentication failed: %v\n", err) + return + } + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + fmt.Println("GitLab Duo PAT authentication successful!") +} diff --git a/internal/runtime/executor/gitlab_executor.go b/internal/runtime/executor/gitlab_executor.go new file mode 100644 index 00000000..d05fa086 --- /dev/null +++ b/internal/runtime/executor/gitlab_executor.go @@ -0,0 +1,746 @@ +package executor + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gitlab" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" + "github.com/tidwall/gjson" +) + +const ( + gitLabProviderKey = "gitlab" + gitLabAuthMethodOAuth = "oauth" + gitLabAuthMethodPAT = "pat" + gitLabChatEndpoint = "/api/v4/chat/completions" + gitLabCodeSuggestionsEndpoint = "/api/v4/code_suggestions/completions" +) + +type GitLabExecutor struct { + cfg *config.Config +} + +type gitLabPrompt struct { + Instruction string + FileName string + ContentAboveCursor string + ChatContext []map[string]any + CodeSuggestionContext []map[string]any +} + +func NewGitLabExecutor(cfg *config.Config) *GitLabExecutor { + return &GitLabExecutor{cfg: cfg} +} + +func (e *GitLabExecutor) Identifier() string { return gitLabProviderKey } + +func (e *GitLabExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { + baseModel := thinking.ParseSuffix(req.Model).ModelName + + reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) + defer reporter.trackFailure(ctx, &err) + + translated, err := e.translateToOpenAI(req, opts) + if err != nil { + return resp, err + } + prompt := buildGitLabPrompt(translated) + if strings.TrimSpace(prompt.Instruction) == "" && strings.TrimSpace(prompt.ContentAboveCursor) == "" { + err = statusErr{code: http.StatusBadRequest, msg: "gitlab duo executor: request has no usable text content"} + return resp, err + } + + text, err := e.invoke(ctx, auth, prompt) + if err != nil { + return resp, err + } + + responseModel := gitLabResolvedModel(auth, req.Model) + openAIResponse := buildGitLabOpenAIResponse(responseModel, text, translated) + reporter.publish(ctx, parseOpenAIUsage(openAIResponse)) + reporter.ensurePublished(ctx) + + var param any + out := sdktranslator.TranslateNonStream( + ctx, + sdktranslator.FromString("openai"), + opts.SourceFormat, + req.Model, + opts.OriginalRequest, + translated, + openAIResponse, + ¶m, + ) + return cliproxyexecutor.Response{Payload: []byte(out), Headers: make(http.Header)}, nil +} + +func (e *GitLabExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { + baseModel := thinking.ParseSuffix(req.Model).ModelName + + reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) + defer reporter.trackFailure(ctx, &err) + + translated, err := e.translateToOpenAI(req, opts) + if err != nil { + return nil, err + } + prompt := buildGitLabPrompt(translated) + if strings.TrimSpace(prompt.Instruction) == "" && strings.TrimSpace(prompt.ContentAboveCursor) == "" { + return nil, statusErr{code: http.StatusBadRequest, msg: "gitlab duo executor: request has no usable text content"} + } + + text, err := e.invoke(ctx, auth, prompt) + if err != nil { + return nil, err + } + + responseModel := gitLabResolvedModel(auth, req.Model) + openAIResponse := buildGitLabOpenAIResponse(responseModel, text, translated) + reporter.publish(ctx, parseOpenAIUsage(openAIResponse)) + reporter.ensurePublished(ctx) + + out := make(chan cliproxyexecutor.StreamChunk, 8) + go func() { + defer close(out) + var param any + lines := buildGitLabOpenAIStream(responseModel, text) + for _, line := range lines { + chunks := sdktranslator.TranslateStream( + ctx, + sdktranslator.FromString("openai"), + opts.SourceFormat, + req.Model, + opts.OriginalRequest, + translated, + []byte(line), + ¶m, + ) + for i := range chunks { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])} + } + } + }() + return &cliproxyexecutor.StreamResult{Headers: make(http.Header), Chunks: out}, nil +} + +func (e *GitLabExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { + if auth == nil { + return nil, fmt.Errorf("gitlab duo executor: auth is nil") + } + baseURL := gitLabBaseURL(auth) + token := gitLabPrimaryToken(auth) + if baseURL == "" || token == "" { + return nil, fmt.Errorf("gitlab duo executor: missing base URL or token") + } + + client := gitlab.NewAuthClient(e.cfg) + method := strings.ToLower(strings.TrimSpace(gitLabMetadataString(auth.Metadata, "auth_method", "auth_kind"))) + if method == "" { + method = gitLabAuthMethodOAuth + } + + if method == gitLabAuthMethodOAuth { + if refreshed, refreshErr := e.refreshOAuthToken(ctx, client, auth, baseURL); refreshErr == nil && refreshed != nil { + token = refreshed.AccessToken + applyGitLabTokenMetadata(auth.Metadata, refreshed) + } + } + + direct, err := client.FetchDirectAccess(ctx, baseURL, token) + if err != nil && method == gitLabAuthMethodOAuth { + if refreshed, refreshErr := e.refreshOAuthToken(ctx, client, auth, baseURL); refreshErr == nil && refreshed != nil { + token = refreshed.AccessToken + applyGitLabTokenMetadata(auth.Metadata, refreshed) + direct, err = client.FetchDirectAccess(ctx, baseURL, token) + } + } + if err != nil { + return nil, err + } + + if auth.Metadata == nil { + auth.Metadata = make(map[string]any) + } + auth.Metadata["type"] = gitLabProviderKey + auth.Metadata["auth_method"] = method + auth.Metadata["auth_kind"] = gitLabAuthKind(method) + auth.Metadata["base_url"] = gitlab.NormalizeBaseURL(baseURL) + auth.Metadata["last_refresh"] = time.Now().UTC().Format(time.RFC3339) + mergeGitLabDirectAccessMetadata(auth.Metadata, direct) + return auth, nil +} + +func (e *GitLabExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + baseModel := thinking.ParseSuffix(req.Model).ModelName + translated := sdktranslator.TranslateRequest(opts.SourceFormat, sdktranslator.FromString("openai"), baseModel, req.Payload, false) + enc, err := tokenizerForModel(baseModel) + if err != nil { + return cliproxyexecutor.Response{}, fmt.Errorf("gitlab duo executor: tokenizer init failed: %w", err) + } + count, err := countOpenAIChatTokens(enc, translated) + if err != nil { + return cliproxyexecutor.Response{}, err + } + return cliproxyexecutor.Response{Payload: buildOpenAIUsageJSON(count), Headers: make(http.Header)}, nil +} + +func (e *GitLabExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) { + if req == nil { + return nil, fmt.Errorf("gitlab duo executor: request is nil") + } + if ctx == nil { + ctx = req.Context() + } + httpReq := req.WithContext(ctx) + if token := gitLabPrimaryToken(auth); token != "" { + httpReq.Header.Set("Authorization", "Bearer "+token) + } + return newProxyAwareHTTPClient(ctx, e.cfg, auth, 0).Do(httpReq) +} + +func (e *GitLabExecutor) translateToOpenAI(req cliproxyexecutor.Request, opts cliproxyexecutor.Options) ([]byte, error) { + baseModel := thinking.ParseSuffix(req.Model).ModelName + return sdktranslator.TranslateRequest(opts.SourceFormat, sdktranslator.FromString("openai"), baseModel, req.Payload, opts.Stream), nil +} + +func (e *GitLabExecutor) invoke(ctx context.Context, auth *cliproxyauth.Auth, prompt gitLabPrompt) (string, error) { + if text, err := e.requestChat(ctx, auth, prompt); err == nil { + return text, nil + } else if !shouldFallbackToCodeSuggestions(err) { + return "", err + } + return e.requestCodeSuggestions(ctx, auth, prompt) +} + +func (e *GitLabExecutor) requestChat(ctx context.Context, auth *cliproxyauth.Auth, prompt gitLabPrompt) (string, error) { + body := map[string]any{ + "content": prompt.Instruction, + "with_clean_history": true, + } + if len(prompt.ChatContext) > 0 { + body["additional_context"] = prompt.ChatContext + } + return e.doJSONTextRequest(ctx, auth, gitLabChatEndpoint, body) +} + +func (e *GitLabExecutor) requestCodeSuggestions(ctx context.Context, auth *cliproxyauth.Auth, prompt gitLabPrompt) (string, error) { + contentAbove := strings.TrimSpace(prompt.ContentAboveCursor) + if contentAbove == "" { + contentAbove = prompt.Instruction + } + body := map[string]any{ + "current_file": map[string]any{ + "file_name": prompt.FileName, + "content_above_cursor": contentAbove, + "content_below_cursor": "", + }, + "intent": "generation", + "generation_type": "small_file", + "user_instruction": prompt.Instruction, + "stream": false, + } + if len(prompt.CodeSuggestionContext) > 0 { + body["context"] = prompt.CodeSuggestionContext + } + return e.doJSONTextRequest(ctx, auth, gitLabCodeSuggestionsEndpoint, body) +} + +func (e *GitLabExecutor) doJSONTextRequest(ctx context.Context, auth *cliproxyauth.Auth, endpoint string, payload map[string]any) (string, error) { + token := gitLabPrimaryToken(auth) + baseURL := gitLabBaseURL(auth) + if token == "" || baseURL == "" { + return "", statusErr{code: http.StatusUnauthorized, msg: "gitlab duo executor: missing credentials"} + } + + body, err := json.Marshal(payload) + if err != nil { + return "", fmt.Errorf("gitlab duo executor: marshal request failed: %w", err) + } + + url := strings.TrimRight(baseURL, "/") + endpoint + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return "", err + } + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", "CLIProxyAPI/GitLab-Duo") + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: req.Header.Clone(), + Body: body, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + resp, err := httpClient.Do(req) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return "", err + } + defer func() { _ = resp.Body.Close() }() + recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone()) + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return "", err + } + appendAPIResponseChunk(ctx, e.cfg, respBody) + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return "", statusErr{code: resp.StatusCode, msg: strings.TrimSpace(string(respBody))} + } + + text, err := parseGitLabTextResponse(endpoint, respBody) + if err != nil { + return "", err + } + return strings.TrimSpace(text), nil +} + +func (e *GitLabExecutor) refreshOAuthToken(ctx context.Context, client *gitlab.AuthClient, auth *cliproxyauth.Auth, baseURL string) (*gitlab.TokenResponse, error) { + if auth == nil { + return nil, fmt.Errorf("gitlab duo executor: auth is nil") + } + refreshToken := gitLabMetadataString(auth.Metadata, "refresh_token") + if refreshToken == "" { + return nil, fmt.Errorf("gitlab duo executor: refresh token missing") + } + if !gitLabOAuthTokenNeedsRefresh(auth.Metadata) && gitLabPrimaryToken(auth) != "" { + return nil, nil + } + return client.RefreshTokens( + ctx, + baseURL, + gitLabMetadataString(auth.Metadata, "oauth_client_id"), + gitLabMetadataString(auth.Metadata, "oauth_client_secret"), + refreshToken, + ) +} + +func buildGitLabPrompt(payload []byte) gitLabPrompt { + root := gjson.ParseBytes(payload) + prompt := gitLabPrompt{ + FileName: "prompt.txt", + } + + msgs := root.Get("messages") + if msgs.Exists() && msgs.IsArray() { + systemIndex := 0 + contextIndex := 0 + transcript := make([]string, 0, len(msgs.Array())) + var lastUser string + msgs.ForEach(func(_, msg gjson.Result) bool { + role := strings.TrimSpace(msg.Get("role").String()) + if role == "" { + role = "user" + } + content := openAIContentText(msg.Get("content")) + if content == "" { + return true + } + switch role { + case "system": + systemIndex++ + prompt.ChatContext = append(prompt.ChatContext, map[string]any{ + "category": "snippet", + "id": fmt.Sprintf("system-%d", systemIndex), + "content": content, + }) + case "user": + lastUser = content + contextIndex++ + prompt.CodeSuggestionContext = append(prompt.CodeSuggestionContext, map[string]any{ + "type": "snippet", + "name": fmt.Sprintf("user-%d", contextIndex), + "content": content, + }) + transcript = append(transcript, "User:\n"+content) + default: + contextIndex++ + prompt.ChatContext = append(prompt.ChatContext, map[string]any{ + "category": "snippet", + "id": fmt.Sprintf("%s-%d", role, contextIndex), + "content": content, + }) + prompt.CodeSuggestionContext = append(prompt.CodeSuggestionContext, map[string]any{ + "type": "snippet", + "name": fmt.Sprintf("%s-%d", role, contextIndex), + "content": content, + }) + transcript = append(transcript, strings.Title(role)+":\n"+content) + } + return true + }) + prompt.Instruction = strings.TrimSpace(lastUser) + prompt.ContentAboveCursor = truncateGitLabPrompt(strings.Join(transcript, "\n\n"), 12000) + } + + if prompt.Instruction == "" { + for _, key := range []string{"prompt", "input", "instructions"} { + if value := strings.TrimSpace(root.Get(key).String()); value != "" { + prompt.Instruction = value + break + } + } + } + if prompt.ContentAboveCursor == "" { + prompt.ContentAboveCursor = prompt.Instruction + } + prompt.Instruction = truncateGitLabPrompt(prompt.Instruction, 4000) + prompt.ContentAboveCursor = truncateGitLabPrompt(prompt.ContentAboveCursor, 12000) + return prompt +} + +func openAIContentText(content gjson.Result) string { + segments := make([]string, 0, 8) + collectOpenAIContent(content, &segments) + return strings.TrimSpace(strings.Join(segments, "\n")) +} + +func truncateGitLabPrompt(value string, limit int) string { + value = strings.TrimSpace(value) + if limit <= 0 || len(value) <= limit { + return value + } + return strings.TrimSpace(value[:limit]) +} + +func parseGitLabTextResponse(endpoint string, body []byte) (string, error) { + if endpoint == gitLabChatEndpoint { + var text string + if err := json.Unmarshal(body, &text); err == nil { + return text, nil + } + if value := strings.TrimSpace(gjson.GetBytes(body, "response").String()); value != "" { + return value, nil + } + } + if value := strings.TrimSpace(gjson.GetBytes(body, "choices.0.text").String()); value != "" { + return value, nil + } + if value := strings.TrimSpace(gjson.GetBytes(body, "response").String()); value != "" { + return value, nil + } + var plain string + if err := json.Unmarshal(body, &plain); err == nil && strings.TrimSpace(plain) != "" { + return plain, nil + } + return "", fmt.Errorf("gitlab duo executor: upstream returned no text payload") +} + +func shouldFallbackToCodeSuggestions(err error) bool { + if err == nil { + return false + } + status, ok := err.(interface{ StatusCode() int }) + if !ok { + return false + } + switch status.StatusCode() { + case http.StatusForbidden, http.StatusNotFound, http.StatusMethodNotAllowed, http.StatusNotImplemented: + return true + default: + return false + } +} + +func buildGitLabOpenAIResponse(model, text string, translatedReq []byte) []byte { + promptTokens, completionTokens := gitLabUsage(model, translatedReq, text) + payload := map[string]any{ + "id": fmt.Sprintf("gitlab-%d", time.Now().UnixNano()), + "object": "chat.completion", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]any{{ + "index": 0, + "message": map[string]any{ + "role": "assistant", + "content": text, + }, + "finish_reason": "stop", + }}, + "usage": map[string]any{ + "prompt_tokens": promptTokens, + "completion_tokens": completionTokens, + "total_tokens": promptTokens + completionTokens, + }, + } + raw, _ := json.Marshal(payload) + return raw +} + +func buildGitLabOpenAIStream(model, text string) []string { + now := time.Now().Unix() + id := fmt.Sprintf("gitlab-%d", time.Now().UnixNano()) + chunks := []map[string]any{ + { + "id": id, + "object": "chat.completion.chunk", + "created": now, + "model": model, + "choices": []map[string]any{{ + "index": 0, + "delta": map[string]any{"role": "assistant"}, + }}, + }, + { + "id": id, + "object": "chat.completion.chunk", + "created": now, + "model": model, + "choices": []map[string]any{{ + "index": 0, + "delta": map[string]any{"content": text}, + }}, + }, + { + "id": id, + "object": "chat.completion.chunk", + "created": now, + "model": model, + "choices": []map[string]any{{ + "index": 0, + "delta": map[string]any{}, + "finish_reason": "stop", + }}, + }, + } + lines := make([]string, 0, len(chunks)+1) + for _, chunk := range chunks { + raw, _ := json.Marshal(chunk) + lines = append(lines, "data: "+string(raw)) + } + lines = append(lines, "data: [DONE]") + return lines +} + +func gitLabUsage(model string, translatedReq []byte, text string) (int64, int64) { + enc, err := tokenizerForModel(model) + if err != nil { + return 0, 0 + } + promptTokens, err := countOpenAIChatTokens(enc, translatedReq) + if err != nil { + promptTokens = 0 + } + completionCount, err := enc.Count(strings.TrimSpace(text)) + if err != nil { + return promptTokens, 0 + } + return promptTokens, int64(completionCount) +} + +func gitLabPrimaryToken(auth *cliproxyauth.Auth) string { + if auth == nil || auth.Metadata == nil { + return "" + } + if token := gitLabMetadataString(auth.Metadata, "access_token"); token != "" { + return token + } + return gitLabMetadataString(auth.Metadata, "personal_access_token") +} + +func gitLabBaseURL(auth *cliproxyauth.Auth) string { + if auth == nil || auth.Metadata == nil { + return "" + } + return gitlab.NormalizeBaseURL(gitLabMetadataString(auth.Metadata, "base_url")) +} + +func gitLabResolvedModel(auth *cliproxyauth.Auth, requested string) string { + requested = strings.TrimSpace(thinking.ParseSuffix(requested).ModelName) + if requested != "" && !strings.EqualFold(requested, "gitlab-duo") { + return requested + } + if auth != nil && auth.Metadata != nil { + for _, model := range gitlab.ExtractDiscoveredModels(auth.Metadata) { + if name := strings.TrimSpace(model.ModelName); name != "" { + return name + } + } + } + if requested != "" { + return requested + } + return "gitlab-duo" +} + +func gitLabMetadataString(metadata map[string]any, keys ...string) string { + for _, key := range keys { + if metadata == nil { + return "" + } + if value, ok := metadata[key].(string); ok { + if trimmed := strings.TrimSpace(value); trimmed != "" { + return trimmed + } + } + } + return "" +} + +func gitLabOAuthTokenNeedsRefresh(metadata map[string]any) bool { + expiry := gitLabMetadataString(metadata, "oauth_expires_at") + if expiry == "" { + return true + } + ts, err := time.Parse(time.RFC3339, expiry) + if err != nil { + return true + } + return time.Until(ts) <= 5*time.Minute +} + +func applyGitLabTokenMetadata(metadata map[string]any, tokenResp *gitlab.TokenResponse) { + if metadata == nil || tokenResp == nil { + return + } + if accessToken := strings.TrimSpace(tokenResp.AccessToken); accessToken != "" { + metadata["access_token"] = accessToken + } + if refreshToken := strings.TrimSpace(tokenResp.RefreshToken); refreshToken != "" { + metadata["refresh_token"] = refreshToken + } + if tokenType := strings.TrimSpace(tokenResp.TokenType); tokenType != "" { + metadata["token_type"] = tokenType + } + if scope := strings.TrimSpace(tokenResp.Scope); scope != "" { + metadata["scope"] = scope + } + if expiry := gitlab.TokenExpiry(time.Now(), tokenResp); !expiry.IsZero() { + metadata["oauth_expires_at"] = expiry.Format(time.RFC3339) + } +} + +func mergeGitLabDirectAccessMetadata(metadata map[string]any, direct *gitlab.DirectAccessResponse) { + if metadata == nil || direct == nil { + return + } + if base := strings.TrimSpace(direct.BaseURL); base != "" { + metadata["duo_gateway_base_url"] = base + } + if token := strings.TrimSpace(direct.Token); token != "" { + metadata["duo_gateway_token"] = token + } + if direct.ExpiresAt > 0 { + expiry := time.Unix(direct.ExpiresAt, 0).UTC() + metadata["duo_gateway_expires_at"] = expiry.Format(time.RFC3339) + if ttl := expiry.Sub(time.Now().UTC()); ttl > 0 { + interval := int(ttl.Seconds()) / 2 + switch { + case interval < 60: + interval = 60 + case interval > 240: + interval = 240 + } + metadata["refresh_interval_seconds"] = interval + } + } + if len(direct.Headers) > 0 { + headers := make(map[string]string, len(direct.Headers)) + for key, value := range direct.Headers { + key = strings.TrimSpace(key) + value = strings.TrimSpace(value) + if key == "" || value == "" { + continue + } + headers[key] = value + } + if len(headers) > 0 { + metadata["duo_gateway_headers"] = headers + } + } + if direct.ModelDetails != nil { + modelDetails := map[string]any{} + if provider := strings.TrimSpace(direct.ModelDetails.ModelProvider); provider != "" { + modelDetails["model_provider"] = provider + metadata["model_provider"] = provider + } + if model := strings.TrimSpace(direct.ModelDetails.ModelName); model != "" { + modelDetails["model_name"] = model + metadata["model_name"] = model + } + if len(modelDetails) > 0 { + metadata["model_details"] = modelDetails + } + } +} + +func gitLabAuthKind(method string) string { + switch strings.ToLower(strings.TrimSpace(method)) { + case gitLabAuthMethodPAT: + return "personal_access_token" + default: + return "oauth" + } +} + +func GitLabModelsFromAuth(auth *cliproxyauth.Auth) []*registry.ModelInfo { + models := make([]*registry.ModelInfo, 0, 4) + seen := make(map[string]struct{}, 4) + addModel := func(id, displayName, provider string) { + id = strings.TrimSpace(id) + if id == "" { + return + } + key := strings.ToLower(id) + if _, ok := seen[key]; ok { + return + } + seen[key] = struct{}{} + models = append(models, ®istry.ModelInfo{ + ID: id, + Object: "model", + Created: time.Now().Unix(), + OwnedBy: "gitlab", + Type: "gitlab", + DisplayName: displayName, + Description: provider, + UserDefined: true, + }) + } + + addModel("gitlab-duo", "GitLab Duo", "gitlab") + if auth == nil { + return models + } + for _, model := range gitlab.ExtractDiscoveredModels(auth.Metadata) { + name := strings.TrimSpace(model.ModelName) + if name == "" { + continue + } + displayName := "GitLab Duo" + if provider := strings.TrimSpace(model.ModelProvider); provider != "" { + displayName = fmt.Sprintf("GitLab Duo (%s)", provider) + } + addModel(name, displayName, strings.TrimSpace(model.ModelProvider)) + } + return models +} diff --git a/internal/runtime/executor/gitlab_executor_test.go b/internal/runtime/executor/gitlab_executor_test.go new file mode 100644 index 00000000..8257ddb6 --- /dev/null +++ b/internal/runtime/executor/gitlab_executor_test.go @@ -0,0 +1,124 @@ +package executor + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" +) + +func TestGitLabExecutorRefresh_WithPATStoresGatewayMetadata(t *testing.T) { + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/v4/code_suggestions/direct_access" { + t.Fatalf("unexpected path %s", r.URL.Path) + } + if got := r.Header.Get("Authorization"); got != "Bearer pat-123" { + t.Fatalf("unexpected Authorization header %q", got) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "base_url":"` + server.URL + `", + "token":"gateway-token", + "expires_at":2000000000, + "headers":{"X-Gitlab-Realm":"saas"}, + "model_details":{"model_provider":"mistral","model_name":"codestral-2501"} + }`)) + })) + defer server.Close() + + exec := NewGitLabExecutor(nil) + auth := &cliproxyauth.Auth{ + ID: "gitlab-pat.json", + Provider: "gitlab", + Metadata: map[string]any{ + "type": "gitlab", + "auth_method": "pat", + "base_url": server.URL, + "personal_access_token": "pat-123", + }, + } + + updated, err := exec.Refresh(context.Background(), auth) + if err != nil { + t.Fatalf("Refresh returned error: %v", err) + } + if got := metadataString(updated.Metadata, "duo_gateway_token"); got != "gateway-token" { + t.Fatalf("unexpected gateway token %q", got) + } + if got := gitLabModelName(updated); got != "codestral-2501" { + t.Fatalf("unexpected model name %q", got) + } + headers := gitLabHeaders(updated) + if headers["X-Gitlab-Realm"] != "saas" { + t.Fatalf("unexpected gateway headers %+v", headers) + } +} + +func TestGitLabExecutorExecute_UsesGatewayHeadersAndResolvedModel(t *testing.T) { + var receivedAuth string + var receivedRealm string + var receivedModel string + + gateway := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + receivedAuth = r.Header.Get("Authorization") + receivedRealm = r.Header.Get("X-Gitlab-Realm") + receivedModel = findJSONField(string(body), `"model":"`, `"`) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"ok","object":"chat.completion","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}]}`)) + })) + defer gateway.Close() + + exec := NewGitLabExecutor(nil) + auth := &cliproxyauth.Auth{ + ID: "gitlab-oauth.json", + Provider: "gitlab", + Metadata: map[string]any{ + "type": "gitlab", + "auth_method": "oauth", + "duo_gateway_base_url": gateway.URL, + "duo_gateway_token": "gateway-token", + "duo_gateway_headers": map[string]any{"X-Gitlab-Realm": "saas"}, + "model_details": map[string]any{"model_name": "codestral-2501", "model_provider": "mistral"}, + }, + } + + resp, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{ + Model: "gitlab-duo", + Payload: []byte(`{"model":"gitlab-duo","messages":[{"role":"user","content":"hello"}]}`), + }, cliproxyexecutor.Options{}) + if err != nil { + t.Fatalf("Execute returned error: %v", err) + } + if len(resp.Payload) == 0 { + t.Fatal("expected non-empty payload") + } + if receivedAuth != "Bearer gateway-token" { + t.Fatalf("unexpected Authorization header %q", receivedAuth) + } + if receivedRealm != "saas" { + t.Fatalf("unexpected X-Gitlab-Realm header %q", receivedRealm) + } + if receivedModel != "codestral-2501" { + t.Fatalf("unexpected resolved model %q", receivedModel) + } +} + +func findJSONField(body, prefix, suffix string) string { + start := strings.Index(body, prefix) + if start < 0 { + return "" + } + start += len(prefix) + end := strings.Index(body[start:], suffix) + if end < 0 { + return "" + } + return body[start : start+end] +} diff --git a/sdk/auth/gitlab.go b/sdk/auth/gitlab.go new file mode 100644 index 00000000..ae2b9177 --- /dev/null +++ b/sdk/auth/gitlab.go @@ -0,0 +1,462 @@ +package auth + +import ( + "context" + "fmt" + "strings" + "time" + + gitlabauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gitlab" + "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + log "github.com/sirupsen/logrus" +) + +const ( + gitLabLoginModeMetadataKey = "login_mode" + gitLabLoginModeOAuth = "oauth" + gitLabLoginModePAT = "pat" + gitLabBaseURLMetadataKey = "base_url" + gitLabOAuthClientIDMetadataKey = "oauth_client_id" + gitLabOAuthClientSecretMetadataKey = "oauth_client_secret" + gitLabPersonalAccessTokenMetadataKey = "personal_access_token" +) + +var gitLabRefreshLead = 5 * time.Minute + +type GitLabAuthenticator struct { + CallbackPort int +} + +func NewGitLabAuthenticator() *GitLabAuthenticator { + return &GitLabAuthenticator{CallbackPort: gitlabauth.DefaultCallbackPort} +} + +func (a *GitLabAuthenticator) Provider() string { + return "gitlab" +} + +func (a *GitLabAuthenticator) RefreshLead() *time.Duration { + return &gitLabRefreshLead +} + +func (a *GitLabAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { + if cfg == nil { + return nil, fmt.Errorf("cliproxy auth: configuration is required") + } + if ctx == nil { + ctx = context.Background() + } + if opts == nil { + opts = &LoginOptions{} + } + + switch strings.ToLower(strings.TrimSpace(opts.Metadata[gitLabLoginModeMetadataKey])) { + case "", gitLabLoginModeOAuth: + return a.loginOAuth(ctx, cfg, opts) + case gitLabLoginModePAT: + return a.loginPAT(ctx, cfg, opts) + default: + return nil, fmt.Errorf("gitlab auth: unsupported login mode %q", opts.Metadata[gitLabLoginModeMetadataKey]) + } +} + +func (a *GitLabAuthenticator) loginOAuth(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { + client := gitlabauth.NewAuthClient(cfg) + baseURL := a.resolveString(opts, gitLabBaseURLMetadataKey, gitlabauth.DefaultBaseURL) + clientID, err := a.requireInput(opts, gitLabOAuthClientIDMetadataKey, "Enter GitLab OAuth application client ID: ") + if err != nil { + return nil, err + } + clientSecret, err := a.optionalInput(opts, gitLabOAuthClientSecretMetadataKey, "Enter GitLab OAuth application client secret (press Enter for public PKCE app): ") + if err != nil { + return nil, err + } + + callbackPort := a.CallbackPort + if opts.CallbackPort > 0 { + callbackPort = opts.CallbackPort + } + redirectURI := gitlabauth.RedirectURL(callbackPort) + + pkceCodes, err := gitlabauth.GeneratePKCECodes() + if err != nil { + return nil, err + } + state, err := misc.GenerateRandomState() + if err != nil { + return nil, fmt.Errorf("gitlab state generation failed: %w", err) + } + + oauthServer := gitlabauth.NewOAuthServer(callbackPort) + if err := oauthServer.Start(); err != nil { + return nil, err + } + defer func() { + stopCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if stopErr := oauthServer.Stop(stopCtx); stopErr != nil { + log.Warnf("gitlab oauth server stop error: %v", stopErr) + } + }() + + authURL, err := client.GenerateAuthURL(baseURL, clientID, redirectURI, state, pkceCodes) + if err != nil { + return nil, err + } + + if !opts.NoBrowser { + fmt.Println("Opening browser for GitLab Duo authentication") + if !browser.IsAvailable() { + log.Warn("No browser available; please open the URL manually") + util.PrintSSHTunnelInstructions(callbackPort) + fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) + } else if err = browser.OpenURL(authURL); err != nil { + log.Warnf("Failed to open browser automatically: %v", err) + util.PrintSSHTunnelInstructions(callbackPort) + fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) + } + } else { + util.PrintSSHTunnelInstructions(callbackPort) + fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) + } + + fmt.Println("Waiting for GitLab OAuth callback...") + + callbackCh := make(chan *gitlabauth.OAuthResult, 1) + callbackErrCh := make(chan error, 1) + go func() { + result, waitErr := oauthServer.WaitForCallback(5 * time.Minute) + if waitErr != nil { + callbackErrCh <- waitErr + return + } + callbackCh <- result + }() + + var result *gitlabauth.OAuthResult + var manualPromptTimer *time.Timer + var manualPromptC <-chan time.Time + if opts.Prompt != nil { + manualPromptTimer = time.NewTimer(15 * time.Second) + manualPromptC = manualPromptTimer.C + defer manualPromptTimer.Stop() + } + +waitForCallback: + for { + select { + case result = <-callbackCh: + break waitForCallback + case err = <-callbackErrCh: + return nil, err + case <-manualPromptC: + manualPromptC = nil + if manualPromptTimer != nil { + manualPromptTimer.Stop() + } + input, promptErr := opts.Prompt("Paste the GitLab callback URL (or press Enter to keep waiting): ") + if promptErr != nil { + return nil, promptErr + } + parsed, parseErr := misc.ParseOAuthCallback(input) + if parseErr != nil { + return nil, parseErr + } + if parsed == nil { + continue + } + result = &gitlabauth.OAuthResult{ + Code: parsed.Code, + State: parsed.State, + Error: parsed.Error, + } + break waitForCallback + } + } + + if result.Error != "" { + return nil, fmt.Errorf("gitlab oauth returned error: %s", result.Error) + } + if result.State != state { + return nil, fmt.Errorf("gitlab auth: state mismatch") + } + + tokenResp, err := client.ExchangeCodeForTokens(ctx, baseURL, clientID, clientSecret, redirectURI, result.Code, pkceCodes.CodeVerifier) + if err != nil { + return nil, err + } + accessToken := strings.TrimSpace(tokenResp.AccessToken) + if accessToken == "" { + return nil, fmt.Errorf("gitlab auth: missing access token") + } + + user, err := client.GetCurrentUser(ctx, baseURL, accessToken) + if err != nil { + return nil, err + } + direct, err := client.FetchDirectAccess(ctx, baseURL, accessToken) + if err != nil { + return nil, err + } + + identifier := gitLabAccountIdentifier(user) + fileName := fmt.Sprintf("gitlab-%s.json", sanitizeGitLabFileName(identifier)) + metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModeOAuth, tokenResp, direct) + metadata["auth_kind"] = "oauth" + metadata[gitLabOAuthClientIDMetadataKey] = clientID + if strings.TrimSpace(clientSecret) != "" { + metadata[gitLabOAuthClientSecretMetadataKey] = clientSecret + } + metadata["username"] = strings.TrimSpace(user.Username) + if email := strings.TrimSpace(primaryGitLabEmail(user)); email != "" { + metadata["email"] = email + } + metadata["name"] = strings.TrimSpace(user.Name) + + fmt.Println("GitLab Duo authentication successful") + + return &coreauth.Auth{ + ID: fileName, + Provider: a.Provider(), + FileName: fileName, + Label: identifier, + Metadata: metadata, + }, nil +} + +func (a *GitLabAuthenticator) loginPAT(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { + client := gitlabauth.NewAuthClient(cfg) + baseURL := a.resolveString(opts, gitLabBaseURLMetadataKey, gitlabauth.DefaultBaseURL) + token, err := a.requireInput(opts, gitLabPersonalAccessTokenMetadataKey, "Enter GitLab personal access token: ") + if err != nil { + return nil, err + } + + user, err := client.GetCurrentUser(ctx, baseURL, token) + if err != nil { + return nil, err + } + _, err = client.GetPersonalAccessTokenSelf(ctx, baseURL, token) + if err != nil { + return nil, err + } + direct, err := client.FetchDirectAccess(ctx, baseURL, token) + if err != nil { + return nil, err + } + + identifier := gitLabAccountIdentifier(user) + fileName := fmt.Sprintf("gitlab-%s-pat.json", sanitizeGitLabFileName(identifier)) + metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModePAT, nil, direct) + metadata["auth_kind"] = "personal_access_token" + metadata[gitLabPersonalAccessTokenMetadataKey] = strings.TrimSpace(token) + metadata["token_preview"] = maskGitLabToken(token) + metadata["username"] = strings.TrimSpace(user.Username) + if email := strings.TrimSpace(primaryGitLabEmail(user)); email != "" { + metadata["email"] = email + } + metadata["name"] = strings.TrimSpace(user.Name) + + fmt.Println("GitLab Duo PAT authentication successful") + + return &coreauth.Auth{ + ID: fileName, + Provider: a.Provider(), + FileName: fileName, + Label: identifier + " (PAT)", + Metadata: metadata, + }, nil +} + +func buildGitLabAuthMetadata(baseURL, mode string, tokenResp *gitlabauth.TokenResponse, direct *gitlabauth.DirectAccessResponse) map[string]any { + metadata := map[string]any{ + "type": "gitlab", + "auth_method": strings.TrimSpace(mode), + gitLabBaseURLMetadataKey: gitlabauth.NormalizeBaseURL(baseURL), + "last_refresh": time.Now().UTC().Format(time.RFC3339), + "refresh_interval_seconds": 240, + } + if tokenResp != nil { + metadata["access_token"] = strings.TrimSpace(tokenResp.AccessToken) + if refreshToken := strings.TrimSpace(tokenResp.RefreshToken); refreshToken != "" { + metadata["refresh_token"] = refreshToken + } + if tokenType := strings.TrimSpace(tokenResp.TokenType); tokenType != "" { + metadata["token_type"] = tokenType + } + if scope := strings.TrimSpace(tokenResp.Scope); scope != "" { + metadata["scope"] = scope + } + if expiry := gitlabauth.TokenExpiry(time.Now(), tokenResp); !expiry.IsZero() { + metadata["oauth_expires_at"] = expiry.Format(time.RFC3339) + } + } + mergeGitLabDirectAccessMetadata(metadata, direct) + return metadata +} + +func mergeGitLabDirectAccessMetadata(metadata map[string]any, direct *gitlabauth.DirectAccessResponse) { + if metadata == nil || direct == nil { + return + } + if base := strings.TrimSpace(direct.BaseURL); base != "" { + metadata["duo_gateway_base_url"] = base + } + if token := strings.TrimSpace(direct.Token); token != "" { + metadata["duo_gateway_token"] = token + } + if direct.ExpiresAt > 0 { + expiry := time.Unix(direct.ExpiresAt, 0).UTC() + metadata["duo_gateway_expires_at"] = expiry.Format(time.RFC3339) + now := time.Now().UTC() + if ttl := expiry.Sub(now); ttl > 0 { + interval := int(ttl.Seconds()) / 2 + switch { + case interval < 60: + interval = 60 + case interval > 240: + interval = 240 + } + metadata["refresh_interval_seconds"] = interval + } + } + if len(direct.Headers) > 0 { + headers := make(map[string]string, len(direct.Headers)) + for key, value := range direct.Headers { + key = strings.TrimSpace(key) + value = strings.TrimSpace(value) + if key == "" || value == "" { + continue + } + headers[key] = value + } + if len(headers) > 0 { + metadata["duo_gateway_headers"] = headers + } + } + if direct.ModelDetails != nil { + modelDetails := map[string]any{} + if provider := strings.TrimSpace(direct.ModelDetails.ModelProvider); provider != "" { + modelDetails["model_provider"] = provider + metadata["model_provider"] = provider + } + if model := strings.TrimSpace(direct.ModelDetails.ModelName); model != "" { + modelDetails["model_name"] = model + metadata["model_name"] = model + } + if len(modelDetails) > 0 { + metadata["model_details"] = modelDetails + } + } +} + +func (a *GitLabAuthenticator) resolveString(opts *LoginOptions, key, fallback string) string { + if opts != nil && opts.Metadata != nil { + if value := strings.TrimSpace(opts.Metadata[key]); value != "" { + return value + } + } + if strings.TrimSpace(fallback) != "" { + return fallback + } + return "" +} + +func (a *GitLabAuthenticator) requireInput(opts *LoginOptions, key, prompt string) (string, error) { + if value := a.resolveString(opts, key, ""); value != "" { + return value, nil + } + if opts != nil && opts.Prompt != nil { + value, err := opts.Prompt(prompt) + if err != nil { + return "", err + } + if trimmed := strings.TrimSpace(value); trimmed != "" { + return trimmed, nil + } + } + return "", fmt.Errorf("gitlab auth: missing required %s", key) +} + +func (a *GitLabAuthenticator) optionalInput(opts *LoginOptions, key, prompt string) (string, error) { + if value := a.resolveString(opts, key, ""); value != "" { + return value, nil + } + if opts != nil && opts.Prompt != nil { + value, err := opts.Prompt(prompt) + if err != nil { + return "", err + } + return strings.TrimSpace(value), nil + } + return "", nil +} + +func primaryGitLabEmail(user *gitlabauth.User) string { + if user == nil { + return "" + } + if value := strings.TrimSpace(user.Email); value != "" { + return value + } + return strings.TrimSpace(user.PublicEmail) +} + +func gitLabAccountIdentifier(user *gitlabauth.User) string { + if user == nil { + return "user" + } + for _, value := range []string{user.Username, primaryGitLabEmail(user), user.Name} { + if trimmed := strings.TrimSpace(value); trimmed != "" { + return trimmed + } + } + return "user" +} + +func sanitizeGitLabFileName(value string) string { + value = strings.TrimSpace(strings.ToLower(value)) + if value == "" { + return "user" + } + var builder strings.Builder + lastDash := false + for _, r := range value { + switch { + case r >= 'a' && r <= 'z': + builder.WriteRune(r) + lastDash = false + case r >= '0' && r <= '9': + builder.WriteRune(r) + lastDash = false + case r == '-' || r == '_' || r == '.': + builder.WriteRune(r) + lastDash = false + default: + if !lastDash { + builder.WriteRune('-') + lastDash = true + } + } + } + result := strings.Trim(builder.String(), "-") + if result == "" { + return "user" + } + return result +} + +func maskGitLabToken(token string) string { + trimmed := strings.TrimSpace(token) + if trimmed == "" { + return "" + } + if len(trimmed) <= 8 { + return trimmed + } + return trimmed[:4] + "..." + trimmed[len(trimmed)-4:] +} diff --git a/sdk/auth/refresh_registry.go b/sdk/auth/refresh_registry.go index ecf8e820..411950ae 100644 --- a/sdk/auth/refresh_registry.go +++ b/sdk/auth/refresh_registry.go @@ -17,6 +17,7 @@ func init() { registerRefreshLead("kimi", func() Authenticator { return NewKimiAuthenticator() }) registerRefreshLead("kiro", func() Authenticator { return NewKiroAuthenticator() }) registerRefreshLead("github-copilot", func() Authenticator { return NewGitHubCopilotAuthenticator() }) + registerRefreshLead("gitlab", func() Authenticator { return NewGitLabAuthenticator() }) } func registerRefreshLead(provider string, factory func() Authenticator) { diff --git a/sdk/cliproxy/auth/types.go b/sdk/cliproxy/auth/types.go index e24919f5..225038e1 100644 --- a/sdk/cliproxy/auth/types.go +++ b/sdk/cliproxy/auth/types.go @@ -390,6 +390,27 @@ func (a *Auth) AccountInfo() (string, string) { // Check metadata for email first (OAuth-style auth) if a.Metadata != nil { + if method, ok := a.Metadata["auth_method"].(string); ok { + switch strings.ToLower(strings.TrimSpace(method)) { + case "oauth": + for _, key := range []string{"email", "username", "name"} { + if value, okValue := a.Metadata[key].(string); okValue { + if trimmed := strings.TrimSpace(value); trimmed != "" { + return "oauth", trimmed + } + } + } + case "pat", "personal_access_token": + for _, key := range []string{"username", "email", "name", "token_preview"} { + if value, okValue := a.Metadata[key].(string); okValue { + if trimmed := strings.TrimSpace(value); trimmed != "" { + return "personal_access_token", trimmed + } + } + } + return "personal_access_token", "" + } + } if v, ok := a.Metadata["email"].(string); ok { email := strings.TrimSpace(v) if email != "" { diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 82f6c85d..e0955603 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -119,6 +119,7 @@ func newDefaultAuthManager() *sdkAuth.Manager { sdkAuth.NewCodexAuthenticator(), sdkAuth.NewClaudeAuthenticator(), sdkAuth.NewQwenAuthenticator(), + sdkAuth.NewGitLabAuthenticator(), ) } @@ -444,6 +445,8 @@ func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace s.coreManager.RegisterExecutor(executor.NewKiloExecutor(s.cfg)) case "github-copilot": s.coreManager.RegisterExecutor(executor.NewGitHubCopilotExecutor(s.cfg)) + case "gitlab": + s.coreManager.RegisterExecutor(executor.NewGitLabExecutor(s.cfg)) default: providerKey := strings.ToLower(strings.TrimSpace(a.Provider)) if providerKey == "" { @@ -891,7 +894,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { models = applyExcludedModels(models, excluded) case "kimi": models = registry.GetKimiModels() - models = applyExcludedModels(models, excluded) + models = applyExcludedModels(models, excluded) case "github-copilot": ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() @@ -903,6 +906,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { case "kilo": models = executor.FetchKiloModels(context.Background(), a, s.cfg) models = applyExcludedModels(models, excluded) + case "gitlab": + models = executor.GitLabModelsFromAuth(a) + models = applyExcludedModels(models, excluded) default: // Handle OpenAI-compatibility providers by name using config if s.cfg != nil { diff --git a/sdk/cliproxy/service_gitlab_models_test.go b/sdk/cliproxy/service_gitlab_models_test.go new file mode 100644 index 00000000..7794649c --- /dev/null +++ b/sdk/cliproxy/service_gitlab_models_test.go @@ -0,0 +1,59 @@ +package cliproxy + +import ( + "strings" + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" +) + +func TestRegisterModelsForAuth_GitLabUsesDiscoveredModelAndAlias(t *testing.T) { + service := &Service{cfg: &config.Config{}} + auth := &coreauth.Auth{ + ID: "gitlab-auth", + Provider: "gitlab", + Status: coreauth.StatusActive, + Metadata: map[string]any{ + "model_details": map[string]any{ + "model_provider": "mistral", + "model_name": "codestral-2501", + }, + }, + } + + reg := registry.GetGlobalRegistry() + reg.UnregisterClient(auth.ID) + t.Cleanup(func() { + reg.UnregisterClient(auth.ID) + }) + + service.registerModelsForAuth(auth) + + models := reg.GetModelsForClient(auth.ID) + if len(models) == 0 { + t.Fatal("expected GitLab models to be registered") + } + + seenActual := false + seenAlias := false + for _, model := range models { + if model == nil { + continue + } + switch strings.TrimSpace(model.ID) { + case "codestral-2501": + seenActual = true + case "gitlab-duo": + seenAlias = true + } + } + + if !seenActual { + t.Fatal("expected discovered GitLab model to be registered") + } + if !seenAlias { + t.Fatal("expected stable GitLab Duo alias to be registered") + } +}