diff --git a/AGENTS.md b/AGENTS.md index 8ab8793c..89213e7d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -37,6 +37,22 @@ Run the Flask API (if needed): flask --app application/app.py run --host=0.0.0.0 --port=7091 ``` +That's the fast inner-loop option — quick startup, the Werkzeug interactive +debugger still works, and it hot-reloads on source changes. It serves the +Flask routes only (`/api/*`, `/stream`, etc.). + +If you need to exercise the full ASGI stack — the `/mcp` FastMCP endpoint, +or to match the production runtime exactly — run the ASGI composition under +uvicorn instead: + +```bash +uvicorn application.asgi:asgi_app --host 0.0.0.0 --port 7091 --reload +``` + +Production uses `gunicorn -k uvicorn_worker.UvicornWorker` against the same +`application.asgi:asgi_app` target; see `application/Dockerfile` for the +full flag set. + Run the Celery worker in a separate terminal (if needed): ```bash diff --git a/application/Dockerfile b/application/Dockerfile index 48d29e57..849bcc79 100644 --- a/application/Dockerfile +++ b/application/Dockerfile @@ -88,5 +88,15 @@ EXPOSE 7091 # Switch to non-root user USER appuser -# Start Gunicorn -CMD ["gunicorn", "-w", "1", "--timeout", "120", "--bind", "0.0.0.0:7091", "--preload", "application.wsgi:app"] +CMD ["gunicorn", \ + "-w", "1", \ + "-k", "uvicorn_worker.UvicornWorker", \ + "--bind", "0.0.0.0:7091", \ + "--timeout", "180", \ + "--graceful-timeout", "120", \ + "--keep-alive", "5", \ + "--worker-tmp-dir", "/dev/shm", \ + "--max-requests", "1000", \ + "--max-requests-jitter", "100", \ + "--config", "application/gunicorn_conf.py", \ + "application.asgi:asgi_app"] diff --git a/application/api/answer/routes/search.py b/application/api/answer/routes/search.py index c5cdfaf8..d7aa2377 100644 --- a/application/api/answer/routes/search.py +++ b/application/api/answer/routes/search.py @@ -1,21 +1,21 @@ import logging -from typing import Any, Dict, List from flask import make_response, request from flask_restx import fields, 
Resource from application.api.answer.routes.base import answer_ns -from application.core.settings import settings -from application.storage.db.repositories.agents import AgentsRepository -from application.storage.db.session import db_readonly -from application.vectorstore.vector_creator import VectorCreator +from application.services.search_service import ( + InvalidAPIKey, + SearchFailed, + search, +) logger = logging.getLogger(__name__) @answer_ns.route("/api/search") class SearchResource(Resource): - """Fast search endpoint for retrieving relevant documents""" + """Fast search endpoint for retrieving relevant documents.""" search_model = answer_ns.model( "SearchModel", @@ -32,102 +32,10 @@ class SearchResource(Resource): }, ) - def _get_sources_from_api_key(self, api_key: str) -> List[str]: - """Get source IDs connected to the API key/agent.""" - with db_readonly() as conn: - agent_data = AgentsRepository(conn).find_by_key(api_key) - if not agent_data: - return [] - - source_ids: List[str] = [] - # extra_source_ids is a PG ARRAY(UUID) of source UUIDs. 
- extra = agent_data.get("extra_source_ids") or [] - for src in extra: - if src: - source_ids.append(str(src)) - - if not source_ids: - single = agent_data.get("source_id") - if single: - source_ids.append(str(single)) - - return source_ids - - def _search_vectorstores( - self, query: str, source_ids: List[str], chunks: int - ) -> List[Dict[str, Any]]: - """Search across vectorstores and return results""" - if not source_ids: - return [] - - results = [] - chunks_per_source = max(1, chunks // len(source_ids)) - seen_texts = set() - - for source_id in source_ids: - if not source_id or not source_id.strip(): - continue - - try: - docsearch = VectorCreator.create_vectorstore( - settings.VECTOR_STORE, source_id, settings.EMBEDDINGS_KEY - ) - docs = docsearch.search(query, k=chunks_per_source * 2) - - for doc in docs: - if len(results) >= chunks: - break - - if hasattr(doc, "page_content") and hasattr(doc, "metadata"): - page_content = doc.page_content - metadata = doc.metadata - else: - page_content = doc.get("text", doc.get("page_content", "")) - metadata = doc.get("metadata", {}) - - # Skip duplicates - text_hash = hash(page_content[:200]) - if text_hash in seen_texts: - continue - seen_texts.add(text_hash) - - title = metadata.get( - "title", metadata.get("post_title", "") - ) - if not isinstance(title, str): - title = str(title) if title else "" - - # Clean up title - if title: - title = title.split("/")[-1] - else: - # Use filename or first part of content as title - title = metadata.get("filename", page_content[:50] + "...") - - source = metadata.get("source", source_id) - - results.append({ - "text": page_content, - "title": title, - "source": source, - }) - - if len(results) >= chunks: - break - - except Exception as e: - logger.error( - f"Error searching vectorstore {source_id}: {e}", - exc_info=True, - ) - continue - - return results[:chunks] - @answer_ns.expect(search_model) @answer_ns.doc(description="Search for relevant documents based on query") def 
post(self): - data = request.get_json() + data = request.get_json() or {} question = data.get("question") api_key = data.get("api_key") @@ -135,32 +43,13 @@ class SearchResource(Resource): if not question: return make_response({"error": "question is required"}, 400) - if not api_key: return make_response({"error": "api_key is required"}, 400) - # Validate API key - with db_readonly() as conn: - agent = AgentsRepository(conn).find_by_key(api_key) - if not agent: - return make_response({"error": "Invalid API key"}, 401) - try: - # Get sources connected to this API key - source_ids = self._get_sources_from_api_key(api_key) - - if not source_ids: - return make_response([], 200) - - # Perform search - results = self._search_vectorstores(question, source_ids, chunks) - - return make_response(results, 200) - - except Exception as e: - logger.error( - f"/api/search - error: {str(e)}", - extra={"error": str(e)}, - exc_info=True, - ) + return make_response(search(api_key, question, chunks), 200) + except InvalidAPIKey: + return make_response({"error": "Invalid API key"}, 401) + except SearchFailed: + logger.exception("/api/search failed") return make_response({"error": "Search failed"}, 500) diff --git a/application/app.py b/application/app.py index adc55ba5..31aef718 100644 --- a/application/app.py +++ b/application/app.py @@ -4,7 +4,7 @@ import platform import uuid import dotenv -from flask import Flask, jsonify, redirect, request +from flask import Flask, Response, jsonify, redirect, request from jose import jwt from application.auth import handle_auth @@ -149,12 +149,11 @@ def authenticate_request(): @app.after_request -def after_request(response): - response.headers.add("Access-Control-Allow-Origin", "*") - response.headers.add("Access-Control-Allow-Headers", "Content-Type, Authorization") - response.headers.add( - "Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS" - ) +def after_request(response: Response) -> Response: + """Add CORS headers for the pure 
Flask development entrypoint.""" + response.headers["Access-Control-Allow-Origin"] = "*" + response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization" + response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS" return response diff --git a/application/asgi.py b/application/asgi.py new file mode 100644 index 00000000..19f8cbe7 --- /dev/null +++ b/application/asgi.py @@ -0,0 +1,33 @@ +"""ASGI entrypoint: Flask (WSGI) + FastMCP on the same process.""" + +from __future__ import annotations + +from a2wsgi import WSGIMiddleware +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.cors import CORSMiddleware +from starlette.routing import Mount + +from application.app import app as flask_app +from application.mcp_server import mcp + +_WSGI_THREADPOOL = 32 + +mcp_app = mcp.http_app(path="/") + +asgi_app = Starlette( + routes=[ + Mount("/mcp", app=mcp_app), + Mount("/", app=WSGIMiddleware(flask_app, workers=_WSGI_THREADPOOL)), + ], + middleware=[ + Middleware( + CORSMiddleware, + allow_origins=["*"], + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], + allow_headers=["Content-Type", "Authorization", "Mcp-Session-Id"], + expose_headers=["Mcp-Session-Id"], + ), + ], + lifespan=mcp_app.lifespan, +) diff --git a/application/gunicorn_conf.py b/application/gunicorn_conf.py new file mode 100644 index 00000000..4b56dc8e --- /dev/null +++ b/application/gunicorn_conf.py @@ -0,0 +1,72 @@ +"""Gunicorn config — keeps uvicorn's access log in NCSA format.""" + +from __future__ import annotations + +import logging +import logging.config + +# NCSA common log format: +# %(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s" +# Uvicorn's access formatter exposes a ``client_addr``/``request_line``/ +# ``status_code`` trio but not the full NCSA field set, so we re-derive +# what we can. 
+_NCSA_FMT = ( + '%(client_addr)s - - [%(asctime)s] "%(request_line)s" %(status_code)s' +) + +logconfig_dict = { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "ncsa_access": { + "()": "uvicorn.logging.AccessFormatter", + "fmt": _NCSA_FMT, + "datefmt": "%d/%b/%Y:%H:%M:%S %z", + "use_colors": False, + }, + "default": { + "format": "[%(asctime)s] [%(process)d] [%(levelname)s] %(name)s: %(message)s", + }, + }, + "handlers": { + "access": { + "class": "logging.StreamHandler", + "formatter": "ncsa_access", + "stream": "ext://sys.stdout", + }, + "default": { + "class": "logging.StreamHandler", + "formatter": "default", + "stream": "ext://sys.stderr", + }, + }, + "loggers": { + "uvicorn": {"handlers": ["default"], "level": "INFO", "propagate": False}, + "uvicorn.error": { + "handlers": ["default"], + "level": "INFO", + "propagate": False, + }, + "uvicorn.access": { + "handlers": ["access"], + "level": "INFO", + "propagate": False, + }, + "gunicorn.error": { + "handlers": ["default"], + "level": "INFO", + "propagate": False, + }, + "gunicorn.access": { + "handlers": ["access"], + "level": "INFO", + "propagate": False, + }, + }, + "root": {"handlers": ["default"], "level": "INFO"}, +} + + +def on_starting(server): # pragma: no cover — gunicorn hook + """Ensure gunicorn's own loggers use the configured handlers.""" + logging.config.dictConfig(logconfig_dict) diff --git a/application/mcp_server.py b/application/mcp_server.py new file mode 100644 index 00000000..23b074f1 --- /dev/null +++ b/application/mcp_server.py @@ -0,0 +1,59 @@ +"""FastMCP server exposing DocsGPT retrieval over streamable HTTP. + +Mounted at ``/mcp`` by ``application/asgi.py``. Bearer tokens are the +existing DocsGPT agent API keys — no new credential surface. + +The tool reads the ``Authorization`` header directly via +``get_http_headers(include={"authorization"})``. 
The ``include`` kwarg +is required: by default ``get_http_headers`` strips ``authorization`` +(and a handful of other hop-by-hop headers) so they aren't forwarded +to downstream services — since we deliberately want the caller's +token, we opt it back in. +""" + +from __future__ import annotations + +import asyncio +import logging + +from fastmcp import FastMCP +from fastmcp.server.dependencies import get_http_headers + +from application.services.search_service import ( + InvalidAPIKey, + SearchFailed, + search, +) + +logger = logging.getLogger(__name__) + +mcp = FastMCP("docsgpt") + + +def _extract_bearer_token() -> str | None: + auth = get_http_headers(include={"authorization"}).get("authorization", "") + parts = auth.split(None, 1) + if len(parts) != 2 or parts[0].lower() != "bearer" or not parts[1]: + return None + return parts[1] + + +@mcp.tool +async def search_docs(query: str, chunks: int = 5) -> list[dict]: + """Search the caller's DocsGPT knowledge base. + + Authentication is via ``Authorization: Bearer `` on + the MCP request — the same opaque key that ``/api/search`` accepts + in its JSON body. Returns at most ``chunks`` hits, each a dict with + ``text``, ``title``, ``source`` keys. 
+ """ + api_key = _extract_bearer_token() + if not api_key: + raise PermissionError("Missing Bearer token") + try: + return await asyncio.to_thread(search, api_key, query, chunks) + except InvalidAPIKey as exc: + raise PermissionError("Invalid API key") from exc + except SearchFailed: + logger.exception("search_docs failed") + raise diff --git a/application/requirements.txt b/application/requirements.txt index 2ff38b00..57597811 100644 --- a/application/requirements.txt +++ b/application/requirements.txt @@ -1,5 +1,7 @@ +a2wsgi==1.10.10 alembic>=1.13,<2 anthropic==0.88.0 +asgiref>=3.11.1 boto3==1.42.83 beautifulsoup4==4.14.3 cel-python==0.5.0 @@ -14,7 +16,7 @@ docx2txt==0.9 ddgs>=8.0.0 fast-ebook elevenlabs==2.43.0 -Flask==3.1.3 +Flask==3.1.1 faiss-cpu==1.13.2 fastmcp==3.2.4 flask-restx==1.3.2 @@ -76,6 +78,7 @@ requests==2.33.1 retry==0.9.2 sentence-transformers==5.3.0 sqlalchemy>=2.0,<3 +starlette>=1.0,<2 tiktoken==0.12.0 tokenizers==0.22.2 torch==2.11.0 @@ -85,6 +88,8 @@ typing-extensions==4.15.0 typing-inspect==0.9.0 tzdata==2026.1 urllib3==2.6.3 +uvicorn[standard]>=0.30,<1 +uvicorn-worker>=0.4,<1 vine==5.1.0 wcwidth==0.6.0 werkzeug>=3.1.0 diff --git a/application/services/__init__.py b/application/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/application/services/search_service.py b/application/services/search_service.py new file mode 100644 index 00000000..8febd79b --- /dev/null +++ b/application/services/search_service.py @@ -0,0 +1,153 @@ +"""Shared retrieval service used by the HTTP search route and the MCP tool. + +Flask-free. Raises domain exceptions (``InvalidAPIKey``, ``SearchFailed``) +that callers translate into their own wire protocol (HTTP status codes, +MCP error responses, etc.). 
+""" + +from __future__ import annotations + +import logging +from typing import Any, Dict, List + +from application.core.settings import settings +from application.storage.db.repositories.agents import AgentsRepository +from application.storage.db.session import db_readonly +from application.vectorstore.vector_creator import VectorCreator + +logger = logging.getLogger(__name__) + + +class InvalidAPIKey(Exception): + """The supplied ``api_key`` does not resolve to an agent.""" + + +class SearchFailed(Exception): + """Unexpected error during retrieval (e.g. DB outage). Caller maps to 5xx.""" + + +def _collect_source_ids(agent: Dict[str, Any]) -> List[str]: + """Extract the ordered list of source UUIDs to search. + + Prefers ``extra_source_ids`` (PG ARRAY(UUID) of multi-source agents); + falls back to the legacy single ``source_id`` field. + """ + source_ids: List[str] = [] + extra = agent.get("extra_source_ids") or [] + for src in extra: + if src: + source_ids.append(str(src)) + if not source_ids: + single = agent.get("source_id") + if single: + source_ids.append(str(single)) + return source_ids + + +def _search_sources( + query: str, source_ids: List[str], chunks: int +) -> List[Dict[str, Any]]: + """Search across each source's vectorstore and return up to ``chunks`` hits. + + Per-source errors are logged and skipped so one broken index doesn't + take down the whole search. Results are de-duplicated by content hash. 
+ """ + if chunks <= 0 or not source_ids: + return [] + + results: List[Dict[str, Any]] = [] + chunks_per_source = max(1, chunks // len(source_ids)) + seen_texts: set[int] = set() + + for source_id in source_ids: + if not source_id or not source_id.strip(): + continue + + try: + docsearch = VectorCreator.create_vectorstore( + settings.VECTOR_STORE, source_id, settings.EMBEDDINGS_KEY + ) + docs = docsearch.search(query, k=chunks_per_source * 2) + + for doc in docs: + if len(results) >= chunks: + break + + if hasattr(doc, "page_content") and hasattr(doc, "metadata"): + page_content = doc.page_content + metadata = doc.metadata + else: + page_content = doc.get("text", doc.get("page_content", "")) + metadata = doc.get("metadata", {}) + + text_hash = hash(page_content[:200]) + if text_hash in seen_texts: + continue + seen_texts.add(text_hash) + + title = metadata.get("title", metadata.get("post_title", "")) + if not isinstance(title, str): + title = str(title) if title else "" + + if title: + title = title.split("/")[-1] + else: + title = metadata.get("filename", page_content[:50] + "...") + + source = metadata.get("source", source_id) + + results.append( + { + "text": page_content, + "title": title, + "source": source, + } + ) + + if len(results) >= chunks: + break + + except Exception as e: + logger.error( + f"Error searching vectorstore {source_id}: {e}", + exc_info=True, + ) + continue + + return results[:chunks] + + +def search(api_key: str, query: str, chunks: int = 5) -> List[Dict[str, Any]]: + """Resolve an agent by API key and search its sources. + + Args: + api_key: Agent API key (the opaque string stored on + ``agents.key`` in Postgres). + query: Free-text search query. + chunks: Max number of hits to return. + + Returns: + List of hit dicts with ``text``, ``title``, ``source`` keys. + Empty list if the agent has no sources configured. + + Raises: + InvalidAPIKey: if ``api_key`` does not resolve to an agent. 
+ SearchFailed: on unexpected DB / infrastructure errors. + """ + if chunks <= 0: + return [] + + try: + with db_readonly() as conn: + agent = AgentsRepository(conn).find_by_key(api_key) + except Exception as e: + raise SearchFailed("agent lookup failed") from e + + if not agent: + raise InvalidAPIKey() + + source_ids = _collect_source_ids(agent) + if not source_ids: + return [] + + return _search_sources(query, source_ids, chunks) diff --git a/docs/content/Deploying/Development-Environment.mdx b/docs/content/Deploying/Development-Environment.mdx index 057ed4f1..bbe1ebc6 100644 --- a/docs/content/Deploying/Development-Environment.mdx +++ b/docs/content/Deploying/Development-Environment.mdx @@ -104,7 +104,15 @@ To run the DocsGPT backend locally, you'll need to set up a Python environment a flask --app application/app.py run --host=0.0.0.0 --port=7091 ``` - This command will launch the backend server, making it accessible on `http://localhost:7091`. + This command will launch the backend server, making it accessible on `http://localhost:7091`. It's the fastest inner-loop option for day-to-day development — the Werkzeug interactive debugger still works and it hot-reloads on source changes. It serves the Flask routes only. + + If you need to exercise the full ASGI stack — the `/mcp` endpoint (FastMCP server), or to match the production runtime — run the ASGI composition under uvicorn instead: + + ```bash + uvicorn application.asgi:asgi_app --host 0.0.0.0 --port 7091 --reload + ``` + + Production uses `gunicorn -k uvicorn_worker.UvicornWorker` against the same `application.asgi:asgi_app` target. 6. 
**Start the Celery Worker:** diff --git a/pytest.ini b/pytest.ini index bfe52165..860b86a7 100644 --- a/pytest.ini +++ b/pytest.ini @@ -16,6 +16,7 @@ markers = unit: Unit tests integration: Integration tests slow: Slow running tests +asyncio_mode = strict filterwarnings = ignore::DeprecationWarning ignore::PendingDeprecationWarning diff --git a/scripts/mock_llm.py b/scripts/mock_llm.py new file mode 100644 index 00000000..8ce556ee --- /dev/null +++ b/scripts/mock_llm.py @@ -0,0 +1,137 @@ +"""Mock OpenAI-compatible LLM server for benchmarking. + +Fixed 5-second generation (100 tokens × 50 ms/token). No auth. Emits SSE +chunks in OpenAI's chat.completions streaming format, or a single response +when stream=false. Run on 127.0.0.1:8090 — point DocsGPT at it via +OPENAI_BASE_URL=http://127.0.0.1:8090/v1. +""" + +import asyncio +import json +import logging +import time +import uuid + +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, StreamingResponse + +TOKEN_COUNT = 100 +TOKEN_DELAY_S = 0.05 # 100 * 0.05 = 5.0 s + +logger = logging.getLogger("mock_llm") +logging.basicConfig(level=logging.INFO, format="%(asctime)s mock: %(message)s") + +FILLER_TOKENS = [ + "Lorem", " ipsum", " dolor", " sit", " amet", ",", " consectetur", + " adipiscing", " elit", ".", " Sed", " do", " eiusmod", " tempor", + " incididunt", " ut", " labore", " et", " dolore", " magna", " aliqua", + ".", " Ut", " enim", " ad", " minim", " veniam", ",", " quis", " nostrud", + " exercitation", " ullamco", " laboris", " nisi", " ut", " aliquip", + " ex", " ea", " commodo", " consequat", ".", " Duis", " aute", " irure", + " dolor", " in", " reprehenderit", " in", " voluptate", " velit", + " esse", " cillum", " dolore", " eu", " fugiat", " nulla", " pariatur", + ".", " Excepteur", " sint", " occaecat", " cupidatat", " non", " proident", + ",", " sunt", " in", " culpa", " qui", " officia", " deserunt", + " mollit", " anim", " id", " est", " laborum", ".", " Curabitur", + " 
pretium", " tincidunt", " lacus", ".", " Nulla", " gravida", " orci", + " a", " odio", ".", " Nullam", " varius", ",", " turpis", " et", + " commodo", " pharetra", ",", " est", " eros", " bibendum", " elit", + ".", +] + +app = FastAPI() + + +def _token_stream_id() -> str: + return f"chatcmpl-mock-{uuid.uuid4().hex[:12]}" + + +def _sse_chunk(completion_id: str, model: str, delta: dict, finish_reason=None) -> str: + payload = { + "id": completion_id, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": delta, + "finish_reason": finish_reason, + } + ], + } + return f"data: {json.dumps(payload)}\n\n" + + +async def _stream_response(model: str, req_id: str): + completion_id = _token_stream_id() + yield _sse_chunk(completion_id, model, {"role": "assistant", "content": ""}) + for i, tok in enumerate(FILLER_TOKENS[:TOKEN_COUNT]): + await asyncio.sleep(TOKEN_DELAY_S) + yield _sse_chunk(completion_id, model, {"content": tok}) + yield _sse_chunk(completion_id, model, {}, finish_reason="stop") + yield "data: [DONE]\n\n" + logger.info("[%s] stream done", req_id) + + +@app.post("/v1/chat/completions") +async def chat_completions(request: Request): + body = await request.json() + model = body.get("model", "mock") + stream = bool(body.get("stream", False)) + req_id = uuid.uuid4().hex[:8] + logger.info("[%s] /chat/completions stream=%s model=%s max_tokens=%s", req_id, stream, model, body.get("max_tokens")) + + if stream: + return StreamingResponse( + _stream_response(model, req_id), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache, no-transform", + "X-Accel-Buffering": "no", + }, + ) + + await asyncio.sleep(TOKEN_COUNT * TOKEN_DELAY_S) + logger.info("[%s] non-stream done", req_id) + text = "".join(FILLER_TOKENS[:TOKEN_COUNT]) + completion_id = _token_stream_id() + return JSONResponse( + { + "id": completion_id, + "object": "chat.completion", + "created": int(time.time()), + 
"model": model, + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": text}, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": TOKEN_COUNT, + "total_tokens": 10 + TOKEN_COUNT, + }, + } + ) + + +@app.get("/v1/models") +async def list_models(): + return { + "object": "list", + "data": [{"id": "mock", "object": "model", "owned_by": "mock"}], + } + + +@app.get("/health") +async def health(): + return {"status": "ok"} + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="127.0.0.1", port=8090, log_level="info") diff --git a/tests/api/answer/routes/test_search.py b/tests/api/answer/routes/test_search.py index 6f38934c..73bda7dc 100644 --- a/tests/api/answer/routes/test_search.py +++ b/tests/api/answer/routes/test_search.py @@ -1,3 +1,17 @@ +"""Tests for /api/search route (application/api/answer/routes/search.py). + +Retrieval logic lives in ``application/services/search_service.py`` and +has its own unit tests in ``tests/services/test_search_service.py``. The +tests below focus on what the route specifically owns: + +* Request validation (400 for missing fields). +* Translation of the service's ``InvalidAPIKey`` / ``SearchFailed`` + exceptions to HTTP status codes (401 / 500). +* End-to-end happy path against a real ephemeral Postgres via + ``pg_conn``, to catch regressions in the route's wiring to the + service and repositories. 
+""" + from contextlib import contextmanager from unittest.mock import MagicMock, patch @@ -6,254 +20,97 @@ import pytest @pytest.mark.unit class TestSearchResourceValidation: - pass - - def test_returns_error_when_question_missing(self, mock_mongo_db, flask_app): + def test_returns_400_when_question_missing(self, flask_app): from application.api.answer.routes.search import SearchResource with flask_app.app_context(): - with flask_app.test_request_context( - json={"api_key": "test_key"} - ): - resource = SearchResource() - result = resource.post() - + with flask_app.test_request_context(json={"api_key": "test_key"}): + result = SearchResource().post() assert result.status_code == 400 assert "question" in result.json["error"] - def test_returns_error_when_api_key_missing(self, mock_mongo_db, flask_app): + def test_returns_400_when_api_key_missing(self, flask_app): from application.api.answer.routes.search import SearchResource with flask_app.app_context(): - with flask_app.test_request_context( - json={"question": "test query"} - ): - resource = SearchResource() - result = resource.post() - + with flask_app.test_request_context(json={"question": "test query"}): + result = SearchResource().post() assert result.status_code == 400 assert "api_key" in result.json["error"] - @pytest.mark.unit -class TestGetSourcesFromApiKey: - pass +class TestSearchResourceExceptionMapping: + """Verify the route maps service exceptions to HTTP status codes. - def test_returns_source_id_via_patched_method(self, mock_mongo_db, flask_app): - """Test that _get_sources_from_api_key can return multiple sources via patch.""" + The service function itself is patched; these tests do not care about + the search logic — only that 401/500/200 are produced correctly from + the three possible service outcomes. 
+ """ + + def test_invalid_api_key_returns_401(self, flask_app): + from application.api.answer.routes.search import SearchResource + from application.services.search_service import InvalidAPIKey + + with flask_app.app_context(), flask_app.test_request_context( + json={"question": "q", "api_key": "bad"} + ), patch( + "application.api.answer.routes.search.search", + side_effect=InvalidAPIKey(), + ): + result = SearchResource().post() + assert result.status_code == 401 + assert result.json == {"error": "Invalid API key"} + + def test_search_failed_returns_500(self, flask_app): + from application.api.answer.routes.search import SearchResource + from application.services.search_service import SearchFailed + + with flask_app.app_context(), flask_app.test_request_context( + json={"question": "q", "api_key": "k"} + ), patch( + "application.api.answer.routes.search.search", + side_effect=SearchFailed("boom"), + ): + result = SearchResource().post() + assert result.status_code == 500 + assert result.json == {"error": "Search failed"} + + def test_happy_path_passes_service_result_through(self, flask_app): from application.api.answer.routes.search import SearchResource - with flask_app.app_context(): - resource = SearchResource() + hits = [{"text": "t", "title": "T", "source": "s"}] + with flask_app.app_context(), flask_app.test_request_context( + json={"question": "q", "api_key": "k", "chunks": 7} + ), patch( + "application.api.answer.routes.search.search", + return_value=hits, + ) as mock_search: + result = SearchResource().post() + assert result.status_code == 200 + assert result.json == hits + mock_search.assert_called_once_with("k", "q", 7) - with patch.object(resource, "_get_sources_from_api_key", return_value=["src1", "src2"]): - result = resource._get_sources_from_api_key("any_key") - - assert len(result) == 2 - assert "src1" in result - assert "src2" in result - - - -@pytest.mark.unit -class TestSearchVectorstores: - pass - - def 
test_returns_empty_when_no_source_ids(self, mock_mongo_db, flask_app): + def test_default_chunks_is_5(self, flask_app): from application.api.answer.routes.search import SearchResource - with flask_app.app_context(): - resource = SearchResource() - - result = resource._search_vectorstores("test query", [], 5) - - assert result == [] - - def test_skips_empty_source_ids(self, mock_mongo_db, flask_app): - from application.api.answer.routes.search import SearchResource - - with flask_app.app_context(): - resource = SearchResource() - - with patch( - "application.api.answer.routes.search.VectorCreator.create_vectorstore" - ) as mock_create: - mock_vectorstore = MagicMock() - mock_vectorstore.search.return_value = [] - mock_create.return_value = mock_vectorstore - - result = resource._search_vectorstores("test query", ["", " "], 5) - - mock_create.assert_not_called() - assert result == [] - - def test_returns_search_results(self, mock_mongo_db, flask_app): - from application.api.answer.routes.search import SearchResource - - with flask_app.app_context(): - resource = SearchResource() - - mock_doc = { - "text": "Test content", - "page_content": "Test content", - "metadata": { - "title": "Test Title", - "source": "/path/to/doc", - }, - } - - with patch( - "application.api.answer.routes.search.VectorCreator.create_vectorstore" - ) as mock_create: - mock_vectorstore = MagicMock() - mock_vectorstore.search.return_value = [mock_doc] - mock_create.return_value = mock_vectorstore - - result = resource._search_vectorstores("test query", ["source_id"], 5) - - assert len(result) == 1 - assert result[0]["text"] == "Test content" - assert result[0]["title"] == "Test Title" - assert result[0]["source"] == "/path/to/doc" - - def test_handles_langchain_document_format(self, mock_mongo_db, flask_app): - from application.api.answer.routes.search import SearchResource - - with flask_app.app_context(): - resource = SearchResource() - - mock_doc = MagicMock() - mock_doc.page_content = 
"Langchain content" - mock_doc.metadata = {"title": "LC Title", "source": "/lc/path"} - - with patch( - "application.api.answer.routes.search.VectorCreator.create_vectorstore" - ) as mock_create: - mock_vectorstore = MagicMock() - mock_vectorstore.search.return_value = [mock_doc] - mock_create.return_value = mock_vectorstore - - result = resource._search_vectorstores("test query", ["source_id"], 5) - - assert len(result) == 1 - assert result[0]["text"] == "Langchain content" - assert result[0]["title"] == "LC Title" - - def test_respects_chunks_limit(self, mock_mongo_db, flask_app): - from application.api.answer.routes.search import SearchResource - - with flask_app.app_context(): - resource = SearchResource() - - mock_docs = [ - {"text": f"Content {i}", "metadata": {"title": f"Title {i}"}} - for i in range(10) - ] - - with patch( - "application.api.answer.routes.search.VectorCreator.create_vectorstore" - ) as mock_create: - mock_vectorstore = MagicMock() - mock_vectorstore.search.return_value = mock_docs - mock_create.return_value = mock_vectorstore - - result = resource._search_vectorstores("test query", ["source_id"], 3) - - assert len(result) == 3 - - def test_deduplicates_results(self, mock_mongo_db, flask_app): - from application.api.answer.routes.search import SearchResource - - with flask_app.app_context(): - resource = SearchResource() - - duplicate_text = "Duplicate content " * 20 - mock_docs = [ - {"text": duplicate_text, "metadata": {"title": "Title 1"}}, - {"text": duplicate_text, "metadata": {"title": "Title 2"}}, - {"text": "Unique content", "metadata": {"title": "Title 3"}}, - ] - - with patch( - "application.api.answer.routes.search.VectorCreator.create_vectorstore" - ) as mock_create: - mock_vectorstore = MagicMock() - mock_vectorstore.search.return_value = mock_docs - mock_create.return_value = mock_vectorstore - - result = resource._search_vectorstores("test query", ["source_id"], 5) - - assert len(result) == 2 - - def 
test_handles_vectorstore_error_gracefully(self, mock_mongo_db, flask_app): - from application.api.answer.routes.search import SearchResource - - with flask_app.app_context(): - resource = SearchResource() - - with patch( - "application.api.answer.routes.search.VectorCreator.create_vectorstore" - ) as mock_create: - mock_create.side_effect = Exception("Vectorstore error") - - result = resource._search_vectorstores("test query", ["source_id"], 5) - - assert result == [] - - def test_uses_filename_as_title_fallback(self, mock_mongo_db, flask_app): - from application.api.answer.routes.search import SearchResource - - with flask_app.app_context(): - resource = SearchResource() - - mock_doc = { - "text": "Content without title", - "metadata": {"filename": "document.pdf"}, - } - - with patch( - "application.api.answer.routes.search.VectorCreator.create_vectorstore" - ) as mock_create: - mock_vectorstore = MagicMock() - mock_vectorstore.search.return_value = [mock_doc] - mock_create.return_value = mock_vectorstore - - result = resource._search_vectorstores("test query", ["source_id"], 5) - - assert result[0]["title"] == "document.pdf" - - def test_uses_content_snippet_as_title_last_resort(self, mock_mongo_db, flask_app): - from application.api.answer.routes.search import SearchResource - - with flask_app.app_context(): - resource = SearchResource() - - mock_doc = { - "text": "Content without any title metadata at all", - "metadata": {}, - } - - with patch( - "application.api.answer.routes.search.VectorCreator.create_vectorstore" - ) as mock_create: - mock_vectorstore = MagicMock() - mock_vectorstore.search.return_value = [mock_doc] - mock_create.return_value = mock_vectorstore - - result = resource._search_vectorstores("test query", ["source_id"], 5) - - assert "Content without any title" in result[0]["title"] - assert result[0]["title"].endswith("...") - - -@pytest.mark.unit -class TestSearchEndpoint: - pass + with flask_app.app_context(), flask_app.test_request_context( 
+ json={"question": "q", "api_key": "k"} # no chunks field + ), patch( + "application.api.answer.routes.search.search", + return_value=[], + ) as mock_search: + SearchResource().post() + mock_search.assert_called_once_with("k", "q", 5) # --------------------------------------------------------------------------- -# Real-PG tests for SearchResource. +# End-to-end against a real ephemeral Postgres. +# +# These exercise the full route → service → repository → DB path, patching +# only ``VectorCreator.create_vectorstore`` (so we don't need real embeddings +# or a vector index). ``db_readonly`` is redirected at the *service* module +# since that's where the import now lives. # --------------------------------------------------------------------------- @@ -264,7 +121,7 @@ def _patch_search_db(conn): yield conn with patch( - "application.api.answer.routes.search.db_readonly", _yield + "application.services.search_service.db_readonly", _yield ): yield @@ -298,9 +155,7 @@ class TestSearchResourcePgConn: def test_search_returns_results(self, pg_conn, flask_app): from application.api.answer.routes.search import SearchResource from application.storage.db.repositories.agents import AgentsRepository - from application.storage.db.repositories.sources import ( - SourcesRepository, - ) + from application.storage.db.repositories.sources import SourcesRepository src = SourcesRepository(pg_conn).create("src", user_id="u") AgentsRepository(pg_conn).create( @@ -315,7 +170,7 @@ class TestSearchResourcePgConn: ] with _patch_search_db(pg_conn), patch( - "application.api.answer.routes.search.VectorCreator.create_vectorstore", + "application.services.search_service.VectorCreator.create_vectorstore", return_value=fake_vs, ), flask_app.app_context(): with flask_app.test_request_context( @@ -328,9 +183,7 @@ class TestSearchResourcePgConn: def test_search_uses_extra_source_ids(self, pg_conn, flask_app): from application.api.answer.routes.search import SearchResource from 
application.storage.db.repositories.agents import AgentsRepository - from application.storage.db.repositories.sources import ( - SourcesRepository, - ) + from application.storage.db.repositories.sources import SourcesRepository src1 = SourcesRepository(pg_conn).create("s1", user_id="u") src2 = SourcesRepository(pg_conn).create("s2", user_id="u") @@ -345,7 +198,7 @@ class TestSearchResourcePgConn: {"text": "one", "metadata": {"title": "A"}}, ] with _patch_search_db(pg_conn), patch( - "application.api.answer.routes.search.VectorCreator.create_vectorstore", + "application.services.search_service.VectorCreator.create_vectorstore", return_value=fake_vs, ), flask_app.app_context(): with flask_app.test_request_context( @@ -353,71 +206,3 @@ class TestSearchResourcePgConn: ): result = SearchResource().post() assert result.status_code == 200 - - def test_search_exception_returns_500(self, pg_conn, flask_app): - from application.api.answer.routes.search import SearchResource - from application.storage.db.repositories.agents import AgentsRepository - from application.storage.db.repositories.sources import ( - SourcesRepository, - ) - - src = SourcesRepository(pg_conn).create("src", user_id="u") - AgentsRepository(pg_conn).create( - "u", "a", "published", - key="err-key", - source_id=str(src["id"]), - ) - - with _patch_search_db(pg_conn), patch( - "application.api.answer.routes.search.SearchResource._get_sources_from_api_key", - side_effect=RuntimeError("boom"), - ), flask_app.app_context(): - with flask_app.test_request_context( - json={"question": "q", "api_key": "err-key"}, - ): - result = SearchResource().post() - assert result.status_code == 500 - - -class TestGetSourcesFromApiKeyPg: - def test_empty_for_unknown_key(self, pg_conn, flask_app): - from application.api.answer.routes.search import SearchResource - - with _patch_search_db(pg_conn), flask_app.app_context(): - got = SearchResource()._get_sources_from_api_key("nope") - assert got == [] - - def 
test_returns_extra_source_ids(self, pg_conn, flask_app): - from application.api.answer.routes.search import SearchResource - from application.storage.db.repositories.agents import AgentsRepository - from application.storage.db.repositories.sources import ( - SourcesRepository, - ) - - src = SourcesRepository(pg_conn).create("s", user_id="u") - AgentsRepository(pg_conn).create( - "u", "a", "published", - key="sources-key", - extra_source_ids=[str(src["id"])], - ) - with _patch_search_db(pg_conn), flask_app.app_context(): - got = SearchResource()._get_sources_from_api_key("sources-key") - assert got == [str(src["id"])] - - def test_falls_back_to_single_source(self, pg_conn, flask_app): - from application.api.answer.routes.search import SearchResource - from application.storage.db.repositories.agents import AgentsRepository - from application.storage.db.repositories.sources import ( - SourcesRepository, - ) - - src = SourcesRepository(pg_conn).create("s", user_id="u") - AgentsRepository(pg_conn).create( - "u", "a", "published", - key="single-key", - source_id=str(src["id"]), - ) - with _patch_search_db(pg_conn), flask_app.app_context(): - got = SearchResource()._get_sources_from_api_key("single-key") - assert got == [str(src["id"])] - diff --git a/tests/requirements.txt b/tests/requirements.txt index b555cbf8..2af47c93 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,4 +1,5 @@ pytest>=8.0.0 +pytest-asyncio>=0.23 pytest-cov>=4.1.0 coverage>=7.4.0 pytest-postgresql>=6.0.0 diff --git a/tests/services/__init__.py b/tests/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/services/test_mcp_server.py b/tests/services/test_mcp_server.py new file mode 100644 index 00000000..c1da1cbd --- /dev/null +++ b/tests/services/test_mcp_server.py @@ -0,0 +1,134 @@ +"""Tests for application/mcp_server.py. 
+ +The server module exposes one FastMCP tool, ``search_docs``, that reads +the caller's ``Authorization: Bearer <token>`` header via +``get_http_headers()`` and delegates to +``application.services.search_service.search``. These tests exercise +the tool directly by patching ``get_http_headers`` and ``search``; the +full HTTP-layer plumbing (mount, lifespan, session handshake) is +covered by ``tests/test_asgi.py``. +""" + +from unittest.mock import patch + +import pytest + + +@pytest.mark.unit +class TestSearchDocsTool: + @pytest.mark.asyncio + async def test_missing_bearer_raises_permission_error(self): + from application.mcp_server import search_docs + + with patch( + "application.mcp_server.get_http_headers", return_value={} + ): + with pytest.raises(PermissionError): + await search_docs(query="hi") + + @pytest.mark.asyncio + async def test_non_bearer_header_raises_permission_error(self): + from application.mcp_server import search_docs + + with patch( + "application.mcp_server.get_http_headers", + return_value={"authorization": "Basic dXNlcjpwYXNz"}, + ): + with pytest.raises(PermissionError): + await search_docs(query="hi") + + @pytest.mark.asyncio + async def test_blank_bearer_token_raises_permission_error(self): + from application.mcp_server import search_docs + + with patch( + "application.mcp_server.get_http_headers", + return_value={"authorization": "Bearer "}, + ): + with pytest.raises(PermissionError): + await search_docs(query="hi") + + @pytest.mark.asyncio + async def test_invalid_api_key_raises_permission_error(self): + from application.mcp_server import search_docs + from application.services.search_service import InvalidAPIKey + + with ( + patch( + "application.mcp_server.get_http_headers", + return_value={"authorization": "Bearer bogus"}, + ), + patch( + "application.mcp_server.search", side_effect=InvalidAPIKey() + ), + ): + with pytest.raises(PermissionError): + await search_docs(query="hi") + + @pytest.mark.asyncio + async def 
test_search_failed_bubbles_up(self): + from application.mcp_server import search_docs + from application.services.search_service import SearchFailed + + with ( + patch( + "application.mcp_server.get_http_headers", + return_value={"authorization": "Bearer k"}, + ), + patch( + "application.mcp_server.search", + side_effect=SearchFailed("boom"), + ), + ): + with pytest.raises(SearchFailed): + await search_docs(query="hi") + + @pytest.mark.asyncio + async def test_happy_path_passes_args_and_returns_hits(self): + from application.mcp_server import search_docs + + hits = [{"text": "t", "title": "T", "source": "s"}] + with ( + patch( + "application.mcp_server.get_http_headers", + return_value={"authorization": "Bearer the-key"}, + ), + patch( + "application.mcp_server.search", return_value=hits + ) as mock_search, + ): + out = await search_docs(query="q", chunks=7) + assert out == hits + mock_search.assert_called_once_with("the-key", "q", 7) + + @pytest.mark.asyncio + async def test_default_chunks_is_5(self): + from application.mcp_server import search_docs + + with ( + patch( + "application.mcp_server.get_http_headers", + return_value={"authorization": "Bearer k"}, + ), + patch( + "application.mcp_server.search", return_value=[] + ) as mock_search, + ): + await search_docs(query="q") + mock_search.assert_called_once_with("k", "q", 5) + + @pytest.mark.asyncio + async def test_bearer_scheme_case_insensitive(self): + from application.mcp_server import search_docs + + with ( + patch( + "application.mcp_server.get_http_headers", + return_value={"authorization": "bearer lowercase-scheme"}, + ), + patch( + "application.mcp_server.search", return_value=[] + ) as mock_search, + ): + await search_docs(query="q") + mock_search.assert_called_once_with("lowercase-scheme", "q", 5) diff --git a/tests/services/test_search_service.py b/tests/services/test_search_service.py new file mode 100644 index 00000000..2b725404 --- /dev/null +++ b/tests/services/test_search_service.py @@ -0,0 
+1,230 @@ +"""Unit tests for application/services/search_service.py. + +Tests exercise the service function in isolation — AgentsRepository is +stubbed via a patched ``db_readonly`` context manager, and +``VectorCreator.create_vectorstore`` is patched to return a fake +vectorstore. No Flask app context, no real DB, no real embeddings. +""" + +from contextlib import contextmanager +from unittest.mock import MagicMock, patch + +import pytest + +from application.services.search_service import ( + InvalidAPIKey, + SearchFailed, + _collect_source_ids, + search, +) + + +@contextmanager +def _fake_db_readonly(agent_data): + """Patch ``db_readonly`` so ``AgentsRepository.find_by_key`` returns ``agent_data``.""" + agents_repo = MagicMock() + agents_repo.find_by_key.return_value = agent_data + + @contextmanager + def _yield_conn(): + yield MagicMock() + + with patch( + "application.services.search_service.db_readonly", _yield_conn + ), patch( + "application.services.search_service.AgentsRepository", + return_value=agents_repo, + ): + yield + + +@pytest.mark.unit +class TestCollectSourceIds: + def test_empty_when_no_sources(self): + assert _collect_source_ids({}) == [] + + def test_returns_extra_source_ids(self): + agent = {"extra_source_ids": ["s1", "s2"], "source_id": "legacy"} + assert _collect_source_ids(agent) == ["s1", "s2"] + + def test_falls_back_to_single_source_id(self): + agent = {"extra_source_ids": [], "source_id": "s1"} + assert _collect_source_ids(agent) == ["s1"] + + def test_skips_empty_entries_in_extra(self): + agent = {"extra_source_ids": ["", None, "s1"], "source_id": "fallback"} + assert _collect_source_ids(agent) == ["s1"] + + +@pytest.mark.unit +class TestSearchInvalidAPIKey: + def test_raises_when_key_unknown(self): + with _fake_db_readonly(None): + with pytest.raises(InvalidAPIKey): + search("does-not-exist", "hello", 5) + + def test_raises_search_failed_on_db_error(self): + @contextmanager + def _yield_conn(): + yield MagicMock() + + agents_repo = 
MagicMock() + agents_repo.find_by_key.side_effect = RuntimeError("db down") + + with patch( + "application.services.search_service.db_readonly", _yield_conn + ), patch( + "application.services.search_service.AgentsRepository", + return_value=agents_repo, + ): + with pytest.raises(SearchFailed): + search("any-key", "hello", 5) + + +@pytest.mark.unit +class TestSearchEmptyWhenNoSources: + def test_returns_empty_when_agent_has_no_sources(self): + with _fake_db_readonly({"extra_source_ids": [], "source_id": None}): + assert search("k", "q", 5) == [] + + def test_returns_empty_for_zero_chunks_without_db_lookup(self): + with patch("application.services.search_service.db_readonly") as mock_db: + assert search("k", "q", 0) == [] + mock_db.assert_not_called() + + def test_returns_empty_for_negative_chunks_without_db_lookup(self): + with patch("application.services.search_service.db_readonly") as mock_db: + assert search("k", "q", -1) == [] + mock_db.assert_not_called() + + +@pytest.mark.unit +class TestSearchResults: + def test_returns_hit_shape(self): + agent = {"source_id": "src-1", "extra_source_ids": []} + fake_vs = MagicMock() + fake_vs.search.return_value = [ + { + "text": "Test content", + "metadata": {"title": "Test Title", "source": "/path/to/doc"}, + } + ] + with _fake_db_readonly(agent), patch( + "application.services.search_service.VectorCreator.create_vectorstore", + return_value=fake_vs, + ): + results = search("k", "q", 5) + assert results == [ + {"text": "Test content", "title": "Test Title", "source": "/path/to/doc"} + ] + + def test_handles_langchain_document_format(self): + agent = {"source_id": "src-1", "extra_source_ids": []} + lc_doc = MagicMock() + lc_doc.page_content = "Langchain content" + lc_doc.metadata = {"title": "LC Title", "source": "/lc/path"} + + fake_vs = MagicMock() + fake_vs.search.return_value = [lc_doc] + + with _fake_db_readonly(agent), patch( + "application.services.search_service.VectorCreator.create_vectorstore", + 
return_value=fake_vs, + ): + results = search("k", "q", 5) + assert len(results) == 1 + assert results[0]["text"] == "Langchain content" + assert results[0]["title"] == "LC Title" + + def test_respects_chunks_cap(self): + agent = {"source_id": "src-1", "extra_source_ids": []} + docs = [ + {"text": f"Content {i}", "metadata": {"title": f"T{i}"}} + for i in range(10) + ] + fake_vs = MagicMock() + fake_vs.search.return_value = docs + + with _fake_db_readonly(agent), patch( + "application.services.search_service.VectorCreator.create_vectorstore", + return_value=fake_vs, + ): + results = search("k", "q", 3) + assert len(results) == 3 + + def test_deduplicates_results_by_content_prefix(self): + agent = {"source_id": "src-1", "extra_source_ids": []} + dup_text = "Duplicate content " * 20 + docs = [ + {"text": dup_text, "metadata": {"title": "T1"}}, + {"text": dup_text, "metadata": {"title": "T2"}}, + {"text": "Unique content", "metadata": {"title": "T3"}}, + ] + fake_vs = MagicMock() + fake_vs.search.return_value = docs + + with _fake_db_readonly(agent), patch( + "application.services.search_service.VectorCreator.create_vectorstore", + return_value=fake_vs, + ): + results = search("k", "q", 5) + assert len(results) == 2 + + def test_skips_broken_source_and_returns_from_healthy_ones(self): + # Two sources — the first raises, the second returns a doc. The + # caller should still get the healthy source's result. 
+ agent = {"extra_source_ids": ["broken", "ok"], "source_id": None} + healthy_vs = MagicMock() + healthy_vs.search.return_value = [ + {"text": "ok content", "metadata": {"title": "Ok"}} + ] + + def create_vs(store, source_id, key): + if source_id == "broken": + raise RuntimeError("vector index missing") + return healthy_vs + + with _fake_db_readonly(agent), patch( + "application.services.search_service.VectorCreator.create_vectorstore", + side_effect=create_vs, + ): + results = search("k", "q", 5) + assert len(results) == 1 + assert results[0]["text"] == "ok content" + + def test_uses_filename_when_title_missing(self): + agent = {"source_id": "src-1", "extra_source_ids": []} + fake_vs = MagicMock() + fake_vs.search.return_value = [ + {"text": "body", "metadata": {"filename": "document.pdf"}} + ] + with _fake_db_readonly(agent), patch( + "application.services.search_service.VectorCreator.create_vectorstore", + return_value=fake_vs, + ): + results = search("k", "q", 5) + assert results[0]["title"] == "document.pdf" + + def test_uses_content_snippet_as_title_last_resort(self): + agent = {"source_id": "src-1", "extra_source_ids": []} + fake_vs = MagicMock() + fake_vs.search.return_value = [ + {"text": "Content without any title metadata at all", "metadata": {}} + ] + with _fake_db_readonly(agent), patch( + "application.services.search_service.VectorCreator.create_vectorstore", + return_value=fake_vs, + ): + results = search("k", "q", 5) + assert results[0]["title"].endswith("...") + assert "Content without any title" in results[0]["title"] + + def test_skips_empty_source_ids(self): + # ``source_id=" "`` only — after strip() this leaves no real source. 
+ agent = {"extra_source_ids": [" ", ""], "source_id": None} + with _fake_db_readonly(agent), patch( + "application.services.search_service.VectorCreator.create_vectorstore" + ) as mock_create: + results = search("k", "q", 5) + mock_create.assert_not_called() + assert results == [] diff --git a/tests/test_app_routes.py b/tests/test_app_routes.py index eb5fe26f..656b550b 100644 --- a/tests/test_app_routes.py +++ b/tests/test_app_routes.py @@ -105,11 +105,26 @@ class TestAuthenticateRequest: assert response.status_code == 200 -class TestAfterRequest: +class TestFlaskCors: @pytest.mark.unit - def test_cors_headers(self, client): - response = client.get("/api/health") - assert response.headers.get("Access-Control-Allow-Origin") == "*" - assert "Content-Type" in response.headers.get("Access-Control-Allow-Headers", "") - assert "GET" in response.headers.get("Access-Control-Allow-Methods", "") + def test_cors_headers_on_flask_route(self, client): + response = client.get("/api/health", headers={"Origin": "http://localhost:5173"}) + assert response.headers["Access-Control-Allow-Origin"] == "*" + assert response.headers["Access-Control-Allow-Headers"] == "Content-Type, Authorization" + assert response.headers["Access-Control-Allow-Methods"] == "GET, POST, PUT, DELETE, OPTIONS" + + @pytest.mark.unit + def test_cors_headers_on_flask_preflight(self, client): + response = client.options( + "/api/health", + headers={ + "Origin": "http://localhost:5173", + "Access-Control-Request-Method": "GET", + "Access-Control-Request-Headers": "Content-Type", + }, + ) + assert response.status_code == 200 + assert response.headers["Access-Control-Allow-Origin"] == "*" + assert response.headers["Access-Control-Allow-Headers"] == "Content-Type, Authorization" + assert response.headers["Access-Control-Allow-Methods"] == "GET, POST, PUT, DELETE, OPTIONS" diff --git a/tests/test_asgi.py b/tests/test_asgi.py new file mode 100644 index 00000000..57fd4fc9 --- /dev/null +++ b/tests/test_asgi.py @@ -0,0 
+1,136 @@ +"""Smoke tests for application/asgi.py. + +The goal isn't to re-test Flask or FastMCP internals — it's to catch +regressions in the wiring: mounts resolve, CORS headers emit, lifespan +runs (without it, the /mcp session manager raises "Task group is not +initialized"), routing to ``/`` vs ``/mcp`` doesn't cross paths. + +Uses ``starlette.testclient.TestClient`` because it boots the ASGI app +end-to-end and handles the lifespan protocol automatically — ``httpx`` +alone does not run lifespan events, which would mask the exact kind of +misconfiguration this test suite exists to catch. +""" + +import pytest + + +@pytest.mark.unit +def test_asgi_app_imports(): + from application.asgi import asgi_app + + assert asgi_app is not None + + +@pytest.mark.unit +def test_flask_route_served_through_starlette_mount(): + """GET /api/health should reach the Flask app via a2wsgi and return 200.""" + from starlette.testclient import TestClient + + from application.asgi import asgi_app + + with TestClient(asgi_app) as client: + r = client.get("/api/health") + assert r.status_code == 200 + assert r.json() == {"status": "ok"} + + +@pytest.mark.unit +def test_mcp_endpoint_mounted_and_lifespan_runs(): + """/mcp must be reachable AND the FastMCP session manager must start. + + Without ``lifespan=mcp_app.lifespan`` on the outer Starlette app, + every /mcp request raises ``RuntimeError: Task group is not + initialized``. Hitting the endpoint under a real lifespan-aware + client catches that. + """ + from starlette.testclient import TestClient + + from application.asgi import asgi_app + + with TestClient(asgi_app) as client: + # Minimal MCP initialize request. Doesn't need to succeed — we + # just need a non-404, non-500-with-RuntimeError response to + # confirm the mount + lifespan are both wired. 
+ r = client.post( + "/mcp/", + headers={ + "Origin": "http://example.com", + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + }, + json={ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2025-03-26", + "capabilities": {}, + "clientInfo": {"name": "pytest", "version": "0"}, + }, + }, + ) + assert r.status_code != 404, f"/mcp mount unreachable: {r.status_code}" + # A successful initialize returns 200 with a Mcp-Session-Id header. + assert r.status_code == 200 + assert "mcp-session-id" in {k.lower() for k in r.headers.keys()} + assert r.headers.get("access-control-expose-headers") == "Mcp-Session-Id" + + +@pytest.mark.unit +def test_cors_headers_on_flask_route(): + """CORS middleware should emit allow-origin on actual (non-preflight) requests. + + ``allow_origins=["*"]`` → header value is literal ``*`` (not an echo). + """ + from starlette.testclient import TestClient + + from application.asgi import asgi_app + + with TestClient(asgi_app) as client: + r = client.get("/api/health", headers={"Origin": "http://example.com"}) + assert r.status_code == 200 + assert r.headers.get("access-control-allow-origin") == "*" + + +@pytest.mark.unit +def test_cors_preflight_on_flask_route(): + """OPTIONS preflight on a Flask route should be handled by Starlette CORSMiddleware.""" + from starlette.testclient import TestClient + + from application.asgi import asgi_app + + with TestClient(asgi_app) as client: + r = client.options( + "/api/health", + headers={ + "Origin": "http://example.com", + "Access-Control-Request-Method": "GET", + "Access-Control-Request-Headers": "Content-Type", + }, + ) + assert r.status_code in (200, 204) + assert r.headers.get("access-control-allow-origin") == "*" + assert "GET" in r.headers.get("access-control-allow-methods", "") + + +@pytest.mark.unit +def test_cors_preflight_on_mcp_route(): + """Browser clients hitting /mcp should be allowed to send session headers.""" + 
from starlette.testclient import TestClient + + from application.asgi import asgi_app + + with TestClient(asgi_app) as client: + r = client.options( + "/mcp/", + headers={ + "Origin": "http://example.com", + "Access-Control-Request-Method": "POST", + "Access-Control-Request-Headers": ( + "Authorization, Content-Type, Mcp-Session-Id" + ), + }, + ) + assert r.status_code in (200, 204) + assert r.headers.get("access-control-allow-origin") == "*" + assert "Mcp-Session-Id" in r.headers.get("access-control-allow-headers", "")