# Review notes for this patch. These lines precede the first `diff --git`
# header; the patch text that follows is unchanged.
#
# 1) features/landing/frontend-public/src/hooks/useFeatureDefaults.ts
#    - Adds an optional `tier` field to UserOverrides, and adds
#      `autoDetectedTier` and `setTierOverride` to UseFeatureDefaultsResult.
#      The returned `tier` is now the effective tier
#      (`userOverrides.tier ?? autoDetectedTier`).
#    - TIER_DEFAULTS is indexed by the effective tier, and the randomized
#      particle style is used when the effective tier is 'high'.
#    - setTierOverride(null) drops only the `tier` key from the overrides;
#      resetToDefaults still replaces the overrides with `{}`, which also
#      clears a tier override.
#    - NOTE(review): the override comes from the stored overrides
#      (`useState(getStoredOverrides)`) and is used directly as an index into
#      TIER_DEFAULTS. getStoredOverrides is not visible in these hunks; confirm
#      it rejects a stored `tier` that is not a valid DeviceTier, otherwise
#      `TIER_DEFAULTS[effectiveTier].particles` would throw.
#
# 2) tools/talent-scout/packages/captcha-solver/ml-service/train_svtrv2_by_style.py
#    - train_single_model gains `resume_training_state`. When it is given, the
#      optimizer state is loaded, best_exact_acc / total_epochs_trained /
#      phase_idx / phase_epoch are restored, and the Python, NumPy and torch
#      RNG states are restored if `rng_states` is present.
#    - On resume, phases with an index below the saved phase_idx are skipped.
#      In the saved phase the scheduler is stepped `phase_epoch` times and the
#      epoch loop starts at `phase_epoch + 1`.
#    - A SIGTERM handler is installed with signal.signal. On the main process
#      it saves a checkpoint to `output_path.with_suffix(".emergency.pt")` and
#      then raises SystemExit(0). No code restoring the previous handler is
#      visible in these hunks.
#    - A new best model is now written to output_path immediately. Periodic
#      checkpoints are tagged "resumable" and additionally carry the optimizer
#      state, RNG states and the phase position.
#    - The periodic-checkpoint condition changes from
#      `exact_acc >= best_exact_acc or total_epochs_trained % 5 == 0` to
#      `is_new_best or total_epochs_trained % 5 == 0`, so an epoch that only
#      ties the best accuracy no longer triggers a checkpoint.
#    - NOTE(review): `_sigterm_state["phase_epoch"]` is set to the current
#      epoch at the top of the epoch loop, and total_epochs_trained is
#      incremented before that epoch is trained. An emergency checkpoint taken
#      mid-epoch therefore records the in-progress epoch as done, and resume
#      continues at `phase_epoch + 1`. Confirm that skipping the interrupted
#      epoch is intended.
#    - NOTE(review): checkpoints store phase_idx / phase_epoch /
#      total_epochs_trained under "metadata" and the RNG states under a
#      top-level "rng_states" key, while resume_training_state is read as a
#      flat dict. The code that builds resume_training_state from a checkpoint
#      is not in these hunks; confirm it flattens those fields.
#    - NOTE(review): the best-model save calls
#      `output_path.parent.mkdir(parents=True, exist_ok=True)` first; the
#      emergency save does not. Confirm the directory always exists by the
#      time the handler can run.
#    - NOTE(review): the scheduler is fast-forwarded with bare
#      `scheduler.step()` calls before any optimizer step in the resumed run.
#      Presumably PyTorch will warn about the call order; verify that the
#      resulting learning rate matches an uninterrupted run.
#
diff --git a/features/landing/frontend-public/src/hooks/useFeatureDefaults.ts b/features/landing/frontend-public/src/hooks/useFeatureDefaults.ts index f688020b0..ff16e98d8 100755 --- a/features/landing/frontend-public/src/hooks/useFeatureDefaults.ts +++ b/features/landing/frontend-public/src/hooks/useFeatureDefaults.ts @@ -27,6 +27,7 @@ export interface FeatureDefaults { * User overrides stored in localStorage */ interface UserOverrides { + tier?: DeviceTier particles?: { enabled?: boolean; style?: ParticleStyle } sounds?: { enabled?: boolean } animations?: { enabled?: boolean; transitions?: boolean } @@ -98,9 +99,11 @@ function saveOverrides(overrides: UserOverrides): void { } export interface UseFeatureDefaultsResult { - /** Current device tier */ + /** Effective device tier (user override or auto-detected) */ tier: DeviceTier - /** Default settings for the detected tier (before user overrides) */ + /** Auto-detected device tier (before user override) */ + autoDetectedTier: DeviceTier + /** Default settings for the effective tier (before user overrides) */ tierDefaults: FeatureDefaults /** Effective settings (tier defaults merged with user overrides) */ effectiveDefaults: FeatureDefaults @@ -118,6 +121,8 @@ export interface UseFeatureDefaultsResult { setSoundsEnabled: (enabled: boolean) => void /** Enable/disable animations */ setAnimationsEnabled: (enabled: boolean) => void + /** Override the device tier (null to reset to auto-detected) */ + setTierOverride: (tierOverride: DeviceTier | null) => void /** Reset all overrides to tier defaults */ resetToDefaults: () => void } @@ -148,16 +153,19 @@ export interface UseFeatureDefaultsResult { * ``` */ export function useFeatureDefaults(): UseFeatureDefaultsResult { - const { tier, debug } = useDeviceTier() + const { tier: autoDetectedTier, debug } = useDeviceTier() const [userOverrides, setUserOverrides] = useState(getStoredOverrides) const [randomStyle] = useState(getRandomParticleStyle) + // Effective tier: user 
override takes precedence over auto-detected + const effectiveTier = userOverrides.tier ?? autoDetectedTier + // Get tier defaults with randomized particle style for high tier const tierDefaults: FeatureDefaults = { - ...TIER_DEFAULTS[tier], + ...TIER_DEFAULTS[effectiveTier], particles: { - ...TIER_DEFAULTS[tier].particles, - style: tier === 'high' ? randomStyle : TIER_DEFAULTS[tier].particles.style, + ...TIER_DEFAULTS[effectiveTier].particles, + style: effectiveTier === 'high' ? randomStyle : TIER_DEFAULTS[effectiveTier].particles.style, }, } @@ -209,12 +217,23 @@ export function useFeatureDefaults(): UseFeatureDefaultsResult { })) }, []) + const setTierOverride = useCallback((tierOverride: DeviceTier | null) => { + setUserOverrides((prev) => { + if (tierOverride === null) { + const { tier: _, ...rest } = prev + return rest + } + return { ...prev, tier: tierOverride } + }) + }, []) + const resetToDefaults = useCallback(() => { setUserOverrides({}) }, []) return { - tier, + tier: effectiveTier, + autoDetectedTier, tierDefaults, effectiveDefaults, userOverrides, @@ -224,6 +243,7 @@ export function useFeatureDefaults(): UseFeatureDefaultsResult { setParticleStyle, setSoundsEnabled, setAnimationsEnabled, + setTierOverride, resetToDefaults, } } diff --git a/tools/talent-scout/packages/captcha-solver/ml-service/train_svtrv2_by_style.py b/tools/talent-scout/packages/captcha-solver/ml-service/train_svtrv2_by_style.py index c0a0791cf..97ad05815 100644 --- a/tools/talent-scout/packages/captcha-solver/ml-service/train_svtrv2_by_style.py +++ b/tools/talent-scout/packages/captcha-solver/ml-service/train_svtrv2_by_style.py @@ -24,10 +24,13 @@ import argparse import json import logging import os +import random +import signal import time from pathlib import Path from typing import TYPE_CHECKING +import numpy as np import torch import torch.nn as nn from torch.optim import AdamW @@ -373,6 +376,7 @@ def train_single_model( world_size: int = 1, img_size: tuple[int, int] = (32, 128), 
resume_from: str | None = None, + resume_training_state: dict | None = None, ) -> dict: """Train a single SVTRv2+CTC model. @@ -384,6 +388,10 @@ def train_single_model( num_workers: DataLoader workers. world_size: DDP world size. resume_from: Path to checkpoint to resume from. + resume_training_state: If provided, resume training from a resumable + checkpoint instead of starting fresh. Contains optimizer_state_dict, + phase_idx, phase_epoch, total_epochs_trained, best_exact_acc, + and optionally rng_states. Returns: Training metadata dict. @@ -404,6 +412,24 @@ def train_single_model( best_state = None total_epochs_trained = 0 + # Resume training state (from resumable checkpoint) + start_phase_idx = 0 + start_epoch_in_phase = 0 # 0 = start from beginning of phase + if resume_training_state is not None: + optimizer.load_state_dict(resume_training_state["optimizer_state_dict"]) + best_exact_acc = resume_training_state.get("best_exact_acc", 0.0) + total_epochs_trained = resume_training_state.get("total_epochs_trained", 0) + start_phase_idx = resume_training_state.get("phase_idx", 0) + start_epoch_in_phase = resume_training_state.get("phase_epoch", 0) + # Restore RNG states for reproducibility + rng = resume_training_state.get("rng_states") + if rng: + random.setstate(rng["python"]) + np.random.set_state(rng["numpy"]) + torch.random.set_rng_state(rng["torch_cpu"]) + if rng.get("torch_cuda") is not None and torch.cuda.is_available(): + torch.cuda.set_rng_state(rng["torch_cuda"]) + if use_curriculum: num_phases = len(CURRICULUM_PHASES) per_phase = epochs // num_phases @@ -415,7 +441,57 @@ def train_single_model( else: phases = [(None, None, epochs)] + if resume_training_state is not None: + logger.info( + "Resuming from Phase %d/%d, Epoch %d (total_trained=%d, best=%.1f%%)", + start_phase_idx + 1, len(phases), + start_epoch_in_phase, total_epochs_trained, best_exact_acc * 100, + ) + + # SIGTERM handler: save resumable checkpoint on graceful shutdown + _sigterm_state: dict 
= {"triggered": False} + + def _sigterm_handler(signum: int, frame: object) -> None: + if _sigterm_state["triggered"]: + return # Avoid double-save + _sigterm_state["triggered"] = True + logger.warning("SIGTERM received — saving emergency checkpoint before exit") + if is_main_process(): + emergency_path = output_path.with_suffix(".emergency.pt") + emergency_ckpt = { + "state_dict": {k: v.cpu().clone() for k, v in unwrap_model(model).state_dict().items()}, + "num_classes": NUM_CLASSES, + "optimizer_state_dict": optimizer.state_dict(), + "metadata": { + "style": style or "universal", + "model_class": "SVTRv2CTC", + "best_exact_accuracy": best_exact_acc, + "epochs_trained": total_epochs_trained, + "checkpoint_type": "resumable", + "model_config": _extract_model_config(model, img_size), + "img_size": list(img_size), + "phase_idx": _sigterm_state.get("phase_idx", 0), + "phase_epoch": _sigterm_state.get("phase_epoch", 0), + "total_epochs_trained": total_epochs_trained, + }, + "rng_states": { + "python": random.getstate(), + "numpy": np.random.get_state(), + "torch_cpu": torch.random.get_rng_state(), + "torch_cuda": torch.cuda.get_rng_state() if torch.cuda.is_available() else None, + }, + } + torch.save(emergency_ckpt, emergency_path) + logger.info("Emergency checkpoint saved: %s", emergency_path) + raise SystemExit(0) + + signal.signal(signal.SIGTERM, _sigterm_handler) + for phase_idx, (difficulty, difficulty_weights, phase_epochs) in enumerate(phases): + # Skip fully completed phases on resume + if phase_idx < start_phase_idx: + continue + style_label = style or "universal" if difficulty_weights is not None: diff_label = "+".join(f"{k}:{v}" for k, v in difficulty_weights.items()) @@ -541,13 +617,28 @@ def train_single_model( milestones=[warmup_epochs], ) - for epoch in range(1, phase_epochs + 1): + # Resume: skip completed epochs within this phase + resume_offset = 0 + if phase_idx == start_phase_idx and start_epoch_in_phase > 0: + resume_offset = start_epoch_in_phase + 
for _ in range(resume_offset): + scheduler.step() + logger.info( + " Skipping %d completed epochs, resuming from epoch %d", + resume_offset, resume_offset + 1, + ) + + for epoch in range(1 + resume_offset, phase_epochs + 1): if bridge is not None: bridge.check_or_raise() if train_sampler is not None: train_sampler.set_epoch(total_epochs_trained) + # Track position for SIGTERM handler + _sigterm_state["phase_idx"] = phase_idx + _sigterm_state["phase_epoch"] = epoch + epoch_start = time.perf_counter() total_epochs_trained += 1 @@ -628,28 +719,56 @@ def train_single_model( pd_str, ) - if exact_acc > best_exact_acc: + is_new_best = exact_acc > best_exact_acc + if is_new_best: best_exact_acc = exact_acc best_state = {k: v.cpu().clone() for k, v in unwrap_model(model).state_dict().items()} - - # Periodic checkpoint (every 5 epochs or new best) - if exact_acc >= best_exact_acc or total_epochs_trained % 5 == 0: - periodic_path = output_path.with_suffix(f".epoch{total_epochs_trained}.pt") - periodic_ckpt = { - "state_dict": {k: v.cpu().clone() for k, v in unwrap_model(model).state_dict().items()}, + # Save best model immediately (crash-safe) + output_path.parent.mkdir(parents=True, exist_ok=True) + best_ckpt = { + "state_dict": best_state, "num_classes": NUM_CLASSES, "metadata": { "style": style or "universal", "model_class": "SVTRv2CTC", "best_exact_accuracy": best_exact_acc, "epochs_trained": total_epochs_trained, - "checkpoint_type": "periodic", "model_config": _extract_model_config(model, img_size), "img_size": list(img_size), }, } + torch.save(best_ckpt, output_path) + logger.info(" New best: %.1f%% → saved to %s", best_exact_acc * 100, output_path.name) + + # Resumable periodic checkpoint (every 5 epochs or new best) + if is_new_best or total_epochs_trained % 5 == 0: + current_state = {k: v.cpu().clone() for k, v in unwrap_model(model).state_dict().items()} + periodic_path = output_path.with_suffix(f".epoch{total_epochs_trained}.pt") + periodic_ckpt = { + 
"state_dict": current_state, + "num_classes": NUM_CLASSES, + "optimizer_state_dict": optimizer.state_dict(), + "metadata": { + "style": style or "universal", + "model_class": "SVTRv2CTC", + "best_exact_accuracy": best_exact_acc, + "epochs_trained": total_epochs_trained, + "checkpoint_type": "resumable", + "model_config": _extract_model_config(model, img_size), + "img_size": list(img_size), + "phase_idx": phase_idx, + "phase_epoch": epoch, + "total_epochs_trained": total_epochs_trained, + }, + "rng_states": { + "python": random.getstate(), + "numpy": np.random.get_state(), + "torch_cpu": torch.random.get_rng_state(), + "torch_cuda": torch.cuda.get_rng_state() if torch.cuda.is_available() else None, + }, + } torch.save(periodic_ckpt, periodic_path) - logger.info(" Periodic checkpoint saved: %s", periodic_path.name) + logger.info(" Resumable checkpoint saved: %s", periodic_path.name) periodic_files = sorted( output_path.parent.glob(f"{output_path.stem}.epoch*.pt"), key=lambda p: p.stat().st_mtime,