import os
import logging

from openai import OpenAI  # Keep for type hinting and default client creation

from openai_compatible_inference_bot import OpenAICompatibleInferenceBot
from telegram_helper import TelegramHelper  # Used in main


class ChatGPTTelegramInferenceBot(OpenAICompatibleInferenceBot):
    """Inference bot that toggles between a "small" and a "large" OpenAI model.

    Model names and max-token settings are resolved once in ``__init__``
    (explicit argument, then environment variable, then class default) and
    reused by ``switch_model``.
    """

    DEFAULT_SMALL_MODEL_NAME = "gpt-3.5-turbo"
    DEFAULT_LARGE_MODEL_NAME = "gpt-4"
    # Default max tokens can be None, relying on parent or API defaults
    DEFAULT_SMALL_MODEL_MAX_TOKENS = None
    DEFAULT_LARGE_MODEL_MAX_TOKENS = None

    def __init__(
        self,
        client: OpenAI | None = None,  # Accepts an OpenAI client
        api_key: str | None = None,
        small_model_name: str | None = None,
        small_model_max_tokens: str | None = None,  # Kept as str for consistency with env vars
        large_model_name: str | None = None,
        large_model_max_tokens: str | None = None,
        system_prompt_content: str | None = None,
        system_prompt_path: str | None = None,
        base_url: str | None = None,
    ):
        """Resolve model settings and delegate client setup to the parent class.

        Args:
            client: Pre-built OpenAI client, forwarded to the parent initializer.
            api_key: API key; falls back to ``OPENAI_API_KEY`` when a client
                has to be created here.
            small_model_name: Falls back to ``OPENAI_SMALL_MODEL``, then
                ``DEFAULT_SMALL_MODEL_NAME``.
            small_model_max_tokens: Falls back to
                ``OPENAI_SMALL_MODEL_MAX_TOKENS``, then the class default.
            large_model_name: Falls back to ``OPENAI_LARGE_MODEL``, then
                ``DEFAULT_LARGE_MODEL_NAME``.
            large_model_max_tokens: Falls back to
                ``OPENAI_LARGE_MODEL_MAX_TOKENS``, then the class default.
            system_prompt_content: Forwarded unchanged to the parent initializer.
            system_prompt_path: Forwarded unchanged to the parent initializer.
            base_url: Forwarded unchanged to the parent initializer.

        Raises:
            ValueError: If a standard OpenAI client must be created here and
                no API key is available.
        """
        # Set before super().__init__() because the parent may read these
        # (e.g. via _configure_model_and_tokens) -- NOTE(review): confirm.
        self.small_model_name = (
            small_model_name
            or os.environ.get("OPENAI_SMALL_MODEL")
            or self.DEFAULT_SMALL_MODEL_NAME
        )
        self.small_model_max_tokens_str = (
            small_model_max_tokens
            or os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS")
            or self.DEFAULT_SMALL_MODEL_MAX_TOKENS
        )
        self.large_model_name = (
            large_model_name
            or os.environ.get("OPENAI_LARGE_MODEL")
            or self.DEFAULT_LARGE_MODEL_NAME
        )
        self.large_model_max_tokens_str = (
            large_model_max_tokens
            or os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS")
            or self.DEFAULT_LARGE_MODEL_MAX_TOKENS
        )

        # The bot starts on the small model.
        super().__init__(
            client=client,
            api_key=api_key,
            model_name=self.small_model_name,
            max_tokens_str=self.small_model_max_tokens_str,
            system_prompt_content=system_prompt_content,
            system_prompt_path=system_prompt_path,
            base_url=base_url,
        )

        # Fall back to a standard OpenAI client only when the parent left us
        # without a usable one: no client at all, or a non-OpenAI client that
        # was built because a base_url was supplied.
        current_client = getattr(self, "client", None)
        if not isinstance(current_client, OpenAI) and (not current_client or base_url):
            fallback_key = api_key or os.environ.get("OPENAI_API_KEY")
            if not fallback_key:
                raise ValueError(
                    "OpenAI API key must be provided for ChatGPTTelegramInferenceBot if no client is passed."
                )
            self.client = OpenAI(api_key=fallback_key)
            logging.info(
                "Client re-initialized to standard OpenAI client for ChatGPTTelegramInferenceBot."
            )

    async def switch_model(self):
        """Toggle between the small and large model and report the result.

        large -> small, small -> large; any other active model is reset to
        the small model. Returns a user-facing status string.
        """
        if not self.small_model_name or not self.large_model_name:
            logging.warning(
                "Small or Large model names for OpenAI are not defined. Cannot switch model."
            )
            return f"Model switching not fully configured. Currently using {self.model}."

        if self.model == self.large_model_name:
            target_model = self.small_model_name
            target_max_tokens_str = self.small_model_max_tokens_str
        elif self.model == self.small_model_name:
            target_model = self.large_model_name
            target_max_tokens_str = self.large_model_max_tokens_str
        else:
            # Active model is neither of this bot's two models: reset to small.
            logging.warning(
                "Current model %s is unrecognized for ChatGPT bot. "
                "Switching to default small model: %s.",
                self.model,
                self.small_model_name,
            )
            target_model = self.small_model_name
            target_max_tokens_str = self.small_model_max_tokens_str

        # _configure_model_and_tokens is expected to update self.model and
        # self.max_tokens (defined on the parent class).
        self._configure_model_and_tokens(target_model, target_max_tokens_str)
        logging.info("Switched to OpenAI model: %s", self.model)
        shown_max_tokens = self.max_tokens if self.max_tokens is not None else 'API default'
        return f"Switched to OpenAI model: {self.model} (Max Tokens: {shown_max_tokens})"


def main():
    """Configure logging, build the bot from the environment, and run Telegram."""
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    )

    try:
        # api_key from env; other params default or come from env via the constructor
        bot = ChatGPTTelegramInferenceBot(api_key=os.environ.get("OPENAI_API_KEY"))
    except ValueError as e:
        logging.error("FATAL: %s", e)
        return
    except Exception as e:
        # Top-level boundary: keep the traceback so the failure is diagnosable.
        logging.exception("An unexpected error occurred during bot initialization: %s", e)
        return

    telegram_helper = TelegramHelper(bot)
    telegram_helper.run()


if __name__ == '__main__':
    main()