refactor(llm): ♻️ Introduce task-based routing logic in MultiModelLlamaClient for dynamic model selection
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
a1b4748f45
commit
c59be053f0
1 changed files with 47 additions and 16 deletions
|
|
@ -50,6 +50,13 @@ class MultiModelLlamaClient:
|
||||||
interaction needed.
|
interaction needed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Map task buckets to their model-boss TaskRegistry names.
|
||||||
|
# Why: per p1-74, opt these consumers into central recommendations
|
||||||
|
# instead of pinning to a model id. Current resolutions (per tasks.yaml)
|
||||||
|
# match the historical model ids — flipping is a separate operator action.
|
||||||
|
_REASONING_TASK = "code.agent"
|
||||||
|
_INSTRUCT_TASK = "summarization.short"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
reasoning_model_id: str = "ministral-14b-reasoning",
|
reasoning_model_id: str = "ministral-14b-reasoning",
|
||||||
|
|
@ -207,23 +214,47 @@ class MultiModelLlamaClient:
|
||||||
|
|
||||||
async def _chat(
|
async def _chat(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
task: str,
|
||||||
messages: list[dict[str, str]],
|
messages: list[dict[str, str]],
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
context: dict | None = None,
|
context: dict | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
result = await self._retry(
|
# Routes through /v1/chat/completions (not the InferenceClient queue
|
||||||
model_id,
|
# at /v1/requests, which strips `task`). The OpenAI-compatible proxy
|
||||||
lambda: self._client.chat(
|
# resolves `task` via the coordinator's TaskRegistry. `context=` and
|
||||||
model=model_id,
|
# `keep_alive=` are queue-only and intentionally not preserved here —
|
||||||
messages=messages,
|
# see model_boss.client.InferenceClient.chat docstring.
|
||||||
max_tokens=max_tokens,
|
import httpx
|
||||||
temperature=self._temperature,
|
|
||||||
keep_alive=300,
|
coordinator_url = self._client._coordinator_url
|
||||||
context=context,
|
body: dict[str, Any] = {
|
||||||
),
|
"task": task,
|
||||||
)
|
"messages": messages,
|
||||||
return result if isinstance(result, str) else ""
|
"max_tokens": max_tokens,
|
||||||
|
"temperature": self._temperature,
|
||||||
|
"x_client_id": self._client_id,
|
||||||
|
"x_keep_alive": 300,
|
||||||
|
}
|
||||||
|
priority = self._client._default_priority
|
||||||
|
if priority != "normal":
|
||||||
|
body["x_priority"] = priority
|
||||||
|
if context:
|
||||||
|
body["x_context"] = context
|
||||||
|
|
||||||
|
async def _post() -> str:
|
||||||
|
async with httpx.AsyncClient(timeout=self._timeout) as http:
|
||||||
|
resp = await http.post(
|
||||||
|
f"{coordinator_url}/v1/chat/completions",
|
||||||
|
json=body,
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
choices = data.get("choices", [])
|
||||||
|
if not choices:
|
||||||
|
return ""
|
||||||
|
return choices[0].get("message", {}).get("content", "") or ""
|
||||||
|
|
||||||
|
return await self._retry(task, _post)
|
||||||
|
|
||||||
async def analyze_commit(
|
async def analyze_commit(
|
||||||
self, prompt: str, system_prompt: str, max_tokens: int = 512, repo_name: str = "",
|
self, prompt: str, system_prompt: str, max_tokens: int = 512, repo_name: str = "",
|
||||||
|
|
@ -232,7 +263,7 @@ class MultiModelLlamaClient:
|
||||||
{"role": "system", "content": system_prompt},
|
{"role": "system", "content": system_prompt},
|
||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": prompt},
|
||||||
]
|
]
|
||||||
return await self._chat(self._reasoning_model_id, messages, max_tokens, context={
|
return await self._chat(self._REASONING_TASK, messages, max_tokens, context={
|
||||||
"label": f"Analyze: {repo_name}" if repo_name else "Analyze commit",
|
"label": f"Analyze: {repo_name}" if repo_name else "Analyze commit",
|
||||||
"tags": ["stage:analyze"],
|
"tags": ["stage:analyze"],
|
||||||
})
|
})
|
||||||
|
|
@ -248,7 +279,7 @@ class MultiModelLlamaClient:
|
||||||
{"role": "system", "content": GROUPING_SYSTEM_PROMPT},
|
{"role": "system", "content": GROUPING_SYSTEM_PROMPT},
|
||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": prompt},
|
||||||
]
|
]
|
||||||
return await self._chat(self._reasoning_model_id, messages, 2000)
|
return await self._chat(self._REASONING_TASK, messages, 2000)
|
||||||
|
|
||||||
async def group_files_pipeline(
|
async def group_files_pipeline(
|
||||||
self, files: list[str], diff_summary: str, repo_name: str = "", branch: str = "main",
|
self, files: list[str], diff_summary: str, repo_name: str = "", branch: str = "main",
|
||||||
|
|
@ -333,7 +364,7 @@ class MultiModelLlamaClient:
|
||||||
{"role": "system", "content": system_prompt},
|
{"role": "system", "content": system_prompt},
|
||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": prompt},
|
||||||
]
|
]
|
||||||
return await self._chat(self._instruct_model_id, messages, max_tokens, context={
|
return await self._chat(self._INSTRUCT_TASK, messages, max_tokens, context={
|
||||||
"label": f"Format: {repo_name}" if repo_name else "Format commit msg",
|
"label": f"Format: {repo_name}" if repo_name else "Format commit msg",
|
||||||
"tags": ["stage:format"],
|
"tags": ["stage:format"],
|
||||||
})
|
})
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue