457 lines
20 KiB
Python
457 lines
20 KiB
Python
"""Tests for the LLM client module.

Verifies that "sloppy" commit messages returned by the LLM are cleaned up
into the expected format: type(scope): emoji description
"""
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, patch, MagicMock
|
|
import httpx
|
|
|
|
from auto_commit_service.git.diff_parser import DiffSummary, summarize_diff
|
|
from auto_commit_service.llm.client import (
|
|
LlamaCommitClient,
|
|
LlamaServiceError,
|
|
LlamaServiceUnavailable,
|
|
)
|
|
|
|
|
|
class TestLlamaCommitClient:
    """Tests for LlamaCommitClient class."""

    @staticmethod
    def _unreachable_base_url() -> str:
        """Return a loopback URL whose port currently has nothing listening.

        Binding to port 0 makes the OS hand out a free ephemeral port; the
        socket is closed right away, so connections to that port are refused.
        This keeps the "service is down" tests hermetic instead of assuming
        nothing runs on a well-known dev port such as 8000.
        """
        import socket

        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(("127.0.0.1", 0))
            port = sock.getsockname()[1]
        return f"http://127.0.0.1:{port}"

    @pytest.fixture
    def client(self) -> LlamaCommitClient:
        """Create a test client pointed at a port with no listening service."""
        return LlamaCommitClient(
            base_url=self._unreachable_base_url(),
            timeout=5.0,
        )

    async def test_health_check_unavailable(self, client: LlamaCommitClient) -> None:
        """Test health check when service is unavailable."""
        # Use the client as a context manager so the HTTP client that
        # health_check() lazily creates is closed when the test ends.
        async with client:
            health = await client.health_check()

        assert health["status"] == "error"
        assert not health["model_loaded"]

    async def test_is_available_when_down(self, client: LlamaCommitClient) -> None:
        """Test is_available when service is down."""
        async with client:
            available = await client.is_available()
        assert not available
|
|
|
|
|
|
class TestCleanResponseNormalization:
    """Tests proving the _clean_response method normalizes sloppy LLM output.

    The LLM may produce commit messages in various formats. These tests
    verify that all variations get normalized to: type(scope): emoji description
    """

    @pytest.fixture
    def client(self) -> LlamaCommitClient:
        """Create a default client; _clean_response needs no network access."""
        return LlamaCommitClient()

    # ==========================================================================
    # Already correct format - should pass through
    # ==========================================================================

    def test_correct_format_passes_through(self, client: LlamaCommitClient) -> None:
        """Correctly formatted messages pass through unchanged."""
        result = client._clean_response("feat(auth): ✨ add login endpoint")
        assert result == "feat(auth): ✨ add login endpoint"

    def test_correct_format_with_different_types(self, client: LlamaCommitClient) -> None:
        """All valid types with correct format pass through (emojis may be normalized).

        Each case pins both its ``type(scope):`` prefix and a fragment of its
        *own* description, so a case that loses its description cannot pass
        merely because some other case's description text is present.
        """
        test_cases = [
            ("fix(api): 🐛 resolve null pointer in handler", "fix(api):", "resolve null pointer"),
            ("refactor(core): ♻️ extract validation logic", "refactor(core):", "extract validation"),
            ("chore(deps): 🔧 update eslint to v9", "chore(deps):", "update eslint"),
            ("docs(readme): 📝 add installation section", "docs(readme):", "add installation"),
            ("build(ci): ⬆️ upgrade node to v20", "build(ci):", "upgrade node"),
            ("test(auth): ✅ add unit tests for login", "test(auth):", "add unit tests"),
            ("perf(query): ⚡ optimize database lookup", "perf(query):", "optimize database"),
        ]
        for msg, expected_prefix, expected_fragment in test_cases:
            result = client._clean_response(msg)
            assert result.startswith(expected_prefix), (
                f"Expected {result} to start with {expected_prefix}"
            )
            # Verify this case's own description is preserved
            assert expected_fragment in result, (
                f"Expected description {expected_fragment!r} to be preserved in {result!r}"
            )

    # ==========================================================================
    # Markdown formatting removal
    # ==========================================================================

    def test_removes_markdown_code_blocks(self, client: LlamaCommitClient) -> None:
        """Strips markdown code block wrapper."""
        result = client._clean_response("```\nfeat(ui): ✨ add button component\n```")
        assert result == "feat(ui): ✨ add button component"

    def test_removes_language_tagged_code_blocks(self, client: LlamaCommitClient) -> None:
        """Strips code blocks with language tags."""
        result = client._clean_response("```text\nfix(api): 🐛 fix timeout\n```")
        assert result == "fix(api): 🐛 fix timeout"

    def test_removes_surrounding_quotes(self, client: LlamaCommitClient) -> None:
        """Strips surrounding double quotes."""
        result = client._clean_response('"feat(auth): ✨ add oauth"')
        assert result == "feat(auth): ✨ add oauth"

    def test_removes_single_quotes(self, client: LlamaCommitClient) -> None:
        """Strips surrounding single quotes."""
        result = client._clean_response("'chore(config): 🔧 update settings'")
        assert result == "chore(config): 🔧 update settings"

    def test_takes_only_first_line(self, client: LlamaCommitClient) -> None:
        """When LLM produces multiple lines, only first is used."""
        # Built with explicit "\n" so the body lines carry no accidental
        # indentation from the surrounding method.
        sloppy = (
            "feat(api): ✨ add user endpoint\n"
            "This adds a new REST endpoint for user management.\n"
            "It supports CRUD operations."
        )
        result = client._clean_response(sloppy)
        assert result == "feat(api): ✨ add user endpoint"

    # ==========================================================================
    # Emoji position correction
    # ==========================================================================

    def test_moves_emoji_from_start_to_after_colon(self, client: LlamaCommitClient) -> None:
        """When emoji comes before type, move it after the colon."""
        result = client._clean_response("✨ feat(ui): add new component")
        assert result == "feat(ui): ✨ add new component"

    def test_fixes_emoji_before_type_with_fix(self, client: LlamaCommitClient) -> None:
        """Bug fix emoji at start gets repositioned."""
        result = client._clean_response("🐛 fix(auth): resolve login bug")
        assert result == "fix(auth): 🐛 resolve login bug"

    # ==========================================================================
    # Missing emoji addition
    # ==========================================================================

    def test_adds_emoji_for_feat_type(self, client: LlamaCommitClient) -> None:
        """Adds ✨ emoji for feat type when missing."""
        result = client._clean_response("feat(api): add health endpoint")
        assert result == "feat(api): ✨ add health endpoint"

    def test_adds_emoji_for_fix_type(self, client: LlamaCommitClient) -> None:
        """Adds 🐛 emoji for fix type when missing."""
        result = client._clean_response("fix(auth): resolve timeout issue")
        assert result == "fix(auth): 🐛 resolve timeout issue"

    def test_adds_emoji_for_refactor_type(self, client: LlamaCommitClient) -> None:
        """Adds ♻️ emoji for refactor type when missing."""
        result = client._clean_response("refactor(core): extract shared logic")
        assert result == "refactor(core): ♻️ extract shared logic"

    def test_adds_emoji_for_chore_type(self, client: LlamaCommitClient) -> None:
        """Adds 🔧 emoji for chore type when missing."""
        result = client._clean_response("chore(deps): update dependencies")
        assert result == "chore(deps): 🔧 update dependencies"

    def test_adds_emoji_for_docs_type(self, client: LlamaCommitClient) -> None:
        """Adds 📝 emoji for docs type when missing."""
        result = client._clean_response("docs(readme): update installation guide")
        assert result == "docs(readme): 📝 update installation guide"

    # ==========================================================================
    # Fallback to chore(shared) for emoji-only messages
    # ==========================================================================

    def test_wraps_emoji_only_message(self, client: LlamaCommitClient) -> None:
        """Bare emoji + description gets wrapped in chore(shared)."""
        result = client._clean_response("✨ add new feature")
        assert result == "chore(shared): ✨ add new feature"

    def test_wraps_wrench_emoji_message(self, client: LlamaCommitClient) -> None:
        """Wrench emoji messages become chore(shared)."""
        result = client._clean_response("🔧 update config")
        assert result == "chore(shared): 🔧 update config"

    # ==========================================================================
    # Keyword inference for unstructured messages
    # ==========================================================================

    def test_infers_feat_from_add_keyword(self, client: LlamaCommitClient) -> None:
        """'add' keyword triggers feat type inference."""
        result = client._clean_response("add new authentication module")
        assert result.startswith("feat(shared): ✨")
        assert "add new authentication module" in result

    def test_infers_fix_from_fix_keyword(self, client: LlamaCommitClient) -> None:
        """'fix' keyword triggers fix type inference."""
        result = client._clean_response("fix null pointer in user handler")
        assert result.startswith("fix(shared): 🐛")

    def test_infers_refactor_from_refactor_keyword(self, client: LlamaCommitClient) -> None:
        """'refactor' keyword triggers refactor type inference."""
        result = client._clean_response("refactor the database layer")
        assert result.startswith("refactor(shared): ♻️")

    def test_infers_chore_from_update_keyword(self, client: LlamaCommitClient) -> None:
        """'update' keyword triggers chore type inference."""
        result = client._clean_response("update eslint configuration")
        assert result.startswith("chore(shared): 🔧")

    def test_fallback_to_chore_for_unknown(self, client: LlamaCommitClient) -> None:
        """Unknown messages default to chore(shared): 🔧."""
        result = client._clean_response("miscellaneous changes to codebase")
        assert result.startswith("chore(shared): 🔧")
|
|
|
|
|
|
class TestCommitMessageGeneration:
    """End-to-end commit message generation against a mocked LLM transport.

    ``_get_client`` is patched so no real HTTP traffic happens. Each test
    controls what the fake transport's ``post`` returns or raises.
    """

    @pytest.fixture
    def client(self) -> LlamaCommitClient:
        """Client under test; its base URL is never actually contacted."""
        return LlamaCommitClient(base_url="http://test:8000")

    @pytest.fixture
    def mock_httpx_response(self):
        """Factory to create mock httpx responses.

        The returned callable takes the body ``content`` and an optional
        ``status_code`` (default 200). It yields a ``MagicMock`` shaped like
        ``httpx.Response`` whose ``.text`` is the content and whose ``.json()``
        returns ``{"content": content}``.
        """
        def _build(content: str, status_code: int = 200):
            fake = MagicMock(spec=httpx.Response)
            fake.text = content
            fake.status_code = status_code
            fake.json.return_value = {"content": content}
            return fake

        return _build

    @staticmethod
    def _patched_transport(client: LlamaCommitClient, **post_behaviour):
        """Return a ``patch.object`` context that swaps in a fake HTTP client.

        ``post_behaviour`` entries (e.g. ``return_value=...`` or
        ``side_effect=...``) are applied to the fake client's ``post`` mock.
        """
        transport = AsyncMock()
        for attr_name, attr_value in post_behaviour.items():
            setattr(transport.post, attr_name, attr_value)
        return patch.object(client, "_get_client", return_value=transport)

    async def test_generate_from_diff_cleans_output(
        self, client: LlamaCommitClient, mock_httpx_response
    ) -> None:
        """Full flow: diff -> prompt -> mock LLM -> cleaned message."""
        # The fake LLM answers with a quoted message whose emoji is misplaced.
        reply = mock_httpx_response('"✨ feat(api): add new endpoint"')

        with self._patched_transport(client, return_value=reply):
            message = await client.generate_from_diff(
                diff="diff --git a/api.py b/api.py\n+def new_endpoint(): pass",
                repo_name="test-repo",
            )

        # Quotes are stripped and the emoji moves after the colon.
        assert message == "feat(api): ✨ add new endpoint"

    async def test_generate_commit_message_with_summary(
        self, client: LlamaCommitClient, mock_httpx_response
    ) -> None:
        """Generate from DiffSummary produces cleaned message."""
        diff_summary = DiffSummary(
            files_modified=2,
            files_added=1,
            additions=50,
            deletions=10,
            file_types={".py": 2, ".md": 1},
            key_files=["src/api.py", "src/utils.py", "README.md"],
            diff_excerpt="@@ -1,5 +1,10 @@ ...",
        )
        # The fake LLM leaves out the emoji entirely.
        reply = mock_httpx_response("feat(api): add authentication module")

        with self._patched_transport(client, return_value=reply):
            message = await client.generate_commit_message(
                diff_summary=diff_summary,
                repo_name="auth-service",
                branch="main",
            )

        # The cleaner fills in the type's emoji.
        assert message == "feat(api): ✨ add authentication module"

    async def test_handles_service_unavailable(
        self, client: LlamaCommitClient
    ) -> None:
        """Raises LlamaServiceUnavailable when service is down."""
        refused = httpx.ConnectError("Connection refused")

        with self._patched_transport(client, side_effect=refused):
            with pytest.raises(LlamaServiceUnavailable):
                await client.generate_from_diff("diff content", "repo")

    async def test_handles_503_service_unavailable(
        self, client: LlamaCommitClient, mock_httpx_response
    ) -> None:
        """Raises LlamaServiceUnavailable on 503 response."""
        reply = mock_httpx_response("Service Unavailable", status_code=503)

        with self._patched_transport(client, return_value=reply):
            with pytest.raises(LlamaServiceUnavailable):
                await client.generate_from_diff("diff content", "repo")
|
|
|
|
|
|
class TestDiffToMessageIntegration:
    """Integration tests: diff parsing -> prompt building -> message cleaning.

    These tests demonstrate the full pipeline from raw git diff to final
    cleaned commit message, using mocked LLM responses.
    """

    @staticmethod
    def _mock_llm_reply(client: LlamaCommitClient, content: str):
        """Return a ``patch.object`` context making the LLM "reply" ``content``.

        Inside the ``with`` block, ``client._get_client`` returns a fake async
        HTTP client. That client's ``post`` yields a 200 response whose
        ``.json()`` is ``{"content": content}``.
        """
        response = MagicMock(spec=httpx.Response)
        response.status_code = 200
        response.json.return_value = {"content": content}
        response.text = content

        http_client = AsyncMock()
        http_client.post.return_value = response
        return patch.object(client, "_get_client", return_value=http_client)

    @pytest.fixture
    def sample_feature_diff(self) -> str:
        """Diff showing a new feature being added."""
        # Hunk header counts match the body (8 added lines).
        return """diff --git a/src/auth/login.py b/src/auth/login.py
new file mode 100644
index 0000000..abc1234
--- /dev/null
+++ b/src/auth/login.py
@@ -0,0 +1,8 @@
+from flask import request, jsonify
+
+def login_endpoint():
+    username = request.json.get('username')
+    password = request.json.get('password')
+    if authenticate(username, password):
+        return jsonify({'token': generate_token(username)})
+    return jsonify({'error': 'Invalid credentials'}), 401
"""

    @pytest.fixture
    def sample_bugfix_diff(self) -> str:
        """Diff showing a bug being fixed."""
        # Old side: 1 removed + 1 context line; new side: 3 added + 1 context.
        # Context lines start with a single space, as unified diffs require.
        return """diff --git a/src/api/handler.py b/src/api/handler.py
index 1234567..abcdefg 100644
--- a/src/api/handler.py
+++ b/src/api/handler.py
@@ -15,2 +15,4 @@ def process_request(data):
-    result = data.get('value')
+    result = data.get('value')
+    if result is None:
+        raise ValueError("Missing required field: value")
     return transform(result)
"""

    @pytest.fixture
    def sample_chore_diff(self) -> str:
        """Diff showing a config/maintenance change."""
        # Old and new sides each: 1 changed line + 2 context lines.
        return """diff --git a/pyproject.toml b/pyproject.toml
index 1234567..abcdefg 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,3 +10,3 @@ dependencies = [
-    "httpx>=0.24.0",
+    "httpx>=0.25.0",
     "pydantic>=2.0.0",
 ]
"""

    def test_parse_feature_diff(self, sample_feature_diff: str) -> None:
        """Feature diff is correctly parsed."""
        summary = summarize_diff(sample_feature_diff)

        assert summary.files_added == 1
        assert summary.files_modified == 0
        assert ".py" in summary.file_types
        assert "src/auth/login.py" in summary.key_files

    def test_parse_bugfix_diff(self, sample_bugfix_diff: str) -> None:
        """Bugfix diff is correctly parsed."""
        summary = summarize_diff(sample_bugfix_diff)

        assert summary.files_modified == 1
        assert summary.additions >= 2  # Added lines
        assert "src/api/handler.py" in summary.key_files

    def test_parse_chore_diff(self, sample_chore_diff: str) -> None:
        """Chore/config diff is correctly parsed."""
        summary = summarize_diff(sample_chore_diff)

        assert summary.files_modified == 1
        assert ".toml" in summary.file_types
        assert "pyproject.toml" in summary.key_files

    async def test_full_pipeline_feature(self, sample_feature_diff: str) -> None:
        """Full pipeline for feature: diff -> summary -> prompt -> clean message."""
        summary = summarize_diff(sample_feature_diff)
        client = LlamaCommitClient()

        # Simulate sloppy LLM response (emoji placed before the type)
        with self._mock_llm_reply(client, "✨ feat(auth): add login endpoint"):
            result = await client.generate_commit_message(summary, "auth-service")

        # Emoji moved to correct position
        assert result == "feat(auth): ✨ add login endpoint"

    async def test_full_pipeline_bugfix(self, sample_bugfix_diff: str) -> None:
        """Full pipeline for bugfix: diff -> summary -> prompt -> clean message."""
        summary = summarize_diff(sample_bugfix_diff)
        client = LlamaCommitClient()

        # LLM returns message without emoji
        with self._mock_llm_reply(client, "fix(api): handle missing value field"):
            result = await client.generate_commit_message(summary, "api-service")

        # Emoji should be added
        assert result == "fix(api): 🐛 handle missing value field"

    async def test_full_pipeline_chore(self, sample_chore_diff: str) -> None:
        """Full pipeline for chore: diff -> summary -> prompt -> clean message."""
        summary = summarize_diff(sample_chore_diff)
        client = LlamaCommitClient()

        # LLM returns bare emoji message
        with self._mock_llm_reply(client, "⬆️ update httpx dependency"):
            result = await client.generate_commit_message(summary, "project")

        # Should be wrapped in chore(shared)
        assert result == "chore(shared): ⬆️ update httpx dependency"
|
|
|
|
|
|
class TestClientContextManager:
    """Tests for async context manager."""

    async def test_context_manager(self) -> None:
        """Test using client as async context manager.

        The client targets a loopback port with nothing listening. The
        ``health_check`` call therefore only serves to force creation of the
        underlying HTTP client. It never talks to whatever service may be
        running at the default URL, and the refused connection fails fast.
        """
        import socket

        # Grab a free ephemeral port, then release it so nothing listens there.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(("127.0.0.1", 0))
            port = sock.getsockname()[1]

        async with LlamaCommitClient(
            base_url=f"http://127.0.0.1:{port}",
            timeout=5.0,
        ) as client:
            # Trigger client creation by calling a method
            await client.health_check()
            assert client._client is not None

        # After exit, client should be closed
        assert client._client is None or client._client.is_closed
|