mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-21 04:45:06 +00:00
609 lines
20 KiB
Python
609 lines
20 KiB
Python
from contextlib import contextmanager
|
|
from datetime import timedelta
|
|
from unittest.mock import ANY, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
|
|
@contextmanager
|
|
def _patch_decorator_db(conn):
|
|
"""Route the decorator's own ``db_session`` / ``db_readonly`` at ``conn``."""
|
|
|
|
@contextmanager
|
|
def _yield():
|
|
yield conn
|
|
|
|
with patch(
|
|
"application.api.user.idempotency.db_session", _yield
|
|
), patch(
|
|
"application.api.user.idempotency.db_readonly", _yield
|
|
):
|
|
yield
|
|
|
|
|
|
class TestIngestTask:
|
|
@pytest.mark.unit
|
|
@patch("application.api.user.tasks.ingest_worker")
|
|
def test_calls_ingest_worker(self, mock_worker):
|
|
from application.api.user.tasks import ingest
|
|
|
|
mock_worker.return_value = {"status": "ok"}
|
|
|
|
result = ingest("dir", ["pdf"], "job1", "user1", "/path", "file.pdf")
|
|
|
|
mock_worker.assert_called_once_with(
|
|
ANY, "dir", ["pdf"], "job1", "/path", "file.pdf", "user1",
|
|
file_name_map=None, idempotency_key=None, source_id=None,
|
|
)
|
|
assert result == {"status": "ok"}
|
|
|
|
@pytest.mark.unit
|
|
@patch("application.api.user.tasks.ingest_worker")
|
|
def test_passes_file_name_map(self, mock_worker):
|
|
from application.api.user.tasks import ingest
|
|
|
|
mock_worker.return_value = {"status": "ok"}
|
|
name_map = {"a.pdf": "b.pdf"}
|
|
|
|
ingest("dir", ["pdf"], "job1", "user1", "/path", "file.pdf",
|
|
file_name_map=name_map)
|
|
|
|
mock_worker.assert_called_once_with(
|
|
ANY, "dir", ["pdf"], "job1", "/path", "file.pdf", "user1",
|
|
file_name_map=name_map, idempotency_key=None, source_id=None,
|
|
)
|
|
|
|
|
|
class TestIngestRemoteTask:
|
|
@pytest.mark.unit
|
|
@patch("application.api.user.tasks.remote_worker")
|
|
def test_calls_remote_worker(self, mock_worker):
|
|
from application.api.user.tasks import ingest_remote
|
|
|
|
mock_worker.return_value = {"status": "ok"}
|
|
|
|
result = ingest_remote({"url": "http://x"}, "job1", "user1", "web")
|
|
|
|
mock_worker.assert_called_once_with(
|
|
ANY, {"url": "http://x"}, "job1", "user1", "web",
|
|
idempotency_key=None, source_id=None,
|
|
)
|
|
assert result == {"status": "ok"}
|
|
|
|
|
|
class TestReingestSourceTask:
|
|
@pytest.mark.unit
|
|
@patch("application.worker.reingest_source_worker")
|
|
def test_calls_reingest_worker(self, mock_worker):
|
|
from application.api.user.tasks import reingest_source_task
|
|
|
|
mock_worker.return_value = {"status": "ok"}
|
|
|
|
result = reingest_source_task("source123", "user1")
|
|
|
|
mock_worker.assert_called_once_with(ANY, "source123", "user1")
|
|
assert result == {"status": "ok"}
|
|
|
|
|
|
class TestScheduleSyncsTask:
|
|
@pytest.mark.unit
|
|
@patch("application.api.user.tasks.sync_worker")
|
|
def test_calls_sync_worker(self, mock_worker):
|
|
from application.api.user.tasks import schedule_syncs
|
|
|
|
mock_worker.return_value = {"status": "ok"}
|
|
|
|
result = schedule_syncs("daily")
|
|
|
|
mock_worker.assert_called_once_with(ANY, "daily")
|
|
assert result == {"status": "ok"}
|
|
|
|
|
|
class TestSyncSourceTask:
|
|
@pytest.mark.unit
|
|
@patch("application.api.user.tasks.sync")
|
|
def test_calls_sync(self, mock_sync):
|
|
from application.api.user.tasks import sync_source
|
|
|
|
mock_sync.return_value = {"status": "ok"}
|
|
|
|
result = sync_source(
|
|
{"data": 1}, "job1", "user1", "web", "daily", "classic", "doc1"
|
|
)
|
|
|
|
mock_sync.assert_called_once_with(
|
|
ANY, {"data": 1}, "job1", "user1", "web", "daily", "classic", "doc1"
|
|
)
|
|
assert result == {"status": "ok"}
|
|
|
|
|
|
class TestStoreAttachmentTask:
|
|
@pytest.mark.unit
|
|
@patch("application.api.user.tasks.attachment_worker")
|
|
def test_calls_attachment_worker(self, mock_worker):
|
|
from application.api.user.tasks import store_attachment
|
|
|
|
mock_worker.return_value = {"status": "ok"}
|
|
|
|
result = store_attachment({"file": "info"}, "user1")
|
|
|
|
mock_worker.assert_called_once_with(ANY, {"file": "info"}, "user1")
|
|
assert result == {"status": "ok"}
|
|
|
|
|
|
class TestProcessAgentWebhookTask:
|
|
@pytest.mark.unit
|
|
@patch("application.api.user.tasks.agent_webhook_worker")
|
|
def test_calls_agent_webhook_worker(self, mock_worker):
|
|
from application.api.user.tasks import process_agent_webhook
|
|
|
|
mock_worker.return_value = {"status": "ok"}
|
|
|
|
result = process_agent_webhook("agent123", {"event": "test"})
|
|
|
|
mock_worker.assert_called_once_with(ANY, "agent123", {"event": "test"})
|
|
assert result == {"status": "ok"}
|
|
|
|
|
|
class TestIngestConnectorTask:
|
|
@pytest.mark.unit
|
|
@patch("application.worker.ingest_connector")
|
|
def test_calls_ingest_connector_defaults(self, mock_worker):
|
|
from application.api.user.tasks import ingest_connector_task
|
|
|
|
mock_worker.return_value = {"status": "ok"}
|
|
|
|
result = ingest_connector_task("job1", "user1", "gdrive")
|
|
|
|
mock_worker.assert_called_once_with(
|
|
ANY,
|
|
"job1",
|
|
"user1",
|
|
"gdrive",
|
|
session_token=None,
|
|
file_ids=None,
|
|
folder_ids=None,
|
|
recursive=True,
|
|
retriever="classic",
|
|
operation_mode="upload",
|
|
doc_id=None,
|
|
sync_frequency="never",
|
|
idempotency_key=None,
|
|
source_id=None,
|
|
)
|
|
assert result == {"status": "ok"}
|
|
|
|
@pytest.mark.unit
|
|
@patch("application.worker.ingest_connector")
|
|
def test_calls_ingest_connector_custom(self, mock_worker):
|
|
from application.api.user.tasks import ingest_connector_task
|
|
|
|
mock_worker.return_value = {"status": "ok"}
|
|
|
|
result = ingest_connector_task(
|
|
"job1",
|
|
"user1",
|
|
"sharepoint",
|
|
session_token="tok",
|
|
file_ids=["f1"],
|
|
folder_ids=["d1"],
|
|
recursive=False,
|
|
retriever="duckdb",
|
|
operation_mode="sync",
|
|
doc_id="doc1",
|
|
sync_frequency="daily",
|
|
)
|
|
|
|
mock_worker.assert_called_once_with(
|
|
ANY,
|
|
"job1",
|
|
"user1",
|
|
"sharepoint",
|
|
session_token="tok",
|
|
file_ids=["f1"],
|
|
folder_ids=["d1"],
|
|
recursive=False,
|
|
retriever="duckdb",
|
|
operation_mode="sync",
|
|
doc_id="doc1",
|
|
sync_frequency="daily",
|
|
idempotency_key=None,
|
|
source_id=None,
|
|
)
|
|
assert result == {"status": "ok"}
|
|
|
|
|
|
class TestSetupPeriodicTasks:
|
|
@pytest.mark.unit
|
|
def test_registers_periodic_tasks(self):
|
|
from application.api.user.tasks import setup_periodic_tasks
|
|
|
|
sender = MagicMock()
|
|
|
|
setup_periodic_tasks(sender)
|
|
|
|
assert sender.add_periodic_task.call_count == 8
|
|
|
|
calls = sender.add_periodic_task.call_args_list
|
|
|
|
# daily
|
|
assert calls[0][0][0] == timedelta(days=1)
|
|
# weekly
|
|
assert calls[1][0][0] == timedelta(weeks=1)
|
|
# monthly
|
|
assert calls[2][0][0] == timedelta(days=30)
|
|
# pending_tool_state TTL cleanup (60s)
|
|
assert calls[3][0][0] == timedelta(seconds=60)
|
|
assert calls[3][1].get("name") == "cleanup-pending-tool-state"
|
|
# idempotency dedup TTL cleanup (1h)
|
|
assert calls[4][0][0] == timedelta(hours=1)
|
|
assert calls[4][1].get("name") == "cleanup-idempotency-dedup"
|
|
# reconciliation sweep (30s)
|
|
assert calls[5][0][0] == timedelta(seconds=30)
|
|
assert calls[5][1].get("name") == "reconciliation"
|
|
# version-check (every 7h)
|
|
assert calls[6][0][0] == timedelta(hours=7)
|
|
# message_events retention sweep (24h)
|
|
assert calls[7][0][0] == timedelta(hours=24)
|
|
assert calls[7][1].get("name") == "cleanup-message-events"
|
|
|
|
|
|
class TestMcpOauthTask:
|
|
@pytest.mark.unit
|
|
@patch("application.api.user.tasks.mcp_oauth")
|
|
def test_calls_mcp_oauth(self, mock_worker):
|
|
from application.api.user.tasks import mcp_oauth_task
|
|
|
|
mock_worker.return_value = {"url": "http://auth"}
|
|
|
|
result = mcp_oauth_task({"server": "mcp"}, "user1")
|
|
|
|
mock_worker.assert_called_once_with(ANY, {"server": "mcp"}, "user1")
|
|
assert result == {"url": "http://auth"}
|
|
|
|
|
|
class TestDurableTaskRetryPolicy:
|
|
"""The long-running tasks share a uniform retry policy."""
|
|
|
|
@pytest.mark.unit
|
|
@pytest.mark.parametrize(
|
|
"task_name",
|
|
[
|
|
"ingest",
|
|
"ingest_remote",
|
|
"reingest_source_task",
|
|
"store_attachment",
|
|
"process_agent_webhook",
|
|
"ingest_connector_task",
|
|
],
|
|
)
|
|
def test_task_has_retry_config(self, task_name):
|
|
import application.api.user.tasks as tasks_module
|
|
|
|
task = getattr(tasks_module, task_name)
|
|
assert task.acks_late is True
|
|
assert Exception in task.autoretry_for
|
|
assert task.retry_backoff is True
|
|
assert task.retry_kwargs == {"max_retries": 3, "countdown": 60}
|
|
|
|
@pytest.mark.unit
|
|
@pytest.mark.parametrize(
|
|
"task_name",
|
|
[
|
|
"schedule_syncs",
|
|
"sync_source",
|
|
"mcp_oauth_task",
|
|
"cleanup_pending_tool_state",
|
|
"reconciliation_task",
|
|
"version_check_task",
|
|
],
|
|
)
|
|
def test_short_periodic_tasks_have_no_retry_config(self, task_name):
|
|
import application.api.user.tasks as tasks_module
|
|
|
|
task = getattr(tasks_module, task_name)
|
|
assert not getattr(task, "autoretry_for", None)
|
|
|
|
|
|
class TestProcessAgentWebhookIdempotency:
|
|
"""Wrapper short-circuits a second call with the same key on the durable webhook task."""
|
|
|
|
@pytest.mark.unit
|
|
def test_repeat_with_same_key_short_circuits(self, pg_conn):
|
|
from application.api.user.tasks import process_agent_webhook
|
|
|
|
worker_calls = []
|
|
|
|
def _fake_worker(self, agent_id, payload):
|
|
worker_calls.append((agent_id, payload))
|
|
return {"status": "success", "result": {"answer": "ok"}}
|
|
|
|
with _patch_decorator_db(pg_conn), patch(
|
|
"application.api.user.tasks.agent_webhook_worker",
|
|
side_effect=_fake_worker,
|
|
):
|
|
first = process_agent_webhook(
|
|
"agent", {"event": "x"}, idempotency_key="dur-k1",
|
|
)
|
|
second = process_agent_webhook(
|
|
"agent", {"event": "x"}, idempotency_key="dur-k1",
|
|
)
|
|
|
|
assert first == {"status": "success", "result": {"answer": "ok"}}
|
|
assert second == first
|
|
assert len(worker_calls) == 1
|
|
|
|
|
|
class TestCleanupPendingToolState:
|
|
"""Janitor reverts stale 'resuming' rows and deletes TTL-expired rows."""
|
|
|
|
@pytest.mark.unit
|
|
def test_reverts_stale_and_deletes_expired(self, pg_conn):
|
|
from sqlalchemy import text as _text
|
|
|
|
from application.api.user.tasks import cleanup_pending_tool_state
|
|
from application.storage.db.repositories.conversations import (
|
|
ConversationsRepository,
|
|
)
|
|
from application.storage.db.repositories.pending_tool_state import (
|
|
PendingToolStateRepository,
|
|
)
|
|
|
|
repo = PendingToolStateRepository(pg_conn)
|
|
|
|
def _sample() -> dict:
|
|
return {
|
|
"messages": [],
|
|
"pending_tool_calls": [],
|
|
"tools_dict": {},
|
|
"tool_schemas": [],
|
|
"agent_config": {},
|
|
}
|
|
|
|
# Pending and fresh — should be left alone.
|
|
c1 = ConversationsRepository(pg_conn).create("u", "fresh-pending")
|
|
repo.save_state(c1["id"], "u", **_sample())
|
|
|
|
# Pending but already expired — should be deleted.
|
|
c2 = ConversationsRepository(pg_conn).create("u", "expired-pending")
|
|
repo.save_state(c2["id"], "u", **_sample(), ttl_seconds=0)
|
|
|
|
# Resuming within grace — should stay 'resuming'.
|
|
c3 = ConversationsRepository(pg_conn).create("u", "fresh-resuming")
|
|
repo.save_state(c3["id"], "u", **_sample())
|
|
repo.mark_resuming(c3["id"], "u")
|
|
|
|
# Resuming past grace — should revert to 'pending'.
|
|
c4 = ConversationsRepository(pg_conn).create("u", "stale-resuming")
|
|
repo.save_state(c4["id"], "u", **_sample())
|
|
repo.mark_resuming(c4["id"], "u")
|
|
pg_conn.execute(
|
|
_text(
|
|
"UPDATE pending_tool_state "
|
|
"SET resumed_at = clock_timestamp() "
|
|
" - make_interval(secs => 660) "
|
|
"WHERE conversation_id = CAST(:conv_id AS uuid)"
|
|
),
|
|
{"conv_id": c4["id"]},
|
|
)
|
|
|
|
from contextlib import contextmanager
|
|
|
|
@contextmanager
|
|
def _fake_begin():
|
|
yield pg_conn
|
|
|
|
fake_engine = MagicMock()
|
|
fake_engine.begin = _fake_begin
|
|
|
|
with patch(
|
|
"application.storage.db.engine.get_engine",
|
|
return_value=fake_engine,
|
|
):
|
|
result = cleanup_pending_tool_state.run()
|
|
|
|
assert result["reverted"] == 1
|
|
assert result["deleted"] == 1
|
|
|
|
# Final state assertions.
|
|
assert repo.load_state(c1["id"], "u")["status"] == "pending"
|
|
assert repo.load_state(c2["id"], "u") is None
|
|
assert repo.load_state(c3["id"], "u")["status"] == "resuming"
|
|
c4_row = repo.load_state(c4["id"], "u")
|
|
assert c4_row["status"] == "pending"
|
|
assert c4_row["resumed_at"] is None
|
|
|
|
@pytest.mark.unit
|
|
def test_skips_when_postgres_uri_missing(self, monkeypatch):
|
|
from application.api.user.tasks import cleanup_pending_tool_state
|
|
from application.core.settings import settings
|
|
|
|
monkeypatch.setattr(settings, "POSTGRES_URI", None, raising=False)
|
|
|
|
result = cleanup_pending_tool_state.run()
|
|
assert result == {
|
|
"deleted": 0,
|
|
"reverted": 0,
|
|
"skipped": "POSTGRES_URI not set",
|
|
}
|
|
|
|
|
|
class TestCleanupMessageEventsTask:
|
|
"""Retention janitor delegates to MessageEventsRepository.cleanup_older_than."""
|
|
|
|
@pytest.mark.unit
|
|
def test_skips_when_postgres_uri_missing(self, monkeypatch):
|
|
from application.api.user.tasks import cleanup_message_events
|
|
from application.core.settings import settings
|
|
|
|
monkeypatch.setattr(settings, "POSTGRES_URI", None, raising=False)
|
|
|
|
result = cleanup_message_events.run()
|
|
assert result == {"deleted": 0, "skipped": "POSTGRES_URI not set"}
|
|
|
|
@pytest.mark.unit
|
|
def test_deletes_rows_past_retention_window(self, pg_conn, monkeypatch):
|
|
import uuid
|
|
|
|
from sqlalchemy import text as _text
|
|
|
|
from application.api.user.tasks import cleanup_message_events
|
|
from application.core.settings import settings
|
|
from application.storage.db.repositories.message_events import (
|
|
MessageEventsRepository,
|
|
)
|
|
|
|
# Seed parent rows so the FK on message_events holds.
|
|
user_id = f"user-{uuid.uuid4().hex[:8]}"
|
|
conv_id = uuid.uuid4()
|
|
msg_id = uuid.uuid4()
|
|
pg_conn.execute(
|
|
_text("INSERT INTO users (user_id) VALUES (:u)"),
|
|
{"u": user_id},
|
|
)
|
|
pg_conn.execute(
|
|
_text(
|
|
"INSERT INTO conversations (id, user_id, name) "
|
|
"VALUES (:id, :u, 'test')"
|
|
),
|
|
{"id": conv_id, "u": user_id},
|
|
)
|
|
pg_conn.execute(
|
|
_text(
|
|
"INSERT INTO conversation_messages (id, conversation_id, "
|
|
"user_id, position) VALUES (:id, :c, :u, 0)"
|
|
),
|
|
{"id": msg_id, "c": conv_id, "u": user_id},
|
|
)
|
|
|
|
repo = MessageEventsRepository(pg_conn)
|
|
repo.record(str(msg_id), 0, "answer", {"chunk": "stale"})
|
|
repo.record(str(msg_id), 1, "answer", {"chunk": "fresh"})
|
|
# Backdate seq=0 past the default 14-day retention so the
|
|
# janitor catches it; seq=1 stays at "now" and must survive.
|
|
pg_conn.execute(
|
|
_text(
|
|
"UPDATE message_events SET created_at = now() - interval '20 days' "
|
|
"WHERE message_id = CAST(:id AS uuid) AND sequence_no = 0"
|
|
),
|
|
{"id": str(msg_id)},
|
|
)
|
|
|
|
monkeypatch.setattr(
|
|
settings, "POSTGRES_URI", "postgresql://stub", raising=False
|
|
)
|
|
|
|
@contextmanager
|
|
def _fake_begin():
|
|
yield pg_conn
|
|
|
|
fake_engine = MagicMock()
|
|
fake_engine.begin = _fake_begin
|
|
|
|
with patch(
|
|
"application.storage.db.engine.get_engine",
|
|
return_value=fake_engine,
|
|
):
|
|
result = cleanup_message_events.run()
|
|
|
|
assert result == {
|
|
"deleted": 1,
|
|
"ttl_days": settings.MESSAGE_EVENTS_RETENTION_DAYS,
|
|
}
|
|
# Only the fresh row survives.
|
|
rows = repo.read_after(str(msg_id))
|
|
assert [r["sequence_no"] for r in rows] == [1]
|
|
|
|
|
|
class TestIngestIdempotency:
|
|
"""Same short-circuit applies to the ingest task path."""
|
|
|
|
@pytest.mark.unit
|
|
def test_repeat_with_same_key_short_circuits(self, pg_conn):
|
|
from application.api.user.tasks import ingest
|
|
|
|
worker_calls = []
|
|
|
|
def _fake_worker(self, directory, formats, job_name, file_path,
|
|
filename, user, file_name_map=None,
|
|
idempotency_key=None, source_id=None):
|
|
worker_calls.append(filename)
|
|
return {"status": "ok", "directory": directory}
|
|
|
|
with _patch_decorator_db(pg_conn), patch(
|
|
"application.api.user.tasks.ingest_worker",
|
|
side_effect=_fake_worker,
|
|
):
|
|
first = ingest(
|
|
"dir", ["pdf"], "job1", "user1", "/path", "file.pdf",
|
|
idempotency_key="dur-ing-1",
|
|
)
|
|
second = ingest(
|
|
"dir", ["pdf"], "job1", "user1", "/path", "file.pdf",
|
|
idempotency_key="dur-ing-1",
|
|
)
|
|
|
|
assert first == second
|
|
assert first == {"status": "ok", "directory": "dir"}
|
|
assert len(worker_calls) == 1
|
|
|
|
|
|
class TestIngestPoisonEvent:
|
|
"""The poison hook publishes a terminal source.ingest.failed so the
|
|
upload toast resolves instead of hanging on "training".
|
|
"""
|
|
|
|
@pytest.mark.unit
|
|
def test_publishes_failed_event(self):
|
|
from application.api.user.tasks import _emit_ingest_poison_event
|
|
|
|
published = []
|
|
|
|
def _fake_publish(user, event_type, payload, *, scope=None):
|
|
published.append((user, event_type, payload, scope))
|
|
|
|
with patch(
|
|
"application.events.publisher.publish_user_event",
|
|
side_effect=_fake_publish,
|
|
):
|
|
_emit_ingest_poison_event(
|
|
"ingest",
|
|
{"user": "u1", "source_id": "src-9", "filename": "doc.pdf"},
|
|
)
|
|
|
|
assert len(published) == 1
|
|
user, event_type, payload, scope = published[0]
|
|
assert user == "u1"
|
|
assert event_type == "source.ingest.failed"
|
|
assert payload["source_id"] == "src-9"
|
|
assert payload["filename"] == "doc.pdf"
|
|
assert payload["operation"] == "upload"
|
|
assert scope == {"kind": "source", "id": "src-9"}
|
|
|
|
@pytest.mark.unit
|
|
def test_skips_when_source_id_missing(self):
|
|
from application.api.user.tasks import _emit_ingest_poison_event
|
|
|
|
with patch(
|
|
"application.events.publisher.publish_user_event",
|
|
) as mock_publish:
|
|
_emit_ingest_poison_event("ingest", {"user": "u1"})
|
|
|
|
mock_publish.assert_not_called()
|
|
|
|
@pytest.mark.unit
|
|
def test_reingest_uses_reingest_operation(self):
|
|
from application.api.user.tasks import _emit_ingest_poison_event
|
|
|
|
published = []
|
|
with patch(
|
|
"application.events.publisher.publish_user_event",
|
|
side_effect=lambda *a, **k: published.append((a, k)),
|
|
):
|
|
_emit_ingest_poison_event(
|
|
"reingest_source_task",
|
|
{"user": "u1", "source_id": "src-r"},
|
|
)
|
|
|
|
assert published[0][0][2]["operation"] == "reingest"
|