"""GPU integration tests with real models.

These tests require:
- A GPU with the models loaded (ministral-14b-reasoning, ministral-3b-instruct).
- The ``gpu_services`` session fixture, which auto-starts the llama-http
  systemd services.

Run with: pytest tests/test_integration_gpu.py -v
Skip with: pytest -m "not gpu"

Tests are designed to fail fast:
- Short inference timeouts (30s for reasoning, 15s for instruct; the
  long-prompt test allows 60s).
- Services are guaranteed to be running via the ``gpu_services`` fixture.
"""
|
|
|
|
import asyncio
import time
from typing import AsyncGenerator

import pytest

from auto_commit_service.llm import MultiModelLlamaClient
from auto_commit_service.pipeline.format_utils import (
    build_format_system_prompt,
    correct_emoji,
    extract_commit_message,
    sanitize_message,
)
|
|
|
|
# Apply the ``gpu`` marker to every test in this module so the whole file can
# be deselected with ``pytest -m "not gpu"``.
pytestmark = [pytest.mark.gpu]
|
|
|
|
|
|
# =============================================================================
|
|
# Fixtures
|
|
# =============================================================================
|
|
|
|
|
|
@pytest.fixture
async def multi_model_client(
    gpu_services: dict[str, str],
) -> AsyncGenerator[MultiModelLlamaClient, None]:
    """Yield a client wired to the reasoning (14B) and instruct (3B) models.

    ``gpu_services`` is requested only for ordering: per that session-scoped
    fixture's contract, the services are healthy before this client is built.
    Its mapping is not read here. The client is closed after the test.

    NOTE(review): a plain ``@pytest.fixture`` on an async generator relies on an
    async pytest plugin running in auto mode; presumably configured elsewhere.
    """
    llm = MultiModelLlamaClient(
        reasoning_model_id="ministral-14b-reasoning",
        instruct_model_id="ministral-3b-instruct",
        temperature=0.2,
        timeout=30.0,  # presumably seconds, matching the 30s reasoning budget
    )
    yield llm
    # Teardown: runs once the requesting test has finished.
    await llm.close()
|
|
|
|
|
|
# =============================================================================
|
|
# Service Health Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestServiceHealth:
    """Fast sanity checks on the inference client before heavier tests."""

    async def test_reasoning_model_available(self, multi_model_client: MultiModelLlamaClient) -> None:
        """The client's availability probe succeeds once services are up.

        NOTE(review): ``is_available()`` is a client-wide call; whether it probes
        only the 14B reasoning model (as this test's name implies) or both
        models is not visible from this file.
        """
        available = await multi_model_client.is_available()
        assert available, "Reasoning model should be available"

    async def test_client_has_correct_model_ids(self, multi_model_client: MultiModelLlamaClient) -> None:
        """The fixture wires the expected reasoning and instruct model IDs.

        This only confirms that the client exposes the IDs it was constructed
        with; it does not contact the services.
        """
        client = multi_model_client
        assert client.reasoning_model_id == "ministral-14b-reasoning"
        assert client.instruct_model_id == "ministral-3b-instruct"
|
|
|
|
|
|
# =============================================================================
|
|
# Reasoning Model Tests (14B)
|
|
# =============================================================================
|
|
|
|
|
|
class TestReasoningModel:
    """Integration tests for the 14B reasoning model."""

    @pytest.fixture
    def sample_diff(self) -> str:
        """Return a unified diff that adds a small Flask login endpoint.

        The hunk header declares exactly the 14 added lines that follow, so the
        model is given a well-formed diff.
        """
        return """diff --git a/src/auth/login.py b/src/auth/login.py
new file mode 100644
--- /dev/null
+++ b/src/auth/login.py
@@ -0,0 +1,14 @@
+from flask import request, jsonify
+
+def login_endpoint():
+    username = request.json.get('username')
+    password = request.json.get('password')
+
+    if not username or not password:
+        return jsonify({'error': 'Missing credentials'}), 400
+
+    if authenticate(username, password):
+        token = generate_token(username)
+        return jsonify({'token': token})
+
+    return jsonify({'error': 'Invalid credentials'}), 401
"""

    async def test_analyze_commit_returns_response(
        self, multi_model_client: MultiModelLlamaClient, sample_diff: str
    ) -> None:
        """Reasoning model returns non-empty analysis for a diff."""
        prompt = f"""Analyze these file changes and provide:
1. TYPE: feat/fix/refactor/chore/docs/test/perf/build/ci/style
2. SCOPE: component or area affected
3. IMPACT: brief description of what changed
4. REASONING: why these changes were made

Changes:
{sample_diff}
"""
        system_prompt = "You are a code analysis assistant. Analyze git diffs concisely."

        response = await asyncio.wait_for(
            multi_model_client.analyze_commit(
                prompt=prompt,
                system_prompt=system_prompt,
                max_tokens=256,
            ),
            timeout=30.0,  # 30s timeout for reasoning
        )

        assert response, "Reasoning model should return non-empty response"
        assert len(response) > 20, "Response should be substantial"
        # Should contain some analysis keywords (case-insensitive).
        response_lower = response.lower()
        assert any(kw in response_lower for kw in ["type", "feat", "add", "auth", "login"]), \
            f"Response should contain relevant analysis: {response[:200]}"

    async def test_group_files_returns_response(
        self, multi_model_client: MultiModelLlamaClient
    ) -> None:
        """Reasoning model can group files logically."""
        files = [
            "src/auth/login.py",
            "src/auth/logout.py",
            "src/auth/session.py",
            "tests/test_auth.py",
            "docs/auth.md",
            "pyproject.toml",
        ]
        diff_summary = "Auth module changes with test and doc updates"

        response = await asyncio.wait_for(
            multi_model_client.group_files(
                files=files,
                diff_summary=diff_summary,
                repo_name="test-repo",
                branch="main",
            ),
            timeout=30.0,
        )

        assert response, "Reasoning model should return grouping response"
        # Match case-insensitively, like the other keyword checks in this
        # module, so capitalized prose ("Auth module ...") still counts.
        # str() guards against group_files returning a non-string structure
        # (return type not visible here).
        response_lower = str(response).lower()
        assert any(kw in response_lower for kw in ["auth", "login", "test"]), \
            f"Response should reference files: {response[:200]}"
|
|
|
|
|
|
# =============================================================================
|
|
# Instruct Model Tests (3B)
|
|
# =============================================================================
|
|
|
|
|
|
class TestInstructModel:
    """Integration tests for the 3B instruct (formatting) model."""

    async def test_format_commit_message_returns_clean_output(
        self, multi_model_client: MultiModelLlamaClient
    ) -> None:
        """Instruct model turns structured analysis into a short commit message."""
        analysis_block = (
            "Format this analysis into a conventional commit message:\n"
            "\n"
            "TYPE: feat\n"
            "SCOPE: auth\n"
            "IMPACT: Added user login endpoint with JWT token generation\n"
            "REASONING: Implementing authentication for the API\n"
            "\n"
            "Format: type(scope): description\n"
            "Keep it under 72 characters.\n"
        )
        formatter_system = (
            "You are a commit message formatter. "
            "Output only the commit message, no explanation."
        )

        formatted = await asyncio.wait_for(
            multi_model_client.format_commit_message(
                prompt=analysis_block,
                system_prompt=formatter_system,
                max_tokens=100,
            ),
            timeout=15.0,  # instruct is the faster model, so a tighter budget
        )

        assert formatted, "Instruct model should return formatted message"
        # A bare commit message is short; long output suggests extra prose.
        assert len(formatted) < 200, "Commit message should be concise"
        # Expect the type or some of the described change to appear.
        lowered = formatted.lower()
        assert any(word in lowered for word in ("feat", "add", "auth", "login")), (
            f"Response should be a proper commit message: {formatted}"
        )

    async def test_instruct_handles_short_prompts(
        self, multi_model_client: MultiModelLlamaClient
    ) -> None:
        """Instruct model still answers when given a one-line prompt."""
        reply = await asyncio.wait_for(
            multi_model_client.format_commit_message(
                prompt="Format: fix(api): resolve null pointer bug",
                system_prompt="Output only the commit message.",
                max_tokens=50,
            ),
            timeout=15.0,
        )

        assert reply, "Should return something for short prompt"
|
|
|
|
|
|
# =============================================================================
|
|
# Full Pipeline Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestFullPipeline:
    """End-to-end tests using both models in sequence."""

    async def test_analyze_then_format_pipeline(
        self, multi_model_client: MultiModelLlamaClient
    ) -> None:
        """Full pipeline: reasoning model analyzes, instruct model formats.

        The formatted output is post-processed the same way as the real
        pipeline (extract, emoji-correct, sanitize). Its first line must be
        non-empty and at most 120 characters.
        """
        # Step 1: Analyze with reasoning model
        analysis_prompt = """Analyze this change:
- File: src/api/health.py (new)
- Content: Added health check endpoint returning {"status": "ok"}

Provide TYPE, SCOPE, IMPACT.
"""
        analysis_system = "Analyze code changes concisely."

        analysis = await asyncio.wait_for(
            multi_model_client.analyze_commit(
                prompt=analysis_prompt,
                system_prompt=analysis_system,
                max_tokens=150,
            ),
            timeout=30.0,
        )

        assert analysis, "Analysis step should return response"

        # Step 2: Format with instruct model (use real system prompt like production)
        format_prompt = f"""Based on this analysis, write a commit message:

{analysis}

Format: type(scope): emoji description
"""
        format_system = build_format_system_prompt()

        message = await asyncio.wait_for(
            multi_model_client.format_commit_message(
                prompt=format_prompt,
                system_prompt=format_system,
                max_tokens=100,
            ),
            timeout=15.0,
        )

        assert message, "Format step should return commit message"

        # Post-process like the real pipeline
        raw = extract_commit_message(message)
        first_line = sanitize_message(correct_emoji(raw)).split("\n")[0]
        # Without this check, post-processing that strips everything would
        # pass the length check below vacuously.
        assert first_line.strip(), \
            f"Post-processing produced an empty commit message from: {message!r}"
        assert len(first_line) <= 120, f"Commit message too long: {first_line}"
|
|
|
|
|
|
# =============================================================================
|
|
# Error Handling Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestErrorHandling:
    """Degenerate inputs must yield a string from the client, not an exception."""

    async def test_empty_prompt_handling(
        self, multi_model_client: MultiModelLlamaClient
    ) -> None:
        """An empty user prompt still produces a string response."""
        request = multi_model_client.format_commit_message(
            prompt="",
            system_prompt="Return a default commit message.",
            max_tokens=50,
        )
        response = await asyncio.wait_for(request, timeout=15.0)

        # Only the result type is checked; an empty reply is acceptable.
        assert isinstance(response, str)

    async def test_very_long_prompt_handling(
        self, multi_model_client: MultiModelLlamaClient
    ) -> None:
        """A long prompt (5000-char filler body simulating a large diff) is handled."""
        prompt = "\n".join(
            ["Analyze this change:", "x" * 5000, "", "Provide TYPE and SCOPE."]
        )
        request = multi_model_client.analyze_commit(
            prompt=prompt,
            system_prompt="Analyze briefly.",
            max_tokens=100,
        )
        # NOTE(review): the 60s budget is double the usual reasoning timeout, but
        # the fixture's client is built with timeout=30.0. If that is a
        # per-request timeout, it will fire first; confirm which one applies.
        response = await asyncio.wait_for(request, timeout=60.0)

        # Only the result type is checked (the model may truncate or summarize).
        assert isinstance(response, str)
|
|
|
|
|
|
# =============================================================================
|
|
# Performance / Smoke Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestPerformance:
    """Quick smoke tests to verify response times are acceptable.

    Each request is also bounded by ``asyncio.wait_for`` (15s instruct, 30s
    reasoning, matching the other tests in this module). A stalled model
    therefore fails fast instead of waiting on the client's own
    ``timeout=30.0``. Elapsed time is measured with ``time.monotonic()``,
    which is unaffected by system clock changes.
    """

    async def test_instruct_response_under_10_seconds(
        self, multi_model_client: MultiModelLlamaClient
    ) -> None:
        """Instruct model responds within 10 seconds for simple prompts."""
        prompt = "Format: feat(api): add endpoint"
        system_prompt = "Output commit message."

        start = time.monotonic()
        response = await asyncio.wait_for(
            multi_model_client.format_commit_message(
                prompt=prompt,
                system_prompt=system_prompt,
                max_tokens=50,
            ),
            timeout=15.0,
        )
        elapsed = time.monotonic() - start

        assert response, "Should get response"
        assert elapsed < 10.0, f"Instruct model too slow: {elapsed:.2f}s"

    async def test_reasoning_response_under_30_seconds(
        self, multi_model_client: MultiModelLlamaClient
    ) -> None:
        """Reasoning model responds within 30 seconds for typical prompts."""
        prompt = """Analyze: Added login endpoint with JWT tokens.
- File: auth/login.py
Provide TYPE, SCOPE, IMPACT."""
        system_prompt = "Analyze concisely."

        start = time.monotonic()
        response = await asyncio.wait_for(
            multi_model_client.analyze_commit(
                prompt=prompt,
                system_prompt=system_prompt,
                max_tokens=150,
            ),
            timeout=30.0,
        )
        elapsed = time.monotonic() - start

        assert response, "Should get response"
        assert elapsed < 30.0, f"Reasoning model too slow: {elapsed:.2f}s"
|