Preserve native LLM tool context (#27116)

This commit is contained in:
Kit Langton
2026-05-12 16:16:58 -04:00
committed by GitHub
parent b9e7cbf13c
commit dd14413a64
18 changed files with 244 additions and 75 deletions

View File

@@ -78,7 +78,7 @@ const streamText = LLM.stream(request).pipe(
Stream.tap((event) =>
Effect.sync(() => {
if (event.type === "text-delta") process.stdout.write(`\ntext: ${event.text}`)
if (event.type === "request-finish") process.stdout.write(`\nfinish: ${event.reason}\n`)
if (event.type === "finish") process.stdout.write(`\nfinish: ${event.reason}\n`)
}),
),
Stream.runDrain,
@@ -185,7 +185,7 @@ const FakeProtocol = Protocol.make<FakeBody, string, string, void>({
event: Schema.String,
initial: () => undefined,
step: (_, frame) => Effect.succeed([undefined, [{ type: "text-delta", id: "text-0", text: frame }]] as const),
onHalt: () => [{ type: "request-finish", reason: "stop" }],
onHalt: () => [{ type: "finish", reason: "stop" }],
},
})

View File

@@ -17,6 +17,7 @@ export type {
ExecutableTools,
Tool as ToolShape,
ToolExecute,
ToolExecuteContext,
Tools,
ToolSchema,
} from "./tool"

View File

@@ -380,7 +380,7 @@ type StepResult = readonly [ParserState, ReadonlyArray<LLMEvent>]
const NO_EVENTS: StepResult["1"] = []
// `response.completed` / `response.incomplete` are clean finishes that emit a
// `request-finish` event; `response.failed` is a hard failure that emits a
// `finish` event; `response.failed` is a hard failure that emits a
// `provider-error`. All three end the stream — kept in one set so `step` and
// the protocol's `terminal` predicate stay in sync.
const TERMINAL_TYPES = new Set(["response.completed", "response.incomplete", "response.failed"])

View File

@@ -80,7 +80,7 @@ export const finish = (
usage: input.usage,
providerMetadata: input.providerMetadata,
}),
LLMEvent.requestFinish(input),
LLMEvent.finish(input),
)
return { ...stepped, stepStarted: false }
}

View File

@@ -1,5 +1,5 @@
import { Schema } from "effect"
import { ContentBlockID, FinishReason, ProtocolID, ProviderMetadata, ResponseID, RouteID, ToolCallID } from "./ids"
import { ContentBlockID, FinishReason, ProtocolID, ProviderMetadata, RouteID, ToolCallID } from "./ids"
import { ModelRef } from "./options"
import { ToolResultValue } from "./messages"
@@ -66,14 +66,13 @@ export class Usage extends Schema.Class<Usage>("LLM.Usage")({
get visibleOutputTokens() {
return Math.max(0, (this.outputTokens ?? 0) - (this.reasoningTokens ?? 0))
}
static from(input: UsageInput) {
return input instanceof Usage ? input : new Usage(input)
}
}
export const RequestStart = Schema.Struct({
type: Schema.tag("request-start"),
id: ResponseID,
model: ModelRef,
}).annotate({ identifier: "LLM.Event.RequestStart" })
export type RequestStart = Schema.Schema.Type<typeof RequestStart>
export type UsageInput = Usage | ConstructorParameters<typeof Usage>[0]
export const StepStart = Schema.Struct({
type: Schema.tag("step-start"),
@@ -185,13 +184,13 @@ export const StepFinish = Schema.Struct({
}).annotate({ identifier: "LLM.Event.StepFinish" })
export type StepFinish = Schema.Schema.Type<typeof StepFinish>
export const RequestFinish = Schema.Struct({
type: Schema.tag("request-finish"),
export const Finish = Schema.Struct({
type: Schema.tag("finish"),
reason: FinishReason,
usage: Schema.optional(Usage),
providerMetadata: Schema.optional(ProviderMetadata),
}).annotate({ identifier: "LLM.Event.RequestFinish" })
export type RequestFinish = Schema.Schema.Type<typeof RequestFinish>
}).annotate({ identifier: "LLM.Event.Finish" })
export type Finish = Schema.Schema.Type<typeof Finish>
export const ProviderErrorEvent = Schema.Struct({
type: Schema.tag("provider-error"),
@@ -202,7 +201,6 @@ export const ProviderErrorEvent = Schema.Struct({
export type ProviderErrorEvent = Schema.Schema.Type<typeof ProviderErrorEvent>
const llmEventTagged = Schema.Union([
RequestStart,
StepStart,
TextStart,
TextDelta,
@@ -217,13 +215,15 @@ const llmEventTagged = Schema.Union([
ToolResult,
ToolError,
StepFinish,
RequestFinish,
Finish,
ProviderErrorEvent,
]).pipe(Schema.toTaggedUnion("type"))
type WithID<Event extends { readonly id: unknown }, ID> = Omit<Event, "type" | "id"> & { readonly id: ID | string }
type WithUsage<Event extends { readonly usage?: Usage }> = Omit<Event, "type" | "usage"> & {
readonly usage?: UsageInput
}
const responseID = (value: ResponseID | string) => ResponseID.make(value)
const contentBlockID = (value: ContentBlockID | string) => ContentBlockID.make(value)
const toolCallID = (value: ToolCallID | string) => ToolCallID.make(value)
@@ -233,7 +233,6 @@ const toolCallID = (value: ToolCallID | string) => ToolCallID.make(value)
* `events.filter(LLMEvent.guards["tool-call"])`.
*/
export const LLMEvent = Object.assign(llmEventTagged, {
requestStart: (input: WithID<RequestStart, ResponseID>) => RequestStart.make({ ...input, id: responseID(input.id) }),
stepStart: StepStart.make,
textStart: (input: WithID<TextStart, ContentBlockID>) => TextStart.make({ ...input, id: contentBlockID(input.id) }),
textDelta: (input: WithID<TextDelta, ContentBlockID>) => TextDelta.make({ ...input, id: contentBlockID(input.id) }),
@@ -252,11 +251,18 @@ export const LLMEvent = Object.assign(llmEventTagged, {
toolCall: (input: WithID<ToolCall, ToolCallID>) => ToolCall.make({ ...input, id: toolCallID(input.id) }),
toolResult: (input: WithID<ToolResult, ToolCallID>) => ToolResult.make({ ...input, id: toolCallID(input.id) }),
toolError: (input: WithID<ToolError, ToolCallID>) => ToolError.make({ ...input, id: toolCallID(input.id) }),
stepFinish: StepFinish.make,
requestFinish: RequestFinish.make,
stepFinish: (input: WithUsage<StepFinish>) =>
StepFinish.make({
...input,
usage: input.usage === undefined ? undefined : Usage.from(input.usage),
}),
finish: (input: WithUsage<Finish>) =>
Finish.make({
...input,
usage: input.usage === undefined ? undefined : Usage.from(input.usage),
}),
providerError: ProviderErrorEvent.make,
is: {
requestStart: llmEventTagged.guards["request-start"],
stepStart: llmEventTagged.guards["step-start"],
textStart: llmEventTagged.guards["text-start"],
textDelta: llmEventTagged.guards["text-delta"],
@@ -271,7 +277,7 @@ export const LLMEvent = Object.assign(llmEventTagged, {
toolResult: llmEventTagged.guards["tool-result"],
toolError: llmEventTagged.guards["tool-error"],
stepFinish: llmEventTagged.guards["step-finish"],
requestFinish: llmEventTagged.guards["request-finish"],
finish: llmEventTagged.guards.finish,
providerError: llmEventTagged.guards["provider-error"],
},
})

View File

@@ -12,6 +12,7 @@ import {
ToolFailure,
ToolResultPart,
type ToolResultValue,
Usage,
} from "./schema"
import { type AnyTool, type ExecutableTools, type Tools, toDefinitions } from "./tool"
@@ -72,19 +73,42 @@ export const stream = <T extends Tools>(options: StreamOptions<T>): Stream.Strea
tools: [...options.request.tools.filter((tool) => !runtimeToolNames.has(tool.name)), ...runtimeTools],
})
const loop = (request: LLMRequest, step: number): Stream.Stream<LLMEvent, LLMError> =>
const loop = (
request: LLMRequest,
step: number,
usage: Usage | undefined,
providerMetadata: ProviderMetadata | undefined,
): Stream.Stream<LLMEvent, LLMError> =>
Stream.unwrap(
Effect.gen(function* () {
const state: StepState = { assistantContent: [], toolCalls: [], finishReason: undefined }
const state: StepState = {
assistantContent: [],
toolCalls: [],
finishReason: undefined,
usage: undefined,
providerMetadata: undefined,
}
const modelStream = options
.stream(request)
.pipe(Stream.map((event) => indexStep(event, step)))
.pipe(Stream.tap((event) => Effect.sync(() => accumulate(state, event))))
.pipe(Stream.filter((event) => event.type !== "finish"))
const continuation = Stream.unwrap(
Effect.gen(function* () {
if (state.finishReason !== "tool-calls" || state.toolCalls.length === 0) return Stream.empty
if (options.toolExecution === "none") return Stream.empty
const totalUsage = addUsage(usage, state.usage)
const totalProviderMetadata = mergeProviderMetadata(providerMetadata, state.providerMetadata)
const finishStream = Stream.fromIterable([
LLMEvent.finish({
reason: state.finishReason ?? "unknown",
usage: totalUsage,
providerMetadata: totalProviderMetadata,
}),
])
if (state.finishReason !== "tool-calls" || state.toolCalls.length === 0) return finishStream
if (options.toolExecution === "none") return finishStream
const dispatched = yield* Effect.forEach(
state.toolCalls,
@@ -93,10 +117,14 @@ export const stream = <T extends Tools>(options: StreamOptions<T>): Stream.Strea
)
const resultStream = Stream.fromIterable(dispatched.flatMap(([call, result]) => emitEvents(call, result)))
if (!options.stopWhen) return resultStream
if (options.stopWhen({ step, request })) return resultStream
if (!options.stopWhen) return resultStream.pipe(Stream.concat(finishStream))
if (options.stopWhen({ step, request })) return resultStream.pipe(Stream.concat(finishStream))
return resultStream.pipe(Stream.concat(loop(followUpRequest(request, state, dispatched), step + 1)))
return resultStream.pipe(
Stream.concat(
loop(followUpRequest(request, state, dispatched), step + 1, totalUsage, totalProviderMetadata),
),
)
}),
)
@@ -104,13 +132,21 @@ export const stream = <T extends Tools>(options: StreamOptions<T>): Stream.Strea
}),
)
return loop(initialRequest, 0)
return loop(initialRequest, 0, undefined, undefined)
}
// Stamps the current loop step index onto the step-boundary events so
// consumers can tell which iteration of the tool loop an event belongs to.
// All other event types pass through untouched.
const indexStep = (event: LLMEvent, index: number): LLMEvent => {
  switch (event.type) {
    case "step-start":
      return LLMEvent.stepStart({ index })
    case "step-finish":
      return LLMEvent.stepFinish({ ...event, index })
    default:
      return event
  }
}
// Mutable accumulator for a single model step. A fresh instance is created at
// the top of each `loop` iteration and filled in by `accumulate` as provider
// events stream through.
interface StepState {
  // Assistant-authored content parts gathered so far this step.
  assistantContent: ContentPart[]
  // Tool calls requested this step; drives the continuation/follow-up logic.
  toolCalls: ToolCallPart[]
  // Why the step ended. `accumulate` rewrites a "stop" to "tool-calls" when
  // tool calls were observed during the step.
  finishReason: FinishReason | undefined
  // Per-step token usage, summed across step-finish events via `addUsage`.
  usage: Usage | undefined
  // Provider metadata merged across events for this step.
  providerMetadata: ProviderMetadata | undefined
}
const accumulate = (state: StepState, event: LLMEvent) => {
@@ -154,9 +190,43 @@ const accumulate = (state: StepState, event: LLMEvent) => {
)
return
}
if (event.type === "step-finish" || event.type === "request-finish") {
if (event.type === "step-finish") {
state.finishReason = event.reason === "stop" && state.toolCalls.length > 0 ? "tool-calls" : event.reason
state.usage = addUsage(state.usage, event.usage)
state.providerMetadata = mergeProviderMetadata(state.providerMetadata, event.providerMetadata)
return
}
if (event.type === "finish") {
state.finishReason ??= event.reason
state.usage ??= event.usage
state.providerMetadata = mergeProviderMetadata(state.providerMetadata, event.providerMetadata)
}
}
// Sums two usage snapshots field-by-field so multi-step tool loops can report
// aggregate token counts on the single terminal `finish` event. Either side
// may be undefined (e.g. a provider that never reported usage); in that case
// the other side is returned unchanged rather than cloned.
const addUsage = (left: Usage | undefined, right: Usage | undefined) => {
  if (!left) return right
  if (!right) return left
  type UsageKey =
    | "inputTokens"
    | "outputTokens"
    | "nonCachedInputTokens"
    | "cacheReadInputTokens"
    | "cacheWriteInputTokens"
    | "reasoningTokens"
    | "totalTokens"
  // A field stays undefined only when BOTH sides lack it — "never reported"
  // must stay distinguishable from 0. Otherwise a missing side counts as 0
  // so partial reporting still sums correctly.
  const sum = (key: UsageKey) =>
    left[key] === undefined && right[key] === undefined ? undefined : Number(left[key] ?? 0) + Number(right[key] ?? 0)
  return new Usage({
    inputTokens: sum("inputTokens"),
    outputTokens: sum("outputTokens"),
    nonCachedInputTokens: sum("nonCachedInputTokens"),
    cacheReadInputTokens: sum("cacheReadInputTokens"),
    cacheWriteInputTokens: sum("cacheWriteInputTokens"),
    reasoningTokens: sum("reasoningTokens"),
    totalTokens: sum("totalTokens"),
    // Metadata does not add numerically; merge it instead.
    providerMetadata: mergeProviderMetadata(left.providerMetadata, right.providerMetadata),
  })
}
const sameProviderMetadata = (left: ProviderMetadata | undefined, right: ProviderMetadata | undefined) =>
@@ -200,17 +270,17 @@ const dispatch = (tools: Tools, call: ToolCallPart): Effect.Effect<ToolResultVal
if (!tool.execute)
return Effect.succeed({ type: "error" as const, value: `Tool has no execute handler: ${call.name}` })
return decodeAndExecute(tool, call.input).pipe(
return decodeAndExecute(tool, call).pipe(
Effect.catchTag("LLM.ToolFailure", (failure) =>
Effect.succeed({ type: "error" as const, value: failure.message } satisfies ToolResultValue),
),
)
}
const decodeAndExecute = (tool: AnyTool, input: unknown): Effect.Effect<ToolResultValue, ToolFailure> =>
tool._decode(input).pipe(
const decodeAndExecute = (tool: AnyTool, call: ToolCallPart): Effect.Effect<ToolResultValue, ToolFailure> =>
tool._decode(call.input).pipe(
Effect.mapError((error) => new ToolFailure({ message: `Invalid tool input: ${error.message}` })),
Effect.flatMap((decoded) => tool.execute!(decoded)),
Effect.flatMap((decoded) => tool.execute!(decoded, { id: call.id, name: call.name })),
Effect.flatMap((value) =>
tool._encode(value).pipe(
Effect.mapError(

View File

@@ -1,5 +1,5 @@
import { Effect, JsonSchema, Schema } from "effect"
import type { ToolDefinition as ToolDefinitionClass } from "./schema"
import type { ToolCallPart, ToolDefinition as ToolDefinitionClass } from "./schema"
import { ToolDefinition, ToolFailure } from "./schema"
/**
@@ -8,9 +8,14 @@ import { ToolDefinition, ToolFailure } from "./schema"
* beyond pure data conversion belongs in the handler closure.
*/
export type ToolSchema<T> = Schema.Codec<T, any, never, never>
/**
 * Identity of the tool call currently being executed, passed as the second
 * argument to `execute` handlers so they can correlate work (logging, result
 * routing) with the originating call.
 */
export interface ToolExecuteContext {
  // The provider-assigned tool call id (same id used on the tool-result event).
  readonly id: ToolCallPart["id"]
  // The tool's registered name as it appeared in the tool call.
  readonly name: ToolCallPart["name"]
}
/**
 * Handler signature for a typed tool. `context` is declared optional so that
 * pre-existing one-argument handlers remain assignable to this type, while the
 * runtime can supply the originating call's id/name as a second argument.
 */
export type ToolExecute<Parameters extends ToolSchema<any>, Success extends ToolSchema<any>> = (
  params: Schema.Schema.Type<Parameters>,
  context?: ToolExecuteContext,
) => Effect.Effect<Schema.Schema.Type<Success>, ToolFailure>
/**
@@ -61,7 +66,7 @@ type TypedToolConfig = {
type DynamicToolConfig = {
readonly description: string
readonly jsonSchema: JsonSchema.JsonSchema
readonly execute?: (params: unknown) => Effect.Effect<unknown, ToolFailure>
readonly execute?: (params: unknown, context?: ToolExecuteContext) => Effect.Effect<unknown, ToolFailure>
}
/**
@@ -110,7 +115,7 @@ export function make<Parameters extends ToolSchema<any>, Success extends ToolSch
export function make(config: {
readonly description: string
readonly jsonSchema: JsonSchema.JsonSchema
readonly execute: (params: unknown) => Effect.Effect<unknown, ToolFailure>
readonly execute: (params: unknown, context?: ToolExecuteContext) => Effect.Effect<unknown, ToolFailure>
}): AnyExecutableTool
export function make(config: {
readonly description: string

View File

@@ -51,7 +51,7 @@ const request = LLM.request({
const raiseEvent = (event: FakeEvent): import("../src/schema").LLMEvent =>
event.type === "finish"
? { type: "request-finish", reason: event.reason }
? { type: "finish", reason: event.reason }
: { type: "text-delta", id: "text-0", text: event.text }
const fakeProtocol = Protocol.make<FakeBody, FakeEvent, FakeEvent, void>({
@@ -112,8 +112,8 @@ describe("llm route", () => {
const events = Array.from(yield* llm.stream(request).pipe(Stream.runCollect))
const response = yield* llm.generate(request)
expect(events.map((event) => event.type)).toEqual(["text-delta", "request-finish"])
expect(response.events.map((event) => event.type)).toEqual(["text-delta", "request-finish"])
expect(events.map((event) => event.type)).toEqual(["text-delta", "finish"])
expect(response.events.map((event) => event.type)).toEqual(["text-delta", "finish"])
}),
)

View File

@@ -127,7 +127,7 @@ describe("llm constructors", () => {
LLMResponse.text({
events: [
{ type: "text-delta", id: "text-0", text: "hi" },
{ type: "request-finish", reason: "stop" },
{ type: "finish", reason: "stop" },
],
}),
).toBe("hi")

View File

@@ -124,7 +124,7 @@ describe("Anthropic Messages route", () => {
providerMetadata: { anthropic: { signature: "sig_1" } },
})
expect(response.events.at(-1)).toMatchObject({
type: "request-finish",
type: "finish",
reason: "stop",
providerMetadata: { anthropic: { stopSequence: "\n\nHuman:" } },
})
@@ -182,7 +182,7 @@ describe("Anthropic Messages route", () => {
},
{ type: "step-finish", index: 0, reason: "tool-calls", usage, providerMetadata: undefined },
{
type: "request-finish",
type: "finish",
reason: "tool-calls",
providerMetadata: undefined,
usage,
@@ -275,7 +275,7 @@ describe("Anthropic Messages route", () => {
providerMetadata: { anthropic: { blockType: "web_search_tool_result" } },
})
expect(response.text).toBe("Found it.")
expect(response.events.at(-1)).toMatchObject({ type: "request-finish", reason: "stop" })
expect(response.events.at(-1)).toMatchObject({ type: "finish", reason: "stop" })
}),
)

View File

@@ -169,12 +169,12 @@ describe("Bedrock Converse route", () => {
const response = yield* LLMClient.generate(baseRequest).pipe(Effect.provide(fixedBytes(body)))
expect(response.text).toBe("Hello!")
const finishes = response.events.filter((event) => event.type === "request-finish")
const finishes = response.events.filter((event) => event.type === "finish")
// Bedrock splits the finish across `messageStop` (carries reason) and
// `metadata` (carries usage). We consolidate them into a single
// terminal `request-finish` event with both.
// terminal `finish` event with both.
expect(finishes).toHaveLength(1)
expect(finishes[0]).toMatchObject({ type: "request-finish", reason: "stop" })
expect(finishes[0]).toMatchObject({ type: "finish", reason: "stop" })
expect(response.usage).toMatchObject({
inputTokens: 5,
outputTokens: 2,
@@ -213,7 +213,7 @@ describe("Bedrock Converse route", () => {
{ type: "tool-input-delta", id: "tool_1", name: "lookup", text: '{"query"' },
{ type: "tool-input-delta", id: "tool_1", name: "lookup", text: ':"weather"}' },
])
expect(response.events.at(-1)).toMatchObject({ type: "request-finish", reason: "tool-calls" })
expect(response.events.at(-1)).toMatchObject({ type: "finish", reason: "tool-calls" })
}),
)

View File

@@ -232,7 +232,7 @@ describe("Gemini route", () => {
{ type: "text-end", id: "text-0" },
{ type: "step-finish", index: 0, reason: "stop", usage, providerMetadata: undefined },
{
type: "request-finish",
type: "finish",
reason: "stop",
usage,
},
@@ -291,7 +291,7 @@ describe("Gemini route", () => {
},
{ type: "step-finish", index: 0, reason: "tool-calls", usage, providerMetadata: undefined },
{
type: "request-finish",
type: "finish",
reason: "tool-calls",
usage,
},
@@ -325,7 +325,7 @@ describe("Gemini route", () => {
{ type: "tool-call", id: "tool_0", name: "lookup", input: { query: "weather" } },
{ type: "tool-call", id: "tool_1", name: "lookup", input: { query: "news" } },
])
expect(response.events.at(-1)).toMatchObject({ type: "request-finish", reason: "tool-calls" })
expect(response.events.at(-1)).toMatchObject({ type: "finish", reason: "tool-calls" })
}),
)
@@ -344,10 +344,10 @@ describe("Gemini route", () => {
),
)
expect(length.events.map((event) => event.type)).toEqual(["step-start", "step-finish", "request-finish"])
expect(length.events.at(-1)).toMatchObject({ type: "request-finish", reason: "length" })
expect(filtered.events.map((event) => event.type)).toEqual(["step-start", "step-finish", "request-finish"])
expect(filtered.events.at(-1)).toMatchObject({ type: "request-finish", reason: "content-filter" })
expect(length.events.map((event) => event.type)).toEqual(["step-start", "step-finish", "finish"])
expect(length.events.at(-1)).toMatchObject({ type: "finish", reason: "length" })
expect(filtered.events.map((event) => event.type)).toEqual(["step-start", "step-finish", "finish"])
expect(filtered.events.at(-1)).toMatchObject({ type: "finish", reason: "content-filter" })
}),
)

View File

@@ -249,7 +249,7 @@ describe("OpenAI Chat route", () => {
{ type: "text-end", id: "text-0" },
{ type: "step-finish", index: 0, reason: "stop", usage, providerMetadata: undefined },
{
type: "request-finish",
type: "finish",
reason: "stop",
usage,
},
@@ -288,7 +288,7 @@ describe("OpenAI Chat route", () => {
providerMetadata: undefined,
},
{ type: "step-finish", index: 0, reason: "tool-calls", usage: undefined, providerMetadata: undefined },
{ type: "request-finish", reason: "tool-calls", usage: undefined },
{ type: "finish", reason: "tool-calls", usage: undefined },
])
}),
)

View File

@@ -231,7 +231,7 @@ describe("OpenAI-compatible Chat route", () => {
expect(response.text).toBe("Hello!")
expect(response.usage).toMatchObject({ inputTokens: 5, outputTokens: 2, totalTokens: 7 })
expect(response.events.at(-1)).toMatchObject({ type: "request-finish", reason: "stop" })
expect(response.events.at(-1)).toMatchObject({ type: "finish", reason: "stop" })
}),
)
})

View File

@@ -366,7 +366,7 @@ describe("OpenAI Responses route", () => {
usage,
},
{
type: "request-finish",
type: "finish",
reason: "stop",
providerMetadata: { openai: { responseId: "resp_1", serviceTier: "default" } },
usage,
@@ -447,7 +447,7 @@ describe("OpenAI Responses route", () => {
},
{ type: "step-finish", index: 0, reason: "tool-calls", usage, providerMetadata: undefined },
{
type: "request-finish",
type: "finish",
reason: "tool-calls",
providerMetadata: undefined,
usage,

View File

@@ -120,8 +120,8 @@ export const runWeatherToolLoop = (request: LLMRequest) =>
export const expectFinish = (
events: ReadonlyArray<LLMEvent>,
reason: Extract<LLMEvent, { readonly type: "request-finish" }>["reason"],
) => expect(events.at(-1)).toMatchObject({ type: "request-finish", reason })
reason: Extract<LLMEvent, { readonly type: "finish" }>["reason"],
) => expect(events.at(-1)).toMatchObject({ type: "finish", reason })
export const expectWeatherToolCall = (response: LLMResponse) =>
expect(response.toolCalls).toMatchObject([
@@ -129,10 +129,12 @@ export const expectWeatherToolCall = (response: LLMResponse) =>
])
export const expectWeatherToolLoop = (events: ReadonlyArray<LLMEvent>) => {
const finishes = events.filter(LLMEvent.is.requestFinish)
expect(finishes).toHaveLength(2)
expect(finishes[0]?.reason).toBe("tool-calls")
expect(finishes.at(-1)?.reason).toBe("stop")
const finishes = events.filter(LLMEvent.is.finish)
expect(finishes).toHaveLength(1)
expect(finishes[0]?.reason).toBe("stop")
const stepFinishes = events.filter(LLMEvent.is.stepFinish)
expect(stepFinishes.map((event) => event.reason)).toEqual(["tool-calls", "stop"])
const toolCalls = events.filter(LLMEvent.is.toolCall)
expect(toolCalls).toHaveLength(1)
@@ -272,7 +274,7 @@ export const eventSummary = (events: ReadonlyArray<LLMEvent>) => {
summary.push({ type: "tool-error", name: event.name, message: event.message })
continue
}
if (event.type === "request-finish") {
if (event.type === "finish") {
summary.push({ type: "finish", reason: event.reason, usage: usageSummary(event.usage) })
}
}

View File

@@ -44,6 +44,11 @@ describe("llm schema", () => {
expect(() => Schema.decodeUnknownSync(LLMEvent)({ type: "bogus" })).toThrow()
})
test("finish constructors accept usage input", () => {
expect(LLMEvent.stepFinish({ index: 0, reason: "stop", usage: { inputTokens: 1 } }).usage).toBeInstanceOf(Usage)
expect(LLMEvent.finish({ reason: "stop", usage: { outputTokens: 2 } }).usage).toBeInstanceOf(Usage)
})
test("content part tagged union exposes guards", () => {
expect(ContentPart.guards.text({ type: "text", text: "hi" })).toBe(true)
expect(ContentPart.guards.media({ type: "text", text: "hi" })).toBe(false)

View File

@@ -4,7 +4,8 @@ import { GenerationOptions, LLM, LLMEvent, LLMRequest, LLMResponse, ToolChoice }
import { LLMClient } from "../src/route"
import * as AnthropicMessages from "../src/protocols/anthropic-messages"
import * as OpenAIChat from "../src/protocols/openai-chat"
import { tool, ToolFailure } from "../src/tool"
import { tool, ToolFailure, type ToolExecuteContext } from "../src/tool"
import { ToolRuntime } from "../src/tool-runtime"
import { it } from "./lib/effect"
import * as TestToolRuntime from "./lib/tool-runtime"
import { dynamicResponse, scriptedResponses } from "./lib/http"
@@ -129,7 +130,7 @@ describe("LLMClient tools", () => {
name: "get_weather",
result: { type: "json", value: { temperature: 22, condition: "sunny" } },
})
expect(events.at(-1)?.type).toBe("request-finish")
expect(events.at(-1)?.type).toBe("finish")
expect(LLMResponse.text({ events })).toBe("It's sunny in Paris.")
}),
)
@@ -148,11 +149,40 @@ describe("LLMClient tools", () => {
),
)
expect(events.filter(LLMEvent.is.requestFinish)).toHaveLength(1)
expect(events.filter(LLMEvent.is.finish)).toHaveLength(1)
expect(events.find(LLMEvent.is.toolResult)).toMatchObject({ type: "tool-result", id: "call_1" })
}),
)
it.effect("passes tool call context to execute", () =>
Effect.gen(function* () {
let context: ToolExecuteContext | undefined
const contextual = tool({
description: "Capture tool context.",
parameters: Schema.Struct({ value: Schema.String }),
success: Schema.Struct({ ok: Schema.Boolean }),
execute: (_params, ctx) =>
Effect.sync(() => {
context = ctx
return { ok: true }
}),
})
const events = Array.from(
yield* TestToolRuntime.runTools({ request: baseRequest, tools: { contextual } }).pipe(
Stream.runCollect,
Effect.provide(
scriptedResponses([
sseEvents(toolCallChunk("call_ctx", "contextual", '{"value":"x"}'), finishChunk("tool_calls")),
]),
),
),
)
expect(events.some(LLMEvent.is.toolResult)).toBe(true)
expect(context).toEqual({ id: "call_ctx", name: "contextual" })
}),
)
it.effect("can expose tool schemas without executing tool calls", () =>
Effect.gen(function* () {
const layer = scriptedResponses([
@@ -319,7 +349,7 @@ describe("LLMClient tools", () => {
"text-delta",
"text-end",
"step-finish",
"request-finish",
"finish",
])
expect(LLMResponse.text({ events })).toBe("Done.")
}),
@@ -343,7 +373,57 @@ describe("LLMClient tools", () => {
),
)
expect(events.filter(LLMEvent.is.requestFinish)).toHaveLength(2)
expect(events.filter(LLMEvent.is.finish)).toHaveLength(1)
expect(events.filter(LLMEvent.is.stepStart).map((event) => event.index)).toEqual([0, 1])
expect(events.filter(LLMEvent.is.stepFinish).map((event) => event.index)).toEqual([0, 1])
}),
)
it.effect("emits one final finish with aggregate usage", () =>
Effect.gen(function* () {
let calls = 0
const events = Array.from(
yield* ToolRuntime.stream({
request: baseRequest,
tools: { get_weather },
stopWhen: ToolRuntime.stepCountIs(2),
stream: () =>
Stream.fromIterable<LLMEvent>(
calls++ === 0
? [
LLMEvent.stepStart({ index: 0 }),
LLMEvent.toolCall({ id: "call_1", name: "get_weather", input: { city: "Paris" } }),
LLMEvent.stepFinish({
index: 0,
reason: "tool-calls",
usage: { inputTokens: 1, outputTokens: 2, totalTokens: 3 },
}),
LLMEvent.finish({
reason: "tool-calls",
usage: { inputTokens: 1, outputTokens: 2, totalTokens: 3 },
}),
]
: [
LLMEvent.stepStart({ index: 0 }),
LLMEvent.textDelta({ id: "text_1", text: "Done." }),
LLMEvent.stepFinish({
index: 0,
reason: "stop",
usage: { inputTokens: 4, outputTokens: 5, totalTokens: 9 },
}),
LLMEvent.finish({ reason: "stop", usage: { inputTokens: 4, outputTokens: 5, totalTokens: 9 } }),
],
),
}).pipe(Stream.runCollect),
)
expect(events.filter(LLMEvent.is.stepFinish).map((event) => event.index)).toEqual([0, 1])
expect(events.filter(LLMEvent.is.finish)).toHaveLength(1)
expect(events.find(LLMEvent.is.finish)?.usage).toMatchObject({
inputTokens: 5,
outputTokens: 7,
totalTokens: 12,
})
}),
)
@@ -362,7 +442,7 @@ describe("LLMClient tools", () => {
}).pipe(Stream.runCollect, Effect.provide(layer)),
)
expect(events.filter(LLMEvent.is.requestFinish)).toHaveLength(1)
expect(events.filter(LLMEvent.is.finish)).toHaveLength(1)
expect(events.find(LLMEvent.is.toolResult)).toMatchObject({ type: "tool-result", id: "call_1" })
}),
)