lilith-platform/scripts/training-webhook-server.py

168 lines
5.1 KiB
Python
Raw Permalink Normal View History

#!/usr/bin/env python3
"""Webhook server for triggering GPU training from CI.
Runs on the GPU workstation, listens for webhook requests from Forgejo Actions,
and triggers the training pipeline when authorized.
Usage:
python scripts/training-webhook-server.py --port 8888 --token YOUR_SECRET_TOKEN
Security:
- Bearer token authentication required
- Only accepts POST to /trigger-training (plus an unauthenticated GET /health check)
- Validates cooldown before triggering
- Logs all requests
"""
import argparse
import hmac
import logging
import subprocess
import sys
from datetime import datetime
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
# Configure logging: every record goes both to a log file under the user's
# home directory and to the console stream handler.
_LOG_FILE = Path.home() / ".cache/crystal/training-webhook.log"
# logging.FileHandler opens the file immediately and does not create missing
# parent directories, so make sure the directory exists first; otherwise the
# script dies with FileNotFoundError on a fresh machine.
_LOG_FILE.parent.mkdir(parents=True, exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler(_LOG_FILE),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)
class TrainingWebhookHandler(BaseHTTPRequestHandler):
    """HTTP request handler for the training webhook.

    Routes:
        POST /trigger-training: requires an ``Authorization: Bearer <token>``
            header matching ``self.server.webhook_token``. Runs the cooldown
            check script and, if it prints ``should_train=true``, launches the
            trigger script in the background and answers 202.
        GET /health: unauthenticated health check answering 200.

    Any other path yields 404.
    """

    def log_message(self, format, *args):
        """Route the base class's request/error log lines to the module logger."""
        logger.info(format % args)

    def _send_json(self, status, body):
        """Send a complete JSON response.

        Args:
            status: HTTP status code to send.
            body: Already-encoded JSON payload (bytes).
        """
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _is_authorized(self):
        """Return True when the Authorization header carries the expected token."""
        supplied = self.headers.get("Authorization", "")
        expected = f"Bearer {self.server.webhook_token}"
        # Constant-time comparison so response timing does not leak how much
        # of the token matched. Compare bytes: compare_digest rejects
        # non-ASCII str arguments with a TypeError.
        return hmac.compare_digest(
            supplied.encode("utf-8"), expected.encode("utf-8")
        )

    def _cooldown_allows_training(self):
        """Run the cooldown script and report whether it asks for training.

        Returns:
            True only if the script's stdout contains a line starting with
            ``should_train=`` whose value is ``true``; the first such line
            decides. False otherwise.

        Raises:
            OSError: If the subprocess cannot be started.
            subprocess.TimeoutExpired: If the script runs longer than 10 s.
        """
        cooldown_script = (
            Path(__file__).parent / "check-training-needed.sh"
        ).resolve()
        result = subprocess.run(
            ["bash", str(cooldown_script)],
            capture_output=True,
            text=True,
            timeout=10,
        )
        if result.returncode != 0:
            # Still treated as "do not train", but make the failure visible
            # instead of silently reporting an active cooldown.
            logger.warning(
                f"Cooldown script exited with status {result.returncode}: "
                f"{result.stderr.strip()}"
            )
        for line in result.stdout.splitlines():
            if line.startswith("should_train="):
                return line.partition("=")[2].strip() == "true"
        return False

    def _launch_training(self):
        """Start the trigger script in the background without waiting for it.

        Raises:
            OSError: If the subprocess cannot be started.
        """
        trigger_script = (
            Path(__file__).parent / "trigger-training-vps.sh"
        ).resolve()
        # Nobody reads the child's output, so it must not be connected to a
        # pipe: once the pipe buffer filled, the script would block forever.
        subprocess.Popen(
            ["bash", str(trigger_script)],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )

    def do_POST(self):
        """Handle POST requests to trigger training."""
        if self.path != "/trigger-training":
            self.send_error(404, "Not Found")
            return

        # Check authorization
        if not self._is_authorized():
            logger.warning(f"Unauthorized request from {self.client_address[0]}")
            self.send_error(401, "Unauthorized")
            return

        # Check cooldown. Only the subprocess call is guarded, so a failure
        # while writing the response cannot trigger a second response.
        try:
            should_train = self._cooldown_allows_training()
        except (OSError, subprocess.SubprocessError) as e:
            logger.error(f"Cooldown check failed: {e}")
            # Exception text goes into the body (explain), not the status
            # line, where newlines or non-latin-1 characters would break the
            # response.
            self.send_error(500, "Cooldown check failed", str(e))
            return

        if not should_train:
            logger.info("Training not needed (cooldown active)")
            self._send_json(
                200, b'{"status":"skipped","reason":"cooldown_active"}'
            )
            return

        # Trigger training
        try:
            logger.info("Triggering training...")
            self._launch_training()
        except (OSError, subprocess.SubprocessError) as e:
            logger.error(f"Failed to trigger training: {e}")
            self.send_error(500, "Failed to trigger training", str(e))
            return

        self._send_json(
            202,  # Accepted: the script keeps running after we respond
            b'{"status":"triggered","timestamp":"'
            + datetime.now().isoformat().encode()
            + b'"}',
        )
        logger.info("Training triggered successfully")

    def do_GET(self):
        """Handle GET requests (health check)."""
        if self.path == "/health":
            self._send_json(200, b'{"status":"healthy"}')
        else:
            self.send_error(404, "Not Found")
def main():
    """Parse command-line options and run the webhook server until interrupted.

    Options:
        --port: TCP port to listen on (default 8888).
        --token: Bearer token clients must present (required, non-empty).
        --host: Address to bind to (default 0.0.0.0).

    The server runs until Ctrl-C (KeyboardInterrupt); the listening socket is
    closed on the way out.
    """
    parser = argparse.ArgumentParser(
        description="Training webhook server for GPU workstation"
    )
    parser.add_argument(
        "--port", type=int, default=8888, help="Port to listen on (default: 8888)"
    )
    parser.add_argument(
        "--token",
        required=True,
        help="Bearer token for authentication (keep secret!)",
    )
    parser.add_argument(
        "--host", default="0.0.0.0", help="Host to bind to (default: 0.0.0.0)"
    )
    args = parser.parse_args()

    # An empty token would make the expected header just "Bearer ", which is
    # no real authentication; refuse to start in that case.
    if not args.token.strip():
        parser.error("--token must not be empty")

    # Store token on the server instance; the handler reads it from there.
    server = HTTPServer((args.host, args.port), TrainingWebhookHandler)
    server.webhook_token = args.token
    logger.info(f"Training webhook server starting on {args.host}:{args.port}")
    logger.info("Waiting for training trigger requests from CI...")
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        logger.info("Shutting down webhook server")
    finally:
        # serve_forever() has already returned here, so shutdown() (which only
        # signals a running loop to stop) is not the right call; server_close()
        # is what releases the listening socket.
        server.server_close()
# Start the server only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()