fix(provider): type auth errors (#27301)

This commit is contained in:
Shoubhit Dash
2026-05-13 15:47:13 +05:30
committed by GitHub
parent 733bd3c74e
commit 809af5c590
6 changed files with 200 additions and 32 deletions

View File

@@ -1,7 +1,6 @@
import type { AuthOAuthResult, Hooks } from "@opencode-ai/plugin"
import { Auth } from "@/auth"
import { InstanceState } from "@/effect/instance-state"
import { NamedError } from "@opencode-ai/core/util/error"
import { optionalOmitUndefined } from "@opencode-ai/core/schema"
import { Plugin } from "../plugin"
import { ProviderID } from "./schema"
@@ -64,23 +63,30 @@ export const CallbackInput = Schema.Struct({
})
export type CallbackInput = Schema.Schema.Type<typeof CallbackInput>
export const OauthMissing = NamedError.create("ProviderAuthOauthMissing", { providerID: ProviderID })
export class OauthMissing extends Schema.TaggedErrorClass<OauthMissing>()("ProviderAuthOauthMissing", {
providerID: ProviderID,
}) {}
export const OauthCodeMissing = NamedError.create("ProviderAuthOauthCodeMissing", { providerID: ProviderID })
export class OauthCodeMissing extends Schema.TaggedErrorClass<OauthCodeMissing>()("ProviderAuthOauthCodeMissing", {
providerID: ProviderID,
}) {}
export const OauthCallbackFailed = NamedError.create("ProviderAuthOauthCallbackFailed", {})
export class OauthCallbackFailed extends Schema.TaggedErrorClass<OauthCallbackFailed>()(
"ProviderAuthOauthCallbackFailed",
{},
) {}
export const ValidationFailed = NamedError.create("ProviderAuthValidationFailed", {
export class ValidationFailed extends Schema.TaggedErrorClass<ValidationFailed>()("ProviderAuthValidationFailed", {
field: Schema.String,
message: Schema.String,
})
}) {}
export type Error =
| Auth.AuthError
| InstanceType<typeof OauthMissing>
| InstanceType<typeof OauthCodeMissing>
| InstanceType<typeof OauthCallbackFailed>
| InstanceType<typeof ValidationFailed>
| OauthMissing
| OauthCodeMissing
| OauthCallbackFailed
| ValidationFailed
type Hook = NonNullable<Hooks["auth"]>
@@ -166,7 +172,7 @@ export const layer: Layer.Layer<Service, never, Auth.Service | Plugin.Service> =
for (const prompt of method.prompts) {
if (prompt.type === "text" && prompt.validate && input.inputs[prompt.key] !== undefined) {
const error = prompt.validate(input.inputs[prompt.key])
if (error) return yield* Effect.fail(new ValidationFailed({ field: prompt.key, message: error }))
if (error) return yield* new ValidationFailed({ field: prompt.key, message: error })
}
}
}
@@ -183,15 +189,15 @@ export const layer: Layer.Layer<Service, never, Auth.Service | Plugin.Service> =
const callback = Effect.fn("ProviderAuth.callback")(function* (input: { providerID: ProviderID } & CallbackInput) {
const pending = (yield* InstanceState.get(state)).pending
const match = pending.get(input.providerID)
if (!match) return yield* Effect.fail(new OauthMissing({ providerID: input.providerID }))
if (!match) return yield* new OauthMissing({ providerID: input.providerID })
if (match.method === "code" && !input.code) {
return yield* Effect.fail(new OauthCodeMissing({ providerID: input.providerID }))
return yield* new OauthCodeMissing({ providerID: input.providerID })
}
const result = yield* Effect.promise(() =>
match.method === "code" ? match.callback(input.code!) : match.callback(),
)
if (!result || result.type !== "success") return yield* Effect.fail(new OauthCallbackFailed({}))
if (!result || result.type !== "success") return yield* new OauthCallbackFailed({})
if ("key" in result) {
yield* auth.set(input.providerID, {

View File

@@ -10,6 +10,26 @@ import { described } from "./metadata"
const root = "/provider"
const ProviderAuthErrorName = Schema.Union([
Schema.Literal("BadRequest"),
Schema.Literal("ProviderAuthOauthMissing"),
Schema.Literal("ProviderAuthOauthCodeMissing"),
Schema.Literal("ProviderAuthOauthCallbackFailed"),
Schema.Literal("ProviderAuthValidationFailed"),
])
export class ProviderAuthApiError extends Schema.ErrorClass<ProviderAuthApiError>("ProviderAuthError")(
{
name: ProviderAuthErrorName,
data: Schema.Struct({
providerID: Schema.optional(ProviderID),
field: Schema.optional(Schema.String),
message: Schema.optional(Schema.String),
kind: Schema.optional(Schema.String),
}),
},
{ httpApiStatus: 400 },
) {}
export const ProviderApi = HttpApi.make("provider")
.add(
HttpApiGroup.make("provider")
@@ -39,7 +59,7 @@ export const ProviderApi = HttpApi.make("provider")
query: WorkspaceRoutingQuery,
payload: ProviderAuth.AuthorizeInput,
success: described(Schema.UndefinedOr(ProviderAuth.Authorization), "Authorization URL and method"),
error: HttpApiError.BadRequest,
error: ProviderAuthApiError,
}).annotateMerge(
OpenApi.annotations({
identifier: "provider.oauth.authorize",
@@ -52,7 +72,7 @@ export const ProviderApi = HttpApi.make("provider")
query: WorkspaceRoutingQuery,
payload: ProviderAuth.CallbackInput,
success: described(Schema.Boolean, "OAuth callback processed successfully"),
error: HttpApiError.BadRequest,
error: ProviderAuthApiError,
}).annotateMerge(
OpenApi.annotations({
identifier: "provider.oauth.callback",

View File

@@ -6,8 +6,29 @@ import { ProviderID } from "@/provider/schema"
import { mapValues } from "remeda"
import { Effect, Schema } from "effect"
import { HttpServerRequest, HttpServerResponse } from "effect/unstable/http"
import { HttpApiBuilder, HttpApiError } from "effect/unstable/httpapi"
import { HttpApiBuilder } from "effect/unstable/httpapi"
import { InstanceHttpApi } from "../api"
import { ProviderAuthApiError } from "../groups/provider"
function mapProviderAuthError<A, R>(self: Effect.Effect<A, ProviderAuth.Error, R>) {
return self.pipe(
Effect.mapError((error) => {
if (error instanceof ProviderAuth.OauthMissing) {
return new ProviderAuthApiError({ name: error._tag, data: { providerID: error.providerID } })
}
if (error instanceof ProviderAuth.OauthCodeMissing) {
return new ProviderAuthApiError({ name: error._tag, data: { providerID: error.providerID } })
}
if (error instanceof ProviderAuth.OauthCallbackFailed) {
return new ProviderAuthApiError({ name: error._tag, data: {} })
}
if (error instanceof ProviderAuth.ValidationFailed) {
return new ProviderAuthApiError({ name: error._tag, data: { field: error.field, message: error.message } })
}
return new ProviderAuthApiError({ name: "BadRequest", data: {} })
}),
)
}
export const providerHandlers = HttpApiBuilder.group(InstanceHttpApi, "provider", (handlers) =>
Effect.gen(function* () {
@@ -44,13 +65,13 @@ export const providerHandlers = HttpApiBuilder.group(InstanceHttpApi, "provider"
params: { providerID: ProviderID }
payload: ProviderAuth.AuthorizeInput
}) {
return yield* svc
.authorize({
return yield* mapProviderAuthError(
svc.authorize({
providerID: ctx.params.providerID,
method: ctx.payload.method,
inputs: ctx.payload.inputs,
})
.pipe(Effect.catch(() => Effect.fail(new HttpApiError.BadRequest({}))))
}),
)
})
const authorizeRaw = Effect.fn("ProviderHttpApi.authorizeRaw")(function* (ctx: {
@@ -59,7 +80,7 @@ export const providerHandlers = HttpApiBuilder.group(InstanceHttpApi, "provider"
}) {
const body = yield* Effect.orDie(ctx.request.text)
const payload = yield* Schema.decodeUnknownEffect(Schema.fromJsonString(ProviderAuth.AuthorizeInput))(body).pipe(
Effect.mapError(() => new HttpApiError.BadRequest({})),
Effect.mapError(() => new ProviderAuthApiError({ name: "BadRequest", data: {} })),
)
// Match legacy route behavior: when authorize() resolves without a
// result (e.g. no further redirect), serialize as JSON `null` instead
@@ -72,13 +93,13 @@ export const providerHandlers = HttpApiBuilder.group(InstanceHttpApi, "provider"
params: { providerID: ProviderID }
payload: ProviderAuth.CallbackInput
}) {
yield* svc
.callback({
yield* mapProviderAuthError(
svc.callback({
providerID: ctx.params.providerID,
method: ctx.payload.method,
code: ctx.payload.code,
})
.pipe(Effect.catch(() => Effect.fail(new HttpApiError.BadRequest({}))))
}),
)
return true
})

View File

@@ -28,7 +28,6 @@ export const errorLayer = HttpRouter.middleware<{ handles: unknown }>()((effect)
HttpServerResponse.jsonUnsafe(error.toObject(), {
status: iife(() => {
if (error instanceof Provider.ModelNotFoundError) return 400
if (error.name === "ProviderAuthValidationFailed") return 400
return 500
}),
}),

View File

@@ -80,12 +80,33 @@ function requestAuthorize(input: {
providerID: string
method: number
headers: HeadersInit
inputs?: Record<string, string>
}) {
return Effect.promise(async () => {
const response = await input.app.request(`/provider/${input.providerID}/oauth/authorize`, {
method: "POST",
headers: input.headers,
body: JSON.stringify({ method: input.method }),
body: JSON.stringify({ method: input.method, ...(input.inputs ? { inputs: input.inputs } : {}) }),
})
return {
status: response.status,
body: await response.text(),
}
})
}
function requestCallback(input: {
app: ReturnType<typeof app>
providerID: string
method: number
headers: HeadersInit
code?: string
}) {
return Effect.promise(async () => {
const response = await input.app.request(`/provider/${input.providerID}/oauth/callback`, {
method: "POST",
headers: input.headers,
body: JSON.stringify({ method: input.method, ...(input.code ? { code: input.code } : {}) }),
})
return {
status: response.status,
@@ -128,6 +149,47 @@ function writeProviderAuthPlugin(dir: string) {
})
}
function writeProviderAuthValidationPlugin(dir: string) {
return Effect.gen(function* () {
const fs = yield* AppFileSystem.Service
yield* fs.writeWithDirs(
path.join(dir, ".opencode", "plugin", "provider-oauth-validation.ts"),
[
"export default {",
' id: "test.provider-oauth-validation",',
" server: async () => ({",
" auth: {",
' provider: "test-oauth-validation",',
" methods: [",
" {",
' type: "oauth",',
' label: "OAuth",',
" prompts: [",
" {",
' type: "text",',
' key: "token",',
' message: "Token",',
" validate: (value) => value === 'ok' ? undefined : 'Token must be ok',",
" },",
" ],",
" authorize: async () => ({",
` url: "${oauthURL}",`,
' method: "code",',
` instructions: "${oauthInstructions}",`,
" callback: async () => ({ type: 'success', key: 'token' }),",
" }),",
" },",
" ],",
" },",
" }),",
"}",
"",
].join("\n"),
)
})
}
function writeFunctionOptionsPlugin(dir: string) {
return Effect.gen(function* () {
const fs = yield* AppFileSystem.Service
@@ -240,6 +302,51 @@ describe("provider HttpApi", () => {
})
}),
projectOptions,
30000,
)
it.instance(
"returns declared provider auth validation errors",
Effect.gen(function* () {
const instance = yield* TestInstance
yield* writeProviderAuthValidationPlugin(instance.directory)
const response = yield* requestAuthorize({
app: app(),
providerID: "test-oauth-validation",
method: 0,
inputs: { token: "nope" },
headers: { "x-opencode-directory": instance.directory, "content-type": "application/json" },
})
expect(response.status).toBe(400)
expect(JSON.parse(response.body)).toEqual({
name: "ProviderAuthValidationFailed",
data: { field: "token", message: "Token must be ok" },
})
}),
projectOptions,
30000,
)
it.instance(
"returns declared provider auth callback errors",
Effect.gen(function* () {
const instance = yield* TestInstance
const response = yield* requestCallback({
app: app(),
providerID,
method: 0,
headers: { "x-opencode-directory": instance.directory, "content-type": "application/json" },
})
expect(response.status).toBe(400)
expect(JSON.parse(response.body)).toEqual({
name: "ProviderAuthOauthMissing",
data: { providerID },
})
}),
projectOptions,
30000,
)
it.instance(

View File

@@ -1721,6 +1721,21 @@ export type ProviderAuthAuthorization = {
instructions: string
}
export type ProviderAuthError1 = {
name:
| "BadRequest"
| "ProviderAuthOauthMissing"
| "ProviderAuthOauthCodeMissing"
| "ProviderAuthOauthCallbackFailed"
| "ProviderAuthValidationFailed"
data: {
providerID?: string
field?: string
message?: string
kind?: string
}
}
export type TextPartInput = {
id?: string
type: "text"
@@ -5155,9 +5170,9 @@ export type ProviderOauthAuthorizeData = {
export type ProviderOauthAuthorizeErrors = {
/**
* Bad request
* ProviderAuthError
*/
400: BadRequestError
400: ProviderAuthError1
}
export type ProviderOauthAuthorizeError = ProviderOauthAuthorizeErrors[keyof ProviderOauthAuthorizeErrors]
@@ -5191,9 +5206,9 @@ export type ProviderOauthCallbackData = {
export type ProviderOauthCallbackErrors = {
/**
* Bad request
* ProviderAuthError
*/
400: BadRequestError
400: ProviderAuthError1
}
export type ProviderOauthCallbackError = ProviderOauthCallbackErrors[keyof ProviderOauthCallbackErrors]