"""Pydantic request/response models for the inference API."""

from __future__ import annotations
|
|
|
|
from datetime import datetime
|
|
from typing import Literal
|
|
|
|
from pydantic import BaseModel, Field, field_validator
|
|
|
|
from content_moderation_feedback.categories import CATEGORY_SET
|
|
|
|
# Surface the submitted content came from (used by ClassifyRequest.content_type).
ContentType = Literal[
    "message",
    "bio",
    "listing",
    "review",
    "coop_description",
]

# Severity level reported in ClassifyResponse.severity.
Severity = Literal[
    "critical",
    "high",
    "medium",
    "low",
    "none",
]

# Action reported in ClassifyResponse.action.
Action = Literal[
    "allow",
    "warn",
    "soft_block",
    "hard_block",
    "age_gate",
    "payment_route",
]

# Kind of misclassification reported via FeedbackRequest.feedback_type.
FeedbackType = Literal["false_positive", "false_negative"]
|
|
|
|
|
|
# ── Classify ─────────────────────────────────────────────────────────────────
|
|
|
|
class ClassifyRequest(BaseModel):
    # A single piece of content submitted for classification.

    # The content itself; must be 1..8192 characters long.
    text: str = Field(..., min_length=1, max_length=8192)
    # Which surface the content came from (see ContentType).
    content_type: ContentType
    # Optional caller-supplied prefix that replaces the auto-generated one.
    # It is client-controlled and, being a prefix, presumably ends up in the
    # model input next to `text` (confirm against the handler). It therefore
    # gets the same upper bound as `text`; otherwise a client could bypass the
    # 8192-character cap by moving content into the prefix.
    context_prefix: str | None = Field(
        default=None,
        max_length=8192,
        description="Override the auto-generated context prefix.",
    )
|
|
|
|
|
|
class ClassifyResponse(BaseModel):
    # Classification outcome for one submitted item.

    # Category name -> probability. Per the field description, only
    # categories scoring at least 0.01 are included.
    scores: dict[str, float] = Field(
        ...,
        description="Probability score per category (only categories >= 0.01 are included).",
    )
    # Category names whose probability is above that category's threshold.
    flagged_categories: list[str] = Field(
        ...,
        description="Categories whose probability exceeds their threshold.",
    )
    # Overall severity level and resulting action (closed sets: see the
    # Severity and Action literals).
    severity: Severity
    action: Action
    # Identifier of the model version that produced this result.
    model_version: str
|
|
|
|
|
|
# ── Batch classify ────────────────────────────────────────────────────────────
|
|
|
|
class BatchClassifyRequest(BaseModel):
    # Several classification requests submitted in one call; between 1 and 64
    # items are accepted.
    items: list[ClassifyRequest] = Field(min_length=1, max_length=64)
|
|
|
|
|
|
class BatchClassifyResponse(BaseModel):
    # Results for a batch call, one ClassifyResponse per item (presumably in
    # the same order as the request items; confirm against the handler).
    results: list[ClassifyResponse]
|
|
|
|
|
|
# ── Feedback ──────────────────────────────────────────────────────────────────
|
|
|
|
class FeedbackRequest(BaseModel):
    # A report that some text was misclassified (false positive or negative).

    # The text in question; must be 1..8192 characters long.
    text: str = Field(min_length=1, max_length=8192)
    # Category names; both lists are validated against CATEGORY_SET below.
    expected_categories: list[str]
    actual_categories: list[str]
    feedback_type: FeedbackType
    # Optional free-form comment.
    notes: str | None = None

    @field_validator("expected_categories", "actual_categories")
    @classmethod
    def _validate_categories(cls, cats: list[str]) -> list[str]:
        # Reject any name outside CATEGORY_SET. Offenders are reported in input
        # order (repeats included); a valid list is returned unchanged.
        rejected = [name for name in cats if name not in CATEGORY_SET]
        if not rejected:
            return cats
        raise ValueError(f"Unknown categories: {rejected}")
|
|
|
|
|
|
class FeedbackResponse(BaseModel):
    # Acknowledgement for a feedback submission; `status` can only be "ok".
    status: Literal["ok"] = "ok"
    # Identifier for the submitted feedback (presumably assigned server-side).
    feedback_id: str
    # Presumably echoes FeedbackRequest.feedback_type. It is typed with the same
    # FeedbackType literal as the request rather than a bare `str`, so invalid
    # values are rejected and the response schema advertises the same closed
    # set of values.
    feedback_type: FeedbackType
|
|
|
|
|
|
# ── Model info ────────────────────────────────────────────────────────────────
|
|
|
|
class ModelInfoResponse(BaseModel):
    # Metadata about the served model: version, categories and thresholds.
    version: str
    categories: list[str]
    # Category name -> threshold (presumably the per-category flagging
    # threshold referenced by ClassifyResponse.flagged_categories).
    thresholds: dict[str, float]
    # NOTE(review): carried as a plain string. This module imports `datetime`
    # but never uses it, which hints that a datetime was intended. Confirm the
    # format producers emit before tightening this type.
    loaded_at: str
|
|
|
|
|
|
# ── Model reload ──────────────────────────────────────────────────────────────
|
|
|
|
class ModelReloadResponse(BaseModel):
    # Model version identifiers before and after a reload.
    previous_version: str
    new_version: str
|
|
|
|
|
|
# ── Health ────────────────────────────────────────────────────────────────────
|
|
|
|
class HealthResponse(BaseModel):
    # Health-check payload for the service.
    status: Literal["ok", "degraded"]
    model_loaded: bool
    # Required but nullable: it must always be supplied, possibly as None
    # (presumably None when no model is loaded; confirm).
    model_version: str | None
    uptime_seconds: float
|