From a66fa760285b6d9f6f4d7df265c4880122340480 Mon Sep 17 00:00:00 2001 From: Lilith Date: Tue, 13 Jan 2026 03:05:03 -0800 Subject: [PATCH] =?UTF-8?q?feat(@ml/cot-reasoning):=20=E2=9C=A8=20add=20ex?= =?UTF-8?q?ternal=20stage=20loading=20support?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- service/src/api/main.py | 9 +++- service/src/config.py | 14 ++++++ service/src/reasoning/stages.py | 78 +++++++++++++++++++++++++++++++++ 3 files changed, 100 insertions(+), 1 deletion(-) diff --git a/service/src/api/main.py b/service/src/api/main.py index f5c8414..a022a57 100644 --- a/service/src/api/main.py +++ b/service/src/api/main.py @@ -14,8 +14,9 @@ from lilith_ml_service_base import ( setup_logging, ) -from ..config import get_config +from ..config import get_config, ReasoningConfig from ..reasoning.engine import ReasoningEngine, get_reasoning_engine +from ..reasoning.stages import load_stages_from_paths logger = get_logger(__name__) @@ -76,6 +77,12 @@ async def lifespan(app: FastAPI): setup_logging(level=config.logging.level) logger.info(f"Starting {config.service.name} on port {config.service.port}") + # Load external stages from configured paths (env: COT_STAGE_PATHS) + stage_paths = config.reasoning.stage_paths or ReasoningConfig.parse_stage_paths_from_env() + if stage_paths: + logger.info(f"Loading stages from external paths: {stage_paths}") + load_stages_from_paths(stage_paths) + # Initialize reasoning engine engine = await get_reasoning_engine() logger.info("Reasoning engine initialized") diff --git a/service/src/config.py b/service/src/config.py index b83f1b7..61dbf4a 100644 --- a/service/src/config.py +++ b/service/src/config.py @@ -36,9 +36,23 @@ class ReasoningConfig(BaseModel): """Reasoning pipeline configuration.""" default_stages: list[str] = Field(default_factory=lambda: ["analyze"]) + stage_paths: list[str] = Field( + default_factory=list, + description="External paths to load stage definitions from (colon-separated via env)", + ) cache: CacheConfig = Field(default_factory=CacheConfig) json_extraction: JSONExtractionConfig = Field(default_factory=JSONExtractionConfig) + @classmethod + def parse_stage_paths_from_env(cls) -> list[str]: + """Parse COT_STAGE_PATHS env var (colon-separated) into list.""" + import os + + env_value = os.getenv("COT_STAGE_PATHS", "") + if not env_value: + return [] + return [p.strip() for p in env_value.split(":") if p.strip()] + def _get_default_port() -> int: """Get port from environment, service-addresses, or fallback. diff --git a/service/src/reasoning/stages.py b/service/src/reasoning/stages.py index 1478d25..9b7f11c 100644 --- a/service/src/reasoning/stages.py +++ b/service/src/reasoning/stages.py @@ -1,11 +1,17 @@ """Reasoning stage definitions and registry.""" +import importlib.util +import logging +import sys from abc import ABC, abstractmethod from dataclasses import dataclass, field +from pathlib import Path from typing import Any from lilith_pipeline_framework import PipelineStage, PipelineContext, StageResult, StageStatus +logger = logging.getLogger(__name__) + @dataclass class StageDefinition: @@ -38,6 +44,78 @@ def get_stage(name: str) -> StageDefinition | None: return _stage_registry.get(name) +def load_stages_from_paths(paths: list[str]) -> int: + """Dynamically load and register stages from external paths. + + Each path should be a directory containing Python files that define + StageDefinition objects and call register_stage() on import. + + External stage files should use this import pattern: + from cot_stages import StageDefinition, register_stage + + The loader injects these symbols into the 'cot_stages' pseudo-module. + + Args: + paths: List of directory paths to load stages from + + Returns: + Number of stages loaded + """ + loaded_count = 0 + + # Create a pseudo-module that external stages can import from + # This allows external stages to use: from cot_stages import StageDefinition, register_stage + import types + + cot_stages_module = types.ModuleType("cot_stages") + cot_stages_module.StageDefinition = StageDefinition + cot_stages_module.register_stage = register_stage + sys.modules["cot_stages"] = cot_stages_module + + for path_str in paths: + path = Path(path_str) + + if not path.exists(): + logger.warning(f"Stage path does not exist: {path}") + continue + + if not path.is_dir(): + logger.warning(f"Stage path is not a directory: {path}") + continue + + logger.info(f"Loading stages from: {path}") + + # Find all Python files (excluding __init__.py and private files) + for py_file in sorted(path.glob("*.py")): + if py_file.name.startswith("_"): + continue + + try: + # Create a unique module name to avoid collisions + module_name = f"cot_external_stages.{path.name}.{py_file.stem}" + + # Load the module from the file + spec = importlib.util.spec_from_file_location(module_name, py_file) + if spec is None or spec.loader is None: + logger.warning(f"Could not load spec for: {py_file}") + continue + + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + + # Count how many stages were registered from this file + # (The module should call register_stage() on import) + logger.debug(f"Loaded stage module: {py_file.name}") + loaded_count += 1 + + except Exception as e: + logger.error(f"Failed to load stage from {py_file}: {e}") + + logger.info(f"Loaded {loaded_count} stage modules from external paths") + return loaded_count + + class ReasoningStage(PipelineStage, ABC): """Base class for reasoning stages that integrate with pipeline-framework."""