diff --git a/src/auto_commit_service/llm/multi_model_client.py b/src/auto_commit_service/llm/multi_model_client.py index 5695515..6b2a2fb 100644 --- a/src/auto_commit_service/llm/multi_model_client.py +++ b/src/auto_commit_service/llm/multi_model_client.py @@ -50,6 +50,13 @@ class MultiModelLlamaClient: interaction needed. """ + # Map task buckets to their model-boss TaskRegistry names. + # Why: per p1-74, opt these consumers into central recommendations + # instead of pinning to a model id. Current resolutions (per tasks.yaml) + # match the historical model ids — flipping is a separate operator action. + _REASONING_TASK = "code.agent" + _INSTRUCT_TASK = "summarization.short" + def __init__( self, reasoning_model_id: str = "ministral-14b-reasoning", @@ -207,23 +214,47 @@ class MultiModelLlamaClient: async def _chat( self, - model_id: str, + task: str, messages: list[dict[str, str]], max_tokens: int, context: dict | None = None, ) -> str: - result = await self._retry( - model_id, - lambda: self._client.chat( - model=model_id, - messages=messages, - max_tokens=max_tokens, - temperature=self._temperature, - keep_alive=300, - context=context, - ), - ) - return result if isinstance(result, str) else "" + # Routes through /v1/chat/completions (not the InferenceClient queue + # at /v1/requests, which strips `task`). The OpenAI-compatible proxy + # resolves `task` via the coordinator's TaskRegistry. `context=` and + # `keep_alive=` are queue-only and intentionally not preserved here — + # see model_boss.client.InferenceClient.chat docstring. + import httpx + + coordinator_url = self._client._coordinator_url + body: dict[str, Any] = { + "task": task, + "messages": messages, + "max_tokens": max_tokens, + "temperature": self._temperature, + "x_client_id": self._client_id, + "x_keep_alive": 300, + } + priority = self._client._default_priority + if priority != "normal": + body["x_priority"] = priority + if context: + body["x_context"] = context + + async def _post() -> str: + async with httpx.AsyncClient(timeout=self._timeout) as http: + resp = await http.post( + f"{coordinator_url}/v1/chat/completions", + json=body, + ) + resp.raise_for_status() + data = resp.json() + choices = data.get("choices", []) + if not choices: + return "" + return choices[0].get("message", {}).get("content", "") or "" + + return await self._retry(task, _post) async def analyze_commit( self, prompt: str, system_prompt: str, max_tokens: int = 512, repo_name: str = "", @@ -232,7 +263,7 @@ class MultiModelLlamaClient: {"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}, ] - return await self._chat(self._reasoning_model_id, messages, max_tokens, context={ + return await self._chat(self._REASONING_TASK, messages, max_tokens, context={ "label": f"Analyze: {repo_name}" if repo_name else "Analyze commit", "tags": ["stage:analyze"], }) @@ -248,7 +279,7 @@ class MultiModelLlamaClient: {"role": "system", "content": GROUPING_SYSTEM_PROMPT}, {"role": "user", "content": prompt}, ] - return await self._chat(self._reasoning_model_id, messages, 2000) + return await self._chat(self._REASONING_TASK, messages, 2000) async def group_files_pipeline( self, files: list[str], diff_summary: str, repo_name: str = "", branch: str = "main", @@ -333,7 +364,7 @@ class MultiModelLlamaClient: {"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}, ] - return await self._chat(self._instruct_model_id, messages, max_tokens, context={ + return await self._chat(self._INSTRUCT_TASK, messages, max_tokens, context={ "label": f"Format: {repo_name}" if repo_name else "Format commit msg", "tags": ["stage:format"], })