content-moderation/services/inference-api/models.py
2026-03-13 04:13:49 -07:00

97 lines
3.6 KiB
Python

"""Pydantic request/response models for the inference API."""
from __future__ import annotations
from datetime import datetime
from typing import Literal
from pydantic import BaseModel, Field, field_validator
from content_moderation_feedback.categories import CATEGORY_SET
# Closed vocabularies shared by the request/response models in this module.

# Kind of user content submitted for classification.
ContentType = Literal[
    "message",
    "bio",
    "listing",
    "review",
    "coop_description",
]

# Severity labels, ordered from "critical" down to "none".
Severity = Literal[
    "critical",
    "high",
    "medium",
    "low",
    "none",
]

# Moderation action labels that a classification can resolve to.
Action = Literal[
    "allow",
    "warn",
    "soft_block",
    "hard_block",
    "age_gate",
    "payment_route",
]

# Kind of misclassification being reported through the feedback endpoint.
FeedbackType = Literal[
    "false_positive",
    "false_negative",
]
# ── Classify ─────────────────────────────────────────────────────────────────
class ClassifyRequest(BaseModel):
    """A single piece of content to classify."""

    # Must be non-empty and at most 8192 characters.
    text: str = Field(..., min_length=1, max_length=8192)
    content_type: ContentType
    # This field is client-controlled and was previously unbounded, so a caller
    # could bypass the ``text`` length cap by sending the payload here instead.
    # It is capped with the same limit as ``text``.
    # NOTE(review): presumably the prefix is combined with ``text`` before
    # inference; confirm against the handler.
    context_prefix: str | None = Field(
        default=None,
        max_length=8192,
        description="Override the auto-generated context prefix.",
    )
# Result of classifying one item: per-category scores, the categories over
# threshold, the severity and action labels, and the model version used.
class ClassifyResponse(BaseModel):
    # Sparse mapping: categories scoring below 0.01 are omitted.
    scores: dict[str, float] = Field(
        description=(
            "Probability score per category (only categories >= 0.01 are included)."
        ),
    )
    flagged_categories: list[str] = Field(
        description=(
            "Categories whose probability exceeds their threshold."
        ),
    )
    severity: Severity
    action: Action
    model_version: str
# ── Batch classify ────────────────────────────────────────────────────────────
# A batch of classify requests; must contain between 1 and 64 items.
class BatchClassifyRequest(BaseModel):
    items: list[ClassifyRequest] = Field(
        ...,
        min_length=1,
        max_length=64,
    )
# Wrapper around a list of per-item classification results.
# NOTE(review): presumably results are in the same order as the request's
# ``items``; confirm in the batch handler.
class BatchClassifyResponse(BaseModel):
    results: list[ClassifyResponse]
# ── Feedback ──────────────────────────────────────────────────────────────────
class FeedbackRequest(BaseModel):
    """A report that a classification was a false positive or false negative."""

    # Same non-empty, 8192-character cap as ``ClassifyRequest.text``.
    text: str = Field(..., min_length=1, max_length=8192)
    # Every entry in both lists must be a known category (see the validator).
    expected_categories: list[str]
    actual_categories: list[str]
    feedback_type: FeedbackType
    # Free-form comment. Previously unbounded; now capped like ``text`` so a
    # single feedback request cannot carry an arbitrarily large payload.
    notes: str | None = Field(default=None, max_length=8192)

    @field_validator("expected_categories", "actual_categories")
    @classmethod
    def _validate_categories(cls, cats: list[str]) -> list[str]:
        """Reject category names that are not in ``CATEGORY_SET``.

        Args:
            cats: The submitted category list for one of the two fields.

        Returns:
            The list unchanged, if every entry is known.

        Raises:
            ValueError: Lists the unknown names in input order. Pydantic
                reports this as a validation error for the field.
        """
        unknown = [c for c in cats if c not in CATEGORY_SET]
        if unknown:
            raise ValueError(f"Unknown categories: {unknown}")
        return cats
class FeedbackResponse(BaseModel):
    """Acknowledgement for a feedback submission."""

    status: Literal["ok"] = "ok"
    feedback_id: str
    # Was a bare ``str``. It is now typed like ``FeedbackRequest.feedback_type``,
    # so the response schema advertises the same closed set of values and a
    # handler that echoes an invalid value fails validation.
    feedback_type: FeedbackType
# ── Model info ────────────────────────────────────────────────────────────────
# Metadata describing a model: its version, category names, per-category
# thresholds, and when it was loaded.
class ModelInfoResponse(BaseModel):
    version: str
    categories: list[str]
    # Presumably keyed by category name; confirm against the loader.
    thresholds: dict[str, float]
    # NOTE(review): ``datetime`` is imported by this module but unused. This
    # field may have been meant to be a ``datetime``. Confirm what callers
    # pass before changing it, because that would alter the serialized format.
    loaded_at: str
# ── Model reload ──────────────────────────────────────────────────────────────
# Model version identifiers from before and after a reload.
class ModelReloadResponse(BaseModel):
    previous_version: str
    new_version: str
# ── Health ────────────────────────────────────────────────────────────────────
# Service health snapshot.
class HealthResponse(BaseModel):
    status: Literal["ok", "degraded"]
    model_loaded: bool
    # Required but nullable. Presumably None when no model is loaded; confirm
    # in the health handler.
    model_version: str | None
    uptime_seconds: float