content-moderation/services/inference-api/batch_api.py

91 lines
2.4 KiB
Python
Raw Permalink Normal View History

"""Minimal batch-classify HTTP API — no model-boss / GPU lease required.

Exposes the same POST /api/v1/classify/batch endpoint as the production
app.py, but loads the ONNX model directly on CPU. Intended for dev
environments where model-boss is not available.

Usage:
    CM_MODEL_DIR=../../models/v15_mpnet_full_overlap/onnx \\
        python batch_api.py
"""
from __future__ import annotations
import logging
import sys
from contextlib import asynccontextmanager
from typing import AsyncIterator
import inference
from config import API_HOST, API_PORT
from fastapi import FastAPI
from pydantic import BaseModel, Field
# Log to stdout at INFO level with timestamp, level, and logger name on every line.
_LOG_FORMAT = "%(asctime)s %(levelname)s %(name)s %(message)s"
logging.basicConfig(stream=sys.stdout, format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
class BatchItem(BaseModel):
    """A single piece of content submitted for classification.

    ``text`` is required and must be non-empty (``min_length=1``). A
    whitespace-only string is not rejected. ``content_type`` defaults to
    ``"message"``; this module's ``classify_batch`` reads only ``text``.
    """

    text: str = Field(min_length=1)
    content_type: str = Field(default="message")
class BatchRequest(BaseModel):
    """Request body for ``POST /api/v1/classify/batch``.

    ``items`` is required and must hold between 1 and 64 entries inclusive;
    requests outside that range fail validation.
    """

    items: list[BatchItem] = Field(max_length=64, min_length=1)
class ItemResult(BaseModel):
    """Classification outcome for one input item."""

    # Category -> score mapping, as produced by ``inference.predict_batch``.
    scores: dict[str, float]
    # Categories whose score met or exceeded their threshold.
    flagged_categories: list[str]
    # classify_batch in this module sets this to "high" or "none".
    severity: str
    # classify_batch in this module sets this to "block" or "allow".
    action: str
class BatchResponse(BaseModel):
    """Response body for the batch classify endpoint.

    Holds one ``ItemResult`` per score map returned by the model. This
    presumably lines up 1:1, in order, with the request items; that assumes
    ``inference.predict_batch`` preserves input order and length (confirm).
    """

    results: list[ItemResult]
@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncIterator[None]:
    """FastAPI lifespan hook: load the model before the app serves requests.

    Calls ``inference.initialize()`` and logs the loaded model's version
    and category count. Nothing is released after ``yield``.
    """
    logger.info("Loading ONNX model...")
    inference.initialize()

    loaded = inference.get_state()
    version = loaded.version
    category_count = len(loaded.categories)
    logger.info("Model loaded: %s (%d categories)", version, category_count)

    yield
# ASGI application; `lifespan` loads the model once at startup.
app = FastAPI(lifespan=lifespan, title="content-moderation-batch")
@app.post("/api/v1/classify/batch", response_model=BatchResponse)
def classify_batch(req: BatchRequest) -> BatchResponse:
    """Score a batch of texts and derive a moderation decision for each.

    A category is flagged when its score is >= its threshold from the
    loaded model state. Categories with no configured threshold fall back
    to 0.5. Any flagged category yields severity "high" and action "block";
    otherwise the result is "none" / "allow". Only ``item.text`` is sent to
    the model; ``content_type`` is ignored here.
    """
    score_maps = inference.predict_batch([item.text for item in req.items])
    thresholds = inference.get_state().thresholds

    def _to_result(scores: dict[str, float]) -> ItemResult:
        hits = [name for name, p in scores.items() if p >= thresholds.get(name, 0.5)]
        if hits:
            return ItemResult(
                scores=scores, flagged_categories=hits, severity="high", action="block"
            )
        return ItemResult(
            scores=scores, flagged_categories=hits, severity="none", action="allow"
        )

    return BatchResponse(results=[_to_result(s) for s in score_maps])
@app.get("/health")
def health() -> dict[str, str]:
    """Return a static ``{"status": "ok"}``; model state is not inspected."""
    return dict(status="ok")
if __name__ == "__main__":
    # Dev entry point. uvicorn is imported only when this file runs as a
    # script, so importing the module does not import uvicorn.
    from uvicorn import run as _serve

    _serve(app, port=API_PORT, host=API_HOST)