Added features to github tool

This commit is contained in:
2024-08-17 11:30:34 -05:00
parent df0e6ad9ad
commit fd549c203e
2 changed files with 226 additions and 49 deletions
+5
View File
@@ -3,3 +3,8 @@
1. Run `python setup_venv.py` to create the virtual environment and install dependencies. 1. Run `python setup_venv.py` to create the virtual environment and install dependencies.
2. To activate the virtual environment, run `activate_venv` in the Windows Terminal. 2. To activate the virtual environment, run `activate_venv` in the Windows Terminal.
3. To deactivate the virtual environment, simply type `deactivate`. 3. To deactivate the virtual environment, simply type `deactivate`.
## Running the code (Any)
1. Run telegram_inference_bot.py after entering the environment
2. now we're cooking with gas!
+211 -39
View File
@@ -3,6 +3,7 @@ from .base_tool import BaseTool
import requests import requests
import os import os
import base64 import base64
import logging
class GitHubTool(BaseTool): class GitHubTool(BaseTool):
def __init__(self): def __init__(self):
@@ -13,6 +14,28 @@ class GitHubTool(BaseTool):
"Accept": "application/vnd.github.v3+json" "Accept": "application/vnd.github.v3+json"
} }
self.repo = os.environ.get("GITHUB_REPOSITORY") self.repo = os.environ.get("GITHUB_REPOSITORY")
self.current_branch = "main" # Default to main branch
# Set up logging
self.logger = logging.getLogger(__name__)
self.logger.setLevel(logging.INFO)
# Create a file handler
file_handler = logging.FileHandler('github_tool.log')
file_handler.setLevel(logging.INFO)
# Create a console handler
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
# Create a formatting for the logs
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
console_handler.setFormatter(formatter)
# Add the handlers to the logger
self.logger.addHandler(file_handler)
self.logger.addHandler(console_handler)
def get_functions(self): def get_functions(self):
return [ return [
@@ -55,10 +78,6 @@ class GitHubTool(BaseTool):
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"branch_name": {
"type": "string",
"description": "Name of the branch to commit to"
},
"file_path": { "file_path": {
"type": "string", "type": "string",
"description": "Path to the file in the repository" "description": "Path to the file in the repository"
@@ -72,7 +91,7 @@ class GitHubTool(BaseTool):
"description": "Commit message" "description": "Commit message"
} }
}, },
"required": ["branch_name", "file_path", "content", "commit_message"] "required": ["file_path", "content", "commit_message"]
} }
}, },
{ {
@@ -89,17 +108,13 @@ class GitHubTool(BaseTool):
"type": "string", "type": "string",
"description": "Body of the pull request" "description": "Body of the pull request"
}, },
"head": {
"type": "string",
"description": "The name of the branch where your changes are implemented"
},
"base": { "base": {
"type": "string", "type": "string",
"description": "The name of the branch you want the changes pulled into", "description": "The name of the branch you want the changes pulled into",
"default": "main" "default": "main"
} }
}, },
"required": ["title", "body", "head"] "required": ["title", "body"]
} }
}, },
{ {
@@ -162,44 +177,119 @@ class GitHubTool(BaseTool):
}, },
"required": ["branch"] "required": ["branch"]
} }
},
{
"name": "get_current_branch",
"description": "Get the name of the current branch",
"parameters": {}
},
{
"name": "set_current_branch",
"description": "Set the current branch",
"parameters": {
"type": "object",
"properties": {
"branch_name": {
"type": "string",
"description": "Name of the branch to set as current"
}
},
"required": ["branch_name"]
}
},
{
"name": "get_file_at_commit",
"description": "Get the contents of a file at a specific commit",
"parameters": {
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "Path to the file in the repository"
},
"commit_sha": {
"type": "string",
"description": "SHA of the commit to retrieve the file from"
}
},
"required": ["file_path", "commit_sha"]
}
},
{
"name": "list_branches",
"description": "List all branches in the repository",
"parameters": {
"type": "object",
"properties": {
"per_page": {
"type": "integer",
"description": "Number of branches to return per page (max 100)",
"default": 100
},
"all_pages": {
"type": "boolean",
"description": "Whether to fetch all pages of results",
"default": True
}
}
}
} }
] ]
def execute(self, function_name, **kwargs): def execute(self, function_name, **kwargs):
self.logger.info(f"Executing: {function_name}")
if function_name == "read_file": if function_name == "read_file":
return self._read_file(kwargs["path"]) return self._read_file(kwargs["path"])
elif function_name == "create_branch": elif function_name == "create_branch":
return self._create_branch(kwargs["branch_name"], kwargs.get("base_branch", "main")) return self._create_branch(kwargs["branch_name"], kwargs.get("base_branch", "main"))
elif function_name == "commit_file": elif function_name == "commit_file":
return self._commit_file(kwargs["branch_name"], kwargs["file_path"], kwargs["content"], kwargs["commit_message"]) return self._commit_file(kwargs["file_path"], kwargs["content"], kwargs["commit_message"])
elif function_name == "create_pull_request": elif function_name == "create_pull_request":
return self._create_pull_request(kwargs["title"], kwargs["body"], kwargs["head"], kwargs.get("base", "main")) return self._create_pull_request(kwargs["title"], kwargs["body"], kwargs.get("base", "main"))
elif function_name == "list_files": elif function_name == "list_files":
return self._list_files(kwargs["path"]) return self._list_files(kwargs["path"])
elif function_name == "search_code": elif function_name == "search_code":
return self._search_code(kwargs["query"]) return self._search_code(kwargs["query"])
elif function_name == "get_commit_history": elif function_name == "get_commit_history":
return self._get_commit_history(kwargs["file_path"], kwargs.get("num_commits", 10)) return self._get_commit_history(kwargs["file_path"], kwargs.get("num_commits", 10))
elif function_name == "get_current_branch":
return self._get_current_branch()
elif function_name == "set_current_branch":
return self._set_current_branch(kwargs["branch_name"])
elif function_name == "get_file_at_commit":
return self._get_file_at_commit(kwargs["file_path"], kwargs["commit_sha"])
elif function_name == "list_branches":
return self._list_branches(kwargs.get("per_page", 100), kwargs.get("all_pages", True))
elif function_name == "get_branch_sha": elif function_name == "get_branch_sha":
return self._get_branch_sha(kwargs["branch"]) return self._get_branch_sha(kwargs["branch"])
else: else:
return f"Unknown function: {function_name}" error_message = f"Unknown function: {function_name}"
self.logger.error(error_message)
return error_message
def _read_file(self, path): def _read_file(self, path):
self.logger.info(f"Reading file: {path} from branch: {self.current_branch}")
url = f"{self.base_url}/repos/{self.repo}/contents/{path}" url = f"{self.base_url}/repos/{self.repo}/contents/{path}"
response = requests.get(url, headers=self.headers) response = requests.get(url, headers=self.headers, params={"ref": self.current_branch})
if response.status_code == 200: if response.status_code == 200:
content = response.json()["content"] content = response.json()["content"]
return content decoded_content = base64.b64decode(content).decode('utf-8')
self.logger.info(f"Successfully read file: {path}")
return decoded_content
else: else:
return f"Error reading file: {response.status_code}" error_message = f"Error reading file: {response.status_code}"
self.logger.error(error_message)
return error_message
def _create_branch(self, branch_name, base_branch): def _create_branch(self, branch_name, base_branch):
self.logger.info(f"Creating branch: {branch_name} from base: {base_branch}")
url = f"{self.base_url}/repos/{self.repo}/git/refs" url = f"{self.base_url}/repos/{self.repo}/git/refs"
response = requests.get(f"{url}/heads/{base_branch}", headers=self.headers) response = requests.get(f"{url}/heads/{base_branch}", headers=self.headers)
if response.status_code != 200: if response.status_code != 200:
return f"Error getting base branch: {response.status_code}" error_message = f"Error getting base branch: {response.status_code}"
self.logger.error(error_message)
return error_message
sha = response.json()["object"]["sha"] sha = response.json()["object"]["sha"]
data = { data = {
@@ -208,50 +298,69 @@ class GitHubTool(BaseTool):
} }
response = requests.post(url, headers=self.headers, json=data) response = requests.post(url, headers=self.headers, json=data)
if response.status_code == 201: if response.status_code == 201:
return f"Branch '{branch_name}' created successfully" self.current_branch = branch_name
success_message = f"Branch '{branch_name}' created successfully and set as current branch"
self.logger.info(success_message)
return success_message
else: else:
return f"Error creating branch: {response.status_code}" error_message = f"Error creating branch: {response.status_code}"
self.logger.error(error_message)
return error_message
def _commit_file(self, branch_name, file_path, content, commit_message): def _commit_file(self, file_path, content, commit_message):
if branch_name == "main": self.logger.info(f"Committing file: {file_path} to branch: {self.current_branch}")
return "Cannot commit directly to main branch" if self.current_branch == "main":
error_message = "Cannot commit directly to main branch"
self.logger.error(error_message)
return error_message
url = f"{self.base_url}/repos/{self.repo}/contents/{file_path}" url = f"{self.base_url}/repos/{self.repo}/contents/{file_path}"
# First, check if the file already exists self.logger.info("Checking if file already exists")
response = requests.get(url, headers=self.headers, params={"ref": branch_name}) response = requests.get(url, headers=self.headers, params={"ref": self.current_branch})
data = { data = {
"message": commit_message, "message": commit_message,
"content": base64.b64encode(content.encode()).decode(), "content": base64.b64encode(content.encode()).decode(),
"branch": branch_name "branch": self.current_branch
} }
if response.status_code == 200: if response.status_code == 200:
# File exists, so we need to update it self.logger.info("File exists, updating")
file_sha = response.json()["sha"] file_sha = response.json()["sha"]
data["sha"] = file_sha data["sha"] = file_sha
else:
self.logger.info("File does not exist, creating new file")
response = requests.put(url, headers=self.headers, json=data) response = requests.put(url, headers=self.headers, json=data)
if response.status_code in [200, 201]: if response.status_code in [200, 201]:
return f"File committed successfully to branch '{branch_name}'" success_message = f"File committed successfully to branch '{self.current_branch}'"
self.logger.info(success_message)
return success_message
else: else:
return f"Error committing file: {response.status_code}\nResponse: {response.text}" error_message = f"Error committing file: {response.status_code}\nResponse: {response.text}"
self.logger.error(error_message)
return error_message
def _create_pull_request(self, title, body, head, base): def _create_pull_request(self, title, body, base):
self.logger.info(f"Creating pull request: {title} from {self.current_branch} to {base}")
url = f"{self.base_url}/repos/{self.repo}/pulls" url = f"{self.base_url}/repos/{self.repo}/pulls"
data = { data = {
"title": title, "title": title,
"body": body, "body": body,
"head": head, "head": self.current_branch,
"base": base "base": base
} }
response = requests.post(url, headers=self.headers, json=data) response = requests.post(url, headers=self.headers, json=data)
if response.status_code == 201: if response.status_code == 201:
return f"Pull request created successfully: {response.json()['html_url']}" success_message = f"Pull request created successfully: {response.json()['html_url']}"
self.logger.info(success_message)
return success_message
else: else:
return f"Error creating pull request: {response.status_code}\nResponse: {response.text}" error_message = f"Error creating pull request: {response.status_code}\nResponse: {response.text}"
self.logger.error(error_message)
return error_message
def _get_branch_sha(self, branch): def _get_branch_sha(self, branch):
url = f"{self.base_url}/repos/{self.repo}/git/refs/heads/{branch}" url = f"{self.base_url}/repos/{self.repo}/git/refs/heads/{branch}"
@@ -261,18 +370,22 @@ class GitHubTool(BaseTool):
else: else:
return f"Error getting branch SHA: {response.status_code}" return f"Error getting branch SHA: {response.status_code}"
def _list_files(self, path, branch): def _list_files(self, path):
self.logger.info(f"Listing files in: {path} on branch: {self.current_branch}")
url = f"{self.base_url}/repos/{self.repo}/contents/{path}" url = f"{self.base_url}/repos/{self.repo}/contents/{path}"
params = {"ref": branch} response = requests.get(url, headers=self.headers, params={"ref": self.current_branch})
response = requests.get(url, headers=self.headers, params=params)
if response.status_code == 200: if response.status_code == 200:
files = [item["name"] for item in response.json() if item["type"] == "file"] files = [item["name"] for item in response.json() if item["type"] == "file"]
directories = [item["name"] for item in response.json() if item["type"] == "dir"] directories = [item["name"] for item in response.json() if item["type"] == "dir"]
self.logger.info(f"Successfully listed files and directories in {path}")
return {"files": files, "directories": directories} return {"files": files, "directories": directories}
else: else:
return f"Error listing files: {response.status_code}" error_message = f"Error listing files: {response.status_code}"
self.logger.error(error_message)
return error_message
def _search_code(self, query): def _search_code(self, query):
self.logger.info(f"Searching code with query: {query}")
url = f"{self.base_url}/search/code" url = f"{self.base_url}/search/code"
params = { params = {
"q": f"{query} repo:{self.repo}", "q": f"{query} repo:{self.repo}",
@@ -281,11 +394,15 @@ class GitHubTool(BaseTool):
response = requests.get(url, headers=self.headers, params=params) response = requests.get(url, headers=self.headers, params=params)
if response.status_code == 200: if response.status_code == 200:
results = [{"file": item["path"], "url": item["html_url"]} for item in response.json()["items"]] results = [{"file": item["path"], "url": item["html_url"]} for item in response.json()["items"]]
self.logger.info(f"Successfully searched code. Found {len(results)} results.")
return results return results
else: else:
return f"Error searching code: {response.status_code}" error_message = f"Error searching code: {response.status_code}"
self.logger.error(error_message)
return error_message
def _get_commit_history(self, file_path, num_commits): def _get_commit_history(self, file_path, num_commits):
self.logger.info(f"Getting commit history for file: {file_path}, number of commits: {num_commits}")
url = f"{self.base_url}/repos/{self.repo}/commits" url = f"{self.base_url}/repos/{self.repo}/commits"
params = { params = {
"path": file_path, "path": file_path,
@@ -294,6 +411,61 @@ class GitHubTool(BaseTool):
response = requests.get(url, headers=self.headers, params=params) response = requests.get(url, headers=self.headers, params=params)
if response.status_code == 200: if response.status_code == 200:
commits = [{"sha": commit["sha"], "message": commit["commit"]["message"], "date": commit["commit"]["author"]["date"]} for commit in response.json()] commits = [{"sha": commit["sha"], "message": commit["commit"]["message"], "date": commit["commit"]["author"]["date"]} for commit in response.json()]
self.logger.info(f"Successfully retrieved commit history. Found {len(commits)} commits.")
return commits return commits
else: else:
return f"Error getting commit history: {response.status_code}" error_message = f"Error getting commit history: {response.status_code}"
self.logger.error(error_message)
return error_message
def _get_current_branch(self):
self.logger.info(f"Getting current branch: {self.current_branch}")
return self.current_branch
def _set_current_branch(self, branch_name):
self.logger.info(f"Setting current branch from {self.current_branch} to {branch_name}")
self.current_branch = branch_name
return f"Current branch set to: {self.current_branch}"
def _get_file_at_commit(self, file_path, commit_sha):
self.logger.info(f"Getting file: {file_path} at commit: {commit_sha}")
url = f"{self.base_url}/repos/{self.repo}/contents/{file_path}"
response = requests.get(url, headers=self.headers, params={"ref": commit_sha})
if response.status_code == 200:
content = response.json()["content"]
decoded_content = base64.b64decode(content).decode('utf-8')
self.logger.info(f"Successfully retrieved file at commit")
return decoded_content
else:
error_message = f"Error reading file at commit: {response.status_code}"
self.logger.error(error_message)
return error_message
def _list_branches(self, per_page=100, all_pages=True):
self.logger.info(f"Listing branches. Per page: {per_page}, All pages: {all_pages}")
url = f"{self.base_url}/repos/{self.repo}/branches"
params = {"per_page": min(per_page, 100)} # GitHub API max is 100 per page
all_branches = []
while url:
self.logger.info(f"Fetching branches from: {url}")
response = requests.get(url, headers=self.headers, params=params)
if response.status_code != 200:
error_message = f"Error listing branches: {response.status_code}"
self.logger.error(error_message)
return error_message
branches = [branch["name"] for branch in response.json()]
all_branches.extend(branches)
self.logger.info(f"Fetched {len(branches)} branches")
if not all_pages:
break
# Check if there's a next page
url = response.links.get('next', {}).get('url')
if url:
params = {} # Remove per_page for subsequent requests
self.logger.info(f"Successfully listed all branches. Total: {len(all_branches)}")
return all_branches