✨ Add new files
This commit is contained in:
parent
a02b0e7b9b
commit
6c9036b573
23 changed files with 470 additions and 209 deletions
|
|
@ -37,7 +37,7 @@ The service requires a language model to generate commit messages. You have thre
|
|||
|
||||
2. **Manual model path**:
|
||||
```bash
|
||||
export LLAMA_SERVICE_FAST_MODEL_PATH=/path/to/model.gguf
|
||||
export LLAMA_SERVICE_MODEL_PATH=/path/to/model.gguf
|
||||
```
|
||||
|
||||
3. **Disable auto-start** (use external llama-service):
|
||||
|
|
@ -54,9 +54,8 @@ Environment variables (prefix: `AUTO_COMMIT_`):
|
|||
| Variable | Default | Description |
|
||||
|----------|---------|-------------|
|
||||
| `AUTO_COMMIT_LLAMA_SERVICE_URL` | `http://localhost:8000` | llama-service URL |
|
||||
| `AUTO_COMMIT_LLAMA_MODEL` | `fast` | Model to use (fast/reasoning) |
|
||||
| `AUTO_COMMIT_LLAMA_SERVICE_AUTOSTART` | `true` | Auto-start llama-service if down |
|
||||
| `AUTO_COMMIT_LLAMA_FAST_MODEL_ID` | `ministral-3b-instruct` | Model ID for model-boss |
|
||||
| `AUTO_COMMIT_LLAMA_MODEL_ID` | `deepseek-r1-70b` | Model ID for model-boss |
|
||||
| `AUTO_COMMIT_USE_MODEL_BOSS` | `true` | Use model-boss for model loading |
|
||||
| `AUTO_COMMIT_CYCLE_INTERVAL_SECONDS` | `900` | Seconds between cycles (15 min) |
|
||||
| `AUTO_COMMIT_ENABLED` | `true` | Enable daemon on startup |
|
||||
|
|
|
|||
Binary file not shown.
Binary file not shown.
|
|
@ -38,7 +38,6 @@ def create_auto_commit_service(
|
|||
# Initialize components
|
||||
llm_client = LlamaCommitClient(
|
||||
base_url=settings.llama_service_url,
|
||||
model=settings.llama_model,
|
||||
timeout=settings.llama_timeout,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -23,10 +23,6 @@ class AutoCommitSettings(BaseServiceSettings):
|
|||
default="http://localhost:8000",
|
||||
description="URL for llama-service inference",
|
||||
)
|
||||
llama_model: str = Field(
|
||||
default="fast",
|
||||
description="Model to use for commit message generation (fast/reasoning)",
|
||||
)
|
||||
llama_timeout: float = Field(
|
||||
default=30.0,
|
||||
description="Timeout for LLM requests in seconds",
|
||||
|
|
@ -123,13 +119,9 @@ class AutoCommitSettings(BaseServiceSettings):
|
|||
)
|
||||
|
||||
# Model-boss integration for auto-loading LLM
|
||||
llama_fast_model_id: str = Field(
|
||||
llama_model_id: str = Field(
|
||||
default="deepseek-r1-70b",
|
||||
description="Model ID for fast commit message generation (resolved via model-boss)",
|
||||
)
|
||||
llama_reasoning_model_id: str | None = Field(
|
||||
default=None,
|
||||
description="Optional model ID for reasoning tasks (resolved via model-boss)",
|
||||
description="Model ID for commit message generation (resolved via model-boss)",
|
||||
)
|
||||
use_model_boss: bool = Field(
|
||||
default=True,
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -46,25 +46,28 @@ def discover_git_repos(
|
|||
# Path is not relative to base_path, skip
|
||||
continue
|
||||
|
||||
# Check depth limit
|
||||
if max_depth is not None and current_depth >= max_depth:
|
||||
dirs.clear() # Don't descend further
|
||||
continue
|
||||
|
||||
# Filter excluded directories IN-PLACE
|
||||
# This prevents os.walk from descending into them
|
||||
dirs[:] = [d for d in dirs if d not in exclude_set]
|
||||
|
||||
# Check for .git directory BEFORE clearing for depth limit
|
||||
has_git = ".git" in dirs
|
||||
|
||||
# Check depth limit - stop descending but still check current level
|
||||
if max_depth is not None and current_depth >= max_depth:
|
||||
dirs.clear() # Don't descend further
|
||||
|
||||
# Check for .git directory
|
||||
if ".git" in dirs:
|
||||
if has_git:
|
||||
# Validate it's a directory, not a file (gitlinks/submodules use .git file)
|
||||
git_path = root_path / ".git"
|
||||
if git_path.is_dir() and not should_exclude(root_path):
|
||||
repos.append(root_path)
|
||||
logger.debug(f"Found repo: {root_path}")
|
||||
|
||||
# Don't descend into .git
|
||||
dirs.remove(".git")
|
||||
# Don't descend into .git (may already be removed by depth limit)
|
||||
if ".git" in dirs:
|
||||
dirs.remove(".git")
|
||||
|
||||
except PermissionError as e:
|
||||
logger.warning(f"Permission denied accessing {base_path}: {e}")
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -31,22 +31,11 @@ class LlamaCommitClient:
|
|||
def __init__(
|
||||
self,
|
||||
base_url: str = "http://localhost:8000",
|
||||
model: str = "fast",
|
||||
timeout: float = 30.0,
|
||||
max_tokens: int = 100,
|
||||
temperature: float = 0.2,
|
||||
):
|
||||
"""Initialize the client.
|
||||
|
||||
Args:
|
||||
base_url: llama-service URL
|
||||
model: Model to use (fast or reasoning)
|
||||
timeout: Request timeout in seconds
|
||||
max_tokens: Maximum tokens to generate
|
||||
temperature: Generation temperature (lower = more deterministic)
|
||||
"""
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.model = model
|
||||
self.timeout = timeout
|
||||
self.max_tokens = max_tokens
|
||||
self.temperature = temperature
|
||||
|
|
@ -65,51 +54,23 @@ class LlamaCommitClient:
|
|||
self._client = None
|
||||
|
||||
async def health_check(self) -> dict[str, Any]:
|
||||
"""Check if llama-service is available.
|
||||
|
||||
Returns:
|
||||
Health status dict with keys: status, fast_model_loaded, reasoning_model_loaded
|
||||
"""
|
||||
"""Check if llama-service is available."""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
response = await client.get(f"{self.base_url}/health")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.ConnectError:
|
||||
return {
|
||||
"status": "error",
|
||||
"error": "Service unreachable",
|
||||
"fast_model_loaded": False,
|
||||
"reasoning_model_loaded": False,
|
||||
}
|
||||
return {"status": "error", "error": "Service unreachable", "model_loaded": False}
|
||||
except httpx.TimeoutException:
|
||||
return {
|
||||
"status": "error",
|
||||
"error": "Health check timeout",
|
||||
"fast_model_loaded": False,
|
||||
"reasoning_model_loaded": False,
|
||||
}
|
||||
return {"status": "error", "error": "Health check timeout", "model_loaded": False}
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
"fast_model_loaded": False,
|
||||
"reasoning_model_loaded": False,
|
||||
}
|
||||
return {"status": "error", "error": str(e), "model_loaded": False}
|
||||
|
||||
async def is_available(self) -> bool:
|
||||
"""Check if the service is available and has models loaded."""
|
||||
"""Check if the service is available and has model loaded."""
|
||||
health = await self.health_check()
|
||||
if health.get("status") == "error":
|
||||
return False
|
||||
|
||||
# Check if required model is loaded
|
||||
if self.model == "fast":
|
||||
return health.get("fast_model_loaded", False)
|
||||
elif self.model == "reasoning":
|
||||
return health.get("reasoning_model_loaded", False)
|
||||
|
||||
return health.get("fast_model_loaded", False)
|
||||
return health.get("status") == "ok" and health.get("model_loaded", False)
|
||||
|
||||
async def generate_commit_message(
|
||||
self,
|
||||
|
|
@ -169,7 +130,6 @@ class LlamaCommitClient:
|
|||
json={
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"system_prompt": COMMIT_SYSTEM_PROMPT,
|
||||
"model": self.model,
|
||||
"max_tokens": self.max_tokens,
|
||||
"temperature": self.temperature,
|
||||
"stream": False,
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -48,8 +48,7 @@ class CommitDaemon:
|
|||
lock_file=settings.llama_service_lock_file,
|
||||
startup_timeout=settings.llama_service_startup_timeout,
|
||||
health_check_timeout=5.0,
|
||||
fast_model_id=settings.llama_fast_model_id,
|
||||
reasoning_model_id=settings.llama_reasoning_model_id,
|
||||
model_id=settings.llama_model_id,
|
||||
use_model_boss=settings.use_model_boss,
|
||||
)
|
||||
else:
|
||||
|
|
@ -400,11 +399,6 @@ class CommitDaemon:
|
|||
logger.error("Failed to start llama service")
|
||||
return False
|
||||
|
||||
if health == ServiceHealth.DEGRADED:
|
||||
logger.warning("Llama service is degraded")
|
||||
# Fail-fast: degraded service is not acceptable
|
||||
return False
|
||||
|
||||
# Service is healthy
|
||||
self._service_crashed = False
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ from .manager import (
|
|||
ServiceHealth,
|
||||
ServiceManagerError,
|
||||
ServiceStartError,
|
||||
ServiceCrashError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -13,5 +12,4 @@ __all__ = [
|
|||
"ServiceHealth",
|
||||
"ServiceManagerError",
|
||||
"ServiceStartError",
|
||||
"ServiceCrashError",
|
||||
]
|
||||
|
|
|
|||
Binary file not shown.
Binary file not shown.
|
|
@ -23,15 +23,10 @@ class ServiceStartError(ServiceManagerError):
|
|||
"""Failed to start service."""
|
||||
|
||||
|
||||
class ServiceCrashError(ServiceManagerError):
|
||||
"""Service crashed unexpectedly."""
|
||||
|
||||
|
||||
class ServiceHealth(str, Enum):
|
||||
"""Service health status."""
|
||||
|
||||
HEALTHY = "healthy"
|
||||
DEGRADED = "degraded"
|
||||
CRASHED = "crashed"
|
||||
UNREACHABLE = "unreachable"
|
||||
|
||||
|
|
@ -46,23 +41,19 @@ class LlamaServiceManager:
|
|||
lock_file: Path | None = None,
|
||||
startup_timeout: float = 30.0,
|
||||
health_check_timeout: float = 5.0,
|
||||
fast_model_id: str | None = None,
|
||||
reasoning_model_id: str | None = None,
|
||||
use_model_boss: bool = True,
|
||||
model_id: str | None = None,
|
||||
use_model_boss: bool = False,
|
||||
):
|
||||
"""Initialize service manager."""
|
||||
self.service_url = service_url
|
||||
self._pid_file = pid_file or Path.home() / ".config/commits/llama-service.pid"
|
||||
self._lock_file = lock_file or Path.home() / ".config/commits/llama-service.lock"
|
||||
self._startup_timeout = startup_timeout
|
||||
self._health_check_timeout = health_check_timeout
|
||||
self._fast_model_id = fast_model_id
|
||||
self._reasoning_model_id = reasoning_model_id
|
||||
self._model_id = model_id
|
||||
self._use_model_boss = use_model_boss
|
||||
self._spawned_pid: int | None = None
|
||||
self._lock_fd: int | None = None
|
||||
self._resolved_fast_model_path: str | None = None
|
||||
self._resolved_reasoning_model_path: str | None = None
|
||||
self._resolved_model_path: str | None = None
|
||||
|
||||
async def ensure_service_available(self) -> bool:
|
||||
"""Ensure service is available, starting if necessary."""
|
||||
|
|
@ -70,52 +61,33 @@ class LlamaServiceManager:
|
|||
|
||||
if health == ServiceHealth.HEALTHY:
|
||||
return True
|
||||
if health == ServiceHealth.DEGRADED:
|
||||
# Degraded but running - acceptable for commits
|
||||
return True
|
||||
if health == ServiceHealth.CRASHED:
|
||||
# Stale PID file from previous session - clean up and restart
|
||||
logger.info("Detected crashed service (stale PID file), cleaning up...")
|
||||
logger.info("Detected crashed service (stale PID), cleaning up...")
|
||||
self._cleanup_pid_file()
|
||||
# Fall through to restart logic
|
||||
|
||||
logger.info("Llama service unreachable, attempting to start...")
|
||||
|
||||
try:
|
||||
# Resolve model paths via model-boss before starting
|
||||
if self._use_model_boss and self._fast_model_id:
|
||||
await self._resolve_model_paths()
|
||||
|
||||
if self._use_model_boss and self._model_id:
|
||||
await self._resolve_model_path()
|
||||
return await self.start_service()
|
||||
except ServiceStartError as e:
|
||||
logger.error(f"Failed to start llama-service: {e}")
|
||||
return False
|
||||
|
||||
async def _resolve_model_paths(self) -> None:
|
||||
"""Resolve model IDs to paths via model-boss.
|
||||
|
||||
Raises:
|
||||
ServiceStartError: If model resolution fails
|
||||
"""
|
||||
async def _resolve_model_path(self) -> None:
|
||||
"""Resolve model ID to path via model-boss."""
|
||||
try:
|
||||
from lilith_model_boss import ensure_model
|
||||
|
||||
if self._fast_model_id and not self._resolved_fast_model_path:
|
||||
logger.info(f"Resolving fast model via model-boss: {self._fast_model_id}")
|
||||
self._resolved_fast_model_path = ensure_model(self._fast_model_id)
|
||||
logger.info(f"Resolved fast model path: {self._resolved_fast_model_path}")
|
||||
|
||||
if self._reasoning_model_id and not self._resolved_reasoning_model_path:
|
||||
logger.info(f"Resolving reasoning model via model-boss: {self._reasoning_model_id}")
|
||||
self._resolved_reasoning_model_path = ensure_model(self._reasoning_model_id)
|
||||
logger.info(f"Resolved reasoning model path: {self._resolved_reasoning_model_path}")
|
||||
|
||||
if self._model_id and not self._resolved_model_path:
|
||||
logger.info(f"Resolving model via model-boss: {self._model_id}")
|
||||
self._resolved_model_path = ensure_model(self._model_id)
|
||||
logger.info(f"Resolved model path: {self._resolved_model_path}")
|
||||
except ImportError:
|
||||
raise ServiceStartError(
|
||||
"model-boss not installed. Install with: pip install auto-commit-service[model-boss]"
|
||||
)
|
||||
raise ServiceStartError("model-boss not installed")
|
||||
except Exception as e:
|
||||
raise ServiceStartError(f"Failed to resolve model paths: {e}")
|
||||
raise ServiceStartError(f"Failed to resolve model path: {e}")
|
||||
|
||||
async def start_service(self) -> bool:
|
||||
"""Start llama service subprocess."""
|
||||
|
|
@ -148,7 +120,6 @@ class LlamaServiceManager:
|
|||
else:
|
||||
logger.error(f"✗ Service failed to start within {self._startup_timeout}s")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to start llama service: {e}")
|
||||
return False
|
||||
|
|
@ -156,7 +127,7 @@ class LlamaServiceManager:
|
|||
self._release_lock()
|
||||
|
||||
async def check_health(self) -> ServiceHealth:
|
||||
"""Check service health and detect crashes."""
|
||||
"""Check service health."""
|
||||
pid = self._read_pid_file()
|
||||
if pid and not self._is_process_alive(pid):
|
||||
return ServiceHealth.CRASHED
|
||||
|
|
@ -166,7 +137,7 @@ class LlamaServiceManager:
|
|||
response = await client.get(f"{self.service_url}/health")
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
return ServiceHealth.HEALTHY if data.get("status") == "ok" else ServiceHealth.DEGRADED
|
||||
return ServiceHealth.HEALTHY if data.get("status") == "ok" else ServiceHealth.UNREACHABLE
|
||||
return ServiceHealth.UNREACHABLE
|
||||
except (httpx.ConnectError, httpx.TimeoutException):
|
||||
return ServiceHealth.UNREACHABLE
|
||||
|
|
@ -194,7 +165,6 @@ class LlamaServiceManager:
|
|||
logger.exception(f"Error stopping service: {e}")
|
||||
|
||||
def _acquire_lock(self) -> bool:
|
||||
"""Acquire exclusive lock."""
|
||||
self._lock_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
self._lock_fd = os.open(self._lock_file, os.O_CREAT | os.O_WRONLY)
|
||||
|
|
@ -204,7 +174,6 @@ class LlamaServiceManager:
|
|||
return False
|
||||
|
||||
def _release_lock(self) -> None:
|
||||
"""Release lock."""
|
||||
if self._lock_fd is not None:
|
||||
try:
|
||||
fcntl.flock(self._lock_fd, fcntl.LOCK_UN)
|
||||
|
|
@ -214,7 +183,6 @@ class LlamaServiceManager:
|
|||
pass
|
||||
|
||||
def _read_pid_file(self) -> int | None:
|
||||
"""Read PID from file."""
|
||||
if not self._pid_file.exists():
|
||||
return None
|
||||
try:
|
||||
|
|
@ -224,12 +192,10 @@ class LlamaServiceManager:
|
|||
return None
|
||||
|
||||
def _write_pid_file(self, pid: int) -> None:
|
||||
"""Write PID to file."""
|
||||
self._pid_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._pid_file.write_text(str(pid))
|
||||
|
||||
def _cleanup_pid_file(self) -> None:
|
||||
"""Remove PID file."""
|
||||
try:
|
||||
if self._pid_file.exists():
|
||||
self._pid_file.unlink()
|
||||
|
|
@ -237,7 +203,6 @@ class LlamaServiceManager:
|
|||
pass
|
||||
|
||||
def _is_process_alive(self, pid: int) -> bool:
|
||||
"""Check if process is alive."""
|
||||
try:
|
||||
os.kill(pid, 0)
|
||||
return True
|
||||
|
|
@ -250,79 +215,48 @@ class LlamaServiceManager:
|
|||
if not cache_dir.exists():
|
||||
return None
|
||||
|
||||
# Preferred models for commit message generation (fast, small)
|
||||
preferred_models = [
|
||||
preferred = [
|
||||
"qwen2.5-1.5b-instruct-q4_k_m.gguf",
|
||||
"Ministral-3-3B-Instruct-2512-Q8_0.gguf",
|
||||
"ministral-3b-instruct",
|
||||
]
|
||||
|
||||
# Check for preferred models first
|
||||
for model_name in preferred_models:
|
||||
model_path = cache_dir / model_name
|
||||
if model_path.exists():
|
||||
return str(model_path)
|
||||
for name in preferred:
|
||||
path = cache_dir / name
|
||||
if path.exists():
|
||||
return str(path)
|
||||
|
||||
# Fall back to any small GGUF file (< 5GB)
|
||||
for gguf_file in cache_dir.glob("*.gguf"):
|
||||
if gguf_file.stat().st_size < 5 * 1024 * 1024 * 1024: # < 5GB
|
||||
return str(gguf_file)
|
||||
for gguf in cache_dir.glob("*.gguf"):
|
||||
if gguf.stat().st_size < 5 * 1024 * 1024 * 1024:
|
||||
return str(gguf)
|
||||
|
||||
return None
|
||||
|
||||
async def _spawn_service(self) -> asyncio.subprocess.Process:
|
||||
"""Spawn service as background subprocess.
|
||||
|
||||
Raises:
|
||||
ServiceStartError: If no model paths are configured
|
||||
"""
|
||||
"""Spawn service subprocess."""
|
||||
cmd = [sys.executable, "-m", "lilith_llama_service"]
|
||||
env = os.environ.copy()
|
||||
|
||||
# Use resolved model paths from model-boss if available
|
||||
has_model_paths = False
|
||||
model_path = self._resolved_model_path or env.get("LLAMA_SERVICE_MODEL_PATH")
|
||||
|
||||
if self._resolved_fast_model_path:
|
||||
env["LLAMA_SERVICE_FAST_MODEL_PATH"] = self._resolved_fast_model_path
|
||||
has_model_paths = True
|
||||
logger.info(f"Using fast model: {self._resolved_fast_model_path}")
|
||||
if not model_path:
|
||||
model_path = self._find_default_model()
|
||||
|
||||
if self._resolved_reasoning_model_path:
|
||||
env["LLAMA_SERVICE_REASONING_MODEL_PATH"] = self._resolved_reasoning_model_path
|
||||
has_model_paths = True
|
||||
logger.info(f"Using reasoning model: {self._resolved_reasoning_model_path}")
|
||||
|
||||
# Fall back to environment variables if set
|
||||
if not has_model_paths:
|
||||
if "LLAMA_SERVICE_FAST_MODEL_PATH" in env or "LLAMA_SERVICE_REASONING_MODEL_PATH" in env:
|
||||
has_model_paths = True
|
||||
logger.info("Using model paths from environment variables")
|
||||
|
||||
# Fall back to auto-discovered model in cache
|
||||
if not has_model_paths:
|
||||
default_model = self._find_default_model()
|
||||
if default_model:
|
||||
env["LLAMA_SERVICE_FAST_MODEL_PATH"] = default_model
|
||||
has_model_paths = True
|
||||
logger.info(f"Using auto-discovered model: {default_model}")
|
||||
|
||||
# Fail if no models are configured - do not fall back to mock mode
|
||||
if not has_model_paths:
|
||||
if not model_path:
|
||||
raise ServiceStartError(
|
||||
"No model paths configured and no models found in ~/.cache/models/. Either:\n"
|
||||
" 1. Install model-boss: pip install auto-commit-service[model-boss]\n"
|
||||
" 2. Set LLAMA_SERVICE_FAST_MODEL_PATH environment variable\n"
|
||||
" 3. Place a GGUF model in ~/.cache/models/\n"
|
||||
" 4. Disable llama_service_autostart in config"
|
||||
"No model found. Either:\n"
|
||||
" 1. Set LLAMA_SERVICE_MODEL_PATH\n"
|
||||
" 2. Place a GGUF model in ~/.cache/models/"
|
||||
)
|
||||
|
||||
env["LLAMA_SERVICE_MODEL_PATH"] = model_path
|
||||
logger.info(f"Using model: {model_path}")
|
||||
|
||||
log_file = self._pid_file.parent / "llama-service.log"
|
||||
log_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(log_file, "a") as log:
|
||||
log.write(f"\n=== Service started at {time.ctime()} ===\n")
|
||||
log.write(f"Fast model: {env.get('LLAMA_SERVICE_FAST_MODEL_PATH', 'not set')}\n")
|
||||
log.write(f"Reasoning model: {env.get('LLAMA_SERVICE_REASONING_MODEL_PATH', 'not set')}\n")
|
||||
log.write(f"Model: {model_path}\n")
|
||||
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
|
|
@ -334,7 +268,6 @@ class LlamaServiceManager:
|
|||
return process
|
||||
|
||||
async def _wait_for_healthy(self, timeout: float) -> bool:
|
||||
"""Wait for service to become healthy."""
|
||||
start = time.time()
|
||||
while time.time() - start < timeout:
|
||||
try:
|
||||
|
|
|
|||
BIN
tests/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
tests/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
tests/__pycache__/conftest.cpython-312-pytest-9.0.2.pyc
Normal file
BIN
tests/__pycache__/conftest.cpython-312-pytest-9.0.2.pyc
Normal file
Binary file not shown.
BIN
tests/__pycache__/test_diff_parser.cpython-312-pytest-9.0.2.pyc
Normal file
BIN
tests/__pycache__/test_diff_parser.cpython-312-pytest-9.0.2.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
tests/__pycache__/test_llm_client.cpython-312-pytest-9.0.2.pyc
Normal file
BIN
tests/__pycache__/test_llm_client.cpython-312-pytest-9.0.2.pyc
Normal file
Binary file not shown.
BIN
tests/__pycache__/test_prompts.cpython-312-pytest-9.0.2.pyc
Normal file
BIN
tests/__pycache__/test_prompts.cpython-312-pytest-9.0.2.pyc
Normal file
Binary file not shown.
|
|
@ -1,9 +1,14 @@
|
|||
"""Tests for LLM client module."""
|
||||
"""Tests for LLM client module.
|
||||
|
||||
Proves that "sloppy" commit messages from the LLM get properly cleaned
|
||||
into the expected format: type(scope): emoji description
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
import httpx
|
||||
|
||||
from auto_commit_service.git.diff_parser import DiffSummary
|
||||
from auto_commit_service.git.diff_parser import DiffSummary, summarize_diff
|
||||
from auto_commit_service.llm.client import (
|
||||
LlamaCommitClient,
|
||||
LlamaServiceError,
|
||||
|
|
@ -19,47 +24,423 @@ class TestLlamaCommitClient:
|
|||
"""Create a test client."""
|
||||
return LlamaCommitClient(
|
||||
base_url="http://localhost:8000",
|
||||
model="fast",
|
||||
timeout=5.0,
|
||||
)
|
||||
|
||||
async def test_health_check_unavailable(self, client: LlamaCommitClient) -> None:
|
||||
"""Test health check when service is unavailable."""
|
||||
# No mock server, should fail
|
||||
health = await client.health_check()
|
||||
|
||||
assert health["status"] == "error"
|
||||
assert not health["fast_model_loaded"]
|
||||
assert not health["model_loaded"]
|
||||
|
||||
async def test_is_available_when_down(self, client: LlamaCommitClient) -> None:
|
||||
"""Test is_available when service is down."""
|
||||
available = await client.is_available()
|
||||
assert not available
|
||||
|
||||
def test_clean_response_removes_quotes(self, client: LlamaCommitClient) -> None:
|
||||
"""Test that clean_response removes surrounding quotes."""
|
||||
result = client._clean_response('"✨ Add new feature"')
|
||||
assert result == "✨ Add new feature"
|
||||
|
||||
def test_clean_response_removes_code_blocks(self, client: LlamaCommitClient) -> None:
|
||||
"""Test that clean_response removes markdown code blocks."""
|
||||
result = client._clean_response("```\n✨ Add new feature\n```")
|
||||
assert result == "✨ Add new feature"
|
||||
class TestCleanResponseNormalization:
|
||||
"""Tests proving the _clean_response method normalizes sloppy LLM output.
|
||||
|
||||
def test_clean_response_takes_first_line(self, client: LlamaCommitClient) -> None:
|
||||
"""Test that clean_response takes only first line."""
|
||||
result = client._clean_response("✨ Add feature\nThis is extra text")
|
||||
assert result == "✨ Add feature"
|
||||
The LLM may produce commit messages in various formats. These tests
|
||||
verify that all variations get normalized to: type(scope): emoji description
|
||||
"""
|
||||
|
||||
def test_clean_response_adds_emoji_if_missing(self, client: LlamaCommitClient) -> None:
|
||||
"""Test that clean_response adds emoji for known types."""
|
||||
result = client._clean_response("Add new authentication")
|
||||
assert result.startswith("✨")
|
||||
@pytest.fixture
|
||||
def client(self) -> LlamaCommitClient:
|
||||
return LlamaCommitClient()
|
||||
|
||||
def test_clean_response_preserves_valid_emoji(self, client: LlamaCommitClient) -> None:
|
||||
"""Test that valid emoji prefixes are preserved."""
|
||||
result = client._clean_response("🔧 Update dependencies")
|
||||
assert result == "🔧 Update dependencies"
|
||||
# ==========================================================================
|
||||
# Already correct format - should pass through
|
||||
# ==========================================================================
|
||||
|
||||
def test_correct_format_passes_through(self, client: LlamaCommitClient) -> None:
|
||||
"""Correctly formatted messages pass through unchanged."""
|
||||
result = client._clean_response("feat(auth): ✨ add login endpoint")
|
||||
assert result == "feat(auth): ✨ add login endpoint"
|
||||
|
||||
def test_correct_format_with_different_types(self, client: LlamaCommitClient) -> None:
|
||||
"""All valid types with correct format pass through (emojis may be normalized)."""
|
||||
test_cases = [
|
||||
("fix(api): 🐛 resolve null pointer in handler", "fix(api):"),
|
||||
("refactor(core): ♻️ extract validation logic", "refactor(core):"),
|
||||
("chore(deps): 🔧 update eslint to v9", "chore(deps):"),
|
||||
("docs(readme): 📝 add installation section", "docs(readme):"),
|
||||
("build(ci): ⬆️ upgrade node to v20", "build(ci):"),
|
||||
("test(auth): ✅ add unit tests for login", "test(auth):"),
|
||||
("perf(query): ⚡ optimize database lookup", "perf(query):"),
|
||||
]
|
||||
for msg, expected_prefix in test_cases:
|
||||
result = client._clean_response(msg)
|
||||
assert result.startswith(expected_prefix), f"Expected {result} to start with {expected_prefix}"
|
||||
# Verify description is preserved
|
||||
assert "resolve null pointer" in result or "extract validation" in result or \
|
||||
"update eslint" in result or "add installation" in result or \
|
||||
"upgrade node" in result or "add unit tests" in result or \
|
||||
"optimize database" in result
|
||||
|
||||
# ==========================================================================
|
||||
# Markdown formatting removal
|
||||
# ==========================================================================
|
||||
|
||||
def test_removes_markdown_code_blocks(self, client: LlamaCommitClient) -> None:
|
||||
"""Strips markdown code block wrapper."""
|
||||
result = client._clean_response("```\nfeat(ui): ✨ add button component\n```")
|
||||
assert result == "feat(ui): ✨ add button component"
|
||||
|
||||
def test_removes_language_tagged_code_blocks(self, client: LlamaCommitClient) -> None:
|
||||
"""Strips code blocks with language tags."""
|
||||
result = client._clean_response("```text\nfix(api): 🐛 fix timeout\n```")
|
||||
assert result == "fix(api): 🐛 fix timeout"
|
||||
|
||||
def test_removes_surrounding_quotes(self, client: LlamaCommitClient) -> None:
|
||||
"""Strips surrounding double quotes."""
|
||||
result = client._clean_response('"feat(auth): ✨ add oauth"')
|
||||
assert result == "feat(auth): ✨ add oauth"
|
||||
|
||||
def test_removes_single_quotes(self, client: LlamaCommitClient) -> None:
|
||||
"""Strips surrounding single quotes."""
|
||||
result = client._clean_response("'chore(config): 🔧 update settings'")
|
||||
assert result == "chore(config): 🔧 update settings"
|
||||
|
||||
def test_takes_only_first_line(self, client: LlamaCommitClient) -> None:
|
||||
"""When LLM produces multiple lines, only first is used."""
|
||||
sloppy = """feat(api): ✨ add user endpoint
|
||||
This adds a new REST endpoint for user management.
|
||||
It supports CRUD operations."""
|
||||
result = client._clean_response(sloppy)
|
||||
assert result == "feat(api): ✨ add user endpoint"
|
||||
|
||||
# ==========================================================================
|
||||
# Emoji position correction
|
||||
# ==========================================================================
|
||||
|
||||
def test_moves_emoji_from_start_to_after_colon(self, client: LlamaCommitClient) -> None:
|
||||
"""When emoji comes before type, move it after the colon."""
|
||||
result = client._clean_response("✨ feat(ui): add new component")
|
||||
assert result == "feat(ui): ✨ add new component"
|
||||
|
||||
def test_fixes_emoji_before_type_with_fix(self, client: LlamaCommitClient) -> None:
|
||||
"""Bug fix emoji at start gets repositioned."""
|
||||
result = client._clean_response("🐛 fix(auth): resolve login bug")
|
||||
assert result == "fix(auth): 🐛 resolve login bug"
|
||||
|
||||
# ==========================================================================
|
||||
# Missing emoji addition
|
||||
# ==========================================================================
|
||||
|
||||
def test_adds_emoji_for_feat_type(self, client: LlamaCommitClient) -> None:
|
||||
"""Adds ✨ emoji for feat type when missing."""
|
||||
result = client._clean_response("feat(api): add health endpoint")
|
||||
assert result == "feat(api): ✨ add health endpoint"
|
||||
|
||||
def test_adds_emoji_for_fix_type(self, client: LlamaCommitClient) -> None:
|
||||
"""Adds 🐛 emoji for fix type when missing."""
|
||||
result = client._clean_response("fix(auth): resolve timeout issue")
|
||||
assert result == "fix(auth): 🐛 resolve timeout issue"
|
||||
|
||||
def test_adds_emoji_for_refactor_type(self, client: LlamaCommitClient) -> None:
|
||||
"""Adds ♻️ emoji for refactor type when missing."""
|
||||
result = client._clean_response("refactor(core): extract shared logic")
|
||||
assert result == "refactor(core): ♻️ extract shared logic"
|
||||
|
||||
def test_adds_emoji_for_chore_type(self, client: LlamaCommitClient) -> None:
|
||||
"""Adds 🔧 emoji for chore type when missing."""
|
||||
result = client._clean_response("chore(deps): update dependencies")
|
||||
assert result == "chore(deps): 🔧 update dependencies"
|
||||
|
||||
def test_adds_emoji_for_docs_type(self, client: LlamaCommitClient) -> None:
|
||||
"""Adds 📝 emoji for docs type when missing."""
|
||||
result = client._clean_response("docs(readme): update installation guide")
|
||||
assert result == "docs(readme): 📝 update installation guide"
|
||||
|
||||
# ==========================================================================
|
||||
# Fallback to chore(shared) for emoji-only messages
|
||||
# ==========================================================================
|
||||
|
||||
def test_wraps_emoji_only_message(self, client: LlamaCommitClient) -> None:
|
||||
"""Bare emoji + description gets wrapped in chore(shared)."""
|
||||
result = client._clean_response("✨ add new feature")
|
||||
assert result == "chore(shared): ✨ add new feature"
|
||||
|
||||
def test_wraps_wrench_emoji_message(self, client: LlamaCommitClient) -> None:
|
||||
"""Wrench emoji messages become chore(shared)."""
|
||||
result = client._clean_response("🔧 update config")
|
||||
assert result == "chore(shared): 🔧 update config"
|
||||
|
||||
# ==========================================================================
|
||||
# Keyword inference for unstructured messages
|
||||
# ==========================================================================
|
||||
|
||||
def test_infers_feat_from_add_keyword(self, client: LlamaCommitClient) -> None:
|
||||
"""'add' keyword triggers feat type inference."""
|
||||
result = client._clean_response("add new authentication module")
|
||||
assert result.startswith("feat(shared): ✨")
|
||||
assert "add new authentication module" in result
|
||||
|
||||
def test_infers_fix_from_fix_keyword(self, client: LlamaCommitClient) -> None:
|
||||
"""'fix' keyword triggers fix type inference."""
|
||||
result = client._clean_response("fix null pointer in user handler")
|
||||
assert result.startswith("fix(shared): 🐛")
|
||||
|
||||
def test_infers_refactor_from_refactor_keyword(self, client: LlamaCommitClient) -> None:
    """An unstructured message containing 'refactor' is classified as refactor."""
    expected_prefix = "refactor(shared): ♻️"

    cleaned = client._clean_response("refactor the database layer")

    assert cleaned.startswith(expected_prefix)
|
||||
|
||||
def test_infers_chore_from_update_keyword(self, client: LlamaCommitClient) -> None:
    """An unstructured message containing 'update' is classified as chore."""
    expected_prefix = "chore(shared): 🔧"

    cleaned = client._clean_response("update eslint configuration")

    assert cleaned.startswith(expected_prefix)
|
||||
|
||||
def test_fallback_to_chore_for_unknown(self, client: LlamaCommitClient) -> None:
    """A message matching no known keyword falls back to chore(shared): 🔧."""
    expected_prefix = "chore(shared): 🔧"

    cleaned = client._clean_response("miscellaneous changes to codebase")

    assert cleaned.startswith(expected_prefix)
|
||||
|
||||
|
||||
class TestCommitMessageGeneration:
    """Exercise commit message generation end to end against a mocked LLM backend."""

    @pytest.fixture
    def client(self) -> LlamaCommitClient:
        """Provide a client pointed at a placeholder service URL."""
        return LlamaCommitClient(base_url="http://test:8000")

    @pytest.fixture
    def mock_httpx_response(self):
        """Provide a factory that builds stand-in ``httpx.Response`` objects."""

        def _build(content: str, status_code: int = 200):
            # ``content`` is exposed both as the JSON payload and as the body text.
            fake = MagicMock(spec=httpx.Response)
            fake.status_code = status_code
            fake.json.return_value = {"content": content}
            fake.text = content
            return fake

        return _build

    async def test_generate_from_diff_cleans_output(
        self, client: LlamaCommitClient, mock_httpx_response
    ) -> None:
        """A quoted, emoji-first reply to a raw diff comes back normalised."""
        # The simulated reply is wrapped in quotes and leads with the emoji.
        llm_reply = mock_httpx_response('"✨ feat(api): add new endpoint"')

        with patch.object(client, "_get_client") as get_client_stub:
            http_stub = AsyncMock()
            http_stub.post.return_value = llm_reply
            get_client_stub.return_value = http_stub

            message = await client.generate_from_diff(
                diff="diff --git a/api.py b/api.py\n+def new_endpoint(): pass",
                repo_name="test-repo",
            )

        # Quotes are gone and the emoji sits after the type(scope) prefix.
        assert message == "feat(api): ✨ add new endpoint"

    async def test_generate_commit_message_with_summary(
        self, client: LlamaCommitClient, mock_httpx_response
    ) -> None:
        """A DiffSummary input yields a cleaned message with the emoji filled in."""
        diff_summary = DiffSummary(
            files_modified=2,
            files_added=1,
            additions=50,
            deletions=10,
            file_types={".py": 2, ".md": 1},
            key_files=["src/api.py", "src/utils.py", "README.md"],
            diff_excerpt="@@ -1,5 +1,10 @@ ...",
        )

        # The simulated reply has a valid prefix but no emoji.
        llm_reply = mock_httpx_response("feat(api): add authentication module")

        with patch.object(client, "_get_client") as get_client_stub:
            http_stub = AsyncMock()
            http_stub.post.return_value = llm_reply
            get_client_stub.return_value = http_stub

            message = await client.generate_commit_message(
                diff_summary=diff_summary,
                repo_name="auth-service",
                branch="main",
            )

        # The feat emoji is inserted after the prefix.
        assert message == "feat(api): ✨ add authentication module"

    async def test_handles_service_unavailable(
        self, client: LlamaCommitClient
    ) -> None:
        """A connection error from the HTTP post surfaces as LlamaServiceUnavailable."""
        with patch.object(client, "_get_client") as get_client_stub:
            http_stub = AsyncMock()
            http_stub.post.side_effect = httpx.ConnectError("Connection refused")
            get_client_stub.return_value = http_stub

            with pytest.raises(LlamaServiceUnavailable):
                await client.generate_from_diff("diff content", "repo")

    async def test_handles_503_service_unavailable(
        self, client: LlamaCommitClient, mock_httpx_response
    ) -> None:
        """A 503 reply from the service surfaces as LlamaServiceUnavailable."""
        unavailable_reply = mock_httpx_response("Service Unavailable", status_code=503)

        with patch.object(client, "_get_client") as get_client_stub:
            http_stub = AsyncMock()
            http_stub.post.return_value = unavailable_reply
            get_client_stub.return_value = http_stub

            with pytest.raises(LlamaServiceUnavailable):
                await client.generate_from_diff("diff content", "repo")
|
||||
|
||||
|
||||
class TestDiffToMessageIntegration:
    """Integration tests covering diff parsing, prompt building and message cleaning.

    Each pipeline test walks a raw git diff through to the final cleaned
    commit message, with the LLM reply replaced by a canned mock response.
    """

    @pytest.fixture
    def sample_feature_diff(self) -> str:
        """Return a diff that adds a new file containing a login endpoint."""
        return """diff --git a/src/auth/login.py b/src/auth/login.py
new file mode 100644
index 0000000..abc1234
--- /dev/null
+++ b/src/auth/login.py
@@ -0,0 +1,25 @@
+from flask import request, jsonify
+
+def login_endpoint():
+ username = request.json.get('username')
+ password = request.json.get('password')
+ if authenticate(username, password):
+ return jsonify({'token': generate_token(username)})
+ return jsonify({'error': 'Invalid credentials'}), 401
"""

    @pytest.fixture
    def sample_bugfix_diff(self) -> str:
        """Return a diff that adds a missing-value check to an existing handler."""
        return """diff --git a/src/api/handler.py b/src/api/handler.py
index 1234567..abcdefg 100644
--- a/src/api/handler.py
+++ b/src/api/handler.py
@@ -15,7 +15,9 @@ def process_request(data):
- result = data.get('value')
+ result = data.get('value')
+ if result is None:
+ raise ValueError("Missing required field: value")
return transform(result)
"""

    @pytest.fixture
    def sample_chore_diff(self) -> str:
        """Return a diff that bumps a dependency pin in pyproject.toml."""
        return """diff --git a/pyproject.toml b/pyproject.toml
index 1234567..abcdefg 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,7 +10,7 @@ dependencies = [
- "httpx>=0.24.0",
+ "httpx>=0.25.0",
"pydantic>=2.0.0",
]
"""

    @staticmethod
    async def _generate_with_canned_reply(diff: str, llm_output: str, repo_name: str):
        """Summarise ``diff`` and generate a message while the HTTP layer replies ``llm_output``."""
        summary = summarize_diff(diff)
        client = LlamaCommitClient()

        canned = MagicMock(spec=httpx.Response)
        canned.status_code = 200
        canned.json.return_value = {"content": llm_output}

        with patch.object(client, "_get_client") as get_client_stub:
            http_stub = AsyncMock()
            http_stub.post.return_value = canned
            get_client_stub.return_value = http_stub

            return await client.generate_commit_message(summary, repo_name)

    def test_parse_feature_diff(self, sample_feature_diff: str) -> None:
        """The feature diff is summarised as one added Python file."""
        parsed = summarize_diff(sample_feature_diff)

        assert parsed.files_added == 1
        assert parsed.files_modified == 0
        assert ".py" in parsed.file_types
        assert "src/auth/login.py" in parsed.key_files

    def test_parse_bugfix_diff(self, sample_bugfix_diff: str) -> None:
        """The bugfix diff is summarised as one modified file with added lines."""
        parsed = summarize_diff(sample_bugfix_diff)

        assert parsed.files_modified == 1
        assert parsed.additions >= 2  # at least the two new guard lines
        assert "src/api/handler.py" in parsed.key_files

    def test_parse_chore_diff(self, sample_chore_diff: str) -> None:
        """The config diff is summarised as one modified TOML file."""
        parsed = summarize_diff(sample_chore_diff)

        assert parsed.files_modified == 1
        assert ".toml" in parsed.file_types
        assert "pyproject.toml" in parsed.key_files

    async def test_full_pipeline_feature(self, sample_feature_diff: str) -> None:
        """Feature pipeline: an emoji-first reply is reordered behind its prefix."""
        message = await self._generate_with_canned_reply(
            sample_feature_diff,
            llm_output="✨ feat(auth): add login endpoint",
            repo_name="auth-service",
        )

        # The emoji has moved to follow the type(scope) prefix.
        assert message == "feat(auth): ✨ add login endpoint"

    async def test_full_pipeline_bugfix(self, sample_bugfix_diff: str) -> None:
        """Bugfix pipeline: a reply without an emoji gets the fix emoji added."""
        message = await self._generate_with_canned_reply(
            sample_bugfix_diff,
            llm_output="fix(api): handle missing value field",
            repo_name="api-service",
        )

        # The fix emoji has been inserted.
        assert message == "fix(api): 🐛 handle missing value field"

    async def test_full_pipeline_chore(self, sample_chore_diff: str) -> None:
        """Chore pipeline: a bare emoji reply is wrapped in chore(shared)."""
        message = await self._generate_with_canned_reply(
            sample_chore_diff,
            llm_output="⬆️ update httpx dependency",
            repo_name="project",
        )

        # With no type(scope) in the reply, the fallback prefix is applied.
        assert message == "chore(shared): ⬆️ update httpx dependency"
|
||||
|
||||
|
||||
class TestClientContextManager:
|
||||
|
|
@ -68,6 +449,8 @@ class TestClientContextManager:
|
|||
async def test_context_manager(self) -> None:
|
||||
"""Test using client as async context manager."""
|
||||
async with LlamaCommitClient() as client:
|
||||
# Trigger client creation by calling a method
|
||||
await client.health_check()
|
||||
assert client._client is not None
|
||||
|
||||
# After exit, client should be closed
|
||||
|
|
|
|||
|
|
@ -22,8 +22,8 @@ class TestCommitSystemPrompt:
|
|||
|
||||
def test_includes_format_rules(self) -> None:
    """Test that system prompt includes formatting rules."""
    # NOTE(review): this span was recovered from a rendered diff hunk, so the
    # first two assertions may be the pre-change lines and the last two their
    # replacements. Confirm against the real file which pair is current.
    assert "50 characters" in COMMIT_SYSTEM_PROMPT or "under 50" in COMMIT_SYSTEM_PROMPT
    assert "imperative" in COMMIT_SYSTEM_PROMPT.lower()
    assert "50" in COMMIT_SYSTEM_PROMPT  # Length limit: the number 50 must appear
    assert "lowercase" in COMMIT_SYSTEM_PROMPT.lower()  # Case requirement is mentioned
|
||||
|
||||
|
||||
class TestBuildCommitPrompt:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue