Refactored the Gemini, OpenAI and Claude bots into one file and removed logic from the base class; also made the messenger helper class selectable from the command line
This commit is contained in:
@@ -1,91 +1,67 @@
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
import inspect
|
||||
from abc import abstractmethod
|
||||
from base_telegram_inference_bot import BaseTelegramInferenceBot
|
||||
from openai import OpenAI, AzureOpenAI # Import both
|
||||
|
||||
class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot):
|
||||
DEFAULT_MAX_TOKENS = 1000 # Default for _configure_model_and_tokens
|
||||
from openai import OpenAI
|
||||
from tools.base_tool import BaseTool
|
||||
from telegram_helper import TelegramHelper
|
||||
import argparse
|
||||
from inference_bot import InferenceBot
|
||||
|
||||
class OpenAICompatibleInferenceBot(InferenceBot):
|
||||
def __init__(
|
||||
self,
|
||||
client: OpenAI | AzureOpenAI | None = None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
api_version: str | None = None, # For Azure
|
||||
azure_deployment: str | None = None, # Model for Azure, distinct from general model_name if needed
|
||||
model_name: str | None = None, # General model name for the API call
|
||||
max_tokens_str: str | None = None,
|
||||
system_prompt_content: str | None = None,
|
||||
system_prompt_path: str | None = None,
|
||||
is_gemini: bool = False, # Hint for specific API key if others are not set
|
||||
max_history_length: int | None = None
|
||||
small_model_name: str | None = None,
|
||||
small_model_max_tokens: str | None = None,
|
||||
large_model_name: str | None = None,
|
||||
large_model_max_tokens: str | None = None,
|
||||
allowed_function_tags: list[str] | None = None,
|
||||
system_prompt_path: str | None = None
|
||||
):
|
||||
super().__init__(system_prompt_content=system_prompt_content, system_prompt_path=system_prompt_path)
|
||||
|
||||
self.client = client
|
||||
|
||||
if not self.client:
|
||||
_api_key = api_key
|
||||
_base_url = base_url
|
||||
_api_version = api_version
|
||||
_azure_deployment_name = azure_deployment # This will be used as the model for Azure
|
||||
|
||||
# Determine if configuring for Azure OpenAI
|
||||
is_azure = False
|
||||
if _azure_deployment_name or (_base_url and "azure.com" in _base_url) or os.environ.get("AZURE_OPENAI_ENDPOINT"):
|
||||
is_azure = True
|
||||
|
||||
if is_azure:
|
||||
_base_url = _base_url or os.environ.get("AZURE_OPENAI_ENDPOINT")
|
||||
_api_key = _api_key or os.environ.get("AZURE_OPENAI_KEY")
|
||||
_api_version = _api_version or os.environ.get("AZURE_OPENAI_API_VERSION")
|
||||
# For Azure, the model parameter in API calls is the deployment name
|
||||
_effective_model_name = _azure_deployment_name or model_name # Use deployment if available, else model_name
|
||||
if not _base_url or not _api_key or not _api_version or not _effective_model_name:
|
||||
raise ValueError("For Azure OpenAI, endpoint, API key, API version, and deployment/model name must be configured.")
|
||||
self.client = AzureOpenAI(
|
||||
api_key=_api_key,
|
||||
azure_endpoint=_base_url,
|
||||
api_version=_api_version
|
||||
)
|
||||
# The model to be used in API calls for Azure is the deployment name.
|
||||
# _configure_model_and_tokens will set self.model to this.
|
||||
model_name_for_config = _effective_model_name
|
||||
logging.info(f"Initialized AzureOpenAI client for deployment: {model_name_for_config} at {_base_url}")
|
||||
else:
|
||||
# Standard OpenAI or other OpenAI-compatible (like Gemini via base_url)
|
||||
_base_url = _base_url or os.environ.get("OPENAI_API_BASE_URL") # For other compatible APIs
|
||||
if not _api_key: # Try different ENV sources for API key
|
||||
if is_gemini and os.environ.get("GEMINI_API_KEY"):
|
||||
_api_key = os.environ.get("GEMINI_API_KEY")
|
||||
else:
|
||||
_api_key = os.environ.get("OPENAI_API_KEY")
|
||||
|
||||
if not _api_key and not _base_url : # For completely local models with no key needed via base_url
|
||||
pass # Allow client to be created with no API key if base_url is set and points to local model
|
||||
elif not _api_key:
|
||||
raise ValueError("API key must be provided for OpenAI compatible client if not Azure or local anonymous.")
|
||||
|
||||
self.client = OpenAI(api_key=_api_key, base_url=_base_url)
|
||||
model_name_for_config = model_name # Use the general model_name for non-Azure
|
||||
log_msg = f"Initialized OpenAI compatible client. Target URL: {_base_url if _base_url else 'OpenAI default'}."
|
||||
logging.info(log_msg)
|
||||
else:
|
||||
# Client was provided directly
|
||||
model_name_for_config = model_name # Use provided model_name
|
||||
logging.info(f"Using provided client: {type(self.client)}")
|
||||
self.model_config = {
|
||||
"small_model_name": small_model_name,
|
||||
"small_model_max_tokens": small_model_max_tokens,
|
||||
"large_model_name": large_model_name,
|
||||
"large_model_max_tokens": large_model_max_tokens
|
||||
}
|
||||
self.allowed_function_tags = allowed_function_tags if allowed_function_tags else None
|
||||
self.conversation_history = {}
|
||||
self._processing_status = {}
|
||||
# MODIFIED to pass arguments
|
||||
self.system_prompt = self.load_system_prompt(
|
||||
file_path=system_prompt_path
|
||||
)
|
||||
self.tools, self.functions = self.load_functions()
|
||||
self.client = OpenAI(api_key=api_key, base_url=base_url)
|
||||
log_msg = f"Initialized OpenAI compatible client. Target URL: {base_url if base_url else 'OpenAI default'}."
|
||||
logging.info(log_msg)
|
||||
|
||||
# Configure the actual model name and max_tokens for API calls
|
||||
self._configure_model_and_tokens(
|
||||
model_name_for_config,
|
||||
max_tokens_str,
|
||||
default_max_tokens=self.DEFAULT_MAX_TOKENS
|
||||
self.model_config["small_model_name"],
|
||||
self.model_config["small_model_max_tokens"]
|
||||
)
|
||||
@property
def processing_status(self):
    """
    Read-only access to the per-user processing-status mapping.

    The returned object is the dict stored in ``self._processing_status``
    itself (not a copy), so lookups such as
    ``self.processing_status.get(user_id)`` and in-place updates both act
    on the bot's live state.
    """
    status_by_user = self._processing_status
    return status_by_user
|
||||
|
||||
def clear_conversation_history(self, user_id):
    """Forget the stored conversation for ``user_id`` and reset every tool.

    A user without stored history is silently ignored. Every loaded tool
    has its ``clear()`` method called regardless of whether any history
    was removed.
    """
    # pop() with a default is a no-op when the user has no history.
    self.conversation_history.pop(user_id, None)

    for loaded_tool in self.tools:
        loaded_tool.clear()
|
||||
|
||||
def _configure_model_and_tokens(self, model_name: str | None, max_tokens_str: str | None, default_max_tokens: int = 1000):
|
||||
self.model = model_name if model_name else "default-model" # Fallback model name
|
||||
def _configure_model_and_tokens(self, model_name: str | None, max_tokens_str: str | None):
|
||||
self.model = model_name
|
||||
try:
|
||||
# If max_tokens_str is explicitly "None" or empty, treat as None for API default
|
||||
if max_tokens_str and max_tokens_str.lower() not in ["none", "", "null"]:
|
||||
@@ -93,7 +69,7 @@ class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot):
|
||||
else:
|
||||
self.max_tokens = None # Use API default by not sending the parameter or sending null
|
||||
except ValueError:
|
||||
logging.warning(f"Invalid value for max_tokens: {max_tokens_str}. Using API default (None). stalwart default was {default_max_tokens}")
|
||||
logging.warning(f"Invalid value for max_tokens: {max_tokens_str}. Using API default (None)")
|
||||
self.max_tokens = None # Use API default
|
||||
|
||||
logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}")
|
||||
@@ -109,11 +85,32 @@ class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot):
|
||||
raise ValueError("OpenAI client not initialized.")
|
||||
try:
|
||||
# Pass self.max_tokens directly. If None, OpenAI library omits it or handles it.
|
||||
# Initialize tools filtering based on allowed tags
|
||||
cleaned_tools = None
|
||||
if hasattr(self, 'functions') and self.functions:
|
||||
# Create a copy of functions without "_tags" field
|
||||
cleaned_tools = []
|
||||
for func in self.functions:
|
||||
include_function = False
|
||||
|
||||
if not hasattr(self, 'allowed_function_tags') or self.allowed_function_tags is None:
|
||||
# Include all functions if no tag filtering is specified
|
||||
include_function = True
|
||||
else:
|
||||
# Only include if function has matching tags
|
||||
tags = func.get("_tags", [])
|
||||
if any(tag in self.allowed_function_tags for tag in tags):
|
||||
include_function = True
|
||||
|
||||
if include_function:
|
||||
func_copy = {k: v for k, v in func.items() if k != "_tags"}
|
||||
cleaned_tools.append(func_copy)
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
tools=self.functions if hasattr(self, 'functions') and self.functions else None,
|
||||
tool_choice="auto" if hasattr(self, 'functions') and self.functions else None,
|
||||
tools=cleaned_tools,
|
||||
tool_choice="auto" if cleaned_tools else None,
|
||||
max_tokens=self.max_tokens
|
||||
)
|
||||
return response
|
||||
@@ -200,20 +197,184 @@ class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot):
|
||||
async def start(self):
    """Log a start-up line naming the bot class and the configured model."""
    bot_name = type(self).__name__
    logging.info(f"{bot_name} (Model: {self.model}) started.")
|
||||
|
||||
# clear_conversation_history is inherited from BaseTelegramInferenceBot
|
||||
|
||||
async def abort_processing(self, user_id):
    """Drop the in-progress marker for ``user_id`` and report the outcome.

    This only clears the bookkeeping entry in ``self.processing_status``;
    it does not cancel anything and leaves the conversation history
    untouched.

    Returns:
        A user-facing message saying whether an active entry was found.
    """
    if user_id not in self.processing_status:
        return "No active processing found to abort. If you wish, /clear the conversation history."

    self.clear_processing_status(user_id)
    logging.info(f"Processing aborted for user {user_id}.")
    return "Processing aborted. You can send a new message or /clear the conversation."
|
||||
|
||||
def load_functions(self):
    """Discover tool classes under the sibling ``tools`` directory.

    Every ``*.py`` file in ``<this file's directory>/tools`` (except
    ``__init__.py`` and ``base_tool.py``) is imported as ``tools.<name>``.
    Each ``BaseTool`` subclass found is instantiated with no arguments and
    the function specs returned by its ``get_functions()`` are collected.

    Import and instantiation failures are logged and skipped, so one
    broken tool does not prevent the others from loading.

    Returns:
        A ``(tools, functions)`` tuple: the instantiated tool objects and
        the flat list of their function specs. Both are empty when the
        tools directory does not exist.
    """
    tools_dir = os.path.join(os.path.dirname(__file__), 'tools')
    if not os.path.exists(tools_dir):
        logging.warning(f"Tools directory not found: {tools_dir}")
        return [], []

    tools = []
    # A tool class imported into another tool module shows up in both
    # modules' members; remember what was handled so each class is
    # instantiated (and its functions registered) only once.
    seen_classes = set()
    skipped_files = ('__init__.py', 'base_tool.py')

    # sorted(): os.listdir() order is arbitrary, and the tool order decides
    # the order of the function specs handed to the model.
    for filename in sorted(os.listdir(tools_dir)):
        if not filename.endswith('.py') or filename in skipped_files:
            continue

        module_name = f'tools.{filename[:-3]}'
        try:
            module = importlib.import_module(module_name)
            for name, obj in inspect.getmembers(module, inspect.isclass):
                if not issubclass(obj, BaseTool) or obj is BaseTool:
                    continue
                if obj in seen_classes:
                    continue
                seen_classes.add(obj)
                try:
                    # No-argument construction: tools needing configuration
                    # will fail here and be reported below.
                    tools.append(obj())
                except Exception as e:
                    logging.error(f"Error instantiating tool {name} from {filename}: {e}")
        except Exception as e:
            logging.error(f"Error importing module {module_name}: {e}")

    functions = []
    for tool in tools:
        functions.extend(tool.get_functions())
    return tools, functions
|
||||
|
||||
def load_system_prompt(self, direct_content: str | None = None, file_path: str | None = None) -> str:
|
||||
default_prompt = "You are a helpful AI assistant."
|
||||
|
||||
if direct_content:
|
||||
logging.info("Using direct content for system prompt.")
|
||||
return direct_content.strip()
|
||||
|
||||
prompt_path_to_try = file_path or os.getenv("SYSTEM_PROMPT_PATH")
|
||||
|
||||
if prompt_path_to_try:
|
||||
if os.path.isfile(prompt_path_to_try):
|
||||
try:
|
||||
with open(prompt_path_to_try, "r", encoding="utf-8") as file:
|
||||
content = file.read().strip()
|
||||
logging.info(f"Successfully loaded system prompt from {prompt_path_to_try}.")
|
||||
return content
|
||||
except IOError as e:
|
||||
logging.warning(f"Could not read system prompt file {prompt_path_to_try}: {e}. Using default.")
|
||||
return default_prompt
|
||||
else:
|
||||
# This condition now also covers if 'file_path' argument was given but invalid
|
||||
logging.warning(f"System prompt file {prompt_path_to_try} not found. Using default system prompt.")
|
||||
return default_prompt
|
||||
else:
|
||||
logging.info("No system prompt path provided (argument or ENV) or direct content. Using default system prompt.")
|
||||
return default_prompt
|
||||
|
||||
def set_processing_status(self, user_id: int, message_id: int):
    """Mark ``user_id`` as having a message in progress.

    Stores ``{"processing": True, "message_id": message_id}`` under the
    user's id, replacing any earlier entry for that user.
    """
    status_entry = dict(processing=True, message_id=message_id)
    self.processing_status[user_id] = status_entry
|
||||
|
||||
def clear_processing_status(self, user_id: int):
    """Remove the in-progress entry for ``user_id``; do nothing if there is none."""
    self.processing_status.pop(user_id, None)
|
||||
|
||||
def call_tool(self, function_call_name, function_call_arguments):
    """Execute the named tool function with the supplied arguments.

    Args:
        function_call_name: Name of the function to run; matched against
            ``function["function"]["name"]`` of every spec returned by
            each loaded tool's ``get_functions()``.
        function_call_arguments: The call arguments as a ``dict``, a JSON
            string, or ``None``. ``None`` and an empty or whitespace-only
            string both mean "no arguments".

    Returns:
        Whatever the tool's ``execute()`` returns. On any failure
        (malformed JSON, unsupported argument type, unknown function, or
        an exception raised by the tool) an error string is returned
        instead of raising, so the caller can pass it back to the model.
    """
    function_name = function_call_name

    if function_call_arguments is None:
        function_args = {}
    elif isinstance(function_call_arguments, dict):
        function_args = function_call_arguments
    elif isinstance(function_call_arguments, str):
        if not function_call_arguments.strip():
            # An empty argument string for a parameterless function is not
            # valid JSON; treat it as "no arguments" rather than failing.
            function_args = {}
        else:
            try:
                function_args = json.loads(function_call_arguments)
            except json.JSONDecodeError as e:
                logging.error(f"Error decoding function call arguments (string) for {function_call_name}: {e}. Arguments: {function_call_arguments}")
                return f"Error: Malformed arguments for tool call: {e}"
    else:
        logging.error(f"Unexpected type for function_call_arguments for {function_call_name}: {type(function_call_arguments)}. Arguments: {function_call_arguments}")
        return f"Error: Invalid argument type for tool call: {type(function_call_arguments)}"

    for tool in self.tools:
        for function in tool.get_functions():
            if function["function"]["name"] != function_name:
                continue

            # Valid JSON can still decode to a list, number, etc., which
            # cannot be expanded into keyword arguments.
            if not isinstance(function_args, dict):
                logging.error(f"Internal error: function_args not a dict for {function_name} before execution. Args: {function_args}")
                return f"Internal error preparing arguments for tool {function_name}."

            try:
                return tool.execute(function_name, **function_args)
            except Exception as e:
                logging.error(f"Error executing tool {function_name} with args {function_args}: {e}")
                return f"Error executing tool {function_name}: {e}"

    logging.warning(f"Tool function {function_name} not found.")
    return f"Error: Tool function {function_name} not found."
|
||||
|
||||
async def switch_model(self):
    """Toggle between the configured small and large models.

    If the large model is active the small one is selected, and vice
    versa. If the active model matches neither, a warning is logged and
    the small model is selected. The change is applied through
    ``_configure_model_and_tokens`` with the target model's max-tokens
    setting from ``self.model_config``.

    Returns:
        A status message in every case: either that switching is not
        configured (a small or large model name is missing), or which
        model is now active.
    """
    small_name = self.model_config["small_model_name"]
    large_name = self.model_config["large_model_name"]

    if not small_name or not large_name:
        logging.warning("Small or Large model names are not defined. Cannot switch model.")
        return f"Model switching not fully configured. Currently using {self.model}."

    # Check "large" first so the result is unchanged if both names match.
    if self.model == large_name:
        target_size = "small"
    elif self.model == small_name:
        target_size = "large"
    else:
        logging.warning(f"Current model {self.model} is unrecognized. Switching to default small model: {self.model_config['small_model_name']}.")
        target_size = "small"

    target_model = self.model_config[f"{target_size}_model_name"]
    target_max_tokens_str = self.model_config[f"{target_size}_model_max_tokens"]
    self._configure_model_and_tokens(target_model, target_max_tokens_str)

    # Report the outcome so the success path returns a message just like
    # the not-configured path does.
    return f"Switched to model {self.model}."
|
||||
|
||||
def main():
    """Command-line entry point: build the bot and run it behind a messenger helper.

    Arguments:
        --config     Prefix for the environment variables to read
                     (``<PREFIX>_API_KEY``, ``<PREFIX>_API_BASE_URL``,
                     ``<PREFIX>_SMALL_MODEL``, ``<PREFIX>_LARGE_MODEL``,
                     ``<PREFIX>_SMALL_MODEL_MAX_TOKENS``,
                     ``<PREFIX>_LARGE_MODEL_MAX_TOKENS``); upper-cased
                     before use.
        --messenger  Messenger name. The module ``<name>_helper`` is
                     imported and its class ``<Name>Helper`` is
                     instantiated with the bot and run.
        --persona    Optional path to a system prompt file.
        --tools      Optional list of allowed function tags.

    Configuration and start-up errors are logged and the function
    returns; nothing is re-raised.
    """
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    parser = argparse.ArgumentParser(description='OpenAI Compatible Inference Bot')
    parser.add_argument('--config', type=str, help='Configuration Prepend (i.e. gemini, openai, etc)', default="Telegram")
    parser.add_argument('--messenger', type=str, help='Messenger type (i.e. telegram)', required=True)
    parser.add_argument('--persona', type=str, help='Path to system prompt file', required=False)
    # Add these to launch.json arguments if you want to limit the toolset available: "--tools", "read", "communicate"
    parser.add_argument('--tools', nargs='+', help='List of allowed function tags', required=False)
    args = parser.parse_args()

    if args.persona:
        logging.info(f"Using custom persona from: {args.persona}")

    system_prompt_path = args.persona or None
    allowed_function_tags = args.tools or None
    messenger = args.messenger

    try:
        if not args.config:
            # Without a prefix there are no environment variable names to read.
            raise ValueError("A configuration prefix (--config) is required to locate the API settings.")
        env_prefix = args.config.upper()

        bot = OpenAICompatibleInferenceBot(
            api_key=os.environ.get(f"{env_prefix}_API_KEY"),
            # Pass None, not "", when no base URL is configured so the
            # client falls back to its default endpoint.
            base_url=os.environ.get(f"{env_prefix}_API_BASE_URL") or None,
            small_model_name=os.environ.get(f"{env_prefix}_SMALL_MODEL"),
            small_model_max_tokens=os.environ.get(f"{env_prefix}_SMALL_MODEL_MAX_TOKENS"),
            large_model_name=os.environ.get(f"{env_prefix}_LARGE_MODEL"),
            large_model_max_tokens=os.environ.get(f"{env_prefix}_LARGE_MODEL_MAX_TOKENS"),
            system_prompt_path=system_prompt_path,
            allowed_function_tags=allowed_function_tags
        )

        helper_module_name = f'{messenger.lower()}_helper'
        try:
            helper_module = importlib.import_module(helper_module_name)
        except ImportError as e:
            raise ValueError(f"Messenger helper module {helper_module_name} could not be imported: {e}") from e

        messenger_helper_class_name = f"{messenger.capitalize()}Helper"
        if not hasattr(helper_module, messenger_helper_class_name):
            raise ValueError(f"Messenger helper class {messenger_helper_class_name} not found in {helper_module.__name__}.")
        messenger_helper_class = getattr(helper_module, messenger_helper_class_name)

        helper = messenger_helper_class(bot)
        helper.run()
    except ValueError as e:
        logging.error(f"FATAL: {e}")
        return
    except Exception as e:  # Top-level boundary: report anything else with its traceback.
        logging.exception(f"An unexpected error occurred during bot initialization: {e}")
        return
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user