Add unit tests for OpenAICompatibleInferenceBot

This commit is contained in:
cyclop-bot
2025-06-02 16:50:48 -05:00
parent 59be9dbb5d
commit 3ac1c350ab
@@ -0,0 +1,332 @@
import unittest
from unittest.mock import MagicMock, patch, AsyncMock, ANY
import os
import json
# Assuming openai_compatible_inference_bot.py is in the parent directory or PYTHONPATH is set
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot
# Mock response from OpenAI client's chat.completions.create
def create_mock_openai_response(content=None, tool_calls=None):
mock_message = MagicMock()
mock_message.role = "assistant"
mock_message.content = content
if tool_calls:
# tool_calls should be a list of objects with id and function (name, arguments)
mock_tool_calls = []
for tc in tool_calls:
mock_tc = MagicMock()
mock_tc.id = tc["id"]
mock_tc.function.name = tc["function"]["name"]
mock_tc.function.arguments = tc["function"]["arguments"]
mock_tool_calls.append(mock_tc)
mock_message.tool_calls = mock_tool_calls
else:
mock_message.tool_calls = None
mock_choice = MagicMock()
mock_choice.message = mock_message
mock_response = MagicMock()
mock_response.choices = [mock_choice]
return mock_response
# Concrete class for testing
class ConcreteOpenAICompatibleBot(OpenAICompatibleInferenceBot):
# Implement abstract methods for instantiation
async def switch_model(self):
# Simple switch for testing if needed, or just pass
if self.model == self.small_model_name:
self._configure_model_and_tokens(self.large_model_name, self.large_model_max_tokens_str)
else:
self._configure_model_and_tokens(self.small_model_name, self.small_model_max_tokens_str)
return f"Switched to {self.model}"
# Override load_functions if it's called by parent and needs mocking for these tests
# (OpenAICompatibleInferenceBot's __init__ calls BaseTelegramInferenceBot's __init__, which calls load_functions)
def load_functions(self):
# For these tests, assume no tools unless specifically added
self.tools = []
self.functions = []
return self.tools, self.functions
class TestOpenAICompatibleInferenceBot(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.original_openai_api_key = os.environ.get("OPENAI_API_KEY")
self.original_azure_openai_key = os.environ.get("AZURE_OPENAI_KEY")
self.original_azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
self.original_api_version = os.environ.get("AZURE_OPENAI_API_VERSION")
self.original_azure_deployment = os.environ.get("AZURE_DEPLOYMENT_NAME")
# Clear relevant env vars before each test
for key in ["OPENAI_API_KEY", "AZURE_OPENAI_KEY", "AZURE_OPENAI_ENDPOINT",
"AZURE_OPENAI_API_VERSION", "AZURE_DEPLOYMENT_NAME", "SYSTEM_PROMPT_PATH"]:
if os.environ.get(key):
del os.environ[key]
self.mock_openai_client_instance = MagicMock()
self.mock_openai_client_instance.chat.completions.create = MagicMock()
def tearDown(self):
# Restore environment variables
if self.original_openai_api_key: os.environ["OPENAI_API_KEY"] = self.original_openai_api_key
if self.original_azure_openai_key: os.environ["AZURE_OPENAI_KEY"] = self.original_azure_openai_key
if self.original_azure_endpoint: os.environ["AZURE_OPENAI_ENDPOINT"] = self.original_azure_endpoint
if self.original_api_version: os.environ["AZURE_OPENAI_API_VERSION"] = self.original_api_version
if self.original_azure_deployment: os.environ["AZURE_DEPLOYMENT_NAME"] = self.original_azure_deployment
@patch('openai.OpenAI')
def test_init_with_openai_defaults(self, MockOpenAIConstructor):
MockOpenAIConstructor.return_value = self.mock_openai_client_instance
os.environ["OPENAI_API_KEY"] = "test_openai_key"
bot = ConcreteOpenAICompatibleBot(model_name="gpt-4")
MockOpenAIConstructor.assert_called_once_with(api_key="test_openai_key", base_url=None)
self.assertEqual(bot.client, self.mock_openai_client_instance)
self.assertEqual(bot.model, "gpt-4")
self.assertEqual(bot.max_tokens, 1000) # Default from _configure_model_and_tokens
self.assertEqual(bot.azure_openai, False)
@patch('openai.OpenAI')
def test_init_with_provided_client(self, MockOpenAIConstructor):
preconfigured_client = MagicMock()
bot = ConcreteOpenAICompatibleBot(client=preconfigured_client, model_name="gpt-3.5")
MockOpenAIConstructor.assert_not_called()
self.assertEqual(bot.client, preconfigured_client)
self.assertEqual(bot.model, "gpt-3.5")
@patch('openai.AzureOpenAI')
def test_init_with_azure_config_args(self, MockAzureOpenAIConstructor):
MockAzureOpenAIConstructor.return_value = self.mock_openai_client_instance
bot = ConcreteOpenAICompatibleBot(
api_key="azure_key",
azure_endpoint="https://myenv.openai.azure.com",
api_version="2023-05-15",
azure_deployment="my-gpt-4", # This should be used as model_name for API call
model_name="should_be_overridden_by_azure_deployment_for_api"
# model_name is passed to _configure_model_and_tokens, which sets self.model for display/logging
# but for Azure, the client needs the deployment name.
)
MockAzureOpenAIConstructor.assert_called_once_with(
api_key="azure_key",
azure_endpoint="https://myenv.openai.azure.com",
api_version="2023-05-15"
)
self.assertEqual(bot.client, self.mock_openai_client_instance)
self.assertEqual(bot.model, "my-gpt-4") # Azure deployment name becomes the model for API calls
self.assertEqual(bot.azure_openai, True)
@patch('openai.AzureOpenAI')
def test_init_with_azure_env_vars(self, MockAzureOpenAIConstructor):
MockAzureOpenAIConstructor.return_value = self.mock_openai_client_instance
os.environ["AZURE_OPENAI_KEY"] = "env_azure_key"
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://env.openai.azure.com"
os.environ["AZURE_OPENAI_API_VERSION"] = "2023-06-01"
os.environ["AZURE_DEPLOYMENT_NAME"] = "env-gpt-35" # Used as model_name
bot = ConcreteOpenAICompatibleBot(model_name="ignored_if_azure_deployment_env_is_set")
MockAzureOpenAIConstructor.assert_called_once_with(
api_key="env_azure_key",
azure_endpoint="https://env.openai.azure.com",
api_version="2023-06-01"
)
self.assertEqual(bot.model, "env-gpt-35")
self.assertTrue(bot.azure_openai)
@patch('openai.OpenAI')
def test_init_with_gemini_config_args(self, MockOpenAIConstructor):
MockOpenAIConstructor.return_value = self.mock_openai_client_instance
bot = ConcreteOpenAICompatibleBot(
api_key="gemini_key",
base_url="https://gemini.example.com",
model_name="gemini-pro",
is_gemini=True
)
MockOpenAIConstructor.assert_called_once_with(api_key="gemini_key", base_url="https://gemini.example.com")
self.assertEqual(bot.model, "gemini-pro")
self.assertFalse(bot.azure_openai) # is_gemini doesn't mean azure_openai
def test_configure_model_and_tokens(self):
bot = ConcreteOpenAICompatibleBot(model_name="initial_model") # init calls _configure
bot._configure_model_and_tokens("test-model", "500")
self.assertEqual(bot.model, "test-model")
self.assertEqual(bot.max_tokens, 500)
bot._configure_model_and_tokens("test-model-2", None, default_max_tokens=150)
self.assertEqual(bot.max_tokens, 150)
bot._configure_model_and_tokens("test-model-3", "invalid_token_val")
self.assertEqual(bot.max_tokens, 1000) # Default fallback
def test_get_llm_description(self):
bot = ConcreteOpenAICompatibleBot(model_name="desc-model", max_tokens_str="256")
self.assertEqual(bot.get_llm_description(), "LLM: desc-model, Max Tokens: 256, Azure: False")
bot_azure = ConcreteOpenAICompatibleBot(azure_deployment="azure-model", azure_endpoint="x", api_key="y", api_version="z")
self.assertEqual(bot_azure.get_llm_description(), "LLM: azure-model, Max Tokens: 1000, Azure: True")
def test_get_chat_response_success(self):
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="test-gpt")
bot.max_tokens = 50 # Ensure this is set
mock_api_response = create_mock_openai_response(content="Hello from API")
self.mock_openai_client_instance.chat.completions.create.return_value = mock_api_response
messages = [{"role": "user", "content": "Hi"}]
response = bot.get_chat_response(messages)
self.mock_openai_client_instance.chat.completions.create.assert_called_once_with(
model="test-gpt",
messages=messages,
tools=ANY, # Assuming functions can be None or empty list
tool_choice=ANY,
max_tokens=50
)
self.assertEqual(response, mock_api_response)
def test_get_chat_response_api_error(self):
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="error-gpt")
self.mock_openai_client_instance.chat.completions.create.side_effect = Exception("API Down")
with self.assertRaisesRegex(Exception, "API Down"):
bot.get_chat_response([{"role": "user", "content": "trigger"}])
async def test_handle_message_simple_response(self):
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="chatty")
bot.system_prompt = "You are a test bot." # Set directly for simplicity
mock_api_response = create_mock_openai_response(content="Test reply")
self.mock_openai_client_instance.chat.completions.create.return_value = mock_api_response
response_content = await bot.handle_message(user_id=1, user_message="Hello")
self.assertEqual(response_content, "Test reply")
self.assertIn(1, bot.conversation_history)
self.assertEqual(len(bot.conversation_history[1]), 3) # System, User, Assistant
self.assertEqual(bot.conversation_history[1][0]["content"], "You are a test bot.")
self.assertEqual(bot.conversation_history[1][2]["content"], "Test reply")
async def test_handle_message_with_tool_call_and_response(self):
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="tool-user")
# Mock functions/tools setup on the bot
mock_tool_def = {"function": {"name": "get_weather", "description": "Gets weather", "parameters": {}}}
bot.functions = [mock_tool_def] # Simulate tools are loaded
# API response 1: Request to call tool
tool_call_request = [{"id": "call123", "function": {"name": "get_weather", "arguments": '''{"location": "moon"}'''}}]
api_response_1 = create_mock_openai_response(tool_calls=tool_call_request)
# API response 2: Final answer after tool execution
api_response_2 = create_mock_openai_response(content="The weather on the moon is chilly.")
self.mock_openai_client_instance.chat.completions.create.side_effect = [api_response_1, api_response_2]
# Mock self.call_tool
bot.call_tool = MagicMock(return_value='''{"temperature": "-100 C"}''')
final_response = await bot.handle_message(user_id=2, user_message="Weather on moon?")
self.assertEqual(final_response, "The weather on the moon is chilly.")
bot.call_tool.assert_called_once_with("get_weather", '''{"location": "moon"}''')
# Check conversation history includes tool messages
history = bot.conversation_history[2]
self.assertTrue(any(msg["role"] == "assistant" and msg.tool_calls is not None for msg in history))
self.assertTrue(any(msg["role"] == "tool" and msg["name"] == "get_weather" for msg in history))
self.assertEqual(self.mock_openai_client_instance.chat.completions.create.call_count, 2)
async def test_handle_message_max_history_length(self):
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="hist-test", max_history_length=3)
self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Ok")
await bot.handle_message(1, "Msg1") # Sys, User, Assist (3)
self.assertEqual(len(bot.conversation_history[1]), 3)
await bot.handle_message(1, "Msg2") # User, Assist. Should be 3 (prev User, prev Assist, new User) -> then adds new Assist.
# Before new call: [Sys, U1, A1]. New U2. Call with [Sys,U1,A1,U2]. Resp A2.
# History: [Sys,U1,A1,U2,A2]. Limit 3. -> [A1,U2,A2] (if system is not preserved specially)
# The current code appends to history then truncates if over limit.
# So after Msg1: [S, U1, A1]. len=3.
# For Msg2: History is [S, U1, A1]. Append U2. Call with [S,U1,A1,U2]. Append A2.
# History now [S,U1,A1,U2,A2]. len=5. Truncate to 3.
# Expected: [A1, U2, A2] or [U1,A1,U2] or [U2,A2,S] depending on how system prompt is handled in truncation.
# The code is: self.conversation_history[user_id][-self.max_history_length:]
# And system prompt is only added IF user_id not in self.conversation_history.
# So, for Msg2, system prompt is not re-added.
# History before Msg2 call: [S, U1, A1]
# Messages for Msg2 call: [S, U1, A1, U2]
# History after Msg2 response A2: [S, U1, A1, U2, A2]. Len 5.
# Truncated to self.max_history_length=3: [A1, U2, A2]
# Call 1
self.mock_openai_client_instance.chat.completions.create.reset_mock()
self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply1")
await bot.handle_message(user_id=7, user_message="First message")
self.assertEqual(len(bot.conversation_history[7]), 3) # System, User1, Assistant1
# Call 2
self.mock_openai_client_instance.chat.completions.create.reset_mock()
self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply2")
await bot.handle_message(user_id=7, user_message="Second message")
# History before call: [S, U1, A1]. Messages for call: [S, U1, A1, U2]. History after: [S, U1, A1, U2, A2].
# Truncated to 3: [A1, U2, A2]
self.assertEqual(len(bot.conversation_history[7]), 3)
self.assertEqual(bot.conversation_history[7][0]["content"], "Reply1") # A1
self.assertEqual(bot.conversation_history[7][1]["content"], "Second message") # U2
self.assertEqual(bot.conversation_history[7][2]["content"], "Reply2") # A2
# Call 3
self.mock_openai_client_instance.chat.completions.create.reset_mock()
self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply3")
await bot.handle_message(user_id=7, user_message="Third message")
# History before call: [A1, U2, A2]. Messages for call: [A1, U2, A2, U3]. History after: [A1, U2, A2, U3, A3].
# Truncated to 3: [A2, U3, A3]
self.assertEqual(len(bot.conversation_history[7]), 3)
self.assertEqual(bot.conversation_history[7][0]["content"], "Reply2") # A2
self.assertEqual(bot.conversation_history[7][1]["content"], "Third message") # U3
self.assertEqual(bot.conversation_history[7][2]["content"], "Reply3") # A3
async def test_abort_processing(self):
bot = ConcreteOpenAICompatibleBot(model_name="test")
user_id = 123
bot.processing_status[user_id] = {"processing": True, "message_id": 456}
bot.conversation_history[user_id] = [{"role": "user", "content": "stuff"}]
with patch.object(bot, 'clear_conversation_history') as mock_clear_hist: # Patching the method from Base class
result = await bot.abort_processing(user_id)
self.assertEqual(result, "Processing aborted and conversation cleared.")
self.assertFalse(bot.processing_status[user_id]["processing"])
mock_clear_hist.assert_called_once_with(user_id)
async def test_abort_processing_no_active_processing(self):
bot = ConcreteOpenAICompatibleBot(model_name="test")
user_id = 404 # Not in processing_status
with patch.object(bot, 'clear_conversation_history') as mock_clear_hist:
result = await bot.abort_processing(user_id)
self.assertEqual(result, "No active processing found to abort. Conversation cleared.")
mock_clear_hist.assert_called_once_with(user_id)
# Test for the abstract switch_model (basic call, actual logic in concrete class for this test)
async def test_switch_model_concrete_implementation(self):
bot = ConcreteOpenAICompatibleBot(model_name="model1", small_model_name="model1", large_model_name="model2", max_tokens_str="100")
self.assertEqual(bot.model, "model1")
await bot.switch_model() # Calls the concrete implementation
self.assertEqual(bot.model, "model2")
await bot.switch_model()
self.assertEqual(bot.model, "model1")
if __name__ == '__main__':
unittest.main()