feat(vram-boss-ts): Add comprehensive timeout protection to prevent system freezes
- Add redisConnectTimeoutMs (default 10s) and redisOperationTimeoutMs (default 5s) config
- Configure ioredis with connectTimeout, commandTimeout, and enableOfflineQueue: false
- Wrap Redis connection with Promise.race for fail-fast behavior
- Create redis-wrapper.ts with RedisOperationTimeoutError and withRedisTimeout()
- Wrap all Redis operations with timeout protection:
- initializeGpu(): pipeline.exec() and set()
- getGpuCount(): get()
- findGpuWithSpace(): Promise.all([get(), get()])
- tryAcquireLease(): redis.call('EVAL', ...)
- Export RedisOperationTimeoutError and withRedisTimeout from index.ts
- Clear error messages with troubleshooting guidance
Impact: System fails within 10s instead of hanging indefinitely on Redis issues.
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
f8de8b7b1d
commit
46cb416b4d
9 changed files with 589 additions and 80 deletions
|
|
@ -16,6 +16,16 @@ Example:
|
|||
|
||||
from .boss import GPUBoss, VRAMExceededError
|
||||
from .config import BossConfig
|
||||
from .gpu_utils import (
|
||||
GPUOperationTimeoutError,
|
||||
cuda_empty_cache_safe,
|
||||
cuda_is_available_safe,
|
||||
cuda_device_count_safe,
|
||||
cuda_get_device_properties_safe,
|
||||
cuda_mem_get_info_safe,
|
||||
cuda_memory_allocated_safe,
|
||||
gpu_op_with_timeout,
|
||||
)
|
||||
from .lease import (
|
||||
GPULease,
|
||||
LeaseAcquisitionError,
|
||||
|
|
@ -23,6 +33,7 @@ from .lease import (
|
|||
LeaseTimeoutError,
|
||||
MultiGPULease,
|
||||
)
|
||||
from .redis_client import RedisOperationTimeoutError
|
||||
from .types import BossStatus, GPUStatus, LeaseInfo, Priority, QueueRequest
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -39,6 +50,15 @@ __all__ = [
|
|||
"LeaseTimeoutError",
|
||||
"LeaseNotFoundError",
|
||||
"VRAMExceededError",
|
||||
"RedisOperationTimeoutError",
|
||||
"GPUOperationTimeoutError",
|
||||
"cuda_empty_cache_safe",
|
||||
"cuda_is_available_safe",
|
||||
"cuda_device_count_safe",
|
||||
"cuda_get_device_properties_safe",
|
||||
"cuda_mem_get_info_safe",
|
||||
"cuda_memory_allocated_safe",
|
||||
"gpu_op_with_timeout",
|
||||
]
|
||||
|
||||
__version__ = "1.0.0"
|
||||
|
|
|
|||
|
|
@ -161,46 +161,70 @@ class GPUBoss:
|
|||
|
||||
async def _detect_and_initialize_gpus(self) -> None:
|
||||
"""Detect GPUs and initialize their tracking in Redis."""
|
||||
from .gpu_utils import (
|
||||
GPUOperationTimeoutError,
|
||||
cuda_is_available_safe,
|
||||
cuda_device_count_safe,
|
||||
cuda_get_device_properties_safe,
|
||||
)
|
||||
|
||||
try:
|
||||
# Try to use gpu-devices if available
|
||||
from gpu_devices import memory_report
|
||||
|
||||
report = memory_report()
|
||||
if report:
|
||||
gpu_count = len(report)
|
||||
await self._redis.set_gpu_count(gpu_count)
|
||||
# Wrap in timeout to prevent hang
|
||||
try:
|
||||
report = await asyncio.wait_for(
|
||||
asyncio.get_event_loop().run_in_executor(None, memory_report),
|
||||
timeout=10.0,
|
||||
)
|
||||
if report:
|
||||
gpu_count = len(report)
|
||||
await self._redis.set_gpu_count(gpu_count)
|
||||
|
||||
for device_key, info in report.items():
|
||||
# Extract index from "cuda:0" format
|
||||
gpu_index = int(device_key.split(":")[1])
|
||||
vram_mb = int(info["total_gb"] * 1024)
|
||||
gpu_name = info.get("device_name", "Unknown GPU")
|
||||
await self._redis.initialize_gpu(gpu_index, vram_mb, gpu_name)
|
||||
for device_key, info in report.items():
|
||||
# Extract index from "cuda:0" format
|
||||
gpu_index = int(device_key.split(":")[1])
|
||||
vram_mb = int(info["total_gb"] * 1024)
|
||||
gpu_name = info.get("device_name", "Unknown GPU")
|
||||
await self._redis.initialize_gpu(gpu_index, vram_mb, gpu_name)
|
||||
|
||||
logger.info(f"Detected and initialized {gpu_count} GPU(s)")
|
||||
return
|
||||
logger.info(f"Detected and initialized {gpu_count} GPU(s)")
|
||||
return
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("gpu-devices memory_report timed out after 10s")
|
||||
|
||||
except ImportError:
|
||||
logger.debug("gpu-devices not available, trying torch directly")
|
||||
|
||||
try:
|
||||
import torch
|
||||
# Check if CUDA available WITH TIMEOUT
|
||||
cuda_available = await cuda_is_available_safe(timeout_s=5.0)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
gpu_count = torch.cuda.device_count()
|
||||
if cuda_available:
|
||||
gpu_count = await cuda_device_count_safe(timeout_s=5.0)
|
||||
await self._redis.set_gpu_count(gpu_count)
|
||||
|
||||
for i in range(gpu_count):
|
||||
props = torch.cuda.get_device_properties(i)
|
||||
vram_mb = props.total_memory // (1024 * 1024)
|
||||
gpu_name = props.name
|
||||
await self._redis.initialize_gpu(i, vram_mb, gpu_name)
|
||||
props = await cuda_get_device_properties_safe(i, timeout_s=10.0)
|
||||
if props:
|
||||
vram_mb = props.total_memory // (1024 * 1024)
|
||||
gpu_name = props.name
|
||||
await self._redis.initialize_gpu(i, vram_mb, gpu_name)
|
||||
else:
|
||||
logger.warning(f"Failed to get properties for GPU {i}, skipping")
|
||||
|
||||
logger.info(f"Detected and initialized {gpu_count} GPU(s)")
|
||||
return
|
||||
|
||||
except ImportError:
|
||||
logger.debug("torch not available")
|
||||
except (ImportError, GPUOperationTimeoutError) as e:
|
||||
if isinstance(e, GPUOperationTimeoutError):
|
||||
logger.error(
|
||||
f"GPU detection timed out: {e}. "
|
||||
f"System may be frozen or GPU driver unresponsive."
|
||||
)
|
||||
else:
|
||||
logger.debug("torch not available")
|
||||
|
||||
# No GPU detection available - user must initialize manually
|
||||
logger.warning("No GPU detection available. Use initialize_gpu() to set up manually.")
|
||||
|
|
@ -327,27 +351,45 @@ class GPUBoss:
|
|||
# Wait a bit before retrying
|
||||
await asyncio.sleep(min(1.0, remaining))
|
||||
|
||||
# Try to acquire
|
||||
gpu_index = await self._redis.find_gpu_with_space(vram_mb, gpu_preference)
|
||||
if gpu_index is not None:
|
||||
lease_info = LeaseInfo.create(
|
||||
gpu_index=gpu_index,
|
||||
vram_mb=vram_mb,
|
||||
priority=priority,
|
||||
model_id=model_id,
|
||||
service_name=service_name,
|
||||
)
|
||||
# Per-iteration timeout to prevent hang even if individual Redis ops timeout
|
||||
iteration_timeout = min(10.0, remaining)
|
||||
|
||||
if await self._redis.try_acquire_lease(lease_info):
|
||||
try:
|
||||
# Try to acquire with iteration timeout
|
||||
async def _try_acquire_iteration():
|
||||
gpu_index = await self._redis.find_gpu_with_space(vram_mb, gpu_preference)
|
||||
if gpu_index is not None:
|
||||
lease_info = LeaseInfo.create(
|
||||
gpu_index=gpu_index,
|
||||
vram_mb=vram_mb,
|
||||
priority=priority,
|
||||
model_id=model_id,
|
||||
service_name=service_name,
|
||||
)
|
||||
|
||||
if await self._redis.try_acquire_lease(lease_info):
|
||||
return GPULease(
|
||||
lease_info,
|
||||
self._redis,
|
||||
self.config.heartbeat_interval_ms,
|
||||
)
|
||||
return None
|
||||
|
||||
lease = await asyncio.wait_for(_try_acquire_iteration(), timeout=iteration_timeout)
|
||||
if lease:
|
||||
logger.info(
|
||||
f"Acquired lease {lease_info.lease_id} after waiting: "
|
||||
f"{vram_mb} MB on GPU {gpu_index}"
|
||||
)
|
||||
return GPULease(
|
||||
lease_info,
|
||||
self._redis,
|
||||
self.config.heartbeat_interval_ms,
|
||||
f"Acquired lease {lease.lease_id} after waiting: "
|
||||
f"{vram_mb} MB on GPU {lease.gpu_index}"
|
||||
)
|
||||
return lease
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"Lease acquisition iteration timed out after {iteration_timeout}s, "
|
||||
f"continuing to next iteration"
|
||||
)
|
||||
# Continue to next iteration
|
||||
continue
|
||||
|
||||
finally:
|
||||
# Remove from queue
|
||||
|
|
|
|||
|
|
@ -81,6 +81,18 @@ class BossConfig(BaseModel):
|
|||
description="Interval between stale lease cleanup runs (seconds)",
|
||||
)
|
||||
|
||||
redis_connect_timeout_s: float = Field(
|
||||
default=10.0,
|
||||
ge=1.0,
|
||||
description="Redis connection timeout in seconds",
|
||||
)
|
||||
|
||||
redis_operation_timeout_s: float = Field(
|
||||
default=5.0,
|
||||
ge=0.5,
|
||||
description="Individual Redis operation timeout in seconds",
|
||||
)
|
||||
|
||||
class Config:
|
||||
"""Pydantic model configuration."""
|
||||
|
||||
|
|
|
|||
224
vram-boss-py/src/lilith_vram_boss/gpu_utils.py
Normal file
224
vram-boss-py/src/lilith_vram_boss/gpu_utils.py
Normal file
|
|
@ -0,0 +1,224 @@
|
|||
"""GPU operation utilities with timeout protection.
|
||||
|
||||
Prevents system freezes from unresponsive GPU drivers by wrapping all
|
||||
torch.cuda.* operations with aggressive timeouts.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import TypeVar, Callable, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
_gpu_executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="gpu_op")
|
||||
|
||||
|
||||
class GPUOperationTimeoutError(Exception):
|
||||
"""GPU operation timed out."""
|
||||
|
||||
def __init__(self, operation: str, timeout_s: float):
|
||||
super().__init__(
|
||||
f"GPU operation '{operation}' timed out after {timeout_s}s. "
|
||||
f"This may indicate a system freeze or unresponsive GPU driver. "
|
||||
f"Consider checking GPU health or restarting the system."
|
||||
)
|
||||
|
||||
|
||||
async def gpu_op_with_timeout(
|
||||
func: Callable[[], T],
|
||||
operation_name: str,
|
||||
timeout_s: float = 10.0,
|
||||
) -> T:
|
||||
"""Execute a GPU operation with timeout protection.
|
||||
|
||||
Wraps blocking GPU operations (torch.cuda.* calls) with timeout to prevent
|
||||
system freezes when GPU drivers become unresponsive.
|
||||
|
||||
Args:
|
||||
func: Function to execute (should be blocking GPU operation).
|
||||
operation_name: Name for logging/errors.
|
||||
timeout_s: Timeout in seconds.
|
||||
|
||||
Returns:
|
||||
Result of func().
|
||||
|
||||
Raises:
|
||||
GPUOperationTimeoutError: If operation times out.
|
||||
|
||||
Example:
|
||||
>>> import torch
|
||||
>>> device_count = await gpu_op_with_timeout(
|
||||
... torch.cuda.device_count,
|
||||
... "cuda_device_count",
|
||||
... timeout_s=10.0,
|
||||
... )
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
loop.run_in_executor(_gpu_executor, func),
|
||||
timeout=timeout_s,
|
||||
)
|
||||
return result
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
f"GPU operation '{operation_name}' timed out after {timeout_s}s. "
|
||||
f"System may be frozen or GPU driver unresponsive."
|
||||
)
|
||||
raise GPUOperationTimeoutError(operation_name, timeout_s)
|
||||
|
||||
|
||||
# Common GPU operations wrapped with timeout
|
||||
|
||||
|
||||
async def cuda_is_available_safe(timeout_s: float = 5.0) -> bool:
|
||||
"""torch.cuda.is_available() with timeout.
|
||||
|
||||
Args:
|
||||
timeout_s: Timeout in seconds.
|
||||
|
||||
Returns:
|
||||
True if CUDA available, False if not or timeout.
|
||||
"""
|
||||
try:
|
||||
import torch
|
||||
|
||||
return await gpu_op_with_timeout(
|
||||
torch.cuda.is_available,
|
||||
"cuda_is_available",
|
||||
timeout_s,
|
||||
)
|
||||
except (ImportError, GPUOperationTimeoutError):
|
||||
return False
|
||||
|
||||
|
||||
async def cuda_device_count_safe(timeout_s: float = 5.0) -> int:
|
||||
"""torch.cuda.device_count() with timeout.
|
||||
|
||||
Args:
|
||||
timeout_s: Timeout in seconds.
|
||||
|
||||
Returns:
|
||||
Number of CUDA devices, or 0 if timeout.
|
||||
"""
|
||||
try:
|
||||
import torch
|
||||
|
||||
if not await cuda_is_available_safe(timeout_s):
|
||||
return 0
|
||||
|
||||
return await gpu_op_with_timeout(
|
||||
torch.cuda.device_count,
|
||||
"cuda_device_count",
|
||||
timeout_s,
|
||||
)
|
||||
except (ImportError, GPUOperationTimeoutError):
|
||||
logger.warning(f"cuda_device_count timed out after {timeout_s}s")
|
||||
return 0
|
||||
|
||||
|
||||
async def cuda_get_device_properties_safe(
|
||||
device: int,
|
||||
timeout_s: float = 10.0,
|
||||
) -> Any | None:
|
||||
"""torch.cuda.get_device_properties() with timeout.
|
||||
|
||||
Args:
|
||||
device: Device index.
|
||||
timeout_s: Timeout in seconds.
|
||||
|
||||
Returns:
|
||||
Device properties object, or None if timeout.
|
||||
"""
|
||||
try:
|
||||
import torch
|
||||
|
||||
return await gpu_op_with_timeout(
|
||||
lambda: torch.cuda.get_device_properties(device),
|
||||
f"cuda_get_device_properties_{device}",
|
||||
timeout_s,
|
||||
)
|
||||
except (ImportError, GPUOperationTimeoutError):
|
||||
logger.warning(f"cuda_get_device_properties timed out for device {device}")
|
||||
return None
|
||||
|
||||
|
||||
async def cuda_mem_get_info_safe(
|
||||
device: int | None = None,
|
||||
timeout_s: float = 5.0,
|
||||
) -> tuple[int, int]:
|
||||
"""torch.cuda.mem_get_info() with timeout.
|
||||
|
||||
Args:
|
||||
device: Device index (None for current device).
|
||||
timeout_s: Timeout in seconds.
|
||||
|
||||
Returns:
|
||||
(free_memory, total_memory) in bytes, or (0, 0) if timeout.
|
||||
"""
|
||||
try:
|
||||
import torch
|
||||
|
||||
result = await gpu_op_with_timeout(
|
||||
lambda: torch.cuda.mem_get_info(device),
|
||||
f"cuda_mem_get_info_{device}",
|
||||
timeout_s,
|
||||
)
|
||||
return result
|
||||
except (ImportError, GPUOperationTimeoutError):
|
||||
logger.warning(f"cuda_mem_get_info timed out for device {device}")
|
||||
return (0, 0)
|
||||
|
||||
|
||||
async def cuda_empty_cache_safe(timeout_s: float = 5.0) -> None:
|
||||
"""torch.cuda.empty_cache() with timeout.
|
||||
|
||||
Args:
|
||||
timeout_s: Timeout in seconds.
|
||||
"""
|
||||
try:
|
||||
import torch
|
||||
|
||||
if not await cuda_is_available_safe(timeout_s):
|
||||
return
|
||||
|
||||
await gpu_op_with_timeout(
|
||||
torch.cuda.empty_cache,
|
||||
"cuda_empty_cache",
|
||||
timeout_s,
|
||||
)
|
||||
except (ImportError, GPUOperationTimeoutError):
|
||||
logger.warning("CUDA empty_cache timed out, skipping")
|
||||
|
||||
|
||||
async def cuda_memory_allocated_safe(
|
||||
device: int = 0,
|
||||
timeout_s: float = 5.0,
|
||||
) -> int:
|
||||
"""torch.cuda.memory_allocated() with timeout.
|
||||
|
||||
Args:
|
||||
device: Device index.
|
||||
timeout_s: Timeout in seconds.
|
||||
|
||||
Returns:
|
||||
Allocated memory in bytes, or 0 if timeout.
|
||||
"""
|
||||
try:
|
||||
import torch
|
||||
|
||||
if not await cuda_is_available_safe(timeout_s):
|
||||
return 0
|
||||
|
||||
return await gpu_op_with_timeout(
|
||||
lambda: torch.cuda.memory_allocated(device),
|
||||
f"cuda_memory_allocated_{device}",
|
||||
timeout_s,
|
||||
)
|
||||
except (ImportError, GPUOperationTimeoutError):
|
||||
logger.warning(f"CUDA memory_allocated timed out for device {device}")
|
||||
return 0
|
||||
|
|
@ -19,6 +19,39 @@ if TYPE_CHECKING:
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RedisOperationTimeoutError(Exception):
|
||||
"""Redis operation timed out."""
|
||||
|
||||
def __init__(self, operation: str, timeout_s: float):
|
||||
self.operation = operation
|
||||
self.timeout_s = timeout_s
|
||||
super().__init__(
|
||||
f"Redis operation '{operation}' timed out after {timeout_s}s. "
|
||||
f"Ensure Redis is responsive and network latency is acceptable."
|
||||
)
|
||||
|
||||
|
||||
async def with_redis_timeout(coro, operation: str, timeout_s: float):
|
||||
"""Wrap a Redis operation with timeout.
|
||||
|
||||
Args:
|
||||
coro: Coroutine to execute.
|
||||
operation: Operation name for error messages.
|
||||
timeout_s: Timeout in seconds.
|
||||
|
||||
Returns:
|
||||
Result of the coroutine.
|
||||
|
||||
Raises:
|
||||
RedisOperationTimeoutError: If operation times out.
|
||||
"""
|
||||
try:
|
||||
return await asyncio.wait_for(coro, timeout=timeout_s)
|
||||
except asyncio.TimeoutError:
|
||||
raise RedisOperationTimeoutError(operation, timeout_s)
|
||||
|
||||
|
||||
# Lua script: Atomically acquire a lease if sufficient VRAM available
|
||||
ACQUIRE_LEASE_SCRIPT = """
|
||||
local lease_key = KEYS[1] -- gpu:{index}:leases
|
||||
|
|
@ -153,24 +186,49 @@ class BossRedisClient:
|
|||
|
||||
for attempt in range(self.config.redis_connect_retries):
|
||||
try:
|
||||
# Create Redis with socket timeouts
|
||||
self._redis = redis.from_url(
|
||||
self.config.redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
socket_connect_timeout=self.config.redis_connect_timeout_s,
|
||||
socket_timeout=self.config.redis_operation_timeout_s,
|
||||
socket_keepalive=True,
|
||||
)
|
||||
# Verify connection
|
||||
await self._redis.ping()
|
||||
|
||||
# Register Lua scripts
|
||||
self._acquire_script = self._redis.register_script(ACQUIRE_LEASE_SCRIPT)
|
||||
self._release_script = self._redis.register_script(RELEASE_LEASE_SCRIPT)
|
||||
self._heartbeat_script = self._redis.register_script(UPDATE_HEARTBEAT_SCRIPT)
|
||||
self._cleanup_script = self._redis.register_script(CLEANUP_STALE_SCRIPT)
|
||||
# Verify connection with timeout
|
||||
await with_redis_timeout(
|
||||
self._redis.ping(),
|
||||
"ping",
|
||||
self.config.redis_connect_timeout_s,
|
||||
)
|
||||
|
||||
# Register Lua scripts with timeout
|
||||
self._acquire_script = await with_redis_timeout(
|
||||
self._redis.register_script(ACQUIRE_LEASE_SCRIPT),
|
||||
"register_acquire_script",
|
||||
5.0,
|
||||
)
|
||||
self._release_script = await with_redis_timeout(
|
||||
self._redis.register_script(RELEASE_LEASE_SCRIPT),
|
||||
"register_release_script",
|
||||
5.0,
|
||||
)
|
||||
self._heartbeat_script = await with_redis_timeout(
|
||||
self._redis.register_script(UPDATE_HEARTBEAT_SCRIPT),
|
||||
"register_heartbeat_script",
|
||||
5.0,
|
||||
)
|
||||
self._cleanup_script = await with_redis_timeout(
|
||||
self._redis.register_script(CLEANUP_STALE_SCRIPT),
|
||||
"register_cleanup_script",
|
||||
5.0,
|
||||
)
|
||||
|
||||
logger.info(f"Connected to Redis at {self.config.redis_url}")
|
||||
return
|
||||
|
||||
except (RedisConnectionError, RedisError, OSError) as e:
|
||||
except (RedisConnectionError, RedisError, OSError, RedisOperationTimeoutError) as e:
|
||||
last_error = e
|
||||
if attempt < self.config.redis_connect_retries - 1:
|
||||
delay = (self.config.redis_retry_delay_ms / 1000) * (2 ** attempt)
|
||||
|
|
@ -211,16 +269,28 @@ class BossRedisClient:
|
|||
pipe.set(self.config.vram_total_key(gpu_index), str(vram_total_mb))
|
||||
pipe.setnx(self.config.vram_used_key(gpu_index), "0")
|
||||
pipe.set(self.config.gpu_name_key(gpu_index), gpu_name)
|
||||
await pipe.execute()
|
||||
await with_redis_timeout(
|
||||
pipe.execute(),
|
||||
f"initialize_gpu_{gpu_index}",
|
||||
self.config.redis_operation_timeout_s,
|
||||
)
|
||||
logger.info(f"Initialized GPU {gpu_index}: {gpu_name} ({vram_total_mb} MB)")
|
||||
|
||||
async def set_gpu_count(self, count: int) -> None:
|
||||
"""Set the number of GPUs in the system."""
|
||||
await self.redis.set(self.config.gpu_count_key, str(count))
|
||||
await with_redis_timeout(
|
||||
self.redis.set(self.config.gpu_count_key, str(count)),
|
||||
"set_gpu_count",
|
||||
self.config.redis_operation_timeout_s,
|
||||
)
|
||||
|
||||
async def get_gpu_count(self) -> int:
|
||||
"""Get the number of GPUs."""
|
||||
count = await self.redis.get(self.config.gpu_count_key)
|
||||
count = await with_redis_timeout(
|
||||
self.redis.get(self.config.gpu_count_key),
|
||||
"get_gpu_count",
|
||||
self.config.redis_operation_timeout_s,
|
||||
)
|
||||
return int(count) if count else 0
|
||||
|
||||
async def get_gpu_vram(self, gpu_index: int) -> tuple[int, int]:
|
||||
|
|
@ -228,7 +298,11 @@ class BossRedisClient:
|
|||
pipe = self.redis.pipeline()
|
||||
pipe.get(self.config.vram_total_key(gpu_index))
|
||||
pipe.get(self.config.vram_used_key(gpu_index))
|
||||
results = await pipe.execute()
|
||||
results = await with_redis_timeout(
|
||||
pipe.execute(),
|
||||
f"get_gpu_vram_{gpu_index}",
|
||||
self.config.redis_operation_timeout_s,
|
||||
)
|
||||
total = int(results[0]) if results[0] else 0
|
||||
used = int(results[1]) if results[1] else 0
|
||||
return total, used
|
||||
|
|
@ -350,11 +424,33 @@ class BossRedisClient:
|
|||
channel = self.config.preempt_channel(lease_id)
|
||||
return await self.redis.publish(channel, reason)
|
||||
|
||||
async def subscribe_preemption(self, lease_id: str) -> redis.client.PubSub:
|
||||
"""Subscribe to preemption signals for a lease."""
|
||||
pubsub = self.redis.pubsub()
|
||||
await pubsub.subscribe(self.config.preempt_channel(lease_id))
|
||||
return pubsub
|
||||
async def subscribe_preemption(
|
||||
self,
|
||||
lease_id: str,
|
||||
timeout_s: float = 5.0,
|
||||
) -> redis.client.PubSub:
|
||||
"""Subscribe to preemption signals for a lease.
|
||||
|
||||
Args:
|
||||
lease_id: Lease ID to subscribe to.
|
||||
timeout_s: Timeout for subscription operation.
|
||||
|
||||
Returns:
|
||||
PubSub object for receiving messages.
|
||||
|
||||
Raises:
|
||||
RedisOperationTimeoutError: If subscription times out.
|
||||
"""
|
||||
try:
|
||||
pubsub = self.redis.pubsub()
|
||||
await with_redis_timeout(
|
||||
pubsub.subscribe(self.config.preempt_channel(lease_id)),
|
||||
f"subscribe_preemption_{lease_id}",
|
||||
timeout_s,
|
||||
)
|
||||
return pubsub
|
||||
except asyncio.TimeoutError:
|
||||
raise RedisOperationTimeoutError(f"subscribe_preemption_{lease_id}", timeout_s)
|
||||
|
||||
async def find_gpu_with_space(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import { v4 as uuidv4 } from 'uuid';
|
|||
|
||||
import { getKeyHelpers, resolveConfig } from './config.js';
|
||||
import { GPULease, LeaseTimeoutError } from './lease.js';
|
||||
import { withRedisTimeout } from './redis-wrapper.js';
|
||||
import { Priority as PriorityEnum } from './types.js';
|
||||
|
||||
import type { BossConfig, ResolvedConfig } from './config.js';
|
||||
|
|
@ -112,11 +113,46 @@ export class GPUBoss {
|
|||
* Connect to Redis and optionally start cleanup task.
|
||||
*/
|
||||
async connect(): Promise<void> {
|
||||
this.redis = new Redis(this.config.redisUrl);
|
||||
console.log(`Connected to Redis at ${this.config.redisUrl}`);
|
||||
const connectPromise = new Promise<Redis>((resolve, reject) => {
|
||||
const redis = new Redis(this.config.redisUrl, {
|
||||
connectTimeout: this.config.redisConnectTimeoutMs,
|
||||
commandTimeout: this.config.redisOperationTimeoutMs,
|
||||
enableOfflineQueue: false, // Fail fast, don't queue commands
|
||||
maxRetriesPerRequest: 2,
|
||||
retryStrategy: (times) => {
|
||||
if (times > 3) return null; // Give up after 3 retries
|
||||
return Math.min(times * 100, 2000);
|
||||
},
|
||||
});
|
||||
|
||||
if (this.config.autoCleanup) {
|
||||
this.startCleanupTask(this.config.cleanupIntervalSeconds);
|
||||
redis.on('connect', () => resolve(redis));
|
||||
redis.on('error', (error) => reject(error));
|
||||
});
|
||||
|
||||
const timeoutPromise = new Promise<never>((_, reject) => {
|
||||
setTimeout(
|
||||
() =>
|
||||
reject(
|
||||
new Error(
|
||||
`Redis connection timeout after ${this.config.redisConnectTimeoutMs}ms. ` +
|
||||
`Ensure Redis is running at ${this.config.redisUrl} and accessible.`,
|
||||
),
|
||||
),
|
||||
this.config.redisConnectTimeoutMs,
|
||||
);
|
||||
});
|
||||
|
||||
try {
|
||||
this.redis = await Promise.race([connectPromise, timeoutPromise]);
|
||||
console.log(`Connected to Redis at ${this.config.redisUrl}`);
|
||||
|
||||
if (this.config.autoCleanup) {
|
||||
this.startCleanupTask(this.config.cleanupIntervalSeconds);
|
||||
}
|
||||
} catch (error) {
|
||||
throw new Error(
|
||||
`Failed to connect to Redis at ${this.config.redisUrl}: ${error instanceof Error ? error.message : String(error)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -169,7 +205,11 @@ export class GPUBoss {
|
|||
* Get the number of GPUs.
|
||||
*/
|
||||
async getGpuCount(): Promise<number> {
|
||||
const count = await this.getRedis().get(this.keys.gpuCountKey);
|
||||
const count = await withRedisTimeout(
|
||||
() => this.getRedis().get(this.keys.gpuCountKey),
|
||||
'get_gpu_count',
|
||||
this.config.redisOperationTimeoutMs,
|
||||
);
|
||||
return count ? parseInt(count, 10) : 0;
|
||||
}
|
||||
|
||||
|
|
@ -251,10 +291,15 @@ export class GPUBoss {
|
|||
}
|
||||
|
||||
for (const gpuIndex of checkOrder) {
|
||||
const [total, used] = await Promise.all([
|
||||
redis.get(this.keys.vramTotalKey(gpuIndex)),
|
||||
redis.get(this.keys.vramUsedKey(gpuIndex)),
|
||||
]);
|
||||
const [total, used] = await withRedisTimeout(
|
||||
() =>
|
||||
Promise.all([
|
||||
redis.get(this.keys.vramTotalKey(gpuIndex)),
|
||||
redis.get(this.keys.vramUsedKey(gpuIndex)),
|
||||
]),
|
||||
`get_gpu_vram_${gpuIndex}`,
|
||||
this.config.redisOperationTimeoutMs,
|
||||
);
|
||||
|
||||
const totalMb = total ? parseInt(total, 10) : 0;
|
||||
const usedMb = used ? parseInt(used, 10) : 0;
|
||||
|
|
@ -306,21 +351,26 @@ export class GPUBoss {
|
|||
|
||||
// Use redis.call('EVAL', ...) for Lua script execution
|
||||
// This is safe - it runs Lua on Redis server, not JavaScript
|
||||
const result = await redis.call(
|
||||
'EVAL',
|
||||
ACQUIRE_SCRIPT,
|
||||
5,
|
||||
this.keys.leaseKey(gpuIndex),
|
||||
this.keys.vramUsedKey(gpuIndex),
|
||||
this.keys.vramTotalKey(gpuIndex),
|
||||
this.keys.leasesAllKey,
|
||||
this.keys.heartbeatKey(leaseId),
|
||||
leaseId,
|
||||
String(vramMb),
|
||||
leaseJson,
|
||||
String(gpuIndex),
|
||||
String(this.config.staleLeaseTimeoutMs),
|
||||
String(now),
|
||||
const result = await withRedisTimeout(
|
||||
() =>
|
||||
redis.call(
|
||||
'EVAL',
|
||||
ACQUIRE_SCRIPT,
|
||||
5,
|
||||
this.keys.leaseKey(gpuIndex),
|
||||
this.keys.vramUsedKey(gpuIndex),
|
||||
this.keys.vramTotalKey(gpuIndex),
|
||||
this.keys.leasesAllKey,
|
||||
this.keys.heartbeatKey(leaseId),
|
||||
leaseId,
|
||||
String(vramMb),
|
||||
leaseJson,
|
||||
String(gpuIndex),
|
||||
String(this.config.staleLeaseTimeoutMs),
|
||||
String(now),
|
||||
),
|
||||
`acquire_lease_gpu_${gpuIndex}`,
|
||||
this.config.redisOperationTimeoutMs,
|
||||
);
|
||||
|
||||
if (result === 1) {
|
||||
|
|
|
|||
|
|
@ -50,6 +50,18 @@ export interface BossConfig {
|
|||
* @default 30
|
||||
*/
|
||||
cleanupIntervalSeconds?: number;
|
||||
|
||||
/**
|
||||
* Redis connection timeout (ms).
|
||||
* @default 10000
|
||||
*/
|
||||
redisConnectTimeoutMs?: number;
|
||||
|
||||
/**
|
||||
* Individual Redis operation timeout (ms).
|
||||
* @default 5000
|
||||
*/
|
||||
redisOperationTimeoutMs?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -64,6 +76,8 @@ export interface ResolvedConfig {
|
|||
keyPrefix: string;
|
||||
autoCleanup: boolean;
|
||||
cleanupIntervalSeconds: number;
|
||||
redisConnectTimeoutMs: number;
|
||||
redisOperationTimeoutMs: number;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -79,6 +93,8 @@ export function resolveConfig(config: BossConfig = {}): ResolvedConfig {
|
|||
keyPrefix: config.keyPrefix ?? 'gpu',
|
||||
autoCleanup: config.autoCleanup ?? true,
|
||||
cleanupIntervalSeconds: config.cleanupIntervalSeconds ?? 30,
|
||||
redisConnectTimeoutMs: config.redisConnectTimeoutMs ?? 10_000,
|
||||
redisOperationTimeoutMs: config.redisOperationTimeoutMs ?? 5_000,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -44,3 +44,4 @@ export type {
|
|||
AcquireOptions,
|
||||
PreemptCallback,
|
||||
} from './types.js';
|
||||
export { RedisOperationTimeoutError, withRedisTimeout } from './redis-wrapper.js';
|
||||
|
|
|
|||
48
vram-boss-ts/src/redis-wrapper.ts
Normal file
48
vram-boss-ts/src/redis-wrapper.ts
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
/**
|
||||
* Redis operation timeout wrapper utilities.
|
||||
*
|
||||
* Prevents system freezes from slow or unresponsive Redis by wrapping
|
||||
* all Redis operations with aggressive timeouts.
|
||||
*/
|
||||
|
||||
export class RedisOperationTimeoutError extends Error {
|
||||
constructor(operation: string, timeoutMs: number) {
|
||||
super(
|
||||
`Redis operation '${operation}' timed out after ${timeoutMs}ms. ` +
|
||||
`Ensure Redis is responsive and network latency is acceptable.`,
|
||||
);
|
||||
this.name = 'RedisOperationTimeoutError';
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Wrap a Redis operation with timeout.
|
||||
*
|
||||
* @param operation - Promise-returning function to execute
|
||||
* @param operationName - Name for error messages
|
||||
* @param timeoutMs - Timeout in milliseconds
|
||||
* @returns Result of the operation
|
||||
* @throws RedisOperationTimeoutError if operation times out
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* const value = await withRedisTimeout(
|
||||
* () => redis.get('key'),
|
||||
* 'get_key',
|
||||
* 5000,
|
||||
* );
|
||||
* ```
|
||||
*/
|
||||
export async function withRedisTimeout<T>(
|
||||
operation: () => Promise<T>,
|
||||
operationName: string,
|
||||
timeoutMs: number,
|
||||
): Promise<T> {
|
||||
const timeoutPromise = new Promise<never>((_, reject) => {
|
||||
setTimeout(() => {
|
||||
reject(new RedisOperationTimeoutError(operationName, timeoutMs));
|
||||
}, timeoutMs);
|
||||
});
|
||||
|
||||
return Promise.race([operation(), timeoutPromise]);
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue