mirror of
https://github.com/pocketpaw/pocketpaw.git
synced 2026-05-21 01:04:57 +00:00
- Mock GoalParser in deep_work_session tests (was hitting real LLM, 4m10s → 1s) - Add embedding_provider="hash" to file_memory_fixes tests (was timing out on ollama) - Fix mock_config fixture in test_remote_access to use tmp_path (race condition under -n auto) - Update assertions for WEBSOCKET channel now having format hints - Accept 401 for /api/v1/sessions when ee.cloud router overrides core router - Exclude tests/cloud and tests/e2e from default pytest run (pymongo hangs) - Reorganize ee/cloud tests into tests/cloud/, v1 tests into tests/v1/ Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1233 lines
43 KiB
Python
1233 lines
43 KiB
Python
# Tests for file memory fixes: UUID collision, fuzzy search, daily loading,
|
|
# dedup, persistent delete, ForgetTool, auto-learn, context limits.
|
|
# Created: 2026-02-09
|
|
|
|
import sqlite3
|
|
from datetime import UTC, date, datetime, timedelta
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from pocketpaw.memory.file_store import FileMemoryStore, _make_deterministic_id, _tokenize
|
|
from pocketpaw.memory.manager import MemoryManager
|
|
from pocketpaw.memory.protocol import MemoryEntry, MemoryType
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture
def tmp_store(tmp_path):
    """Provide a FileMemoryStore rooted at pytest's per-test temporary directory."""
    store = FileMemoryStore(base_path=tmp_path)
    return store
|
|
|
|
|
|
@pytest.fixture
def tmp_manager(tmp_store):
    """Provide a MemoryManager backed by the ``tmp_store`` fixture."""
    manager = MemoryManager(store=tmp_store)
    return manager
|
|
|
|
|
|
# ===========================================================================
|
|
# TestUUIDCollision
|
|
# ===========================================================================
|
|
|
|
|
|
class TestUUIDCollision:
    """Step 1: entries that share a header but differ in body receive distinct IDs."""

    @staticmethod
    def _long_term(content, header="Memory"):
        """Build an unsaved LONG_TERM entry with the given body and header."""
        return MemoryEntry(
            id="",
            type=MemoryType.LONG_TERM,
            content=content,
            metadata={"header": header},
        )

    async def test_multiple_memories_survive(self, tmp_store):
        """Two 'Memory'-headed entries with different content get different IDs."""
        name_entry = self._long_term("User's name is Rohit")
        theme_entry = self._long_term("User prefers dark mode")

        name_id = await tmp_store.save(name_entry)
        theme_id = await tmp_store.save(theme_entry)

        assert name_id != theme_id
        assert len(tmp_store._index) >= 2

        # Each saved entry must be individually retrievable.
        for saved_id in (name_id, theme_id):
            assert (await tmp_store.get(saved_id)) is not None

    async def test_survive_restart(self, tmp_path):
        """Entries written by one store are visible to a fresh store on the same path."""
        facts = ("Fact A", "Fact B", "Fact C")

        writer = FileMemoryStore(base_path=tmp_path)
        for fact in facts:
            await writer.save(self._long_term(fact))

        # A second store over the same directory stands in for a process restart.
        reader = FileMemoryStore(base_path=tmp_path)
        long_term = await reader.get_by_type(MemoryType.LONG_TERM)

        assert len(long_term) == 3
        assert {entry.content for entry in long_term} == set(facts)

    async def test_custom_headers_work(self, tmp_store):
        """Identical bodies filed under different headers still get unique IDs."""
        id_a = await tmp_store.save(self._long_term("Same body", header="Header A"))
        id_b = await tmp_store.save(self._long_term("Same body", header="Header B"))
        assert id_a != id_b

    async def test_deterministic_id_includes_body(self, tmp_path):
        """_make_deterministic_id yields different IDs when only the body differs."""
        md_path = tmp_path / "test.md"
        ids = {
            _make_deterministic_id(md_path, "Memory", "Fact one"),
            _make_deterministic_id(md_path, "Memory", "Fact two"),
        }
        assert len(ids) == 2

    async def test_deterministic_id_same_content(self, tmp_path):
        """Identical path/header/body inputs always map to the same ID."""
        md_path = tmp_path / "test.md"
        first = _make_deterministic_id(md_path, "Memory", "Fact one")
        second = _make_deterministic_id(md_path, "Memory", "Fact one")
        assert first == second
|
|
|
|
|
|
class TestVectorSchemaMigrations:
    """Schema version checks for the sqlite vector index."""

    async def test_vector_schema_migrates_user_version_on_init(self, tmp_path):
        """A vector-enabled store is expected to bump a version-0 database to version 1.

        The database file is pre-created with ``user_version = 0``; after the
        store is constructed the test checks the stored version and that the
        ``memory_vectors`` table exposes the ``doc_id`` and ``embedding_json``
        columns.
        """
        # ``with sqlite3.connect(...)`` only scopes a transaction and leaves the
        # connection open; closing() is what actually releases the file handle.
        from contextlib import closing

        db_path = tmp_path / "vector_index.sqlite3"
        with closing(sqlite3.connect(db_path)) as conn:
            conn.execute("PRAGMA user_version = 0")
            conn.commit()

        FileMemoryStore(base_path=tmp_path, vector_enabled=True, embedding_provider="hash")

        with closing(sqlite3.connect(db_path)) as conn:
            version_row = conn.execute("PRAGMA user_version").fetchone()
            columns = conn.execute("PRAGMA table_info(memory_vectors)").fetchall()

        assert version_row is not None
        assert int(version_row[0]) == 1

        # table_info rows carry the column name in position 1.
        column_names = {str(row[1]) for row in columns}
        assert "doc_id" in column_names
        assert "embedding_json" in column_names
|
|
|
|
|
|
# ===========================================================================
|
|
# TestFuzzySearch
|
|
# ===========================================================================
|
|
|
|
|
|
class TestFuzzySearch:
    """Step 3: word-overlap search replaces the broken substring match."""

    @staticmethod
    async def _store_fact(store, content, header="Memory"):
        """Save one LONG_TERM entry into ``store`` and return its ID."""
        return await store.save(
            MemoryEntry(
                id="",
                type=MemoryType.LONG_TERM,
                content=content,
                metadata={"header": header},
            )
        )

    async def test_word_overlap_matches(self, tmp_store):
        """Query 'name' finds the entry "User's name is Rohit"."""
        await self._store_fact(tmp_store, "User's name is Rohit")

        hits = await tmp_store.search("name")

        assert len(hits) == 1
        assert "Rohit" in hits[0].content

    async def test_multi_word_query(self, tmp_store):
        """A multi-word query ranks entries by how many query words they share."""
        await self._store_fact(tmp_store, "User's name is Rohit")
        await self._store_fact(tmp_store, "Rohit prefers dark mode")

        # Two query words: the first entry contains both, the second only one.
        hits = await tmp_store.search("Rohit name")

        assert len(hits) == 2
        assert "name is Rohit" in hits[0].content  # best overlap (2/2) comes first

    async def test_no_false_matches(self, tmp_store):
        """A query sharing no words with any entry yields an empty result."""
        await self._store_fact(tmp_store, "User prefers dark mode")

        hits = await tmp_store.search("banana")

        assert len(hits) == 0

    async def test_stop_words_excluded(self):
        """The tokenizer drops stop words and keeps content words."""
        tokens = _tokenize("the user is a developer")

        for stop_word in ("the", "is", "a"):
            assert stop_word not in tokens
        for kept_word in ("user", "developer"):
            assert kept_word in tokens

    async def test_search_includes_header(self, tmp_store):
        """Header text participates in matching, not just the body."""
        await self._store_fact(tmp_store, "Likes coffee", header="Preferences")

        hits = await tmp_store.search("preferences")

        assert len(hits) == 1

    async def test_ranking_order(self, tmp_store):
        """Hits come back ordered by descending overlap score."""
        await self._store_fact(tmp_store, "Python developer")
        await self._store_fact(tmp_store, "Python backend developer at Google")

        hits = await tmp_store.search("Python backend developer Google")

        assert len(hits) == 2
        # The longer entry shares more query words, so it must lead.
        assert "Google" in hits[0].content
|
|
|
|
|
|
# ===========================================================================
|
|
# TestDailyFileIndexing
|
|
# ===========================================================================
|
|
|
|
|
|
class TestDailyFileIndexing:
    """Step 2: daily files from earlier dates are indexed, not only today's."""

    async def test_past_daily_files_loaded(self, tmp_path):
        """An entry in yesterday's daily file is indexed when a store is created."""
        yesterday = date.today() - timedelta(days=1)
        (tmp_path / f"{yesterday.isoformat()}.md").write_text(
            "## 10:30\n\nHad meeting with Alice\n"
        )

        store = FileMemoryStore(base_path=tmp_path)
        indexed = await store.get_by_type(MemoryType.DAILY)

        assert len(indexed) == 1
        assert "Alice" in indexed[0].content

    async def test_multiple_daily_files(self, tmp_path):
        """Every daily file on disk contributes its entry to the index."""
        # One file each for today, yesterday and the day before.
        for i in range(3):
            d = date.today() - timedelta(days=i)
            daily_path = tmp_path / f"{d.isoformat()}.md"
            daily_path.write_text(f"## Note\n\nDay {i} note\n")

        # Create sessions dir (needed by constructor)
        sessions_dir = tmp_path / "sessions"
        sessions_dir.mkdir(exist_ok=True)

        store = FileMemoryStore(base_path=tmp_path)
        indexed = await store.get_by_type(MemoryType.DAILY)

        assert len(indexed) == 3
|
|
|
|
|
|
# ===========================================================================
|
|
# TestDeduplication
|
|
# ===========================================================================
|
|
|
|
|
|
class TestDeduplication:
    """Step 1 dedup: saving the same fact twice must not create a second entry."""

    @staticmethod
    def _fact(content):
        """Build an unsaved LONG_TERM entry under the 'Memory' header."""
        return MemoryEntry(
            id="",
            type=MemoryType.LONG_TERM,
            content=content,
            metadata={"header": "Memory"},
        )

    async def test_same_fact_not_duplicated(self, tmp_store):
        """A repeated save returns the original ID and leaves the index size unchanged."""
        first_id = await tmp_store.save(self._fact("User's name is Rohit"))
        size_after_first = len(tmp_store._index)

        second_id = await tmp_store.save(self._fact("User's name is Rohit"))

        assert first_id == second_id
        assert len(tmp_store._index) == size_after_first

    async def test_dedup_across_restart(self, tmp_path):
        """Dedup still applies when the duplicate arrives after a reload from disk."""
        writer = FileMemoryStore(base_path=tmp_path)
        await writer.save(self._fact("Fact X"))

        # A fresh store over the same directory simulates a restart.
        reloaded = FileMemoryStore(base_path=tmp_path)
        size_before = len(reloaded._index)

        await reloaded.save(self._fact("Fact X"))

        assert len(reloaded._index) == size_before
|
|
|
|
|
|
# ===========================================================================
|
|
# TestPersistentDelete
|
|
# ===========================================================================
|
|
|
|
|
|
class TestPersistentDelete:
    """Step 4: deletions persist across restarts."""

    @staticmethod
    def _fact(content):
        """Build an unsaved LONG_TERM entry under the 'Memory' header."""
        return MemoryEntry(
            id="",
            type=MemoryType.LONG_TERM,
            content=content,
            metadata={"header": "Memory"},
        )

    async def test_delete_persists_across_restart(self, tmp_path):
        """An entry deleted from one store stays gone in a store reloaded from disk."""
        writer = FileMemoryStore(base_path=tmp_path)
        kept_id = await writer.save(self._fact("Keep this"))
        doomed_id = await writer.save(self._fact("Delete this"))

        was_deleted = await writer.delete(doomed_id)
        assert was_deleted is True

        # Restart
        reloaded = FileMemoryStore(base_path=tmp_path)
        assert await reloaded.get(kept_id) is not None
        assert await reloaded.get(doomed_id) is None

        survivors = await reloaded.get_by_type(MemoryType.LONG_TERM)
        assert len(survivors) == 1
        assert survivors[0].content == "Keep this"

    async def test_delete_only_removes_target(self, tmp_path):
        """Deleting one entry leaves its siblings in the same file untouched."""
        store = FileMemoryStore(base_path=tmp_path)
        ids = [await store.save(self._fact(f"Fact {i}")) for i in range(5)]

        await store.delete(ids[2])

        remaining = await store.get_by_type(MemoryType.LONG_TERM)
        assert len(remaining) == 4

        contents = {entry.content for entry in remaining}
        assert "Fact 2" not in contents
        expected = {f"Fact {i}" for i in (0, 1, 3, 4)}
        assert expected <= contents

    async def test_delete_last_entry_removes_file(self, tmp_path):
        """Deleting the sole entry also removes the long-term file from disk."""
        store = FileMemoryStore(base_path=tmp_path)
        only_id = await store.save(self._fact("Only fact"))

        assert store.long_term_file.exists()

        await store.delete(only_id)

        assert not store.long_term_file.exists()
|
|
|
|
|
|
# ===========================================================================
|
|
# TestForgetTool
|
|
# ===========================================================================
|
|
|
|
|
|
class TestForgetTool:
    """Step 5: ForgetTool looks memories up by query and deletes them."""

    def test_forget_tool_definition(self):
        """ForgetTool is named 'forget' and declares 'query' as a required parameter."""
        from pocketpaw.tools.builtin.memory import ForgetTool

        forget = ForgetTool()
        assert forget.name == "forget"

        schema = forget.parameters
        assert "query" in schema["properties"]
        assert "query" in schema["required"]

    async def test_forget_removes_matching_memory(self, tmp_path):
        """Executing ForgetTool removes the memory that matches the query."""
        from pocketpaw.tools.builtin.memory import ForgetTool

        store = FileMemoryStore(base_path=tmp_path)
        manager = MemoryManager(store=store)

        for fact in ("User's name is Rohit", "User prefers dark mode"):
            await manager.remember(fact)

        # Sanity check: both facts are stored before anything is forgotten.
        stored = await store.get_by_type(MemoryType.LONG_TERM)
        assert len(stored) == 2

        forget = ForgetTool()
        # Route the tool's manager lookup to the temp-dir manager built above.
        with patch("pocketpaw.tools.builtin.memory.get_memory_manager", return_value=manager):
            outcome = await forget.execute(query="name Rohit")

        assert "Forgot" in outcome
        assert "1" in outcome

        survivors = await store.get_by_type(MemoryType.LONG_TERM)
        assert len(survivors) == 1
        assert "dark mode" in survivors[0].content

    def test_forget_in_policy_group(self):
        """The 'forget' tool is listed in the group:memory policy group."""
        from pocketpaw.tools.policy import TOOL_GROUPS

        memory_group = TOOL_GROUPS["group:memory"]
        assert "forget" in memory_group
|
|
|
|
|
|
# ===========================================================================
|
|
# TestFileAutoLearn
|
|
# ===========================================================================
|
|
|
|
|
|
class TestFileAutoLearn:
    """Step 7: LLM-based automatic fact extraction for the file backend."""

    @staticmethod
    def _client_replying(text):
        """Return a mocked Anthropic client whose ``messages.create`` replies with ``text``."""
        reply = MagicMock()
        reply.content = [MagicMock(text=text)]

        client = AsyncMock()
        client.messages.create = AsyncMock(return_value=reply)
        return client

    async def test_extracts_facts(self, tmp_path):
        """_file_auto_learn stores each fact returned by the (mocked) model call."""
        store = FileMemoryStore(base_path=tmp_path)
        manager = MemoryManager(store=store)
        client = self._client_replying('["User name is Rohit", "Likes Python"]')

        conversation = [
            {"role": "user", "content": "My name is Rohit and I like Python"},
            {"role": "assistant", "content": "Nice to meet you, Rohit!"},
        ]
        with patch("anthropic.AsyncAnthropic", return_value=client):
            outcome = await manager._file_auto_learn(conversation)

        assert len(outcome.get("results", [])) == 2

        long_term = await store.get_by_type(MemoryType.LONG_TERM)
        assert len(long_term) == 2

        saved = {entry.content for entry in long_term}
        assert "User name is Rohit" in saved
        assert "Likes Python" in saved

    async def test_graceful_without_api_key(self, tmp_path):
        """_file_auto_learn returns an empty dict when the client cannot be created."""
        store = FileMemoryStore(base_path=tmp_path)
        manager = MemoryManager(store=store)

        # Constructing the client raises, as it would with no API key configured.
        failing_client = patch(
            "anthropic.AsyncAnthropic",
            side_effect=Exception("No API key"),
        )
        with failing_client:
            outcome = await manager._file_auto_learn([{"role": "user", "content": "Hello"}])

        assert outcome == {}

    async def test_deduplicates_auto_learned_facts(self, tmp_path):
        """Learning the same fact twice leaves a single long-term entry."""
        store = FileMemoryStore(base_path=tmp_path)
        manager = MemoryManager(store=store)
        client = self._client_replying('["User name is Rohit"]')

        with patch("anthropic.AsyncAnthropic", return_value=client):
            # Same conversation processed twice in a row.
            for _ in range(2):
                await manager._file_auto_learn(
                    [
                        {"role": "user", "content": "My name is Rohit"},
                        {"role": "assistant", "content": "Hi Rohit!"},
                    ]
                )

        long_term = await store.get_by_type(MemoryType.LONG_TERM)
        assert len(long_term) == 1

    async def test_auto_learn_passes_flag(self, tmp_path):
        """auto_learn() delegates to _file_auto_learn when file_auto_learn=True."""
        store = FileMemoryStore(base_path=tmp_path)
        manager = MemoryManager(store=store)
        # Replace the real extractor so only the dispatch is observed.
        manager._file_auto_learn = AsyncMock(return_value={"results": []})

        messages = [{"role": "user", "content": "hello"}]
        await manager.auto_learn(messages, file_auto_learn=True)

        manager._file_auto_learn.assert_called_once()
|
|
|
|
|
|
# ===========================================================================
|
|
# TestContextLimits
|
|
# ===========================================================================
|
|
|
|
|
|
class TestContextLimits:
    """Step 6: raised limits for memory injected into the agent context."""

    @staticmethod
    def _manager_at(base_path):
        """Return a MemoryManager over a FileMemoryStore rooted at ``base_path``."""
        return MemoryManager(store=FileMemoryStore(base_path=base_path))

    async def test_more_than_10_memories_in_context(self, tmp_path):
        """With default limits, all 25 stored facts show up in the context."""
        manager = self._manager_at(tmp_path)
        for i in range(25):
            await manager.remember(f"Fact number {i}")

        context = await manager.get_context_for_agent()

        # Each stored fact contributes exactly one "Fact number" occurrence.
        assert context.count("Fact number") == 25

    async def test_entry_truncation_at_500(self, tmp_path):
        """A 600-char entry is cut to 500 chars in the context (not to 200)."""
        manager = self._manager_at(tmp_path)
        await manager.remember("x" * 600)

        context = await manager.get_context_for_agent()

        # 500 consecutive x's must be present, but never the full 600.
        assert "x" * 500 in context
        assert "x" * 600 not in context

    async def test_custom_limits(self, tmp_path):
        """An explicit long_term_limit caps how many facts are injected."""
        manager = self._manager_at(tmp_path)
        for i in range(10):
            await manager.remember(f"Fact {i}")

        context = await manager.get_context_for_agent(long_term_limit=3)

        assert context.count("Fact") == 3

    async def test_max_chars_truncation(self, tmp_path):
        """Context longer than max_chars is cut and flagged as truncated."""
        manager = self._manager_at(tmp_path)
        padding = "a" * 100
        for i in range(100):
            await manager.remember(f"Fact {i}: " + padding)

        context = await manager.get_context_for_agent(max_chars=500)

        assert len(context) <= 520  # 500 + "...(truncated)" suffix
        assert "(truncated)" in context
|
|
|
|
|
|
# ===========================================================================
|
|
# TestFileVectorSemanticSearch
|
|
# ===========================================================================
|
|
|
|
|
|
class TestFileVectorSemanticSearch:
    """Phase 1: the file backend combines markdown storage with vector retrieval."""

    @staticmethod
    def _vector_store(base_path):
        """Create a vector-enabled store that uses the 'hash' embedding provider."""
        return FileMemoryStore(
            base_path=base_path,
            vector_enabled=True,
            embedding_provider="hash",
        )

    @staticmethod
    def _long_term(content, header):
        """Build an unsaved LONG_TERM entry with the given body and header."""
        return MemoryEntry(
            id="",
            type=MemoryType.LONG_TERM,
            content=content,
            metadata={"header": header},
        )

    async def test_semantic_search_returns_relevant_memory(self, tmp_path):
        """The entry related to the query is ranked first by semantic search."""
        store = self._vector_store(tmp_path)
        await store.save(
            self._long_term("API refactor should preserve backward compatibility", "Architecture")
        )
        await store.save(self._long_term("Buy groceries tomorrow morning", "Personal"))

        hits = await store.semantic_search("api backward compatibility", user_id="default")

        assert len(hits) >= 1
        assert "API refactor" in hits[0]["memory"]

    async def test_delete_removes_vector_record(self, tmp_path):
        """A deleted entry no longer appears in semantic search results."""
        store = self._vector_store(tmp_path)
        entry_id = await store.save(self._long_term("Legacy deployment issue in staging", "Deploy"))

        hits_before = await store.semantic_search("deployment issue", user_id="default")
        assert any(hit["id"] == entry_id for hit in hits_before)

        assert await store.delete(entry_id) is True

        hits_after = await store.semantic_search("deployment issue", user_id="default")
        assert all(hit["id"] != entry_id for hit in hits_after)

    async def test_manager_uses_file_semantic_context(self, tmp_path):
        """MemoryManager.get_semantic_context surfaces a remembered fact."""
        manager = MemoryManager(store=self._vector_store(tmp_path))
        await manager.remember(
            "Project Phoenix uses React and TypeScript",
            header="Project",
        )

        context = await manager.get_semantic_context("what stack does project phoenix use")

        assert "Relevant Memories" in context
        assert "React" in context
|
|
|
|
|
|
class TestFileGraphAndManagement:
    """Phase 2/3: graph indexing, memory edits, and pruning."""

    @staticmethod
    def _open_db(db_path):
        """Open ``db_path`` as a sqlite connection that is closed when the block exits.

        ``with sqlite3.connect(...)`` on its own only scopes a transaction and
        leaves the connection open, so tests wrap the connection in
        ``contextlib.closing`` and commit explicitly where they write.
        """
        from contextlib import closing

        return closing(sqlite3.connect(db_path))

    async def test_graph_db_not_created_when_graph_feature_disabled(self, tmp_path):
        """With vectors disabled the graph feature is off and no graph DB file is created."""
        store = FileMemoryStore(base_path=tmp_path, vector_enabled=False)

        assert store._graph_enabled is False
        assert not store._graph_db_path.exists()

    async def test_graph_snapshot_disabled_returns_empty_without_db_creation(self, tmp_path):
        """A snapshot request on a graph-disabled store is empty and creates no DB file."""
        store = FileMemoryStore(base_path=tmp_path, vector_enabled=False)

        snapshot = await store.get_graph_snapshot(user_id="default", limit=500)

        assert snapshot == {"nodes": [], "edges": []}
        assert not store._graph_db_path.exists()

    async def test_graph_snapshot_limit_500_avoids_sql_variable_limit(self, tmp_path):
        """Graph snapshot with limit=500 should not exceed SQLite bind variable cap."""
        store = FileMemoryStore(base_path=tmp_path, vector_enabled=True, embedding_provider="hash")

        now = datetime.now(tz=UTC).isoformat()
        entities = [
            (f"entity_{i}", "default", f"entity-{i}", f"Entity {i}", 1, now, now)
            for i in range(500)
        ]
        # Chain the entities: entity_0 -> entity_1 -> ... -> entity_499.
        relationships = [
            (
                f"rel_{i}",
                "default",
                f"entity_{i}",
                f"entity_{i + 1}",
                "related_to",
                1,
                now,
                now,
            )
            for i in range(499)
        ]

        with self._open_db(store._graph_db_path) as conn:
            conn.executemany(
                """INSERT INTO entities
                   (entity_id, user_scope, entity_key, display_name,
                    mention_count, first_seen, last_seen)
                   VALUES (?, ?, ?, ?, ?, ?, ?)""",
                entities,
            )
            conn.executemany(
                """INSERT INTO relationships
                   (relationship_id, user_scope, source_entity_id, target_entity_id,
                    relation_type, weight, first_seen, last_seen)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
                relationships,
            )
            conn.commit()

        snapshot = await store.get_graph_snapshot(user_id="default", limit=500)

        assert len(snapshot["nodes"]) == 500
        assert len(snapshot["edges"]) <= 500

    async def test_graph_snapshot_contains_entities_and_edges(self, tmp_path):
        """Saving a 'uses' sentence yields at least two nodes and a 'uses' edge."""
        store = FileMemoryStore(base_path=tmp_path, vector_enabled=True, embedding_provider="hash")

        await store.save(
            MemoryEntry(
                id="",
                type=MemoryType.LONG_TERM,
                content="Project Phoenix uses React",
                metadata={"header": "Project"},
            )
        )

        graph = await store.get_graph_snapshot(user_id="default")
        assert len(graph["nodes"]) >= 2
        assert any(edge["relation"] == "uses" for edge in graph["edges"])

    async def test_update_entry_reindexes_content(self, tmp_path):
        """update_entry reports success and the new content is returned by get()."""
        store = FileMemoryStore(base_path=tmp_path, vector_enabled=True, embedding_provider="hash")

        entry_id = await store.save(
            MemoryEntry(
                id="",
                type=MemoryType.LONG_TERM,
                content="Service runs on Flask",
                metadata={"header": "Tech"},
            )
        )

        updated = await store.update_entry(entry_id, content="Service runs on FastAPI")
        assert updated is True

        entry = await store.get(entry_id)
        assert entry is not None
        assert "FastAPI" in entry.content

    async def test_prune_memories_deletes_old_daily_entries(self, tmp_path):
        """A 45-day-old DAILY entry is removed by prune_memories(older_than_days=30)."""
        store = FileMemoryStore(base_path=tmp_path, vector_enabled=True, embedding_provider="hash")

        old_entry = MemoryEntry(
            id="",
            type=MemoryType.DAILY,
            content="Old daily memory",
            created_at=datetime.now(tz=UTC) - timedelta(days=45),
            metadata={"header": "Daily"},
        )
        old_id = await store.save(old_entry)

        result = await store.prune_memories(older_than_days=30)
        assert result["ok"] is True
        assert result["deleted_daily_memories"] >= 1
        assert await store.get(old_id) is None

    async def test_clear_session_handles_non_list_json(self, tmp_path):
        """clear_session tolerates a session file holding a JSON object instead of a list."""
        store = FileMemoryStore(base_path=tmp_path)
        session_key = "discord:test-session"
        session_file = store._get_session_file(session_key)
        session_file.write_text('{"unexpected": "shape"}', encoding="utf-8")

        count = await store.clear_session(session_key)
        assert count == 0
        assert not session_file.exists()

    async def test_cleanup_orphan_records_handles_large_memory_stores(self, tmp_path):
        """Test cleanup_orphan_records with 1000+ memories without SQLite variable limit crash.

        Guards against SQLite's bound-parameter limit (SQLITE_MAX_VARIABLE_NUMBER),
        which defaults to 999 in SQLite builds older than 3.32.0 and to 32766 in
        newer ones. With more valid memory IDs than that limit, a direct NOT IN
        clause fails with ``OperationalError: too many SQL variables``. On
        newer SQLite builds 1050 IDs stay under the limit, so there this test
        only checks that orphans are removed and valid rows are kept.
        """
        store = FileMemoryStore(base_path=tmp_path, vector_enabled=True, embedding_provider="hash")

        # Create 1050 entries to comfortably exceed the legacy variable limit (999)
        for i in range(1050):
            await store.save(
                MemoryEntry(
                    id="",
                    type=MemoryType.LONG_TERM,
                    content=f"Memory {i}: Important fact about topic {i}",
                    metadata={"header": f"Entry {i}"},
                )
            )

        # Verify all 1050 are indexed
        assert len(store._index) == 1050

        # Create orphan vector records (not in _index)
        now = datetime.now(tz=UTC).isoformat()
        orphans = [
            ("orphan_1", "orphaned text", "long_term", "default", now, "{}", "[]"),
            ("orphan_2", "another orphaned text", "long_term", "default", now, "{}", "[]"),
        ]
        with self._open_db(store._vector_db_path) as conn:
            conn.executemany(
                """
                INSERT INTO memory_vectors
                (doc_id, content, memory_type, user_scope, created_at, metadata_json,
                 embedding_json)
                VALUES (?, ?, ?, ?, ?, ?, ?)
                """,
                orphans,
            )
            conn.commit()

        # Verify orphans are present before cleanup
        with self._open_db(store._vector_db_path) as conn:
            orphan_count = conn.execute(
                "SELECT COUNT(*) FROM memory_vectors WHERE doc_id LIKE 'orphan_%'"
            ).fetchone()[0]
            total_before = conn.execute("SELECT COUNT(*) FROM memory_vectors").fetchone()[0]

        assert orphan_count == 2
        assert total_before >= 1050 + 2

        # Run cleanup — must not crash with "too many SQL variables" error
        await store._cleanup_orphan_records()

        # Verify orphans are deleted and valid entries remain
        with self._open_db(store._vector_db_path) as conn:
            remaining_orphan_count = conn.execute(
                "SELECT COUNT(*) FROM memory_vectors WHERE doc_id LIKE 'orphan_%'"
            ).fetchone()[0]
            total_after = conn.execute("SELECT COUNT(*) FROM memory_vectors").fetchone()[0]

        assert remaining_orphan_count == 0, "Orphaned records should be deleted"
        # All 1050 valid entries should still have vector records
        assert total_after >= 1050, "Valid entries should be preserved"

    async def test_cleanup_orphan_records_with_graph_entities(self, tmp_path):
        """Cleanup runs without error on a store with graph data and keeps its memories."""
        store = FileMemoryStore(base_path=tmp_path, vector_enabled=True, embedding_provider="hash")

        # 100 entries sharing one sentence; only the header differs per entry.
        for i in range(100):
            await store.save(
                MemoryEntry(
                    id="",
                    type=MemoryType.LONG_TERM,
                    content="Project Alpha uses PostgreSQL and FastAPI framework",
                    metadata={"header": f"Entry {i}"},
                )
            )

        await store._cleanup_orphan_records()

        # Cleanup must complete without error and leave the saved memories counted.
        stats = await store.get_memory_stats()
        assert stats["total_memories"] >= 100

    async def test_semantic_search_sqlite_caps_candidate_limit(self, tmp_path):
        """Candidate fetch size should be bounded to avoid excessive memory usage."""
        store = FileMemoryStore(base_path=tmp_path, vector_enabled=True, embedding_provider="hash")

        await store.save(
            MemoryEntry(
                id="",
                type=MemoryType.LONG_TERM,
                content="Semantic candidate limit regression test",
                metadata={"header": "Limits"},
            )
        )

        observed_limits: list[int] = []
        real_connect = sqlite3.connect

        class _ConnectionProxy:
            """Wrap a real connection and record the LIMIT bound to vector queries."""

            def __init__(self, conn):
                self._conn = conn

            def __enter__(self):
                self._conn.__enter__()
                # Return the proxy so execute() calls inside ``with`` are observed.
                return self

            def __exit__(self, exc_type, exc_val, exc_tb):
                return self._conn.__exit__(exc_type, exc_val, exc_tb)

            def execute(self, sql, params=()):
                # The second bound parameter is recorded as the candidate limit.
                if "FROM memory_vectors" in sql and "LIMIT ?" in sql and len(params) >= 2:
                    observed_limits.append(int(params[1]))
                return self._conn.execute(sql, params)

            def __getattr__(self, name):
                # Everything else is delegated to the wrapped connection.
                return getattr(self._conn, name)

        def _connect_proxy(*args, **kwargs):
            return _ConnectionProxy(real_connect(*args, **kwargs))

        with patch("pocketpaw.memory.file_store.sqlite3.connect", side_effect=_connect_proxy):
            await store._semantic_search_sqlite("semantic query", user_id="default", limit=1)
            await store._semantic_search_sqlite("semantic query", user_id="default", limit=10_000)

        assert observed_limits, "Expected semantic search SQL query to run"
        # limit=1 is expected to be raised to 200; limit=10_000 capped at 2000.
        assert 200 in observed_limits
        assert 2000 in observed_limits
|
|
|
|
|
|
# ===========================================================================
|
|
# TestGraphSVGHtmlEscaping
|
|
# ===========================================================================
|
|
|
|
|
|
class TestGraphSVGHtmlEscaping:
|
|
"""HTML escaping in get_graph_svg to prevent malformed SVG."""
|
|
|
|
@staticmethod
def _skip_if_graph_unavailable(svg: str) -> None:
    """Skip the running test when ``svg`` contains the fallback message.

    Returns normally (``None``) when the text "Graph visualization
    unavailable" is absent from ``svg``; otherwise calls ``pytest.skip``.
    """
    if "Graph visualization unavailable" not in svg:
        return
    pytest.skip("graph extras unavailable; SVG renderer fallback active")
|
|
|
|
async def test_direct_entity_escaping_greater_than(self, tmp_path):
    """A ``>`` in an entity's display name is escaped in the rendered SVG.

    The entity row is written directly into the graph database (bypassing
    ``store.save``) so the special character is stored exactly as given.
    """
    store = FileMemoryStore(
        base_path=tmp_path,
        vector_enabled=True,
        embedding_provider="hash",
    )

    now = datetime.now(UTC).isoformat()
    with sqlite3.connect(store._graph_db_path) as conn:
        conn.execute(
            """INSERT INTO entities
               (entity_id, entity_key, display_name, mention_count, user_scope,
                first_seen, last_seen)
               VALUES (?, ?, ?, ?, ?, ?, ?)""",
            (
                "entity_1",
                "value>100",
                "Value>100",
                5,
                "default",
                now,
                now,
            ),
        )
        conn.commit()

    svg = await store.get_graph_svg(user_id="default")
    self._skip_if_graph_unavailable(svg)
    # A bare ">" appears in every SVG document (each tag ends with one), so
    # checking for it proves nothing. Look for the escaped entity instead and
    # make sure the raw display name did not leak into the markup.
    assert "&gt;" in svg
    assert "Value>100" not in svg
    assert "<svg" in svg
    assert "</svg>" in svg
|
|
|
|
async def test_direct_entity_escaping_ampersand(self, tmp_path):
    """An ``&`` in an entity's display name is escaped in the rendered SVG.

    The entity row is written directly into the graph database (bypassing
    ``store.save``) so the special character is stored exactly as given.
    """
    store = FileMemoryStore(
        base_path=tmp_path,
        vector_enabled=True,
        embedding_provider="hash",
    )

    now = datetime.now(UTC).isoformat()
    with sqlite3.connect(store._graph_db_path) as conn:
        conn.execute(
            """INSERT INTO entities
               (entity_id, entity_key, display_name, mention_count, user_scope,
                first_seen, last_seen)
               VALUES (?, ?, ?, ?, ?, ?, ?)""",
            (
                "entity_1",
                "dogs&cats",
                "Dogs&Cats",
                5,
                "default",
                now,
                now,
            ),
        )
        conn.commit()

    svg = await store.get_graph_svg(user_id="default")
    self._skip_if_graph_unavailable(svg)
    # A plain "&" check would also pass for an unescaped (malformed) label,
    # so require the escaped entity and the absence of the raw name.
    assert "&amp;" in svg
    assert "Dogs&Cats" not in svg
    assert "<svg" in svg
    assert "</svg>" in svg
|
|
|
|
async def test_direct_relation_type_escaping(self, tmp_path):
    """A graph with one relationship renders as well-formed SVG.

    Two entities and a ``uses`` relationship between them are inserted
    directly into the graph database, then the SVG is rendered.
    NOTE(review): the relation type used here contains no special characters,
    so this checks rendering of an edge rather than escaping itself.
    """
    store = FileMemoryStore(
        base_path=tmp_path,
        vector_enabled=True,
        embedding_provider="hash",
    )

    timestamp = datetime.now(UTC).isoformat()
    insert_entity_sql = """INSERT INTO entities
               (entity_id, entity_key, display_name, mention_count, user_scope,
                first_seen, last_seen)
               VALUES (?, ?, ?, ?, ?, ?, ?)"""
    insert_relationship_sql = """INSERT INTO relationships
               (relationship_id, source_entity_id, target_entity_id,
                relation_type, weight, user_scope, first_seen, last_seen)
               VALUES (?, ?, ?, ?, ?, ?, ?, ?)"""
    entity_rows = [
        ("entity_1", "python", "Python", 5, "default", timestamp, timestamp),
        ("entity_2", "framework", "Framework", 3, "default", timestamp, timestamp),
    ]
    # Relationship with normalized types; test escaping
    relationship_row = (
        "rel_1",
        "entity_1",
        "entity_2",
        "uses",
        0,
        "default",
        timestamp,
        timestamp,
    )

    with sqlite3.connect(store._graph_db_path) as conn:
        conn.executemany(insert_entity_sql, entity_rows)
        conn.execute(insert_relationship_sql, relationship_row)
        conn.commit()

    svg = await store.get_graph_svg(user_id="default")
    self._skip_if_graph_unavailable(svg)
    # The document must be delimited correctly and mention at least one of
    # the inserted labels.
    assert "<svg" in svg
    assert "</svg>" in svg
    assert any(label in svg for label in ("Python", "Framework", "uses"))
|
|
|
|
async def test_svg_without_special_chars_unchanged(self, tmp_path):
    """Plain entity names appear in the SVG as ordinary label text."""
    store = FileMemoryStore(
        base_path=tmp_path,
        vector_enabled=True,
        embedding_provider="hash",
    )

    plain_entry = MemoryEntry(
        id="",
        type=MemoryType.LONG_TERM,
        content="Python is great for backend development",
        metadata={"header": "Language"},
    )
    await store.save(plain_entry)

    svg = await store.get_graph_svg(user_id="default")
    self._skip_if_graph_unavailable(svg)
    # Normal SVG structure: opening tag, closing tag and at least one text node.
    for fragment in ("<svg", "</svg>", "<text"):
        assert fragment in svg
    # The node label itself should be present.
    assert "Python" in svg
|
|
|
|
async def test_svg_malformed_without_escaping(self, tmp_path):
    """Angle brackets in a display name must not produce stray SVG tags.

    An entity named ``Test<Bad>Evil`` is written directly into the graph
    database. Rendered unescaped, ``<Bad>`` would be read as an element and
    break the document, so the output must carry the escaped form instead.
    """
    store = FileMemoryStore(
        base_path=tmp_path,
        vector_enabled=True,
        embedding_provider="hash",
    )

    now = datetime.now(UTC).isoformat()
    with sqlite3.connect(store._graph_db_path) as conn:
        conn.execute(
            """INSERT INTO entities
               (entity_id, entity_key, display_name, mention_count, user_scope,
                first_seen, last_seen)
               VALUES (?, ?, ?, ?, ?, ?, ?)""",
            (
                "entity_1",
                "test<bad>evil",
                "Test<Bad>Evil",
                5,
                "default",
                now,
                now,
            ),
        )
        conn.commit()

    svg = await store.get_graph_svg(user_id="default")
    self._skip_if_graph_unavailable(svg)
    # SVG should still be well-formed (properly closed tags)
    assert svg.count("<svg") == 1
    assert svg.count("</svg>") == 1
    # Raw "<" and ">" occur in any SVG, so test for the escaped entities and
    # for the absence of the injected pseudo-tag.
    assert "&lt;" in svg or "&gt;" in svg
    assert "<Bad>" not in svg
|
|
|
|
async def test_empty_graph_returns_valid_svg(self, tmp_path):
    """Rendering a store with no entities still yields an SVG document.

    Also checks that asking for the graph does not create the graph
    database file as a side effect.
    """
    empty_store = FileMemoryStore(base_path=tmp_path)

    rendered = await empty_store.get_graph_svg(user_id="default")

    # Valid SVG delimiters even though there is nothing to draw.
    assert "<svg" in rendered
    assert "</svg>" in rendered
    assert not empty_store._graph_db_path.exists()
|
|
|
|
async def test_graph_with_networkx_unavailable_returns_fallback(self, tmp_path):
    """A failing import during rendering produces the fallback SVG.

    ``importlib.import_module`` is patched to raise ``ImportError`` while
    ``get_graph_svg`` runs; the result must still be an SVG document and
    must carry one of the fallback hints.
    """
    store = FileMemoryStore(
        base_path=tmp_path,
        vector_enabled=True,
        embedding_provider="hash",
    )

    seed_entry = MemoryEntry(
        id="",
        type=MemoryType.LONG_TERM,
        content="Test entity",
        metadata={"header": "Test"},
    )
    await store.save(seed_entry)

    # Simulate the graph dependency being missing for the render call only.
    import_failure = ImportError("No networkx")
    with patch("importlib.import_module", side_effect=import_failure):
        svg = await store.get_graph_svg(user_id="default")

    # Fallback output is still SVG and tells the user what is missing.
    assert "<svg" in svg
    assert "</svg>" in svg
    fallback_hints = ("Graph visualization unavailable", "pocketpaw[graph]")
    assert any(hint in svg for hint in fallback_hints)
|