# content-moderation/services/inference-api/inference.py
# 2026-03-13 04:13:49 -07:00
#
# 203 lines
# 7 KiB
# Python

"""ONNX model loading, tokenization, and prediction."""
from __future__ import annotations
import json
import logging
import threading
from datetime import UTC, datetime
from pathlib import Path
from typing import NamedTuple
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer
from content_moderation_feedback.categories import CATEGORIES
from config import CM_MODEL_DIR_ENV, MAX_SEQ_LENGTH, MODELS_ROOT
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# ── Model state ───────────────────────────────────────────────────────────────
class ModelState(NamedTuple):
    """Immutable bundle of everything needed to serve one loaded model.

    A fresh instance is built on every load; the module swaps the whole
    tuple rather than mutating fields in place.
    """

    # ONNX Runtime session used to execute the model.
    session: ort.InferenceSession
    # Tokenizer loaded from the same directory as the model file.
    tokenizer: AutoTokenizer
    # Per-category decision thresholds, keyed by category name.
    thresholds: dict[str, float]
    # Category names, in the order of the model's output columns.
    categories: tuple[str, ...]
    # Human-readable model version string.
    version: str
    # Directory the model and tokenizer were loaded from.
    model_dir: Path
    # Timestamp string recorded when the model finished loading.
    loaded_at: str
# Currently active model; stays None until initialize() or reload() stores one.
_state: ModelState | None = None
# Re-entrant lock guarding every read and write of _state.
_state_lock = threading.RLock()
# ── Discovery ─────────────────────────────────────────────────────────────────
def _find_latest_model_dir() -> Path:
    """Locate the newest ``<model>/onnx`` directory under ``MODELS_ROOT``.

    Every immediate sub-directory of ``MODELS_ROOT`` is inspected. A
    sub-directory qualifies when its ``onnx/`` child contains
    ``model_fp16.onnx`` or ``model.onnx``. Among the qualifying ``onnx/``
    directories, the one with the greatest modification time wins.

    Returns:
        Path to the selected ``onnx/`` directory.

    Raises:
        RuntimeError: if ``MODELS_ROOT`` does not exist or holds no
            qualifying model directory.
    """
    if not MODELS_ROOT.exists():
        raise RuntimeError(f"Models directory not found: {MODELS_ROOT}")

    model_filenames = ("model_fp16.onnx", "model.onnx")
    newest_dir: Path | None = None
    newest_mtime = 0.0

    for child in MODELS_ROOT.iterdir():
        if not child.is_dir():
            continue
        onnx_dir = child / "onnx"
        has_model = any((onnx_dir / fname).exists() for fname in model_filenames)
        if not has_model:
            continue
        mtime = onnx_dir.stat().st_mtime
        # ">=" so that, on equal timestamps, the later-visited directory wins.
        if newest_dir is None or mtime >= newest_mtime:
            newest_dir, newest_mtime = onnx_dir, mtime

    if newest_dir is None:
        raise RuntimeError(f"No ONNX models found under {MODELS_ROOT}/*/onnx/")
    return newest_dir
def _select_model_file(model_dir: Path) -> Path:
    """Pick the model file to load from *model_dir*.

    ``model_fp16.onnx`` is preferred; ``model.onnx`` is the fallback. Any
    other file in the directory is never considered.

    Args:
        model_dir: directory expected to contain the ONNX model file.

    Returns:
        Path of the first preferred file that exists.

    Raises:
        RuntimeError: if neither expected file exists in *model_dir*.
    """
    preferred_order = ("model_fp16.onnx", "model.onnx")
    chosen = next(
        (
            model_dir / fname
            for fname in preferred_order
            if (model_dir / fname).exists()
        ),
        None,
    )
    if chosen is None:
        raise RuntimeError(
            f"No compatible ONNX model in {model_dir}. "
            "Expected model_fp16.onnx or model.onnx."
        )
    return chosen
# ── Loading ───────────────────────────────────────────────────────────────────
def load_model(model_dir_override: Path | None = None) -> ModelState:
    """Load the ONNX session, tokenizer and thresholds into a fresh ModelState.

    The model directory is resolved in priority order: the explicit
    ``model_dir_override``, then ``CM_MODEL_DIR_ENV`` when it is truthy, then
    auto-discovery via ``_find_latest_model_dir()``. Module-level state is
    not touched; the caller decides what to do with the returned value.

    Args:
        model_dir_override: explicit onnx/ directory; auto-discover if None.

    Returns:
        A newly built ModelState.

    Raises:
        RuntimeError: if no model directory or no compatible model file can
            be found.
    """
    if model_dir_override is not None:
        model_dir = model_dir_override
    elif CM_MODEL_DIR_ENV:
        model_dir = Path(CM_MODEL_DIR_ENV)
    else:
        model_dir = _find_latest_model_dir()
    model_path = _select_model_file(model_dir)

    # Prefer CUDA when this onnxruntime build offers it; CPU is always listed
    # last as the fallback provider.
    providers: list[str] = []
    if "CUDAExecutionProvider" in ort.get_available_providers():
        providers.append("CUDAExecutionProvider")
    providers.append("CPUExecutionProvider")

    logger.info("Loading ONNX session from %s (providers: %s)", model_path, providers)
    session = ort.InferenceSession(str(model_path), providers=providers)

    logger.info("Loading tokenizer from %s", model_dir)
    tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(str(model_dir))

    # The second dimension of the first output is the label count. ONNX
    # Runtime reports a dynamic dimension as a string or None instead of an
    # int; slicing with a string would raise TypeError, so anything that is
    # not an int falls back to the full category list (which is what a None
    # dimension already produced).
    label_dim = session.get_outputs()[0].shape[1]
    if isinstance(label_dim, int):
        categories: tuple[str, ...] = tuple(CATEGORIES[:label_dim])
    else:
        logger.warning(
            "Model output has non-static label dimension %r; using all %d categories",
            label_dim,
            len(CATEGORIES),
        )
        categories = tuple(CATEGORIES)

    # Optional per-category thresholds; categories missing from the file
    # (or all of them, when there is no file) default to 0.5.
    thresholds: dict[str, float] = {}
    thresholds_path = model_dir / "thresholds.json"
    if thresholds_path.exists():
        # Explicit encoding: do not depend on the platform's locale default.
        with open(thresholds_path, encoding="utf-8") as fh:
            thresholds = json.load(fh)
    for cat in categories:
        thresholds.setdefault(cat, 0.5)

    # With the conventional <version>/onnx layout the version is the parent
    # directory's name; otherwise the directory's own name is used.
    version = model_dir.parent.name if model_dir.name == "onnx" else model_dir.name
    loaded_at = datetime.now(UTC).isoformat()

    logger.info(
        "Model %s loaded: %d categories, providers=%s",
        version,
        len(categories),
        session.get_providers(),
    )
    return ModelState(
        session=session,
        tokenizer=tokenizer,
        thresholds=thresholds,
        categories=categories,
        version=version,
        model_dir=model_dir,
        loaded_at=loaded_at,
    )
# ── Module-level init ─────────────────────────────────────────────────────────
def initialize() -> None:
    """Load a model and install it as the module-level state.

    Intended to be called once at startup. The load runs while the state
    lock is held.
    """
    global _state
    with _state_lock:
        fresh_state = load_model()
        _state = fresh_state
def reload() -> tuple[str, str]:
    """Hot-reload the model from disk and swap it into module-level state.

    Returns:
        ``(prev_version, new_version)``. ``prev_version`` is the empty string
        when no model had been loaded before.

    Raises:
        Whatever ``load_model()`` raises; in that case the previously
        installed state is left untouched.
    """
    global _state
    # NOTE: the lock is held for the whole load, so get_state() callers wait
    # until the new model is ready. This also serializes concurrent reloads.
    with _state_lock:
        prev_version = _state.version if _state is not None else ""
        new_state = load_model()
        _state = new_state
        # Separator added: the old "%s%s" format ran both versions together.
        logger.info("Model reloaded: %s -> %s", prev_version, new_state.version)
    return prev_version, new_state.version
def get_state() -> ModelState:
    """Return the currently installed ModelState.

    Raises:
        RuntimeError: if no model has been loaded yet.
    """
    # Take a snapshot under the lock; the None check needs no lock because
    # it only inspects the local reference.
    with _state_lock:
        current = _state
    if current is None:
        raise RuntimeError("Model not loaded. Call initialize() first.")
    return current
# ── Inference ─────────────────────────────────────────────────────────────────
def _run_session(
    session: ort.InferenceSession,
    tokenizer: AutoTokenizer,
    texts: list[str],
) -> np.ndarray:
    """Tokenize *texts* and run them through the ONNX session.

    Args:
        session: ONNX Runtime session to execute.
        tokenizer: tokenizer callable producing NumPy tensors.
        texts: input strings, one row per text.

    Returns:
        The session's first output cast to float32, one row per input text
        (expected shape ``(N, num_labels)``).
    """
    # Every text is padded/truncated to exactly MAX_SEQ_LENGTH tokens.
    encoded = tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
        return_tensors="np",
    )
    # Feed only the tensors the model declares as inputs, as int64.
    feed: dict[str, np.ndarray] = {
        inp.name: encoded[inp.name].astype(np.int64)
        for inp in session.get_inputs()
        if inp.name in encoded
    }
    outputs = session.run(None, feed)
    # An fp16 model may emit float16 logits; cast so callers always get the
    # documented float32 (no copy is made when the dtype already matches).
    return np.asarray(outputs[0], dtype=np.float32)
def predict(text: str) -> dict[str, float]:
    """Score a single text against every category.

    Args:
        text: the text to classify.

    Returns:
        Mapping of category name to sigmoid probability, rounded to four
        decimal places.

    Raises:
        RuntimeError: if no model has been loaded yet.
    """
    state = get_state()
    row_logits = _run_session(state.session, state.tokenizer, [text])[0]
    # Independent sigmoid per label (multi-label scoring, not softmax).
    sigmoid = 1.0 / (1.0 + np.exp(-row_logits))
    scores: dict[str, float] = {}
    for idx, category in enumerate(state.categories):
        scores[category] = round(float(sigmoid[idx]), 4)
    return scores
def predict_batch(texts: list[str]) -> list[dict[str, float]]:
    """Score several texts in one model call.

    Args:
        texts: the texts to classify. May be empty.

    Returns:
        One ``{category: probability}`` dict per input text, in input order.
        Probabilities are sigmoid outputs rounded to four decimal places.
        An empty input yields an empty list without touching the model.

    Raises:
        RuntimeError: if *texts* is non-empty and no model has been loaded.
    """
    # Nothing to score: skip the tokenizer and session, which are not
    # meant to be fed a zero-row batch.
    if not texts:
        return []

    state = get_state()
    logits_batch = _run_session(state.session, state.tokenizer, texts)
    # Element-wise sigmoid over the whole (N, num_labels) batch at once.
    probs_batch = 1.0 / (1.0 + np.exp(-logits_batch))
    categories = state.categories
    return [
        {cat: round(float(row[i]), 4) for i, cat in enumerate(categories)}
        for row in probs_batch
    ]