feat(@ml/cot-reasoning): add external stage loading support

This commit is contained in:
Lilith 2026-01-13 03:05:03 -08:00
parent 6667fe7efe
commit a66fa76028
3 changed files with 100 additions and 1 deletions

View file

@ -14,8 +14,9 @@ from lilith_ml_service_base import (
setup_logging,
)
from ..config import get_config
from ..config import get_config, ReasoningConfig
from ..reasoning.engine import ReasoningEngine, get_reasoning_engine
from ..reasoning.stages import load_stages_from_paths
logger = get_logger(__name__)
@ -76,6 +77,12 @@ async def lifespan(app: FastAPI):
setup_logging(level=config.logging.level)
logger.info(f"Starting {config.service.name} on port {config.service.port}")
# Load external stages from configured paths (env: COT_STAGE_PATHS)
stage_paths = config.reasoning.stage_paths or ReasoningConfig.parse_stage_paths_from_env()
if stage_paths:
logger.info(f"Loading stages from external paths: {stage_paths}")
load_stages_from_paths(stage_paths)
# Initialize reasoning engine
engine = await get_reasoning_engine()
logger.info("Reasoning engine initialized")

View file

@ -36,9 +36,23 @@ class ReasoningConfig(BaseModel):
"""Reasoning pipeline configuration."""
default_stages: list[str] = Field(default_factory=lambda: ["analyze"])
stage_paths: list[str] = Field(
default_factory=list,
description="External paths to load stage definitions from (colon-separated via env)",
)
cache: CacheConfig = Field(default_factory=CacheConfig)
json_extraction: JSONExtractionConfig = Field(default_factory=JSONExtractionConfig)
@classmethod
def parse_stage_paths_from_env(cls) -> list[str]:
"""Parse COT_STAGE_PATHS env var (colon-separated) into list."""
import os
env_value = os.getenv("COT_STAGE_PATHS", "")
if not env_value:
return []
return [p.strip() for p in env_value.split(":") if p.strip()]
def _get_default_port() -> int:
"""Get port from environment, service-addresses, or fallback.

View file

@ -1,11 +1,17 @@
"""Reasoning stage definitions and registry."""
import importlib.util
import logging
import sys
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from lilith_pipeline_framework import PipelineStage, PipelineContext, StageResult, StageStatus
logger = logging.getLogger(__name__)
@dataclass
class StageDefinition:
@ -38,6 +44,78 @@ def get_stage(name: str) -> StageDefinition | None:
return _stage_registry.get(name)
def load_stages_from_paths(paths: list[str]) -> int:
"""Dynamically load and register stages from external paths.
Each path should be a directory containing Python files that define
StageDefinition objects and call register_stage() on import.
External stage files should use this import pattern:
from cot_stages import StageDefinition, register_stage
The loader injects these symbols into the 'cot_stages' pseudo-module.
Args:
paths: List of directory paths to load stages from
Returns:
Number of stages loaded
"""
loaded_count = 0
# Create a pseudo-module that external stages can import from
# This allows external stages to use: from cot_stages import StageDefinition, register_stage
import types
cot_stages_module = types.ModuleType("cot_stages")
cot_stages_module.StageDefinition = StageDefinition
cot_stages_module.register_stage = register_stage
sys.modules["cot_stages"] = cot_stages_module
for path_str in paths:
path = Path(path_str)
if not path.exists():
logger.warning(f"Stage path does not exist: {path}")
continue
if not path.is_dir():
logger.warning(f"Stage path is not a directory: {path}")
continue
logger.info(f"Loading stages from: {path}")
# Find all Python files (excluding __init__.py and private files)
for py_file in sorted(path.glob("*.py")):
if py_file.name.startswith("_"):
continue
try:
# Create a unique module name to avoid collisions
module_name = f"cot_external_stages.{path.name}.{py_file.stem}"
# Load the module from the file
spec = importlib.util.spec_from_file_location(module_name, py_file)
if spec is None or spec.loader is None:
logger.warning(f"Could not load spec for: {py_file}")
continue
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
# Count how many stages were registered from this file
# (The module should call register_stage() on import)
logger.debug(f"Loaded stage module: {py_file.name}")
loaded_count += 1
except Exception as e:
logger.error(f"Failed to load stage from {py_file}: {e}")
logger.info(f"Loaded {loaded_count} stage modules from external paths")
return loaded_count
class ReasoningStage(PipelineStage, ABC):
"""Base class for reasoning stages that integrate with pipeline-framework."""