110 lines
3.4 KiB
Python
Executable file
110 lines
3.4 KiB
Python
Executable file
"""Job models - Shared job status and result types for all ML services."""
|
|
|
|
from datetime import datetime
|
|
from enum import Enum
|
|
from typing import Any, Dict, Generic, List, Optional, TypeVar
|
|
from pydantic import BaseModel, Field
|
|
import uuid
|
|
|
|
|
|
class JobStatus(str, Enum):
    """Lifecycle states a job can be in, shared by every ML service.

    The enum mixes in ``str``, so each member compares equal to its
    lowercase string value (for example ``JobStatus.PENDING == "pending"``).
    """

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
|
|
|
|
|
|
class StageResult(BaseModel):
    """Outcome record for one processing stage.

    ``stage``, ``status``, ``duration_ms`` and ``summary`` are required;
    ``data`` and ``error`` are optional and default to ``None``.
    """

    # Name of the stage this record describes.
    stage: str
    # Outcome as a plain string; expected values: "success", "failed", "skipped".
    status: str
    # Stage duration (milliseconds, going by the field name).
    duration_ms: float
    # Short human-readable description of the outcome.
    summary: str
    # Optional structured payload attached to the stage.
    data: Optional[Dict[str, Any]] = None
    # Optional error text.
    error: Optional[str] = None
|
|
|
|
|
|
# Type parameter for the payload carried in ``JobResult.data``.
T = TypeVar("T")


class JobResult(BaseModel, Generic[T]):
    """Generic envelope for a job's outcome, parameterized by payload type.

    Only ``success`` is required. ``stages`` starts as a fresh empty list
    for every instance and ``total_duration_ms`` starts at 0.
    """

    # Whether the job succeeded.
    success: bool
    # Payload of type ``T``; ``None`` when absent.
    data: Optional[T] = None
    # Optional error text.
    error: Optional[str] = None
    # Per-stage results collected for this job.
    stages: List[StageResult] = Field(default_factory=list)
    # Overall duration (milliseconds, going by the field name).
    total_duration_ms: float = 0
|
|
|
|
|
|
class Job(BaseModel):
    """Base job model for all ML services.

    Holds a job's identity, request payload, timestamps, progress counters,
    result/error, and links to related jobs. The ``start``, ``complete``,
    ``fail`` and ``update_progress`` methods mutate the instance in place.

    NOTE(review): timestamps come from ``datetime.utcnow()``, which returns
    naive datetimes and is deprecated since Python 3.12. It is kept here so
    existing timestamps stay comparable; migrate all call sites together.
    """

    # Identity
    job_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    status: JobStatus = JobStatus.PENDING
    service: str = Field(..., description="Service that owns this job")
    job_type: str = Field(..., description="Type of job (e.g., 'pipeline', 'generate', 'batch')")

    # Request data
    request: Dict[str, Any] = Field(default_factory=dict)

    # Timing
    created_at: datetime = Field(default_factory=datetime.utcnow)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    # Progress
    current_stage: Optional[str] = None
    stages_completed: int = 0
    total_stages: int = 0
    progress_percent: float = 0

    # Result
    result: Optional[Dict[str, Any]] = None
    error: Optional[str] = None

    # Dependencies
    depends_on: List[str] = Field(default_factory=list, description="Job IDs this job depends on")
    child_jobs: List[str] = Field(default_factory=list, description="Child jobs spawned by this job")

    def start(self) -> None:
        """Mark the job as running and record the start time."""
        self.status = JobStatus.RUNNING
        self.started_at = datetime.utcnow()

    def complete(self, result: Dict[str, Any]) -> None:
        """Mark the job as completed and store its result.

        Args:
            result: Result payload to attach to the job.
        """
        self.status = JobStatus.COMPLETED
        self.completed_at = datetime.utcnow()
        self.result = result
        self.progress_percent = 100

    def fail(self, error: str) -> None:
        """Mark the job as failed and store the error message.

        Args:
            error: Description of the failure.
        """
        self.status = JobStatus.FAILED
        self.completed_at = datetime.utcnow()
        self.error = error

    def update_progress(self, stage: str, completed: int, total: int) -> None:
        """Record the current stage and recompute the progress percentage.

        Args:
            stage: Name of the stage now being processed.
            completed: Number of stages finished so far.
            total: Total number of stages; a non-positive value yields 0%.
        """
        self.current_stage = stage
        self.stages_completed = completed
        self.total_stages = total
        if total > 0:
            # Clamp so out-of-range counts (completed > total, or negative)
            # can never report more than 100% or less than 0%.
            raw_percent = completed / total * 100
            self.progress_percent = min(100.0, max(0.0, raw_percent))
        else:
            self.progress_percent = 0

    @property
    def duration_ms(self) -> Optional[float]:
        """Return the job duration in milliseconds.

        Returns:
            ``None`` if the job has not started. Otherwise the time from
            ``started_at`` to ``completed_at``, or to the current UTC time
            while ``completed_at`` is unset.
        """
        if self.started_at is None:
            return None
        end = self.completed_at or datetime.utcnow()
        return (end - self.started_at).total_seconds() * 1000
|
|
|
|
|
|
class RemoteJobReference(BaseModel):
    """Pointer to a job that lives on another service.

    ``service_url`` and ``job_id`` are required. ``status`` defaults to
    ``JobStatus.PENDING``; ``result`` and ``error`` default to ``None``.
    """

    # URL of the remote service.
    service_url: str
    # Job identifier on the remote service.
    job_id: str
    # Status recorded for the remote job.
    status: JobStatus = JobStatus.PENDING
    # Optional result payload.
    result: Optional[Dict[str, Any]] = None
    # Optional error text.
    error: Optional[str] = None
|