From eb80117c8d2d8e45041b36baa06afbe727a985ed Mon Sep 17 00:00:00 2001 From: Dax Raad Date: Mon, 8 Dec 2025 20:04:37 -0500 Subject: [PATCH] sync --- packages/opencode/src/provider/provider.ts | 2 +- packages/opencode/src/provider/transform.ts | 6 +- packages/opencode/src/session/compaction.ts | 67 ++-------- packages/opencode/src/session/llm.ts | 141 ++++++++++++++++++++ packages/opencode/src/session/message-v2.ts | 16 ++- packages/opencode/src/session/processor.ts | 5 +- packages/opencode/src/session/prompt.ts | 49 +------ 7 files changed, 170 insertions(+), 116 deletions(-) create mode 100644 packages/opencode/src/session/llm.ts diff --git a/packages/opencode/src/provider/provider.ts b/packages/opencode/src/provider/provider.ts index 880238a0d9..b823aceac4 100644 --- a/packages/opencode/src/provider/provider.ts +++ b/packages/opencode/src/provider/provider.ts @@ -838,7 +838,7 @@ export namespace Provider { return info } - export async function getLanguage(model: Model) { + export async function getLanguage(model: Model): Promise { const s = await state() const key = `${model.providerID}/${model.id}` if (s.models.has(key)) return s.models.get(key)! diff --git a/packages/opencode/src/provider/transform.ts b/packages/opencode/src/provider/transform.ts index 17fbf18f5f..fb432860a9 100644 --- a/packages/opencode/src/provider/transform.ts +++ b/packages/opencode/src/provider/transform.ts @@ -273,8 +273,8 @@ export namespace ProviderTransform { return options } - export function providerOptions(npm: string | undefined, providerID: string, options: { [x: string]: any }) { - switch (npm) { + export function providerOptions(model: Provider.Model, options: { [x: string]: any }) { + switch (model.api.npm) { case "@ai-sdk/openai": case "@ai-sdk/azure": return { @@ -302,7 +302,7 @@ export namespace ProviderTransform { } default: return { - [providerID]: options, + [model.providerID]: options, } } } diff --git a/packages/opencode/src/session/compaction.ts b/packages/opencode/src/session/compaction.ts index de75eda6e4..613b49b162 100644 --- a/packages/opencode/src/session/compaction.ts +++ b/packages/opencode/src/session/compaction.ts @@ -1,4 +1,3 @@ -import { wrapLanguageModel, type ModelMessage } from "ai" import { Session } from "." import { Identifier } from "../id/id" import { Instance } from "../project/instance" @@ -12,10 +11,9 @@ import { Flag } from "../flag/flag" import { Token } from "../util/token" import { Config } from "../config/config" import { Log } from "../util/log" -import { ProviderTransform } from "@/provider/transform" import { SessionProcessor } from "./processor" import { fn } from "@/util/fn" -import { mergeDeep, pipe } from "remeda" +import { Agent } from "@/agent/agent" export namespace SessionCompaction { const log = Log.create({ service: "session.compaction" }) @@ -97,9 +95,7 @@ export namespace SessionCompaction { abort: AbortSignal auto: boolean }) { - const cfg = await Config.get() const model = await Provider.getModel(input.model.providerID, input.model.modelID) - const language = await Provider.getLanguage(model) const system = [...SystemPrompt.compaction(model.providerID)] const msg = (await Session.updateMessage({ id: Identifier.ascending("message"), @@ -131,44 +127,16 @@ export namespace SessionCompaction { model: model, abort: input.abort, }) + const agent = await Agent.get(input.agent) const result = await processor.process({ - onError(error) { - log.error("stream error", { - error, - }) - }, - // set to 0, we handle loop - maxRetries: 0, - providerOptions: ProviderTransform.providerOptions( - model.api.npm, - model.providerID, - pipe({}, mergeDeep(ProviderTransform.options(model, input.sessionID)), mergeDeep(model.options)), - ), - headers: model.headers, - abortSignal: input.abort, - tools: model.capabilities.toolcall ? {} : undefined, + requestID: input.parentID, + agent, + abort: input.abort, + sessionID: input.sessionID, + tools: {}, + system, messages: [ - ...system.map( - (x): ModelMessage => ({ - role: "system", - content: x, - }), - ), - ...MessageV2.toModelMessage( - input.messages.filter((m) => { - if (m.info.role !== "assistant" || m.info.error === undefined) { - return true - } - if ( - MessageV2.AbortedError.isInstance(m.info.error) && - m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning") - ) { - return true - } - - return false - }), - ), + ...MessageV2.toModelMessage(input.messages), { role: "user", content: [ @@ -179,22 +147,9 @@ export namespace SessionCompaction { ], }, ], - model: wrapLanguageModel({ - model: language, - middleware: [ - { - async transformParams(args) { - if (args.type === "stream") { - // @ts-expect-error - args.params.prompt = ProviderTransform.message(args.params.prompt, model) - } - return args.params - }, - }, - ], - }), - experimental_telemetry: { isEnabled: cfg.experimental?.openTelemetry }, + model, }) + if (result === "continue" && input.auto) { const continueMsg = await Session.updateMessage({ id: Identifier.ascending("message"), diff --git a/packages/opencode/src/session/llm.ts b/packages/opencode/src/session/llm.ts new file mode 100644 index 0000000000..a4996a31aa --- /dev/null +++ b/packages/opencode/src/session/llm.ts @@ -0,0 +1,141 @@ +import { Provider } from "@/provider/provider" +import { Log } from "@/util/log" +import { streamText, wrapLanguageModel, type ModelMessage, type StreamTextResult, type Tool, type ToolSet } from "ai" +import { mergeDeep, pipe } from "remeda" +import { ProviderTransform } from "@/provider/transform" +import { iife } from "@/util/iife" +import { Config } from "@/config/config" +import { Instance } from "@/project/instance" +import type { Agent } from "@/agent/agent" + +export namespace LLM { + const log = Log.create({ service: "llm" }) + + export const OUTPUT_TOKEN_MAX = 32_000 + + export type StreamInput = { + requestID: string + sessionID: string + model: Provider.Model + agent: Agent.Info + system: string[] + abort: AbortSignal + messages: ModelMessage[] + tools: Record + retries?: number + } + + export type StreamOutput = StreamTextResult + + export async function stream(input: StreamInput) { + const [language, cfg] = await Promise.all([Provider.getLanguage(input.model), Config.get()]) + + const [first, ...rest] = input.system + const system = [first, rest.join("\n")] + const options = pipe( + ProviderTransform.options(input.model, input.sessionID), + mergeDeep(input.model.options), + mergeDeep(input.agent.options), + ) + const maxOutputTokens = ProviderTransform.maxOutputTokens( + input.model.api.npm, + options, + input.model.limit.output, + OUTPUT_TOKEN_MAX, + ) + const temperature = input.model.capabilities.temperature + ? (input.agent.temperature ?? ProviderTransform.temperature(input.model)) + : undefined + const topP = input.agent.topP ?? ProviderTransform.topP(input.model) + + return streamText({ + onError(error) { + log.error("stream error", { + error, + }) + }, + async experimental_repairToolCall(failed) { + const lower = failed.toolCall.toolName.toLowerCase() + if (lower !== failed.toolCall.toolName && input.tools[lower]) { + log.info("repairing tool call", { + tool: failed.toolCall.toolName, + repaired: lower, + }) + return { + ...failed.toolCall, + toolName: lower, + } + } + return { + ...failed.toolCall, + input: JSON.stringify({ + tool: failed.toolCall.toolName, + error: failed.error.message, + }), + toolName: "invalid", + } + }, + temperature, + topP, + providerOptions: { + [iife(() => { + switch (input.model.api.npm) { + case "@ai-sdk/openai": + case "@ai-sdk/azure": + return `openai` + case "@ai-sdk/amazon-bedrock": + return `bedrock` + case "@ai-sdk/anthropic": + return `anthropic` + case "@ai-sdk/google": + return `google` + case "@ai-sdk/gateway": + return `gateway` + case "@openrouter/ai-sdk-provider": + return `openrouter` + default: + return input.model.providerID + } + })]: options, + }, + activeTools: Object.keys(input.tools).filter((x) => x !== "invalid"), + maxOutputTokens, + abortSignal: input.abort, + headers: { + ...(input.model.providerID.startsWith("opencode") + ? { + "x-opencode-project": Instance.project.id, + "x-opencode-session": input.sessionID, + "x-opencode-request": input.requestID, + } + : undefined), + ...input.model.headers, + }, + maxRetries: input.retries ?? 0, + messages: [ + ...system.map( + (x): ModelMessage => ({ + role: "system", + content: x, + }), + ), + ...input.messages, + ], + model: wrapLanguageModel({ + model: language, + middleware: [ + { + async transformParams(args) { + if (args.type === "stream") { + // @ts-expect-error + args.params.prompt = ProviderTransform.message(args.params.prompt, input.model) + } + return args.params + }, + }, + ], + }), + experimental_telemetry: { isEnabled: cfg.experimental?.openTelemetry }, + }) + } +} diff --git a/packages/opencode/src/session/message-v2.ts b/packages/opencode/src/session/message-v2.ts index 50a480626e..ed918d5bdb 100644 --- a/packages/opencode/src/session/message-v2.ts +++ b/packages/opencode/src/session/message-v2.ts @@ -411,12 +411,7 @@ export namespace MessageV2 { }) export type WithParts = z.infer - export function toModelMessage( - input: { - info: Info - parts: Part[] - }[], - ): ModelMessage[] { + export function toModelMessage(input: WithParts[]): ModelMessage[] { const result: UIMessage[] = [] for (const msg of input) { @@ -460,6 +455,15 @@ export namespace MessageV2 { } if (msg.info.role === "assistant") { + if ( + msg.info.error && + !( + MessageV2.AbortedError.isInstance(msg.info.error) && + msg.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning") + ) + ) { + continue + } const assistantMessage: UIMessage = { id: msg.info.id, role: "assistant", diff --git a/packages/opencode/src/session/processor.ts b/packages/opencode/src/session/processor.ts index f1f7dd0964..fe2756957f 100644 --- a/packages/opencode/src/session/processor.ts +++ b/packages/opencode/src/session/processor.ts @@ -12,6 +12,7 @@ import { SessionRetry } from "./retry" import { SessionStatus } from "./status" import { Plugin } from "@/plugin" import type { Provider } from "@/provider/provider" +import { LLM } from "./llm" export namespace SessionProcessor { const DOOM_LOOP_THRESHOLD = 3 @@ -47,13 +48,13 @@ export namespace SessionProcessor { partFromToolCall(toolCallID: string) { return toolcalls[toolCallID] }, - async process(streamInput: StreamInput) { + async process(streamInput: LLM.StreamInput) { log.info("process") while (true) { try { let currentText: MessageV2.TextPart | undefined let reasoningMap: Record = {} - const stream = streamText(streamInput) + const stream = await LLM.stream(streamInput) for await (const value of stream.fullStream) { input.abort.throwIfAborted() diff --git a/packages/opencode/src/session/prompt.ts b/packages/opencode/src/session/prompt.ts index d5010bc47d..e819345830 100644 --- a/packages/opencode/src/session/prompt.ts +++ b/packages/opencode/src/session/prompt.ts @@ -516,54 +516,6 @@ export namespace SessionPrompt { } const result = await processor.process({ - onError(error) { - log.error("stream error", { - error, - }) - }, - async experimental_repairToolCall(input) { - const lower = input.toolCall.toolName.toLowerCase() - if (lower !== input.toolCall.toolName && tools[lower]) { - log.info("repairing tool call", { - tool: input.toolCall.toolName, - repaired: lower, - }) - return { - ...input.toolCall, - toolName: lower, - } - } - return { - ...input.toolCall, - input: JSON.stringify({ - tool: input.toolCall.toolName, - error: input.error.message, - }), - toolName: "invalid", - } - }, - headers: { - ...(model.providerID.startsWith("opencode") - ? { - "x-opencode-project": Instance.project.id, - "x-opencode-session": sessionID, - "x-opencode-request": lastUser.id, - } - : undefined), - ...model.headers, - }, - // set to 0, we handle loop - maxRetries: 0, - activeTools: Object.keys(tools).filter((x) => x !== "invalid"), - maxOutputTokens: ProviderTransform.maxOutputTokens( - model.api.npm, - params.options, - model.limit.output, - OUTPUT_TOKEN_MAX, - ), - abortSignal: abort, - providerOptions: ProviderTransform.providerOptions(model.api.npm, model.providerID, params.options), - stopWhen: stepCountIs(1), temperature: params.temperature, topP: params.topP, toolChoice: isLastStep ? "none" : undefined, @@ -692,6 +644,7 @@ export namespace SessionPrompt { mergeDeep(await ToolRegistry.enabled(input.agent)), mergeDeep(input.tools ?? {}), ) + for (const item of await ToolRegistry.tools(input.model.providerID)) { if (Wildcard.all(item.id, enabledTools) === false) continue const schema = ProviderTransform.schema(input.model, z.toJSONSchema(item.parameters))