feat(tray): Implement tray components for commit processing pipeline with TrayClient, CommitLoop, LocalAgent, Prefilter, and tests

Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
autocommit 2026-04-17 21:20:13 -07:00
parent 6358be5902
commit 72c9df974f
8 changed files with 671 additions and 4 deletions

View file

@ -1,5 +1,11 @@
"""macOS system tray (menu bar) client for the auto-commit daemon."""
from .app import run_tray
from __future__ import annotations
def run_tray(*args, **kwargs):
from .app import run_tray as _run_tray
return _run_tray(*args, **kwargs)
__all__ = ["run_tray"]

View file

@ -74,3 +74,39 @@ class DaemonClient:
return resp.json()
except Exception:
return None
def generate_message(self, diff: str, repo_name: str, branch: str = "main") -> dict | None:
return self._post("/generate-message", diff=diff, repo_name=repo_name, branch=branch)
def record_commit(
self,
hash: str,
repo_name: str,
message: str,
timestamp: str,
hostname: str,
branch: str = "main",
files_changed: int | None = None,
insertions: int | None = None,
deletions: int | None = None,
) -> dict | None:
body: dict = {
"hash": hash,
"repo_name": repo_name,
"message": message,
"timestamp": timestamp,
"hostname": hostname,
"branch": branch,
}
if files_changed is not None:
body["files_changed"] = files_changed
if insertions is not None:
body["insertions"] = insertions
if deletions is not None:
body["deletions"] = deletions
try:
resp = self._client.post("/record-commit", json=body)
resp.raise_for_status()
return resp.json()
except Exception:
return None

View file

@ -24,8 +24,10 @@ import httpx
try:
from .local_git import discover_repos
from .prefilter import filter_dirty_paths, truncate_diff
except ImportError:
from local_git import discover_repos # type: ignore[no-redef]
from prefilter import filter_dirty_paths, truncate_diff # type: ignore[no-redef]
logger = logging.getLogger(__name__)
@ -234,13 +236,40 @@ class LocalCommitAgent:
self._push_if_safe(repo_path, result)
return False
# Stage all changes
_git(repo_path, "add", "-A")
# Extract dirty paths from porcelain output and apply secret prefilter.
# Status format: "XY path" where XY is 2-char status, followed by the path.
dirty_paths: list[str] = []
for line in status.splitlines():
if not line.strip():
continue
# Skip the 3-char prefix "XY ". Handle renames ("R old -> new") by taking new.
entry = line[3:]
if " -> " in entry:
entry = entry.split(" -> ", 1)[1]
dirty_paths.append(entry.strip().strip('"'))
allowed, denied = filter_dirty_paths(dirty_paths)
if denied:
logger.info(
f"Prefilter dropped {len(denied)} secret-like file(s) from "
f"{_repo_display_name(repo_path)}: {', '.join(denied[:5])}"
+ (f" (+{len(denied) - 5} more)" if len(denied) > 5 else "")
)
if not allowed:
logger.debug(f"{_repo_display_name(repo_path)}: all dirty files on denylist, skipping")
return False
# Stage only the allowed files (never blanket `git add -A` — that would
# stage denied secret paths too).
_git(repo_path, "add", "--", *allowed)
# Get the diff of staged changes
diff = _git(repo_path, "diff", "--cached", "--stat") + "\n" + _git(
raw_diff = _git(repo_path, "diff", "--cached", "--stat") + "\n" + _git(
repo_path, "diff", "--cached", max_bytes=6000
)
diff, was_truncated = truncate_diff(raw_diff)
if was_truncated:
logger.debug(f"{_repo_display_name(repo_path)}: diff truncated for transmission")
if not diff.strip():
return False

View file

@ -0,0 +1,113 @@
"""Secret/path denylist and diff size cap for tray -> apricot transmission.
Pure, stdlib-only filter applied before diffs leave plum for the remote
`/generate-message` endpoint. Blocks paths that commonly contain credentials
and caps the per-repo diff payload size.
"""
from __future__ import annotations
from collections.abc import Iterable
from fnmatch import fnmatchcase
from pathlib import PurePosixPath
DENYLIST_PATTERNS: list[str] = [
".env",
".env.*",
"*.pem",
"*.key",
"*.p12",
"*.pfx",
"id_rsa",
"id_rsa.*",
"id_dsa",
"id_dsa.*",
"id_ecdsa",
"id_ecdsa.*",
"id_ed25519",
"id_ed25519.*",
"*.asc",
".ssh/**",
"**/.ssh/**",
"**/secrets.yaml",
"**/secrets.yml",
"**/secrets.json",
"secrets.yaml",
"secrets.yml",
"secrets.json",
".git/config",
".git/credentials",
"**/.git/config",
"**/.git/credentials",
"**/credentials.json",
"credentials.json",
"**/.netrc",
".netrc",
"**/*.keystore",
"*.keystore",
"**/*.jks",
"*.jks",
]
ALLOWLIST_BASENAMES: frozenset[str] = frozenset({".env.example"})
MAX_DIFF_BYTES: int = 131_072
def _normalize(path: str) -> str:
normalized = path.replace("\\", "/")
while normalized.startswith("./"):
normalized = normalized[2:]
return normalized.lstrip("/")
def _matches_any(candidate: str, patterns: Iterable[str]) -> bool:
return any(fnmatchcase(candidate, pat) for pat in patterns)
def is_secret_path(path: str) -> bool:
"""Return True if `path` matches the secret denylist.
Checks both the full (posix-normalized) path and the basename against
`DENYLIST_PATTERNS`. An explicit allowlist (e.g. `.env.example`) wins
over denylist matches.
"""
if not path:
return False
normalized = _normalize(path)
basename = PurePosixPath(normalized).name
if basename in ALLOWLIST_BASENAMES:
return False
if _matches_any(basename, DENYLIST_PATTERNS):
return True
if _matches_any(normalized, DENYLIST_PATTERNS):
return True
return False
def filter_dirty_paths(paths: Iterable[str]) -> tuple[list[str], list[str]]:
"""Partition `paths` into `(allowed, denied)` based on the denylist."""
allowed: list[str] = []
denied: list[str] = []
for path in paths:
if is_secret_path(path):
denied.append(path)
else:
allowed.append(path)
return allowed, denied
def truncate_diff(diff: str, max_bytes: int = MAX_DIFF_BYTES) -> tuple[str, bool]:
"""Truncate `diff` to at most `max_bytes` UTF-8 bytes.
Returns `(possibly_truncated, was_truncated)`. Truncation happens on a
UTF-8 boundary so the returned string is always valid UTF-8.
"""
encoded = diff.encode("utf-8")
if len(encoded) <= max_bytes:
return diff, False
return encoded[:max_bytes].decode("utf-8", errors="ignore"), True

0
tests/tray/__init__.py Normal file
View file

42
tests/tray/conftest.py Normal file
View file

@ -0,0 +1,42 @@
"""Conftest for tray tests — stub rumps so tray/__init__.py can be imported on Linux."""
from __future__ import annotations
import sys
import types
def _stub_rumps() -> None:
"""Insert a minimal rumps stub so tray/app.py can be imported without the macOS dep."""
if "rumps" in sys.modules:
return
rumps = types.ModuleType("rumps")
class _App:
def __init__(self, *a, **kw): ...
def run(self): ...
class _MenuItem:
def __init__(self, *a, **kw): ...
class _Timer:
def __init__(self, *a, **kw): ...
def start(self): ...
def stop(self): ...
def notification(*a, **kw): ...
def alert(*a, **kw): ...
rumps.App = _App
rumps.MenuItem = _MenuItem
rumps.Timer = _Timer
rumps.notification = notification
rumps.alert = alert
rumps.clicked = lambda *a, **kw: (lambda f: f)
rumps.timer = lambda *a, **kw: (lambda f: f)
sys.modules["rumps"] = rumps
_stub_rumps()

214
tests/tray/test_client.py Normal file
View file

@ -0,0 +1,214 @@
"""Tests for DaemonClient.generate_message() and record_commit() RPC methods."""
from __future__ import annotations
import json
import httpx
import pytest
from pytest_httpx import HTTPXMock
from auto_commit_service.tray.client import DaemonClient
BASE_URL = "http://localhost:8200"
@pytest.fixture
def client(httpx_mock: HTTPXMock) -> DaemonClient:
return DaemonClient(base_url=BASE_URL, timeout=5.0)
class TestGenerateMessage:
def test_happy_path_returns_dict_with_message(
self, client: DaemonClient, httpx_mock: HTTPXMock
) -> None:
httpx_mock.add_response(
method="POST",
url=f"{BASE_URL}/generate-message",
json={"message": "feat: add new endpoint"},
status_code=200,
)
result = client.generate_message(
diff="diff --git a/foo.py ...", repo_name="myrepo", branch="main"
)
assert result is not None
assert result["message"] == "feat: add new endpoint"
def test_sends_correct_payload(
self, client: DaemonClient, httpx_mock: HTTPXMock
) -> None:
httpx_mock.add_response(
method="POST",
url=f"{BASE_URL}/generate-message",
json={"message": "chore: update deps"},
status_code=200,
)
client.generate_message(
diff="+ new line", repo_name="my-service", branch="feature/xyz"
)
request = httpx_mock.get_request()
assert request is not None
body = json.loads(request.content)
assert body["diff"] == "+ new line"
assert body["repo_name"] == "my-service"
assert body["branch"] == "feature/xyz"
def test_server_400_returns_none(
self, client: DaemonClient, httpx_mock: HTTPXMock
) -> None:
httpx_mock.add_response(
method="POST",
url=f"{BASE_URL}/generate-message",
json={"detail": "diff cannot be empty"},
status_code=400,
)
result = client.generate_message(diff="", repo_name="myrepo", branch="main")
assert result is None
def test_server_500_returns_none(
self, client: DaemonClient, httpx_mock: HTTPXMock
) -> None:
httpx_mock.add_response(
method="POST",
url=f"{BASE_URL}/generate-message",
json={"detail": "internal error"},
status_code=500,
)
result = client.generate_message(diff="some diff", repo_name="r", branch="main")
assert result is None
def test_connection_refused_returns_none(self, httpx_mock: HTTPXMock) -> None:
httpx_mock.add_exception(
httpx.ConnectError("Connection refused"),
url=f"{BASE_URL}/generate-message",
)
c = DaemonClient(base_url=BASE_URL)
result = c.generate_message(diff="x", repo_name="r", branch="main")
assert result is None
def test_default_branch_is_main(
self, client: DaemonClient, httpx_mock: HTTPXMock
) -> None:
httpx_mock.add_response(
method="POST",
url=f"{BASE_URL}/generate-message",
json={"message": "fix: something"},
status_code=200,
)
client.generate_message(diff="x", repo_name="r")
request = httpx_mock.get_request()
body = json.loads(request.content)
assert body["branch"] == "main"
class TestRecordCommit:
def test_happy_path_sends_correct_payload(
self, client: DaemonClient, httpx_mock: HTTPXMock
) -> None:
httpx_mock.add_response(
method="POST",
url=f"{BASE_URL}/record-commit",
json={"status": "ok"},
status_code=200,
)
result = client.record_commit(
hash="abc123",
repo_name="my-repo",
message="feat: add feature",
timestamp="2026-04-17T12:00:00+00:00",
hostname="plum",
branch="main",
files_changed=3,
insertions=10,
deletions=2,
)
assert result is not None
assert result["status"] == "ok"
request = httpx_mock.get_request()
body = json.loads(request.content)
assert body["hash"] == "abc123"
assert body["repo_name"] == "my-repo"
assert body["message"] == "feat: add feature"
assert body["timestamp"] == "2026-04-17T12:00:00+00:00"
assert body["hostname"] == "plum"
assert body["branch"] == "main"
assert body["files_changed"] == 3
assert body["insertions"] == 10
assert body["deletions"] == 2
def test_optional_fields_omitted_when_none(
self, client: DaemonClient, httpx_mock: HTTPXMock
) -> None:
httpx_mock.add_response(
method="POST",
url=f"{BASE_URL}/record-commit",
json={"status": "ok"},
status_code=200,
)
client.record_commit(
hash="abc123",
repo_name="r",
message="m",
timestamp="2026-01-01T00:00:00+00:00",
hostname="plum",
)
request = httpx_mock.get_request()
body = json.loads(request.content)
assert "files_changed" not in body
assert "insertions" not in body
assert "deletions" not in body
def test_server_500_returns_none(
self, client: DaemonClient, httpx_mock: HTTPXMock
) -> None:
httpx_mock.add_response(
method="POST",
url=f"{BASE_URL}/record-commit",
json={"detail": "db error"},
status_code=500,
)
result = client.record_commit(
hash="x",
repo_name="r",
message="m",
timestamp="2026-01-01T00:00:00+00:00",
hostname="plum",
)
assert result is None
def test_network_error_returns_none(self, httpx_mock: HTTPXMock) -> None:
httpx_mock.add_exception(
httpx.ConnectTimeout("timed out"),
url=f"{BASE_URL}/record-commit",
)
c = DaemonClient(base_url=BASE_URL)
result = c.record_commit(
hash="x",
repo_name="r",
message="m",
timestamp="2026-01-01T00:00:00+00:00",
hostname="plum",
)
assert result is None

View file

@ -0,0 +1,227 @@
"""Tests for tray/prefilter.py — secret/path denylist and diff size cap."""
from __future__ import annotations
import pytest
from auto_commit_service.tray.prefilter import (
DENYLIST_PATTERNS,
MAX_DIFF_BYTES,
filter_dirty_paths,
is_secret_path,
truncate_diff,
)
class TestConstants:
def test_max_diff_bytes_is_128_kib(self) -> None:
assert MAX_DIFF_BYTES == 131_072
def test_denylist_patterns_is_nonempty(self) -> None:
assert len(DENYLIST_PATTERNS) > 0
class TestIsSecretPath:
@pytest.mark.parametrize(
"path",
[
".env",
".env.local",
".env.production",
".env.test",
],
)
def test_env_files_are_denied(self, path: str) -> None:
assert is_secret_path(path) is True
def test_env_example_is_allowed(self) -> None:
assert is_secret_path(".env.example") is False
@pytest.mark.parametrize(
"path",
[
"cert.pem",
"server.key",
"keystore.p12",
"bundle.pfx",
],
)
def test_certificate_and_key_files_are_denied(self, path: str) -> None:
assert is_secret_path(path) is True
@pytest.mark.parametrize(
"path",
[
"id_rsa",
"id_rsa.pub",
"id_dsa",
"id_ecdsa",
"id_ecdsa.pub",
"id_ed25519",
"id_ed25519.pub",
],
)
def test_ssh_key_files_are_denied(self, path: str) -> None:
assert is_secret_path(path) is True
def test_asc_gpg_key_is_denied(self) -> None:
assert is_secret_path("key.asc") is True
@pytest.mark.parametrize(
"path",
[
".ssh/id_rsa",
".ssh/config",
".ssh/known_hosts",
"home/user/.ssh/id_ed25519",
],
)
def test_ssh_directory_paths_are_denied(self, path: str) -> None:
assert is_secret_path(path) is True
@pytest.mark.parametrize(
"path",
[
"secrets.yaml",
"secrets.yml",
"secrets.json",
"config/secrets.yaml",
"infra/k8s/secrets.json",
],
)
def test_secrets_files_are_denied(self, path: str) -> None:
assert is_secret_path(path) is True
@pytest.mark.parametrize(
"path",
[
".git/config",
".git/credentials",
],
)
def test_git_credential_files_are_denied(self, path: str) -> None:
assert is_secret_path(path) is True
@pytest.mark.parametrize(
"path",
[
"credentials.json",
"config/credentials.json",
],
)
def test_credentials_json_is_denied(self, path: str) -> None:
assert is_secret_path(path) is True
@pytest.mark.parametrize(
"path",
[
".netrc",
"home/.netrc",
],
)
def test_netrc_is_denied(self, path: str) -> None:
assert is_secret_path(path) is True
@pytest.mark.parametrize(
"path",
[
"app.keystore",
"android.jks",
"release.keystore",
],
)
def test_keystore_files_are_denied(self, path: str) -> None:
assert is_secret_path(path) is True
@pytest.mark.parametrize(
"path",
[
"src/foo.py",
"README.md",
"docs/guide.md",
"package.json",
"pyproject.toml",
"src/config.py",
"tests/test_main.py",
],
)
def test_safe_paths_are_allowed(self, path: str) -> None:
assert is_secret_path(path) is False
def test_empty_string_is_allowed(self) -> None:
assert is_secret_path("") is False
def test_dotenv_with_relative_prefix_is_denied(self) -> None:
assert is_secret_path("./.env") is True
def test_dotenv_with_leading_slash_is_denied(self) -> None:
assert is_secret_path("/.env") is True
class TestFilterDirtyPaths:
def test_partitions_mixed_paths(self) -> None:
allowed, denied = filter_dirty_paths(
["ok.py", ".env", ".env.example", "secrets.yaml"]
)
assert allowed == ["ok.py", ".env.example"]
assert denied == [".env", "secrets.yaml"]
def test_all_safe_paths_returns_empty_denied(self) -> None:
allowed, denied = filter_dirty_paths(["src/main.py", "README.md"])
assert allowed == ["src/main.py", "README.md"]
assert denied == []
def test_all_denied_paths_returns_empty_allowed(self) -> None:
allowed, denied = filter_dirty_paths([".env", "id_rsa", "secrets.json"])
assert allowed == []
assert set(denied) == {".env", "id_rsa", "secrets.json"}
def test_empty_input_returns_empty_lists(self) -> None:
allowed, denied = filter_dirty_paths([])
assert allowed == []
assert denied == []
def test_order_preserved_in_allowed(self) -> None:
paths = ["c.py", "a.py", "b.py"]
allowed, _ = filter_dirty_paths(paths)
assert allowed == paths
class TestTruncateDiff:
def test_short_diff_not_truncated(self) -> None:
text, was_truncated = truncate_diff("short diff")
assert text == "short diff"
assert was_truncated is False
def test_empty_string_not_truncated(self) -> None:
text, was_truncated = truncate_diff("")
assert text == ""
assert was_truncated is False
def test_exactly_max_bytes_not_truncated(self) -> None:
boundary = "a" * MAX_DIFF_BYTES
text, was_truncated = truncate_diff(boundary)
assert was_truncated is False
assert text == boundary
def test_one_byte_over_truncated(self) -> None:
over = "a" * (MAX_DIFF_BYTES + 1)
text, was_truncated = truncate_diff(over)
assert was_truncated is True
assert len(text.encode("utf-8")) <= MAX_DIFF_BYTES
def test_large_diff_truncated(self) -> None:
_, was_truncated = truncate_diff("a" * 200_000)
assert was_truncated is True
def test_custom_max_bytes(self) -> None:
text, was_truncated = truncate_diff("hello world", max_bytes=5)
assert was_truncated is True
assert len(text.encode("utf-8")) <= 5
def test_result_is_valid_utf8(self) -> None:
# Multi-byte chars that could be split at a byte boundary
multi = "こんにちは" * 10_000
text, was_truncated = truncate_diff(multi, max_bytes=100)
assert was_truncated is True
text.encode("utf-8") # Must not raise