mirror of
https://github.com/browseros-ai/BrowserOS.git
synced 2026-05-18 11:06:19 +00:00
fix: robust compaction with Pi-style token counting + overflow middle… (#444)
* fix: robust compaction with Pi-style token counting + overflow middleware Root cause: getCurrentTokenCount() returned stale inputTokens from the previous step, ignoring new tool results added to messages since that step. A large tool output (DOM snapshot, page content) caused a token jump that bypassed the compaction threshold check, leading to context_length_exceeded errors (322K tokens sent, model max 262K). Layer 1 — Accurate token counting (proactive): - Adopt Pi coding agent's additive approach: base(inputTokens) + outputTokens + estimate(trailing tool results) - Trailing tool results are estimated by walking backwards from end of messages array until a non-tool message is found - Falls back to full estimation with safety multiplier when no real usage data is available (first step of a turn) Layer 2 — Context overflow middleware (reactive): - LanguageModelV3Middleware that wraps doGenerate/doStream - Catches context_length_exceeded errors at the model call level - Truncates prompt (keeps system messages + most recent non-system messages targeting 60% of context window) - Retries the model call once Verified end-to-end with real model (Gemini Flash Lite via OpenRouter) on 16K context window: 4 compactions triggered correctly across 8 steps, no context_length_exceeded errors. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: adopt Pi-style overflow detection patterns + fix truncation edge case - Replace 6 generic substring matches with 17 provider-specific regex patterns from Pi coding agent (Anthropic, OpenAI, Google, xAI, Groq, OpenRouter, Bedrock, Copilot, llama.cpp, LM Studio, MiniMax, Kimi, Mistral, z.ai) - Fix truncatePrompt edge case: when the last message alone exceeds the target, keepFrom was never updated → empty non-system messages. Now always keeps at least the most recent non-system message. - Add runtime guard for LanguageModelV3 cast in ai-sdk-agent.ts - Add tests for false-positive rejection and truncation edge case Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,6 +1,12 @@
|
||||
import type { LanguageModelV3 } from '@ai-sdk/provider'
|
||||
import { AGENT_LIMITS } from '@browseros/shared/constants/limits'
|
||||
import type { BrowserContext } from '@browseros/shared/schemas/browser-context'
|
||||
import { stepCountIs, ToolLoopAgent, type UIMessage } from 'ai'
|
||||
import {
|
||||
stepCountIs,
|
||||
ToolLoopAgent,
|
||||
type UIMessage,
|
||||
wrapLanguageModel,
|
||||
} from 'ai'
|
||||
import type { Browser } from '../browser/browser'
|
||||
import type { KlavisClient } from '../lib/clients/klavis/klavis-client'
|
||||
import { logger } from '../lib/logger'
|
||||
@@ -10,6 +16,7 @@ import { buildMemoryToolSet } from '../tools/memory/build-toolset'
|
||||
import type { ToolRegistry } from '../tools/tool-registry'
|
||||
import { CHAT_MODE_ALLOWED_TOOLS } from './chat-mode'
|
||||
import { createCompactionPrepareStep } from './compaction'
|
||||
import { createContextOverflowMiddleware } from './context-overflow-middleware'
|
||||
import { buildMcpServerSpecs, createMcpClients } from './mcp-builder'
|
||||
import { buildSystemPrompt } from './prompt'
|
||||
import { createLanguageModel } from './provider-factory'
|
||||
@@ -34,8 +41,19 @@ export class AiSdkAgent {
|
||||
) {}
|
||||
|
||||
static async create(config: AiSdkAgentConfig): Promise<AiSdkAgent> {
|
||||
// Build language model from provider config
|
||||
const model = createLanguageModel(config.resolvedConfig)
|
||||
const contextWindow =
|
||||
config.resolvedConfig.contextWindowSize ??
|
||||
AGENT_LIMITS.DEFAULT_CONTEXT_WINDOW
|
||||
|
||||
// Build language model with overflow protection middleware
|
||||
const rawModel = createLanguageModel(config.resolvedConfig)
|
||||
const model =
|
||||
(rawModel as any).specificationVersion === 'v3'
|
||||
? wrapLanguageModel({
|
||||
model: rawModel as LanguageModelV3,
|
||||
middleware: createContextOverflowMiddleware(contextWindow),
|
||||
})
|
||||
: rawModel
|
||||
|
||||
// Build browser tools from the unified tool registry
|
||||
const allBrowserTools = buildBrowserToolSet(config.registry, config.browser)
|
||||
@@ -95,9 +113,6 @@ export class AiSdkAgent {
|
||||
})
|
||||
|
||||
// Configure compaction for context window management
|
||||
const contextWindow =
|
||||
config.resolvedConfig.contextWindowSize ??
|
||||
AGENT_LIMITS.DEFAULT_CONTEXT_WINDOW
|
||||
const prepareStep = createCompactionPrepareStep({
|
||||
contextWindow,
|
||||
})
|
||||
|
||||
@@ -157,8 +157,11 @@ export function estimateTokens(
|
||||
return Math.ceil(chars / 4) + imageCount * imageTokenEstimate
|
||||
}
|
||||
|
||||
interface StepWithUsage {
|
||||
usage?: { inputTokens?: number | undefined }
|
||||
export interface StepWithUsage {
|
||||
usage?: {
|
||||
inputTokens?: number | undefined
|
||||
outputTokens?: number | undefined
|
||||
}
|
||||
}
|
||||
|
||||
export function getCurrentTokenCount(
|
||||
@@ -166,15 +169,31 @@ export function getCurrentTokenCount(
|
||||
messages: ModelMessage[],
|
||||
config: ComputedConfig,
|
||||
): number {
|
||||
// Use real API usage from the last step when available
|
||||
if (steps.length > 0) {
|
||||
const lastStep = steps[steps.length - 1]
|
||||
if (lastStep.usage?.inputTokens != null && lastStep.usage.inputTokens > 0) {
|
||||
return lastStep.usage.inputTokens
|
||||
// Pi-style additive: real usage as base + trailing content added since
|
||||
const base = lastStep.usage.inputTokens
|
||||
const outputTokens = lastStep.usage.outputTokens ?? 0
|
||||
|
||||
// Estimate trailing tool result messages added after the last model call
|
||||
let trailingTokens = 0
|
||||
for (let i = messages.length - 1; i >= 0; i--) {
|
||||
if (messages[i].role === 'tool') {
|
||||
trailingTokens += estimateTokens(
|
||||
[messages[i]],
|
||||
config.imageTokenEstimate,
|
||||
)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return base + outputTokens + trailingTokens
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: estimation with safety multiplier + overhead
|
||||
// No real usage → full estimation with safety margin
|
||||
const estimated = estimateTokens(messages, config.imageTokenEstimate)
|
||||
return Math.ceil(estimated * config.safetyMultiplier) + config.fixedOverhead
|
||||
}
|
||||
|
||||
116
apps/server/src/agent/context-overflow-middleware.ts
Normal file
116
apps/server/src/agent/context-overflow-middleware.ts
Normal file
@@ -0,0 +1,116 @@
|
||||
import type {
|
||||
LanguageModelV3CallOptions,
|
||||
LanguageModelV3Message,
|
||||
LanguageModelV3Middleware,
|
||||
LanguageModelV3Prompt,
|
||||
} from '@ai-sdk/provider'
|
||||
import { logger } from '../lib/logger'
|
||||
|
||||
/**
|
||||
* Provider-specific regex patterns for context overflow errors.
|
||||
* Adapted from Pi coding agent's overflow detection.
|
||||
*
|
||||
* @see https://github.com/badlogic/pi-mono/blob/main/packages/ai/src/utils/overflow.ts
|
||||
*/
|
||||
const OVERFLOW_PATTERNS: RegExp[] = [
|
||||
/prompt is too long/i, // Anthropic
|
||||
/input is too long for requested model/i, // Amazon Bedrock
|
||||
/exceeds the context window/i, // OpenAI (Completions & Responses API)
|
||||
/input token count.*exceeds the maximum/i, // Google (Gemini)
|
||||
/maximum prompt length is \d+/i, // xAI (Grok)
|
||||
/reduce the length of the messages/i, // Groq
|
||||
/maximum context length is \d+ tokens/i, // OpenRouter (all backends)
|
||||
/exceeds the limit of \d+/i, // GitHub Copilot
|
||||
/exceeds the available context size/i, // llama.cpp server
|
||||
/greater than the context length/i, // LM Studio
|
||||
/context window exceeds limit/i, // MiniMax
|
||||
/exceeded model token limit/i, // Kimi For Coding
|
||||
/too large for model with \d+ maximum context length/i, // Mistral
|
||||
/model_context_window_exceeded/i, // z.ai non-standard finish_reason
|
||||
/context[_ ]length[_ ]exceeded/i, // Generic fallback
|
||||
/too many tokens/i, // Generic fallback
|
||||
/token limit exceeded/i, // Generic fallback
|
||||
]
|
||||
|
||||
export function isContextOverflowError(error: unknown): boolean {
|
||||
if (!(error instanceof Error)) return false
|
||||
const msg = error.message
|
||||
return OVERFLOW_PATTERNS.some((p) => p.test(msg))
|
||||
}
|
||||
|
||||
function truncatePrompt(
|
||||
prompt: LanguageModelV3Prompt,
|
||||
contextWindow: number,
|
||||
): LanguageModelV3Prompt {
|
||||
const systemMessages: LanguageModelV3Message[] = []
|
||||
const nonSystem: LanguageModelV3Message[] = []
|
||||
for (const m of prompt) {
|
||||
if (m.role === 'system') systemMessages.push(m)
|
||||
else nonSystem.push(m)
|
||||
}
|
||||
|
||||
// Target 60% of context window to leave headroom
|
||||
const targetChars = contextWindow * 4 * 0.6
|
||||
let totalChars = 0
|
||||
let keepFrom = nonSystem.length
|
||||
|
||||
for (let i = nonSystem.length - 1; i >= 0; i--) {
|
||||
totalChars += JSON.stringify(nonSystem[i].content).length
|
||||
if (totalChars > targetChars) break
|
||||
keepFrom = i
|
||||
}
|
||||
|
||||
// Always keep at least the most recent non-system message
|
||||
if (keepFrom >= nonSystem.length && nonSystem.length > 0) {
|
||||
keepFrom = nonSystem.length - 1
|
||||
}
|
||||
|
||||
const kept: LanguageModelV3Prompt = [
|
||||
...systemMessages,
|
||||
...nonSystem.slice(keepFrom),
|
||||
]
|
||||
logger.warn('Emergency prompt truncation', {
|
||||
original: prompt.length,
|
||||
kept: kept.length,
|
||||
dropped: prompt.length - kept.length,
|
||||
})
|
||||
return kept
|
||||
}
|
||||
|
||||
export function createContextOverflowMiddleware(
|
||||
contextWindow: number,
|
||||
): LanguageModelV3Middleware {
|
||||
return {
|
||||
specificationVersion: 'v3',
|
||||
wrapGenerate: async ({ doGenerate, params }) => {
|
||||
try {
|
||||
return await doGenerate()
|
||||
} catch (error) {
|
||||
if (!isContextOverflowError(error)) throw error
|
||||
logger.warn(
|
||||
'Context overflow detected in doGenerate, truncating and retrying',
|
||||
)
|
||||
;(params as LanguageModelV3CallOptions).prompt = truncatePrompt(
|
||||
params.prompt,
|
||||
contextWindow,
|
||||
)
|
||||
return await doGenerate()
|
||||
}
|
||||
},
|
||||
wrapStream: async ({ doStream, params }) => {
|
||||
try {
|
||||
return await doStream()
|
||||
} catch (error) {
|
||||
if (!isContextOverflowError(error)) throw error
|
||||
logger.warn(
|
||||
'Context overflow detected in doStream, truncating and retrying',
|
||||
)
|
||||
;(params as LanguageModelV3CallOptions).prompt = truncatePrompt(
|
||||
params.prompt,
|
||||
contextWindow,
|
||||
)
|
||||
return await doStream()
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,8 @@ import {
|
||||
computeConfig,
|
||||
estimateTokens,
|
||||
findSafeSplitPoint,
|
||||
getCurrentTokenCount,
|
||||
type StepWithUsage,
|
||||
slidingWindow,
|
||||
truncateToolOutputs,
|
||||
} from '../../src/agent/compaction'
|
||||
@@ -12,6 +14,10 @@ import {
|
||||
buildTurnPrefixPrompt,
|
||||
messagesToTranscript,
|
||||
} from '../../src/agent/compaction-prompt'
|
||||
import {
|
||||
createContextOverflowMiddleware,
|
||||
isContextOverflowError,
|
||||
} from '../../src/agent/context-overflow-middleware'
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
@@ -732,3 +738,328 @@ describe('end-to-end config coherence', () => {
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// getCurrentTokenCount — Pi-style additive counting
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
describe('getCurrentTokenCount — Pi-style additive', () => {
|
||||
const config = computeConfig(200_000)
|
||||
|
||||
it('returns estimated with safety margin when no steps exist', () => {
|
||||
const msgs = [userMsg('a'.repeat(400))]
|
||||
const result = getCurrentTokenCount([], msgs, config)
|
||||
const rawEstimate = estimateTokens(msgs, config.imageTokenEstimate)
|
||||
const expected =
|
||||
Math.ceil(rawEstimate * config.safetyMultiplier) + config.fixedOverhead
|
||||
expect(result).toBe(expected)
|
||||
})
|
||||
|
||||
it('returns estimated when last step has no usage', () => {
|
||||
const steps: StepWithUsage[] = [{ usage: undefined }]
|
||||
const msgs = [userMsg('hello')]
|
||||
const result = getCurrentTokenCount(steps, msgs, config)
|
||||
const rawEstimate = estimateTokens(msgs, config.imageTokenEstimate)
|
||||
const expected =
|
||||
Math.ceil(rawEstimate * config.safetyMultiplier) + config.fixedOverhead
|
||||
expect(result).toBe(expected)
|
||||
})
|
||||
|
||||
it('adds outputTokens to base when no trailing tool results', () => {
|
||||
const steps: StepWithUsage[] = [
|
||||
{ usage: { inputTokens: 50_000, outputTokens: 2_000 } },
|
||||
]
|
||||
const msgs = [userMsg('hello'), assistantMsg('response')]
|
||||
const result = getCurrentTokenCount(steps, msgs, config)
|
||||
expect(result).toBe(52_000)
|
||||
})
|
||||
|
||||
it('adds trailing tool result tokens to base + output', () => {
|
||||
const toolOutput = 'x'.repeat(40_000) // ~10K tokens
|
||||
const steps: StepWithUsage[] = [
|
||||
{ usage: { inputTokens: 100_000, outputTokens: 1_000 } },
|
||||
]
|
||||
const msgs = [
|
||||
userMsg('hello'),
|
||||
assistantToolCall('snapshot', {}),
|
||||
toolResult('snapshot', toolOutput),
|
||||
]
|
||||
|
||||
const result = getCurrentTokenCount(steps, msgs, config)
|
||||
const expectedTrailing = estimateTokens(
|
||||
[toolResult('snapshot', toolOutput)],
|
||||
config.imageTokenEstimate,
|
||||
)
|
||||
expect(result).toBe(100_000 + 1_000 + expectedTrailing)
|
||||
})
|
||||
|
||||
it('catches large DOM snapshot that would bypass threshold', () => {
|
||||
// Simulates the original bug: last step saw 150K tokens,
|
||||
// then a 100K-char tool result (~25K tokens) is added
|
||||
const largeSnapshot = 'x'.repeat(100_000)
|
||||
const steps: StepWithUsage[] = [
|
||||
{ usage: { inputTokens: 150_000, outputTokens: 500 } },
|
||||
]
|
||||
const msgs = [
|
||||
userMsg('navigate to site'),
|
||||
assistantToolCall('snapshot', {}),
|
||||
toolResult('snapshot', largeSnapshot),
|
||||
]
|
||||
|
||||
const result = getCurrentTokenCount(steps, msgs, config)
|
||||
// Must be significantly above 150K — the old code returned 150K (stale)
|
||||
expect(result).toBeGreaterThan(170_000)
|
||||
})
|
||||
|
||||
it('counts multiple trailing tool results', () => {
|
||||
const steps: StepWithUsage[] = [
|
||||
{ usage: { inputTokens: 80_000, outputTokens: 1_000 } },
|
||||
]
|
||||
const msgs = [
|
||||
userMsg('do things'),
|
||||
assistantToolCall('click', { selector: '#btn' }),
|
||||
toolResult('click', 'x'.repeat(4_000)),
|
||||
toolResult('snapshot', 'y'.repeat(8_000)),
|
||||
]
|
||||
|
||||
const result = getCurrentTokenCount(steps, msgs, config)
|
||||
const trailing1 = estimateTokens(
|
||||
[toolResult('click', 'x'.repeat(4_000))],
|
||||
config.imageTokenEstimate,
|
||||
)
|
||||
const trailing2 = estimateTokens(
|
||||
[toolResult('snapshot', 'y'.repeat(8_000))],
|
||||
config.imageTokenEstimate,
|
||||
)
|
||||
expect(result).toBe(80_000 + 1_000 + trailing1 + trailing2)
|
||||
})
|
||||
|
||||
it('stops counting trailing at first non-tool message', () => {
|
||||
const steps: StepWithUsage[] = [
|
||||
{ usage: { inputTokens: 50_000, outputTokens: 500 } },
|
||||
]
|
||||
// assistant message after tool results — trailing should be 0
|
||||
const msgs = [
|
||||
userMsg('hello'),
|
||||
assistantToolCall('click', {}),
|
||||
toolResult('click', 'x'.repeat(4_000)),
|
||||
assistantMsg('done'),
|
||||
]
|
||||
|
||||
const result = getCurrentTokenCount(steps, msgs, config)
|
||||
// No trailing tool results (last message is assistant)
|
||||
expect(result).toBe(50_500)
|
||||
})
|
||||
|
||||
it('handles zero outputTokens gracefully', () => {
|
||||
const steps: StepWithUsage[] = [{ usage: { inputTokens: 50_000 } }]
|
||||
const msgs = [userMsg('hello')]
|
||||
const result = getCurrentTokenCount(steps, msgs, config)
|
||||
expect(result).toBe(50_000)
|
||||
})
|
||||
})
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Context overflow middleware
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
describe('createContextOverflowMiddleware', () => {
|
||||
it('passes through when model succeeds', async () => {
|
||||
const middleware = createContextOverflowMiddleware(200_000)
|
||||
const mockResult = { text: 'hello' }
|
||||
const params = {
|
||||
prompt: [
|
||||
{ role: 'system', content: 'You are helpful' },
|
||||
{ role: 'user', content: 'hi' },
|
||||
],
|
||||
}
|
||||
|
||||
const result = await middleware.wrapGenerate!({
|
||||
doGenerate: async () => mockResult,
|
||||
params,
|
||||
} as any)
|
||||
|
||||
expect(result).toBe(mockResult)
|
||||
})
|
||||
|
||||
it('rethrows non-context errors', async () => {
|
||||
const middleware = createContextOverflowMiddleware(200_000)
|
||||
const params = {
|
||||
prompt: [{ role: 'user', content: 'hi' }],
|
||||
}
|
||||
|
||||
await expect(
|
||||
middleware.wrapGenerate!({
|
||||
doGenerate: async () => {
|
||||
throw new Error('network timeout')
|
||||
},
|
||||
params,
|
||||
} as any),
|
||||
).rejects.toThrow('network timeout')
|
||||
})
|
||||
|
||||
it('truncates and retries on context_length error', async () => {
|
||||
const middleware = createContextOverflowMiddleware(200_000)
|
||||
let callCount = 0
|
||||
const params = {
|
||||
prompt: [
|
||||
{ role: 'system', content: 'system prompt' },
|
||||
{ role: 'user', content: 'old message 1' },
|
||||
{ role: 'assistant', content: 'old response 1' },
|
||||
{ role: 'user', content: 'old message 2' },
|
||||
{ role: 'assistant', content: 'old response 2' },
|
||||
{ role: 'user', content: 'recent message' },
|
||||
],
|
||||
}
|
||||
|
||||
const result = await middleware.wrapGenerate!({
|
||||
doGenerate: async () => {
|
||||
callCount++
|
||||
if (callCount === 1) {
|
||||
throw new Error('context_length_exceeded')
|
||||
}
|
||||
return { text: 'success after truncation' }
|
||||
},
|
||||
params,
|
||||
} as any)
|
||||
|
||||
expect(callCount).toBe(2)
|
||||
expect(result).toEqual({ text: 'success after truncation' })
|
||||
// System message should be preserved
|
||||
expect(params.prompt.some((m: any) => m.role === 'system')).toBe(true)
|
||||
// Prompt should be shorter after truncation
|
||||
expect(params.prompt.length).toBeLessThanOrEqual(6)
|
||||
})
|
||||
|
||||
it('preserves system messages during truncation', async () => {
|
||||
const middleware = createContextOverflowMiddleware(10_000)
|
||||
let truncatedPrompt: any[] = []
|
||||
const params = {
|
||||
prompt: [
|
||||
{ role: 'system', content: 'important system prompt' },
|
||||
{ role: 'user', content: 'a'.repeat(50_000) },
|
||||
{ role: 'assistant', content: 'b'.repeat(50_000) },
|
||||
{ role: 'user', content: 'recent' },
|
||||
],
|
||||
}
|
||||
|
||||
await middleware.wrapGenerate!({
|
||||
doGenerate: async () => {
|
||||
if (truncatedPrompt.length === 0) {
|
||||
truncatedPrompt = [...params.prompt]
|
||||
throw new Error('maximum context length exceeded')
|
||||
}
|
||||
truncatedPrompt = [...params.prompt]
|
||||
return { text: 'ok' }
|
||||
},
|
||||
params,
|
||||
} as any)
|
||||
|
||||
const systemMsgs = truncatedPrompt.filter((m: any) => m.role === 'system')
|
||||
expect(systemMsgs.length).toBe(1)
|
||||
expect(systemMsgs[0].content).toBe('important system prompt')
|
||||
})
|
||||
|
||||
it('handles wrapStream the same way', async () => {
|
||||
const middleware = createContextOverflowMiddleware(200_000)
|
||||
let callCount = 0
|
||||
const params = {
|
||||
prompt: [
|
||||
{ role: 'system', content: 'system' },
|
||||
{ role: 'user', content: 'message' },
|
||||
],
|
||||
}
|
||||
|
||||
const result = await middleware.wrapStream!({
|
||||
doStream: async () => {
|
||||
callCount++
|
||||
if (callCount === 1) {
|
||||
throw new Error('token limit exceeded')
|
||||
}
|
||||
return { stream: 'mock-stream' }
|
||||
},
|
||||
params,
|
||||
} as any)
|
||||
|
||||
expect(callCount).toBe(2)
|
||||
expect(result).toEqual({ stream: 'mock-stream' })
|
||||
})
|
||||
|
||||
it('detects provider-specific context overflow errors', async () => {
|
||||
const middleware = createContextOverflowMiddleware(200_000)
|
||||
const errorMessages = [
|
||||
'context_length_exceeded', // Generic
|
||||
'prompt is too long: 213462 tokens > 200000 maximum', // Anthropic
|
||||
'Your input exceeds the context window of this model', // OpenAI
|
||||
'The input token count (1196265) exceeds the maximum number of tokens allowed', // Google
|
||||
"This model's maximum prompt length is 131072 but the request contains 537812 tokens", // xAI
|
||||
'Please reduce the length of the messages or completion', // Groq
|
||||
'maximum context length is 128000 tokens', // OpenRouter
|
||||
'token limit exceeded', // Generic
|
||||
'too many tokens', // Generic
|
||||
'exceeded model token limit', // Kimi
|
||||
'input is too long for requested model', // Amazon Bedrock
|
||||
]
|
||||
|
||||
for (const errMsg of errorMessages) {
|
||||
let callCount = 0
|
||||
const params = {
|
||||
prompt: [{ role: 'user', content: 'hi' }],
|
||||
}
|
||||
|
||||
await middleware.wrapGenerate!({
|
||||
doGenerate: async () => {
|
||||
callCount++
|
||||
if (callCount === 1) throw new Error(errMsg)
|
||||
return { text: 'ok' }
|
||||
},
|
||||
params,
|
||||
} as any)
|
||||
|
||||
expect(callCount).toBe(2)
|
||||
}
|
||||
})
|
||||
|
||||
it('does not false-positive on unrelated errors', () => {
|
||||
const unrelatedErrors = [
|
||||
'URL is too long',
|
||||
'Invalid max_tokens: must be between 1 and 4096',
|
||||
'session token is too long',
|
||||
'file name is too long',
|
||||
'network timeout',
|
||||
'rate limit exceeded',
|
||||
]
|
||||
|
||||
for (const errMsg of unrelatedErrors) {
|
||||
expect(isContextOverflowError(new Error(errMsg))).toBe(false)
|
||||
}
|
||||
})
|
||||
|
||||
it('keeps at least the last non-system message when it exceeds target', async () => {
|
||||
const middleware = createContextOverflowMiddleware(1_000)
|
||||
let truncatedPrompt: any[] = []
|
||||
const params = {
|
||||
prompt: [
|
||||
{ role: 'system', content: 'system' },
|
||||
{ role: 'user', content: 'x'.repeat(100_000) },
|
||||
],
|
||||
}
|
||||
|
||||
await middleware.wrapGenerate!({
|
||||
doGenerate: async () => {
|
||||
if (truncatedPrompt.length === 0) {
|
||||
truncatedPrompt = [...params.prompt]
|
||||
throw new Error('context_length_exceeded')
|
||||
}
|
||||
truncatedPrompt = [...params.prompt]
|
||||
return { text: 'ok' }
|
||||
},
|
||||
params,
|
||||
} as any)
|
||||
|
||||
// Must keep system + at least the last user message (not empty)
|
||||
expect(truncatedPrompt.length).toBe(2)
|
||||
expect(truncatedPrompt[0].role).toBe('system')
|
||||
expect(truncatedPrompt[1].role).toBe('user')
|
||||
})
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user