From 3808faf94de462141e64dcffc0f9494c1379ac63 Mon Sep 17 00:00:00 2001 From: shivammittal274 <56757235+shivammittal274@users.noreply.github.com> Date: Mon, 9 Mar 2026 14:22:35 +0530 Subject: [PATCH] =?UTF-8?q?fix:=20robust=20compaction=20with=20Pi-style=20?= =?UTF-8?q?token=20counting=20+=20overflow=20middle=E2=80=A6=20(#444)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: robust compaction with Pi-style token counting + overflow middleware Root cause: getCurrentTokenCount() returned stale inputTokens from the previous step, ignoring new tool results added to messages since that step. A large tool output (DOM snapshot, page content) caused a token jump that bypassed the compaction threshold check, leading to context_length_exceeded errors (322K tokens sent, model max 262K). Layer 1 — Accurate token counting (proactive): - Adopt Pi coding agent's additive approach: base(inputTokens) + outputTokens + estimate(trailing tool results) - Trailing tool results are estimated by walking backwards from end of messages array until a non-tool message is found - Falls back to full estimation with safety multiplier when no real usage data is available (first step of a turn) Layer 2 — Context overflow middleware (reactive): - LanguageModelV3Middleware that wraps doGenerate/doStream - Catches context_length_exceeded errors at the model call level - Truncates prompt (keeps system messages + most recent non-system messages targeting 60% of context window) - Retries the model call once Verified end-to-end with real model (Gemini Flash Lite via OpenRouter) on 16K context window: 4 compactions triggered correctly across 8 steps, no context_length_exceeded errors. Co-Authored-By: Claude Opus 4.6 * fix: adopt Pi-style overflow detection patterns + fix truncation edge case - Replace 6 generic substring matches with 17 provider-specific regex patterns from Pi coding agent (Anthropic, OpenAI, Google, xAI, Groq, OpenRouter, Bedrock, Copilot, llama.cpp, LM Studio, MiniMax, Kimi, Mistral, z.ai) - Fix truncatePrompt edge case: when the last message alone exceeds the target, keepFrom was never updated → empty non-system messages. Now always keeps at least the most recent non-system message. - Add runtime guard for LanguageModelV3 cast in ai-sdk-agent.ts - Add tests for false-positive rejection and truncation edge case Co-Authored-By: Claude Opus 4.6 --------- Co-authored-by: Claude Opus 4.6 --- apps/server/src/agent/ai-sdk-agent.ts | 27 +- apps/server/src/agent/compaction.ts | 29 +- .../src/agent/context-overflow-middleware.ts | 116 ++++++ apps/server/tests/agent/compaction.test.ts | 331 ++++++++++++++++++ 4 files changed, 492 insertions(+), 11 deletions(-) create mode 100644 apps/server/src/agent/context-overflow-middleware.ts diff --git a/apps/server/src/agent/ai-sdk-agent.ts b/apps/server/src/agent/ai-sdk-agent.ts index 4d83ea7d..49b4e0af 100644 --- a/apps/server/src/agent/ai-sdk-agent.ts +++ b/apps/server/src/agent/ai-sdk-agent.ts @@ -1,6 +1,12 @@ +import type { LanguageModelV3 } from '@ai-sdk/provider' import { AGENT_LIMITS } from '@browseros/shared/constants/limits' import type { BrowserContext } from '@browseros/shared/schemas/browser-context' -import { stepCountIs, ToolLoopAgent, type UIMessage } from 'ai' +import { + stepCountIs, + ToolLoopAgent, + type UIMessage, + wrapLanguageModel, +} from 'ai' import type { Browser } from '../browser/browser' import type { KlavisClient } from '../lib/clients/klavis/klavis-client' import { logger } from '../lib/logger' @@ -10,6 +16,7 @@ import { buildMemoryToolSet } from '../tools/memory/build-toolset' import type { ToolRegistry } from '../tools/tool-registry' import { CHAT_MODE_ALLOWED_TOOLS } from './chat-mode' import { createCompactionPrepareStep } from './compaction' +import { createContextOverflowMiddleware } from './context-overflow-middleware' import { buildMcpServerSpecs, createMcpClients } from './mcp-builder' import { buildSystemPrompt } from './prompt' import { createLanguageModel } from './provider-factory' @@ -34,8 +41,19 @@ export class AiSdkAgent { ) {} static async create(config: AiSdkAgentConfig): Promise { - // Build language model from provider config - const model = createLanguageModel(config.resolvedConfig) + const contextWindow = + config.resolvedConfig.contextWindowSize ?? + AGENT_LIMITS.DEFAULT_CONTEXT_WINDOW + + // Build language model with overflow protection middleware + const rawModel = createLanguageModel(config.resolvedConfig) + const model = + (rawModel as any).specificationVersion === 'v3' + ? wrapLanguageModel({ + model: rawModel as LanguageModelV3, + middleware: createContextOverflowMiddleware(contextWindow), + }) + : rawModel // Build browser tools from the unified tool registry const allBrowserTools = buildBrowserToolSet(config.registry, config.browser) @@ -95,9 +113,6 @@ export class AiSdkAgent { }) // Configure compaction for context window management - const contextWindow = - config.resolvedConfig.contextWindowSize ?? - AGENT_LIMITS.DEFAULT_CONTEXT_WINDOW const prepareStep = createCompactionPrepareStep({ contextWindow, }) diff --git a/apps/server/src/agent/compaction.ts b/apps/server/src/agent/compaction.ts index effdee49..5972c25a 100644 --- a/apps/server/src/agent/compaction.ts +++ b/apps/server/src/agent/compaction.ts @@ -157,8 +157,11 @@ export function estimateTokens( return Math.ceil(chars / 4) + imageCount * imageTokenEstimate } -interface StepWithUsage { - usage?: { inputTokens?: number | undefined } +export interface StepWithUsage { + usage?: { + inputTokens?: number | undefined + outputTokens?: number | undefined + } } export function getCurrentTokenCount( @@ -166,15 +169,31 @@ export function getCurrentTokenCount( messages: ModelMessage[], config: ComputedConfig, ): number { - // Use real API usage from the last step when available if (steps.length > 0) { const lastStep = steps[steps.length - 1] if (lastStep.usage?.inputTokens != null && lastStep.usage.inputTokens > 0) { - return lastStep.usage.inputTokens + // Pi-style additive: real usage as base + trailing content added since + const base = lastStep.usage.inputTokens + const outputTokens = lastStep.usage.outputTokens ?? 0 + + // Estimate trailing tool result messages added after the last model call + let trailingTokens = 0 + for (let i = messages.length - 1; i >= 0; i--) { + if (messages[i].role === 'tool') { + trailingTokens += estimateTokens( + [messages[i]], + config.imageTokenEstimate, + ) + } else { + break + } + } + + return base + outputTokens + trailingTokens } } - // Fallback: estimation with safety multiplier + overhead + // No real usage → full estimation with safety margin const estimated = estimateTokens(messages, config.imageTokenEstimate) return Math.ceil(estimated * config.safetyMultiplier) + config.fixedOverhead } diff --git a/apps/server/src/agent/context-overflow-middleware.ts b/apps/server/src/agent/context-overflow-middleware.ts new file mode 100644 index 00000000..a7ba9a68 --- /dev/null +++ b/apps/server/src/agent/context-overflow-middleware.ts @@ -0,0 +1,116 @@ +import type { + LanguageModelV3CallOptions, + LanguageModelV3Message, + LanguageModelV3Middleware, + LanguageModelV3Prompt, +} from '@ai-sdk/provider' +import { logger } from '../lib/logger' + +/** + * Provider-specific regex patterns for context overflow errors. + * Adapted from Pi coding agent's overflow detection. + * + * @see https://github.com/badlogic/pi-mono/blob/main/packages/ai/src/utils/overflow.ts + */ +const OVERFLOW_PATTERNS: RegExp[] = [ + /prompt is too long/i, // Anthropic + /input is too long for requested model/i, // Amazon Bedrock + /exceeds the context window/i, // OpenAI (Completions & Responses API) + /input token count.*exceeds the maximum/i, // Google (Gemini) + /maximum prompt length is \d+/i, // xAI (Grok) + /reduce the length of the messages/i, // Groq + /maximum context length is \d+ tokens/i, // OpenRouter (all backends) + /exceeds the limit of \d+/i, // GitHub Copilot + /exceeds the available context size/i, // llama.cpp server + /greater than the context length/i, // LM Studio + /context window exceeds limit/i, // MiniMax + /exceeded model token limit/i, // Kimi For Coding + /too large for model with \d+ maximum context length/i, // Mistral + /model_context_window_exceeded/i, // z.ai non-standard finish_reason + /context[_ ]length[_ ]exceeded/i, // Generic fallback + /too many tokens/i, // Generic fallback + /token limit exceeded/i, // Generic fallback +] + +export function isContextOverflowError(error: unknown): boolean { + if (!(error instanceof Error)) return false + const msg = error.message + return OVERFLOW_PATTERNS.some((p) => p.test(msg)) +} + +function truncatePrompt( + prompt: LanguageModelV3Prompt, + contextWindow: number, +): LanguageModelV3Prompt { + const systemMessages: LanguageModelV3Message[] = [] + const nonSystem: LanguageModelV3Message[] = [] + for (const m of prompt) { + if (m.role === 'system') systemMessages.push(m) + else nonSystem.push(m) + } + + // Target 60% of context window to leave headroom + const targetChars = contextWindow * 4 * 0.6 + let totalChars = 0 + let keepFrom = nonSystem.length + + for (let i = nonSystem.length - 1; i >= 0; i--) { + totalChars += JSON.stringify(nonSystem[i].content).length + if (totalChars > targetChars) break + keepFrom = i + } + + // Always keep at least the most recent non-system message + if (keepFrom >= nonSystem.length && nonSystem.length > 0) { + keepFrom = nonSystem.length - 1 + } + + const kept: LanguageModelV3Prompt = [ + ...systemMessages, + ...nonSystem.slice(keepFrom), + ] + logger.warn('Emergency prompt truncation', { + original: prompt.length, + kept: kept.length, + dropped: prompt.length - kept.length, + }) + return kept +} + +export function createContextOverflowMiddleware( + contextWindow: number, +): LanguageModelV3Middleware { + return { + specificationVersion: 'v3', + wrapGenerate: async ({ doGenerate, params }) => { + try { + return await doGenerate() + } catch (error) { + if (!isContextOverflowError(error)) throw error + logger.warn( + 'Context overflow detected in doGenerate, truncating and retrying', + ) + ;(params as LanguageModelV3CallOptions).prompt = truncatePrompt( + params.prompt, + contextWindow, + ) + return await doGenerate() + } + }, + wrapStream: async ({ doStream, params }) => { + try { + return await doStream() + } catch (error) { + if (!isContextOverflowError(error)) throw error + logger.warn( + 'Context overflow detected in doStream, truncating and retrying', + ) + ;(params as LanguageModelV3CallOptions).prompt = truncatePrompt( + params.prompt, + contextWindow, + ) + return await doStream() + } + }, + } +} diff --git a/apps/server/tests/agent/compaction.test.ts b/apps/server/tests/agent/compaction.test.ts index d5d2126d..ac9cba1f 100644 --- a/apps/server/tests/agent/compaction.test.ts +++ b/apps/server/tests/agent/compaction.test.ts @@ -4,6 +4,8 @@ import { computeConfig, estimateTokens, findSafeSplitPoint, + getCurrentTokenCount, + type StepWithUsage, slidingWindow, truncateToolOutputs, } from '../../src/agent/compaction' @@ -12,6 +14,10 @@ import { buildTurnPrefixPrompt, messagesToTranscript, } from '../../src/agent/compaction-prompt' +import { + createContextOverflowMiddleware, + isContextOverflowError, +} from '../../src/agent/context-overflow-middleware' // --------------------------------------------------------------------------- // Helpers @@ -732,3 +738,328 @@ describe('end-to-end config coherence', () => { } }) }) + +// --------------------------------------------------------------------------- +// getCurrentTokenCount — Pi-style additive counting +// --------------------------------------------------------------------------- + +describe('getCurrentTokenCount — Pi-style additive', () => { + const config = computeConfig(200_000) + + it('returns estimated with safety margin when no steps exist', () => { + const msgs = [userMsg('a'.repeat(400))] + const result = getCurrentTokenCount([], msgs, config) + const rawEstimate = estimateTokens(msgs, config.imageTokenEstimate) + const expected = + Math.ceil(rawEstimate * config.safetyMultiplier) + config.fixedOverhead + expect(result).toBe(expected) + }) + + it('returns estimated when last step has no usage', () => { + const steps: StepWithUsage[] = [{ usage: undefined }] + const msgs = [userMsg('hello')] + const result = getCurrentTokenCount(steps, msgs, config) + const rawEstimate = estimateTokens(msgs, config.imageTokenEstimate) + const expected = + Math.ceil(rawEstimate * config.safetyMultiplier) + config.fixedOverhead + expect(result).toBe(expected) + }) + + it('adds outputTokens to base when no trailing tool results', () => { + const steps: StepWithUsage[] = [ + { usage: { inputTokens: 50_000, outputTokens: 2_000 } }, + ] + const msgs = [userMsg('hello'), assistantMsg('response')] + const result = getCurrentTokenCount(steps, msgs, config) + expect(result).toBe(52_000) + }) + + it('adds trailing tool result tokens to base + output', () => { + const toolOutput = 'x'.repeat(40_000) // ~10K tokens + const steps: StepWithUsage[] = [ + { usage: { inputTokens: 100_000, outputTokens: 1_000 } }, + ] + const msgs = [ + userMsg('hello'), + assistantToolCall('snapshot', {}), + toolResult('snapshot', toolOutput), + ] + + const result = getCurrentTokenCount(steps, msgs, config) + const expectedTrailing = estimateTokens( + [toolResult('snapshot', toolOutput)], + config.imageTokenEstimate, + ) + expect(result).toBe(100_000 + 1_000 + expectedTrailing) + }) + + it('catches large DOM snapshot that would bypass threshold', () => { + // Simulates the original bug: last step saw 150K tokens, + // then a 100K-char tool result (~25K tokens) is added + const largeSnapshot = 'x'.repeat(100_000) + const steps: StepWithUsage[] = [ + { usage: { inputTokens: 150_000, outputTokens: 500 } }, + ] + const msgs = [ + userMsg('navigate to site'), + assistantToolCall('snapshot', {}), + toolResult('snapshot', largeSnapshot), + ] + + const result = getCurrentTokenCount(steps, msgs, config) + // Must be significantly above 150K — the old code returned 150K (stale) + expect(result).toBeGreaterThan(170_000) + }) + + it('counts multiple trailing tool results', () => { + const steps: StepWithUsage[] = [ + { usage: { inputTokens: 80_000, outputTokens: 1_000 } }, + ] + const msgs = [ + userMsg('do things'), + assistantToolCall('click', { selector: '#btn' }), + toolResult('click', 'x'.repeat(4_000)), + toolResult('snapshot', 'y'.repeat(8_000)), + ] + + const result = getCurrentTokenCount(steps, msgs, config) + const trailing1 = estimateTokens( + [toolResult('click', 'x'.repeat(4_000))], + config.imageTokenEstimate, + ) + const trailing2 = estimateTokens( + [toolResult('snapshot', 'y'.repeat(8_000))], + config.imageTokenEstimate, + ) + expect(result).toBe(80_000 + 1_000 + trailing1 + trailing2) + }) + + it('stops counting trailing at first non-tool message', () => { + const steps: StepWithUsage[] = [ + { usage: { inputTokens: 50_000, outputTokens: 500 } }, + ] + // assistant message after tool results — trailing should be 0 + const msgs = [ + userMsg('hello'), + assistantToolCall('click', {}), + toolResult('click', 'x'.repeat(4_000)), + assistantMsg('done'), + ] + + const result = getCurrentTokenCount(steps, msgs, config) + // No trailing tool results (last message is assistant) + expect(result).toBe(50_500) + }) + + it('handles zero outputTokens gracefully', () => { + const steps: StepWithUsage[] = [{ usage: { inputTokens: 50_000 } }] + const msgs = [userMsg('hello')] + const result = getCurrentTokenCount(steps, msgs, config) + expect(result).toBe(50_000) + }) +}) + +// --------------------------------------------------------------------------- +// Context overflow middleware +// --------------------------------------------------------------------------- + +describe('createContextOverflowMiddleware', () => { + it('passes through when model succeeds', async () => { + const middleware = createContextOverflowMiddleware(200_000) + const mockResult = { text: 'hello' } + const params = { + prompt: [ + { role: 'system', content: 'You are helpful' }, + { role: 'user', content: 'hi' }, + ], + } + + const result = await middleware.wrapGenerate!({ + doGenerate: async () => mockResult, + params, + } as any) + + expect(result).toBe(mockResult) + }) + + it('rethrows non-context errors', async () => { + const middleware = createContextOverflowMiddleware(200_000) + const params = { + prompt: [{ role: 'user', content: 'hi' }], + } + + await expect( + middleware.wrapGenerate!({ + doGenerate: async () => { + throw new Error('network timeout') + }, + params, + } as any), + ).rejects.toThrow('network timeout') + }) + + it('truncates and retries on context_length error', async () => { + const middleware = createContextOverflowMiddleware(200_000) + let callCount = 0 + const params = { + prompt: [ + { role: 'system', content: 'system prompt' }, + { role: 'user', content: 'old message 1' }, + { role: 'assistant', content: 'old response 1' }, + { role: 'user', content: 'old message 2' }, + { role: 'assistant', content: 'old response 2' }, + { role: 'user', content: 'recent message' }, + ], + } + + const result = await middleware.wrapGenerate!({ + doGenerate: async () => { + callCount++ + if (callCount === 1) { + throw new Error('context_length_exceeded') + } + return { text: 'success after truncation' } + }, + params, + } as any) + + expect(callCount).toBe(2) + expect(result).toEqual({ text: 'success after truncation' }) + // System message should be preserved + expect(params.prompt.some((m: any) => m.role === 'system')).toBe(true) + // Prompt should be shorter after truncation + expect(params.prompt.length).toBeLessThanOrEqual(6) + }) + + it('preserves system messages during truncation', async () => { + const middleware = createContextOverflowMiddleware(10_000) + let truncatedPrompt: any[] = [] + const params = { + prompt: [ + { role: 'system', content: 'important system prompt' }, + { role: 'user', content: 'a'.repeat(50_000) }, + { role: 'assistant', content: 'b'.repeat(50_000) }, + { role: 'user', content: 'recent' }, + ], + } + + await middleware.wrapGenerate!({ + doGenerate: async () => { + if (truncatedPrompt.length === 0) { + truncatedPrompt = [...params.prompt] + throw new Error('maximum context length exceeded') + } + truncatedPrompt = [...params.prompt] + return { text: 'ok' } + }, + params, + } as any) + + const systemMsgs = truncatedPrompt.filter((m: any) => m.role === 'system') + expect(systemMsgs.length).toBe(1) + expect(systemMsgs[0].content).toBe('important system prompt') + }) + + it('handles wrapStream the same way', async () => { + const middleware = createContextOverflowMiddleware(200_000) + let callCount = 0 + const params = { + prompt: [ + { role: 'system', content: 'system' }, + { role: 'user', content: 'message' }, + ], + } + + const result = await middleware.wrapStream!({ + doStream: async () => { + callCount++ + if (callCount === 1) { + throw new Error('token limit exceeded') + } + return { stream: 'mock-stream' } + }, + params, + } as any) + + expect(callCount).toBe(2) + expect(result).toEqual({ stream: 'mock-stream' }) + }) + + it('detects provider-specific context overflow errors', async () => { + const middleware = createContextOverflowMiddleware(200_000) + const errorMessages = [ + 'context_length_exceeded', // Generic + 'prompt is too long: 213462 tokens > 200000 maximum', // Anthropic + 'Your input exceeds the context window of this model', // OpenAI + 'The input token count (1196265) exceeds the maximum number of tokens allowed', // Google + "This model's maximum prompt length is 131072 but the request contains 537812 tokens", // xAI + 'Please reduce the length of the messages or completion', // Groq + 'maximum context length is 128000 tokens', // OpenRouter + 'token limit exceeded', // Generic + 'too many tokens', // Generic + 'exceeded model token limit', // Kimi + 'input is too long for requested model', // Amazon Bedrock + ] + + for (const errMsg of errorMessages) { + let callCount = 0 + const params = { + prompt: [{ role: 'user', content: 'hi' }], + } + + await middleware.wrapGenerate!({ + doGenerate: async () => { + callCount++ + if (callCount === 1) throw new Error(errMsg) + return { text: 'ok' } + }, + params, + } as any) + + expect(callCount).toBe(2) + } + }) + + it('does not false-positive on unrelated errors', () => { + const unrelatedErrors = [ + 'URL is too long', + 'Invalid max_tokens: must be between 1 and 4096', + 'session token is too long', + 'file name is too long', + 'network timeout', + 'rate limit exceeded', + ] + + for (const errMsg of unrelatedErrors) { + expect(isContextOverflowError(new Error(errMsg))).toBe(false) + } + }) + + it('keeps at least the last non-system message when it exceeds target', async () => { + const middleware = createContextOverflowMiddleware(1_000) + let truncatedPrompt: any[] = [] + const params = { + prompt: [ + { role: 'system', content: 'system' }, + { role: 'user', content: 'x'.repeat(100_000) }, + ], + } + + await middleware.wrapGenerate!({ + doGenerate: async () => { + if (truncatedPrompt.length === 0) { + truncatedPrompt = [...params.prompt] + throw new Error('context_length_exceeded') + } + truncatedPrompt = [...params.prompt] + return { text: 'ok' } + }, + params, + } as any) + + // Must keep system + at least the last user message (not empty) + expect(truncatedPrompt.length).toBe(2) + expect(truncatedPrompt[0].role).toBe('system') + expect(truncatedPrompt[1].role).toBe('user') + }) +})