From e4e5949fded6f77cab5c80b79313a35e32d97e03 Mon Sep 17 00:00:00 2001 From: cyclop-bot <178948048+cyclop-bot@users.noreply.github.com> Date: Mon, 2 Jun 2025 17:06:46 -0500 Subject: [PATCH] Add unit tests for refactored MetricsTool --- tests/tools/test_metrics_tool.py | 161 +++++++++++++++++++++++++++++++ 1 file changed, 161 insertions(+) create mode 100644 tests/tools/test_metrics_tool.py diff --git a/tests/tools/test_metrics_tool.py b/tests/tools/test_metrics_tool.py new file mode 100644 index 0000000..17c1b9d --- /dev/null +++ b/tests/tools/test_metrics_tool.py @@ -0,0 +1,161 @@ +import unittest +from unittest.mock import MagicMock, patch +import logging + +# Ensure tools.metrics_tool and tools.metrics are accessible +from tools.metrics_tool import MetricsTool +from tools.metrics import Metrics # Used for typehinting and creating a mockable instance + +class TestMetricsTool(unittest.TestCase): + + def setUp(self): + self.mock_metrics_provider = MagicMock(spec=Metrics) + self.logger = logging.getLogger('tools.metrics_tool.test') + if not self.logger.handlers: + self.logger.addHandler(logging.NullHandler()) + self.logger.propagate = False + + + def test_init_with_provider(self): + tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) + self.assertEqual(tool.metrics_provider, self.mock_metrics_provider) + + @patch('tools.metrics_tool.global_metrics_instance') # Patch the global instance path + def test_init_default_provider(self, mock_global_metrics): + tool = MetricsTool(logger=self.logger) + self.assertEqual(tool.metrics_provider, mock_global_metrics) + + def test_get_functions(self): + tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) + functions = tool.get_functions() + self.assertIsInstance(functions, list) + self.assertTrue(len(functions) == 3) # Based on current definition + self.assertIn("get_function_metrics", [f["function"]["name"] for f in functions]) + self.assertIn("get_specific_function_metrics", [f["function"]["name"] for f in functions]) + self.assertIn("get_top_n_functions", [f["function"]["name"] for f in functions]) + + def test_execute_get_function_metrics(self): + tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) + expected_metrics = {"func1": {"call_count": 1, "total_time": 0.1}} + self.mock_metrics_provider.get_metrics.return_value = expected_metrics + + result = tool.execute(function_name="get_function_metrics") + + self.mock_metrics_provider.get_metrics.assert_called_once() + self.assertEqual(result, expected_metrics) + + def test_execute_get_specific_function_metrics_found(self): + tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) + func_metrics = {"call_count": 5, "total_time": 0.5, "average_time": 0.1} + all_metrics = {"specific_func": func_metrics, "other_func": {}} + self.mock_metrics_provider.get_metrics.return_value = all_metrics + + # The execute method expects kwargs that match the function parameters in get_functions. + # So, the argument name for the function to get is 'function_name' in the tool's spec. + result = tool.execute(function_name="get_specific_function_metrics", **{"function_name": "specific_func"}) + self.assertEqual(result, func_metrics) + + def test_execute_get_specific_function_metrics_not_found(self): + tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) + self.mock_metrics_provider.get_metrics.return_value = {"other_func": {}} + + result = tool.execute(function_name="get_specific_function_metrics", **{"function_name": "non_existent_func"}) + self.assertEqual(result, "No metrics found for function: non_existent_func") + + def test_execute_get_specific_function_metrics_missing_arg(self): + tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) + result = tool.execute(function_name="get_specific_function_metrics") # Missing function_name kwarg + self.assertIn("Error: Missing required argument 'function_name'", result) + + + def test_execute_get_top_n_functions(self): + tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) + metrics_data = { + "func_a": {"call_count": 1, "total_time": 0.3}, + "func_b": {"call_count": 1, "total_time": 0.1}, + "func_c": {"call_count": 1, "total_time": 0.5}, + "func_d": {"call_count": 1, "total_time": 0.2}, + } + self.mock_metrics_provider.get_metrics.return_value = metrics_data + + # Test getting top 2 + result = tool.execute(function_name="get_top_n_functions", n=2) + expected_top_2 = {"func_c": metrics_data["func_c"], "func_a": metrics_data["func_a"]} + self.assertEqual(result, expected_top_2) + + # Test getting top 1 + result_top_1 = tool.execute(function_name="get_top_n_functions", n=1) + expected_top_1 = {"func_c": metrics_data["func_c"]} + self.assertEqual(result_top_1, expected_top_1) + + # Test N larger than available functions + result_top_all = tool.execute(function_name="get_top_n_functions", n=10) + # Order should be func_c, func_a, func_d, func_b + expected_top_all_keys = ["func_c", "func_a", "func_d", "func_b"] + self.assertEqual(list(result_top_all.keys()), expected_top_all_keys) + + def test_execute_get_top_n_functions_malformed_metrics(self): + tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) + metrics_data = { + "func_a": {"call_count": 1, "total_time": 0.3}, + "func_b": "not a dict", # Malformed + "func_c": {"call_count": 1}, # Missing total_time + "func_d": {"call_count": 1, "total_time": 0.2}, + } + self.mock_metrics_provider.get_metrics.return_value = metrics_data + + metrics_tool_logger = logging.getLogger('tools.metrics_tool') + original_level = metrics_tool_logger.level + metrics_tool_logger.setLevel(logging.WARNING) + with self.assertLogs(metrics_tool_logger, level='WARNING') as log_capture: + result = tool.execute(function_name="get_top_n_functions", n=2) + metrics_tool_logger.setLevel(original_level) + + # Check that warnings were logged for malformed items + self.assertTrue(any("Metric item for 'func_b' is not in expected format" in msg for msg in log_capture.output)) + self.assertTrue(any("Metric item for 'func_c' is not in expected format" in msg for msg in log_capture.output)) + + # Expected: func_a, func_d (as they are valid and sortable) + expected_result = { + "func_a": metrics_data["func_a"], + "func_d": metrics_data["func_d"] + } + self.assertEqual(result, expected_result) + + + def test_execute_get_top_n_functions_invalid_n(self): + tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) + self.mock_metrics_provider.get_metrics.return_value = {} # No metrics needed for this test + + result_zero = tool.execute(function_name="get_top_n_functions", n=0) + self.assertIn("Error: Argument 'n' must be a positive integer.", result_zero) + + result_negative = tool.execute(function_name="get_top_n_functions", n=-1) + self.assertIn("Error: Argument 'n' must be a positive integer.", result_negative) + + result_string = tool.execute(function_name="get_top_n_functions", n="abc") + self.assertIn("Error: Argument 'n' must be an integer.", result_string) + + def test_execute_get_top_n_functions_missing_arg_n(self): + tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) + result = tool.execute(function_name="get_top_n_functions") # Missing n + self.assertIn("Error: Missing required argument 'n'.", result) + + + def test_execute_unknown_function(self): + tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) + result = tool.execute(function_name="non_existent_metrics_function") + self.assertIn("Unknown function: non_existent_metrics_function", result) + + def test_clear_method(self): + tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) + metrics_tool_logger = logging.getLogger('tools.metrics_tool') + original_level = metrics_tool_logger.level + metrics_tool_logger.setLevel(logging.DEBUG) + with self.assertLogs(metrics_tool_logger, level='DEBUG') as cm: + tool.clear() + metrics_tool_logger.setLevel(original_level) + self.assertTrue(any("MetricsTool clear method called" in message for message in cm.output)) + +if __name__ == '__main__': + unittest.main()