diff --git a/tools/github_tool.py b/tools/github_tool.py index 1808084..435a000 100644 --- a/tools/github_tool.py +++ b/tools/github_tool.py @@ -134,6 +134,36 @@ class GitHubTool(BaseTool): }, "_tags": ["write"] }, + { + "type": "function", + "function": { + "name": "commit_file_patch", + "description": "Apply partial line-based edits to a file and commit the result (without requiring the caller to upload the entire file).", + "parameters": { + "type": "object", + "properties": { + "file_path": {"type": "string", "description": "Path to the file in the repository"}, + "commit_message": {"type": "string", "description": "Commit message"}, + "edits": { + "type": "array", + "description": "List of line-based edits. Each edit replaces lines [start_line..end_line] (inclusive) with 'replacement'. If end_line < start_line, the 'replacement' is inserted before start_line.", + "items": { + "type": "object", + "properties": { + "start_line": {"type": "integer", "description": "1-based start line"}, + "end_line": {"type": "integer", "description": "1-based end line (inclusive), or set less than start_line for insertion"}, + "replacement": {"type": "string", "description": "Replacement text for the specified range (can be multi-line)"} + }, + "required": ["start_line", "end_line", "replacement"] + } + }, + "base_sha": {"type": "string", "description": "Optional expected current blob SHA for the file; if provided and does not match, the operation aborts to prevent overwriting newer changes."} + }, + "required": ["file_path", "commit_message", "edits"] + } + }, + "_tags": ["write"] + }, { "type": "function", "function": { @@ -688,7 +718,7 @@ class GitHubTool(BaseTool): return error_message def _commit_file(self, file_path, content, commit_message): - content.strip("'''") + content = content.strip("'''") self.logger.info(f"Committing file: {file_path} to branch: {self.current_branch} with message: '{commit_message}'") if self.current_branch == "main": error_message = "Action directly to main branch is not allowed. Please create and switch to a new branch first." @@ -727,8 +757,131 @@ class GitHubTool(BaseTool): self.logger.error(error_message) return error_message + def _commit_file_patch(self, file_path, edits, commit_message, base_sha=None): + self.logger.info(f"Committing partial edits to file: {file_path} on branch: {self.current_branch}") + if self.current_branch == "main": + error_message = "Action directly to main branch is not allowed. Please create and switch to a new branch first." + self.logger.warning(error_message) + return error_message + + # Fetch current file content and sha + url = f"{self.base_url}/repos/{self._repo}/contents/{file_path}" + get_response = self.session.get(url, params={"ref": self.current_branch}) + if get_response.status_code == 404: + error_message = f"File '{file_path}' not found on branch '{self.current_branch}'. Cannot apply partial edits to a non-existent file." + self.logger.error(error_message) + return error_message + if get_response.status_code != 200: + error_message = f"Error reading file '{file_path}' for patching: {get_response.status_code} - {get_response.text}" + self.logger.error(error_message) + return error_message + + current_sha = get_response.json()["sha"] + if base_sha and base_sha != current_sha: + error_message = ( + f"Abort: base_sha mismatch for '{file_path}'. Expected {base_sha}, current {current_sha}. " + "Please refresh the file content and re-apply your edits." + ) + self.logger.warning(error_message) + return error_message + + content_b64 = get_response.json()["content"] + decoded_content = base64.b64decode(content_b64).decode('utf-8') + + # Apply edits + try: + new_content = self._apply_line_edits(decoded_content, edits) + except Exception as e: + self.logger.error(f"Failed to apply edits to '{file_path}': {e}", exc_info=True) + return f"Error applying edits: {str(e)}" + + if new_content == decoded_content: + msg = f"No changes detected after applying edits to '{file_path}'. Skipping commit." + self.logger.info(msg) + return msg + + # Commit updated content using the Contents API + encoded_content = base64.b64encode(new_content.encode('utf-8')).decode('utf-8') + data = { + "message": commit_message, + "content": encoded_content, + "branch": self.current_branch, + "sha": current_sha + } + put_response = self.session.put(url, json=data) + if put_response.status_code in [200, 201]: + commit_sha = put_response.json().get("commit", {}).get("sha", "N/A") + success_message = f"Partial edits committed to '{file_path}' on branch '{self.current_branch}'. Commit SHA: {commit_sha}" + self.logger.info(success_message) + return success_message + else: + error_message = f"Error committing partial edits to '{file_path}': {put_response.status_code} - {put_response.text}" + self.logger.error(error_message) + return error_message + + def _apply_line_edits(self, original_text, edits): + """ + Apply line-based edits to the provided text. + - Lines are 1-based. + - For each edit: replace lines [start_line..end_line] inclusive with 'replacement'. + If end_line < start_line, insert 'replacement' before start_line (no deletion). + - Preserves the file's original newline style (\n or \r\n) for joins. + """ + # Determine newline style + newline = "\r\n" if "\r\n" in original_text else "\n" + ends_with_newline = original_text.endswith("\n") or original_text.endswith("\r\n") + + # Normalize the working representation to a list of lines without newlines + lines = original_text.replace("\r\n", "\n").replace("\r", "\n").split("\n") + if ends_with_newline and (len(lines) == 0 or lines[-1] != ""): + # split drops the trailing empty element when text ends with newline; add it back as an empty logical line + lines.append("") + + # Validate and apply edits in reverse order to avoid index shifts + sorted_edits = sorted(edits, key=lambda e: (e.get('start_line', 1), e.get('end_line', 0)), reverse=True) + + for idx, edit in enumerate(sorted_edits): + start_line = edit.get('start_line') + end_line = edit.get('end_line') + replacement = edit.get('replacement', "") + + if start_line is None or end_line is None: + raise ValueError(f"Edit #{idx+1} missing start_line or end_line") + if not isinstance(start_line, int) or not isinstance(end_line, int): + raise ValueError(f"Edit #{idx+1} start_line and end_line must be integers") + if start_line < 1: + raise ValueError(f"Edit #{idx+1} start_line must be >= 1") + + # Normalize replacement to list of lines (without trailing newline) + rep_lines = replacement.replace("\r\n", "\n").replace("\r", "\n").split("\n") + + if end_line >= start_line: + # Replace lines in the inclusive range [start_line..end_line] + if end_line > len(lines): + raise ValueError(f"Edit #{idx+1} end_line {end_line} exceeds file length {len(lines)}") + # Python slice end is exclusive; convert to 0-based indices + s = start_line - 1 + e = end_line + lines[s:e] = rep_lines + else: + # Insertion before start_line (no deletion) + if start_line > len(lines) + 1: + raise ValueError(f"Edit #{idx+1} start_line {start_line} is beyond end of file {len(lines)}") + insert_at = start_line - 1 + lines[insert_at:insert_at] = rep_lines + + # Reconstruct the text using the original newline style + text = newline.join(lines) + if not ends_with_newline and text.endswith(newline): + # Original did not end with newline; remove any trailing newline we may have introduced + text = text[:-len(newline)] + elif ends_with_newline and not text.endswith(newline): + # Original ended with newline; ensure we keep it + text += newline + return text + def _create_pull_request(self, title, body, base="main"): - self.logger.info(f"Creating pull request: '{title}' from branch '{self.current_branch}' to '{base}'") + self.logger.info(f"Creating pull request: '{title}' from branch '{self.current_branch}' to '{base}") if self.current_branch == base: error_message = f"Cannot create a pull request from branch '{self.current_branch}' to itself ('{base}')." self.logger.warning(error_message)