feat(mcp): add OAuth support, remove registry tab, improve error handling

- Add MCP OAuth flow: oauth_store.py for token persistence, callback
  coordination via asyncio Futures, dashboard callback endpoint, and
  WebSocket redirect broadcast to frontend
- Add `oauth` flag to MCPServerConfig and MCPPreset for OAuth-aware
  install flow (browser popup for auth)
- Remove MCP Registry tab and related dashboard/JS code (replaced by
  curated preset catalog)
- Improve MCP manager error handling: unwrap ExceptionGroup to surface
  root-cause errors instead of unhelpful anyio cancel-scope messages
- Add mcp_client_metadata_url config field for OIDC client metadata
- Frontend: streamline mcp.js (OAuth popup handling, cancel flow)
- Tests: 24 new OAuth tests, updated manager tests

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Rohit Kushwaha
2026-02-17 01:45:18 +05:30
parent cf98082e40
commit 38c0aac72c
12 changed files with 901 additions and 3066 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,13 @@
{
"client_name": "PocketPaw",
"redirect_uris": [
"http://localhost:8888/api/mcp/oauth/callback",
"http://127.0.0.1:8888/api/mcp/oauth/callback"
],
"token_endpoint_auth_method": "none",
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
"client_uri": "https://github.com/pocketpaw/pocketpaw",
"software_id": "pocketpaw",
"software_version": "0.4.1"
}

View File

@@ -411,6 +411,12 @@ class Settings(BaseSettings):
web_host: str = Field(default="127.0.0.1", description="Web server host")
web_port: int = Field(default=8888, description="Web server port")
# MCP OAuth
mcp_client_metadata_url: str = Field(
default="",
description="CIMD URL for MCP OAuth (optional, for servers without dynamic registration)",
)
# Identity / Multi-user
owner_id: str = Field(
default="",
@@ -553,6 +559,8 @@ class Settings(BaseSettings):
"google_oauth_client_secret": (
self.google_oauth_client_secret or existing.get("google_oauth_client_secret")
),
# MCP OAuth
"mcp_client_metadata_url": self.mcp_client_metadata_url,
# Voice/TTS
"tts_provider": self.tts_provider,
"elevenlabs_api_key": (self.elevenlabs_api_key or existing.get("elevenlabs_api_key")),

View File

@@ -399,9 +399,19 @@ async def startup_event():
except Exception as e:
logger.warning("Failed to recover interrupted projects: %s", e)
# Auto-start enabled MCP servers
# Wire MCP OAuth broadcast + auto-start enabled MCP servers
try:
from pocketpaw.mcp.manager import get_mcp_manager
from pocketpaw.mcp.manager import get_mcp_manager, set_ws_broadcast
async def _mcp_ws_broadcast(message: dict) -> None:
"""Broadcast an MCP message to all connected WebSocket clients."""
for ws in active_connections[:]:
try:
await ws.send_json(message)
except Exception:
pass
set_ws_broadcast(_mcp_ws_broadcast)
mcp = get_mcp_manager()
await mcp.start_enabled_servers()
@@ -616,6 +626,7 @@ async def list_mcp_presets():
"url": p.url,
"docs_url": p.docs_url,
"needs_args": p.needs_args,
"oauth": p.oauth,
"installed": p.id in installed_names,
"env_keys": [
{
@@ -670,244 +681,35 @@ async def install_mcp_preset(request: Request):
}
# ==================== MCP Registry API ====================
@app.get("/api/mcp/oauth/callback")
async def mcp_oauth_callback(code: str = "", state: str = ""):
"""OAuth callback endpoint — receives authorization code from OAuth provider.
_MCP_REGISTRY_BASE = "https://registry.modelcontextprotocol.io"
# Server name parts that are too generic to use alone as a config name.
_GENERIC_SERVER_PARTS = {"mcp", "server", "mcp-server", "main", "app", "api"}
def _derive_registry_short_name(raw_name: str, title: str | None = None) -> str:
"""Derive a short, readable config name from a registry server name.
Examples:
"com.zomato/mcp" -> "zomato-mcp"
"acme/weather-server" -> "weather-server"
"@anthropic/claude" -> "claude"
"simple-tool" -> "simple-tool"
This is the redirect target after user authenticates with GitHub, Notion, etc.
Auth-exempt because the OAuth provider redirects the user's browser here.
"""
if not raw_name:
return ""
from fastapi.responses import HTMLResponse
if "/" not in raw_name:
return raw_name
from pocketpaw.mcp.manager import set_oauth_callback_result
parts = raw_name.split("/")
org = parts[0]
server_part = parts[-1]
# Clean up org: "com.zomato" -> "zomato", "@anthropic" -> "anthropic"
if "." in org:
org = org.rsplit(".", 1)[-1]
org = org.lstrip("@")
# If the server part is too generic, combine with org for disambiguation
if server_part.lower() in _GENERIC_SERVER_PARTS:
return f"{org}-{server_part}"
return server_part
@app.get("/api/mcp/registry/search")
async def search_mcp_registry(
q: str = "",
limit: int = 30,
cursor: str = "",
):
"""Proxy search to the official MCP Registry (avoids CORS)."""
import httpx
params: dict[str, str | int] = {"limit": min(limit, 100)}
if q:
params["search"] = q
if cursor:
params["cursor"] = cursor
try:
async with httpx.AsyncClient(timeout=15.0) as client:
resp = await client.get(
f"{_MCP_REGISTRY_BASE}/v0/servers",
params=params,
)
resp.raise_for_status()
data = resp.json()
# Registry wraps each entry as {server: {...}, _meta: {...}}.
# Unwrap so the frontend gets flat server objects.
# Also: lift environmentVariables from packages[0] to server
# level, remove $schema ($ prefix can confuse Alpine.js proxies),
# and ensure expected fields have defaults.
servers = []
raw_entries = data.get("servers", [])
if not isinstance(raw_entries, list):
raw_entries = []
for entry in raw_entries:
if not isinstance(entry, dict):
continue
raw_server = entry.get("server", entry)
if not isinstance(raw_server, dict):
continue
srv = dict(raw_server)
meta = entry.get("_meta", srv.get("_meta", {}))
srv["_meta"] = meta if isinstance(meta, dict) else {}
srv.pop("$schema", None)
name = srv.get("name")
description = srv.get("description")
packages = srv.get("packages")
remotes = srv.get("remotes")
env_vars = srv.get("environmentVariables")
srv["name"] = name if isinstance(name, str) else ""
srv["description"] = description if isinstance(description, str) else ""
srv["packages"] = packages if isinstance(packages, list) else []
srv["remotes"] = remotes if isinstance(remotes, list) else []
srv["environmentVariables"] = env_vars if isinstance(env_vars, list) else []
# Lift env vars from the first package to the server level.
if not srv["environmentVariables"]:
for pkg in srv["packages"]:
if not isinstance(pkg, dict):
continue
pkg_env = pkg.get("environmentVariables")
if isinstance(pkg_env, list) and pkg_env:
srv["environmentVariables"] = pkg_env
break
# Skip entries without a valid name.
if srv["name"]:
servers.append(srv)
metadata = data.get("metadata", {})
if not isinstance(metadata, dict):
metadata = {}
if "nextCursor" not in metadata and "next_cursor" in metadata:
metadata["nextCursor"] = metadata["next_cursor"]
metadata.setdefault("count", len(servers))
return {"servers": servers, "metadata": metadata}
except Exception as exc:
logger.warning("MCP registry search failed: %s", exc)
return {"servers": [], "metadata": {"count": 0}, "error": str(exc)}
@app.post("/api/mcp/registry/install")
async def install_from_registry(request: Request):
"""Install an MCP server from registry metadata.
Expects a JSON body with the server's registry data (name, packages/remotes,
environmentVariables) and user-supplied env values.
"""
from fastapi.responses import JSONResponse
from pocketpaw.mcp.config import MCPServerConfig
from pocketpaw.mcp.manager import get_mcp_manager
data = await request.json()
server = data.get("server", {})
user_env = data.get("env", {})
# Derive a short, readable name from the registry name.
# e.g. "com.zomato/mcp" -> "zomato-mcp", "acme/weather-server" -> "weather-server"
raw_name = server.get("name", "")
short_name = _derive_registry_short_name(raw_name, server.get("title"))
if not short_name:
return JSONResponse({"error": "Missing server name"}, status_code=400)
# Try remotes first (HTTP transport — simplest, no npm needed)
remotes = server.get("remotes", [])
packages = server.get("packages", [])
config = None
if remotes:
remote = remotes[0]
# Registry API uses "type" (e.g. "streamable-http"), legacy uses "transportType"
transport = remote.get("type", remote.get("transportType", "http"))
# Normalize SSE to "http" but keep "streamable-http" distinct — they need
# different MCP SDK clients.
if transport == "sse":
transport = "http"
elif transport not in ("http", "streamable-http"):
transport = "http" # safe fallback
config = MCPServerConfig(
name=short_name,
transport=transport,
url=remote.get("url", ""),
env=user_env,
enabled=True,
)
elif packages:
pkg = packages[0]
registry_type = pkg.get("registryType", "")
pkg_name = pkg.get("name", "") or pkg.get("identifier", "")
runtime = pkg.get("runtime", "node")
if registry_type == "docker":
args = ["run", "-i", "--rm"]
for ra in pkg.get("runtimeArguments", []):
if ra.get("isFixed"):
args.append(ra.get("value", ""))
args.append(pkg_name)
config = MCPServerConfig(
name=short_name,
transport="stdio",
command="docker",
args=args,
env=user_env,
enabled=True,
)
elif registry_type == "pypi":
config = MCPServerConfig(
name=short_name,
transport="stdio",
command="uvx",
args=[pkg_name],
env=user_env,
enabled=True,
)
elif registry_type == "npm" or runtime == "node":
args = ["-y", pkg_name]
for pa in pkg.get("packageArguments", []):
if pa.get("isFixed"):
args.append(pa.get("value", ""))
config = MCPServerConfig(
name=short_name,
transport="stdio",
command="npx",
args=args,
env=user_env,
enabled=True,
)
if config is None:
return JSONResponse(
{"error": "Could not determine install method from registry data"},
if not code or not state:
return HTMLResponse(
"<html><body><h3>Missing code or state parameter.</h3></body></html>",
status_code=400,
)
mgr = get_mcp_manager()
mgr.add_server_config(config)
connected = await mgr.start_server(config)
tools = mgr.discover_tools(config.name) if connected else []
result: dict = {
"status": "ok",
"name": config.name,
"connected": connected,
"tools": [{"name": t.name, "description": t.description} for t in tools],
}
# Surface connection error so the frontend can display it
if not connected:
status = mgr.get_server_status()
srv = status.get(config.name, {})
if srv.get("error"):
result["error"] = srv["error"]
return result
resolved = set_oauth_callback_result(state, code)
if resolved:
return HTMLResponse(
"<html><body>"
"<h3>Authenticated! You can close this tab.</h3>"
"<script>window.close()</script>"
"</body></html>"
)
return HTMLResponse(
"<html><body><h3>OAuth flow expired or not found.</h3></body></html>",
status_code=400,
)
# ==================== Skills Library API ====================
@@ -1704,6 +1506,7 @@ async def auth_middleware(request: Request, call_next):
"/webhook/inbound",
"/api/whatsapp/qr",
"/oauth/callback",
"/api/mcp/oauth/callback",
]
for path in exempt_paths:

View File

@@ -2,14 +2,14 @@
* PocketPaw - MCP Servers Feature Module
*
* Created: 2026-02-07
* Updated: 2026-02-12 — Registry tab (browse official MCP registry), dynamic categories, needs_args.
* Updated: 2026-02-17 Removed Registry tab, added paste-command input.
*
* Manages MCP (Model Context Protocol) server connections:
* - List/add/remove servers
* - Enable/disable servers
* - View tool inventory
* - Browse & install presets from the catalog
* - Search & install from the official MCP Registry (16K+ servers)
* - Browse & install presets from the curated catalog
* - Paste a full command to auto-fill Add Server form
*/
window.PocketPaw = window.PocketPaw || {};
@@ -28,7 +28,8 @@ window.PocketPaw.MCP = {
transport: 'stdio',
command: '',
args: '',
url: ''
url: '',
fullCommand: ''
},
mcpLoading: false,
mcpShowAddForm: false,
@@ -38,17 +39,8 @@ window.PocketPaw.MCP = {
mcpInstallEnv: {},
mcpInstallArgs: '',
mcpInstalling: false,
mcpCategoryFilter: 'all',
// Registry state
mcpRegistryQuery: '',
mcpRegistryResults: [],
mcpRegistryFeatured: [],
mcpRegistryLoading: false,
mcpRegistryFeaturedError: false,
mcpRegistryCursor: null,
mcpRegistryLoadingMore: false,
mcpRegistryInstalling: null,
mcpRegistryInstallEnv: {}
mcpInstallAbort: null,
mcpCategoryFilter: 'all'
};
},
@@ -64,6 +56,31 @@ window.PocketPaw.MCP = {
this.showMCP = true;
await this.getMCPStatus();
await this.loadPresets();
// Register WS handler for OAuth redirect (once)
if (!window.PocketPaw._mcpOAuthRegistered && window.socket) {
window.socket.on('mcp_oauth_redirect', (data) => {
if (!data.url) return;
// Navigate pre-opened popup or show fallback
const popup = window.PocketPaw._oauthPopup;
if (popup && !popup.closed) {
popup.location = data.url;
} else {
// Popup was blocked — show clickable link
const name = data.server || 'server';
if (this.showToast) {
this.showToast(
`Open auth link for ${name}: ` +
data.url.substring(0, 60) + '...',
'info'
);
}
window.open(data.url, '_blank');
}
});
window.PocketPaw._mcpOAuthRegistered = true;
}
this.$nextTick(() => {
if (window.refreshIcons) window.refreshIcons();
});
@@ -108,7 +125,7 @@ window.PocketPaw.MCP = {
const data = await res.json();
if (data.status === 'ok') {
this.showToast(`MCP server "${this.mcpForm.name}" added`, 'success');
this.mcpForm = { name: '', transport: 'stdio', command: '', args: '', url: '' };
this.mcpForm = { name: '', transport: 'stdio', command: '', args: '', url: '', fullCommand: '' };
await this.getMCPStatus();
} else {
this.showToast(data.error || 'Failed to add server', 'error');
@@ -199,12 +216,29 @@ window.PocketPaw.MCP = {
}
},
/**
* Cancel an in-progress install (abort fetch, reset state, close popup)
*/
cancelInstall() {
if (this.mcpInstallAbort) {
this.mcpInstallAbort.abort();
this.mcpInstallAbort = null;
}
this.mcpInstalling = false;
this.mcpInstallId = null;
const popup = window.PocketPaw._oauthPopup;
if (popup && !popup.closed) {
try { popup.close(); } catch (_) { /* cross-origin */ }
}
window.PocketPaw._oauthPopup = null;
},
/**
* Show install form for a preset
*/
showInstallForm(presetId) {
if (this.mcpInstallId === presetId) {
this.mcpInstallId = null;
this.cancelInstall();
return;
}
this.mcpInstallId = presetId;
@@ -228,6 +262,21 @@ window.PocketPaw.MCP = {
async installPreset() {
if (!this.mcpInstallId) return;
this.mcpInstalling = true;
// AbortController so Cancel can kill the pending fetch
const abort = new AbortController();
this.mcpInstallAbort = abort;
// For OAuth presets: open a blank popup NOW (in user click context)
// so the browser allows it. The WS handler will navigate it later.
const isOAuth = this.presetIsOAuth(this.mcpInstallId);
if (isOAuth) {
window.PocketPaw._oauthPopup = window.open(
'about:blank', 'pocketpaw_oauth',
'width=600,height=700,scrollbars=yes'
);
}
try {
const body = {
preset_id: this.mcpInstallId,
@@ -240,7 +289,8 @@ window.PocketPaw.MCP = {
const res = await fetch('/api/mcp/presets/install', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(body)
body: JSON.stringify(body),
signal: abort.signal
});
const data = await res.json();
if (res.ok && data.status === 'ok') {
@@ -256,9 +306,21 @@ window.PocketPaw.MCP = {
this.showToast(data.error || 'Install failed', 'error');
}
} catch (e) {
if (e.name === 'AbortError') return; // User cancelled
this.showToast('Install failed: ' + e.message, 'error');
} finally {
this.mcpInstalling = false;
this.mcpInstallAbort = null;
// Close leftover OAuth popup if still blank
const popup = window.PocketPaw._oauthPopup;
if (popup && !popup.closed) {
try {
if (popup.location.href === 'about:blank') {
popup.close();
}
} catch (_) { /* cross-origin — popup navigated, leave it */ }
}
window.PocketPaw._oauthPopup = null;
this.$nextTick(() => {
if (window.refreshIcons) window.refreshIcons();
});
@@ -290,294 +352,62 @@ window.PocketPaw.MCP = {
},
/**
* Normalize one registry entry to a flat server object.
* Accepts both wrapped shape ({server, _meta}) and flat shape.
* Check if a preset uses OAuth authentication
*/
normalizeRegistryServer(entry) {
if (!entry || typeof entry !== 'object') return null;
presetIsOAuth(presetId) {
const preset = this.mcpPresets.find(p => p.id === presetId);
return preset ? !!preset.oauth : false;
},
const raw = (entry.server && typeof entry.server === 'object')
? entry.server
: entry;
if (!raw || typeof raw !== 'object') return null;
/**
* Get button text for a preset based on OAuth status and install state
*/
presetButtonText(presetId) {
if (this.mcpInstallId === presetId) return 'Cancel';
return this.presetIsOAuth(presetId) ? 'Authenticate' : 'Install';
},
const server = { ...raw };
const meta = (entry._meta && typeof entry._meta === 'object')
? entry._meta
: ((server._meta && typeof server._meta === 'object') ? server._meta : {});
/**
* Parse a pasted full command string and auto-populate the Add Server form.
* Handles patterns like:
* "npx -y @some/package"
* "uvx mcp-server-git"
* "docker run -i --rm ghcr.io/org/img"
*/
parseFullCommand() {
const raw = (this.mcpForm.fullCommand || '').trim();
if (!raw) return;
server._meta = meta;
delete server.$schema;
const parts = raw.split(/\s+/);
if (parts.length === 0) return;
server.name = typeof server.name === 'string' ? server.name : '';
server.title = typeof server.title === 'string' ? server.title : '';
server.description = typeof server.description === 'string' ? server.description : '';
server.version = typeof server.version === 'string' ? server.version : '';
server.packages = Array.isArray(server.packages) ? server.packages : [];
server.remotes = Array.isArray(server.remotes) ? server.remotes : [];
server.environmentVariables = Array.isArray(server.environmentVariables)
? server.environmentVariables
: [];
const command = parts[0];
const args = parts.slice(1);
// Lift env vars from package metadata when server-level list is missing.
if (server.environmentVariables.length === 0) {
for (const pkg of server.packages) {
if (!pkg || typeof pkg !== 'object') continue;
if (Array.isArray(pkg.environmentVariables) && pkg.environmentVariables.length > 0) {
server.environmentVariables = pkg.environmentVariables;
break;
}
this.mcpForm.command = command;
this.mcpForm.args = args.join(', ');
// Auto-derive a name from the last arg that looks like a package
let name = '';
for (let i = args.length - 1; i >= 0; i--) {
const a = args[i];
// Skip flags and version suffixes
if (a.startsWith('-')) continue;
// Use the package-like arg
name = a
.replace(/@latest$/, '')
.replace(/@[\d.]+.*$/, '');
// Extract short name: "@scope/pkg" -> "pkg", "mcp-server-git" -> "mcp-server-git"
if (name.includes('/')) {
name = name.split('/').pop() || name;
}
name = name.replace(/^@/, '');
break;
}
return server.name ? server : null;
},
/**
* Normalize a registry server list into safe flat entries.
*/
normalizeRegistryServers(entries) {
if (!Array.isArray(entries)) return [];
const normalized = [];
for (const entry of entries) {
const server = this.normalizeRegistryServer(entry);
if (server) normalized.push(server);
if (name && !this.mcpForm.name) {
this.mcpForm.name = name;
}
return normalized;
},
/**
* Normalize next-cursor variants returned by registry metadata.
*/
registryNextCursor(metadata) {
if (!metadata || typeof metadata !== 'object') return null;
return metadata.nextCursor || metadata.next_cursor || null;
},
// ==================== Registry Methods ====================
/**
* Search the official MCP Registry (debounced via Alpine @input.debounce)
*/
async searchRegistry() {
const q = this.mcpRegistryQuery.trim();
if (!q) {
this.mcpRegistryResults = [];
this.mcpRegistryCursor = null;
return;
}
this.mcpRegistryLoading = true;
try {
const url = `/api/mcp/registry/search?q=${encodeURIComponent(q)}&limit=30`;
const res = await fetch(url);
if (res.ok) {
const data = await res.json();
this.mcpRegistryResults = this.normalizeRegistryServers(data.servers);
this.mcpRegistryCursor = this.registryNextCursor(data.metadata);
if (data.error) {
this.showToast(`Registry search failed: ${data.error}`, 'error');
}
} else {
this.mcpRegistryResults = [];
this.mcpRegistryCursor = null;
}
} catch (e) {
console.error('Registry search failed', e);
} finally {
this.mcpRegistryLoading = false;
this.$nextTick(() => {
if (window.refreshIcons) window.refreshIcons();
});
}
},
/**
* Load featured/popular registry servers for initial view
*/
async loadRegistryFeatured() {
if (this.mcpRegistryFeatured.length > 0) return;
this.mcpRegistryLoading = true;
this.mcpRegistryFeaturedError = false;
try {
const res = await fetch('/api/mcp/registry/search?limit=30');
if (res.ok) {
const data = await res.json();
this.mcpRegistryFeatured = this.normalizeRegistryServers(data.servers);
if (data.error) {
this.mcpRegistryFeaturedError = true;
}
} else {
this.mcpRegistryFeaturedError = true;
}
} catch (e) {
console.error('Failed to load registry featured', e);
this.mcpRegistryFeaturedError = true;
} finally {
this.mcpRegistryLoading = false;
this.$nextTick(() => {
if (window.refreshIcons) window.refreshIcons();
});
}
},
/**
* Retry loading featured servers (clears cache first)
*/
async retryRegistryFeatured() {
this.mcpRegistryFeatured = [];
await this.loadRegistryFeatured();
},
/**
* Load more registry results (pagination)
*/
async loadMoreRegistry() {
if (!this.mcpRegistryCursor || this.mcpRegistryLoadingMore) return;
this.mcpRegistryLoadingMore = true;
try {
const q = this.mcpRegistryQuery.trim();
let url = `/api/mcp/registry/search?limit=30&cursor=${encodeURIComponent(this.mcpRegistryCursor)}`;
if (q) url += `&q=${encodeURIComponent(q)}`;
const res = await fetch(url);
if (res.ok) {
const data = await res.json();
const newServers = this.normalizeRegistryServers(data.servers);
this.mcpRegistryResults = [...this.mcpRegistryResults, ...newServers];
this.mcpRegistryCursor = this.registryNextCursor(data.metadata);
}
} catch (e) {
console.error('Registry load more failed', e);
} finally {
this.mcpRegistryLoadingMore = false;
this.$nextTick(() => {
if (window.refreshIcons) window.refreshIcons();
});
}
},
/**
* Get the list to display in registry view
*/
registryDisplayResults() {
return this.mcpRegistryQuery.trim()
? this.mcpRegistryResults
: this.mcpRegistryFeatured;
},
/**
* Extract a short display name from a registry server
*/
registryServerName(server) {
if (server.title) return server.title;
const name = server.name || '';
return name.includes('/') ? name.split('/').pop() : name;
},
/**
* Extract a source label (e.g. "npm: @mcp/server" or "HTTP")
*/
registryServerSource(server) {
const remotes = server.remotes || [];
const packages = server.packages || [];
if (remotes.length > 0) return 'HTTP';
if (packages.length > 0) {
const pkg = packages[0];
const type = pkg.registryType || 'npm';
return `${type}: ${pkg.name || ''}`;
}
return server.name || '';
},
/**
* Check if a registry server is already installed locally
*/
isRegistryServerInstalled(server) {
const rawName = server.name || '';
const installed = Object.keys(this.mcpServers).map(n => n.toLowerCase());
// Check both the full derived name and the simple last-segment name
const parts = rawName.split('/');
const lastPart = (parts.pop() || '').toLowerCase();
const orgPart = parts.length > 0
? (parts[0].includes('.') ? parts[0].split('.').pop() : parts[0]).replace(/^@/, '').toLowerCase()
: '';
const generic = ['mcp', 'server', 'mcp-server', 'main', 'app', 'api'];
const derivedName = generic.includes(lastPart) && orgPart
? `${orgPart}-${lastPart}`
: lastPart;
return installed.includes(derivedName) || installed.includes(lastPart);
},
/**
* Get env vars required by a registry server
*/
registryServerEnvVars(server) {
return (server.environmentVariables || []).filter(ev => ev.required !== false);
},
/**
* Show install form for a registry server
*/
showRegistryInstallForm(serverName) {
if (this.mcpRegistryInstalling === serverName) {
this.mcpRegistryInstalling = null;
return;
}
this.mcpRegistryInstalling = serverName;
// Pre-fill env
const results = this.registryDisplayResults();
const server = results.find(s => s.name === serverName);
const env = {};
if (server) {
for (const ev of (server.environmentVariables || [])) {
env[ev.name] = '';
}
}
this.mcpRegistryInstallEnv = env;
this.$nextTick(() => {
if (window.refreshIcons) window.refreshIcons();
});
},
/**
* Install a server from the registry
*/
async installFromRegistry(server) {
const serverName = server.name;
this.mcpRegistryInstalling = serverName;
try {
const res = await fetch('/api/mcp/registry/install', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
server: server,
env: this.mcpRegistryInstallEnv
})
});
const data = await res.json();
if (res.ok && data.status === 'ok') {
const toolCount = data.tools ? data.tools.length : 0;
let msg;
if (data.connected) {
msg = `Installed "${data.name}" — ${toolCount} tools`;
} else {
msg = `Installed "${data.name}" (not yet connected)`;
if (data.error) msg += `: ${data.error}`;
}
this.showToast(msg, data.connected ? 'success' : 'warning');
this.mcpRegistryInstalling = null;
await this.getMCPStatus();
} else {
this.showToast(data.error || 'Install failed', 'error');
this.mcpRegistryInstalling = null;
}
} catch (e) {
this.showToast('Install failed: ' + e.message, 'error');
this.mcpRegistryInstalling = null;
}
this.$nextTick(() => {
if (window.refreshIcons) window.refreshIcons();
});
}
};
}

View File

@@ -29,9 +29,14 @@ class MCPServerConfig:
env: dict[str, str] = field(default_factory=dict)
enabled: bool = True
timeout: int = 30 # Connection timeout in seconds
# Legacy: original registry identifier (e.g. "@cmd8/excalidraw-mcp@0.1.4").
# Kept for backward compatibility with servers installed from the now-removed
# MCP Registry tab. Not used by new installations.
registry_ref: str = ""
oauth: bool = False # True if server uses OAuth authentication
def to_dict(self) -> dict:
return {
d = {
"name": self.name,
"transport": self.transport,
"command": self.command,
@@ -41,6 +46,11 @@ class MCPServerConfig:
"enabled": self.enabled,
"timeout": self.timeout,
}
if self.registry_ref:
d["registry_ref"] = self.registry_ref
if self.oauth:
d["oauth"] = True
return d
@classmethod
def from_dict(cls, data: dict) -> MCPServerConfig:
@@ -53,6 +63,8 @@ class MCPServerConfig:
env=data.get("env", {}),
enabled=data.get("enabled", True),
timeout=data.get("timeout", 30),
registry_ref=data.get("registry_ref", ""),
oauth=data.get("oauth", False),
)

View File

@@ -14,13 +14,43 @@ from __future__ import annotations
import asyncio
import logging
import os
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any
from urllib.parse import parse_qs, urlparse
from pocketpaw.mcp.config import MCPServerConfig, load_mcp_config, save_mcp_config
logger = logging.getLogger(__name__)
# OAuth callback coordination: state -> Future[(code, state)]
_oauth_pending: dict[str, asyncio.Future] = {}
# WebSocket broadcast function injected by dashboard at startup
_ws_broadcast: Callable | None = None
def set_ws_broadcast(fn: Callable) -> None:
"""Set the WebSocket broadcast function (called by dashboard at startup)."""
global _ws_broadcast
_ws_broadcast = fn
def set_oauth_callback_result(state: str, code: str) -> bool:
"""Resolve a pending OAuth Future with the authorization code.
Called by the dashboard callback endpoint when the OAuth provider
redirects back with code + state params.
Returns True if a pending Future was found and resolved.
"""
future = _oauth_pending.get(state)
if future and not future.done():
future.set_result((code, state))
return True
logger.warning("No pending OAuth flow for state=%s", state[:16])
return False
@dataclass
class MCPToolInfo:
@@ -46,6 +76,46 @@ class _ServerState:
connected: bool = False
_UNHELPFUL_ERRORS = {
"Attempted to exit a cancel scope that isn't the current tasks's current cancel scope",
}
def _extract_root_error(exc: BaseException) -> str:
"""Unwrap ExceptionGroup / BaseExceptionGroup to find the real error.
MCP stdio transport wraps subprocess crashes in anyio cancel-scope errors
that are uninformative (e.g. "Attempted to exit a cancel scope…").
Walk the exception tree to find a concrete root-cause message.
"""
# Collect all leaf exceptions from exception groups
leaves: list[BaseException] = []
def _collect(e: BaseException) -> None:
if isinstance(e, (ExceptionGroup, BaseExceptionGroup)):
for sub in e.exceptions:
_collect(sub)
elif e.__cause__:
_collect(e.__cause__)
else:
leaves.append(e)
_collect(exc)
# Pick the most useful message (skip unhelpful anyio internals)
for leaf in leaves:
msg = str(leaf).strip()
if msg and msg not in _UNHELPFUL_ERRORS:
return msg
# Fallback: if everything is unhelpful, use the top-level message and
# hint that the server process crashed.
top = str(exc).strip()
if top in _UNHELPFUL_ERRORS:
return "Server process crashed during startup (check terminal for details)"
return top
class MCPManager:
"""Manages MCP server connections and tool invocations."""
@@ -104,6 +174,88 @@ class MCPManager:
env.update(config_env)
return env
@staticmethod
def _make_oauth_auth(config: MCPServerConfig):
"""Create an httpx.Auth for OAuth-based MCP servers.
Uses the MCP SDK's OAuthClientProvider which handles:
- OAuth 2.1 metadata discovery
- CIMD (Client ID Metadata Document) for servers that support it
- Dynamic client registration as fallback
- PKCE authorization code flow
- Token refresh
"""
from mcp.client.auth import OAuthClientProvider
from mcp.shared.auth import OAuthClientMetadata
from pocketpaw.config import Settings
from pocketpaw.mcp.oauth_store import MCPTokenStorage
settings = Settings.load()
port = settings.web_port or 8888
storage = MCPTokenStorage(config.name)
client_metadata = OAuthClientMetadata(
client_name="PocketPaw",
redirect_uris=[f"http://localhost:{port}/api/mcp/oauth/callback"],
token_endpoint_auth_method="none",
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
)
# CIMD URL — servers that support client_id_metadata_document_supported
# will use this URL as the client_id instead of dynamic registration.
cimd_url = (settings.mcp_client_metadata_url or "").strip() or None
# Shared mutable state between the two closures
_flow_state: dict[str, str] = {}
async def redirect_handler(auth_url: str) -> None:
"""Called by SDK with the authorization URL — broadcast to frontend."""
parsed = urlparse(auth_url)
params = parse_qs(parsed.query)
state_values = params.get("state", [])
state = state_values[0] if state_values else ""
if state:
loop = asyncio.get_running_loop()
future = loop.create_future()
_oauth_pending[state] = future
_flow_state["state"] = state
if _ws_broadcast:
await _ws_broadcast(
{
"type": "mcp_oauth_redirect",
"url": auth_url,
"server": config.name,
}
)
logger.info("OAuth redirect for MCP server '%s' — waiting for callback", config.name)
async def callback_handler() -> tuple[str, str | None]:
"""Called by SDK to wait for the OAuth callback result."""
state = _flow_state.get("state")
if not state or state not in _oauth_pending:
raise RuntimeError(f"No pending OAuth flow for server '{config.name}'")
future = _oauth_pending[state]
try:
code, returned_state = await asyncio.wait_for(future, timeout=300)
return (code, returned_state)
finally:
_oauth_pending.pop(state, None)
return OAuthClientProvider(
server_url=config.url,
client_metadata=client_metadata,
storage=storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
client_metadata_url=cimd_url,
)
async def start_server(self, config: MCPServerConfig) -> bool:
"""Start an MCP server and initialize its session.
@@ -117,23 +269,62 @@ class MCPManager:
state = _ServerState(config=config)
self._servers[config.name] = state
# Build OAuth auth if needed
auth = None
if config.oauth:
try:
auth = self._make_oauth_auth(config)
except Exception as e:
state.error = f"OAuth setup failed: {e}"
logger.error("OAuth setup failed for '%s': %s", config.name, e)
return False
try:
timeout = config.timeout or 30
# OAuth flows need more time for user interaction
connect_timeout = 300 if config.oauth else timeout
if config.transport == "stdio":
await asyncio.wait_for(self._connect_stdio(state), timeout=timeout)
elif config.transport == "streamable-http":
# Streamable HTTP uses a different MCP SDK client than SSE.
await self._connect_remote_with_timeout(
state, timeout, self._connect_streamable_http
state,
connect_timeout,
lambda s: self._connect_streamable_http(s, auth=auth),
)
elif config.transport == "sse":
await self._connect_remote_with_timeout(
state,
connect_timeout,
lambda s: self._connect_sse(s, auth=auth),
)
elif config.transport == "http":
# HTTP/SSE connections use anyio cancel scopes internally.
# asyncio.wait_for cancels the task on timeout, which disrupts
# anyio's cancel scope cleanup and causes TaskGroup errors.
# Instead, run with a manual timeout that doesn't cancel the task.
await self._connect_remote_with_timeout(
state, timeout, self._connect_sse
)
# Auto-detect: try Streamable HTTP first, fall back to SSE.
# Modern MCP servers use Streamable HTTP (POST-based);
# older ones use SSE (GET-based).
try:
await self._connect_remote_with_timeout(
state,
connect_timeout,
lambda s: self._connect_streamable_http(
s, auth=auth
),
)
except TimeoutError:
raise # Don't waste time retrying on timeout
except BaseException:
await self._cleanup_state(state)
state = _ServerState(config=config)
self._servers[config.name] = state
logger.debug(
"Streamable HTTP failed for '%s', trying SSE",
config.name,
)
await self._connect_remote_with_timeout(
state,
connect_timeout,
lambda s: self._connect_sse(s, auth=auth),
)
else:
state.error = f"Unknown transport: {config.transport}"
logger.error(state.error)
@@ -150,18 +341,31 @@ class MCPManager:
return True
except TimeoutError:
state.error = f"Connection timed out after {timeout}s"
effective_timeout = 300 if config.oauth else (config.timeout or 30)
state.error = f"Connection timed out after {effective_timeout}s"
state.connected = False
await self._cleanup_state(state)
logger.error("MCP server '%s' timed out after %ds", config.name, timeout)
logger.error("MCP server '%s' timed out after %ds", config.name, effective_timeout)
return False
except BaseException as e:
# Catch BaseException to handle ExceptionGroup / BaseExceptionGroup
# from anyio TaskGroup failures in the MCP library.
state.error = str(e)
root_msg = _extract_root_error(e)
# Provide actionable hint for OAuth registration failures
if config.oauth and "Registration failed" in root_msg:
root_msg = (
f"{root_msg}. "
"This server doesn't support dynamic client registration. "
"You can set mcp_client_metadata_url in Settings to a "
"publicly-hosted CIMD JSON file, or configure the server "
"with an API token instead of OAuth."
)
state.error = root_msg
state.connected = False
await self._cleanup_state(state)
logger.error("Failed to start MCP server '%s': %s", config.name, e)
logger.error("Failed to start MCP server '%s': %s", config.name, root_msg)
return False
async def _connect_stdio(self, state: _ServerState) -> None:
@@ -222,12 +426,15 @@ class MCPManager:
await self._cleanup_state(state)
raise
async def _connect_sse(self, state: _ServerState) -> None:
async def _connect_sse(self, state: _ServerState, auth=None) -> None:
"""Connect to an MCP server via SSE (Server-Sent Events)."""
from mcp import ClientSession
from mcp.client.sse import sse_client
ctx = sse_client(url=state.config.url)
kwargs: dict[str, Any] = {"url": state.config.url}
if auth is not None:
kwargs["auth"] = auth
ctx = sse_client(**kwargs)
streams = await ctx.__aenter__()
state.client = ctx
state.read_stream = streams[0]
@@ -243,12 +450,15 @@ class MCPManager:
state.client = None
raise
async def _connect_streamable_http(self, state: _ServerState) -> None:
async def _connect_streamable_http(self, state: _ServerState, auth=None) -> None:
"""Connect to an MCP server via Streamable HTTP transport."""
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
ctx = streamablehttp_client(url=state.config.url)
kwargs: dict[str, Any] = {"url": state.config.url}
if auth is not None:
kwargs["auth"] = auth
ctx = streamablehttp_client(**kwargs)
streams = await ctx.__aenter__()
state.client = ctx
# streamablehttp_client yields (read, write, get_session_id)
@@ -356,22 +566,28 @@ class MCPManager:
result = {}
# First, include all servers from the config file
for cfg in load_mcp_config():
result[cfg.name] = {
info: dict = {
"connected": False,
"tool_count": 0,
"error": "",
"transport": cfg.transport,
"enabled": cfg.enabled,
}
if cfg.registry_ref:
info["registry_ref"] = cfg.registry_ref
result[cfg.name] = info
# Overlay runtime state for servers that have been started
for name, state in self._servers.items():
result[name] = {
info = {
"connected": state.connected,
"tool_count": len(state.tools),
"error": state.error,
"transport": state.config.transport,
"enabled": state.config.enabled,
}
if state.config.registry_ref:
info["registry_ref"] = state.config.registry_ref
result[name] = info
return result
async def start_enabled_servers(self) -> None:

View File

@@ -0,0 +1,96 @@
"""MCP OAuth Token Storage — file-based persistence for MCP OAuth tokens.
Implements the MCP SDK's ``TokenStorage`` protocol for persisting OAuth tokens
and client registration info to ``~/.pocketpaw/mcp_oauth/{server_name}.json``.
Created: 2026-02-17
"""
from __future__ import annotations
import json
import logging
import os
import stat
from pathlib import Path
from pocketpaw.config import get_config_dir
logger = logging.getLogger(__name__)
def _get_oauth_dir() -> Path:
"""Get/create the MCP OAuth token directory."""
d = get_config_dir() / "mcp_oauth"
d.mkdir(exist_ok=True)
return d
class MCPTokenStorage:
"""File-based token storage for MCP OAuth at ~/.pocketpaw/mcp_oauth/{name}.json.
Stores both OAuth tokens and dynamic client registration info.
Files are chmod 0600 (owner-only read/write).
"""
def __init__(self, server_name: str) -> None:
self._server_name = server_name
self._path = _get_oauth_dir() / f"{server_name}.json"
def _load(self) -> dict:
if not self._path.exists():
return {}
try:
return json.loads(self._path.read_text())
except (json.JSONDecodeError, OSError) as e:
logger.warning("Failed to load MCP OAuth data for %s: %s", self._server_name, e)
return {}
def _save(self, data: dict) -> None:
self._path.write_text(json.dumps(data, indent=2))
try:
os.chmod(self._path, stat.S_IRUSR | stat.S_IWUSR)
except OSError:
pass
async def get_tokens(self):
"""Get stored OAuth tokens."""
from mcp.shared.auth import OAuthToken
data = self._load()
tokens_data = data.get("tokens")
if not tokens_data:
return None
try:
return OAuthToken.model_validate(tokens_data)
except Exception as e:
logger.warning("Failed to parse MCP OAuth tokens for %s: %s", self._server_name, e)
return None
async def set_tokens(self, tokens) -> None:
"""Store OAuth tokens."""
data = self._load()
data["tokens"] = tokens.model_dump()
self._save(data)
logger.debug("Saved MCP OAuth tokens for %s", self._server_name)
async def get_client_info(self):
"""Get stored client registration info."""
from mcp.shared.auth import OAuthClientInformationFull
data = self._load()
client_data = data.get("client_info")
if not client_data:
return None
try:
return OAuthClientInformationFull.model_validate(client_data)
except Exception as e:
logger.warning("Failed to parse MCP OAuth client info for %s: %s", self._server_name, e)
return None
async def set_client_info(self, client_info) -> None:
"""Store client registration info."""
data = self._load()
data["client_info"] = client_info.model_dump(mode="json")
self._save(data)
logger.debug("Saved MCP OAuth client info for %s", self._server_name)

View File

@@ -402,6 +402,68 @@ class TestMCPManagerConfigMethods:
# ======================================================================
class TestHTTPAutoDetect:
"""Tests for transport='http' auto-detect (Streamable HTTP → SSE fallback)."""
async def test_http_transport_tries_streamable_first(self):
"""transport='http' should try Streamable HTTP and succeed."""
mgr = MCPManager()
cfg = MCPServerConfig(name="modern", transport="http", url="https://example.com/mcp")
with (
patch.object(mgr, "_connect_streamable_http", new_callable=AsyncMock),
patch.object(mgr, "_connect_sse", new_callable=AsyncMock) as mock_sse,
patch.object(
mgr,
"_discover_tools",
new_callable=AsyncMock,
),
):
result = await mgr.start_server(cfg)
assert result is True
# SSE should NOT have been called
mock_sse.assert_not_called()
async def test_http_transport_falls_back_to_sse(self):
"""transport='http' should fall back to SSE when Streamable HTTP fails."""
mgr = MCPManager()
cfg = MCPServerConfig(name="legacy", transport="http", url="https://example.com/mcp")
with (
patch.object(
mgr,
"_connect_streamable_http",
new_callable=AsyncMock,
side_effect=RuntimeError("405 Method Not Allowed"),
),
patch.object(mgr, "_connect_sse", new_callable=AsyncMock),
patch.object(mgr, "_discover_tools", new_callable=AsyncMock),
patch.object(mgr, "_cleanup_state", new_callable=AsyncMock),
):
result = await mgr.start_server(cfg)
assert result is True
async def test_http_transport_no_fallback_on_timeout(self):
"""transport='http' should NOT fall back to SSE on timeout."""
mgr = MCPManager()
cfg = MCPServerConfig(
name="slow", transport="http", url="https://example.com/mcp", timeout=1
)
with (
patch.object(
mgr,
"_connect_remote_with_timeout",
new_callable=AsyncMock,
side_effect=TimeoutError("Connection timed out"),
),
):
result = await mgr.start_server(cfg)
assert result is False
status = mgr._servers["slow"]
assert "timed out" in status.error
class TestGetMCPManager:
def test_returns_same_instance(self):
import pocketpaw.mcp.manager as mod

308
tests/test_mcp_oauth.py Normal file
View File

@@ -0,0 +1,308 @@
"""Tests for MCP OAuth flow — token storage, callback coordination,
preset flags, and dashboard endpoint.
Created: 2026-02-17
"""
import asyncio
from unittest.mock import MagicMock, patch
import pytest
from pocketpaw.mcp.config import MCPServerConfig
from pocketpaw.mcp.presets import get_all_presets, get_preset, preset_to_config
# ======================================================================
# MCPTokenStorage tests
# ======================================================================
class TestMCPTokenStorage:
@pytest.fixture
def storage(self, tmp_path):
with patch("pocketpaw.mcp.oauth_store.get_config_dir", return_value=tmp_path):
from pocketpaw.mcp.oauth_store import MCPTokenStorage
return MCPTokenStorage("test-server")
async def test_get_tokens_empty(self, storage):
result = await storage.get_tokens()
assert result is None
async def test_set_and_get_tokens(self, storage):
from mcp.shared.auth import OAuthToken
token = OAuthToken(access_token="access_123", token_type="Bearer", refresh_token="ref_456")
await storage.set_tokens(token)
loaded = await storage.get_tokens()
assert loaded is not None
assert loaded.access_token == "access_123"
assert loaded.refresh_token == "ref_456"
assert loaded.token_type == "Bearer"
async def test_get_client_info_empty(self, storage):
result = await storage.get_client_info()
assert result is None
async def test_set_and_get_client_info(self, storage):
from mcp.shared.auth import OAuthClientInformationFull
info = OAuthClientInformationFull(
client_id="cid_123",
client_secret="secret_456",
redirect_uris=["http://localhost:8888/callback"],
)
await storage.set_client_info(info)
loaded = await storage.get_client_info()
assert loaded is not None
assert loaded.client_id == "cid_123"
assert loaded.client_secret == "secret_456"
async def test_tokens_and_client_info_coexist(self, storage):
"""Both tokens and client info can be stored in the same file."""
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
token = OAuthToken(access_token="tok")
info = OAuthClientInformationFull(client_id="cid", redirect_uris=["http://localhost/cb"])
await storage.set_tokens(token)
await storage.set_client_info(info)
loaded_tok = await storage.get_tokens()
loaded_info = await storage.get_client_info()
assert loaded_tok.access_token == "tok"
assert loaded_info.client_id == "cid"
async def test_file_permissions(self, storage, tmp_path):
"""Token file should be chmod 0600."""
import os
import stat
from mcp.shared.auth import OAuthToken
await storage.set_tokens(OAuthToken(access_token="x"))
path = tmp_path / "mcp_oauth" / "test-server.json"
mode = os.stat(path).st_mode
assert mode & stat.S_IRWXG == 0 # No group perms
assert mode & stat.S_IRWXO == 0 # No other perms
assert mode & stat.S_IRUSR # Owner read
async def test_corrupted_file_returns_none(self, storage, tmp_path):
"""Corrupted JSON should return None, not raise."""
path = tmp_path / "mcp_oauth" / "test-server.json"
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text("not json{{{")
result = await storage.get_tokens()
assert result is None
result = await storage.get_client_info()
assert result is None
# ======================================================================
# OAuth callback coordination tests
# ======================================================================
class TestOAuthCallbackCoordination:
def setup_method(self):
from pocketpaw.mcp import manager
manager._oauth_pending.clear()
def teardown_method(self):
from pocketpaw.mcp import manager
manager._oauth_pending.clear()
def test_set_oauth_callback_result_resolves_future(self):
from pocketpaw.mcp.manager import set_oauth_callback_result
loop = asyncio.new_event_loop()
future = loop.create_future()
from pocketpaw.mcp import manager
manager._oauth_pending["test_state_123"] = future
result = set_oauth_callback_result("test_state_123", "auth_code_456")
assert result is True
assert future.done()
assert future.result() == ("auth_code_456", "test_state_123")
loop.close()
def test_set_oauth_callback_result_unknown_state(self):
from pocketpaw.mcp.manager import set_oauth_callback_result
result = set_oauth_callback_result("nonexistent_state", "code")
assert result is False
def test_set_oauth_callback_result_already_resolved(self):
from pocketpaw.mcp.manager import set_oauth_callback_result
loop = asyncio.new_event_loop()
future = loop.create_future()
future.set_result(("old_code", "old_state"))
from pocketpaw.mcp import manager
manager._oauth_pending["done_state"] = future
result = set_oauth_callback_result("done_state", "new_code")
assert result is False
loop.close()
def test_set_ws_broadcast(self):
from pocketpaw.mcp import manager
from pocketpaw.mcp.manager import set_ws_broadcast
old = manager._ws_broadcast
try:
fn = MagicMock()
set_ws_broadcast(fn)
assert manager._ws_broadcast is fn
finally:
manager._ws_broadcast = old
# ======================================================================
# Preset OAuth flag tests
# ======================================================================
class TestPresetOAuthFlags:
_OAUTH_PRESET_IDS = {
"github",
"notion",
"atlassian",
"stripe",
"cloudflare",
"supabase",
"vercel",
"gitlab",
"figma",
}
def test_http_presets_have_oauth_true(self):
for preset_id in self._OAUTH_PRESET_IDS:
preset = get_preset(preset_id)
assert preset is not None, f"Preset {preset_id} not found"
assert preset.oauth is True, f"Preset {preset_id} should have oauth=True"
assert preset.transport == "http"
def test_stdio_presets_have_oauth_false(self):
for preset in get_all_presets():
if preset.transport == "stdio":
assert preset.oauth is False, f"Stdio preset {preset.id} should have oauth=False"
def test_preset_to_config_passes_oauth(self):
preset = get_preset("github")
config = preset_to_config(preset)
assert config.oauth is True
def test_preset_to_config_stdio_no_oauth(self):
preset = get_preset("sentry")
config = preset_to_config(preset, env={"SENTRY_ACCESS_TOKEN": "x"})
assert config.oauth is False
# ======================================================================
# MCPServerConfig oauth field tests
# ======================================================================
class TestConfigOAuthField:
def test_to_dict_includes_oauth_when_true(self):
config = MCPServerConfig(name="test", oauth=True)
d = config.to_dict()
assert d["oauth"] is True
def test_to_dict_excludes_oauth_when_false(self):
config = MCPServerConfig(name="test", oauth=False)
d = config.to_dict()
assert "oauth" not in d
def test_from_dict_reads_oauth(self):
config = MCPServerConfig.from_dict({"name": "test", "oauth": True})
assert config.oauth is True
def test_from_dict_defaults_oauth_false(self):
config = MCPServerConfig.from_dict({"name": "test"})
assert config.oauth is False
# ======================================================================
# Dashboard OAuth callback endpoint tests
# ======================================================================
_TEST_TOKEN = "test-oauth-token-12345"
def _auth(**extra):
h = {"Authorization": f"Bearer {_TEST_TOKEN}"}
h.update(extra)
return h
class TestDashboardOAuthCallback:
@pytest.fixture(autouse=True)
def _mock_token(self):
with patch("pocketpaw.dashboard.get_access_token", return_value=_TEST_TOKEN):
yield
@pytest.fixture
def client(self):
from fastapi.testclient import TestClient
from pocketpaw.dashboard import app
return TestClient(app, raise_server_exceptions=False)
@patch("pocketpaw.mcp.manager.set_oauth_callback_result")
def test_callback_success(self, mock_set_result, client):
mock_set_result.return_value = True
res = client.get("/api/mcp/oauth/callback?code=abc123&state=xyz789")
assert res.status_code == 200
assert "Authenticated" in res.text
assert "window.close" in res.text
mock_set_result.assert_called_once_with("xyz789", "abc123")
@patch("pocketpaw.mcp.manager.set_oauth_callback_result")
def test_callback_expired_flow(self, mock_set_result, client):
mock_set_result.return_value = False
res = client.get("/api/mcp/oauth/callback?code=abc&state=xyz")
assert res.status_code == 400
assert "expired" in res.text.lower() or "not found" in res.text.lower()
def test_callback_missing_params(self, client):
# Empty defaults → 400
res = client.get("/api/mcp/oauth/callback")
assert res.status_code == 400
# Explicit empty strings → also 400
res = client.get("/api/mcp/oauth/callback?code=&state=")
assert res.status_code == 400
def test_callback_is_auth_exempt(self, client):
"""OAuth callback should work without auth token."""
with patch("pocketpaw.mcp.manager.set_oauth_callback_result", return_value=True):
res = client.get(
"/api/mcp/oauth/callback?code=test&state=test",
# No auth header
)
# Should not get 401 — the endpoint is auth-exempt
assert res.status_code != 401
@patch("pocketpaw.mcp.config.load_mcp_config")
def test_presets_include_oauth_field(self, mock_load, client):
mock_load.return_value = []
res = client.get("/api/mcp/presets", headers=_auth())
assert res.status_code == 200
data = res.json()
github = next(p for p in data if p["id"] == "github")
assert github["oauth"] is True
sentry = next(p for p in data if p["id"] == "sentry")
assert sentry["oauth"] is False

View File

@@ -258,6 +258,7 @@ class TestSkillsRESTEndpoints:
# Make wait_for actually call the coroutine
async def passthrough(coro, timeout):
return await coro
mock_asyncio.wait_for = passthrough
# Prepare a fake cloned repo with a skill inside skills/ subdir
@@ -291,9 +292,7 @@ class TestSkillsRESTEndpoints:
request.json = AsyncMock(return_value={"source": "owner/bad-repo/skill"})
mock_proc = AsyncMock()
mock_proc.communicate = AsyncMock(
return_value=(b"", b"fatal: repository not found\n")
)
mock_proc.communicate = AsyncMock(return_value=(b"", b"fatal: repository not found\n"))
mock_proc.returncode = 128
with (
@@ -305,6 +304,7 @@ class TestSkillsRESTEndpoints:
async def passthrough(coro, timeout):
return await coro
mock_asyncio.wait_for = passthrough
from pocketpaw.dashboard import install_skill
@@ -415,364 +415,3 @@ class TestMCPPresetNeedsArgs:
for p in get_all_presets():
# Every preset should have a bool needs_args
assert isinstance(p.needs_args, bool), f"Preset {p.id} needs_args is not bool"
# ======================================================================
# MCP Registry endpoint tests
# ======================================================================
class TestMCPRegistryEndpoints:
async def test_search_registry_empty_query(self):
"""GET /api/mcp/registry/search with no query returns featured servers."""
mock_response = MagicMock()
# Registry API wraps each entry as {server: {...}, _meta: {...}}
mock_response.json.return_value = {
"servers": [
{
"server": {"name": "org/server", "description": "A server"},
"_meta": {"score": 0.9},
}
],
"metadata": {"count": 1},
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client.get = AsyncMock(return_value=mock_response)
with patch("httpx.AsyncClient", return_value=mock_client):
from pocketpaw.dashboard import search_mcp_registry
result = await search_mcp_registry(q="", limit=30, cursor="")
assert "servers" in result
# Backend should unwrap nested structure
assert result["servers"][0]["name"] == "org/server"
assert result["servers"][0]["_meta"]["score"] == 0.9
# Should call registry with just limit (no search param)
mock_client.get.assert_called_once()
call_kwargs = mock_client.get.call_args
params = call_kwargs.kwargs.get("params") or call_kwargs[1].get("params")
assert "search" not in params
async def test_search_registry_with_query(self):
"""GET /api/mcp/registry/search proxies search param to registry."""
mock_response = MagicMock()
mock_response.json.return_value = {
"servers": [{"server": {"name": "org/sql-server", "description": "SQL"}, "_meta": {}}],
"metadata": {"count": 1, "nextCursor": "abc123"},
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client.get = AsyncMock(return_value=mock_response)
with patch("httpx.AsyncClient", return_value=mock_client):
from pocketpaw.dashboard import search_mcp_registry
result = await search_mcp_registry(q="sql", limit=10, cursor="")
# Unwrapped: flat server object
assert result["servers"][0]["name"] == "org/sql-server"
call_kwargs = mock_client.get.call_args
params = call_kwargs.kwargs.get("params") or call_kwargs[1].get("params")
assert params["search"] == "sql"
assert params["limit"] == 10
async def test_search_registry_with_cursor(self):
"""GET /api/mcp/registry/search passes cursor for pagination."""
mock_response = MagicMock()
mock_response.json.return_value = {
"servers": [],
"metadata": {"count": 0},
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client.get = AsyncMock(return_value=mock_response)
with patch("httpx.AsyncClient", return_value=mock_client):
from pocketpaw.dashboard import search_mcp_registry
await search_mcp_registry(q="test", limit=30, cursor="page2")
call_kwargs = mock_client.get.call_args
params = call_kwargs.kwargs.get("params") or call_kwargs[1].get("params")
assert params["cursor"] == "page2"
async def test_search_registry_limit_capped(self):
"""Limit is capped at 100."""
mock_response = MagicMock()
mock_response.json.return_value = {"servers": [], "metadata": {"count": 0}}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client.get = AsyncMock(return_value=mock_response)
with patch("httpx.AsyncClient", return_value=mock_client):
from pocketpaw.dashboard import search_mcp_registry
await search_mcp_registry(q="x", limit=999, cursor="")
call_kwargs = mock_client.get.call_args
params = call_kwargs.kwargs.get("params") or call_kwargs[1].get("params")
assert params["limit"] == 100
async def test_search_registry_accepts_flat_entries(self):
"""Flat server rows (no nested 'server' key) are normalized correctly."""
mock_response = MagicMock()
mock_response.json.return_value = {
"servers": [
{
"name": "org/flat-server",
"description": "Flat format row",
"packages": [],
"remotes": [],
}
],
"metadata": {"count": 1, "next_cursor": "cursor123"},
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client.get = AsyncMock(return_value=mock_response)
with patch("httpx.AsyncClient", return_value=mock_client):
from pocketclaw.dashboard import search_mcp_registry
result = await search_mcp_registry(q="flat", limit=10, cursor="")
assert result["servers"][0]["name"] == "org/flat-server"
# next_cursor is normalized for frontend compatibility
assert result["metadata"]["nextCursor"] == "cursor123"
async def test_search_registry_skips_malformed_rows(self):
"""Malformed rows are skipped instead of crashing the whole result set."""
mock_response = MagicMock()
mock_response.json.return_value = {
"servers": [
None,
{"server": None},
{
"server": {
"name": "org/good-server",
"description": "Good row",
"packages": [{"environmentVariables": [{"name": "API_KEY"}]}],
"remotes": [],
}
},
],
"metadata": "invalid-metadata-type",
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client.get = AsyncMock(return_value=mock_response)
with patch("httpx.AsyncClient", return_value=mock_client):
from pocketclaw.dashboard import search_mcp_registry
result = await search_mcp_registry(q="good", limit=10, cursor="")
assert len(result["servers"]) == 1
assert result["servers"][0]["name"] == "org/good-server"
# env vars should be lifted from package metadata
assert result["servers"][0]["environmentVariables"][0]["name"] == "API_KEY"
# invalid metadata falls back to normalized dict with count
assert result["metadata"]["count"] == 1
async def test_install_from_registry_missing_name(self):
"""POST /api/mcp/registry/install with empty server name returns 400."""
from pocketpaw.dashboard import install_from_registry
request = MagicMock()
request.json = AsyncMock(return_value={"server": {}, "env": {}})
result = await install_from_registry(request)
assert result.status_code == 400
async def test_install_from_registry_http_remote(self):
"""Install from registry with HTTP remote transport (legacy transportType key)."""
mock_mgr = MagicMock()
mock_mgr.add_server_config = MagicMock()
mock_mgr.start_server = AsyncMock(return_value=True)
mock_mgr.discover_tools = MagicMock(return_value=[])
request = MagicMock()
request.json = AsyncMock(
return_value={
"server": {
"name": "org/my-server",
"remotes": [{"url": "https://api.example.com/mcp", "transportType": "http"}],
},
"env": {},
}
)
with patch("pocketpaw.mcp.manager.get_mcp_manager", return_value=mock_mgr):
from pocketpaw.dashboard import install_from_registry
result = await install_from_registry(request)
assert result["status"] == "ok"
assert result["name"] == "my-server"
assert result["connected"] is True
# Verify the config was created with HTTP transport
config = mock_mgr.add_server_config.call_args[0][0]
assert config.transport == "http"
assert config.url == "https://api.example.com/mcp"
async def test_install_from_registry_streamable_http_remote(self):
"""Install from registry with streamable-http transport (actual registry API format)."""
mock_mgr = MagicMock()
mock_mgr.add_server_config = MagicMock()
mock_mgr.start_server = AsyncMock(return_value=True)
mock_mgr.discover_tools = MagicMock(return_value=[])
request = MagicMock()
request.json = AsyncMock(
return_value={
"server": {
"name": "org/stream-server",
"remotes": [{"url": "https://api.example.com/mcp", "type": "streamable-http"}],
},
"env": {},
}
)
with patch("pocketpaw.mcp.manager.get_mcp_manager", return_value=mock_mgr):
from pocketpaw.dashboard import install_from_registry
result = await install_from_registry(request)
assert result["status"] == "ok"
assert result["name"] == "stream-server"
# streamable-http is preserved (needs different MCP SDK client)
config = mock_mgr.add_server_config.call_args[0][0]
assert config.transport == "streamable-http"
assert config.url == "https://api.example.com/mcp"
async def test_install_from_registry_npm_package(self):
"""Install from registry with npm package (stdio)."""
mock_mgr = MagicMock()
mock_mgr.add_server_config = MagicMock()
mock_mgr.start_server = AsyncMock(return_value=False)
mock_mgr.discover_tools = MagicMock(return_value=[])
request = MagicMock()
request.json = AsyncMock(
return_value={
"server": {
"name": "org/cool-mcp",
"packages": [
{
"registryType": "npm",
"name": "@cool/mcp-server",
"runtime": "node",
"packageArguments": [],
}
],
},
"env": {"API_KEY": "test123"},
}
)
with patch("pocketpaw.mcp.manager.get_mcp_manager", return_value=mock_mgr):
from pocketpaw.dashboard import install_from_registry
result = await install_from_registry(request)
assert result["status"] == "ok"
assert result["name"] == "cool-mcp"
config = mock_mgr.add_server_config.call_args[0][0]
assert config.transport == "stdio"
assert config.command == "npx"
assert "-y" in config.args
assert "@cool/mcp-server" in config.args
assert config.env == {"API_KEY": "test123"}
async def test_install_from_registry_pypi_package(self):
"""Install from registry with pypi package (uvx)."""
mock_mgr = MagicMock()
mock_mgr.add_server_config = MagicMock()
mock_mgr.start_server = AsyncMock(return_value=True)
mock_mgr.discover_tools = MagicMock(return_value=[])
request = MagicMock()
request.json = AsyncMock(
return_value={
"server": {
"name": "org/py-server",
"packages": [
{"registryType": "pypi", "name": "mcp-py-server", "runtime": "python"}
],
},
"env": {},
}
)
with patch("pocketpaw.mcp.manager.get_mcp_manager", return_value=mock_mgr):
from pocketpaw.dashboard import install_from_registry
result = await install_from_registry(request)
assert result["status"] == "ok"
config = mock_mgr.add_server_config.call_args[0][0]
assert config.command == "uvx"
assert "mcp-py-server" in config.args
async def test_install_from_registry_docker_package(self):
"""Install from registry with docker package."""
mock_mgr = MagicMock()
mock_mgr.add_server_config = MagicMock()
mock_mgr.start_server = AsyncMock(return_value=True)
mock_mgr.discover_tools = MagicMock(return_value=[])
request = MagicMock()
request.json = AsyncMock(
return_value={
"server": {
"name": "org/docker-srv",
"packages": [
{
"registryType": "docker",
"name": "ghcr.io/org/mcp-docker",
"runtimeArguments": [
{"isFixed": True, "value": "-p"},
{"isFixed": True, "value": "3000:3000"},
],
}
],
},
"env": {},
}
)
with patch("pocketpaw.mcp.manager.get_mcp_manager", return_value=mock_mgr):
from pocketpaw.dashboard import install_from_registry
result = await install_from_registry(request)
assert result["status"] == "ok"
config = mock_mgr.add_server_config.call_args[0][0]
assert config.command == "docker"
assert "run" in config.args
assert "-p" in config.args
assert "ghcr.io/org/mcp-docker" in config.args
async def test_install_from_registry_no_install_method(self):
"""POST /api/mcp/registry/install with no packages or remotes returns 400."""
from pocketpaw.dashboard import install_from_registry
request = MagicMock()
request.json = AsyncMock(
return_value={
"server": {"name": "org/empty-server"},
"env": {},
}
)
result = await install_from_registry(request)
assert result.status_code == 400

2
uv.lock generated
View File

@@ -2820,7 +2820,7 @@ wheels = [
[[package]]
name = "pocketpaw"
version = "0.4.0"
version = "0.4.1"
source = { editable = "." }
dependencies = [
{ name = "anthropic" },