auto-commit-service/src/auto_commit_service/git/operations.py

286 lines
8 KiB
Python

"""Async git command wrappers using subprocess."""
import asyncio
import logging
import os
from pathlib import Path

from .repository import GitStatus, CommitResult, PushResult
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class GitError(Exception):
    """Base class for every failure raised by the git helpers.

    Attributes:
        stderr: Text git wrote to stderr (empty when not applicable).
        returncode: Exit status of the failing git command (defaults to 1).
    """

    def __init__(self, message: str, stderr: str = "", returncode: int = 1):
        self.stderr = stderr
        self.returncode = returncode
        super().__init__(message)


class MergeConflictError(GitError):
    """Signals that a rebase/merge stopped because of conflicting changes."""


class PushRejectedError(GitError):
    """Signals that the remote refused the pushed commits."""
async def _run_git_command(
    *args: str,
    cwd: Path,
    check: bool = True,
) -> tuple[str, str, int]:
    """Run ``git`` with *args* inside *cwd* and capture its output.

    Uses ``asyncio.create_subprocess_exec`` with an argument list, so no
    shell is involved and arguments reach git verbatim (the equivalent of
    Node.js ``execFile``; no shell injection).

    Args:
        *args: Arguments passed to ``git``.
        cwd: Working directory for the git process.
        check: If True, raise ``GitError`` when git exits non-zero.

    Returns:
        ``(stdout, stderr, returncode)``. Both streams are decoded as UTF-8
        (undecodable bytes become U+FFFD) and stripped of surrounding
        whitespace.

    Raises:
        GitError: If *check* is True and git exits non-zero, or, regardless
            of *check*, if the git process cannot be started (for example
            git is not installed or *cwd* does not exist).
    """
    # Force untranslated git messages: callers in this module match on
    # English output such as "nothing to commit", "rejected" and "conflict".
    env = {**os.environ, "LC_ALL": "C"}
    try:
        proc = await asyncio.create_subprocess_exec(
            "git",
            *args,
            cwd=str(cwd),
            env=env,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
    except OSError as e:
        raise GitError(f"git {' '.join(args)} could not be started: {e}") from e
    stdout, stderr = await proc.communicate()
    # Diff contents, paths and commit messages are not guaranteed to be
    # UTF-8; a strict decode would raise UnicodeDecodeError on them.
    stdout_str = stdout.decode("utf-8", errors="replace").strip()
    stderr_str = stderr.decode("utf-8", errors="replace").strip()
    # returncode is set once communicate() has returned; the fallback only
    # narrows the Optional[int] type.
    returncode = proc.returncode or 0
    if check and returncode != 0:
        raise GitError(
            f"git {' '.join(args)} failed: {stderr_str}",
            stderr=stderr_str,
            returncode=returncode,
        )
    return stdout_str, stderr_str, returncode
def _parse_branch_header(header: str) -> tuple[str, int, int]:
    """Parse the ``## ...`` header of ``git status --porcelain -b``.

    Example headers: ``## main``, ``## main...origin/main [ahead 1, behind 2]``,
    ``## No commits yet on main``, ``## HEAD (no branch)``.

    Returns:
        ``(branch, ahead, behind)``; counts are 0 when not reported and the
        branch falls back to "main" when the header carries no name.
    """
    info = header[3:]
    # An unborn branch is reported with a prefix ("Initial commit on" in
    # older git); drop it so the branch name is returned, not "No".
    for prefix in ("No commits yet on ", "Initial commit on "):
        if info.startswith(prefix):
            info = info[len(prefix):]
            break

    if "..." in info:
        branch = info.split("...")[0]
    else:
        words = info.split()
        branch = words[0] if words else "main"

    ahead = behind = 0
    # The tracking summary is the trailing "[...]" group; "[" is not allowed
    # in ref names, so the last "[" starts it.
    if info.endswith("]") and "[" in info:
        summary = info[info.rindex("[") + 1 : -1]
        for item in summary.split(","):
            key, _, count = item.strip().partition(" ")
            if not count.isdecimal():
                continue
            if key == "ahead":
                ahead = int(count)
            elif key == "behind":
                behind = int(count)
    return branch, ahead, behind


async def git_status(repo_path: Path) -> GitStatus:
    """Return the working-tree status of the repository at *repo_path*.

    Each changed path lands in exactly one bucket, checked in this order:
    untracked (``??``), deleted (``D`` in either column), staged (``M``,
    ``A``, ``R`` or ``C`` in the index column), modified (``M`` in the
    worktree column). Entries matching none of these (e.g. unmerged ``UU``
    or type changes) are not reported. Renamed/copied entries report the
    destination path.

    Raises:
        GitError: If ``git status`` exits non-zero.
    """
    # -z output: NUL-terminated entries, paths emitted verbatim (no C-style
    # quoting of spaces or non-ASCII), and renames/copies written as
    # "XY <new>" followed by a separate "<old>" entry instead of
    # "XY <old> -> <new>".
    stdout, _, _ = await _run_git_command(
        "status", "--porcelain", "-b", "-z", cwd=repo_path
    )
    staged: list[str] = []
    modified: list[str] = []
    untracked: list[str] = []
    deleted: list[str] = []
    branch = "main"
    ahead = 0
    behind = 0

    entries = iter(stdout.split("\0"))
    for entry in entries:
        if not entry:
            continue
        if entry.startswith("##"):
            branch, ahead, behind = _parse_branch_header(entry)
            continue
        if len(entry) < 3:
            continue
        index_status = entry[0]
        worktree_status = entry[1]
        filepath = entry[3:]
        if index_status in ("R", "C") or worktree_status in ("R", "C"):
            # The next NUL-separated field is the rename/copy source; skip it.
            next(entries, None)

        if index_status == "?" and worktree_status == "?":
            untracked.append(filepath)
        elif index_status == "D" or worktree_status == "D":
            deleted.append(filepath)
        elif index_status in ("M", "A", "R", "C"):
            staged.append(filepath)
        elif worktree_status == "M":
            modified.append(filepath)

    has_changes = bool(staged or modified or untracked or deleted)
    return GitStatus(
        has_changes=has_changes,
        staged=staged,
        modified=modified,
        untracked=untracked,
        deleted=deleted,
        branch=branch,
        ahead=ahead,
        behind=behind,
    )
async def git_diff(repo_path: Path, staged: bool = False) -> str:
    """Return the textual diff of the repository's changes.

    Args:
        repo_path: Path to the repository.
        staged: If True, diff the index against HEAD (``git diff --cached``).
            Otherwise diff the working tree against HEAD (``git diff HEAD``),
            which covers both staged and unstaged changes to tracked files.
            Untracked files appear in neither mode.

    Returns:
        Whatever git wrote to stdout. This is empty when there are no changes
        or when git fails; failures are logged rather than raised.
    """
    args = ["diff", "--cached" if staged else "HEAD"]
    stdout, stderr, returncode = await _run_git_command(
        *args, cwd=repo_path, check=False
    )
    if returncode != 0:
        # Stay best-effort, but do not hide the failure (e.g. "git diff HEAD"
        # in a repository that has no commits yet).
        logger.warning(
            "git %s failed in %s (exit %d): %s",
            " ".join(args),
            repo_path,
            returncode,
            stderr,
        )
    return stdout
async def git_add_all(repo_path: Path) -> None:
    """Stage every change in the work tree, untracked files included.

    Runs ``git add -A``.

    Raises:
        GitError: If ``git add`` exits non-zero.
    """
    await _run_git_command("add", "-A", cwd=repo_path, check=True)
async def git_commit(repo_path: Path, message: str) -> CommitResult:
    """Commit the currently staged changes with *message*.

    Returns:
        On success, a ``CommitResult`` carrying the short hash of the new
        HEAD and the message. On failure, ``success=False`` with ``error``
        set as follows:

        - "Nothing to commit" when git's output says so.
        - Otherwise git's stderr, or stdout if stderr is empty.
        - The ``GitError`` text if a git command raises.
    """
    try:
        out, err, code = await _run_git_command(
            "commit", "-m", message, cwd=repo_path, check=False
        )
        if code == 0:
            short_hash, _, _ = await _run_git_command(
                "rev-parse", "--short", "HEAD", cwd=repo_path
            )
            return CommitResult(
                success=True,
                commit_hash=short_hash,
                message=message,
            )
        nothing_staged = any("nothing to commit" in text for text in (out, err))
        if nothing_staged:
            return CommitResult(success=False, error="Nothing to commit")
        return CommitResult(success=False, error=err or out)
    except GitError as exc:
        return CommitResult(success=False, error=str(exc))
async def git_push(
    repo_path: Path,
    remote: str = "origin",
    branch: str = "main",
) -> PushResult:
    """Push *branch* to *remote*.

    Returns:
        ``PushResult(success=True)`` when git exits 0. Otherwise returns
        ``success=False`` with git's stderr as ``error``, or the ``GitError``
        text if running git raised.

    Raises:
        PushRejectedError: If a failed push's stderr mentions "rejected" or
            "non-fast-forward" (case-insensitive).
    """
    try:
        _, err_text, code = await _run_git_command(
            "push", remote, branch, cwd=repo_path, check=False
        )
    except GitError as exc:
        return PushResult(
            success=False,
            remote=remote,
            branch=branch,
            error=str(exc),
        )

    if code == 0:
        return PushResult(success=True, remote=remote, branch=branch)

    lowered = err_text.lower()
    if "rejected" in lowered or "non-fast-forward" in lowered:
        raise PushRejectedError(
            f"Push rejected: {err_text}",
            stderr=err_text,
            returncode=code,
        )
    # Rejections raised above, so a plain failure is never "rejected".
    return PushResult(
        success=False,
        remote=remote,
        branch=branch,
        error=err_text,
        rejected=False,
    )
async def git_pull_rebase(
    repo_path: Path,
    remote: str = "origin",
    branch: str = "main",
) -> bool:
    """Pull *branch* from *remote* with ``--rebase`` to reconcile diverged history.

    When git reports a conflict, the in-progress rebase is aborted with
    ``git rebase --abort`` before raising.

    Returns:
        True if the pull succeeded.

    Raises:
        MergeConflictError: If the failed pull's output mentions a conflict.
        GitError: For any other failure, including unexpected exceptions
            (chained as ``__cause__``).
    """
    try:
        stdout, stderr, returncode = await _run_git_command(
            "pull", "--rebase", remote, branch, cwd=repo_path, check=False
        )
        if returncode == 0:
            return True
        # git may report "CONFLICT (...)" lines on stdout rather than
        # stderr, so inspect both streams (case-insensitively).
        if "conflict" in f"{stdout}\n{stderr}".lower():
            _, abort_err, abort_code = await _run_git_command(
                "rebase", "--abort", cwd=repo_path, check=False
            )
            if abort_code != 0:
                logger.warning(
                    "git rebase --abort failed in %s (exit %d): %s",
                    repo_path,
                    abort_code,
                    abort_err,
                )
            raise MergeConflictError(
                f"Merge conflict during rebase: {stderr}",
                stderr=stderr,
                returncode=returncode,
            )
        raise GitError(
            f"Pull rebase failed: {stderr}",
            stderr=stderr,
            returncode=returncode,
        )
    except GitError:
        # Includes MergeConflictError (a GitError subclass).
        raise
    except Exception as e:
        raise GitError(f"Unexpected error during pull rebase: {e}") from e
async def git_log_recent(
    repo_path: Path,
    count: int = 5,
) -> list[str]:
    """Return the subject lines of the last *count* commits from ``git log``.

    Meant as a style reference for new commit messages. Errors are not
    raised (the command runs with ``check=False``). Whatever git printed to
    stdout is split into lines, and empty lines are dropped.
    """
    output, _, _ = await _run_git_command(
        "log", f"-{count}", "--format=%s", cwd=repo_path, check=False
    )
    return list(filter(None, output.split("\n")))