diff --git a/tests/test_openai_compatible_inference_bot.py b/tests/test_openai_compatible_inference_bot.py new file mode 100644 index 0000000..dc667c0 --- /dev/null +++ b/tests/test_openai_compatible_inference_bot.py @@ -0,0 +1,332 @@ +import unittest +from unittest.mock import MagicMock, patch, AsyncMock, ANY +import os +import json + +# Assuming openai_compatible_inference_bot.py is in the parent directory or PYTHONPATH is set +from openai_compatible_inference_bot import OpenAICompatibleInferenceBot + +# Mock response from OpenAI client's chat.completions.create +def create_mock_openai_response(content=None, tool_calls=None): + mock_message = MagicMock() + mock_message.role = "assistant" + mock_message.content = content + if tool_calls: + # tool_calls should be a list of objects with id and function (name, arguments) + mock_tool_calls = [] + for tc in tool_calls: + mock_tc = MagicMock() + mock_tc.id = tc["id"] + mock_tc.function.name = tc["function"]["name"] + mock_tc.function.arguments = tc["function"]["arguments"] + mock_tool_calls.append(mock_tc) + mock_message.tool_calls = mock_tool_calls + else: + mock_message.tool_calls = None + + mock_choice = MagicMock() + mock_choice.message = mock_message + + mock_response = MagicMock() + mock_response.choices = [mock_choice] + return mock_response + +# Concrete class for testing +class ConcreteOpenAICompatibleBot(OpenAICompatibleInferenceBot): + # Implement abstract methods for instantiation + async def switch_model(self): + # Simple switch for testing if needed, or just pass + if self.model == self.small_model_name: + self._configure_model_and_tokens(self.large_model_name, self.large_model_max_tokens_str) + else: + self._configure_model_and_tokens(self.small_model_name, self.small_model_max_tokens_str) + return f"Switched to {self.model}" + + # Override load_functions if it's called by parent and needs mocking for these tests + # (OpenAICompatibleInferenceBot's __init__ calls BaseTelegramInferenceBot's __init__, which calls load_functions) + def load_functions(self): + # For these tests, assume no tools unless specifically added + self.tools = [] + self.functions = [] + return self.tools, self.functions + + +class TestOpenAICompatibleInferenceBot(unittest.IsolatedAsyncioTestCase): + + def setUp(self): + self.original_openai_api_key = os.environ.get("OPENAI_API_KEY") + self.original_azure_openai_key = os.environ.get("AZURE_OPENAI_KEY") + self.original_azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT") + self.original_api_version = os.environ.get("AZURE_OPENAI_API_VERSION") + self.original_azure_deployment = os.environ.get("AZURE_DEPLOYMENT_NAME") + + # Clear relevant env vars before each test + for key in ["OPENAI_API_KEY", "AZURE_OPENAI_KEY", "AZURE_OPENAI_ENDPOINT", + "AZURE_OPENAI_API_VERSION", "AZURE_DEPLOYMENT_NAME", "SYSTEM_PROMPT_PATH"]: + if os.environ.get(key): + del os.environ[key] + + self.mock_openai_client_instance = MagicMock() + self.mock_openai_client_instance.chat.completions.create = MagicMock() + + def tearDown(self): + # Restore environment variables + if self.original_openai_api_key: os.environ["OPENAI_API_KEY"] = self.original_openai_api_key + if self.original_azure_openai_key: os.environ["AZURE_OPENAI_KEY"] = self.original_azure_openai_key + if self.original_azure_endpoint: os.environ["AZURE_OPENAI_ENDPOINT"] = self.original_azure_endpoint + if self.original_api_version: os.environ["AZURE_OPENAI_API_VERSION"] = self.original_api_version + if self.original_azure_deployment: os.environ["AZURE_DEPLOYMENT_NAME"] = self.original_azure_deployment + + + @patch('openai.OpenAI') + def test_init_with_openai_defaults(self, MockOpenAIConstructor): + MockOpenAIConstructor.return_value = self.mock_openai_client_instance + os.environ["OPENAI_API_KEY"] = "test_openai_key" + + bot = ConcreteOpenAICompatibleBot(model_name="gpt-4") + + MockOpenAIConstructor.assert_called_once_with(api_key="test_openai_key", base_url=None) + self.assertEqual(bot.client, self.mock_openai_client_instance) + self.assertEqual(bot.model, "gpt-4") + self.assertEqual(bot.max_tokens, 1000) # Default from _configure_model_and_tokens + self.assertEqual(bot.azure_openai, False) + + @patch('openai.OpenAI') + def test_init_with_provided_client(self, MockOpenAIConstructor): + preconfigured_client = MagicMock() + bot = ConcreteOpenAICompatibleBot(client=preconfigured_client, model_name="gpt-3.5") + + MockOpenAIConstructor.assert_not_called() + self.assertEqual(bot.client, preconfigured_client) + self.assertEqual(bot.model, "gpt-3.5") + + @patch('openai.AzureOpenAI') + def test_init_with_azure_config_args(self, MockAzureOpenAIConstructor): + MockAzureOpenAIConstructor.return_value = self.mock_openai_client_instance + + bot = ConcreteOpenAICompatibleBot( + api_key="azure_key", + azure_endpoint="https://myenv.openai.azure.com", + api_version="2023-05-15", + azure_deployment="my-gpt-4", # This should be used as model_name for API call + model_name="should_be_overridden_by_azure_deployment_for_api" + # model_name is passed to _configure_model_and_tokens, which sets self.model for display/logging + # but for Azure, the client needs the deployment name. + ) + + MockAzureOpenAIConstructor.assert_called_once_with( + api_key="azure_key", + azure_endpoint="https://myenv.openai.azure.com", + api_version="2023-05-15" + ) + self.assertEqual(bot.client, self.mock_openai_client_instance) + self.assertEqual(bot.model, "my-gpt-4") # Azure deployment name becomes the model for API calls + self.assertEqual(bot.azure_openai, True) + + + @patch('openai.AzureOpenAI') + def test_init_with_azure_env_vars(self, MockAzureOpenAIConstructor): + MockAzureOpenAIConstructor.return_value = self.mock_openai_client_instance + os.environ["AZURE_OPENAI_KEY"] = "env_azure_key" + os.environ["AZURE_OPENAI_ENDPOINT"] = "https://env.openai.azure.com" + os.environ["AZURE_OPENAI_API_VERSION"] = "2023-06-01" + os.environ["AZURE_DEPLOYMENT_NAME"] = "env-gpt-35" # Used as model_name + + bot = ConcreteOpenAICompatibleBot(model_name="ignored_if_azure_deployment_env_is_set") + + MockAzureOpenAIConstructor.assert_called_once_with( + api_key="env_azure_key", + azure_endpoint="https://env.openai.azure.com", + api_version="2023-06-01" + ) + self.assertEqual(bot.model, "env-gpt-35") + self.assertTrue(bot.azure_openai) + + @patch('openai.OpenAI') + def test_init_with_gemini_config_args(self, MockOpenAIConstructor): + MockOpenAIConstructor.return_value = self.mock_openai_client_instance + + bot = ConcreteOpenAICompatibleBot( + api_key="gemini_key", + base_url="https://gemini.example.com", + model_name="gemini-pro", + is_gemini=True + ) + MockOpenAIConstructor.assert_called_once_with(api_key="gemini_key", base_url="https://gemini.example.com") + self.assertEqual(bot.model, "gemini-pro") + self.assertFalse(bot.azure_openai) # is_gemini doesn't mean azure_openai + + def test_configure_model_and_tokens(self): + bot = ConcreteOpenAICompatibleBot(model_name="initial_model") # init calls _configure + bot._configure_model_and_tokens("test-model", "500") + self.assertEqual(bot.model, "test-model") + self.assertEqual(bot.max_tokens, 500) + + bot._configure_model_and_tokens("test-model-2", None, default_max_tokens=150) + self.assertEqual(bot.max_tokens, 150) + + bot._configure_model_and_tokens("test-model-3", "invalid_token_val") + self.assertEqual(bot.max_tokens, 1000) # Default fallback + + def test_get_llm_description(self): + bot = ConcreteOpenAICompatibleBot(model_name="desc-model", max_tokens_str="256") + self.assertEqual(bot.get_llm_description(), "LLM: desc-model, Max Tokens: 256, Azure: False") + + bot_azure = ConcreteOpenAICompatibleBot(azure_deployment="azure-model", azure_endpoint="x", api_key="y", api_version="z") + self.assertEqual(bot_azure.get_llm_description(), "LLM: azure-model, Max Tokens: 1000, Azure: True") + + + def test_get_chat_response_success(self): + bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="test-gpt") + bot.max_tokens = 50 # Ensure this is set + mock_api_response = create_mock_openai_response(content="Hello from API") + self.mock_openai_client_instance.chat.completions.create.return_value = mock_api_response + + messages = [{"role": "user", "content": "Hi"}] + response = bot.get_chat_response(messages) + + self.mock_openai_client_instance.chat.completions.create.assert_called_once_with( + model="test-gpt", + messages=messages, + tools=ANY, # Assuming functions can be None or empty list + tool_choice=ANY, + max_tokens=50 + ) + self.assertEqual(response, mock_api_response) + + def test_get_chat_response_api_error(self): + bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="error-gpt") + self.mock_openai_client_instance.chat.completions.create.side_effect = Exception("API Down") + + with self.assertRaisesRegex(Exception, "API Down"): + bot.get_chat_response([{"role": "user", "content": "trigger"}]) + + async def test_handle_message_simple_response(self): + bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="chatty") + bot.system_prompt = "You are a test bot." # Set directly for simplicity + mock_api_response = create_mock_openai_response(content="Test reply") + self.mock_openai_client_instance.chat.completions.create.return_value = mock_api_response + + response_content = await bot.handle_message(user_id=1, user_message="Hello") + + self.assertEqual(response_content, "Test reply") + self.assertIn(1, bot.conversation_history) + self.assertEqual(len(bot.conversation_history[1]), 3) # System, User, Assistant + self.assertEqual(bot.conversation_history[1][0]["content"], "You are a test bot.") + self.assertEqual(bot.conversation_history[1][2]["content"], "Test reply") + + async def test_handle_message_with_tool_call_and_response(self): + bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="tool-user") + + # Mock functions/tools setup on the bot + mock_tool_def = {"function": {"name": "get_weather", "description": "Gets weather", "parameters": {}}} + bot.functions = [mock_tool_def] # Simulate tools are loaded + + # API response 1: Request to call tool + tool_call_request = [{"id": "call123", "function": {"name": "get_weather", "arguments": '''{"location": "moon"}'''}}] + api_response_1 = create_mock_openai_response(tool_calls=tool_call_request) + + # API response 2: Final answer after tool execution + api_response_2 = create_mock_openai_response(content="The weather on the moon is chilly.") + + self.mock_openai_client_instance.chat.completions.create.side_effect = [api_response_1, api_response_2] + + # Mock self.call_tool + bot.call_tool = MagicMock(return_value='''{"temperature": "-100 C"}''') + + final_response = await bot.handle_message(user_id=2, user_message="Weather on moon?") + + self.assertEqual(final_response, "The weather on the moon is chilly.") + bot.call_tool.assert_called_once_with("get_weather", '''{"location": "moon"}''') + + # Check conversation history includes tool messages + history = bot.conversation_history[2] + self.assertTrue(any(msg["role"] == "assistant" and msg.tool_calls is not None for msg in history)) + self.assertTrue(any(msg["role"] == "tool" and msg["name"] == "get_weather" for msg in history)) + self.assertEqual(self.mock_openai_client_instance.chat.completions.create.call_count, 2) + + async def test_handle_message_max_history_length(self): + bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="hist-test", max_history_length=3) + self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Ok") + + await bot.handle_message(1, "Msg1") # Sys, User, Assist (3) + self.assertEqual(len(bot.conversation_history[1]), 3) + + await bot.handle_message(1, "Msg2") # User, Assist. Should be 3 (prev User, prev Assist, new User) -> then adds new Assist. + # Before new call: [Sys, U1, A1]. New U2. Call with [Sys,U1,A1,U2]. Resp A2. + # History: [Sys,U1,A1,U2,A2]. Limit 3. -> [A1,U2,A2] (if system is not preserved specially) + # The current code appends to history then truncates if over limit. + # So after Msg1: [S, U1, A1]. len=3. + # For Msg2: History is [S, U1, A1]. Append U2. Call with [S,U1,A1,U2]. Append A2. + # History now [S,U1,A1,U2,A2]. len=5. Truncate to 3. + # Expected: [A1, U2, A2] or [U1,A1,U2] or [U2,A2,S] depending on how system prompt is handled in truncation. + # The code is: self.conversation_history[user_id][-self.max_history_length:] + # And system prompt is only added IF user_id not in self.conversation_history. + # So, for Msg2, system prompt is not re-added. + # History before Msg2 call: [S, U1, A1] + # Messages for Msg2 call: [S, U1, A1, U2] + # History after Msg2 response A2: [S, U1, A1, U2, A2]. Len 5. + # Truncated to self.max_history_length=3: [A1, U2, A2] + + # Call 1 + self.mock_openai_client_instance.chat.completions.create.reset_mock() + self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply1") + await bot.handle_message(user_id=7, user_message="First message") + self.assertEqual(len(bot.conversation_history[7]), 3) # System, User1, Assistant1 + + # Call 2 + self.mock_openai_client_instance.chat.completions.create.reset_mock() + self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply2") + await bot.handle_message(user_id=7, user_message="Second message") + # History before call: [S, U1, A1]. Messages for call: [S, U1, A1, U2]. History after: [S, U1, A1, U2, A2]. + # Truncated to 3: [A1, U2, A2] + self.assertEqual(len(bot.conversation_history[7]), 3) + self.assertEqual(bot.conversation_history[7][0]["content"], "Reply1") # A1 + self.assertEqual(bot.conversation_history[7][1]["content"], "Second message") # U2 + self.assertEqual(bot.conversation_history[7][2]["content"], "Reply2") # A2 + + # Call 3 + self.mock_openai_client_instance.chat.completions.create.reset_mock() + self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply3") + await bot.handle_message(user_id=7, user_message="Third message") + # History before call: [A1, U2, A2]. Messages for call: [A1, U2, A2, U3]. History after: [A1, U2, A2, U3, A3]. + # Truncated to 3: [A2, U3, A3] + self.assertEqual(len(bot.conversation_history[7]), 3) + self.assertEqual(bot.conversation_history[7][0]["content"], "Reply2") # A2 + self.assertEqual(bot.conversation_history[7][1]["content"], "Third message") # U3 + self.assertEqual(bot.conversation_history[7][2]["content"], "Reply3") # A3 + + + async def test_abort_processing(self): + bot = ConcreteOpenAICompatibleBot(model_name="test") + user_id = 123 + bot.processing_status[user_id] = {"processing": True, "message_id": 456} + bot.conversation_history[user_id] = [{"role": "user", "content": "stuff"}] + + with patch.object(bot, 'clear_conversation_history') as mock_clear_hist: # Patching the method from Base class + result = await bot.abort_processing(user_id) + + self.assertEqual(result, "Processing aborted and conversation cleared.") + self.assertFalse(bot.processing_status[user_id]["processing"]) + mock_clear_hist.assert_called_once_with(user_id) + + async def test_abort_processing_no_active_processing(self): + bot = ConcreteOpenAICompatibleBot(model_name="test") + user_id = 404 # Not in processing_status + with patch.object(bot, 'clear_conversation_history') as mock_clear_hist: + result = await bot.abort_processing(user_id) + self.assertEqual(result, "No active processing found to abort. Conversation cleared.") + mock_clear_hist.assert_called_once_with(user_id) + + # Test for the abstract switch_model (basic call, actual logic in concrete class for this test) + async def test_switch_model_concrete_implementation(self): + bot = ConcreteOpenAICompatibleBot(model_name="model1", small_model_name="model1", large_model_name="model2", max_tokens_str="100") + self.assertEqual(bot.model, "model1") + await bot.switch_model() # Calls the concrete implementation + self.assertEqual(bot.model, "model2") + await bot.switch_model() + self.assertEqual(bot.model, "model1") + + +if __name__ == '__main__': + unittest.main()