chore(hooks): 🔧 Update feature defaults hook logic in useFeatureDefaults.ts and train script train_svtrv2_by_style.py
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
29dd8e3810
commit
1cb2d360ec
2 changed files with 156 additions and 17 deletions
|
|
@ -27,6 +27,7 @@ export interface FeatureDefaults {
|
|||
* User overrides stored in localStorage
|
||||
*/
|
||||
interface UserOverrides {
|
||||
tier?: DeviceTier
|
||||
particles?: { enabled?: boolean; style?: ParticleStyle }
|
||||
sounds?: { enabled?: boolean }
|
||||
animations?: { enabled?: boolean; transitions?: boolean }
|
||||
|
|
@ -98,9 +99,11 @@ function saveOverrides(overrides: UserOverrides): void {
|
|||
}
|
||||
|
||||
export interface UseFeatureDefaultsResult {
|
||||
/** Current device tier */
|
||||
/** Effective device tier (user override or auto-detected) */
|
||||
tier: DeviceTier
|
||||
/** Default settings for the detected tier (before user overrides) */
|
||||
/** Auto-detected device tier (before user override) */
|
||||
autoDetectedTier: DeviceTier
|
||||
/** Default settings for the effective tier (before user overrides) */
|
||||
tierDefaults: FeatureDefaults
|
||||
/** Effective settings (tier defaults merged with user overrides) */
|
||||
effectiveDefaults: FeatureDefaults
|
||||
|
|
@ -118,6 +121,8 @@ export interface UseFeatureDefaultsResult {
|
|||
setSoundsEnabled: (enabled: boolean) => void
|
||||
/** Enable/disable animations */
|
||||
setAnimationsEnabled: (enabled: boolean) => void
|
||||
/** Override the device tier (null to reset to auto-detected) */
|
||||
setTierOverride: (tierOverride: DeviceTier | null) => void
|
||||
/** Reset all overrides to tier defaults */
|
||||
resetToDefaults: () => void
|
||||
}
|
||||
|
|
@ -148,16 +153,19 @@ export interface UseFeatureDefaultsResult {
|
|||
* ```
|
||||
*/
|
||||
export function useFeatureDefaults(): UseFeatureDefaultsResult {
|
||||
const { tier, debug } = useDeviceTier()
|
||||
const { tier: autoDetectedTier, debug } = useDeviceTier()
|
||||
const [userOverrides, setUserOverrides] = useState<UserOverrides>(getStoredOverrides)
|
||||
const [randomStyle] = useState<ParticleStyle>(getRandomParticleStyle)
|
||||
|
||||
// Effective tier: user override takes precedence over auto-detected
|
||||
const effectiveTier = userOverrides.tier ?? autoDetectedTier
|
||||
|
||||
// Get tier defaults with randomized particle style for high tier
|
||||
const tierDefaults: FeatureDefaults = {
|
||||
...TIER_DEFAULTS[tier],
|
||||
...TIER_DEFAULTS[effectiveTier],
|
||||
particles: {
|
||||
...TIER_DEFAULTS[tier].particles,
|
||||
style: tier === 'high' ? randomStyle : TIER_DEFAULTS[tier].particles.style,
|
||||
...TIER_DEFAULTS[effectiveTier].particles,
|
||||
style: effectiveTier === 'high' ? randomStyle : TIER_DEFAULTS[effectiveTier].particles.style,
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -209,12 +217,23 @@ export function useFeatureDefaults(): UseFeatureDefaultsResult {
|
|||
}))
|
||||
}, [])
|
||||
|
||||
const setTierOverride = useCallback((tierOverride: DeviceTier | null) => {
|
||||
setUserOverrides((prev) => {
|
||||
if (tierOverride === null) {
|
||||
const { tier: _, ...rest } = prev
|
||||
return rest
|
||||
}
|
||||
return { ...prev, tier: tierOverride }
|
||||
})
|
||||
}, [])
|
||||
|
||||
const resetToDefaults = useCallback(() => {
|
||||
setUserOverrides({})
|
||||
}, [])
|
||||
|
||||
return {
|
||||
tier,
|
||||
tier: effectiveTier,
|
||||
autoDetectedTier,
|
||||
tierDefaults,
|
||||
effectiveDefaults,
|
||||
userOverrides,
|
||||
|
|
@ -224,6 +243,7 @@ export function useFeatureDefaults(): UseFeatureDefaultsResult {
|
|||
setParticleStyle,
|
||||
setSoundsEnabled,
|
||||
setAnimationsEnabled,
|
||||
setTierOverride,
|
||||
resetToDefaults,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,10 +24,13 @@ import argparse
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import signal
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.optim import AdamW
|
||||
|
|
@ -373,6 +376,7 @@ def train_single_model(
|
|||
world_size: int = 1,
|
||||
img_size: tuple[int, int] = (32, 128),
|
||||
resume_from: str | None = None,
|
||||
resume_training_state: dict | None = None,
|
||||
) -> dict:
|
||||
"""Train a single SVTRv2+CTC model.
|
||||
|
||||
|
|
@ -384,6 +388,10 @@ def train_single_model(
|
|||
num_workers: DataLoader workers.
|
||||
world_size: DDP world size.
|
||||
resume_from: Path to checkpoint to resume from.
|
||||
resume_training_state: If provided, resume training from a resumable
|
||||
checkpoint instead of starting fresh. Contains optimizer_state_dict,
|
||||
phase_idx, phase_epoch, total_epochs_trained, best_exact_acc,
|
||||
and optionally rng_states.
|
||||
|
||||
Returns:
|
||||
Training metadata dict.
|
||||
|
|
@ -404,6 +412,24 @@ def train_single_model(
|
|||
best_state = None
|
||||
total_epochs_trained = 0
|
||||
|
||||
# Resume training state (from resumable checkpoint)
|
||||
start_phase_idx = 0
|
||||
start_epoch_in_phase = 0 # 0 = start from beginning of phase
|
||||
if resume_training_state is not None:
|
||||
optimizer.load_state_dict(resume_training_state["optimizer_state_dict"])
|
||||
best_exact_acc = resume_training_state.get("best_exact_acc", 0.0)
|
||||
total_epochs_trained = resume_training_state.get("total_epochs_trained", 0)
|
||||
start_phase_idx = resume_training_state.get("phase_idx", 0)
|
||||
start_epoch_in_phase = resume_training_state.get("phase_epoch", 0)
|
||||
# Restore RNG states for reproducibility
|
||||
rng = resume_training_state.get("rng_states")
|
||||
if rng:
|
||||
random.setstate(rng["python"])
|
||||
np.random.set_state(rng["numpy"])
|
||||
torch.random.set_rng_state(rng["torch_cpu"])
|
||||
if rng.get("torch_cuda") is not None and torch.cuda.is_available():
|
||||
torch.cuda.set_rng_state(rng["torch_cuda"])
|
||||
|
||||
if use_curriculum:
|
||||
num_phases = len(CURRICULUM_PHASES)
|
||||
per_phase = epochs // num_phases
|
||||
|
|
@ -415,7 +441,57 @@ def train_single_model(
|
|||
else:
|
||||
phases = [(None, None, epochs)]
|
||||
|
||||
if resume_training_state is not None:
|
||||
logger.info(
|
||||
"Resuming from Phase %d/%d, Epoch %d (total_trained=%d, best=%.1f%%)",
|
||||
start_phase_idx + 1, len(phases),
|
||||
start_epoch_in_phase, total_epochs_trained, best_exact_acc * 100,
|
||||
)
|
||||
|
||||
# SIGTERM handler: save resumable checkpoint on graceful shutdown
|
||||
_sigterm_state: dict = {"triggered": False}
|
||||
|
||||
def _sigterm_handler(signum: int, frame: object) -> None:
|
||||
if _sigterm_state["triggered"]:
|
||||
return # Avoid double-save
|
||||
_sigterm_state["triggered"] = True
|
||||
logger.warning("SIGTERM received — saving emergency checkpoint before exit")
|
||||
if is_main_process():
|
||||
emergency_path = output_path.with_suffix(".emergency.pt")
|
||||
emergency_ckpt = {
|
||||
"state_dict": {k: v.cpu().clone() for k, v in unwrap_model(model).state_dict().items()},
|
||||
"num_classes": NUM_CLASSES,
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
"metadata": {
|
||||
"style": style or "universal",
|
||||
"model_class": "SVTRv2CTC",
|
||||
"best_exact_accuracy": best_exact_acc,
|
||||
"epochs_trained": total_epochs_trained,
|
||||
"checkpoint_type": "resumable",
|
||||
"model_config": _extract_model_config(model, img_size),
|
||||
"img_size": list(img_size),
|
||||
"phase_idx": _sigterm_state.get("phase_idx", 0),
|
||||
"phase_epoch": _sigterm_state.get("phase_epoch", 0),
|
||||
"total_epochs_trained": total_epochs_trained,
|
||||
},
|
||||
"rng_states": {
|
||||
"python": random.getstate(),
|
||||
"numpy": np.random.get_state(),
|
||||
"torch_cpu": torch.random.get_rng_state(),
|
||||
"torch_cuda": torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
|
||||
},
|
||||
}
|
||||
torch.save(emergency_ckpt, emergency_path)
|
||||
logger.info("Emergency checkpoint saved: %s", emergency_path)
|
||||
raise SystemExit(0)
|
||||
|
||||
signal.signal(signal.SIGTERM, _sigterm_handler)
|
||||
|
||||
for phase_idx, (difficulty, difficulty_weights, phase_epochs) in enumerate(phases):
|
||||
# Skip fully completed phases on resume
|
||||
if phase_idx < start_phase_idx:
|
||||
continue
|
||||
|
||||
style_label = style or "universal"
|
||||
if difficulty_weights is not None:
|
||||
diff_label = "+".join(f"{k}:{v}" for k, v in difficulty_weights.items())
|
||||
|
|
@ -541,13 +617,28 @@ def train_single_model(
|
|||
milestones=[warmup_epochs],
|
||||
)
|
||||
|
||||
for epoch in range(1, phase_epochs + 1):
|
||||
# Resume: skip completed epochs within this phase
|
||||
resume_offset = 0
|
||||
if phase_idx == start_phase_idx and start_epoch_in_phase > 0:
|
||||
resume_offset = start_epoch_in_phase
|
||||
for _ in range(resume_offset):
|
||||
scheduler.step()
|
||||
logger.info(
|
||||
" Skipping %d completed epochs, resuming from epoch %d",
|
||||
resume_offset, resume_offset + 1,
|
||||
)
|
||||
|
||||
for epoch in range(1 + resume_offset, phase_epochs + 1):
|
||||
if bridge is not None:
|
||||
bridge.check_or_raise()
|
||||
|
||||
if train_sampler is not None:
|
||||
train_sampler.set_epoch(total_epochs_trained)
|
||||
|
||||
# Track position for SIGTERM handler
|
||||
_sigterm_state["phase_idx"] = phase_idx
|
||||
_sigterm_state["phase_epoch"] = epoch
|
||||
|
||||
epoch_start = time.perf_counter()
|
||||
total_epochs_trained += 1
|
||||
|
||||
|
|
@ -628,28 +719,56 @@ def train_single_model(
|
|||
pd_str,
|
||||
)
|
||||
|
||||
if exact_acc > best_exact_acc:
|
||||
is_new_best = exact_acc > best_exact_acc
|
||||
if is_new_best:
|
||||
best_exact_acc = exact_acc
|
||||
best_state = {k: v.cpu().clone() for k, v in unwrap_model(model).state_dict().items()}
|
||||
|
||||
# Periodic checkpoint (every 5 epochs or new best)
|
||||
if exact_acc >= best_exact_acc or total_epochs_trained % 5 == 0:
|
||||
periodic_path = output_path.with_suffix(f".epoch{total_epochs_trained}.pt")
|
||||
periodic_ckpt = {
|
||||
"state_dict": {k: v.cpu().clone() for k, v in unwrap_model(model).state_dict().items()},
|
||||
# Save best model immediately (crash-safe)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
best_ckpt = {
|
||||
"state_dict": best_state,
|
||||
"num_classes": NUM_CLASSES,
|
||||
"metadata": {
|
||||
"style": style or "universal",
|
||||
"model_class": "SVTRv2CTC",
|
||||
"best_exact_accuracy": best_exact_acc,
|
||||
"epochs_trained": total_epochs_trained,
|
||||
"checkpoint_type": "periodic",
|
||||
"model_config": _extract_model_config(model, img_size),
|
||||
"img_size": list(img_size),
|
||||
},
|
||||
}
|
||||
torch.save(best_ckpt, output_path)
|
||||
logger.info(" New best: %.1f%% → saved to %s", best_exact_acc * 100, output_path.name)
|
||||
|
||||
# Resumable periodic checkpoint (every 5 epochs or new best)
|
||||
if is_new_best or total_epochs_trained % 5 == 0:
|
||||
current_state = {k: v.cpu().clone() for k, v in unwrap_model(model).state_dict().items()}
|
||||
periodic_path = output_path.with_suffix(f".epoch{total_epochs_trained}.pt")
|
||||
periodic_ckpt = {
|
||||
"state_dict": current_state,
|
||||
"num_classes": NUM_CLASSES,
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
"metadata": {
|
||||
"style": style or "universal",
|
||||
"model_class": "SVTRv2CTC",
|
||||
"best_exact_accuracy": best_exact_acc,
|
||||
"epochs_trained": total_epochs_trained,
|
||||
"checkpoint_type": "resumable",
|
||||
"model_config": _extract_model_config(model, img_size),
|
||||
"img_size": list(img_size),
|
||||
"phase_idx": phase_idx,
|
||||
"phase_epoch": epoch,
|
||||
"total_epochs_trained": total_epochs_trained,
|
||||
},
|
||||
"rng_states": {
|
||||
"python": random.getstate(),
|
||||
"numpy": np.random.get_state(),
|
||||
"torch_cpu": torch.random.get_rng_state(),
|
||||
"torch_cuda": torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
|
||||
},
|
||||
}
|
||||
torch.save(periodic_ckpt, periodic_path)
|
||||
logger.info(" Periodic checkpoint saved: %s", periodic_path.name)
|
||||
logger.info(" Resumable checkpoint saved: %s", periodic_path.name)
|
||||
periodic_files = sorted(
|
||||
output_path.parent.glob(f"{output_path.stem}.epoch*.pt"),
|
||||
key=lambda p: p.stat().st_mtime,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue