# cyclop/tests/test_gemini_telegram_inference_bot.py
# (captured from a repository web view: 155 lines, 7.4 KiB, Python)

import unittest
from unittest.mock import MagicMock, patch, ANY
import os
# Assuming gemini_telegram_inference_bot.py and its parent are accessible
from gemini_telegram_inference_bot import GeminiTelegramInferenceBot
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot # For patching super
class TestGeminiTelegramInferenceBot(unittest.IsolatedAsyncioTestCase):
    """Tests for GeminiTelegramInferenceBot construction, model switching and description.

    Each test starts with the Gemini-related environment variables cleared so
    that the developer's shell configuration cannot influence the assertions;
    tearDown puts the environment back exactly as it was.
    """

    # Environment variables these tests set or depend on. They are cleared in
    # setUp and restored (or removed again, if originally unset) in tearDown.
    _ENV_KEYS = (
        "GEMINI_API_KEY",
        "GEMINI_API_BASE_URL",
        "GEMINI_SMALL_MODEL",
        "GEMINI_LARGE_MODEL",
        "GEMINI_SMALL_MODEL_MAX_TOKENS",
        "GEMINI_LARGE_MODEL_MAX_TOKENS",
        "SYSTEM_PROMPT_PATH",
    )

    def setUp(self):
        """Snapshot and clear the environment variables listed in _ENV_KEYS."""
        # None marks "was unset", which is distinct from "was set to ''".
        self._saved_env = {key: os.environ.get(key) for key in self._ENV_KEYS}
        for key in self._ENV_KEYS:
            # pop() also clears variables set to "", which a truthiness check would skip.
            os.environ.pop(key, None)
        self.mock_openai_client = MagicMock()  # Used if superclass creates an OpenAI client

    def tearDown(self):
        """Restore every variable in _ENV_KEYS to its pre-test state."""
        for key, value in self._saved_env.items():
            if value is None:
                # The test may have set a variable that was originally absent;
                # remove it so it does not leak into later tests.
                os.environ.pop(key, None)
            else:
                os.environ[key] = value

    @patch.object(OpenAICompatibleInferenceBot, '__init__')  # Mock the superclass's __init__
    def test_init_defaults_and_super_call(self, mock_super_init):
        """Without constructor arguments the bot forwards GEMINI_* env values (or defaults) to super()."""
        os.environ["GEMINI_API_KEY"] = "test_key_gemini"
        os.environ["GEMINI_API_BASE_URL"] = "https://gemini.env.com"
        os.environ["GEMINI_SMALL_MODEL"] = "gemini-pro-env"
        os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = "360"

        GeminiTelegramInferenceBot()

        mock_super_init.assert_called_once_with(
            client=None,
            api_key="test_key_gemini",
            base_url="https://gemini.env.com",  # Passed to super
            api_version=None,
            azure_deployment=None,
            model_name="gemini-pro-env",
            max_tokens_str="360",
            small_model_name="gemini-pro-env",
            small_model_max_tokens_str="360",
            # GEMINI_LARGE_MODEL* were cleared in setUp, so the defaults apply.
            large_model_name="gemini-1.5-pro-latest",
            large_model_max_tokens_str=None,
            system_prompt_content=None,
            system_prompt_path=None,
            is_gemini=True,  # Important for Gemini bot
            max_history_length=20,
        )

    @patch.object(OpenAICompatibleInferenceBot, '__init__')
    def test_init_with_arguments(self, mock_super_init):
        """Explicit constructor arguments are forwarded to super() under its parameter names."""
        mock_client_arg = MagicMock()

        GeminiTelegramInferenceBot(
            openai_client=mock_client_arg,  # Name in Gemini bot is openai_client for consistency
            api_key="arg_gem_key",
            base_url="https://arg.gemini.com",
            small_model_name="arg_gem_small",
            small_model_max_tokens="124",
            large_model_name="arg_gem_large",
            large_model_max_tokens="457",
            system_prompt_content="Gemini prompt",
        )

        mock_super_init.assert_called_once_with(
            client=mock_client_arg,
            api_key="arg_gem_key",
            base_url="https://arg.gemini.com",
            api_version=None,
            azure_deployment=None,
            model_name="arg_gem_small",
            max_tokens_str="124",
            small_model_name="arg_gem_small",
            small_model_max_tokens_str="124",
            large_model_name="arg_gem_large",
            large_model_max_tokens_str="457",
            system_prompt_content="Gemini prompt",
            system_prompt_path=None,
            is_gemini=True,
            max_history_length=20,
        )

    # NOTE(review): this patches the attribute on the `openai` module. If the
    # superclass module does `from openai import OpenAI`, the patch target
    # should be that module's name instead - confirm against the superclass.
    @patch('openai.OpenAI')  # Gemini bot uses OpenAI client configured for Gemini endpoint
    async def test_switch_model_logic(self, mock_openai_constructor):
        """switch_model toggles between the env-configured small and large models."""
        mock_openai_constructor.return_value = self.mock_openai_client
        os.environ["GEMINI_SMALL_MODEL"] = "env-gemini-small"
        os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = "110"
        os.environ["GEMINI_LARGE_MODEL"] = "env-gemini-large"
        os.environ["GEMINI_LARGE_MODEL_MAX_TOKENS"] = "220"

        bot = GeminiTelegramInferenceBot()  # Uses env vars by default
        self.assertEqual(bot.model, "env-gemini-small")
        self.assertEqual(bot.max_tokens, 110)

        status = await bot.switch_model()
        self.assertEqual(bot.model, "env-gemini-large")
        self.assertEqual(bot.max_tokens, 220)
        self.assertEqual(status, "Switched to model: env-gemini-large")

        status = await bot.switch_model()
        self.assertEqual(bot.model, "env-gemini-small")
        self.assertEqual(bot.max_tokens, 110)
        self.assertEqual(status, "Switched to model: env-gemini-small")

    @patch('openai.OpenAI')
    async def test_switch_model_uses_instance_configs_if_provided(self, mock_openai_constructor):
        """switch_model toggles between the models passed to the constructor."""
        mock_openai_constructor.return_value = self.mock_openai_client
        bot = GeminiTelegramInferenceBot(
            small_model_name="init-gem-small", small_model_max_tokens="55",
            large_model_name="init-gem-large", large_model_max_tokens="155",
        )
        self.assertEqual(bot.model, "init-gem-small")
        self.assertEqual(bot.max_tokens, 55)

        status = await bot.switch_model()
        self.assertEqual(bot.model, "init-gem-large")
        self.assertEqual(bot.max_tokens, 155)
        self.assertEqual(status, "Switched to model: init-gem-large")

        status = await bot.switch_model()
        self.assertEqual(bot.model, "init-gem-small")
        self.assertEqual(bot.max_tokens, 55)
        self.assertEqual(status, "Switched to model: init-gem-small")

    @patch('openai.OpenAI')
    def test_get_llm_description_for_gemini_bot(self, mock_openai_constructor):
        """get_llm_description reports the active model, its token limit and Azure: False."""
        mock_openai_constructor.return_value = self.mock_openai_client
        bot = GeminiTelegramInferenceBot(
            small_model_name="gemini-pro-desc",
            small_model_max_tokens="888",
            # is_gemini is True by default in constructor call to super
        )
        # The description should report Azure: False even though the bot builds
        # on the OpenAI-compatible base class. Per the original author, is_gemini
        # mainly affects client instantiation in the superclass, and the
        # superclass's azure flag depends on whether an azure endpoint is set.
        self.assertEqual(bot.get_llm_description(), "LLM: gemini-pro-desc, Max Tokens: 888, Azure: False")
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()