refactor(llm): ♻️ Introduce task-based routing logic in MultiModelLlamaClient for dynamic model selection

Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
autocommit 2026-05-11 08:56:36 -07:00
parent a1b4748f45
commit c59be053f0

View file

@ -50,6 +50,13 @@ class MultiModelLlamaClient:
interaction needed. interaction needed.
""" """
# Map task buckets to their model-boss TaskRegistry names.
# Why: per p1-74, opt these consumers into central recommendations
# instead of pinning to a model id. Current resolutions (per tasks.yaml)
# match the historical model ids — flipping is a separate operator action.
_REASONING_TASK = "code.agent"
_INSTRUCT_TASK = "summarization.short"
def __init__( def __init__(
self, self,
reasoning_model_id: str = "ministral-14b-reasoning", reasoning_model_id: str = "ministral-14b-reasoning",
@ -207,23 +214,47 @@ class MultiModelLlamaClient:
async def _chat( async def _chat(
self, self,
model_id: str, task: str,
messages: list[dict[str, str]], messages: list[dict[str, str]],
max_tokens: int, max_tokens: int,
context: dict | None = None, context: dict | None = None,
) -> str: ) -> str:
result = await self._retry( # Routes through /v1/chat/completions (not the InferenceClient queue
model_id, # at /v1/requests, which strips `task`). The OpenAI-compatible proxy
lambda: self._client.chat( # resolves `task` via the coordinator's TaskRegistry. `context=` and
model=model_id, # `keep_alive=` are queue-only and intentionally not preserved here —
messages=messages, # see model_boss.client.InferenceClient.chat docstring.
max_tokens=max_tokens, import httpx
temperature=self._temperature,
keep_alive=300, coordinator_url = self._client._coordinator_url
context=context, body: dict[str, Any] = {
), "task": task,
) "messages": messages,
return result if isinstance(result, str) else "" "max_tokens": max_tokens,
"temperature": self._temperature,
"x_client_id": self._client_id,
"x_keep_alive": 300,
}
priority = self._client._default_priority
if priority != "normal":
body["x_priority"] = priority
if context:
body["x_context"] = context
async def _post() -> str:
async with httpx.AsyncClient(timeout=self._timeout) as http:
resp = await http.post(
f"{coordinator_url}/v1/chat/completions",
json=body,
)
resp.raise_for_status()
data = resp.json()
choices = data.get("choices", [])
if not choices:
return ""
return choices[0].get("message", {}).get("content", "") or ""
return await self._retry(task, _post)
async def analyze_commit( async def analyze_commit(
self, prompt: str, system_prompt: str, max_tokens: int = 512, repo_name: str = "", self, prompt: str, system_prompt: str, max_tokens: int = 512, repo_name: str = "",
@ -232,7 +263,7 @@ class MultiModelLlamaClient:
{"role": "system", "content": system_prompt}, {"role": "system", "content": system_prompt},
{"role": "user", "content": prompt}, {"role": "user", "content": prompt},
] ]
return await self._chat(self._reasoning_model_id, messages, max_tokens, context={ return await self._chat(self._REASONING_TASK, messages, max_tokens, context={
"label": f"Analyze: {repo_name}" if repo_name else "Analyze commit", "label": f"Analyze: {repo_name}" if repo_name else "Analyze commit",
"tags": ["stage:analyze"], "tags": ["stage:analyze"],
}) })
@ -248,7 +279,7 @@ class MultiModelLlamaClient:
{"role": "system", "content": GROUPING_SYSTEM_PROMPT}, {"role": "system", "content": GROUPING_SYSTEM_PROMPT},
{"role": "user", "content": prompt}, {"role": "user", "content": prompt},
] ]
return await self._chat(self._reasoning_model_id, messages, 2000) return await self._chat(self._REASONING_TASK, messages, 2000)
async def group_files_pipeline( async def group_files_pipeline(
self, files: list[str], diff_summary: str, repo_name: str = "", branch: str = "main", self, files: list[str], diff_summary: str, repo_name: str = "", branch: str = "main",
@ -333,7 +364,7 @@ class MultiModelLlamaClient:
{"role": "system", "content": system_prompt}, {"role": "system", "content": system_prompt},
{"role": "user", "content": prompt}, {"role": "user", "content": prompt},
] ]
return await self._chat(self._instruct_model_id, messages, max_tokens, context={ return await self._chat(self._INSTRUCT_TASK, messages, max_tokens, context={
"label": f"Format: {repo_name}" if repo_name else "Format commit msg", "label": f"Format: {repo_name}" if repo_name else "Format commit msg",
"tags": ["stage:format"], "tags": ["stage:format"],
}) })