"""Unit tests for OpenAICompatibleInferenceBot.

OpenAI/Azure client construction is mocked for every test, so no test can
create a real API client, and the environment variables the bot reads are
snapshotted and restored around each test.
"""
import unittest
from unittest.mock import MagicMock, patch, AsyncMock, ANY
import os
import json

# Assuming openai_compatible_inference_bot.py is in the parent directory or PYTHONPATH is set
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot

# Environment variables that influence how the bot configures its client and
# system prompt. Each test starts with all of them unset.
_BOT_ENV_KEYS = (
    "OPENAI_API_KEY",
    "AZURE_OPENAI_KEY",
    "AZURE_OPENAI_ENDPOINT",
    "AZURE_OPENAI_API_VERSION",
    "AZURE_DEPLOYMENT_NAME",
    "SYSTEM_PROMPT_PATH",
)


def create_mock_openai_response(content=None, tool_calls=None):
    """Build a mock shaped like the result of ``client.chat.completions.create``.

    Args:
        content: Text of the assistant message (``None`` for tool-call turns).
        tool_calls: Optional list of dicts of the form
            ``{"id": ..., "function": {"name": ..., "arguments": <json str>}}``.

    Returns:
        A MagicMock whose ``choices[0].message`` has ``role``, ``content`` and
        ``tool_calls`` attributes, and whose ``choices[0].finish_reason`` is
        ``"tool_calls"`` when tool calls are given, otherwise ``"stop"``.
    """
    mock_message = MagicMock()
    mock_message.role = "assistant"
    mock_message.content = content

    if tool_calls:
        mock_tool_calls = []
        for tc in tool_calls:
            mock_tc = MagicMock()
            mock_tc.id = tc["id"]
            mock_tc.type = "function"
            # `name` has to be assigned after creation: passing name= to the
            # MagicMock constructor names the mock instead of setting an attribute.
            mock_tc.function.name = tc["function"]["name"]
            mock_tc.function.arguments = tc["function"]["arguments"]
            mock_tool_calls.append(mock_tc)
        mock_message.tool_calls = mock_tool_calls
    else:
        mock_message.tool_calls = None

    mock_choice = MagicMock()
    mock_choice.message = mock_message
    # Mirror the real API so code that inspects finish_reason sees a string
    # instead of an auto-created MagicMock attribute.
    mock_choice.finish_reason = "tool_calls" if tool_calls else "stop"

    mock_response = MagicMock()
    mock_response.choices = [mock_choice]
    return mock_response


def _msg_field(msg, name):
    """Read ``name`` from a history entry that may be a dict or an object.

    Returns ``None`` when the field is absent. The bot may keep plain dicts
    or the raw API message object in its history; this lets assertions work
    for either representation.
    """
    if isinstance(msg, dict):
        return msg.get(name)
    return getattr(msg, name, None)


class ConcreteOpenAICompatibleBot(OpenAICompatibleInferenceBot):
    """Minimal concrete subclass so the abstract base can be instantiated."""

    async def switch_model(self):
        """Toggle between the small and large model configurations."""
        if self.model == self.small_model_name:
            self._configure_model_and_tokens(self.large_model_name, self.large_model_max_tokens_str)
        else:
            self._configure_model_and_tokens(self.small_model_name, self.small_model_max_tokens_str)
        return f"Switched to {self.model}"

    # Override load_functions if it's called by parent and needs mocking for these tests
    # (OpenAICompatibleInferenceBot's __init__ calls BaseTelegramInferenceBot's __init__,
    # which calls load_functions).
    def load_functions(self):
        """Register no tools; tests that need tools set them explicitly."""
        self.tools = []
        self.functions = []
        return self.tools, self.functions


class TestOpenAICompatibleInferenceBot(unittest.IsolatedAsyncioTestCase):
    """Tests for client setup, chat calls, tool handling and history management."""

    def setUp(self):
        # patch.dict snapshots os.environ and restores it exactly on cleanup,
        # including removing variables a test added (e.g. OPENAI_API_KEY).
        env_patcher = patch.dict(os.environ)
        env_patcher.start()
        self.addCleanup(env_patcher.stop)
        for key in _BOT_ENV_KEYS:
            os.environ.pop(key, None)

        # Mock the client constructors by default so a bot built without an
        # explicit client never creates a real OpenAI/Azure client (the real
        # openai.OpenAI raises when no API key is available). Per-test
        # @patch decorators layer on top of these and take precedence.
        for target in ('openai.OpenAI', 'openai.AzureOpenAI'):
            client_patcher = patch(target)
            client_patcher.start()
            self.addCleanup(client_patcher.stop)

        self.mock_openai_client_instance = MagicMock()
        self.mock_openai_client_instance.chat.completions.create = MagicMock()

    @patch('openai.OpenAI')
    def test_init_with_openai_defaults(self, MockOpenAIConstructor):
        """OPENAI_API_KEY from the environment is used to build a plain OpenAI client."""
        MockOpenAIConstructor.return_value = self.mock_openai_client_instance
        os.environ["OPENAI_API_KEY"] = "test_openai_key"

        bot = ConcreteOpenAICompatibleBot(model_name="gpt-4")

        MockOpenAIConstructor.assert_called_once_with(api_key="test_openai_key", base_url=None)
        self.assertEqual(bot.client, self.mock_openai_client_instance)
        self.assertEqual(bot.model, "gpt-4")
        self.assertEqual(bot.max_tokens, 1000)  # Default from _configure_model_and_tokens
        self.assertEqual(bot.azure_openai, False)

    @patch('openai.OpenAI')
    def test_init_with_provided_client(self, MockOpenAIConstructor):
        """A caller-supplied client is used as-is and no client is constructed."""
        preconfigured_client = MagicMock()

        bot = ConcreteOpenAICompatibleBot(client=preconfigured_client, model_name="gpt-3.5")

        MockOpenAIConstructor.assert_not_called()
        self.assertEqual(bot.client, preconfigured_client)
        self.assertEqual(bot.model, "gpt-3.5")

    @patch('openai.AzureOpenAI')
    def test_init_with_azure_config_args(self, MockAzureOpenAIConstructor):
        """Azure arguments build an AzureOpenAI client; the deployment becomes the model."""
        MockAzureOpenAIConstructor.return_value = self.mock_openai_client_instance

        bot = ConcreteOpenAICompatibleBot(
            api_key="azure_key",
            azure_endpoint="https://myenv.openai.azure.com",
            api_version="2023-05-15",
            azure_deployment="my-gpt-4",
            # For Azure the API call needs the deployment name, so it is
            # expected to take precedence over model_name.
            model_name="should_be_overridden_by_azure_deployment_for_api",
        )

        MockAzureOpenAIConstructor.assert_called_once_with(
            api_key="azure_key",
            azure_endpoint="https://myenv.openai.azure.com",
            api_version="2023-05-15",
        )
        self.assertEqual(bot.client, self.mock_openai_client_instance)
        self.assertEqual(bot.model, "my-gpt-4")  # Azure deployment name becomes the model for API calls
        self.assertEqual(bot.azure_openai, True)

    @patch('openai.AzureOpenAI')
    def test_init_with_azure_env_vars(self, MockAzureOpenAIConstructor):
        """Azure settings are read from AZURE_* environment variables."""
        MockAzureOpenAIConstructor.return_value = self.mock_openai_client_instance
        os.environ["AZURE_OPENAI_KEY"] = "env_azure_key"
        os.environ["AZURE_OPENAI_ENDPOINT"] = "https://env.openai.azure.com"
        os.environ["AZURE_OPENAI_API_VERSION"] = "2023-06-01"
        os.environ["AZURE_DEPLOYMENT_NAME"] = "env-gpt-35"  # Used as model_name

        bot = ConcreteOpenAICompatibleBot(model_name="ignored_if_azure_deployment_env_is_set")

        MockAzureOpenAIConstructor.assert_called_once_with(
            api_key="env_azure_key",
            azure_endpoint="https://env.openai.azure.com",
            api_version="2023-06-01",
        )
        self.assertEqual(bot.model, "env-gpt-35")
        self.assertTrue(bot.azure_openai)

    @patch('openai.OpenAI')
    def test_init_with_gemini_config_args(self, MockOpenAIConstructor):
        """Gemini goes through the OpenAI client with a custom base_url, not Azure."""
        MockOpenAIConstructor.return_value = self.mock_openai_client_instance

        bot = ConcreteOpenAICompatibleBot(
            api_key="gemini_key",
            base_url="https://gemini.example.com",
            model_name="gemini-pro",
            is_gemini=True,
        )

        MockOpenAIConstructor.assert_called_once_with(api_key="gemini_key", base_url="https://gemini.example.com")
        self.assertEqual(bot.model, "gemini-pro")
        self.assertFalse(bot.azure_openai)  # is_gemini doesn't mean azure_openai

    def test_configure_model_and_tokens(self):
        """max_tokens parses numeric strings and falls back to defaults otherwise."""
        bot = ConcreteOpenAICompatibleBot(model_name="initial_model")  # init calls _configure

        bot._configure_model_and_tokens("test-model", "500")
        self.assertEqual(bot.model, "test-model")
        self.assertEqual(bot.max_tokens, 500)

        bot._configure_model_and_tokens("test-model-2", None, default_max_tokens=150)
        self.assertEqual(bot.max_tokens, 150)

        bot._configure_model_and_tokens("test-model-3", "invalid_token_val")
        self.assertEqual(bot.max_tokens, 1000)  # Default fallback

    def test_get_llm_description(self):
        """The description reports model, token limit and Azure flag."""
        bot = ConcreteOpenAICompatibleBot(model_name="desc-model", max_tokens_str="256")
        self.assertEqual(bot.get_llm_description(), "LLM: desc-model, Max Tokens: 256, Azure: False")

        bot_azure = ConcreteOpenAICompatibleBot(
            azure_deployment="azure-model", azure_endpoint="x", api_key="y", api_version="z"
        )
        self.assertEqual(bot_azure.get_llm_description(), "LLM: azure-model, Max Tokens: 1000, Azure: True")

    def test_get_chat_response_success(self):
        """get_chat_response forwards model, messages and max_tokens to the client."""
        bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="test-gpt")
        bot.max_tokens = 50
        mock_api_response = create_mock_openai_response(content="Hello from API")
        self.mock_openai_client_instance.chat.completions.create.return_value = mock_api_response
        messages = [{"role": "user", "content": "Hi"}]

        response = bot.get_chat_response(messages)

        self.mock_openai_client_instance.chat.completions.create.assert_called_once_with(
            model="test-gpt",
            messages=messages,
            tools=ANY,  # functions may be None or an empty list
            tool_choice=ANY,
            max_tokens=50,
        )
        self.assertEqual(response, mock_api_response)

    def test_get_chat_response_api_error(self):
        """Client exceptions propagate out of get_chat_response."""
        bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="error-gpt")
        self.mock_openai_client_instance.chat.completions.create.side_effect = Exception("API Down")

        with self.assertRaisesRegex(Exception, "API Down"):
            bot.get_chat_response([{"role": "user", "content": "trigger"}])

    async def test_handle_message_simple_response(self):
        """A plain reply is returned and stored after the system and user messages."""
        bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="chatty")
        bot.system_prompt = "You are a test bot."
        self.mock_openai_client_instance.chat.completions.create.return_value = (
            create_mock_openai_response(content="Test reply")
        )

        response_content = await bot.handle_message(user_id=1, user_message="Hello")

        self.assertEqual(response_content, "Test reply")
        self.assertIn(1, bot.conversation_history)
        self.assertEqual(len(bot.conversation_history[1]), 3)  # System, User, Assistant
        self.assertEqual(bot.conversation_history[1][0]["content"], "You are a test bot.")
        self.assertEqual(bot.conversation_history[1][2]["content"], "Test reply")

    async def test_handle_message_with_tool_call_and_response(self):
        """A tool-call turn runs the tool, then returns the follow-up answer."""
        bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="tool-user")
        mock_tool_def = {"function": {"name": "get_weather", "description": "Gets weather", "parameters": {}}}
        bot.functions = [mock_tool_def]  # Simulate tools being loaded

        # First API response requests the tool; the second is the final answer.
        tool_call_request = [
            {"id": "call123", "function": {"name": "get_weather", "arguments": '{"location": "moon"}'}}
        ]
        api_response_1 = create_mock_openai_response(tool_calls=tool_call_request)
        api_response_2 = create_mock_openai_response(content="The weather on the moon is chilly.")
        self.mock_openai_client_instance.chat.completions.create.side_effect = [api_response_1, api_response_2]
        bot.call_tool = MagicMock(return_value='{"temperature": "-100 C"}')

        final_response = await bot.handle_message(user_id=2, user_message="Weather on moon?")

        self.assertEqual(final_response, "The weather on the moon is chilly.")
        bot.call_tool.assert_called_once_with("get_weather", '{"location": "moon"}')

        # History must contain the assistant tool-call turn and the tool result.
        # Entries may be dicts or API message objects, so read fields uniformly.
        history = bot.conversation_history[2]
        self.assertTrue(any(
            _msg_field(msg, "role") == "assistant" and _msg_field(msg, "tool_calls") is not None
            for msg in history
        ))
        self.assertTrue(any(
            _msg_field(msg, "role") == "tool" and _msg_field(msg, "name") == "get_weather"
            for msg in history
        ))
        self.assertEqual(self.mock_openai_client_instance.chat.completions.create.call_count, 2)

    async def test_handle_message_max_history_length(self):
        """History is truncated to the newest ``max_history_length`` messages.

        Expected behavior: the system prompt is only inserted when a user's
        history is first created, and truncation keeps the tail
        (``history[-max_history_length:]``). With a limit of 3:

            after message 1: [S, U1, A1]
            after message 2: [S, U1, A1, U2, A2] -> [A1, U2, A2]
            after message 3: [A1, U2, A2, U3, A3] -> [A2, U3, A3]
        """
        bot = ConcreteOpenAICompatibleBot(
            client=self.mock_openai_client_instance, model_name="hist-test", max_history_length=3
        )
        bot.system_prompt = "You are a test bot."
        create = self.mock_openai_client_instance.chat.completions.create

        async def send(text, reply):
            # Queue `reply` as the next API answer, send `text`, and return the
            # contents of user 7's history afterwards.
            create.reset_mock()
            create.return_value = create_mock_openai_response(content=reply)
            await bot.handle_message(user_id=7, user_message=text)
            return [msg["content"] for msg in bot.conversation_history[7]]

        self.assertEqual(
            await send("First message", "Reply1"),
            ["You are a test bot.", "First message", "Reply1"],
        )
        self.assertEqual(
            await send("Second message", "Reply2"),
            ["Reply1", "Second message", "Reply2"],
        )
        self.assertEqual(
            await send("Third message", "Reply3"),
            ["Reply2", "Third message", "Reply3"],
        )

    async def test_abort_processing(self):
        """Aborting active processing clears the flag and the conversation."""
        bot = ConcreteOpenAICompatibleBot(model_name="test")
        user_id = 123
        bot.processing_status[user_id] = {"processing": True, "message_id": 456}
        bot.conversation_history[user_id] = [{"role": "user", "content": "stuff"}]

        # clear_conversation_history comes from the base class; patch it to observe the call.
        with patch.object(bot, 'clear_conversation_history') as mock_clear_hist:
            result = await bot.abort_processing(user_id)

        self.assertEqual(result, "Processing aborted and conversation cleared.")
        self.assertFalse(bot.processing_status[user_id]["processing"])
        mock_clear_hist.assert_called_once_with(user_id)

    async def test_abort_processing_no_active_processing(self):
        """Aborting with nothing in progress still clears the conversation."""
        bot = ConcreteOpenAICompatibleBot(model_name="test")
        user_id = 404  # Not in processing_status

        with patch.object(bot, 'clear_conversation_history') as mock_clear_hist:
            result = await bot.abort_processing(user_id)

        self.assertEqual(result, "No active processing found to abort. Conversation cleared.")
        mock_clear_hist.assert_called_once_with(user_id)

    async def test_switch_model_concrete_implementation(self):
        """The test subclass's switch_model toggles between small and large models."""
        bot = ConcreteOpenAICompatibleBot(
            model_name="model1", small_model_name="model1", large_model_name="model2", max_tokens_str="100"
        )
        self.assertEqual(bot.model, "model1")

        await bot.switch_model()
        self.assertEqual(bot.model, "model2")

        await bot.switch_model()
        self.assertEqual(bot.model, "model1")


if __name__ == '__main__':
    unittest.main()