Refactor: Inject dependencies in GeminiTelegramInferenceBot
This commit is contained in:
@@ -1,43 +1,103 @@
|
||||
import os
|
||||
import logging
|
||||
from openai import OpenAI
|
||||
from openai import OpenAI # For type hinting and default client creation if needed
|
||||
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot
|
||||
from telegram_helper import TelegramHelper
|
||||
from telegram_helper import TelegramHelper # Used in main
|
||||
|
||||
class GeminiTelegramInferenceBot(OpenAICompatibleInferenceBot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.client = OpenAI(api_key=os.environ.get("GEMINI_API_KEY"), base_url=os.environ.get("GEMINI_API_BASE_URL"))
|
||||
DEFAULT_GEMINI_API_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"
|
||||
DEFAULT_SMALL_MODEL_NAME = "gemini-pro" # Actual model name for Gemini, not via OpenAI client directly
|
||||
DEFAULT_LARGE_MODEL_NAME = "gemini-1.5-pro-latest"
|
||||
DEFAULT_SMALL_MODEL_MAX_TOKENS = "2048" # Gemini uses outputTokenLimit, not exactly max_tokens in OpenAI sense
|
||||
DEFAULT_LARGE_MODEL_MAX_TOKENS = "8192"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: OpenAI | None = None, # OpenAI client for compatible mode
|
||||
api_key: str | None = None, # Gemini API Key
|
||||
base_url: str | None = None, # Gemini API Base URL for OpenAI client
|
||||
small_model_name: str | None = None,
|
||||
small_model_max_tokens: str | None = None,
|
||||
large_model_name: str | None = None,
|
||||
large_model_max_tokens: str | None = None,
|
||||
system_prompt_content: str | None = None,
|
||||
system_prompt_path: str | None = None
|
||||
):
|
||||
_api_key = api_key or os.environ.get("GEMINI_API_KEY")
|
||||
_base_url = base_url or os.environ.get("GEMINI_API_BASE_URL") or self.DEFAULT_GEMINI_API_BASE_URL
|
||||
|
||||
if not _api_key:
|
||||
# This check might seem redundant if super() also checks, but it's good for clarity
|
||||
# for this specific bot type if it were to be instantiated directly with missing critical env vars.
|
||||
raise ValueError("Gemini API key must be provided either via api_key argument or GEMINI_API_KEY environment variable.")
|
||||
|
||||
self.small_model_name = small_model_name or os.environ.get("GEMINI_SMALL_MODEL") or self.DEFAULT_SMALL_MODEL_NAME
|
||||
self.small_model_max_tokens_str = small_model_max_tokens or os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS") or self.DEFAULT_SMALL_MODEL_MAX_TOKENS
|
||||
|
||||
self._configure_model_and_tokens(
|
||||
os.environ.get("GEMINI_SMALL_MODEL", "gemini-pro"),
|
||||
os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS")
|
||||
self.large_model_name = large_model_name or os.environ.get("GEMINI_LARGE_MODEL") or self.DEFAULT_LARGE_MODEL_NAME
|
||||
self.large_model_max_tokens_str = large_model_max_tokens or os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS") or self.DEFAULT_LARGE_MODEL_MAX_TOKENS
|
||||
|
||||
# Pass parameters to the OpenAICompatibleInferenceBot constructor
|
||||
# It will create an OpenAI client configured for the Gemini endpoint
|
||||
super().__init__(
|
||||
client=client,
|
||||
api_key=_api_key, # This key will be used by OpenAI client for the custom base_url
|
||||
model_name=self.small_model_name, # Initial model
|
||||
max_tokens_str=self.small_model_max_tokens_str,
|
||||
system_prompt_content=system_prompt_content,
|
||||
system_prompt_path=system_prompt_path,
|
||||
base_url=_base_url, # Crucial for Gemini via OpenAI client
|
||||
is_gemini=True # Flag for specific Gemini handling in compatible layer if needed
|
||||
)
|
||||
# self.client will be set by OpenAICompatibleInferenceBot with base_url and api_key.
|
||||
# Logging to confirm Gemini specific setup
|
||||
logging.info(f"GeminiTelegramInferenceBot initialized to use model {self.model} via {_base_url}")
|
||||
|
||||
async def switch_model(self):
|
||||
current_small_model = os.environ.get("GEMINI_SMALL_MODEL", "gemini-pro")
|
||||
current_large_model = os.environ.get("GEMINI_LARGE_MODEL", "gemini-1.5-pro-latest")
|
||||
if not self.small_model_name or not self.large_model_name:
|
||||
logging.warning("Small or Large model names for Gemini are not defined. Cannot switch model.")
|
||||
return f"Model switching not fully configured. Currently using {self.model}."
|
||||
|
||||
if self.model == current_large_model or self.model != current_small_model :
|
||||
target_model = current_small_model
|
||||
target_max_tokens = os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS")
|
||||
current_is_small = self.model == self.small_model_name
|
||||
current_is_large = self.model == self.large_model_name
|
||||
|
||||
if current_is_large:
|
||||
target_model = self.small_model_name
|
||||
target_max_tokens_str = self.small_model_max_tokens_str
|
||||
elif current_is_small:
|
||||
target_model = self.large_model_name
|
||||
target_max_tokens_str = self.large_model_max_tokens_str
|
||||
else:
|
||||
target_model = current_large_model
|
||||
target_max_tokens = os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS")
|
||||
|
||||
self._configure_model_and_tokens(target_model, target_max_tokens)
|
||||
logging.info(f"Switched to model: {self.model}")
|
||||
return f"Switched to model: {self.model}"
|
||||
logging.warning(f"Current model {self.model} is unrecognized for Gemini bot. Switching to default small model: {self.small_model_name}.")
|
||||
target_model = self.small_model_name
|
||||
target_max_tokens_str = self.small_model_max_tokens_str
|
||||
|
||||
self._configure_model_and_tokens(target_model, target_max_tokens_str)
|
||||
logging.info(f"Switched to Gemini model: {self.model}")
|
||||
# For Gemini, max_tokens might translate to outputTokenLimit, so be clear it's a configuration parameter
|
||||
return f"Switched to Gemini model: {self.model} (Configured Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'})"
|
||||
|
||||
def main():
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
|
||||
# GEMINI_API_KEY is crucial for this bot
|
||||
if not os.environ.get("GEMINI_API_KEY"):
|
||||
logging.error("FATAL: GEMINI_API_KEY environment variable not set.")
|
||||
return
|
||||
# GEMINI_API_BASE_URL is also important, but constructor has a default
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
|
||||
bot = GeminiTelegramInferenceBot()
|
||||
telegram_helper = TelegramHelper(bot)
|
||||
try:
|
||||
bot = GeminiTelegramInferenceBot(
|
||||
# api_key and base_url will be picked from ENV by constructor if not passed
|
||||
)
|
||||
except ValueError as e:
|
||||
logging.error(f"FATAL: {e}")
|
||||
return
|
||||
except Exception as e: # Catch any other init errors
|
||||
logging.error(f"An unexpected error occurred during bot initialization: {e}")
|
||||
return
|
||||
|
||||
telegram_helper = TelegramHelper(bot)
|
||||
telegram_helper.run()
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user