chore(hooks): 🔧 Update feature defaults hook logic in useFeatureDefaults.ts and train script train_svtrv2_by_style.py

Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
Lilith 2026-02-21 01:07:03 -08:00
parent 29dd8e3810
commit 1cb2d360ec
2 changed files with 156 additions and 17 deletions

View file

@ -27,6 +27,7 @@ export interface FeatureDefaults {
* User overrides stored in localStorage
*/
interface UserOverrides {
tier?: DeviceTier
particles?: { enabled?: boolean; style?: ParticleStyle }
sounds?: { enabled?: boolean }
animations?: { enabled?: boolean; transitions?: boolean }
@ -98,9 +99,11 @@ function saveOverrides(overrides: UserOverrides): void {
}
export interface UseFeatureDefaultsResult {
/** Current device tier */
/** Effective device tier (user override or auto-detected) */
tier: DeviceTier
/** Default settings for the detected tier (before user overrides) */
/** Auto-detected device tier (before user override) */
autoDetectedTier: DeviceTier
/** Default settings for the effective tier (before user overrides) */
tierDefaults: FeatureDefaults
/** Effective settings (tier defaults merged with user overrides) */
effectiveDefaults: FeatureDefaults
@ -118,6 +121,8 @@ export interface UseFeatureDefaultsResult {
setSoundsEnabled: (enabled: boolean) => void
/** Enable/disable animations */
setAnimationsEnabled: (enabled: boolean) => void
/** Override the device tier (null to reset to auto-detected) */
setTierOverride: (tierOverride: DeviceTier | null) => void
/** Reset all overrides to tier defaults */
resetToDefaults: () => void
}
@ -148,16 +153,19 @@ export interface UseFeatureDefaultsResult {
* ```
*/
export function useFeatureDefaults(): UseFeatureDefaultsResult {
const { tier, debug } = useDeviceTier()
const { tier: autoDetectedTier, debug } = useDeviceTier()
const [userOverrides, setUserOverrides] = useState<UserOverrides>(getStoredOverrides)
const [randomStyle] = useState<ParticleStyle>(getRandomParticleStyle)
// Effective tier: user override takes precedence over auto-detected
const effectiveTier = userOverrides.tier ?? autoDetectedTier
// Get tier defaults with randomized particle style for high tier
const tierDefaults: FeatureDefaults = {
...TIER_DEFAULTS[tier],
...TIER_DEFAULTS[effectiveTier],
particles: {
...TIER_DEFAULTS[tier].particles,
style: tier === 'high' ? randomStyle : TIER_DEFAULTS[tier].particles.style,
...TIER_DEFAULTS[effectiveTier].particles,
style: effectiveTier === 'high' ? randomStyle : TIER_DEFAULTS[effectiveTier].particles.style,
},
}
@ -209,12 +217,23 @@ export function useFeatureDefaults(): UseFeatureDefaultsResult {
}))
}, [])
const setTierOverride = useCallback((tierOverride: DeviceTier | null) => {
setUserOverrides((prev) => {
if (tierOverride === null) {
const { tier: _, ...rest } = prev
return rest
}
return { ...prev, tier: tierOverride }
})
}, [])
const resetToDefaults = useCallback(() => {
setUserOverrides({})
}, [])
return {
tier,
tier: effectiveTier,
autoDetectedTier,
tierDefaults,
effectiveDefaults,
userOverrides,
@ -224,6 +243,7 @@ export function useFeatureDefaults(): UseFeatureDefaultsResult {
setParticleStyle,
setSoundsEnabled,
setAnimationsEnabled,
setTierOverride,
resetToDefaults,
}
}

View file

@ -24,10 +24,13 @@ import argparse
import json
import logging
import os
import random
import signal
import time
from pathlib import Path
from typing import TYPE_CHECKING
import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW
@ -373,6 +376,7 @@ def train_single_model(
world_size: int = 1,
img_size: tuple[int, int] = (32, 128),
resume_from: str | None = None,
resume_training_state: dict | None = None,
) -> dict:
"""Train a single SVTRv2+CTC model.
@ -384,6 +388,10 @@ def train_single_model(
num_workers: DataLoader workers.
world_size: DDP world size.
resume_from: Path to checkpoint to resume from.
resume_training_state: If provided, resume training from a resumable
checkpoint instead of starting fresh. Contains optimizer_state_dict,
phase_idx, phase_epoch, total_epochs_trained, best_exact_acc,
and optionally rng_states.
Returns:
Training metadata dict.
@ -404,6 +412,24 @@ def train_single_model(
best_state = None
total_epochs_trained = 0
# Resume training state (from resumable checkpoint)
start_phase_idx = 0
start_epoch_in_phase = 0 # 0 = start from beginning of phase
if resume_training_state is not None:
optimizer.load_state_dict(resume_training_state["optimizer_state_dict"])
best_exact_acc = resume_training_state.get("best_exact_acc", 0.0)
total_epochs_trained = resume_training_state.get("total_epochs_trained", 0)
start_phase_idx = resume_training_state.get("phase_idx", 0)
start_epoch_in_phase = resume_training_state.get("phase_epoch", 0)
# Restore RNG states for reproducibility
rng = resume_training_state.get("rng_states")
if rng:
random.setstate(rng["python"])
np.random.set_state(rng["numpy"])
torch.random.set_rng_state(rng["torch_cpu"])
if rng.get("torch_cuda") is not None and torch.cuda.is_available():
torch.cuda.set_rng_state(rng["torch_cuda"])
if use_curriculum:
num_phases = len(CURRICULUM_PHASES)
per_phase = epochs // num_phases
@ -415,7 +441,57 @@ def train_single_model(
else:
phases = [(None, None, epochs)]
if resume_training_state is not None:
logger.info(
"Resuming from Phase %d/%d, Epoch %d (total_trained=%d, best=%.1f%%)",
start_phase_idx + 1, len(phases),
start_epoch_in_phase, total_epochs_trained, best_exact_acc * 100,
)
# SIGTERM handler: save resumable checkpoint on graceful shutdown
_sigterm_state: dict = {"triggered": False}
def _sigterm_handler(signum: int, frame: object) -> None:
if _sigterm_state["triggered"]:
return # Avoid double-save
_sigterm_state["triggered"] = True
logger.warning("SIGTERM received — saving emergency checkpoint before exit")
if is_main_process():
emergency_path = output_path.with_suffix(".emergency.pt")
emergency_ckpt = {
"state_dict": {k: v.cpu().clone() for k, v in unwrap_model(model).state_dict().items()},
"num_classes": NUM_CLASSES,
"optimizer_state_dict": optimizer.state_dict(),
"metadata": {
"style": style or "universal",
"model_class": "SVTRv2CTC",
"best_exact_accuracy": best_exact_acc,
"epochs_trained": total_epochs_trained,
"checkpoint_type": "resumable",
"model_config": _extract_model_config(model, img_size),
"img_size": list(img_size),
"phase_idx": _sigterm_state.get("phase_idx", 0),
"phase_epoch": _sigterm_state.get("phase_epoch", 0),
"total_epochs_trained": total_epochs_trained,
},
"rng_states": {
"python": random.getstate(),
"numpy": np.random.get_state(),
"torch_cpu": torch.random.get_rng_state(),
"torch_cuda": torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
},
}
torch.save(emergency_ckpt, emergency_path)
logger.info("Emergency checkpoint saved: %s", emergency_path)
raise SystemExit(0)
signal.signal(signal.SIGTERM, _sigterm_handler)
for phase_idx, (difficulty, difficulty_weights, phase_epochs) in enumerate(phases):
# Skip fully completed phases on resume
if phase_idx < start_phase_idx:
continue
style_label = style or "universal"
if difficulty_weights is not None:
diff_label = "+".join(f"{k}:{v}" for k, v in difficulty_weights.items())
@ -541,13 +617,28 @@ def train_single_model(
milestones=[warmup_epochs],
)
for epoch in range(1, phase_epochs + 1):
# Resume: skip completed epochs within this phase
resume_offset = 0
if phase_idx == start_phase_idx and start_epoch_in_phase > 0:
resume_offset = start_epoch_in_phase
for _ in range(resume_offset):
scheduler.step()
logger.info(
" Skipping %d completed epochs, resuming from epoch %d",
resume_offset, resume_offset + 1,
)
for epoch in range(1 + resume_offset, phase_epochs + 1):
if bridge is not None:
bridge.check_or_raise()
if train_sampler is not None:
train_sampler.set_epoch(total_epochs_trained)
# Track position for SIGTERM handler
_sigterm_state["phase_idx"] = phase_idx
_sigterm_state["phase_epoch"] = epoch
epoch_start = time.perf_counter()
total_epochs_trained += 1
@ -628,28 +719,56 @@ def train_single_model(
pd_str,
)
if exact_acc > best_exact_acc:
is_new_best = exact_acc > best_exact_acc
if is_new_best:
best_exact_acc = exact_acc
best_state = {k: v.cpu().clone() for k, v in unwrap_model(model).state_dict().items()}
# Periodic checkpoint (every 5 epochs or new best)
if exact_acc >= best_exact_acc or total_epochs_trained % 5 == 0:
periodic_path = output_path.with_suffix(f".epoch{total_epochs_trained}.pt")
periodic_ckpt = {
"state_dict": {k: v.cpu().clone() for k, v in unwrap_model(model).state_dict().items()},
# Save best model immediately (crash-safe)
output_path.parent.mkdir(parents=True, exist_ok=True)
best_ckpt = {
"state_dict": best_state,
"num_classes": NUM_CLASSES,
"metadata": {
"style": style or "universal",
"model_class": "SVTRv2CTC",
"best_exact_accuracy": best_exact_acc,
"epochs_trained": total_epochs_trained,
"checkpoint_type": "periodic",
"model_config": _extract_model_config(model, img_size),
"img_size": list(img_size),
},
}
torch.save(best_ckpt, output_path)
logger.info(" New best: %.1f%% → saved to %s", best_exact_acc * 100, output_path.name)
# Resumable periodic checkpoint (every 5 epochs or new best)
if is_new_best or total_epochs_trained % 5 == 0:
current_state = {k: v.cpu().clone() for k, v in unwrap_model(model).state_dict().items()}
periodic_path = output_path.with_suffix(f".epoch{total_epochs_trained}.pt")
periodic_ckpt = {
"state_dict": current_state,
"num_classes": NUM_CLASSES,
"optimizer_state_dict": optimizer.state_dict(),
"metadata": {
"style": style or "universal",
"model_class": "SVTRv2CTC",
"best_exact_accuracy": best_exact_acc,
"epochs_trained": total_epochs_trained,
"checkpoint_type": "resumable",
"model_config": _extract_model_config(model, img_size),
"img_size": list(img_size),
"phase_idx": phase_idx,
"phase_epoch": epoch,
"total_epochs_trained": total_epochs_trained,
},
"rng_states": {
"python": random.getstate(),
"numpy": np.random.get_state(),
"torch_cpu": torch.random.get_rng_state(),
"torch_cuda": torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
},
}
torch.save(periodic_ckpt, periodic_path)
logger.info(" Periodic checkpoint saved: %s", periodic_path.name)
logger.info(" Resumable checkpoint saved: %s", periodic_path.name)
periodic_files = sorted(
output_path.parent.glob(f"{output_path.stem}.epoch*.pt"),
key=lambda p: p.stat().st_mtime,