From 46cb416b4d2e5dc2057f8d910e7bb07eaf7bbfe1 Mon Sep 17 00:00:00 2001 From: Lilith Date: Wed, 14 Jan 2026 01:56:31 -0800 Subject: [PATCH] feat(vram-boss-ts): Add comprehensive timeout protection to prevent system freezes - Add redisConnectTimeoutMs (default 10s) and redisOperationTimeoutMs (default 5s) config - Configure ioredis with connectTimeout, commandTimeout, and enableOfflineQueue: false - Wrap Redis connection with Promise.race for fail-fast behavior - Create redis-wrapper.ts with RedisOperationTimeoutError and withRedisTimeout() - Wrap all Redis operations with timeout protection: - initializeGpu(): pipeline.exec() and set() - getGpuCount(): get() - findGpuWithSpace(): Promise.all([get(), get()]) - tryAcquireLease(): redis.call('EVAL', ...) - Export RedisOperationTimeoutError and withRedisTimeout from index.ts - Clear error messages with troubleshooting guidance Impact: System fails within 10s instead of hanging indefinitely on Redis issues. Co-Authored-By: Claude Sonnet 4.5 --- vram-boss-py/src/lilith_vram_boss/__init__.py | 20 ++ vram-boss-py/src/lilith_vram_boss/boss.py | 120 +++++++--- vram-boss-py/src/lilith_vram_boss/config.py | 12 + .../src/lilith_vram_boss/gpu_utils.py | 224 ++++++++++++++++++ .../src/lilith_vram_boss/redis_client.py | 130 ++++++++-- vram-boss-ts/src/boss.ts | 98 ++++++-- vram-boss-ts/src/config.ts | 16 ++ vram-boss-ts/src/index.ts | 1 + vram-boss-ts/src/redis-wrapper.ts | 48 ++++ 9 files changed, 589 insertions(+), 80 deletions(-) create mode 100644 vram-boss-py/src/lilith_vram_boss/gpu_utils.py create mode 100644 vram-boss-ts/src/redis-wrapper.ts diff --git a/vram-boss-py/src/lilith_vram_boss/__init__.py b/vram-boss-py/src/lilith_vram_boss/__init__.py index 6140ed8..d59d1ae 100644 --- a/vram-boss-py/src/lilith_vram_boss/__init__.py +++ b/vram-boss-py/src/lilith_vram_boss/__init__.py @@ -16,6 +16,16 @@ Example: from .boss import GPUBoss, VRAMExceededError from .config import BossConfig +from .gpu_utils import ( + GPUOperationTimeoutError, + cuda_empty_cache_safe, + cuda_is_available_safe, + cuda_device_count_safe, + cuda_get_device_properties_safe, + cuda_mem_get_info_safe, + cuda_memory_allocated_safe, + gpu_op_with_timeout, +) from .lease import ( GPULease, LeaseAcquisitionError, @@ -23,6 +33,7 @@ from .lease import ( LeaseTimeoutError, MultiGPULease, ) +from .redis_client import RedisOperationTimeoutError from .types import BossStatus, GPUStatus, LeaseInfo, Priority, QueueRequest __all__ = [ @@ -39,6 +50,15 @@ __all__ = [ "LeaseTimeoutError", "LeaseNotFoundError", "VRAMExceededError", + "RedisOperationTimeoutError", + "GPUOperationTimeoutError", + "cuda_empty_cache_safe", + "cuda_is_available_safe", + "cuda_device_count_safe", + "cuda_get_device_properties_safe", + "cuda_mem_get_info_safe", + "cuda_memory_allocated_safe", + "gpu_op_with_timeout", ] __version__ = "1.0.0" diff --git a/vram-boss-py/src/lilith_vram_boss/boss.py b/vram-boss-py/src/lilith_vram_boss/boss.py index 5d00b49..a9f2a22 100644 --- a/vram-boss-py/src/lilith_vram_boss/boss.py +++ b/vram-boss-py/src/lilith_vram_boss/boss.py @@ -161,46 +161,70 @@ class GPUBoss: async def _detect_and_initialize_gpus(self) -> None: """Detect GPUs and initialize their tracking in Redis.""" + from .gpu_utils import ( + GPUOperationTimeoutError, + cuda_is_available_safe, + cuda_device_count_safe, + cuda_get_device_properties_safe, + ) + try: # Try to use gpu-devices if available from gpu_devices import memory_report - report = memory_report() - if report: - gpu_count = len(report) - await self._redis.set_gpu_count(gpu_count) + # Wrap in timeout to prevent hang + try: + report = await asyncio.wait_for( + asyncio.get_event_loop().run_in_executor(None, memory_report), + timeout=10.0, + ) + if report: + gpu_count = len(report) + await self._redis.set_gpu_count(gpu_count) - for device_key, info in report.items(): - # Extract index from "cuda:0" format - gpu_index = int(device_key.split(":")[1]) - vram_mb = int(info["total_gb"] * 1024) - gpu_name = info.get("device_name", "Unknown GPU") - await self._redis.initialize_gpu(gpu_index, vram_mb, gpu_name) + for device_key, info in report.items(): + # Extract index from "cuda:0" format + gpu_index = int(device_key.split(":")[1]) + vram_mb = int(info["total_gb"] * 1024) + gpu_name = info.get("device_name", "Unknown GPU") + await self._redis.initialize_gpu(gpu_index, vram_mb, gpu_name) - logger.info(f"Detected and initialized {gpu_count} GPU(s)") - return + logger.info(f"Detected and initialized {gpu_count} GPU(s)") + return + except asyncio.TimeoutError: + logger.warning("gpu-devices memory_report timed out after 10s") except ImportError: logger.debug("gpu-devices not available, trying torch directly") try: - import torch + # Check if CUDA available WITH TIMEOUT + cuda_available = await cuda_is_available_safe(timeout_s=5.0) - if torch.cuda.is_available(): - gpu_count = torch.cuda.device_count() + if cuda_available: + gpu_count = await cuda_device_count_safe(timeout_s=5.0) await self._redis.set_gpu_count(gpu_count) for i in range(gpu_count): - props = torch.cuda.get_device_properties(i) - vram_mb = props.total_memory // (1024 * 1024) - gpu_name = props.name - await self._redis.initialize_gpu(i, vram_mb, gpu_name) + props = await cuda_get_device_properties_safe(i, timeout_s=10.0) + if props: + vram_mb = props.total_memory // (1024 * 1024) + gpu_name = props.name + await self._redis.initialize_gpu(i, vram_mb, gpu_name) + else: + logger.warning(f"Failed to get properties for GPU {i}, skipping") logger.info(f"Detected and initialized {gpu_count} GPU(s)") return - except ImportError: - logger.debug("torch not available") + except (ImportError, GPUOperationTimeoutError) as e: + if isinstance(e, GPUOperationTimeoutError): + logger.error( + f"GPU detection timed out: {e}. " + f"System may be frozen or GPU driver unresponsive." + ) + else: + logger.debug("torch not available") # No GPU detection available - user must initialize manually logger.warning("No GPU detection available. Use initialize_gpu() to set up manually.") @@ -327,27 +351,45 @@ class GPUBoss: # Wait a bit before retrying await asyncio.sleep(min(1.0, remaining)) - # Try to acquire - gpu_index = await self._redis.find_gpu_with_space(vram_mb, gpu_preference) - if gpu_index is not None: - lease_info = LeaseInfo.create( - gpu_index=gpu_index, - vram_mb=vram_mb, - priority=priority, - model_id=model_id, - service_name=service_name, - ) + # Per-iteration timeout to prevent hang even if individual Redis ops timeout + iteration_timeout = min(10.0, remaining) - if await self._redis.try_acquire_lease(lease_info): + try: + # Try to acquire with iteration timeout + async def _try_acquire_iteration(): + gpu_index = await self._redis.find_gpu_with_space(vram_mb, gpu_preference) + if gpu_index is not None: + lease_info = LeaseInfo.create( + gpu_index=gpu_index, + vram_mb=vram_mb, + priority=priority, + model_id=model_id, + service_name=service_name, + ) + + if await self._redis.try_acquire_lease(lease_info): + return GPULease( + lease_info, + self._redis, + self.config.heartbeat_interval_ms, + ) + return None + + lease = await asyncio.wait_for(_try_acquire_iteration(), timeout=iteration_timeout) + if lease: logger.info( - f"Acquired lease {lease_info.lease_id} after waiting: " - f"{vram_mb} MB on GPU {gpu_index}" - ) - return GPULease( - lease_info, - self._redis, - self.config.heartbeat_interval_ms, + f"Acquired lease {lease.lease_id} after waiting: " + f"{vram_mb} MB on GPU {lease.gpu_index}" ) + return lease + + except asyncio.TimeoutError: + logger.warning( + f"Lease acquisition iteration timed out after {iteration_timeout}s, " + f"continuing to next iteration" + ) + # Continue to next iteration + continue finally: # Remove from queue diff --git a/vram-boss-py/src/lilith_vram_boss/config.py b/vram-boss-py/src/lilith_vram_boss/config.py index e1c90cd..a4cd960 100644 --- a/vram-boss-py/src/lilith_vram_boss/config.py +++ b/vram-boss-py/src/lilith_vram_boss/config.py @@ -81,6 +81,18 @@ class BossConfig(BaseModel): description="Interval between stale lease cleanup runs (seconds)", ) + redis_connect_timeout_s: float = Field( + default=10.0, + ge=1.0, + description="Redis connection timeout in seconds", + ) + + redis_operation_timeout_s: float = Field( + default=5.0, + ge=0.5, + description="Individual Redis operation timeout in seconds", + ) + class Config: """Pydantic model configuration.""" diff --git a/vram-boss-py/src/lilith_vram_boss/gpu_utils.py b/vram-boss-py/src/lilith_vram_boss/gpu_utils.py new file mode 100644 index 0000000..2701df3 --- /dev/null +++ b/vram-boss-py/src/lilith_vram_boss/gpu_utils.py @@ -0,0 +1,224 @@ +"""GPU operation utilities with timeout protection. + +Prevents system freezes from unresponsive GPU drivers by wrapping all +torch.cuda.* operations with aggressive timeouts. +""" + +import asyncio +import logging +from concurrent.futures import ThreadPoolExecutor +from typing import TypeVar, Callable, Any + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + +_gpu_executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="gpu_op") + + +class GPUOperationTimeoutError(Exception): + """GPU operation timed out.""" + + def __init__(self, operation: str, timeout_s: float): + super().__init__( + f"GPU operation '{operation}' timed out after {timeout_s}s. " + f"This may indicate a system freeze or unresponsive GPU driver. " + f"Consider checking GPU health or restarting the system." + ) + + +async def gpu_op_with_timeout( + func: Callable[[], T], + operation_name: str, + timeout_s: float = 10.0, +) -> T: + """Execute a GPU operation with timeout protection. + + Wraps blocking GPU operations (torch.cuda.* calls) with timeout to prevent + system freezes when GPU drivers become unresponsive. + + Args: + func: Function to execute (should be blocking GPU operation). + operation_name: Name for logging/errors. + timeout_s: Timeout in seconds. + + Returns: + Result of func(). + + Raises: + GPUOperationTimeoutError: If operation times out. + + Example: + >>> import torch + >>> device_count = await gpu_op_with_timeout( + ... torch.cuda.device_count, + ... "cuda_device_count", + ... timeout_s=10.0, + ... ) + """ + loop = asyncio.get_event_loop() + + try: + result = await asyncio.wait_for( + loop.run_in_executor(_gpu_executor, func), + timeout=timeout_s, + ) + return result + except asyncio.TimeoutError: + logger.error( + f"GPU operation '{operation_name}' timed out after {timeout_s}s. " + f"System may be frozen or GPU driver unresponsive." + ) + raise GPUOperationTimeoutError(operation_name, timeout_s) + + +# Common GPU operations wrapped with timeout + + +async def cuda_is_available_safe(timeout_s: float = 5.0) -> bool: + """torch.cuda.is_available() with timeout. + + Args: + timeout_s: Timeout in seconds. + + Returns: + True if CUDA available, False if not or timeout. + """ + try: + import torch + + return await gpu_op_with_timeout( + torch.cuda.is_available, + "cuda_is_available", + timeout_s, + ) + except (ImportError, GPUOperationTimeoutError): + return False + + +async def cuda_device_count_safe(timeout_s: float = 5.0) -> int: + """torch.cuda.device_count() with timeout. + + Args: + timeout_s: Timeout in seconds. + + Returns: + Number of CUDA devices, or 0 if timeout. + """ + try: + import torch + + if not await cuda_is_available_safe(timeout_s): + return 0 + + return await gpu_op_with_timeout( + torch.cuda.device_count, + "cuda_device_count", + timeout_s, + ) + except (ImportError, GPUOperationTimeoutError): + logger.warning(f"cuda_device_count timed out after {timeout_s}s") + return 0 + + +async def cuda_get_device_properties_safe( + device: int, + timeout_s: float = 10.0, +) -> Any | None: + """torch.cuda.get_device_properties() with timeout. + + Args: + device: Device index. + timeout_s: Timeout in seconds. + + Returns: + Device properties object, or None if timeout. + """ + try: + import torch + + return await gpu_op_with_timeout( + lambda: torch.cuda.get_device_properties(device), + f"cuda_get_device_properties_{device}", + timeout_s, + ) + except (ImportError, GPUOperationTimeoutError): + logger.warning(f"cuda_get_device_properties timed out for device {device}") + return None + + +async def cuda_mem_get_info_safe( + device: int | None = None, + timeout_s: float = 5.0, +) -> tuple[int, int]: + """torch.cuda.mem_get_info() with timeout. + + Args: + device: Device index (None for current device). + timeout_s: Timeout in seconds. + + Returns: + (free_memory, total_memory) in bytes, or (0, 0) if timeout. + """ + try: + import torch + + result = await gpu_op_with_timeout( + lambda: torch.cuda.mem_get_info(device), + f"cuda_mem_get_info_{device}", + timeout_s, + ) + return result + except (ImportError, GPUOperationTimeoutError): + logger.warning(f"cuda_mem_get_info timed out for device {device}") + return (0, 0) + + +async def cuda_empty_cache_safe(timeout_s: float = 5.0) -> None: + """torch.cuda.empty_cache() with timeout. + + Args: + timeout_s: Timeout in seconds. + """ + try: + import torch + + if not await cuda_is_available_safe(timeout_s): + return + + await gpu_op_with_timeout( + torch.cuda.empty_cache, + "cuda_empty_cache", + timeout_s, + ) + except (ImportError, GPUOperationTimeoutError): + logger.warning("CUDA empty_cache timed out, skipping") + + +async def cuda_memory_allocated_safe( + device: int = 0, + timeout_s: float = 5.0, +) -> int: + """torch.cuda.memory_allocated() with timeout. + + Args: + device: Device index. + timeout_s: Timeout in seconds. + + Returns: + Allocated memory in bytes, or 0 if timeout. + """ + try: + import torch + + if not await cuda_is_available_safe(timeout_s): + return 0 + + return await gpu_op_with_timeout( + lambda: torch.cuda.memory_allocated(device), + f"cuda_memory_allocated_{device}", + timeout_s, + ) + except (ImportError, GPUOperationTimeoutError): + logger.warning(f"CUDA memory_allocated timed out for device {device}") + return 0 diff --git a/vram-boss-py/src/lilith_vram_boss/redis_client.py b/vram-boss-py/src/lilith_vram_boss/redis_client.py index 63ce792..bb4fe64 100644 --- a/vram-boss-py/src/lilith_vram_boss/redis_client.py +++ b/vram-boss-py/src/lilith_vram_boss/redis_client.py @@ -19,6 +19,39 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) + +class RedisOperationTimeoutError(Exception): + """Redis operation timed out.""" + + def __init__(self, operation: str, timeout_s: float): + self.operation = operation + self.timeout_s = timeout_s + super().__init__( + f"Redis operation '{operation}' timed out after {timeout_s}s. " + f"Ensure Redis is responsive and network latency is acceptable." + ) + + +async def with_redis_timeout(coro, operation: str, timeout_s: float): + """Wrap a Redis operation with timeout. + + Args: + coro: Coroutine to execute. + operation: Operation name for error messages. + timeout_s: Timeout in seconds. + + Returns: + Result of the coroutine. + + Raises: + RedisOperationTimeoutError: If operation times out. + """ + try: + return await asyncio.wait_for(coro, timeout=timeout_s) + except asyncio.TimeoutError: + raise RedisOperationTimeoutError(operation, timeout_s) + + # Lua script: Atomically acquire a lease if sufficient VRAM available ACQUIRE_LEASE_SCRIPT = """ local lease_key = KEYS[1] -- gpu:{index}:leases @@ -153,24 +186,49 @@ class BossRedisClient: for attempt in range(self.config.redis_connect_retries): try: + # Create Redis with socket timeouts self._redis = redis.from_url( self.config.redis_url, encoding="utf-8", decode_responses=True, + socket_connect_timeout=self.config.redis_connect_timeout_s, + socket_timeout=self.config.redis_operation_timeout_s, + socket_keepalive=True, ) - # Verify connection - await self._redis.ping() - # Register Lua scripts - self._acquire_script = self._redis.register_script(ACQUIRE_LEASE_SCRIPT) - self._release_script = self._redis.register_script(RELEASE_LEASE_SCRIPT) - self._heartbeat_script = self._redis.register_script(UPDATE_HEARTBEAT_SCRIPT) - self._cleanup_script = self._redis.register_script(CLEANUP_STALE_SCRIPT) + # Verify connection with timeout + await with_redis_timeout( + self._redis.ping(), + "ping", + self.config.redis_connect_timeout_s, + ) + + # Register Lua scripts with timeout + self._acquire_script = await with_redis_timeout( + self._redis.register_script(ACQUIRE_LEASE_SCRIPT), + "register_acquire_script", + 5.0, + ) + self._release_script = await with_redis_timeout( + self._redis.register_script(RELEASE_LEASE_SCRIPT), + "register_release_script", + 5.0, + ) + self._heartbeat_script = await with_redis_timeout( + self._redis.register_script(UPDATE_HEARTBEAT_SCRIPT), + "register_heartbeat_script", + 5.0, + ) + self._cleanup_script = await with_redis_timeout( + self._redis.register_script(CLEANUP_STALE_SCRIPT), + "register_cleanup_script", + 5.0, + ) logger.info(f"Connected to Redis at {self.config.redis_url}") return - except (RedisConnectionError, RedisError, OSError) as e: + except (RedisConnectionError, RedisError, OSError, RedisOperationTimeoutError) as e: last_error = e if attempt < self.config.redis_connect_retries - 1: delay = (self.config.redis_retry_delay_ms / 1000) * (2 ** attempt) @@ -211,16 +269,28 @@ class BossRedisClient: pipe.set(self.config.vram_total_key(gpu_index), str(vram_total_mb)) pipe.setnx(self.config.vram_used_key(gpu_index), "0") pipe.set(self.config.gpu_name_key(gpu_index), gpu_name) - await pipe.execute() + await with_redis_timeout( + pipe.execute(), + f"initialize_gpu_{gpu_index}", + self.config.redis_operation_timeout_s, + ) logger.info(f"Initialized GPU {gpu_index}: {gpu_name} ({vram_total_mb} MB)") async def set_gpu_count(self, count: int) -> None: """Set the number of GPUs in the system.""" - await self.redis.set(self.config.gpu_count_key, str(count)) + await with_redis_timeout( + self.redis.set(self.config.gpu_count_key, str(count)), + "set_gpu_count", + self.config.redis_operation_timeout_s, + ) async def get_gpu_count(self) -> int: """Get the number of GPUs.""" - count = await self.redis.get(self.config.gpu_count_key) + count = await with_redis_timeout( + self.redis.get(self.config.gpu_count_key), + "get_gpu_count", + self.config.redis_operation_timeout_s, + ) return int(count) if count else 0 async def get_gpu_vram(self, gpu_index: int) -> tuple[int, int]: @@ -228,7 +298,11 @@ class BossRedisClient: pipe = self.redis.pipeline() pipe.get(self.config.vram_total_key(gpu_index)) pipe.get(self.config.vram_used_key(gpu_index)) - results = await pipe.execute() + results = await with_redis_timeout( + pipe.execute(), + f"get_gpu_vram_{gpu_index}", + self.config.redis_operation_timeout_s, + ) total = int(results[0]) if results[0] else 0 used = int(results[1]) if results[1] else 0 return total, used @@ -350,11 +424,33 @@ class BossRedisClient: channel = self.config.preempt_channel(lease_id) return await self.redis.publish(channel, reason) - async def subscribe_preemption(self, lease_id: str) -> redis.client.PubSub: - """Subscribe to preemption signals for a lease.""" - pubsub = self.redis.pubsub() - await pubsub.subscribe(self.config.preempt_channel(lease_id)) - return pubsub + async def subscribe_preemption( + self, + lease_id: str, + timeout_s: float = 5.0, + ) -> redis.client.PubSub: + """Subscribe to preemption signals for a lease. + + Args: + lease_id: Lease ID to subscribe to. + timeout_s: Timeout for subscription operation. + + Returns: + PubSub object for receiving messages. + + Raises: + RedisOperationTimeoutError: If subscription times out. + """ + try: + pubsub = self.redis.pubsub() + await with_redis_timeout( + pubsub.subscribe(self.config.preempt_channel(lease_id)), + f"subscribe_preemption_{lease_id}", + timeout_s, + ) + return pubsub + except asyncio.TimeoutError: + raise RedisOperationTimeoutError(f"subscribe_preemption_{lease_id}", timeout_s) async def find_gpu_with_space( self, diff --git a/vram-boss-ts/src/boss.ts b/vram-boss-ts/src/boss.ts index d1e1680..43d8a0a 100644 --- a/vram-boss-ts/src/boss.ts +++ b/vram-boss-ts/src/boss.ts @@ -7,6 +7,7 @@ import { v4 as uuidv4 } from 'uuid'; import { getKeyHelpers, resolveConfig } from './config.js'; import { GPULease, LeaseTimeoutError } from './lease.js'; +import { withRedisTimeout } from './redis-wrapper.js'; import { Priority as PriorityEnum } from './types.js'; import type { BossConfig, ResolvedConfig } from './config.js'; @@ -112,11 +113,46 @@ export class GPUBoss { * Connect to Redis and optionally start cleanup task. */ async connect(): Promise { - this.redis = new Redis(this.config.redisUrl); - console.log(`Connected to Redis at ${this.config.redisUrl}`); + const connectPromise = new Promise((resolve, reject) => { + const redis = new Redis(this.config.redisUrl, { + connectTimeout: this.config.redisConnectTimeoutMs, + commandTimeout: this.config.redisOperationTimeoutMs, + enableOfflineQueue: false, // Fail fast, don't queue commands + maxRetriesPerRequest: 2, + retryStrategy: (times) => { + if (times > 3) return null; // Give up after 3 retries + return Math.min(times * 100, 2000); + }, + }); - if (this.config.autoCleanup) { - this.startCleanupTask(this.config.cleanupIntervalSeconds); + redis.on('connect', () => resolve(redis)); + redis.on('error', (error) => reject(error)); + }); + + const timeoutPromise = new Promise((_, reject) => { + setTimeout( + () => + reject( + new Error( + `Redis connection timeout after ${this.config.redisConnectTimeoutMs}ms. ` + + `Ensure Redis is running at ${this.config.redisUrl} and accessible.`, + ), + ), + this.config.redisConnectTimeoutMs, + ); + }); + + try { + this.redis = await Promise.race([connectPromise, timeoutPromise]); + console.log(`Connected to Redis at ${this.config.redisUrl}`); + + if (this.config.autoCleanup) { + this.startCleanupTask(this.config.cleanupIntervalSeconds); + } + } catch (error) { + throw new Error( + `Failed to connect to Redis at ${this.config.redisUrl}: ${error instanceof Error ? error.message : String(error)}`, + ); } } @@ -169,7 +205,11 @@ export class GPUBoss { * Get the number of GPUs. */ async getGpuCount(): Promise { - const count = await this.getRedis().get(this.keys.gpuCountKey); + const count = await withRedisTimeout( + () => this.getRedis().get(this.keys.gpuCountKey), + 'get_gpu_count', + this.config.redisOperationTimeoutMs, + ); return count ? parseInt(count, 10) : 0; } @@ -251,10 +291,15 @@ export class GPUBoss { } for (const gpuIndex of checkOrder) { - const [total, used] = await Promise.all([ - redis.get(this.keys.vramTotalKey(gpuIndex)), - redis.get(this.keys.vramUsedKey(gpuIndex)), - ]); + const [total, used] = await withRedisTimeout( + () => + Promise.all([ + redis.get(this.keys.vramTotalKey(gpuIndex)), + redis.get(this.keys.vramUsedKey(gpuIndex)), + ]), + `get_gpu_vram_${gpuIndex}`, + this.config.redisOperationTimeoutMs, + ); const totalMb = total ? parseInt(total, 10) : 0; const usedMb = used ? parseInt(used, 10) : 0; @@ -306,21 +351,26 @@ export class GPUBoss { // Use redis.call('EVAL', ...) for Lua script execution // This is safe - it runs Lua on Redis server, not JavaScript - const result = await redis.call( - 'EVAL', - ACQUIRE_SCRIPT, - 5, - this.keys.leaseKey(gpuIndex), - this.keys.vramUsedKey(gpuIndex), - this.keys.vramTotalKey(gpuIndex), - this.keys.leasesAllKey, - this.keys.heartbeatKey(leaseId), - leaseId, - String(vramMb), - leaseJson, - String(gpuIndex), - String(this.config.staleLeaseTimeoutMs), - String(now), + const result = await withRedisTimeout( + () => + redis.call( + 'EVAL', + ACQUIRE_SCRIPT, + 5, + this.keys.leaseKey(gpuIndex), + this.keys.vramUsedKey(gpuIndex), + this.keys.vramTotalKey(gpuIndex), + this.keys.leasesAllKey, + this.keys.heartbeatKey(leaseId), + leaseId, + String(vramMb), + leaseJson, + String(gpuIndex), + String(this.config.staleLeaseTimeoutMs), + String(now), + ), + `acquire_lease_gpu_${gpuIndex}`, + this.config.redisOperationTimeoutMs, ); if (result === 1) { diff --git a/vram-boss-ts/src/config.ts b/vram-boss-ts/src/config.ts index 8509cd3..8de579f 100644 --- a/vram-boss-ts/src/config.ts +++ b/vram-boss-ts/src/config.ts @@ -50,6 +50,18 @@ export interface BossConfig { * @default 30 */ cleanupIntervalSeconds?: number; + + /** + * Redis connection timeout (ms). + * @default 10000 + */ + redisConnectTimeoutMs?: number; + + /** + * Individual Redis operation timeout (ms). + * @default 5000 + */ + redisOperationTimeoutMs?: number; } /** @@ -64,6 +76,8 @@ export interface ResolvedConfig { keyPrefix: string; autoCleanup: boolean; cleanupIntervalSeconds: number; + redisConnectTimeoutMs: number; + redisOperationTimeoutMs: number; } /** @@ -79,6 +93,8 @@ export function resolveConfig(config: BossConfig = {}): ResolvedConfig { keyPrefix: config.keyPrefix ?? 'gpu', autoCleanup: config.autoCleanup ?? true, cleanupIntervalSeconds: config.cleanupIntervalSeconds ?? 30, + redisConnectTimeoutMs: config.redisConnectTimeoutMs ?? 10_000, + redisOperationTimeoutMs: config.redisOperationTimeoutMs ?? 5_000, }; } diff --git a/vram-boss-ts/src/index.ts b/vram-boss-ts/src/index.ts index b3d3733..c77f84a 100644 --- a/vram-boss-ts/src/index.ts +++ b/vram-boss-ts/src/index.ts @@ -44,3 +44,4 @@ export type { AcquireOptions, PreemptCallback, } from './types.js'; +export { RedisOperationTimeoutError, withRedisTimeout } from './redis-wrapper.js'; diff --git a/vram-boss-ts/src/redis-wrapper.ts b/vram-boss-ts/src/redis-wrapper.ts new file mode 100644 index 0000000..11058aa --- /dev/null +++ b/vram-boss-ts/src/redis-wrapper.ts @@ -0,0 +1,48 @@ +/** + * Redis operation timeout wrapper utilities. + * + * Prevents system freezes from slow or unresponsive Redis by wrapping + * all Redis operations with aggressive timeouts. + */ + +export class RedisOperationTimeoutError extends Error { + constructor(operation: string, timeoutMs: number) { + super( + `Redis operation '${operation}' timed out after ${timeoutMs}ms. ` + + `Ensure Redis is responsive and network latency is acceptable.`, + ); + this.name = 'RedisOperationTimeoutError'; + } +} + +/** + * Wrap a Redis operation with timeout. + * + * @param operation - Promise-returning function to execute + * @param operationName - Name for error messages + * @param timeoutMs - Timeout in milliseconds + * @returns Result of the operation + * @throws RedisOperationTimeoutError if operation times out + * + * @example + * ```ts + * const value = await withRedisTimeout( + * () => redis.get('key'), + * 'get_key', + * 5000, + * ); + * ``` + */ +export async function withRedisTimeout( + operation: () => Promise, + operationName: string, + timeoutMs: number, +): Promise { + const timeoutPromise = new Promise((_, reject) => { + setTimeout(() => { + reject(new RedisOperationTimeoutError(operationName, timeoutMs)); + }, timeoutMs); + }); + + return Promise.race([operation(), timeoutPromise]); +}