diff --git a/tools/github_tool_functions/create_branch.py b/tools/github_tool_functions/create_branch.py new file mode 100644 index 0000000..b103179 --- /dev/null +++ b/tools/github_tool_functions/create_branch.py @@ -0,0 +1,58 @@ +import requests +import logging + +class CreateBranch: + def __init__(self, base_url, token, repo, current_branch): + self.base_url = base_url + self.headers = { + "Authorization": f"token {token}", + "Accept": "application/vnd.github.v3+json" + } + self.repo = repo + self.current_branch = current_branch + + # Set up logging + self.logger = logging.getLogger(__name__) + self.logger.setLevel(logging.INFO) + + # Create a file handler + file_handler = logging.FileHandler('create_branch.log') + file_handler.setLevel(logging.INFO) + + # Create a console handler + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + + # Create a formatting for the logs + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + file_handler.setFormatter(formatter) + console_handler.setFormatter(formatter) + + # Add the handlers to the logger + self.logger.addHandler(file_handler) + self.logger.addHandler(console_handler) + + def __call__(self, branch_name, base_branch="main"): + self.logger.info(f"Creating branch: {branch_name} from base: {base_branch}") + url = f"{self.base_url}/repos/{self.repo}/git/refs" + response = requests.get(f"{url}/heads/{base_branch}", headers=self.headers) + if response.status_code != 200: + error_message = f"Error getting base branch: {response.status_code}" + self.logger.error(error_message) + return error_message + + sha = response.json()["object"]["sha"] + data = { + "ref": f"refs/heads/{branch_name}", + "sha": sha + } + response = requests.post(url, headers=self.headers, json=data) + if response.status_code == 201: + self.current_branch = branch_name + success_message = f"Branch '{branch_name}' created successfully and set as current branch" + self.logger.info(success_message) + return success_message + else: + error_message = f"Error creating branch: {response.status_code}" + self.logger.error(error_message) + return error_message