From 7271c5da8475061f90a2b0afc8373eff3727b5ed Mon Sep 17 00:00:00 2001 From: bucolucas Date: Mon, 19 Aug 2024 16:01:24 -0500 Subject: [PATCH] Refactor github_tool.py to use consistent JSON definitions --- tools/github_tool.py | 86 ++++++++++++-------------------------------- 1 file changed, 23 insertions(+), 63 deletions(-) diff --git a/tools/github_tool.py b/tools/github_tool.py index ac38afa..4cbb8f8 100644 --- a/tools/github_tool.py +++ b/tools/github_tool.py @@ -1,23 +1,5 @@ -from .github_tool_functions.read_file import ReadFile, read_file_definition -from .github_tool_functions.create_branch import CreateBranch, create_branch_definition -from .github_tool_functions.commit_file import CommitFile, commit_file_definition -from .github_tool_functions.create_pull_request import CreatePullRequest, create_pull_request_definition -from .github_tool_functions.list_files import ListFiles, list_files_definition -from .github_tool_functions.search_code import SearchCode, search_code_definition -from .github_tool_functions.get_commit_history import GetCommitHistory, get_commit_history_definition -from .github_tool_functions.get_current_branch import GetCurrentBranch, get_current_branch_definition -from .github_tool_functions.set_current_branch import SetCurrentBranch, set_current_branch_definition -from .github_tool_functions.get_file_at_commit import GetFileAtCommit, get_file_at_commit_definition -from .github_tool_functions.list_branches import ListBranches, list_branches_definition -from .github_tool_functions.get_branch_sha import GetBranchSHA, get_branch_sha_definition -from .github_tool_functions.approve_pull_request import ApprovePullRequest, approve_pull_request_definition -from .github_tool_functions.close_pull_request import ClosePullRequest, close_pull_request_definition -from .github_tool_functions.merge_pull_request import MergePullRequest, merge_pull_request_definition -from .github_tool_functions.delete_branch import DeleteBranch, delete_branch_definition -from .github_tool_functions.get_issue_details import GetIssueDetails, get_issue_details_definition -from .github_tool_functions.create_issue import CreateIssue, create_issue_definition -from .github_tool_functions.list_issues import ListIssues, list_issues_definition - +import os +import json class GitHubTool: def __init__(self): @@ -26,27 +8,26 @@ class GitHubTool: self.repo = os.environ.get("GITHUB_REPOSITORY") self.current_branch = "main" # Default to main branch - self.instances = { - "read_file": ReadFile(self.base_url, self.token, self.repo, self.current_branch), - "create_branch": CreateBranch(self.base_url, self.token, self.repo, self.current_branch), - "commit_file": CommitFile(self.base_url, self.token, self.repo, self.current_branch), - "create_pull_request": CreatePullRequest(self.base_url, self.token, self.repo, self.current_branch), - "list_files": ListFiles(self.base_url, self.token, self.repo, self.current_branch), - "search_code": SearchCode(self.base_url, self.token, self.repo), - "get_commit_history": GetCommitHistory(self.base_url, self.token, self.repo), - "get_current_branch": GetCurrentBranch(self.current_branch), - "set_current_branch": SetCurrentBranch(self.current_branch), - "get_file_at_commit": GetFileAtCommit(self.base_url, self.token, self.repo), - "list_branches": ListBranches(self.base_url, self.token, self.repo), - "get_branch_sha": GetBranchSHA(self.base_url, self.token, self.repo), - "approve_pull_request": ApprovePullRequest(self.base_url, self.token, self.repo), - "close_pull_request": ClosePullRequest(self.base_url, self.token, self.repo), - "merge_pull_request": MergePullRequest(self.base_url, self.token, self.repo), - "delete_branch": DeleteBranch(self.base_url, self.token, self.repo), - "get_issue_details": GetIssueDetails(self.base_url, self.token, self.repo), - "create_issue": CreateIssue(self.base_url, self.token, self.repo), - "list_issues": ListIssues(self.base_url, self.token, self.repo) - } + self.instances = {} + self.functions = [] + self._load_functions() + + def _load_functions(self): + function_dir = os.path.join(os.path.dirname(__file__), "github_tool_functions") + for filename in os.listdir(function_dir): + if filename.endswith(".py") and filename != "__init__.py": + function_name = filename[:-3] + module = __import__(f"tools.github_tool_functions.{function_name}", fromlist=[function_name.capitalize()]) + class_name = getattr(module, function_name.capitalize()) + self.instances[function_name] = class_name(self.base_url, self.token, self.repo, self.current_branch) + + with open(os.path.join(function_dir, filename), 'r') as f: + content = f.read() + json_start = content.find('{') + json_end = content.rfind('}') + 1 + if json_start != -1 and json_end != -1: + json_def = json.loads(content[json_start:json_end]) + self.functions.append(json_def) def execute(self, function_name, **kwargs): if function_name in self.instances: @@ -55,26 +36,5 @@ class GitHubTool: error_message = f"Unknown function: {function_name}" return {"error": error_message} - def get_functions(self): - return [ - read_file_definition, - create_branch_definition, - commit_file_definition, - create_pull_request_definition, - list_files_definition, - search_code_definition, - get_commit_history_definition, - get_current_branch_definition, - set_current_branch_definition, - get_file_at_commit_definition, - list_branches_definition, - get_branch_sha_definition, - approve_pull_request_definition, - close_pull_request_definition, - merge_pull_request_definition, - delete_branch_definition, - get_issue_details_definition, - create_issue_definition, - list_issues_definition - ] \ No newline at end of file + return self.functions \ No newline at end of file