155 lines
7.4 KiB
Python
155 lines
7.4 KiB
Python
|
|
import unittest
|
||
|
|
from unittest.mock import MagicMock, patch, ANY
|
||
|
|
import os
|
||
|
|
|
||
|
|
# Assuming gemini_telegram_inference_bot.py and its parent are accessible
|
||
|
|
from gemini_telegram_inference_bot import GeminiTelegramInferenceBot
|
||
|
|
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot # For patching super
|
||
|
|
|
||
|
|
class TestGeminiTelegramInferenceBot(unittest.IsolatedAsyncioTestCase):
    """Unit tests for ``GeminiTelegramInferenceBot``.

    Covers:
    * how the constructor forwards configuration (from environment variables
      or explicit arguments) to ``OpenAICompatibleInferenceBot.__init__``;
    * toggling between the small and large model via ``switch_model``;
    * the string returned by ``get_llm_description``.
    """

    # Environment variables that may influence the bot's configuration. Every
    # test starts with all of them unset and they are restored afterwards.
    _ENV_KEYS = (
        "GEMINI_API_KEY",
        "GEMINI_API_BASE_URL",
        "GEMINI_SMALL_MODEL",
        "GEMINI_LARGE_MODEL",
        "GEMINI_SMALL_MODEL_MAX_TOKENS",
        "GEMINI_LARGE_MODEL_MAX_TOKENS",
        "SYSTEM_PROMPT_PATH",
    )

    def setUp(self):
        """Snapshot the relevant environment variables, then unset them."""
        # Keep the exact prior value (None means "was not set") so tearDown can
        # restore it faithfully, including variables set to the empty string.
        self._saved_env = {key: os.environ.get(key) for key in self._ENV_KEYS}
        for key in self._ENV_KEYS:
            # Unconditional pop: a truthiness check would leave "" values behind.
            os.environ.pop(key, None)

        self.mock_openai_client = MagicMock()  # Returned by the patched OpenAI constructor

    def tearDown(self):
        """Restore every snapshotted environment variable to its prior state."""
        for key, value in self._saved_env.items():
            if value is None:
                # Unset before the test: remove whatever the test set so it
                # cannot leak into later tests in the same run.
                os.environ.pop(key, None)
            else:
                os.environ[key] = value

    @patch.object(OpenAICompatibleInferenceBot, '__init__')  # Mock the superclass's __init__
    def test_init_defaults_and_super_call(self, mock_super_init):
        """With no arguments, configuration is taken from GEMINI_* env vars."""
        os.environ["GEMINI_API_KEY"] = "test_key_gemini"
        os.environ["GEMINI_API_BASE_URL"] = "https://gemini.env.com"
        os.environ["GEMINI_SMALL_MODEL"] = "gemini-pro-env"
        os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = "360"
        # GEMINI_LARGE_MODEL and GEMINI_LARGE_MODEL_MAX_TOKENS remain unset (setUp
        # removed them), so the constructor's fallback values are expected below.

        GeminiTelegramInferenceBot()

        mock_super_init.assert_called_once_with(
            client=None,
            api_key="test_key_gemini",
            base_url="https://gemini.env.com",  # Passed to super
            api_version=None,
            azure_deployment=None,
            model_name="gemini-pro-env",
            max_tokens_str="360",
            small_model_name="gemini-pro-env",
            small_model_max_tokens_str="360",
            large_model_name="gemini-1.5-pro-latest",  # Default large model
            large_model_max_tokens_str=None,
            system_prompt_content=None,
            system_prompt_path=None,
            is_gemini=True,  # Important for Gemini bot
            max_history_length=20
        )

    @patch.object(OpenAICompatibleInferenceBot, '__init__')
    def test_init_with_arguments(self, mock_super_init):
        """Explicit constructor arguments are forwarded to the superclass."""
        mock_client_arg = MagicMock()
        GeminiTelegramInferenceBot(
            openai_client=mock_client_arg,  # Name in Gemini bot is openai_client for consistency
            api_key="arg_gem_key",
            base_url="https://arg.gemini.com",
            small_model_name="arg_gem_small",
            small_model_max_tokens="124",
            large_model_name="arg_gem_large",
            large_model_max_tokens="457",
            system_prompt_content="Gemini prompt"
        )
        mock_super_init.assert_called_once_with(
            client=mock_client_arg,
            api_key="arg_gem_key",
            base_url="https://arg.gemini.com",
            api_version=None,
            azure_deployment=None,
            model_name="arg_gem_small",
            max_tokens_str="124",
            small_model_name="arg_gem_small",
            small_model_max_tokens_str="124",
            large_model_name="arg_gem_large",
            large_model_max_tokens_str="457",
            system_prompt_content="Gemini prompt",
            system_prompt_path=None,
            is_gemini=True,
            max_history_length=20
        )

    # NOTE(review): patching 'openai.OpenAI' only intercepts construction if the
    # code under test looks the class up as `openai.OpenAI` at call time. If it
    # does `from openai import OpenAI`, patch that module's own name instead
    # (e.g. 'openai_compatible_inference_bot.OpenAI') -- TODO confirm.
    @patch('openai.OpenAI')  # Gemini bot uses OpenAI client configured for Gemini endpoint
    async def test_switch_model_logic(self, mock_openai_constructor):
        """switch_model toggles between env-configured small and large models."""
        mock_openai_constructor.return_value = self.mock_openai_client

        os.environ["GEMINI_SMALL_MODEL"] = "env-gemini-small"
        os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = "110"
        os.environ["GEMINI_LARGE_MODEL"] = "env-gemini-large"
        os.environ["GEMINI_LARGE_MODEL_MAX_TOKENS"] = "220"

        bot = GeminiTelegramInferenceBot()  # Uses env vars by default
        self.assertEqual(bot.model, "env-gemini-small")
        self.assertEqual(bot.max_tokens, 110)

        status = await bot.switch_model()
        self.assertEqual(bot.model, "env-gemini-large")
        self.assertEqual(bot.max_tokens, 220)
        self.assertEqual(status, "Switched to model: env-gemini-large")

        status = await bot.switch_model()
        self.assertEqual(bot.model, "env-gemini-small")
        self.assertEqual(bot.max_tokens, 110)
        self.assertEqual(status, "Switched to model: env-gemini-small")

    @patch('openai.OpenAI')  # See the patch-target note on test_switch_model_logic's decorator
    async def test_switch_model_uses_instance_configs_if_provided(self, mock_openai_constructor):
        """switch_model toggles between models given explicitly to the constructor."""
        mock_openai_constructor.return_value = self.mock_openai_client

        bot = GeminiTelegramInferenceBot(
            small_model_name="init-gem-small", small_model_max_tokens="55",
            large_model_name="init-gem-large", large_model_max_tokens="155"
        )
        self.assertEqual(bot.model, "init-gem-small")
        self.assertEqual(bot.max_tokens, 55)

        status = await bot.switch_model()
        self.assertEqual(bot.model, "init-gem-large")
        self.assertEqual(bot.max_tokens, 155)
        self.assertEqual(status, "Switched to model: init-gem-large")

        status = await bot.switch_model()
        self.assertEqual(bot.model, "init-gem-small")
        self.assertEqual(bot.max_tokens, 55)
        self.assertEqual(status, "Switched to model: init-gem-small")

    @patch('openai.OpenAI')
    def test_get_llm_description_for_gemini_bot(self, mock_openai_constructor):
        """get_llm_description reports the active model, its token limit and Azure=False."""
        mock_openai_constructor.return_value = self.mock_openai_client
        bot = GeminiTelegramInferenceBot(
            small_model_name="gemini-pro-desc",
            small_model_max_tokens="888",
        )
        # No Azure-specific settings are configured here, so the description is
        # expected to report "Azure: False". (Assumes the superclass derives its
        # Azure flag from Azure-specific configuration rather than from
        # is_gemini -- confirm against OpenAICompatibleInferenceBot.)
        self.assertEqual(bot.get_llm_description(), "LLM: gemini-pro-desc, Max Tokens: 888, Azure: False")
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Executed directly (not imported by a test runner): discover and run the
    # TestCase classes defined in this module.
    unittest.main()
|