auto-commit-service/tests/test_llm_client.py
2026-01-05 18:41:40 -08:00

457 lines
20 KiB
Python

"""Tests for LLM client module.
Proves that "sloppy" commit messages from the LLM get properly cleaned
into the expected format: type(scope): emoji description
"""
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
import httpx
from auto_commit_service.git.diff_parser import DiffSummary, summarize_diff
from auto_commit_service.llm.client import (
LlamaCommitClient,
LlamaServiceError,
LlamaServiceUnavailable,
)
class TestLlamaCommitClient:
    """Smoke tests for LlamaCommitClient with no HTTP mocking in place."""

    @pytest.fixture
    def client(self) -> LlamaCommitClient:
        """Build the client under test with an explicit URL and timeout."""
        settings = {"base_url": "http://localhost:8000", "timeout": 5.0}
        return LlamaCommitClient(**settings)

    async def test_health_check_unavailable(self, client: LlamaCommitClient) -> None:
        """health_check() reports an error status and no loaded model."""
        # NOTE(review): nothing is patched here, so this presumably relies on
        # no service listening at the fixture's URL - confirm for CI hosts.
        report = await client.health_check()

        assert report["status"] == "error"
        assert not report["model_loaded"]

    async def test_is_available_when_down(self, client: LlamaCommitClient) -> None:
        """is_available() returns a falsy value when the service is down."""
        reachable = await client.is_available()

        assert not reachable
class TestCleanResponseNormalization:
    """Tests proving the _clean_response method normalizes sloppy LLM output.

    The LLM may produce commit messages in various formats. These tests
    verify that all variations get normalized to: type(scope): emoji description
    """

    @pytest.fixture
    def client(self) -> LlamaCommitClient:
        """Provide a client constructed with its default arguments."""
        return LlamaCommitClient()

    # ==========================================================================
    # Already correct format - should pass through
    # ==========================================================================

    def test_correct_format_passes_through(self, client: LlamaCommitClient) -> None:
        """Correctly formatted messages pass through unchanged."""
        result = client._clean_response("feat(auth): ✨ add login endpoint")
        assert result == "feat(auth): ✨ add login endpoint"

    def test_correct_format_with_different_types(self, client: LlamaCommitClient) -> None:
        """All valid types keep their prefix and description (emojis may be normalized)."""
        # Each case names the description fragment that must survive for THAT
        # message, so the check is tied to the input being cleaned and a failure
        # reports exactly which description went missing.
        test_cases = [
            ("fix(api): 🐛 resolve null pointer in handler", "fix(api):", "resolve null pointer"),
            ("refactor(core): ♻️ extract validation logic", "refactor(core):", "extract validation"),
            ("chore(deps): 🔧 update eslint to v9", "chore(deps):", "update eslint"),
            ("docs(readme): 📝 add installation section", "docs(readme):", "add installation"),
            ("build(ci): ⬆️ upgrade node to v20", "build(ci):", "upgrade node"),
            ("test(auth): ✅ add unit tests for login", "test(auth):", "add unit tests"),
            ("perf(query): ⚡ optimize database lookup", "perf(query):", "optimize database"),
        ]
        for msg, expected_prefix, expected_fragment in test_cases:
            result = client._clean_response(msg)
            assert result.startswith(expected_prefix), f"Expected {result} to start with {expected_prefix}"
            # Verify this case's own description is preserved
            assert expected_fragment in result, f"Expected {result} to contain {expected_fragment}"

    # ==========================================================================
    # Markdown formatting removal
    # ==========================================================================

    def test_removes_markdown_code_blocks(self, client: LlamaCommitClient) -> None:
        """Strips markdown code block wrapper."""
        result = client._clean_response("```\nfeat(ui): ✨ add button component\n```")
        assert result == "feat(ui): ✨ add button component"

    def test_removes_language_tagged_code_blocks(self, client: LlamaCommitClient) -> None:
        """Strips code blocks with language tags."""
        result = client._clean_response("```text\nfix(api): 🐛 fix timeout\n```")
        assert result == "fix(api): 🐛 fix timeout"

    def test_removes_surrounding_quotes(self, client: LlamaCommitClient) -> None:
        """Strips surrounding double quotes."""
        result = client._clean_response('"feat(auth): ✨ add oauth"')
        assert result == "feat(auth): ✨ add oauth"

    def test_removes_single_quotes(self, client: LlamaCommitClient) -> None:
        """Strips surrounding single quotes."""
        result = client._clean_response("'chore(config): 🔧 update settings'")
        assert result == "chore(config): 🔧 update settings"

    def test_takes_only_first_line(self, client: LlamaCommitClient) -> None:
        """When LLM produces multiple lines, only first is used."""
        # Spelled with explicit newlines so the extra body lines are obvious.
        sloppy = (
            "feat(api): ✨ add user endpoint\n"
            "This adds a new REST endpoint for user management.\n"
            "It supports CRUD operations."
        )
        result = client._clean_response(sloppy)
        assert result == "feat(api): ✨ add user endpoint"

    # ==========================================================================
    # Emoji position correction
    # ==========================================================================

    def test_moves_emoji_from_start_to_after_colon(self, client: LlamaCommitClient) -> None:
        """When emoji comes before type, move it after the colon."""
        result = client._clean_response("✨ feat(ui): add new component")
        assert result == "feat(ui): ✨ add new component"

    def test_fixes_emoji_before_type_with_fix(self, client: LlamaCommitClient) -> None:
        """Bug fix emoji at start gets repositioned."""
        result = client._clean_response("🐛 fix(auth): resolve login bug")
        assert result == "fix(auth): 🐛 resolve login bug"

    # ==========================================================================
    # Missing emoji addition
    # ==========================================================================

    def test_adds_emoji_for_feat_type(self, client: LlamaCommitClient) -> None:
        """Adds ✨ emoji for feat type when missing."""
        result = client._clean_response("feat(api): add health endpoint")
        assert result == "feat(api): ✨ add health endpoint"

    def test_adds_emoji_for_fix_type(self, client: LlamaCommitClient) -> None:
        """Adds 🐛 emoji for fix type when missing."""
        result = client._clean_response("fix(auth): resolve timeout issue")
        assert result == "fix(auth): 🐛 resolve timeout issue"

    def test_adds_emoji_for_refactor_type(self, client: LlamaCommitClient) -> None:
        """Adds ♻️ emoji for refactor type when missing."""
        result = client._clean_response("refactor(core): extract shared logic")
        assert result == "refactor(core): ♻️ extract shared logic"

    def test_adds_emoji_for_chore_type(self, client: LlamaCommitClient) -> None:
        """Adds 🔧 emoji for chore type when missing."""
        result = client._clean_response("chore(deps): update dependencies")
        assert result == "chore(deps): 🔧 update dependencies"

    def test_adds_emoji_for_docs_type(self, client: LlamaCommitClient) -> None:
        """Adds 📝 emoji for docs type when missing."""
        result = client._clean_response("docs(readme): update installation guide")
        assert result == "docs(readme): 📝 update installation guide"

    # ==========================================================================
    # Fallback to chore(shared) for emoji-only messages
    # ==========================================================================

    def test_wraps_emoji_only_message(self, client: LlamaCommitClient) -> None:
        """Bare emoji + description gets wrapped in chore(shared)."""
        result = client._clean_response("✨ add new feature")
        assert result == "chore(shared): ✨ add new feature"

    def test_wraps_wrench_emoji_message(self, client: LlamaCommitClient) -> None:
        """Wrench emoji messages become chore(shared)."""
        result = client._clean_response("🔧 update config")
        assert result == "chore(shared): 🔧 update config"

    # ==========================================================================
    # Keyword inference for unstructured messages
    # ==========================================================================

    def test_infers_feat_from_add_keyword(self, client: LlamaCommitClient) -> None:
        """'add' keyword triggers feat type inference."""
        result = client._clean_response("add new authentication module")
        assert result.startswith("feat(shared): ✨")
        assert "add new authentication module" in result

    def test_infers_fix_from_fix_keyword(self, client: LlamaCommitClient) -> None:
        """'fix' keyword triggers fix type inference."""
        result = client._clean_response("fix null pointer in user handler")
        assert result.startswith("fix(shared): 🐛")

    def test_infers_refactor_from_refactor_keyword(self, client: LlamaCommitClient) -> None:
        """'refactor' keyword triggers refactor type inference."""
        result = client._clean_response("refactor the database layer")
        assert result.startswith("refactor(shared): ♻️")

    def test_infers_chore_from_update_keyword(self, client: LlamaCommitClient) -> None:
        """'update' keyword triggers chore type inference."""
        result = client._clean_response("update eslint configuration")
        assert result.startswith("chore(shared): 🔧")

    def test_fallback_to_chore_for_unknown(self, client: LlamaCommitClient) -> None:
        """Unknown messages default to chore(shared): 🔧."""
        result = client._clean_response("miscellaneous changes to codebase")
        assert result.startswith("chore(shared): 🔧")
class TestCommitMessageGeneration:
    """Commit message generation exercised against a mocked HTTP client."""

    @pytest.fixture
    def client(self) -> LlamaCommitClient:
        """Provide a client pointed at a placeholder base URL."""
        return LlamaCommitClient(base_url="http://test:8000")

    @pytest.fixture
    def mock_httpx_response(self):
        """Return a builder that produces stand-in httpx responses."""

        def _create(content: str, status_code: int = 200):
            # The same text is exposed through both .json() and .text.
            fake = MagicMock(spec=httpx.Response)
            fake.text = content
            fake.json.return_value = {"content": content}
            fake.status_code = status_code
            return fake

        return _create

    async def test_generate_from_diff_cleans_output(
        self, client: LlamaCommitClient, mock_httpx_response
    ) -> None:
        """A raw diff goes in; the mocked LLM reply comes back cleaned."""
        # The canned reply is wrapped in quotes and leads with its emoji.
        http = AsyncMock()
        http.post.return_value = mock_httpx_response('"✨ feat(api): add new endpoint"')

        with patch.object(client, "_get_client", return_value=http):
            message = await client.generate_from_diff(
                diff="diff --git a/api.py b/api.py\n+def new_endpoint(): pass",
                repo_name="test-repo",
            )

        # Expect the quotes gone and the emoji placed after the colon.
        assert message == "feat(api): ✨ add new endpoint"

    async def test_generate_commit_message_with_summary(
        self, client: LlamaCommitClient, mock_httpx_response
    ) -> None:
        """A DiffSummary input also yields a cleaned message."""
        overview = DiffSummary(
            files_modified=2,
            files_added=1,
            additions=50,
            deletions=10,
            file_types={".py": 2, ".md": 1},
            key_files=["src/api.py", "src/utils.py", "README.md"],
            diff_excerpt="@@ -1,5 +1,10 @@ ...",
        )
        # The canned reply carries no emoji at all.
        http = AsyncMock()
        http.post.return_value = mock_httpx_response("feat(api): add authentication module")

        with patch.object(client, "_get_client", return_value=http):
            message = await client.generate_commit_message(
                diff_summary=overview,
                repo_name="auth-service",
                branch="main",
            )

        # Expect the emoji to have been inserted.
        assert message == "feat(api): ✨ add authentication module"

    async def test_handles_service_unavailable(
        self, client: LlamaCommitClient
    ) -> None:
        """A connection error from the HTTP layer raises LlamaServiceUnavailable."""
        http = AsyncMock()
        http.post.side_effect = httpx.ConnectError("Connection refused")

        with patch.object(client, "_get_client", return_value=http), pytest.raises(
            LlamaServiceUnavailable
        ):
            await client.generate_from_diff("diff content", "repo")

    async def test_handles_503_service_unavailable(
        self, client: LlamaCommitClient, mock_httpx_response
    ) -> None:
        """A 503 reply raises LlamaServiceUnavailable."""
        http = AsyncMock()
        http.post.return_value = mock_httpx_response("Service Unavailable", status_code=503)

        with patch.object(client, "_get_client", return_value=http), pytest.raises(
            LlamaServiceUnavailable
        ):
            await client.generate_from_diff("diff content", "repo")
class TestDiffToMessageIntegration:
    """Pipeline tests: summarize a raw diff, then generate and clean a message.

    The diffs are fixed sample texts and every LLM reply is a mocked HTTP
    response, so only diff summarizing and message cleaning run for real.
    """

    @staticmethod
    def _canned_reply(text: str) -> MagicMock:
        """Build a status-200 response stub whose JSON body carries ``text``."""
        reply = MagicMock(spec=httpx.Response)
        reply.status_code = 200
        reply.json.return_value = {"content": text}
        return reply

    @pytest.fixture
    def sample_feature_diff(self) -> str:
        """Sample diff that introduces a new file."""
        return """diff --git a/src/auth/login.py b/src/auth/login.py
new file mode 100644
index 0000000..abc1234
--- /dev/null
+++ b/src/auth/login.py
@@ -0,0 +1,25 @@
+from flask import request, jsonify
+
+def login_endpoint():
+ username = request.json.get('username')
+ password = request.json.get('password')
+ if authenticate(username, password):
+ return jsonify({'token': generate_token(username)})
+ return jsonify({'error': 'Invalid credentials'}), 401
"""

    @pytest.fixture
    def sample_bugfix_diff(self) -> str:
        """Sample diff that modifies an existing handler."""
        return """diff --git a/src/api/handler.py b/src/api/handler.py
index 1234567..abcdefg 100644
--- a/src/api/handler.py
+++ b/src/api/handler.py
@@ -15,7 +15,9 @@ def process_request(data):
- result = data.get('value')
+ result = data.get('value')
+ if result is None:
+ raise ValueError("Missing required field: value")
return transform(result)
"""

    @pytest.fixture
    def sample_chore_diff(self) -> str:
        """Sample diff that bumps a dependency in pyproject.toml."""
        return """diff --git a/pyproject.toml b/pyproject.toml
index 1234567..abcdefg 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,7 +10,7 @@ dependencies = [
- "httpx>=0.24.0",
+ "httpx>=0.25.0",
"pydantic>=2.0.0",
]
"""

    def test_parse_feature_diff(self, sample_feature_diff: str) -> None:
        """The new-file diff counts as one added file and no modified files."""
        parsed = summarize_diff(sample_feature_diff)

        assert parsed.files_added == 1
        assert parsed.files_modified == 0
        assert ".py" in parsed.file_types
        assert "src/auth/login.py" in parsed.key_files

    def test_parse_bugfix_diff(self, sample_bugfix_diff: str) -> None:
        """The handler diff counts as one modified file with added lines."""
        parsed = summarize_diff(sample_bugfix_diff)

        assert parsed.files_modified == 1
        assert parsed.additions >= 2  # Added lines
        assert "src/api/handler.py" in parsed.key_files

    def test_parse_chore_diff(self, sample_chore_diff: str) -> None:
        """The pyproject diff counts as one modified .toml file."""
        parsed = summarize_diff(sample_chore_diff)

        assert parsed.files_modified == 1
        assert ".toml" in parsed.file_types
        assert "pyproject.toml" in parsed.key_files

    async def test_full_pipeline_feature(self, sample_feature_diff: str) -> None:
        """Feature diff: a reply with a leading emoji gets the emoji repositioned."""
        parsed = summarize_diff(sample_feature_diff)
        llm = LlamaCommitClient()

        # Mocked reply puts the emoji ahead of the type.
        http = AsyncMock()
        http.post.return_value = self._canned_reply("✨ feat(auth): add login endpoint")

        with patch.object(llm, "_get_client", return_value=http):
            message = await llm.generate_commit_message(parsed, "auth-service")

        assert message == "feat(auth): ✨ add login endpoint"

    async def test_full_pipeline_bugfix(self, sample_bugfix_diff: str) -> None:
        """Bugfix diff: a reply without an emoji gets one inserted."""
        parsed = summarize_diff(sample_bugfix_diff)
        llm = LlamaCommitClient()

        # Mocked reply has no emoji.
        http = AsyncMock()
        http.post.return_value = self._canned_reply("fix(api): handle missing value field")

        with patch.object(llm, "_get_client", return_value=http):
            message = await llm.generate_commit_message(parsed, "api-service")

        assert message == "fix(api): 🐛 handle missing value field"

    async def test_full_pipeline_chore(self, sample_chore_diff: str) -> None:
        """Chore diff: a bare emoji reply gets wrapped in chore(shared)."""
        parsed = summarize_diff(sample_chore_diff)
        llm = LlamaCommitClient()

        # Mocked reply is only an emoji plus a description.
        http = AsyncMock()
        http.post.return_value = self._canned_reply("⬆️ update httpx dependency")

        with patch.object(llm, "_get_client", return_value=http):
            message = await llm.generate_commit_message(parsed, "project")

        assert message == "chore(shared): ⬆️ update httpx dependency"
class TestClientContextManager:
    """Checks the client's ``async with`` support."""

    async def test_context_manager(self) -> None:
        """The inner client is set inside the block and released after it."""
        async with LlamaCommitClient() as llm:
            # Call a method first so the inner client gets created.
            await llm.health_check()
            assert llm._client is not None

        # Once the block exits, the inner client is either dropped or closed.
        leftover = llm._client
        assert leftover is None or leftover.is_closed