diff --git a/extensions/byteplus/video-generation-provider.ts b/extensions/byteplus/video-generation-provider.ts index 3ac517a7894..d4000469102 100644 --- a/extensions/byteplus/video-generation-provider.ts +++ b/extensions/byteplus/video-generation-provider.ts @@ -4,7 +4,8 @@ import { resolveApiKeyForProvider } from "openclaw/plugin-sdk/provider-auth-runt import { assertOkOrThrowHttpError, createProviderOperationDeadline, - fetchWithTimeout, + fetchProviderDownloadResponse, + fetchProviderOperationResponse, postJsonRequest, resolveProviderOperationTimeoutMs, resolveProviderHttpRequestConfig, @@ -82,16 +83,21 @@ async function pollBytePlusTask(params: { label: `BytePlus video generation task ${params.taskId}`, }); for (let attempt = 0; attempt < MAX_POLL_ATTEMPTS; attempt += 1) { - const response = await fetchWithTimeout( - `${params.baseUrl}/contents/generations/tasks/${params.taskId}`, - { + const response = await fetchProviderOperationResponse({ + stage: "poll", + url: `${params.baseUrl}/contents/generations/tasks/${params.taskId}`, + init: { method: "GET", headers: params.headers, }, - resolveProviderOperationTimeoutMs({ deadline, defaultTimeoutMs: DEFAULT_TIMEOUT_MS }), - params.fetchFn, - ); - await assertOkOrThrowHttpError(response, "BytePlus video status request failed"); + timeoutMs: resolveProviderOperationTimeoutMs({ + deadline, + defaultTimeoutMs: DEFAULT_TIMEOUT_MS, + }), + fetchFn: params.fetchFn, + provider: "byteplus", + requestFailedMessage: "BytePlus video status request failed", + }); const payload = (await response.json()) as BytePlusTaskResponse; switch (normalizeOptionalString(payload.status)) { case "succeeded": @@ -116,13 +122,14 @@ async function downloadBytePlusVideo(params: { timeoutMs?: number; fetchFn: typeof fetch; }): Promise { - const response = await fetchWithTimeout( - params.url, - { method: "GET" }, - params.timeoutMs ?? DEFAULT_TIMEOUT_MS, - params.fetchFn, - ); - await assertOkOrThrowHttpError(response, "BytePlus generated video download failed"); + const response = await fetchProviderDownloadResponse({ + url: params.url, + init: { method: "GET" }, + timeoutMs: params.timeoutMs ?? DEFAULT_TIMEOUT_MS, + fetchFn: params.fetchFn, + provider: "byteplus", + requestFailedMessage: "BytePlus generated video download failed", + }); const mimeType = normalizeOptionalString(response.headers.get("content-type")) ?? "video/mp4"; const arrayBuffer = await response.arrayBuffer(); return { diff --git a/extensions/google/embedding-provider.ts b/extensions/google/embedding-provider.ts index 477516691bf..5831dd6c146 100644 --- a/extensions/google/embedding-provider.ts +++ b/extensions/google/embedding-provider.ts @@ -14,7 +14,10 @@ import { requireApiKey, resolveApiKeyForProvider, } from "openclaw/plugin-sdk/provider-auth-runtime"; -import { createProviderHttpError } from "openclaw/plugin-sdk/provider-http"; +import { + createProviderHttpError, + providerOperationRetryConfig, +} from "openclaw/plugin-sdk/provider-http"; import type { SsrFPolicy } from "openclaw/plugin-sdk/ssrf-runtime"; import { normalizeOptionalString } from "openclaw/plugin-sdk/string-coerce-runtime"; @@ -199,7 +202,7 @@ async function fetchGeminiEmbeddingPayload(params: { return await executeWithApiKeyRotation({ provider: "google", apiKeys: params.client.apiKeys, - transientRetry: true, + transientRetry: providerOperationRetryConfig("read"), execute: async (apiKey) => { const authHeaders = parseGeminiAuth(apiKey); const headers = { diff --git a/extensions/google/video-generation-provider.ts b/extensions/google/video-generation-provider.ts index cce6ae6784f..19527d31e70 100644 --- a/extensions/google/video-generation-provider.ts +++ b/extensions/google/video-generation-provider.ts @@ -3,6 +3,7 @@ import path from "node:path"; import { resolveApiKeyForProvider } from "openclaw/plugin-sdk/provider-auth-runtime"; import { createProviderOperationDeadline, + executeProviderOperationWithRetry, resolveProviderOperationTimeoutMs, waitProviderOperationPollInterval, } from "openclaw/plugin-sdk/provider-http"; @@ -161,9 +162,15 @@ async function downloadGeneratedVideo(params: { rootDir: tempDir, path: fileName, write: async (downloadPath) => { - await params.client.files.download({ - file: params.file as never, - downloadPath, + await executeProviderOperationWithRetry({ + provider: "google", + stage: "download", + operation: async () => { + await params.client.files.download({ + file: params.file as never, + downloadPath, + }); + }, }); }, }); @@ -230,27 +237,33 @@ async function downloadGeneratedVideoFromUri(params: { if (!downloadUrl) { return undefined; } - const { response, release } = await fetchWithSsrFGuard({ - url: downloadUrl, + return await executeProviderOperationWithRetry({ + provider: "google", + stage: "download", + operation: async () => { + const { response, release } = await fetchWithSsrFGuard({ + url: downloadUrl, + }); + try { + if (!response.ok) { + throw new Error( + `Failed to download Google generated video: ${response.status} ${response.statusText}`, + ); + } + const buffer = Buffer.from(await response.arrayBuffer()); + return { + buffer, + mimeType: + normalizeOptionalString(response.headers.get("content-type")) || + normalizeOptionalString(params.mimeType) || + "video/mp4", + fileName: `video-${params.index + 1}.mp4`, + }; + } finally { + await release(); + } + }, }); - try { - if (!response.ok) { - throw new Error( - `Failed to download Google generated video: ${response.status} ${response.statusText}`, - ); - } - const buffer = Buffer.from(await response.arrayBuffer()); - return { - buffer, - mimeType: - normalizeOptionalString(response.headers.get("content-type")) || - normalizeOptionalString(params.mimeType) || - "video/mp4", - fileName: `video-${params.index + 1}.mp4`, - }; - } finally { - await release(); - } } function extractGoogleApiErrorCode(error: unknown): number | undefined { @@ -284,39 +297,52 @@ async function requestGoogleVideoJson(params: { method: "GET" | "POST"; headers: Record; deadline: ReturnType; + stage: "create" | "poll"; body?: unknown; }): Promise { - const controller = new AbortController(); - const timeout = setTimeout( - () => controller.abort(), - resolveProviderOperationTimeoutMs({ - deadline: params.deadline, - defaultTimeoutMs: DEFAULT_TIMEOUT_MS, - }), - ); - try { - const { response, release } = await fetchWithSsrFGuard({ - url: params.url, - init: { - method: params.method, - headers: params.headers, - ...(params.body === undefined ? {} : { body: JSON.stringify(params.body) }), - }, - signal: controller.signal, - }); - try { - const text = await response.text(); - const payload = text ? (JSON.parse(text) as unknown) : {}; - if (!response.ok) { - throw new Error(typeof payload === "string" ? payload : JSON.stringify(payload ?? null)); + return await executeProviderOperationWithRetry({ + provider: "google", + stage: params.stage, + operation: async () => { + const controller = new AbortController(); + const timeout = setTimeout( + () => { + const error = new Error("request timed out"); + error.name = "TimeoutError"; + controller.abort(error); + }, + resolveProviderOperationTimeoutMs({ + deadline: params.deadline, + defaultTimeoutMs: DEFAULT_TIMEOUT_MS, + }), + ); + try { + const { response, release } = await fetchWithSsrFGuard({ + url: params.url, + init: { + method: params.method, + headers: params.headers, + ...(params.body === undefined ? {} : { body: JSON.stringify(params.body) }), + }, + signal: controller.signal, + }); + try { + const text = await response.text(); + const payload = text ? (JSON.parse(text) as unknown) : {}; + if (!response.ok) { + throw new Error( + typeof payload === "string" ? payload : JSON.stringify(payload ?? null), + ); + } + return payload; + } finally { + await release(); + } + } finally { + clearTimeout(timeout); } - return payload; - } finally { - await release(); - } - } finally { - clearTimeout(timeout); - } + }, + }); } async function generateGoogleVideoViaRest(params: { @@ -334,6 +360,7 @@ async function generateGoogleVideoViaRest(params: { method: "POST", headers: params.headers, deadline: params.deadline, + stage: "create", body: { instances: [{ prompt: params.prompt }], parameters: { @@ -363,6 +390,7 @@ async function generateGoogleVideoViaRest(params: { method: "GET", headers: params.headers, deadline: params.deadline, + stage: "poll", }); } const error = (operation as { error?: unknown }).error; @@ -462,7 +490,11 @@ export function buildGoogleVideoGenerationProvider(): VideoGenerationProvider { } await waitProviderOperationPollInterval({ deadline, pollIntervalMs: POLL_INTERVAL_MS }); resolveProviderOperationTimeoutMs({ deadline, defaultTimeoutMs: DEFAULT_TIMEOUT_MS }); - sdkOperation = await client.operations.getVideosOperation({ operation: sdkOperation }); + sdkOperation = await executeProviderOperationWithRetry({ + provider: "google", + stage: "poll", + operation: () => client.operations.getVideosOperation({ operation: sdkOperation }), + }); } operation = sdkOperation; } diff --git a/extensions/minimax/music-generation-provider.ts b/extensions/minimax/music-generation-provider.ts index 5334d07db7c..bcafb8b186a 100644 --- a/extensions/minimax/music-generation-provider.ts +++ b/extensions/minimax/music-generation-provider.ts @@ -8,7 +8,7 @@ import { isProviderApiKeyConfigured } from "openclaw/plugin-sdk/provider-auth"; import { resolveApiKeyForProvider } from "openclaw/plugin-sdk/provider-auth-runtime"; import { assertOkOrThrowHttpError, - fetchWithTimeout, + fetchProviderDownloadResponse, postJsonRequest, resolveProviderHttpRequestConfig, } from "openclaw/plugin-sdk/provider-http"; @@ -89,13 +89,14 @@ async function downloadTrackFromUrl(params: { timeoutMs?: number; fetchFn: typeof fetch; }): Promise { - const response = await fetchWithTimeout( - params.url, - { method: "GET" }, - params.timeoutMs ?? DEFAULT_TIMEOUT_MS, - params.fetchFn, - ); - await assertOkOrThrowHttpError(response, "MiniMax generated music download failed"); + const response = await fetchProviderDownloadResponse({ + url: params.url, + init: { method: "GET" }, + timeoutMs: params.timeoutMs ?? DEFAULT_TIMEOUT_MS, + fetchFn: params.fetchFn, + provider: "minimax", + requestFailedMessage: "MiniMax generated music download failed", + }); const mimeType = normalizeOptionalString(response.headers.get("content-type")) ?? "audio/mpeg"; const ext = extensionForMime(mimeType)?.replace(/^\./u, "") || "mp3"; return { diff --git a/extensions/minimax/provider-http.test-helpers.ts b/extensions/minimax/provider-http.test-helpers.ts index a6eaa233ea9..ee0be39ed18 100644 --- a/extensions/minimax/provider-http.test-helpers.ts +++ b/extensions/minimax/provider-http.test-helpers.ts @@ -1,9 +1,15 @@ -import type { resolveProviderHttpRequestConfig } from "openclaw/plugin-sdk/provider-http"; +import type { + fetchProviderDownloadResponse, + fetchProviderOperationResponse, + resolveProviderHttpRequestConfig, +} from "openclaw/plugin-sdk/provider-http"; import { afterEach, vi, type Mock } from "vitest"; type ResolveProviderHttpRequestConfigParams = Parameters< typeof resolveProviderHttpRequestConfig >[0]; +type FetchProviderOperationResponseParams = Parameters[0]; +type FetchProviderDownloadResponseParams = Parameters[0]; type ResolveProviderHttpRequestConfigResult = { baseUrl: string; @@ -18,6 +24,8 @@ interface MinimaxProviderHttpMocks { resolveApiKeyForProviderMock: Mock<() => Promise<{ apiKey: string }>>; postJsonRequestMock: AnyMock; fetchWithTimeoutMock: AnyMock; + fetchProviderOperationResponseMock: AnyMock; + fetchProviderDownloadResponseMock: AnyMock; assertOkOrThrowHttpErrorMock: Mock<() => Promise>; resolveProviderHttpRequestConfigMock: Mock< (params: ResolveProviderHttpRequestConfigParams) => ResolveProviderHttpRequestConfigResult @@ -28,7 +36,9 @@ const minimaxProviderHttpMocks = vi.hoisted(() => ({ resolveApiKeyForProviderMock: vi.fn(async () => ({ apiKey: "provider-key" })), postJsonRequestMock: vi.fn(), fetchWithTimeoutMock: vi.fn(), - assertOkOrThrowHttpErrorMock: vi.fn(async () => {}), + fetchProviderOperationResponseMock: vi.fn(), + fetchProviderDownloadResponseMock: vi.fn(), + assertOkOrThrowHttpErrorMock: vi.fn(async (_response: Response, _label: string) => {}), resolveProviderHttpRequestConfigMock: vi.fn((params: ResolveProviderHttpRequestConfigParams) => ({ baseUrl: params.baseUrl ?? params.defaultBaseUrl, allowPrivateNetwork: false, @@ -37,6 +47,40 @@ const minimaxProviderHttpMocks = vi.hoisted(() => ({ })), })); +minimaxProviderHttpMocks.fetchProviderOperationResponseMock.mockImplementation( + async (params: FetchProviderOperationResponseParams) => { + const response = await minimaxProviderHttpMocks.fetchWithTimeoutMock( + params.url, + params.init ?? {}, + params.timeoutMs ?? 60_000, + params.fetchFn, + ); + if (params.requestFailedMessage) { + await minimaxProviderHttpMocks.assertOkOrThrowHttpErrorMock( + response, + params.requestFailedMessage, + ); + } + return response; + }, +); + +minimaxProviderHttpMocks.fetchProviderDownloadResponseMock.mockImplementation( + async (params: FetchProviderDownloadResponseParams) => { + const response = await minimaxProviderHttpMocks.fetchWithTimeoutMock( + params.url, + params.init ?? {}, + params.timeoutMs ?? 60_000, + params.fetchFn, + ); + await minimaxProviderHttpMocks.assertOkOrThrowHttpErrorMock( + response, + params.requestFailedMessage, + ); + return response; + }, +); + vi.mock("openclaw/plugin-sdk/provider-auth-runtime", () => ({ resolveApiKeyForProvider: minimaxProviderHttpMocks.resolveApiKeyForProviderMock, })); @@ -53,6 +97,8 @@ vi.mock("openclaw/plugin-sdk/provider-http", () => ({ label, timeoutMs, }), + fetchProviderDownloadResponse: minimaxProviderHttpMocks.fetchProviderDownloadResponseMock, + fetchProviderOperationResponse: minimaxProviderHttpMocks.fetchProviderOperationResponseMock, fetchWithTimeout: minimaxProviderHttpMocks.fetchWithTimeoutMock, postJsonRequest: minimaxProviderHttpMocks.postJsonRequestMock, resolveProviderOperationTimeoutMs: ({ defaultTimeoutMs }: { defaultTimeoutMs: number }) => @@ -70,6 +116,8 @@ export function installMinimaxProviderHttpMockCleanup(): void { minimaxProviderHttpMocks.resolveApiKeyForProviderMock.mockClear(); minimaxProviderHttpMocks.postJsonRequestMock.mockReset(); minimaxProviderHttpMocks.fetchWithTimeoutMock.mockReset(); + minimaxProviderHttpMocks.fetchProviderOperationResponseMock.mockClear(); + minimaxProviderHttpMocks.fetchProviderDownloadResponseMock.mockClear(); minimaxProviderHttpMocks.assertOkOrThrowHttpErrorMock.mockClear(); minimaxProviderHttpMocks.resolveProviderHttpRequestConfigMock.mockClear(); }); diff --git a/extensions/minimax/video-generation-provider.ts b/extensions/minimax/video-generation-provider.ts index d784ffa9f5e..9fde36c110a 100644 --- a/extensions/minimax/video-generation-provider.ts +++ b/extensions/minimax/video-generation-provider.ts @@ -4,7 +4,8 @@ import { resolveApiKeyForProvider } from "openclaw/plugin-sdk/provider-auth-runt import { assertOkOrThrowHttpError, createProviderOperationDeadline, - fetchWithTimeout, + fetchProviderDownloadResponse, + fetchProviderOperationResponse, postJsonRequest, resolveProviderOperationTimeoutMs, resolveProviderHttpRequestConfig, @@ -171,16 +172,21 @@ async function pollMinimaxVideo(params: { for (let attempt = 0; attempt < MAX_POLL_ATTEMPTS; attempt += 1) { const url = new URL(`${params.baseUrl}/v1/query/video_generation`); url.searchParams.set("task_id", params.taskId); - const response = await fetchWithTimeout( - url.toString(), - { + const response = await fetchProviderOperationResponse({ + stage: "poll", + url: url.toString(), + init: { method: "GET", headers: params.headers, }, - resolveProviderOperationTimeoutMs({ deadline, defaultTimeoutMs: DEFAULT_TIMEOUT_MS }), - params.fetchFn, - ); - await assertOkOrThrowHttpError(response, "MiniMax video status request failed"); + timeoutMs: resolveProviderOperationTimeoutMs({ + deadline, + defaultTimeoutMs: DEFAULT_TIMEOUT_MS, + }), + fetchFn: params.fetchFn, + provider: "minimax", + requestFailedMessage: "MiniMax video status request failed", + }); const payload = (await response.json()) as MinimaxQueryResponse; assertMinimaxBaseResp(payload.base_resp, "MiniMax video generation failed"); switch (normalizeOptionalString(payload.status)) { @@ -206,13 +212,14 @@ async function downloadVideoFromUrl(params: { timeoutMs?: number; fetchFn: typeof fetch; }): Promise { - const response = await fetchWithTimeout( - params.url, - { method: "GET" }, - params.timeoutMs ?? DEFAULT_TIMEOUT_MS, - params.fetchFn, - ); - await assertOkOrThrowHttpError(response, "MiniMax generated video download failed"); + const response = await fetchProviderDownloadResponse({ + url: params.url, + init: { method: "GET" }, + timeoutMs: params.timeoutMs ?? DEFAULT_TIMEOUT_MS, + fetchFn: params.fetchFn, + provider: "minimax", + requestFailedMessage: "MiniMax generated video download failed", + }); const mimeType = normalizeOptionalString(response.headers.get("content-type")) ?? "video/mp4"; const arrayBuffer = await response.arrayBuffer(); return { @@ -231,32 +238,32 @@ async function downloadVideoFromFileId(params: { }): Promise { const url = new URL(`${params.baseUrl}/v1/files/retrieve`); url.searchParams.set("file_id", params.fileId); - const metadataResponse = await fetchWithTimeout( - url.toString(), - { + const metadataResponse = await fetchProviderOperationResponse({ + stage: "download", + url: url.toString(), + init: { method: "GET", headers: params.headers, }, - params.timeoutMs ?? DEFAULT_TIMEOUT_MS, - params.fetchFn, - ); - await assertOkOrThrowHttpError( - metadataResponse, - "MiniMax generated video metadata request failed", - ); + timeoutMs: params.timeoutMs ?? DEFAULT_TIMEOUT_MS, + fetchFn: params.fetchFn, + provider: "minimax", + requestFailedMessage: "MiniMax generated video metadata request failed", + }); const metadata = (await metadataResponse.json()) as MinimaxFileRetrieveResponse; assertMinimaxBaseResp(metadata.base_resp, "MiniMax generated video metadata request failed"); const downloadUrl = normalizeOptionalString(metadata.file?.download_url); if (!downloadUrl) { throw new Error("MiniMax generated video metadata missing download_url"); } - const response = await fetchWithTimeout( - downloadUrl, - { method: "GET" }, - params.timeoutMs ?? DEFAULT_TIMEOUT_MS, - params.fetchFn, - ); - await assertOkOrThrowHttpError(response, "MiniMax generated video download failed"); + const response = await fetchProviderDownloadResponse({ + url: downloadUrl, + init: { method: "GET" }, + timeoutMs: params.timeoutMs ?? DEFAULT_TIMEOUT_MS, + fetchFn: params.fetchFn, + provider: "minimax", + requestFailedMessage: "MiniMax generated video download failed", + }); const mimeType = normalizeOptionalString(response.headers.get("content-type")) ?? "video/mp4"; const arrayBuffer = await response.arrayBuffer(); return { diff --git a/extensions/openai/video-generation-provider.ts b/extensions/openai/video-generation-provider.ts index 71a8eafd599..ab6b6ceb41c 100644 --- a/extensions/openai/video-generation-provider.ts +++ b/extensions/openai/video-generation-provider.ts @@ -4,6 +4,7 @@ import { resolveApiKeyForProvider } from "openclaw/plugin-sdk/provider-auth-runt import { assertOkOrThrowHttpError, createProviderOperationDeadline, + fetchProviderDownloadResponse, fetchWithTimeout, pollProviderOperationJson, postJsonRequest, @@ -151,19 +152,20 @@ async function downloadOpenAIVideo(params: { }): Promise { const url = new URL(`${params.baseUrl}/videos/${params.videoId}/content`); url.searchParams.set("variant", "video"); - const response = await fetchWithTimeout( - url.toString(), - { + const response = await fetchProviderDownloadResponse({ + url: url.toString(), + init: { method: "GET", headers: new Headers({ ...Object.fromEntries(params.headers.entries()), Accept: "application/binary", }), }, - params.timeoutMs ?? DEFAULT_TIMEOUT_MS, - params.fetchFn, - ); - await assertOkOrThrowHttpError(response, "OpenAI video download failed"); + timeoutMs: params.timeoutMs ?? DEFAULT_TIMEOUT_MS, + fetchFn: params.fetchFn, + provider: "openai", + requestFailedMessage: "OpenAI video download failed", + }); const mimeType = normalizeOptionalString(response.headers.get("content-type")) ?? "video/mp4"; const arrayBuffer = await response.arrayBuffer(); return { diff --git a/extensions/runway/video-generation-provider.ts b/extensions/runway/video-generation-provider.ts index 87879a49b47..469ebfc177b 100644 --- a/extensions/runway/video-generation-provider.ts +++ b/extensions/runway/video-generation-provider.ts @@ -4,7 +4,8 @@ import { resolveApiKeyForProvider } from "openclaw/plugin-sdk/provider-auth-runt import { assertOkOrThrowHttpError, createProviderOperationDeadline, - fetchWithTimeout, + fetchProviderDownloadResponse, + fetchProviderOperationResponse, postJsonRequest, resolveProviderOperationTimeoutMs, resolveProviderHttpRequestConfig, @@ -208,16 +209,21 @@ async function pollRunwayTask(params: { label: `Runway video generation task ${params.taskId}`, }); for (let attempt = 0; attempt < MAX_POLL_ATTEMPTS; attempt += 1) { - const response = await fetchWithTimeout( - `${params.baseUrl}/v1/tasks/${params.taskId}`, - { + const response = await fetchProviderOperationResponse({ + stage: "poll", + url: `${params.baseUrl}/v1/tasks/${params.taskId}`, + init: { method: "GET", headers: params.headers, }, - resolveProviderOperationTimeoutMs({ deadline, defaultTimeoutMs: DEFAULT_TIMEOUT_MS }), - params.fetchFn, - ); - await assertOkOrThrowHttpError(response, "Runway video status request failed"); + timeoutMs: resolveProviderOperationTimeoutMs({ + deadline, + defaultTimeoutMs: DEFAULT_TIMEOUT_MS, + }), + fetchFn: params.fetchFn, + provider: "runway", + requestFailedMessage: "Runway video status request failed", + }); const payload = (await response.json()) as RunwayTaskDetailResponse; switch (payload.status) { case "SUCCEEDED": @@ -247,13 +253,14 @@ async function downloadRunwayVideos(params: { }): Promise { const videos: GeneratedVideoAsset[] = []; for (const [index, url] of params.urls.entries()) { - const response = await fetchWithTimeout( + const response = await fetchProviderDownloadResponse({ url, - { method: "GET" }, - params.timeoutMs ?? DEFAULT_TIMEOUT_MS, - params.fetchFn, - ); - await assertOkOrThrowHttpError(response, "Runway generated video download failed"); + init: { method: "GET" }, + timeoutMs: params.timeoutMs ?? DEFAULT_TIMEOUT_MS, + fetchFn: params.fetchFn, + provider: "runway", + requestFailedMessage: "Runway generated video download failed", + }); const mimeType = normalizeOptionalString(response.headers.get("content-type")) ?? "video/mp4"; const arrayBuffer = await response.arrayBuffer(); videos.push({ diff --git a/extensions/together/video-generation-provider.ts b/extensions/together/video-generation-provider.ts index 2ee6065921b..93f2d906f16 100644 --- a/extensions/together/video-generation-provider.ts +++ b/extensions/together/video-generation-provider.ts @@ -4,7 +4,7 @@ import { resolveApiKeyForProvider } from "openclaw/plugin-sdk/provider-auth-runt import { assertOkOrThrowHttpError, createProviderOperationDeadline, - fetchWithTimeout, + fetchProviderDownloadResponse, pollProviderOperationJson, postJsonRequest, resolveProviderOperationTimeoutMs, @@ -102,13 +102,14 @@ async function downloadTogetherVideo(params: { timeoutMs?: number; fetchFn: typeof fetch; }): Promise { - const response = await fetchWithTimeout( - params.url, - { method: "GET" }, - params.timeoutMs ?? DEFAULT_TIMEOUT_MS, - params.fetchFn, - ); - await assertOkOrThrowHttpError(response, "Together generated video download failed"); + const response = await fetchProviderDownloadResponse({ + url: params.url, + init: { method: "GET" }, + timeoutMs: params.timeoutMs ?? DEFAULT_TIMEOUT_MS, + fetchFn: params.fetchFn, + provider: "together", + requestFailedMessage: "Together generated video download failed", + }); const mimeType = normalizeOptionalString(response.headers.get("content-type")) ?? "video/mp4"; const arrayBuffer = await response.arrayBuffer(); return { diff --git a/extensions/xai/video-generation-provider.ts b/extensions/xai/video-generation-provider.ts index 796bbf8c1b8..4da9e440b78 100644 --- a/extensions/xai/video-generation-provider.ts +++ b/extensions/xai/video-generation-provider.ts @@ -4,7 +4,8 @@ import { resolveApiKeyForProvider } from "openclaw/plugin-sdk/provider-auth-runt import { assertOkOrThrowHttpError, createProviderOperationDeadline, - fetchWithTimeout, + fetchProviderDownloadResponse, + fetchProviderOperationResponse, postJsonRequest, resolveProviderOperationTimeoutMs, resolveProviderHttpRequestConfig, @@ -261,16 +262,21 @@ async function pollXaiVideo(params: { label: `xAI video generation request ${params.requestId}`, }); for (let attempt = 0; attempt < MAX_POLL_ATTEMPTS; attempt += 1) { - const response = await fetchWithTimeout( - `${params.baseUrl}/videos/${params.requestId}`, - { + const response = await fetchProviderOperationResponse({ + stage: "poll", + url: `${params.baseUrl}/videos/${params.requestId}`, + init: { method: "GET", headers: params.headers, }, - resolveProviderOperationTimeoutMs({ deadline, defaultTimeoutMs: DEFAULT_TIMEOUT_MS }), - params.fetchFn, - ); - await assertOkOrThrowHttpError(response, "xAI video status request failed"); + timeoutMs: resolveProviderOperationTimeoutMs({ + deadline, + defaultTimeoutMs: DEFAULT_TIMEOUT_MS, + }), + fetchFn: params.fetchFn, + provider: "xai", + requestFailedMessage: "xAI video status request failed", + }); const payload = (await response.json()) as XaiVideoStatusResponse; switch (payload.status) { case "done": @@ -296,13 +302,14 @@ async function downloadXaiVideo(params: { timeoutMs?: number; fetchFn: typeof fetch; }): Promise { - const response = await fetchWithTimeout( - params.url, - { method: "GET" }, - params.timeoutMs ?? DEFAULT_TIMEOUT_MS, - params.fetchFn, - ); - await assertOkOrThrowHttpError(response, "xAI generated video download failed"); + const response = await fetchProviderDownloadResponse({ + url: params.url, + init: { method: "GET" }, + timeoutMs: params.timeoutMs ?? DEFAULT_TIMEOUT_MS, + fetchFn: params.fetchFn, + provider: "xai", + requestFailedMessage: "xAI generated video download failed", + }); const mimeType = normalizeOptionalString(response.headers.get("content-type")) ?? "video/mp4"; const arrayBuffer = await response.arrayBuffer(); return { diff --git a/src/agents/api-key-rotation.ts b/src/agents/api-key-rotation.ts index 7cf7a043e39..6a2feecd306 100644 --- a/src/agents/api-key-rotation.ts +++ b/src/agents/api-key-rotation.ts @@ -1,13 +1,13 @@ import { sleepWithAbort } from "../infra/backoff.js"; import { formatErrorMessage } from "../infra/errors.js"; -import { collectProviderApiKeys, isApiKeyRateLimitError } from "./live-auth-keys.js"; import { resolveTransientProviderAttempts, resolveTransientProviderDelayMs, resolveTransientProviderRetryOptions, shouldRetrySameKeyProviderOperation, type TransientProviderRetryConfig, -} from "./provider-operation-retry.js"; +} from "../provider-runtime/operation-retry.js"; +import { collectProviderApiKeys, isApiKeyRateLimitError } from "./live-auth-keys.js"; type ApiKeyRetryParams = { apiKey: string; diff --git a/src/agents/provider-operation-retry.ts b/src/agents/provider-operation-retry.ts index fc81d38bc18..c8939a1dd55 100644 --- a/src/agents/provider-operation-retry.ts +++ b/src/agents/provider-operation-retry.ts @@ -1,205 +1,15 @@ -import { formatErrorMessage } from "../infra/errors.js"; - -export type TransientProviderRetryParams = { - error: unknown; - message: string; - provider: string; - apiKeyIndex: number; - attemptNumber: number; -}; - -export type TransientProviderRetryOptions = { - /** - * Total executions per API key, including the first call. - * attempts: 2 means one initial call plus one same-key retry. - */ - attempts: number; - baseDelayMs?: number; - maxDelayMs?: number; - signal?: AbortSignal; - shouldRetry?: (params: TransientProviderRetryParams) => boolean; - sleep?: (ms: number, signal?: AbortSignal) => Promise; -}; - -export type TransientProviderRetryConfig = boolean | TransientProviderRetryOptions; - -export const DEFAULT_TRANSIENT_PROVIDER_RETRY_OPTIONS = { - attempts: 2, - baseDelayMs: 250, - maxDelayMs: 1_000, -} as const satisfies TransientProviderRetryOptions; - -export function resolveTransientProviderRetryOptions( - options?: TransientProviderRetryConfig, -): TransientProviderRetryOptions | undefined { - if (!options) { - return undefined; - } - if (options === true) { - return DEFAULT_TRANSIENT_PROVIDER_RETRY_OPTIONS; - } - return options; -} - -function readErrorName(error: unknown): string | undefined { - if (typeof error !== "object" || error === null) { - return undefined; - } - const name = (error as { name?: unknown }).name; - return typeof name === "string" ? name : undefined; -} - -function isTimeoutNamedError(error: unknown): boolean { - const name = readErrorName(error); - return name === "TimeoutError" || name === "RequestTimeoutError"; -} - -function readErrorStatus(error: unknown): number | undefined { - if (typeof error !== "object" || error === null) { - return undefined; - } - const record = error as { status?: unknown; statusCode?: unknown; code?: unknown }; - for (const value of [record.status, record.statusCode, record.code]) { - if (typeof value === "number" && Number.isInteger(value)) { - return value; - } - if (typeof value === "string" && /^\d{3}$/.test(value.trim())) { - return Number(value.trim()); - } - } - return undefined; -} - -function readErrorCode(error: unknown): string | undefined { - if (typeof error !== "object" || error === null) { - return undefined; - } - const code = (error as { code?: unknown }).code; - return typeof code === "string" ? code : undefined; -} - -function readErrorCause(error: unknown): unknown { - if (typeof error !== "object" || error === null) { - return undefined; - } - return (error as { cause?: unknown }).cause; -} - -function hasTransientNetworkSignal(error: unknown, message: string): boolean { - const transientCodes = /\b(?:ECONNRESET|ECONNREFUSED|ETIMEDOUT|EAI_AGAIN)\b/i; - if (transientCodes.test(message)) { - return true; - } - const code = readErrorCode(error); - if (code && transientCodes.test(code)) { - return true; - } - const cause = readErrorCause(error); - if (!cause || cause === error) { - return false; - } - const causeCode = readErrorCode(cause); - if (causeCode && transientCodes.test(causeCode)) { - return true; - } - const causeMessage = formatErrorMessage(cause); - return transientCodes.test(causeMessage); -} - -function hasTimeoutSignal(error: unknown, message: string): boolean { - if (isTimeoutNamedError(error)) { - return true; - } - if (/\b(?:request timeout|provider timeout|timed out|timeout)\b/i.test(message)) { - return true; - } - const cause = readErrorCause(error); - if (!cause || cause === error) { - return false; - } - if (isTimeoutNamedError(cause)) { - return true; - } - return /\b(?:request timeout|provider timeout|timed out|timeout)\b/i.test( - formatErrorMessage(cause), - ); -} - -export function isTransientProviderOperationError(error: unknown, message: string): boolean { - const status = readErrorStatus(error); - if (status !== undefined) { - return status === 500 || status === 502 || status === 503 || status === 504; - } - if ( - /\b(?:HTTP\s*)?(?:400|401|403|404)\b/i.test(message) || - /\b(?:invalid api key|permission denied|model not found|validation|unsupported model)\b/i.test( - message, - ) - ) { - return false; - } - if (/\b(?:HTTP\s*)?(?:500|502|503|504)\b/i.test(message)) { - return true; - } - if (hasTransientNetworkSignal(error, message)) { - return true; - } - if (hasTimeoutSignal(error, message)) { - return true; - } - if (/\bfetch failed\b/i.test(message)) { - return hasTransientNetworkSignal(error, message); - } - return false; -} - -export function resolveTransientProviderAttempts(options?: TransientProviderRetryOptions): number { - if (!options) { - return 1; - } - return Math.max(1, Math.round(Number.isFinite(options.attempts) ? options.attempts : 1)); -} - -export function resolveTransientProviderDelayMs( - options: TransientProviderRetryOptions, - attemptNumber: number, -): number { - const rawBaseDelayMs = options.baseDelayMs ?? 250; - const baseDelayMs = Math.max( - 0, - Math.round(Number.isFinite(rawBaseDelayMs) ? rawBaseDelayMs : 250), - ); - const rawMaxDelayMs = options.maxDelayMs ?? 1_000; - const maxDelayMs = Math.max( - baseDelayMs, - Math.round(Number.isFinite(rawMaxDelayMs) ? rawMaxDelayMs : 1_000), - ); - return Math.min(maxDelayMs, baseDelayMs * 2 ** Math.max(attemptNumber - 1, 0)); -} - -export function shouldRetrySameKeyProviderOperation(params: { - options: TransientProviderRetryOptions; - error: unknown; - message: string; - provider: string; - apiKeyIndex: number; - attemptNumber: number; - maxAttempts: number; -}): boolean { - if (params.attemptNumber >= params.maxAttempts) { - return false; - } - if (params.options.signal?.aborted) { - return false; - } - const retryParams: TransientProviderRetryParams = { - error: params.error, - message: params.message, - provider: params.provider, - apiKeyIndex: params.apiKeyIndex, - attemptNumber: params.attemptNumber, - }; - return params.options.shouldRetry - ? params.options.shouldRetry(retryParams) - : isTransientProviderOperationError(params.error, params.message); -} +export { + DEFAULT_TRANSIENT_PROVIDER_RETRY_OPTIONS, + defaultTransientProviderRetryForStage, + executeProviderOperationWithRetry, + isTransientProviderOperationError, + providerOperationRetryConfig, + resolveTransientProviderAttempts, + resolveTransientProviderDelayMs, + resolveTransientProviderRetryOptions, + shouldRetrySameKeyProviderOperation, + type ProviderOperationRetryStage, + type TransientProviderRetryConfig, + type TransientProviderRetryOptions, + type TransientProviderRetryParams, +} from "../provider-runtime/operation-retry.js"; diff --git a/src/media-understanding/runner.entries.ts b/src/media-understanding/runner.entries.ts index 28503ff50fc..9eec4b46ed8 100644 --- a/src/media-understanding/runner.entries.ts +++ b/src/media-understanding/runner.entries.ts @@ -22,6 +22,7 @@ import { resolveProxyFetchFromEnv } from "../infra/net/proxy-fetch.js"; import { resolvePreferredOpenClawTmpDir } from "../infra/tmp-openclaw-dir.js"; import { runFfmpeg } from "../media/ffmpeg-exec.js"; import { runExec } from "../process/exec.js"; +import { providerOperationRetryConfig } from "../provider-runtime/operation-retry.js"; import { normalizeLowercaseStringOrEmpty } from "../shared/string-coerce.js"; import { MediaAttachmentCache } from "./attachments.js"; import { @@ -648,7 +649,7 @@ export async function runProviderEntry(params: { const result = await executeWithApiKeyRotation({ provider: providerId, apiKeys, - transientRetry: true, + transientRetry: providerOperationRetryConfig("read"), execute: async (apiKey) => transcribeAudio({ buffer: media.buffer, @@ -706,7 +707,7 @@ export async function runProviderEntry(params: { const result = await executeWithApiKeyRotation({ provider: providerId, apiKeys, - transientRetry: true, + transientRetry: providerOperationRetryConfig("read"), execute: (apiKey) => describeVideo({ buffer: media.buffer, diff --git a/src/media-understanding/shared.test.ts b/src/media-understanding/shared.test.ts index a5109d9a65c..9da28a8d5bd 100644 --- a/src/media-understanding/shared.test.ts +++ b/src/media-understanding/shared.test.ts @@ -28,6 +28,7 @@ vi.mock("../infra/net/proxy-env.js", async () => { import { createProviderOperationDeadline, + fetchProviderDownloadResponse, fetchWithTimeoutGuarded, pollProviderOperationJson, postJsonRequest, @@ -179,6 +180,60 @@ describe("provider operation deadlines", () => { }), ).rejects.toThrow("model rejected"); }); + + it("retries transient provider status failures while polling", async () => { + vi.useFakeTimers(); + vi.setSystemTime(1_000); + const fetchFn = vi + .fn() + .mockResolvedValueOnce( + new Response("busy", { status: 503, statusText: "Service Unavailable" }), + ) + .mockResolvedValueOnce(new Response(JSON.stringify({ status: "completed" }))); + + const result = pollProviderOperationJson<{ status?: string }>({ + url: "https://api.example.com/v1/videos/task-1", + headers: new Headers({ authorization: "Bearer test" }), + deadline: createProviderOperationDeadline({ + label: "video generation task task-1", + timeoutMs: 10_000, + }), + defaultTimeoutMs: 5_000, + fetchFn, + maxAttempts: 3, + pollIntervalMs: 1_000, + requestFailedMessage: "status failed", + timeoutMessage: "task timed out", + isComplete: (payload) => payload.status === "completed", + }); + + await vi.advanceTimersByTimeAsync(250); + + await expect(result).resolves.toEqual({ status: "completed" }); + expect(fetchFn).toHaveBeenCalledTimes(2); + }); + + it("retries transient generated asset downloads", async () => { + const sleep = vi.fn(async () => undefined); + const fetchFn = vi + .fn() + .mockRejectedValueOnce(Object.assign(new Error("socket hang up"), { code: "ECONNRESET" })) + .mockResolvedValueOnce(new Response("video-bytes", { status: 200 })); + + const response = await fetchProviderDownloadResponse({ + url: "https://cdn.example.com/video.mp4", + init: { method: "GET" }, + timeoutMs: 5_000, + fetchFn, + provider: "test-video", + requestFailedMessage: "download failed", + retry: { attempts: 2, baseDelayMs: 0, maxDelayMs: 0, sleep }, + }); + + expect(await response.text()).toBe("video-bytes"); + expect(fetchFn).toHaveBeenCalledTimes(2); + expect(sleep).toHaveBeenCalledWith(0, undefined); + }); }); describe("resolveProviderHttpRequestConfig", () => { @@ -410,6 +465,54 @@ describe("fetchWithTimeoutGuarded", () => { expect(getFirstGuardedFetchCall().pinDns).toBe(false); }); + it("does not retry JSON POST requests by default", async () => { + fetchWithSsrFGuardMock.mockReset(); + fetchWithSsrFGuardMock + .mockRejectedValueOnce(Object.assign(new Error("socket hang up"), { code: "ECONNRESET" })) + .mockResolvedValueOnce({ + response: new Response(null, { status: 200 }), + finalUrl: "https://api.example.com", + release: async () => {}, + }); + + await expect( + postJsonRequest({ + url: "https://api.example.com/v1/create", + headers: new Headers(), + body: { prompt: "make a video" }, + fetchFn: fetch, + }), + ).rejects.toThrow("socket hang up"); + + expect(fetchWithSsrFGuardMock).toHaveBeenCalledTimes(1); + }); + + it("retries JSON POST requests only when marked as read operations", async () => { + fetchWithSsrFGuardMock.mockReset(); + const sleep = vi.fn(async () => undefined); + fetchWithSsrFGuardMock + .mockRejectedValueOnce(Object.assign(new Error("socket hang up"), { code: "ECONNRESET" })) + .mockResolvedValueOnce({ + response: new Response(null, { status: 200 }), + finalUrl: "https://api.example.com", + release: async () => {}, + }); + + await expect( + postJsonRequest({ + url: "https://api.example.com/v1/analyze", + headers: new Headers(), + body: { media: "base64" }, + fetchFn: fetch, + retryStage: "read", + retry: { attempts: 2, baseDelayMs: 0, maxDelayMs: 0, sleep }, + }), + ).resolves.toEqual(expect.objectContaining({ finalUrl: "https://api.example.com" })); + + expect(fetchWithSsrFGuardMock).toHaveBeenCalledTimes(2); + expect(sleep).toHaveBeenCalledWith(0, undefined); + }); + it("forwards explicit pinDns overrides to transcription requests", async () => { fetchWithSsrFGuardMock.mockResolvedValue({ response: new Response(null, { status: 200 }), diff --git a/src/media-understanding/shared.ts b/src/media-understanding/shared.ts index e8a58c03204..d3b76f9fe2e 100644 --- a/src/media-understanding/shared.ts +++ b/src/media-understanding/shared.ts @@ -15,6 +15,11 @@ import type { GuardedFetchMode, GuardedFetchResult } from "../infra/net/fetch-gu import { fetchWithSsrFGuard, GUARDED_FETCH_MODE } from "../infra/net/fetch-guard.js"; import { shouldUseEnvHttpProxyForUrl } from "../infra/net/proxy-env.js"; import type { LookupFn, PinnedDispatcherPolicy, SsrFPolicy } from "../infra/net/ssrf.js"; +import { + executeProviderOperationWithRetry, + type ProviderOperationRetryStage, + type TransientProviderRetryConfig, +} from "../provider-runtime/operation-retry.js"; import { fetchWithTimeout } from "../utils/fetch-timeout.js"; export { fetchWithTimeout }; export { normalizeBaseUrl } from "../agents/provider-request-config.js"; @@ -130,19 +135,20 @@ export async function pollProviderOperationJson(params: { getFailureMessage?: (payload: TPayload) => string | undefined; }): Promise { for (let attempt = 0; attempt < params.maxAttempts; attempt += 1) { - const response = await fetchWithTimeout( - params.url, - { + const response = await fetchProviderOperationResponse({ + stage: "poll", + url: params.url, + init: { method: "GET", headers: params.headers, }, - resolveProviderOperationTimeoutMs({ + timeoutMs: resolveProviderOperationTimeoutMs({ deadline: params.deadline, defaultTimeoutMs: params.defaultTimeoutMs, }), - params.fetchFn, - ); - await assertOkOrThrowHttpError(response, params.requestFailedMessage); + fetchFn: params.fetchFn, + requestFailedMessage: params.requestFailedMessage, + }); const payload = (await response.json()) as TPayload; if (params.isComplete(payload)) { return payload; @@ -159,6 +165,56 @@ export async function pollProviderOperationJson(params: { throw new Error(params.timeoutMessage); } +export async function fetchProviderOperationResponse(params: { + stage: ProviderOperationRetryStage; + url: string; + init?: RequestInit; + timeoutMs?: number; + fetchFn: typeof fetch; + provider?: string; + requestFailedMessage?: string; + retry?: TransientProviderRetryConfig; +}): Promise { + return await executeProviderOperationWithRetry({ + provider: params.provider ?? "provider-http", + stage: params.stage, + retry: params.retry, + operation: async () => { + const response = await fetchWithTimeout( + params.url, + params.init ?? {}, + params.timeoutMs ?? DEFAULT_GUARDED_HTTP_TIMEOUT_MS, + params.fetchFn, + ); + if (params.requestFailedMessage) { + await assertOkOrThrowHttpError(response, params.requestFailedMessage); + } + return response; + }, + }); +} + +export async function fetchProviderDownloadResponse(params: { + url: string; + init?: RequestInit; + timeoutMs?: number; + fetchFn: typeof fetch; + provider?: string; + requestFailedMessage: string; + retry?: TransientProviderRetryConfig; +}): Promise { + return await fetchProviderOperationResponse({ + stage: "download", + url: params.url, + init: params.init, + timeoutMs: params.timeoutMs, + fetchFn: params.fetchFn, + provider: params.provider, + requestFailedMessage: params.requestFailedMessage, + retry: params.retry, + }); +} + function resolveGuardedHttpTimeoutMs(timeoutMs: number | undefined): number { if (typeof timeoutMs !== "number" || !Number.isFinite(timeoutMs) || timeoutMs <= 0) { return DEFAULT_GUARDED_HTTP_TIMEOUT_MS; @@ -360,97 +416,146 @@ function resolveGuardedPostRequestOptions(params: { }; } -export async function postTranscriptionRequest(params: { - url: string; - headers: Headers; - body: BodyInit; - timeoutMs?: number; - fetchFn: typeof fetch; - pinDns?: boolean; - allowPrivateNetwork?: boolean; - ssrfPolicy?: SsrFPolicy; - dispatcherPolicy?: PinnedDispatcherPolicy; - auditContext?: string; +type GuardedPostRequestRetryOptions = { /** - * Override the guarded-fetch mode. Defaults to an auto-upgrade to - * `TRUSTED_ENV_PROXY` when `HTTP_PROXY`/`HTTPS_PROXY` is configured in the - * environment; pass `"strict"` to force pinned-DNS even inside a proxy. + * POST requests default to no retry because many provider endpoints create + * billable jobs. Pass "read" only for read/analysis POST endpoints. */ - mode?: GuardedFetchMode; -}) { - return fetchWithTimeoutGuarded( - params.url, - { + retryStage?: ProviderOperationRetryStage; + retry?: TransientProviderRetryConfig; +}; + +export async function postTranscriptionRequest( + params: { + url: string; + headers: Headers; + body: BodyInit; + timeoutMs?: number; + fetchFn: typeof fetch; + pinDns?: boolean; + allowPrivateNetwork?: boolean; + ssrfPolicy?: SsrFPolicy; + dispatcherPolicy?: PinnedDispatcherPolicy; + auditContext?: string; + /** + * Override the guarded-fetch mode. Defaults to an auto-upgrade to + * `TRUSTED_ENV_PROXY` when `HTTP_PROXY`/`HTTPS_PROXY` is configured in the + * environment; pass `"strict"` to force pinned-DNS even inside a proxy. + */ + mode?: GuardedFetchMode; + } & GuardedPostRequestRetryOptions, +) { + return await postGuardedRequest({ + url: params.url, + init: { method: "POST", headers: params.headers, body: params.body, }, - params.timeoutMs, - params.fetchFn, - resolveGuardedPostRequestOptions(params), - ); + timeoutMs: params.timeoutMs, + fetchFn: params.fetchFn, + guardedOptions: resolveGuardedPostRequestOptions(params), + retryStage: params.retryStage, + retry: params.retry, + }); } -export async function postJsonRequest(params: { +async function postGuardedRequest(params: { url: string; - headers: Headers; - body: unknown; + init: RequestInit; timeoutMs?: number; fetchFn: typeof fetch; - pinDns?: boolean; - allowPrivateNetwork?: boolean; - ssrfPolicy?: SsrFPolicy; - dispatcherPolicy?: PinnedDispatcherPolicy; - auditContext?: string; - /** - * Override the guarded-fetch mode. Defaults to an auto-upgrade to - * `TRUSTED_ENV_PROXY` when `HTTP_PROXY`/`HTTPS_PROXY` is configured in the - * environment; pass `"strict"` to force pinned-DNS even inside a proxy. - */ - mode?: GuardedFetchMode; + guardedOptions?: GuardedPostRequestOptions; + retryStage?: ProviderOperationRetryStage; + retry?: TransientProviderRetryConfig; }) { - return fetchWithTimeoutGuarded( - params.url, - { + const operation = () => + fetchWithTimeoutGuarded( + params.url, + params.init, + params.timeoutMs, + params.fetchFn, + params.guardedOptions, + ); + if (!params.retryStage) { + return await operation(); + } + return await executeProviderOperationWithRetry({ + provider: "provider-http", + stage: params.retryStage, + retry: params.retry, + operation, + }); +} + +export async function postJsonRequest( + params: { + url: string; + headers: Headers; + body: unknown; + timeoutMs?: number; + fetchFn: typeof fetch; + pinDns?: boolean; + allowPrivateNetwork?: boolean; + ssrfPolicy?: SsrFPolicy; + dispatcherPolicy?: PinnedDispatcherPolicy; + auditContext?: string; + /** + * Override the guarded-fetch mode. Defaults to an auto-upgrade to + * `TRUSTED_ENV_PROXY` when `HTTP_PROXY`/`HTTPS_PROXY` is configured in the + * environment; pass `"strict"` to force pinned-DNS even inside a proxy. + */ + mode?: GuardedFetchMode; + } & GuardedPostRequestRetryOptions, +) { + return await postGuardedRequest({ + url: params.url, + init: { method: "POST", headers: params.headers, body: JSON.stringify(params.body), }, - params.timeoutMs, - params.fetchFn, - resolveGuardedPostRequestOptions(params), - ); + timeoutMs: params.timeoutMs, + fetchFn: params.fetchFn, + guardedOptions: resolveGuardedPostRequestOptions(params), + retryStage: params.retryStage, + retry: params.retry, + }); } -export async function postMultipartRequest(params: { - url: string; - headers: Headers; - body: BodyInit; - timeoutMs?: number; - fetchFn: typeof fetch; - pinDns?: boolean; - allowPrivateNetwork?: boolean; - ssrfPolicy?: SsrFPolicy; - dispatcherPolicy?: PinnedDispatcherPolicy; - auditContext?: string; - /** - * Override the guarded-fetch mode. Defaults to an auto-upgrade to - * `TRUSTED_ENV_PROXY` when `HTTP_PROXY`/`HTTPS_PROXY` is configured in the - * environment; pass `"strict"` to force pinned-DNS even inside a proxy. - */ - mode?: GuardedFetchMode; -}) { - return fetchWithTimeoutGuarded( - params.url, - { +export async function postMultipartRequest( + params: { + url: string; + headers: Headers; + body: BodyInit; + timeoutMs?: number; + fetchFn: typeof fetch; + pinDns?: boolean; + allowPrivateNetwork?: boolean; + ssrfPolicy?: SsrFPolicy; + dispatcherPolicy?: PinnedDispatcherPolicy; + auditContext?: string; + /** + * Override the guarded-fetch mode. Defaults to an auto-upgrade to + * `TRUSTED_ENV_PROXY` when `HTTP_PROXY`/`HTTPS_PROXY` is configured in the + * environment; pass `"strict"` to force pinned-DNS even inside a proxy. + */ + mode?: GuardedFetchMode; + } & GuardedPostRequestRetryOptions, +) { + return await postGuardedRequest({ + url: params.url, + init: { method: "POST", headers: params.headers, body: params.body, }, - params.timeoutMs, - params.fetchFn, - resolveGuardedPostRequestOptions(params), - ); + timeoutMs: params.timeoutMs, + fetchFn: params.fetchFn, + guardedOptions: resolveGuardedPostRequestOptions(params), + retryStage: params.retryStage, + retry: params.retry, + }); } export async function readErrorResponse(res: Response): Promise { diff --git a/src/plugin-sdk/provider-http.ts b/src/plugin-sdk/provider-http.ts index b7a1cca8b56..b3abcd30290 100644 --- a/src/plugin-sdk/provider-http.ts +++ b/src/plugin-sdk/provider-http.ts @@ -15,6 +15,8 @@ export { export { buildAudioTranscriptionFormData, createProviderOperationDeadline, + fetchProviderDownloadResponse, + fetchProviderOperationResponse, fetchWithTimeout, fetchWithTimeoutGuarded, normalizeBaseUrl, @@ -30,6 +32,16 @@ export { waitProviderOperationPollInterval, } from "../media-understanding/shared.js"; export type { ProviderOperationDeadline } from "../media-understanding/shared.js"; +export { + executeProviderOperationWithRetry, + providerOperationRetryConfig, +} from "../provider-runtime/operation-retry.js"; +export type { + ProviderOperationRetryStage, + TransientProviderRetryConfig, + TransientProviderRetryOptions, + TransientProviderRetryParams, +} from "../provider-runtime/operation-retry.js"; export type { ProviderAttributionPolicy, ProviderRequestCapabilities, diff --git a/src/plugin-sdk/test-helpers/provider-http-mocks.ts b/src/plugin-sdk/test-helpers/provider-http-mocks.ts index 73a8dcd9204..5526aba5693 100644 --- a/src/plugin-sdk/test-helpers/provider-http-mocks.ts +++ b/src/plugin-sdk/test-helpers/provider-http-mocks.ts @@ -1,5 +1,7 @@ import { afterEach, vi, type Mock } from "vitest"; import type { + fetchProviderDownloadResponse, + fetchProviderOperationResponse, pollProviderOperationJson, resolveProviderHttpRequestConfig, sanitizeConfiguredModelProviderRequest, @@ -9,6 +11,8 @@ type ResolveProviderHttpRequestConfigParams = Parameters< typeof resolveProviderHttpRequestConfig >[0]; type PollProviderOperationJsonParams = Parameters[0]; +type FetchProviderOperationResponseParams = Parameters[0]; +type FetchProviderDownloadResponseParams = Parameters[0]; type SanitizeConfiguredModelProviderRequestParams = Parameters< typeof sanitizeConfiguredModelProviderRequest >[0]; @@ -43,6 +47,8 @@ const providerHttpMocks = vi.hoisted(() => ({ resolveApiKeyForProviderMock: vi.fn(async () => ({ apiKey: "provider-key" })), postJsonRequestMock: vi.fn(), fetchWithTimeoutMock: vi.fn(), + fetchProviderOperationResponseMock: vi.fn(), + fetchProviderDownloadResponseMock: vi.fn(), pollProviderOperationJsonMock: vi.fn(), assertOkOrThrowHttpErrorMock: vi.fn(async (_response: Response, _label: string) => {}), assertOkOrThrowProviderErrorMock: vi.fn(async (_response: Response, _label: string) => {}), @@ -57,6 +63,34 @@ const providerHttpMocks = vi.hoisted(() => ({ })), })); +providerHttpMocks.fetchProviderOperationResponseMock.mockImplementation( + async (params: FetchProviderOperationResponseParams) => { + const response = await providerHttpMocks.fetchWithTimeoutMock( + params.url, + params.init ?? {}, + params.timeoutMs ?? 60_000, + params.fetchFn, + ); + if (params.requestFailedMessage) { + await providerHttpMocks.assertOkOrThrowHttpErrorMock(response, params.requestFailedMessage); + } + return response; + }, +); + +providerHttpMocks.fetchProviderDownloadResponseMock.mockImplementation( + async (params: FetchProviderDownloadResponseParams) => { + const response = await providerHttpMocks.fetchWithTimeoutMock( + params.url, + params.init ?? {}, + params.timeoutMs ?? 60_000, + params.fetchFn, + ); + await providerHttpMocks.assertOkOrThrowHttpErrorMock(response, params.requestFailedMessage); + return response; + }, +); + providerHttpMocks.pollProviderOperationJsonMock.mockImplementation( async (params: PollProviderOperationJsonParams) => { for (let attempt = 0; attempt < params.maxAttempts; attempt += 1) { @@ -100,9 +134,14 @@ vi.mock("openclaw/plugin-sdk/provider-http", () => ({ label, timeoutMs, }), + executeProviderOperationWithRetry: async ({ operation }: { operation: () => Promise }) => + await operation(), + fetchProviderDownloadResponse: providerHttpMocks.fetchProviderDownloadResponseMock, + fetchProviderOperationResponse: providerHttpMocks.fetchProviderOperationResponseMock, fetchWithTimeout: providerHttpMocks.fetchWithTimeoutMock, pollProviderOperationJson: providerHttpMocks.pollProviderOperationJsonMock, postJsonRequest: providerHttpMocks.postJsonRequestMock, + providerOperationRetryConfig: (_stage: string) => true, resolveProviderOperationTimeoutMs: ({ defaultTimeoutMs }: { defaultTimeoutMs: number }) => defaultTimeoutMs, resolveProviderHttpRequestConfig: providerHttpMocks.resolveProviderHttpRequestConfigMock, @@ -120,6 +159,8 @@ export function installProviderHttpMockCleanup(): void { providerHttpMocks.resolveApiKeyForProviderMock.mockClear(); providerHttpMocks.postJsonRequestMock.mockReset(); providerHttpMocks.fetchWithTimeoutMock.mockReset(); + providerHttpMocks.fetchProviderOperationResponseMock.mockClear(); + providerHttpMocks.fetchProviderDownloadResponseMock.mockClear(); providerHttpMocks.pollProviderOperationJsonMock.mockClear(); providerHttpMocks.assertOkOrThrowHttpErrorMock.mockClear(); providerHttpMocks.assertOkOrThrowProviderErrorMock.mockClear(); diff --git a/src/provider-runtime/operation-retry.ts b/src/provider-runtime/operation-retry.ts new file mode 100644 index 00000000000..161fec1b99e --- /dev/null +++ b/src/provider-runtime/operation-retry.ts @@ -0,0 +1,266 @@ +import { sleepWithAbort } from "../infra/backoff.js"; +import { formatErrorMessage } from "../infra/errors.js"; + +export type ProviderOperationRetryStage = "read" | "poll" | "download" | "create"; + +export type TransientProviderRetryParams = { + error: unknown; + message: string; + provider: string; + apiKeyIndex: number; + attemptNumber: number; + stage?: ProviderOperationRetryStage; +}; + +export type TransientProviderRetryOptions = { + /** + * Total executions, including the first call. + * attempts: 2 means one initial call plus one retry. + */ + attempts: number; + baseDelayMs?: number; + maxDelayMs?: number; + signal?: AbortSignal; + shouldRetry?: (params: TransientProviderRetryParams) => boolean; + sleep?: (ms: number, signal?: AbortSignal) => Promise; +}; + +export type TransientProviderRetryConfig = boolean | TransientProviderRetryOptions; + +export const DEFAULT_TRANSIENT_PROVIDER_RETRY_OPTIONS = { + attempts: 2, + baseDelayMs: 250, + maxDelayMs: 1_000, +} as const satisfies TransientProviderRetryOptions; + +export function resolveTransientProviderRetryOptions( + options?: TransientProviderRetryConfig, +): TransientProviderRetryOptions | undefined { + if (!options) { + return undefined; + } + if (options === true) { + return DEFAULT_TRANSIENT_PROVIDER_RETRY_OPTIONS; + } + return options; +} + +export function defaultTransientProviderRetryForStage( + stage: ProviderOperationRetryStage, +): TransientProviderRetryConfig | undefined { + return stage === "create" ? undefined : true; +} + +export function providerOperationRetryConfig( + stage: ProviderOperationRetryStage, + options?: TransientProviderRetryConfig, +): TransientProviderRetryConfig | undefined { + return options ?? defaultTransientProviderRetryForStage(stage); +} + +function readErrorName(error: unknown): string | undefined { + if (typeof error !== "object" || error === null) { + return undefined; + } + const name = (error as { name?: unknown }).name; + return typeof name === "string" ? name : undefined; +} + +function isTimeoutNamedError(error: unknown): boolean { + const name = readErrorName(error); + return name === "TimeoutError" || name === "RequestTimeoutError"; +} + +function readErrorStatus(error: unknown): number | undefined { + if (typeof error !== "object" || error === null) { + return undefined; + } + const record = error as { status?: unknown; statusCode?: unknown; code?: unknown }; + for (const value of [record.status, record.statusCode, record.code]) { + if (typeof value === "number" && Number.isInteger(value)) { + return value; + } + if (typeof value === "string" && /^\d{3}$/.test(value.trim())) { + return Number(value.trim()); + } + } + return undefined; +} + +function readErrorCode(error: unknown): string | undefined { + if (typeof error !== "object" || error === null) { + return undefined; + } + const code = (error as { code?: unknown }).code; + return typeof code === "string" ? code : undefined; +} + +function readErrorCause(error: unknown): unknown { + if (typeof error !== "object" || error === null) { + return undefined; + } + return (error as { cause?: unknown }).cause; +} + +function hasTransientNetworkSignal(error: unknown, message: string): boolean { + const transientCodes = /\b(?:ECONNRESET|ECONNREFUSED|ETIMEDOUT|EAI_AGAIN)\b/i; + if (transientCodes.test(message)) { + return true; + } + const code = readErrorCode(error); + if (code && transientCodes.test(code)) { + return true; + } + const cause = readErrorCause(error); + if (!cause || cause === error) { + return false; + } + const causeCode = readErrorCode(cause); + if (causeCode && transientCodes.test(causeCode)) { + return true; + } + const causeMessage = formatErrorMessage(cause); + return transientCodes.test(causeMessage); +} + +function hasTimeoutSignal(error: unknown, message: string): boolean { + if (isTimeoutNamedError(error)) { + return true; + } + if (/\b(?:request timeout|provider timeout|timed out|timeout)\b/i.test(message)) { + return true; + } + const cause = readErrorCause(error); + if (!cause || cause === error) { + return false; + } + if (isTimeoutNamedError(cause)) { + return true; + } + return /\b(?:request timeout|provider timeout|timed out|timeout)\b/i.test( + formatErrorMessage(cause), + ); +} + +export function isTransientProviderOperationError(error: unknown, message: string): boolean { + const status = readErrorStatus(error); + if (status !== undefined) { + return status === 500 || status === 502 || status === 503 || status === 504; + } + if ( + /\b(?:HTTP\s*)?(?:400|401|403|404)\b/i.test(message) || + /\b(?:invalid api key|permission denied|model not found|validation|unsupported model)\b/i.test( + message, + ) + ) { + return false; + } + if (/\b(?:HTTP\s*)?(?:500|502|503|504)\b/i.test(message)) { + return true; + } + if (hasTransientNetworkSignal(error, message)) { + return true; + } + if (hasTimeoutSignal(error, message)) { + return true; + } + if (/\bfetch failed\b/i.test(message)) { + return hasTransientNetworkSignal(error, message); + } + return false; +} + +export function resolveTransientProviderAttempts(options?: TransientProviderRetryOptions): number { + if (!options) { + return 1; + } + return Math.max(1, Math.round(Number.isFinite(options.attempts) ? options.attempts : 1)); +} + +export function resolveTransientProviderDelayMs( + options: TransientProviderRetryOptions, + attemptNumber: number, +): number { + const rawBaseDelayMs = options.baseDelayMs ?? 250; + const baseDelayMs = Math.max( + 0, + Math.round(Number.isFinite(rawBaseDelayMs) ? rawBaseDelayMs : 250), + ); + const rawMaxDelayMs = options.maxDelayMs ?? 1_000; + const maxDelayMs = Math.max( + baseDelayMs, + Math.round(Number.isFinite(rawMaxDelayMs) ? rawMaxDelayMs : 1_000), + ); + return Math.min(maxDelayMs, baseDelayMs * 2 ** Math.max(attemptNumber - 1, 0)); +} + +export function shouldRetrySameKeyProviderOperation(params: { + options: TransientProviderRetryOptions; + error: unknown; + message: string; + provider: string; + apiKeyIndex: number; + attemptNumber: number; + maxAttempts: number; + stage?: ProviderOperationRetryStage; +}): boolean { + if (params.attemptNumber >= params.maxAttempts) { + return false; + } + if (params.options.signal?.aborted) { + return false; + } + const retryParams: TransientProviderRetryParams = { + error: params.error, + message: params.message, + provider: params.provider, + apiKeyIndex: params.apiKeyIndex, + attemptNumber: params.attemptNumber, + ...(params.stage ? { stage: params.stage } : {}), + }; + return params.options.shouldRetry + ? params.options.shouldRetry(retryParams) + : isTransientProviderOperationError(params.error, params.message); +} + +export async function executeProviderOperationWithRetry(params: { + provider: string; + stage: ProviderOperationRetryStage; + operation: () => Promise; + retry?: TransientProviderRetryConfig; +}): Promise { + const retryConfig = providerOperationRetryConfig(params.stage, params.retry); + const retryOptions = resolveTransientProviderRetryOptions(retryConfig); + const maxAttempts = resolveTransientProviderAttempts(retryOptions); + let lastError: unknown; + + for (let attemptNumber = 1; attemptNumber <= maxAttempts; attemptNumber += 1) { + try { + return await params.operation(); + } catch (error) { + lastError = error; + const message = formatErrorMessage(error); + if ( + !retryOptions || + !shouldRetrySameKeyProviderOperation({ + options: retryOptions, + error, + message, + provider: params.provider, + apiKeyIndex: 0, + attemptNumber, + maxAttempts, + stage: params.stage, + }) + ) { + throw error; + } + + const delayMs = resolveTransientProviderDelayMs(retryOptions, attemptNumber); + const sleep = retryOptions.sleep ?? sleepWithAbort; + await sleep(delayMs, retryOptions.signal); + } + } + + throw lastError; +} diff --git a/src/video-generation/dashscope-compatible.ts b/src/video-generation/dashscope-compatible.ts index ea847fe4f04..78bcafa71a3 100644 --- a/src/video-generation/dashscope-compatible.ts +++ b/src/video-generation/dashscope-compatible.ts @@ -1,7 +1,8 @@ import { assertOkOrThrowHttpError, createProviderOperationDeadline, - fetchWithTimeout, + fetchProviderDownloadResponse, + fetchProviderOperationResponse, postJsonRequest, resolveProviderOperationTimeoutMs, waitProviderOperationPollInterval, @@ -174,19 +175,18 @@ export async function pollDashscopeVideoTaskUntilComplete(params: { label: `${params.providerLabel} video generation task ${params.taskId}`, }); for (let attempt = 0; attempt < DEFAULT_VIDEO_GENERATION_MAX_POLL_ATTEMPTS; attempt += 1) { - const response = await fetchWithTimeout( - `${params.baseUrl}/api/v1/tasks/${params.taskId}`, - { + const response = await fetchProviderOperationResponse({ + stage: "poll", + url: `${params.baseUrl}/api/v1/tasks/${params.taskId}`, + init: { method: "GET", headers: params.headers, }, - resolveProviderOperationTimeoutMs({ deadline, defaultTimeoutMs }), - params.fetchFn, - ); - await assertOkOrThrowHttpError( - response, - `${params.providerLabel} video-generation task poll failed`, - ); + timeoutMs: resolveProviderOperationTimeoutMs({ deadline, defaultTimeoutMs }), + fetchFn: params.fetchFn, + provider: params.providerLabel, + requestFailedMessage: `${params.providerLabel} video-generation task poll failed`, + }); const payload = (await response.json()) as DashscopeVideoGenerationResponse; const status = payload.output?.task_status?.trim().toUpperCase(); if (status === "SUCCEEDED") { @@ -302,16 +302,14 @@ export async function downloadDashscopeGeneratedVideos(params: { }): Promise { const videos: GeneratedVideoAsset[] = []; for (const [index, url] of params.urls.entries()) { - const response = await fetchWithTimeout( + const response = await fetchProviderDownloadResponse({ url, - { method: "GET" }, - params.timeoutMs ?? params.defaultTimeoutMs ?? DEFAULT_VIDEO_GENERATION_TIMEOUT_MS, - params.fetchFn, - ); - await assertOkOrThrowHttpError( - response, - `${params.providerLabel} generated video download failed`, - ); + init: { method: "GET" }, + timeoutMs: params.timeoutMs ?? params.defaultTimeoutMs ?? DEFAULT_VIDEO_GENERATION_TIMEOUT_MS, + fetchFn: params.fetchFn, + provider: params.providerLabel, + requestFailedMessage: `${params.providerLabel} generated video download failed`, + }); const arrayBuffer = await response.arrayBuffer(); videos.push({ buffer: Buffer.from(arrayBuffer),