diff --git a/tests/test_gemini_telegram_inference_bot.py b/tests/test_gemini_telegram_inference_bot.py new file mode 100644 index 0000000..8e5cc4f --- /dev/null +++ b/tests/test_gemini_telegram_inference_bot.py @@ -0,0 +1,154 @@ +import unittest +from unittest.mock import MagicMock, patch, ANY +import os + +# Assuming gemini_telegram_inference_bot.py and its parent are accessible +from gemini_telegram_inference_bot import GeminiTelegramInferenceBot +from openai_compatible_inference_bot import OpenAICompatibleInferenceBot # For patching super + +class TestGeminiTelegramInferenceBot(unittest.IsolatedAsyncioTestCase): + + def setUp(self): + # Store and clear relevant environment variables + self.original_gemini_key = os.environ.get("GEMINI_API_KEY") + self.original_gemini_base_url = os.environ.get("GEMINI_API_BASE_URL") + self.original_small_model = os.environ.get("GEMINI_SMALL_MODEL") + self.original_large_model = os.environ.get("GEMINI_LARGE_MODEL") + self.original_small_tokens = os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS") + self.original_large_tokens = os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS") + self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH") + + for key in ["GEMINI_API_KEY", "GEMINI_API_BASE_URL", "GEMINI_SMALL_MODEL", "GEMINI_LARGE_MODEL", + "GEMINI_SMALL_MODEL_MAX_TOKENS", "GEMINI_LARGE_MODEL_MAX_TOKENS", "SYSTEM_PROMPT_PATH"]: + if os.environ.get(key): + del os.environ[key] + + self.mock_openai_client = MagicMock() # Used if superclass creates an OpenAI client + + def tearDown(self): + # Restore environment variables + if self.original_gemini_key: os.environ["GEMINI_API_KEY"] = self.original_gemini_key + if self.original_gemini_base_url: os.environ["GEMINI_API_BASE_URL"] = self.original_gemini_base_url + if self.original_small_model: os.environ["GEMINI_SMALL_MODEL"] = self.original_small_model + if self.original_large_model: os.environ["GEMINI_LARGE_MODEL"] = self.original_large_model + if self.original_small_tokens: os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = self.original_small_tokens + if self.original_large_tokens: os.environ["GEMINI_LARGE_MODEL_MAX_TOKENS"] = self.original_large_tokens + if self.original_system_prompt_path: os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path + + @patch.object(OpenAICompatibleInferenceBot, '__init__') # Mock the superclass's __init__ + def test_init_defaults_and_super_call(self, mock_super_init): + os.environ["GEMINI_API_KEY"] = "test_key_gemini" + os.environ["GEMINI_API_BASE_URL"] = "https://gemini.env.com" + os.environ["GEMINI_SMALL_MODEL"] = "gemini-pro-env" + os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = "360" + + bot = GeminiTelegramInferenceBot() + + mock_super_init.assert_called_once_with( + client=None, + api_key="test_key_gemini", + base_url="https://gemini.env.com", # Passed to super + api_version=None, + azure_deployment=None, + model_name="gemini-pro-env", + max_tokens_str="360", + small_model_name="gemini-pro-env", + small_model_max_tokens_str="360", + large_model_name=os.environ.get("GEMINI_LARGE_MODEL", "gemini-1.5-pro-latest"), # Default large + large_model_max_tokens_str=os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS"), + system_prompt_content=None, + system_prompt_path=None, + is_gemini=True, # Important for Gemini bot + max_history_length=20 + ) + + @patch.object(OpenAICompatibleInferenceBot, '__init__') + def test_init_with_arguments(self, mock_super_init): + mock_client_arg = MagicMock() + bot = GeminiTelegramInferenceBot( + openai_client=mock_client_arg, # Name in Gemini bot is openai_client for consistency + api_key="arg_gem_key", + base_url="https://arg.gemini.com", + small_model_name="arg_gem_small", + small_model_max_tokens="124", + large_model_name="arg_gem_large", + large_model_max_tokens="457", + system_prompt_content="Gemini prompt" + ) + mock_super_init.assert_called_once_with( + client=mock_client_arg, + api_key="arg_gem_key", + base_url="https://arg.gemini.com", + api_version=None, + azure_deployment=None, + model_name="arg_gem_small", + max_tokens_str="124", + small_model_name="arg_gem_small", + small_model_max_tokens_str="124", + large_model_name="arg_gem_large", + large_model_max_tokens_str="457", + system_prompt_content="Gemini prompt", + system_prompt_path=None, + is_gemini=True, + max_history_length=20 + ) + + @patch('openai.OpenAI') # Gemini bot uses OpenAI client configured for Gemini endpoint + async def test_switch_model_logic(self, mock_openai_constructor): + mock_openai_constructor.return_value = self.mock_openai_client + + os.environ["GEMINI_SMALL_MODEL"] = "env-gemini-small" + os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = "110" + os.environ["GEMINI_LARGE_MODEL"] = "env-gemini-large" + os.environ["GEMINI_LARGE_MODEL_MAX_TOKENS"] = "220" + + bot = GeminiTelegramInferenceBot() # Uses env vars by default + self.assertEqual(bot.model, "env-gemini-small") + self.assertEqual(bot.max_tokens, 110) + + status = await bot.switch_model() + self.assertEqual(bot.model, "env-gemini-large") + self.assertEqual(bot.max_tokens, 220) + self.assertEqual(status, "Switched to model: env-gemini-large") + + status = await bot.switch_model() + self.assertEqual(bot.model, "env-gemini-small") + self.assertEqual(bot.max_tokens, 110) + self.assertEqual(status, "Switched to model: env-gemini-small") + + @patch('openai.OpenAI') + async def test_switch_model_uses_instance_configs_if_provided(self, mock_openai_constructor): + mock_openai_constructor.return_value = self.mock_openai_client + + bot = GeminiTelegramInferenceBot( + small_model_name="init-gem-small", small_model_max_tokens="55", + large_model_name="init-gem-large", large_model_max_tokens="155" + ) + self.assertEqual(bot.model, "init-gem-small") + self.assertEqual(bot.max_tokens, 55) + + status = await bot.switch_model() + self.assertEqual(bot.model, "init-gem-large") + self.assertEqual(bot.max_tokens, 155) + self.assertEqual(status, "Switched to model: init-gem-large") + + status = await bot.switch_model() + self.assertEqual(bot.model, "init-gem-small") + self.assertEqual(bot.max_tokens, 55) + self.assertEqual(status, "Switched to model: init-gem-small") + + @patch('openai.OpenAI') + def test_get_llm_description_for_gemini_bot(self, mock_openai_constructor): + mock_openai_constructor.return_value = self.mock_openai_client + bot = GeminiTelegramInferenceBot( + small_model_name="gemini-pro-desc", + small_model_max_tokens="888", + # is_gemini is True by default in constructor call to super + ) + # LLM description should indicate not Azure, even though it uses OpenAICompatible... base + # The is_gemini flag primarily affects client instantiation logic in the superclass. + # The azure_openai flag in superclass is based on azure_endpoint presence. + self.assertEqual(bot.get_llm_description(), "LLM: gemini-pro-desc, Max Tokens: 888, Azure: False") + +if __name__ == '__main__': + unittest.main()