feat(@ml/cot-reasoning): ✨ add external stage loading support
This commit is contained in:
parent
6667fe7efe
commit
a66fa76028
3 changed files with 100 additions and 1 deletions
|
|
@ -14,8 +14,9 @@ from lilith_ml_service_base import (
|
|||
setup_logging,
|
||||
)
|
||||
|
||||
from ..config import get_config
|
||||
from ..config import get_config, ReasoningConfig
|
||||
from ..reasoning.engine import ReasoningEngine, get_reasoning_engine
|
||||
from ..reasoning.stages import load_stages_from_paths
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
|
@ -76,6 +77,12 @@ async def lifespan(app: FastAPI):
|
|||
setup_logging(level=config.logging.level)
|
||||
logger.info(f"Starting {config.service.name} on port {config.service.port}")
|
||||
|
||||
# Load external stages from configured paths (env: COT_STAGE_PATHS)
|
||||
stage_paths = config.reasoning.stage_paths or ReasoningConfig.parse_stage_paths_from_env()
|
||||
if stage_paths:
|
||||
logger.info(f"Loading stages from external paths: {stage_paths}")
|
||||
load_stages_from_paths(stage_paths)
|
||||
|
||||
# Initialize reasoning engine
|
||||
engine = await get_reasoning_engine()
|
||||
logger.info("Reasoning engine initialized")
|
||||
|
|
|
|||
|
|
@ -36,9 +36,23 @@ class ReasoningConfig(BaseModel):
|
|||
"""Reasoning pipeline configuration."""
|
||||
|
||||
default_stages: list[str] = Field(default_factory=lambda: ["analyze"])
|
||||
stage_paths: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="External paths to load stage definitions from (colon-separated via env)",
|
||||
)
|
||||
cache: CacheConfig = Field(default_factory=CacheConfig)
|
||||
json_extraction: JSONExtractionConfig = Field(default_factory=JSONExtractionConfig)
|
||||
|
||||
@classmethod
|
||||
def parse_stage_paths_from_env(cls) -> list[str]:
|
||||
"""Parse COT_STAGE_PATHS env var (colon-separated) into list."""
|
||||
import os
|
||||
|
||||
env_value = os.getenv("COT_STAGE_PATHS", "")
|
||||
if not env_value:
|
||||
return []
|
||||
return [p.strip() for p in env_value.split(":") if p.strip()]
|
||||
|
||||
|
||||
def _get_default_port() -> int:
|
||||
"""Get port from environment, service-addresses, or fallback.
|
||||
|
|
|
|||
|
|
@ -1,11 +1,17 @@
|
|||
"""Reasoning stage definitions and registry."""
|
||||
|
||||
import importlib.util
|
||||
import logging
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from lilith_pipeline_framework import PipelineStage, PipelineContext, StageResult, StageStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StageDefinition:
|
||||
|
|
@ -38,6 +44,78 @@ def get_stage(name: str) -> StageDefinition | None:
|
|||
return _stage_registry.get(name)
|
||||
|
||||
|
||||
def load_stages_from_paths(paths: list[str]) -> int:
|
||||
"""Dynamically load and register stages from external paths.
|
||||
|
||||
Each path should be a directory containing Python files that define
|
||||
StageDefinition objects and call register_stage() on import.
|
||||
|
||||
External stage files should use this import pattern:
|
||||
from cot_stages import StageDefinition, register_stage
|
||||
|
||||
The loader injects these symbols into the 'cot_stages' pseudo-module.
|
||||
|
||||
Args:
|
||||
paths: List of directory paths to load stages from
|
||||
|
||||
Returns:
|
||||
Number of stages loaded
|
||||
"""
|
||||
loaded_count = 0
|
||||
|
||||
# Create a pseudo-module that external stages can import from
|
||||
# This allows external stages to use: from cot_stages import StageDefinition, register_stage
|
||||
import types
|
||||
|
||||
cot_stages_module = types.ModuleType("cot_stages")
|
||||
cot_stages_module.StageDefinition = StageDefinition
|
||||
cot_stages_module.register_stage = register_stage
|
||||
sys.modules["cot_stages"] = cot_stages_module
|
||||
|
||||
for path_str in paths:
|
||||
path = Path(path_str)
|
||||
|
||||
if not path.exists():
|
||||
logger.warning(f"Stage path does not exist: {path}")
|
||||
continue
|
||||
|
||||
if not path.is_dir():
|
||||
logger.warning(f"Stage path is not a directory: {path}")
|
||||
continue
|
||||
|
||||
logger.info(f"Loading stages from: {path}")
|
||||
|
||||
# Find all Python files (excluding __init__.py and private files)
|
||||
for py_file in sorted(path.glob("*.py")):
|
||||
if py_file.name.startswith("_"):
|
||||
continue
|
||||
|
||||
try:
|
||||
# Create a unique module name to avoid collisions
|
||||
module_name = f"cot_external_stages.{path.name}.{py_file.stem}"
|
||||
|
||||
# Load the module from the file
|
||||
spec = importlib.util.spec_from_file_location(module_name, py_file)
|
||||
if spec is None or spec.loader is None:
|
||||
logger.warning(f"Could not load spec for: {py_file}")
|
||||
continue
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# Count how many stages were registered from this file
|
||||
# (The module should call register_stage() on import)
|
||||
logger.debug(f"Loaded stage module: {py_file.name}")
|
||||
loaded_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load stage from {py_file}: {e}")
|
||||
|
||||
logger.info(f"Loaded {loaded_count} stage modules from external paths")
|
||||
return loaded_count
|
||||
|
||||
|
||||
class ReasoningStage(PipelineStage, ABC):
|
||||
"""Base class for reasoning stages that integrate with pipeline-framework."""
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue