Add unit tests for OpenAICompatibleInferenceBot
This commit is contained in:
@@ -0,0 +1,332 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock, ANY
|
||||
import os
|
||||
import json
|
||||
|
||||
# Assuming openai_compatible_inference_bot.py is in the parent directory or PYTHONPATH is set
|
||||
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot
|
||||
|
||||
# Mock response from OpenAI client's chat.completions.create
|
||||
def create_mock_openai_response(content=None, tool_calls=None):
|
||||
mock_message = MagicMock()
|
||||
mock_message.role = "assistant"
|
||||
mock_message.content = content
|
||||
if tool_calls:
|
||||
# tool_calls should be a list of objects with id and function (name, arguments)
|
||||
mock_tool_calls = []
|
||||
for tc in tool_calls:
|
||||
mock_tc = MagicMock()
|
||||
mock_tc.id = tc["id"]
|
||||
mock_tc.function.name = tc["function"]["name"]
|
||||
mock_tc.function.arguments = tc["function"]["arguments"]
|
||||
mock_tool_calls.append(mock_tc)
|
||||
mock_message.tool_calls = mock_tool_calls
|
||||
else:
|
||||
mock_message.tool_calls = None
|
||||
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message = mock_message
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [mock_choice]
|
||||
return mock_response
|
||||
|
||||
# Concrete class for testing
|
||||
class ConcreteOpenAICompatibleBot(OpenAICompatibleInferenceBot):
|
||||
# Implement abstract methods for instantiation
|
||||
async def switch_model(self):
|
||||
# Simple switch for testing if needed, or just pass
|
||||
if self.model == self.small_model_name:
|
||||
self._configure_model_and_tokens(self.large_model_name, self.large_model_max_tokens_str)
|
||||
else:
|
||||
self._configure_model_and_tokens(self.small_model_name, self.small_model_max_tokens_str)
|
||||
return f"Switched to {self.model}"
|
||||
|
||||
# Override load_functions if it's called by parent and needs mocking for these tests
|
||||
# (OpenAICompatibleInferenceBot's __init__ calls BaseTelegramInferenceBot's __init__, which calls load_functions)
|
||||
def load_functions(self):
|
||||
# For these tests, assume no tools unless specifically added
|
||||
self.tools = []
|
||||
self.functions = []
|
||||
return self.tools, self.functions
|
||||
|
||||
|
||||
class TestOpenAICompatibleInferenceBot(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.original_openai_api_key = os.environ.get("OPENAI_API_KEY")
|
||||
self.original_azure_openai_key = os.environ.get("AZURE_OPENAI_KEY")
|
||||
self.original_azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
|
||||
self.original_api_version = os.environ.get("AZURE_OPENAI_API_VERSION")
|
||||
self.original_azure_deployment = os.environ.get("AZURE_DEPLOYMENT_NAME")
|
||||
|
||||
# Clear relevant env vars before each test
|
||||
for key in ["OPENAI_API_KEY", "AZURE_OPENAI_KEY", "AZURE_OPENAI_ENDPOINT",
|
||||
"AZURE_OPENAI_API_VERSION", "AZURE_DEPLOYMENT_NAME", "SYSTEM_PROMPT_PATH"]:
|
||||
if os.environ.get(key):
|
||||
del os.environ[key]
|
||||
|
||||
self.mock_openai_client_instance = MagicMock()
|
||||
self.mock_openai_client_instance.chat.completions.create = MagicMock()
|
||||
|
||||
def tearDown(self):
|
||||
# Restore environment variables
|
||||
if self.original_openai_api_key: os.environ["OPENAI_API_KEY"] = self.original_openai_api_key
|
||||
if self.original_azure_openai_key: os.environ["AZURE_OPENAI_KEY"] = self.original_azure_openai_key
|
||||
if self.original_azure_endpoint: os.environ["AZURE_OPENAI_ENDPOINT"] = self.original_azure_endpoint
|
||||
if self.original_api_version: os.environ["AZURE_OPENAI_API_VERSION"] = self.original_api_version
|
||||
if self.original_azure_deployment: os.environ["AZURE_DEPLOYMENT_NAME"] = self.original_azure_deployment
|
||||
|
||||
|
||||
@patch('openai.OpenAI')
|
||||
def test_init_with_openai_defaults(self, MockOpenAIConstructor):
|
||||
MockOpenAIConstructor.return_value = self.mock_openai_client_instance
|
||||
os.environ["OPENAI_API_KEY"] = "test_openai_key"
|
||||
|
||||
bot = ConcreteOpenAICompatibleBot(model_name="gpt-4")
|
||||
|
||||
MockOpenAIConstructor.assert_called_once_with(api_key="test_openai_key", base_url=None)
|
||||
self.assertEqual(bot.client, self.mock_openai_client_instance)
|
||||
self.assertEqual(bot.model, "gpt-4")
|
||||
self.assertEqual(bot.max_tokens, 1000) # Default from _configure_model_and_tokens
|
||||
self.assertEqual(bot.azure_openai, False)
|
||||
|
||||
@patch('openai.OpenAI')
|
||||
def test_init_with_provided_client(self, MockOpenAIConstructor):
|
||||
preconfigured_client = MagicMock()
|
||||
bot = ConcreteOpenAICompatibleBot(client=preconfigured_client, model_name="gpt-3.5")
|
||||
|
||||
MockOpenAIConstructor.assert_not_called()
|
||||
self.assertEqual(bot.client, preconfigured_client)
|
||||
self.assertEqual(bot.model, "gpt-3.5")
|
||||
|
||||
@patch('openai.AzureOpenAI')
|
||||
def test_init_with_azure_config_args(self, MockAzureOpenAIConstructor):
|
||||
MockAzureOpenAIConstructor.return_value = self.mock_openai_client_instance
|
||||
|
||||
bot = ConcreteOpenAICompatibleBot(
|
||||
api_key="azure_key",
|
||||
azure_endpoint="https://myenv.openai.azure.com",
|
||||
api_version="2023-05-15",
|
||||
azure_deployment="my-gpt-4", # This should be used as model_name for API call
|
||||
model_name="should_be_overridden_by_azure_deployment_for_api"
|
||||
# model_name is passed to _configure_model_and_tokens, which sets self.model for display/logging
|
||||
# but for Azure, the client needs the deployment name.
|
||||
)
|
||||
|
||||
MockAzureOpenAIConstructor.assert_called_once_with(
|
||||
api_key="azure_key",
|
||||
azure_endpoint="https://myenv.openai.azure.com",
|
||||
api_version="2023-05-15"
|
||||
)
|
||||
self.assertEqual(bot.client, self.mock_openai_client_instance)
|
||||
self.assertEqual(bot.model, "my-gpt-4") # Azure deployment name becomes the model for API calls
|
||||
self.assertEqual(bot.azure_openai, True)
|
||||
|
||||
|
||||
@patch('openai.AzureOpenAI')
|
||||
def test_init_with_azure_env_vars(self, MockAzureOpenAIConstructor):
|
||||
MockAzureOpenAIConstructor.return_value = self.mock_openai_client_instance
|
||||
os.environ["AZURE_OPENAI_KEY"] = "env_azure_key"
|
||||
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://env.openai.azure.com"
|
||||
os.environ["AZURE_OPENAI_API_VERSION"] = "2023-06-01"
|
||||
os.environ["AZURE_DEPLOYMENT_NAME"] = "env-gpt-35" # Used as model_name
|
||||
|
||||
bot = ConcreteOpenAICompatibleBot(model_name="ignored_if_azure_deployment_env_is_set")
|
||||
|
||||
MockAzureOpenAIConstructor.assert_called_once_with(
|
||||
api_key="env_azure_key",
|
||||
azure_endpoint="https://env.openai.azure.com",
|
||||
api_version="2023-06-01"
|
||||
)
|
||||
self.assertEqual(bot.model, "env-gpt-35")
|
||||
self.assertTrue(bot.azure_openai)
|
||||
|
||||
@patch('openai.OpenAI')
|
||||
def test_init_with_gemini_config_args(self, MockOpenAIConstructor):
|
||||
MockOpenAIConstructor.return_value = self.mock_openai_client_instance
|
||||
|
||||
bot = ConcreteOpenAICompatibleBot(
|
||||
api_key="gemini_key",
|
||||
base_url="https://gemini.example.com",
|
||||
model_name="gemini-pro",
|
||||
is_gemini=True
|
||||
)
|
||||
MockOpenAIConstructor.assert_called_once_with(api_key="gemini_key", base_url="https://gemini.example.com")
|
||||
self.assertEqual(bot.model, "gemini-pro")
|
||||
self.assertFalse(bot.azure_openai) # is_gemini doesn't mean azure_openai
|
||||
|
||||
def test_configure_model_and_tokens(self):
|
||||
bot = ConcreteOpenAICompatibleBot(model_name="initial_model") # init calls _configure
|
||||
bot._configure_model_and_tokens("test-model", "500")
|
||||
self.assertEqual(bot.model, "test-model")
|
||||
self.assertEqual(bot.max_tokens, 500)
|
||||
|
||||
bot._configure_model_and_tokens("test-model-2", None, default_max_tokens=150)
|
||||
self.assertEqual(bot.max_tokens, 150)
|
||||
|
||||
bot._configure_model_and_tokens("test-model-3", "invalid_token_val")
|
||||
self.assertEqual(bot.max_tokens, 1000) # Default fallback
|
||||
|
||||
def test_get_llm_description(self):
|
||||
bot = ConcreteOpenAICompatibleBot(model_name="desc-model", max_tokens_str="256")
|
||||
self.assertEqual(bot.get_llm_description(), "LLM: desc-model, Max Tokens: 256, Azure: False")
|
||||
|
||||
bot_azure = ConcreteOpenAICompatibleBot(azure_deployment="azure-model", azure_endpoint="x", api_key="y", api_version="z")
|
||||
self.assertEqual(bot_azure.get_llm_description(), "LLM: azure-model, Max Tokens: 1000, Azure: True")
|
||||
|
||||
|
||||
def test_get_chat_response_success(self):
|
||||
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="test-gpt")
|
||||
bot.max_tokens = 50 # Ensure this is set
|
||||
mock_api_response = create_mock_openai_response(content="Hello from API")
|
||||
self.mock_openai_client_instance.chat.completions.create.return_value = mock_api_response
|
||||
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
response = bot.get_chat_response(messages)
|
||||
|
||||
self.mock_openai_client_instance.chat.completions.create.assert_called_once_with(
|
||||
model="test-gpt",
|
||||
messages=messages,
|
||||
tools=ANY, # Assuming functions can be None or empty list
|
||||
tool_choice=ANY,
|
||||
max_tokens=50
|
||||
)
|
||||
self.assertEqual(response, mock_api_response)
|
||||
|
||||
def test_get_chat_response_api_error(self):
|
||||
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="error-gpt")
|
||||
self.mock_openai_client_instance.chat.completions.create.side_effect = Exception("API Down")
|
||||
|
||||
with self.assertRaisesRegex(Exception, "API Down"):
|
||||
bot.get_chat_response([{"role": "user", "content": "trigger"}])
|
||||
|
||||
async def test_handle_message_simple_response(self):
|
||||
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="chatty")
|
||||
bot.system_prompt = "You are a test bot." # Set directly for simplicity
|
||||
mock_api_response = create_mock_openai_response(content="Test reply")
|
||||
self.mock_openai_client_instance.chat.completions.create.return_value = mock_api_response
|
||||
|
||||
response_content = await bot.handle_message(user_id=1, user_message="Hello")
|
||||
|
||||
self.assertEqual(response_content, "Test reply")
|
||||
self.assertIn(1, bot.conversation_history)
|
||||
self.assertEqual(len(bot.conversation_history[1]), 3) # System, User, Assistant
|
||||
self.assertEqual(bot.conversation_history[1][0]["content"], "You are a test bot.")
|
||||
self.assertEqual(bot.conversation_history[1][2]["content"], "Test reply")
|
||||
|
||||
async def test_handle_message_with_tool_call_and_response(self):
|
||||
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="tool-user")
|
||||
|
||||
# Mock functions/tools setup on the bot
|
||||
mock_tool_def = {"function": {"name": "get_weather", "description": "Gets weather", "parameters": {}}}
|
||||
bot.functions = [mock_tool_def] # Simulate tools are loaded
|
||||
|
||||
# API response 1: Request to call tool
|
||||
tool_call_request = [{"id": "call123", "function": {"name": "get_weather", "arguments": '''{"location": "moon"}'''}}]
|
||||
api_response_1 = create_mock_openai_response(tool_calls=tool_call_request)
|
||||
|
||||
# API response 2: Final answer after tool execution
|
||||
api_response_2 = create_mock_openai_response(content="The weather on the moon is chilly.")
|
||||
|
||||
self.mock_openai_client_instance.chat.completions.create.side_effect = [api_response_1, api_response_2]
|
||||
|
||||
# Mock self.call_tool
|
||||
bot.call_tool = MagicMock(return_value='''{"temperature": "-100 C"}''')
|
||||
|
||||
final_response = await bot.handle_message(user_id=2, user_message="Weather on moon?")
|
||||
|
||||
self.assertEqual(final_response, "The weather on the moon is chilly.")
|
||||
bot.call_tool.assert_called_once_with("get_weather", '''{"location": "moon"}''')
|
||||
|
||||
# Check conversation history includes tool messages
|
||||
history = bot.conversation_history[2]
|
||||
self.assertTrue(any(msg["role"] == "assistant" and msg.tool_calls is not None for msg in history))
|
||||
self.assertTrue(any(msg["role"] == "tool" and msg["name"] == "get_weather" for msg in history))
|
||||
self.assertEqual(self.mock_openai_client_instance.chat.completions.create.call_count, 2)
|
||||
|
||||
async def test_handle_message_max_history_length(self):
|
||||
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="hist-test", max_history_length=3)
|
||||
self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Ok")
|
||||
|
||||
await bot.handle_message(1, "Msg1") # Sys, User, Assist (3)
|
||||
self.assertEqual(len(bot.conversation_history[1]), 3)
|
||||
|
||||
await bot.handle_message(1, "Msg2") # User, Assist. Should be 3 (prev User, prev Assist, new User) -> then adds new Assist.
|
||||
# Before new call: [Sys, U1, A1]. New U2. Call with [Sys,U1,A1,U2]. Resp A2.
|
||||
# History: [Sys,U1,A1,U2,A2]. Limit 3. -> [A1,U2,A2] (if system is not preserved specially)
|
||||
# The current code appends to history then truncates if over limit.
|
||||
# So after Msg1: [S, U1, A1]. len=3.
|
||||
# For Msg2: History is [S, U1, A1]. Append U2. Call with [S,U1,A1,U2]. Append A2.
|
||||
# History now [S,U1,A1,U2,A2]. len=5. Truncate to 3.
|
||||
# Expected: [A1, U2, A2] or [U1,A1,U2] or [U2,A2,S] depending on how system prompt is handled in truncation.
|
||||
# The code is: self.conversation_history[user_id][-self.max_history_length:]
|
||||
# And system prompt is only added IF user_id not in self.conversation_history.
|
||||
# So, for Msg2, system prompt is not re-added.
|
||||
# History before Msg2 call: [S, U1, A1]
|
||||
# Messages for Msg2 call: [S, U1, A1, U2]
|
||||
# History after Msg2 response A2: [S, U1, A1, U2, A2]. Len 5.
|
||||
# Truncated to self.max_history_length=3: [A1, U2, A2]
|
||||
|
||||
# Call 1
|
||||
self.mock_openai_client_instance.chat.completions.create.reset_mock()
|
||||
self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply1")
|
||||
await bot.handle_message(user_id=7, user_message="First message")
|
||||
self.assertEqual(len(bot.conversation_history[7]), 3) # System, User1, Assistant1
|
||||
|
||||
# Call 2
|
||||
self.mock_openai_client_instance.chat.completions.create.reset_mock()
|
||||
self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply2")
|
||||
await bot.handle_message(user_id=7, user_message="Second message")
|
||||
# History before call: [S, U1, A1]. Messages for call: [S, U1, A1, U2]. History after: [S, U1, A1, U2, A2].
|
||||
# Truncated to 3: [A1, U2, A2]
|
||||
self.assertEqual(len(bot.conversation_history[7]), 3)
|
||||
self.assertEqual(bot.conversation_history[7][0]["content"], "Reply1") # A1
|
||||
self.assertEqual(bot.conversation_history[7][1]["content"], "Second message") # U2
|
||||
self.assertEqual(bot.conversation_history[7][2]["content"], "Reply2") # A2
|
||||
|
||||
# Call 3
|
||||
self.mock_openai_client_instance.chat.completions.create.reset_mock()
|
||||
self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply3")
|
||||
await bot.handle_message(user_id=7, user_message="Third message")
|
||||
# History before call: [A1, U2, A2]. Messages for call: [A1, U2, A2, U3]. History after: [A1, U2, A2, U3, A3].
|
||||
# Truncated to 3: [A2, U3, A3]
|
||||
self.assertEqual(len(bot.conversation_history[7]), 3)
|
||||
self.assertEqual(bot.conversation_history[7][0]["content"], "Reply2") # A2
|
||||
self.assertEqual(bot.conversation_history[7][1]["content"], "Third message") # U3
|
||||
self.assertEqual(bot.conversation_history[7][2]["content"], "Reply3") # A3
|
||||
|
||||
|
||||
async def test_abort_processing(self):
|
||||
bot = ConcreteOpenAICompatibleBot(model_name="test")
|
||||
user_id = 123
|
||||
bot.processing_status[user_id] = {"processing": True, "message_id": 456}
|
||||
bot.conversation_history[user_id] = [{"role": "user", "content": "stuff"}]
|
||||
|
||||
with patch.object(bot, 'clear_conversation_history') as mock_clear_hist: # Patching the method from Base class
|
||||
result = await bot.abort_processing(user_id)
|
||||
|
||||
self.assertEqual(result, "Processing aborted and conversation cleared.")
|
||||
self.assertFalse(bot.processing_status[user_id]["processing"])
|
||||
mock_clear_hist.assert_called_once_with(user_id)
|
||||
|
||||
async def test_abort_processing_no_active_processing(self):
|
||||
bot = ConcreteOpenAICompatibleBot(model_name="test")
|
||||
user_id = 404 # Not in processing_status
|
||||
with patch.object(bot, 'clear_conversation_history') as mock_clear_hist:
|
||||
result = await bot.abort_processing(user_id)
|
||||
self.assertEqual(result, "No active processing found to abort. Conversation cleared.")
|
||||
mock_clear_hist.assert_called_once_with(user_id)
|
||||
|
||||
# Test for the abstract switch_model (basic call, actual logic in concrete class for this test)
|
||||
async def test_switch_model_concrete_implementation(self):
|
||||
bot = ConcreteOpenAICompatibleBot(model_name="model1", small_model_name="model1", large_model_name="model2", max_tokens_str="100")
|
||||
self.assertEqual(bot.model, "model1")
|
||||
await bot.switch_model() # Calls the concrete implementation
|
||||
self.assertEqual(bot.model, "model2")
|
||||
await bot.switch_model()
|
||||
self.assertEqual(bot.model, "model1")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user