Refactor: Inject dependencies in GeminiTelegramInferenceBot

This commit is contained in:
cyclop-bot
2025-06-02 16:42:48 -05:00
parent 0347b8bd4f
commit 0e4ba10e04
+83 -23
View File
@@ -1,43 +1,103 @@
import os
import logging
from openai import OpenAI
from openai import OpenAI # For type hinting and default client creation if needed
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot
from telegram_helper import TelegramHelper
from telegram_helper import TelegramHelper # Used in main
class GeminiTelegramInferenceBot(OpenAICompatibleInferenceBot):
def __init__(self):
super().__init__()
self.client = OpenAI(api_key=os.environ.get("GEMINI_API_KEY"), base_url=os.environ.get("GEMINI_API_BASE_URL"))
DEFAULT_GEMINI_API_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"
DEFAULT_SMALL_MODEL_NAME = "gemini-pro" # Actual model name for Gemini, not via OpenAI client directly
DEFAULT_LARGE_MODEL_NAME = "gemini-1.5-pro-latest"
DEFAULT_SMALL_MODEL_MAX_TOKENS = "2048" # Gemini uses outputTokenLimit, not exactly max_tokens in OpenAI sense
DEFAULT_LARGE_MODEL_MAX_TOKENS = "8192"
def __init__(
self,
client: OpenAI | None = None, # OpenAI client for compatible mode
api_key: str | None = None, # Gemini API Key
base_url: str | None = None, # Gemini API Base URL for OpenAI client
small_model_name: str | None = None,
small_model_max_tokens: str | None = None,
large_model_name: str | None = None,
large_model_max_tokens: str | None = None,
system_prompt_content: str | None = None,
system_prompt_path: str | None = None
):
_api_key = api_key or os.environ.get("GEMINI_API_KEY")
_base_url = base_url or os.environ.get("GEMINI_API_BASE_URL") or self.DEFAULT_GEMINI_API_BASE_URL
if not _api_key:
# This check might seem redundant if super() also checks, but it's good for clarity
# for this specific bot type if it were to be instantiated directly with missing critical env vars.
raise ValueError("Gemini API key must be provided either via api_key argument or GEMINI_API_KEY environment variable.")
self.small_model_name = small_model_name or os.environ.get("GEMINI_SMALL_MODEL") or self.DEFAULT_SMALL_MODEL_NAME
self.small_model_max_tokens_str = small_model_max_tokens or os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS") or self.DEFAULT_SMALL_MODEL_MAX_TOKENS
self._configure_model_and_tokens(
os.environ.get("GEMINI_SMALL_MODEL", "gemini-pro"),
os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS")
self.large_model_name = large_model_name or os.environ.get("GEMINI_LARGE_MODEL") or self.DEFAULT_LARGE_MODEL_NAME
self.large_model_max_tokens_str = large_model_max_tokens or os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS") or self.DEFAULT_LARGE_MODEL_MAX_TOKENS
# Pass parameters to the OpenAICompatibleInferenceBot constructor
# It will create an OpenAI client configured for the Gemini endpoint
super().__init__(
client=client,
api_key=_api_key, # This key will be used by OpenAI client for the custom base_url
model_name=self.small_model_name, # Initial model
max_tokens_str=self.small_model_max_tokens_str,
system_prompt_content=system_prompt_content,
system_prompt_path=system_prompt_path,
base_url=_base_url, # Crucial for Gemini via OpenAI client
is_gemini=True # Flag for specific Gemini handling in compatible layer if needed
)
# self.client will be set by OpenAICompatibleInferenceBot with base_url and api_key.
# Logging to confirm Gemini specific setup
logging.info(f"GeminiTelegramInferenceBot initialized to use model {self.model} via {_base_url}")
async def switch_model(self):
current_small_model = os.environ.get("GEMINI_SMALL_MODEL", "gemini-pro")
current_large_model = os.environ.get("GEMINI_LARGE_MODEL", "gemini-1.5-pro-latest")
if not self.small_model_name or not self.large_model_name:
logging.warning("Small or Large model names for Gemini are not defined. Cannot switch model.")
return f"Model switching not fully configured. Currently using {self.model}."
if self.model == current_large_model or self.model != current_small_model :
target_model = current_small_model
target_max_tokens = os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS")
current_is_small = self.model == self.small_model_name
current_is_large = self.model == self.large_model_name
if current_is_large:
target_model = self.small_model_name
target_max_tokens_str = self.small_model_max_tokens_str
elif current_is_small:
target_model = self.large_model_name
target_max_tokens_str = self.large_model_max_tokens_str
else:
target_model = current_large_model
target_max_tokens = os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS")
self._configure_model_and_tokens(target_model, target_max_tokens)
logging.info(f"Switched to model: {self.model}")
return f"Switched to model: {self.model}"
logging.warning(f"Current model {self.model} is unrecognized for Gemini bot. Switching to default small model: {self.small_model_name}.")
target_model = self.small_model_name
target_max_tokens_str = self.small_model_max_tokens_str
self._configure_model_and_tokens(target_model, target_max_tokens_str)
logging.info(f"Switched to Gemini model: {self.model}")
# For Gemini, max_tokens might translate to outputTokenLimit, so be clear it's a configuration parameter
return f"Switched to Gemini model: {self.model} (Configured Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'})"
def main():
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# GEMINI_API_KEY is crucial for this bot
if not os.environ.get("GEMINI_API_KEY"):
logging.error("FATAL: GEMINI_API_KEY environment variable not set.")
return
# GEMINI_API_BASE_URL is also important, but constructor has a default
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
bot = GeminiTelegramInferenceBot()
telegram_helper = TelegramHelper(bot)
try:
bot = GeminiTelegramInferenceBot(
# api_key and base_url will be picked from ENV by constructor if not passed
)
except ValueError as e:
logging.error(f"FATAL: {e}")
return
except Exception as e: # Catch any other init errors
logging.error(f"An unexpected error occurred during bot initialization: {e}")
return
telegram_helper = TelegramHelper(bot)
telegram_helper.run()
if __name__ == '__main__':