logger = logging.getLogger(__name__)


@dataclass
class FileGroup:
    """A logical group of files for a single commit.

    Attributes:
        files: Paths of the files that belong in the commit.
        reasoning: Short explanation of why the files belong together.
        diff_excerpt: Leading slice of the diff. Within this module only the
            single-group fallback fills it in; LLM-derived groups leave it
            empty.
    """

    files: list[str]
    reasoning: str
    diff_excerpt: str = ""


class FileGroupingStrategy:
    """Strategy for grouping files into logical commits using LLM reasoning.

    The LLM client is asked to partition the changed files. Its answer is
    parsed as a JSON array of objects carrying a ``files`` list and a
    ``reasoning`` string. Whenever grouping is disabled, not worthwhile, or
    fails, every file is placed in one fallback group.
    """

    # Change sets with this many files or fewer are committed as one group.
    SINGLE_GROUP_MAX_FILES = 3
    # Number of leading diff characters kept on the fallback group.
    FALLBACK_EXCERPT_CHARS = 500
    FALLBACK_REASONING = "All changes (grouping disabled or failed)"
    CATCH_ALL_REASONING = "Miscellaneous changes not categorized"

    def __init__(
        self,
        llm_client: "LlamaCommitClient",
        settings: "AutoCommitSettings",
    ):
        """Initialize the grouping strategy.

        Args:
            llm_client: LLM client for grouping analysis
            settings: Service settings
        """
        self.llm_client = llm_client
        self.settings = settings

    async def group_files(
        self,
        repo: "Repository",
        changed_files: list[str],
        diff: str,
        *,
        branch: str = "main",
    ) -> list[FileGroup]:
        """Group files into logical commits using the configured LLM.

        Every distinct path in ``changed_files`` ends up in exactly one of the
        returned groups, even when the LLM omits it or names it several times.

        Args:
            repo: Repository being processed (only ``repo.name`` is read)
            changed_files: List of all changed file paths; duplicates are
                collapsed, keeping first-seen order
            diff: Full diff for context; ``None`` or empty is tolerated
            branch: Branch name forwarded to the LLM client. Defaults to
                ``"main"``, the value that used to be hard-coded.

        Returns:
            List of file groups for separate commits
        """
        # A path can be reported more than once by the caller (for example
        # when it is both staged and modified); commit it only once.
        unique_files = list(dict.fromkeys(changed_files))

        # Check if grouping is enabled
        if not self.settings.use_mistral_for_grouping:
            logger.debug("Mistral grouping disabled, using single group")
            return self._fallback_single_group(unique_files, diff)

        # If too few files, don't bother grouping
        if len(unique_files) <= self.SINGLE_GROUP_MAX_FILES:
            logger.debug("Only %d files, skipping grouping", len(unique_files))
            return self._fallback_single_group(unique_files, diff)

        try:
            logger.info(
                "Grouping %d files with Mistral 3 14B for %s",
                len(unique_files),
                repo.name,
            )

            response = await self.llm_client.group_files(
                files=unique_files,
                diff_summary=self._summarize_diff(diff),
                repo_name=repo.name,
                branch=branch,
                model=self.settings.grouping_model,
            )

            groups_data = self._parse_grouping_response(response)
            if not groups_data:
                logger.warning("Failed to parse grouping response, using fallback")
                return self._fallback_single_group(unique_files, diff)

            groups = self._build_groups(groups_data, unique_files)
            logger.info("Created %d commit groups for %s", len(groups), repo.name)
            return groups

        except LlamaServiceError as e:
            logger.error("LLM service error during grouping: %s", e)
            return self._fallback_single_group(unique_files, diff)
        except Exception as e:
            # Grouping is best-effort: any failure degrades to a single
            # group instead of aborting the caller.
            logger.exception("Unexpected error during file grouping: %s", e)
            return self._fallback_single_group(unique_files, diff)

    def _build_groups(
        self,
        groups_data: list[dict],
        changed_files: list[str],
    ) -> list[FileGroup]:
        """Turn parsed LLM group entries into validated ``FileGroup`` objects.

        Args:
            groups_data: Dictionaries read for their ``files`` and
                ``reasoning`` keys
            changed_files: Distinct changed paths, in the order to preserve

        Returns:
            Groups that together contain each changed path exactly once
        """
        # Insertion-ordered "set" of paths that still need a group. Removing
        # a path on first use keeps it out of any later group.
        unassigned = dict.fromkeys(changed_files)
        groups: list[FileGroup] = []
        limit = self.settings.max_files_per_commit

        for entry in groups_data:
            proposed = entry.get("files") or []
            reasoning = str(entry.get("reasoning") or "")

            if isinstance(proposed, str):
                # Tolerate a lone path given instead of a list.
                proposed = [proposed]
            elif not isinstance(proposed, (list, tuple)):
                proposed = []
            if not proposed:
                continue

            # Keep only paths that really changed and are not yet assigned.
            accepted = []
            for path in proposed:
                if isinstance(path, str) and path in unassigned:
                    del unassigned[path]
                    accepted.append(path)

            if not accepted:
                logger.warning("Group has no valid files: %s", reasoning)
                continue

            if 0 < limit < len(accepted):
                logger.warning(
                    "Group exceeds max files (%d > %d), splitting: %s",
                    len(accepted),
                    limit,
                    reasoning,
                )
            groups.extend(self._split_by_limit(accepted, reasoning))

        # Ensure all files are covered
        if unassigned:
            logger.warning(
                "%d files not grouped by LLM, adding catch-all group",
                len(unassigned),
            )
            groups.extend(
                self._split_by_limit(list(unassigned), self.CATCH_ALL_REASONING)
            )

        return groups

    def _split_by_limit(self, files: list[str], reasoning: str) -> list[FileGroup]:
        """Wrap ``files`` in one group, or several when over the size limit.

        Args:
            files: Paths to place, already validated
            reasoning: Explanation shared by the resulting groups

        Returns:
            A single group when the files fit within
            ``settings.max_files_per_commit`` (or that limit is not
            positive); otherwise consecutive chunks whose reasoning is
            suffixed with ``(part N)``
        """
        limit = self.settings.max_files_per_commit
        if limit < 1 or len(files) <= limit:
            return [FileGroup(files=files, reasoning=reasoning)]

        return [
            FileGroup(
                files=files[start : start + limit],
                reasoning=f"{reasoning} (part {part})",
            )
            for part, start in enumerate(range(0, len(files), limit), 1)
        ]

    def _parse_grouping_response(self, response: str) -> list[dict]:
        """Parse the JSON grouping response from the LLM.

        Markdown code fences are stripped. If the remaining text is not
        valid JSON as a whole, the first embedded JSON array of objects is
        used instead, so explanatory text around the array does not discard
        an otherwise usable answer.

        Args:
            response: Raw LLM response

        Returns:
            List of group dictionaries or empty list on failure
        """
        if not isinstance(response, str):
            logger.error("Grouping response is not a list")
            return []

        # Clean response - remove markdown code blocks if present
        cleaned = response.strip()
        if cleaned.startswith("```"):
            cleaned = cleaned[3:]
            if cleaned[:4].lower() == "json":
                cleaned = cleaned[4:]
        if cleaned.endswith("```"):
            cleaned = cleaned[:-3]
        cleaned = cleaned.strip()

        try:
            parsed = json.loads(cleaned)
        except json.JSONDecodeError as e:
            parsed = self._extract_json_array(cleaned)
            if parsed is None:
                logger.error("Failed to parse grouping JSON: %s", e)
                logger.debug("Raw response: %s", response[:500])
                return []

        if not isinstance(parsed, list):
            logger.error("Grouping response is not a list")
            return []

        # Honour the declared return type: drop entries that are not objects.
        return [entry for entry in parsed if isinstance(entry, dict)]

    def _extract_json_array(self, text: str) -> "list[dict] | None":
        """Find the first JSON array of objects embedded in ``text``.

        Args:
            text: Text that failed to parse as JSON as a whole

        Returns:
            The decoded non-empty array whose elements are all objects, or
            ``None`` when no such array can be decoded
        """
        decoder = json.JSONDecoder()
        start = text.find("[")
        while start != -1:
            try:
                candidate, _ = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                candidate = None
            if (
                isinstance(candidate, list)
                and candidate
                and all(isinstance(item, dict) for item in candidate)
            ):
                return candidate
            start = text.find("[", start + 1)
        return None

    def _fallback_single_group(self, files: list[str], diff: str) -> list[FileGroup]:
        """Fallback to single group when LLM grouping fails or is disabled.

        Args:
            files: All changed files
            diff: Full diff; ``None`` or empty yields an empty excerpt

        Returns:
            Single group containing all files
        """
        return [
            FileGroup(
                # Copy so the group does not share the caller's list.
                files=list(files),
                reasoning=self.FALLBACK_REASONING,
                diff_excerpt=(diff or "")[: self.FALLBACK_EXCERPT_CHARS],
            )
        ]

    def _summarize_diff(self, diff: str, max_length: int = 1500) -> str:
        """Summarize diff for grouping prompt.

        Args:
            diff: Full diff; ``None`` is treated as an empty diff
            max_length: Maximum length of summary

        Returns:
            The diff itself when short enough, otherwise its first
            ``max_length`` characters followed by a truncation note
        """
        diff = diff or ""
        if len(diff) <= max_length:
            return diff

        # Take first part and indicate truncation
        omitted = len(diff) - max_length
        return diff[:max_length] + f"\n\n... (truncated, {omitted} chars omitted)"
+ Args: repo: Repository to process @@ -117,67 +122,102 @@ class CommitProcessor: status=ProcessStatus.NO_CHANGES, ) + all_files = status.staged + status.modified + status.untracked + status.deleted logger.info( - f"Changes found in {repo.name}: " - f"{len(status.staged)} staged, " - f"{len(status.modified)} modified, " - f"{len(status.untracked)} untracked" + f"Changes found in {repo.name}: {len(all_files)} files total " + f"({len(status.staged)} staged, {len(status.modified)} modified, " + f"{len(status.untracked)} untracked, {len(status.deleted)} deleted)" ) - # Step 3: Get diff for LLM + # Step 3: Get full diff for grouping analysis diff = await git_diff(repo.path) - if not diff: - # Might be only untracked files - still commit them - logger.debug(f"No diff but has changes (likely new files) in {repo.name}") + if not diff and not status.untracked: + logger.debug(f"No diff in {repo.name}") - # Step 4: Generate commit message - summary = summarize_diff(diff) if diff else summarize_diff("") - summary.files_added = len(status.untracked) + # Step 4: Group files using Mistral 3 14B + file_groups = await self.grouping_strategy.group_files( + repo=repo, + changed_files=all_files, + diff=diff, + ) - try: - message = await self.llm_client.generate_commit_message( - diff_summary=summary, - repo_name=repo.name, - branch=status.branch, + logger.info(f"Created {len(file_groups)} commit group(s) for {repo.name}") + + # Step 5: Create commits for each group + commit_results = [] + for idx, group in enumerate(file_groups, 1): + logger.info( + f"Processing group {idx}/{len(file_groups)} for {repo.name}: " + f"{len(group.files)} files - {group.reasoning}" ) - logger.info(f"Generated message for {repo.name}: {message}") - except LlamaServiceUnavailable as e: - logger.error(f"LLM service unavailable for {repo.name}: {e}") + try: + # Stage only files in this group + await git_add_specific(repo.path, group.files) + + # Generate commit message for this group + summary = 
summarize_diff(diff) if diff else summarize_diff("") + summary.files_added = len([f for f in group.files if f in status.untracked]) + summary.files_modified = len([f for f in group.files if f in status.modified]) + summary.files_deleted = len([f for f in group.files if f in status.deleted]) + + try: + message = await self.llm_client.generate_commit_message( + diff_summary=summary, + repo_name=repo.name, + branch=status.branch, + ) + # Enhance message with group reasoning if useful + logger.info(f"Generated message for group {idx}: {message}") + + except LlamaServiceUnavailable as e: + logger.error(f"LLM service unavailable for {repo.name}: {e}") + return RepoProcessResult( + repo_name=repo.name, + status=ProcessStatus.ERROR, + error=f"LLM service unavailable: {e}", + ) + except LlamaServiceError as e: + logger.error(f"LLM error for {repo.name}: {e}") + # Use fallback message + message = f"chore({repo.name}): 🔧 {group.reasoning[:50]}" + logger.info(f"Using fallback message for group {idx}: {message}") + + # Create commit + commit_result = await git_commit(repo.path, message) + + if not commit_result.success: + logger.error(f"Commit failed for group {idx} in {repo.name}: {commit_result.error}") + # Don't fail entire repo - continue with other groups + continue + + logger.info(f"Committed {commit_result.commit_hash} for group {idx} in {repo.name}") + commit_results.append(commit_result) + + except GitError as e: + logger.error(f"Git error for group {idx} in {repo.name}: {e}") + # Continue with other groups + continue + + # Step 6: Check if any commits succeeded + if not commit_results: + logger.error(f"No commits succeeded for {repo.name}") return RepoProcessResult( repo_name=repo.name, status=ProcessStatus.ERROR, - error=f"LLM service unavailable: {e}", - ) - except LlamaServiceError as e: - logger.error(f"LLM error for {repo.name}: {e}") - # Use fallback message - message = self._fallback_message(status) - logger.info(f"Using fallback message for {repo.name}: 
{message}") - - # Step 5: Stage all changes - await git_add_all(repo.path) - - # Step 6: Commit - commit_result = await git_commit(repo.path, message) - - if not commit_result.success: - logger.error(f"Commit failed for {repo.name}: {commit_result.error}") - return RepoProcessResult( - repo_name=repo.name, - status=ProcessStatus.ERROR, - error=f"Commit failed: {commit_result.error}", + error="All commit groups failed", ) - logger.info(f"Committed {commit_result.commit_hash} in {repo.name}") + # Return COMMITTED status with info about all commits + commit_hashes = ", ".join(r.commit_hash for r in commit_results) + logger.info(f"Completed {len(commit_results)} commit(s) in {repo.name}: {commit_hashes}") - # Return COMMITTED status (not SUCCESS yet - that comes after push) return RepoProcessResult( repo_name=repo.name, status=ProcessStatus.COMMITTED, - commit_hash=commit_result.commit_hash, - commit_message=message, + commit_hash=commit_results[0].commit_hash, # First commit hash + commit_message=f"{len(commit_results)} commits created", ) except MergeConflictError as e: