"""Telegram inference bot that reaches Gemini models through an OpenAI-style client."""

import os
import logging

from openai import OpenAI

from openai_compatible_inference_bot import OpenAICompatibleInferenceBot
from telegram_helper import TelegramHelper


class GeminiTelegramInferenceBot(OpenAICompatibleInferenceBot):
    """Inference bot configured for Gemini via the OpenAI-compatible base class.

    Every setting is resolved in the same order: explicit constructor
    argument, then the matching ``GEMINI_*`` environment variable, then the
    class-level default. The bot starts on the "small" model and
    ``switch_model`` toggles between the small and large configurations.
    """

    # Root of Google's OpenAI-compatibility layer. The OpenAI client appends
    # paths such as ``chat/completions`` to this URL, so it has to point at the
    # ``/openai/`` compatibility root rather than the native Gemini REST root.
    DEFAULT_GEMINI_API_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai/"
    DEFAULT_SMALL_MODEL_NAME = "gemini-pro"
    DEFAULT_LARGE_MODEL_NAME = "gemini-1.5-pro-latest"
    # Token limits are kept as strings because they are forwarded verbatim as
    # ``max_tokens_str`` / to ``_configure_model_and_tokens`` on the base class.
    # NOTE(review): Gemini's native API calls this limit outputTokenLimit;
    # confirm how the base class maps it.
    DEFAULT_SMALL_MODEL_MAX_TOKENS = "2048"
    DEFAULT_LARGE_MODEL_MAX_TOKENS = "8192"

    def __init__(
        self,
        client: OpenAI | None = None,
        api_key: str | None = None,
        base_url: str | None = None,
        small_model_name: str | None = None,
        small_model_max_tokens: str | None = None,
        large_model_name: str | None = None,
        large_model_max_tokens: str | None = None,
        system_prompt_content: str | None = None,
        system_prompt_path: str | None = None,
    ) -> None:
        """Resolve the Gemini configuration and initialise the base bot.

        Args:
            client: Optional pre-built OpenAI client, forwarded to the base class.
            api_key: Gemini API key; falls back to ``GEMINI_API_KEY``.
            base_url: Endpoint for the OpenAI client; falls back to
                ``GEMINI_API_BASE_URL`` and then ``DEFAULT_GEMINI_API_BASE_URL``.
            small_model_name: Small model; falls back to ``GEMINI_SMALL_MODEL``.
            small_model_max_tokens: Token limit (as a string) for the small
                model; falls back to ``GEMINI_SMALL_MODEL_MAX_TOKENS``.
            large_model_name: Large model; falls back to ``GEMINI_LARGE_MODEL``.
            large_model_max_tokens: Token limit (as a string) for the large
                model; falls back to ``GEMINI_LARGE_MODEL_MAX_TOKENS``.
            system_prompt_content: Forwarded unchanged to the base class.
            system_prompt_path: Forwarded unchanged to the base class.

        Raises:
            ValueError: If no API key is supplied by argument or environment.
        """
        env = os.environ
        resolved_key = api_key or env.get("GEMINI_API_KEY")
        resolved_url = (
            base_url
            or env.get("GEMINI_API_BASE_URL")
            or self.DEFAULT_GEMINI_API_BASE_URL
        )

        # Fail early with a Gemini-specific message instead of relying on
        # whatever validation the base class may perform.
        if not resolved_key:
            raise ValueError(
                "Gemini API key must be provided either via api_key argument or GEMINI_API_KEY environment variable."
            )

        # Both model configurations are stored on the instance so that
        # switch_model() does not have to re-read the environment later.
        self.small_model_name = (
            small_model_name
            or env.get("GEMINI_SMALL_MODEL")
            or self.DEFAULT_SMALL_MODEL_NAME
        )
        self.small_model_max_tokens_str = (
            small_model_max_tokens
            or env.get("GEMINI_SMALL_MODEL_MAX_TOKENS")
            or self.DEFAULT_SMALL_MODEL_MAX_TOKENS
        )
        self.large_model_name = (
            large_model_name
            or env.get("GEMINI_LARGE_MODEL")
            or self.DEFAULT_LARGE_MODEL_NAME
        )
        self.large_model_max_tokens_str = (
            large_model_max_tokens
            or env.get("GEMINI_LARGE_MODEL_MAX_TOKENS")
            or self.DEFAULT_LARGE_MODEL_MAX_TOKENS
        )

        # The bot starts on the small model; the base class receives the key
        # and endpoint so it can set up the client.
        super().__init__(
            client=client,
            api_key=resolved_key,
            model_name=self.small_model_name,
            max_tokens_str=self.small_model_max_tokens_str,
            system_prompt_content=system_prompt_content,
            system_prompt_path=system_prompt_path,
            base_url=resolved_url,
            is_gemini=True,
        )
        logging.info(
            "GeminiTelegramInferenceBot initialized to use model %s via %s",
            self.model,
            resolved_url,
        )

    async def switch_model(self) -> str:
        """Toggle between the small and large Gemini models.

        Large switches to small and small switches to large. If the active
        model is neither, a warning is logged and the small model is selected.

        Returns:
            A human-readable status message naming the model now in use.
        """
        if not self.small_model_name or not self.large_model_name:
            logging.warning(
                "Small or Large model names for Gemini are not defined. Cannot switch model."
            )
            return f"Model switching not fully configured. Currently using {self.model}."

        small = (self.small_model_name, self.small_model_max_tokens_str)
        large = (self.large_model_name, self.large_model_max_tokens_str)

        # The large-model check comes first so that identical small/large
        # names resolve to the small configuration.
        if self.model == self.large_model_name:
            target_model, target_max_tokens_str = small
        elif self.model == self.small_model_name:
            target_model, target_max_tokens_str = large
        else:
            logging.warning(
                "Current model %s is unrecognized for Gemini bot. Switching to default small model: %s.",
                self.model,
                self.small_model_name,
            )
            target_model, target_max_tokens_str = small

        self._configure_model_and_tokens(target_model, target_max_tokens_str)
        logging.info("Switched to Gemini model: %s", self.model)

        configured_limit = self.max_tokens if self.max_tokens is not None else 'API default'
        return f"Switched to Gemini model: {self.model} (Configured Max Tokens: {configured_limit})"


def main() -> None:
    """Configure logging, build the Gemini bot and run the Telegram helper.

    Returns early, after logging the reason, when the API key is missing or
    the bot cannot be constructed.
    """
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    )

    # The key is mandatory; the base URL has a default inside the constructor.
    if not os.environ.get("GEMINI_API_KEY"):
        logging.error("FATAL: GEMINI_API_KEY environment variable not set.")
        return

    try:
        # api_key and base_url are picked up from the environment.
        bot = GeminiTelegramInferenceBot()
    except ValueError as e:
        logging.error("FATAL: %s", e)
        return
    except Exception as e:
        # logging.exception keeps the traceback, which a plain error() call
        # with only str(e) would discard.
        logging.exception(
            "An unexpected error occurred during bot initialization: %s", e
        )
        return

    telegram_helper = TelegramHelper(bot)
    telegram_helper.run()


if __name__ == '__main__':
    main()