Refactor: Inject dependencies in ChatGPTTelegramInferenceBot
This commit is contained in:
@@ -1,42 +1,104 @@
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
from openai import OpenAI
|
from openai import OpenAI # Keep for type hinting and default client creation
|
||||||
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot
|
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot
|
||||||
from telegram_helper import TelegramHelper
|
from telegram_helper import TelegramHelper # Used in main
|
||||||
|
|
||||||
class ChatGPTTelegramInferenceBot(OpenAICompatibleInferenceBot):
|
class ChatGPTTelegramInferenceBot(OpenAICompatibleInferenceBot):
|
||||||
def __init__(self):
|
DEFAULT_SMALL_MODEL_NAME = "gpt-3.5-turbo"
|
||||||
super().__init__()
|
DEFAULT_LARGE_MODEL_NAME = "gpt-4"
|
||||||
self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
# Default max tokens can be None, relying on parent or API defaults
|
||||||
|
DEFAULT_SMALL_MODEL_MAX_TOKENS = None
|
||||||
|
DEFAULT_LARGE_MODEL_MAX_TOKENS = None
|
||||||
|
|
||||||
self._configure_model_and_tokens(
|
def __init__(
|
||||||
os.environ.get("OPENAI_SMALL_MODEL", "gpt-3.5-turbo"),
|
self,
|
||||||
os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS")
|
client: OpenAI | None = None, # Accepts an OpenAI client
|
||||||
|
api_key: str | None = None,
|
||||||
|
small_model_name: str | None = None,
|
||||||
|
small_model_max_tokens: str | None = None, # Kept as str for consistency with env vars
|
||||||
|
large_model_name: str | None = None,
|
||||||
|
large_model_max_tokens: str | None = None,
|
||||||
|
system_prompt_content: str | None = None,
|
||||||
|
system_prompt_path: str | None = None,
|
||||||
|
base_url: str | None = None, # For OpenAI compatible, though direct OpenAI client doesn't use it here
|
||||||
|
):
|
||||||
|
# Initialize model names and tokens before calling super, as super might use them via _configure_model_and_tokens
|
||||||
|
self.small_model_name = small_model_name or os.environ.get("OPENAI_SMALL_MODEL") or self.DEFAULT_SMALL_MODEL_NAME
|
||||||
|
self.small_model_max_tokens_str = small_model_max_tokens or os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS") or self.DEFAULT_SMALL_MODEL_MAX_TOKENS
|
||||||
|
|
||||||
|
self.large_model_name = large_model_name or os.environ.get("OPENAI_LARGE_MODEL") or self.DEFAULT_LARGE_MODEL_NAME
|
||||||
|
self.large_model_max_tokens_str = large_model_max_tokens or os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS") or self.DEFAULT_LARGE_MODEL_MAX_TOKENS
|
||||||
|
|
||||||
|
# The actual client and active model configuration will be handled by OpenAICompatibleInferenceBot's __init__
|
||||||
|
# We pass the specific OpenAI client or parameters to create one.
|
||||||
|
# If a client is passed, api_key and base_url might be ignored by super if super prioritizes existing client.
|
||||||
|
super().__init__(
|
||||||
|
client=client,
|
||||||
|
api_key=api_key,
|
||||||
|
model_name=self.small_model_name, # Initial model
|
||||||
|
max_tokens_str=self.small_model_max_tokens_str,
|
||||||
|
system_prompt_content=system_prompt_content,
|
||||||
|
system_prompt_path=system_prompt_path,
|
||||||
|
base_url=base_url # Pass base_url, though for standard OpenAI it's fixed
|
||||||
)
|
)
|
||||||
|
# Ensure client is of type OpenAI for this specific class, if not already set by super with a compatible one.
|
||||||
|
# This check is more of an assertion, as OpenAICompatibleInferenceBot should handle client creation.
|
||||||
|
if not isinstance(self.client, OpenAI):
|
||||||
|
# If super() didn't create a vanilla OpenAI client (e.g. if base_url was for Azure)
|
||||||
|
# we might need to recreate it here if this class *must* use a non-Azure OpenAI client.
|
||||||
|
# However, the current structure of OpenAICompatibleInferenceBot handles this.
|
||||||
|
# This is more about ensuring type correctness if code specific to OpenAI (non-compatible) methods were added here.
|
||||||
|
_api_key = api_key or os.environ.get("OPENAI_API_KEY")
|
||||||
|
if not self.client or (base_url and not isinstance(self.client, OpenAI)):
|
||||||
|
# If superclass initialized with a generic client due to base_url, re-init for OpenAI specifically if needed.
|
||||||
|
# For now, assume superclass correctly initializes based on absence of Azure env vars for this path.
|
||||||
|
# This logic might be simplified once OpenAICompatibleInferenceBot is fully refactored.
|
||||||
|
if not _api_key: # Ensure API key is available if we need to create a client
|
||||||
|
raise ValueError("OpenAI API key must be provided for ChatGPTTelegramInferenceBot if no client is passed.")
|
||||||
|
self.client = OpenAI(api_key=_api_key)
|
||||||
|
logging.info("Client re-initialized to standard OpenAI client for ChatGPTTelegramInferenceBot.")
|
||||||
|
|
||||||
async def switch_model(self):
|
async def switch_model(self):
|
||||||
current_small_model = os.environ.get("OPENAI_SMALL_MODEL", "gpt-3.5-turbo")
|
# Uses instance variables for model names set in __init__
|
||||||
current_large_model = os.environ.get("OPENAI_LARGE_MODEL", "gpt-4")
|
if not self.small_model_name or not self.large_model_name:
|
||||||
|
logging.warning("Small or Large model names for OpenAI are not defined. Cannot switch model.")
|
||||||
|
return f"Model switching not fully configured. Currently using {self.model}."
|
||||||
|
|
||||||
if self.model == current_large_model or self.model != current_small_model:
|
current_is_small = self.model == self.small_model_name
|
||||||
target_model = current_small_model
|
current_is_large = self.model == self.large_model_name
|
||||||
target_max_tokens = os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS")
|
|
||||||
|
if current_is_large:
|
||||||
|
target_model = self.small_model_name
|
||||||
|
target_max_tokens_str = self.small_model_max_tokens_str
|
||||||
|
elif current_is_small:
|
||||||
|
target_model = self.large_model_name
|
||||||
|
target_max_tokens_str = self.large_model_max_tokens_str
|
||||||
else:
|
else:
|
||||||
target_model = current_large_model
|
# Current model is neither the designated small nor large for this bot,
|
||||||
target_max_tokens = os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS")
|
# switch to this bot's default small model as a reset.
|
||||||
|
logging.warning(f"Current model {self.model} is unrecognized for ChatGPT bot. Switching to default small model: {self.small_model_name}.")
|
||||||
|
target_model = self.small_model_name
|
||||||
|
target_max_tokens_str = self.small_model_max_tokens_str
|
||||||
|
|
||||||
self._configure_model_and_tokens(target_model, target_max_tokens)
|
self._configure_model_and_tokens(target_model, target_max_tokens_str)
|
||||||
logging.info(f"Switched to model: {self.model}")
|
# self.model and self.max_tokens are updated by _configure_model_and_tokens
|
||||||
return f"Switched to model: {self.model}"
|
logging.info(f"Switched to OpenAI model: {self.model}")
|
||||||
|
return f"Switched to OpenAI model: {self.model} (Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'})"
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
if not os.environ.get("OPENAI_API_KEY"):
|
|
||||||
logging.error("FATAL: OPENAI_API_KEY environment variable not set.")
|
|
||||||
return
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||||
|
|
||||||
bot = ChatGPTTelegramInferenceBot()
|
try:
|
||||||
|
# Example: api_key from env, other params default or from env via constructor logic
|
||||||
|
bot = ChatGPTTelegramInferenceBot(api_key=os.environ.get("OPENAI_API_KEY"))
|
||||||
|
except ValueError as e:
|
||||||
|
logging.error(f"FATAL: {e}")
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"An unexpected error occurred during bot initialization: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
telegram_helper = TelegramHelper(bot)
|
telegram_helper = TelegramHelper(bot)
|
||||||
telegram_helper.run()
|
telegram_helper.run()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user