content-moderation/services/inference-api/app.py
Claude Code cc76075ef8 feat(inference-api): Add new inference endpoint or refactor inference logic in app.py
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
2026-03-25 06:42:41 -07:00

318 lines
10 KiB
Python

"""Content moderation inference API — FastAPI application entry point."""
from __future__ import annotations
import logging
import time
from contextlib import asynccontextmanager
from typing import AsyncIterator
from fastapi import Depends, FastAPI, HTTPException, Security, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
import inference
import policy
from config import ADMIN_TOKEN, API_HOST, API_PORT, FEEDBACK_STORE_PATH
from models import (
BatchClassifyRequest,
BatchClassifyResponse,
ClassifyRequest,
ClassifyResponse,
FeedbackRequest,
FeedbackResponse,
HealthResponse,
ModelInfoResponse,
ModelReloadResponse,
)
logger = logging.getLogger(__name__)
# auto_error=False: when no bearer credentials are sent, the dependency yields
# None instead of FastAPI rejecting the request itself, so _require_admin
# decides which error to return.
_bearer = HTTPBearer(auto_error=False)
# Monotonic timestamp taken when this module is imported; /health reports
# uptime relative to it (so it includes lease acquisition and model loading).
_startup_time: float = time.monotonic()
# ── Lifespan ──────────────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncIterator[None]:
    """Hold a coordinator GPU lease and the loaded model for the app's lifetime.

    Startup acquires a VRAM lease from the model-boss coordinator and
    restricts this process to the leased GPU via ``CUDA_VISIBLE_DEVICES``.
    It then starts a background task that heartbeats the lease every
    10 seconds, loads the model, and attaches a ``FeedbackClient`` to
    ``app.state``.

    Teardown cancels the heartbeat, releases the lease (if one was acquired)
    and disposes the coordinator client. Teardown runs in a ``finally`` block,
    so it also happens when startup fails part-way, for example when
    ``inference.initialize()`` raises. A failure to release the lease is
    logged, not raised.
    """
    import asyncio
    import os

    from model_boss.client import InferenceClient

    vram_mb = 500  # ONNX text classifier is small
    heartbeat_interval_s = 10.0

    logger.info("Acquiring GPU lease via model-boss coordinator...")
    client = InferenceClient(client_id="content-moderation", auto_start_services=False)
    # Track what has actually been acquired so teardown only undoes that.
    lease_id = None
    heartbeat_task = None

    async def _heartbeat(active_lease_id) -> None:
        # Best-effort keepalive: errors are logged and retried on the next
        # tick. Cancellation (from teardown) ends the loop cleanly.
        while True:
            try:
                await asyncio.sleep(heartbeat_interval_s)
                await client.heartbeat(active_lease_id)
            except asyncio.CancelledError:
                break
            except Exception as exc:
                logger.warning("Heartbeat failed: %s", exc)

    try:
        lease = await client.acquire_lease(
            model_id="moderation:content-classifier",
            vram_mb=vram_mb,
            priority="normal",
        )
        gpu_index = lease["gpu_index"]
        lease_id = lease["lease_id"]
        # Pin the process to the leased GPU before the model is loaded below.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
        logger.info(
            "GPU lease acquired: %dMB on GPU %d (lease=%s)", vram_mb, gpu_index, lease_id
        )
        heartbeat_task = asyncio.create_task(_heartbeat(lease_id))

        # Load model (CUDA_VISIBLE_DEVICES is set at this point).
        logger.info("Initializing model...")
        inference.initialize()
        from content_moderation_feedback import FeedbackClient

        app.state.feedback_client = FeedbackClient(
            store_path=FEEDBACK_STORE_PATH,
            model_version=inference.get_state().version,
        )
        logger.info("Startup complete. Listening on %s:%d", API_HOST, API_PORT)
        yield
    finally:
        if heartbeat_task is not None:
            heartbeat_task.cancel()
            try:
                await heartbeat_task
            except asyncio.CancelledError:
                # Raised e.g. when the task was cancelled before it started.
                pass
        if lease_id is not None:
            try:
                await client.release_lease(lease_id)
                logger.info("GPU lease released: %s", lease_id)
            except Exception as exc:
                logger.warning("Failed to release lease: %s", exc)
        await client.dispose()
        logger.info("Shutting down.")
# ── App ───────────────────────────────────────────────────────────────────────
# The ASGI application. ``lifespan`` acquires the GPU lease and loads the
# model before requests are served, and releases them on shutdown.
app = FastAPI(
    title="Content Moderation Inference API",
    version="1.0.0",
    lifespan=lifespan,
)
# ── Auth ──────────────────────────────────────────────────────────────────────
def _require_admin(
    credentials: HTTPAuthorizationCredentials | None = Security(_bearer),
) -> None:
    """Enforce the admin bearer token on privileged endpoints.

    Used as a route dependency. It returns nothing and signals failure only
    by raising.

    Raises:
        HTTPException: 503 when ``ADMIN_TOKEN`` is empty/unset (privileged
            endpoints are disabled); 401 with ``WWW-Authenticate: Bearer``
            when the bearer token is missing or does not match.
    """
    import secrets

    if not ADMIN_TOKEN:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Admin token not configured — endpoint disabled.",
        )
    # Constant-time comparison so response timing does not reveal how much of
    # the token matched. Comparing UTF-8 bytes (not str) avoids the TypeError
    # compare_digest raises for non-ASCII str input.
    if credentials is None or not secrets.compare_digest(
        credentials.credentials.encode("utf-8"),
        ADMIN_TOKEN.encode("utf-8"),
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing admin token.",
            headers={"WWW-Authenticate": "Bearer"},
        )
# ── Helpers ───────────────────────────────────────────────────────────────────
def _classify_one(req: ClassifyRequest) -> ClassifyResponse:
    """Score one request with the currently loaded model and apply policy.

    The text is prefixed with ``policy.build_context_prefix(...)`` (joined by
    a single space) before inference. The response's ``scores`` omit
    categories whose probability is below 0.01. A category is flagged when its
    probability reaches its threshold from the model state, or 0.5 when the
    state has none for it. ``policy.evaluate`` always receives the full,
    unfiltered probability map.
    """
    model_state = inference.get_state()
    context = policy.build_context_prefix(req.content_type, req.context_prefix)
    probabilities = inference.predict(f"{context} {req.text}")

    reported = {}
    over_threshold = []
    for category, probability in probabilities.items():
        if probability >= 0.01:
            reported[category] = probability
        if probability >= model_state.thresholds.get(category, 0.5):
            over_threshold.append(category)

    severity, action = policy.evaluate(
        scores=probabilities,
        thresholds=model_state.thresholds,
        flagged_categories=over_threshold,
    )
    return ClassifyResponse(
        scores=reported,
        flagged_categories=over_threshold,
        severity=severity,
        action=action,
        model_version=model_state.version,
    )
# ── Endpoints ─────────────────────────────────────────────────────────────────
@app.post(
    "/api/v1/classify",
    response_model=ClassifyResponse,
    summary="Classify a single text",
)
async def classify(req: ClassifyRequest) -> ClassifyResponse:
    # Thin async wrapper: prefixing, scoring and policy all live in
    # _classify_one. (Comment rather than docstring so the OpenAPI
    # description stays unchanged.)
    response = _classify_one(req)
    return response
@app.post(
    "/api/v1/classify/batch",
    response_model=BatchClassifyResponse,
    summary="Classify up to 64 texts in one request",
)
async def classify_batch(req: BatchClassifyRequest) -> BatchClassifyResponse:
    # Batch counterpart of _classify_one. It applies the same context
    # prefixing, 0.01 score floor and per-category thresholds (default 0.5),
    # but makes a single inference.predict_batch call. One response is built
    # per prediction returned, in that order.
    # NOTE(review): predict_batch is called synchronously inside this async
    # handler, so the event loop is blocked while a batch is scored.
    model_state = inference.get_state()

    def _prefixed(item) -> str:
        context = policy.build_context_prefix(item.content_type, item.context_prefix)
        return f"{context} {item.text}"

    def _respond(probabilities) -> ClassifyResponse:
        reported = {
            category: probability
            for category, probability in probabilities.items()
            if probability >= 0.01
        }
        over_threshold = [
            category
            for category, probability in probabilities.items()
            if probability >= model_state.thresholds.get(category, 0.5)
        ]
        severity, action = policy.evaluate(
            scores=probabilities,
            thresholds=model_state.thresholds,
            flagged_categories=over_threshold,
        )
        return ClassifyResponse(
            scores=reported,
            flagged_categories=over_threshold,
            severity=severity,
            action=action,
            model_version=model_state.version,
        )

    texts = [_prefixed(item) for item in req.items]
    predictions = inference.predict_batch(texts)
    return BatchClassifyResponse(results=[_respond(p) for p in predictions])
@app.post(
    "/api/v1/feedback",
    response_model=FeedbackResponse,
    summary="Submit a misclassification correction",
)
async def submit_feedback(req: FeedbackRequest) -> FeedbackResponse:
    # Record a reported misclassification through the FeedbackClient stored on
    # app.state (created in lifespan, rebuilt by /api/v1/model/reload).
    feedback = app.state.feedback_client
    model_state = inference.get_state()

    # Re-score the text now so the stored record carries the current model
    # output. NOTE(review): unlike the classify endpoints, this scores the bare
    # text with no policy context prefix, so these probabilities may differ
    # from what /api/v1/classify returned for the same input. Confirm this is
    # intended.
    probabilities = inference.predict(req.text)
    reported = {cat: p for cat, p in probabilities.items() if p >= 0.01}
    per_category_thresholds = {
        cat: model_state.thresholds.get(cat, 0.5) for cat in probabilities
    }

    # Any feedback_type other than "false_positive" is filed as a false
    # negative.
    if req.feedback_type == "false_positive":
        report = feedback.report_false_positive
        categories = req.actual_categories
    else:
        report = feedback.report_false_negative
        categories = req.expected_categories

    record = report(
        text=req.text,
        probabilities=reported,
        thresholds=per_category_thresholds,
        categories=categories,
        reason=req.notes or "",
        source="inference-api",
    )
    return FeedbackResponse(
        feedback_id=record.id,
        feedback_type=record.feedback_type.value,
    )
@app.get(
    "/api/v1/model/info",
    response_model=ModelInfoResponse,
    summary="Current model metadata",
)
async def model_info() -> ModelInfoResponse:
    # Report the loaded model's version, category list, per-category
    # thresholds and load timestamp, as held by inference.get_state().
    current = inference.get_state()
    category_list = list(current.categories)
    return ModelInfoResponse(
        version=current.version,
        categories=category_list,
        thresholds=current.thresholds,
        loaded_at=current.loaded_at,
    )
@app.post(
    "/api/v1/model/reload",
    response_model=ModelReloadResponse,
    summary="Hot-reload model from disk (admin only)",
    dependencies=[Depends(_require_admin)],
)
async def model_reload() -> ModelReloadResponse:
    # Guarded by the _require_admin dependency. Reloads the model via
    # inference.reload(), then rebuilds the feedback client so it is
    # constructed with the newly loaded model version.
    # NOTE(review): reload() is called synchronously inside this async
    # handler, so other requests wait while it runs.
    versions = inference.reload()
    previous_version, reloaded_version = versions

    from content_moderation_feedback import FeedbackClient

    app.state.feedback_client = FeedbackClient(
        store_path=FEEDBACK_STORE_PATH,
        model_version=reloaded_version,
    )
    return ModelReloadResponse(
        previous_version=previous_version,
        new_version=reloaded_version,
    )
@app.get(
    "/health",
    response_model=HealthResponse,
    summary="Health check with model load status",
)
async def health() -> HealthResponse:
    # Reports "ok" with the loaded version when inference.get_state()
    # succeeds, and "degraded" when it raises RuntimeError (presumably: model
    # not initialized yet). No status code is set, so both cases use the
    # route default. Uptime is seconds since module import, rounded to 2 dp.
    uptime_seconds = round(time.monotonic() - _startup_time, 2)
    try:
        loaded_version = inference.get_state().version
    except RuntimeError:
        return HealthResponse(
            status="degraded",
            model_loaded=False,
            model_version=None,
            uptime_seconds=uptime_seconds,
        )
    return HealthResponse(
        status="ok",
        model_loaded=True,
        model_version=loaded_version,
        uptime_seconds=uptime_seconds,
    )
# ── Dev entrypoint ────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn

    # Local/dev run: console logging at INFO, then serve the app on the
    # configured host and port with uvicorn's auto-reload disabled.
    log_format = "%(asctime)s %(levelname)s %(name)s %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)

    server_options = {"host": API_HOST, "port": API_PORT, "reload": False}
    uvicorn.run("app:app", **server_options)