diff --git a/readme.md b/readme.md index e3e9071..1b258dc 100644 --- a/readme.md +++ b/readme.md @@ -2,4 +2,9 @@ 1. Run `python setup_venv.py` to create the virtual environment and install dependencies. 2. To activate the virtual environment, run `activate_venv` in the Windows Terminal. -3. To deactivate the virtual environment, simply type `deactivate`. \ No newline at end of file +3. To deactivate the virtual environment, simply type `deactivate`. + + +## Running the code (Any) +1. Run telegram_inference_bot.py after entering the environment +2. now we're cooking with gas! \ No newline at end of file diff --git a/tools/github_tool.py b/tools/github_tool.py index f2e2ce4..a679bbd 100644 --- a/tools/github_tool.py +++ b/tools/github_tool.py @@ -3,6 +3,7 @@ from .base_tool import BaseTool import requests import os import base64 +import logging class GitHubTool(BaseTool): def __init__(self): @@ -13,6 +14,28 @@ class GitHubTool(BaseTool): "Accept": "application/vnd.github.v3+json" } self.repo = os.environ.get("GITHUB_REPOSITORY") + self.current_branch = "main" # Default to main branch + + # Set up logging + self.logger = logging.getLogger(__name__) + self.logger.setLevel(logging.INFO) + + # Create a file handler + file_handler = logging.FileHandler('github_tool.log') + file_handler.setLevel(logging.INFO) + + # Create a console handler + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + + # Create a formatting for the logs + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + file_handler.setFormatter(formatter) + console_handler.setFormatter(formatter) + + # Add the handlers to the logger + self.logger.addHandler(file_handler) + self.logger.addHandler(console_handler) def get_functions(self): return [ @@ -55,10 +78,6 @@ class GitHubTool(BaseTool): "parameters": { "type": "object", "properties": { - "branch_name": { - "type": "string", - "description": "Name of the branch to commit to" - }, "file_path": { "type": "string", "description": "Path to the file in the repository" @@ -72,7 +91,7 @@ class GitHubTool(BaseTool): "description": "Commit message" } }, - "required": ["branch_name", "file_path", "content", "commit_message"] + "required": ["file_path", "content", "commit_message"] } }, { @@ -89,17 +108,13 @@ class GitHubTool(BaseTool): "type": "string", "description": "Body of the pull request" }, - "head": { - "type": "string", - "description": "The name of the branch where your changes are implemented" - }, "base": { "type": "string", "description": "The name of the branch you want the changes pulled into", "default": "main" } }, - "required": ["title", "body", "head"] + "required": ["title", "body"] } }, { @@ -162,44 +177,119 @@ class GitHubTool(BaseTool): }, "required": ["branch"] } + }, + { + "name": "get_current_branch", + "description": "Get the name of the current branch", + "parameters": {} + }, + { + "name": "set_current_branch", + "description": "Set the current branch", + "parameters": { + "type": "object", + "properties": { + "branch_name": { + "type": "string", + "description": "Name of the branch to set as current" + } + }, + "required": ["branch_name"] + } + }, + { + "name": "get_file_at_commit", + "description": "Get the contents of a file at a specific commit", + "parameters": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "Path to the file in the repository" + }, + "commit_sha": { + "type": "string", + "description": "SHA of the commit to retrieve the file from" + } + }, + "required": ["file_path", "commit_sha"] + } + }, + { + "name": "list_branches", + "description": "List all branches in the repository", + "parameters": { + "type": "object", + "properties": { + "per_page": { + "type": "integer", + "description": "Number of branches to return per page (max 100)", + "default": 100 + }, + "all_pages": { + "type": "boolean", + "description": "Whether to fetch all pages of results", + "default": True + } + } + } } ] def execute(self, function_name, **kwargs): + self.logger.info(f"Executing: {function_name}") + if function_name == "read_file": return self._read_file(kwargs["path"]) elif function_name == "create_branch": return self._create_branch(kwargs["branch_name"], kwargs.get("base_branch", "main")) elif function_name == "commit_file": - return self._commit_file(kwargs["branch_name"], kwargs["file_path"], kwargs["content"], kwargs["commit_message"]) + return self._commit_file(kwargs["file_path"], kwargs["content"], kwargs["commit_message"]) elif function_name == "create_pull_request": - return self._create_pull_request(kwargs["title"], kwargs["body"], kwargs["head"], kwargs.get("base", "main")) + return self._create_pull_request(kwargs["title"], kwargs["body"], kwargs.get("base", "main")) elif function_name == "list_files": return self._list_files(kwargs["path"]) elif function_name == "search_code": return self._search_code(kwargs["query"]) elif function_name == "get_commit_history": return self._get_commit_history(kwargs["file_path"], kwargs.get("num_commits", 10)) + elif function_name == "get_current_branch": + return self._get_current_branch() + elif function_name == "set_current_branch": + return self._set_current_branch(kwargs["branch_name"]) + elif function_name == "get_file_at_commit": + return self._get_file_at_commit(kwargs["file_path"], kwargs["commit_sha"]) + elif function_name == "list_branches": + return self._list_branches(kwargs.get("per_page", 100), kwargs.get("all_pages", True)) elif function_name == "get_branch_sha": - return self._get_branch_sha(kwargs["branch"]) + return self._get_branch_sha(kwargs["branch"]) else: - return f"Unknown function: {function_name}" - + error_message = f"Unknown function: {function_name}" + self.logger.error(error_message) + return error_message def _read_file(self, path): + self.logger.info(f"Reading file: {path} from branch: {self.current_branch}") url = f"{self.base_url}/repos/{self.repo}/contents/{path}" - response = requests.get(url, headers=self.headers) + response = requests.get(url, headers=self.headers, params={"ref": self.current_branch}) if response.status_code == 200: content = response.json()["content"] - return content + decoded_content = base64.b64decode(content).decode('utf-8') + self.logger.info(f"Successfully read file: {path}") + return decoded_content else: - return f"Error reading file: {response.status_code}" + error_message = f"Error reading file: {response.status_code}" + self.logger.error(error_message) + return error_message def _create_branch(self, branch_name, base_branch): + self.logger.info(f"Creating branch: {branch_name} from base: {base_branch}") url = f"{self.base_url}/repos/{self.repo}/git/refs" response = requests.get(f"{url}/heads/{base_branch}", headers=self.headers) if response.status_code != 200: - return f"Error getting base branch: {response.status_code}" + error_message = f"Error getting base branch: {response.status_code}" + self.logger.error(error_message) + return error_message sha = response.json()["object"]["sha"] data = { @@ -208,51 +298,70 @@ class GitHubTool(BaseTool): } response = requests.post(url, headers=self.headers, json=data) if response.status_code == 201: - return f"Branch '{branch_name}' created successfully" + self.current_branch = branch_name + success_message = f"Branch '{branch_name}' created successfully and set as current branch" + self.logger.info(success_message) + return success_message else: - return f"Error creating branch: {response.status_code}" + error_message = f"Error creating branch: {response.status_code}" + self.logger.error(error_message) + return error_message - def _commit_file(self, branch_name, file_path, content, commit_message): - if branch_name == "main": - return "Cannot commit directly to main branch" + def _commit_file(self, file_path, content, commit_message): + self.logger.info(f"Committing file: {file_path} to branch: {self.current_branch}") + if self.current_branch == "main": + error_message = "Cannot commit directly to main branch" + self.logger.error(error_message) + return error_message url = f"{self.base_url}/repos/{self.repo}/contents/{file_path}" - # First, check if the file already exists - response = requests.get(url, headers=self.headers, params={"ref": branch_name}) + self.logger.info("Checking if file already exists") + response = requests.get(url, headers=self.headers, params={"ref": self.current_branch}) data = { "message": commit_message, "content": base64.b64encode(content.encode()).decode(), - "branch": branch_name + "branch": self.current_branch } if response.status_code == 200: - # File exists, so we need to update it + self.logger.info("File exists, updating") file_sha = response.json()["sha"] data["sha"] = file_sha + else: + self.logger.info("File does not exist, creating new file") response = requests.put(url, headers=self.headers, json=data) if response.status_code in [200, 201]: - return f"File committed successfully to branch '{branch_name}'" + success_message = f"File committed successfully to branch '{self.current_branch}'" + self.logger.info(success_message) + return success_message else: - return f"Error committing file: {response.status_code}\nResponse: {response.text}" + error_message = f"Error committing file: {response.status_code}\nResponse: {response.text}" + self.logger.error(error_message) + return error_message - def _create_pull_request(self, title, body, head, base): + def _create_pull_request(self, title, body, base): + self.logger.info(f"Creating pull request: {title} from {self.current_branch} to {base}") url = f"{self.base_url}/repos/{self.repo}/pulls" data = { "title": title, "body": body, - "head": head, + "head": self.current_branch, "base": base } response = requests.post(url, headers=self.headers, json=data) if response.status_code == 201: - return f"Pull request created successfully: {response.json()['html_url']}" + success_message = f"Pull request created successfully: {response.json()['html_url']}" + self.logger.info(success_message) + return success_message else: - return f"Error creating pull request: {response.status_code}\nResponse: {response.text}" - + error_message = f"Error creating pull request: {response.status_code}\nResponse: {response.text}" + self.logger.error(error_message) + return error_message + def _get_branch_sha(self, branch): url = f"{self.base_url}/repos/{self.repo}/git/refs/heads/{branch}" response = requests.get(url, headers=self.headers) @@ -261,18 +370,22 @@ class GitHubTool(BaseTool): else: return f"Error getting branch SHA: {response.status_code}" - def _list_files(self, path, branch): - url = f"{self.base_url}/repos/{self.repo}/contents/{path}" - params = {"ref": branch} - response = requests.get(url, headers=self.headers, params=params) - if response.status_code == 200: - files = [item["name"] for item in response.json() if item["type"] == "file"] - directories = [item["name"] for item in response.json() if item["type"] == "dir"] - return {"files": files, "directories": directories} - else: - return f"Error listing files: {response.status_code}" - + def _list_files(self, path): + self.logger.info(f"Listing files in: {path} on branch: {self.current_branch}") + url = f"{self.base_url}/repos/{self.repo}/contents/{path}" + response = requests.get(url, headers=self.headers, params={"ref": self.current_branch}) + if response.status_code == 200: + files = [item["name"] for item in response.json() if item["type"] == "file"] + directories = [item["name"] for item in response.json() if item["type"] == "dir"] + self.logger.info(f"Successfully listed files and directories in {path}") + return {"files": files, "directories": directories} + else: + error_message = f"Error listing files: {response.status_code}" + self.logger.error(error_message) + return error_message + def _search_code(self, query): + self.logger.info(f"Searching code with query: {query}") url = f"{self.base_url}/search/code" params = { "q": f"{query} repo:{self.repo}", @@ -281,11 +394,15 @@ class GitHubTool(BaseTool): response = requests.get(url, headers=self.headers, params=params) if response.status_code == 200: results = [{"file": item["path"], "url": item["html_url"]} for item in response.json()["items"]] + self.logger.info(f"Successfully searched code. Found {len(results)} results.") return results else: - return f"Error searching code: {response.status_code}" + error_message = f"Error searching code: {response.status_code}" + self.logger.error(error_message) + return error_message def _get_commit_history(self, file_path, num_commits): + self.logger.info(f"Getting commit history for file: {file_path}, number of commits: {num_commits}") url = f"{self.base_url}/repos/{self.repo}/commits" params = { "path": file_path, @@ -294,6 +411,61 @@ class GitHubTool(BaseTool): response = requests.get(url, headers=self.headers, params=params) if response.status_code == 200: commits = [{"sha": commit["sha"], "message": commit["commit"]["message"], "date": commit["commit"]["author"]["date"]} for commit in response.json()] + self.logger.info(f"Successfully retrieved commit history. Found {len(commits)} commits.") return commits else: - return f"Error getting commit history: {response.status_code}" \ No newline at end of file + error_message = f"Error getting commit history: {response.status_code}" + self.logger.error(error_message) + return error_message + + def _get_current_branch(self): + self.logger.info(f"Getting current branch: {self.current_branch}") + return self.current_branch + + def _set_current_branch(self, branch_name): + self.logger.info(f"Setting current branch from {self.current_branch} to {branch_name}") + self.current_branch = branch_name + return f"Current branch set to: {self.current_branch}" + + def _get_file_at_commit(self, file_path, commit_sha): + self.logger.info(f"Getting file: {file_path} at commit: {commit_sha}") + url = f"{self.base_url}/repos/{self.repo}/contents/{file_path}" + response = requests.get(url, headers=self.headers, params={"ref": commit_sha}) + if response.status_code == 200: + content = response.json()["content"] + decoded_content = base64.b64decode(content).decode('utf-8') + self.logger.info(f"Successfully retrieved file at commit") + return decoded_content + else: + error_message = f"Error reading file at commit: {response.status_code}" + self.logger.error(error_message) + return error_message + + def _list_branches(self, per_page=100, all_pages=True): + self.logger.info(f"Listing branches. Per page: {per_page}, All pages: {all_pages}") + url = f"{self.base_url}/repos/{self.repo}/branches" + params = {"per_page": min(per_page, 100)} # GitHub API max is 100 per page + all_branches = [] + + while url: + self.logger.info(f"Fetching branches from: {url}") + response = requests.get(url, headers=self.headers, params=params) + if response.status_code != 200: + error_message = f"Error listing branches: {response.status_code}" + self.logger.error(error_message) + return error_message + + branches = [branch["name"] for branch in response.json()] + all_branches.extend(branches) + self.logger.info(f"Fetched {len(branches)} branches") + + if not all_pages: + break + + # Check if there's a next page + url = response.links.get('next', {}).get('url') + if url: + params = {} # Remove per_page for subsequent requests + + self.logger.info(f"Successfully listed all branches. Total: {len(all_branches)}") + return all_branches