"""Multi-model LLM client routed through the model-boss coordinator.

The coordinator handles model loading, GPU placement, VRAM management,
health monitoring, and LRU eviction. This client only needs to say which
model (or, for chat requests, which registry task) each request should use.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
from typing import Any
|
|
|
|
from model_boss import InferenceClient
|
|
from model_boss.client import PipelineHandle
|
|
|
|
# Module-level logger; used for retry warnings, priority changes and queue drains.
logger = logging.getLogger(__name__)
|
|
|
|
# Exception classes treated as transient transport failures and retried.
|
|
# Resolved lazily so the import works even if httpx/anyio are absent.
|
|
def _transient_exc_types() -> tuple[type[BaseException], ...]:
|
|
types: list[type[BaseException]] = [ConnectionError, TimeoutError]
|
|
try:
|
|
import httpx
|
|
types.extend([
|
|
httpx.RemoteProtocolError,
|
|
httpx.ReadError,
|
|
httpx.ConnectError,
|
|
httpx.ReadTimeout,
|
|
httpx.WriteError,
|
|
])
|
|
except ImportError:
|
|
pass
|
|
try:
|
|
import anyio
|
|
types.append(anyio.EndOfStream)
|
|
except (ImportError, AttributeError):
|
|
pass
|
|
return tuple(types)
|
|
|
|
|
|
# Resolved once at import time; MultiModelLlamaClient._retry() catches exactly these types.
_TRANSIENT_EXCEPTIONS = _transient_exc_types()
|
|
|
|
|
|
class MultiModelLlamaClient:
    """Routes LLM requests to reasoning (14B) or instruct (3B) models.

    All requests go through the model-boss coordinator, which dynamically
    loads models on the best available GPU. No port management or systemd
    interaction needed.
    """

    # Map task buckets to their model-boss TaskRegistry names.
    # Why: per p1-74, opt these consumers into central recommendations
    # instead of pinning to a model id. Current resolutions (per tasks.yaml)
    # match the historical model ids — flipping is a separate operator action.
    _REASONING_TASK = "code.agent"
    _INSTRUCT_TASK = "summarization.short"

    # Priority names accepted by set_priority(); they match the model-boss
    # Priority enum (urgent=1, high=5, normal=10, low=20).
    _VALID_PRIORITIES = frozenset({"urgent", "high", "normal", "low"})

    def __init__(
        self,
        reasoning_model_id: str = "ministral-14b-reasoning",
        instruct_model_id: str = "ministral-3b-instruct",
        timeout: float = 120.0,
        temperature: float = 0.2,
        client_id: str = "auto-commit-service",
        retry_attempts: int = 2,
    ):
        """Create the client and its underlying model-boss InferenceClient.

        Args:
            reasoning_model_id: Model id used for pipeline submissions
                (grouping / analysis steps) and reported by health_check().
            instruct_model_id: Model id reported by health_check().
            timeout: Timeout in seconds for InferenceClient calls, chat
                POSTs and pipeline step waits.
            temperature: Sampling temperature sent with every request.
            client_id: Identifier sent to the coordinator with requests; also
                used to select this client's entries when draining the queue.
            retry_attempts: Extra attempts after the first on transient
                transport errors; negative values are clamped to 0.
        """
        self._reasoning_model_id = reasoning_model_id
        self._instruct_model_id = instruct_model_id
        self._timeout = timeout
        self._temperature = temperature
        self._client_id = client_id
        self._retry_attempts = max(0, retry_attempts)
        # Priority restored by set_priority(None).
        self._configured_priority = "normal"
        self._client = InferenceClient(
            timeout=timeout,
            auto_start_services=True,
            client_id=client_id,
            default_priority=self._configured_priority,
        )

    @property
    def reasoning_model_id(self) -> str:
        """Model id used for pipeline (grouping / analysis) requests."""
        return self._reasoning_model_id

    @property
    def instruct_model_id(self) -> str:
        """Instruct model id, as reported by health_check()."""
        return self._instruct_model_id

    @property
    def inference_client(self) -> InferenceClient:
        """Expose the underlying InferenceClient for reuse by other modules."""
        return self._client

    @property
    def current_priority(self) -> str:
        """Current inference priority level."""
        return self._client._default_priority

    def set_priority(self, level: str | None) -> None:
        """Override the inference queue priority.

        Args:
            level: One of "urgent", "high", "normal", "low", or None to restore
                the configured default ("normal").

        Raises:
            ValueError: If ``level`` is not one of the recognised names.

        Valid values match the model-boss Priority enum:
            urgent=1, high=5, normal=10, low=20
        """
        effective = level if level is not None else self._configured_priority
        if effective not in self._VALID_PRIORITIES:
            raise ValueError(
                f"Invalid priority {effective!r}. "
                f"Must be one of {sorted(self._VALID_PRIORITIES)}"
            )
        # Stored on the InferenceClient; _chat() reads it back from there.
        self._client._default_priority = effective
        logger.info(f"Inference priority set to {effective!r}")

    async def drain_pending_requests(self) -> int:
        """Cancel all pending (queued, not yet running) requests for this client.

        Called after a priority change so stale lower-priority queue entries
        are flushed. ACS will resubmit them at the new priority on the next
        cycle. Active requests (currently being processed by a model slot)
        are unaffected — they are not in the pending queue.

        Best-effort: a failure to fetch the queue returns 0, and malformed
        entries or failed cancellations are logged and skipped.

        Returns:
            Number of queue entries removed (DELETE answered 200, 204 or 404).
        """
        import httpx as _httpx

        coordinator_url = self._client._coordinator_url
        client_id = self._client._client_id

        async with _httpx.AsyncClient(timeout=10.0) as http:
            try:
                resp = await http.get(f"{coordinator_url}/v1/queue")
                resp.raise_for_status()
                data = resp.json()
            except Exception as exc:
                logger.warning(f"Failed to fetch queue for drain: {exc}")
                return 0

            # Tolerate an unexpected payload shape rather than aborting mid-drain.
            queue_items = data.get("pending") if isinstance(data, dict) else None
            pending = [
                item for item in (queue_items or [])
                if isinstance(item, dict) and item.get("client_id") == client_id
            ]

            cancelled = 0
            for item in pending:
                request_id = item.get("request_id")
                if request_id is None:
                    logger.warning(f"Skipping queue entry without request_id: {item!r}")
                    continue
                try:
                    r = await http.delete(f"{coordinator_url}/v1/queue/{request_id}")
                except Exception as exc:
                    logger.warning(f"Failed to cancel request {request_id}: {exc}")
                    continue
                # 404 is counted as well: the entry is no longer queued
                # (presumably already dispatched or removed) — TODO confirm.
                if r.status_code in (200, 204, 404):
                    cancelled += 1
                else:
                    logger.warning(
                        f"Unexpected status {r.status_code} cancelling request {request_id}"
                    )

        if cancelled:
            logger.info(f"Drained {cancelled} pending requests from queue after priority change")
        return cancelled

    async def is_available(self) -> bool:
        """Check if inference is available via the coordinator."""
        try:
            await self._client.connect()
            return True
        except Exception:
            return False

    async def health_check(self) -> dict[str, str | bool]:
        """Report coordinator reachability.

        Both ``*_available`` flags mirror a single is_available() probe;
        per-model availability is not checked individually.
        """
        available = await self.is_available()
        return {
            "status": "ok" if available else "unavailable",
            "instruct_available": available,
            "instruct_model": self._instruct_model_id,
            "reasoning_available": available,
            "reasoning_model": self._reasoning_model_id,
        }

    async def ensure_services(self) -> None:
        """Connect to coordinator (which auto-starts Redis + coordinator)."""
        await self._client.connect()

    async def release_services(self) -> None:
        """No-op — coordinator manages VRAM lifecycle."""
        pass

    async def close(self) -> None:
        """Dispose of the underlying InferenceClient."""
        await self._client.dispose()

    async def _retry(self, label: str, coro_fn: Any) -> Any:
        """Run coro_fn() with exponential-backoff retry on transient transport errors.

        coro_fn must be a zero-argument async callable that is safe to call
        multiple times (i.e. creates a fresh coroutine each invocation and
        does not repeat non-idempotent side effects).

        Args:
            label: Short description used in retry log messages.
            coro_fn: Zero-argument callable returning an awaitable.

        Returns:
            Whatever the awaited coroutine returns.

        Raises:
            The last transient exception once all attempts are exhausted;
            non-transient exceptions propagate immediately.
        """
        attempts = self._retry_attempts + 1
        for attempt in range(attempts):
            try:
                return await coro_fn()
            except _TRANSIENT_EXCEPTIONS as exc:
                if attempt + 1 >= attempts:
                    # Out of attempts: re-raise with the original traceback.
                    raise
                backoff = 2 ** attempt
                logger.warning(
                    f"Transport error on {label} (attempt {attempt + 1}/{attempts}): "
                    f"{type(exc).__name__}: {exc!r}. Retrying in {backoff}s."
                )
                await asyncio.sleep(backoff)
        # attempts >= 1 (retry_attempts is clamped), so the loop always
        # returns or raises; this line only satisfies static analysis.
        raise RuntimeError(f"retry loop for {label} exited without a result")

    async def _chat(
        self,
        task: str,
        messages: list[dict[str, str]],
        max_tokens: int,
        context: dict | None = None,
    ) -> str:
        """POST a chat completion for a registry task and return the reply text.

        Returns "" when the response carries no choices or no content.
        HTTP error statuses raise ``httpx.HTTPStatusError`` (not retried).
        """
        # Routes through /v1/chat/completions (not the InferenceClient queue
        # at /v1/requests, which strips `task`). The OpenAI-compatible proxy
        # resolves `task` via the coordinator's TaskRegistry. `context=` and
        # `keep_alive=` are queue-only and intentionally not preserved here —
        # see model_boss.client.InferenceClient.chat docstring.
        import httpx

        coordinator_url = self._client._coordinator_url
        body: dict[str, Any] = {
            "task": task,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": self._temperature,
            "x_client_id": self._client_id,
            "x_keep_alive": 300,
        }
        # "normal" is the default, so only non-default priorities are sent.
        priority = self._client._default_priority
        if priority != "normal":
            body["x_priority"] = priority
        if context:
            body["x_context"] = context

        async def _post() -> str:
            # Fresh AsyncClient per attempt so a retry never reuses a broken connection.
            async with httpx.AsyncClient(timeout=self._timeout) as http:
                resp = await http.post(
                    f"{coordinator_url}/v1/chat/completions",
                    json=body,
                )
                resp.raise_for_status()
                data = resp.json()
                choices = data.get("choices", [])
                if not choices:
                    return ""
                return choices[0].get("message", {}).get("content", "") or ""

        return await self._retry(task, _post)

    async def analyze_commit(
        self, prompt: str, system_prompt: str, max_tokens: int = 512, repo_name: str = "",
    ) -> str:
        """Run a commit analysis prompt on the reasoning task and return the text."""
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]
        return await self._chat(self._REASONING_TASK, messages, max_tokens, context={
            "label": f"Analyze: {repo_name}" if repo_name else "Analyze commit",
            "tags": ["stage:analyze"],
        })

    async def group_files(
        self, files: list[str], diff_summary: str, repo_name: str = "", branch: str = "main",
    ) -> str:
        """Group files — standalone call without pipeline reservation."""
        from .prompts import build_grouping_prompt, GROUPING_SYSTEM_PROMPT

        prompt = build_grouping_prompt(files, diff_summary, repo_name, branch)
        messages = [
            {"role": "system", "content": GROUPING_SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ]
        return await self._chat(self._REASONING_TASK, messages, 2000)

    async def group_files_pipeline(
        self, files: list[str], diff_summary: str, repo_name: str = "", branch: str = "main",
    ) -> tuple[str, PipelineHandle]:
        """Group files as step 0 of a pipeline, reserving step 1 for analyze.

        Returns (group_result_text, pipeline_handle). The caller uses the
        handle to fill step 1 with the analysis prompt once grouping is done.
        """
        from .prompts import build_grouping_prompt, GROUPING_SYSTEM_PROMPT

        prompt = build_grouping_prompt(files, diff_summary, repo_name, branch)
        messages = [
            {"role": "system", "content": GROUPING_SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ]

        # submit_pipeline is not safe to repeat: each call submits a new
        # two-step pipeline with its own reserved step 1. Submit once, outside
        # the retry boundary, and retry only the wait — the same split that
        # analyze_commit_via_pipeline uses for fill_step / wait_step.
        handle = await self._client.submit_pipeline(
            model=self._reasoning_model_id,
            messages=messages,
            steps=2,
            max_tokens=2000,
            temperature=self._temperature,
            keep_alive=300,
            context={
                "label": f"Group files: {repo_name or 'unknown'}",
                "description": f"{len(files)} changed files → semantic grouping",
                "tags": [f"repo:{repo_name}", "stage:group"],
            },
        )
        result = await self._retry(
            f"group_files_pipeline:{repo_name or 'unknown'}",
            lambda: handle.wait_step(0, timeout=self._timeout),
        )
        return self._extract_text(result), handle

    async def analyze_commit_via_pipeline(
        self, handle: PipelineHandle, prompt: str, system_prompt: str, max_tokens: int = 512,
    ) -> str:
        """Fill pipeline step 1 with an analysis prompt and wait for result."""
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]

        # fill_step is not idempotent — submit once, outside the retry boundary.
        # Only wait_step is safe to retry (re-polling an already-queued step).
        await self._client.fill_step(
            handle.pipeline_id,
            step=1,
            model=self._reasoning_model_id,
            messages=messages,
            max_tokens=max_tokens,
            temperature=self._temperature,
        )
        result = await self._retry(
            f"analyze_via_pipeline:{handle.pipeline_id}",
            lambda: handle.wait_step(1, timeout=self._timeout),
        )
        return self._extract_text(result)

    @staticmethod
    def _extract_text(result: dict) -> str:
        """Extract response text from a pubsub result envelope.

        The payload is taken from ``result["result"]`` when present, otherwise
        ``result`` itself. A chat-completion-shaped dict yields the first
        choice's message content ("" when choices are empty or content is
        null, matching _chat()). Any other truthy payload is stringified;
        falsy payloads give "".
        """
        inner = result.get("result", result)
        if isinstance(inner, dict) and "choices" in inner:
            choices = inner["choices"] or []
            if not choices:
                return ""
            msg = choices[0].get("message") or {}
            return msg.get("content") or ""
        return str(inner) if inner else ""

    async def format_commit_message(
        self, prompt: str, system_prompt: str, max_tokens: int = 150, repo_name: str = "",
    ) -> str:
        """Format a commit message via the instruct task and return the text."""
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]
        return await self._chat(self._INSTRUCT_TASK, messages, max_tokens, context={
            "label": f"Format: {repo_name}" if repo_name else "Format commit msg",
            "tags": ["stage:format"],
        })

    async def __aenter__(self) -> MultiModelLlamaClient:
        await self._client.connect()
        return self

    async def __aexit__(self, *args: Any) -> None:
        await self.close()
|