life-tooling/scripts/vram-manager.py
2026-03-20 09:32:40 -07:00

362 lines
11 KiB
Python
Executable file

#!/usr/bin/env python3
"""
VRAM Lifecycle Manager — Intelligent GPU warmup and idle release.
Watches Caddy access logs for user activity and coordinates model loading/unloading
via model-boss. All VRAM orchestration lives in model-boss — this service is purely
a user-activity signal source.
State machine:
    COLD ──(activity)──→ WARMING ──(warmup done)──→ HOT
     ▲                                               │
     └────────────────(idle timeout)─────────────────┘
"""
import asyncio
import json
import logging
import os
import re
import signal
import sys
import time
from dataclasses import dataclass, field
from enum import Enum
from http.client import HTTPException
from pathlib import Path
from typing import Optional
from urllib.error import URLError
from urllib.request import Request, urlopen
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Endpoints, timeout and log location; each can be overridden with a VRAM_* env var.
MODEL_BOSS_URL = os.getenv("VRAM_MODEL_BOSS_URL", "http://localhost:8210")
API_URL = os.getenv("VRAM_API_URL", "http://localhost:3700")
# Seconds without user activity before GPU leases are released.
IDLE_TIMEOUT = int(os.getenv("VRAM_IDLE_TIMEOUT", "900"))
# Caddy JSON access log that is tailed for user activity.
LOG_FILE = os.getenv(
    "VRAM_LOG_FILE",
    os.path.expanduser("~/.local/state/life-manager/access.log"),
)
# Seconds between idle checks.
IDLE_CHECK_INTERVAL = 30

# Model that is warmed on every activity, independent of conversation domain.
DEFAULT_MODEL = "qwen3-8b"

# Conversation domain -> specialist model (mirrors model-config.ts).
DOMAIN_MODEL_MAP: dict[str, str] = dict(
    health="ii-medical-8b",
    finance="kuvera-8b",
    finance_reasoning="fin-o1-8b",
    psychology="phi-4-therapy",
    philosophy="veritas-12b",
)

# Requests for static assets never trigger a warmup: the extension must end
# the path or be followed directly by a query string.
STATIC_ASSET_RE = re.compile(
    r"\.(js|css|png|jpg|jpeg|gif|svg|ico|woff2?|ttf|eot|map|webp|avif)(\?|$)",
    flags=re.IGNORECASE,
)

logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s [%(levelname)s] %(message)s",
)
log = logging.getLogger("vram-manager")
# ---------------------------------------------------------------------------
# State
# ---------------------------------------------------------------------------
class State(Enum):
    """Lifecycle phase of the models this service has warmed."""

    COLD = "COLD"  # initial phase, and the phase after an idle release
    WARMING = "WARMING"  # warmup requests are in flight
    HOT = "HOT"  # warmup finished; eligible for release once idle


@dataclass
class ManagerState:
    """Mutable state shared by the activity monitor and the idle checker."""

    # Current lifecycle phase.
    state: State = State.COLD
    # time.monotonic() of the latest user request; 0.0 means none seen yet.
    last_activity_at: float = 0.0
    # Models whose most recent warmup call succeeded (fresh set per instance).
    warmed_models: set[str] = field(default_factory=lambda: set())
# ---------------------------------------------------------------------------
# HTTP helpers (stdlib only — no dependencies)
# ---------------------------------------------------------------------------
def http_json(
    method: str,
    url: str,
    body: Optional[dict] = None,
    timeout: float = 30,
) -> Optional[dict]:
    """Send an HTTP request and return the decoded JSON response, or None on error.

    Args:
        method: HTTP verb, e.g. "GET" or "POST".
        url: Absolute URL to call.
        body: JSON-serialisable payload. It is sent whenever it is not None,
            so an empty dict goes out as ``{}`` rather than being dropped.
        timeout: Socket timeout in seconds.

    Returns:
        The parsed JSON document; ``{}`` when a successful response has an
        empty body (e.g. 204 No Content); None when the request fails —
        connection error, HTTP error status, truncated or non-UTF-8 response,
        or invalid JSON. Failures are logged, never raised, so long-running
        callers are not killed by a flaky peer.
    """
    try:
        data = json.dumps(body).encode() if body is not None else None
        req = Request(url, data=data, method=method)
        req.add_header("Content-Type", "application/json")
        with urlopen(req, timeout=timeout) as resp:
            raw = resp.read()
        if not raw.strip():
            # Success with nothing to parse is still success.
            return {}
        return json.loads(raw.decode())
    # URLError/HTTPError/TimeoutError are OSErrors; HTTPException covers
    # IncompleteRead and friends from resp.read(); ValueError covers both
    # JSONDecodeError and UnicodeDecodeError (and a malformed URL).
    except (URLError, OSError, HTTPException, ValueError) as exc:
        log.warning("HTTP %s %s failed: %s", method, url, exc)
        return None
def warmup_model(model_id: str) -> bool:
    """Ask model-boss to load *model_id* by requesting a one-token completion.

    Returns True if model-boss answered with a JSON body, False otherwise.
    The 120 s timeout leaves room for a cold model load.
    """
    log.info("Warming up model: %s", model_id)
    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": "warmup"}],
        "max_tokens": 1,
    }
    endpoint = f"{MODEL_BOSS_URL}/v1/chat/completions"
    if http_json("POST", endpoint, payload, timeout=120) is None:
        log.warning("Model %s warmup failed", model_id)
        return False
    log.info("Model %s warmed up successfully", model_id)
    return True
def drain_gpu() -> bool:
    """Ask model-boss to release every GPU lease.

    Returns True when model-boss answered with a JSON body, False otherwise.
    """
    log.info("Draining all GPU leases...")
    succeeded = http_json("POST", f"{MODEL_BOSS_URL}/api/v1/gpu/drain") is not None
    if not succeeded:
        log.warning("GPU drain failed")
    else:
        log.info("GPU drain complete")
    return succeeded
def fetch_recent_domains() -> list[str]:
    """Return known domains detected in the most recently updated conversations.

    Queries the backend for up to 5 conversations sorted by ``updatedAt``
    descending. Only domains present in DOMAIN_MODEL_MAP are kept, each once,
    in first-seen order. Returns [] when the API is unreachable or the
    response is not shaped like ``{"data": [{"detectedDomains": [...]}, ...]}``;
    malformed individual entries are skipped instead of raising.
    """
    result = http_json(
        "GET",
        f"{API_URL}/api/conversations?limit=5&sortBy=updatedAt&sortOrder=DESC",
    )
    if not isinstance(result, dict):
        return []
    conversations = result.get("data")
    if not isinstance(conversations, list):
        return []
    # A dict keeps insertion order: de-duplicated and deterministic.
    found: dict[str, None] = {}
    for conv in conversations:
        if not isinstance(conv, dict):
            continue
        detected = conv.get("detectedDomains")
        if not isinstance(detected, list):
            continue
        for domain in detected:
            # isinstance guard: an unhashable JSON value would raise on lookup.
            if isinstance(domain, str) and domain in DOMAIN_MODEL_MAP:
                found.setdefault(domain, None)
    return list(found)
# ---------------------------------------------------------------------------
# Log tail — async generator over Caddy JSON access log
# ---------------------------------------------------------------------------
async def tail_log(path: str, poll_interval: float = 0.5):
    """Async generator yielding complete new lines appended to *path* (like ``tail -F``).

    The first time the file is opened, reading starts at its end, so only
    requests that arrive after startup are reported. Lines are yielded with
    surrounding whitespace stripped, and only once their trailing newline has
    been written, so a half-flushed line is never reported as fragments.
    When the path is replaced by a new file (rename-based log rotation) or
    the file is truncated in place, the new contents are read from the start.

    Args:
        path: Log file to follow; it does not need to exist yet.
        poll_interval: Seconds to sleep when no new data is available.
    """
    # Only the first open skips existing content; a file opened after a
    # rotation/truncation contains nothing but new requests.
    skip_existing = True
    while True:
        try:
            if not Path(path).exists():
                log.info("Waiting for log file: %s", path)
                await asyncio.sleep(5)
                continue
            with open(path, "r", encoding="utf-8", errors="replace") as fh:
                if skip_existing:
                    fh.seek(0, os.SEEK_END)
                    skip_existing = False
                pending = ""
                while True:
                    chunk = fh.readline()
                    if chunk:
                        pending += chunk
                        if pending.endswith("\n"):
                            line, pending = pending.strip(), ""
                            yield line
                        continue
                    # At EOF: reopen if the path now points at a different
                    # (or truncated) file, otherwise wait for more data.
                    if _log_file_replaced(path, fh):
                        log.info("Log file rotated or truncated, reopening: %s", path)
                        break
                    await asyncio.sleep(poll_interval)
        except FileNotFoundError:
            log.info("Log file disappeared, waiting for recreation: %s", path)
            await asyncio.sleep(5)
        except Exception as exc:
            log.error("Log tail error: %s", exc)
            # Re-open at the end so the retry does not replay old requests.
            skip_existing = True
            await asyncio.sleep(5)


def _log_file_replaced(path: str, fh) -> bool:
    """Return True if *path* no longer names the file open as *fh*, or that file shrank.

    A missing *path* is "not replaced yet": the old handle keeps being read
    until a new file appears at *path*.
    """
    try:
        on_disk = os.stat(path)
    except FileNotFoundError:
        return False
    opened = os.fstat(fh.fileno())
    if (on_disk.st_dev, on_disk.st_ino) != (opened.st_dev, opened.st_ino):
        return True
    # Same file, but smaller than the OS-level offset already read: truncated.
    return on_disk.st_size < os.lseek(fh.fileno(), 0, os.SEEK_CUR)
def _open_file(path: str):
"""Helper for asyncio.to_thread — just validates file access."""
class _Ctx:
async def __aenter__(self):
return self
async def __aexit__(self, *_):
pass
open(path, "r").close()
return _Ctx()
def is_user_activity(log_line: str) -> bool:
    """Return True if a Caddy JSON access-log line represents a real user request.

    A line is rejected when it is not a JSON object, has no ``request.uri``
    string (so it is not an access entry), targets a static asset
    (STATIC_ASSET_RE), or hits a health-check endpoint. Never raises on
    malformed input, so a stray line cannot kill the activity monitor.
    """
    try:
        entry = json.loads(log_line)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return False
    if not isinstance(entry, dict):
        return False
    request = entry.get("request")
    if not isinstance(request, dict):
        return False
    uri = request.get("uri")
    if not isinstance(uri, str) or not uri:
        return False
    # Filter out static assets
    if STATIC_ASSET_RE.search(uri):
        return False
    # Filter out health checks; compare the path only so that a query string
    # does not defeat the filter.
    if uri.partition("?")[0] in ("/health", "/api/health"):
        return False
    return True
# ---------------------------------------------------------------------------
# Core lifecycle
# ---------------------------------------------------------------------------
async def do_warmup(mgr: ManagerState) -> None:
"""Warm up models based on default + recent conversation domains."""
mgr.state = State.WARMING
models_to_warm: list[str] = [DEFAULT_MODEL]
# Query recent conversation domains for smart warmup
domains = await asyncio.to_thread(fetch_recent_domains)
for domain in domains:
model = DOMAIN_MODEL_MAP.get(domain)
if model and model not in models_to_warm:
models_to_warm.append(model)
log.info(
"Warming up models: %s (domains: %s)",
", ".join(models_to_warm),
", ".join(domains) if domains else "none",
)
start = time.monotonic()
# Warm models concurrently
results = await asyncio.gather(
*[asyncio.to_thread(warmup_model, m) for m in models_to_warm],
return_exceptions=True,
)
mgr.warmed_models = {
m for m, r in zip(models_to_warm, results) if r is True
}
elapsed = time.monotonic() - start
log.info(
"Warmup complete in %.1fs — loaded: %s",
elapsed,
", ".join(sorted(mgr.warmed_models)) if mgr.warmed_models else "none",
)
mgr.state = State.HOT
async def do_release(mgr: ManagerState) -> None:
"""Release all VRAM leases and transition to COLD."""
log.info(
"Idle timeout reached (%ds). Releasing VRAM for models: %s",
IDLE_TIMEOUT,
", ".join(sorted(mgr.warmed_models)) if mgr.warmed_models else "all",
)
await asyncio.to_thread(drain_gpu)
mgr.warmed_models.clear()
mgr.state = State.COLD
async def activity_monitor(mgr: ManagerState) -> None:
"""Watch Caddy access log for user activity and trigger warmup."""
async for line in tail_log(LOG_FILE):
if not is_user_activity(line):
continue
mgr.last_activity_at = time.monotonic()
if mgr.state == State.COLD:
log.info("User activity detected — triggering warmup")
# Fire warmup in background so we keep watching the log
asyncio.create_task(do_warmup(mgr))
async def idle_checker(mgr: ManagerState) -> None:
"""Periodically check for idle timeout and release VRAM."""
while True:
await asyncio.sleep(IDLE_CHECK_INTERVAL)
if mgr.state != State.HOT:
continue
if mgr.last_activity_at == 0:
continue
idle_seconds = time.monotonic() - mgr.last_activity_at
if idle_seconds >= IDLE_TIMEOUT:
await do_release(mgr)
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
async def main() -> None:
    """Run the activity monitor and idle checker until shutdown, then drain.

    Shutdown is triggered by SIGTERM/SIGINT, or by either background task
    exiting. Both loops are meant to run forever, so an exit means a crash;
    carrying on with a dead worker would silently stop warmups or releases.
    GPU leases are drained on the way out if models may be loaded. When a
    crashed task caused the shutdown, the process exits with status 1 so a
    process supervisor can notice and restart it.
    """
    log.info("VRAM Lifecycle Manager starting")
    log.info("  model-boss:   %s", MODEL_BOSS_URL)
    log.info("  api:          %s", API_URL)
    log.info("  log file:     %s", LOG_FILE)
    log.info("  idle timeout: %ds", IDLE_TIMEOUT)
    mgr = ManagerState()

    # Graceful shutdown
    loop = asyncio.get_running_loop()
    shutdown_event = asyncio.Event()

    def handle_signal() -> None:
        log.info("Shutdown signal received — draining GPU leases")
        shutdown_event.set()

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, handle_signal)

    # Run activity monitor and idle checker concurrently
    monitor_task = asyncio.create_task(activity_monitor(mgr), name="activity-monitor")
    idle_task = asyncio.create_task(idle_checker(mgr), name="idle-checker")
    shutdown_wait = asyncio.create_task(shutdown_event.wait(), name="shutdown-wait")

    done, _ = await asyncio.wait(
        {monitor_task, idle_task, shutdown_wait},
        return_when=asyncio.FIRST_COMPLETED,
    )
    crashed = [t for t in (monitor_task, idle_task) if t in done]
    for task in crashed:
        exc = None if task.cancelled() else task.exception()
        log.error("Background task %s exited unexpectedly", task.get_name(), exc_info=exc)

    # Cancel whatever is still running and wait for it, so file handles and
    # other cleanup in those tasks complete before we drain.
    for task in (monitor_task, idle_task, shutdown_wait):
        task.cancel()
    await asyncio.gather(monitor_task, idle_task, shutdown_wait, return_exceptions=True)

    # Drain on shutdown if we have models loaded
    if mgr.state in (State.WARMING, State.HOT):
        try:
            await asyncio.to_thread(drain_gpu)
        except Exception:
            log.exception("GPU drain on shutdown failed")
    log.info("VRAM Lifecycle Manager stopped")
    if crashed:
        sys.exit(1)
# Script entry point: run the lifecycle manager's event loop until main() returns.
if __name__ == "__main__":
    asyncio.run(main())