auto-commit-service/tests/test_multi_model_client.py

202 lines
7.7 KiB
Python
Raw Permalink Normal View History

"""Tests for MultiModelLlamaClient — retry logic and pipeline methods."""
import asyncio
from unittest.mock import AsyncMock, MagicMock, call, patch
import pytest
from auto_commit_service.llm.multi_model_client import (
MultiModelLlamaClient,
_TRANSIENT_EXCEPTIONS,
_transient_exc_types,
)
class TestTransientExcTypes:
    """_transient_exc_types builds the correct tuple of retryable exceptions."""

    def test_always_includes_builtins(self) -> None:
        """The builtin ConnectionError and TimeoutError are always present."""
        types = _transient_exc_types()
        assert ConnectionError in types
        assert TimeoutError in types

    def test_no_asyncio_timeout_error_duplicate(self) -> None:
        """No exception class may be listed twice, asyncio.TimeoutError included."""
        # asyncio.TimeoutError is TimeoutError on Python 3.11+ — must not be
        # included separately (it would be an identical duplicate).
        types = _transient_exc_types()
        assert asyncio.TimeoutError not in types or asyncio.TimeoutError is TimeoutError
        # The assertion above is always true wherever the alias holds, so it
        # cannot catch a duplicate on its own; check uniqueness directly.
        assert len(types) == len(set(types))

    def test_returns_tuple(self) -> None:
        """The helper returns a tuple."""
        assert isinstance(_transient_exc_types(), tuple)

    def test_module_constant_matches_function(self) -> None:
        """_TRANSIENT_EXCEPTIONS holds the same classes the helper returns."""
        assert set(_TRANSIENT_EXCEPTIONS) == set(_transient_exc_types())

    def test_httpx_included_when_available(self) -> None:
        """httpx transport errors are retryable; skipped when httpx is missing."""
        # Only the import is guarded, so the assertions below are never
        # inside the ImportError handler's try body.
        try:
            import httpx
        except ImportError:
            pytest.skip("httpx not installed")
        assert httpx.RemoteProtocolError in _TRANSIENT_EXCEPTIONS
        assert httpx.ReadError in _TRANSIENT_EXCEPTIONS
        assert httpx.ConnectError in _TRANSIENT_EXCEPTIONS
class TestRetryHelper:
    """_retry retries transient transport errors with exponential backoff."""

    @pytest.fixture
    def client(self) -> MultiModelLlamaClient:
        """Client with retry_attempts=2, built while InferenceClient is patched."""
        with patch("auto_commit_service.llm.multi_model_client.InferenceClient"):
            return MultiModelLlamaClient(retry_attempts=2)

    @pytest.mark.asyncio
    async def test_success_on_first_attempt(self, client: MultiModelLlamaClient) -> None:
        """A call that succeeds immediately is awaited once and its value returned."""
        coro_fn = AsyncMock(return_value="result")
        result = await client._retry("test", coro_fn)
        assert result == "result"
        coro_fn.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_retries_on_transient_error(self, client: MultiModelLlamaClient) -> None:
        """One ConnectionError is followed by one retry after a 1s sleep."""
        exc = ConnectionError("disconnected")
        coro_fn = AsyncMock(side_effect=[exc, "result"])
        with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
            result = await client._retry("test", coro_fn)
        assert result == "result"
        assert coro_fn.await_count == 2
        mock_sleep.assert_awaited_once_with(1)  # 2**0 = 1s backoff

    @pytest.mark.asyncio
    async def test_exponential_backoff_sequence(self, client: MultiModelLlamaClient) -> None:
        """Two consecutive failures sleep for 1s and then 2s before succeeding."""
        exc = ConnectionError("disconnected")
        # 3 total attempts (retry_attempts=2): fail, fail, succeed
        coro_fn = AsyncMock(side_effect=[exc, exc, "result"])
        with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
            result = await client._retry("test", coro_fn)
        assert result == "result"
        assert coro_fn.await_count == 3
        assert mock_sleep.await_args_list == [call(1), call(2)]  # 2**0, 2**1

    @pytest.mark.asyncio
    async def test_raises_after_all_attempts_exhausted(
        self, client: MultiModelLlamaClient
    ) -> None:
        """A persistent ConnectionError is re-raised after three attempts."""
        exc = ConnectionError("disconnected")
        coro_fn = AsyncMock(side_effect=exc)
        with patch("asyncio.sleep", new_callable=AsyncMock):
            with pytest.raises(ConnectionError):
                await client._retry("test", coro_fn)
        assert coro_fn.await_count == 3  # 1 initial + 2 retries

    @pytest.mark.asyncio
    async def test_does_not_retry_non_transient_errors(
        self, client: MultiModelLlamaClient
    ) -> None:
        """A ValueError propagates after a single attempt and no backoff sleep."""
        exc = ValueError("bad input")
        coro_fn = AsyncMock(side_effect=exc)
        # Patch sleep so a regression cannot stall the suite, and so the
        # absence of any backoff can be asserted explicitly.
        with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
            with pytest.raises(ValueError):
                await client._retry("test", coro_fn)
        mock_sleep.assert_not_awaited()
        coro_fn.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_zero_retries_no_sleep(self) -> None:
        """With retry_attempts=0 the first ConnectionError is raised without sleeping."""
        # Builds its own client, so the shared ``client`` fixture is not requested.
        with patch("auto_commit_service.llm.multi_model_client.InferenceClient"):
            no_retry_client = MultiModelLlamaClient(retry_attempts=0)
        exc = ConnectionError("disconnected")
        coro_fn = AsyncMock(side_effect=exc)
        with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
            with pytest.raises(ConnectionError):
                await no_retry_client._retry("test", coro_fn)
        mock_sleep.assert_not_awaited()
        coro_fn.assert_awaited_once()
class TestAnalyzeCommitViaPipeline:
    """analyze_commit_via_pipeline awaits fill_step a single time; retries cover wait_step only."""

    @pytest.fixture
    def client(self) -> MultiModelLlamaClient:
        """Client with retry_attempts=2 whose ``_client`` is swapped for an AsyncMock."""
        with patch("auto_commit_service.llm.multi_model_client.InferenceClient"):
            llama_client = MultiModelLlamaClient(retry_attempts=2)
            llama_client._client = AsyncMock()
        return llama_client

    @pytest.fixture
    def handle(self) -> MagicMock:
        """Mock pipeline handle whose wait_step resolves to a canned chat response."""
        canned_response = {
            "result": {"choices": [{"message": {"content": "feat(auth): ✨ Add login"}}]}
        }
        # MagicMock keyword arguments set attributes on the new mock.
        return MagicMock(
            pipeline_id="pipe-123",
            wait_step=AsyncMock(return_value=canned_response),
        )

    @pytest.mark.asyncio
    async def test_fill_step_called_exactly_once_on_success(
        self, client: MultiModelLlamaClient, handle: MagicMock
    ) -> None:
        """On the happy path fill_step is awaited exactly one time."""
        await client.analyze_commit_via_pipeline(handle, "prompt", "system")
        client._client.fill_step.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_wait_step_retried_on_transient_error(
        self, client: MultiModelLlamaClient, handle: MagicMock
    ) -> None:
        """A transient wait_step failure is retried without awaiting fill_step again."""
        transient = ConnectionError("read error")
        success_payload = {"result": {"choices": [{"message": {"content": "feat: ✨ x"}}]}}
        handle.wait_step = AsyncMock(side_effect=[transient, success_payload])
        with patch("asyncio.sleep", new_callable=AsyncMock):
            commit_message = await client.analyze_commit_via_pipeline(
                handle, "prompt", "system"
            )
        # A single fill_step await shows it sits outside the retried section.
        client._client.fill_step.assert_awaited_once()
        # One failing await plus the retry that succeeded.
        assert handle.wait_step.await_count == 2
        assert commit_message == "feat: ✨ x"

    @pytest.mark.asyncio
    async def test_fill_step_failure_propagates_without_retry(
        self, client: MultiModelLlamaClient, handle: MagicMock
    ) -> None:
        """A fill_step error surfaces at once, with no retry and no wait_step call."""
        failing_fill = AsyncMock(side_effect=ConnectionError("fill failed"))
        client._client.fill_step = failing_fill
        with pytest.raises(ConnectionError):
            await client.analyze_commit_via_pipeline(handle, "prompt", "system")
        # One fill_step await; wait_step is never reached.
        client._client.fill_step.assert_awaited_once()
        handle.wait_step.assert_not_awaited()
class TestBuildLlmClientFactory:
    """The _build_llm_client factory in app.py copies settings onto the client it returns."""

    def test_factory_passes_all_settings(self) -> None:
        """Model ids, timeout and retry count from the settings appear on the built client."""
        from auto_commit_service.app import _build_llm_client
        from auto_commit_service.config import AutoCommitSettings

        # Values fed into the settings and then expected back on the client.
        reasoning_id = "ministral-14b-reasoning"
        instruct_id = "ministral-3b-instruct"
        timeout_seconds = 45.0
        retry_count = 3

        app_settings = AutoCommitSettings(
            service_name="test",
            reasoning_model_id=reasoning_id,
            instruct_model_id=instruct_id,
            llm_timeout=timeout_seconds,
            llm_retry_attempts=retry_count,
        )
        with patch("auto_commit_service.llm.multi_model_client.InferenceClient"):
            built = _build_llm_client(app_settings)

        assert built._reasoning_model_id == reasoning_id
        assert built._instruct_model_id == instruct_id
        assert built._timeout == timeout_seconds
        assert built._retry_attempts == retry_count