feat(tray): ✨ Implement tray components for commit processing pipeline with TrayClient, CommitLoop, LocalAgent, Prefilter, and tests
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
6358be5902
commit
72c9df974f
8 changed files with 671 additions and 4 deletions
|
|
@ -1,5 +1,11 @@
|
|||
"""macOS system tray (menu bar) client for the auto-commit daemon."""
|
||||
|
||||
from .app import run_tray
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def run_tray(*args, **kwargs):
|
||||
from .app import run_tray as _run_tray
|
||||
return _run_tray(*args, **kwargs)
|
||||
|
||||
|
||||
__all__ = ["run_tray"]
|
||||
|
|
|
|||
|
|
@ -74,3 +74,39 @@ class DaemonClient:
|
|||
return resp.json()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def generate_message(self, diff: str, repo_name: str, branch: str = "main") -> dict | None:
|
||||
return self._post("/generate-message", diff=diff, repo_name=repo_name, branch=branch)
|
||||
|
||||
def record_commit(
|
||||
self,
|
||||
hash: str,
|
||||
repo_name: str,
|
||||
message: str,
|
||||
timestamp: str,
|
||||
hostname: str,
|
||||
branch: str = "main",
|
||||
files_changed: int | None = None,
|
||||
insertions: int | None = None,
|
||||
deletions: int | None = None,
|
||||
) -> dict | None:
|
||||
body: dict = {
|
||||
"hash": hash,
|
||||
"repo_name": repo_name,
|
||||
"message": message,
|
||||
"timestamp": timestamp,
|
||||
"hostname": hostname,
|
||||
"branch": branch,
|
||||
}
|
||||
if files_changed is not None:
|
||||
body["files_changed"] = files_changed
|
||||
if insertions is not None:
|
||||
body["insertions"] = insertions
|
||||
if deletions is not None:
|
||||
body["deletions"] = deletions
|
||||
try:
|
||||
resp = self._client.post("/record-commit", json=body)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
except Exception:
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -24,8 +24,10 @@ import httpx
|
|||
|
||||
try:
|
||||
from .local_git import discover_repos
|
||||
from .prefilter import filter_dirty_paths, truncate_diff
|
||||
except ImportError:
|
||||
from local_git import discover_repos # type: ignore[no-redef]
|
||||
from prefilter import filter_dirty_paths, truncate_diff # type: ignore[no-redef]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -234,13 +236,40 @@ class LocalCommitAgent:
|
|||
self._push_if_safe(repo_path, result)
|
||||
return False
|
||||
|
||||
# Stage all changes
|
||||
_git(repo_path, "add", "-A")
|
||||
# Extract dirty paths from porcelain output and apply secret prefilter.
|
||||
# Status format: "XY path" where XY is 2-char status, followed by the path.
|
||||
dirty_paths: list[str] = []
|
||||
for line in status.splitlines():
|
||||
if not line.strip():
|
||||
continue
|
||||
# Skip the 3-char prefix "XY ". Handle renames ("R old -> new") by taking new.
|
||||
entry = line[3:]
|
||||
if " -> " in entry:
|
||||
entry = entry.split(" -> ", 1)[1]
|
||||
dirty_paths.append(entry.strip().strip('"'))
|
||||
|
||||
allowed, denied = filter_dirty_paths(dirty_paths)
|
||||
if denied:
|
||||
logger.info(
|
||||
f"Prefilter dropped {len(denied)} secret-like file(s) from "
|
||||
f"{_repo_display_name(repo_path)}: {', '.join(denied[:5])}"
|
||||
+ (f" (+{len(denied) - 5} more)" if len(denied) > 5 else "")
|
||||
)
|
||||
if not allowed:
|
||||
logger.debug(f"{_repo_display_name(repo_path)}: all dirty files on denylist, skipping")
|
||||
return False
|
||||
|
||||
# Stage only the allowed files (never blanket `git add -A` — that would
|
||||
# stage denied secret paths too).
|
||||
_git(repo_path, "add", "--", *allowed)
|
||||
|
||||
# Get the diff of staged changes
|
||||
diff = _git(repo_path, "diff", "--cached", "--stat") + "\n" + _git(
|
||||
raw_diff = _git(repo_path, "diff", "--cached", "--stat") + "\n" + _git(
|
||||
repo_path, "diff", "--cached", max_bytes=6000
|
||||
)
|
||||
diff, was_truncated = truncate_diff(raw_diff)
|
||||
if was_truncated:
|
||||
logger.debug(f"{_repo_display_name(repo_path)}: diff truncated for transmission")
|
||||
if not diff.strip():
|
||||
return False
|
||||
|
||||
|
|
|
|||
113
src/auto_commit_service/tray/prefilter.py
Normal file
113
src/auto_commit_service/tray/prefilter.py
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
"""Secret/path denylist and diff size cap for tray -> apricot transmission.
|
||||
|
||||
Pure, stdlib-only filter applied before diffs leave plum for the remote
|
||||
`/generate-message` endpoint. Blocks paths that commonly contain credentials
|
||||
and caps the per-repo diff payload size.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
from fnmatch import fnmatchcase
|
||||
from pathlib import PurePosixPath
|
||||
|
||||
DENYLIST_PATTERNS: list[str] = [
|
||||
".env",
|
||||
".env.*",
|
||||
"*.pem",
|
||||
"*.key",
|
||||
"*.p12",
|
||||
"*.pfx",
|
||||
"id_rsa",
|
||||
"id_rsa.*",
|
||||
"id_dsa",
|
||||
"id_dsa.*",
|
||||
"id_ecdsa",
|
||||
"id_ecdsa.*",
|
||||
"id_ed25519",
|
||||
"id_ed25519.*",
|
||||
"*.asc",
|
||||
".ssh/**",
|
||||
"**/.ssh/**",
|
||||
"**/secrets.yaml",
|
||||
"**/secrets.yml",
|
||||
"**/secrets.json",
|
||||
"secrets.yaml",
|
||||
"secrets.yml",
|
||||
"secrets.json",
|
||||
".git/config",
|
||||
".git/credentials",
|
||||
"**/.git/config",
|
||||
"**/.git/credentials",
|
||||
"**/credentials.json",
|
||||
"credentials.json",
|
||||
"**/.netrc",
|
||||
".netrc",
|
||||
"**/*.keystore",
|
||||
"*.keystore",
|
||||
"**/*.jks",
|
||||
"*.jks",
|
||||
]
|
||||
|
||||
ALLOWLIST_BASENAMES: frozenset[str] = frozenset({".env.example"})
|
||||
|
||||
MAX_DIFF_BYTES: int = 131_072
|
||||
|
||||
|
||||
def _normalize(path: str) -> str:
|
||||
normalized = path.replace("\\", "/")
|
||||
while normalized.startswith("./"):
|
||||
normalized = normalized[2:]
|
||||
return normalized.lstrip("/")
|
||||
|
||||
|
||||
def _matches_any(candidate: str, patterns: Iterable[str]) -> bool:
|
||||
return any(fnmatchcase(candidate, pat) for pat in patterns)
|
||||
|
||||
|
||||
def is_secret_path(path: str) -> bool:
|
||||
"""Return True if `path` matches the secret denylist.
|
||||
|
||||
Checks both the full (posix-normalized) path and the basename against
|
||||
`DENYLIST_PATTERNS`. An explicit allowlist (e.g. `.env.example`) wins
|
||||
over denylist matches.
|
||||
"""
|
||||
if not path:
|
||||
return False
|
||||
|
||||
normalized = _normalize(path)
|
||||
basename = PurePosixPath(normalized).name
|
||||
|
||||
if basename in ALLOWLIST_BASENAMES:
|
||||
return False
|
||||
|
||||
if _matches_any(basename, DENYLIST_PATTERNS):
|
||||
return True
|
||||
if _matches_any(normalized, DENYLIST_PATTERNS):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def filter_dirty_paths(paths: Iterable[str]) -> tuple[list[str], list[str]]:
|
||||
"""Partition `paths` into `(allowed, denied)` based on the denylist."""
|
||||
allowed: list[str] = []
|
||||
denied: list[str] = []
|
||||
for path in paths:
|
||||
if is_secret_path(path):
|
||||
denied.append(path)
|
||||
else:
|
||||
allowed.append(path)
|
||||
return allowed, denied
|
||||
|
||||
|
||||
def truncate_diff(diff: str, max_bytes: int = MAX_DIFF_BYTES) -> tuple[str, bool]:
|
||||
"""Truncate `diff` to at most `max_bytes` UTF-8 bytes.
|
||||
|
||||
Returns `(possibly_truncated, was_truncated)`. Truncation happens on a
|
||||
UTF-8 boundary so the returned string is always valid UTF-8.
|
||||
"""
|
||||
encoded = diff.encode("utf-8")
|
||||
if len(encoded) <= max_bytes:
|
||||
return diff, False
|
||||
return encoded[:max_bytes].decode("utf-8", errors="ignore"), True
|
||||
0
tests/tray/__init__.py
Normal file
0
tests/tray/__init__.py
Normal file
42
tests/tray/conftest.py
Normal file
42
tests/tray/conftest.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
"""Conftest for tray tests — stub rumps so tray/__init__.py can be imported on Linux."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import types
|
||||
|
||||
|
||||
def _stub_rumps() -> None:
|
||||
"""Insert a minimal rumps stub so tray/app.py can be imported without the macOS dep."""
|
||||
if "rumps" in sys.modules:
|
||||
return
|
||||
|
||||
rumps = types.ModuleType("rumps")
|
||||
|
||||
class _App:
|
||||
def __init__(self, *a, **kw): ...
|
||||
def run(self): ...
|
||||
|
||||
class _MenuItem:
|
||||
def __init__(self, *a, **kw): ...
|
||||
|
||||
class _Timer:
|
||||
def __init__(self, *a, **kw): ...
|
||||
def start(self): ...
|
||||
def stop(self): ...
|
||||
|
||||
def notification(*a, **kw): ...
|
||||
def alert(*a, **kw): ...
|
||||
|
||||
rumps.App = _App
|
||||
rumps.MenuItem = _MenuItem
|
||||
rumps.Timer = _Timer
|
||||
rumps.notification = notification
|
||||
rumps.alert = alert
|
||||
rumps.clicked = lambda *a, **kw: (lambda f: f)
|
||||
rumps.timer = lambda *a, **kw: (lambda f: f)
|
||||
|
||||
sys.modules["rumps"] = rumps
|
||||
|
||||
|
||||
_stub_rumps()
|
||||
214
tests/tray/test_client.py
Normal file
214
tests/tray/test_client.py
Normal file
|
|
@ -0,0 +1,214 @@
|
|||
"""Tests for DaemonClient.generate_message() and record_commit() RPC methods."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from pytest_httpx import HTTPXMock
|
||||
|
||||
from auto_commit_service.tray.client import DaemonClient
|
||||
|
||||
|
||||
BASE_URL = "http://localhost:8200"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(httpx_mock: HTTPXMock) -> DaemonClient:
|
||||
return DaemonClient(base_url=BASE_URL, timeout=5.0)
|
||||
|
||||
|
||||
class TestGenerateMessage:
|
||||
def test_happy_path_returns_dict_with_message(
|
||||
self, client: DaemonClient, httpx_mock: HTTPXMock
|
||||
) -> None:
|
||||
httpx_mock.add_response(
|
||||
method="POST",
|
||||
url=f"{BASE_URL}/generate-message",
|
||||
json={"message": "feat: add new endpoint"},
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
result = client.generate_message(
|
||||
diff="diff --git a/foo.py ...", repo_name="myrepo", branch="main"
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result["message"] == "feat: add new endpoint"
|
||||
|
||||
def test_sends_correct_payload(
|
||||
self, client: DaemonClient, httpx_mock: HTTPXMock
|
||||
) -> None:
|
||||
httpx_mock.add_response(
|
||||
method="POST",
|
||||
url=f"{BASE_URL}/generate-message",
|
||||
json={"message": "chore: update deps"},
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
client.generate_message(
|
||||
diff="+ new line", repo_name="my-service", branch="feature/xyz"
|
||||
)
|
||||
|
||||
request = httpx_mock.get_request()
|
||||
assert request is not None
|
||||
body = json.loads(request.content)
|
||||
assert body["diff"] == "+ new line"
|
||||
assert body["repo_name"] == "my-service"
|
||||
assert body["branch"] == "feature/xyz"
|
||||
|
||||
def test_server_400_returns_none(
|
||||
self, client: DaemonClient, httpx_mock: HTTPXMock
|
||||
) -> None:
|
||||
httpx_mock.add_response(
|
||||
method="POST",
|
||||
url=f"{BASE_URL}/generate-message",
|
||||
json={"detail": "diff cannot be empty"},
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
result = client.generate_message(diff="", repo_name="myrepo", branch="main")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_server_500_returns_none(
|
||||
self, client: DaemonClient, httpx_mock: HTTPXMock
|
||||
) -> None:
|
||||
httpx_mock.add_response(
|
||||
method="POST",
|
||||
url=f"{BASE_URL}/generate-message",
|
||||
json={"detail": "internal error"},
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
result = client.generate_message(diff="some diff", repo_name="r", branch="main")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_connection_refused_returns_none(self, httpx_mock: HTTPXMock) -> None:
|
||||
httpx_mock.add_exception(
|
||||
httpx.ConnectError("Connection refused"),
|
||||
url=f"{BASE_URL}/generate-message",
|
||||
)
|
||||
c = DaemonClient(base_url=BASE_URL)
|
||||
|
||||
result = c.generate_message(diff="x", repo_name="r", branch="main")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_default_branch_is_main(
|
||||
self, client: DaemonClient, httpx_mock: HTTPXMock
|
||||
) -> None:
|
||||
httpx_mock.add_response(
|
||||
method="POST",
|
||||
url=f"{BASE_URL}/generate-message",
|
||||
json={"message": "fix: something"},
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
client.generate_message(diff="x", repo_name="r")
|
||||
|
||||
request = httpx_mock.get_request()
|
||||
body = json.loads(request.content)
|
||||
assert body["branch"] == "main"
|
||||
|
||||
|
||||
class TestRecordCommit:
|
||||
def test_happy_path_sends_correct_payload(
|
||||
self, client: DaemonClient, httpx_mock: HTTPXMock
|
||||
) -> None:
|
||||
httpx_mock.add_response(
|
||||
method="POST",
|
||||
url=f"{BASE_URL}/record-commit",
|
||||
json={"status": "ok"},
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
result = client.record_commit(
|
||||
hash="abc123",
|
||||
repo_name="my-repo",
|
||||
message="feat: add feature",
|
||||
timestamp="2026-04-17T12:00:00+00:00",
|
||||
hostname="plum",
|
||||
branch="main",
|
||||
files_changed=3,
|
||||
insertions=10,
|
||||
deletions=2,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result["status"] == "ok"
|
||||
|
||||
request = httpx_mock.get_request()
|
||||
body = json.loads(request.content)
|
||||
assert body["hash"] == "abc123"
|
||||
assert body["repo_name"] == "my-repo"
|
||||
assert body["message"] == "feat: add feature"
|
||||
assert body["timestamp"] == "2026-04-17T12:00:00+00:00"
|
||||
assert body["hostname"] == "plum"
|
||||
assert body["branch"] == "main"
|
||||
assert body["files_changed"] == 3
|
||||
assert body["insertions"] == 10
|
||||
assert body["deletions"] == 2
|
||||
|
||||
def test_optional_fields_omitted_when_none(
|
||||
self, client: DaemonClient, httpx_mock: HTTPXMock
|
||||
) -> None:
|
||||
httpx_mock.add_response(
|
||||
method="POST",
|
||||
url=f"{BASE_URL}/record-commit",
|
||||
json={"status": "ok"},
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
client.record_commit(
|
||||
hash="abc123",
|
||||
repo_name="r",
|
||||
message="m",
|
||||
timestamp="2026-01-01T00:00:00+00:00",
|
||||
hostname="plum",
|
||||
)
|
||||
|
||||
request = httpx_mock.get_request()
|
||||
body = json.loads(request.content)
|
||||
assert "files_changed" not in body
|
||||
assert "insertions" not in body
|
||||
assert "deletions" not in body
|
||||
|
||||
def test_server_500_returns_none(
|
||||
self, client: DaemonClient, httpx_mock: HTTPXMock
|
||||
) -> None:
|
||||
httpx_mock.add_response(
|
||||
method="POST",
|
||||
url=f"{BASE_URL}/record-commit",
|
||||
json={"detail": "db error"},
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
result = client.record_commit(
|
||||
hash="x",
|
||||
repo_name="r",
|
||||
message="m",
|
||||
timestamp="2026-01-01T00:00:00+00:00",
|
||||
hostname="plum",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_network_error_returns_none(self, httpx_mock: HTTPXMock) -> None:
|
||||
httpx_mock.add_exception(
|
||||
httpx.ConnectTimeout("timed out"),
|
||||
url=f"{BASE_URL}/record-commit",
|
||||
)
|
||||
c = DaemonClient(base_url=BASE_URL)
|
||||
|
||||
result = c.record_commit(
|
||||
hash="x",
|
||||
repo_name="r",
|
||||
message="m",
|
||||
timestamp="2026-01-01T00:00:00+00:00",
|
||||
hostname="plum",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
227
tests/tray/test_prefilter.py
Normal file
227
tests/tray/test_prefilter.py
Normal file
|
|
@ -0,0 +1,227 @@
|
|||
"""Tests for tray/prefilter.py — secret/path denylist and diff size cap."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from auto_commit_service.tray.prefilter import (
|
||||
DENYLIST_PATTERNS,
|
||||
MAX_DIFF_BYTES,
|
||||
filter_dirty_paths,
|
||||
is_secret_path,
|
||||
truncate_diff,
|
||||
)
|
||||
|
||||
|
||||
class TestConstants:
|
||||
def test_max_diff_bytes_is_128_kib(self) -> None:
|
||||
assert MAX_DIFF_BYTES == 131_072
|
||||
|
||||
def test_denylist_patterns_is_nonempty(self) -> None:
|
||||
assert len(DENYLIST_PATTERNS) > 0
|
||||
|
||||
|
||||
class TestIsSecretPath:
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
".env",
|
||||
".env.local",
|
||||
".env.production",
|
||||
".env.test",
|
||||
],
|
||||
)
|
||||
def test_env_files_are_denied(self, path: str) -> None:
|
||||
assert is_secret_path(path) is True
|
||||
|
||||
def test_env_example_is_allowed(self) -> None:
|
||||
assert is_secret_path(".env.example") is False
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"cert.pem",
|
||||
"server.key",
|
||||
"keystore.p12",
|
||||
"bundle.pfx",
|
||||
],
|
||||
)
|
||||
def test_certificate_and_key_files_are_denied(self, path: str) -> None:
|
||||
assert is_secret_path(path) is True
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"id_rsa",
|
||||
"id_rsa.pub",
|
||||
"id_dsa",
|
||||
"id_ecdsa",
|
||||
"id_ecdsa.pub",
|
||||
"id_ed25519",
|
||||
"id_ed25519.pub",
|
||||
],
|
||||
)
|
||||
def test_ssh_key_files_are_denied(self, path: str) -> None:
|
||||
assert is_secret_path(path) is True
|
||||
|
||||
def test_asc_gpg_key_is_denied(self) -> None:
|
||||
assert is_secret_path("key.asc") is True
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
".ssh/id_rsa",
|
||||
".ssh/config",
|
||||
".ssh/known_hosts",
|
||||
"home/user/.ssh/id_ed25519",
|
||||
],
|
||||
)
|
||||
def test_ssh_directory_paths_are_denied(self, path: str) -> None:
|
||||
assert is_secret_path(path) is True
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"secrets.yaml",
|
||||
"secrets.yml",
|
||||
"secrets.json",
|
||||
"config/secrets.yaml",
|
||||
"infra/k8s/secrets.json",
|
||||
],
|
||||
)
|
||||
def test_secrets_files_are_denied(self, path: str) -> None:
|
||||
assert is_secret_path(path) is True
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
".git/config",
|
||||
".git/credentials",
|
||||
],
|
||||
)
|
||||
def test_git_credential_files_are_denied(self, path: str) -> None:
|
||||
assert is_secret_path(path) is True
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"credentials.json",
|
||||
"config/credentials.json",
|
||||
],
|
||||
)
|
||||
def test_credentials_json_is_denied(self, path: str) -> None:
|
||||
assert is_secret_path(path) is True
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
".netrc",
|
||||
"home/.netrc",
|
||||
],
|
||||
)
|
||||
def test_netrc_is_denied(self, path: str) -> None:
|
||||
assert is_secret_path(path) is True
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"app.keystore",
|
||||
"android.jks",
|
||||
"release.keystore",
|
||||
],
|
||||
)
|
||||
def test_keystore_files_are_denied(self, path: str) -> None:
|
||||
assert is_secret_path(path) is True
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"src/foo.py",
|
||||
"README.md",
|
||||
"docs/guide.md",
|
||||
"package.json",
|
||||
"pyproject.toml",
|
||||
"src/config.py",
|
||||
"tests/test_main.py",
|
||||
],
|
||||
)
|
||||
def test_safe_paths_are_allowed(self, path: str) -> None:
|
||||
assert is_secret_path(path) is False
|
||||
|
||||
def test_empty_string_is_allowed(self) -> None:
|
||||
assert is_secret_path("") is False
|
||||
|
||||
def test_dotenv_with_relative_prefix_is_denied(self) -> None:
|
||||
assert is_secret_path("./.env") is True
|
||||
|
||||
def test_dotenv_with_leading_slash_is_denied(self) -> None:
|
||||
assert is_secret_path("/.env") is True
|
||||
|
||||
|
||||
class TestFilterDirtyPaths:
|
||||
def test_partitions_mixed_paths(self) -> None:
|
||||
allowed, denied = filter_dirty_paths(
|
||||
["ok.py", ".env", ".env.example", "secrets.yaml"]
|
||||
)
|
||||
assert allowed == ["ok.py", ".env.example"]
|
||||
assert denied == [".env", "secrets.yaml"]
|
||||
|
||||
def test_all_safe_paths_returns_empty_denied(self) -> None:
|
||||
allowed, denied = filter_dirty_paths(["src/main.py", "README.md"])
|
||||
assert allowed == ["src/main.py", "README.md"]
|
||||
assert denied == []
|
||||
|
||||
def test_all_denied_paths_returns_empty_allowed(self) -> None:
|
||||
allowed, denied = filter_dirty_paths([".env", "id_rsa", "secrets.json"])
|
||||
assert allowed == []
|
||||
assert set(denied) == {".env", "id_rsa", "secrets.json"}
|
||||
|
||||
def test_empty_input_returns_empty_lists(self) -> None:
|
||||
allowed, denied = filter_dirty_paths([])
|
||||
assert allowed == []
|
||||
assert denied == []
|
||||
|
||||
def test_order_preserved_in_allowed(self) -> None:
|
||||
paths = ["c.py", "a.py", "b.py"]
|
||||
allowed, _ = filter_dirty_paths(paths)
|
||||
assert allowed == paths
|
||||
|
||||
|
||||
class TestTruncateDiff:
|
||||
def test_short_diff_not_truncated(self) -> None:
|
||||
text, was_truncated = truncate_diff("short diff")
|
||||
assert text == "short diff"
|
||||
assert was_truncated is False
|
||||
|
||||
def test_empty_string_not_truncated(self) -> None:
|
||||
text, was_truncated = truncate_diff("")
|
||||
assert text == ""
|
||||
assert was_truncated is False
|
||||
|
||||
def test_exactly_max_bytes_not_truncated(self) -> None:
|
||||
boundary = "a" * MAX_DIFF_BYTES
|
||||
text, was_truncated = truncate_diff(boundary)
|
||||
assert was_truncated is False
|
||||
assert text == boundary
|
||||
|
||||
def test_one_byte_over_truncated(self) -> None:
|
||||
over = "a" * (MAX_DIFF_BYTES + 1)
|
||||
text, was_truncated = truncate_diff(over)
|
||||
assert was_truncated is True
|
||||
assert len(text.encode("utf-8")) <= MAX_DIFF_BYTES
|
||||
|
||||
def test_large_diff_truncated(self) -> None:
|
||||
_, was_truncated = truncate_diff("a" * 200_000)
|
||||
assert was_truncated is True
|
||||
|
||||
def test_custom_max_bytes(self) -> None:
|
||||
text, was_truncated = truncate_diff("hello world", max_bytes=5)
|
||||
assert was_truncated is True
|
||||
assert len(text.encode("utf-8")) <= 5
|
||||
|
||||
def test_result_is_valid_utf8(self) -> None:
|
||||
# Multi-byte chars that could be split at a byte boundary
|
||||
multi = "こんにちは" * 10_000
|
||||
text, was_truncated = truncate_diff(multi, max_bytes=100)
|
||||
assert was_truncated is True
|
||||
text.encode("utf-8") # Must not raise
|
||||
Loading…
Add table
Reference in a new issue