Add unit tests for refactored GitHubTool
This commit is contained in:
@@ -0,0 +1,307 @@
|
|||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
import os
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
import requests # Required for spec in MagicMock
|
||||||
|
|
||||||
|
# Ensure tools/github_tool.py is accessible
|
||||||
|
from tools.github_tool import GitHubTool
|
||||||
|
|
||||||
|
# Helper to create a mock response for requests.Session
|
||||||
|
def create_mock_response(status_code, json_data=None, text_data=None, headers=None, links=None):
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.status_code = status_code
|
||||||
|
if json_data is not None:
|
||||||
|
mock_resp.json = MagicMock(return_value=json_data)
|
||||||
|
mock_resp.text = text_data if text_data is not None else str(json_data)
|
||||||
|
mock_resp.headers = headers if headers else {}
|
||||||
|
mock_resp.links = links if links else {} # For pagination in _list_branches
|
||||||
|
return mock_resp
|
||||||
|
|
||||||
|
class TestGitHubTool(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.mock_session = MagicMock(spec=requests.Session)
|
||||||
|
self.mock_session.headers = {} # Simulate a new session's headers
|
||||||
|
|
||||||
|
self.test_token = "test_github_token"
|
||||||
|
self.test_repo = "owner/repo"
|
||||||
|
self.test_base_url = "https://api.example.com" # Use a non-default base_url for some tests
|
||||||
|
|
||||||
|
# Suppress logging output during tests unless explicitly testing for it
|
||||||
|
self.logger = logging.getLogger('tools.github_tool')
|
||||||
|
# Ensure only one NullHandler to prevent duplicate messages if tests run multiple times in a session
|
||||||
|
if not any(isinstance(h, logging.NullHandler) for h in self.logger.handlers):
|
||||||
|
self.logger.addHandler(logging.NullHandler())
|
||||||
|
self.logger.propagate = False # Prevent propagation to root logger if it has handlers
|
||||||
|
|
||||||
|
def test_init_with_args_and_session(self):
|
||||||
|
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, base_url=self.test_base_url, logger=self.logger)
|
||||||
|
self.assertEqual(tool.session, self.mock_session)
|
||||||
|
self.assertEqual(tool._token, self.test_token)
|
||||||
|
self.assertEqual(tool._repo, self.test_repo)
|
||||||
|
self.assertEqual(tool.base_url, self.test_base_url)
|
||||||
|
self.assertEqual(tool.current_branch, "main") # Default initial branch
|
||||||
|
|
||||||
|
@patch('requests.Session')
|
||||||
|
def test_init_creates_session_if_not_provided(self, MockSessionConstructor):
|
||||||
|
mock_created_session = MagicMock(spec=requests.Session)
|
||||||
|
mock_created_session.headers = {}
|
||||||
|
MockSessionConstructor.return_value = mock_created_session
|
||||||
|
|
||||||
|
# Temporarily set env vars for this test
|
||||||
|
original_token = os.environ.get("GITHUB_TOKEN")
|
||||||
|
original_repo = os.environ.get("GITHUB_REPOSITORY")
|
||||||
|
os.environ["GITHUB_TOKEN"] = "env_token"
|
||||||
|
os.environ["GITHUB_REPOSITORY"] = "env/repo"
|
||||||
|
|
||||||
|
tool = GitHubTool(logger=self.logger) # Use env vars
|
||||||
|
|
||||||
|
MockSessionConstructor.assert_called_once()
|
||||||
|
self.assertEqual(tool.session, mock_created_session)
|
||||||
|
self.assertEqual(tool._token, "env_token")
|
||||||
|
self.assertEqual(tool._repo, "env/repo")
|
||||||
|
self.assertIn("Authorization", mock_created_session.headers)
|
||||||
|
self.assertEqual(mock_created_session.headers["Authorization"], "token env_token")
|
||||||
|
|
||||||
|
# Restore original env vars
|
||||||
|
if original_token is None: del os.environ["GITHUB_TOKEN"]
|
||||||
|
else: os.environ["GITHUB_TOKEN"] = original_token
|
||||||
|
if original_repo is None: del os.environ["GITHUB_REPOSITORY"]
|
||||||
|
else: os.environ["GITHUB_REPOSITORY"] = original_repo
|
||||||
|
|
||||||
|
def test_init_raises_value_error_if_no_token(self):
|
||||||
|
original_token = os.environ.pop("GITHUB_TOKEN", None)
|
||||||
|
with self.assertRaisesRegex(ValueError, "GitHub token must be provided"):
|
||||||
|
GitHubTool(repo=self.test_repo, logger=self.logger)
|
||||||
|
if original_token: os.environ["GITHUB_TOKEN"] = original_token
|
||||||
|
|
||||||
|
def test_init_raises_value_error_if_no_repo(self):
|
||||||
|
original_repo = os.environ.pop("GITHUB_REPOSITORY", None)
|
||||||
|
with self.assertRaisesRegex(ValueError, "GitHub repository.*must be provided"):
|
||||||
|
GitHubTool(token=self.test_token, logger=self.logger)
|
||||||
|
if original_repo: os.environ["GITHUB_REPOSITORY"] = original_repo
|
||||||
|
|
||||||
|
def test_clear_resets_branch(self):
|
||||||
|
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, initial_branch="feature-branch", logger=self.logger)
|
||||||
|
# Mock _get_branch_sha for _set_current_branch called by clear
|
||||||
|
with patch.object(tool, '_get_branch_sha', return_value="sha_for_main"):
|
||||||
|
tool.clear()
|
||||||
|
self.assertEqual(tool.current_branch, "main")
|
||||||
|
|
||||||
|
def test_get_functions_returns_list(self):
|
||||||
|
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||||
|
functions = tool.get_functions()
|
||||||
|
self.assertIsInstance(functions, list)
|
||||||
|
self.assertTrue(len(functions) > 0)
|
||||||
|
self.assertIn("name", functions[0]["function"])
|
||||||
|
|
||||||
|
|
||||||
|
# --- Test individual private methods ---
|
||||||
|
|
||||||
|
def test_read_file_success(self):
|
||||||
|
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||||
|
file_content = "Hello World!"
|
||||||
|
encoded_content = base64.b64encode(file_content.encode('utf-8')).decode('utf-8')
|
||||||
|
self.mock_session.get.return_value = create_mock_response(200, json_data={"content": encoded_content})
|
||||||
|
|
||||||
|
result = tool._read_file(path="test.txt")
|
||||||
|
self.assertEqual(result, file_content)
|
||||||
|
self.mock_session.get.assert_called_once_with(
|
||||||
|
f"{tool.base_url}/repos/{self.test_repo}/contents/test.txt",
|
||||||
|
params={"ref": "main"}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_read_file_error(self):
|
||||||
|
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||||
|
self.mock_session.get.return_value = create_mock_response(404, text_data="Not Found")
|
||||||
|
result = tool._read_file(path="nonexistent.txt")
|
||||||
|
self.assertIn("Error reading file", result)
|
||||||
|
|
||||||
|
def test_create_branch_success(self):
|
||||||
|
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||||
|
# Mock getting base branch SHA
|
||||||
|
self.mock_session.get.return_value = create_mock_response(200, json_data={"object": {"sha": "base_sha123"}})
|
||||||
|
# Mock creating new branch
|
||||||
|
self.mock_session.post.return_value = create_mock_response(201, json_data={"ref": "refs/heads/new-feature"})
|
||||||
|
|
||||||
|
result = tool._create_branch(branch_name="new-feature", base_branch="main")
|
||||||
|
self.assertIn("Branch 'new-feature' created successfully", result)
|
||||||
|
self.assertEqual(tool.current_branch, "new-feature")
|
||||||
|
self.mock_session.get.assert_called_once() # For base branch SHA
|
||||||
|
self.mock_session.post.assert_called_once() # For creating branch
|
||||||
|
|
||||||
|
def test_create_branch_base_sha_error(self):
|
||||||
|
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||||
|
self.mock_session.get.return_value = create_mock_response(404, text_data="Base branch not found")
|
||||||
|
result = tool._create_branch(branch_name="new-feature", base_branch="nonexistent-base")
|
||||||
|
self.assertIn("Error getting base branch SHA", result)
|
||||||
|
|
||||||
|
def test_create_branch_creation_error(self):
|
||||||
|
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||||
|
self.mock_session.get.return_value = create_mock_response(200, json_data={"object": {"sha": "base_sha456"}})
|
||||||
|
self.mock_session.post.return_value = create_mock_response(422, text_data="Validation failed")
|
||||||
|
result = tool._create_branch(branch_name="bad-branch", base_branch="main")
|
||||||
|
self.assertIn("Error creating branch", result)
|
||||||
|
|
||||||
|
def test_commit_file_success_new_file(self):
|
||||||
|
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||||
|
tool.current_branch = "dev-branch" # Cannot commit to main by default
|
||||||
|
|
||||||
|
# Mock GET for checking file existence (404 means new file)
|
||||||
|
self.mock_session.get.return_value = create_mock_response(404)
|
||||||
|
# Mock PUT for committing file
|
||||||
|
self.mock_session.put.return_value = create_mock_response(201, json_data={"commit": {"sha": "commit_sha_abc"}})
|
||||||
|
|
||||||
|
result = tool._commit_file(file_path="new_file.py", content="print('Hello')", commit_message="Add new_file.py")
|
||||||
|
self.assertIn("committed successfully", result)
|
||||||
|
self.assertIn("commit_sha_abc", result)
|
||||||
|
self.mock_session.get.assert_called_once() # Check file existence
|
||||||
|
self.mock_session.put.assert_called_once() # Commit file
|
||||||
|
|
||||||
|
def test_commit_file_success_update_file(self):
|
||||||
|
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||||
|
tool.current_branch = "dev-branch"
|
||||||
|
|
||||||
|
# Mock GET for checking file existence (200 means existing file)
|
||||||
|
self.mock_session.get.return_value = create_mock_response(200, json_data={"sha": "existing_file_sha"})
|
||||||
|
# Mock PUT for committing file
|
||||||
|
self.mock_session.put.return_value = create_mock_response(200, json_data={"commit": {"sha": "commit_sha_def"}})
|
||||||
|
|
||||||
|
result = tool._commit_file(file_path="existing_file.txt", content="Updated content", commit_message="Update existing_file.txt")
|
||||||
|
self.assertIn("committed successfully", result)
|
||||||
|
self.assertIn("commit_sha_def", result)
|
||||||
|
args, kwargs = self.mock_session.put.call_args
|
||||||
|
self.assertEqual(kwargs['json']['sha'], "existing_file_sha")
|
||||||
|
|
||||||
|
|
||||||
|
def test_commit_file_to_main_branch_error(self):
|
||||||
|
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||||
|
tool.current_branch = "main"
|
||||||
|
result = tool._commit_file(file_path="some.txt", content="content", commit_message="msg")
|
||||||
|
self.assertIn("Action directly to main branch is not allowed", result)
|
||||||
|
|
||||||
|
def test_create_pull_request_success(self):
|
||||||
|
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||||
|
tool.current_branch = "feature-pr"
|
||||||
|
pr_url = "https://example.com/pull/1"
|
||||||
|
self.mock_session.post.return_value = create_mock_response(201, json_data={"html_url": pr_url, "number": 1})
|
||||||
|
|
||||||
|
result = tool._create_pull_request(title="New Feature PR", body="Please review.", base="main")
|
||||||
|
self.assertIn(f"Pull request created successfully: {pr_url}", result)
|
||||||
|
self.mock_session.post.assert_called_once()
|
||||||
|
call_data = self.mock_session.post.call_args[1]['json']
|
||||||
|
self.assertEqual(call_data['head'], "feature-pr")
|
||||||
|
self.assertEqual(call_data['base'], "main")
|
||||||
|
|
||||||
|
def test_create_pull_request_same_branch_error(self):
|
||||||
|
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||||
|
tool.current_branch = "main"
|
||||||
|
result = tool._create_pull_request(title="PR to self", body="This should fail", base="main")
|
||||||
|
self.assertIn("Cannot create a pull request from branch 'main' to itself", result)
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_files_success(self):
|
||||||
|
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||||
|
mock_items = [
|
||||||
|
{"name": "file1.txt", "type": "file", "path": "dir/file1.txt"},
|
||||||
|
{"name": "subdir", "type": "dir", "path": "dir/subdir"}
|
||||||
|
]
|
||||||
|
self.mock_session.get.return_value = create_mock_response(200, json_data=mock_items)
|
||||||
|
|
||||||
|
result = tool._list_files(path="dir")
|
||||||
|
self.assertEqual(len(result), 2)
|
||||||
|
self.assertEqual(result[0]["name"], "file1.txt")
|
||||||
|
self.assertEqual(result[1]["type"], "dir")
|
||||||
|
|
||||||
|
def test_search_code_success(self):
|
||||||
|
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||||
|
mock_search_results = {
|
||||||
|
"items": [{"path": "src/code.py", "html_url": "url1"}]
|
||||||
|
}
|
||||||
|
self.mock_session.get.return_value = create_mock_response(200, json_data=mock_search_results)
|
||||||
|
|
||||||
|
results = tool._search_code(query="my_function")
|
||||||
|
self.assertEqual(len(results), 1)
|
||||||
|
self.assertEqual(results[0]["path"], "src/code.py")
|
||||||
|
|
||||||
|
def test_get_commit_history_success(self):
|
||||||
|
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||||
|
mock_commits = [{
|
||||||
|
"sha": "sha1", "commit": {"message": "Msg1", "author": {"name": "Authy", "date": "Date1"}}
|
||||||
|
}]
|
||||||
|
self.mock_session.get.return_value = create_mock_response(200, json_data=mock_commits)
|
||||||
|
|
||||||
|
commits = tool._get_commit_history(file_path="file.txt", num_commits=1)
|
||||||
|
self.assertEqual(len(commits), 1)
|
||||||
|
self.assertEqual(commits[0]["sha"], "sha1")
|
||||||
|
|
||||||
|
def test_set_current_branch_success(self):
|
||||||
|
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||||
|
# Mock _get_branch_sha to simulate branch exists
|
||||||
|
with patch.object(tool, '_get_branch_sha', return_value="some_sha_for_dev"):
|
||||||
|
result = tool._set_current_branch(branch_name="dev")
|
||||||
|
self.assertEqual(tool.current_branch, "dev")
|
||||||
|
self.assertIn("Current branch set to: dev", result)
|
||||||
|
|
||||||
|
def test_set_current_branch_not_exists(self):
|
||||||
|
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||||
|
with patch.object(tool, '_get_branch_sha', return_value="Error getting SHA for branch"):
|
||||||
|
result = tool._set_current_branch(branch_name="nonexistent-branch")
|
||||||
|
self.assertNotEqual(tool.current_branch, "nonexistent-branch") # Should not change
|
||||||
|
self.assertIn("Cannot set current branch", result)
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_branches_single_page(self):
|
||||||
|
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||||
|
mock_branches = [{"name": "main"}, {"name": "dev"}]
|
||||||
|
self.mock_session.get.return_value = create_mock_response(200, json_data=mock_branches, links={}) # No "next" link
|
||||||
|
|
||||||
|
branches = tool._list_branches(all_pages=True)
|
||||||
|
self.assertEqual(branches, ["main", "dev"])
|
||||||
|
self.mock_session.get.assert_called_once()
|
||||||
|
|
||||||
|
def test_list_branches_multiple_pages(self):
|
||||||
|
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||||
|
|
||||||
|
# Page 1 response
|
||||||
|
page1_branches = [{"name": "branch1"}, {"name": "branch2"}]
|
||||||
|
next_url = f"{tool.base_url}/repos/{self.test_repo}/branches?page=2"
|
||||||
|
response1 = create_mock_response(200, json_data=page1_branches, links={"next": {"url": next_url}})
|
||||||
|
|
||||||
|
# Page 2 response
|
||||||
|
page2_branches = [{"name": "branch3"}]
|
||||||
|
response2 = create_mock_response(200, json_data=page2_branches, links={}) # No "next" link
|
||||||
|
|
||||||
|
self.mock_session.get.side_effect = [response1, response2]
|
||||||
|
|
||||||
|
branches = tool._list_branches(all_pages=True)
|
||||||
|
self.assertEqual(branches, ["branch1", "branch2", "branch3"])
|
||||||
|
self.assertEqual(self.mock_session.get.call_count, 2)
|
||||||
|
|
||||||
|
# Check that the second call used the next_url
|
||||||
|
calls = self.mock_session.get.call_args_list
|
||||||
|
self.assertEqual(calls[1][0][0], next_url) # args[0] is the URL
|
||||||
|
|
||||||
|
# --- Test execute dispatcher ---
|
||||||
|
def test_execute_read_file(self):
|
||||||
|
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||||
|
with patch.object(tool, '_read_file', return_value="file content") as mock_method:
|
||||||
|
result = tool.execute(function_name="read_file", path="test.md")
|
||||||
|
mock_method.assert_called_once_with(path="test.md")
|
||||||
|
self.assertEqual(result, "file content")
|
||||||
|
|
||||||
|
def test_execute_unknown_function(self):
|
||||||
|
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||||
|
result = tool.execute(function_name="non_existent_function_name", arg1="val1")
|
||||||
|
self.assertIn("Unknown function: non_existent_function_name", result)
|
||||||
|
|
||||||
|
def test_execute_method_exception(self):
|
||||||
|
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||||
|
with patch.object(tool, '_read_file', side_effect=Exception("Kaboom")) as mock_method:
|
||||||
|
result = tool.execute(function_name="read_file", path="crash.txt")
|
||||||
|
self.assertIn("Error during read_file execution: Kaboom", result)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user