refactor: shortcut bundled provider contract fixtures

This commit is contained in:
Peter Steinberger
2026-03-27 03:55:54 +00:00
parent 17203d0af9
commit a4b77ad33f
16 changed files with 354 additions and 107 deletions

View File

@@ -0,0 +1 @@
export { __testing, createBraveWebSearchProvider } from "./src/brave-web-search-provider.js";

View File

@@ -0,0 +1 @@
export { createDuckDuckGoWebSearchProvider } from "./src/ddg-search-provider.js";

View File

@@ -0,0 +1 @@
export { __testing, createExaWebSearchProvider } from "./src/exa-web-search-provider.js";

View File

@@ -0,0 +1 @@
export { createFirecrawlWebSearchProvider } from "./src/firecrawl-search-provider.js";

View File

@@ -0,0 +1 @@
export { createGeminiWebSearchProvider } from "./src/gemini-web-search-provider.js";

View File

@@ -0,0 +1 @@
export { createKimiWebSearchProvider } from "./src/kimi-web-search-provider.js";

View File

@@ -0,0 +1 @@
export { createTavilyWebSearchProvider } from "./src/tavily-search-provider.js";

View File

@@ -0,0 +1,23 @@
import { buildFalImageGenerationProvider } from "../extensions/fal/image-generation-provider.js";
import { buildGoogleImageGenerationProvider } from "../extensions/google/image-generation-provider.js";
import {
buildMinimaxImageGenerationProvider,
buildMinimaxPortalImageGenerationProvider,
} from "../extensions/minimax/image-generation-provider.js";
import { buildOpenAIImageGenerationProvider } from "../extensions/openai/image-generation-provider.js";
import type { ImageGenerationProviderPlugin } from "./plugins/types.js";
type BundledImageGenerationProviderEntry = {
pluginId: string;
provider: ImageGenerationProviderPlugin;
};
export function listBundledImageGenerationProviderEntries(): BundledImageGenerationProviderEntry[] {
return [
{ pluginId: "fal", provider: buildFalImageGenerationProvider() },
{ pluginId: "google", provider: buildGoogleImageGenerationProvider() },
{ pluginId: "minimax", provider: buildMinimaxImageGenerationProvider() },
{ pluginId: "minimax", provider: buildMinimaxPortalImageGenerationProvider() },
{ pluginId: "openai", provider: buildOpenAIImageGenerationProvider() },
];
}

View File

@@ -0,0 +1,24 @@
import { createBraveWebSearchProvider } from "../extensions/brave/web-search-provider.js";
import { createDuckDuckGoWebSearchProvider } from "../extensions/duckduckgo/web-search-provider.js";
import { createExaWebSearchProvider } from "../extensions/exa/web-search-provider.js";
import { createFirecrawlWebSearchProvider } from "../extensions/firecrawl/web-search-provider.js";
import { createGeminiWebSearchProvider } from "../extensions/google/web-search-provider.js";
import { createKimiWebSearchProvider } from "../extensions/moonshot/web-search-provider.js";
import { createPerplexityWebSearchProvider } from "../extensions/perplexity/web-search-provider.js";
import { createTavilyWebSearchProvider } from "../extensions/tavily/web-search-provider.js";
import { createXaiWebSearchProvider } from "../extensions/xai/web-search.js";
import type { PluginWebSearchProviderEntry } from "./plugins/types.js";
export function listBundledWebSearchProviderEntries(): PluginWebSearchProviderEntry[] {
return [
{ pluginId: "brave", ...createBraveWebSearchProvider() },
{ pluginId: "duckduckgo", ...createDuckDuckGoWebSearchProvider() },
{ pluginId: "exa", ...createExaWebSearchProvider() },
{ pluginId: "firecrawl", ...createFirecrawlWebSearchProvider() },
{ pluginId: "google", ...createGeminiWebSearchProvider() },
{ pluginId: "moonshot", ...createKimiWebSearchProvider() },
{ pluginId: "perplexity", ...createPerplexityWebSearchProvider() },
{ pluginId: "tavily", ...createTavilyWebSearchProvider() },
{ pluginId: "xai", ...createXaiWebSearchProvider() },
];
}

View File

@@ -1,4 +1,5 @@
import { describe, expect, it } from "vitest";
import { listBundledWebSearchProviderEntries } from "../bundled-web-search.entries.js";
import type { OpenClawConfig } from "../config/config.js";
import { BUNDLED_WEB_SEARCH_PLUGIN_IDS } from "./bundled-web-search-ids.js";
import { resolveBundledWebSearchPluginId } from "./bundled-web-search-provider-ids.js";
@@ -6,7 +7,6 @@ import {
listBundledWebSearchProviders,
resolveBundledWebSearchPluginIds,
} from "./bundled-web-search.js";
import { webSearchProviderContractRegistry } from "./contracts/registry.js";
describe("bundled web search metadata", () => {
function toComparableEntry(params: {
@@ -105,6 +105,7 @@ describe("bundled web search metadata", () => {
it("keeps bundled provider metadata aligned with bundled plugin contracts", async () => {
const fastPathProviders = listBundledWebSearchProviders();
const bundledProviderEntries = listBundledWebSearchProviderEntries();
expect(
sortComparableEntries(
@@ -117,7 +118,7 @@ describe("bundled web search metadata", () => {
),
).toEqual(
sortComparableEntries(
webSearchProviderContractRegistry.map(({ pluginId, provider }) =>
bundledProviderEntries.map(({ pluginId, ...provider }) =>
toComparableEntry({
pluginId,
provider,
@@ -127,12 +128,11 @@ describe("bundled web search metadata", () => {
);
for (const fastPathProvider of fastPathProviders) {
const contractEntry = webSearchProviderContractRegistry.find(
(entry) =>
entry.pluginId === fastPathProvider.pluginId && entry.provider.id === fastPathProvider.id,
const bundledEntry = bundledProviderEntries.find(
(entry) => entry.pluginId === fastPathProvider.pluginId && entry.id === fastPathProvider.id,
);
expect(contractEntry).toBeDefined();
const contractProvider = contractEntry!.provider;
expect(bundledEntry).toBeDefined();
const contractProvider = bundledEntry!;
const fastSearchConfig: Record<string, unknown> = {};
const contractSearchConfig: Record<string, unknown> = {};

View File

@@ -1,6 +1,6 @@
import { listBundledWebSearchProviderEntries } from "../bundled-web-search.entries.js";
import { BUNDLED_WEB_SEARCH_PLUGIN_IDS } from "./bundled-capability-metadata.js";
import { resolveBundledWebSearchPluginId as resolveBundledWebSearchPluginIdFromMap } from "./bundled-web-search-provider-ids.js";
import { webSearchProviderContractRegistry } from "./contracts/registry.js";
import type { PluginLoadOptions } from "./loader.js";
import { loadPluginManifestRegistry } from "./manifest-registry.js";
import type { PluginWebSearchProviderEntry } from "./types.js";
@@ -11,10 +11,7 @@ let bundledWebSearchProvidersCache: BundledWebSearchProviderEntry[] | null = nul
function loadBundledWebSearchProviders(): BundledWebSearchProviderEntry[] {
if (!bundledWebSearchProvidersCache) {
bundledWebSearchProvidersCache = webSearchProviderContractRegistry.map((entry) => ({
...entry.provider,
pluginId: entry.pluginId,
}));
bundledWebSearchProvidersCache = listBundledWebSearchProviderEntries();
}
return bundledWebSearchProvidersCache;
}

View File

@@ -2,7 +2,7 @@ import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { clearRuntimeAuthProfileStoreSnapshots } from "../../agents/auth-profiles/store.js";
import { resolvePreferredProviderForAuthChoice } from "../../plugins/provider-auth-choice-preference.js";
import { buildProviderPluginMethodChoice } from "../provider-wizard.js";
import { requireProviderContractProvider, uniqueProviderContractProviders } from "./registry.js";
import type { ProviderPlugin } from "../types.js";
type ResolvePluginProviders =
typeof import("../../plugins/provider-auth-choice.runtime.js").resolvePluginProviders;
@@ -15,6 +15,7 @@ const resolveProviderPluginChoiceMock = vi.hoisted(() => vi.fn<ResolveProviderPl
const runProviderModelSelectedHookMock = vi.hoisted(() =>
vi.fn<RunProviderModelSelectedHook>(async () => {}),
);
const runAuthMethodMock = vi.hoisted(() => vi.fn(async () => ({ profiles: [] })));
vi.mock("../../plugins/provider-auth-choice.runtime.js", () => ({
resolvePluginProviders: resolvePluginProvidersMock,
@@ -25,7 +26,7 @@ vi.mock("../../plugins/provider-auth-choice.runtime.js", () => ({
describe("provider auth-choice contract", () => {
beforeEach(() => {
resolvePluginProvidersMock.mockReset();
resolvePluginProvidersMock.mockReturnValue(uniqueProviderContractProviders);
resolvePluginProvidersMock.mockReturnValue([]);
resolveProviderPluginChoiceMock.mockReset();
resolveProviderPluginChoiceMock.mockImplementation(({ providers, choice }) => {
const provider = providers.find((entry) =>
@@ -55,24 +56,69 @@ describe("provider auth-choice contract", () => {
});
it("maps provider-plugin choices through the shared preferred-provider fallback resolver", async () => {
const pluginFallbackScenarios = [
"github-copilot",
"minimax-portal",
"modelstudio",
"ollama",
].map((providerId) => {
const provider = requireProviderContractProvider(providerId);
return {
authChoice: buildProviderPluginMethodChoice(provider.id, provider.auth[0]?.id ?? "default"),
expectedProvider: provider.id,
};
});
const pluginFallbackScenarios: ProviderPlugin[] = [
{
id: "github-copilot",
label: "GitHub Copilot",
auth: [
{
id: "oauth",
label: "OAuth",
hint: "Browser sign-in",
kind: "oauth",
run: runAuthMethodMock,
},
],
},
{
id: "minimax-portal",
label: "MiniMax Portal",
auth: [
{
id: "portal",
label: "Portal",
hint: "Browser sign-in",
kind: "oauth",
run: runAuthMethodMock,
},
],
},
{
id: "modelstudio",
label: "ModelStudio",
auth: [
{
id: "api-key",
label: "API key",
hint: "Paste key",
kind: "api_key",
run: runAuthMethodMock,
},
],
},
{
id: "ollama",
label: "Ollama",
auth: [
{
id: "local",
label: "Local",
hint: "No auth",
kind: "custom",
run: runAuthMethodMock,
},
],
},
];
for (const scenario of pluginFallbackScenarios) {
for (const provider of pluginFallbackScenarios) {
resolvePluginProvidersMock.mockClear();
resolvePluginProvidersMock.mockReturnValue([provider]);
await expect(
resolvePreferredProviderForAuthChoice({ choice: scenario.authChoice }),
).resolves.toBe(scenario.expectedProvider);
resolvePreferredProviderForAuthChoice({
choice: buildProviderPluginMethodChoice(provider.id, provider.auth[0]?.id ?? "default"),
}),
).resolves.toBe(provider.id);
expect(resolvePluginProvidersMock).toHaveBeenCalled();
}

View File

@@ -1,10 +1,14 @@
import { beforeAll, beforeEach, describe, it, vi } from "vitest";
import {
registerProviderPlugin,
requireRegisteredProvider,
} from "../../../test/helpers/extensions/provider-registration.js";
import {
expectAugmentedCodexCatalog,
expectCodexBuiltInSuppression,
expectCodexMissingAuthHint,
} from "../provider-runtime.test-support.js";
import { requireProviderContractProvider } from "./registry.js";
import type { ProviderPlugin } from "../types.js";
type ResolvePluginProviders = typeof import("../providers.runtime.js").resolvePluginProviders;
type ResolveOwningPluginIdsForProvider =
@@ -14,9 +18,7 @@ type ResolveNonBundledProviderPluginIds =
const resolvePluginProvidersMock = vi.hoisted(() => vi.fn<ResolvePluginProviders>(() => []));
const resolveOwningPluginIdsForProviderMock = vi.hoisted(() =>
vi.fn<ResolveOwningPluginIdsForProvider>((params) =>
resolveProviderContractPluginIdsForProvider(params.provider),
),
vi.fn<ResolveOwningPluginIdsForProvider>(() => undefined),
);
const resolveNonBundledProviderPluginIdsMock = vi.hoisted(() =>
vi.fn<ResolveNonBundledProviderPluginIds>((_) => [] as string[]),
@@ -36,17 +38,18 @@ vi.mock("../providers.runtime.js", () => ({
let augmentModelCatalogWithProviderPlugins: typeof import("../provider-runtime.js").augmentModelCatalogWithProviderPlugins;
let resetProviderRuntimeHookCacheForTest: typeof import("../provider-runtime.js").resetProviderRuntimeHookCacheForTest;
let resolveProviderBuiltInModelSuppression: typeof import("../provider-runtime.js").resolveProviderBuiltInModelSuppression;
let resolveProviderContractPluginIdsForProvider: typeof import("./registry.js").resolveProviderContractPluginIdsForProvider;
let resolveProviderContractProvidersForPluginIds: typeof import("./registry.js").resolveProviderContractProvidersForPluginIds;
let uniqueProviderContractProviders: typeof import("./registry.js").uniqueProviderContractProviders;
let openaiProviders: ProviderPlugin[];
let openaiProvider: ProviderPlugin;
describe("provider catalog contract", () => {
beforeAll(async () => {
({
resolveProviderContractPluginIdsForProvider,
resolveProviderContractProvidersForPluginIds,
uniqueProviderContractProviders,
} = await import("./registry.js"));
const openaiPlugin = await import("../../../extensions/openai/index.ts");
openaiProviders = registerProviderPlugin({
plugin: openaiPlugin.default,
id: "openai",
name: "OpenAI",
}).providers;
openaiProvider = requireRegisteredProvider(openaiProviders, "openai", "provider");
({
augmentModelCatalogWithProviderPlugins,
resetProviderRuntimeHookCacheForTest,
@@ -61,22 +64,28 @@ describe("provider catalog contract", () => {
resolvePluginProvidersMock.mockImplementation((params?: { onlyPluginIds?: string[] }) => {
const onlyPluginIds = params?.onlyPluginIds;
if (!onlyPluginIds || onlyPluginIds.length === 0) {
return uniqueProviderContractProviders;
return openaiProviders;
}
return resolveProviderContractProvidersForPluginIds(onlyPluginIds);
return onlyPluginIds.includes("openai") ? openaiProviders : [];
});
resolveOwningPluginIdsForProviderMock.mockReset();
resolveOwningPluginIdsForProviderMock.mockImplementation((params) =>
resolveProviderContractPluginIdsForProvider(params.provider),
);
resolveOwningPluginIdsForProviderMock.mockImplementation((params) => {
switch (params.provider) {
case "azure-openai-responses":
case "openai":
case "openai-codex":
return ["openai"];
default:
return undefined;
}
});
resolveNonBundledProviderPluginIdsMock.mockReset();
resolveNonBundledProviderPluginIdsMock.mockReturnValue([]);
});
it("keeps codex-only missing-auth hints wired through the provider runtime", () => {
const openaiProvider = requireProviderContractProvider("openai");
expectCodexMissingAuthHint(
(params) => openaiProvider.buildMissingAuthMessage?.(params.context) ?? undefined,
);

View File

@@ -3,7 +3,8 @@ import { withBundledPluginAllowlistCompat } from "../bundled-compat.js";
import { resolveBundledWebSearchPluginIds } from "../bundled-web-search.js";
import { loadPluginManifestRegistry } from "../manifest-registry.js";
import { __testing as providerTesting } from "../providers.js";
import { providerContractCompatPluginIds, webSearchProviderContractRegistry } from "./registry.js";
import { resolveBundledPluginWebSearchProviders } from "../web-search-providers.js";
import { providerContractCompatPluginIds } from "./registry.js";
import { uniqueSortedStrings } from "./testkit.js";
function resolveBundledManifestProviderPluginIds() {
@@ -48,7 +49,7 @@ describe("plugin loader contract", () => {
env: { VITEST: "1" } as NodeJS.ProcessEnv,
});
webSearchPluginIds = uniqueSortedStrings(
webSearchProviderContractRegistry.map((entry) => entry.pluginId),
resolveBundledPluginWebSearchProviders({}).map((entry) => entry.pluginId),
);
bundledWebSearchPluginIds = uniqueSortedStrings(resolveBundledWebSearchPluginIds({}));
webSearchAllowlistCompatConfig = withBundledPluginAllowlistCompat({

View File

@@ -1,9 +1,8 @@
import { listBundledImageGenerationProviderEntries } from "../../bundled-image-generation-providers.js";
import {
BUNDLED_IMAGE_GENERATION_PLUGIN_IDS,
BUNDLED_MEDIA_UNDERSTANDING_PLUGIN_IDS,
BUNDLED_PLUGIN_CONTRACT_SNAPSHOTS,
BUNDLED_PROVIDER_PLUGIN_IDS,
BUNDLED_SPEECH_PLUGIN_IDS,
BUNDLED_WEB_SEARCH_PLUGIN_IDS,
} from "../bundled-capability-metadata.js";
import { loadBundledCapabilityRuntimeRegistry } from "../bundled-capability-runtime.js";
@@ -43,6 +42,47 @@ type PluginRegistrationContractEntry = {
toolNames: string[];
};
function createProviderContractPluginIdsByProviderId(): Map<string, string[]> {
const result = new Map<string, string[]>();
for (const entry of BUNDLED_PLUGIN_CONTRACT_SNAPSHOTS) {
for (const providerId of entry.providerIds) {
const existing = result.get(providerId) ?? [];
if (!existing.includes(entry.pluginId)) {
existing.push(entry.pluginId);
}
result.set(providerId, existing);
}
}
return result;
}
function createContractSpeechProvider(providerId: string): SpeechProviderPlugin {
return {
id: providerId,
label: providerId,
isConfigured: () => true,
synthesize: async () => ({
audioBuffer: Buffer.alloc(0),
outputFormat: "mp3",
fileExtension: "mp3",
voiceCompatible: true,
}),
listVoices: async () => [],
};
}
function createContractMediaUnderstandingProvider(
providerId: string,
): MediaUnderstandingProviderPlugin {
return {
id: providerId,
capabilities: ["image"],
describeImages: async () => {
throw new Error(`media-understanding contract stub invoked for ${providerId}`);
},
};
}
function uniqueStrings(values: readonly string[]): string[] {
const result: string[] = [];
const seen = new Set<string>();
@@ -64,27 +104,48 @@ let mediaUnderstandingProviderContractRegistryCache:
| null = null;
let imageGenerationProviderContractRegistryCache: ImageGenerationProviderContractEntry[] | null =
null;
const providerContractPluginIdsByProviderId = createProviderContractPluginIdsByProviderId();
const providerContractEntriesByPluginId = new Map<string, ProviderContractEntry[]>();
export let providerContractLoadError: Error | undefined;
function loadProviderContractEntriesForPluginId(pluginId: string): ProviderContractEntry[] {
const cached = providerContractEntriesByPluginId.get(pluginId);
if (cached) {
return cached;
}
try {
providerContractLoadError = undefined;
const entries = resolvePluginProviders({
bundledProviderAllowlistCompat: true,
bundledProviderVitestCompat: true,
onlyPluginIds: [pluginId],
cache: false,
activate: false,
}).map((provider) => ({
pluginId: provider.pluginId ?? pluginId,
provider,
}));
providerContractEntriesByPluginId.set(pluginId, entries);
return entries;
} catch (error) {
providerContractLoadError = error instanceof Error ? error : new Error(String(error));
providerContractEntriesByPluginId.set(pluginId, []);
return [];
}
}
function loadProviderContractEntriesForPluginIds(
pluginIds: readonly string[],
): ProviderContractEntry[] {
return pluginIds.flatMap((pluginId) => loadProviderContractEntriesForPluginId(pluginId));
}
function loadProviderContractRegistry(): ProviderContractEntry[] {
if (!providerContractRegistryCache) {
try {
providerContractLoadError = undefined;
providerContractRegistryCache = resolvePluginProviders({
bundledProviderAllowlistCompat: true,
bundledProviderVitestCompat: true,
onlyPluginIds: [...BUNDLED_PROVIDER_PLUGIN_IDS],
cache: false,
activate: false,
}).map((provider) => ({
pluginId: provider.pluginId ?? "",
provider,
}));
} catch (error) {
providerContractLoadError = error instanceof Error ? error : new Error(String(error));
providerContractRegistryCache = [];
}
providerContractRegistryCache = loadProviderContractEntriesForPluginIds(
BUNDLED_PROVIDER_PLUGIN_IDS,
);
}
return providerContractRegistryCache;
}
@@ -137,27 +198,25 @@ function loadWebSearchProviderContractRegistry(): WebSearchProviderContractEntry
function loadSpeechProviderContractRegistry(): SpeechProviderContractEntry[] {
if (!speechProviderContractRegistryCache) {
const registry = loadBundledCapabilityRuntimeRegistry({
pluginIds: BUNDLED_SPEECH_PLUGIN_IDS,
});
speechProviderContractRegistryCache = registry.speechProviders.map((entry) => ({
pluginId: entry.pluginId,
provider: entry.provider,
}));
// Contract tests only need bundled ownership and public speech surface shape.
speechProviderContractRegistryCache = BUNDLED_PLUGIN_CONTRACT_SNAPSHOTS.flatMap((entry) =>
entry.speechProviderIds.map((providerId) => ({
pluginId: entry.pluginId,
provider: createContractSpeechProvider(providerId),
})),
);
}
return speechProviderContractRegistryCache;
}
function loadMediaUnderstandingProviderContractRegistry(): MediaUnderstandingProviderContractEntry[] {
if (!mediaUnderstandingProviderContractRegistryCache) {
const registry = loadBundledCapabilityRuntimeRegistry({
pluginIds: BUNDLED_MEDIA_UNDERSTANDING_PLUGIN_IDS,
});
mediaUnderstandingProviderContractRegistryCache = registry.mediaUnderstandingProviders.map(
(entry) => ({
pluginId: entry.pluginId,
provider: entry.provider,
}),
mediaUnderstandingProviderContractRegistryCache = BUNDLED_PLUGIN_CONTRACT_SNAPSHOTS.flatMap(
(entry) =>
entry.mediaUnderstandingProviderIds.map((providerId) => ({
pluginId: entry.pluginId,
provider: createContractMediaUnderstandingProvider(providerId),
})),
);
}
return mediaUnderstandingProviderContractRegistryCache;
@@ -165,15 +224,10 @@ function loadMediaUnderstandingProviderContractRegistry(): MediaUnderstandingPro
function loadImageGenerationProviderContractRegistry(): ImageGenerationProviderContractEntry[] {
if (!imageGenerationProviderContractRegistryCache) {
const registry = loadBundledCapabilityRuntimeRegistry({
pluginIds: BUNDLED_IMAGE_GENERATION_PLUGIN_IDS,
});
imageGenerationProviderContractRegistryCache = registry.imageGenerationProviders.map(
(entry) => ({
pluginId: entry.pluginId,
provider: entry.provider,
}),
);
imageGenerationProviderContractRegistryCache =
listBundledImageGenerationProviderEntries().filter((entry) =>
BUNDLED_IMAGE_GENERATION_PLUGIN_IDS.includes(entry.pluginId),
);
}
return imageGenerationProviderContractRegistryCache;
}
@@ -227,7 +281,9 @@ export const providerContractCompatPluginIds: string[] = createLazyArrayView(
);
export function requireProviderContractProvider(providerId: string): ProviderPlugin {
const provider = uniqueProviderContractProviders.find((entry) => entry.id === providerId);
const provider = loadProviderContractEntriesForPluginIds(
providerContractPluginIdsByProviderId.get(providerId) ?? [],
).find((entry) => entry.provider.id === providerId)?.provider;
if (!provider) {
if (providerContractLoadError) {
throw new Error(
@@ -242,13 +298,7 @@ export function requireProviderContractProvider(providerId: string): ProviderPlu
export function resolveProviderContractPluginIdsForProvider(
providerId: string,
): string[] | undefined {
const pluginIds = [
...new Set(
providerContractRegistry
.filter((entry) => entry.provider.id === providerId)
.map((entry) => entry.pluginId),
),
];
const pluginIds = providerContractPluginIdsByProviderId.get(providerId) ?? [];
return pluginIds.length > 0 ? pluginIds : undefined;
}
@@ -258,7 +308,7 @@ export function resolveProviderContractProvidersForPluginIds(
const allowed = new Set(pluginIds);
return [
...new Map(
providerContractRegistry
loadProviderContractEntriesForPluginIds([...allowed])
.filter((entry) => allowed.has(entry.pluginId))
.map((entry) => [entry.provider.id, entry.provider]),
).values(),

View File

@@ -1,14 +1,16 @@
import fs from "node:fs/promises";
import os from "node:os";
import path from "node:path";
import { beforeEach, describe, expect, it, vi } from "vitest";
import { beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
import {
registerProviderPlugin,
requireRegisteredProvider,
} from "../../../test/helpers/extensions/provider-registration.js";
import { createProviderUsageFetch, makeResponse } from "../../test-utils/provider-usage-fetch.js";
import type { ProviderPlugin, ProviderRuntimeModel } from "../types.js";
import { requireProviderContractProvider as requireBundledProviderContractProvider } from "./registry.js";
const CONTRACT_SETUP_TIMEOUT_MS = 300_000;
const getOAuthApiKeyMock = vi.hoisted(() => vi.fn());
const refreshOpenAICodexTokenMock = vi.hoisted(() => vi.fn());
const getOAuthProvidersMock = vi.hoisted(() =>
vi.fn(() => [
@@ -24,9 +26,8 @@ vi.mock("@mariozechner/pi-ai/oauth", async () => {
);
return {
...actual,
getOAuthApiKey: getOAuthApiKeyMock,
getOAuthProviders: getOAuthProvidersMock,
refreshOpenAICodexToken: refreshOpenAICodexTokenMock,
getOAuthProviders: getOAuthProvidersMock,
};
});
@@ -49,13 +50,102 @@ function createModel(overrides: Partial<ProviderRuntimeModel> & Pick<ProviderRun
} satisfies ProviderRuntimeModel;
}
type ProviderRuntimeContractFixture = {
providerIds: string[];
pluginId: string;
name: string;
load: () => Promise<{ default: Parameters<typeof registerProviderPlugin>[0]["plugin"] }>;
};
const PROVIDER_RUNTIME_CONTRACT_FIXTURES: readonly ProviderRuntimeContractFixture[] = [
{
providerIds: ["anthropic"],
pluginId: "anthropic",
name: "Anthropic",
load: async () => await import("../../../extensions/anthropic/index.ts"),
},
{
providerIds: ["github-copilot"],
pluginId: "github-copilot",
name: "GitHub Copilot",
load: async () => await import("../../../extensions/github-copilot/index.ts"),
},
{
providerIds: ["google", "google-gemini-cli"],
pluginId: "google",
name: "Google",
load: async () => await import("../../../extensions/google/index.ts"),
},
{
providerIds: ["openai", "openai-codex"],
pluginId: "openai",
name: "OpenAI",
load: async () => await import("../../../extensions/openai/index.ts"),
},
{
providerIds: ["openrouter"],
pluginId: "openrouter",
name: "OpenRouter",
load: async () => await import("../../../extensions/openrouter/index.ts"),
},
{
providerIds: ["venice"],
pluginId: "venice",
name: "Venice",
load: async () => await import("../../../extensions/venice/index.ts"),
},
{
providerIds: ["xai"],
pluginId: "xai",
name: "xAI",
load: async () => await import("../../../extensions/xai/index.ts"),
},
{
providerIds: ["zai"],
pluginId: "zai",
name: "Z.AI",
load: async () => await import("../../../extensions/zai/index.ts"),
},
] as const;
const providerRuntimeContractProviders = new Map<string, ProviderPlugin>();
function requireProviderContractProvider(providerId: string): ProviderPlugin {
return requireBundledProviderContractProvider(providerId);
const provider = providerRuntimeContractProviders.get(providerId);
if (!provider) {
throw new Error(`provider runtime contract fixture missing for ${providerId}`);
}
return provider;
}
describe("provider runtime contract", () => {
beforeAll(async () => {
providerRuntimeContractProviders.clear();
const registeredFixtures = await Promise.all(
PROVIDER_RUNTIME_CONTRACT_FIXTURES.map(async (fixture) => {
const plugin = await fixture.load();
return {
fixture,
providers: registerProviderPlugin({
plugin: plugin.default,
id: fixture.pluginId,
name: fixture.name,
}).providers,
};
}),
);
for (const { fixture, providers } of registeredFixtures) {
for (const providerId of fixture.providerIds) {
providerRuntimeContractProviders.set(
providerId,
requireRegisteredProvider(providers, providerId, "provider"),
);
}
}
}, CONTRACT_SETUP_TIMEOUT_MS);
beforeEach(() => {
getOAuthApiKeyMock.mockReset();
refreshOpenAICodexTokenMock.mockReset();
getOAuthProvidersMock.mockClear();
refreshOpenAICodexTokenMock.mockReset();
}, CONTRACT_SETUP_TIMEOUT_MS);