mirror of
https://github.com/moltbot/moltbot.git
synced 2026-05-13 23:56:07 +00:00
refactor: shortcut bundled provider contract fixtures
This commit is contained in:
1
extensions/brave/web-search-provider.ts
Normal file
1
extensions/brave/web-search-provider.ts
Normal file
@@ -0,0 +1 @@
|
||||
export { __testing, createBraveWebSearchProvider } from "./src/brave-web-search-provider.js";
|
||||
1
extensions/duckduckgo/web-search-provider.ts
Normal file
1
extensions/duckduckgo/web-search-provider.ts
Normal file
@@ -0,0 +1 @@
|
||||
export { createDuckDuckGoWebSearchProvider } from "./src/ddg-search-provider.js";
|
||||
1
extensions/exa/web-search-provider.ts
Normal file
1
extensions/exa/web-search-provider.ts
Normal file
@@ -0,0 +1 @@
|
||||
export { __testing, createExaWebSearchProvider } from "./src/exa-web-search-provider.js";
|
||||
1
extensions/firecrawl/web-search-provider.ts
Normal file
1
extensions/firecrawl/web-search-provider.ts
Normal file
@@ -0,0 +1 @@
|
||||
export { createFirecrawlWebSearchProvider } from "./src/firecrawl-search-provider.js";
|
||||
1
extensions/google/web-search-provider.ts
Normal file
1
extensions/google/web-search-provider.ts
Normal file
@@ -0,0 +1 @@
|
||||
export { createGeminiWebSearchProvider } from "./src/gemini-web-search-provider.js";
|
||||
1
extensions/moonshot/web-search-provider.ts
Normal file
1
extensions/moonshot/web-search-provider.ts
Normal file
@@ -0,0 +1 @@
|
||||
export { createKimiWebSearchProvider } from "./src/kimi-web-search-provider.js";
|
||||
1
extensions/tavily/web-search-provider.ts
Normal file
1
extensions/tavily/web-search-provider.ts
Normal file
@@ -0,0 +1 @@
|
||||
export { createTavilyWebSearchProvider } from "./src/tavily-search-provider.js";
|
||||
23
src/bundled-image-generation-providers.ts
Normal file
23
src/bundled-image-generation-providers.ts
Normal file
@@ -0,0 +1,23 @@
|
||||
import { buildFalImageGenerationProvider } from "../extensions/fal/image-generation-provider.js";
|
||||
import { buildGoogleImageGenerationProvider } from "../extensions/google/image-generation-provider.js";
|
||||
import {
|
||||
buildMinimaxImageGenerationProvider,
|
||||
buildMinimaxPortalImageGenerationProvider,
|
||||
} from "../extensions/minimax/image-generation-provider.js";
|
||||
import { buildOpenAIImageGenerationProvider } from "../extensions/openai/image-generation-provider.js";
|
||||
import type { ImageGenerationProviderPlugin } from "./plugins/types.js";
|
||||
|
||||
type BundledImageGenerationProviderEntry = {
|
||||
pluginId: string;
|
||||
provider: ImageGenerationProviderPlugin;
|
||||
};
|
||||
|
||||
export function listBundledImageGenerationProviderEntries(): BundledImageGenerationProviderEntry[] {
|
||||
return [
|
||||
{ pluginId: "fal", provider: buildFalImageGenerationProvider() },
|
||||
{ pluginId: "google", provider: buildGoogleImageGenerationProvider() },
|
||||
{ pluginId: "minimax", provider: buildMinimaxImageGenerationProvider() },
|
||||
{ pluginId: "minimax", provider: buildMinimaxPortalImageGenerationProvider() },
|
||||
{ pluginId: "openai", provider: buildOpenAIImageGenerationProvider() },
|
||||
];
|
||||
}
|
||||
24
src/bundled-web-search.entries.ts
Normal file
24
src/bundled-web-search.entries.ts
Normal file
@@ -0,0 +1,24 @@
|
||||
import { createBraveWebSearchProvider } from "../extensions/brave/web-search-provider.js";
|
||||
import { createDuckDuckGoWebSearchProvider } from "../extensions/duckduckgo/web-search-provider.js";
|
||||
import { createExaWebSearchProvider } from "../extensions/exa/web-search-provider.js";
|
||||
import { createFirecrawlWebSearchProvider } from "../extensions/firecrawl/web-search-provider.js";
|
||||
import { createGeminiWebSearchProvider } from "../extensions/google/web-search-provider.js";
|
||||
import { createKimiWebSearchProvider } from "../extensions/moonshot/web-search-provider.js";
|
||||
import { createPerplexityWebSearchProvider } from "../extensions/perplexity/web-search-provider.js";
|
||||
import { createTavilyWebSearchProvider } from "../extensions/tavily/web-search-provider.js";
|
||||
import { createXaiWebSearchProvider } from "../extensions/xai/web-search.js";
|
||||
import type { PluginWebSearchProviderEntry } from "./plugins/types.js";
|
||||
|
||||
export function listBundledWebSearchProviderEntries(): PluginWebSearchProviderEntry[] {
|
||||
return [
|
||||
{ pluginId: "brave", ...createBraveWebSearchProvider() },
|
||||
{ pluginId: "duckduckgo", ...createDuckDuckGoWebSearchProvider() },
|
||||
{ pluginId: "exa", ...createExaWebSearchProvider() },
|
||||
{ pluginId: "firecrawl", ...createFirecrawlWebSearchProvider() },
|
||||
{ pluginId: "google", ...createGeminiWebSearchProvider() },
|
||||
{ pluginId: "moonshot", ...createKimiWebSearchProvider() },
|
||||
{ pluginId: "perplexity", ...createPerplexityWebSearchProvider() },
|
||||
{ pluginId: "tavily", ...createTavilyWebSearchProvider() },
|
||||
{ pluginId: "xai", ...createXaiWebSearchProvider() },
|
||||
];
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { listBundledWebSearchProviderEntries } from "../bundled-web-search.entries.js";
|
||||
import type { OpenClawConfig } from "../config/config.js";
|
||||
import { BUNDLED_WEB_SEARCH_PLUGIN_IDS } from "./bundled-web-search-ids.js";
|
||||
import { resolveBundledWebSearchPluginId } from "./bundled-web-search-provider-ids.js";
|
||||
@@ -6,7 +7,6 @@ import {
|
||||
listBundledWebSearchProviders,
|
||||
resolveBundledWebSearchPluginIds,
|
||||
} from "./bundled-web-search.js";
|
||||
import { webSearchProviderContractRegistry } from "./contracts/registry.js";
|
||||
|
||||
describe("bundled web search metadata", () => {
|
||||
function toComparableEntry(params: {
|
||||
@@ -105,6 +105,7 @@ describe("bundled web search metadata", () => {
|
||||
|
||||
it("keeps bundled provider metadata aligned with bundled plugin contracts", async () => {
|
||||
const fastPathProviders = listBundledWebSearchProviders();
|
||||
const bundledProviderEntries = listBundledWebSearchProviderEntries();
|
||||
|
||||
expect(
|
||||
sortComparableEntries(
|
||||
@@ -117,7 +118,7 @@ describe("bundled web search metadata", () => {
|
||||
),
|
||||
).toEqual(
|
||||
sortComparableEntries(
|
||||
webSearchProviderContractRegistry.map(({ pluginId, provider }) =>
|
||||
bundledProviderEntries.map(({ pluginId, ...provider }) =>
|
||||
toComparableEntry({
|
||||
pluginId,
|
||||
provider,
|
||||
@@ -127,12 +128,11 @@ describe("bundled web search metadata", () => {
|
||||
);
|
||||
|
||||
for (const fastPathProvider of fastPathProviders) {
|
||||
const contractEntry = webSearchProviderContractRegistry.find(
|
||||
(entry) =>
|
||||
entry.pluginId === fastPathProvider.pluginId && entry.provider.id === fastPathProvider.id,
|
||||
const bundledEntry = bundledProviderEntries.find(
|
||||
(entry) => entry.pluginId === fastPathProvider.pluginId && entry.id === fastPathProvider.id,
|
||||
);
|
||||
expect(contractEntry).toBeDefined();
|
||||
const contractProvider = contractEntry!.provider;
|
||||
expect(bundledEntry).toBeDefined();
|
||||
const contractProvider = bundledEntry!;
|
||||
|
||||
const fastSearchConfig: Record<string, unknown> = {};
|
||||
const contractSearchConfig: Record<string, unknown> = {};
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { listBundledWebSearchProviderEntries } from "../bundled-web-search.entries.js";
|
||||
import { BUNDLED_WEB_SEARCH_PLUGIN_IDS } from "./bundled-capability-metadata.js";
|
||||
import { resolveBundledWebSearchPluginId as resolveBundledWebSearchPluginIdFromMap } from "./bundled-web-search-provider-ids.js";
|
||||
import { webSearchProviderContractRegistry } from "./contracts/registry.js";
|
||||
import type { PluginLoadOptions } from "./loader.js";
|
||||
import { loadPluginManifestRegistry } from "./manifest-registry.js";
|
||||
import type { PluginWebSearchProviderEntry } from "./types.js";
|
||||
@@ -11,10 +11,7 @@ let bundledWebSearchProvidersCache: BundledWebSearchProviderEntry[] | null = nul
|
||||
|
||||
function loadBundledWebSearchProviders(): BundledWebSearchProviderEntry[] {
|
||||
if (!bundledWebSearchProvidersCache) {
|
||||
bundledWebSearchProvidersCache = webSearchProviderContractRegistry.map((entry) => ({
|
||||
...entry.provider,
|
||||
pluginId: entry.pluginId,
|
||||
}));
|
||||
bundledWebSearchProvidersCache = listBundledWebSearchProviderEntries();
|
||||
}
|
||||
return bundledWebSearchProvidersCache;
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { clearRuntimeAuthProfileStoreSnapshots } from "../../agents/auth-profiles/store.js";
|
||||
import { resolvePreferredProviderForAuthChoice } from "../../plugins/provider-auth-choice-preference.js";
|
||||
import { buildProviderPluginMethodChoice } from "../provider-wizard.js";
|
||||
import { requireProviderContractProvider, uniqueProviderContractProviders } from "./registry.js";
|
||||
import type { ProviderPlugin } from "../types.js";
|
||||
|
||||
type ResolvePluginProviders =
|
||||
typeof import("../../plugins/provider-auth-choice.runtime.js").resolvePluginProviders;
|
||||
@@ -15,6 +15,7 @@ const resolveProviderPluginChoiceMock = vi.hoisted(() => vi.fn<ResolveProviderPl
|
||||
const runProviderModelSelectedHookMock = vi.hoisted(() =>
|
||||
vi.fn<RunProviderModelSelectedHook>(async () => {}),
|
||||
);
|
||||
const runAuthMethodMock = vi.hoisted(() => vi.fn(async () => ({ profiles: [] })));
|
||||
|
||||
vi.mock("../../plugins/provider-auth-choice.runtime.js", () => ({
|
||||
resolvePluginProviders: resolvePluginProvidersMock,
|
||||
@@ -25,7 +26,7 @@ vi.mock("../../plugins/provider-auth-choice.runtime.js", () => ({
|
||||
describe("provider auth-choice contract", () => {
|
||||
beforeEach(() => {
|
||||
resolvePluginProvidersMock.mockReset();
|
||||
resolvePluginProvidersMock.mockReturnValue(uniqueProviderContractProviders);
|
||||
resolvePluginProvidersMock.mockReturnValue([]);
|
||||
resolveProviderPluginChoiceMock.mockReset();
|
||||
resolveProviderPluginChoiceMock.mockImplementation(({ providers, choice }) => {
|
||||
const provider = providers.find((entry) =>
|
||||
@@ -55,24 +56,69 @@ describe("provider auth-choice contract", () => {
|
||||
});
|
||||
|
||||
it("maps provider-plugin choices through the shared preferred-provider fallback resolver", async () => {
|
||||
const pluginFallbackScenarios = [
|
||||
"github-copilot",
|
||||
"minimax-portal",
|
||||
"modelstudio",
|
||||
"ollama",
|
||||
].map((providerId) => {
|
||||
const provider = requireProviderContractProvider(providerId);
|
||||
return {
|
||||
authChoice: buildProviderPluginMethodChoice(provider.id, provider.auth[0]?.id ?? "default"),
|
||||
expectedProvider: provider.id,
|
||||
};
|
||||
});
|
||||
const pluginFallbackScenarios: ProviderPlugin[] = [
|
||||
{
|
||||
id: "github-copilot",
|
||||
label: "GitHub Copilot",
|
||||
auth: [
|
||||
{
|
||||
id: "oauth",
|
||||
label: "OAuth",
|
||||
hint: "Browser sign-in",
|
||||
kind: "oauth",
|
||||
run: runAuthMethodMock,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
id: "minimax-portal",
|
||||
label: "MiniMax Portal",
|
||||
auth: [
|
||||
{
|
||||
id: "portal",
|
||||
label: "Portal",
|
||||
hint: "Browser sign-in",
|
||||
kind: "oauth",
|
||||
run: runAuthMethodMock,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
id: "modelstudio",
|
||||
label: "ModelStudio",
|
||||
auth: [
|
||||
{
|
||||
id: "api-key",
|
||||
label: "API key",
|
||||
hint: "Paste key",
|
||||
kind: "api_key",
|
||||
run: runAuthMethodMock,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
id: "ollama",
|
||||
label: "Ollama",
|
||||
auth: [
|
||||
{
|
||||
id: "local",
|
||||
label: "Local",
|
||||
hint: "No auth",
|
||||
kind: "custom",
|
||||
run: runAuthMethodMock,
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
for (const scenario of pluginFallbackScenarios) {
|
||||
for (const provider of pluginFallbackScenarios) {
|
||||
resolvePluginProvidersMock.mockClear();
|
||||
resolvePluginProvidersMock.mockReturnValue([provider]);
|
||||
await expect(
|
||||
resolvePreferredProviderForAuthChoice({ choice: scenario.authChoice }),
|
||||
).resolves.toBe(scenario.expectedProvider);
|
||||
resolvePreferredProviderForAuthChoice({
|
||||
choice: buildProviderPluginMethodChoice(provider.id, provider.auth[0]?.id ?? "default"),
|
||||
}),
|
||||
).resolves.toBe(provider.id);
|
||||
expect(resolvePluginProvidersMock).toHaveBeenCalled();
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
import { beforeAll, beforeEach, describe, it, vi } from "vitest";
|
||||
import {
|
||||
registerProviderPlugin,
|
||||
requireRegisteredProvider,
|
||||
} from "../../../test/helpers/extensions/provider-registration.js";
|
||||
import {
|
||||
expectAugmentedCodexCatalog,
|
||||
expectCodexBuiltInSuppression,
|
||||
expectCodexMissingAuthHint,
|
||||
} from "../provider-runtime.test-support.js";
|
||||
import { requireProviderContractProvider } from "./registry.js";
|
||||
import type { ProviderPlugin } from "../types.js";
|
||||
|
||||
type ResolvePluginProviders = typeof import("../providers.runtime.js").resolvePluginProviders;
|
||||
type ResolveOwningPluginIdsForProvider =
|
||||
@@ -14,9 +18,7 @@ type ResolveNonBundledProviderPluginIds =
|
||||
|
||||
const resolvePluginProvidersMock = vi.hoisted(() => vi.fn<ResolvePluginProviders>(() => []));
|
||||
const resolveOwningPluginIdsForProviderMock = vi.hoisted(() =>
|
||||
vi.fn<ResolveOwningPluginIdsForProvider>((params) =>
|
||||
resolveProviderContractPluginIdsForProvider(params.provider),
|
||||
),
|
||||
vi.fn<ResolveOwningPluginIdsForProvider>(() => undefined),
|
||||
);
|
||||
const resolveNonBundledProviderPluginIdsMock = vi.hoisted(() =>
|
||||
vi.fn<ResolveNonBundledProviderPluginIds>((_) => [] as string[]),
|
||||
@@ -36,17 +38,18 @@ vi.mock("../providers.runtime.js", () => ({
|
||||
let augmentModelCatalogWithProviderPlugins: typeof import("../provider-runtime.js").augmentModelCatalogWithProviderPlugins;
|
||||
let resetProviderRuntimeHookCacheForTest: typeof import("../provider-runtime.js").resetProviderRuntimeHookCacheForTest;
|
||||
let resolveProviderBuiltInModelSuppression: typeof import("../provider-runtime.js").resolveProviderBuiltInModelSuppression;
|
||||
let resolveProviderContractPluginIdsForProvider: typeof import("./registry.js").resolveProviderContractPluginIdsForProvider;
|
||||
let resolveProviderContractProvidersForPluginIds: typeof import("./registry.js").resolveProviderContractProvidersForPluginIds;
|
||||
let uniqueProviderContractProviders: typeof import("./registry.js").uniqueProviderContractProviders;
|
||||
let openaiProviders: ProviderPlugin[];
|
||||
let openaiProvider: ProviderPlugin;
|
||||
|
||||
describe("provider catalog contract", () => {
|
||||
beforeAll(async () => {
|
||||
({
|
||||
resolveProviderContractPluginIdsForProvider,
|
||||
resolveProviderContractProvidersForPluginIds,
|
||||
uniqueProviderContractProviders,
|
||||
} = await import("./registry.js"));
|
||||
const openaiPlugin = await import("../../../extensions/openai/index.ts");
|
||||
openaiProviders = registerProviderPlugin({
|
||||
plugin: openaiPlugin.default,
|
||||
id: "openai",
|
||||
name: "OpenAI",
|
||||
}).providers;
|
||||
openaiProvider = requireRegisteredProvider(openaiProviders, "openai", "provider");
|
||||
({
|
||||
augmentModelCatalogWithProviderPlugins,
|
||||
resetProviderRuntimeHookCacheForTest,
|
||||
@@ -61,22 +64,28 @@ describe("provider catalog contract", () => {
|
||||
resolvePluginProvidersMock.mockImplementation((params?: { onlyPluginIds?: string[] }) => {
|
||||
const onlyPluginIds = params?.onlyPluginIds;
|
||||
if (!onlyPluginIds || onlyPluginIds.length === 0) {
|
||||
return uniqueProviderContractProviders;
|
||||
return openaiProviders;
|
||||
}
|
||||
return resolveProviderContractProvidersForPluginIds(onlyPluginIds);
|
||||
return onlyPluginIds.includes("openai") ? openaiProviders : [];
|
||||
});
|
||||
|
||||
resolveOwningPluginIdsForProviderMock.mockReset();
|
||||
resolveOwningPluginIdsForProviderMock.mockImplementation((params) =>
|
||||
resolveProviderContractPluginIdsForProvider(params.provider),
|
||||
);
|
||||
resolveOwningPluginIdsForProviderMock.mockImplementation((params) => {
|
||||
switch (params.provider) {
|
||||
case "azure-openai-responses":
|
||||
case "openai":
|
||||
case "openai-codex":
|
||||
return ["openai"];
|
||||
default:
|
||||
return undefined;
|
||||
}
|
||||
});
|
||||
|
||||
resolveNonBundledProviderPluginIdsMock.mockReset();
|
||||
resolveNonBundledProviderPluginIdsMock.mockReturnValue([]);
|
||||
});
|
||||
|
||||
it("keeps codex-only missing-auth hints wired through the provider runtime", () => {
|
||||
const openaiProvider = requireProviderContractProvider("openai");
|
||||
expectCodexMissingAuthHint(
|
||||
(params) => openaiProvider.buildMissingAuthMessage?.(params.context) ?? undefined,
|
||||
);
|
||||
|
||||
@@ -3,7 +3,8 @@ import { withBundledPluginAllowlistCompat } from "../bundled-compat.js";
|
||||
import { resolveBundledWebSearchPluginIds } from "../bundled-web-search.js";
|
||||
import { loadPluginManifestRegistry } from "../manifest-registry.js";
|
||||
import { __testing as providerTesting } from "../providers.js";
|
||||
import { providerContractCompatPluginIds, webSearchProviderContractRegistry } from "./registry.js";
|
||||
import { resolveBundledPluginWebSearchProviders } from "../web-search-providers.js";
|
||||
import { providerContractCompatPluginIds } from "./registry.js";
|
||||
import { uniqueSortedStrings } from "./testkit.js";
|
||||
|
||||
function resolveBundledManifestProviderPluginIds() {
|
||||
@@ -48,7 +49,7 @@ describe("plugin loader contract", () => {
|
||||
env: { VITEST: "1" } as NodeJS.ProcessEnv,
|
||||
});
|
||||
webSearchPluginIds = uniqueSortedStrings(
|
||||
webSearchProviderContractRegistry.map((entry) => entry.pluginId),
|
||||
resolveBundledPluginWebSearchProviders({}).map((entry) => entry.pluginId),
|
||||
);
|
||||
bundledWebSearchPluginIds = uniqueSortedStrings(resolveBundledWebSearchPluginIds({}));
|
||||
webSearchAllowlistCompatConfig = withBundledPluginAllowlistCompat({
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import { listBundledImageGenerationProviderEntries } from "../../bundled-image-generation-providers.js";
|
||||
import {
|
||||
BUNDLED_IMAGE_GENERATION_PLUGIN_IDS,
|
||||
BUNDLED_MEDIA_UNDERSTANDING_PLUGIN_IDS,
|
||||
BUNDLED_PLUGIN_CONTRACT_SNAPSHOTS,
|
||||
BUNDLED_PROVIDER_PLUGIN_IDS,
|
||||
BUNDLED_SPEECH_PLUGIN_IDS,
|
||||
BUNDLED_WEB_SEARCH_PLUGIN_IDS,
|
||||
} from "../bundled-capability-metadata.js";
|
||||
import { loadBundledCapabilityRuntimeRegistry } from "../bundled-capability-runtime.js";
|
||||
@@ -43,6 +42,47 @@ type PluginRegistrationContractEntry = {
|
||||
toolNames: string[];
|
||||
};
|
||||
|
||||
function createProviderContractPluginIdsByProviderId(): Map<string, string[]> {
|
||||
const result = new Map<string, string[]>();
|
||||
for (const entry of BUNDLED_PLUGIN_CONTRACT_SNAPSHOTS) {
|
||||
for (const providerId of entry.providerIds) {
|
||||
const existing = result.get(providerId) ?? [];
|
||||
if (!existing.includes(entry.pluginId)) {
|
||||
existing.push(entry.pluginId);
|
||||
}
|
||||
result.set(providerId, existing);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
function createContractSpeechProvider(providerId: string): SpeechProviderPlugin {
|
||||
return {
|
||||
id: providerId,
|
||||
label: providerId,
|
||||
isConfigured: () => true,
|
||||
synthesize: async () => ({
|
||||
audioBuffer: Buffer.alloc(0),
|
||||
outputFormat: "mp3",
|
||||
fileExtension: "mp3",
|
||||
voiceCompatible: true,
|
||||
}),
|
||||
listVoices: async () => [],
|
||||
};
|
||||
}
|
||||
|
||||
function createContractMediaUnderstandingProvider(
|
||||
providerId: string,
|
||||
): MediaUnderstandingProviderPlugin {
|
||||
return {
|
||||
id: providerId,
|
||||
capabilities: ["image"],
|
||||
describeImages: async () => {
|
||||
throw new Error(`media-understanding contract stub invoked for ${providerId}`);
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
function uniqueStrings(values: readonly string[]): string[] {
|
||||
const result: string[] = [];
|
||||
const seen = new Set<string>();
|
||||
@@ -64,27 +104,48 @@ let mediaUnderstandingProviderContractRegistryCache:
|
||||
| null = null;
|
||||
let imageGenerationProviderContractRegistryCache: ImageGenerationProviderContractEntry[] | null =
|
||||
null;
|
||||
const providerContractPluginIdsByProviderId = createProviderContractPluginIdsByProviderId();
|
||||
const providerContractEntriesByPluginId = new Map<string, ProviderContractEntry[]>();
|
||||
|
||||
export let providerContractLoadError: Error | undefined;
|
||||
|
||||
function loadProviderContractEntriesForPluginId(pluginId: string): ProviderContractEntry[] {
|
||||
const cached = providerContractEntriesByPluginId.get(pluginId);
|
||||
if (cached) {
|
||||
return cached;
|
||||
}
|
||||
try {
|
||||
providerContractLoadError = undefined;
|
||||
const entries = resolvePluginProviders({
|
||||
bundledProviderAllowlistCompat: true,
|
||||
bundledProviderVitestCompat: true,
|
||||
onlyPluginIds: [pluginId],
|
||||
cache: false,
|
||||
activate: false,
|
||||
}).map((provider) => ({
|
||||
pluginId: provider.pluginId ?? pluginId,
|
||||
provider,
|
||||
}));
|
||||
providerContractEntriesByPluginId.set(pluginId, entries);
|
||||
return entries;
|
||||
} catch (error) {
|
||||
providerContractLoadError = error instanceof Error ? error : new Error(String(error));
|
||||
providerContractEntriesByPluginId.set(pluginId, []);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
function loadProviderContractEntriesForPluginIds(
|
||||
pluginIds: readonly string[],
|
||||
): ProviderContractEntry[] {
|
||||
return pluginIds.flatMap((pluginId) => loadProviderContractEntriesForPluginId(pluginId));
|
||||
}
|
||||
|
||||
function loadProviderContractRegistry(): ProviderContractEntry[] {
|
||||
if (!providerContractRegistryCache) {
|
||||
try {
|
||||
providerContractLoadError = undefined;
|
||||
providerContractRegistryCache = resolvePluginProviders({
|
||||
bundledProviderAllowlistCompat: true,
|
||||
bundledProviderVitestCompat: true,
|
||||
onlyPluginIds: [...BUNDLED_PROVIDER_PLUGIN_IDS],
|
||||
cache: false,
|
||||
activate: false,
|
||||
}).map((provider) => ({
|
||||
pluginId: provider.pluginId ?? "",
|
||||
provider,
|
||||
}));
|
||||
} catch (error) {
|
||||
providerContractLoadError = error instanceof Error ? error : new Error(String(error));
|
||||
providerContractRegistryCache = [];
|
||||
}
|
||||
providerContractRegistryCache = loadProviderContractEntriesForPluginIds(
|
||||
BUNDLED_PROVIDER_PLUGIN_IDS,
|
||||
);
|
||||
}
|
||||
return providerContractRegistryCache;
|
||||
}
|
||||
@@ -137,27 +198,25 @@ function loadWebSearchProviderContractRegistry(): WebSearchProviderContractEntry
|
||||
|
||||
function loadSpeechProviderContractRegistry(): SpeechProviderContractEntry[] {
|
||||
if (!speechProviderContractRegistryCache) {
|
||||
const registry = loadBundledCapabilityRuntimeRegistry({
|
||||
pluginIds: BUNDLED_SPEECH_PLUGIN_IDS,
|
||||
});
|
||||
speechProviderContractRegistryCache = registry.speechProviders.map((entry) => ({
|
||||
pluginId: entry.pluginId,
|
||||
provider: entry.provider,
|
||||
}));
|
||||
// Contract tests only need bundled ownership and public speech surface shape.
|
||||
speechProviderContractRegistryCache = BUNDLED_PLUGIN_CONTRACT_SNAPSHOTS.flatMap((entry) =>
|
||||
entry.speechProviderIds.map((providerId) => ({
|
||||
pluginId: entry.pluginId,
|
||||
provider: createContractSpeechProvider(providerId),
|
||||
})),
|
||||
);
|
||||
}
|
||||
return speechProviderContractRegistryCache;
|
||||
}
|
||||
|
||||
function loadMediaUnderstandingProviderContractRegistry(): MediaUnderstandingProviderContractEntry[] {
|
||||
if (!mediaUnderstandingProviderContractRegistryCache) {
|
||||
const registry = loadBundledCapabilityRuntimeRegistry({
|
||||
pluginIds: BUNDLED_MEDIA_UNDERSTANDING_PLUGIN_IDS,
|
||||
});
|
||||
mediaUnderstandingProviderContractRegistryCache = registry.mediaUnderstandingProviders.map(
|
||||
(entry) => ({
|
||||
pluginId: entry.pluginId,
|
||||
provider: entry.provider,
|
||||
}),
|
||||
mediaUnderstandingProviderContractRegistryCache = BUNDLED_PLUGIN_CONTRACT_SNAPSHOTS.flatMap(
|
||||
(entry) =>
|
||||
entry.mediaUnderstandingProviderIds.map((providerId) => ({
|
||||
pluginId: entry.pluginId,
|
||||
provider: createContractMediaUnderstandingProvider(providerId),
|
||||
})),
|
||||
);
|
||||
}
|
||||
return mediaUnderstandingProviderContractRegistryCache;
|
||||
@@ -165,15 +224,10 @@ function loadMediaUnderstandingProviderContractRegistry(): MediaUnderstandingPro
|
||||
|
||||
function loadImageGenerationProviderContractRegistry(): ImageGenerationProviderContractEntry[] {
|
||||
if (!imageGenerationProviderContractRegistryCache) {
|
||||
const registry = loadBundledCapabilityRuntimeRegistry({
|
||||
pluginIds: BUNDLED_IMAGE_GENERATION_PLUGIN_IDS,
|
||||
});
|
||||
imageGenerationProviderContractRegistryCache = registry.imageGenerationProviders.map(
|
||||
(entry) => ({
|
||||
pluginId: entry.pluginId,
|
||||
provider: entry.provider,
|
||||
}),
|
||||
);
|
||||
imageGenerationProviderContractRegistryCache =
|
||||
listBundledImageGenerationProviderEntries().filter((entry) =>
|
||||
BUNDLED_IMAGE_GENERATION_PLUGIN_IDS.includes(entry.pluginId),
|
||||
);
|
||||
}
|
||||
return imageGenerationProviderContractRegistryCache;
|
||||
}
|
||||
@@ -227,7 +281,9 @@ export const providerContractCompatPluginIds: string[] = createLazyArrayView(
|
||||
);
|
||||
|
||||
export function requireProviderContractProvider(providerId: string): ProviderPlugin {
|
||||
const provider = uniqueProviderContractProviders.find((entry) => entry.id === providerId);
|
||||
const provider = loadProviderContractEntriesForPluginIds(
|
||||
providerContractPluginIdsByProviderId.get(providerId) ?? [],
|
||||
).find((entry) => entry.provider.id === providerId)?.provider;
|
||||
if (!provider) {
|
||||
if (providerContractLoadError) {
|
||||
throw new Error(
|
||||
@@ -242,13 +298,7 @@ export function requireProviderContractProvider(providerId: string): ProviderPlu
|
||||
export function resolveProviderContractPluginIdsForProvider(
|
||||
providerId: string,
|
||||
): string[] | undefined {
|
||||
const pluginIds = [
|
||||
...new Set(
|
||||
providerContractRegistry
|
||||
.filter((entry) => entry.provider.id === providerId)
|
||||
.map((entry) => entry.pluginId),
|
||||
),
|
||||
];
|
||||
const pluginIds = providerContractPluginIdsByProviderId.get(providerId) ?? [];
|
||||
return pluginIds.length > 0 ? pluginIds : undefined;
|
||||
}
|
||||
|
||||
@@ -258,7 +308,7 @@ export function resolveProviderContractProvidersForPluginIds(
|
||||
const allowed = new Set(pluginIds);
|
||||
return [
|
||||
...new Map(
|
||||
providerContractRegistry
|
||||
loadProviderContractEntriesForPluginIds([...allowed])
|
||||
.filter((entry) => allowed.has(entry.pluginId))
|
||||
.map((entry) => [entry.provider.id, entry.provider]),
|
||||
).values(),
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
import fs from "node:fs/promises";
|
||||
import os from "node:os";
|
||||
import path from "node:path";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import {
|
||||
registerProviderPlugin,
|
||||
requireRegisteredProvider,
|
||||
} from "../../../test/helpers/extensions/provider-registration.js";
|
||||
import { createProviderUsageFetch, makeResponse } from "../../test-utils/provider-usage-fetch.js";
|
||||
import type { ProviderPlugin, ProviderRuntimeModel } from "../types.js";
|
||||
import { requireProviderContractProvider as requireBundledProviderContractProvider } from "./registry.js";
|
||||
|
||||
const CONTRACT_SETUP_TIMEOUT_MS = 300_000;
|
||||
|
||||
const getOAuthApiKeyMock = vi.hoisted(() => vi.fn());
|
||||
const refreshOpenAICodexTokenMock = vi.hoisted(() => vi.fn());
|
||||
const getOAuthProvidersMock = vi.hoisted(() =>
|
||||
vi.fn(() => [
|
||||
@@ -24,9 +26,8 @@ vi.mock("@mariozechner/pi-ai/oauth", async () => {
|
||||
);
|
||||
return {
|
||||
...actual,
|
||||
getOAuthApiKey: getOAuthApiKeyMock,
|
||||
getOAuthProviders: getOAuthProvidersMock,
|
||||
refreshOpenAICodexToken: refreshOpenAICodexTokenMock,
|
||||
getOAuthProviders: getOAuthProvidersMock,
|
||||
};
|
||||
});
|
||||
|
||||
@@ -49,13 +50,102 @@ function createModel(overrides: Partial<ProviderRuntimeModel> & Pick<ProviderRun
|
||||
} satisfies ProviderRuntimeModel;
|
||||
}
|
||||
|
||||
type ProviderRuntimeContractFixture = {
|
||||
providerIds: string[];
|
||||
pluginId: string;
|
||||
name: string;
|
||||
load: () => Promise<{ default: Parameters<typeof registerProviderPlugin>[0]["plugin"] }>;
|
||||
};
|
||||
|
||||
const PROVIDER_RUNTIME_CONTRACT_FIXTURES: readonly ProviderRuntimeContractFixture[] = [
|
||||
{
|
||||
providerIds: ["anthropic"],
|
||||
pluginId: "anthropic",
|
||||
name: "Anthropic",
|
||||
load: async () => await import("../../../extensions/anthropic/index.ts"),
|
||||
},
|
||||
{
|
||||
providerIds: ["github-copilot"],
|
||||
pluginId: "github-copilot",
|
||||
name: "GitHub Copilot",
|
||||
load: async () => await import("../../../extensions/github-copilot/index.ts"),
|
||||
},
|
||||
{
|
||||
providerIds: ["google", "google-gemini-cli"],
|
||||
pluginId: "google",
|
||||
name: "Google",
|
||||
load: async () => await import("../../../extensions/google/index.ts"),
|
||||
},
|
||||
{
|
||||
providerIds: ["openai", "openai-codex"],
|
||||
pluginId: "openai",
|
||||
name: "OpenAI",
|
||||
load: async () => await import("../../../extensions/openai/index.ts"),
|
||||
},
|
||||
{
|
||||
providerIds: ["openrouter"],
|
||||
pluginId: "openrouter",
|
||||
name: "OpenRouter",
|
||||
load: async () => await import("../../../extensions/openrouter/index.ts"),
|
||||
},
|
||||
{
|
||||
providerIds: ["venice"],
|
||||
pluginId: "venice",
|
||||
name: "Venice",
|
||||
load: async () => await import("../../../extensions/venice/index.ts"),
|
||||
},
|
||||
{
|
||||
providerIds: ["xai"],
|
||||
pluginId: "xai",
|
||||
name: "xAI",
|
||||
load: async () => await import("../../../extensions/xai/index.ts"),
|
||||
},
|
||||
{
|
||||
providerIds: ["zai"],
|
||||
pluginId: "zai",
|
||||
name: "Z.AI",
|
||||
load: async () => await import("../../../extensions/zai/index.ts"),
|
||||
},
|
||||
] as const;
|
||||
|
||||
const providerRuntimeContractProviders = new Map<string, ProviderPlugin>();
|
||||
|
||||
function requireProviderContractProvider(providerId: string): ProviderPlugin {
|
||||
return requireBundledProviderContractProvider(providerId);
|
||||
const provider = providerRuntimeContractProviders.get(providerId);
|
||||
if (!provider) {
|
||||
throw new Error(`provider runtime contract fixture missing for ${providerId}`);
|
||||
}
|
||||
return provider;
|
||||
}
|
||||
|
||||
describe("provider runtime contract", () => {
|
||||
beforeAll(async () => {
|
||||
providerRuntimeContractProviders.clear();
|
||||
const registeredFixtures = await Promise.all(
|
||||
PROVIDER_RUNTIME_CONTRACT_FIXTURES.map(async (fixture) => {
|
||||
const plugin = await fixture.load();
|
||||
return {
|
||||
fixture,
|
||||
providers: registerProviderPlugin({
|
||||
plugin: plugin.default,
|
||||
id: fixture.pluginId,
|
||||
name: fixture.name,
|
||||
}).providers,
|
||||
};
|
||||
}),
|
||||
);
|
||||
for (const { fixture, providers } of registeredFixtures) {
|
||||
for (const providerId of fixture.providerIds) {
|
||||
providerRuntimeContractProviders.set(
|
||||
providerId,
|
||||
requireRegisteredProvider(providers, providerId, "provider"),
|
||||
);
|
||||
}
|
||||
}
|
||||
}, CONTRACT_SETUP_TIMEOUT_MS);
|
||||
|
||||
beforeEach(() => {
|
||||
getOAuthApiKeyMock.mockReset();
|
||||
refreshOpenAICodexTokenMock.mockReset();
|
||||
getOAuthProvidersMock.mockClear();
|
||||
refreshOpenAICodexTokenMock.mockReset();
|
||||
}, CONTRACT_SETUP_TIMEOUT_MS);
|
||||
|
||||
Reference in New Issue
Block a user