61fe33e1c4
This commit introduces a new command-line argument `--use-large-model` to `openai_compatible_inference_bot.py`. When this argument is provided, the bot will initialize and use the large model (as configured via environment variables) by default, instead of the small model. This allows for easier testing and deployment of the large model from the command line. Fixes #224
413 lines
21 KiB
Python
413 lines
21 KiB
Python
import importlib
|
|
import json
|
|
import os
|
|
import logging
|
|
import inspect
|
|
from abc import abstractmethod
|
|
from openai import OpenAI
|
|
from tools.base_tool import BaseTool
|
|
from telegram_helper import TelegramHelper
|
|
import argparse
|
|
from inference_bot import InferenceBot
|
|
|
|
class OpenAICompatibleInferenceBot(InferenceBot):
    """Inference bot that talks to an OpenAI-compatible chat-completions API.

    The bot keeps one conversation history per user, discovers tool classes
    in the sibling ``tools`` package, offers their functions to the model
    (optionally filtered by tag) and executes the tool calls the model
    requests until it produces a final answer.

    A "small" and a "large" model can be configured; ``switch_model`` toggles
    between them at runtime and ``use_large_model`` selects the initial one.
    """

    # Upper bound on model -> tools -> model round trips for a single user turn.
    MAX_TOOL_ITERATIONS = 200

    # Content returned for tool calls left unanswered when the bound is hit.
    _TOOL_LIMIT_MESSAGE = (
        "Error: Tool call skipped because the maximum number of tool "
        "iterations was reached."
    )

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        small_model_name: str | None = None,
        small_model_max_tokens: str | None = None,
        large_model_name: str | None = None,
        large_model_max_tokens: str | None = None,
        allowed_function_tags: list[str] | None = None,
        system_prompt_path: str | None = None,
        use_large_model: bool = False
    ):
        """Create the client, load the prompt and tools, and pick a model.

        Args:
            api_key: API key handed to the ``OpenAI`` client.
            base_url: Endpoint URL; ``None`` or ``""`` selects the client default.
            small_model_name: Name of the small model.
            small_model_max_tokens: Token limit for the small model, as a
                string; ``None``/``"none"``/``"null"``/``""`` mean "API default".
            large_model_name: Name of the large model.
            large_model_max_tokens: Token limit for the large model (same
                conventions as ``small_model_max_tokens``).
            allowed_function_tags: When given, only functions carrying at
                least one of these tags are offered to, and executable by,
                the model. ``None`` or an empty list disables filtering.
            system_prompt_path: Path of the system prompt file.
            use_large_model: Start with the large model instead of the small one.
        """
        # NOTE(review): InferenceBot.__init__ is not invoked here — confirm the
        # base class needs no initialisation of its own.
        self.model_config = {
            "small_model_name": small_model_name,
            "small_model_max_tokens": small_model_max_tokens,
            "large_model_name": large_model_name,
            "large_model_max_tokens": large_model_max_tokens
        }
        # An empty tag list is normalised to None, i.e. "no filtering".
        self.allowed_function_tags = allowed_function_tags if allowed_function_tags else None
        self.conversation_history = {}
        self._processing_status = {}
        # Kept so get_bot_status() can report which prompt file is in use.
        self.system_prompt_path = system_prompt_path
        self.system_prompt = self.load_system_prompt(file_path=system_prompt_path)
        self.tools, self.functions = self.load_functions()

        # An empty base_url would be used verbatim by the client and break every
        # request; map it to None so the client falls back to its default URL,
        # which is what the log message below already announces.
        self.client = OpenAI(api_key=api_key, base_url=base_url or None)
        log_msg = f"Initialized OpenAI compatible client. Target URL: {base_url if base_url else 'OpenAI default'}."
        logging.info(log_msg)

        # Configure the actual model name and max_tokens for API calls.
        size = "large" if use_large_model else "small"
        self._configure_model_and_tokens(
            self.model_config[f"{size}_model_name"],
            self.model_config[f"{size}_model_max_tokens"]
        )

    @property
    def processing_status(self):
        """
        Mapping of user id to the processing state recorded for that user.
        Example usage in subclass: self.processing_status.get(user_id)
        """
        return self._processing_status

    def clear_conversation_history(self, user_id):
        """Forget the stored conversation of ``user_id`` and clear every tool."""
        self.conversation_history.pop(user_id, None)

        for tool in self.tools:
            tool.clear()

    def _max_tokens_label(self):
        """Return the token limit for display, or 'API default' when unset."""
        return self.max_tokens if self.max_tokens is not None else 'API default'

    def _configure_model_and_tokens(self, model_name: str | None, max_tokens_str: str | None):
        """Set ``self.model`` and ``self.max_tokens`` for subsequent API calls.

        ``max_tokens_str`` is normally a string taken from the environment.
        ``None``, ``""``, ``"none"`` and ``"null"`` (any case) select the API
        default. Integers are accepted as well. Values that are not positive
        integers are reported with a warning and also fall back to the API
        default instead of raising.
        """
        self.model = model_name
        self.max_tokens = None  # None means: do not send a limit at all.

        if max_tokens_str is not None:
            # str() lets callers pass an int without tripping over .lower().
            text = str(max_tokens_str).strip()
            if text.lower() not in ("none", "", "null"):
                try:
                    limit = int(text)
                except ValueError:
                    limit = None
                if limit is not None and limit > 0:
                    self.max_tokens = limit
                else:
                    logging.warning(f"Invalid value for max_tokens: {max_tokens_str}. Using API default (None)")

        logging.info(f"Configured to use model: {self.model} with max_tokens: {self._max_tokens_label()}")

    def get_llm_description(self) -> str:
        """Return a one-line summary of the client type, model and token limit."""
        client_type = type(self.client).__name__
        return f"Client: {client_type}, LLM: {self.model}, Max Tokens: {self._max_tokens_label()}"

    def _allowed_functions(self) -> list:
        """Return the function definitions that pass the tag filter.

        Without a tag filter every loaded function is returned; otherwise only
        functions whose ``_tags`` share at least one entry with
        ``allowed_function_tags``.
        """
        functions = getattr(self, 'functions', None) or []
        allowed_tags = getattr(self, 'allowed_function_tags', None)
        if allowed_tags is None:
            return list(functions)
        return [
            func for func in functions
            if any(tag in allowed_tags for tag in func.get("_tags", []))
        ]

    def get_chat_response(self, messages):
        """Send ``messages`` to the configured model and return the raw response.

        Optional request fields (``tools``, ``tool_choice``, ``max_tokens``)
        are only included when they carry a value: an explicit ``None`` is
        serialised as JSON ``null`` and an empty ``tools`` list is rejected by
        the API, so both cases are avoided by omitting the field.

        Raises:
            ValueError: If the client has not been initialised.
            Exception: Any error raised by the API call is logged and re-raised.
        """
        if not self.client:
            # This should ideally not be hit if __init__ is successful
            logging.error("OpenAI client not initialized before get_chat_response.")
            raise ValueError("OpenAI client not initialized.")
        try:
            request = {"model": self.model, "messages": messages}

            # "_tags" is internal bookkeeping and must not be sent to the API.
            cleaned_tools = [
                {key: value for key, value in func.items() if key != "_tags"}
                for func in self._allowed_functions()
            ]
            if cleaned_tools:
                request["tools"] = cleaned_tools
                request["tool_choice"] = "auto"
            if self.max_tokens is not None:
                request["max_tokens"] = self.max_tokens

            return self.client.chat.completions.create(**request)
        except Exception as e:
            logging.error(f"API call to model {self.model} failed: {e}")
            raise

    def get_bot_status(self):
        """
        Returns a message with the currently enabled model and the system prompt path being used.
        """
        model_name = getattr(self, 'model', None)
        prompt_path = self.system_prompt_path or os.getenv("SYSTEM_PROMPT_PATH") or "(default prompt in use)"
        return f"Current model: {model_name}\nSystem prompt path: {prompt_path}"

    @staticmethod
    def _extract_message(response):
        """Return the first choice's message, or None when there is none."""
        choices = response.choices
        if choices and choices[0].message:
            return choices[0].message
        return None

    def _run_tool_call(self, tool_call, callable_names) -> dict:
        """Execute one requested tool call and return its ``tool`` message."""
        function_name = tool_call.function.name
        function_args_str = tool_call.function.arguments

        logging.info(f"Attempting to call tool: {function_name} with args: {function_args_str}")
        if function_name not in callable_names:
            logging.warning(f"Tool function {function_name} not found in available functions.")
            content = f"Error: Tool function {function_name} not found."
        else:
            try:
                # Arguments arrive as a JSON string; call_tool accepts dict or string.
                content = self.call_tool(function_name, function_args_str)
                # The tool role requires string content.
                if not isinstance(content, str):
                    content = json.dumps(content)
            except Exception as e:
                logging.error(f"Error calling tool {function_name}: {e}")
                content = f"Error executing tool {function_name}: {str(e)}"

        return {
            "role": "tool",
            "tool_call_id": tool_call.id,
            "name": function_name,
            "content": content
        }

    async def handle_message(self, user_id, user_message):
        """Process one user message and return the assistant's textual reply.

        The message is appended to the user's history (seeding it with the
        system prompt when the history is empty), the model is queried, and
        requested tool calls are executed and fed back until the model stops
        requesting tools or ``MAX_TOOL_ITERATIONS`` round trips were made.
        Note that the API calls are made synchronously inside this coroutine.

        Returns:
            The assistant's text, or a fixed fallback/error string when the
            model returned no usable message or no text.
        """
        history = self.conversation_history.get(user_id)
        if not history:
            history = []
            if self.system_prompt:  # Use the loaded system_prompt
                history.append({"role": "system", "content": self.system_prompt})
            self.conversation_history[user_id] = history

        history.append({"role": "user", "content": user_message})
        messages = list(history)  # Work with a copy for this turn

        assistant_message = self._extract_message(self.get_chat_response(messages))
        if assistant_message is None:
            logging.error("No valid response choice message from LLM.")
            # Persist the user message in history even if LLM fails this turn
            self.conversation_history[user_id] = messages
            return "Error: Could not get a valid response from the LLM."
        messages.append(assistant_message)

        # Only functions that were actually offered to the model may be run;
        # computed once per turn instead of once per tool call.
        callable_names = {func["function"]["name"] for func in self._allowed_functions()}

        pending_calls = list(assistant_message.tool_calls or [])
        tool_use_count = 0
        while pending_calls and tool_use_count < self.MAX_TOOL_ITERATIONS:
            messages.extend(
                self._run_tool_call(tool_call, callable_names)
                for tool_call in pending_calls
            )

            assistant_message = self._extract_message(self.get_chat_response(messages))
            if assistant_message is None:
                logging.error("No valid response choice message from LLM after tool call.")
                self.conversation_history[user_id] = messages  # Persist state before error
                return "Error: Could not get a valid response from the LLM after tool call."
            messages.append(assistant_message)

            pending_calls = list(assistant_message.tool_calls or [])
            tool_use_count += 1

        if pending_calls:
            logging.warning(f"Max tool iterations ({self.MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.")
            # Every tool call in the history needs a matching tool message,
            # otherwise the next request for this user would be rejected.
            for tool_call in pending_calls:
                messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "name": tool_call.function.name,
                    "content": self._TOOL_LIMIT_MESSAGE
                })

        self.conversation_history[user_id] = messages

        if assistant_message.role == "assistant" and assistant_message.content is not None:
            return assistant_message.content
        return "Assistant did not provide a textual response."

    async def start(self):
        """Log that the bot has started."""
        logging.info(f"{self.__class__.__name__} (Model: {self.model}) started.")

    async def abort_processing(self, user_id):
        """Drop the processing marker of ``user_id`` and return a status text.

        This is a soft abort: API calls run synchronously inside
        ``handle_message`` and are not interrupted.
        """
        if user_id not in self.processing_status:
            return "No active processing found to abort. If you wish, /clear the conversation history."

        self.clear_processing_status(user_id)
        logging.info(f"Processing aborted for user {user_id}.")
        return "Processing aborted. You can send a new message or /clear the conversation."

    def load_functions(self):
        """Discover and instantiate tools from the sibling ``tools`` directory.

        Every ``*.py`` file except ``__init__.py`` and ``base_tool.py`` is
        imported as ``tools.<name>``; each ``BaseTool`` subclass found is
        instantiated without arguments. Import and instantiation failures are
        logged and skipped.

        Returns:
            A ``(tools, functions)`` tuple: the tool instances and the
            concatenation of their ``get_functions()`` results.
        """
        tools_dir = os.path.join(os.path.dirname(__file__), 'tools')
        if not os.path.exists(tools_dir):
            logging.warning(f"Tools directory not found: {tools_dir}")
            return [], []

        tools = []
        seen_classes = set()
        # os.listdir order is arbitrary; sort so tools load in a stable order.
        for filename in sorted(os.listdir(tools_dir)):
            if not filename.endswith('.py') or filename in ('__init__.py', 'base_tool.py'):
                continue
            module_name = f'tools.{filename[:-3]}'
            try:
                module = importlib.import_module(module_name)
            except Exception as e:
                logging.error(f"Error importing module {module_name}: {e}")
                continue

            for name, obj in inspect.getmembers(module, inspect.isclass):
                if not issubclass(obj, BaseTool) or obj is BaseTool:
                    continue
                # A tool class imported by another tool module shows up in both
                # modules; instantiate it once to avoid duplicate functions.
                if obj in seen_classes:
                    continue
                seen_classes.add(obj)
                try:
                    tools.append(obj())  # This instantiation might be an issue for tools needing config
                except Exception as e:
                    logging.error(f"Error instantiating tool {name} from {filename}: {e}")

        functions = []
        for tool in tools:
            functions.extend(tool.get_functions())
        return tools, functions

    def load_system_prompt(self, direct_content: str | None = None, file_path: str | None = None) -> str:
        """Resolve the system prompt text.

        Precedence: ``direct_content``, then the file at ``file_path``, then
        the file named by the ``SYSTEM_PROMPT_PATH`` environment variable,
        then a built-in default. A missing or unreadable file falls back to
        the default with a warning.
        """
        default_prompt = "You are a helpful AI assistant."

        if direct_content:
            logging.info("Using direct content for system prompt.")
            return direct_content.strip()

        prompt_path_to_try = file_path or os.getenv("SYSTEM_PROMPT_PATH")
        if not prompt_path_to_try:
            logging.info("No system prompt path provided (argument or ENV) or direct content. Using default system prompt.")
            return default_prompt

        if not os.path.isfile(prompt_path_to_try):
            # Also covers a 'file_path' argument that was given but is invalid.
            logging.warning(f"System prompt file {prompt_path_to_try} not found. Using default system prompt.")
            return default_prompt

        try:
            with open(prompt_path_to_try, "r", encoding="utf-8") as file:
                content = file.read().strip()
        except (OSError, UnicodeDecodeError) as e:
            # UnicodeDecodeError is not an OSError; a non-UTF-8 prompt file
            # should degrade to the default just like an unreadable one.
            logging.warning(f"Could not read system prompt file {prompt_path_to_try}: {e}. Using default.")
            return default_prompt

        logging.info(f"Successfully loaded system prompt from {prompt_path_to_try}.")
        return content

    def set_processing_status(self, user_id: int, message_id: int):
        """Mark ``user_id`` as being processed for message ``message_id``."""
        self.processing_status[user_id] = {"processing": True, "message_id": message_id}

    def clear_processing_status(self, user_id: int):
        """Remove the processing marker of ``user_id`` if there is one."""
        self.processing_status.pop(user_id, None)

    def call_tool(self, function_call_name, function_call_arguments):
        """Execute the named tool function and return its result.

        Args:
            function_call_name: Name of the function to run.
            function_call_arguments: Keyword arguments as a dict, a JSON
                string, or ``None``. ``None``, a blank string and JSON
                ``null`` all mean "no arguments".

        Returns:
            Whatever the tool's ``execute`` returns, or an error string when
            the arguments are malformed, the function is unknown, or the tool
            raises.
        """
        function_name = function_call_name

        if function_call_arguments is None:
            function_args = {}
        elif isinstance(function_call_arguments, dict):
            function_args = function_call_arguments
        elif isinstance(function_call_arguments, str):
            if not function_call_arguments.strip():
                # A blank argument string is taken to mean "no arguments".
                function_args = {}
            else:
                try:
                    function_args = json.loads(function_call_arguments)
                except json.JSONDecodeError as e:
                    logging.error(f"Error decoding function call arguments (string) for {function_call_name}: {e}. Arguments: {function_call_arguments}")
                    return f"Error: Malformed arguments for tool call: {e}"
                if function_args is None:
                    function_args = {}
        else:
            logging.error(f"Unexpected type for function_call_arguments for {function_name}: {type(function_call_arguments)}. Arguments: {function_call_arguments}")
            return f"Error: Invalid argument type for tool call: {type(function_call_arguments)}"

        for tool in self.tools:
            provides_function = any(
                function["function"]["name"] == function_name
                for function in tool.get_functions()
            )
            if not provides_function:
                continue

            # JSON such as "[1, 2]" decodes fine but cannot be used as kwargs.
            if not isinstance(function_args, dict):
                logging.error(f"Internal error: function_args not a dict for {function_name} before execution. Args: {function_args}")
                return f"Internal error preparing arguments for tool {function_name}."
            try:
                return tool.execute(function_name, **function_args)
            except Exception as e:
                logging.error(f"Error executing tool {function_name} with args {function_args}: {e}")
                return f"Error executing tool {function_name}: {e}"

        logging.warning(f"Tool function {function_name} not found.")
        return f"Error: Tool function {function_name} not found."

    async def switch_model(self):
        """Toggle between the small and the large model.

        If the current model is neither of the two configured names, the
        small model is selected. Returns a human-readable status message.
        """
        small_name = self.model_config["small_model_name"]
        large_name = self.model_config["large_model_name"]
        if not small_name or not large_name:
            logging.warning("Small or Large model names are not defined. Cannot switch model.")
            return f"Model switching not fully configured. Currently using {self.model}."

        if self.model == large_name:
            target_size = "small"
        elif self.model == small_name:
            target_size = "large"
        else:
            logging.warning(f"Current model {self.model} is unrecognized. Switching to default small model: {self.model_config['small_model_name']}.")
            target_size = "small"

        self._configure_model_and_tokens(
            self.model_config[f"{target_size}_model_name"],
            self.model_config[f"{target_size}_model_max_tokens"]
        )
        return f"Switched model to {self.model}. Max tokens set to {self._max_tokens_label()}."
|
|
|
|
def main():
    """Command-line entry point: build the bot and run it behind a messenger.

    Connection settings are read from environment variables prefixed with the
    upper-cased ``--config`` value (``<PREFIX>_API_KEY``,
    ``<PREFIX>_API_BASE_URL``, ``<PREFIX>_SMALL_MODEL``,
    ``<PREFIX>_LARGE_MODEL`` and the matching ``*_MAX_TOKENS``). The messenger
    helper class is looked up in the module ``<messenger>_helper``.

    Errors are logged and the function returns instead of raising.
    """
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    try:
        parser = argparse.ArgumentParser(description='OpenAI Compatible Inference Bot')
        parser.add_argument('--config', type=str, help='Configuration Prepend (i.e. gemini, openai, etc)', default="Telegram")
        parser.add_argument('--messenger', type=str, help='Messenger type (i.e. telegram)', required=True)
        parser.add_argument('--persona', type=str, help='Path to system prompt file', required=False)
        parser.add_argument('--tools', nargs='+', help='List of allowed function tags', required=False)
        parser.add_argument('--use-large-model', action='store_true', help='Use the large model instead of the small model')
        # Add these to launch.json arguments if you want to limit the toolset available: "--tools", "read", "communicate"
        args = parser.parse_args()

        if args.persona:
            logging.info(f"Using custom persona from: {args.persona}")

        # Both values are dereferenced below; fail with a clear message rather
        # than with a NameError/AttributeError further down.
        if not args.config:
            raise ValueError("--config must not be empty; it selects the environment variable prefix.")
        if not args.messenger:
            raise ValueError("--messenger must not be empty.")

        prefix = args.config.upper()
        messenger = args.messenger

        bot = OpenAICompatibleInferenceBot(
            api_key=os.environ.get(f"{prefix}_API_KEY"),
            # Unset or empty means "use the client's default endpoint"; an
            # empty string would otherwise be used verbatim as the URL.
            base_url=os.environ.get(f"{prefix}_API_BASE_URL") or None,
            small_model_name=os.environ.get(f"{prefix}_SMALL_MODEL"),
            small_model_max_tokens=os.environ.get(f"{prefix}_SMALL_MODEL_MAX_TOKENS"),
            large_model_name=os.environ.get(f"{prefix}_LARGE_MODEL"),
            large_model_max_tokens=os.environ.get(f"{prefix}_LARGE_MODEL_MAX_TOKENS"),
            system_prompt_path=args.persona if args.persona else None,
            allowed_function_tags=args.tools if args.tools else None,
            use_large_model=args.use_large_model
        )

        full_code_file = importlib.import_module(f'{messenger.lower()}_helper')
        # Try "TelegramHelper"-style naming first, then "SMSHelper"-style.
        messenger_helper_class_name = f"{messenger.capitalize()}Helper"
        if not hasattr(full_code_file, messenger_helper_class_name):
            messenger_helper_class_name = f"{messenger.upper()}Helper"
            if not hasattr(full_code_file, messenger_helper_class_name):
                raise ValueError(f"Messenger helper class {messenger_helper_class_name} not found in {full_code_file.__name__}.")
        helper_class = getattr(full_code_file, messenger_helper_class_name)

        helper = helper_class(bot)
        helper.run()
    except ValueError as e:
        logging.error(f"FATAL: {e}")
        return
    except Exception as e:  # Catch any other init errors
        # logging.exception keeps the traceback, which a plain error log drops.
        logging.exception(f"An unexpected error occurred during bot initialization: {e}")
        return
|
|
|
|
# Start the bot only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|