Files
cyclop/openai_compatible_inference_bot.py
T
cyclop-bot 61fe33e1c4 feat: Add --use-large-model argument to openai_compatible_inference_bot.py
This commit introduces a new command-line argument `--use-large-model` to `openai_compatible_inference_bot.py`. When this argument is provided, the bot will initialize and use the large model (as configured via environment variables) by default, instead of the small model. This allows for easier testing and deployment of the large model from the command line.

Fixes #224
2025-06-05 18:03:18 -05:00

413 lines
21 KiB
Python

import importlib
import json
import os
import logging
import inspect
from abc import abstractmethod
from openai import OpenAI
from tools.base_tool import BaseTool
from telegram_helper import TelegramHelper
import argparse
from inference_bot import InferenceBot
class OpenAICompatibleInferenceBot(InferenceBot):
    """Chat bot that talks to an OpenAI-compatible chat-completions endpoint.

    The bot keeps a per-user conversation history, instantiates every
    ``BaseTool`` subclass found in the ``tools`` directory next to this file,
    and can toggle between a configured "small" and "large" model.
    """

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        small_model_name: str | None = None,
        small_model_max_tokens: str | None = None,
        large_model_name: str | None = None,
        large_model_max_tokens: str | None = None,
        allowed_function_tags: list[str] | None = None,
        system_prompt_path: str | None = None,
        use_large_model: bool = False,
    ):
        """Create the API client, load the system prompt and tools, pick a model.

        Args:
            api_key: Key handed to the ``OpenAI`` client.
            base_url: Endpoint URL; ``None`` or an empty string lets the
                client use its own default.
            small_model_name: Model name used when ``use_large_model`` is False.
            small_model_max_tokens: ``max_tokens`` for the small model, as text.
            large_model_name: Model name used when ``use_large_model`` is True.
            large_model_max_tokens: ``max_tokens`` for the large model, as text.
            allowed_function_tags: When given, only functions carrying at least
                one of these ``_tags`` are offered to the model.
            system_prompt_path: File to read the system prompt from.
            use_large_model: Start with the large model instead of the small one.
        """
        # NOTE(review): InferenceBot.__init__ is not invoked here (same as the
        # original code) - confirm the base class needs no initialisation.
        self.model_config = {
            "small_model_name": small_model_name,
            "small_model_max_tokens": small_model_max_tokens,
            "large_model_name": large_model_name,
            "large_model_max_tokens": large_model_max_tokens,
        }
        # An empty tag list means "no filtering", same as None.
        self.allowed_function_tags = allowed_function_tags or None
        self.conversation_history = {}
        self._processing_status = {}
        # Kept so get_bot_status() can report which prompt file is in use.
        self.system_prompt_path = system_prompt_path
        self.system_prompt = self.load_system_prompt(file_path=system_prompt_path)
        self.tools, self.functions = self.load_functions()

        # The client only falls back to its default endpoint when base_url is
        # None; an empty string would be used verbatim, so normalise it.
        self.client = OpenAI(api_key=api_key, base_url=base_url or None)
        log_msg = f"Initialized OpenAI compatible client. Target URL: {base_url if base_url else 'OpenAI default'}."
        logging.info(log_msg)

        # Configure the actual model name and max_tokens for API calls.
        size = "large" if use_large_model else "small"
        self._configure_model_and_tokens(
            self.model_config[f"{size}_model_name"],
            self.model_config[f"{size}_model_max_tokens"],
        )

    @property
    def processing_status(self):
        """
        An attribute to store the processing status for users.
        Example usage in subclass: self.processing_status.get(user_id)
        """
        return self._processing_status

    def clear_conversation_history(self, user_id):
        """Drop the stored conversation for ``user_id`` and call ``clear()`` on every tool."""
        self.conversation_history.pop(user_id, None)
        for tool in self.tools:
            tool.clear()

    def _configure_model_and_tokens(self, model_name: str | None, max_tokens_str: str | None):
        """Set ``self.model`` and ``self.max_tokens`` from raw configuration values.

        ``max_tokens_str`` may be ``None``, empty, ``"none"`` or ``"null"``
        (any case) to request the API default; anything else must parse as an
        integer, otherwise a warning is logged and the API default is used.
        """
        self.model = model_name
        if not model_name:
            logging.warning("No model name configured; requests may be rejected by the API.")

        self.max_tokens = None  # None means "do not send max_tokens".
        # str() keeps this working if a caller passes an int instead of text.
        text = "" if max_tokens_str is None else str(max_tokens_str).strip()
        if text and text.lower() not in ("none", "null"):
            try:
                self.max_tokens = int(text)
            except ValueError:
                logging.warning(f"Invalid value for max_tokens: {max_tokens_str}. Using API default (None)")
        logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}")

    def get_llm_description(self) -> str:
        """Return a one-line summary of the client type, model and token limit."""
        client_type = type(self.client).__name__
        return f"Client: {client_type}, LLM: {self.model}, Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}"

    def _select_tools(self) -> list:
        """Return the function schemas to offer the model, without ``_tags``.

        With no tag filter every function is included; otherwise only those
        whose ``_tags`` intersect ``allowed_function_tags``.
        """
        allowed = getattr(self, "allowed_function_tags", None)
        selected = []
        for func in getattr(self, "functions", None) or []:
            if allowed is not None:
                tags = func.get("_tags") or []
                if not any(tag in allowed for tag in tags):
                    continue
            # "_tags" is bookkeeping for this bot, not part of the API schema.
            selected.append({k: v for k, v in func.items() if k != "_tags"})
        return selected

    def get_chat_response(self, messages):
        """Send ``messages`` to the chat-completions endpoint and return the response.

        Raises:
            ValueError: If the client was never initialised.
            Exception: Any error from the API call is logged and re-raised.
        """
        if not self.client:
            # This should ideally not be hit if __init__ is successful
            logging.error("OpenAI client not initialized before get_chat_response.")
            raise ValueError("OpenAI client not initialized.")
        try:
            request = {"model": self.model, "messages": messages}
            tools = self._select_tools()
            # Optional fields are left out entirely when unset: an empty
            # "tools" array or explicit nulls are rejected by some endpoints.
            if tools:
                request["tools"] = tools
                request["tool_choice"] = "auto"
            if self.max_tokens is not None:
                request["max_tokens"] = self.max_tokens
            return self.client.chat.completions.create(**request)
        except Exception as e:
            logging.error(f"API call to model {self.model} failed: {e}")
            raise

    def get_bot_status(self):
        """
        Returns a message with the currently enabled model and the system prompt path being used.
        """
        model_name = getattr(self, "model", None)
        prompt_path = self.system_prompt_path or os.getenv("SYSTEM_PROMPT_PATH") or "(default prompt in use)"
        return f"Current model: {model_name}\nSystem prompt path: {prompt_path}"

    def _run_tool_call(self, tool_call, known_names) -> dict:
        """Execute one tool call from the model and build its ``tool`` role message."""
        function_name = tool_call.function.name
        function_args_str = tool_call.function.arguments
        logging.info(f"Attempting to call tool: {function_name} with args: {function_args_str}")

        if function_name not in known_names:
            logging.warning(f"Tool function {function_name} not found in available functions.")
            content = f"Error: Tool function {function_name} not found."
        else:
            try:
                # Arguments arrive as a JSON string; call_tool accepts dict or string.
                content = self.call_tool(function_name, function_args_str)
                # The tool role requires string content.
                if not isinstance(content, str):
                    content = json.dumps(content)
            except Exception as e:
                logging.error(f"Error calling tool {function_name}: {e}")
                content = f"Error executing tool {function_name}: {str(e)}"

        return {
            "role": "tool",
            "tool_call_id": tool_call.id,
            "name": function_name,
            "content": content,
        }

    async def handle_message(self, user_id, user_message):
        """Process one user message, running requested tools until the model answers.

        Returns:
            The assistant's text, or a fixed fallback/error string when the
            model returned no usable message or no text content.
        """
        MAX_TOOL_ITERATIONS = 200

        history = self.conversation_history.get(user_id)
        if not history:
            # New (or cleared) conversation: seed it with the system prompt.
            history = []
            if self.system_prompt:
                history.append({"role": "system", "content": self.system_prompt})
            self.conversation_history[user_id] = history
        history.append({"role": "user", "content": user_message})

        messages = list(history)  # Work with a copy for this turn
        response = self.get_chat_response(messages)
        if not (response.choices and response.choices[0].message):
            logging.error("No valid response choice message from LLM.")
            # Persist the user message in history even if LLM fails this turn
            self.conversation_history[user_id] = messages
            return "Error: Could not get a valid response from the LLM."

        assistant_message = response.choices[0].message
        messages.append(assistant_message)
        pending_calls = list(assistant_message.tool_calls) if assistant_message.tool_calls else []

        # Built once per turn instead of once per tool call.
        known_names = {f["function"]["name"] for f in self.functions}
        tool_use_count = 0
        while pending_calls and tool_use_count < MAX_TOOL_ITERATIONS:
            messages.extend(self._run_tool_call(call, known_names) for call in pending_calls)

            response = self.get_chat_response(messages)
            if not (response.choices and response.choices[0].message):
                logging.error("No valid response choice message from LLM after tool call.")
                self.conversation_history[user_id] = messages  # Persist state before error
                return "Error: Could not get a valid response from the LLM after tool call."

            assistant_message = response.choices[0].message
            messages.append(assistant_message)
            pending_calls = list(assistant_message.tool_calls) if assistant_message.tool_calls else []
            tool_use_count += 1

        if pending_calls:
            # Only reachable when the iteration cap stopped the loop.
            logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.")
            # Answer the outstanding calls so the stored history stays valid:
            # an assistant message with tool_calls must be followed by a tool
            # message for each call id, or the next request is rejected.
            for call in pending_calls:
                messages.append({
                    "role": "tool",
                    "tool_call_id": call.id,
                    "name": call.function.name,
                    "content": "Error: Tool call skipped because the tool iteration limit was reached.",
                })

        self.conversation_history[user_id] = messages
        if assistant_message.role == "assistant" and assistant_message.content is not None:
            return assistant_message.content
        return "Assistant did not provide a textual response."

    async def start(self):
        """Log that the bot has started."""
        logging.info(f"{self.__class__.__name__} (Model: {self.model}) started.")

    async def abort_processing(self, user_id):
        """Clear the processing flag for ``user_id`` and return a status message.

        This is a soft abort: API calls run synchronously inside
        handle_message, so nothing in flight is actually cancelled here.
        """
        if user_id in self.processing_status:
            self.clear_processing_status(user_id)
            logging.info(f"Processing aborted for user {user_id}.")
            return "Processing aborted. You can send a new message or /clear the conversation."
        return "No active processing found to abort. If you wish, /clear the conversation history."

    def load_functions(self):
        """Instantiate tools from the ``tools`` directory and gather their functions.

        Every ``*.py`` file except ``__init__.py`` and ``base_tool.py`` is
        imported as ``tools.<name>``; each ``BaseTool`` subclass found is
        instantiated with no arguments. Import and instantiation failures are
        logged and skipped.

        Returns:
            ``(tools, functions)``: the tool instances and the concatenation of
            their ``get_functions()`` results.
        """
        tools_dir = os.path.join(os.path.dirname(__file__), 'tools')
        if not os.path.exists(tools_dir):
            logging.warning(f"Tools directory not found: {tools_dir}")
            return [], []

        tools = []
        # A class imported into several tool modules shows up in each of them;
        # instantiate it only once so its functions are not offered twice.
        seen_classes = set()
        # sorted() makes the tool order independent of directory listing order.
        for filename in sorted(os.listdir(tools_dir)):
            if not filename.endswith('.py') or filename in ('__init__.py', 'base_tool.py'):
                continue
            module_name = f'tools.{filename[:-3]}'
            try:
                module = importlib.import_module(module_name)
            except Exception as e:
                logging.error(f"Error importing module {module_name}: {e}")
                continue
            for name, obj in inspect.getmembers(module, inspect.isclass):
                if not issubclass(obj, BaseTool) or obj is BaseTool or obj in seen_classes:
                    continue
                seen_classes.add(obj)
                try:
                    # Tools are built without arguments; ones that need
                    # configuration will fail here and be skipped.
                    tools.append(obj())
                except Exception as e:
                    logging.error(f"Error instantiating tool {name} from {filename}: {e}")

        functions = []
        for tool in tools:
            functions.extend(tool.get_functions())
        return tools, functions

    def load_system_prompt(self, direct_content: str | None = None, file_path: str | None = None) -> str:
        """Resolve the system prompt text.

        Precedence: ``direct_content``, then the file at ``file_path``, then
        the file named by the ``SYSTEM_PROMPT_PATH`` environment variable,
        then a built-in default. A missing or unreadable file falls back to
        the default with a warning.
        """
        default_prompt = "You are a helpful AI assistant."
        if direct_content:
            logging.info("Using direct content for system prompt.")
            return direct_content.strip()

        prompt_path_to_try = file_path or os.getenv("SYSTEM_PROMPT_PATH")
        if not prompt_path_to_try:
            logging.info("No system prompt path provided (argument or ENV) or direct content. Using default system prompt.")
            return default_prompt
        if not os.path.isfile(prompt_path_to_try):
            # Also covers an explicit 'file_path' argument that does not exist.
            logging.warning(f"System prompt file {prompt_path_to_try} not found. Using default system prompt.")
            return default_prompt

        try:
            with open(prompt_path_to_try, "r", encoding="utf-8") as file:
                content = file.read().strip()
        except (OSError, UnicodeDecodeError) as e:
            # UnicodeDecodeError is not an OSError, so it is listed explicitly.
            logging.warning(f"Could not read system prompt file {prompt_path_to_try}: {e}. Using default.")
            return default_prompt
        logging.info(f"Successfully loaded system prompt from {prompt_path_to_try}.")
        return content

    def set_processing_status(self, user_id: int, message_id: int):
        """Mark ``user_id`` as being processed for ``message_id``."""
        self.processing_status[user_id] = {"processing": True, "message_id": message_id}

    def clear_processing_status(self, user_id: int):
        """Remove any processing marker for ``user_id``."""
        self.processing_status.pop(user_id, None)

    def call_tool(self, function_call_name, function_call_arguments):
        """Run the named tool function and return its result.

        Args:
            function_call_name: Name to look up among the tools' functions.
            function_call_arguments: Keyword arguments as a dict, a JSON
                string, or ``None``. ``None``, a blank string and JSON
                ``null`` all mean "no arguments".

        Returns:
            Whatever ``tool.execute`` returns, or an error string when the
            arguments are malformed, the function is unknown, or execution
            raises.
        """
        function_name = function_call_name
        if function_call_arguments is None:
            function_args = {}
        elif isinstance(function_call_arguments, dict):
            function_args = function_call_arguments
        elif isinstance(function_call_arguments, str):
            if not function_call_arguments.strip():
                # Models may send "" for zero-argument functions.
                function_args = {}
            else:
                try:
                    function_args = json.loads(function_call_arguments)
                except json.JSONDecodeError as e:
                    logging.error(f"Error decoding function call arguments (string) for {function_call_name}: {e}. Arguments: {function_call_arguments}")
                    return f"Error: Malformed arguments for tool call: {e}"
                if function_args is None:
                    function_args = {}
        else:
            logging.error(f"Unexpected type for function_call_arguments for {function_name}: {type(function_call_arguments)}. Arguments: {function_call_arguments}")
            return f"Error: Invalid argument type for tool call: {type(function_call_arguments)}"

        for tool in self.tools:
            if not any(f["function"]["name"] == function_name for f in tool.get_functions()):
                continue
            # Valid JSON that is not an object (e.g. a list) cannot be splatted.
            if not isinstance(function_args, dict):
                logging.error(f"Internal error: function_args not a dict for {function_name} before execution. Args: {function_args}")
                return f"Internal error preparing arguments for tool {function_name}."
            try:
                return tool.execute(function_name, **function_args)
            except Exception as e:
                logging.error(f"Error executing tool {function_name} with args {function_args}: {e}")
                return f"Error executing tool {function_name}: {e}"

        logging.warning(f"Tool function {function_name} not found.")
        return f"Error: Tool function {function_name} not found."

    async def switch_model(self):
        """Toggle between the small and large model and return a status message.

        If either model name is unset nothing changes. If the current model
        matches neither configured name, the small model is selected.
        """
        small = self.model_config["small_model_name"]
        large = self.model_config["large_model_name"]
        if not small or not large:
            logging.warning("Small or Large model names are not defined. Cannot switch model.")
            return f"Model switching not fully configured. Currently using {self.model}."

        # The large check comes first, so identical names resolve to "small".
        if self.model == large:
            target = "small"
        elif self.model == small:
            target = "large"
        else:
            logging.warning(f"Current model {self.model} is unrecognized. Switching to default small model: {small}.")
            target = "small"

        self._configure_model_and_tokens(
            self.model_config[f"{target}_model_name"],
            self.model_config[f"{target}_model_max_tokens"],
        )
        return f"Switched model to {self.model}. Max tokens set to {self.max_tokens if self.max_tokens is not None else 'API default'}."
def main():
    """Command-line entry point: build the bot and run it behind a messenger helper.

    ``--config`` selects the environment-variable prefix (``<CONFIG>_API_KEY``,
    ``<CONFIG>_API_BASE_URL``, ``<CONFIG>_SMALL_MODEL``, ``<CONFIG>_LARGE_MODEL``
    and the matching ``*_MAX_TOKENS``). ``--messenger`` names the module
    ``<messenger>_helper`` whose ``<Messenger>Helper`` class wraps the bot.
    Errors are logged and the function returns without raising.
    """
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    parser = argparse.ArgumentParser(description='OpenAI Compatible Inference Bot')
    parser.add_argument('--config', type=str, help='Configuration Prepend (i.e. gemini, openai, etc)', default="Telegram")
    parser.add_argument('--messenger', type=str, help='Messenger type (i.e. telegram)', required=True)
    parser.add_argument('--persona', type=str, help='Path to system prompt file', required=False)
    # Add these to launch.json arguments if you want to limit the toolset available: "--tools", "read", "communicate"
    parser.add_argument('--tools', nargs='+', help='List of allowed function tags', required=False)
    parser.add_argument('--use-large-model', action='store_true', help='Use the large model instead of the small model')

    try:
        args = parser.parse_args()
        if args.persona:
            logging.info(f"Using custom persona from: {args.persona}")
        if not args.config:
            # Without a prefix there is no way to locate the API settings.
            raise ValueError("--config must not be empty; it selects the <CONFIG>_* environment variables.")

        prefix = args.config.upper()

        def env(suffix):
            """Read ``<PREFIX>_<suffix>`` from the environment (None if unset)."""
            return os.environ.get(f"{prefix}_{suffix}")

        bot = OpenAICompatibleInferenceBot(
            api_key=env("API_KEY"),
            # Unset or empty must become None so the client uses its default URL.
            base_url=env("API_BASE_URL") or None,
            small_model_name=env("SMALL_MODEL"),
            small_model_max_tokens=env("SMALL_MODEL_MAX_TOKENS"),
            large_model_name=env("LARGE_MODEL"),
            large_model_max_tokens=env("LARGE_MODEL_MAX_TOKENS"),
            system_prompt_path=args.persona or None,
            allowed_function_tags=args.tools or None,
            use_large_model=args.use_large_model,
        )

        messenger = args.messenger
        helper_module = importlib.import_module(f'{messenger.lower()}_helper')
        # Try "TelegramHelper"-style first, then an all-caps prefix ("SMSHelper").
        candidates = (f"{messenger.capitalize()}Helper", f"{messenger.upper()}Helper")
        helper_class = next(
            (getattr(helper_module, name) for name in candidates if hasattr(helper_module, name)),
            None,
        )
        if helper_class is None:
            raise ValueError(f"Messenger helper class {candidates[-1]} not found in {helper_module.__name__}.")

        helper_class(bot).run()
    except ValueError as e:
        logging.error(f"FATAL: {e}")
        return
    except Exception as e:  # Catch any other init errors
        # logging.exception keeps the traceback, which the message alone hides.
        logging.exception(f"An unexpected error occurred during bot initialization: {e}")
        return


if __name__ == '__main__':
    main()