91 lines
2.4 KiB
Python
91 lines
2.4 KiB
Python
|
|
"""Minimal batch-classify HTTP API — no model-boss / GPU lease required.

Exposes the same POST /api/v1/classify/batch endpoint as the production
app.py, but loads the ONNX model directly on CPU. Intended for dev
environments where model-boss is not available.

Usage:
    CM_MODEL_DIR=../../models/v15_mpnet_full_overlap/onnx \\
        python batch_api.py
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import logging
|
||
|
|
import sys
|
||
|
|
from contextlib import asynccontextmanager
|
||
|
|
from typing import AsyncIterator
|
||
|
|
|
||
|
|
import inference
|
||
|
|
from config import API_HOST, API_PORT
|
||
|
|
from fastapi import FastAPI
|
||
|
|
from pydantic import BaseModel, Field
|
||
|
|
|
||
|
|
# Send INFO-and-above records to stdout with a timestamped, logger-named
# format. Per the stdlib docs, basicConfig is a no-op if the root logger
# already has handlers attached.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
)

# Module-scoped logger for this service.
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class BatchItem(BaseModel):
    # One piece of content to classify.
    # (Comments rather than a docstring: a pydantic class docstring would
    # become the schema "description" and alter the OpenAPI output.)

    # Text to score; empty strings are rejected by validation.
    text: str = Field(min_length=1)
    # NOTE(review): content_type is accepted but not read by classify_batch
    # in this file; confirm whether per-type handling is intended.
    content_type: str = Field(default="message")
|
||
|
|
|
||
|
|
|
||
|
|
class BatchRequest(BaseModel):
    # Request body for POST /api/v1/classify/batch.

    # Between 1 and 64 items per call; larger batches fail validation.
    items: list[BatchItem] = Field(min_length=1, max_length=64)
|
||
|
|
|
||
|
|
|
||
|
|
class ItemResult(BaseModel):
    # Outcome for one input item, in the same position as the request item.

    # category -> score as returned by inference.predict_batch
    scores: dict[str, float]
    # categories whose score met or exceeded that category's threshold
    flagged_categories: list[str]
    # "high" when anything is flagged, otherwise "none"
    severity: str
    # "block" when anything is flagged, otherwise "allow"
    action: str
|
||
|
|
|
||
|
|
|
||
|
|
class BatchResponse(BaseModel):
    # Response body: one ItemResult per request item, in request order.
    results: list[ItemResult]
|
||
|
|
|
||
|
|
|
||
|
|
@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncIterator[None]:
    """Load the ONNX model once at application startup.

    Runs ``inference.initialize()`` and logs the loaded model's version and
    category count before yielding control to the app. Nothing is torn
    down after the yield.
    """
    logger.info("Loading ONNX model...")
    inference.initialize()
    loaded = inference.get_state()
    logger.info(
        "Model loaded: %s (%d categories)",
        loaded.version,
        len(loaded.categories),
    )
    yield


# The lifespan hook above runs before the app starts serving requests.
app = FastAPI(lifespan=lifespan, title="content-moderation-batch")
|
||
|
|
|
||
|
|
|
||
|
|
# Threshold used for any category missing from the model state's thresholds.
_DEFAULT_THRESHOLD = 0.5


def _to_item_result(scores: dict[str, float], thresholds) -> ItemResult:
    """Build the flag/severity/action verdict for one item's scores.

    Args:
        scores: category -> score for a single input text.
        thresholds: per-category cutoffs; only ``.get(category, default)``
            is used (presumably a ``dict[str, float]`` — confirm in inference).

    Returns:
        An ItemResult whose flagged categories are those scoring at or above
        their threshold; any flag means severity "high" / action "block".
    """
    flagged = [
        cat
        for cat, prob in scores.items()
        if prob >= thresholds.get(cat, _DEFAULT_THRESHOLD)
    ]
    return ItemResult(
        scores=scores,
        flagged_categories=flagged,
        severity="high" if flagged else "none",
        action="block" if flagged else "allow",
    )


@app.post("/api/v1/classify/batch", response_model=BatchResponse)
def classify_batch(req: BatchRequest) -> BatchResponse:
    """Classify every item in the request and return results in request order.

    Args:
        req: validated batch of 1-64 items; only ``item.text`` is used.

    Returns:
        A BatchResponse whose ``results[i]`` corresponds to ``req.items[i]``.

    Raises:
        RuntimeError: if inference returns a different number of score maps
            than texts submitted (surfaced to the client as HTTP 500).
    """
    texts = [item.text for item in req.items]
    # Materialize so the count can be checked even if an iterator comes back.
    score_maps = list(inference.predict_batch(texts))
    # Clients match results to items by position, so a count mismatch would
    # silently attach verdicts to the wrong items; fail loudly instead.
    if len(score_maps) != len(texts):
        logger.error(
            "predict_batch returned %d results for %d inputs",
            len(score_maps),
            len(texts),
        )
        raise RuntimeError(
            f"inference.predict_batch returned {len(score_maps)} results "
            f"for {len(texts)} inputs"
        )
    # Loop-invariant: look up the thresholds mapping once per request.
    thresholds = inference.get_state().thresholds
    return BatchResponse(
        results=[_to_item_result(scores, thresholds) for scores in score_maps]
    )
|
||
|
|
|
||
|
|
|
||
|
|
@app.get("/health")
def health() -> dict[str, str]:
    """Liveness probe; always reports ok and does not consult the model."""
    payload = {"status": "ok"}
    return payload
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Imported here so uvicorn is only required when this file is run directly.
    from uvicorn import run as serve

    serve(app, host=API_HOST, port=API_PORT)
|