311 lines
16 KiB
Python
311 lines
16 KiB
Python
import unittest
|
|
from unittest.mock import patch, mock_open, MagicMock
|
|
import os
|
|
import json
|
|
|
|
# Ensure the path includes the directory where base_telegram_inference_bot is located
|
|
# This might require adjustment based on actual project structure if tests are run from root
|
|
# For now, assuming it can be imported directly or via PYTHONPATH
|
|
from base_telegram_inference_bot import BaseTelegramInferenceBot
|
|
from tools.base_tool import BaseTool # For mocking tool structure
|
|
|
|
# Create a concrete subclass for testing, as BaseTelegramInferenceBot is abstract
|
|
class ConcreteTestBot(BaseTelegramInferenceBot):
    """Minimal concrete bot used by the tests in this module.

    ``BaseTelegramInferenceBot`` is abstract, so every abstract hook gets a
    no-op body here.  Tool discovery is replaced by whatever tool objects and
    function schemas the test injects through ``mock_tools`` / ``mock_functions``.
    """

    def __init__(self, system_prompt_content=None, system_prompt_path=None, mock_tools=None, mock_functions=None):
        # The injected tools must be in place before the base initialiser runs,
        # since it presumably calls load_functions() (overridden below) to
        # populate the bot's tools and functions.
        self._mock_tools = [] if mock_tools is None else mock_tools
        self._mock_functions = [] if mock_functions is None else mock_functions
        super().__init__(
            system_prompt_content=system_prompt_content,
            system_prompt_path=system_prompt_path,
        )

    def load_functions(self):
        """Return the injected ``(tools, functions)`` pair instead of running real tool discovery."""
        return self._mock_tools, self._mock_functions

    def get_chat_response(self, messages):
        """No-op stub; chat completion is not exercised directly by these tests."""

    async def handle_message(self, user_id, user_message):
        """No-op stub for the abstract message handler."""

    def get_llm_description(self) -> str:
        """Return a fixed description so status output is predictable in tests."""
        return "Mock LLM Description"

    async def start(self):
        """No-op stub for the abstract start hook."""

    async def abort_processing(self, user_id):
        """No-op stub for the abstract abort hook."""

    async def switch_model(self):
        """No-op stub for the abstract model-switch hook."""
|
|
|
|
class TestBaseTelegramInferenceBot(unittest.TestCase):
    """Unit tests for the shared behaviour implemented by ``BaseTelegramInferenceBot``."""

    def setUp(self):
        # Snapshot the whole environment and restore it after every test, even
        # if the test fails midway.  Unlike a truthiness check on the saved
        # value, this also restores a variable whose original value was "".
        env_patcher = patch.dict(os.environ)
        env_patcher.start()
        self.addCleanup(env_patcher.stop)
        # Every test starts without a SYSTEM_PROMPT_PATH override.
        os.environ.pop("SYSTEM_PROMPT_PATH", None)

    @staticmethod
    def _make_loader_bot():
        """Instantiate a minimal concrete bot that inherits the base ``load_functions``.

        Prompt loading and every abstract hook are stubbed with MagicMocks, so
        only the base class's own tool loading is exercised.  A fresh subclass
        is created per call so no mock state leaks between tests.
        """

        class BotForLoadTest(BaseTelegramInferenceBot):
            load_system_prompt = MagicMock(return_value="Default")
            get_chat_response = MagicMock()
            handle_message = MagicMock()
            get_llm_description = MagicMock(return_value="mock")
            start = MagicMock()
            abort_processing = MagicMock()
            switch_model = MagicMock()

        return BotForLoadTest()

    # --- system prompt resolution -------------------------------------------

    def test_init_with_direct_system_prompt(self):
        bot = ConcreteTestBot(system_prompt_content="Direct prompt content")
        self.assertEqual(bot.system_prompt, "Direct prompt content")

    @patch("os.path.isfile")
    @patch("builtins.open", new_callable=mock_open, read_data="File prompt content")
    def test_init_with_system_prompt_path_argument(self, mock_file_open, mock_isfile):
        mock_isfile.return_value = True
        bot = ConcreteTestBot(system_prompt_path="dummy/path.txt")
        self.assertEqual(bot.system_prompt, "File prompt content")
        mock_isfile.assert_called_once_with("dummy/path.txt")
        mock_file_open.assert_called_once_with("dummy/path.txt", "r", encoding="utf-8")

    @patch("os.path.isfile")
    @patch("builtins.open", new_callable=mock_open, read_data="Env prompt content")
    def test_init_with_env_system_prompt_path(self, mock_file_open, mock_isfile):
        mock_isfile.return_value = True
        # Restored automatically by the environment snapshot taken in setUp.
        os.environ["SYSTEM_PROMPT_PATH"] = "env/path.txt"
        bot = ConcreteTestBot()
        self.assertEqual(bot.system_prompt, "Env prompt content")
        mock_isfile.assert_called_once_with("env/path.txt")
        mock_file_open.assert_called_once_with("env/path.txt", "r", encoding="utf-8")

    def test_init_with_default_system_prompt(self):
        # setUp guarantees SYSTEM_PROMPT_PATH is unset here.
        bot = ConcreteTestBot()
        self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.")

    @patch("os.path.isfile", return_value=False)
    def test_init_with_invalid_system_prompt_path(self, mock_isfile):
        bot = ConcreteTestBot(system_prompt_path="invalid/path.txt")
        self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.")
        mock_isfile.assert_called_once_with("invalid/path.txt")

    @patch("os.path.isfile")
    @patch("builtins.open", side_effect=IOError("File read error"))
    def test_init_with_system_prompt_file_read_error(self, mock_file_open, mock_isfile):
        mock_isfile.return_value = True
        bot = ConcreteTestBot(system_prompt_path="dummy/path.txt")
        # A read failure falls back to the default prompt instead of raising.
        self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.")

    # --- conversation / processing state ------------------------------------

    def test_clear_conversation_history(self):
        mock_tool_instance = MagicMock(spec=BaseTool)
        bot = ConcreteTestBot(mock_tools=[mock_tool_instance])
        bot.conversation_history[123] = [{"role": "user", "content": "Hello"}]

        bot.clear_conversation_history(123)
        self.assertNotIn(123, bot.conversation_history)
        mock_tool_instance.clear.assert_called_once()

    def test_clear_conversation_history_user_not_found(self):
        mock_tool_instance = MagicMock(spec=BaseTool)
        bot = ConcreteTestBot(mock_tools=[mock_tool_instance])
        bot.clear_conversation_history(404)
        self.assertNotIn(404, bot.conversation_history)
        # Tools are cleared even when the user had no stored history.
        mock_tool_instance.clear.assert_called_once()

    def test_processing_status(self):
        bot = ConcreteTestBot()
        self.assertEqual(bot.processing_status, {})
        bot.set_processing_status(123, 789)
        self.assertEqual(bot.processing_status[123], {"processing": True, "message_id": 789})
        bot.clear_processing_status(123)
        self.assertNotIn(123, bot.processing_status)

    def test_clear_processing_status_user_not_found(self):
        bot = ConcreteTestBot()
        bot.clear_processing_status(404)
        self.assertNotIn(404, bot.processing_status)

    # --- tool dispatch --------------------------------------------------------

    def test_call_tool_success_dict_args(self):
        mock_tool = MagicMock(spec=BaseTool)
        mock_tool.get_functions.return_value = [
            {"function": {"name": "test_tool", "description": "A test tool", "parameters": {}}}
        ]
        mock_tool.execute.return_value = "Tool executed successfully"

        bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions())

        result = bot.call_tool("test_tool", {"arg1": "value1"})
        self.assertEqual(result, "Tool executed successfully")
        mock_tool.execute.assert_called_once_with("test_tool", arg1="value1")

    def test_call_tool_success_json_string_args(self):
        mock_tool = MagicMock(spec=BaseTool)
        mock_tool.get_functions.return_value = [
            {"function": {"name": "test_tool_json", "parameters": {}}}
        ]
        mock_tool.execute.return_value = "Tool JSON OK"
        bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions())

        args_json_str = '''{"param": "value"}'''
        result = bot.call_tool("test_tool_json", args_json_str)
        self.assertEqual(result, "Tool JSON OK")
        mock_tool.execute.assert_called_once_with("test_tool_json", param="value")

    def test_call_tool_malformed_json_string_args(self):
        bot = ConcreteTestBot(mock_tools=[])
        args_malformed_json_str = '''{"param": "value"'''  # missing closing brace
        result = bot.call_tool("some_tool", args_malformed_json_str)
        self.assertIn("Error: Malformed arguments for tool call", result)

    def test_call_tool_unexpected_arg_type(self):
        bot = ConcreteTestBot(mock_tools=[])
        result = bot.call_tool("some_tool", 12345)  # Integer instead of dict/str
        self.assertIn("Error: Invalid argument type for tool call", result)

    def test_call_tool_none_args(self):
        mock_tool = MagicMock(spec=BaseTool)
        mock_tool.get_functions.return_value = [
            {"function": {"name": "test_tool_none", "parameters": {}}}
        ]
        mock_tool.execute.return_value = "Tool None OK"
        bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions())

        result = bot.call_tool("test_tool_none", None)
        self.assertEqual(result, "Tool None OK")
        mock_tool.execute.assert_called_once_with("test_tool_none")  # No kwargs if None

    def test_call_tool_not_found(self):
        bot = ConcreteTestBot(mock_tools=[])
        result = bot.call_tool("non_existent_tool", {})
        self.assertEqual(result, "Error: Tool function non_existent_tool not found.")

    def test_call_tool_execute_exception(self):
        mock_tool = MagicMock(spec=BaseTool)
        mock_tool.get_functions.return_value = [{"function": {"name": "error_tool", "parameters": {}}}]
        mock_tool.execute.side_effect = Exception("Execution failed")
        bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions())

        result = bot.call_tool("error_tool", {})
        self.assertEqual(result, "Error executing tool error_tool: Execution failed")

    # --- status descriptions --------------------------------------------------

    def test_get_system_prompt_description(self):
        # setUp guarantees SYSTEM_PROMPT_PATH is unset at this point.
        bot_default = ConcreteTestBot()
        self.assertEqual(bot_default.get_system_prompt_description(), "System Prompt: Default")

        bot_custom_content = ConcreteTestBot(system_prompt_content="Custom content here")
        self.assertEqual(bot_custom_content.get_system_prompt_description(), "System Prompt: Custom")

        # patch.dict scopes the override: SYSTEM_PROMPT_PATH is unset again on exit.
        with patch.dict(os.environ, {"SYSTEM_PROMPT_PATH": "some/path.txt"}):
            # The ENV path does not resolve to a file, so system_prompt itself stays the default.
            with patch("os.path.isfile", return_value=False):
                bot_env_default_prompt = ConcreteTestBot()
            self.assertEqual(bot_env_default_prompt.get_system_prompt_description(), "System Prompt: Custom (via ENV)")

            # The ENV path resolves to a file, so system_prompt is loaded from it.
            with patch("os.path.isfile", return_value=True), \
                    patch("builtins.open", mock_open(read_data="File prompt from ENV")):
                bot_env_file_prompt = ConcreteTestBot()
                self.assertEqual(bot_env_file_prompt.get_system_prompt_description(), "System Prompt: Custom")

        # With the ENV var unset again, the prompt path comes from the constructor argument.
        with patch("os.path.isfile", return_value=True), \
                patch("builtins.open", mock_open(read_data="File prompt from arg")):
            bot_custom_file_arg = ConcreteTestBot(system_prompt_path="custom/file.txt")
            self.assertEqual(bot_custom_file_arg.get_system_prompt_description(), "System Prompt: Custom")

    @patch.object(ConcreteTestBot, 'get_llm_description', return_value="Test LLM Description")
    @patch.object(ConcreteTestBot, 'get_system_prompt_description', return_value="Test Prompt Description")
    def test_get_bot_status(self, mock_prompt_desc, mock_llm_desc):
        # unittest.TestCase does not await ``async def`` test methods: the
        # coroutine would be created and discarded, so the test would pass
        # without running.  Drive the coroutine explicitly instead.
        # Assumes get_bot_status() returns a coroutine (it is awaited) — TODO confirm.
        import asyncio  # local: no other test in this module needs an event loop

        bot = ConcreteTestBot()
        status = asyncio.run(bot.get_bot_status())
        self.assertEqual(status, "Test Prompt Description\nTest LLM Description")
        mock_prompt_desc.assert_called_once()
        mock_llm_desc.assert_called_once()

    # --- tool discovery (base load_functions) --------------------------------

    @patch('os.path.dirname', return_value='/mock/path')
    @patch('os.path.join')
    @patch('os.path.exists')
    @patch('os.listdir')
    @patch('importlib.import_module')
    def test_load_functions_no_tools_dir(self, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname):
        mock_join.return_value = '/mock/path/tools'
        mock_exists.return_value = False

        bot = self._make_loader_bot()

        self.assertEqual(bot.tools, [])
        self.assertEqual(bot.functions, [])
        # A missing tools directory must short-circuit before listing it.
        mock_listdir.assert_not_called()

    # The os.path.join side_effect binds the REAL join as a default argument
    # when the decorator is evaluated.  Looking up ``os.path.join`` inside the
    # lambda at call time would resolve to the mock itself and recurse until
    # RecursionError.
    @patch('os.path.dirname', return_value='/mock/base_bot_dir')
    @patch('os.path.join', side_effect=lambda *parts, _join=os.path.join: os.path.normpath(_join(*parts)))
    @patch('os.path.exists', return_value=True)
    @patch('os.listdir', return_value=['my_tool.py', '__init__.py', 'base_tool.py'])
    @patch('importlib.import_module')
    def test_load_functions_with_one_tool(self, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname):
        mock_tool_class = MagicMock(spec=BaseTool)  # stands in for the tool class
        mock_tool_instance = MagicMock(spec=BaseTool)  # what instantiating that class yields
        mock_tool_class.return_value = mock_tool_instance
        mock_tool_instance.get_functions.return_value = [{"function": {"name": "sample_function"}}]

        # Stand-in for the imported ``tools.my_tool`` module: one candidate tool
        # class, one non-class member, and BaseTool itself, which the loader
        # should skip.
        # NOTE(review): a MagicMock(spec=BaseTool) is not a real class, so a
        # loader that filters with inspect.isclass/issubclass would reject it.
        # Confirm this matches load_functions' discovery check.
        mock_my_tool_module = MagicMock()
        mock_my_tool_module.ValidToolClass = mock_tool_class
        mock_my_tool_module.NotATool = object()
        mock_my_tool_module.BaseTool = BaseTool

        def import_side_effect(module_name):
            # __init__.py and base_tool.py are expected to be skipped, so
            # tools.my_tool is the only import allowed.
            if module_name == 'tools.my_tool':
                return mock_my_tool_module
            raise ImportError(f"Unexpected import: {module_name}")

        mock_import_module.side_effect = import_side_effect

        bot = self._make_loader_bot()

        self.assertEqual(len(bot.tools), 1)
        self.assertIs(bot.tools[0], mock_tool_instance)
        self.assertEqual(len(bot.functions), 1)
        self.assertEqual(bot.functions[0]['function']['name'], "sample_function")
        mock_import_module.assert_called_once_with('tools.my_tool')
        mock_tool_class.assert_called_once_with()  # Tool class was instantiated
        mock_tool_instance.get_functions.assert_called_once_with()

    @patch('os.path.dirname', return_value='/mock/base_bot_dir')
    @patch('os.path.join', side_effect=lambda *parts, _join=os.path.join: os.path.normpath(_join(*parts)))
    @patch('os.path.exists', return_value=True)
    @patch('os.listdir', return_value=['tool_with_init_error.py'])
    @patch('importlib.import_module')
    @patch('logging.error')  # Mock logging to check for error messages
    def test_load_functions_tool_instantiation_error(self, mock_logging_error, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname):
        mock_tool_class_init_error = MagicMock(spec=BaseTool)
        mock_tool_class_init_error.side_effect = Exception("Failed to init tool")  # Error on instantiation

        mock_error_tool_module = MagicMock()
        mock_error_tool_module.ToolWithInitError = mock_tool_class_init_error
        mock_import_module.return_value = mock_error_tool_module

        bot = self._make_loader_bot()

        # A tool that fails to instantiate is logged and skipped, not fatal.
        self.assertEqual(len(bot.tools), 0)
        self.assertEqual(len(bot.functions), 0)
        mock_logging_error.assert_any_call("Error instantiating tool ToolWithInitError from tool_with_init_error.py: Failed to init tool")
|
|
|
|
if __name__ == '__main__':
    # Discover and run every TestCase in this module when executed directly.
    unittest.main()
|