auto-commit-service/src/auto_commit_service/llm/multi_model_client.py
autocommit c59be053f0 refactor(llm): ♻️ Introduce task-based routing logic in MultiModelLlamaClient for dynamic model selection
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
2026-05-11 08:56:36 -07:00

377 lines
14 KiB
Python

"""Multi-model LLM client routed through the model-boss coordinator.
The coordinator handles model loading, GPU placement, VRAM management,
health monitoring, and LRU eviction. This client only
needs to specify which model to use per request.
"""
from __future__ import annotations
import asyncio
import logging
from typing import Any
from model_boss import InferenceClient
from model_boss.client import PipelineHandle
logger = logging.getLogger(__name__)
# Exception classes treated as transient transport failures and retried.
# Resolved lazily so the import works even if httpx/anyio are absent.
def _transient_exc_types() -> tuple[type[BaseException], ...]:
    """Build the tuple of exception classes treated as transient transport errors.

    The built-in ``ConnectionError`` and ``TimeoutError`` always come first.
    Selected httpx error classes and ``anyio.EndOfStream`` are appended only
    when the corresponding package can be imported.

    Returns:
        A tuple suitable for use in an ``except`` clause.
    """
    collected: list[type[BaseException]] = [ConnectionError, TimeoutError]

    # httpx is optional: a missing package simply contributes nothing.
    try:
        import httpx

        collected += [
            httpx.RemoteProtocolError,
            httpx.ReadError,
            httpx.ConnectError,
            httpx.ReadTimeout,
            httpx.WriteError,
        ]
    except ImportError:
        pass

    # anyio is optional too; AttributeError is also tolerated in case the
    # installed anyio does not expose EndOfStream.
    try:
        import anyio

        collected.append(anyio.EndOfStream)
    except (ImportError, AttributeError):
        pass

    return tuple(collected)
# Computed once at import time; MultiModelLlamaClient._retry matches caught
# exceptions against this tuple to decide whether a failure is retried.
_TRANSIENT_EXCEPTIONS = _transient_exc_types()
class MultiModelLlamaClient:
    """Routes LLM requests to reasoning (14B) or instruct (3B) models.

    All requests go through the model-boss coordinator, which dynamically
    loads models on the best available GPU. No port management or systemd
    interaction needed.
    """

    # Map task buckets to their model-boss TaskRegistry names.
    # Why: per p1-74, opt these consumers into central recommendations
    # instead of pinning to a model id. Current resolutions (per tasks.yaml)
    # match the historical model ids — flipping is a separate operator action.
    _REASONING_TASK = "code.agent"
    _INSTRUCT_TASK = "summarization.short"

    # Priority names accepted by set_priority().
    _VALID_PRIORITIES = frozenset({"urgent", "high", "normal", "low"})

    def __init__(
        self,
        reasoning_model_id: str = "ministral-14b-reasoning",
        instruct_model_id: str = "ministral-3b-instruct",
        timeout: float = 120.0,
        temperature: float = 0.2,
        client_id: str = "auto-commit-service",
        retry_attempts: int = 2,
    ):
        """Create the client and its underlying ``InferenceClient``.

        Args:
            reasoning_model_id: Model id used for pipeline (queue) requests
                and reported by ``health_check``.
            instruct_model_id: Model id reported by ``health_check``.
            timeout: Timeout in seconds for HTTP calls and pipeline waits.
            temperature: Sampling temperature sent with every request.
            client_id: Identifier sent to the coordinator with each request.
            retry_attempts: Number of retries after the first attempt on a
                transient transport error; negative values are clamped to 0.
        """
        self._reasoning_model_id = reasoning_model_id
        self._instruct_model_id = instruct_model_id
        self._timeout = timeout
        self._temperature = temperature
        self._client_id = client_id
        self._retry_attempts = max(0, retry_attempts)
        self._configured_priority = "normal"
        self._client = InferenceClient(
            timeout=timeout,
            auto_start_services=True,
            client_id=client_id,
            default_priority=self._configured_priority,
        )

    @property
    def reasoning_model_id(self) -> str:
        """Model id configured for the reasoning bucket."""
        return self._reasoning_model_id

    @property
    def instruct_model_id(self) -> str:
        """Model id configured for the instruct bucket."""
        return self._instruct_model_id

    @property
    def inference_client(self) -> InferenceClient:
        """Expose the underlying InferenceClient for reuse by other modules."""
        return self._client

    @property
    def current_priority(self) -> str:
        """Current inference priority level (read from the InferenceClient)."""
        return self._client._default_priority

    def set_priority(self, level: str | None) -> None:
        """Override the inference queue priority.

        Args:
            level: One of "urgent", "high", "normal", "low", or None to restore
                the configured default ("normal").

        Valid values match the model-boss Priority enum:
            urgent=1, high=5, normal=10, low=20

        Raises:
            ValueError: If the effective level is not a valid priority name.
        """
        effective = level if level is not None else self._configured_priority
        if effective not in self._VALID_PRIORITIES:
            raise ValueError(
                f"Invalid priority {effective!r}. Must be one of {sorted(self._VALID_PRIORITIES)}"
            )
        self._client._default_priority = effective
        logger.info(f"Inference priority set to {effective!r}")

    async def drain_pending_requests(self) -> int:
        """Cancel all pending (queued, not yet running) requests for this client.

        Called after a priority change so stale lower-priority queue entries
        are flushed. ACS will resubmit them at the new priority on the next
        cycle. Active requests (currently being processed by a model slot)
        are unaffected — they are not in the pending queue.

        This is best-effort: a failed queue fetch or an unexpected payload
        yields 0, and individual cancel failures are logged and skipped.

        Returns:
            Number of requests cancelled.
        """
        import httpx as _httpx

        coordinator_url = self._client._coordinator_url
        client_id = self._client._client_id
        async with _httpx.AsyncClient(timeout=10.0) as http:
            try:
                resp = await http.get(f"{coordinator_url}/v1/queue")
                resp.raise_for_status()
                data = resp.json()
            except Exception as exc:
                logger.warning(f"Failed to fetch queue for drain: {exc}")
                return 0

            if not isinstance(data, dict):
                logger.warning(
                    f"Failed to fetch queue for drain: unexpected payload type {type(data).__name__}"
                )
                return 0

            # Keep only this client's entries. Entries that are malformed or
            # lack a request_id are skipped rather than aborting the drain.
            pending_ids = [
                item["request_id"]
                for item in (data.get("pending") or [])
                if isinstance(item, dict)
                and item.get("client_id") == client_id
                and "request_id" in item
            ]

            cancelled = 0
            for request_id in pending_ids:
                try:
                    r = await http.delete(f"{coordinator_url}/v1/queue/{request_id}")
                    # 404 is counted as well — presumably because the entry
                    # has already left the queue; confirm against coordinator.
                    if r.status_code in (200, 204, 404):
                        cancelled += 1
                except Exception as exc:
                    logger.warning(f"Failed to cancel request {request_id}: {exc}")

        if cancelled:
            logger.info(f"Drained {cancelled} pending requests from queue after priority change")
        return cancelled

    async def is_available(self) -> bool:
        """Check if inference is available via the coordinator."""
        try:
            await self._client.connect()
        except Exception as exc:
            # Any failure means "not available"; keep a trace for debugging.
            logger.debug(f"Coordinator connect failed: {exc!r}")
            return False
        return True

    async def health_check(self) -> dict[str, str | bool]:
        """Report availability and the configured model ids.

        Both buckets share one availability flag because they are served
        through the same coordinator connection check.
        """
        available = await self.is_available()
        return {
            "status": "ok" if available else "unavailable",
            "instruct_available": available,
            "instruct_model": self._instruct_model_id,
            "reasoning_available": available,
            "reasoning_model": self._reasoning_model_id,
        }

    async def ensure_services(self) -> None:
        """Connect to coordinator (which auto-starts Redis + coordinator)."""
        await self._client.connect()

    async def release_services(self) -> None:
        """No-op — coordinator manages VRAM lifecycle."""

    async def close(self) -> None:
        """Dispose of the underlying InferenceClient."""
        await self._client.dispose()

    async def _retry(self, label: str, coro_fn: Any) -> Any:
        """Run coro_fn() with exponential-backoff retry on transient transport errors.

        coro_fn must be a zero-argument async callable that is safe to call
        multiple times (i.e. creates a fresh coroutine each invocation).

        Args:
            label: Short name of the operation, used only in log messages.
            coro_fn: Zero-argument async callable to invoke.

        Returns:
            Whatever coro_fn() returns on the first successful attempt.

        Raises:
            The last transient exception once all attempts are exhausted;
            non-transient exceptions propagate immediately.
        """
        # Always make at least one attempt, even if the retry count is bogus.
        attempts = max(1, self._retry_attempts + 1)
        for attempt in range(attempts):
            try:
                return await coro_fn()
            except _TRANSIENT_EXCEPTIONS as exc:
                if attempt + 1 >= attempts:
                    raise
                backoff = 2 ** attempt  # 1s, 2s, 4s, ...
                logger.warning(
                    f"Transport error on {label} (attempt {attempt + 1}/{attempts}): "
                    f"{type(exc).__name__}: {exc!r}. Retrying in {backoff}s."
                )
                await asyncio.sleep(backoff)
        # Every iteration either returns or raises, so this is never reached.
        raise RuntimeError(f"retry loop for {label} exited without a result")

    @staticmethod
    def _build_messages(system_prompt: str, prompt: str) -> list[dict[str, str]]:
        """Return the two-message (system, user) chat payload."""
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]

    @classmethod
    def _grouping_messages(
        cls, files: list[str], diff_summary: str, repo_name: str, branch: str,
    ) -> list[dict[str, str]]:
        """Build the chat messages for a file-grouping request."""
        from .prompts import build_grouping_prompt, GROUPING_SYSTEM_PROMPT

        prompt = build_grouping_prompt(files, diff_summary, repo_name, branch)
        return cls._build_messages(GROUPING_SYSTEM_PROMPT, prompt)

    async def _chat(
        self,
        task: str,
        messages: list[dict[str, str]],
        max_tokens: int,
        context: dict | None = None,
    ) -> str:
        """POST a task-routed chat completion and return the reply text.

        Args:
            task: Task name sent in the request body for model resolution.
            messages: Chat messages to send.
            max_tokens: Completion token limit.
            context: Optional metadata, forwarded as ``x_context`` when truthy.

        Returns:
            The first choice's message content, or "" when there are no
            choices or the content is missing/null.

        Raises:
            httpx.HTTPStatusError: On a non-success HTTP status.
            Transient transport errors once retries are exhausted.
        """
        # Routes through /v1/chat/completions (not the InferenceClient queue
        # at /v1/requests, which strips `task`). The OpenAI-compatible proxy
        # resolves `task` via the coordinator's TaskRegistry.
        # NOTE(review): an earlier comment here said `context=` and
        # `keep_alive=` are queue-only and not preserved on this path, yet the
        # body below forwards them as `x_context` / `x_keep_alive`. Confirm
        # whether the proxy honors those fields — see
        # model_boss.client.InferenceClient.chat docstring.
        import httpx

        coordinator_url = self._client._coordinator_url
        body: dict[str, Any] = {
            "task": task,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": self._temperature,
            "x_client_id": self._client_id,
            "x_keep_alive": 300,
        }
        # Only non-default priorities are sent explicitly.
        priority = self._client._default_priority
        if priority != "normal":
            body["x_priority"] = priority
        if context:
            body["x_context"] = context

        async def _post() -> str:
            async with httpx.AsyncClient(timeout=self._timeout) as http:
                resp = await http.post(
                    f"{coordinator_url}/v1/chat/completions",
                    json=body,
                )
                resp.raise_for_status()
                data = resp.json()
            choices = data.get("choices", [])
            if not choices:
                return ""
            # Tolerate a null message or null content.
            message = choices[0].get("message") or {}
            return message.get("content") or ""

        return await self._retry(task, _post)

    async def analyze_commit(
        self, prompt: str, system_prompt: str, max_tokens: int = 512, repo_name: str = "",
    ) -> str:
        """Run a commit-analysis prompt on the reasoning task and return the text."""
        messages = self._build_messages(system_prompt, prompt)
        return await self._chat(self._REASONING_TASK, messages, max_tokens, context={
            "label": f"Analyze: {repo_name}" if repo_name else "Analyze commit",
            "tags": ["stage:analyze"],
        })

    async def group_files(
        self, files: list[str], diff_summary: str, repo_name: str = "", branch: str = "main",
    ) -> str:
        """Group files — standalone call without pipeline reservation."""
        messages = self._grouping_messages(files, diff_summary, repo_name, branch)
        return await self._chat(self._REASONING_TASK, messages, 2000)

    async def group_files_pipeline(
        self, files: list[str], diff_summary: str, repo_name: str = "", branch: str = "main",
    ) -> tuple[str, PipelineHandle]:
        """Group files as step 0 of a pipeline, reserving step 1 for analyze.

        Returns (group_result_text, pipeline_handle). The caller uses the
        handle to fill step 1 with the analysis prompt once grouping is done.
        """
        messages = self._grouping_messages(files, diff_summary, repo_name, branch)

        # NOTE(review): submit and wait are retried together, so a transient
        # failure during wait_step submits a second pipeline. Confirm that an
        # abandoned pipeline is harmless on the coordinator side.
        async def _submit_and_wait() -> tuple[Any, Any]:
            h = await self._client.submit_pipeline(
                model=self._reasoning_model_id,
                messages=messages,
                steps=2,
                max_tokens=2000,
                temperature=self._temperature,
                keep_alive=300,
                context={
                    "label": f"Group files: {repo_name or 'unknown'}",
                    "description": f"{len(files)} changed files → semantic grouping",
                    "tags": [f"repo:{repo_name}", "stage:group"],
                },
            )
            r = await h.wait_step(0, timeout=self._timeout)
            return h, r

        handle, result = await self._retry(
            f"group_files_pipeline:{repo_name or 'unknown'}",
            _submit_and_wait,
        )
        text = self._extract_text(result)
        return text, handle

    async def analyze_commit_via_pipeline(
        self, handle: PipelineHandle, prompt: str, system_prompt: str, max_tokens: int = 512,
    ) -> str:
        """Fill pipeline step 1 with an analysis prompt and wait for result."""
        messages = self._build_messages(system_prompt, prompt)
        # fill_step is not idempotent — submit once, outside the retry boundary.
        # Only wait_step is safe to retry (re-polling an already-queued step).
        await self._client.fill_step(
            handle.pipeline_id,
            step=1,
            model=self._reasoning_model_id,
            messages=messages,
            max_tokens=max_tokens,
            temperature=self._temperature,
        )
        result = await self._retry(
            f"analyze_via_pipeline:{handle.pipeline_id}",
            lambda: handle.wait_step(1, timeout=self._timeout),
        )
        return self._extract_text(result)

    @staticmethod
    def _extract_text(result: dict) -> str:
        """Extract response text from a pubsub result envelope.

        The payload is taken from ``result["result"]`` when present, otherwise
        from ``result`` itself. If it is a dict with a non-empty ``choices``
        list, the first choice's message content is returned ("" when the
        message or content is missing/null). Anything else is stringified,
        with falsy payloads mapped to "".
        """
        inner = result.get("result", result)
        if isinstance(inner, dict):
            choices = inner.get("choices", [])
            if choices:
                # Guard against null message/content so a str is always returned.
                message = choices[0].get("message") or {}
                return message.get("content") or ""
        return str(inner) if inner else ""

    async def format_commit_message(
        self, prompt: str, system_prompt: str, max_tokens: int = 150, repo_name: str = "",
    ) -> str:
        """Run a commit-message formatting prompt on the instruct task."""
        messages = self._build_messages(system_prompt, prompt)
        return await self._chat(self._INSTRUCT_TASK, messages, max_tokens, context={
            "label": f"Format: {repo_name}" if repo_name else "Format commit msg",
            "tags": ["stage:format"],
        })

    async def __aenter__(self) -> MultiModelLlamaClient:
        """Connect to the coordinator on entering ``async with``."""
        await self._client.connect()
        return self

    async def __aexit__(self, *args: Any) -> None:
        """Dispose of the client on leaving ``async with``."""
        await self.close()