refactor(llm): ♻️ Introduce task-based routing logic in MultiModelLlamaClient for dynamic model selection
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
a1b4748f45
commit
c59be053f0
1 changed files with 47 additions and 16 deletions
|
|
@ -50,6 +50,13 @@ class MultiModelLlamaClient:
|
|||
interaction needed.
|
||||
"""
|
||||
|
||||
# Map task buckets to their model-boss TaskRegistry names.
|
||||
# Why: per p1-74, opt these consumers into central recommendations
|
||||
# instead of pinning to a model id. Current resolutions (per tasks.yaml)
|
||||
# match the historical model ids — flipping is a separate operator action.
|
||||
_REASONING_TASK = "code.agent"
|
||||
_INSTRUCT_TASK = "summarization.short"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reasoning_model_id: str = "ministral-14b-reasoning",
|
||||
|
|
@ -207,23 +214,47 @@ class MultiModelLlamaClient:
|
|||
|
||||
async def _chat(
|
||||
self,
|
||||
model_id: str,
|
||||
task: str,
|
||||
messages: list[dict[str, str]],
|
||||
max_tokens: int,
|
||||
context: dict | None = None,
|
||||
) -> str:
|
||||
result = await self._retry(
|
||||
model_id,
|
||||
lambda: self._client.chat(
|
||||
model=model_id,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
temperature=self._temperature,
|
||||
keep_alive=300,
|
||||
context=context,
|
||||
),
|
||||
# Routes through /v1/chat/completions (not the InferenceClient queue
|
||||
# at /v1/requests, which strips `task`). The OpenAI-compatible proxy
|
||||
# resolves `task` via the coordinator's TaskRegistry. `context=` and
|
||||
# `keep_alive=` are queue-only and intentionally not preserved here —
|
||||
# see model_boss.client.InferenceClient.chat docstring.
|
||||
import httpx
|
||||
|
||||
coordinator_url = self._client._coordinator_url
|
||||
body: dict[str, Any] = {
|
||||
"task": task,
|
||||
"messages": messages,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": self._temperature,
|
||||
"x_client_id": self._client_id,
|
||||
"x_keep_alive": 300,
|
||||
}
|
||||
priority = self._client._default_priority
|
||||
if priority != "normal":
|
||||
body["x_priority"] = priority
|
||||
if context:
|
||||
body["x_context"] = context
|
||||
|
||||
async def _post() -> str:
|
||||
async with httpx.AsyncClient(timeout=self._timeout) as http:
|
||||
resp = await http.post(
|
||||
f"{coordinator_url}/v1/chat/completions",
|
||||
json=body,
|
||||
)
|
||||
return result if isinstance(result, str) else ""
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
choices = data.get("choices", [])
|
||||
if not choices:
|
||||
return ""
|
||||
return choices[0].get("message", {}).get("content", "") or ""
|
||||
|
||||
return await self._retry(task, _post)
|
||||
|
||||
async def analyze_commit(
|
||||
self, prompt: str, system_prompt: str, max_tokens: int = 512, repo_name: str = "",
|
||||
|
|
@ -232,7 +263,7 @@ class MultiModelLlamaClient:
|
|||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
return await self._chat(self._reasoning_model_id, messages, max_tokens, context={
|
||||
return await self._chat(self._REASONING_TASK, messages, max_tokens, context={
|
||||
"label": f"Analyze: {repo_name}" if repo_name else "Analyze commit",
|
||||
"tags": ["stage:analyze"],
|
||||
})
|
||||
|
|
@ -248,7 +279,7 @@ class MultiModelLlamaClient:
|
|||
{"role": "system", "content": GROUPING_SYSTEM_PROMPT},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
return await self._chat(self._reasoning_model_id, messages, 2000)
|
||||
return await self._chat(self._REASONING_TASK, messages, 2000)
|
||||
|
||||
async def group_files_pipeline(
|
||||
self, files: list[str], diff_summary: str, repo_name: str = "", branch: str = "main",
|
||||
|
|
@ -333,7 +364,7 @@ class MultiModelLlamaClient:
|
|||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
return await self._chat(self._instruct_model_id, messages, max_tokens, context={
|
||||
return await self._chat(self._INSTRUCT_TASK, messages, max_tokens, context={
|
||||
"label": f"Format: {repo_name}" if repo_name else "Format commit msg",
|
||||
"tags": ["stage:format"],
|
||||
})
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue