diff --git a/tests/test_base_telegram_inference_bot.py b/tests/test_base_telegram_inference_bot.py new file mode 100644 index 0000000..bcb2b52 --- /dev/null +++ b/tests/test_base_telegram_inference_bot.py @@ -0,0 +1,310 @@ +import unittest +from unittest.mock import patch, mock_open, MagicMock +import os +import json + +# Ensure the path includes the directory where base_telegram_inference_bot is located +# This might require adjustment based on actual project structure if tests are run from root +# For now, assuming it can be imported directly or via PYTHONPATH +from base_telegram_inference_bot import BaseTelegramInferenceBot +from tools.base_tool import BaseTool # For mocking tool structure + +# Create a concrete subclass for testing, as BaseTelegramInferenceBot is abstract +class ConcreteTestBot(BaseTelegramInferenceBot): + def __init__(self, system_prompt_content=None, system_prompt_path=None, mock_tools=None, mock_functions=None): + # Mock load_functions during super().__init__ if needed, or control tools/functions directly + self._mock_tools = mock_tools if mock_tools is not None else [] + self._mock_functions = mock_functions if mock_functions is not None else [] + super().__init__(system_prompt_content=system_prompt_content, system_prompt_path=system_prompt_path) + + # Override load_functions to use mocks + def load_functions(self): + return self._mock_tools, self._mock_functions + + def get_chat_response(self, messages): + pass # Abstract method, not tested here directly + + async def handle_message(self, user_id, user_message): + pass # Abstract method + + def get_llm_description(self) -> str: + return "Mock LLM Description" # Concrete implementation for testing get_bot_status + + async def start(self): + pass # Abstract method + + async def abort_processing(self, user_id): + pass # Abstract method + + async def switch_model(self): + pass # Abstract method + +class TestBaseTelegramInferenceBot(unittest.TestCase): + + def setUp(self): + # Reset relevant environment variables before each test + self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH") + if "SYSTEM_PROMPT_PATH" in os.environ: + del os.environ["SYSTEM_PROMPT_PATH"] + + def tearDown(self): + # Restore environment variables + if self.original_system_prompt_path: + os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path + elif "SYSTEM_PROMPT_PATH" in os.environ: # Ensure it's removed if test set it and it wasn't there before + del os.environ["SYSTEM_PROMPT_PATH"] + + def test_init_with_direct_system_prompt(self): + bot = ConcreteTestBot(system_prompt_content="Direct prompt content") + self.assertEqual(bot.system_prompt, "Direct prompt content") + + @patch("os.path.isfile") + @patch("builtins.open", new_callable=mock_open, read_data="File prompt content") + def test_init_with_system_prompt_path_argument(self, mock_file_open, mock_isfile): + mock_isfile.return_value = True + bot = ConcreteTestBot(system_prompt_path="dummy/path.txt") + self.assertEqual(bot.system_prompt, "File prompt content") + mock_isfile.assert_called_once_with("dummy/path.txt") + mock_file_open.assert_called_once_with("dummy/path.txt", "r", encoding="utf-8") + + @patch("os.path.isfile") + @patch("builtins.open", new_callable=mock_open, read_data="Env prompt content") + def test_init_with_env_system_prompt_path(self, mock_file_open, mock_isfile): + mock_isfile.return_value = True + os.environ["SYSTEM_PROMPT_PATH"] = "env/path.txt" + bot = ConcreteTestBot() + self.assertEqual(bot.system_prompt, "Env prompt content") + mock_isfile.assert_called_once_with("env/path.txt") + mock_file_open.assert_called_once_with("env/path.txt", "r", encoding="utf-8") + + def test_init_with_default_system_prompt(self): + # Ensure ENV var is not set for this test + if "SYSTEM_PROMPT_PATH" in os.environ: + del os.environ["SYSTEM_PROMPT_PATH"] + bot = ConcreteTestBot() + self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.") + + @patch("os.path.isfile", return_value=False) + def test_init_with_invalid_system_prompt_path(self, mock_isfile): + bot = ConcreteTestBot(system_prompt_path="invalid/path.txt") + self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.") + mock_isfile.assert_called_once_with("invalid/path.txt") + + @patch("os.path.isfile") + @patch("builtins.open", side_effect=IOError("File read error")) + def test_init_with_system_prompt_file_read_error(self, mock_file_open, mock_isfile): + mock_isfile.return_value = True + bot = ConcreteTestBot(system_prompt_path="dummy/path.txt") + self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.") + + def test_clear_conversation_history(self): + mock_tool_instance = MagicMock(spec=BaseTool) + bot = ConcreteTestBot(mock_tools=[mock_tool_instance]) + bot.conversation_history[123] = [{"role": "user", "content": "Hello"}] + + bot.clear_conversation_history(123) + self.assertNotIn(123, bot.conversation_history) + mock_tool_instance.clear.assert_called_once() + + def test_clear_conversation_history_user_not_found(self): + mock_tool_instance = MagicMock(spec=BaseTool) + bot = ConcreteTestBot(mock_tools=[mock_tool_instance]) + bot.clear_conversation_history(404) + self.assertNotIn(404, bot.conversation_history) + mock_tool_instance.clear.assert_called_once() + + def test_processing_status(self): + bot = ConcreteTestBot() + self.assertEqual(bot.processing_status, {}) + bot.set_processing_status(123, 789) + self.assertEqual(bot.processing_status[123], {"processing": True, "message_id": 789}) + bot.clear_processing_status(123) + self.assertNotIn(123, bot.processing_status) + + def test_clear_processing_status_user_not_found(self): + bot = ConcreteTestBot() + bot.clear_processing_status(404) + self.assertNotIn(404, bot.processing_status) + + def test_call_tool_success_dict_args(self): + mock_tool = MagicMock(spec=BaseTool) + mock_tool.get_functions.return_value = [ + {"function": {"name": "test_tool", "description": "A test tool", "parameters": {}}} + ] + mock_tool.execute.return_value = "Tool executed successfully" + + bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions()) + + result = bot.call_tool("test_tool", {"arg1": "value1"}) + self.assertEqual(result, "Tool executed successfully") + mock_tool.execute.assert_called_once_with("test_tool", arg1="value1") + + def test_call_tool_success_json_string_args(self): + mock_tool = MagicMock(spec=BaseTool) + mock_tool.get_functions.return_value = [ + {"function": {"name": "test_tool_json", "parameters": {}}} + ] + mock_tool.execute.return_value = "Tool JSON OK" + bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions()) + + args_json_str = '''{"param": "value"}''' + result = bot.call_tool("test_tool_json", args_json_str) + self.assertEqual(result, "Tool JSON OK") + mock_tool.execute.assert_called_once_with("test_tool_json", param="value") + + def test_call_tool_malformed_json_string_args(self): + bot = ConcreteTestBot(mock_tools=[]) + args_malformed_json_str = '''{"param": "value"''' + result = bot.call_tool("some_tool", args_malformed_json_str) + self.assertTrue("Error: Malformed arguments for tool call" in result) + + def test_call_tool_unexpected_arg_type(self): + bot = ConcreteTestBot(mock_tools=[]) + result = bot.call_tool("some_tool", 12345) # Integer instead of dict/str + self.assertTrue("Error: Invalid argument type for tool call" in result) + + def test_call_tool_none_args(self): + mock_tool = MagicMock(spec=BaseTool) + mock_tool.get_functions.return_value = [ + {"function": {"name": "test_tool_none", "parameters": {}}} + ] + mock_tool.execute.return_value = "Tool None OK" + bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions()) + + result = bot.call_tool("test_tool_none", None) + self.assertEqual(result, "Tool None OK") + mock_tool.execute.assert_called_once_with("test_tool_none") # No kwargs if None + + def test_call_tool_not_found(self): + bot = ConcreteTestBot(mock_tools=[]) + result = bot.call_tool("non_existent_tool", {}) + self.assertEqual(result, "Error: Tool function non_existent_tool not found.") + + def test_call_tool_execute_exception(self): + mock_tool = MagicMock(spec=BaseTool) + mock_tool.get_functions.return_value = [{"function": {"name": "error_tool", "parameters": {}}}] + mock_tool.execute.side_effect = Exception("Execution failed") + bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions()) + + result = bot.call_tool("error_tool", {}) + self.assertEqual(result, "Error executing tool error_tool: Execution failed") + + def test_get_system_prompt_description(self): + if "SYSTEM_PROMPT_PATH" in os.environ: # Ensure clean state + del os.environ["SYSTEM_PROMPT_PATH"] + + bot_default = ConcreteTestBot() + self.assertEqual(bot_default.get_system_prompt_description(), "System Prompt: Default") + + bot_custom_content = ConcreteTestBot(system_prompt_content="Custom content here") + self.assertEqual(bot_custom_content.get_system_prompt_description(), "System Prompt: Custom") + + os.environ["SYSTEM_PROMPT_PATH"] = "some/path.txt" + bot_env_default_prompt = ConcreteTestBot() # system_prompt itself is default + self.assertEqual(bot_env_default_prompt.get_system_prompt_description(), "System Prompt: Custom (via ENV)") + + with patch("os.path.isfile", return_value=True), \ + patch("builtins.open", mock_open(read_data="File prompt from ENV")): + bot_env_file_prompt = ConcreteTestBot() # system_prompt gets loaded from ENV path + self.assertEqual(bot_env_file_prompt.get_system_prompt_description(), "System Prompt: Custom") + del os.environ["SYSTEM_PROMPT_PATH"] + + with patch("os.path.isfile", return_value=True), \ + patch("builtins.open", mock_open(read_data="File prompt from arg")): + bot_custom_file_arg = ConcreteTestBot(system_prompt_path="custom/file.txt") + self.assertEqual(bot_custom_file_arg.get_system_prompt_description(), "System Prompt: Custom") + + @patch.object(ConcreteTestBot, 'get_llm_description', return_value="Test LLM Description") + @patch.object(ConcreteTestBot, 'get_system_prompt_description', return_value="Test Prompt Description") + async def test_get_bot_status(self, mock_prompt_desc, mock_llm_desc): + bot = ConcreteTestBot() + status = await bot.get_bot_status() + self.assertEqual(status, "Test Prompt Description\nTest LLM Description") + mock_prompt_desc.assert_called_once() + mock_llm_desc.assert_called_once() + + @patch('os.path.dirname', return_value='/mock/path') + @patch('os.path.join') + @patch('os.path.exists') + @patch('os.listdir') + @patch('importlib.import_module') + def test_load_functions_no_tools_dir(self, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname): + mock_join.return_value = '/mock/path/tools' + mock_exists.return_value = False + + class BotForLoadTest(BaseTelegramInferenceBot): + load_system_prompt = MagicMock(return_value="Default") + get_chat_response = MagicMock(); handle_message = MagicMock(); get_llm_description = MagicMock(return_value="mock") + start = MagicMock(); abort_processing = MagicMock(); switch_model = MagicMock() + + bot = BotForLoadTest() + self.assertEqual(bot.tools, []) + self.assertEqual(bot.functions, []) + mock_listdir.assert_not_called() + + @patch('os.path.dirname', return_value='/mock/base_bot_dir') + @patch('os.path.join', side_effect=lambda *args: os.path.normpath(os.path.join(*args))) + @patch('os.path.exists', return_value=True) + @patch('os.listdir', return_value=['my_tool.py', '__init__.py', 'base_tool.py']) + @patch('importlib.import_module') + def test_load_functions_with_one_tool(self, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname): + + mock_tool_class = MagicMock(spec=BaseTool) # This is the class itself + mock_tool_instance = MagicMock(spec=BaseTool) # This is the instance + mock_tool_class.return_value = mock_tool_instance # mock_tool_class() creates mock_tool_instance + mock_tool_instance.get_functions.return_value = [{"function": {"name": "sample_function"}}] + + mock_my_tool_module = MagicMock() + # Simulate inspect.getmembers behavior: returns list of (name, member) tuples + # Only include members that are classes, derive from BaseTool, and are not BaseTool itself. + mock_my_tool_module.ValidToolClass = mock_tool_class + mock_my_tool_module.NotATool = object() + mock_my_tool_module.BaseTool = BaseTool # This should be skipped by the loader + + def import_side_effect(module_name): + if module_name == 'tools.my_tool': + return mock_my_tool_module + raise ImportError(f"Unexpected import: {module_name}") + mock_import_module.side_effect = import_side_effect + + class BotForLoadTest(BaseTelegramInferenceBot): + load_system_prompt = MagicMock(return_value="Default") + get_chat_response = MagicMock(); handle_message = MagicMock(); get_llm_description = MagicMock(return_value="mock") + start = MagicMock(); abort_processing = MagicMock(); switch_model = MagicMock() + + bot = BotForLoadTest() + self.assertEqual(len(bot.tools), 1) + self.assertIs(bot.tools[0], mock_tool_instance) + self.assertEqual(len(bot.functions), 1) + self.assertEqual(bot.functions[0]['function']['name'], "sample_function") + mock_import_module.assert_called_once_with('tools.my_tool') + mock_tool_class.assert_called_once_with() # Tool class was instantiated + mock_tool_instance.get_functions.assert_called_once_with() + + @patch('os.path.dirname', return_value='/mock/base_bot_dir') + @patch('os.path.join', side_effect=lambda *args: os.path.normpath(os.path.join(*args))) + @patch('os.path.exists', return_value=True) + @patch('os.listdir', return_value=['tool_with_init_error.py']) + @patch('importlib.import_module') + @patch('logging.error') # Mock logging to check for error messages + def test_load_functions_tool_instantiation_error(self, mock_logging_error, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname): + mock_tool_class_init_error = MagicMock(spec=BaseTool) + mock_tool_class_init_error.side_effect = Exception("Failed to init tool") # Error on instantiation + + mock_error_tool_module = MagicMock() + mock_error_tool_module.ToolWithInitError = mock_tool_class_init_error + + mock_import_module.return_value = mock_error_tool_module + + class BotForLoadTest(BaseTelegramInferenceBot): + load_system_prompt = MagicMock(return_value="Default") + get_chat_response = MagicMock(); handle_message = MagicMock(); get_llm_description = MagicMock(return_value="mock") + start = MagicMock(); abort_processing = MagicMock(); switch_model = MagicMock() + + bot = BotForLoadTest() + self.assertEqual(len(bot.tools), 0) + self.assertEqual(len(bot.functions), 0) + mock_logging_error.assert_any_call("Error instantiating tool ToolWithInitError from tool_with_init_error.py: Failed to init tool") + +if __name__ == '__main__': + unittest.main(闂傚лен䦗婢у〃埊鍓解劓姣)