#!/usr/bin/env python3
|
|
"""
|
|
Model Deployment CLI
|
|
|
|
Validates, quantizes, uploads content moderation models to MinIO,
|
|
and signals the inference API to reload.
|
|
|
|
Usage:
|
|
python tools/deploy-model.py --model-dir models/v15/onnx/ --version v15
|
|
python tools/deploy-model.py --model-dir models/v15/onnx/ --version v15 --skip-quantize
|
|
python tools/deploy-model.py --model-dir models/v15/onnx/ --version v15 --dry-run
|
|
|
|
MinIO bucket structure:
|
|
ml-models/content-moderation/
|
|
manifest.json
|
|
v15/
|
|
model_fp16.onnx
|
|
model_q4.onnx
|
|
tokenizer.json
|
|
vocab.txt
|
|
config.json
|
|
thresholds.json
|
|
metadata.json
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import logging
|
|
import sys
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def validate_model_directory(model_dir: Path) -> dict:
    """Validate that the model directory contains all required files.

    Selects the ONNX model to deploy and checks that the tokenizer and
    config files sit next to it. Only ``*.onnx`` files are candidates; among
    them a file whose name contains ``fp16`` wins, then ``fp32``, then the
    alphabetically first remaining file.

    Args:
        model_dir: Directory holding the exported model and its assets.

    Returns:
        A dict with three keys: ``model_file`` (``Path`` of the chosen ONNX
        file), ``model_size_mb`` (its size in MiB, unrounded) and ``files``
        (names of the model file, the required files and whichever optional
        files are present, in that order).

    Raises:
        SystemExit: With status 1 when no ONNX file is found or a required
            file is missing.
    """
    required_files = ["tokenizer.json", "config.json"]
    optional_files = ["vocab.txt", "thresholds.json", "special_tokens_map.json"]

    # Sorted so the selection is reproducible: glob order depends on the
    # filesystem and would otherwise make the chosen model arbitrary.
    onnx_files = sorted(p for p in model_dir.glob("*.onnx") if p.is_file())
    if not onnx_files:
        logger.error("No ONNX model file found in %s", model_dir)
        sys.exit(1)

    # Prefer fp16, then fp32, then any ONNX file. Searching only the ONNX
    # candidates keeps files such as "notes_fp16.txt" from being mistaken
    # for the model.
    model_file = next(
        (
            candidate
            for precision in ("fp16", "fp32")
            for candidate in onnx_files
            if precision in candidate.name
        ),
        onnx_files[0],
    )

    missing = [name for name in required_files if not (model_dir / name).exists()]
    if missing:
        logger.error("Missing required files: %s", ", ".join(missing))
        sys.exit(1)

    present_optional = [name for name in optional_files if (model_dir / name).exists()]

    model_size_mb = model_file.stat().st_size / (1024 * 1024)
    logger.info("Model file: %s (%.1f MB)", model_file.name, model_size_mb)
    logger.info("Required files: OK")
    logger.info("Optional files present: %s", ", ".join(present_optional) or "none")

    return {
        "model_file": model_file,
        "model_size_mb": model_size_mb,
        "files": [model_file.name] + required_files + present_optional,
    }
|
|
|
|
|
|
def load_thresholds(model_dir: Path) -> dict:
    """Load per-category thresholds from thresholds.json.

    Args:
        model_dir: Directory expected to contain ``thresholds.json``.

    Returns:
        The parsed JSON content of the file, or an empty dict when the file
        does not exist (a warning is logged in that case).

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    thresholds_path = model_dir / "thresholds.json"
    try:
        # JSON is UTF-8 by specification; do not depend on the platform's
        # locale encoding. Opening directly also avoids an exists()/open() race.
        with open(thresholds_path, encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        logger.warning("No thresholds.json found, using default 0.5 for all categories")
        return {}
|
|
|
|
|
|
def create_metadata(
    version: str,
    model_dir: Path,
    model_info: dict,
    thresholds: dict,
    *,
    categories_path: "Path | None" = None,
) -> dict:
    """Create metadata.json for the deployed model version.

    Category names are scraped from the ``CATEGORIES = ( ... )`` tuple in the
    feedback package's ``categories.py``. The file is read as text and never
    imported. If the file or the tuple cannot be found the category list is
    empty.

    Args:
        version: Version label stored in the metadata.
        model_dir: Source directory; its resolved path is recorded.
        model_info: Mapping with ``model_file`` (a ``Path``) and
            ``model_size_mb`` (a number).
        thresholds: Per-category thresholds, stored as given.
        categories_path: Optional override for the location of
            ``categories.py``. Defaults to the path relative to this file.

    Returns:
        A JSON-serialisable dict with keys ``version``, ``categories``,
        ``category_count``, ``thresholds``, ``model_file``,
        ``model_size_mb`` (rounded to 2 decimals), ``deployed_at`` (UTC,
        ISO 8601) and ``deployed_from``.
    """
    import re

    if categories_path is None:
        categories_path = (
            Path(__file__).parent.parent
            / "packages"
            / "content-moderation-feedback"
            / "src"
            / "content_moderation_feedback"
            / "categories.py"
        )

    categories = []
    if categories_path.exists():
        content = categories_path.read_text(encoding="utf-8")
        match = re.search(r'CATEGORIES\s*(?::\s*[^=]*)?\s*=\s*\(([\s\S]*?)\)', content)
        if match:
            for line in match.group(1).splitlines():
                # Drop full-line and trailing comments before tokenising.
                code = line.split("#", 1)[0]
                # Split on commas first, then strip quotes. Stripping quotes
                # before the trailing comma left the closing quote attached
                # ('"hate",' became 'hate"').
                for token in code.split(","):
                    name = token.strip().strip("\"'")
                    if name:
                        categories.append(name)

    return {
        "version": version,
        "categories": categories,
        "category_count": len(categories),
        "thresholds": thresholds,
        "model_file": model_info["model_file"].name,
        "model_size_mb": round(model_info["model_size_mb"], 2),
        "deployed_at": datetime.now(timezone.utc).isoformat(),
        "deployed_from": str(model_dir.resolve()),
    }
|
|
|
|
|
|
def upload_to_minio(
    model_dir: Path,
    version: str,
    files: list[str],
    metadata: dict,
    dry_run: bool = False,
) -> None:
    """Upload model files to MinIO bucket.

    Objects are written under ``content-moderation/<version>/`` in the
    ``ml-models`` bucket, followed by ``metadata.json`` and a manifest update.

    Connection settings come from the environment: ``MINIO_ENDPOINT``,
    ``MINIO_ACCESS_KEY``, ``MINIO_SECRET_KEY`` and ``MINIO_SECURE`` (set to
    ``1``/``true``/``yes``/``on`` to use TLS; plain HTTP otherwise).

    Args:
        model_dir: Directory the files are read from.
        version: Version label used as the object-name prefix.
        files: File names, relative to ``model_dir``, to upload.
        metadata: JSON-serialisable metadata uploaded as ``metadata.json``.
        dry_run: When true, only log what would be uploaded. No client is
            created, so the ``minio`` package is not needed.

    Raises:
        SystemExit: With status 1 when a real upload is requested but the
            ``minio`` package is not installed.
    """
    import os
    from io import BytesIO

    bucket = "ml-models"
    prefix = f"content-moderation/{version}"

    if dry_run:
        logger.info("[DRY RUN] Would upload to %s/%s/:", bucket, prefix)
        for filename in files:
            file_path = model_dir / filename
            if file_path.exists():
                size_mb = file_path.stat().st_size / (1024 * 1024)
                logger.info(" %s (%.1f MB)", filename, size_mb)
        logger.info(" metadata.json")
        return

    # Imported only for real uploads so that --dry-run works on machines
    # without the MinIO SDK.
    try:
        from minio import Minio
    except ImportError:
        logger.error("minio package not installed. Run: pip install minio")
        sys.exit(1)

    # NOTE(review): the fallbacks look like local-development values; real
    # deployments should set these variables explicitly.
    minio_endpoint = os.environ.get("MINIO_ENDPOINT", "localhost:9000")
    minio_access_key = os.environ.get("MINIO_ACCESS_KEY", "minioadmin")
    minio_secret_key = os.environ.get("MINIO_SECRET_KEY", "minioadmin")
    secure = os.environ.get("MINIO_SECURE", "").strip().lower() in {"1", "true", "yes", "on"}

    client = Minio(
        minio_endpoint,
        access_key=minio_access_key,
        secret_key=minio_secret_key,
        secure=secure,
    )

    # Ensure bucket exists
    if not client.bucket_exists(bucket):
        client.make_bucket(bucket)
        logger.info("Created bucket: %s", bucket)

    # Upload model files
    for filename in files:
        file_path = model_dir / filename
        if not file_path.exists():
            # Still skipped, but no longer silently: an incomplete upload
            # should be visible in the deployment log.
            logger.warning("Skipping missing file: %s", file_path)
            continue

        object_name = f"{prefix}/{filename}"
        client.fput_object(bucket, object_name, str(file_path))
        size_mb = file_path.stat().st_size / (1024 * 1024)
        logger.info("Uploaded: %s (%.1f MB)", object_name, size_mb)

    # Upload metadata
    metadata_json = json.dumps(metadata, indent=2).encode()
    client.put_object(
        bucket,
        f"{prefix}/metadata.json",
        BytesIO(metadata_json),
        len(metadata_json),
        content_type="application/json",
    )
    logger.info("Uploaded: %s/metadata.json", prefix)

    # Update manifest
    update_manifest(client, bucket, version, metadata)
|
|
|
|
|
|
def update_manifest(client, bucket: str, version: str, metadata: dict) -> None:
    """Update the manifest.json with the new latest version.

    Reads ``content-moderation/manifest.json`` from ``bucket``, sets
    ``latest`` to ``version``, appends ``version`` to ``versions`` if it is
    not already listed, and writes the manifest back.

    A fresh manifest is started only when the store reports that the object
    does not exist (an error whose ``code`` attribute is ``"NoSuchKey"``).
    Any other read failure is re-raised. Treating a transient error as
    "no manifest" would overwrite the stored manifest and erase its history.

    Args:
        client: Object exposing ``get_object(bucket, key)`` and
            ``put_object(bucket, key, stream, length, content_type=...)``.
        bucket: Bucket holding the manifest.
        version: Version to record as latest.
        metadata: Accepted for interface compatibility; not used here.

    Raises:
        Exception: Whatever ``client.get_object`` raised, unless it signals a
            missing manifest.
        json.JSONDecodeError: If the stored manifest is not valid JSON.
    """
    from io import BytesIO

    manifest_key = "content-moderation/manifest.json"
    manifest = {"latest": None, "versions": []}

    try:
        response = client.get_object(bucket, manifest_key)
    except Exception as exc:
        # Duck-typed check so the SDK's error class need not be imported here.
        if getattr(exc, "code", None) != "NoSuchKey":
            raise
    else:
        try:
            manifest = json.loads(response.read())
        finally:
            # Always hand the connection back, even when parsing fails.
            response.close()
            release_conn = getattr(response, "release_conn", None)
            if callable(release_conn):
                release_conn()

    manifest["latest"] = version
    # Tolerate a manifest that was written without a "versions" list.
    versions = manifest.setdefault("versions", [])
    if version not in versions:
        versions.append(version)

    manifest_json = json.dumps(manifest, indent=2).encode()
    client.put_object(
        bucket,
        manifest_key,
        BytesIO(manifest_json),
        len(manifest_json),
        content_type="application/json",
    )
    logger.info("Updated manifest: latest=%s, versions=%s", version, manifest["versions"])
|
|
|
|
|
|
def signal_reload(inference_url: str, dry_run: bool = False) -> None:
    """Signal the inference API to hot-reload the new model.

    Sends ``POST <inference_url>/api/v1/model/reload`` with a 30 second
    timeout. This step is best-effort: transport failures and unexpected
    response bodies are logged and never raised.

    Args:
        inference_url: Base URL of the inference API; a trailing slash is
            tolerated.
        dry_run: When true, only log the request that would be sent.
    """
    reload_url = f"{inference_url.rstrip('/')}/api/v1/model/reload"

    if dry_run:
        logger.info("[DRY RUN] Would POST %s", reload_url)
        return

    import http.client
    import urllib.error
    import urllib.request

    req = urllib.request.Request(
        reload_url,
        method="POST",
        headers={"Content-Type": "application/json"},
    )

    try:
        with urllib.request.urlopen(req, timeout=30) as response:
            body = response.read()
    except (OSError, http.client.HTTPException) as e:
        # URLError/HTTPError, socket timeouts and connection resets are all
        # OSError subclasses; protocol-level failures surface as HTTPException.
        logger.warning("Could not signal inference API: %s", e)
        logger.info("The model will be loaded on next API restart")
        return

    # The request went through; a malformed body must not fail the deployment.
    try:
        result = json.loads(body)
    except ValueError:
        result = None
    if not isinstance(result, dict):
        logger.warning("Inference API accepted the reload request but returned an unexpected response body")
        return

    logger.info(
        "Model reloaded: %s -> %s",
        result.get("previous_version", "?"),
        result.get("new_version", "?"),
    )
|
|
|
|
|
|
def main() -> None:
    """Parse command-line arguments and run the deployment pipeline.

    The pipeline validates the model directory, loads thresholds, builds the
    version metadata, uploads everything to MinIO (or lists it with
    ``--dry-run``) and finally asks the inference API to reload.

    Raises:
        SystemExit: With status 1 when ``--model-dir`` is not an existing
            directory. ``argparse`` also exits on invalid arguments.
    """
    parser = argparse.ArgumentParser(description="Deploy content moderation model to MinIO")
    parser.add_argument("--model-dir", type=Path, required=True, help="Path to model directory")
    parser.add_argument("--version", required=True, help="Model version (e.g., v15)")
    parser.add_argument("--inference-url", default="http://localhost:3501", help="Inference API URL")
    # NOTE(review): --skip-quantize is accepted but nothing below reads it —
    # confirm whether a quantization step is still meant to run here.
    parser.add_argument("--skip-quantize", action="store_true", help="Skip q4 quantization")
    parser.add_argument("--dry-run", action="store_true", help="Print what would be done")
    parser.add_argument("--verbose", action="store_true", help="Enable debug logging")

    args = parser.parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
    )

    # is_dir() rather than exists(): a regular file would pass an existence
    # check and then fail later with a misleading "no ONNX file" error.
    if not args.model_dir.is_dir():
        logger.error("Model directory does not exist or is not a directory: %s", args.model_dir)
        sys.exit(1)

    logger.info("=== Content Moderation Model Deployment ===")
    logger.info("Version: %s", args.version)
    logger.info("Source: %s", args.model_dir)

    # Step 1: Validate
    logger.info("--- Step 1: Validate ---")
    model_info = validate_model_directory(args.model_dir)

    # Step 2: Load thresholds
    thresholds = load_thresholds(args.model_dir)

    # Step 3: Create metadata
    metadata = create_metadata(args.version, args.model_dir, model_info, thresholds)
    logger.info("Categories: %d", metadata["category_count"])

    # Step 4: Upload to MinIO
    logger.info("--- Step 4: Upload to MinIO ---")
    upload_to_minio(args.model_dir, args.version, model_info["files"], metadata, args.dry_run)

    # Step 5: Signal inference API to reload
    logger.info("--- Step 5: Signal Inference API ---")
    signal_reload(args.inference_url, args.dry_run)

    logger.info("=== Deployment Complete ===")


if __name__ == "__main__":
    main()
|