362 lines
11 KiB
Python
Executable file
362 lines
11 KiB
Python
Executable file
#!/usr/bin/env python3
|
|
"""
|
|
VRAM Lifecycle Manager — Intelligent GPU warmup and idle release.
|
|
|
|
Watches Caddy access logs for user activity and coordinates model loading/unloading
|
|
via model-boss. All VRAM orchestration lives in model-boss — this service is purely
|
|
a user-activity signal source.
|
|
|
|
State machine:
|
|
    COLD ──(activity)──→ WARMING ──(loaded)──→ HOT
      ▲                                         │
      └────────────(idle timeout)───────────────┘
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import signal
|
|
import sys
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from enum import Enum
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
from urllib.error import URLError
|
|
from urllib.request import Request, urlopen
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Configuration
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Endpoints and tunables; each VRAM_* environment variable overrides its default.
MODEL_BOSS_URL = os.getenv("VRAM_MODEL_BOSS_URL", "http://localhost:8210")
API_URL = os.getenv("VRAM_API_URL", "http://localhost:3700")
# Seconds without user activity before VRAM is released.
IDLE_TIMEOUT = int(os.getenv("VRAM_IDLE_TIMEOUT", "900"))
# Access log that is tailed for activity.
LOG_FILE = os.getenv(
    "VRAM_LOG_FILE",
    os.path.expanduser("~/.local/state/life-manager/access.log"),
)

# Model that is warmed on every activity-triggered warmup.
DEFAULT_MODEL = "qwen3-8b"

# Conversation domain -> specialised model (mirrors model-config.ts).
DOMAIN_MODEL_MAP: dict[str, str] = {
    "health": "ii-medical-8b",
    "finance": "kuvera-8b",
    "finance_reasoning": "fin-o1-8b",
    "psychology": "phi-4-therapy",
    "philosophy": "veritas-12b",
}

# Requests for static assets match this pattern and must not trigger a warmup.
STATIC_ASSET_RE = re.compile(
    r"\.(js|css|png|jpg|jpeg|gif|svg|ico|woff2?|ttf|eot|map|webp|avif)(\?|$)",
    re.IGNORECASE,
)

# Seconds between two idle checks.
IDLE_CHECK_INTERVAL = 30
|
|
|
|
# Timestamped INFO-level logging on stdout for the whole process.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s [%(levelname)s] %(message)s",
)

# Logger used by every function in this module.
log = logging.getLogger("vram-manager")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# State
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class State(Enum):
    """Lifecycle phase of the managed GPU models."""

    # Initial phase, and the phase entered again after VRAM has been released.
    COLD = "COLD"
    # A warmup is in progress.
    WARMING = "WARMING"
    # A warmup has finished; the idle timer applies.
    HOT = "HOT"
|
|
|
|
|
|
@dataclass
class ManagerState:
    """Mutable runtime state shared by the coroutines of this service."""

    # Current lifecycle phase; a fresh manager starts COLD.
    state: State = State.COLD

    # time.monotonic() reading of the latest user activity; 0.0 means none yet.
    last_activity_at: float = 0.0

    # IDs of the models whose warmup request succeeded.
    warmed_models: set[str] = field(default_factory=set)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# HTTP helpers (stdlib only — no dependencies)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def http_json(
    method: str,
    url: str,
    body: Optional[dict] = None,
    timeout: float = 30,
) -> Optional[dict]:
    """Send an HTTP request and return the decoded JSON response.

    Args:
        method: HTTP verb, e.g. ``"GET"`` or ``"POST"``.
        url: Absolute URL to request.
        body: Optional payload, serialized as JSON. ``None`` sends no body;
            an empty dict is sent as ``{}``.
        timeout: Socket timeout in seconds.

    Returns:
        The parsed JSON document, or ``None`` when the request failed or the
        response could not be decoded. Such failures are logged as warnings
        instead of being raised.
    """
    # Function-scope import: http.client is not among the module-level imports.
    from http.client import HTTPException

    try:
        # `is not None` so that an explicit empty payload is still transmitted.
        payload = json.dumps(body).encode("utf-8") if body is not None else None
        req = Request(url, data=payload, method=method)
        if payload is not None:
            # Only declare a JSON body when one is actually sent; announcing
            # JSON on a bodiless request can be rejected by strict servers.
            req.add_header("Content-Type", "application/json")
        with urlopen(req, timeout=timeout) as resp:
            return json.loads(resp.read().decode("utf-8"))
    except (OSError, ValueError, HTTPException) as exc:
        # OSError covers URLError, HTTPError and timeouts; ValueError covers
        # malformed URLs, invalid JSON and undecodable bytes; HTTPException
        # covers protocol-level failures such as a truncated response.
        log.warning("HTTP %s %s failed: %s", method, url, exc)
        return None
|
|
|
|
|
|
def warmup_model(model_id: str) -> bool:
    """Ask model-boss for a 1-token chat completion so that it loads *model_id*.

    Returns:
        True when the request produced a JSON response, False otherwise.
    """
    log.info("Warming up model: %s", model_id)

    endpoint = f"{MODEL_BOSS_URL}/v1/chat/completions"
    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": "warmup"}],
        "max_tokens": 1,
    }
    # Generous timeout compared with the http_json default of 30 seconds.
    response = http_json("POST", endpoint, payload, timeout=120)

    succeeded = response is not None
    if succeeded:
        log.info("Model %s warmed up successfully", model_id)
    else:
        log.warning("Model %s warmup failed", model_id)
    return succeeded
|
|
|
|
|
|
def drain_gpu() -> bool:
    """Ask model-boss to release every VRAM lease.

    Returns:
        True when the drain request produced a JSON response, False otherwise.
    """
    log.info("Draining all GPU leases...")

    if http_json("POST", f"{MODEL_BOSS_URL}/api/v1/gpu/drain") is None:
        log.warning("GPU drain failed")
        return False

    log.info("GPU drain complete")
    return True
|
|
|
|
|
|
def fetch_recent_domains() -> list[str]:
    """Return the known domains detected in the most recently updated conversations.

    Queries the backend API for the five latest conversations and collects
    every ``detectedDomains`` entry that has a model in DOMAIN_MODEL_MAP.

    Returns:
        A sorted, duplicate-free list of domain names. The list is empty when
        the API is unreachable or the response does not have the expected
        shape; malformed entries are skipped rather than raising.
    """
    result = http_json(
        "GET",
        f"{API_URL}/api/conversations?limit=5&sortBy=updatedAt&sortOrder=DESC",
    )
    if not isinstance(result, dict):
        return []

    conversations = result.get("data")
    if not isinstance(conversations, list):
        return []

    domains: set[str] = set()
    for conv in conversations:
        if not isinstance(conv, dict):
            continue
        # The key may be present with a null value, so `.get(..., [])` alone
        # is not enough to guarantee something iterable.
        detected = conv.get("detectedDomains") or []
        if not isinstance(detected, list):
            continue
        domains.update(
            domain
            for domain in detected
            # The str check keeps unhashable values out of the dict lookup.
            if isinstance(domain, str) and domain in DOMAIN_MODEL_MAP
        )

    # Sorted so that log output and warmup order are deterministic.
    return sorted(domains)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Log tail — async generator over Caddy JSON access log
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def tail_log(path: str, poll_interval: float = 0.5, retry_delay: float = 5):
    """Yield complete new lines appended to *path*, like ``tail -F``.

    The first successful open skips the content already in the file, so only
    new requests are reported. If the file is rotated, deleted or truncated,
    it is reopened and read from its beginning. A file that does not exist
    yet is waited for and then also read from its beginning.

    Args:
        path: Log file to follow.
        poll_interval: Seconds to sleep when no new data is available.
        retry_delay: Seconds to wait before retrying after the file is missing
            or an unexpected error occurred.

    Yields:
        Each newline-terminated line, decoded as UTF-8 (invalid bytes are
        replaced) and stripped of surrounding whitespace. A line is held back
        until its newline has been written, so consumers never see a
        half-written record.
    """
    skip_existing = True
    while True:
        try:
            # Binary mode keeps tell() an exact byte offset, which the
            # truncation check in _log_was_replaced relies on.
            with open(path, "rb") as fh:
                if skip_existing:
                    fh.seek(0, os.SEEK_END)
                skip_existing = False

                pending = b""
                replaced = False
                while True:
                    chunk = fh.readline()
                    if chunk:
                        pending += chunk
                        if pending.endswith(b"\n"):
                            text = pending.decode("utf-8", errors="replace").strip()
                            pending = b""
                            yield text
                        continue
                    if replaced:
                        # Old handle is fully drained; reopen the new file.
                        break
                    await asyncio.sleep(poll_interval)
                    # After detection the loop runs once more to drain
                    # whatever was still written to the old handle.
                    replaced = _log_was_replaced(path, fh)
        except FileNotFoundError:
            # Everything in a file that appears later is new to us.
            skip_existing = False
            log.info("Waiting for log file: %s", path)
            await asyncio.sleep(retry_delay)
        except Exception as exc:
            # Resume at the end afterwards so old lines are not replayed.
            skip_existing = True
            log.error("Log tail error: %s", exc)
            await asyncio.sleep(retry_delay)


def _log_was_replaced(path: str, fh) -> bool:
    """Return True when *path* no longer names the file behind *fh*, or that file shrank."""
    try:
        on_disk = os.stat(path)
    except FileNotFoundError:
        log.info("Log file disappeared, waiting for recreation: %s", path)
        return True
    if not os.path.samestat(on_disk, os.fstat(fh.fileno())):
        log.info("Log file was rotated, reopening: %s", path)
        return True
    if on_disk.st_size < fh.tell():
        log.info("Log file was truncated, reopening: %s", path)
        return True
    return False
|
|
|
|
|
|
def _open_file(path: str):
    """Check that *path* can be opened for reading.

    Meant to be run through ``asyncio.to_thread``. The file is opened and
    closed immediately; any error from ``open`` propagates to the caller.

    Returns:
        A do-nothing async context manager, so the result can be used in an
        ``async with`` statement.
    """

    class _NullAsyncContext:
        async def __aenter__(self):
            return self

        async def __aexit__(self, *_exc_info):
            pass

    # Open and close right away: only the ability to open matters here.
    with open(path, "r"):
        pass
    return _NullAsyncContext()
|
|
|
|
|
|
def is_user_activity(log_line: str) -> bool:
    """Decide whether one JSON access-log line represents real user activity.

    Args:
        log_line: A single line of the access log.

    Returns:
        False for anything that is not a JSON object, for entries whose
        ``request`` or ``uri`` has an unexpected type, for health-check paths
        (with or without a query string) and for static assets. True for
        every other entry.
    """
    try:
        entry = json.loads(log_line)
    except (ValueError, TypeError):
        # JSONDecodeError is a ValueError; TypeError covers non-text input.
        return False

    # Valid JSON is not necessarily an object ("[]", "null", "42", ...).
    if not isinstance(entry, dict):
        return False
    request = entry.get("request", {})
    if not isinstance(request, dict):
        return False
    uri = request.get("uri", "")
    if not isinstance(uri, str):
        return False

    # Compare the path only, so "/health?probe=1" is still a health check.
    path = uri.partition("?")[0]
    if path in ("/health", "/api/health"):
        return False

    if STATIC_ASSET_RE.search(uri):
        return False

    return True
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Core lifecycle
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def do_warmup(mgr: ManagerState) -> None:
    """Warm the default model plus the models of recently used domains.

    Sets the state to WARMING, issues all warmup requests concurrently in
    worker threads, records the models that succeeded in
    ``mgr.warmed_models`` and finally sets the state to HOT.

    The state becomes HOT even when no model could be warmed, because the
    idle checker only releases VRAM while HOT.
    """
    mgr.state = State.WARMING

    # A failing domain lookup must not abort the warmup: that would leave the
    # state stuck in WARMING. Fall back to the default model only.
    try:
        domains = await asyncio.to_thread(fetch_recent_domains)
    except Exception:
        log.exception("Recent-domain lookup failed; warming the default model only")
        domains = []

    models_to_warm: list[str] = [DEFAULT_MODEL]
    for domain in domains:
        model = DOMAIN_MODEL_MAP.get(domain)
        if model and model not in models_to_warm:
            models_to_warm.append(model)

    log.info(
        "Warming up models: %s (domains: %s)",
        ", ".join(models_to_warm),
        ", ".join(domains) if domains else "none",
    )

    start = time.monotonic()

    # return_exceptions=True: one failing model must not cancel the others.
    results = await asyncio.gather(
        *[asyncio.to_thread(warmup_model, m) for m in models_to_warm],
        return_exceptions=True,
    )

    warmed: set[str] = set()
    for model, outcome in zip(models_to_warm, results):
        if outcome is True:
            warmed.add(model)
        elif isinstance(outcome, BaseException):
            # Surface errors that gather() captured instead of dropping them.
            log.error("Warmup of %s raised: %r", model, outcome)
    mgr.warmed_models = warmed

    elapsed = time.monotonic() - start
    log.info(
        "Warmup complete in %.1fs — loaded: %s",
        elapsed,
        ", ".join(sorted(mgr.warmed_models)) if mgr.warmed_models else "none",
    )

    mgr.state = State.HOT
|
|
|
|
|
|
async def do_release(mgr: ManagerState) -> None:
    """Drain all VRAM leases and, if that succeeded, transition to COLD.

    When the drain fails the state is left untouched, so a caller that
    retries while the manager is HOT will attempt the release again.
    """
    log.info(
        "Idle timeout reached (%ds). Releasing VRAM for models: %s",
        IDLE_TIMEOUT,
        ", ".join(sorted(mgr.warmed_models)) if mgr.warmed_models else "all",
    )

    try:
        drained = await asyncio.to_thread(drain_gpu)
    except Exception:
        # An unexpected error counts as a failed drain; it must not escape
        # into the caller's polling loop.
        log.exception("GPU drain raised unexpectedly")
        drained = False

    if not drained:
        log.warning(
            "VRAM release failed; state stays %s so the release can be retried",
            mgr.state.value,
        )
        return

    mgr.warmed_models.clear()
    mgr.state = State.COLD
|
|
|
|
|
|
async def activity_monitor(mgr: ManagerState) -> None:
    """Follow the access log and start a warmup when activity arrives while COLD.

    Every line that counts as user activity refreshes
    ``mgr.last_activity_at``. The warmup itself runs as a background task so
    that the log keeps being watched.
    """
    # Strong references: the event loop only keeps weak references to tasks,
    # so an otherwise unreferenced task can be garbage-collected mid-run.
    in_flight: set[asyncio.Task] = set()

    def _warmup_finished(task: asyncio.Task) -> None:
        in_flight.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is None:
            return
        log.error("Warmup task failed: %r", exc)
        # Do not stay in WARMING forever after a crash; allow a retry.
        if mgr.state == State.WARMING:
            mgr.state = State.COLD

    async for line in tail_log(LOG_FILE):
        if not is_user_activity(line):
            continue

        mgr.last_activity_at = time.monotonic()

        if mgr.state != State.COLD:
            continue

        log.info("User activity detected — triggering warmup")
        # Leave COLD right here, not only once the task starts running: a
        # burst of buffered log lines is consumed without yielding to the
        # event loop and would otherwise spawn one warmup per line.
        mgr.state = State.WARMING
        task = asyncio.create_task(do_warmup(mgr))
        in_flight.add(task)
        task.add_done_callback(_warmup_finished)
|
|
|
|
|
|
async def idle_checker(mgr: ManagerState) -> None:
    """Poll every IDLE_CHECK_INTERVAL seconds and release VRAM after IDLE_TIMEOUT.

    A release is attempted only while the manager is HOT and at least one
    activity timestamp has been recorded.
    """
    while True:
        await asyncio.sleep(IDLE_CHECK_INTERVAL)

        # last_activity_at == 0 means no activity has been recorded yet.
        armed = mgr.state == State.HOT and mgr.last_activity_at != 0
        if not armed:
            continue

        quiet_for = time.monotonic() - mgr.last_activity_at
        if quiet_for < IDLE_TIMEOUT:
            continue

        await do_release(mgr)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Main
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def main() -> None:
    """Run the activity monitor and idle checker until a shutdown signal arrives.

    On SIGTERM/SIGINT both workers are cancelled and awaited, and the GPU is
    drained if a warmup is in progress or models are loaded. If a worker
    stops on its own, the same cleanup runs and the process exits with
    status 1 instead of idling without that worker.

    Raises:
        SystemExit: with code 1 when a worker stopped unexpectedly.
    """
    log.info("VRAM Lifecycle Manager starting")
    log.info(" model-boss: %s", MODEL_BOSS_URL)
    log.info(" api: %s", API_URL)
    log.info(" log file: %s", LOG_FILE)
    log.info(" idle timeout: %ds", IDLE_TIMEOUT)

    mgr = ManagerState()

    # Graceful shutdown. get_running_loop() is the supported way to obtain
    # the loop from inside a coroutine.
    loop = asyncio.get_running_loop()
    shutdown_event = asyncio.Event()

    def handle_signal():
        log.info("Shutdown signal received — draining GPU leases")
        shutdown_event.set()

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, handle_signal)

    # Run activity monitor and idle checker concurrently.
    monitor_task = asyncio.create_task(activity_monitor(mgr))
    idle_task = asyncio.create_task(idle_checker(mgr))
    shutdown_task = asyncio.create_task(shutdown_event.wait())
    workers = {monitor_task: "activity monitor", idle_task: "idle checker"}

    # Wake up on the shutdown signal, or as soon as a worker dies.
    done, _ = await asyncio.wait(
        {shutdown_task, *workers},
        return_when=asyncio.FIRST_COMPLETED,
    )

    crashed = False
    for task in done:
        if task not in workers:
            continue
        crashed = True
        reason = "cancelled" if task.cancelled() else repr(task.exception())
        log.error("%s stopped unexpectedly: %s", workers[task], reason)

    # Cancel everything and wait until the cancellations have been processed.
    everything = (shutdown_task, *workers)
    for task in everything:
        task.cancel()
    await asyncio.gather(*everything, return_exceptions=True)

    # Drain on shutdown if we have models loaded or loading.
    if mgr.state in (State.WARMING, State.HOT):
        await asyncio.to_thread(drain_gpu)

    log.info("VRAM Lifecycle Manager stopped")

    if crashed:
        raise SystemExit(1)
|
|
|
|
|
|
# Script entry point: run the manager on a fresh event loop until it returns.
if __name__ == "__main__":
    asyncio.run(main())
|