import importlib
import json
import os
import logging
import inspect
from abc import abstractmethod
from openai import OpenAI
from tools.base_tool import BaseTool
from telegram_helper import TelegramHelper
import argparse
from inference_bot import InferenceBot


class OpenAICompatibleInferenceBot(InferenceBot):
    """Inference bot that talks to any OpenAI-compatible chat-completions endpoint.

    Keeps one message list per user in ``conversation_history``, offers the
    tools found by ``load_functions`` to the model (optionally filtered by
    tag), and can toggle between a configured "small" and "large" model.
    """

    # Upper bound on model round-trips spent resolving tool calls for one user message.
    MAX_TOOL_ITERATIONS = 200

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        small_model_name: str | None = None,
        small_model_max_tokens: str | None = None,
        large_model_name: str | None = None,
        large_model_max_tokens: str | None = None,
        allowed_function_tags: list[str] | None = None,
        system_prompt_path: str | None = None
    ):
        """Load the system prompt and tools, create the API client and select the small model.

        Args:
            api_key: API key handed to the ``OpenAI`` client.
            base_url: Endpoint URL. ``None`` or ``""`` selects the client's default endpoint.
            small_model_name: Model used initially.
            small_model_max_tokens: Token limit for the small model, usually a string;
                ``None``, ``""``, ``"none"`` or ``"null"`` leave it to the API default.
            large_model_name: Alternate model selected by ``switch_model``.
            large_model_max_tokens: Token limit for the large model (same format).
            allowed_function_tags: If non-empty, only tools whose ``_tags`` contain one
                of these tags are offered to, and may be executed by, the model.
            system_prompt_path: File to read the system prompt from (``load_system_prompt``
                falls back to ``SYSTEM_PROMPT_PATH`` and then a built-in default).
        """
        self.model_config = {
            "small_model_name": small_model_name,
            "small_model_max_tokens": small_model_max_tokens,
            "large_model_name": large_model_name,
            "large_model_max_tokens": large_model_max_tokens,
        }
        # An empty list is normalised to None, which means "no tag filtering".
        self.allowed_function_tags = allowed_function_tags if allowed_function_tags else None
        self.conversation_history = {}
        self._processing_status = {}
        self.system_prompt_path = system_prompt_path  # Reported by get_bot_status
        self.system_prompt = self.load_system_prompt(file_path=system_prompt_path)
        self.tools, self.functions = self.load_functions()

        # openai-python substitutes its default endpoint only for base_url=None; an
        # empty string would be used verbatim and yield invalid request URLs.
        self.client = OpenAI(api_key=api_key, base_url=base_url or None)
        log_msg = f"Initialized OpenAI compatible client. Target URL: {base_url if base_url else 'OpenAI default'}."
        logging.info(log_msg)

        # Start on the small model.
        self._configure_model_and_tokens(
            self.model_config["small_model_name"],
            self.model_config["small_model_max_tokens"],
        )

    @property
    def processing_status(self):
        """
        An attribute to store the processing status for users.
        Example usage in subclass: self.processing_status.get(user_id)
        """
        return self._processing_status

    def clear_conversation_history(self, user_id):
        """Forget ``user_id``'s conversation and call ``clear()`` on every loaded tool."""
        self.conversation_history.pop(user_id, None)
        # NOTE(review): tools are shared objects and are cleared regardless of which
        # user requested it — confirm this is intended.
        for tool in self.tools:
            tool.clear()

    def _configure_model_and_tokens(self, model_name: str | None, max_tokens_str: str | int | None):
        """Select ``model_name`` and parse ``max_tokens_str`` into ``self.max_tokens``.

        ``None``, ``""``, ``"none"`` and ``"null"`` (any case) mean "API default" and
        leave ``self.max_tokens`` as ``None``. Unparsable or non-positive values are
        logged as warnings and also fall back to ``None``. Integers are accepted too.
        """
        self.model = model_name
        self.max_tokens = None
        raw = "" if max_tokens_str is None else str(max_tokens_str).strip()
        if raw.lower() not in ("", "none", "null"):
            try:
                parsed = int(raw)
            except ValueError:
                parsed = None
            if parsed is not None and parsed > 0:
                self.max_tokens = parsed
            else:
                logging.warning(f"Invalid value for max_tokens: {max_tokens_str}. Using API default (None)")
        logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}")

    def get_llm_description(self) -> str:
        """Return a one-line description of the client class, model and token limit."""
        client_type = type(self.client).__name__
        return f"Client: {client_type}, LLM: {self.model}, Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}"

    def _get_allowed_tools(self) -> list[dict]:
        """Return the tool schemas the model may use, filtered by ``allowed_function_tags``.

        With no tag filter every function is returned; otherwise only functions whose
        ``_tags`` list shares a tag with the filter. The internal ``_tags`` key is
        stripped from each returned copy.
        """
        functions = getattr(self, "functions", None) or []
        allowed_tags = getattr(self, "allowed_function_tags", None)
        allowed = []
        for func in functions:
            if allowed_tags is not None and not any(tag in allowed_tags for tag in func.get("_tags", [])):
                continue
            allowed.append({k: v for k, v in func.items() if k != "_tags"})
        return allowed

    def get_chat_response(self, messages):
        """Request one chat completion for ``messages`` and return the raw response.

        ``tools``/``tool_choice`` are sent only when at least one tool is allowed, and
        ``max_tokens`` only when configured. The OpenAI SDK serialises an explicit
        ``None`` as JSON ``null``, and an empty ``tools`` array is rejected by the API.

        Raises:
            ValueError: if no client is configured.
            Exception: any error from the API call is logged and re-raised.
        """
        if not self.client:
            # This should ideally not be hit if __init__ is successful
            logging.error("OpenAI client not initialized before get_chat_response.")
            raise ValueError("OpenAI client not initialized.")

        request_kwargs = {"model": self.model, "messages": messages}
        tools = self._get_allowed_tools()
        if tools:
            request_kwargs["tools"] = tools
            request_kwargs["tool_choice"] = "auto"
        if self.max_tokens is not None:
            request_kwargs["max_tokens"] = self.max_tokens

        try:
            return self.client.chat.completions.create(**request_kwargs)
        except Exception as e:
            logging.error(f"API call to model {self.model} failed: {e}")
            raise

    def get_bot_status(self):
        """
        Returns a message with the currently enabled model and the system prompt path being used.
        """
        model_name = getattr(self, "model", None)
        prompt_path = self.system_prompt_path or os.getenv("SYSTEM_PROMPT_PATH") or "(default prompt in use)"
        return f"Current model: {model_name}\nSystem prompt path: {prompt_path}"

    @staticmethod
    def _tool_message(tool_call_id, function_name, content):
        """Build the ``tool`` role message that answers one tool call."""
        return {
            "role": "tool",
            "tool_call_id": tool_call_id,
            "name": function_name,
            "content": content,
        }

    def _run_tool_call(self, tool_call, allowed_names):
        """Execute one model-requested tool call; return its ``tool`` message (result or error text)."""
        function_name = tool_call.function.name
        function_args_str = tool_call.function.arguments
        logging.info(f"Attempting to call tool: {function_name} with args: {function_args_str}")

        if function_name not in allowed_names:
            logging.warning(f"Tool function {function_name} not found in available functions.")
            return self._tool_message(tool_call.id, function_name, f"Error: Tool function {function_name} not found.")

        try:
            # call_tool accepts the raw JSON argument string from the API.
            content = self.call_tool(function_name, function_args_str)
            # Tool message content must be a string.
            if not isinstance(content, str):
                content = json.dumps(content)
        except Exception as e:
            logging.error(f"Error calling tool {function_name}: {e}")
            content = f"Error executing tool {function_name}: {str(e)}"
        return self._tool_message(tool_call.id, function_name, content)

    async def handle_message(self, user_id, user_message):
        """Send ``user_message`` to the model for ``user_id`` and resolve any tool calls.

        The user's history is seeded with the system prompt on first use. Tool calls
        requested by the model are executed and their results fed back until the
        model replies without tool calls or ``MAX_TOOL_ITERATIONS`` round-trips are used.

        Returns:
            The assistant's final text, or an error/placeholder string.
        Raises:
            Whatever ``get_chat_response`` raises on API failure. The user message is
            already stored in the history at that point.
        """
        if not self.conversation_history.get(user_id):
            self.conversation_history[user_id] = []
            if self.system_prompt:
                self.conversation_history[user_id].append({"role": "system", "content": self.system_prompt})
        self.conversation_history[user_id].append({"role": "user", "content": user_message})

        # Work on a copy; it replaces the stored history once the turn completes.
        messages = list(self.conversation_history[user_id])
        # NOTE(review): the API call and tools run synchronously inside this coroutine, so
        # the event loop is blocked for the whole turn, and processing_status is never
        # consulted here (abort_processing cannot interrupt a turn).
        response = self.get_chat_response(messages)
        if not (response.choices and response.choices[0].message):
            logging.error("No valid response choice message from LLM.")
            # Persist the user message in history even if LLM fails this turn
            self.conversation_history[user_id] = messages
            return "Error: Could not get a valid response from the LLM."

        assistant_message = response.choices[0].message
        messages.append(assistant_message)
        pending_calls = list(assistant_message.tool_calls) if assistant_message.tool_calls else []
        # Only tools that were actually offered to the model (after tag filtering) may run.
        allowed_names = {tool["function"]["name"] for tool in self._get_allowed_tools()}

        iterations = 0
        while pending_calls and iterations < self.MAX_TOOL_ITERATIONS:
            tool_results = [self._run_tool_call(call, allowed_names) for call in pending_calls]
            messages.extend(tool_results)

            response = self.get_chat_response(messages)
            if not (response.choices and response.choices[0].message):
                logging.error("No valid response choice message from LLM after tool call.")
                self.conversation_history[user_id] = messages  # Persist state before error
                return "Error: Could not get a valid response from the LLM after tool call."

            assistant_message = response.choices[0].message
            messages.append(assistant_message)
            pending_calls = list(assistant_message.tool_calls) if assistant_message.tool_calls else []
            iterations += 1

        if pending_calls:
            logging.warning(f"Max tool iterations ({self.MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.")
            # Answer the unexecuted calls so the stored history stays valid: every
            # assistant tool_call must be followed by a tool message with its id.
            messages.extend(
                self._tool_message(
                    call.id,
                    call.function.name,
                    "Error: Tool call limit reached; this call was not executed.",
                )
                for call in pending_calls
            )

        self.conversation_history[user_id] = messages
        final_text = assistant_message.content
        return final_text if final_text is not None else "Assistant did not provide a textual response."

    async def start(self):
        """Log that the bot has started with its current model."""
        logging.info(f"{self.__class__.__name__} (Model: {self.model}) started.")

    async def abort_processing(self, user_id):
        """Clear ``user_id``'s processing status and return a user-facing message.

        This is a soft abort: it only drops the status entry and does not stop a
        request already in progress.
        """
        if user_id in self.processing_status:
            self.clear_processing_status(user_id)  # Use base class method
            logging.info(f"Processing aborted for user {user_id}.")
            return "Processing aborted. You can send a new message or /clear the conversation."
        return "No active processing found to abort. If you wish, /clear the conversation history."
def load_functions(self): tools = [] functions = [] tools_dir = os.path.join(os.path.dirname(__file__), 'tools') if not os.path.exists(tools_dir): logging.warning(f"Tools directory not found: {tools_dir}") return [], [] for filename in os.listdir(tools_dir): if filename.endswith('.py') and filename != '__init__.py' and filename != 'base_tool.py': module_name = f'tools.{filename[:-3]}' try: module = importlib.import_module(module_name) for name, obj in inspect.getmembers(module): if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool: try: tools.append(obj()) # This instantiation might be an issue for tools needing config except Exception as e: logging.error(f"Error instantiating tool {name} from {filename}: {e}") except Exception as e: logging.error(f"Error importing module {module_name}: {e}") for tool in tools: functions.extend(tool.get_functions()) return tools, functions def load_system_prompt(self, direct_content: str | None = None, file_path: str | None = None) -> str: default_prompt = "You are a helpful AI assistant." if direct_content: logging.info("Using direct content for system prompt.") return direct_content.strip() prompt_path_to_try = file_path or os.getenv("SYSTEM_PROMPT_PATH") if prompt_path_to_try: if os.path.isfile(prompt_path_to_try): try: with open(prompt_path_to_try, "r", encoding="utf-8") as file: content = file.read().strip() logging.info(f"Successfully loaded system prompt from {prompt_path_to_try}.") return content except IOError as e: logging.warning(f"Could not read system prompt file {prompt_path_to_try}: {e}. Using default.") return default_prompt else: # This condition now also covers if 'file_path' argument was given but invalid logging.warning(f"System prompt file {prompt_path_to_try} not found. Using default system prompt.") return default_prompt else: logging.info("No system prompt path provided (argument or ENV) or direct content. 
Using default system prompt.") return default_prompt def set_processing_status(self, user_id: int, message_id: int): self.processing_status[user_id] = {"processing": True, "message_id": message_id} def clear_processing_status(self, user_id: int): if user_id in self.processing_status: del self.processing_status[user_id] def call_tool(self, function_call_name, function_call_arguments): function_name = function_call_name function_args = None if isinstance(function_call_arguments, dict): function_args = function_call_arguments elif isinstance(function_call_arguments, str): try: function_args = json.loads(function_call_arguments) except json.JSONDecodeError as e: logging.error(f"Error decoding function call arguments (string) for {function_call_name}: {e}. Arguments: {function_call_arguments}") return f"Error: Malformed arguments for tool call: {e}" else: if function_call_arguments is None: function_args = {} else: logging.error(f"Unexpected type for function_call_arguments for {function_name}: {type(function_call_arguments)}. Arguments: {function_call_arguments}") return f"Error: Invalid argument type for tool call: {type(function_call_arguments)}" for tool in self.tools: for function in tool.get_functions(): if function["function"]["name"] == function_name: try: if not isinstance(function_args, dict): logging.error(f"Internal error: function_args not a dict for {function_name} before execution. Args: {function_args}") return f"Internal error preparing arguments for tool {function_name}." return tool.execute(function_name, **function_args) except Exception as e: logging.error(f"Error executing tool {function_name} with args {function_args}: {e}") return f"Error executing tool {function_name}: {e}" logging.warning(f"Tool function {function_name} not found.") return f"Error: Tool function {function_name} not found." 
async def switch_model(self): if not self.model_config["small_model_name"] or not self.model_config["large_model_name"]: logging.warning("Small or Large model names are not defined. Cannot switch model.") return f"Model switching not fully configured. Currently using {self.model}." current_is_small = self.model == self.model_config["small_model_name"] current_is_large = self.model == self.model_config["large_model_name"] if current_is_large: target_model = self.model_config["small_model_name"] target_max_tokens_str = self.model_config["small_model_max_tokens"] elif current_is_small: target_model = self.model_config["large_model_name"] target_max_tokens_str = self.model_config["large_model_max_tokens"] else: logging.warning(f"Current model {self.model} is unrecognized. Switching to default small model: {self.model_config['small_model_name']}.") target_model = self.model_config["small_model_name"] target_max_tokens_str = self.model_config["small_model_max_tokens"] self._configure_model_and_tokens(target_model, target_max_tokens_str) return f"Switched model to {self.model}. Max tokens set to {self.max_tokens if self.max_tokens is not None else 'API default'}." def main(): logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') bot = None try: parser = argparse.ArgumentParser(description='OpenAI Compatible Inference Bot') parser.add_argument('--config', type=str, help='Configuration Prepend (i.e. gemini, openai, etc)', default="Telegram") parser.add_argument('--messenger', type=str, help='Messenger type (i.e. 
telegram)', required=True) parser.add_argument('--persona', type=str, help='Path to system prompt file', required=False) parser.add_argument('--tools', nargs='+', help='List of allowed function tags', required=False) # Add these to launch.json arguments if you want to limit the toolset available: "--tools", "read", "communicate" # Parse command line arguments args = parser.parse_args() if args.persona: logging.info(f"Using custom persona from: {args.persona}") system_prompt_path=args.persona if args.persona else None allowed_function_tags=args.tools if args.tools else None config_prepend = args.config if args.config else None messenger = args.messenger if args.messenger else None # Initialize model and max tokens based on the config prepend if config_prepend: api_key = os.environ.get(f"{config_prepend.upper()}_API_KEY") baseurl = os.environ.get(f"{config_prepend.upper()}_API_BASE_URL", "") small_model_name = os.environ.get(f"{config_prepend.upper()}_SMALL_MODEL") large_model_name = os.environ.get(f"{config_prepend.upper()}_LARGE_MODEL") small_model_max_tokens = os.environ.get(f"{config_prepend.upper()}_SMALL_MODEL_MAX_TOKENS") large_model_max_tokens = os.environ.get(f"{config_prepend.upper()}_LARGE_MODEL_MAX_TOKENS") bot = OpenAICompatibleInferenceBot( api_key=api_key, base_url=baseurl, small_model_name=small_model_name, small_model_max_tokens=small_model_max_tokens, large_model_name=large_model_name, large_model_max_tokens=large_model_max_tokens, system_prompt_path=system_prompt_path, allowed_function_tags=allowed_function_tags ) full_code_file = importlib.import_module(f'{messenger.lower()}_helper') messenger_helper_class_name = f"{messenger.capitalize()}Helper" if not hasattr(full_code_file, messenger_helper_class_name): messenger_helper_class_name = f"{messenger.upper()}Helper" if not hasattr(full_code_file, messenger_helper_class_name): raise ValueError(f"Messenger helper class {messenger_helper_class_name} not found in {full_code_file.__name__}.") 
helper_class = getattr(full_code_file, messenger_helper_class_name) helper = helper_class(bot) helper.run() except ValueError as e: logging.error(f"FATAL: {e}") return except Exception as e: # Catch any other init errors logging.error(f"An unexpected error occurred during bot initialization: {e}") return if __name__ == '__main__': main()