Add unit tests for GeminiTelegramInferenceBot
This commit is contained in:
@@ -0,0 +1,154 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch, ANY
|
||||
import os
|
||||
|
||||
# Assuming gemini_telegram_inference_bot.py and its parent are accessible
|
||||
from gemini_telegram_inference_bot import GeminiTelegramInferenceBot
|
||||
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot # For patching super
|
||||
|
||||
class TestGeminiTelegramInferenceBot(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Store and clear relevant environment variables
|
||||
self.original_gemini_key = os.environ.get("GEMINI_API_KEY")
|
||||
self.original_gemini_base_url = os.environ.get("GEMINI_API_BASE_URL")
|
||||
self.original_small_model = os.environ.get("GEMINI_SMALL_MODEL")
|
||||
self.original_large_model = os.environ.get("GEMINI_LARGE_MODEL")
|
||||
self.original_small_tokens = os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS")
|
||||
self.original_large_tokens = os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS")
|
||||
self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH")
|
||||
|
||||
for key in ["GEMINI_API_KEY", "GEMINI_API_BASE_URL", "GEMINI_SMALL_MODEL", "GEMINI_LARGE_MODEL",
|
||||
"GEMINI_SMALL_MODEL_MAX_TOKENS", "GEMINI_LARGE_MODEL_MAX_TOKENS", "SYSTEM_PROMPT_PATH"]:
|
||||
if os.environ.get(key):
|
||||
del os.environ[key]
|
||||
|
||||
self.mock_openai_client = MagicMock() # Used if superclass creates an OpenAI client
|
||||
|
||||
def tearDown(self):
|
||||
# Restore environment variables
|
||||
if self.original_gemini_key: os.environ["GEMINI_API_KEY"] = self.original_gemini_key
|
||||
if self.original_gemini_base_url: os.environ["GEMINI_API_BASE_URL"] = self.original_gemini_base_url
|
||||
if self.original_small_model: os.environ["GEMINI_SMALL_MODEL"] = self.original_small_model
|
||||
if self.original_large_model: os.environ["GEMINI_LARGE_MODEL"] = self.original_large_model
|
||||
if self.original_small_tokens: os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = self.original_small_tokens
|
||||
if self.original_large_tokens: os.environ["GEMINI_LARGE_MODEL_MAX_TOKENS"] = self.original_large_tokens
|
||||
if self.original_system_prompt_path: os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path
|
||||
|
||||
@patch.object(OpenAICompatibleInferenceBot, '__init__') # Mock the superclass's __init__
|
||||
def test_init_defaults_and_super_call(self, mock_super_init):
|
||||
os.environ["GEMINI_API_KEY"] = "test_key_gemini"
|
||||
os.environ["GEMINI_API_BASE_URL"] = "https://gemini.env.com"
|
||||
os.environ["GEMINI_SMALL_MODEL"] = "gemini-pro-env"
|
||||
os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = "360"
|
||||
|
||||
bot = GeminiTelegramInferenceBot()
|
||||
|
||||
mock_super_init.assert_called_once_with(
|
||||
client=None,
|
||||
api_key="test_key_gemini",
|
||||
base_url="https://gemini.env.com", # Passed to super
|
||||
api_version=None,
|
||||
azure_deployment=None,
|
||||
model_name="gemini-pro-env",
|
||||
max_tokens_str="360",
|
||||
small_model_name="gemini-pro-env",
|
||||
small_model_max_tokens_str="360",
|
||||
large_model_name=os.environ.get("GEMINI_LARGE_MODEL", "gemini-1.5-pro-latest"), # Default large
|
||||
large_model_max_tokens_str=os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS"),
|
||||
system_prompt_content=None,
|
||||
system_prompt_path=None,
|
||||
is_gemini=True, # Important for Gemini bot
|
||||
max_history_length=20
|
||||
)
|
||||
|
||||
@patch.object(OpenAICompatibleInferenceBot, '__init__')
|
||||
def test_init_with_arguments(self, mock_super_init):
|
||||
mock_client_arg = MagicMock()
|
||||
bot = GeminiTelegramInferenceBot(
|
||||
openai_client=mock_client_arg, # Name in Gemini bot is openai_client for consistency
|
||||
api_key="arg_gem_key",
|
||||
base_url="https://arg.gemini.com",
|
||||
small_model_name="arg_gem_small",
|
||||
small_model_max_tokens="124",
|
||||
large_model_name="arg_gem_large",
|
||||
large_model_max_tokens="457",
|
||||
system_prompt_content="Gemini prompt"
|
||||
)
|
||||
mock_super_init.assert_called_once_with(
|
||||
client=mock_client_arg,
|
||||
api_key="arg_gem_key",
|
||||
base_url="https://arg.gemini.com",
|
||||
api_version=None,
|
||||
azure_deployment=None,
|
||||
model_name="arg_gem_small",
|
||||
max_tokens_str="124",
|
||||
small_model_name="arg_gem_small",
|
||||
small_model_max_tokens_str="124",
|
||||
large_model_name="arg_gem_large",
|
||||
large_model_max_tokens_str="457",
|
||||
system_prompt_content="Gemini prompt",
|
||||
system_prompt_path=None,
|
||||
is_gemini=True,
|
||||
max_history_length=20
|
||||
)
|
||||
|
||||
@patch('openai.OpenAI') # Gemini bot uses OpenAI client configured for Gemini endpoint
|
||||
async def test_switch_model_logic(self, mock_openai_constructor):
|
||||
mock_openai_constructor.return_value = self.mock_openai_client
|
||||
|
||||
os.environ["GEMINI_SMALL_MODEL"] = "env-gemini-small"
|
||||
os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = "110"
|
||||
os.environ["GEMINI_LARGE_MODEL"] = "env-gemini-large"
|
||||
os.environ["GEMINI_LARGE_MODEL_MAX_TOKENS"] = "220"
|
||||
|
||||
bot = GeminiTelegramInferenceBot() # Uses env vars by default
|
||||
self.assertEqual(bot.model, "env-gemini-small")
|
||||
self.assertEqual(bot.max_tokens, 110)
|
||||
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "env-gemini-large")
|
||||
self.assertEqual(bot.max_tokens, 220)
|
||||
self.assertEqual(status, "Switched to model: env-gemini-large")
|
||||
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "env-gemini-small")
|
||||
self.assertEqual(bot.max_tokens, 110)
|
||||
self.assertEqual(status, "Switched to model: env-gemini-small")
|
||||
|
||||
@patch('openai.OpenAI')
|
||||
async def test_switch_model_uses_instance_configs_if_provided(self, mock_openai_constructor):
|
||||
mock_openai_constructor.return_value = self.mock_openai_client
|
||||
|
||||
bot = GeminiTelegramInferenceBot(
|
||||
small_model_name="init-gem-small", small_model_max_tokens="55",
|
||||
large_model_name="init-gem-large", large_model_max_tokens="155"
|
||||
)
|
||||
self.assertEqual(bot.model, "init-gem-small")
|
||||
self.assertEqual(bot.max_tokens, 55)
|
||||
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "init-gem-large")
|
||||
self.assertEqual(bot.max_tokens, 155)
|
||||
self.assertEqual(status, "Switched to model: init-gem-large")
|
||||
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "init-gem-small")
|
||||
self.assertEqual(bot.max_tokens, 55)
|
||||
self.assertEqual(status, "Switched to model: init-gem-small")
|
||||
|
||||
@patch('openai.OpenAI')
|
||||
def test_get_llm_description_for_gemini_bot(self, mock_openai_constructor):
|
||||
mock_openai_constructor.return_value = self.mock_openai_client
|
||||
bot = GeminiTelegramInferenceBot(
|
||||
small_model_name="gemini-pro-desc",
|
||||
small_model_max_tokens="888",
|
||||
# is_gemini is True by default in constructor call to super
|
||||
)
|
||||
# LLM description should indicate not Azure, even though it uses OpenAICompatible... base
|
||||
# The is_gemini flag primarily affects client instantiation logic in the superclass.
|
||||
# The azure_openai flag in superclass is based on azure_endpoint presence.
|
||||
self.assertEqual(bot.get_llm_description(), "LLM: gemini-pro-desc, Max Tokens: 888, Azure: False")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user