# tools/github_tool.py

import base64
import os
from urllib.parse import quote

import requests

from .base_tool import BaseTool


class GitHubTool(BaseTool):
    """Expose a small set of GitHub REST API operations as agent-callable functions.

    Configuration is read from the environment:

    * ``GITHUB_TOKEN`` -- access token sent as ``Authorization: token <value>``
      (the header is omitted when the variable is unset).
    * ``GITHUB_REPOSITORY`` -- target repository in ``owner/name`` form.

    HTTP-level failures are reported by returning an ``"Error ..."`` string
    rather than raising. Network-level failures (connection errors, timeouts)
    propagate as ``requests.RequestException``.
    """

    # Seconds to wait for GitHub before giving up. Without a timeout, a stalled
    # connection would block the calling agent indefinitely.
    DEFAULT_TIMEOUT = 30

    def __init__(self):
        self.base_url = "https://api.github.com"
        self.token = os.environ.get("GITHUB_TOKEN")
        self.headers = {"Accept": "application/vnd.github.v3+json"}
        # Only authenticate when a token is configured. Sending the literal
        # "token None" is treated as bad credentials, even for repositories
        # that could be read anonymously.
        if self.token:
            self.headers["Authorization"] = f"token {self.token}"
        self.repo = os.environ.get("GITHUB_REPOSITORY")
        self.timeout = self.DEFAULT_TIMEOUT

    def get_functions(self):
        """Return the JSON-schema descriptions of the functions this tool provides."""
        return [
            {
                "name": "read_file",
                "description": "Read a file from the repository",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "path": {
                            "type": "string",
                            "description": "Path to the file in the repository"
                        },
                        "branch": {
                            "type": "string",
                            "description": "Branch, tag, or commit SHA to read from (defaults to the repository's default branch)"
                        }
                    },
                    "required": ["path"]
                }
            },
            {
                "name": "create_branch",
                "description": "Create a new branch in the repository",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "branch_name": {
                            "type": "string",
                            "description": "Name of the new branch"
                        },
                        "base_branch": {
                            "type": "string",
                            "description": "Name of the base branch",
                            "default": "main"
                        }
                    },
                    "required": ["branch_name"]
                }
            },
            {
                "name": "commit_file",
                "description": "Commit a file to a branch (not main)",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "branch_name": {
                            "type": "string",
                            "description": "Name of the branch to commit to"
                        },
                        "file_path": {
                            "type": "string",
                            "description": "Path to the file in the repository"
                        },
                        "content": {
                            "type": "string",
                            "description": "Content of the file"
                        },
                        "commit_message": {
                            "type": "string",
                            "description": "Commit message"
                        }
                    },
                    "required": ["branch_name", "file_path", "content", "commit_message"]
                }
            },
            {
                "name": "create_pull_request",
                "description": "Create a pull request",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "title": {
                            "type": "string",
                            "description": "Title of the pull request"
                        },
                        "body": {
                            "type": "string",
                            "description": "Body of the pull request"
                        },
                        "head": {
                            "type": "string",
                            "description": "The name of the branch where your changes are implemented"
                        },
                        "base": {
                            "type": "string",
                            "description": "The name of the branch you want the changes pulled into",
                            "default": "main"
                        }
                    },
                    "required": ["title", "body", "head"]
                }
            },
            {
                "name": "list_files",
                "description": "List files in a directory of the repository",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "path": {
                            "type": "string",
                            "description": "Path to the directory in the repository"
                        },
                        "branch": {
                            "type": "string",
                            "description": "Branch, tag, or commit SHA to list from (defaults to the repository's default branch)"
                        }
                    },
                    "required": ["path"]
                }
            },
            {
                "name": "search_code",
                "description": "Search for code in the repository",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "Search query"
                        }
                    },
                    "required": ["query"]
                }
            },
            {
                "name": "get_commit_history",
                "description": "Get commit history for a file",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "file_path": {
                            "type": "string",
                            "description": "Path to the file in the repository"
                        },
                        "num_commits": {
                            "type": "integer",
                            "description": "Number of commits to retrieve",
                            "default": 10
                        }
                    },
                    "required": ["file_path"]
                }
            },
            {
                "name": "get_branch_sha",
                "description": "Get the SHA of the latest commit on a branch",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "branch": {
                            "type": "string",
                            "description": "Name of the branch"
                        }
                    },
                    "required": ["branch"]
                }
            }
        ]

    def execute(self, function_name, **kwargs):
        """Run the tool function named *function_name* with *kwargs*.

        Returns the function's result. Returns an error string for an unknown
        function name or when ``GITHUB_REPOSITORY`` is unset. A missing
        required argument raises ``KeyError``.
        """
        handlers = {
            "read_file": lambda: self._read_file(kwargs["path"], kwargs.get("branch")),
            "create_branch": lambda: self._create_branch(
                kwargs["branch_name"], kwargs.get("base_branch", "main")
            ),
            "commit_file": lambda: self._commit_file(
                kwargs["branch_name"], kwargs["file_path"], kwargs["content"], kwargs["commit_message"]
            ),
            "create_pull_request": lambda: self._create_pull_request(
                kwargs["title"], kwargs["body"], kwargs["head"], kwargs.get("base", "main")
            ),
            "list_files": lambda: self._list_files(kwargs["path"], kwargs.get("branch")),
            "search_code": lambda: self._search_code(kwargs["query"]),
            "get_commit_history": lambda: self._get_commit_history(
                kwargs["file_path"], kwargs.get("num_commits", 10)
            ),
            "get_branch_sha": lambda: self._get_branch_sha(kwargs["branch"]),
        }
        handler = handlers.get(function_name)
        if handler is None:
            return f"Unknown function: {function_name}"
        if not self.repo:
            # Without a repository, every URL would contain "repos/None/...".
            return "Error: GITHUB_REPOSITORY environment variable is not set"
        return handler()

    # ------------------------------------------------------------------ helpers

    def _request(self, method, url, **kwargs):
        """Send *method* to *url* with the tool's headers and default timeout.

        Extra keyword arguments (``params``, ``json``, ...) are forwarded to
        ``requests.request``. ``params`` entries whose value is None are
        dropped by requests, so an optional ``ref`` can be passed as-is.
        """
        kwargs.setdefault("timeout", self.timeout)
        return requests.request(method, url, headers=self.headers, **kwargs)

    def _contents_url(self, path):
        """Return the repository-contents API URL for *path*.

        Surrounding slashes are trimmed. Reserved characters such as ``#`` or
        ``?`` are percent-encoded so they stay part of the path instead of
        starting a fragment or query string. An empty path addresses the
        repository root.
        """
        quoted = quote(path.strip("/"))
        url = f"{self.base_url}/repos/{self.repo}/contents"
        return f"{url}/{quoted}" if quoted else url

    def _branch_ref_url(self, branch):
        """Return the single-reference API URL for branch *branch*.

        Uses the singular ``git/ref`` endpoint, which answers 404 unless a
        branch matches exactly. The plural ``git/refs/heads/<name>`` form may
        instead return a JSON list of prefix matches, which cannot be indexed
        as a single reference object.
        """
        return f"{self.base_url}/repos/{self.repo}/git/ref/heads/{quote(branch)}"

    # --------------------------------------------------------------- operations

    def _read_file(self, path, branch=None):
        """Return the decoded text of the file at *path*, or an error string.

        *branch* selects a branch, tag, or commit; None means the repository's
        default branch. Bytes that are not valid UTF-8 are replaced with U+FFFD.
        """
        response = self._request("GET", self._contents_url(path), params={"ref": branch})
        if response.status_code != 200:
            return f"Error reading file: {response.status_code}"
        payload = response.json()
        # A directory path yields a JSON list of entries, not a file object.
        if not isinstance(payload, dict) or payload.get("type") != "file":
            return f"Error reading file: '{path}' is not a file"
        # Large files are returned without inline content (encoding "none").
        if payload.get("encoding") != "base64":
            return (
                f"Error reading file: content of '{path}' is not available inline "
                f"(encoding: {payload.get('encoding')})"
            )
        # The API wraps the base64 payload in newlines. b64decode's default
        # non-validating mode discards them.
        return base64.b64decode(payload["content"]).decode("utf-8", errors="replace")

    def _create_branch(self, branch_name, base_branch):
        """Create *branch_name* pointing at the current head of *base_branch*."""
        response = self._request("GET", self._branch_ref_url(base_branch))
        if response.status_code != 200:
            return f"Error getting base branch: {response.status_code}"

        sha = response.json()["object"]["sha"]
        data = {"ref": f"refs/heads/{branch_name}", "sha": sha}
        response = self._request("POST", f"{self.base_url}/repos/{self.repo}/git/refs", json=data)
        if response.status_code == 201:
            return f"Branch '{branch_name}' created successfully"
        return f"Error creating branch: {response.status_code}"

    def _commit_file(self, branch_name, file_path, content, commit_message):
        """Create or update *file_path* on *branch_name* with text *content*.

        Commits to the branch named exactly "main" are refused.
        """
        if branch_name == "main":
            return "Cannot commit directly to main branch"

        url = self._contents_url(file_path)

        # Updating an existing file requires its current blob SHA, so first
        # check whether the file already exists on the target branch.
        existing = self._request("GET", url, params={"ref": branch_name})

        data = {
            "message": commit_message,
            "content": base64.b64encode(content.encode("utf-8")).decode("ascii"),
            "branch": branch_name,
        }

        if existing.status_code == 200:
            info = existing.json()
            if not isinstance(info, dict):
                # A list response means the path names a directory.
                return f"Error committing file: '{file_path}' is a directory"
            data["sha"] = info["sha"]

        response = self._request("PUT", url, json=data)
        if response.status_code in (200, 201):
            return f"File committed successfully to branch '{branch_name}'"
        return f"Error committing file: {response.status_code}\nResponse: {response.text}"

    def _create_pull_request(self, title, body, head, base):
        """Open a pull request from *head* into *base* and return its URL message."""
        data = {"title": title, "body": body, "head": head, "base": base}
        response = self._request("POST", f"{self.base_url}/repos/{self.repo}/pulls", json=data)
        if response.status_code == 201:
            return f"Pull request created successfully: {response.json()['html_url']}"
        return f"Error creating pull request: {response.status_code}\nResponse: {response.text}"

    def _get_branch_sha(self, branch):
        """Return the commit SHA at the head of *branch*, or an error string."""
        response = self._request("GET", self._branch_ref_url(branch))
        if response.status_code == 200:
            return response.json()["object"]["sha"]
        return f"Error getting branch SHA: {response.status_code}"

    def _list_files(self, path, branch=None):
        """List the directory at *path* as ``{"files": [...], "directories": [...]}``.

        *branch* selects a branch, tag, or commit; None means the repository's
        default branch. Entries that are neither files nor directories
        (e.g. submodules) are omitted.
        """
        response = self._request("GET", self._contents_url(path), params={"ref": branch})
        if response.status_code != 200:
            return f"Error listing files: {response.status_code}"
        entries = response.json()
        # A file path yields a single object rather than a list of entries.
        if not isinstance(entries, list):
            return f"Error listing files: '{path}' is not a directory"
        files = [entry["name"] for entry in entries if entry["type"] == "file"]
        directories = [entry["name"] for entry in entries if entry["type"] == "dir"]
        return {"files": files, "directories": directories}

    def _search_code(self, query):
        """Search this repository's code. Returns up to 10 ``{"file", "url"}`` hits."""
        params = {"q": f"{query} repo:{self.repo}", "per_page": 10}
        response = self._request("GET", f"{self.base_url}/search/code", params=params)
        if response.status_code == 200:
            return [{"file": item["path"], "url": item["html_url"]} for item in response.json()["items"]]
        return f"Error searching code: {response.status_code}"

    def _get_commit_history(self, file_path, num_commits):
        """Return the latest *num_commits* commits touching *file_path*.

        Each commit is returned as ``{"sha", "message", "date"}``.
        """
        params = {"path": file_path, "per_page": num_commits}
        response = self._request("GET", f"{self.base_url}/repos/{self.repo}/commits", params=params)
        if response.status_code == 200:
            return [
                {
                    "sha": commit["sha"],
                    "message": commit["commit"]["message"],
                    "date": commit["commit"]["author"]["date"],
                }
                for commit in response.json()
            ]
        return f"Error getting commit history: {response.status_code}"