286 lines
8 KiB
Python
286 lines
8 KiB
Python
"""Async git command wrappers using subprocess."""
|
|
|
|
import asyncio
|
|
import logging
|
|
from pathlib import Path
|
|
|
|
from .repository import GitStatus, CommitResult, PushResult
|
|
|
|
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class GitError(Exception):
    """Base exception for failures of git commands run by this module.

    Attributes:
        stderr: Captured standard error text of the failing command
            (empty string when not provided).
        returncode: Exit status supplied by the raiser (defaults to 1).
    """

    def __init__(self, message: str, stderr: str = "", returncode: int = 1):
        self.returncode = returncode
        self.stderr = stderr
        super().__init__(message)
|
|
|
|
|
|
class MergeConflictError(GitError):
    """Raised when a git operation stops because of a merge conflict."""
|
|
|
|
|
|
class PushRejectedError(GitError):
    """Raised when the remote refuses a push."""
|
|
|
|
|
|
async def _run_git_command(
    *args: str,
    cwd: Path,
    check: bool = True,
) -> tuple[str, str, int]:
    """Run ``git`` with *args* in *cwd* and capture its output.

    The process is started with ``asyncio.create_subprocess_exec`` and an
    explicit argument list (like Node.js ``execFile``): no shell is involved,
    so arguments are never subject to shell interpretation or injection.

    Args:
        *args: Arguments passed to ``git`` (e.g. ``"status", "--porcelain"``).
        cwd: Working directory for the command (normally the repository).
        check: When True, a non-zero exit status raises :class:`GitError`.

    Returns:
        ``(stdout, stderr, returncode)``. Both streams are decoded as UTF-8,
        with undecodable bytes replaced, and stripped of surrounding
        whitespace.

    Raises:
        GitError: If the ``git`` process cannot be started (e.g. git is not
            installed or *cwd* does not exist), or if *check* is True and the
            command exits with a non-zero status.
    """
    try:
        proc = await asyncio.create_subprocess_exec(
            "git",
            *args,
            cwd=str(cwd),
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
    except OSError as exc:
        # Surface start-up failures as GitError so callers that handle
        # GitError (e.g. git_commit, git_push) also handle this case.
        raise GitError(f"could not run git {' '.join(args)}: {exc}") from exc

    stdout, stderr = await proc.communicate()

    # Git output can contain bytes that are not valid UTF-8 (paths or commit
    # messages in legacy encodings). Replace them instead of letting a
    # UnicodeDecodeError (not a GitError) escape to callers.
    stdout_str = stdout.decode("utf-8", errors="replace").strip()
    stderr_str = stderr.decode("utf-8", errors="replace").strip()
    returncode = proc.returncode or 0

    if check and returncode != 0:
        raise GitError(
            f"git {' '.join(args)} failed: {stderr_str}",
            stderr=stderr_str,
            returncode=returncode,
        )

    return stdout_str, stderr_str, returncode
|
|
|
|
|
|
def _parse_branch_header(header: str) -> tuple[str, int, int]:
    """Parse the text after ``## `` in ``git status --porcelain -b`` output.

    Returns ``(branch, ahead, behind)``. Handles headers such as
    ``main...origin/main [ahead 1, behind 2]``, ``main``, ``[gone]``
    tracking info, ``No commits yet on main`` and ``HEAD (no branch)``.
    Falls back to ``"main"`` when no branch name is present.
    """
    ahead = 0
    behind = 0
    branch_part = header

    # Branch names cannot contain spaces or "[", so a trailing " [...]" can
    # only be the tracking summary, e.g. "[ahead 1, behind 2]" or "[gone]".
    if header.endswith("]") and " [" in header:
        branch_part, _, tracking = header.rpartition(" [")
        for item in tracking[:-1].split(","):
            key, _, value = item.strip().partition(" ")
            try:
                if key == "ahead":
                    ahead = int(value)
                elif key == "behind":
                    behind = int(value)
            except ValueError:
                pass

    # Unborn branches are reported with a prose prefix
    # (newer git: "No commits yet on", older git: "Initial commit on").
    for prefix in ("No commits yet on ", "Initial commit on "):
        if branch_part.startswith(prefix):
            branch_part = branch_part[len(prefix):]
            break

    local = branch_part.split("...", 1)[0]
    words = local.split()
    branch = words[0] if words else "main"
    return branch, ahead, behind


def _status_path(raw: str, renamed: bool) -> str:
    """Return the (destination) path from one porcelain v1 status entry.

    For renames/copies git prints ``old -> new``; only ``new`` is kept.
    Paths git C-quotes (surrounded by double quotes, with backslash and
    octal escapes) are unquoted; if unquoting fails the raw text is kept.
    """
    if renamed and " -> " in raw:
        raw = raw.split(" -> ", 1)[1]
    if len(raw) >= 2 and raw.startswith('"') and raw.endswith('"'):
        try:
            # Octal escapes encode raw UTF-8 bytes, so decode the escapes
            # byte-wise (latin-1 round trip) and then decode as UTF-8.
            raw = (
                raw[1:-1]
                .encode("utf-8")
                .decode("unicode_escape")
                .encode("latin-1")
                .decode("utf-8", errors="replace")
            )
        except UnicodeError:
            pass
    return raw


async def git_status(repo_path: Path) -> GitStatus:
    """Get the git status of a repository.

    Parses ``git status --porcelain -b``. Each changed path goes into exactly
    one list, checked in this order:

    * ``??`` entries go to ``untracked``.
    * Entries with ``D`` in either column go to ``deleted``.
    * ``M``/``A``/``R``/``C`` in the index column go to ``staged``.
    * Everything else goes to ``modified``. This covers work-tree
      modifications and unmerged or type-change entries.

    Raises:
        GitError: If ``git status`` exits non-zero.
    """
    stdout, _, _ = await _run_git_command("status", "--porcelain", "-b", cwd=repo_path)

    staged: list[str] = []
    modified: list[str] = []
    untracked: list[str] = []
    deleted: list[str] = []
    branch = "main"
    ahead = 0
    behind = 0

    for line in stdout.split("\n"):
        if not line:
            continue

        if line.startswith("##"):
            branch, ahead, behind = _parse_branch_header(line[3:])
            continue

        # Entries are "XY <path>": X = index status, Y = work-tree status.
        if len(line) < 3:
            continue
        index_status = line[0]
        worktree_status = line[1]
        renamed = index_status in "RC" or worktree_status in "RC"
        filepath = _status_path(line[3:], renamed)

        if index_status == "?" and worktree_status == "?":
            untracked.append(filepath)
        elif "D" in (index_status, worktree_status):
            deleted.append(filepath)
        elif index_status in "MARC":
            staged.append(filepath)
        else:
            # Work-tree modifications plus states the checks above do not
            # cover (e.g. unmerged "UU"/"UA", type changes "T", intent-to-add
            # " A"); previously these were dropped, hiding real changes.
            modified.append(filepath)

    has_changes = bool(staged or modified or untracked or deleted)

    return GitStatus(
        has_changes=has_changes,
        staged=staged,
        modified=modified,
        untracked=untracked,
        deleted=deleted,
        branch=branch,
        ahead=ahead,
        behind=behind,
    )
|
|
|
|
|
|
async def git_diff(repo_path: Path, staged: bool = False) -> str:
    """Get the diff of changes.

    Args:
        repo_path: Path to the repository.
        staged: If True, show only staged changes (``git diff --cached``).
            Otherwise run ``git diff HEAD``, i.e. staged and unstaged changes
            to tracked files relative to ``HEAD``.

    Returns:
        The diff text as returned by ``_run_git_command`` (stripped). This is
        best-effort: if git exits non-zero (for example when ``HEAD`` does not
        exist yet), the failure is logged and whatever git wrote to stdout is
        returned instead of raising.
    """
    args = ["diff", "--cached" if staged else "HEAD"]

    stdout, stderr, returncode = await _run_git_command(*args, cwd=repo_path, check=False)
    if returncode != 0:
        # Previously swallowed silently; keep best-effort semantics but make
        # the failure visible.
        logger.warning(
            "git %s failed in %s (exit %d): %s",
            " ".join(args),
            repo_path,
            returncode,
            stderr,
        )
    return stdout
|
|
|
|
|
|
async def git_add_all(repo_path: Path) -> None:
    """Stage every change in the work tree via ``git add -A``.

    Modifications, deletions and untracked files are all staged. A non-zero
    exit from git raises GitError (``_run_git_command`` runs with its default
    ``check=True``).
    """
    add_everything = ("add", "-A")
    await _run_git_command(*add_everything, cwd=repo_path)
|
|
|
|
|
|
async def git_commit(repo_path: Path, message: str) -> CommitResult:
    """Create a commit from the staged changes with the given message.

    Args:
        repo_path: Path to the repository.
        message: Commit message, passed to ``git commit -m`` as a single
            argument (no shell involved).

    Returns:
        A CommitResult; this function does not raise GitError.

        * On success: ``success=True`` with ``message``. ``commit_hash`` holds
          the short hash of the new ``HEAD``; it is omitted if that lookup
          fails after the commit was created.
        * When git reports nothing to commit: ``success=False`` and
          ``error="Nothing to commit"``.
        * Otherwise: ``success=False``, with git's stderr (or stdout when
          stderr is empty) or the GitError text in ``error``.
    """
    # Phrases git prints when there is nothing staged to commit. The second
    # appears when the work tree has unstaged edits but the index matches HEAD.
    nothing_to_commit_markers = ("nothing to commit", "no changes added to commit")

    try:
        stdout, stderr, returncode = await _run_git_command(
            "commit", "-m", message, cwd=repo_path, check=False
        )
    except GitError as e:
        return CommitResult(success=False, error=str(e))

    if returncode != 0:
        combined_output = f"{stdout}\n{stderr}"
        if any(marker in combined_output for marker in nothing_to_commit_markers):
            return CommitResult(success=False, error="Nothing to commit")
        return CommitResult(success=False, error=stderr or stdout)

    try:
        commit_hash, _, _ = await _run_git_command(
            "rev-parse", "--short", "HEAD", cwd=repo_path
        )
    except GitError as e:
        # The commit has already been created; only the hash lookup failed.
        # Report success (without a hash) rather than claiming it failed.
        logger.warning("Commit created in %s but resolving HEAD failed: %s", repo_path, e)
        return CommitResult(success=True, message=message)

    return CommitResult(
        success=True,
        commit_hash=commit_hash,
        message=message,
    )
|
|
|
|
|
|
async def git_push(
    repo_path: Path,
    remote: str = "origin",
    branch: str = "main",
) -> PushResult:
    """Push *branch* to *remote* using ``git push <remote> <branch>``.

    Returns:
        PushResult with ``success=True`` when git exits 0. Otherwise
        ``success=False``, with git's stderr (or the GitError text) in
        ``error``.

    Raises:
        PushRejectedError: If git's stderr contains "rejected" or
            "non-fast-forward" (case-insensitive).
    """
    try:
        _, stderr, returncode = await _run_git_command(
            "push", remote, branch, cwd=repo_path, check=False
        )
    except GitError as exc:
        return PushResult(
            success=False,
            remote=remote,
            branch=branch,
            error=str(exc),
        )

    if returncode == 0:
        return PushResult(success=True, remote=remote, branch=branch)

    lowered = stderr.lower()
    if "rejected" in lowered or "non-fast-forward" in lowered:
        raise PushRejectedError(
            f"Push rejected: {stderr}",
            stderr=stderr,
            returncode=returncode,
        )

    return PushResult(
        success=False,
        remote=remote,
        branch=branch,
        error=stderr,
        rejected=False,
    )
|
|
|
|
|
|
async def git_pull_rebase(
    repo_path: Path,
    remote: str = "origin",
    branch: str = "main",
) -> bool:
    """Pull with rebase to resolve diverged history.

    Runs ``git pull --rebase <remote> <branch>``.

    Args:
        repo_path: Path to the repository.
        remote: Remote to pull from.
        branch: Remote branch to rebase onto.

    Returns:
        True if the pull succeeded.

    Raises:
        MergeConflictError: If git's output mentions a conflict. Before
            raising, ``git rebase --abort`` is attempted. This is best effort:
            a failed abort is logged, not raised.
        GitError: For any other failure. Unexpected non-GitError exceptions
            are wrapped in a GitError, with the original kept as ``__cause__``.
    """
    try:
        stdout, stderr, returncode = await _run_git_command(
            "pull", "--rebase", remote, branch, cwd=repo_path, check=False
        )
        if returncode == 0:
            return True

        # Conflict details ("CONFLICT (content): ...") can be written to stdout
        # while the "could not apply ..." error goes to stderr, so check both.
        if "conflict" in f"{stdout}\n{stderr}".lower():
            # Abort so the repository is not left in the middle of a rebase.
            _, abort_stderr, abort_code = await _run_git_command(
                "rebase", "--abort", cwd=repo_path, check=False
            )
            if abort_code != 0:
                logger.warning(
                    "git rebase --abort failed in %s (exit %d): %s",
                    repo_path,
                    abort_code,
                    abort_stderr,
                )
            raise MergeConflictError(
                f"Merge conflict during rebase: {stderr}",
                stderr=stderr,
                returncode=returncode,
            )

        raise GitError(
            f"Pull rebase failed: {stderr}",
            stderr=stderr,
            returncode=returncode,
        )

    except GitError:
        # Includes MergeConflictError; already descriptive, propagate as-is.
        raise
    except Exception as e:
        raise GitError(f"Unexpected error during pull rebase: {e}") from e
|
|
|
|
|
|
async def git_log_recent(
    repo_path: Path,
    count: int = 5,
) -> list[str]:
    """Return the subject lines of up to *count* recent commits.

    Intended as a style reference when writing new commit messages. Git
    failures are not raised (``check=False``). In that case the result is
    built from whatever git printed on stdout, typically nothing.
    """
    log_args = ("log", f"-{count}", "--format=%s")
    output, _, _ = await _run_git_command(*log_args, cwd=repo_path, check=False)
    return list(filter(None, output.split("\n")))
|