fix: robust compaction with Pi-style token counting + overflow middle… (#444)

* fix: robust compaction with Pi-style token counting + overflow middleware

Root cause: getCurrentTokenCount() returned stale inputTokens from the
previous step, ignoring new tool results added to messages since that
step. A large tool output (DOM snapshot, page content) caused a token
jump that bypassed the compaction threshold check, leading to
context_length_exceeded errors (322K tokens sent, model max 262K).

Layer 1 — Accurate token counting (proactive):
- Adopt Pi coding agent's additive approach: base(inputTokens) +
  outputTokens + estimate(trailing tool results)
- Trailing tool results are estimated by walking backwards from end of
  messages array until a non-tool message is found
- Falls back to full estimation with safety multiplier when no real
  usage data is available (first step of a turn)

Layer 2 — Context overflow middleware (reactive):
- LanguageModelV3Middleware that wraps doGenerate/doStream
- Catches context_length_exceeded errors at the model call level
- Truncates prompt (keeps system messages + most recent non-system
  messages targeting 60% of context window)
- Retries the model call once

Verified end-to-end with real model (Gemini Flash Lite via OpenRouter)
on 16K context window: 4 compactions triggered correctly across 8
steps, no context_length_exceeded errors.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: adopt Pi-style overflow detection patterns + fix truncation edge case

- Replace 6 generic substring matches with 17 provider-specific regex
  patterns from Pi coding agent (Anthropic, OpenAI, Google, xAI, Groq,
  OpenRouter, Bedrock, Copilot, llama.cpp, LM Studio, MiniMax, Kimi,
  Mistral, z.ai)
- Fix truncatePrompt edge case: when the last message alone exceeds the
  target, keepFrom was never updated → empty non-system messages. Now
  always keeps at least the most recent non-system message.
- Add runtime guard for LanguageModelV3 cast in ai-sdk-agent.ts
- Add tests for false-positive rejection and truncation edge case

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
shivammittal274
2026-03-09 14:22:35 +05:30
committed by GitHub
parent eb208b0515
commit 3808faf94d
4 changed files with 492 additions and 11 deletions

View File

@@ -1,6 +1,12 @@
import type { LanguageModelV3 } from '@ai-sdk/provider'
import { AGENT_LIMITS } from '@browseros/shared/constants/limits'
import type { BrowserContext } from '@browseros/shared/schemas/browser-context'
import { stepCountIs, ToolLoopAgent, type UIMessage } from 'ai'
import {
stepCountIs,
ToolLoopAgent,
type UIMessage,
wrapLanguageModel,
} from 'ai'
import type { Browser } from '../browser/browser'
import type { KlavisClient } from '../lib/clients/klavis/klavis-client'
import { logger } from '../lib/logger'
@@ -10,6 +16,7 @@ import { buildMemoryToolSet } from '../tools/memory/build-toolset'
import type { ToolRegistry } from '../tools/tool-registry'
import { CHAT_MODE_ALLOWED_TOOLS } from './chat-mode'
import { createCompactionPrepareStep } from './compaction'
import { createContextOverflowMiddleware } from './context-overflow-middleware'
import { buildMcpServerSpecs, createMcpClients } from './mcp-builder'
import { buildSystemPrompt } from './prompt'
import { createLanguageModel } from './provider-factory'
@@ -34,8 +41,19 @@ export class AiSdkAgent {
) {}
static async create(config: AiSdkAgentConfig): Promise<AiSdkAgent> {
// Build language model from provider config
const model = createLanguageModel(config.resolvedConfig)
const contextWindow =
config.resolvedConfig.contextWindowSize ??
AGENT_LIMITS.DEFAULT_CONTEXT_WINDOW
// Build language model with overflow protection middleware
const rawModel = createLanguageModel(config.resolvedConfig)
const model =
(rawModel as any).specificationVersion === 'v3'
? wrapLanguageModel({
model: rawModel as LanguageModelV3,
middleware: createContextOverflowMiddleware(contextWindow),
})
: rawModel
// Build browser tools from the unified tool registry
const allBrowserTools = buildBrowserToolSet(config.registry, config.browser)
@@ -95,9 +113,6 @@ export class AiSdkAgent {
})
// Configure compaction for context window management
const contextWindow =
config.resolvedConfig.contextWindowSize ??
AGENT_LIMITS.DEFAULT_CONTEXT_WINDOW
const prepareStep = createCompactionPrepareStep({
contextWindow,
})

View File

@@ -157,8 +157,11 @@ export function estimateTokens(
return Math.ceil(chars / 4) + imageCount * imageTokenEstimate
}
interface StepWithUsage {
usage?: { inputTokens?: number | undefined }
export interface StepWithUsage {
usage?: {
inputTokens?: number | undefined
outputTokens?: number | undefined
}
}
export function getCurrentTokenCount(
@@ -166,15 +169,31 @@ export function getCurrentTokenCount(
messages: ModelMessage[],
config: ComputedConfig,
): number {
// Use real API usage from the last step when available
if (steps.length > 0) {
const lastStep = steps[steps.length - 1]
if (lastStep.usage?.inputTokens != null && lastStep.usage.inputTokens > 0) {
return lastStep.usage.inputTokens
// Pi-style additive: real usage as base + trailing content added since
const base = lastStep.usage.inputTokens
const outputTokens = lastStep.usage.outputTokens ?? 0
// Estimate trailing tool result messages added after the last model call
let trailingTokens = 0
for (let i = messages.length - 1; i >= 0; i--) {
if (messages[i].role === 'tool') {
trailingTokens += estimateTokens(
[messages[i]],
config.imageTokenEstimate,
)
} else {
break
}
}
return base + outputTokens + trailingTokens
}
}
// Fallback: estimation with safety multiplier + overhead
// No real usage → full estimation with safety margin
const estimated = estimateTokens(messages, config.imageTokenEstimate)
return Math.ceil(estimated * config.safetyMultiplier) + config.fixedOverhead
}

View File

@@ -0,0 +1,116 @@
import type {
LanguageModelV3CallOptions,
LanguageModelV3Message,
LanguageModelV3Middleware,
LanguageModelV3Prompt,
} from '@ai-sdk/provider'
import { logger } from '../lib/logger'
/**
* Provider-specific regex patterns for context overflow errors.
* Adapted from Pi coding agent's overflow detection.
*
* @see https://github.com/badlogic/pi-mono/blob/main/packages/ai/src/utils/overflow.ts
*/
const OVERFLOW_PATTERNS: RegExp[] = [
/prompt is too long/i, // Anthropic
/input is too long for requested model/i, // Amazon Bedrock
/exceeds the context window/i, // OpenAI (Completions & Responses API)
/input token count.*exceeds the maximum/i, // Google (Gemini)
/maximum prompt length is \d+/i, // xAI (Grok)
/reduce the length of the messages/i, // Groq
/maximum context length is \d+ tokens/i, // OpenRouter (all backends)
/exceeds the limit of \d+/i, // GitHub Copilot
/exceeds the available context size/i, // llama.cpp server
/greater than the context length/i, // LM Studio
/context window exceeds limit/i, // MiniMax
/exceeded model token limit/i, // Kimi For Coding
/too large for model with \d+ maximum context length/i, // Mistral
/model_context_window_exceeded/i, // z.ai non-standard finish_reason
/context[_ ]length[_ ]exceeded/i, // Generic fallback
/too many tokens/i, // Generic fallback
/token limit exceeded/i, // Generic fallback
]
export function isContextOverflowError(error: unknown): boolean {
if (!(error instanceof Error)) return false
const msg = error.message
return OVERFLOW_PATTERNS.some((p) => p.test(msg))
}
function truncatePrompt(
prompt: LanguageModelV3Prompt,
contextWindow: number,
): LanguageModelV3Prompt {
const systemMessages: LanguageModelV3Message[] = []
const nonSystem: LanguageModelV3Message[] = []
for (const m of prompt) {
if (m.role === 'system') systemMessages.push(m)
else nonSystem.push(m)
}
// Target 60% of context window to leave headroom
const targetChars = contextWindow * 4 * 0.6
let totalChars = 0
let keepFrom = nonSystem.length
for (let i = nonSystem.length - 1; i >= 0; i--) {
totalChars += JSON.stringify(nonSystem[i].content).length
if (totalChars > targetChars) break
keepFrom = i
}
// Always keep at least the most recent non-system message
if (keepFrom >= nonSystem.length && nonSystem.length > 0) {
keepFrom = nonSystem.length - 1
}
const kept: LanguageModelV3Prompt = [
...systemMessages,
...nonSystem.slice(keepFrom),
]
logger.warn('Emergency prompt truncation', {
original: prompt.length,
kept: kept.length,
dropped: prompt.length - kept.length,
})
return kept
}
export function createContextOverflowMiddleware(
contextWindow: number,
): LanguageModelV3Middleware {
return {
specificationVersion: 'v3',
wrapGenerate: async ({ doGenerate, params }) => {
try {
return await doGenerate()
} catch (error) {
if (!isContextOverflowError(error)) throw error
logger.warn(
'Context overflow detected in doGenerate, truncating and retrying',
)
;(params as LanguageModelV3CallOptions).prompt = truncatePrompt(
params.prompt,
contextWindow,
)
return await doGenerate()
}
},
wrapStream: async ({ doStream, params }) => {
try {
return await doStream()
} catch (error) {
if (!isContextOverflowError(error)) throw error
logger.warn(
'Context overflow detected in doStream, truncating and retrying',
)
;(params as LanguageModelV3CallOptions).prompt = truncatePrompt(
params.prompt,
contextWindow,
)
return await doStream()
}
},
}
}

View File

@@ -4,6 +4,8 @@ import {
computeConfig,
estimateTokens,
findSafeSplitPoint,
getCurrentTokenCount,
type StepWithUsage,
slidingWindow,
truncateToolOutputs,
} from '../../src/agent/compaction'
@@ -12,6 +14,10 @@ import {
buildTurnPrefixPrompt,
messagesToTranscript,
} from '../../src/agent/compaction-prompt'
import {
createContextOverflowMiddleware,
isContextOverflowError,
} from '../../src/agent/context-overflow-middleware'
// ---------------------------------------------------------------------------
// Helpers
@@ -732,3 +738,328 @@ describe('end-to-end config coherence', () => {
}
})
})
// ---------------------------------------------------------------------------
// getCurrentTokenCount — Pi-style additive counting
// ---------------------------------------------------------------------------
describe('getCurrentTokenCount — Pi-style additive', () => {
const config = computeConfig(200_000)
it('returns estimated with safety margin when no steps exist', () => {
const msgs = [userMsg('a'.repeat(400))]
const result = getCurrentTokenCount([], msgs, config)
const rawEstimate = estimateTokens(msgs, config.imageTokenEstimate)
const expected =
Math.ceil(rawEstimate * config.safetyMultiplier) + config.fixedOverhead
expect(result).toBe(expected)
})
it('returns estimated when last step has no usage', () => {
const steps: StepWithUsage[] = [{ usage: undefined }]
const msgs = [userMsg('hello')]
const result = getCurrentTokenCount(steps, msgs, config)
const rawEstimate = estimateTokens(msgs, config.imageTokenEstimate)
const expected =
Math.ceil(rawEstimate * config.safetyMultiplier) + config.fixedOverhead
expect(result).toBe(expected)
})
it('adds outputTokens to base when no trailing tool results', () => {
const steps: StepWithUsage[] = [
{ usage: { inputTokens: 50_000, outputTokens: 2_000 } },
]
const msgs = [userMsg('hello'), assistantMsg('response')]
const result = getCurrentTokenCount(steps, msgs, config)
expect(result).toBe(52_000)
})
it('adds trailing tool result tokens to base + output', () => {
const toolOutput = 'x'.repeat(40_000) // ~10K tokens
const steps: StepWithUsage[] = [
{ usage: { inputTokens: 100_000, outputTokens: 1_000 } },
]
const msgs = [
userMsg('hello'),
assistantToolCall('snapshot', {}),
toolResult('snapshot', toolOutput),
]
const result = getCurrentTokenCount(steps, msgs, config)
const expectedTrailing = estimateTokens(
[toolResult('snapshot', toolOutput)],
config.imageTokenEstimate,
)
expect(result).toBe(100_000 + 1_000 + expectedTrailing)
})
it('catches large DOM snapshot that would bypass threshold', () => {
// Simulates the original bug: last step saw 150K tokens,
// then a 100K-char tool result (~25K tokens) is added
const largeSnapshot = 'x'.repeat(100_000)
const steps: StepWithUsage[] = [
{ usage: { inputTokens: 150_000, outputTokens: 500 } },
]
const msgs = [
userMsg('navigate to site'),
assistantToolCall('snapshot', {}),
toolResult('snapshot', largeSnapshot),
]
const result = getCurrentTokenCount(steps, msgs, config)
// Must be significantly above 150K — the old code returned 150K (stale)
expect(result).toBeGreaterThan(170_000)
})
it('counts multiple trailing tool results', () => {
const steps: StepWithUsage[] = [
{ usage: { inputTokens: 80_000, outputTokens: 1_000 } },
]
const msgs = [
userMsg('do things'),
assistantToolCall('click', { selector: '#btn' }),
toolResult('click', 'x'.repeat(4_000)),
toolResult('snapshot', 'y'.repeat(8_000)),
]
const result = getCurrentTokenCount(steps, msgs, config)
const trailing1 = estimateTokens(
[toolResult('click', 'x'.repeat(4_000))],
config.imageTokenEstimate,
)
const trailing2 = estimateTokens(
[toolResult('snapshot', 'y'.repeat(8_000))],
config.imageTokenEstimate,
)
expect(result).toBe(80_000 + 1_000 + trailing1 + trailing2)
})
it('stops counting trailing at first non-tool message', () => {
const steps: StepWithUsage[] = [
{ usage: { inputTokens: 50_000, outputTokens: 500 } },
]
// assistant message after tool results — trailing should be 0
const msgs = [
userMsg('hello'),
assistantToolCall('click', {}),
toolResult('click', 'x'.repeat(4_000)),
assistantMsg('done'),
]
const result = getCurrentTokenCount(steps, msgs, config)
// No trailing tool results (last message is assistant)
expect(result).toBe(50_500)
})
it('handles zero outputTokens gracefully', () => {
const steps: StepWithUsage[] = [{ usage: { inputTokens: 50_000 } }]
const msgs = [userMsg('hello')]
const result = getCurrentTokenCount(steps, msgs, config)
expect(result).toBe(50_000)
})
})
// ---------------------------------------------------------------------------
// Context overflow middleware
// ---------------------------------------------------------------------------
describe('createContextOverflowMiddleware', () => {
it('passes through when model succeeds', async () => {
const middleware = createContextOverflowMiddleware(200_000)
const mockResult = { text: 'hello' }
const params = {
prompt: [
{ role: 'system', content: 'You are helpful' },
{ role: 'user', content: 'hi' },
],
}
const result = await middleware.wrapGenerate!({
doGenerate: async () => mockResult,
params,
} as any)
expect(result).toBe(mockResult)
})
it('rethrows non-context errors', async () => {
const middleware = createContextOverflowMiddleware(200_000)
const params = {
prompt: [{ role: 'user', content: 'hi' }],
}
await expect(
middleware.wrapGenerate!({
doGenerate: async () => {
throw new Error('network timeout')
},
params,
} as any),
).rejects.toThrow('network timeout')
})
it('truncates and retries on context_length error', async () => {
const middleware = createContextOverflowMiddleware(200_000)
let callCount = 0
const params = {
prompt: [
{ role: 'system', content: 'system prompt' },
{ role: 'user', content: 'old message 1' },
{ role: 'assistant', content: 'old response 1' },
{ role: 'user', content: 'old message 2' },
{ role: 'assistant', content: 'old response 2' },
{ role: 'user', content: 'recent message' },
],
}
const result = await middleware.wrapGenerate!({
doGenerate: async () => {
callCount++
if (callCount === 1) {
throw new Error('context_length_exceeded')
}
return { text: 'success after truncation' }
},
params,
} as any)
expect(callCount).toBe(2)
expect(result).toEqual({ text: 'success after truncation' })
// System message should be preserved
expect(params.prompt.some((m: any) => m.role === 'system')).toBe(true)
// Prompt should be shorter after truncation
expect(params.prompt.length).toBeLessThanOrEqual(6)
})
it('preserves system messages during truncation', async () => {
const middleware = createContextOverflowMiddleware(10_000)
let truncatedPrompt: any[] = []
const params = {
prompt: [
{ role: 'system', content: 'important system prompt' },
{ role: 'user', content: 'a'.repeat(50_000) },
{ role: 'assistant', content: 'b'.repeat(50_000) },
{ role: 'user', content: 'recent' },
],
}
await middleware.wrapGenerate!({
doGenerate: async () => {
if (truncatedPrompt.length === 0) {
truncatedPrompt = [...params.prompt]
throw new Error('maximum context length exceeded')
}
truncatedPrompt = [...params.prompt]
return { text: 'ok' }
},
params,
} as any)
const systemMsgs = truncatedPrompt.filter((m: any) => m.role === 'system')
expect(systemMsgs.length).toBe(1)
expect(systemMsgs[0].content).toBe('important system prompt')
})
it('handles wrapStream the same way', async () => {
const middleware = createContextOverflowMiddleware(200_000)
let callCount = 0
const params = {
prompt: [
{ role: 'system', content: 'system' },
{ role: 'user', content: 'message' },
],
}
const result = await middleware.wrapStream!({
doStream: async () => {
callCount++
if (callCount === 1) {
throw new Error('token limit exceeded')
}
return { stream: 'mock-stream' }
},
params,
} as any)
expect(callCount).toBe(2)
expect(result).toEqual({ stream: 'mock-stream' })
})
it('detects provider-specific context overflow errors', async () => {
const middleware = createContextOverflowMiddleware(200_000)
const errorMessages = [
'context_length_exceeded', // Generic
'prompt is too long: 213462 tokens > 200000 maximum', // Anthropic
'Your input exceeds the context window of this model', // OpenAI
'The input token count (1196265) exceeds the maximum number of tokens allowed', // Google
"This model's maximum prompt length is 131072 but the request contains 537812 tokens", // xAI
'Please reduce the length of the messages or completion', // Groq
'maximum context length is 128000 tokens', // OpenRouter
'token limit exceeded', // Generic
'too many tokens', // Generic
'exceeded model token limit', // Kimi
'input is too long for requested model', // Amazon Bedrock
]
for (const errMsg of errorMessages) {
let callCount = 0
const params = {
prompt: [{ role: 'user', content: 'hi' }],
}
await middleware.wrapGenerate!({
doGenerate: async () => {
callCount++
if (callCount === 1) throw new Error(errMsg)
return { text: 'ok' }
},
params,
} as any)
expect(callCount).toBe(2)
}
})
it('does not false-positive on unrelated errors', () => {
const unrelatedErrors = [
'URL is too long',
'Invalid max_tokens: must be between 1 and 4096',
'session token is too long',
'file name is too long',
'network timeout',
'rate limit exceeded',
]
for (const errMsg of unrelatedErrors) {
expect(isContextOverflowError(new Error(errMsg))).toBe(false)
}
})
it('keeps at least the last non-system message when it exceeds target', async () => {
const middleware = createContextOverflowMiddleware(1_000)
let truncatedPrompt: any[] = []
const params = {
prompt: [
{ role: 'system', content: 'system' },
{ role: 'user', content: 'x'.repeat(100_000) },
],
}
await middleware.wrapGenerate!({
doGenerate: async () => {
if (truncatedPrompt.length === 0) {
truncatedPrompt = [...params.prompt]
throw new Error('context_length_exceeded')
}
truncatedPrompt = [...params.prompt]
return { text: 'ok' }
},
params,
} as any)
// Must keep system + at least the last user message (not empty)
expect(truncatedPrompt.length).toBe(2)
expect(truncatedPrompt[0].role).toBe('system')
expect(truncatedPrompt[1].role).toBe('user')
})
})