"""Unit tests for tools.github_tool.GitHubTool.

All HTTP traffic goes through a MagicMock standing in for requests.Session,
so no network access is performed.
"""
import base64
import logging
import os
import unittest
from unittest.mock import MagicMock, patch

import requests  # Required for spec= in MagicMock

# The project root must be on sys.path so that tools/github_tool.py is importable.
from tools.github_tool import GitHubTool


def create_mock_response(status_code, json_data=None, text_data=None, headers=None, links=None):
    """Build a MagicMock that mimics a ``requests.Response``.

    Args:
        status_code: Value exposed as ``.status_code``.
        json_data: If not None, ``.json()`` returns this value.
        text_data: Value exposed as ``.text``; falls back to ``str(json_data)``.
        headers: Dict exposed as ``.headers`` (an empty dict when falsy).
        links: Dict exposed as ``.links`` (an empty dict when falsy); the
            ``"next"`` entry drives pagination in ``_list_branches``.

    Returns:
        The configured MagicMock.
    """
    mock_resp = MagicMock()
    mock_resp.status_code = status_code
    if json_data is not None:
        mock_resp.json = MagicMock(return_value=json_data)
    mock_resp.text = text_data if text_data is not None else str(json_data)
    mock_resp.headers = headers if headers else {}
    mock_resp.links = links if links else {}
    return mock_resp


class TestGitHubTool(unittest.TestCase):
    """Tests for GitHubTool's constructor, private API wrappers and execute dispatcher."""

    def setUp(self):
        self.mock_session = MagicMock(spec=requests.Session)
        self.mock_session.headers = {}  # Simulate a new session's headers
        self.test_token = "test_github_token"
        self.test_repo = "owner/repo"
        self.test_base_url = "https://api.example.com"  # Non-default base_url for some tests

        # Silence the tool's logger during tests. Both changes are undone via
        # addCleanup so this configuration does not leak into other tests or
        # modules sharing the 'tools.github_tool' logger.
        self.logger = logging.getLogger('tools.github_tool')
        if not any(isinstance(h, logging.NullHandler) for h in self.logger.handlers):
            null_handler = logging.NullHandler()
            self.logger.addHandler(null_handler)
            self.addCleanup(self.logger.removeHandler, null_handler)
        # The current value is captured now, before it is overwritten below.
        self.addCleanup(setattr, self.logger, 'propagate', self.logger.propagate)
        self.logger.propagate = False  # Keep records away from root-logger handlers

    def _make_tool(self, **kwargs):
        """Return a GitHubTool wired to the mock session, test token/repo and logger.

        Extra keyword arguments (e.g. ``base_url``, ``initial_branch``) are
        forwarded to the GitHubTool constructor.
        """
        return GitHubTool(session=self.mock_session, token=self.test_token,
                          repo=self.test_repo, logger=self.logger, **kwargs)

    # --- Constructor / lifecycle ---

    def test_init_with_args_and_session(self):
        tool = self._make_tool(base_url=self.test_base_url)
        self.assertIs(tool.session, self.mock_session)
        self.assertEqual(tool._token, self.test_token)
        self.assertEqual(tool._repo, self.test_repo)
        self.assertEqual(tool.base_url, self.test_base_url)
        self.assertEqual(tool.current_branch, "main")  # Default initial branch

    @patch('requests.Session')
    def test_init_creates_session_if_not_provided(self, MockSessionConstructor):
        mock_created_session = MagicMock(spec=requests.Session)
        mock_created_session.headers = {}
        MockSessionConstructor.return_value = mock_created_session

        # patch.dict restores os.environ on exit, even when an assertion fails.
        env = {"GITHUB_TOKEN": "env_token", "GITHUB_REPOSITORY": "env/repo"}
        with patch.dict(os.environ, env):
            tool = GitHubTool(logger=self.logger)  # Token and repo come from env vars

            MockSessionConstructor.assert_called_once()
            self.assertIs(tool.session, mock_created_session)
            self.assertEqual(tool._token, "env_token")
            self.assertEqual(tool._repo, "env/repo")
            self.assertIn("Authorization", mock_created_session.headers)
            self.assertEqual(mock_created_session.headers["Authorization"], "token env_token")

    def test_init_raises_value_error_if_no_token(self):
        # Snapshot os.environ so the removed variable is always restored.
        with patch.dict(os.environ):
            os.environ.pop("GITHUB_TOKEN", None)
            with self.assertRaisesRegex(ValueError, "GitHub token must be provided"):
                GitHubTool(repo=self.test_repo, logger=self.logger)

    def test_init_raises_value_error_if_no_repo(self):
        # Snapshot os.environ so the removed variable is always restored.
        with patch.dict(os.environ):
            os.environ.pop("GITHUB_REPOSITORY", None)
            with self.assertRaisesRegex(ValueError, "GitHub repository.*must be provided"):
                GitHubTool(token=self.test_token, logger=self.logger)

    def test_clear_resets_branch(self):
        tool = self._make_tool(initial_branch="feature-branch")
        # Mock _get_branch_sha for the _set_current_branch call made by clear()
        with patch.object(tool, '_get_branch_sha', return_value="sha_for_main"):
            tool.clear()
        self.assertEqual(tool.current_branch, "main")

    def test_get_functions_returns_list(self):
        tool = self._make_tool()
        functions = tool.get_functions()
        self.assertIsInstance(functions, list)
        self.assertGreater(len(functions), 0)
        self.assertIn("name", functions[0]["function"])

    # --- Test individual private methods ---

    def test_read_file_success(self):
        tool = self._make_tool()
        file_content = "Hello World!"
        encoded_content = base64.b64encode(file_content.encode('utf-8')).decode('utf-8')
        self.mock_session.get.return_value = create_mock_response(200, json_data={"content": encoded_content})

        result = tool._read_file(path="test.txt")

        self.assertEqual(result, file_content)
        self.mock_session.get.assert_called_once_with(
            f"{tool.base_url}/repos/{self.test_repo}/contents/test.txt",
            params={"ref": "main"}
        )

    def test_read_file_error(self):
        tool = self._make_tool()
        self.mock_session.get.return_value = create_mock_response(404, text_data="Not Found")
        result = tool._read_file(path="nonexistent.txt")
        self.assertIn("Error reading file", result)

    def test_create_branch_success(self):
        tool = self._make_tool()
        # Mock getting the base branch SHA
        self.mock_session.get.return_value = create_mock_response(200, json_data={"object": {"sha": "base_sha123"}})
        # Mock creating the new branch
        self.mock_session.post.return_value = create_mock_response(201, json_data={"ref": "refs/heads/new-feature"})

        result = tool._create_branch(branch_name="new-feature", base_branch="main")

        self.assertIn("Branch 'new-feature' created successfully", result)
        self.assertEqual(tool.current_branch, "new-feature")
        self.mock_session.get.assert_called_once()   # For base branch SHA
        self.mock_session.post.assert_called_once()  # For creating branch

    def test_create_branch_base_sha_error(self):
        tool = self._make_tool()
        self.mock_session.get.return_value = create_mock_response(404, text_data="Base branch not found")
        result = tool._create_branch(branch_name="new-feature", base_branch="nonexistent-base")
        self.assertIn("Error getting base branch SHA", result)

    def test_create_branch_creation_error(self):
        tool = self._make_tool()
        self.mock_session.get.return_value = create_mock_response(200, json_data={"object": {"sha": "base_sha456"}})
        self.mock_session.post.return_value = create_mock_response(422, text_data="Validation failed")
        result = tool._create_branch(branch_name="bad-branch", base_branch="main")
        self.assertIn("Error creating branch", result)

    def test_commit_file_success_new_file(self):
        tool = self._make_tool()
        tool.current_branch = "dev-branch"  # Cannot commit to main by default

        # GET for checking file existence (404 means new file)
        self.mock_session.get.return_value = create_mock_response(404)
        # PUT for committing the file
        self.mock_session.put.return_value = create_mock_response(201, json_data={"commit": {"sha": "commit_sha_abc"}})

        result = tool._commit_file(file_path="new_file.py", content="print('Hello')", commit_message="Add new_file.py")

        self.assertIn("committed successfully", result)
        self.assertIn("commit_sha_abc", result)
        self.mock_session.get.assert_called_once()  # Check file existence
        self.mock_session.put.assert_called_once()  # Commit file

    def test_commit_file_success_update_file(self):
        tool = self._make_tool()
        tool.current_branch = "dev-branch"

        # GET for checking file existence (200 means existing file)
        self.mock_session.get.return_value = create_mock_response(200, json_data={"sha": "existing_file_sha"})
        # PUT for committing the file
        self.mock_session.put.return_value = create_mock_response(200, json_data={"commit": {"sha": "commit_sha_def"}})

        result = tool._commit_file(file_path="existing_file.txt", content="Updated content", commit_message="Update existing_file.txt")

        self.assertIn("committed successfully", result)
        self.assertIn("commit_sha_def", result)
        # Updating an existing file must send the current blob SHA.
        put_payload = self.mock_session.put.call_args[1]['json']
        self.assertEqual(put_payload['sha'], "existing_file_sha")

    def test_commit_file_to_main_branch_error(self):
        tool = self._make_tool()
        tool.current_branch = "main"
        result = tool._commit_file(file_path="some.txt", content="content", commit_message="msg")
        self.assertIn("Action directly to main branch is not allowed", result)

    def test_create_pull_request_success(self):
        tool = self._make_tool()
        tool.current_branch = "feature-pr"
        pr_url = "https://example.com/pull/1"
        self.mock_session.post.return_value = create_mock_response(201, json_data={"html_url": pr_url, "number": 1})

        result = tool._create_pull_request(title="New Feature PR", body="Please review.", base="main")

        self.assertIn(f"Pull request created successfully: {pr_url}", result)
        self.mock_session.post.assert_called_once()
        call_data = self.mock_session.post.call_args[1]['json']
        self.assertEqual(call_data['head'], "feature-pr")
        self.assertEqual(call_data['base'], "main")

    def test_create_pull_request_same_branch_error(self):
        tool = self._make_tool()
        tool.current_branch = "main"
        result = tool._create_pull_request(title="PR to self", body="This should fail", base="main")
        self.assertIn("Cannot create a pull request from branch 'main' to itself", result)

    def test_list_files_success(self):
        tool = self._make_tool()
        mock_items = [
            {"name": "file1.txt", "type": "file", "path": "dir/file1.txt"},
            {"name": "subdir", "type": "dir", "path": "dir/subdir"}
        ]
        self.mock_session.get.return_value = create_mock_response(200, json_data=mock_items)
        result = tool._list_files(path="dir")
        self.assertEqual(len(result), 2)
        self.assertEqual(result[0]["name"], "file1.txt")
        self.assertEqual(result[1]["type"], "dir")

    def test_search_code_success(self):
        tool = self._make_tool()
        mock_search_results = {
            "items": [{"path": "src/code.py", "html_url": "url1"}]
        }
        self.mock_session.get.return_value = create_mock_response(200, json_data=mock_search_results)
        results = tool._search_code(query="my_function")
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0]["path"], "src/code.py")

    def test_get_commit_history_success(self):
        tool = self._make_tool()
        mock_commits = [{
            "sha": "sha1",
            "commit": {"message": "Msg1", "author": {"name": "Authy", "date": "Date1"}}
        }]
        self.mock_session.get.return_value = create_mock_response(200, json_data=mock_commits)
        commits = tool._get_commit_history(file_path="file.txt", num_commits=1)
        self.assertEqual(len(commits), 1)
        self.assertEqual(commits[0]["sha"], "sha1")

    def test_set_current_branch_success(self):
        tool = self._make_tool()
        # Mock _get_branch_sha to simulate that the branch exists
        with patch.object(tool, '_get_branch_sha', return_value="some_sha_for_dev"):
            result = tool._set_current_branch(branch_name="dev")
        self.assertEqual(tool.current_branch, "dev")
        self.assertIn("Current branch set to: dev", result)

    def test_set_current_branch_not_exists(self):
        tool = self._make_tool()
        with patch.object(tool, '_get_branch_sha', return_value="Error getting SHA for branch"):
            result = tool._set_current_branch(branch_name="nonexistent-branch")
        self.assertNotEqual(tool.current_branch, "nonexistent-branch")  # Should not change
        self.assertIn("Cannot set current branch", result)

    def test_list_branches_single_page(self):
        tool = self._make_tool()
        mock_branches = [{"name": "main"}, {"name": "dev"}]
        self.mock_session.get.return_value = create_mock_response(200, json_data=mock_branches, links={})  # No "next" link

        branches = tool._list_branches(all_pages=True)

        self.assertEqual(branches, ["main", "dev"])
        self.mock_session.get.assert_called_once()

    def test_list_branches_multiple_pages(self):
        tool = self._make_tool()
        # Page 1 response advertises a "next" link
        page1_branches = [{"name": "branch1"}, {"name": "branch2"}]
        next_url = f"{tool.base_url}/repos/{self.test_repo}/branches?page=2"
        response1 = create_mock_response(200, json_data=page1_branches, links={"next": {"url": next_url}})
        # Page 2 response has no "next" link
        page2_branches = [{"name": "branch3"}]
        response2 = create_mock_response(200, json_data=page2_branches, links={})

        self.mock_session.get.side_effect = [response1, response2]

        branches = tool._list_branches(all_pages=True)

        self.assertEqual(branches, ["branch1", "branch2", "branch3"])
        self.assertEqual(self.mock_session.get.call_count, 2)
        # The second request must follow the "next" URL (positional args[0]).
        calls = self.mock_session.get.call_args_list
        self.assertEqual(calls[1][0][0], next_url)

    # --- Test execute dispatcher ---

    def test_execute_read_file(self):
        tool = self._make_tool()
        with patch.object(tool, '_read_file', return_value="file content") as mock_method:
            result = tool.execute(function_name="read_file", path="test.md")
        mock_method.assert_called_once_with(path="test.md")
        self.assertEqual(result, "file content")

    def test_execute_unknown_function(self):
        tool = self._make_tool()
        result = tool.execute(function_name="non_existent_function_name", arg1="val1")
        self.assertIn("Unknown function: non_existent_function_name", result)

    def test_execute_method_exception(self):
        tool = self._make_tool()
        with patch.object(tool, '_read_file', side_effect=Exception("Kaboom")):
            result = tool.execute(function_name="read_file", path="crash.txt")
        self.assertIn("Error during read_file execution: Kaboom", result)


if __name__ == '__main__':
    unittest.main()