Files
pocketpaw/tests/test_memory.py
Rohit Kushwaha 8d694a1a81 fix: lint, formatting, and test fixes after PR merge batch
- Fix ruff lint errors (import sorting, unused import, line length)
- Fix ruff formatting (5 files)
- Fix a2a test client missing Content-Type header (FastAPI 0.135 compat)
- Fix memory API test mocking _store instead of public get_by_type
- Fix date-dependent test_parse_only_number assertion
- Add trailing newline to styles.css, test_memory.py, starlette compat test
- Update uv.lock for fastapi>=0.134.0

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-06 18:45:01 +05:30

560 lines
20 KiB
Python

# Tests for Memory System
# Created: 2026-02-02
import tempfile
from pathlib import Path
import pytest
from pocketpaw.memory.file_store import FileMemoryStore
from pocketpaw.memory.manager import MemoryManager
from pocketpaw.memory.protocol import MemoryEntry, MemoryType
@pytest.fixture
def temp_memory_path():
    """Yield a fresh temporary directory (as a Path) and remove it afterwards."""
    import shutil

    scratch_dir = tempfile.mkdtemp()
    yield Path(scratch_dir)
    # Best-effort cleanup: on Windows, SQLite WAL files may still be locked,
    # so removal errors are ignored rather than failing the test.
    shutil.rmtree(scratch_dir, ignore_errors=True)
@pytest.fixture
def memory_store(temp_memory_path):
    """Provide a FileMemoryStore whose base path is the per-test temp directory."""
    store = FileMemoryStore(base_path=temp_memory_path)
    return store
@pytest.fixture
def memory_manager(temp_memory_path):
    """Provide a MemoryManager whose base path is the per-test temp directory."""
    manager = MemoryManager(base_path=temp_memory_path)
    return manager
class TestMemoryEntry:
    """Checks that MemoryEntry objects expose the fields they were built with."""

    def test_create_entry(self):
        """A minimal entry keeps id/type/content and has empty tags and metadata."""
        fields = {
            "id": "test-id",
            "type": MemoryType.LONG_TERM,
            "content": "Test content",
        }
        created = MemoryEntry(**fields)

        for attr_name, expected in fields.items():
            assert getattr(created, attr_name) == expected
        assert created.tags == []
        assert created.metadata == {}

    def test_entry_with_tags(self):
        """Tags passed to the constructor are readable back unchanged."""
        wanted_tags = ["work", "important"]
        tagged = MemoryEntry(
            id="test-id",
            type=MemoryType.DAILY,
            content="Daily note",
            tags=["work", "important"],
        )

        assert tagged.tags == wanted_tags

    def test_session_entry(self):
        """Session entries carry the role and session key they were given."""
        message = MemoryEntry(
            id="test-id",
            type=MemoryType.SESSION,
            content="Hello!",
            role="user",
            session_key="websocket:user123",
        )

        assert message.role == "user"
        assert message.session_key == "websocket:user123"
class TestFileMemoryStore:
    """Exercises saving, session handling and search on FileMemoryStore."""

    @pytest.mark.asyncio
    async def test_save_and_get_long_term(self, memory_store):
        """Saving a long-term entry returns a truthy id and writes the text to disk."""
        saved_id = await memory_store.save(
            MemoryEntry(
                id="",
                type=MemoryType.LONG_TERM,
                content="User prefers dark mode",
                tags=["preferences"],
                metadata={"header": "User Preferences"},
            )
        )
        assert saved_id

        # The long-term file must now exist and contain the saved content.
        assert memory_store.long_term_file.exists()
        file_text = memory_store.long_term_file.read_text(encoding="utf-8")
        assert "User prefers dark mode" in file_text

    @pytest.mark.asyncio
    async def test_save_session(self, memory_store):
        """A saved session message is returned by get_session for the same key."""
        await memory_store.save(
            MemoryEntry(
                id="",
                type=MemoryType.SESSION,
                content="Hello, how are you?",
                role="user",
                session_key="test_session",
            )
        )

        messages = await memory_store.get_session("test_session")
        assert len(messages) == 1
        assert messages[0].content == "Hello, how are you?"
        assert messages[0].role == "user"

    @pytest.mark.asyncio
    async def test_clear_session(self, memory_store):
        """clear_session reports how many messages it removed and empties the session."""
        for index in range(3):
            message = MemoryEntry(
                id="",
                type=MemoryType.SESSION,
                content=f"Message {index}",
                role="user",
                session_key="test_session",
            )
            await memory_store.save(message)

        removed = await memory_store.clear_session("test_session")
        assert removed == 3

        remaining = await memory_store.get_session("test_session")
        assert len(remaining) == 0

    @pytest.mark.asyncio
    async def test_search(self, memory_store):
        """Searching for "Python" returns only the entry that mentions it."""
        seeds = (
            ("User likes Python programming", "Preferences"),
            ("User prefers dark mode", "UI"),
        )
        # Build every entry first, then persist them in order.
        pending = [
            MemoryEntry(
                id="",
                type=MemoryType.LONG_TERM,
                content=text,
                metadata={"header": header},
            )
            for text, header in seeds
        ]
        for item in pending:
            await memory_store.save(item)

        hits = await memory_store.search(query="Python")
        assert len(hits) == 1
        assert "Python" in hits[0].content
class TestMemoryManager:
    """Exercises the MemoryManager convenience methods."""

    @pytest.mark.asyncio
    async def test_remember(self, memory_manager):
        """remember() returns a truthy entry id."""
        stored_id = await memory_manager.remember(
            "User prefers dark mode",
            tags=["preferences"],
            header="UI Preferences",
        )
        assert stored_id

    @pytest.mark.asyncio
    async def test_note(self, memory_manager):
        """note() returns a truthy entry id."""
        stored_id = await memory_manager.note(
            "Had a meeting about project X",
            tags=["work"],
        )
        assert stored_id

    @pytest.mark.asyncio
    async def test_session_flow(self, memory_manager):
        """Messages added to a session come back in order and can be cleared."""
        session_key = "test:session123"
        turns = (
            ("user", "Hello!"),
            ("assistant", "Hi there!"),
            ("user", "How are you?"),
        )
        for role, text in turns:
            await memory_manager.add_to_session(session_key, role, text)

        # History is expected as plain role/content dicts.
        transcript = await memory_manager.get_session_history(session_key)
        assert len(transcript) == 3
        assert transcript[0] == {"role": "user", "content": "Hello!"}
        assert transcript[1] == {"role": "assistant", "content": "Hi there!"}

        cleared = await memory_manager.clear_session(session_key)
        assert cleared == 3

    @pytest.mark.asyncio
    async def test_get_context_for_agent(self, memory_manager):
        """The agent context mentions at least one of the two memory sections."""
        await memory_manager.remember("User prefers dark mode")
        await memory_manager.note("Working on PocketPaw today")

        agent_context = await memory_manager.get_context_for_agent()
        has_long_term = "Long-term Memory" in agent_context
        assert has_long_term or "Today's Notes" in agent_context
class TestMemoryIntegration:
    """End-to-end check of the memory system through MemoryManager."""

    @pytest.mark.asyncio
    async def test_full_workflow(self, temp_memory_path):
        """Store a fact, a note and a short conversation, then read everything back."""
        mgr = MemoryManager(base_path=temp_memory_path)

        # Long-term fact about the user.
        await mgr.remember(
            "User's name is Prakash",
            tags=["user", "identity"],
            header="User Identity",
        )

        # Daily note.
        await mgr.note("Started working on memory system")

        # Two-turn conversation in one session.
        session_id = "websocket:prakash"
        await mgr.add_to_session(session_id, "user", "What's my name?")
        await mgr.add_to_session(session_id, "assistant", "Your name is Prakash!")

        # The agent context should surface the remembered name.
        agent_context = await mgr.get_context_for_agent()
        assert "Prakash" in agent_context

        # Both conversation turns should be in the session history.
        transcript = await mgr.get_session_history(session_id)
        assert len(transcript) == 2

        # On-disk artefacts expected under the base path.
        memory_file = temp_memory_path / "MEMORY.md"
        sessions_dir = temp_memory_path / "sessions"
        assert memory_file.exists()
        assert sessions_dir.is_dir()
class TestGraphExtraction:
    """Tests for the knowledge-graph extraction helpers on FileMemoryStore."""

    @staticmethod
    def _has_edge(relationships, source, relation, target):
        """Return True when the exact (source, relation, target) triple was extracted."""
        return any(
            (src, rel, tgt) == (source, relation, target) for src, rel, tgt in relationships
        )

    @pytest.fixture
    def vector_memory_store(self, temp_memory_path):
        """Provide a FileMemoryStore built with vector_enabled=True (graph extraction on)."""
        return FileMemoryStore(base_path=temp_memory_path, vector_enabled=True)

    @pytest.fixture
    def plain_memory_store(self, temp_memory_path):
        """Provide a FileMemoryStore built with vector_enabled=False (graph extraction off)."""
        return FileMemoryStore(base_path=temp_memory_path, vector_enabled=False)

    def test_entity_blacklist_filtering(self, vector_memory_store):
        """Generic words must be rejected as entity candidates."""
        generic_words = (
            "something",
            "anything",
            "question",
            "thing",
            "it",
            "this",
            "that",
            "they",
            "meeting",
            "call",
            "plan",
        )
        for candidate in generic_words:
            assert not vector_memory_store._is_valid_entity_candidate(candidate)

    def test_valid_entity_candidates(self, vector_memory_store):
        """Project and technology names must be accepted as entity candidates."""
        names = (
            "Project Phoenix",
            "PostgreSQL",
            "FastAPI",
            "OpenAI",
            "MyProject",
        )
        for candidate in names:
            assert vector_memory_store._is_valid_entity_candidate(candidate)

    def test_entity_length_validation(self, vector_memory_store):
        """A sentence-length candidate is rejected while a short name is accepted."""
        is_valid = vector_memory_store._is_valid_entity_candidate

        sentence_fragment = "This is a very long sentence that should not be an entity"
        assert not is_valid(sentence_fragment)

        short_name = "Project Phoenix"
        assert is_valid(short_name)

    def test_self_loop_prevention(self, vector_memory_store):
        """A relationship whose source and target match (ignoring case) is rejected."""
        self_loops = (
            ("Project", "uses", "Project"),
            ("OpenAI", "uses", "openai"),  # differs only by case
        )
        for triple in self_loops:
            assert not vector_memory_store._is_valid_relationship_candidate(*triple)

    def test_valid_relationship_candidates(self, vector_memory_store):
        """Relationships between two distinct entities are accepted."""
        accepted = (
            ("Project Phoenix", "uses", "PostgreSQL"),
            ("MyApp", "depends_on", "Redis"),
        )
        for triple in accepted:
            assert vector_memory_store._is_valid_relationship_candidate(*triple)

    def test_extract_graph_signals_tech_terms(self, vector_memory_store):
        """Known technology names in the text show up among the extracted entities."""
        text = "The project uses Python and PostgreSQL for the backend"
        found_entities, _ = vector_memory_store._extract_graph_signals(text)

        for tech_name in ("Python", "PostgreSQL"):
            assert tech_name in found_entities

    def test_extract_graph_signals_title_case(self, vector_memory_store):
        """Title-case names in the text show up among the extracted entities."""
        text = "Project Phoenix uses FastAPI for the API layer"
        found_entities, _ = vector_memory_store._extract_graph_signals(text)

        for name in ("Project Phoenix", "FastAPI"):
            assert name in found_entities

    def test_extract_graph_signals_uses_pattern(self, vector_memory_store):
        """A "uses" sentence yields the expected 'uses' edge."""
        text = "Project Phoenix uses PostgreSQL for data storage"
        _, edges = vector_memory_store._extract_graph_signals(text)

        assert self._has_edge(edges, "Project Phoenix", "uses", "PostgreSQL for data")

    def test_extract_graph_signals_depends_on_pattern(self, vector_memory_store):
        """A "depends on" sentence yields the expected 'depends_on' edge."""
        text = "MyService depends on Redis for caching"
        _, edges = vector_memory_store._extract_graph_signals(text)

        assert self._has_edge(edges, "MyService", "depends_on", "Redis for caching")

    def test_entity_canonicalization(self, vector_memory_store):
        """Canonical names drop leading determiners/whitespace and fix tech-name casing."""
        canonicalize = vector_memory_store._canonicalize_entity_name

        assert canonicalize("the openai") == "OpenAI"
        assert canonicalize("an postgresql") == "PostgreSQL"
        assert canonicalize(" Project Phoenix ") == "Project Phoenix"

    def test_relation_normalization_schema(self, vector_memory_store):
        """Relation aliases map to schema names; unknown relations map to ''."""
        expected_by_alias = {
            "depends on": "depends_on",
            "invokes": "calls",
            "inherits_from": "extends",
            "unknown_rel": "",
        }
        for alias, expected in expected_by_alias.items():
            assert vector_memory_store._normalize_relation_type(alias) == expected

    def test_confidence_threshold_blocks_low_confidence_edges(
        self,
        vector_memory_store,
        monkeypatch,
    ):
        """With every candidate scored 0.50, entities survive but no edges are emitted."""

        def _always_low(src, rel, tgt):
            return 0.50

        # Force the scorer to return a fixed low score for every candidate.
        monkeypatch.setattr(
            FileMemoryStore,
            "_score_relationship_candidate",
            staticmethod(_always_low),
        )

        text = "Project Phoenix uses PostgreSQL for data storage"
        found_entities, edges = vector_memory_store._extract_graph_signals(text)

        assert "Project Phoenix" in found_entities
        assert "PostgreSQL" in found_entities
        assert edges == []

    def test_stores_only_high_confidence_edges(self, vector_memory_store):
        """With the real scorer, a "built on" sentence still emits a 'built_on' edge."""
        text = "Project Alpha is built on FastAPI"
        _, edges = vector_memory_store._extract_graph_signals(text)

        relation_names = [rel for _, rel, _ in edges]
        assert "built_on" in relation_names

    def test_extract_graph_signals_built_on_pattern(self, vector_memory_store):
        """A "is built on" sentence yields the expected 'built_on' edge."""
        text = "Project Alpha is built on FastAPI"
        _, edges = vector_memory_store._extract_graph_signals(text)

        assert self._has_edge(edges, "Project Alpha", "built_on", "FastAPI")

    def test_extract_graph_signals_is_a_pattern(self, vector_memory_store):
        """A "is a type of" sentence yields the expected 'is_a' edge."""
        text = "PostgreSQL is a type of database"
        _, edges = vector_memory_store._extract_graph_signals(text)

        assert self._has_edge(edges, "PostgreSQL", "is_a", "database")

    def test_extract_graph_signals_part_of_pattern(self, vector_memory_store):
        """A "is part of" sentence yields the expected 'part_of' edge."""
        text = "AuthService is part of Project Phoenix"
        _, edges = vector_memory_store._extract_graph_signals(text)

        assert self._has_edge(edges, "AuthService", "part_of", "Project Phoenix")

    def test_extract_graph_signals_implements_pattern(self, vector_memory_store):
        """An "implements" sentence yields the expected 'implements' edge."""
        text = "MyClient implements the API interface"
        _, edges = vector_memory_store._extract_graph_signals(text)

        assert self._has_edge(edges, "MyClient", "implements", "API interface")

    def test_extract_graph_signals_extends_pattern(self, vector_memory_store):
        """An "extends" sentence yields the expected 'extends' edge."""
        text = "MyBackend extends BaseService"
        _, edges = vector_memory_store._extract_graph_signals(text)

        assert self._has_edge(edges, "MyBackend", "extends", "BaseService")

    def test_extract_graph_signals_inherits_from_pattern(self, vector_memory_store):
        """An "inherits from" sentence is reported as an 'extends' edge."""
        text = "MyBackend inherits from BaseService"
        _, edges = vector_memory_store._extract_graph_signals(text)

        assert self._has_edge(edges, "MyBackend", "extends", "BaseService")

    def test_extract_graph_signals_calls_pattern(self, vector_memory_store):
        """A "calls" sentence yields the expected 'calls' edge."""
        text = "Frontend calls the Backend API"
        _, edges = vector_memory_store._extract_graph_signals(text)

        assert self._has_edge(edges, "Frontend", "calls", "Backend API")

    def test_extract_graph_signals_invokes_pattern(self, vector_memory_store):
        """An "invokes" sentence is reported as a 'calls' edge."""
        text = "Frontend invokes the Backend API"
        _, edges = vector_memory_store._extract_graph_signals(text)

        assert self._has_edge(edges, "Frontend", "calls", "Backend API")

    def test_no_false_positive_has_pattern(self, vector_memory_store):
        """Conversational "has" must not produce a 'has' or user/question edge."""
        text = "The user has a question about billing"
        _, edges = vector_memory_store._extract_graph_signals(text)

        for src, rel, tgt in edges:
            # No edge may use the "has" relation at all ...
            assert rel != "has"
            # ... nor link a "user" source to a "question" target.
            assert not ("user" in src.lower() and "question" in tgt.lower())

    def test_no_false_positive_generic_verbs(self, vector_memory_store):
        """Generic conversational text must not produce any relationships."""
        text = "I have a meeting tomorrow and need to check something"
        _, edges = vector_memory_store._extract_graph_signals(text)

        assert len(edges) == 0

    def test_graph_disabled_when_vector_disabled(self, plain_memory_store):
        """A store built with vector_enabled=False reports graph extraction as off."""
        graph_flag = plain_memory_store._graph_enabled
        assert not graph_flag

    def test_graph_enabled_when_vector_enabled(self, vector_memory_store):
        """A store built with vector_enabled=True reports graph extraction as on."""
        graph_flag = vector_memory_store._graph_enabled
        assert graph_flag

    def test_extraction_bounds(self, vector_memory_store):
        """Extraction output stays within 12 entities and 24 relationships."""
        # Fifty "ProjectN uses ToolN" clauses give far more candidates than the caps.
        text = " ".join(f"Project{i} uses Tool{i}" for i in range(50))
        found_entities, edges = vector_memory_store._extract_graph_signals(text)

        assert len(found_entities) <= 12
        assert len(edges) <= 24

    def test_no_arbitrary_fallback_edges(self, vector_memory_store):
        """Co-mentioned entities must not be linked by a 'related_to' edge."""
        text = "Project Alpha and Project Beta are both important"
        _, edges = vector_memory_store._extract_graph_signals(text)

        assert all(rel != "related_to" for _, rel, _ in edges)
# Explicit marker, matching the other async tests in this file, so the
# coroutine is run even when pytest-asyncio is configured in strict mode.
@pytest.mark.asyncio
async def test_get_by_type_delegates_to_store(tmp_path):
    """MemoryManager.get_by_type forwards its arguments to the underlying store.

    The store's ``get_by_type`` is replaced with an ``AsyncMock`` so the exact
    call made by the manager can be inspected. When the caller gives no
    ``user_id``, the store is expected to be awaited with ``user_id=None``.
    """
    from unittest.mock import AsyncMock

    manager = MemoryManager(base_path=tmp_path)
    # Swap the store method for a mock so the forwarded call can be checked.
    manager._store.get_by_type = AsyncMock(return_value=[])

    await manager.get_by_type(MemoryType.LONG_TERM, limit=10)

    manager._store.get_by_type.assert_awaited_once_with(
        MemoryType.LONG_TERM, limit=10, user_id=None
    )
# Explicit marker, matching the other async tests in this file, so the
# coroutine is run even when pytest-asyncio is configured in strict mode.
@pytest.mark.asyncio
async def test_get_by_type_forwards_user_id(tmp_path):
    """MemoryManager.get_by_type passes an explicit user_id through to the store.

    The store's ``get_by_type`` is replaced with an ``AsyncMock`` so the exact
    call made by the manager can be inspected; the ``user_id`` given by the
    caller must reach the store unchanged.
    """
    from unittest.mock import AsyncMock

    manager = MemoryManager(base_path=tmp_path)
    # Swap the store method for a mock so the forwarded call can be checked.
    manager._store.get_by_type = AsyncMock(return_value=[])

    await manager.get_by_type(MemoryType.LONG_TERM, limit=5, user_id="user123")

    manager._store.get_by_type.assert_awaited_once_with(
        MemoryType.LONG_TERM, limit=5, user_id="user123"
    )