103 lines
3 KiB
Python
103 lines
3 KiB
Python
"""Tests for prompt generation."""
|
|
|
|
import pytest
|
|
|
|
from auto_commit_service.git.diff_parser import DiffSummary
|
|
from auto_commit_service.llm.prompts import (
|
|
COMMIT_SYSTEM_PROMPT,
|
|
build_commit_prompt,
|
|
build_simple_prompt,
|
|
)
|
|
|
|
|
|
class TestCommitSystemPrompt:
    """Checks on the static COMMIT_SYSTEM_PROMPT text."""

    def test_includes_emoji_rules(self) -> None:
        """The system prompt must mention every expected emoji prefix."""
        system_prompt = COMMIT_SYSTEM_PROMPT
        for emoji in ("✨", "🔧", "♻️", "🐛"):
            assert emoji in system_prompt

    def test_includes_format_rules(self) -> None:
        """The system prompt must describe the length limit and the mood."""
        length_phrases = ("50 characters", "under 50")
        assert any(phrase in COMMIT_SYSTEM_PROMPT for phrase in length_phrases)
        assert "imperative" in COMMIT_SYSTEM_PROMPT.lower()
|
|
|
|
|
|
class TestBuildCommitPrompt:
    """Tests for build_commit_prompt function."""

    def test_includes_summary_info(self) -> None:
        """Test that prompt includes repo, branch and diff statistics.

        Distinctive values are used on purpose. Short or generic tokens such
        as "3", "50", "main" or "codebase" can easily appear in fixed prompt
        wording (the system prompt already talks about a 50-character limit),
        so asserting on them could pass even if the summary were never
        rendered into the prompt.
        """
        summary = DiffSummary(
            files_modified=23,
            files_added=1,
            files_deleted=0,
            additions=437,
            deletions=10,
        )

        prompt = build_commit_prompt(
            summary, repo_name="acme-widgets", branch="feat-oauth-login"
        )

        assert "acme-widgets" in prompt
        assert "feat-oauth-login" in prompt
        assert "23" in prompt  # files_modified
        assert "437" in prompt  # additions

    def test_includes_file_types(self) -> None:
        """Test that prompt includes file type information."""
        summary = DiffSummary(
            files_modified=2,
            file_types={".py": 2, ".ts": 1},
        )

        prompt = build_commit_prompt(summary)

        assert ".py" in prompt
        assert ".ts" in prompt

    def test_includes_key_files(self) -> None:
        """Test that prompt includes key file names."""
        summary = DiffSummary(
            files_modified=1,
            key_files=["src/app.py", "src/utils.py"],
        )

        prompt = build_commit_prompt(summary)

        assert "src/app.py" in prompt
        assert "src/utils.py" in prompt

    def test_includes_diff_excerpt(self) -> None:
        """Test that prompt includes diff excerpt."""
        summary = DiffSummary(
            files_modified=1,
            diff_excerpt="+ new_function()",
        )

        prompt = build_commit_prompt(summary)

        assert "new_function" in prompt
|
|
|
|
|
|
class TestBuildSimplePrompt:
    """Tests for build_simple_prompt function."""

    def test_includes_diff(self) -> None:
        """Test that simple prompt includes the diff and the repo name."""
        diff = "+print('hello world')"

        prompt = build_simple_prompt(diff, repo_name="acme-widgets")

        assert "hello world" in prompt
        # A generic word such as "test" could match the prompt template
        # itself; this name can only have come from the repo_name argument.
        assert "acme-widgets" in prompt

    def test_truncates_long_diff(self) -> None:
        """Test that long diffs are truncated."""
        long_diff = "x" * 5000

        prompt = build_simple_prompt(long_diff)

        # Check the diff itself was cut rather than the total prompt length:
        # the overall length also includes the surrounding template text, so
        # a length bound can fail even when truncation works correctly.
        assert long_diff not in prompt
|