# Unit tests for tools.metrics_tool.MetricsTool.
import unittest
|
|
from unittest.mock import MagicMock, patch
|
|
import logging
|
|
|
|
# The project's `tools` package (tools.metrics_tool, tools.metrics) must be importable, i.e. on sys.path.
|
|
from tools.metrics_tool import MetricsTool
|
|
from tools.metrics import Metrics # Used for typehinting and creating a mockable instance
|
|
|
|
class TestMetricsTool(unittest.TestCase):
|
|
|
|
def setUp(self):
|
|
self.mock_metrics_provider = MagicMock(spec=Metrics)
|
|
self.logger = logging.getLogger('tools.metrics_tool.test')
|
|
if not self.logger.handlers:
|
|
self.logger.addHandler(logging.NullHandler())
|
|
self.logger.propagate = False
|
|
|
|
|
|
def test_init_with_provider(self):
|
|
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
|
self.assertEqual(tool.metrics_provider, self.mock_metrics_provider)
|
|
|
|
@patch('tools.metrics_tool.global_metrics_instance') # Patch the global instance path
|
|
def test_init_default_provider(self, mock_global_metrics):
|
|
tool = MetricsTool(logger=self.logger)
|
|
self.assertEqual(tool.metrics_provider, mock_global_metrics)
|
|
|
|
def test_get_functions(self):
|
|
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
|
functions = tool.get_functions()
|
|
self.assertIsInstance(functions, list)
|
|
self.assertTrue(len(functions) == 3) # Based on current definition
|
|
self.assertIn("get_function_metrics", [f["function"]["name"] for f in functions])
|
|
self.assertIn("get_specific_function_metrics", [f["function"]["name"] for f in functions])
|
|
self.assertIn("get_top_n_functions", [f["function"]["name"] for f in functions])
|
|
|
|
def test_execute_get_function_metrics(self):
|
|
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
|
expected_metrics = {"func1": {"call_count": 1, "total_time": 0.1}}
|
|
self.mock_metrics_provider.get_metrics.return_value = expected_metrics
|
|
|
|
result = tool.execute(function_name="get_function_metrics")
|
|
|
|
self.mock_metrics_provider.get_metrics.assert_called_once()
|
|
self.assertEqual(result, expected_metrics)
|
|
|
|
def test_execute_get_specific_function_metrics_found(self):
|
|
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
|
func_metrics = {"call_count": 5, "total_time": 0.5, "average_time": 0.1}
|
|
all_metrics = {"specific_func": func_metrics, "other_func": {}}
|
|
self.mock_metrics_provider.get_metrics.return_value = all_metrics
|
|
|
|
# The execute method expects kwargs that match the function parameters in get_functions.
|
|
# So, the argument name for the function to get is 'function_name' in the tool's spec.
|
|
result = tool.execute(function_name="get_specific_function_metrics", **{"function_name": "specific_func"})
|
|
self.assertEqual(result, func_metrics)
|
|
|
|
def test_execute_get_specific_function_metrics_not_found(self):
|
|
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
|
self.mock_metrics_provider.get_metrics.return_value = {"other_func": {}}
|
|
|
|
result = tool.execute(function_name="get_specific_function_metrics", **{"function_name": "non_existent_func"})
|
|
self.assertEqual(result, "No metrics found for function: non_existent_func")
|
|
|
|
def test_execute_get_specific_function_metrics_missing_arg(self):
|
|
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
|
result = tool.execute(function_name="get_specific_function_metrics") # Missing function_name kwarg
|
|
self.assertIn("Error: Missing required argument 'function_name'", result)
|
|
|
|
|
|
def test_execute_get_top_n_functions(self):
|
|
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
|
metrics_data = {
|
|
"func_a": {"call_count": 1, "total_time": 0.3},
|
|
"func_b": {"call_count": 1, "total_time": 0.1},
|
|
"func_c": {"call_count": 1, "total_time": 0.5},
|
|
"func_d": {"call_count": 1, "total_time": 0.2},
|
|
}
|
|
self.mock_metrics_provider.get_metrics.return_value = metrics_data
|
|
|
|
# Test getting top 2
|
|
result = tool.execute(function_name="get_top_n_functions", n=2)
|
|
expected_top_2 = {"func_c": metrics_data["func_c"], "func_a": metrics_data["func_a"]}
|
|
self.assertEqual(result, expected_top_2)
|
|
|
|
# Test getting top 1
|
|
result_top_1 = tool.execute(function_name="get_top_n_functions", n=1)
|
|
expected_top_1 = {"func_c": metrics_data["func_c"]}
|
|
self.assertEqual(result_top_1, expected_top_1)
|
|
|
|
# Test N larger than available functions
|
|
result_top_all = tool.execute(function_name="get_top_n_functions", n=10)
|
|
# Order should be func_c, func_a, func_d, func_b
|
|
expected_top_all_keys = ["func_c", "func_a", "func_d", "func_b"]
|
|
self.assertEqual(list(result_top_all.keys()), expected_top_all_keys)
|
|
|
|
def test_execute_get_top_n_functions_malformed_metrics(self):
|
|
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
|
metrics_data = {
|
|
"func_a": {"call_count": 1, "total_time": 0.3},
|
|
"func_b": "not a dict", # Malformed
|
|
"func_c": {"call_count": 1}, # Missing total_time
|
|
"func_d": {"call_count": 1, "total_time": 0.2},
|
|
}
|
|
self.mock_metrics_provider.get_metrics.return_value = metrics_data
|
|
|
|
metrics_tool_logger = logging.getLogger('tools.metrics_tool')
|
|
original_level = metrics_tool_logger.level
|
|
metrics_tool_logger.setLevel(logging.WARNING)
|
|
with self.assertLogs(metrics_tool_logger, level='WARNING') as log_capture:
|
|
result = tool.execute(function_name="get_top_n_functions", n=2)
|
|
metrics_tool_logger.setLevel(original_level)
|
|
|
|
# Check that warnings were logged for malformed items
|
|
self.assertTrue(any("Metric item for 'func_b' is not in expected format" in msg for msg in log_capture.output))
|
|
self.assertTrue(any("Metric item for 'func_c' is not in expected format" in msg for msg in log_capture.output))
|
|
|
|
# Expected: func_a, func_d (as they are valid and sortable)
|
|
expected_result = {
|
|
"func_a": metrics_data["func_a"],
|
|
"func_d": metrics_data["func_d"]
|
|
}
|
|
self.assertEqual(result, expected_result)
|
|
|
|
|
|
def test_execute_get_top_n_functions_invalid_n(self):
|
|
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
|
self.mock_metrics_provider.get_metrics.return_value = {} # No metrics needed for this test
|
|
|
|
result_zero = tool.execute(function_name="get_top_n_functions", n=0)
|
|
self.assertIn("Error: Argument 'n' must be a positive integer.", result_zero)
|
|
|
|
result_negative = tool.execute(function_name="get_top_n_functions", n=-1)
|
|
self.assertIn("Error: Argument 'n' must be a positive integer.", result_negative)
|
|
|
|
result_string = tool.execute(function_name="get_top_n_functions", n="abc")
|
|
self.assertIn("Error: Argument 'n' must be an integer.", result_string)
|
|
|
|
def test_execute_get_top_n_functions_missing_arg_n(self):
|
|
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
|
result = tool.execute(function_name="get_top_n_functions") # Missing n
|
|
self.assertIn("Error: Missing required argument 'n'.", result)
|
|
|
|
|
|
def test_execute_unknown_function(self):
|
|
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
|
result = tool.execute(function_name="non_existent_metrics_function")
|
|
self.assertIn("Unknown function: non_existent_metrics_function", result)
|
|
|
|
def test_clear_method(self):
|
|
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
|
metrics_tool_logger = logging.getLogger('tools.metrics_tool')
|
|
original_level = metrics_tool_logger.level
|
|
metrics_tool_logger.setLevel(logging.DEBUG)
|
|
with self.assertLogs(metrics_tool_logger, level='DEBUG') as cm:
|
|
tool.clear()
|
|
metrics_tool_logger.setLevel(original_level)
|
|
self.assertTrue(any("MetricsTool clear method called" in message for message in cm.output))
|
|
|
|
# Allow running this file directly: hand control to unittest's CLI runner,
# which discovers the TestCase classes defined in this module.
if __name__ == "__main__":
    unittest.main()
|