def run_tray(*args, **kwargs):
    """Start the tray application, loading its implementation on demand.

    The `.app` submodule is imported inside this function rather than at
    package import time, so importing the package alone does not import
    `.app`. Positional and keyword arguments are forwarded unchanged and
    the result of the underlying `run_tray` is returned.
    """
    from . import app as _app

    return _app.run_tray(*args, **kwargs)
repo_name: str, branch: str = "main") -> dict | None: + return self._post("/generate-message", diff=diff, repo_name=repo_name, branch=branch) + + def record_commit( + self, + hash: str, + repo_name: str, + message: str, + timestamp: str, + hostname: str, + branch: str = "main", + files_changed: int | None = None, + insertions: int | None = None, + deletions: int | None = None, + ) -> dict | None: + body: dict = { + "hash": hash, + "repo_name": repo_name, + "message": message, + "timestamp": timestamp, + "hostname": hostname, + "branch": branch, + } + if files_changed is not None: + body["files_changed"] = files_changed + if insertions is not None: + body["insertions"] = insertions + if deletions is not None: + body["deletions"] = deletions + try: + resp = self._client.post("/record-commit", json=body) + resp.raise_for_status() + return resp.json() + except Exception: + return None diff --git a/src/auto_commit_service/tray/local_agent.py b/src/auto_commit_service/tray/local_agent.py index 6fa9024..2d30520 100644 --- a/src/auto_commit_service/tray/local_agent.py +++ b/src/auto_commit_service/tray/local_agent.py @@ -24,8 +24,10 @@ import httpx try: from .local_git import discover_repos + from .prefilter import filter_dirty_paths, truncate_diff except ImportError: from local_git import discover_repos # type: ignore[no-redef] + from prefilter import filter_dirty_paths, truncate_diff # type: ignore[no-redef] logger = logging.getLogger(__name__) @@ -234,13 +236,40 @@ class LocalCommitAgent: self._push_if_safe(repo_path, result) return False - # Stage all changes - _git(repo_path, "add", "-A") + # Extract dirty paths from porcelain output and apply secret prefilter. + # Status format: "XY path" where XY is 2-char status, followed by the path. + dirty_paths: list[str] = [] + for line in status.splitlines(): + if not line.strip(): + continue + # Skip the 3-char prefix "XY ". Handle renames ("R old -> new") by taking new. 
"""Secret/path denylist and diff size cap for tray -> apricot transmission.

Pure, stdlib-only filter applied before diffs leave plum for the remote
`/generate-message` endpoint. Blocks paths that commonly contain credentials
and caps the per-repo diff payload size.

Path matching is case-insensitive: a denylist must fail closed, and on a
case-insensitive filesystem `.ENV` and `.env` name the same file.
"""

from __future__ import annotations

from collections.abc import Iterable
from fnmatch import fnmatchcase
from pathlib import PurePosixPath

# Glob patterns (fnmatch syntax, where `*` also matches `/`) tested against
# both the basename and the full normalized path of every candidate.
DENYLIST_PATTERNS: list[str] = [
    ".env",
    ".env.*",
    "*.pem",
    "*.key",
    "*.p12",
    "*.pfx",
    "id_rsa",
    "id_rsa.*",
    "id_dsa",
    "id_dsa.*",
    "id_ecdsa",
    "id_ecdsa.*",
    "id_ed25519",
    "id_ed25519.*",
    "*.asc",
    ".ssh/**",
    "**/.ssh/**",
    "**/secrets.yaml",
    "**/secrets.yml",
    "**/secrets.json",
    "secrets.yaml",
    "secrets.yml",
    "secrets.json",
    ".git/config",
    ".git/credentials",
    "**/.git/config",
    "**/.git/credentials",
    "**/credentials.json",
    "credentials.json",
    "**/.netrc",
    ".netrc",
    "**/*.keystore",
    "*.keystore",
    "**/*.jks",
    "*.jks",
]

# Basenames that are always allowed, even when a denylist pattern matches.
ALLOWLIST_BASENAMES: frozenset[str] = frozenset({".env.example"})

# Default payload cap used by `truncate_diff` (128 KiB).
MAX_DIFF_BYTES: int = 131_072


def _normalize(path: str) -> str:
    """Return `path` with forward slashes and no leading `./` or `/`."""
    normalized = path.replace("\\", "/")
    while normalized.startswith("./"):
        normalized = normalized[2:]
    return normalized.lstrip("/")


def _matches_any(candidate: str, patterns: Iterable[str]) -> bool:
    """Return True if `candidate` matches any glob in `patterns`, ignoring case."""
    folded = candidate.casefold()
    # Fold the pattern as well so caller-supplied mixed-case patterns still match.
    return any(fnmatchcase(folded, pattern.casefold()) for pattern in patterns)


def is_secret_path(path: str) -> bool:
    """Return True if `path` matches the secret denylist.

    Checks both the full (posix-normalized) path and the basename against
    `DENYLIST_PATTERNS`, ignoring letter case. An explicit allowlist
    (e.g. `.env.example`) wins over denylist matches. An empty `path` is
    never treated as secret.
    """
    if not path:
        return False

    normalized = _normalize(path)
    basename = PurePosixPath(normalized).name

    # Compare case-insensitively so the allowlist stays consistent with the
    # case-insensitive denylist below.
    allowed_names = {name.casefold() for name in ALLOWLIST_BASENAMES}
    if basename.casefold() in allowed_names:
        return False

    return _matches_any(basename, DENYLIST_PATTERNS) or _matches_any(
        normalized, DENYLIST_PATTERNS
    )


def filter_dirty_paths(paths: Iterable[str]) -> tuple[list[str], list[str]]:
    """Partition `paths` into `(allowed, denied)` based on the denylist.

    Input order is preserved within each list, and paths are returned
    exactly as given (not normalized).
    """
    allowed: list[str] = []
    denied: list[str] = []
    for path in paths:
        (denied if is_secret_path(path) else allowed).append(path)
    return allowed, denied


def truncate_diff(diff: str, max_bytes: int = MAX_DIFF_BYTES) -> tuple[str, bool]:
    """Truncate `diff` to at most `max_bytes` UTF-8 bytes.

    Returns `(possibly_truncated, was_truncated)`. Truncation happens on a
    UTF-8 boundary so the returned string is always valid UTF-8; a
    multi-byte character cut by the limit is dropped entirely. A negative
    `max_bytes` is treated as 0.
    """
    # Clamp first: a negative slice bound would count from the end and keep
    # almost the whole payload instead of capping it.
    limit = max(max_bytes, 0)
    encoded = diff.encode("utf-8")
    if len(encoded) <= limit:
        return diff, False
    return encoded[:limit].decode("utf-8", errors="ignore"), True
class TestGenerateMessage:
    """Exercise `DaemonClient.generate_message` against a mocked HTTP transport."""

    def test_happy_path_returns_dict_with_message(
        self, client: DaemonClient, httpx_mock: HTTPXMock
    ) -> None:
        """A 200 response is returned to the caller as the decoded JSON dict."""
        httpx_mock.add_response(
            method="POST",
            url=f"{BASE_URL}/generate-message",
            json={"message": "feat: add new endpoint"},
            status_code=200,
        )

        result = client.generate_message(
            diff="diff --git a/foo.py ...", repo_name="myrepo", branch="main"
        )

        assert result is not None
        assert result["message"] == "feat: add new endpoint"

    def test_sends_correct_payload(
        self, client: DaemonClient, httpx_mock: HTTPXMock
    ) -> None:
        """The JSON request body carries `diff`, `repo_name` and `branch` as passed."""
        httpx_mock.add_response(
            method="POST",
            url=f"{BASE_URL}/generate-message",
            json={"message": "chore: update deps"},
            status_code=200,
        )

        client.generate_message(
            diff="+ new line", repo_name="my-service", branch="feature/xyz"
        )

        # Inspect the single request captured by the mock.
        request = httpx_mock.get_request()
        assert request is not None
        body = json.loads(request.content)
        assert body["diff"] == "+ new line"
        assert body["repo_name"] == "my-service"
        assert body["branch"] == "feature/xyz"

    def test_server_400_returns_none(
        self, client: DaemonClient, httpx_mock: HTTPXMock
    ) -> None:
        """A 400 response is reported to the caller as `None`."""
        httpx_mock.add_response(
            method="POST",
            url=f"{BASE_URL}/generate-message",
            json={"detail": "diff cannot be empty"},
            status_code=400,
        )

        result = client.generate_message(diff="", repo_name="myrepo", branch="main")

        assert result is None

    def test_server_500_returns_none(
        self, client: DaemonClient, httpx_mock: HTTPXMock
    ) -> None:
        """A 500 response is reported to the caller as `None`."""
        httpx_mock.add_response(
            method="POST",
            url=f"{BASE_URL}/generate-message",
            json={"detail": "internal error"},
            status_code=500,
        )

        result = client.generate_message(diff="some diff", repo_name="r", branch="main")

        assert result is None

    def test_connection_refused_returns_none(self, httpx_mock: HTTPXMock) -> None:
        """A transport-level `ConnectError` is reported to the caller as `None`."""
        httpx_mock.add_exception(
            httpx.ConnectError("Connection refused"),
            url=f"{BASE_URL}/generate-message",
        )
        # Builds its own client (default timeout) instead of using the fixture.
        c = DaemonClient(base_url=BASE_URL)

        result = c.generate_message(diff="x", repo_name="r", branch="main")

        assert result is None

    def test_default_branch_is_main(
        self, client: DaemonClient, httpx_mock: HTTPXMock
    ) -> None:
        """Omitting `branch` sends `"main"` in the request body."""
        httpx_mock.add_response(
            method="POST",
            url=f"{BASE_URL}/generate-message",
            json={"message": "fix: something"},
            status_code=200,
        )

        client.generate_message(diff="x", repo_name="r")

        request = httpx_mock.get_request()
        body = json.loads(request.content)
        assert body["branch"] == "main"
result["status"] == "ok" + + request = httpx_mock.get_request() + body = json.loads(request.content) + assert body["hash"] == "abc123" + assert body["repo_name"] == "my-repo" + assert body["message"] == "feat: add feature" + assert body["timestamp"] == "2026-04-17T12:00:00+00:00" + assert body["hostname"] == "plum" + assert body["branch"] == "main" + assert body["files_changed"] == 3 + assert body["insertions"] == 10 + assert body["deletions"] == 2 + + def test_optional_fields_omitted_when_none( + self, client: DaemonClient, httpx_mock: HTTPXMock + ) -> None: + httpx_mock.add_response( + method="POST", + url=f"{BASE_URL}/record-commit", + json={"status": "ok"}, + status_code=200, + ) + + client.record_commit( + hash="abc123", + repo_name="r", + message="m", + timestamp="2026-01-01T00:00:00+00:00", + hostname="plum", + ) + + request = httpx_mock.get_request() + body = json.loads(request.content) + assert "files_changed" not in body + assert "insertions" not in body + assert "deletions" not in body + + def test_server_500_returns_none( + self, client: DaemonClient, httpx_mock: HTTPXMock + ) -> None: + httpx_mock.add_response( + method="POST", + url=f"{BASE_URL}/record-commit", + json={"detail": "db error"}, + status_code=500, + ) + + result = client.record_commit( + hash="x", + repo_name="r", + message="m", + timestamp="2026-01-01T00:00:00+00:00", + hostname="plum", + ) + + assert result is None + + def test_network_error_returns_none(self, httpx_mock: HTTPXMock) -> None: + httpx_mock.add_exception( + httpx.ConnectTimeout("timed out"), + url=f"{BASE_URL}/record-commit", + ) + c = DaemonClient(base_url=BASE_URL) + + result = c.record_commit( + hash="x", + repo_name="r", + message="m", + timestamp="2026-01-01T00:00:00+00:00", + hostname="plum", + ) + + assert result is None diff --git a/tests/tray/test_prefilter.py b/tests/tray/test_prefilter.py new file mode 100644 index 0000000..36bdf57 --- /dev/null +++ b/tests/tray/test_prefilter.py @@ -0,0 +1,227 @@ +"""Tests for 
class TestIsSecretPath:
    """Table-driven checks of `is_secret_path` for denied and allowed paths."""

    @pytest.mark.parametrize(
        "path",
        [
            ".env",
            ".env.local",
            ".env.production",
            ".env.test",
        ],
    )
    def test_env_files_are_denied(self, path: str) -> None:
        """`.env` and `.env.<suffix>` files are classified as secret."""
        assert is_secret_path(path) is True

    def test_env_example_is_allowed(self) -> None:
        """`.env.example` is allowed even though it has the `.env.<suffix>` shape."""
        assert is_secret_path(".env.example") is False

    @pytest.mark.parametrize(
        "path",
        [
            "cert.pem",
            "server.key",
            "keystore.p12",
            "bundle.pfx",
        ],
    )
    def test_certificate_and_key_files_are_denied(self, path: str) -> None:
        """Files ending in `.pem`, `.key`, `.p12` or `.pfx` are classified as secret."""
        assert is_secret_path(path) is True

    @pytest.mark.parametrize(
        "path",
        [
            "id_rsa",
            "id_rsa.pub",
            "id_dsa",
            "id_ecdsa",
            "id_ecdsa.pub",
            "id_ed25519",
            "id_ed25519.pub",
        ],
    )
    def test_ssh_key_files_are_denied(self, path: str) -> None:
        """SSH key file names, including `.pub` variants, are classified as secret."""
        assert is_secret_path(path) is True

    def test_asc_gpg_key_is_denied(self) -> None:
        """A `.asc` file is classified as secret."""
        assert is_secret_path("key.asc") is True

    @pytest.mark.parametrize(
        "path",
        [
            ".ssh/id_rsa",
            ".ssh/config",
            ".ssh/known_hosts",
            "home/user/.ssh/id_ed25519",
        ],
    )
    def test_ssh_directory_paths_are_denied(self, path: str) -> None:
        """Anything under a `.ssh/` directory, top-level or nested, is secret."""
        assert is_secret_path(path) is True

    @pytest.mark.parametrize(
        "path",
        [
            "secrets.yaml",
            "secrets.yml",
            "secrets.json",
            "config/secrets.yaml",
            "infra/k8s/secrets.json",
        ],
    )
    def test_secrets_files_are_denied(self, path: str) -> None:
        """`secrets.{yaml,yml,json}` is secret at the top level and when nested."""
        assert is_secret_path(path) is True

    @pytest.mark.parametrize(
        "path",
        [
            ".git/config",
            ".git/credentials",
        ],
    )
    def test_git_credential_files_are_denied(self, path: str) -> None:
        """`.git/config` and `.git/credentials` are classified as secret."""
        assert is_secret_path(path) is True

    @pytest.mark.parametrize(
        "path",
        [
            "credentials.json",
            "config/credentials.json",
        ],
    )
    def test_credentials_json_is_denied(self, path: str) -> None:
        """`credentials.json` is secret at the top level and when nested."""
        assert is_secret_path(path) is True

    @pytest.mark.parametrize(
        "path",
        [
            ".netrc",
            "home/.netrc",
        ],
    )
    def test_netrc_is_denied(self, path: str) -> None:
        """`.netrc` is secret at the top level and when nested."""
        assert is_secret_path(path) is True

    @pytest.mark.parametrize(
        "path",
        [
            "app.keystore",
            "android.jks",
            "release.keystore",
        ],
    )
    def test_keystore_files_are_denied(self, path: str) -> None:
        """Files ending in `.keystore` or `.jks` are classified as secret."""
        assert is_secret_path(path) is True

    @pytest.mark.parametrize(
        "path",
        [
            "src/foo.py",
            "README.md",
            "docs/guide.md",
            "package.json",
            "pyproject.toml",
            "src/config.py",
            "tests/test_main.py",
        ],
    )
    def test_safe_paths_are_allowed(self, path: str) -> None:
        """Ordinary source, documentation and config files are not secret."""
        assert is_secret_path(path) is False

    def test_empty_string_is_allowed(self) -> None:
        """An empty path is reported as not secret."""
        assert is_secret_path("") is False

    def test_dotenv_with_relative_prefix_is_denied(self) -> None:
        """A leading `./` does not hide a `.env` file from the denylist."""
        assert is_secret_path("./.env") is True

    def test_dotenv_with_leading_slash_is_denied(self) -> None:
        """A leading `/` does not hide a `.env` file from the denylist."""
        assert is_secret_path("/.env") is True
class TestTruncateDiff:
    """Checks for `truncate_diff`: size-cap boundaries and UTF-8 validity."""

    def test_short_diff_not_truncated(self) -> None:
        """Input below the cap comes back unchanged with the flag unset."""
        output, cut = truncate_diff("short diff")
        assert cut is False
        assert output == "short diff"

    def test_empty_string_not_truncated(self) -> None:
        """The empty string comes back unchanged with the flag unset."""
        output, cut = truncate_diff("")
        assert cut is False
        assert output == ""

    def test_exactly_max_bytes_not_truncated(self) -> None:
        """Input of exactly `MAX_DIFF_BYTES` bytes is not truncated."""
        payload = "a" * MAX_DIFF_BYTES
        output, cut = truncate_diff(payload)
        assert cut is False
        assert output == payload

    def test_one_byte_over_truncated(self) -> None:
        """One byte past the cap sets the flag and keeps the result within the cap."""
        payload = "a" * (MAX_DIFF_BYTES + 1)
        output, cut = truncate_diff(payload)
        assert cut is True
        assert len(output.encode("utf-8")) <= MAX_DIFF_BYTES

    def test_large_diff_truncated(self) -> None:
        """A 200,000-character input sets the truncation flag."""
        _, cut = truncate_diff("a" * 200_000)
        assert cut is True

    def test_custom_max_bytes(self) -> None:
        """An explicit `max_bytes` is used in place of the default cap."""
        output, cut = truncate_diff("hello world", max_bytes=5)
        assert cut is True
        assert len(output.encode("utf-8")) <= 5

    def test_result_is_valid_utf8(self) -> None:
        """Cutting inside multi-byte text still yields an encodable string."""
        # Multi-byte chars that could be split at a byte boundary
        wide_text = "こんにちは" * 10_000
        output, cut = truncate_diff(wide_text, max_bytes=100)
        assert cut is True
        output.encode("utf-8")  # Must not raise