Recommit github_tool.py with missing pieces added back

This commit is contained in:
2024-08-20 18:54:57 -05:00
parent 34550d128b
commit ca2b0bd667
+108 -52
View File
@@ -1,3 +1,4 @@
# tools/github_tool.py
from .base_tool import BaseTool from .base_tool import BaseTool
from .metrics import metrics from .metrics import metrics
import requests import requests
@@ -560,75 +561,127 @@ class GitHubTool(BaseTool):
return error_message return error_message
@metrics.measure @metrics.measure
def _create_project_board(self, name, body=None): def _create_branch(self, branch_name, base_branch):
url = f"{self.base_url}/repos/{self.repo}/projects" self.logger.info(f"Creating branch: {branch_name} from base: {base_branch}")
data = {"name": name, "body": body} url = f"{self.base_url}/repos/{self.repo}/git/refs"
response = requests.get(f"{url}/heads/{base_branch}", headers=self.headers)
if response.status_code != 200:
error_message = f"Error getting base branch: {response.status_code}"
self.logger.error(error_message)
return error_message
sha = response.json()["object"]["sha"]
data = {
"ref": f"refs/heads/{branch_name}",
"sha": sha
}
response = requests.post(url, headers=self.headers, json=data) response = requests.post(url, headers=self.headers, json=data)
if response.status_code == 201: if response.status_code == 201:
project = response.json() self.current_branch = branch_name
success_message = f"Project board '{name}' created successfully." success_message = f"Branch '{branch_name}' created successfully and set as current branch"
self.logger.info(success_message) self.logger.info(success_message)
return success_message return success_message
else: else:
error_message = f"Error creating project board: {response.status_code}" error_message = f"Error creating branch: {response.status_code}"
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure @metrics.measure
def _create_project_column(self, project_id, column_name): def _commit_file(self, file_path, content, commit_message):
url = f"{self.base_url}/projects/{project_id}/columns" self.logger.info(f"Committing file: {file_path} to branch: {self.current_branch}")
data = {"name": column_name} if self.current_branch == "main":
response = requests.post(url, headers=self.headers, json=data) error_message = "Cannot commit directly to main branch"
if response.status_code == 201:
column = response.json()
success_message = f"Column '{column_name}' created successfully in project {project_id}."
self.logger.info(success_message)
return success_message
else:
error_message = f"Error creating project column: {response.status_code}"
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure url = f"{self.base_url}/repos/{self.repo}/contents/{file_path}"
def _create_project_card(self, column_id, note):
url = f"{self.base_url}/projects/columns/{column_id}/cards"
data = {"note": note}
response = requests.post(url, headers=self.headers, json=data)
if response.status_code == 201:
card = response.json()
success_message = f"Card created successfully in column {column_id}."
self.logger.info(success_message)
return success_message
else:
error_message = f"Error creating project card: {response.status_code}"
self.logger.error(error_message)
return error_message
@metrics.measure self.logger.info("Checking if file already exists")
def _move_project_card(self, card_id, position, column_id): response = requests.get(url, headers=self.headers, params={"ref": self.current_branch})
url = f"{self.base_url}/projects/columns/cards/{card_id}/moves"
data = {"position": position, "column_id": column_id} data = {
response = requests.post(url, headers=self.headers, json=data) "message": commit_message,
if response.status_code == 201: "content": base64.b64encode(content.encode()).decode(),
success_message = f"Card {card_id} moved successfully." "branch": self.current_branch
self.logger.info(success_message) }
return success_message
else:
error_message = f"Error moving project card: {response.status_code}"
self.logger.error(error_message)
return error_message
@metrics.measure
def _link_issue_to_project_card(self, card_id, content_id, content_type):
url = f"{self.base_url}/projects/columns/cards/{card_id}"
data = {"content_id": content_id, "content_type": content_type}
response = requests.patch(url, headers=self.headers, json=data)
if response.status_code == 200: if response.status_code == 200:
success_message = f"Issue/PR linked to card {card_id} successfully." self.logger.info("File exists, updating")
file_sha = response.json()["sha"]
data["sha"] = file_sha
else:
self.logger.info("File does not exist, creating new file")
response = requests.put(url, headers=self.headers, json=data)
if response.status_code in [200, 201]:
success_message = f"File committed successfully to branch '{self.current_branch}'"
self.logger.info(success_message) self.logger.info(success_message)
return success_message return success_message
else: else:
error_message = f"Error linking issue/PR to project card: {response.status_code}" error_message = f"Error committing file: {response.status_code}\nResponse: {response.text}"
self.logger.error(error_message)
return error_message
@metrics.measure
def _create_pull_request(self, title, body, base):
self.logger.info(f"Creating pull request: {title} from {self.current_branch} to {base}")
url = f"{self.base_url}/repos/{self.repo}/pulls"
data = {
"title": title,
"body": body,
"head": self.current_branch,
"base": base
}
response = requests.post(url, headers=self.headers, json=data)
if response.status_code == 201:
success_message = f"Pull request created successfully: {response.json()['html_url']}"
self.logger.info(success_message)
return success_message
else:
error_message = f"Error creating pull request: {response.status_code}\nResponse: {response.text}"
self.logger.error(error_message)
return error_message
@metrics.measure
def _get_branch_sha(self, branch):
url = f"{self.base_url}/repos/{self.repo}/git/refs/heads/{branch}"
response = requests.get(url, headers=self.headers)
if response.status_code == 200:
return response.json()["object"]["sha"]
else:
return f"Error getting branch SHA: {response.status_code}"
@metrics.measure
def _list_files(self, path):
self.logger.info(f"Listing files in: {path} on branch: {self.current_branch}")
url = f"{self.base_url}/repos/{self.repo}/contents/{path}"
response = requests.get(url, headers=self.headers, params={"ref": self.current_branch})
if response.status_code == 200:
files = [item["name"] for item in response.json() if item["type"] == "file"]
directories = [item["name"] for item in response.json() if item["type"] == "dir"]
self.logger.info(f"Successfully listed files and directories in {path}")
return {"files": files, "directories": directories}
else:
error_message = f"Error listing files: {response.status_code}"
self.logger.error(error_message)
return error_message
@metrics.measure
def _search_code(self, query):
self.logger.info(f"Searching code with query: {query}")
url = f"{self.base_url}/search/code"
params = {
"q": f"{query} repo:{self.repo}",
"per_page": 10
}
response = requests.get(url, headers=self.headers, params=params)
if response.status_code == 200:
results = [{"file": item["path"], "url": item["html_url"]} for item in response.json()["items"]]
self.logger.info(f"Successfully searched code. Found {len(results)} results.")
return results
else:
error_message = f"Error searching code: {response.status_code}"
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@@ -636,7 +689,10 @@ class GitHubTool(BaseTool):
def _get_commit_history(self, file_path, num_commits): def _get_commit_history(self, file_path, num_commits):
self.logger.info(f"Getting commit history for file: {file_path}, number of commits: {num_commits}") self.logger.info(f"Getting commit history for file: {file_path}, number of commits: {num_commits}")
url = f"{self.base_url}/repos/{self.repo}/commits" url = f"{self.base_url}/repos/{self.repo}/commits"
params = {"path": file_path, "per_page": num_commits} params = {
"path": file_path,
"per_page": num_commits
}
response = requests.get(url, headers=self.headers, params=params) response = requests.get(url, headers=self.headers, params=params)
if response.status_code == 200: if response.status_code == 200:
commits = [{"sha": commit["sha"], "message": commit["commit"]["message"], "date": commit["commit"]["author"]["date"]} for commit in response.json()] commits = [{"sha": commit["sha"], "message": commit["commit"]["message"], "date": commit["commit"]["author"]["date"]} for commit in response.json()]