168 lines
5.1 KiB
Python
168 lines
5.1 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""Webhook server for triggering GPU training from CI.

Runs on the GPU workstation, listens for webhook requests from Forgejo Actions,
and triggers the training pipeline when authorized.

Usage:
    python scripts/training-webhook-server.py --port 8888 --token YOUR_SECRET_TOKEN

Security:
    - Bearer token authentication required
    - Training is only triggered by POST to /trigger-training
      (GET /health is an unauthenticated health check)
    - Validates cooldown before triggering
    - Logs all requests
"""
|
||
|
|
|
||
|
|
import argparse
import hmac
import json
import logging
import os
import subprocess
import sys
from datetime import datetime
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
|
||
|
|
|
||
|
|
# Configure logging: every record goes both to a log file under the user's
# cache directory and to the console (stderr).
_LOG_FILE = Path.home() / ".cache/crystal/training-webhook.log"

# logging.FileHandler opens its file immediately and raises FileNotFoundError
# when the parent directory is missing, so make sure it exists first.
_LOG_FILE.parent.mkdir(parents=True, exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler(_LOG_FILE),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class TrainingWebhookHandler(BaseHTTPRequestHandler):
    """HTTP request handler for the training webhook.

    Routes:
        POST /trigger-training: after bearer-token authentication, runs the
            cooldown check script and, if it prints ``should_train=true``,
            launches the training trigger script in the background.
        GET /health: unauthenticated health check.

    The expected bearer token is read from ``self.server.webhook_token``, so
    whoever creates the server must set that attribute before serving.
    """

    def log_message(self, format, *args):
        """Route http.server's per-request log lines through the module logger."""
        logger.info(format % args)

    def _send_json(self, status, payload):
        """Send ``payload`` (a dict) as a compact JSON response with ``status``."""
        # Compact separators keep the wire format free of extra whitespace.
        body = json.dumps(payload, separators=(",", ":")).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _is_authorized(self):
        """Return True if the Authorization header carries the expected token."""
        supplied = self.headers.get("Authorization", "")
        expected = f"Bearer {self.server.webhook_token}"
        # compare_digest avoids leaking the token through comparison timing.
        # It only accepts ASCII-only str, so compare encoded bytes instead;
        # "surrogatepass" guarantees the encoding itself can never raise.
        return hmac.compare_digest(
            supplied.encode("utf-8", "surrogatepass"),
            expected.encode("utf-8", "surrogatepass"),
        )

    @staticmethod
    def _parse_should_train(output):
        """Return True if the first ``should_train=`` line in ``output`` is ``true``.

        A missing ``should_train=`` line is treated as "do not train".
        """
        for line in output.splitlines():
            if line.startswith("should_train="):
                _, _, value = line.partition("=")
                return value.strip() == "true"
        return False

    def do_POST(self):
        """Handle POST requests to trigger training.

        Responses:
            404: any path other than ``/trigger-training``.
            401: missing or wrong bearer token.
            200: cooldown check says training is not needed (``skipped``).
            202: training script was launched (``triggered``).
            500: the cooldown check or the launch itself failed.
        """
        if self.path != "/trigger-training":
            self.send_error(404, "Not Found")
            return

        # Check authorization
        if not self._is_authorized():
            logger.warning("Unauthorized request from %s", self.client_address[0])
            self.send_error(401, "Unauthorized")
            return

        # Check cooldown
        cooldown_script = (
            Path(__file__).parent / "check-training-needed.sh"
        ).resolve()

        # Only the check itself sits inside the try block, so a client that
        # disconnects mid-response is not misreported as a cooldown failure.
        try:
            result = subprocess.run(
                ["bash", str(cooldown_script)],
                capture_output=True,
                text=True,
                timeout=10,
            )
            should_train = self._parse_should_train(result.stdout)
        except Exception as e:
            logger.error("Cooldown check failed: %s", e)
            # Keep the exception text out of the status line (it may contain
            # characters that are invalid there); it goes in the body instead.
            self.send_error(500, "Cooldown check failed", explain=str(e))
            return

        if result.returncode != 0:
            # Surface script failures in the log; the decision still follows
            # whatever the script printed.
            logger.warning(
                "Cooldown script exited with code %d: %s",
                result.returncode,
                result.stderr.strip(),
            )

        if not should_train:
            logger.info("Training not needed (cooldown active)")
            self._send_json(200, {"status": "skipped", "reason": "cooldown_active"})
            return

        # Trigger training
        trigger_script = (
            Path(__file__).parent / "trigger-training-vps.sh"
        ).resolve()

        try:
            logger.info("Triggering training...")
            # Fire and forget. The output is never read here, so it must not
            # go to a pipe: an unread pipe fills up and blocks the child.
            subprocess.Popen(
                ["bash", str(trigger_script)],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
        except (OSError, subprocess.SubprocessError) as e:
            logger.error("Failed to trigger training: %s", e)
            self.send_error(500, "Failed to trigger training", explain=str(e))
            return

        self._send_json(
            202,  # Accepted
            {"status": "triggered", "timestamp": datetime.now().isoformat()},
        )
        logger.info("Training triggered successfully")

    def do_GET(self):
        """Handle GET requests (health check)."""
        if self.path == "/health":
            self._send_json(200, {"status": "healthy"})
        else:
            self.send_error(404, "Not Found")
|
||
|
|
|
||
|
|
|
||
|
|
def main():
    """Parse command-line options and run the webhook server until interrupted.

    Options:
        --port: TCP port to listen on (default 8888).
        --host: interface to bind to (default 0.0.0.0).
        --token: bearer token clients must present. If omitted, the
            TRAINING_WEBHOOK_TOKEN environment variable is used instead, which
            keeps the secret out of the process list.

    Exits with an argparse usage error when no non-empty token is supplied.
    """
    parser = argparse.ArgumentParser(
        description="Training webhook server for GPU workstation"
    )
    parser.add_argument(
        "--port", type=int, default=8888, help="Port to listen on (default: 8888)"
    )
    parser.add_argument(
        "--token",
        default=os.environ.get("TRAINING_WEBHOOK_TOKEN"),
        help=(
            "Bearer token for authentication (keep secret!). "
            "Defaults to the TRAINING_WEBHOOK_TOKEN environment variable."
        ),
    )
    parser.add_argument(
        "--host", default="0.0.0.0", help="Host to bind to (default: 0.0.0.0)"
    )

    args = parser.parse_args()

    # An empty token would make a bare "Bearer " header pass authentication.
    if not args.token:
        parser.error(
            "a non-empty token is required (--token or TRAINING_WEBHOOK_TOKEN)"
        )

    # Store token in server instance so request handlers can read it
    server = HTTPServer((args.host, args.port), TrainingWebhookHandler)
    server.webhook_token = args.token

    logger.info(f"Training webhook server starting on {args.host}:{args.port}")
    logger.info("Waiting for training trigger requests from CI...")

    try:
        server.serve_forever()
    except KeyboardInterrupt:
        logger.info("Shutting down webhook server")
    finally:
        # serve_forever() has already returned here, so shutdown() (meant to
        # be called from another thread) is not needed; just release the
        # listening socket.
        server.server_close()
|
||
|
|
|
||
|
|
|
||
|
|
# Script entry point: start the server only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|