This commit is contained in:
Dax Raad
2025-12-08 20:04:37 -05:00
parent d22754dd68
commit eb80117c8d
7 changed files with 170 additions and 116 deletions

View File

@@ -838,7 +838,7 @@ export namespace Provider {
return info
}
export async function getLanguage(model: Model) {
export async function getLanguage(model: Model): Promise<LanguageModelV2> {
const s = await state()
const key = `${model.providerID}/${model.id}`
if (s.models.has(key)) return s.models.get(key)!

View File

@@ -273,8 +273,8 @@ export namespace ProviderTransform {
return options
}
export function providerOptions(npm: string | undefined, providerID: string, options: { [x: string]: any }) {
switch (npm) {
export function providerOptions(model: Provider.Model, options: { [x: string]: any }) {
switch (model.api.npm) {
case "@ai-sdk/openai":
case "@ai-sdk/azure":
return {
@@ -302,7 +302,7 @@ export namespace ProviderTransform {
}
default:
return {
[providerID]: options,
[model.providerID]: options,
}
}
}

View File

@@ -1,4 +1,3 @@
import { wrapLanguageModel, type ModelMessage } from "ai"
import { Session } from "."
import { Identifier } from "../id/id"
import { Instance } from "../project/instance"
@@ -12,10 +11,9 @@ import { Flag } from "../flag/flag"
import { Token } from "../util/token"
import { Config } from "../config/config"
import { Log } from "../util/log"
import { ProviderTransform } from "@/provider/transform"
import { SessionProcessor } from "./processor"
import { fn } from "@/util/fn"
import { mergeDeep, pipe } from "remeda"
import { Agent } from "@/agent/agent"
export namespace SessionCompaction {
const log = Log.create({ service: "session.compaction" })
@@ -97,9 +95,7 @@ export namespace SessionCompaction {
abort: AbortSignal
auto: boolean
}) {
const cfg = await Config.get()
const model = await Provider.getModel(input.model.providerID, input.model.modelID)
const language = await Provider.getLanguage(model)
const system = [...SystemPrompt.compaction(model.providerID)]
const msg = (await Session.updateMessage({
id: Identifier.ascending("message"),
@@ -131,44 +127,16 @@ export namespace SessionCompaction {
model: model,
abort: input.abort,
})
const agent = await Agent.get(input.agent)
const result = await processor.process({
onError(error) {
log.error("stream error", {
error,
})
},
// set to 0, we handle loop
maxRetries: 0,
providerOptions: ProviderTransform.providerOptions(
model.api.npm,
model.providerID,
pipe({}, mergeDeep(ProviderTransform.options(model, input.sessionID)), mergeDeep(model.options)),
),
headers: model.headers,
abortSignal: input.abort,
tools: model.capabilities.toolcall ? {} : undefined,
requestID: input.parentID,
agent,
abort: input.abort,
sessionID: input.sessionID,
tools: {},
system,
messages: [
...system.map(
(x): ModelMessage => ({
role: "system",
content: x,
}),
),
...MessageV2.toModelMessage(
input.messages.filter((m) => {
if (m.info.role !== "assistant" || m.info.error === undefined) {
return true
}
if (
MessageV2.AbortedError.isInstance(m.info.error) &&
m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning")
) {
return true
}
return false
}),
),
...MessageV2.toModelMessage(input.messages),
{
role: "user",
content: [
@@ -179,22 +147,9 @@ export namespace SessionCompaction {
],
},
],
model: wrapLanguageModel({
model: language,
middleware: [
{
async transformParams(args) {
if (args.type === "stream") {
// @ts-expect-error
args.params.prompt = ProviderTransform.message(args.params.prompt, model)
}
return args.params
},
},
],
}),
experimental_telemetry: { isEnabled: cfg.experimental?.openTelemetry },
model,
})
if (result === "continue" && input.auto) {
const continueMsg = await Session.updateMessage({
id: Identifier.ascending("message"),

View File

@@ -0,0 +1,141 @@
import { Provider } from "@/provider/provider"
import { Log } from "@/util/log"
import { streamText, wrapLanguageModel, type ModelMessage, type StreamTextResult, type Tool, type ToolSet } from "ai"
import { mergeDeep, pipe } from "remeda"
import { ProviderTransform } from "@/provider/transform"
import { iife } from "@/util/iife"
import { Config } from "@/config/config"
import { Instance } from "@/project/instance"
import type { Agent } from "@/agent/agent"
export namespace LLM {
const log = Log.create({ service: "llm" })
export const OUTPUT_TOKEN_MAX = 32_000
export type StreamInput = {
requestID: string
sessionID: string
model: Provider.Model
agent: Agent.Info
system: string[]
abort: AbortSignal
messages: ModelMessage[]
tools: Record<string, Tool>
retries?: number
}
export type StreamOutput = StreamTextResult<ToolSet, unknown>
export async function stream(input: StreamInput) {
const [language, cfg] = await Promise.all([Provider.getLanguage(input.model), Config.get()])
const [first, ...rest] = input.system
const system = [first, rest.join("\n")]
const options = pipe(
ProviderTransform.options(input.model, input.sessionID),
mergeDeep(input.model.options),
mergeDeep(input.agent.options),
)
const maxOutputTokens = ProviderTransform.maxOutputTokens(
input.model.api.npm,
options,
input.model.limit.output,
OUTPUT_TOKEN_MAX,
)
const temperature = input.model.capabilities.temperature
? (input.agent.temperature ?? ProviderTransform.temperature(input.model))
: undefined
const topP = input.agent.topP ?? ProviderTransform.topP(input.model)
return streamText({
onError(error) {
log.error("stream error", {
error,
})
},
async experimental_repairToolCall(failed) {
const lower = failed.toolCall.toolName.toLowerCase()
if (lower !== failed.toolCall.toolName && input.tools[lower]) {
log.info("repairing tool call", {
tool: failed.toolCall.toolName,
repaired: lower,
})
return {
...failed.toolCall,
toolName: lower,
}
}
return {
...failed.toolCall,
input: JSON.stringify({
tool: failed.toolCall.toolName,
error: failed.error.message,
}),
toolName: "invalid",
}
},
temperature,
topP,
providerOptions: {
[iife(() => {
switch (input.model.api.npm) {
case "@ai-sdk/openai":
case "@ai-sdk/azure":
return `openai`
case "@ai-sdk/amazon-bedrock":
return `bedrock`
case "@ai-sdk/anthropic":
return `anthropic`
case "@ai-sdk/google":
return `google`
case "@ai-sdk/gateway":
return `gateway`
case "@openrouter/ai-sdk-provider":
return `openrouter`
default:
return input.model.providerID
}
})]: options,
},
activeTools: Object.keys(input.tools).filter((x) => x !== "invalid"),
maxOutputTokens,
abortSignal: input.abort,
headers: {
...(input.model.providerID.startsWith("opencode")
? {
"x-opencode-project": Instance.project.id,
"x-opencode-session": input.sessionID,
"x-opencode-request": input.requestID,
}
: undefined),
...input.model.headers,
},
maxRetries: input.retries ?? 0,
messages: [
...system.map(
(x): ModelMessage => ({
role: "system",
content: x,
}),
),
...input.messages,
],
model: wrapLanguageModel({
model: language,
middleware: [
{
async transformParams(args) {
if (args.type === "stream") {
// @ts-expect-error
args.params.prompt = ProviderTransform.message(args.params.prompt, input.model)
}
return args.params
},
},
],
}),
experimental_telemetry: { isEnabled: cfg.experimental?.openTelemetry },
})
}
}

View File

@@ -411,12 +411,7 @@ export namespace MessageV2 {
})
export type WithParts = z.infer<typeof WithParts>
export function toModelMessage(
input: {
info: Info
parts: Part[]
}[],
): ModelMessage[] {
export function toModelMessage(input: WithParts[]): ModelMessage[] {
const result: UIMessage[] = []
for (const msg of input) {
@@ -460,6 +455,15 @@ export namespace MessageV2 {
}
if (msg.info.role === "assistant") {
if (
msg.info.error &&
!(
MessageV2.AbortedError.isInstance(msg.info.error) &&
msg.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning")
)
) {
continue
}
const assistantMessage: UIMessage = {
id: msg.info.id,
role: "assistant",

View File

@@ -12,6 +12,7 @@ import { SessionRetry } from "./retry"
import { SessionStatus } from "./status"
import { Plugin } from "@/plugin"
import type { Provider } from "@/provider/provider"
import { LLM } from "./llm"
export namespace SessionProcessor {
const DOOM_LOOP_THRESHOLD = 3
@@ -47,13 +48,13 @@ export namespace SessionProcessor {
partFromToolCall(toolCallID: string) {
return toolcalls[toolCallID]
},
async process(streamInput: StreamInput) {
async process(streamInput: LLM.StreamInput) {
log.info("process")
while (true) {
try {
let currentText: MessageV2.TextPart | undefined
let reasoningMap: Record<string, MessageV2.ReasoningPart> = {}
const stream = streamText(streamInput)
const stream = await LLM.stream(streamInput)
for await (const value of stream.fullStream) {
input.abort.throwIfAborted()

View File

@@ -516,54 +516,6 @@ export namespace SessionPrompt {
}
const result = await processor.process({
onError(error) {
log.error("stream error", {
error,
})
},
async experimental_repairToolCall(input) {
const lower = input.toolCall.toolName.toLowerCase()
if (lower !== input.toolCall.toolName && tools[lower]) {
log.info("repairing tool call", {
tool: input.toolCall.toolName,
repaired: lower,
})
return {
...input.toolCall,
toolName: lower,
}
}
return {
...input.toolCall,
input: JSON.stringify({
tool: input.toolCall.toolName,
error: input.error.message,
}),
toolName: "invalid",
}
},
headers: {
...(model.providerID.startsWith("opencode")
? {
"x-opencode-project": Instance.project.id,
"x-opencode-session": sessionID,
"x-opencode-request": lastUser.id,
}
: undefined),
...model.headers,
},
// set to 0, we handle loop
maxRetries: 0,
activeTools: Object.keys(tools).filter((x) => x !== "invalid"),
maxOutputTokens: ProviderTransform.maxOutputTokens(
model.api.npm,
params.options,
model.limit.output,
OUTPUT_TOKEN_MAX,
),
abortSignal: abort,
providerOptions: ProviderTransform.providerOptions(model.api.npm, model.providerID, params.options),
stopWhen: stepCountIs(1),
temperature: params.temperature,
topP: params.topP,
toolChoice: isLastStep ? "none" : undefined,
@@ -692,6 +644,7 @@ export namespace SessionPrompt {
mergeDeep(await ToolRegistry.enabled(input.agent)),
mergeDeep(input.tools ?? {}),
)
for (const item of await ToolRegistry.tools(input.model.providerID)) {
if (Wildcard.all(item.id, enabledTools) === false) continue
const schema = ProviderTransform.schema(input.model, z.toJSONSchema(item.parameters))