import os
import unittest
from unittest.mock import ANY, MagicMock, patch

# Assumes gemini_telegram_inference_bot.py and its parent module are importable
# (e.g. the tests are run from the directory that contains them).
from gemini_telegram_inference_bot import GeminiTelegramInferenceBot
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot  # For patching super


class TestGeminiTelegramInferenceBot(unittest.IsolatedAsyncioTestCase):
    """Tests for GeminiTelegramInferenceBot construction, model switching and description."""

    # Environment variables that influence the bot's configuration; every test
    # starts with all of them unset so results do not depend on the developer's shell.
    _ENV_KEYS = (
        "GEMINI_API_KEY",
        "GEMINI_API_BASE_URL",
        "GEMINI_SMALL_MODEL",
        "GEMINI_LARGE_MODEL",
        "GEMINI_SMALL_MODEL_MAX_TOKENS",
        "GEMINI_LARGE_MODEL_MAX_TOKENS",
        "SYSTEM_PROMPT_PATH",
    )

    def setUp(self):
        """Isolate ``os.environ`` for the duration of each test and clear the bot's variables.

        ``patch.dict`` snapshots the entire environment and restores it exactly when
        the cleanup runs: variables a test sets are removed again, and pre-existing
        values (including empty strings) come back unchanged. The previous manual
        save/restore left test-set variables behind whenever no original value existed.
        """
        env_patcher = patch.dict(os.environ)
        env_patcher.start()
        self.addCleanup(env_patcher.stop)
        for key in self._ENV_KEYS:
            os.environ.pop(key, None)

        # Returned by the patched OpenAI constructor in tests that build a real bot.
        self.mock_openai_client = MagicMock()

    @patch.object(OpenAICompatibleInferenceBot, '__init__')  # Mock the superclass's __init__
    def test_init_defaults_and_super_call(self, mock_super_init):
        """Env vars are forwarded to super(); unset large-model settings fall back to defaults."""
        os.environ["GEMINI_API_KEY"] = "test_key_gemini"
        os.environ["GEMINI_API_BASE_URL"] = "https://gemini.env.com"
        os.environ["GEMINI_SMALL_MODEL"] = "gemini-pro-env"
        os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = "360"

        GeminiTelegramInferenceBot()

        mock_super_init.assert_called_once_with(
            client=None,
            api_key="test_key_gemini",
            base_url="https://gemini.env.com",  # Passed to super
            api_version=None,
            azure_deployment=None,
            model_name="gemini-pro-env",
            max_tokens_str="360",
            small_model_name="gemini-pro-env",
            small_model_max_tokens_str="360",
            # GEMINI_LARGE_MODEL* were cleared in setUp, so the defaults must apply.
            large_model_name="gemini-1.5-pro-latest",
            large_model_max_tokens_str=None,
            system_prompt_content=None,
            system_prompt_path=None,
            is_gemini=True,  # Important for Gemini bot
            max_history_length=20,
        )

    @patch.object(OpenAICompatibleInferenceBot, '__init__')
    def test_init_with_arguments(self, mock_super_init):
        """Explicit constructor arguments are forwarded to super() unchanged."""
        mock_client_arg = MagicMock()

        GeminiTelegramInferenceBot(
            openai_client=mock_client_arg,  # Name in Gemini bot is openai_client for consistency
            api_key="arg_gem_key",
            base_url="https://arg.gemini.com",
            small_model_name="arg_gem_small",
            small_model_max_tokens="124",
            large_model_name="arg_gem_large",
            large_model_max_tokens="457",
            system_prompt_content="Gemini prompt",
        )

        mock_super_init.assert_called_once_with(
            client=mock_client_arg,
            api_key="arg_gem_key",
            base_url="https://arg.gemini.com",
            api_version=None,
            azure_deployment=None,
            model_name="arg_gem_small",
            max_tokens_str="124",
            small_model_name="arg_gem_small",
            small_model_max_tokens_str="124",
            large_model_name="arg_gem_large",
            large_model_max_tokens_str="457",
            system_prompt_content="Gemini prompt",
            system_prompt_path=None,
            is_gemini=True,
            max_history_length=20,
        )

    # Gemini bot uses an OpenAI client configured for the Gemini endpoint.
    # NOTE(review): patching 'openai.OpenAI' only takes effect if the bot module
    # resolves OpenAI through the openai module at call time; if it does
    # `from openai import OpenAI`, the patch target should be that module's name.
    @patch('openai.OpenAI')
    async def test_switch_model_logic(self, mock_openai_constructor):
        """switch_model toggles between the env-configured small and large models."""
        mock_openai_constructor.return_value = self.mock_openai_client
        os.environ["GEMINI_SMALL_MODEL"] = "env-gemini-small"
        os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = "110"
        os.environ["GEMINI_LARGE_MODEL"] = "env-gemini-large"
        os.environ["GEMINI_LARGE_MODEL_MAX_TOKENS"] = "220"

        bot = GeminiTelegramInferenceBot()  # Uses env vars by default
        self.assertEqual(bot.model, "env-gemini-small")
        self.assertEqual(bot.max_tokens, 110)

        status = await bot.switch_model()
        self.assertEqual(bot.model, "env-gemini-large")
        self.assertEqual(bot.max_tokens, 220)
        self.assertEqual(status, "Switched to model: env-gemini-large")

        status = await bot.switch_model()
        self.assertEqual(bot.model, "env-gemini-small")
        self.assertEqual(bot.max_tokens, 110)
        self.assertEqual(status, "Switched to model: env-gemini-small")

    @patch('openai.OpenAI')
    async def test_switch_model_uses_instance_configs_if_provided(self, mock_openai_constructor):
        """switch_model toggles between the models passed to the constructor."""
        mock_openai_constructor.return_value = self.mock_openai_client
        bot = GeminiTelegramInferenceBot(
            small_model_name="init-gem-small",
            small_model_max_tokens="55",
            large_model_name="init-gem-large",
            large_model_max_tokens="155",
        )
        self.assertEqual(bot.model, "init-gem-small")
        self.assertEqual(bot.max_tokens, 55)

        status = await bot.switch_model()
        self.assertEqual(bot.model, "init-gem-large")
        self.assertEqual(bot.max_tokens, 155)
        self.assertEqual(status, "Switched to model: init-gem-large")

        status = await bot.switch_model()
        self.assertEqual(bot.model, "init-gem-small")
        self.assertEqual(bot.max_tokens, 55)
        self.assertEqual(status, "Switched to model: init-gem-small")

    @patch('openai.OpenAI')
    def test_get_llm_description_for_gemini_bot(self, mock_openai_constructor):
        """The description reports the active model, its token limit and Azure: False."""
        mock_openai_constructor.return_value = self.mock_openai_client
        bot = GeminiTelegramInferenceBot(
            small_model_name="gemini-pro-desc",
            small_model_max_tokens="888",
            # is_gemini is True by default in constructor call to super
        )
        # The description should report Azure: False even though the bot builds on the
        # OpenAICompatible base class. The is_gemini flag primarily affects client
        # instantiation in the superclass. The superclass's azure_openai flag is based
        # on whether an azure_endpoint is present.
        self.assertEqual(
            bot.get_llm_description(),
            "LLM: gemini-pro-desc, Max Tokens: 888, Azure: False",
        )


if __name__ == '__main__':
    unittest.main()