mirror of
https://github.com/pocketpaw/pocketpaw.git
synced 2026-05-13 21:21:53 +00:00
feat(mcp): add OAuth support, remove registry tab, improve error handling
- Add MCP OAuth flow: oauth_store.py for token persistence, callback coordination via asyncio Futures, dashboard callback endpoint, and WebSocket redirect broadcast to frontend - Add `oauth` flag to MCPServerConfig and MCPPreset for OAuth-aware install flow (browser popup for auth) - Remove MCP Registry tab and related dashboard/JS code (replaced by curated preset catalog) - Improve MCP manager error handling: unwrap ExceptionGroup to surface root-cause errors instead of unhelpful anyio cancel-scope messages - Add mcp_client_metadata_url config field for OIDC client metadata - Frontend: streamline mcp.js (OAuth popup handling, cancel flow) - Tests: 24 new OAuth tests, updated manager tests Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
File diff suppressed because it is too large
Load Diff
13
docs/public/mcp-client.json
Normal file
13
docs/public/mcp-client.json
Normal file
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"client_name": "PocketPaw",
|
||||
"redirect_uris": [
|
||||
"http://localhost:8888/api/mcp/oauth/callback",
|
||||
"http://127.0.0.1:8888/api/mcp/oauth/callback"
|
||||
],
|
||||
"token_endpoint_auth_method": "none",
|
||||
"grant_types": ["authorization_code", "refresh_token"],
|
||||
"response_types": ["code"],
|
||||
"client_uri": "https://github.com/pocketpaw/pocketpaw",
|
||||
"software_id": "pocketpaw",
|
||||
"software_version": "0.4.1"
|
||||
}
|
||||
@@ -411,6 +411,12 @@ class Settings(BaseSettings):
|
||||
web_host: str = Field(default="127.0.0.1", description="Web server host")
|
||||
web_port: int = Field(default=8888, description="Web server port")
|
||||
|
||||
# MCP OAuth
|
||||
mcp_client_metadata_url: str = Field(
|
||||
default="",
|
||||
description="CIMD URL for MCP OAuth (optional, for servers without dynamic registration)",
|
||||
)
|
||||
|
||||
# Identity / Multi-user
|
||||
owner_id: str = Field(
|
||||
default="",
|
||||
@@ -553,6 +559,8 @@ class Settings(BaseSettings):
|
||||
"google_oauth_client_secret": (
|
||||
self.google_oauth_client_secret or existing.get("google_oauth_client_secret")
|
||||
),
|
||||
# MCP OAuth
|
||||
"mcp_client_metadata_url": self.mcp_client_metadata_url,
|
||||
# Voice/TTS
|
||||
"tts_provider": self.tts_provider,
|
||||
"elevenlabs_api_key": (self.elevenlabs_api_key or existing.get("elevenlabs_api_key")),
|
||||
|
||||
@@ -399,9 +399,19 @@ async def startup_event():
|
||||
except Exception as e:
|
||||
logger.warning("Failed to recover interrupted projects: %s", e)
|
||||
|
||||
# Auto-start enabled MCP servers
|
||||
# Wire MCP OAuth broadcast + auto-start enabled MCP servers
|
||||
try:
|
||||
from pocketpaw.mcp.manager import get_mcp_manager
|
||||
from pocketpaw.mcp.manager import get_mcp_manager, set_ws_broadcast
|
||||
|
||||
async def _mcp_ws_broadcast(message: dict) -> None:
|
||||
"""Broadcast an MCP message to all connected WebSocket clients."""
|
||||
for ws in active_connections[:]:
|
||||
try:
|
||||
await ws.send_json(message)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
set_ws_broadcast(_mcp_ws_broadcast)
|
||||
|
||||
mcp = get_mcp_manager()
|
||||
await mcp.start_enabled_servers()
|
||||
@@ -616,6 +626,7 @@ async def list_mcp_presets():
|
||||
"url": p.url,
|
||||
"docs_url": p.docs_url,
|
||||
"needs_args": p.needs_args,
|
||||
"oauth": p.oauth,
|
||||
"installed": p.id in installed_names,
|
||||
"env_keys": [
|
||||
{
|
||||
@@ -670,244 +681,35 @@ async def install_mcp_preset(request: Request):
|
||||
}
|
||||
|
||||
|
||||
# ==================== MCP Registry API ====================
|
||||
@app.get("/api/mcp/oauth/callback")
|
||||
async def mcp_oauth_callback(code: str = "", state: str = ""):
|
||||
"""OAuth callback endpoint — receives authorization code from OAuth provider.
|
||||
|
||||
_MCP_REGISTRY_BASE = "https://registry.modelcontextprotocol.io"
|
||||
|
||||
# Server name parts that are too generic to use alone as a config name.
|
||||
_GENERIC_SERVER_PARTS = {"mcp", "server", "mcp-server", "main", "app", "api"}
|
||||
|
||||
|
||||
def _derive_registry_short_name(raw_name: str, title: str | None = None) -> str:
|
||||
"""Derive a short, readable config name from a registry server name.
|
||||
|
||||
Examples:
|
||||
"com.zomato/mcp" -> "zomato-mcp"
|
||||
"acme/weather-server" -> "weather-server"
|
||||
"@anthropic/claude" -> "claude"
|
||||
"simple-tool" -> "simple-tool"
|
||||
This is the redirect target after user authenticates with GitHub, Notion, etc.
|
||||
Auth-exempt because the OAuth provider redirects the user's browser here.
|
||||
"""
|
||||
if not raw_name:
|
||||
return ""
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
if "/" not in raw_name:
|
||||
return raw_name
|
||||
from pocketpaw.mcp.manager import set_oauth_callback_result
|
||||
|
||||
parts = raw_name.split("/")
|
||||
org = parts[0]
|
||||
server_part = parts[-1]
|
||||
|
||||
# Clean up org: "com.zomato" -> "zomato", "@anthropic" -> "anthropic"
|
||||
if "." in org:
|
||||
org = org.rsplit(".", 1)[-1]
|
||||
org = org.lstrip("@")
|
||||
|
||||
# If the server part is too generic, combine with org for disambiguation
|
||||
if server_part.lower() in _GENERIC_SERVER_PARTS:
|
||||
return f"{org}-{server_part}"
|
||||
|
||||
return server_part
|
||||
|
||||
|
||||
@app.get("/api/mcp/registry/search")
|
||||
async def search_mcp_registry(
|
||||
q: str = "",
|
||||
limit: int = 30,
|
||||
cursor: str = "",
|
||||
):
|
||||
"""Proxy search to the official MCP Registry (avoids CORS)."""
|
||||
import httpx
|
||||
|
||||
params: dict[str, str | int] = {"limit": min(limit, 100)}
|
||||
if q:
|
||||
params["search"] = q
|
||||
if cursor:
|
||||
params["cursor"] = cursor
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=15.0) as client:
|
||||
resp = await client.get(
|
||||
f"{_MCP_REGISTRY_BASE}/v0/servers",
|
||||
params=params,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
# Registry wraps each entry as {server: {...}, _meta: {...}}.
|
||||
# Unwrap so the frontend gets flat server objects.
|
||||
# Also: lift environmentVariables from packages[0] to server
|
||||
# level, remove $schema ($ prefix can confuse Alpine.js proxies),
|
||||
# and ensure expected fields have defaults.
|
||||
servers = []
|
||||
raw_entries = data.get("servers", [])
|
||||
if not isinstance(raw_entries, list):
|
||||
raw_entries = []
|
||||
|
||||
for entry in raw_entries:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
|
||||
raw_server = entry.get("server", entry)
|
||||
if not isinstance(raw_server, dict):
|
||||
continue
|
||||
|
||||
srv = dict(raw_server)
|
||||
meta = entry.get("_meta", srv.get("_meta", {}))
|
||||
srv["_meta"] = meta if isinstance(meta, dict) else {}
|
||||
srv.pop("$schema", None)
|
||||
|
||||
name = srv.get("name")
|
||||
description = srv.get("description")
|
||||
packages = srv.get("packages")
|
||||
remotes = srv.get("remotes")
|
||||
env_vars = srv.get("environmentVariables")
|
||||
|
||||
srv["name"] = name if isinstance(name, str) else ""
|
||||
srv["description"] = description if isinstance(description, str) else ""
|
||||
srv["packages"] = packages if isinstance(packages, list) else []
|
||||
srv["remotes"] = remotes if isinstance(remotes, list) else []
|
||||
srv["environmentVariables"] = env_vars if isinstance(env_vars, list) else []
|
||||
|
||||
# Lift env vars from the first package to the server level.
|
||||
if not srv["environmentVariables"]:
|
||||
for pkg in srv["packages"]:
|
||||
if not isinstance(pkg, dict):
|
||||
continue
|
||||
pkg_env = pkg.get("environmentVariables")
|
||||
if isinstance(pkg_env, list) and pkg_env:
|
||||
srv["environmentVariables"] = pkg_env
|
||||
break
|
||||
|
||||
# Skip entries without a valid name.
|
||||
if srv["name"]:
|
||||
servers.append(srv)
|
||||
|
||||
metadata = data.get("metadata", {})
|
||||
if not isinstance(metadata, dict):
|
||||
metadata = {}
|
||||
if "nextCursor" not in metadata and "next_cursor" in metadata:
|
||||
metadata["nextCursor"] = metadata["next_cursor"]
|
||||
metadata.setdefault("count", len(servers))
|
||||
|
||||
return {"servers": servers, "metadata": metadata}
|
||||
except Exception as exc:
|
||||
logger.warning("MCP registry search failed: %s", exc)
|
||||
return {"servers": [], "metadata": {"count": 0}, "error": str(exc)}
|
||||
|
||||
|
||||
@app.post("/api/mcp/registry/install")
|
||||
async def install_from_registry(request: Request):
|
||||
"""Install an MCP server from registry metadata.
|
||||
|
||||
Expects a JSON body with the server's registry data (name, packages/remotes,
|
||||
environmentVariables) and user-supplied env values.
|
||||
"""
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from pocketpaw.mcp.config import MCPServerConfig
|
||||
from pocketpaw.mcp.manager import get_mcp_manager
|
||||
|
||||
data = await request.json()
|
||||
server = data.get("server", {})
|
||||
user_env = data.get("env", {})
|
||||
|
||||
# Derive a short, readable name from the registry name.
|
||||
# e.g. "com.zomato/mcp" -> "zomato-mcp", "acme/weather-server" -> "weather-server"
|
||||
raw_name = server.get("name", "")
|
||||
short_name = _derive_registry_short_name(raw_name, server.get("title"))
|
||||
if not short_name:
|
||||
return JSONResponse({"error": "Missing server name"}, status_code=400)
|
||||
|
||||
# Try remotes first (HTTP transport — simplest, no npm needed)
|
||||
remotes = server.get("remotes", [])
|
||||
packages = server.get("packages", [])
|
||||
|
||||
config = None
|
||||
|
||||
if remotes:
|
||||
remote = remotes[0]
|
||||
# Registry API uses "type" (e.g. "streamable-http"), legacy uses "transportType"
|
||||
transport = remote.get("type", remote.get("transportType", "http"))
|
||||
# Normalize SSE to "http" but keep "streamable-http" distinct — they need
|
||||
# different MCP SDK clients.
|
||||
if transport == "sse":
|
||||
transport = "http"
|
||||
elif transport not in ("http", "streamable-http"):
|
||||
transport = "http" # safe fallback
|
||||
config = MCPServerConfig(
|
||||
name=short_name,
|
||||
transport=transport,
|
||||
url=remote.get("url", ""),
|
||||
env=user_env,
|
||||
enabled=True,
|
||||
)
|
||||
elif packages:
|
||||
pkg = packages[0]
|
||||
registry_type = pkg.get("registryType", "")
|
||||
pkg_name = pkg.get("name", "") or pkg.get("identifier", "")
|
||||
runtime = pkg.get("runtime", "node")
|
||||
|
||||
if registry_type == "docker":
|
||||
args = ["run", "-i", "--rm"]
|
||||
for ra in pkg.get("runtimeArguments", []):
|
||||
if ra.get("isFixed"):
|
||||
args.append(ra.get("value", ""))
|
||||
args.append(pkg_name)
|
||||
config = MCPServerConfig(
|
||||
name=short_name,
|
||||
transport="stdio",
|
||||
command="docker",
|
||||
args=args,
|
||||
env=user_env,
|
||||
enabled=True,
|
||||
)
|
||||
elif registry_type == "pypi":
|
||||
config = MCPServerConfig(
|
||||
name=short_name,
|
||||
transport="stdio",
|
||||
command="uvx",
|
||||
args=[pkg_name],
|
||||
env=user_env,
|
||||
enabled=True,
|
||||
)
|
||||
elif registry_type == "npm" or runtime == "node":
|
||||
args = ["-y", pkg_name]
|
||||
for pa in pkg.get("packageArguments", []):
|
||||
if pa.get("isFixed"):
|
||||
args.append(pa.get("value", ""))
|
||||
config = MCPServerConfig(
|
||||
name=short_name,
|
||||
transport="stdio",
|
||||
command="npx",
|
||||
args=args,
|
||||
env=user_env,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
if config is None:
|
||||
return JSONResponse(
|
||||
{"error": "Could not determine install method from registry data"},
|
||||
if not code or not state:
|
||||
return HTMLResponse(
|
||||
"<html><body><h3>Missing code or state parameter.</h3></body></html>",
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
mgr = get_mcp_manager()
|
||||
mgr.add_server_config(config)
|
||||
connected = await mgr.start_server(config)
|
||||
tools = mgr.discover_tools(config.name) if connected else []
|
||||
|
||||
result: dict = {
|
||||
"status": "ok",
|
||||
"name": config.name,
|
||||
"connected": connected,
|
||||
"tools": [{"name": t.name, "description": t.description} for t in tools],
|
||||
}
|
||||
# Surface connection error so the frontend can display it
|
||||
if not connected:
|
||||
status = mgr.get_server_status()
|
||||
srv = status.get(config.name, {})
|
||||
if srv.get("error"):
|
||||
result["error"] = srv["error"]
|
||||
return result
|
||||
resolved = set_oauth_callback_result(state, code)
|
||||
if resolved:
|
||||
return HTMLResponse(
|
||||
"<html><body>"
|
||||
"<h3>Authenticated! You can close this tab.</h3>"
|
||||
"<script>window.close()</script>"
|
||||
"</body></html>"
|
||||
)
|
||||
return HTMLResponse(
|
||||
"<html><body><h3>OAuth flow expired or not found.</h3></body></html>",
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
|
||||
# ==================== Skills Library API ====================
|
||||
@@ -1704,6 +1506,7 @@ async def auth_middleware(request: Request, call_next):
|
||||
"/webhook/inbound",
|
||||
"/api/whatsapp/qr",
|
||||
"/oauth/callback",
|
||||
"/api/mcp/oauth/callback",
|
||||
]
|
||||
|
||||
for path in exempt_paths:
|
||||
|
||||
@@ -2,14 +2,14 @@
|
||||
* PocketPaw - MCP Servers Feature Module
|
||||
*
|
||||
* Created: 2026-02-07
|
||||
* Updated: 2026-02-12 — Registry tab (browse official MCP registry), dynamic categories, needs_args.
|
||||
* Updated: 2026-02-17 — Removed Registry tab, added paste-command input.
|
||||
*
|
||||
* Manages MCP (Model Context Protocol) server connections:
|
||||
* - List/add/remove servers
|
||||
* - Enable/disable servers
|
||||
* - View tool inventory
|
||||
* - Browse & install presets from the catalog
|
||||
* - Search & install from the official MCP Registry (16K+ servers)
|
||||
* - Browse & install presets from the curated catalog
|
||||
* - Paste a full command to auto-fill Add Server form
|
||||
*/
|
||||
|
||||
window.PocketPaw = window.PocketPaw || {};
|
||||
@@ -28,7 +28,8 @@ window.PocketPaw.MCP = {
|
||||
transport: 'stdio',
|
||||
command: '',
|
||||
args: '',
|
||||
url: ''
|
||||
url: '',
|
||||
fullCommand: ''
|
||||
},
|
||||
mcpLoading: false,
|
||||
mcpShowAddForm: false,
|
||||
@@ -38,17 +39,8 @@ window.PocketPaw.MCP = {
|
||||
mcpInstallEnv: {},
|
||||
mcpInstallArgs: '',
|
||||
mcpInstalling: false,
|
||||
mcpCategoryFilter: 'all',
|
||||
// Registry state
|
||||
mcpRegistryQuery: '',
|
||||
mcpRegistryResults: [],
|
||||
mcpRegistryFeatured: [],
|
||||
mcpRegistryLoading: false,
|
||||
mcpRegistryFeaturedError: false,
|
||||
mcpRegistryCursor: null,
|
||||
mcpRegistryLoadingMore: false,
|
||||
mcpRegistryInstalling: null,
|
||||
mcpRegistryInstallEnv: {}
|
||||
mcpInstallAbort: null,
|
||||
mcpCategoryFilter: 'all'
|
||||
};
|
||||
},
|
||||
|
||||
@@ -64,6 +56,31 @@ window.PocketPaw.MCP = {
|
||||
this.showMCP = true;
|
||||
await this.getMCPStatus();
|
||||
await this.loadPresets();
|
||||
|
||||
// Register WS handler for OAuth redirect (once)
|
||||
if (!window.PocketPaw._mcpOAuthRegistered && window.socket) {
|
||||
window.socket.on('mcp_oauth_redirect', (data) => {
|
||||
if (!data.url) return;
|
||||
// Navigate pre-opened popup or show fallback
|
||||
const popup = window.PocketPaw._oauthPopup;
|
||||
if (popup && !popup.closed) {
|
||||
popup.location = data.url;
|
||||
} else {
|
||||
// Popup was blocked — show clickable link
|
||||
const name = data.server || 'server';
|
||||
if (this.showToast) {
|
||||
this.showToast(
|
||||
`Open auth link for ${name}: ` +
|
||||
data.url.substring(0, 60) + '...',
|
||||
'info'
|
||||
);
|
||||
}
|
||||
window.open(data.url, '_blank');
|
||||
}
|
||||
});
|
||||
window.PocketPaw._mcpOAuthRegistered = true;
|
||||
}
|
||||
|
||||
this.$nextTick(() => {
|
||||
if (window.refreshIcons) window.refreshIcons();
|
||||
});
|
||||
@@ -108,7 +125,7 @@ window.PocketPaw.MCP = {
|
||||
const data = await res.json();
|
||||
if (data.status === 'ok') {
|
||||
this.showToast(`MCP server "${this.mcpForm.name}" added`, 'success');
|
||||
this.mcpForm = { name: '', transport: 'stdio', command: '', args: '', url: '' };
|
||||
this.mcpForm = { name: '', transport: 'stdio', command: '', args: '', url: '', fullCommand: '' };
|
||||
await this.getMCPStatus();
|
||||
} else {
|
||||
this.showToast(data.error || 'Failed to add server', 'error');
|
||||
@@ -199,12 +216,29 @@ window.PocketPaw.MCP = {
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* Cancel an in-progress install (abort fetch, reset state, close popup)
|
||||
*/
|
||||
cancelInstall() {
|
||||
if (this.mcpInstallAbort) {
|
||||
this.mcpInstallAbort.abort();
|
||||
this.mcpInstallAbort = null;
|
||||
}
|
||||
this.mcpInstalling = false;
|
||||
this.mcpInstallId = null;
|
||||
const popup = window.PocketPaw._oauthPopup;
|
||||
if (popup && !popup.closed) {
|
||||
try { popup.close(); } catch (_) { /* cross-origin */ }
|
||||
}
|
||||
window.PocketPaw._oauthPopup = null;
|
||||
},
|
||||
|
||||
/**
|
||||
* Show install form for a preset
|
||||
*/
|
||||
showInstallForm(presetId) {
|
||||
if (this.mcpInstallId === presetId) {
|
||||
this.mcpInstallId = null;
|
||||
this.cancelInstall();
|
||||
return;
|
||||
}
|
||||
this.mcpInstallId = presetId;
|
||||
@@ -228,6 +262,21 @@ window.PocketPaw.MCP = {
|
||||
async installPreset() {
|
||||
if (!this.mcpInstallId) return;
|
||||
this.mcpInstalling = true;
|
||||
|
||||
// AbortController so Cancel can kill the pending fetch
|
||||
const abort = new AbortController();
|
||||
this.mcpInstallAbort = abort;
|
||||
|
||||
// For OAuth presets: open a blank popup NOW (in user click context)
|
||||
// so the browser allows it. The WS handler will navigate it later.
|
||||
const isOAuth = this.presetIsOAuth(this.mcpInstallId);
|
||||
if (isOAuth) {
|
||||
window.PocketPaw._oauthPopup = window.open(
|
||||
'about:blank', 'pocketpaw_oauth',
|
||||
'width=600,height=700,scrollbars=yes'
|
||||
);
|
||||
}
|
||||
|
||||
try {
|
||||
const body = {
|
||||
preset_id: this.mcpInstallId,
|
||||
@@ -240,7 +289,8 @@ window.PocketPaw.MCP = {
|
||||
const res = await fetch('/api/mcp/presets/install', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(body)
|
||||
body: JSON.stringify(body),
|
||||
signal: abort.signal
|
||||
});
|
||||
const data = await res.json();
|
||||
if (res.ok && data.status === 'ok') {
|
||||
@@ -256,9 +306,21 @@ window.PocketPaw.MCP = {
|
||||
this.showToast(data.error || 'Install failed', 'error');
|
||||
}
|
||||
} catch (e) {
|
||||
if (e.name === 'AbortError') return; // User cancelled
|
||||
this.showToast('Install failed: ' + e.message, 'error');
|
||||
} finally {
|
||||
this.mcpInstalling = false;
|
||||
this.mcpInstallAbort = null;
|
||||
// Close leftover OAuth popup if still blank
|
||||
const popup = window.PocketPaw._oauthPopup;
|
||||
if (popup && !popup.closed) {
|
||||
try {
|
||||
if (popup.location.href === 'about:blank') {
|
||||
popup.close();
|
||||
}
|
||||
} catch (_) { /* cross-origin — popup navigated, leave it */ }
|
||||
}
|
||||
window.PocketPaw._oauthPopup = null;
|
||||
this.$nextTick(() => {
|
||||
if (window.refreshIcons) window.refreshIcons();
|
||||
});
|
||||
@@ -290,294 +352,62 @@ window.PocketPaw.MCP = {
|
||||
},
|
||||
|
||||
/**
|
||||
* Normalize one registry entry to a flat server object.
|
||||
* Accepts both wrapped shape ({server, _meta}) and flat shape.
|
||||
* Check if a preset uses OAuth authentication
|
||||
*/
|
||||
normalizeRegistryServer(entry) {
|
||||
if (!entry || typeof entry !== 'object') return null;
|
||||
presetIsOAuth(presetId) {
|
||||
const preset = this.mcpPresets.find(p => p.id === presetId);
|
||||
return preset ? !!preset.oauth : false;
|
||||
},
|
||||
|
||||
const raw = (entry.server && typeof entry.server === 'object')
|
||||
? entry.server
|
||||
: entry;
|
||||
if (!raw || typeof raw !== 'object') return null;
|
||||
/**
|
||||
* Get button text for a preset based on OAuth status and install state
|
||||
*/
|
||||
presetButtonText(presetId) {
|
||||
if (this.mcpInstallId === presetId) return 'Cancel';
|
||||
return this.presetIsOAuth(presetId) ? 'Authenticate' : 'Install';
|
||||
},
|
||||
|
||||
const server = { ...raw };
|
||||
const meta = (entry._meta && typeof entry._meta === 'object')
|
||||
? entry._meta
|
||||
: ((server._meta && typeof server._meta === 'object') ? server._meta : {});
|
||||
/**
|
||||
* Parse a pasted full command string and auto-populate the Add Server form.
|
||||
* Handles patterns like:
|
||||
* "npx -y @some/package"
|
||||
* "uvx mcp-server-git"
|
||||
* "docker run -i --rm ghcr.io/org/img"
|
||||
*/
|
||||
parseFullCommand() {
|
||||
const raw = (this.mcpForm.fullCommand || '').trim();
|
||||
if (!raw) return;
|
||||
|
||||
server._meta = meta;
|
||||
delete server.$schema;
|
||||
const parts = raw.split(/\s+/);
|
||||
if (parts.length === 0) return;
|
||||
|
||||
server.name = typeof server.name === 'string' ? server.name : '';
|
||||
server.title = typeof server.title === 'string' ? server.title : '';
|
||||
server.description = typeof server.description === 'string' ? server.description : '';
|
||||
server.version = typeof server.version === 'string' ? server.version : '';
|
||||
server.packages = Array.isArray(server.packages) ? server.packages : [];
|
||||
server.remotes = Array.isArray(server.remotes) ? server.remotes : [];
|
||||
server.environmentVariables = Array.isArray(server.environmentVariables)
|
||||
? server.environmentVariables
|
||||
: [];
|
||||
const command = parts[0];
|
||||
const args = parts.slice(1);
|
||||
|
||||
// Lift env vars from package metadata when server-level list is missing.
|
||||
if (server.environmentVariables.length === 0) {
|
||||
for (const pkg of server.packages) {
|
||||
if (!pkg || typeof pkg !== 'object') continue;
|
||||
if (Array.isArray(pkg.environmentVariables) && pkg.environmentVariables.length > 0) {
|
||||
server.environmentVariables = pkg.environmentVariables;
|
||||
break;
|
||||
}
|
||||
this.mcpForm.command = command;
|
||||
this.mcpForm.args = args.join(', ');
|
||||
|
||||
// Auto-derive a name from the last arg that looks like a package
|
||||
let name = '';
|
||||
for (let i = args.length - 1; i >= 0; i--) {
|
||||
const a = args[i];
|
||||
// Skip flags and version suffixes
|
||||
if (a.startsWith('-')) continue;
|
||||
// Use the package-like arg
|
||||
name = a
|
||||
.replace(/@latest$/, '')
|
||||
.replace(/@[\d.]+.*$/, '');
|
||||
// Extract short name: "@scope/pkg" -> "pkg", "mcp-server-git" -> "mcp-server-git"
|
||||
if (name.includes('/')) {
|
||||
name = name.split('/').pop() || name;
|
||||
}
|
||||
name = name.replace(/^@/, '');
|
||||
break;
|
||||
}
|
||||
|
||||
return server.name ? server : null;
|
||||
},
|
||||
|
||||
/**
|
||||
* Normalize a registry server list into safe flat entries.
|
||||
*/
|
||||
normalizeRegistryServers(entries) {
|
||||
if (!Array.isArray(entries)) return [];
|
||||
const normalized = [];
|
||||
for (const entry of entries) {
|
||||
const server = this.normalizeRegistryServer(entry);
|
||||
if (server) normalized.push(server);
|
||||
if (name && !this.mcpForm.name) {
|
||||
this.mcpForm.name = name;
|
||||
}
|
||||
return normalized;
|
||||
},
|
||||
|
||||
/**
|
||||
* Normalize next-cursor variants returned by registry metadata.
|
||||
*/
|
||||
registryNextCursor(metadata) {
|
||||
if (!metadata || typeof metadata !== 'object') return null;
|
||||
return metadata.nextCursor || metadata.next_cursor || null;
|
||||
},
|
||||
|
||||
// ==================== Registry Methods ====================
|
||||
|
||||
/**
|
||||
* Search the official MCP Registry (debounced via Alpine @input.debounce)
|
||||
*/
|
||||
async searchRegistry() {
|
||||
const q = this.mcpRegistryQuery.trim();
|
||||
if (!q) {
|
||||
this.mcpRegistryResults = [];
|
||||
this.mcpRegistryCursor = null;
|
||||
return;
|
||||
}
|
||||
|
||||
this.mcpRegistryLoading = true;
|
||||
try {
|
||||
const url = `/api/mcp/registry/search?q=${encodeURIComponent(q)}&limit=30`;
|
||||
const res = await fetch(url);
|
||||
if (res.ok) {
|
||||
const data = await res.json();
|
||||
this.mcpRegistryResults = this.normalizeRegistryServers(data.servers);
|
||||
this.mcpRegistryCursor = this.registryNextCursor(data.metadata);
|
||||
if (data.error) {
|
||||
this.showToast(`Registry search failed: ${data.error}`, 'error');
|
||||
}
|
||||
} else {
|
||||
this.mcpRegistryResults = [];
|
||||
this.mcpRegistryCursor = null;
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Registry search failed', e);
|
||||
} finally {
|
||||
this.mcpRegistryLoading = false;
|
||||
this.$nextTick(() => {
|
||||
if (window.refreshIcons) window.refreshIcons();
|
||||
});
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* Load featured/popular registry servers for initial view
|
||||
*/
|
||||
async loadRegistryFeatured() {
|
||||
if (this.mcpRegistryFeatured.length > 0) return;
|
||||
this.mcpRegistryLoading = true;
|
||||
this.mcpRegistryFeaturedError = false;
|
||||
try {
|
||||
const res = await fetch('/api/mcp/registry/search?limit=30');
|
||||
if (res.ok) {
|
||||
const data = await res.json();
|
||||
this.mcpRegistryFeatured = this.normalizeRegistryServers(data.servers);
|
||||
if (data.error) {
|
||||
this.mcpRegistryFeaturedError = true;
|
||||
}
|
||||
} else {
|
||||
this.mcpRegistryFeaturedError = true;
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Failed to load registry featured', e);
|
||||
this.mcpRegistryFeaturedError = true;
|
||||
} finally {
|
||||
this.mcpRegistryLoading = false;
|
||||
this.$nextTick(() => {
|
||||
if (window.refreshIcons) window.refreshIcons();
|
||||
});
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* Retry loading featured servers (clears cache first)
|
||||
*/
|
||||
async retryRegistryFeatured() {
|
||||
this.mcpRegistryFeatured = [];
|
||||
await this.loadRegistryFeatured();
|
||||
},
|
||||
|
||||
/**
|
||||
* Load more registry results (pagination)
|
||||
*/
|
||||
async loadMoreRegistry() {
|
||||
if (!this.mcpRegistryCursor || this.mcpRegistryLoadingMore) return;
|
||||
this.mcpRegistryLoadingMore = true;
|
||||
try {
|
||||
const q = this.mcpRegistryQuery.trim();
|
||||
let url = `/api/mcp/registry/search?limit=30&cursor=${encodeURIComponent(this.mcpRegistryCursor)}`;
|
||||
if (q) url += `&q=${encodeURIComponent(q)}`;
|
||||
const res = await fetch(url);
|
||||
if (res.ok) {
|
||||
const data = await res.json();
|
||||
const newServers = this.normalizeRegistryServers(data.servers);
|
||||
this.mcpRegistryResults = [...this.mcpRegistryResults, ...newServers];
|
||||
this.mcpRegistryCursor = this.registryNextCursor(data.metadata);
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Registry load more failed', e);
|
||||
} finally {
|
||||
this.mcpRegistryLoadingMore = false;
|
||||
this.$nextTick(() => {
|
||||
if (window.refreshIcons) window.refreshIcons();
|
||||
});
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* Get the list to display in registry view
|
||||
*/
|
||||
registryDisplayResults() {
|
||||
return this.mcpRegistryQuery.trim()
|
||||
? this.mcpRegistryResults
|
||||
: this.mcpRegistryFeatured;
|
||||
},
|
||||
|
||||
/**
|
||||
* Extract a short display name from a registry server
|
||||
*/
|
||||
registryServerName(server) {
|
||||
if (server.title) return server.title;
|
||||
const name = server.name || '';
|
||||
return name.includes('/') ? name.split('/').pop() : name;
|
||||
},
|
||||
|
||||
/**
|
||||
* Extract a source label (e.g. "npm: @mcp/server" or "HTTP")
|
||||
*/
|
||||
registryServerSource(server) {
|
||||
const remotes = server.remotes || [];
|
||||
const packages = server.packages || [];
|
||||
if (remotes.length > 0) return 'HTTP';
|
||||
if (packages.length > 0) {
|
||||
const pkg = packages[0];
|
||||
const type = pkg.registryType || 'npm';
|
||||
return `${type}: ${pkg.name || ''}`;
|
||||
}
|
||||
return server.name || '';
|
||||
},
|
||||
|
||||
/**
|
||||
* Check if a registry server is already installed locally
|
||||
*/
|
||||
isRegistryServerInstalled(server) {
|
||||
const rawName = server.name || '';
|
||||
const installed = Object.keys(this.mcpServers).map(n => n.toLowerCase());
|
||||
// Check both the full derived name and the simple last-segment name
|
||||
const parts = rawName.split('/');
|
||||
const lastPart = (parts.pop() || '').toLowerCase();
|
||||
const orgPart = parts.length > 0
|
||||
? (parts[0].includes('.') ? parts[0].split('.').pop() : parts[0]).replace(/^@/, '').toLowerCase()
|
||||
: '';
|
||||
const generic = ['mcp', 'server', 'mcp-server', 'main', 'app', 'api'];
|
||||
const derivedName = generic.includes(lastPart) && orgPart
|
||||
? `${orgPart}-${lastPart}`
|
||||
: lastPart;
|
||||
return installed.includes(derivedName) || installed.includes(lastPart);
|
||||
},
|
||||
|
||||
/**
|
||||
* Get env vars required by a registry server
|
||||
*/
|
||||
registryServerEnvVars(server) {
|
||||
return (server.environmentVariables || []).filter(ev => ev.required !== false);
|
||||
},
|
||||
|
||||
/**
|
||||
* Show install form for a registry server
|
||||
*/
|
||||
showRegistryInstallForm(serverName) {
|
||||
if (this.mcpRegistryInstalling === serverName) {
|
||||
this.mcpRegistryInstalling = null;
|
||||
return;
|
||||
}
|
||||
this.mcpRegistryInstalling = serverName;
|
||||
// Pre-fill env
|
||||
const results = this.registryDisplayResults();
|
||||
const server = results.find(s => s.name === serverName);
|
||||
const env = {};
|
||||
if (server) {
|
||||
for (const ev of (server.environmentVariables || [])) {
|
||||
env[ev.name] = '';
|
||||
}
|
||||
}
|
||||
this.mcpRegistryInstallEnv = env;
|
||||
this.$nextTick(() => {
|
||||
if (window.refreshIcons) window.refreshIcons();
|
||||
});
|
||||
},
|
||||
|
||||
/**
|
||||
* Install a server from the registry
|
||||
*/
|
||||
async installFromRegistry(server) {
|
||||
const serverName = server.name;
|
||||
this.mcpRegistryInstalling = serverName;
|
||||
try {
|
||||
const res = await fetch('/api/mcp/registry/install', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
server: server,
|
||||
env: this.mcpRegistryInstallEnv
|
||||
})
|
||||
});
|
||||
const data = await res.json();
|
||||
if (res.ok && data.status === 'ok') {
|
||||
const toolCount = data.tools ? data.tools.length : 0;
|
||||
let msg;
|
||||
if (data.connected) {
|
||||
msg = `Installed "${data.name}" — ${toolCount} tools`;
|
||||
} else {
|
||||
msg = `Installed "${data.name}" (not yet connected)`;
|
||||
if (data.error) msg += `: ${data.error}`;
|
||||
}
|
||||
this.showToast(msg, data.connected ? 'success' : 'warning');
|
||||
this.mcpRegistryInstalling = null;
|
||||
await this.getMCPStatus();
|
||||
} else {
|
||||
this.showToast(data.error || 'Install failed', 'error');
|
||||
this.mcpRegistryInstalling = null;
|
||||
}
|
||||
} catch (e) {
|
||||
this.showToast('Install failed: ' + e.message, 'error');
|
||||
this.mcpRegistryInstalling = null;
|
||||
}
|
||||
this.$nextTick(() => {
|
||||
if (window.refreshIcons) window.refreshIcons();
|
||||
});
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -29,9 +29,14 @@ class MCPServerConfig:
|
||||
env: dict[str, str] = field(default_factory=dict)
|
||||
enabled: bool = True
|
||||
timeout: int = 30 # Connection timeout in seconds
|
||||
# Legacy: original registry identifier (e.g. "@cmd8/excalidraw-mcp@0.1.4").
|
||||
# Kept for backward compatibility with servers installed from the now-removed
|
||||
# MCP Registry tab. Not used by new installations.
|
||||
registry_ref: str = ""
|
||||
oauth: bool = False # True if server uses OAuth authentication
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
d = {
|
||||
"name": self.name,
|
||||
"transport": self.transport,
|
||||
"command": self.command,
|
||||
@@ -41,6 +46,11 @@ class MCPServerConfig:
|
||||
"enabled": self.enabled,
|
||||
"timeout": self.timeout,
|
||||
}
|
||||
if self.registry_ref:
|
||||
d["registry_ref"] = self.registry_ref
|
||||
if self.oauth:
|
||||
d["oauth"] = True
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> MCPServerConfig:
|
||||
@@ -53,6 +63,8 @@ class MCPServerConfig:
|
||||
env=data.get("env", {}),
|
||||
enabled=data.get("enabled", True),
|
||||
timeout=data.get("timeout", 30),
|
||||
registry_ref=data.get("registry_ref", ""),
|
||||
oauth=data.get("oauth", False),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -14,13 +14,43 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from pocketpaw.mcp.config import MCPServerConfig, load_mcp_config, save_mcp_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# OAuth callback coordination: state -> Future[(code, state)]
|
||||
_oauth_pending: dict[str, asyncio.Future] = {}
|
||||
|
||||
# WebSocket broadcast function injected by dashboard at startup
|
||||
_ws_broadcast: Callable | None = None
|
||||
|
||||
|
||||
def set_ws_broadcast(fn: Callable) -> None:
|
||||
"""Set the WebSocket broadcast function (called by dashboard at startup)."""
|
||||
global _ws_broadcast
|
||||
_ws_broadcast = fn
|
||||
|
||||
|
||||
def set_oauth_callback_result(state: str, code: str) -> bool:
|
||||
"""Resolve a pending OAuth Future with the authorization code.
|
||||
|
||||
Called by the dashboard callback endpoint when the OAuth provider
|
||||
redirects back with code + state params.
|
||||
|
||||
Returns True if a pending Future was found and resolved.
|
||||
"""
|
||||
future = _oauth_pending.get(state)
|
||||
if future and not future.done():
|
||||
future.set_result((code, state))
|
||||
return True
|
||||
logger.warning("No pending OAuth flow for state=%s", state[:16])
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class MCPToolInfo:
|
||||
@@ -46,6 +76,46 @@ class _ServerState:
|
||||
connected: bool = False
|
||||
|
||||
|
||||
_UNHELPFUL_ERRORS = {
|
||||
"Attempted to exit a cancel scope that isn't the current tasks's current cancel scope",
|
||||
}
|
||||
|
||||
|
||||
def _extract_root_error(exc: BaseException) -> str:
|
||||
"""Unwrap ExceptionGroup / BaseExceptionGroup to find the real error.
|
||||
|
||||
MCP stdio transport wraps subprocess crashes in anyio cancel-scope errors
|
||||
that are uninformative (e.g. "Attempted to exit a cancel scope…").
|
||||
Walk the exception tree to find a concrete root-cause message.
|
||||
"""
|
||||
# Collect all leaf exceptions from exception groups
|
||||
leaves: list[BaseException] = []
|
||||
|
||||
def _collect(e: BaseException) -> None:
|
||||
if isinstance(e, (ExceptionGroup, BaseExceptionGroup)):
|
||||
for sub in e.exceptions:
|
||||
_collect(sub)
|
||||
elif e.__cause__:
|
||||
_collect(e.__cause__)
|
||||
else:
|
||||
leaves.append(e)
|
||||
|
||||
_collect(exc)
|
||||
|
||||
# Pick the most useful message (skip unhelpful anyio internals)
|
||||
for leaf in leaves:
|
||||
msg = str(leaf).strip()
|
||||
if msg and msg not in _UNHELPFUL_ERRORS:
|
||||
return msg
|
||||
|
||||
# Fallback: if everything is unhelpful, use the top-level message and
|
||||
# hint that the server process crashed.
|
||||
top = str(exc).strip()
|
||||
if top in _UNHELPFUL_ERRORS:
|
||||
return "Server process crashed during startup (check terminal for details)"
|
||||
return top
|
||||
|
||||
|
||||
class MCPManager:
|
||||
"""Manages MCP server connections and tool invocations."""
|
||||
|
||||
@@ -104,6 +174,88 @@ class MCPManager:
|
||||
env.update(config_env)
|
||||
return env
|
||||
|
||||
@staticmethod
|
||||
def _make_oauth_auth(config: MCPServerConfig):
|
||||
"""Create an httpx.Auth for OAuth-based MCP servers.
|
||||
|
||||
Uses the MCP SDK's OAuthClientProvider which handles:
|
||||
- OAuth 2.1 metadata discovery
|
||||
- CIMD (Client ID Metadata Document) for servers that support it
|
||||
- Dynamic client registration as fallback
|
||||
- PKCE authorization code flow
|
||||
- Token refresh
|
||||
"""
|
||||
from mcp.client.auth import OAuthClientProvider
|
||||
from mcp.shared.auth import OAuthClientMetadata
|
||||
|
||||
from pocketpaw.config import Settings
|
||||
from pocketpaw.mcp.oauth_store import MCPTokenStorage
|
||||
|
||||
settings = Settings.load()
|
||||
port = settings.web_port or 8888
|
||||
|
||||
storage = MCPTokenStorage(config.name)
|
||||
|
||||
client_metadata = OAuthClientMetadata(
|
||||
client_name="PocketPaw",
|
||||
redirect_uris=[f"http://localhost:{port}/api/mcp/oauth/callback"],
|
||||
token_endpoint_auth_method="none",
|
||||
grant_types=["authorization_code", "refresh_token"],
|
||||
response_types=["code"],
|
||||
)
|
||||
|
||||
# CIMD URL — servers that support client_id_metadata_document_supported
|
||||
# will use this URL as the client_id instead of dynamic registration.
|
||||
cimd_url = (settings.mcp_client_metadata_url or "").strip() or None
|
||||
|
||||
# Shared mutable state between the two closures
|
||||
_flow_state: dict[str, str] = {}
|
||||
|
||||
async def redirect_handler(auth_url: str) -> None:
|
||||
"""Called by SDK with the authorization URL — broadcast to frontend."""
|
||||
parsed = urlparse(auth_url)
|
||||
params = parse_qs(parsed.query)
|
||||
state_values = params.get("state", [])
|
||||
state = state_values[0] if state_values else ""
|
||||
|
||||
if state:
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.create_future()
|
||||
_oauth_pending[state] = future
|
||||
_flow_state["state"] = state
|
||||
|
||||
if _ws_broadcast:
|
||||
await _ws_broadcast(
|
||||
{
|
||||
"type": "mcp_oauth_redirect",
|
||||
"url": auth_url,
|
||||
"server": config.name,
|
||||
}
|
||||
)
|
||||
logger.info("OAuth redirect for MCP server '%s' — waiting for callback", config.name)
|
||||
|
||||
async def callback_handler() -> tuple[str, str | None]:
|
||||
"""Called by SDK to wait for the OAuth callback result."""
|
||||
state = _flow_state.get("state")
|
||||
if not state or state not in _oauth_pending:
|
||||
raise RuntimeError(f"No pending OAuth flow for server '{config.name}'")
|
||||
|
||||
future = _oauth_pending[state]
|
||||
try:
|
||||
code, returned_state = await asyncio.wait_for(future, timeout=300)
|
||||
return (code, returned_state)
|
||||
finally:
|
||||
_oauth_pending.pop(state, None)
|
||||
|
||||
return OAuthClientProvider(
|
||||
server_url=config.url,
|
||||
client_metadata=client_metadata,
|
||||
storage=storage,
|
||||
redirect_handler=redirect_handler,
|
||||
callback_handler=callback_handler,
|
||||
client_metadata_url=cimd_url,
|
||||
)
|
||||
|
||||
async def start_server(self, config: MCPServerConfig) -> bool:
|
||||
"""Start an MCP server and initialize its session.
|
||||
|
||||
@@ -117,23 +269,62 @@ class MCPManager:
|
||||
state = _ServerState(config=config)
|
||||
self._servers[config.name] = state
|
||||
|
||||
# Build OAuth auth if needed
|
||||
auth = None
|
||||
if config.oauth:
|
||||
try:
|
||||
auth = self._make_oauth_auth(config)
|
||||
except Exception as e:
|
||||
state.error = f"OAuth setup failed: {e}"
|
||||
logger.error("OAuth setup failed for '%s': %s", config.name, e)
|
||||
return False
|
||||
|
||||
try:
|
||||
timeout = config.timeout or 30
|
||||
# OAuth flows need more time for user interaction
|
||||
connect_timeout = 300 if config.oauth else timeout
|
||||
|
||||
if config.transport == "stdio":
|
||||
await asyncio.wait_for(self._connect_stdio(state), timeout=timeout)
|
||||
elif config.transport == "streamable-http":
|
||||
# Streamable HTTP uses a different MCP SDK client than SSE.
|
||||
await self._connect_remote_with_timeout(
|
||||
state, timeout, self._connect_streamable_http
|
||||
state,
|
||||
connect_timeout,
|
||||
lambda s: self._connect_streamable_http(s, auth=auth),
|
||||
)
|
||||
elif config.transport == "sse":
|
||||
await self._connect_remote_with_timeout(
|
||||
state,
|
||||
connect_timeout,
|
||||
lambda s: self._connect_sse(s, auth=auth),
|
||||
)
|
||||
elif config.transport == "http":
|
||||
# HTTP/SSE connections use anyio cancel scopes internally.
|
||||
# asyncio.wait_for cancels the task on timeout, which disrupts
|
||||
# anyio's cancel scope cleanup and causes TaskGroup errors.
|
||||
# Instead, run with a manual timeout that doesn't cancel the task.
|
||||
await self._connect_remote_with_timeout(
|
||||
state, timeout, self._connect_sse
|
||||
)
|
||||
# Auto-detect: try Streamable HTTP first, fall back to SSE.
|
||||
# Modern MCP servers use Streamable HTTP (POST-based);
|
||||
# older ones use SSE (GET-based).
|
||||
try:
|
||||
await self._connect_remote_with_timeout(
|
||||
state,
|
||||
connect_timeout,
|
||||
lambda s: self._connect_streamable_http(
|
||||
s, auth=auth
|
||||
),
|
||||
)
|
||||
except TimeoutError:
|
||||
raise # Don't waste time retrying on timeout
|
||||
except BaseException:
|
||||
await self._cleanup_state(state)
|
||||
state = _ServerState(config=config)
|
||||
self._servers[config.name] = state
|
||||
logger.debug(
|
||||
"Streamable HTTP failed for '%s', trying SSE",
|
||||
config.name,
|
||||
)
|
||||
await self._connect_remote_with_timeout(
|
||||
state,
|
||||
connect_timeout,
|
||||
lambda s: self._connect_sse(s, auth=auth),
|
||||
)
|
||||
else:
|
||||
state.error = f"Unknown transport: {config.transport}"
|
||||
logger.error(state.error)
|
||||
@@ -150,18 +341,31 @@ class MCPManager:
|
||||
return True
|
||||
|
||||
except TimeoutError:
|
||||
state.error = f"Connection timed out after {timeout}s"
|
||||
effective_timeout = 300 if config.oauth else (config.timeout or 30)
|
||||
state.error = f"Connection timed out after {effective_timeout}s"
|
||||
state.connected = False
|
||||
await self._cleanup_state(state)
|
||||
logger.error("MCP server '%s' timed out after %ds", config.name, timeout)
|
||||
logger.error("MCP server '%s' timed out after %ds", config.name, effective_timeout)
|
||||
return False
|
||||
except BaseException as e:
|
||||
# Catch BaseException to handle ExceptionGroup / BaseExceptionGroup
|
||||
# from anyio TaskGroup failures in the MCP library.
|
||||
state.error = str(e)
|
||||
root_msg = _extract_root_error(e)
|
||||
|
||||
# Provide actionable hint for OAuth registration failures
|
||||
if config.oauth and "Registration failed" in root_msg:
|
||||
root_msg = (
|
||||
f"{root_msg}. "
|
||||
"This server doesn't support dynamic client registration. "
|
||||
"You can set mcp_client_metadata_url in Settings to a "
|
||||
"publicly-hosted CIMD JSON file, or configure the server "
|
||||
"with an API token instead of OAuth."
|
||||
)
|
||||
|
||||
state.error = root_msg
|
||||
state.connected = False
|
||||
await self._cleanup_state(state)
|
||||
logger.error("Failed to start MCP server '%s': %s", config.name, e)
|
||||
logger.error("Failed to start MCP server '%s': %s", config.name, root_msg)
|
||||
return False
|
||||
|
||||
async def _connect_stdio(self, state: _ServerState) -> None:
|
||||
@@ -222,12 +426,15 @@ class MCPManager:
|
||||
await self._cleanup_state(state)
|
||||
raise
|
||||
|
||||
async def _connect_sse(self, state: _ServerState) -> None:
|
||||
async def _connect_sse(self, state: _ServerState, auth=None) -> None:
|
||||
"""Connect to an MCP server via SSE (Server-Sent Events)."""
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
ctx = sse_client(url=state.config.url)
|
||||
kwargs: dict[str, Any] = {"url": state.config.url}
|
||||
if auth is not None:
|
||||
kwargs["auth"] = auth
|
||||
ctx = sse_client(**kwargs)
|
||||
streams = await ctx.__aenter__()
|
||||
state.client = ctx
|
||||
state.read_stream = streams[0]
|
||||
@@ -243,12 +450,15 @@ class MCPManager:
|
||||
state.client = None
|
||||
raise
|
||||
|
||||
async def _connect_streamable_http(self, state: _ServerState) -> None:
|
||||
async def _connect_streamable_http(self, state: _ServerState, auth=None) -> None:
|
||||
"""Connect to an MCP server via Streamable HTTP transport."""
|
||||
from mcp import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
|
||||
ctx = streamablehttp_client(url=state.config.url)
|
||||
kwargs: dict[str, Any] = {"url": state.config.url}
|
||||
if auth is not None:
|
||||
kwargs["auth"] = auth
|
||||
ctx = streamablehttp_client(**kwargs)
|
||||
streams = await ctx.__aenter__()
|
||||
state.client = ctx
|
||||
# streamablehttp_client yields (read, write, get_session_id)
|
||||
@@ -356,22 +566,28 @@ class MCPManager:
|
||||
result = {}
|
||||
# First, include all servers from the config file
|
||||
for cfg in load_mcp_config():
|
||||
result[cfg.name] = {
|
||||
info: dict = {
|
||||
"connected": False,
|
||||
"tool_count": 0,
|
||||
"error": "",
|
||||
"transport": cfg.transport,
|
||||
"enabled": cfg.enabled,
|
||||
}
|
||||
if cfg.registry_ref:
|
||||
info["registry_ref"] = cfg.registry_ref
|
||||
result[cfg.name] = info
|
||||
# Overlay runtime state for servers that have been started
|
||||
for name, state in self._servers.items():
|
||||
result[name] = {
|
||||
info = {
|
||||
"connected": state.connected,
|
||||
"tool_count": len(state.tools),
|
||||
"error": state.error,
|
||||
"transport": state.config.transport,
|
||||
"enabled": state.config.enabled,
|
||||
}
|
||||
if state.config.registry_ref:
|
||||
info["registry_ref"] = state.config.registry_ref
|
||||
result[name] = info
|
||||
return result
|
||||
|
||||
async def start_enabled_servers(self) -> None:
|
||||
|
||||
96
src/pocketpaw/mcp/oauth_store.py
Normal file
96
src/pocketpaw/mcp/oauth_store.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""MCP OAuth Token Storage — file-based persistence for MCP OAuth tokens.
|
||||
|
||||
Implements the MCP SDK's ``TokenStorage`` protocol for persisting OAuth tokens
|
||||
and client registration info to ``~/.pocketpaw/mcp_oauth/{server_name}.json``.
|
||||
|
||||
Created: 2026-02-17
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import stat
|
||||
from pathlib import Path
|
||||
|
||||
from pocketpaw.config import get_config_dir
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_oauth_dir() -> Path:
|
||||
"""Get/create the MCP OAuth token directory."""
|
||||
d = get_config_dir() / "mcp_oauth"
|
||||
d.mkdir(exist_ok=True)
|
||||
return d
|
||||
|
||||
|
||||
class MCPTokenStorage:
|
||||
"""File-based token storage for MCP OAuth at ~/.pocketpaw/mcp_oauth/{name}.json.
|
||||
|
||||
Stores both OAuth tokens and dynamic client registration info.
|
||||
Files are chmod 0600 (owner-only read/write).
|
||||
"""
|
||||
|
||||
def __init__(self, server_name: str) -> None:
|
||||
self._server_name = server_name
|
||||
self._path = _get_oauth_dir() / f"{server_name}.json"
|
||||
|
||||
def _load(self) -> dict:
|
||||
if not self._path.exists():
|
||||
return {}
|
||||
try:
|
||||
return json.loads(self._path.read_text())
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
logger.warning("Failed to load MCP OAuth data for %s: %s", self._server_name, e)
|
||||
return {}
|
||||
|
||||
def _save(self, data: dict) -> None:
|
||||
self._path.write_text(json.dumps(data, indent=2))
|
||||
try:
|
||||
os.chmod(self._path, stat.S_IRUSR | stat.S_IWUSR)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
async def get_tokens(self):
|
||||
"""Get stored OAuth tokens."""
|
||||
from mcp.shared.auth import OAuthToken
|
||||
|
||||
data = self._load()
|
||||
tokens_data = data.get("tokens")
|
||||
if not tokens_data:
|
||||
return None
|
||||
try:
|
||||
return OAuthToken.model_validate(tokens_data)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to parse MCP OAuth tokens for %s: %s", self._server_name, e)
|
||||
return None
|
||||
|
||||
async def set_tokens(self, tokens) -> None:
|
||||
"""Store OAuth tokens."""
|
||||
data = self._load()
|
||||
data["tokens"] = tokens.model_dump()
|
||||
self._save(data)
|
||||
logger.debug("Saved MCP OAuth tokens for %s", self._server_name)
|
||||
|
||||
async def get_client_info(self):
|
||||
"""Get stored client registration info."""
|
||||
from mcp.shared.auth import OAuthClientInformationFull
|
||||
|
||||
data = self._load()
|
||||
client_data = data.get("client_info")
|
||||
if not client_data:
|
||||
return None
|
||||
try:
|
||||
return OAuthClientInformationFull.model_validate(client_data)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to parse MCP OAuth client info for %s: %s", self._server_name, e)
|
||||
return None
|
||||
|
||||
async def set_client_info(self, client_info) -> None:
|
||||
"""Store client registration info."""
|
||||
data = self._load()
|
||||
data["client_info"] = client_info.model_dump(mode="json")
|
||||
self._save(data)
|
||||
logger.debug("Saved MCP OAuth client info for %s", self._server_name)
|
||||
@@ -402,6 +402,68 @@ class TestMCPManagerConfigMethods:
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestHTTPAutoDetect:
|
||||
"""Tests for transport='http' auto-detect (Streamable HTTP → SSE fallback)."""
|
||||
|
||||
async def test_http_transport_tries_streamable_first(self):
|
||||
"""transport='http' should try Streamable HTTP and succeed."""
|
||||
mgr = MCPManager()
|
||||
cfg = MCPServerConfig(name="modern", transport="http", url="https://example.com/mcp")
|
||||
|
||||
with (
|
||||
patch.object(mgr, "_connect_streamable_http", new_callable=AsyncMock),
|
||||
patch.object(mgr, "_connect_sse", new_callable=AsyncMock) as mock_sse,
|
||||
patch.object(
|
||||
mgr,
|
||||
"_discover_tools",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
):
|
||||
result = await mgr.start_server(cfg)
|
||||
assert result is True
|
||||
# SSE should NOT have been called
|
||||
mock_sse.assert_not_called()
|
||||
|
||||
async def test_http_transport_falls_back_to_sse(self):
|
||||
"""transport='http' should fall back to SSE when Streamable HTTP fails."""
|
||||
mgr = MCPManager()
|
||||
cfg = MCPServerConfig(name="legacy", transport="http", url="https://example.com/mcp")
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
mgr,
|
||||
"_connect_streamable_http",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=RuntimeError("405 Method Not Allowed"),
|
||||
),
|
||||
patch.object(mgr, "_connect_sse", new_callable=AsyncMock),
|
||||
patch.object(mgr, "_discover_tools", new_callable=AsyncMock),
|
||||
patch.object(mgr, "_cleanup_state", new_callable=AsyncMock),
|
||||
):
|
||||
result = await mgr.start_server(cfg)
|
||||
assert result is True
|
||||
|
||||
async def test_http_transport_no_fallback_on_timeout(self):
|
||||
"""transport='http' should NOT fall back to SSE on timeout."""
|
||||
mgr = MCPManager()
|
||||
cfg = MCPServerConfig(
|
||||
name="slow", transport="http", url="https://example.com/mcp", timeout=1
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
mgr,
|
||||
"_connect_remote_with_timeout",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=TimeoutError("Connection timed out"),
|
||||
),
|
||||
):
|
||||
result = await mgr.start_server(cfg)
|
||||
assert result is False
|
||||
status = mgr._servers["slow"]
|
||||
assert "timed out" in status.error
|
||||
|
||||
|
||||
class TestGetMCPManager:
|
||||
def test_returns_same_instance(self):
|
||||
import pocketpaw.mcp.manager as mod
|
||||
|
||||
308
tests/test_mcp_oauth.py
Normal file
308
tests/test_mcp_oauth.py
Normal file
@@ -0,0 +1,308 @@
|
||||
"""Tests for MCP OAuth flow — token storage, callback coordination,
|
||||
preset flags, and dashboard endpoint.
|
||||
|
||||
Created: 2026-02-17
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from pocketpaw.mcp.config import MCPServerConfig
|
||||
from pocketpaw.mcp.presets import get_all_presets, get_preset, preset_to_config
|
||||
|
||||
# ======================================================================
|
||||
# MCPTokenStorage tests
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestMCPTokenStorage:
|
||||
@pytest.fixture
|
||||
def storage(self, tmp_path):
|
||||
with patch("pocketpaw.mcp.oauth_store.get_config_dir", return_value=tmp_path):
|
||||
from pocketpaw.mcp.oauth_store import MCPTokenStorage
|
||||
|
||||
return MCPTokenStorage("test-server")
|
||||
|
||||
async def test_get_tokens_empty(self, storage):
|
||||
result = await storage.get_tokens()
|
||||
assert result is None
|
||||
|
||||
async def test_set_and_get_tokens(self, storage):
|
||||
from mcp.shared.auth import OAuthToken
|
||||
|
||||
token = OAuthToken(access_token="access_123", token_type="Bearer", refresh_token="ref_456")
|
||||
await storage.set_tokens(token)
|
||||
|
||||
loaded = await storage.get_tokens()
|
||||
assert loaded is not None
|
||||
assert loaded.access_token == "access_123"
|
||||
assert loaded.refresh_token == "ref_456"
|
||||
assert loaded.token_type == "Bearer"
|
||||
|
||||
async def test_get_client_info_empty(self, storage):
|
||||
result = await storage.get_client_info()
|
||||
assert result is None
|
||||
|
||||
async def test_set_and_get_client_info(self, storage):
|
||||
from mcp.shared.auth import OAuthClientInformationFull
|
||||
|
||||
info = OAuthClientInformationFull(
|
||||
client_id="cid_123",
|
||||
client_secret="secret_456",
|
||||
redirect_uris=["http://localhost:8888/callback"],
|
||||
)
|
||||
await storage.set_client_info(info)
|
||||
|
||||
loaded = await storage.get_client_info()
|
||||
assert loaded is not None
|
||||
assert loaded.client_id == "cid_123"
|
||||
assert loaded.client_secret == "secret_456"
|
||||
|
||||
async def test_tokens_and_client_info_coexist(self, storage):
|
||||
"""Both tokens and client info can be stored in the same file."""
|
||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
|
||||
|
||||
token = OAuthToken(access_token="tok")
|
||||
info = OAuthClientInformationFull(client_id="cid", redirect_uris=["http://localhost/cb"])
|
||||
|
||||
await storage.set_tokens(token)
|
||||
await storage.set_client_info(info)
|
||||
|
||||
loaded_tok = await storage.get_tokens()
|
||||
loaded_info = await storage.get_client_info()
|
||||
assert loaded_tok.access_token == "tok"
|
||||
assert loaded_info.client_id == "cid"
|
||||
|
||||
async def test_file_permissions(self, storage, tmp_path):
|
||||
"""Token file should be chmod 0600."""
|
||||
import os
|
||||
import stat
|
||||
|
||||
from mcp.shared.auth import OAuthToken
|
||||
|
||||
await storage.set_tokens(OAuthToken(access_token="x"))
|
||||
|
||||
path = tmp_path / "mcp_oauth" / "test-server.json"
|
||||
mode = os.stat(path).st_mode
|
||||
assert mode & stat.S_IRWXG == 0 # No group perms
|
||||
assert mode & stat.S_IRWXO == 0 # No other perms
|
||||
assert mode & stat.S_IRUSR # Owner read
|
||||
|
||||
async def test_corrupted_file_returns_none(self, storage, tmp_path):
|
||||
"""Corrupted JSON should return None, not raise."""
|
||||
path = tmp_path / "mcp_oauth" / "test-server.json"
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text("not json{{{")
|
||||
|
||||
result = await storage.get_tokens()
|
||||
assert result is None
|
||||
|
||||
result = await storage.get_client_info()
|
||||
assert result is None
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# OAuth callback coordination tests
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestOAuthCallbackCoordination:
|
||||
def setup_method(self):
|
||||
from pocketpaw.mcp import manager
|
||||
|
||||
manager._oauth_pending.clear()
|
||||
|
||||
def teardown_method(self):
|
||||
from pocketpaw.mcp import manager
|
||||
|
||||
manager._oauth_pending.clear()
|
||||
|
||||
def test_set_oauth_callback_result_resolves_future(self):
|
||||
from pocketpaw.mcp.manager import set_oauth_callback_result
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
future = loop.create_future()
|
||||
|
||||
from pocketpaw.mcp import manager
|
||||
|
||||
manager._oauth_pending["test_state_123"] = future
|
||||
|
||||
result = set_oauth_callback_result("test_state_123", "auth_code_456")
|
||||
assert result is True
|
||||
assert future.done()
|
||||
assert future.result() == ("auth_code_456", "test_state_123")
|
||||
loop.close()
|
||||
|
||||
def test_set_oauth_callback_result_unknown_state(self):
|
||||
from pocketpaw.mcp.manager import set_oauth_callback_result
|
||||
|
||||
result = set_oauth_callback_result("nonexistent_state", "code")
|
||||
assert result is False
|
||||
|
||||
def test_set_oauth_callback_result_already_resolved(self):
|
||||
from pocketpaw.mcp.manager import set_oauth_callback_result
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
future = loop.create_future()
|
||||
future.set_result(("old_code", "old_state"))
|
||||
|
||||
from pocketpaw.mcp import manager
|
||||
|
||||
manager._oauth_pending["done_state"] = future
|
||||
|
||||
result = set_oauth_callback_result("done_state", "new_code")
|
||||
assert result is False
|
||||
loop.close()
|
||||
|
||||
def test_set_ws_broadcast(self):
|
||||
from pocketpaw.mcp import manager
|
||||
from pocketpaw.mcp.manager import set_ws_broadcast
|
||||
|
||||
old = manager._ws_broadcast
|
||||
try:
|
||||
fn = MagicMock()
|
||||
set_ws_broadcast(fn)
|
||||
assert manager._ws_broadcast is fn
|
||||
finally:
|
||||
manager._ws_broadcast = old
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# Preset OAuth flag tests
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestPresetOAuthFlags:
|
||||
_OAUTH_PRESET_IDS = {
|
||||
"github",
|
||||
"notion",
|
||||
"atlassian",
|
||||
"stripe",
|
||||
"cloudflare",
|
||||
"supabase",
|
||||
"vercel",
|
||||
"gitlab",
|
||||
"figma",
|
||||
}
|
||||
|
||||
def test_http_presets_have_oauth_true(self):
|
||||
for preset_id in self._OAUTH_PRESET_IDS:
|
||||
preset = get_preset(preset_id)
|
||||
assert preset is not None, f"Preset {preset_id} not found"
|
||||
assert preset.oauth is True, f"Preset {preset_id} should have oauth=True"
|
||||
assert preset.transport == "http"
|
||||
|
||||
def test_stdio_presets_have_oauth_false(self):
|
||||
for preset in get_all_presets():
|
||||
if preset.transport == "stdio":
|
||||
assert preset.oauth is False, f"Stdio preset {preset.id} should have oauth=False"
|
||||
|
||||
def test_preset_to_config_passes_oauth(self):
|
||||
preset = get_preset("github")
|
||||
config = preset_to_config(preset)
|
||||
assert config.oauth is True
|
||||
|
||||
def test_preset_to_config_stdio_no_oauth(self):
|
||||
preset = get_preset("sentry")
|
||||
config = preset_to_config(preset, env={"SENTRY_ACCESS_TOKEN": "x"})
|
||||
assert config.oauth is False
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# MCPServerConfig oauth field tests
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestConfigOAuthField:
|
||||
def test_to_dict_includes_oauth_when_true(self):
|
||||
config = MCPServerConfig(name="test", oauth=True)
|
||||
d = config.to_dict()
|
||||
assert d["oauth"] is True
|
||||
|
||||
def test_to_dict_excludes_oauth_when_false(self):
|
||||
config = MCPServerConfig(name="test", oauth=False)
|
||||
d = config.to_dict()
|
||||
assert "oauth" not in d
|
||||
|
||||
def test_from_dict_reads_oauth(self):
|
||||
config = MCPServerConfig.from_dict({"name": "test", "oauth": True})
|
||||
assert config.oauth is True
|
||||
|
||||
def test_from_dict_defaults_oauth_false(self):
|
||||
config = MCPServerConfig.from_dict({"name": "test"})
|
||||
assert config.oauth is False
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# Dashboard OAuth callback endpoint tests
|
||||
# ======================================================================
|
||||
|
||||
_TEST_TOKEN = "test-oauth-token-12345"
|
||||
|
||||
|
||||
def _auth(**extra):
|
||||
h = {"Authorization": f"Bearer {_TEST_TOKEN}"}
|
||||
h.update(extra)
|
||||
return h
|
||||
|
||||
|
||||
class TestDashboardOAuthCallback:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_token(self):
|
||||
with patch("pocketpaw.dashboard.get_access_token", return_value=_TEST_TOKEN):
|
||||
yield
|
||||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from pocketpaw.dashboard import app
|
||||
|
||||
return TestClient(app, raise_server_exceptions=False)
|
||||
|
||||
@patch("pocketpaw.mcp.manager.set_oauth_callback_result")
|
||||
def test_callback_success(self, mock_set_result, client):
|
||||
mock_set_result.return_value = True
|
||||
res = client.get("/api/mcp/oauth/callback?code=abc123&state=xyz789")
|
||||
assert res.status_code == 200
|
||||
assert "Authenticated" in res.text
|
||||
assert "window.close" in res.text
|
||||
mock_set_result.assert_called_once_with("xyz789", "abc123")
|
||||
|
||||
@patch("pocketpaw.mcp.manager.set_oauth_callback_result")
|
||||
def test_callback_expired_flow(self, mock_set_result, client):
|
||||
mock_set_result.return_value = False
|
||||
res = client.get("/api/mcp/oauth/callback?code=abc&state=xyz")
|
||||
assert res.status_code == 400
|
||||
assert "expired" in res.text.lower() or "not found" in res.text.lower()
|
||||
|
||||
def test_callback_missing_params(self, client):
|
||||
# Empty defaults → 400
|
||||
res = client.get("/api/mcp/oauth/callback")
|
||||
assert res.status_code == 400
|
||||
# Explicit empty strings → also 400
|
||||
res = client.get("/api/mcp/oauth/callback?code=&state=")
|
||||
assert res.status_code == 400
|
||||
|
||||
def test_callback_is_auth_exempt(self, client):
|
||||
"""OAuth callback should work without auth token."""
|
||||
with patch("pocketpaw.mcp.manager.set_oauth_callback_result", return_value=True):
|
||||
res = client.get(
|
||||
"/api/mcp/oauth/callback?code=test&state=test",
|
||||
# No auth header
|
||||
)
|
||||
# Should not get 401 — the endpoint is auth-exempt
|
||||
assert res.status_code != 401
|
||||
|
||||
@patch("pocketpaw.mcp.config.load_mcp_config")
|
||||
def test_presets_include_oauth_field(self, mock_load, client):
|
||||
mock_load.return_value = []
|
||||
res = client.get("/api/mcp/presets", headers=_auth())
|
||||
assert res.status_code == 200
|
||||
data = res.json()
|
||||
github = next(p for p in data if p["id"] == "github")
|
||||
assert github["oauth"] is True
|
||||
sentry = next(p for p in data if p["id"] == "sentry")
|
||||
assert sentry["oauth"] is False
|
||||
@@ -258,6 +258,7 @@ class TestSkillsRESTEndpoints:
|
||||
# Make wait_for actually call the coroutine
|
||||
async def passthrough(coro, timeout):
|
||||
return await coro
|
||||
|
||||
mock_asyncio.wait_for = passthrough
|
||||
|
||||
# Prepare a fake cloned repo with a skill inside skills/ subdir
|
||||
@@ -291,9 +292,7 @@ class TestSkillsRESTEndpoints:
|
||||
request.json = AsyncMock(return_value={"source": "owner/bad-repo/skill"})
|
||||
|
||||
mock_proc = AsyncMock()
|
||||
mock_proc.communicate = AsyncMock(
|
||||
return_value=(b"", b"fatal: repository not found\n")
|
||||
)
|
||||
mock_proc.communicate = AsyncMock(return_value=(b"", b"fatal: repository not found\n"))
|
||||
mock_proc.returncode = 128
|
||||
|
||||
with (
|
||||
@@ -305,6 +304,7 @@ class TestSkillsRESTEndpoints:
|
||||
|
||||
async def passthrough(coro, timeout):
|
||||
return await coro
|
||||
|
||||
mock_asyncio.wait_for = passthrough
|
||||
|
||||
from pocketpaw.dashboard import install_skill
|
||||
@@ -415,364 +415,3 @@ class TestMCPPresetNeedsArgs:
|
||||
for p in get_all_presets():
|
||||
# Every preset should have a bool needs_args
|
||||
assert isinstance(p.needs_args, bool), f"Preset {p.id} needs_args is not bool"
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# MCP Registry endpoint tests
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestMCPRegistryEndpoints:
|
||||
async def test_search_registry_empty_query(self):
|
||||
"""GET /api/mcp/registry/search with no query returns featured servers."""
|
||||
mock_response = MagicMock()
|
||||
# Registry API wraps each entry as {server: {...}, _meta: {...}}
|
||||
mock_response.json.return_value = {
|
||||
"servers": [
|
||||
{
|
||||
"server": {"name": "org/server", "description": "A server"},
|
||||
"_meta": {"score": 0.9},
|
||||
}
|
||||
],
|
||||
"metadata": {"count": 1},
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
from pocketpaw.dashboard import search_mcp_registry
|
||||
|
||||
result = await search_mcp_registry(q="", limit=30, cursor="")
|
||||
assert "servers" in result
|
||||
# Backend should unwrap nested structure
|
||||
assert result["servers"][0]["name"] == "org/server"
|
||||
assert result["servers"][0]["_meta"]["score"] == 0.9
|
||||
# Should call registry with just limit (no search param)
|
||||
mock_client.get.assert_called_once()
|
||||
call_kwargs = mock_client.get.call_args
|
||||
params = call_kwargs.kwargs.get("params") or call_kwargs[1].get("params")
|
||||
assert "search" not in params
|
||||
|
||||
async def test_search_registry_with_query(self):
|
||||
"""GET /api/mcp/registry/search proxies search param to registry."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"servers": [{"server": {"name": "org/sql-server", "description": "SQL"}, "_meta": {}}],
|
||||
"metadata": {"count": 1, "nextCursor": "abc123"},
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
from pocketpaw.dashboard import search_mcp_registry
|
||||
|
||||
result = await search_mcp_registry(q="sql", limit=10, cursor="")
|
||||
# Unwrapped: flat server object
|
||||
assert result["servers"][0]["name"] == "org/sql-server"
|
||||
call_kwargs = mock_client.get.call_args
|
||||
params = call_kwargs.kwargs.get("params") or call_kwargs[1].get("params")
|
||||
assert params["search"] == "sql"
|
||||
assert params["limit"] == 10
|
||||
|
||||
async def test_search_registry_with_cursor(self):
|
||||
"""GET /api/mcp/registry/search passes cursor for pagination."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"servers": [],
|
||||
"metadata": {"count": 0},
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
from pocketpaw.dashboard import search_mcp_registry
|
||||
|
||||
await search_mcp_registry(q="test", limit=30, cursor="page2")
|
||||
call_kwargs = mock_client.get.call_args
|
||||
params = call_kwargs.kwargs.get("params") or call_kwargs[1].get("params")
|
||||
assert params["cursor"] == "page2"
|
||||
|
||||
async def test_search_registry_limit_capped(self):
|
||||
"""Limit is capped at 100."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"servers": [], "metadata": {"count": 0}}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
from pocketpaw.dashboard import search_mcp_registry
|
||||
|
||||
await search_mcp_registry(q="x", limit=999, cursor="")
|
||||
call_kwargs = mock_client.get.call_args
|
||||
params = call_kwargs.kwargs.get("params") or call_kwargs[1].get("params")
|
||||
assert params["limit"] == 100
|
||||
|
||||
async def test_search_registry_accepts_flat_entries(self):
|
||||
"""Flat server rows (no nested 'server' key) are normalized correctly."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"servers": [
|
||||
{
|
||||
"name": "org/flat-server",
|
||||
"description": "Flat format row",
|
||||
"packages": [],
|
||||
"remotes": [],
|
||||
}
|
||||
],
|
||||
"metadata": {"count": 1, "next_cursor": "cursor123"},
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
from pocketclaw.dashboard import search_mcp_registry
|
||||
|
||||
result = await search_mcp_registry(q="flat", limit=10, cursor="")
|
||||
assert result["servers"][0]["name"] == "org/flat-server"
|
||||
# next_cursor is normalized for frontend compatibility
|
||||
assert result["metadata"]["nextCursor"] == "cursor123"
|
||||
|
||||
async def test_search_registry_skips_malformed_rows(self):
|
||||
"""Malformed rows are skipped instead of crashing the whole result set."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"servers": [
|
||||
None,
|
||||
{"server": None},
|
||||
{
|
||||
"server": {
|
||||
"name": "org/good-server",
|
||||
"description": "Good row",
|
||||
"packages": [{"environmentVariables": [{"name": "API_KEY"}]}],
|
||||
"remotes": [],
|
||||
}
|
||||
},
|
||||
],
|
||||
"metadata": "invalid-metadata-type",
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
from pocketclaw.dashboard import search_mcp_registry
|
||||
|
||||
result = await search_mcp_registry(q="good", limit=10, cursor="")
|
||||
assert len(result["servers"]) == 1
|
||||
assert result["servers"][0]["name"] == "org/good-server"
|
||||
# env vars should be lifted from package metadata
|
||||
assert result["servers"][0]["environmentVariables"][0]["name"] == "API_KEY"
|
||||
# invalid metadata falls back to normalized dict with count
|
||||
assert result["metadata"]["count"] == 1
|
||||
|
||||
async def test_install_from_registry_missing_name(self):
|
||||
"""POST /api/mcp/registry/install with empty server name returns 400."""
|
||||
from pocketpaw.dashboard import install_from_registry
|
||||
|
||||
request = MagicMock()
|
||||
request.json = AsyncMock(return_value={"server": {}, "env": {}})
|
||||
|
||||
result = await install_from_registry(request)
|
||||
assert result.status_code == 400
|
||||
|
||||
async def test_install_from_registry_http_remote(self):
|
||||
"""Install from registry with HTTP remote transport (legacy transportType key)."""
|
||||
mock_mgr = MagicMock()
|
||||
mock_mgr.add_server_config = MagicMock()
|
||||
mock_mgr.start_server = AsyncMock(return_value=True)
|
||||
mock_mgr.discover_tools = MagicMock(return_value=[])
|
||||
|
||||
request = MagicMock()
|
||||
request.json = AsyncMock(
|
||||
return_value={
|
||||
"server": {
|
||||
"name": "org/my-server",
|
||||
"remotes": [{"url": "https://api.example.com/mcp", "transportType": "http"}],
|
||||
},
|
||||
"env": {},
|
||||
}
|
||||
)
|
||||
|
||||
with patch("pocketpaw.mcp.manager.get_mcp_manager", return_value=mock_mgr):
|
||||
from pocketpaw.dashboard import install_from_registry
|
||||
|
||||
result = await install_from_registry(request)
|
||||
assert result["status"] == "ok"
|
||||
assert result["name"] == "my-server"
|
||||
assert result["connected"] is True
|
||||
# Verify the config was created with HTTP transport
|
||||
config = mock_mgr.add_server_config.call_args[0][0]
|
||||
assert config.transport == "http"
|
||||
assert config.url == "https://api.example.com/mcp"
|
||||
|
||||
async def test_install_from_registry_streamable_http_remote(self):
|
||||
"""Install from registry with streamable-http transport (actual registry API format)."""
|
||||
mock_mgr = MagicMock()
|
||||
mock_mgr.add_server_config = MagicMock()
|
||||
mock_mgr.start_server = AsyncMock(return_value=True)
|
||||
mock_mgr.discover_tools = MagicMock(return_value=[])
|
||||
|
||||
request = MagicMock()
|
||||
request.json = AsyncMock(
|
||||
return_value={
|
||||
"server": {
|
||||
"name": "org/stream-server",
|
||||
"remotes": [{"url": "https://api.example.com/mcp", "type": "streamable-http"}],
|
||||
},
|
||||
"env": {},
|
||||
}
|
||||
)
|
||||
|
||||
with patch("pocketpaw.mcp.manager.get_mcp_manager", return_value=mock_mgr):
|
||||
from pocketpaw.dashboard import install_from_registry
|
||||
|
||||
result = await install_from_registry(request)
|
||||
assert result["status"] == "ok"
|
||||
assert result["name"] == "stream-server"
|
||||
# streamable-http is preserved (needs different MCP SDK client)
|
||||
config = mock_mgr.add_server_config.call_args[0][0]
|
||||
assert config.transport == "streamable-http"
|
||||
assert config.url == "https://api.example.com/mcp"
|
||||
|
||||
async def test_install_from_registry_npm_package(self):
|
||||
"""Install from registry with npm package (stdio)."""
|
||||
mock_mgr = MagicMock()
|
||||
mock_mgr.add_server_config = MagicMock()
|
||||
mock_mgr.start_server = AsyncMock(return_value=False)
|
||||
mock_mgr.discover_tools = MagicMock(return_value=[])
|
||||
|
||||
request = MagicMock()
|
||||
request.json = AsyncMock(
|
||||
return_value={
|
||||
"server": {
|
||||
"name": "org/cool-mcp",
|
||||
"packages": [
|
||||
{
|
||||
"registryType": "npm",
|
||||
"name": "@cool/mcp-server",
|
||||
"runtime": "node",
|
||||
"packageArguments": [],
|
||||
}
|
||||
],
|
||||
},
|
||||
"env": {"API_KEY": "test123"},
|
||||
}
|
||||
)
|
||||
|
||||
with patch("pocketpaw.mcp.manager.get_mcp_manager", return_value=mock_mgr):
|
||||
from pocketpaw.dashboard import install_from_registry
|
||||
|
||||
result = await install_from_registry(request)
|
||||
assert result["status"] == "ok"
|
||||
assert result["name"] == "cool-mcp"
|
||||
config = mock_mgr.add_server_config.call_args[0][0]
|
||||
assert config.transport == "stdio"
|
||||
assert config.command == "npx"
|
||||
assert "-y" in config.args
|
||||
assert "@cool/mcp-server" in config.args
|
||||
assert config.env == {"API_KEY": "test123"}
|
||||
|
||||
async def test_install_from_registry_pypi_package(self):
|
||||
"""Install from registry with pypi package (uvx)."""
|
||||
mock_mgr = MagicMock()
|
||||
mock_mgr.add_server_config = MagicMock()
|
||||
mock_mgr.start_server = AsyncMock(return_value=True)
|
||||
mock_mgr.discover_tools = MagicMock(return_value=[])
|
||||
|
||||
request = MagicMock()
|
||||
request.json = AsyncMock(
|
||||
return_value={
|
||||
"server": {
|
||||
"name": "org/py-server",
|
||||
"packages": [
|
||||
{"registryType": "pypi", "name": "mcp-py-server", "runtime": "python"}
|
||||
],
|
||||
},
|
||||
"env": {},
|
||||
}
|
||||
)
|
||||
|
||||
with patch("pocketpaw.mcp.manager.get_mcp_manager", return_value=mock_mgr):
|
||||
from pocketpaw.dashboard import install_from_registry
|
||||
|
||||
result = await install_from_registry(request)
|
||||
assert result["status"] == "ok"
|
||||
config = mock_mgr.add_server_config.call_args[0][0]
|
||||
assert config.command == "uvx"
|
||||
assert "mcp-py-server" in config.args
|
||||
|
||||
async def test_install_from_registry_docker_package(self):
|
||||
"""Install from registry with docker package."""
|
||||
mock_mgr = MagicMock()
|
||||
mock_mgr.add_server_config = MagicMock()
|
||||
mock_mgr.start_server = AsyncMock(return_value=True)
|
||||
mock_mgr.discover_tools = MagicMock(return_value=[])
|
||||
|
||||
request = MagicMock()
|
||||
request.json = AsyncMock(
|
||||
return_value={
|
||||
"server": {
|
||||
"name": "org/docker-srv",
|
||||
"packages": [
|
||||
{
|
||||
"registryType": "docker",
|
||||
"name": "ghcr.io/org/mcp-docker",
|
||||
"runtimeArguments": [
|
||||
{"isFixed": True, "value": "-p"},
|
||||
{"isFixed": True, "value": "3000:3000"},
|
||||
],
|
||||
}
|
||||
],
|
||||
},
|
||||
"env": {},
|
||||
}
|
||||
)
|
||||
|
||||
with patch("pocketpaw.mcp.manager.get_mcp_manager", return_value=mock_mgr):
|
||||
from pocketpaw.dashboard import install_from_registry
|
||||
|
||||
result = await install_from_registry(request)
|
||||
assert result["status"] == "ok"
|
||||
config = mock_mgr.add_server_config.call_args[0][0]
|
||||
assert config.command == "docker"
|
||||
assert "run" in config.args
|
||||
assert "-p" in config.args
|
||||
assert "ghcr.io/org/mcp-docker" in config.args
|
||||
|
||||
async def test_install_from_registry_no_install_method(self):
|
||||
"""POST /api/mcp/registry/install with no packages or remotes returns 400."""
|
||||
from pocketpaw.dashboard import install_from_registry
|
||||
|
||||
request = MagicMock()
|
||||
request.json = AsyncMock(
|
||||
return_value={
|
||||
"server": {"name": "org/empty-server"},
|
||||
"env": {},
|
||||
}
|
||||
)
|
||||
|
||||
result = await install_from_registry(request)
|
||||
assert result.status_code == 400
|
||||
|
||||
Reference in New Issue
Block a user