"""Pydantic request/response models for the inference API.""" from __future__ import annotations from datetime import datetime from typing import Literal from pydantic import BaseModel, Field, field_validator from content_moderation_feedback.categories import CATEGORY_SET ContentType = Literal["message", "bio", "listing", "review", "coop_description"] Severity = Literal["critical", "high", "medium", "low", "none"] Action = Literal["allow", "warn", "soft_block", "hard_block", "age_gate", "payment_route"] FeedbackType = Literal["false_positive", "false_negative"] # ── Classify ───────────────────────────────────────────────────────────────── class ClassifyRequest(BaseModel): text: str = Field(..., min_length=1, max_length=8192) content_type: ContentType context_prefix: str | None = Field( default=None, description="Override the auto-generated context prefix.", ) class ClassifyResponse(BaseModel): scores: dict[str, float] = Field( description="Probability score per category (only categories >= 0.01 are included)." ) flagged_categories: list[str] = Field( description="Categories whose probability exceeds their threshold." ) severity: Severity action: Action model_version: str # ── Batch classify ──────────────────────────────────────────────────────────── class BatchClassifyRequest(BaseModel): items: list[ClassifyRequest] = Field(..., min_length=1, max_length=64) class BatchClassifyResponse(BaseModel): results: list[ClassifyResponse] # ── Feedback ────────────────────────────────────────────────────────────────── class FeedbackRequest(BaseModel): text: str = Field(..., min_length=1, max_length=8192) expected_categories: list[str] actual_categories: list[str] feedback_type: FeedbackType notes: str | None = None @field_validator("expected_categories", "actual_categories") @classmethod def _validate_categories(cls, cats: list[str]) -> list[str]: unknown = [c for c in cats if c not in CATEGORY_SET] if unknown: raise ValueError(f"Unknown categories: {unknown}") return cats class FeedbackResponse(BaseModel): status: Literal["ok"] = "ok" feedback_id: str feedback_type: str # ── Model info ──────────────────────────────────────────────────────────────── class ModelInfoResponse(BaseModel): version: str categories: list[str] thresholds: dict[str, float] loaded_at: str # ── Model reload ────────────────────────────────────────────────────────────── class ModelReloadResponse(BaseModel): previous_version: str new_version: str # ── Health ──────────────────────────────────────────────────────────────────── class HealthResponse(BaseModel): status: Literal["ok", "degraded"] model_loaded: bool model_version: str | None uptime_seconds: float