mirror of
https://github.com/anomalyco/opencode.git
synced 2026-05-13 15:44:56 +00:00
fix(provider): type auth errors (#27301)
This commit is contained in:
@@ -1,7 +1,6 @@
|
|||||||
import type { AuthOAuthResult, Hooks } from "@opencode-ai/plugin"
|
import type { AuthOAuthResult, Hooks } from "@opencode-ai/plugin"
|
||||||
import { Auth } from "@/auth"
|
import { Auth } from "@/auth"
|
||||||
import { InstanceState } from "@/effect/instance-state"
|
import { InstanceState } from "@/effect/instance-state"
|
||||||
import { NamedError } from "@opencode-ai/core/util/error"
|
|
||||||
import { optionalOmitUndefined } from "@opencode-ai/core/schema"
|
import { optionalOmitUndefined } from "@opencode-ai/core/schema"
|
||||||
import { Plugin } from "../plugin"
|
import { Plugin } from "../plugin"
|
||||||
import { ProviderID } from "./schema"
|
import { ProviderID } from "./schema"
|
||||||
@@ -64,23 +63,30 @@ export const CallbackInput = Schema.Struct({
|
|||||||
})
|
})
|
||||||
export type CallbackInput = Schema.Schema.Type<typeof CallbackInput>
|
export type CallbackInput = Schema.Schema.Type<typeof CallbackInput>
|
||||||
|
|
||||||
export const OauthMissing = NamedError.create("ProviderAuthOauthMissing", { providerID: ProviderID })
|
export class OauthMissing extends Schema.TaggedErrorClass<OauthMissing>()("ProviderAuthOauthMissing", {
|
||||||
|
providerID: ProviderID,
|
||||||
|
}) {}
|
||||||
|
|
||||||
export const OauthCodeMissing = NamedError.create("ProviderAuthOauthCodeMissing", { providerID: ProviderID })
|
export class OauthCodeMissing extends Schema.TaggedErrorClass<OauthCodeMissing>()("ProviderAuthOauthCodeMissing", {
|
||||||
|
providerID: ProviderID,
|
||||||
|
}) {}
|
||||||
|
|
||||||
export const OauthCallbackFailed = NamedError.create("ProviderAuthOauthCallbackFailed", {})
|
export class OauthCallbackFailed extends Schema.TaggedErrorClass<OauthCallbackFailed>()(
|
||||||
|
"ProviderAuthOauthCallbackFailed",
|
||||||
|
{},
|
||||||
|
) {}
|
||||||
|
|
||||||
export const ValidationFailed = NamedError.create("ProviderAuthValidationFailed", {
|
export class ValidationFailed extends Schema.TaggedErrorClass<ValidationFailed>()("ProviderAuthValidationFailed", {
|
||||||
field: Schema.String,
|
field: Schema.String,
|
||||||
message: Schema.String,
|
message: Schema.String,
|
||||||
})
|
}) {}
|
||||||
|
|
||||||
export type Error =
|
export type Error =
|
||||||
| Auth.AuthError
|
| Auth.AuthError
|
||||||
| InstanceType<typeof OauthMissing>
|
| OauthMissing
|
||||||
| InstanceType<typeof OauthCodeMissing>
|
| OauthCodeMissing
|
||||||
| InstanceType<typeof OauthCallbackFailed>
|
| OauthCallbackFailed
|
||||||
| InstanceType<typeof ValidationFailed>
|
| ValidationFailed
|
||||||
|
|
||||||
type Hook = NonNullable<Hooks["auth"]>
|
type Hook = NonNullable<Hooks["auth"]>
|
||||||
|
|
||||||
@@ -166,7 +172,7 @@ export const layer: Layer.Layer<Service, never, Auth.Service | Plugin.Service> =
|
|||||||
for (const prompt of method.prompts) {
|
for (const prompt of method.prompts) {
|
||||||
if (prompt.type === "text" && prompt.validate && input.inputs[prompt.key] !== undefined) {
|
if (prompt.type === "text" && prompt.validate && input.inputs[prompt.key] !== undefined) {
|
||||||
const error = prompt.validate(input.inputs[prompt.key])
|
const error = prompt.validate(input.inputs[prompt.key])
|
||||||
if (error) return yield* Effect.fail(new ValidationFailed({ field: prompt.key, message: error }))
|
if (error) return yield* new ValidationFailed({ field: prompt.key, message: error })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -183,15 +189,15 @@ export const layer: Layer.Layer<Service, never, Auth.Service | Plugin.Service> =
|
|||||||
const callback = Effect.fn("ProviderAuth.callback")(function* (input: { providerID: ProviderID } & CallbackInput) {
|
const callback = Effect.fn("ProviderAuth.callback")(function* (input: { providerID: ProviderID } & CallbackInput) {
|
||||||
const pending = (yield* InstanceState.get(state)).pending
|
const pending = (yield* InstanceState.get(state)).pending
|
||||||
const match = pending.get(input.providerID)
|
const match = pending.get(input.providerID)
|
||||||
if (!match) return yield* Effect.fail(new OauthMissing({ providerID: input.providerID }))
|
if (!match) return yield* new OauthMissing({ providerID: input.providerID })
|
||||||
if (match.method === "code" && !input.code) {
|
if (match.method === "code" && !input.code) {
|
||||||
return yield* Effect.fail(new OauthCodeMissing({ providerID: input.providerID }))
|
return yield* new OauthCodeMissing({ providerID: input.providerID })
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = yield* Effect.promise(() =>
|
const result = yield* Effect.promise(() =>
|
||||||
match.method === "code" ? match.callback(input.code!) : match.callback(),
|
match.method === "code" ? match.callback(input.code!) : match.callback(),
|
||||||
)
|
)
|
||||||
if (!result || result.type !== "success") return yield* Effect.fail(new OauthCallbackFailed({}))
|
if (!result || result.type !== "success") return yield* new OauthCallbackFailed({})
|
||||||
|
|
||||||
if ("key" in result) {
|
if ("key" in result) {
|
||||||
yield* auth.set(input.providerID, {
|
yield* auth.set(input.providerID, {
|
||||||
|
|||||||
@@ -10,6 +10,26 @@ import { described } from "./metadata"
|
|||||||
|
|
||||||
const root = "/provider"
|
const root = "/provider"
|
||||||
|
|
||||||
|
const ProviderAuthErrorName = Schema.Union([
|
||||||
|
Schema.Literal("BadRequest"),
|
||||||
|
Schema.Literal("ProviderAuthOauthMissing"),
|
||||||
|
Schema.Literal("ProviderAuthOauthCodeMissing"),
|
||||||
|
Schema.Literal("ProviderAuthOauthCallbackFailed"),
|
||||||
|
Schema.Literal("ProviderAuthValidationFailed"),
|
||||||
|
])
|
||||||
|
export class ProviderAuthApiError extends Schema.ErrorClass<ProviderAuthApiError>("ProviderAuthError")(
|
||||||
|
{
|
||||||
|
name: ProviderAuthErrorName,
|
||||||
|
data: Schema.Struct({
|
||||||
|
providerID: Schema.optional(ProviderID),
|
||||||
|
field: Schema.optional(Schema.String),
|
||||||
|
message: Schema.optional(Schema.String),
|
||||||
|
kind: Schema.optional(Schema.String),
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
{ httpApiStatus: 400 },
|
||||||
|
) {}
|
||||||
|
|
||||||
export const ProviderApi = HttpApi.make("provider")
|
export const ProviderApi = HttpApi.make("provider")
|
||||||
.add(
|
.add(
|
||||||
HttpApiGroup.make("provider")
|
HttpApiGroup.make("provider")
|
||||||
@@ -39,7 +59,7 @@ export const ProviderApi = HttpApi.make("provider")
|
|||||||
query: WorkspaceRoutingQuery,
|
query: WorkspaceRoutingQuery,
|
||||||
payload: ProviderAuth.AuthorizeInput,
|
payload: ProviderAuth.AuthorizeInput,
|
||||||
success: described(Schema.UndefinedOr(ProviderAuth.Authorization), "Authorization URL and method"),
|
success: described(Schema.UndefinedOr(ProviderAuth.Authorization), "Authorization URL and method"),
|
||||||
error: HttpApiError.BadRequest,
|
error: ProviderAuthApiError,
|
||||||
}).annotateMerge(
|
}).annotateMerge(
|
||||||
OpenApi.annotations({
|
OpenApi.annotations({
|
||||||
identifier: "provider.oauth.authorize",
|
identifier: "provider.oauth.authorize",
|
||||||
@@ -52,7 +72,7 @@ export const ProviderApi = HttpApi.make("provider")
|
|||||||
query: WorkspaceRoutingQuery,
|
query: WorkspaceRoutingQuery,
|
||||||
payload: ProviderAuth.CallbackInput,
|
payload: ProviderAuth.CallbackInput,
|
||||||
success: described(Schema.Boolean, "OAuth callback processed successfully"),
|
success: described(Schema.Boolean, "OAuth callback processed successfully"),
|
||||||
error: HttpApiError.BadRequest,
|
error: ProviderAuthApiError,
|
||||||
}).annotateMerge(
|
}).annotateMerge(
|
||||||
OpenApi.annotations({
|
OpenApi.annotations({
|
||||||
identifier: "provider.oauth.callback",
|
identifier: "provider.oauth.callback",
|
||||||
|
|||||||
@@ -6,8 +6,29 @@ import { ProviderID } from "@/provider/schema"
|
|||||||
import { mapValues } from "remeda"
|
import { mapValues } from "remeda"
|
||||||
import { Effect, Schema } from "effect"
|
import { Effect, Schema } from "effect"
|
||||||
import { HttpServerRequest, HttpServerResponse } from "effect/unstable/http"
|
import { HttpServerRequest, HttpServerResponse } from "effect/unstable/http"
|
||||||
import { HttpApiBuilder, HttpApiError } from "effect/unstable/httpapi"
|
import { HttpApiBuilder } from "effect/unstable/httpapi"
|
||||||
import { InstanceHttpApi } from "../api"
|
import { InstanceHttpApi } from "../api"
|
||||||
|
import { ProviderAuthApiError } from "../groups/provider"
|
||||||
|
|
||||||
|
function mapProviderAuthError<A, R>(self: Effect.Effect<A, ProviderAuth.Error, R>) {
|
||||||
|
return self.pipe(
|
||||||
|
Effect.mapError((error) => {
|
||||||
|
if (error instanceof ProviderAuth.OauthMissing) {
|
||||||
|
return new ProviderAuthApiError({ name: error._tag, data: { providerID: error.providerID } })
|
||||||
|
}
|
||||||
|
if (error instanceof ProviderAuth.OauthCodeMissing) {
|
||||||
|
return new ProviderAuthApiError({ name: error._tag, data: { providerID: error.providerID } })
|
||||||
|
}
|
||||||
|
if (error instanceof ProviderAuth.OauthCallbackFailed) {
|
||||||
|
return new ProviderAuthApiError({ name: error._tag, data: {} })
|
||||||
|
}
|
||||||
|
if (error instanceof ProviderAuth.ValidationFailed) {
|
||||||
|
return new ProviderAuthApiError({ name: error._tag, data: { field: error.field, message: error.message } })
|
||||||
|
}
|
||||||
|
return new ProviderAuthApiError({ name: "BadRequest", data: {} })
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
export const providerHandlers = HttpApiBuilder.group(InstanceHttpApi, "provider", (handlers) =>
|
export const providerHandlers = HttpApiBuilder.group(InstanceHttpApi, "provider", (handlers) =>
|
||||||
Effect.gen(function* () {
|
Effect.gen(function* () {
|
||||||
@@ -44,13 +65,13 @@ export const providerHandlers = HttpApiBuilder.group(InstanceHttpApi, "provider"
|
|||||||
params: { providerID: ProviderID }
|
params: { providerID: ProviderID }
|
||||||
payload: ProviderAuth.AuthorizeInput
|
payload: ProviderAuth.AuthorizeInput
|
||||||
}) {
|
}) {
|
||||||
return yield* svc
|
return yield* mapProviderAuthError(
|
||||||
.authorize({
|
svc.authorize({
|
||||||
providerID: ctx.params.providerID,
|
providerID: ctx.params.providerID,
|
||||||
method: ctx.payload.method,
|
method: ctx.payload.method,
|
||||||
inputs: ctx.payload.inputs,
|
inputs: ctx.payload.inputs,
|
||||||
})
|
}),
|
||||||
.pipe(Effect.catch(() => Effect.fail(new HttpApiError.BadRequest({}))))
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
const authorizeRaw = Effect.fn("ProviderHttpApi.authorizeRaw")(function* (ctx: {
|
const authorizeRaw = Effect.fn("ProviderHttpApi.authorizeRaw")(function* (ctx: {
|
||||||
@@ -59,7 +80,7 @@ export const providerHandlers = HttpApiBuilder.group(InstanceHttpApi, "provider"
|
|||||||
}) {
|
}) {
|
||||||
const body = yield* Effect.orDie(ctx.request.text)
|
const body = yield* Effect.orDie(ctx.request.text)
|
||||||
const payload = yield* Schema.decodeUnknownEffect(Schema.fromJsonString(ProviderAuth.AuthorizeInput))(body).pipe(
|
const payload = yield* Schema.decodeUnknownEffect(Schema.fromJsonString(ProviderAuth.AuthorizeInput))(body).pipe(
|
||||||
Effect.mapError(() => new HttpApiError.BadRequest({})),
|
Effect.mapError(() => new ProviderAuthApiError({ name: "BadRequest", data: {} })),
|
||||||
)
|
)
|
||||||
// Match legacy route behavior: when authorize() resolves without a
|
// Match legacy route behavior: when authorize() resolves without a
|
||||||
// result (e.g. no further redirect), serialize as JSON `null` instead
|
// result (e.g. no further redirect), serialize as JSON `null` instead
|
||||||
@@ -72,13 +93,13 @@ export const providerHandlers = HttpApiBuilder.group(InstanceHttpApi, "provider"
|
|||||||
params: { providerID: ProviderID }
|
params: { providerID: ProviderID }
|
||||||
payload: ProviderAuth.CallbackInput
|
payload: ProviderAuth.CallbackInput
|
||||||
}) {
|
}) {
|
||||||
yield* svc
|
yield* mapProviderAuthError(
|
||||||
.callback({
|
svc.callback({
|
||||||
providerID: ctx.params.providerID,
|
providerID: ctx.params.providerID,
|
||||||
method: ctx.payload.method,
|
method: ctx.payload.method,
|
||||||
code: ctx.payload.code,
|
code: ctx.payload.code,
|
||||||
})
|
}),
|
||||||
.pipe(Effect.catch(() => Effect.fail(new HttpApiError.BadRequest({}))))
|
)
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ export const errorLayer = HttpRouter.middleware<{ handles: unknown }>()((effect)
|
|||||||
HttpServerResponse.jsonUnsafe(error.toObject(), {
|
HttpServerResponse.jsonUnsafe(error.toObject(), {
|
||||||
status: iife(() => {
|
status: iife(() => {
|
||||||
if (error instanceof Provider.ModelNotFoundError) return 400
|
if (error instanceof Provider.ModelNotFoundError) return 400
|
||||||
if (error.name === "ProviderAuthValidationFailed") return 400
|
|
||||||
return 500
|
return 500
|
||||||
}),
|
}),
|
||||||
}),
|
}),
|
||||||
|
|||||||
@@ -80,12 +80,33 @@ function requestAuthorize(input: {
|
|||||||
providerID: string
|
providerID: string
|
||||||
method: number
|
method: number
|
||||||
headers: HeadersInit
|
headers: HeadersInit
|
||||||
|
inputs?: Record<string, string>
|
||||||
}) {
|
}) {
|
||||||
return Effect.promise(async () => {
|
return Effect.promise(async () => {
|
||||||
const response = await input.app.request(`/provider/${input.providerID}/oauth/authorize`, {
|
const response = await input.app.request(`/provider/${input.providerID}/oauth/authorize`, {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
headers: input.headers,
|
headers: input.headers,
|
||||||
body: JSON.stringify({ method: input.method }),
|
body: JSON.stringify({ method: input.method, ...(input.inputs ? { inputs: input.inputs } : {}) }),
|
||||||
|
})
|
||||||
|
return {
|
||||||
|
status: response.status,
|
||||||
|
body: await response.text(),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
function requestCallback(input: {
|
||||||
|
app: ReturnType<typeof app>
|
||||||
|
providerID: string
|
||||||
|
method: number
|
||||||
|
headers: HeadersInit
|
||||||
|
code?: string
|
||||||
|
}) {
|
||||||
|
return Effect.promise(async () => {
|
||||||
|
const response = await input.app.request(`/provider/${input.providerID}/oauth/callback`, {
|
||||||
|
method: "POST",
|
||||||
|
headers: input.headers,
|
||||||
|
body: JSON.stringify({ method: input.method, ...(input.code ? { code: input.code } : {}) }),
|
||||||
})
|
})
|
||||||
return {
|
return {
|
||||||
status: response.status,
|
status: response.status,
|
||||||
@@ -128,6 +149,47 @@ function writeProviderAuthPlugin(dir: string) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function writeProviderAuthValidationPlugin(dir: string) {
|
||||||
|
return Effect.gen(function* () {
|
||||||
|
const fs = yield* AppFileSystem.Service
|
||||||
|
|
||||||
|
yield* fs.writeWithDirs(
|
||||||
|
path.join(dir, ".opencode", "plugin", "provider-oauth-validation.ts"),
|
||||||
|
[
|
||||||
|
"export default {",
|
||||||
|
' id: "test.provider-oauth-validation",',
|
||||||
|
" server: async () => ({",
|
||||||
|
" auth: {",
|
||||||
|
' provider: "test-oauth-validation",',
|
||||||
|
" methods: [",
|
||||||
|
" {",
|
||||||
|
' type: "oauth",',
|
||||||
|
' label: "OAuth",',
|
||||||
|
" prompts: [",
|
||||||
|
" {",
|
||||||
|
' type: "text",',
|
||||||
|
' key: "token",',
|
||||||
|
' message: "Token",',
|
||||||
|
" validate: (value) => value === 'ok' ? undefined : 'Token must be ok',",
|
||||||
|
" },",
|
||||||
|
" ],",
|
||||||
|
" authorize: async () => ({",
|
||||||
|
` url: "${oauthURL}",`,
|
||||||
|
' method: "code",',
|
||||||
|
` instructions: "${oauthInstructions}",`,
|
||||||
|
" callback: async () => ({ type: 'success', key: 'token' }),",
|
||||||
|
" }),",
|
||||||
|
" },",
|
||||||
|
" ],",
|
||||||
|
" },",
|
||||||
|
" }),",
|
||||||
|
"}",
|
||||||
|
"",
|
||||||
|
].join("\n"),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
function writeFunctionOptionsPlugin(dir: string) {
|
function writeFunctionOptionsPlugin(dir: string) {
|
||||||
return Effect.gen(function* () {
|
return Effect.gen(function* () {
|
||||||
const fs = yield* AppFileSystem.Service
|
const fs = yield* AppFileSystem.Service
|
||||||
@@ -240,6 +302,51 @@ describe("provider HttpApi", () => {
|
|||||||
})
|
})
|
||||||
}),
|
}),
|
||||||
projectOptions,
|
projectOptions,
|
||||||
|
30000,
|
||||||
|
)
|
||||||
|
|
||||||
|
it.instance(
|
||||||
|
"returns declared provider auth validation errors",
|
||||||
|
Effect.gen(function* () {
|
||||||
|
const instance = yield* TestInstance
|
||||||
|
yield* writeProviderAuthValidationPlugin(instance.directory)
|
||||||
|
const response = yield* requestAuthorize({
|
||||||
|
app: app(),
|
||||||
|
providerID: "test-oauth-validation",
|
||||||
|
method: 0,
|
||||||
|
inputs: { token: "nope" },
|
||||||
|
headers: { "x-opencode-directory": instance.directory, "content-type": "application/json" },
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(response.status).toBe(400)
|
||||||
|
expect(JSON.parse(response.body)).toEqual({
|
||||||
|
name: "ProviderAuthValidationFailed",
|
||||||
|
data: { field: "token", message: "Token must be ok" },
|
||||||
|
})
|
||||||
|
}),
|
||||||
|
projectOptions,
|
||||||
|
30000,
|
||||||
|
)
|
||||||
|
|
||||||
|
it.instance(
|
||||||
|
"returns declared provider auth callback errors",
|
||||||
|
Effect.gen(function* () {
|
||||||
|
const instance = yield* TestInstance
|
||||||
|
const response = yield* requestCallback({
|
||||||
|
app: app(),
|
||||||
|
providerID,
|
||||||
|
method: 0,
|
||||||
|
headers: { "x-opencode-directory": instance.directory, "content-type": "application/json" },
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(response.status).toBe(400)
|
||||||
|
expect(JSON.parse(response.body)).toEqual({
|
||||||
|
name: "ProviderAuthOauthMissing",
|
||||||
|
data: { providerID },
|
||||||
|
})
|
||||||
|
}),
|
||||||
|
projectOptions,
|
||||||
|
30000,
|
||||||
)
|
)
|
||||||
|
|
||||||
it.instance(
|
it.instance(
|
||||||
|
|||||||
@@ -1721,6 +1721,21 @@ export type ProviderAuthAuthorization = {
|
|||||||
instructions: string
|
instructions: string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export type ProviderAuthError1 = {
|
||||||
|
name:
|
||||||
|
| "BadRequest"
|
||||||
|
| "ProviderAuthOauthMissing"
|
||||||
|
| "ProviderAuthOauthCodeMissing"
|
||||||
|
| "ProviderAuthOauthCallbackFailed"
|
||||||
|
| "ProviderAuthValidationFailed"
|
||||||
|
data: {
|
||||||
|
providerID?: string
|
||||||
|
field?: string
|
||||||
|
message?: string
|
||||||
|
kind?: string
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
export type TextPartInput = {
|
export type TextPartInput = {
|
||||||
id?: string
|
id?: string
|
||||||
type: "text"
|
type: "text"
|
||||||
@@ -5155,9 +5170,9 @@ export type ProviderOauthAuthorizeData = {
|
|||||||
|
|
||||||
export type ProviderOauthAuthorizeErrors = {
|
export type ProviderOauthAuthorizeErrors = {
|
||||||
/**
|
/**
|
||||||
* Bad request
|
* ProviderAuthError
|
||||||
*/
|
*/
|
||||||
400: BadRequestError
|
400: ProviderAuthError1
|
||||||
}
|
}
|
||||||
|
|
||||||
export type ProviderOauthAuthorizeError = ProviderOauthAuthorizeErrors[keyof ProviderOauthAuthorizeErrors]
|
export type ProviderOauthAuthorizeError = ProviderOauthAuthorizeErrors[keyof ProviderOauthAuthorizeErrors]
|
||||||
@@ -5191,9 +5206,9 @@ export type ProviderOauthCallbackData = {
|
|||||||
|
|
||||||
export type ProviderOauthCallbackErrors = {
|
export type ProviderOauthCallbackErrors = {
|
||||||
/**
|
/**
|
||||||
* Bad request
|
* ProviderAuthError
|
||||||
*/
|
*/
|
||||||
400: BadRequestError
|
400: ProviderAuthError1
|
||||||
}
|
}
|
||||||
|
|
||||||
export type ProviderOauthCallbackError = ProviderOauthCallbackErrors[keyof ProviderOauthCallbackErrors]
|
export type ProviderOauthCallbackError = ProviderOauthCallbackErrors[keyof ProviderOauthCallbackErrors]
|
||||||
|
|||||||
Reference in New Issue
Block a user