Added standardized workflows for automated publishing on push to main/master. Configuration-driven, version-checked, workspace-aware workflows. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
188 lines
6.2 KiB
Python
188 lines
6.2 KiB
Python
"""Tests for ML exception classes."""
|
|
|
|
import pytest
|
|
|
|
from lilith_ml_exceptions.errors import (
|
|
InferenceError,
|
|
MLBaseError,
|
|
ModelLoadError,
|
|
ResourceError,
|
|
ValidationError,
|
|
)
|
|
from lilith_ml_exceptions.http_mappers import (
|
|
exception_to_http_status,
|
|
exception_to_response,
|
|
)
|
|
|
|
|
|
class TestMLBaseError:
    """Checks on the attributes and dict form exposed by MLBaseError."""

    def test_basic_creation(self):
        """A message-only instance reports the message, empty details, no cause."""
        text = "Something went wrong"
        err = MLBaseError(text)
        assert str(err) == text
        assert err.message == text
        assert err.details == {}
        assert err.cause is None

    def test_with_details(self):
        """The ``details`` keyword is readable back from the instance."""
        extra = {"key": "value"}
        err = MLBaseError("Error", details=extra)
        assert err.details == {"key": "value"}

    def test_with_cause(self):
        """The very object passed as ``cause`` is kept on the instance."""
        original = ValueError("Original error")
        wrapped = MLBaseError("Wrapped error", cause=original)
        # Identity, not equality: the same object must be stored.
        assert wrapped.cause is original

    def test_to_dict(self):
        """``to_dict()`` yields the error code, message and details."""
        payload = MLBaseError("Error message", details={"foo": "bar"}).to_dict()

        assert payload["error_code"] == "ML_ERROR"
        assert payload["message"] == "Error message"
        assert payload["details"] == {"foo": "bar"}
|
|
|
|
|
|
class TestModelLoadError:
    """Checks on ModelLoadError: code, extra fields and HTTP status."""

    def test_error_code(self):
        """The ``error_code`` attribute is ``MODEL_LOAD_ERROR``."""
        err = ModelLoadError("Model not found")
        assert err.error_code == "MODEL_LOAD_ERROR"

    def test_with_model_path(self):
        """``model_path`` appears both as an attribute and inside ``details``."""
        path = "/models/bert.pt"
        err = ModelLoadError("Not found", model_path=path)
        assert err.model_path == path
        assert err.details["model_path"] == path

    def test_with_model_name(self):
        """``model_name`` appears both as an attribute and inside ``details``."""
        name = "bert-base"
        err = ModelLoadError("Failed", model_name=name)
        assert err.model_name == name
        assert err.details["model_name"] == name

    def test_http_status(self):
        """``exception_to_http_status`` returns 503 for a ModelLoadError."""
        status = exception_to_http_status(ModelLoadError("Error"))
        assert status == 503
|
|
|
|
|
|
class TestInferenceError:
    """Checks on InferenceError: code, extra fields and HTTP status."""

    def test_error_code(self):
        """The ``error_code`` attribute is ``INFERENCE_ERROR``."""
        err = InferenceError("Inference failed")
        assert err.error_code == "INFERENCE_ERROR"

    def test_with_input_shape(self):
        """``input_shape`` appears both as an attribute and inside ``details``."""
        shape = (1, 3, 224, 224)
        err = InferenceError("OOM", input_shape=shape)
        assert err.input_shape == shape
        assert err.details["input_shape"] == shape

    def test_with_timeout(self):
        """``timeout_seconds`` appears both as an attribute and inside ``details``."""
        limit = 30.0
        err = InferenceError("Timeout", timeout_seconds=limit)
        assert err.timeout_seconds == limit
        assert err.details["timeout_seconds"] == limit

    def test_http_status(self):
        """``exception_to_http_status`` returns 500 for an InferenceError."""
        status = exception_to_http_status(InferenceError("Error"))
        assert status == 500
|
|
|
|
|
|
class TestValidationError:
    """Checks on ValidationError: code, extra fields and HTTP status."""

    def test_error_code(self):
        """The ``error_code`` attribute is ``VALIDATION_ERROR``."""
        err = ValidationError("Invalid input")
        assert err.error_code == "VALIDATION_ERROR"

    def test_with_field(self):
        """``field`` appears both as an attribute and inside ``details``."""
        field_name = "image"
        err = ValidationError("Invalid", field=field_name)
        assert err.field == field_name
        assert err.details["field"] == field_name

    def test_with_expected_type(self):
        """``expected_type`` is readable back as an attribute."""
        type_name = "numpy.ndarray"
        err = ValidationError("Wrong type", expected_type=type_name)
        assert err.expected_type == type_name

    def test_with_constraints(self):
        """``constraints`` is readable back as an attribute."""
        limits = {"min_size": 224, "max_size": 1024}
        err = ValidationError("Out of range", constraints=limits)
        assert err.constraints == limits

    def test_http_status(self):
        """``exception_to_http_status`` returns 400 for a ValidationError."""
        status = exception_to_http_status(ValidationError("Error"))
        assert status == 400
|
|
|
|
|
|
class TestResourceError:
    """Checks on ResourceError: code, resource fields and HTTP status."""

    def test_error_code(self):
        """The ``error_code`` attribute is ``RESOURCE_ERROR``."""
        err = ResourceError("GPU not available")
        assert err.error_code == "RESOURCE_ERROR"

    def test_with_resource_info(self):
        """Resource type and required/available amounts are kept as attributes."""
        kind, needed, free = "VRAM", "8GB", "4GB"
        err = ResourceError(
            "Insufficient VRAM",
            resource_type=kind,
            required_amount=needed,
            available_amount=free,
        )
        assert err.resource_type == kind
        assert err.required_amount == needed
        assert err.available_amount == free

    def test_http_status(self):
        """``exception_to_http_status`` returns 503 for a ResourceError."""
        status = exception_to_http_status(ResourceError("Error"))
        assert status == 503
|
|
|
|
|
|
class TestExceptionToResponse:
    """Checks on the response dict built by ``exception_to_response``."""

    def test_ml_exception_response(self):
        """An ML exception yields its own code, message and details."""
        body = exception_to_response(ModelLoadError("Not found", model_name="bert"))

        assert body["status"] == "error"
        # Looked up only after the status check, preserving the original order.
        error = body["error"]
        assert error["error_code"] == "MODEL_LOAD_ERROR"
        assert error["message"] == "Not found"
        assert error["details"]["model_name"] == "bert"

    def test_generic_exception_response(self):
        """A non-ML exception is reported with the ``INTERNAL_ERROR`` code."""
        body = exception_to_response(ValueError("Something went wrong"))

        assert body["status"] == "error"
        error = body["error"]
        assert error["error_code"] == "INTERNAL_ERROR"
        assert error["message"] == "Something went wrong"
|
|
|
|
|
|
class TestUnknownExceptionStatus:
    """Checks on the HTTP status chosen for exceptions outside the ML hierarchy."""

    def test_unknown_exception_returns_500(self):
        """A plain RuntimeError is mapped to status 500."""
        status = exception_to_http_status(RuntimeError("Unknown error"))
        assert status == 500
|