"""Minimal batch-classify HTTP API — no model-boss / GPU lease required.

Exposes the same ``POST /api/v1/classify/batch`` endpoint as the production
``app.py``, but loads the ONNX model directly on CPU. Intended for dev
environments where model-boss is not available.

Usage::

    CM_MODEL_DIR=../../models/v15_mpnet_full_overlap/onnx \\
        python batch_api.py
"""

from __future__ import annotations

import logging
import sys
from contextlib import asynccontextmanager
from typing import AsyncIterator

import inference
from config import API_HOST, API_PORT
from fastapi import FastAPI
from pydantic import BaseModel, Field

# Log to stdout so dev-container / terminal output captures everything.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
)
logger = logging.getLogger(__name__)


# --- Request / response schemas -------------------------------------------
# NOTE: these use `#` comments rather than class docstrings on purpose —
# pydantic copies a model's docstring into its JSON schema, which would
# change the generated OpenAPI document.


# A single text to classify. Empty strings are rejected (min_length=1).
# `content_type` is accepted but not read by classify_batch below
# (presumably kept for parity with the production schema — confirm).
class BatchItem(BaseModel):
    text: str = Field(..., min_length=1)
    content_type: str = "message"


# A batch of 1 to 64 items; size limits are enforced at validation time.
class BatchRequest(BaseModel):
    items: list[BatchItem] = Field(..., min_length=1, max_length=64)


# Verdict for one item: the raw per-category scores from
# inference.predict_batch, the categories that met their threshold,
# and the derived severity / action labels.
class ItemResult(BaseModel):
    scores: dict[str, float]
    flagged_categories: list[str]
    severity: str
    action: str


# Results in the same order as the items produced by predict_batch.
class BatchResponse(BaseModel):
    results: list[ItemResult]


# --- Application ------------------------------------------------------------


@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncIterator[None]:
    """Initialise the inference module at startup, then hand control to the app.

    Nothing is released on shutdown; the context simply yields once.
    """
    logger.info("Loading ONNX model...")
    inference.initialize()
    loaded = inference.get_state()
    logger.info(
        "Model loaded: %s (%d categories)",
        loaded.version,
        len(loaded.categories),
    )
    yield


app = FastAPI(lifespan=lifespan, title="content-moderation-batch")


# Score every item in the batch and derive a moderation verdict per item.
# A category is flagged when its score is >= that category's threshold from
# the loaded model state (0.5 when the state has no threshold for it). Any
# flag makes the item severity "high" / action "block"; otherwise the item is
# severity "none" / action "allow".
# Declared as a plain `def` (not `async def`), so FastAPI runs it in its
# threadpool rather than on the event loop.
@app.post("/api/v1/classify/batch", response_model=BatchResponse)
def classify_batch(req: BatchRequest) -> BatchResponse:
    batch_scores = inference.predict_batch([entry.text for entry in req.items])
    thresholds = inference.get_state().thresholds

    def judge(item_scores: dict[str, float]) -> ItemResult:
        # Categories whose score reaches their (possibly defaulted) threshold.
        hits = [
            name
            for name, prob in item_scores.items()
            if prob >= thresholds.get(name, 0.5)
        ]
        if hits:
            severity, action = "high", "block"
        else:
            severity, action = "none", "allow"
        return ItemResult(
            scores=item_scores,
            flagged_categories=hits,
            severity=severity,
            action=action,
        )

    return BatchResponse(results=[judge(item_scores) for item_scores in batch_scores])


# Liveness probe: always reports ok once the app is serving requests.
@app.get("/health")
def health() -> dict[str, str]:
    return {"status": "ok"}


if __name__ == "__main__":
    # Imported lazily so merely importing this module does not require uvicorn.
    import uvicorn

    uvicorn.run(app, host=API_HOST, port=API_PORT)