# Listing metadata (viewer residue, not part of the program): 318 lines, 10 KiB, Python
"""Content moderation inference API — FastAPI application entry point."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import time
|
|
from contextlib import asynccontextmanager
|
|
from typing import AsyncIterator
|
|
|
|
from fastapi import Depends, FastAPI, HTTPException, Security, status
|
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|
|
|
import inference
|
|
import policy
|
|
from config import ADMIN_TOKEN, API_HOST, API_PORT, FEEDBACK_STORE_PATH
|
|
from models import (
|
|
BatchClassifyRequest,
|
|
BatchClassifyResponse,
|
|
ClassifyRequest,
|
|
ClassifyResponse,
|
|
FeedbackRequest,
|
|
FeedbackResponse,
|
|
HealthResponse,
|
|
ModelInfoResponse,
|
|
ModelReloadResponse,
|
|
)
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Bearer-credential extractor used by the admin dependency. With
# auto_error=False a request without credentials is handed on as None
# (which _require_admin checks for) instead of being rejected here.
_bearer = HTTPBearer(auto_error=False)

# Monotonic reference point captured at import time; /health subtracts it
# from the current monotonic clock to report uptime.
_startup_time = time.monotonic()
|
|
|
|
|
|
# ── Lifespan ──────────────────────────────────────────────────────────────────
|
|
|
|
@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncIterator[None]:
    """Initialize model with coordinator VRAM lease, and always undo it.

    Startup: acquires a 500 MB lease from the model-boss coordinator, pins
    ``CUDA_VISIBLE_DEVICES`` to the leased GPU, starts a 10-second heartbeat
    task, loads the model and attaches a ``FeedbackClient`` to ``app.state``.

    Shutdown: cancels the heartbeat, releases the lease (best effort) and
    disposes the coordinator client.

    The teardown lives in a ``finally`` block so that it also runs when
    startup fails part-way (for example when model loading raises); without
    it the lease and the heartbeat task would be left behind.
    """
    import asyncio
    import os

    from model_boss.client import InferenceClient

    # Acquire GPU lease through coordinator
    logger.info("Acquiring GPU lease via model-boss coordinator...")
    client = InferenceClient(client_id="content-moderation", auto_start_services=False)

    # Both stay None until the corresponding resource really exists, so the
    # teardown below only undoes what startup actually managed to set up.
    lease_id = None
    heartbeat_task = None

    try:
        lease = await client.acquire_lease(
            model_id="moderation:content-classifier",
            vram_mb=500,  # ONNX text classifier is small
            priority="normal",
        )
        gpu_index = lease["gpu_index"]
        lease_id = lease["lease_id"]
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
        logger.info("GPU lease acquired: 500MB on GPU %d (lease=%s)", gpu_index, lease_id)

        # Start heartbeat
        async def _heartbeat() -> None:
            """Ping the coordinator every 10 seconds until cancelled."""
            while True:
                try:
                    await asyncio.sleep(10.0)
                    await client.heartbeat(lease_id)
                except asyncio.CancelledError:
                    break
                except Exception as exc:
                    # A failed ping is logged and retried on the next tick.
                    logger.warning("Heartbeat failed: %s", exc)

        heartbeat_task = asyncio.create_task(_heartbeat())

        # Load model (now CUDA_VISIBLE_DEVICES is set)
        # NOTE(review): initialize() is a synchronous call, so the heartbeat
        # task cannot run while it executes — confirm the load time stays
        # below whatever lease timeout the coordinator applies.
        logger.info("Initializing model...")
        inference.initialize()

        from content_moderation_feedback import FeedbackClient

        app.state.feedback_client = FeedbackClient(
            store_path=FEEDBACK_STORE_PATH,
            model_version=inference.get_state().version,
        )
        logger.info("Startup complete. Listening on %s:%d", API_HOST, API_PORT)

        yield
    finally:
        # Cleanup
        if heartbeat_task is not None:
            heartbeat_task.cancel()
            try:
                await heartbeat_task
            except asyncio.CancelledError:
                pass

        if lease_id is not None:
            try:
                await client.release_lease(lease_id)
                logger.info("GPU lease released: %s", lease_id)
            except Exception as exc:
                # Best effort: a failed release must not block shutdown.
                logger.warning("Failed to release lease: %s", exc)

        await client.dispose()
        logger.info("Shutting down.")
|
|
|
|
|
|
# ── App ───────────────────────────────────────────────────────────────────────
|
|
|
|
# The ASGI application object; startup and shutdown work is delegated to
# the ``lifespan`` context manager.
app = FastAPI(
    lifespan=lifespan,
    title="Content Moderation Inference API",
    version="1.0.0",
)
|
|
|
|
|
|
# ── Auth ──────────────────────────────────────────────────────────────────────
|
|
|
|
def _require_admin(
    credentials: HTTPAuthorizationCredentials | None = Security(_bearer),
) -> None:
    """Dependency that enforces the admin token for privileged endpoints.

    Args:
        credentials: Bearer credentials extracted by ``_bearer``, or ``None``
            when the request carried none.

    Raises:
        HTTPException: 503 when ``ADMIN_TOKEN`` is empty/unset (the endpoint
            is treated as disabled); 401 when the credentials are missing or
            do not match the configured token.
    """
    # Function-scope import keeps this block self-contained.
    import secrets

    if not ADMIN_TOKEN:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Admin token not configured — endpoint disabled.",
        )

    # Missing credentials are compared as an empty string; ADMIN_TOKEN is
    # known to be non-empty here, so that can never match.
    supplied = credentials.credentials if credentials is not None else ""

    # Constant-time comparison so response timing does not reveal how much of
    # the token was guessed correctly. Both sides are encoded to bytes because
    # compare_digest rejects non-ASCII str arguments.
    # Assumes ADMIN_TOKEN is a str — TODO confirm against config.
    token_matches = secrets.compare_digest(
        supplied.encode("utf-8"),
        ADMIN_TOKEN.encode("utf-8"),
    )
    if not token_matches:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing admin token.",
            # A 401 response should advertise the expected auth scheme.
            headers={"WWW-Authenticate": "Bearer"},
        )
|
|
|
|
|
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
|
|
|
def _classify_one(req: ClassifyRequest) -> ClassifyResponse:
    """Classify a single request against the currently loaded model.

    The text is scored with a context prefix prepended. The response's
    ``scores`` keeps only categories with probability >= 0.01, while flagging
    and policy evaluation work on the unfiltered scores. A category without a
    configured threshold is flagged at 0.5.
    """
    model_state = inference.get_state()
    prefix = policy.build_context_prefix(req.content_type, req.context_prefix)
    probabilities = inference.predict(f"{prefix} {req.text}")

    # Categories worth reporting back to the caller.
    reportable = {}
    for category, probability in probabilities.items():
        if probability >= 0.01:
            reportable[category] = probability

    # Categories whose probability reaches their threshold.
    over_threshold = []
    for category, probability in probabilities.items():
        if probability >= model_state.thresholds.get(category, 0.5):
            over_threshold.append(category)

    severity, action = policy.evaluate(
        scores=probabilities,
        thresholds=model_state.thresholds,
        flagged_categories=over_threshold,
    )

    return ClassifyResponse(
        scores=reportable,
        flagged_categories=over_threshold,
        severity=severity,
        action=action,
        model_version=model_state.version,
    )
|
|
|
|
|
|
# ── Endpoints ─────────────────────────────────────────────────────────────────
|
|
|
|
# POST /api/v1/classify — single-text classification; all of the work is
# done by _classify_one.
@app.post(
    "/api/v1/classify",
    summary="Classify a single text",
    response_model=ClassifyResponse,
)
async def classify(req: ClassifyRequest) -> ClassifyResponse:
    verdict = _classify_one(req)
    return verdict
|
|
|
|
|
|
# POST /api/v1/classify/batch — scores every item with one predict_batch
# call, then applies the same filtering/flagging/policy steps per item.
@app.post(
    "/api/v1/classify/batch",
    summary="Classify up to 64 texts in one request",
    response_model=BatchClassifyResponse,
)
async def classify_batch(req: BatchClassifyRequest) -> BatchClassifyResponse:
    model_state = inference.get_state()

    def _with_prefix(entry) -> str:
        # Prepend the policy-built context prefix to one item's text.
        prefix = policy.build_context_prefix(entry.content_type, entry.context_prefix)
        return f"{prefix} {entry.text}"

    def _to_response(probabilities) -> ClassifyResponse:
        # Turn one item's raw scores into its response object.
        reportable = {}
        for category, probability in probabilities.items():
            if probability >= 0.01:
                reportable[category] = probability

        over_threshold = []
        for category, probability in probabilities.items():
            if probability >= model_state.thresholds.get(category, 0.5):
                over_threshold.append(category)

        severity, action = policy.evaluate(
            scores=probabilities,
            thresholds=model_state.thresholds,
            flagged_categories=over_threshold,
        )
        return ClassifyResponse(
            scores=reportable,
            flagged_categories=over_threshold,
            severity=severity,
            action=action,
            model_version=model_state.version,
        )

    # Build full texts for batch inference
    prefixed_texts: list[str] = [_with_prefix(entry) for entry in req.items]
    batch_scores = inference.predict_batch(prefixed_texts)

    responses: list[ClassifyResponse] = [
        _to_response(probabilities) for probabilities in batch_scores
    ]
    return BatchClassifyResponse(results=responses)
|
|
|
|
|
|
# POST /api/v1/feedback — records a misclassification report through the
# feedback client stored on app.state.
@app.post(
    "/api/v1/feedback",
    summary="Submit a misclassification correction",
    response_model=FeedbackResponse,
)
async def submit_feedback(req: FeedbackRequest) -> FeedbackResponse:
    feedback_client = app.state.feedback_client
    model_state = inference.get_state()

    # Score the text again so the stored record reflects what the model
    # outputs at submission time.
    # NOTE(review): the text is scored as-is here, without the context prefix
    # that the classify endpoints prepend — confirm this is intended.
    raw = inference.predict(req.text)

    probabilities = {}
    for category, probability in raw.items():
        if probability >= 0.01:
            probabilities[category] = probability

    thresholds = {}
    for category in raw:
        thresholds[category] = model_state.thresholds.get(category, 0.5)

    # Anything other than "false_positive" is reported as a false negative.
    if req.feedback_type == "false_positive":
        report = feedback_client.report_false_positive
        categories = req.actual_categories
    else:
        report = feedback_client.report_false_negative
        categories = req.expected_categories

    record = report(
        text=req.text,
        probabilities=probabilities,
        thresholds=thresholds,
        categories=categories,
        reason=req.notes or "",
        source="inference-api",
    )

    return FeedbackResponse(
        feedback_id=record.id,
        feedback_type=record.feedback_type.value,
    )
|
|
|
|
|
|
# GET /api/v1/model/info — reports version, categories, thresholds and load
# time of the currently loaded model state.
@app.get(
    "/api/v1/model/info",
    summary="Current model metadata",
    response_model=ModelInfoResponse,
)
async def model_info() -> ModelInfoResponse:
    current = inference.get_state()
    info = ModelInfoResponse(
        version=current.version,
        categories=list(current.categories),
        thresholds=current.thresholds,
        loaded_at=current.loaded_at,
    )
    return info
|
|
|
|
|
|
# POST /api/v1/model/reload — admin-only (guarded by _require_admin);
# reloads the model and swaps in a feedback client for the new version.
@app.post(
    "/api/v1/model/reload",
    summary="Hot-reload model from disk (admin only)",
    response_model=ModelReloadResponse,
    dependencies=[Depends(_require_admin)],
)
async def model_reload() -> ModelReloadResponse:
    versions = inference.reload()
    old_version, current_version = versions

    # Replace the feedback client so it is built with the new model version.
    from content_moderation_feedback import FeedbackClient

    refreshed_client = FeedbackClient(
        store_path=FEEDBACK_STORE_PATH,
        model_version=current_version,
    )
    app.state.feedback_client = refreshed_client

    return ModelReloadResponse(
        previous_version=old_version,
        new_version=current_version,
    )
|
|
|
|
|
|
# GET /health — "ok" with the model version when the model state can be
# read; "degraded" when reading it raises RuntimeError.
@app.get(
    "/health",
    summary="Health check with model load status",
    response_model=HealthResponse,
)
async def health() -> HealthResponse:
    # Seconds since this module was imported, rounded to two decimals.
    uptime_seconds = round(time.monotonic() - _startup_time, 2)

    try:
        current = inference.get_state()
        return HealthResponse(
            status="ok",
            model_loaded=True,
            model_version=current.version,
            uptime_seconds=uptime_seconds,
        )
    except RuntimeError:
        return HealthResponse(
            status="degraded",
            model_loaded=False,
            model_version=None,
            uptime_seconds=uptime_seconds,
        )
|
|
|
|
|
|
# ── Dev entrypoint ────────────────────────────────────────────────────────────
|
|
|
|
if __name__ == "__main__":
    # Development entry point: configure root logging, then serve the app
    # with uvicorn on the configured host/port, auto-reload disabled.
    import uvicorn

    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(name)s %(message)s",
        level=logging.INFO,
    )
    uvicorn.run("app:app", host=API_HOST, port=API_PORT, reload=False)
|