import argparse
import importlib
import inspect
import json
import logging
import os
import re
from abc import abstractmethod

import tiktoken
from openai import OpenAI

from tools.base_tool import BaseTool
from telegram_helper import TelegramHelper
from inference_bot import InferenceBot


class OpenAICompatibleInferenceBot(InferenceBot):
    """Chat bot backed by an OpenAI-compatible chat-completions endpoint.

    Keeps a per-user conversation history, offers the tools discovered in the
    sibling ``tools`` package to the model, executes the tool calls the model
    requests, and shrinks oversized prompts before each API call.
    """

    # Maximum number of tool-call round-trips per user message.
    MAX_TOOL_ITERATIONS = 200
    # Message keys forwarded to the API; anything else is dropped during normalization.
    _MESSAGE_KEYS = frozenset({"role", "content", "name", "tool_call_id", "tool_calls"})
    # Completion tokens reserved out of the inference limit when max_tokens is the API default.
    _DEFAULT_COMPLETION_RESERVE = 1024
    # Lower bound for the prompt-token budget, whatever the limit/reserve arithmetic gives.
    _MIN_PROMPT_BUDGET = 1024
    # Tool-call argument strings longer than this are summarized when over budget.
    _TOOL_ARGS_SUMMARY_THRESHOLD = 700
    _FENCED_CODE_RE = re.compile(r"```[\s\S]*?```")
    # Runs of lines indented by >= 4 spaces/tabs (blank lines may sit between them).
    # [ \t] instead of \s so blank lines followed by unindented prose are never matched.
    _INDENTED_BLOCK_RE = re.compile(r"(?:\n(?:[ \t]*\n)*[ \t]{4,}.+)+")

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        small_model_name: str | None = None,
        small_model_max_tokens: str | None = None,
        large_model_name: str | None = None,
        large_model_max_tokens: str | None = None,
        allowed_function_tags: list[str] | None = None,
        system_prompt_path: str | None = None,
        use_large_model: bool = False,
    ):
        """Create the API client, load the system prompt and tools, and select the model.

        Args:
            api_key: API key passed to the ``OpenAI`` client.
            base_url: Endpoint URL; ``None`` or ``""`` selects the SDK's default endpoint.
            small_model_name: Model identifier for the small tier.
            small_model_max_tokens: Completion-token cap for the small tier, as a string;
                empty, "none" or "null" mean "API default".
            large_model_name: Model identifier for the large tier.
            large_model_max_tokens: Completion-token cap for the large tier (same format).
            allowed_function_tags: Only tools whose ``_tags`` contain one of these tags are
                offered to, and executable by, the model; ``None`` or empty allows all tools.
            system_prompt_path: File holding the system prompt (see ``load_system_prompt``).
            use_large_model: Start on the large model instead of the small one.
        """
        self.model_config = {
            "small_model_name": small_model_name,
            "small_model_max_tokens": small_model_max_tokens,
            "large_model_name": large_model_name,
            "large_model_max_tokens": large_model_max_tokens,
        }
        # An empty tag list is treated like "no restriction".
        self.allowed_function_tags = allowed_function_tags if allowed_function_tags else None
        self.conversation_history = {}
        self._processing_status = {}
        self.system_prompt_path = system_prompt_path
        self.system_prompt = self.load_system_prompt(file_path=system_prompt_path)
        self.tools, self.functions = self.load_functions()

        # The SDK only falls back to its default endpoint when base_url is None;
        # an empty string would be used verbatim, so collapse it to None.
        self.client = OpenAI(api_key=api_key, base_url=base_url or None)
        log_msg = f"Initialized OpenAI compatible client. Target URL: {base_url if base_url else 'OpenAI default'}."
        logging.info(log_msg)

        # Load inference token limits (defaults: small=16k, large=32k)
        self.small_model_max_inference_tokens = int(os.getenv("_SMALL_MODEL_MAX_INFERENCE_TOKENS", "16384"))
        self.large_model_max_inference_tokens = int(os.getenv("_LARGE_MODEL_MAX_INFERENCE_TOKENS", "32768"))

        # Configure the actual model name and max_tokens for API calls
        tier = "large" if use_large_model else "small"
        self._configure_model_and_tokens(
            self.model_config[f"{tier}_model_name"],
            self.model_config[f"{tier}_model_max_tokens"],
        )

    @property
    def processing_status(self):
        """Mapping of user_id -> {"processing": bool, "message_id": int} for users being served."""
        return self._processing_status

    def clear_conversation_history(self, user_id):
        """Forget ``user_id``'s history and call ``clear()`` on every loaded tool.

        NOTE(review): the tools are cleared regardless of ``user_id`` (no user is passed to
        ``tool.clear()``), so any per-tool state shared between users is reset too.
        """
        self.conversation_history.pop(user_id, None)
        for tool in self.tools:
            tool.clear()

    def _configure_model_and_tokens(self, model_name: str | None, max_tokens_str: str | None):
        """Set ``self.model`` and parse ``max_tokens_str`` into ``self.max_tokens``.

        ``None``, blank, "none" and "null" (case-insensitive) yield ``None`` (API default);
        an unparsable value is logged and also yields ``None``.
        """
        self.model = model_name
        self.max_tokens = None
        if max_tokens_str is not None:
            # str() so an int passed by a caller is accepted instead of raising AttributeError.
            raw = str(max_tokens_str).strip()
            if raw.lower() not in ("none", "", "null"):
                try:
                    self.max_tokens = int(raw)
                except ValueError:
                    logging.warning(f"Invalid value for max_tokens: {max_tokens_str}. Using API default (None)")
        logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}")

    def get_llm_description(self) -> str:
        """Return a one-line description of the client class, model, and max_tokens."""
        client_type = type(self.client).__name__
        return f"Client: {client_type}, LLM: {self.model}, Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}"

    def _encoding_for_model(self, model: str | None):
        """Return the tiktoken encoding for ``model``, falling back to ``cl100k_base``."""
        try:
            return tiktoken.encoding_for_model(model) if model else tiktoken.get_encoding("cl100k_base")
        except KeyError:
            logging.warning(f"Warning: model {model} not found. Using cl100k_base encoding.")
            return tiktoken.get_encoding("cl100k_base")

    @staticmethod
    def _tool_call_to_dict(tc):
        """Convert an SDK tool-call object into the plain dict shape sent to the API (best effort)."""
        try:
            fn = getattr(tc, "function", None)
            return {
                "id": getattr(tc, "id", None),
                "type": getattr(tc, "type", "function"),
                "function": {
                    "name": getattr(fn, "name", None),
                    "arguments": getattr(fn, "arguments", "{}"),
                },
            }
        except Exception:
            # Best-effort fallback for an object that cannot be introspected.
            return {"id": None, "type": "function", "function": {"name": "unknown", "arguments": "{}"}}

    def _normalize_messages(self, messages):
        """Return a list of plain dict chat messages acceptable by the API.

        - Dict messages are shallow-copied, keeping only the keys in ``_MESSAGE_KEYS``.
        - Other objects (OpenAI SDK messages) are converted to dicts with role and
          content, plus ``name`` and ``tool_calls`` when present.

        The input dicts are not mutated, so top-level keys of the returned dicts can be
        reassigned without touching the stored history.
        """
        normalized = []
        for m in messages:
            if isinstance(m, dict):
                normalized.append({k: v for k, v in m.items() if k in self._MESSAGE_KEYS})
                continue
            # Likely an OpenAI message object
            entry = {"role": getattr(m, "role", None), "content": getattr(m, "content", None)}
            name = getattr(m, "name", None)
            if name:
                entry["name"] = name
            tool_calls = [self._tool_call_to_dict(tc) for tc in (getattr(m, "tool_calls", None) or [])]
            if tool_calls:
                entry["tool_calls"] = tool_calls
            normalized.append(entry)
        return normalized

    def _estimate_tokens(self, messages):
        """Estimate tokens for messages with tiktoken, including tool_calls names and arguments.

        Approximates OpenAI's chat token counting: 4 tokens of overhead per message,
        the encoded role/name/content strings, and 2 tokens of assistant priming.
        Non-dict messages contribute only the per-message overhead.
        """
        enc = self._encoding_for_model(self.model)
        num_tokens = 2  # assistant priming
        for m in messages:
            num_tokens += 4  # per-message overhead
            if not isinstance(m, dict):
                continue
            for key in ("role", "name", "content"):
                value = m.get(key)
                if isinstance(value, str):
                    num_tokens += len(enc.encode(value))
            tcs = m.get("tool_calls")
            if not (tcs and isinstance(tcs, list)):
                continue
            # approximate cost of the tool_calls JSON the model sees
            for tc in tcs:
                fn = tc.get("function", {}) if isinstance(tc, dict) else {}
                for part in (fn.get("name"), fn.get("arguments")):
                    if isinstance(part, str):
                        num_tokens += len(enc.encode(part))
        return num_tokens

    def _get_inference_limit(self):
        """Return the prompt-token limit for the active model, or ``None`` if it is neither tier."""
        if self.model == self.model_config["small_model_name"]:
            return self.small_model_max_inference_tokens
        if self.model == self.model_config["large_model_name"]:
            return self.large_model_max_inference_tokens
        logging.warning(f"Could not determine inference token limit for model: {self.model}. Proceeding without check.")
        return None

    def _token_budget(self, limit: int) -> int:
        """Return the prompt-token budget: ``limit`` minus the completion reserve, floored."""
        reserve = self.max_tokens if isinstance(self.max_tokens, int) else self._DEFAULT_COMPLETION_RESERVE
        return max(self._MIN_PROMPT_BUDGET, limit - reserve)

    def _summarize_tool_args(self, args_str: str, max_chars: int = 512) -> str:
        """Summarize tool-call request arguments without altering tool responses.

        - If JSON object: keep keys, preview long string values (>160 chars -> first 120
          plus the length), and replace list/dict values with a size marker.
        - Otherwise: truncate the raw string with an indicator.
        The result is truncated to roughly ``max_chars`` characters.
        """
        try:
            parsed = json.loads(args_str)
        except (TypeError, ValueError):  # json.JSONDecodeError is a ValueError
            parsed = None
        if isinstance(parsed, dict):
            summary = {}
            for k, v in parsed.items():
                if isinstance(v, str) and len(v) > 160:
                    summary[k] = v[:120] + f"... [len={len(v)}]"
                elif isinstance(v, (list, dict)):
                    # structural summary only
                    summary[k] = f"<{type(v).__name__} size={len(v)}>"
                else:
                    summary[k] = v
            s = json.dumps(summary, ensure_ascii=False)
            if len(s) > max_chars:
                s = s[: max_chars - 20] + "... [summarized]"
            return s
        # Fallback: truncate raw string
        return (args_str[: max_chars - 20] + "... [summarized]") if len(args_str) > max_chars else args_str

    def _summarize_tool_call_requests_in_messages(self, messages):
        """Replace long tool-call argument strings in ``messages`` with summaries, in place.

        Only the request side (assistant ``tool_calls``) is touched; tool responses are
        preserved. Tool-call dicts are copied before editing, so shared objects are not
        mutated. Returns True if any message was changed.
        """
        changed = False
        for m in messages:
            if not (isinstance(m, dict) and m.get("tool_calls")):
                continue
            message_changed = False
            new_tool_calls = []
            for tc in m["tool_calls"]:
                if isinstance(tc, dict):
                    fn = tc.get("function") or {}
                    args = fn.get("arguments")
                    if isinstance(args, str) and len(args) > self._TOOL_ARGS_SUMMARY_THRESHOLD:
                        tc = {**tc, "function": {**fn, "arguments": self._summarize_tool_args(args)}}
                        message_changed = True
                new_tool_calls.append(tc)
            if message_changed:
                m["tool_calls"] = new_tool_calls
                changed = True
        return changed

    def _elide_redundant_code_blocks(self, messages):
        """As a last resort, remove large code blocks from older assistant messages, in place.

        Fenced ``` blocks and runs of lines indented by 4+ spaces/tabs are replaced with
        short markers. The latest assistant message is kept intact. Returns True if any
        message was changed.
        """
        assistant_indices = [
            i for i, m in enumerate(messages)
            if isinstance(m, dict) and m.get("role") == "assistant" and m.get("content")
        ]
        changed = False
        # Protect the last assistant message
        for i in assistant_indices[:-1]:
            m = messages[i]
            content = m.get("content")
            if not isinstance(content, str):
                continue
            new_content = self._FENCED_CODE_RE.sub("[code block omitted]", content)
            new_content = self._INDENTED_BLOCK_RE.sub("\n[long block omitted]", new_content)
            if new_content != content:
                m["content"] = new_content
                changed = True
        return changed

    def _enforce_budget(self, messages):
        """Normalize and enforce token budget by summarizing only tool-call requests first,
        then eliding redundant code blocks if still too large.

        Returns normalized messages (possibly still over budget; that case is logged).
        """
        normalized = self._normalize_messages(messages)
        limit = self._get_inference_limit()
        if not limit:
            return normalized
        budget = self._token_budget(limit)
        tokens = self._estimate_tokens(normalized)
        if tokens <= budget:
            return normalized

        # Step 1: summarize only tool-call request arguments
        if self._summarize_tool_call_requests_in_messages(normalized):
            tokens = self._estimate_tokens(normalized)
            logging.info(f"Applied tool-call request summarization. tokens={tokens}/{budget}")
            if tokens <= budget:
                return normalized

        # Step 2: elide redundant code blocks from older assistant messages
        if self._elide_redundant_code_blocks(normalized):
            tokens = self._estimate_tokens(normalized)
            logging.info(f"Elided redundant code blocks. tokens={tokens}/{budget}")
            if tokens <= budget:
                return normalized

        # If still over, log and proceed; the API may still reject; caller may choose to abort
        logging.warning(f"Projected tokens still exceed budget after optimizations: {tokens}/{budget}")
        return normalized

    def _allowed_functions(self):
        """Return copies of the tool schemas permitted by ``allowed_function_tags``.

        The private ``_tags`` key is stripped from each copy. With no tag restriction
        every loaded function is returned.
        """
        allowed_tags = getattr(self, "allowed_function_tags", None)
        allowed = []
        for func in getattr(self, "functions", None) or []:
            if allowed_tags is not None and not any(tag in allowed_tags for tag in func.get("_tags", [])):
                continue
            allowed.append({k: v for k, v in func.items() if k != "_tags"})
        return allowed

    def get_chat_response(self, messages):
        """Send ``messages`` (budget-enforced) plus the allowed tools to the chat API.

        Returns the raw chat-completion response. Raises ``ValueError`` if no client is
        configured; API errors are logged and re-raised.
        """
        if not self.client:
            logging.error("OpenAI client not initialized before get_chat_response.")
            raise ValueError("OpenAI client not initialized.")
        try:
            # Send None rather than an empty list: the API rejects an empty `tools` array.
            cleaned_tools = self._allowed_functions() or None
            # Enforce token budget prior to API call
            messages_for_api = self._enforce_budget(messages)
            return self.client.chat.completions.create(
                model=self.model,
                messages=messages_for_api,
                tools=cleaned_tools,
                tool_choice="auto" if cleaned_tools else None,
                max_tokens=self.max_tokens,
            )
        except Exception as e:
            logging.error(f"API call to model {self.model} failed: {e}")
            raise

    def get_bot_status(self):
        """
        Returns a message with the currently enabled model and the system prompt path being used.
        """
        model_name = getattr(self, "model", None)
        prompt_path = self.system_prompt_path or os.getenv("SYSTEM_PROMPT_PATH") or "(default prompt in use)"
        return f"Current model: {model_name}\nSystem prompt path: {prompt_path}"

    def _run_tool_calls(self, tool_calls, allowed_names):
        """Execute each requested tool call and return one ``tool`` message per call, in order.

        Calls naming a function outside ``allowed_names`` are not executed; the model gets a
        "not found" error instead. Tool exceptions become error strings.
        """
        results = []
        for tool_call in tool_calls:
            function_name = tool_call.function.name
            logging.info(f"Attempting to call tool: {function_name} with args: [request summarized if large]")
            if function_name not in allowed_names:
                logging.warning(f"Tool function {function_name} not found in available functions.")
                content = f"Error: Tool function {function_name} not found."
            else:
                try:
                    content = self.call_tool(function_name, tool_call.function.arguments)
                    if not isinstance(content, str):
                        content = json.dumps(content)
                except Exception as e:
                    logging.error(f"Error calling tool {function_name}: {e}")
                    content = f"Error executing tool {function_name}: {str(e)}"
            results.append({
                "role": "tool",
                "tool_call_id": tool_call.id,
                "name": function_name,
                "content": content,
            })
        return results

    @staticmethod
    def _skipped_tool_results(tool_calls):
        """Build placeholder ``tool`` replies for calls that were never executed.

        The chat API requires every assistant ``tool_calls`` id to be answered by a tool
        message; without these the stored history would be rejected on the next turn.
        """
        return [
            {
                "role": "tool",
                "tool_call_id": tc.id,
                "name": tc.function.name,
                "content": "Error: tool call not executed because the maximum number of tool iterations was reached.",
            }
            for tc in tool_calls
        ]

    @staticmethod
    def _message_text(message):
        """Return the text content of an assistant message (SDK object or dict), or a fallback."""
        if isinstance(message, dict):
            content = message.get("content")
        elif getattr(message, "role", None) == "assistant":
            content = getattr(message, "content", None)
        else:
            content = None
        return content if content is not None else "Assistant did not provide a textual response."

    async def handle_message(self, user_id, user_message):
        """Process one user turn and return the assistant's reply text.

        Appends the message to the user's history (seeding it with the system prompt on
        the first turn), refuses requests that exceed the token budget even after
        optimizations, then alternates LLM calls and tool execution until the model
        answers without tool calls or ``MAX_TOOL_ITERATIONS`` is reached.

        NOTE(review): the OpenAI client call and tool execution are synchronous, so this
        coroutine blocks the event loop while they run.
        """
        history = self.conversation_history.get(user_id)
        if not history:
            history = []
            if self.system_prompt:
                history.append({"role": "system", "content": self.system_prompt})
            self.conversation_history[user_id] = history
        history.append({"role": "user", "content": user_message})
        messages = list(history)

        # Pre-inference token limit check with budgeted optimizations
        limit = self._get_inference_limit()
        if limit is not None:
            token_count = self._estimate_tokens(self._enforce_budget(messages))
            budget = self._token_budget(limit)
            if token_count > budget:
                logging.warning(f"Request for user {user_id} exceeds inference token budget even after optimizations ({token_count}/{budget}).")
                # Do not persist this message in history as it was not processed by LLM
                last = history[-1] if history else None
                if isinstance(last, dict) and last.get("role") == "user" and last.get("content") == user_message:
                    history.pop()
                return "Request exceeds inference token limit after optimization. Please shorten your request, use /clear, or implement RAG in your application."

        response = self.get_chat_response(messages)
        if not (response.choices and response.choices[0].message):
            logging.error("No valid response choice message from LLM.")
            self.conversation_history[user_id] = messages
            return "Error: Could not get a valid response from the LLM."
        assistant_message = response.choices[0].message
        messages.append(assistant_message)
        tool_calls = list(assistant_message.tool_calls or [])

        # Only tools actually offered to the model may be executed.
        allowed_names = {f["function"]["name"] for f in self._allowed_functions()}
        tool_use_count = 0
        while tool_calls and tool_use_count < self.MAX_TOOL_ITERATIONS:
            messages.extend(self._run_tool_calls(tool_calls, allowed_names))
            # get_chat_response enforces the budget (summarizes requests only; tool responses preserved)
            response = self.get_chat_response(messages)
            if not (response.choices and response.choices[0].message):
                logging.error("No valid response choice message from LLM after tool call.")
                self.conversation_history[user_id] = messages
                return "Error: Could not get a valid response from the LLM after tool call."
            assistant_message = response.choices[0].message
            messages.append(assistant_message)
            tool_calls = list(assistant_message.tool_calls or [])
            tool_use_count += 1

        if tool_calls:
            # Loop ended on the iteration cap with calls still pending.
            logging.warning(f"Max tool iterations ({self.MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.")
            messages.extend(self._skipped_tool_results(tool_calls))

        self.conversation_history[user_id] = messages
        return self._message_text(assistant_message)

    async def start(self):
        """Log that the bot has started with its current model."""
        logging.info(f"{self.__class__.__name__} (Model: {self.model}) started.")

    async def abort_processing(self, user_id):
        """Clear ``user_id``'s processing status and return a user-facing message.

        NOTE(review): nothing in this class reads the status, so this does not itself
        interrupt a ``handle_message`` call already in progress.
        """
        if user_id not in self.processing_status:
            return "No active processing found to abort. If you wish, /clear the conversation history."
        self.clear_processing_status(user_id)
        logging.info(f"Processing aborted for user {user_id}.")
        return "Processing aborted. You can send a new message or /clear the conversation."

    def load_functions(self):
        """Discover and instantiate every ``BaseTool`` subclass in the sibling ``tools/`` package.

        Returns:
            ``(tools, functions)``: the tool instances and the concatenation of their
            ``get_functions()`` schemas; both empty when the directory is missing.
            Modules that fail to import and classes that fail to instantiate are logged
            and skipped.
        """
        tools = []
        functions = []
        tools_dir = os.path.join(os.path.dirname(__file__), "tools")
        if not os.path.exists(tools_dir):
            logging.warning(f"Tools directory not found: {tools_dir}")
            return [], []

        seen_classes = set()
        # sorted(): os.listdir order is arbitrary; a stable order keeps the tool list reproducible.
        for filename in sorted(os.listdir(tools_dir)):
            if not filename.endswith(".py") or filename in ("__init__.py", "base_tool.py"):
                continue
            module_name = f"tools.{filename[:-3]}"
            try:
                module = importlib.import_module(module_name)
                for name, obj in inspect.getmembers(module, inspect.isclass):
                    # A tool class imported into several modules must be instantiated only once.
                    if not issubclass(obj, BaseTool) or obj is BaseTool or obj in seen_classes:
                        continue
                    seen_classes.add(obj)
                    try:
                        tools.append(obj())  # tools that require constructor arguments fail here and are skipped
                    except Exception as e:
                        logging.error(f"Error instantiating tool {name} from {filename}: {e}")
            except Exception as e:
                logging.error(f"Error importing module {module_name}: {e}")

        for tool in tools:
            functions.extend(tool.get_functions())
        return tools, functions

    def load_system_prompt(self, direct_content: str | None = None, file_path: str | None = None) -> str:
        """Return the system prompt text.

        Precedence: ``direct_content`` (stripped), then the file at ``file_path`` or the
        ``SYSTEM_PROMPT_PATH`` environment variable (read as UTF-8, stripped). If no file is
        configured, or it is missing or unreadable, a built-in default prompt is returned.
        """
        default_prompt = "You are a helpful AI assistant."
        if direct_content:
            logging.info("Using direct content for system prompt.")
            return direct_content.strip()

        prompt_path_to_try = file_path or os.getenv("SYSTEM_PROMPT_PATH")
        if not prompt_path_to_try:
            logging.info("No system prompt path provided (argument or ENV) or direct content. Using default system prompt.")
            return default_prompt
        if not os.path.isfile(prompt_path_to_try):
            logging.warning(f"System prompt file {prompt_path_to_try} not found. Using default system prompt.")
            return default_prompt
        try:
            with open(prompt_path_to_try, "r", encoding="utf-8") as file:
                content = file.read().strip()
        except (OSError, UnicodeDecodeError) as e:
            # UnicodeDecodeError too: a non-UTF-8 file should fall back, not crash __init__.
            logging.warning(f"Could not read system prompt file {prompt_path_to_try}: {e}. Using default.")
            return default_prompt
        logging.info(f"Successfully loaded system prompt from {prompt_path_to_try}.")
        return content

    def set_processing_status(self, user_id: int, message_id: int):
        """Mark ``user_id`` as being processed for the given ``message_id``."""
        self.processing_status[user_id] = {"processing": True, "message_id": message_id}

    def clear_processing_status(self, user_id: int):
        """Remove ``user_id``'s processing status, if any."""
        self.processing_status.pop(user_id, None)

    def call_tool(self, function_call_name, function_call_arguments):
        """Parse the arguments and dispatch ``function_call_name`` to the tool declaring it.

        Args:
            function_call_name: Tool function name requested by the model.
            function_call_arguments: A dict, a JSON object string, or ``None``. ``None``,
                a blank string and JSON ``null`` all mean "no arguments".

        Returns:
            The value returned by ``tool.execute``, or an error string when the arguments
            are malformed, the function is unknown, or execution raises.
        """
        function_name = function_call_name
        raw_args = function_call_arguments
        if raw_args is None:
            function_args = {}
        elif isinstance(raw_args, dict):
            function_args = raw_args
        elif isinstance(raw_args, str):
            if not raw_args.strip():
                # Blank argument string: treat like "no arguments" rather than a decode error.
                function_args = {}
            else:
                try:
                    function_args = json.loads(raw_args)
                except json.JSONDecodeError as e:
                    logging.error(f"Error decoding function call arguments (string) for {function_call_name}: {e}. Arguments: {function_call_arguments}")
                    return f"Error: Malformed arguments for tool call: {e}"
                if function_args is None:  # JSON "null"
                    function_args = {}
        else:
            logging.error(f"Unexpected type for function_call_arguments for {function_name}: {type(function_call_arguments)}. Arguments: {function_call_arguments}")
            return f"Error: Invalid argument type for tool call: {type(function_call_arguments)}"

        for tool in self.tools:
            for function in tool.get_functions():
                if function["function"]["name"] != function_name:
                    continue
                if not isinstance(function_args, dict):
                    logging.error(f"Internal error: function_args not a dict for {function_name} before execution. Args: {function_args}")
                    return f"Internal error preparing arguments for tool {function_name}."
                try:
                    return tool.execute(function_name, **function_args)
                except Exception as e:
                    logging.error(f"Error executing tool {function_name} with args {function_args}: {e}")
                    return f"Error executing tool {function_name}: {e}"
        logging.warning(f"Tool function {function_name} not found.")
        return f"Error: Tool function {function_name} not found."

    async def switch_model(self):
        """Toggle between the small and large model and return a status message.

        An unrecognized current model switches to the small model. If either model name
        is unset, nothing changes.
        """
        small_name = self.model_config["small_model_name"]
        large_name = self.model_config["large_model_name"]
        if not small_name or not large_name:
            logging.warning("Small or Large model names are not defined. Cannot switch model.")
            return f"Model switching not fully configured. Currently using {self.model}."

        if self.model == large_name:
            target = "small"
        elif self.model == small_name:
            target = "large"
        else:
            logging.warning(f"Current model {self.model} is unrecognized. Switching to default small model: {small_name}.")
            target = "small"

        self._configure_model_and_tokens(
            self.model_config[f"{target}_model_name"],
            self.model_config[f"{target}_model_max_tokens"],
        )
        return f"Switched model to {self.model}. Max tokens set to {self.max_tokens if self.max_tokens is not None else 'API default'}."
def main():
    """Command-line entry point: build the bot from environment settings and run a messenger helper.

    Settings are read from environment variables prefixed with the upper-cased ``--config``
    value ``P``: ``P_API_KEY``, ``P_API_BASE_URL``, ``P_SMALL_MODEL``, ``P_LARGE_MODEL``,
    ``P_SMALL_MODEL_MAX_TOKENS`` and ``P_LARGE_MODEL_MAX_TOKENS``. The messenger helper is
    the class ``<Messenger>Helper`` (or ``<MESSENGER>Helper``) from module
    ``<messenger>_helper``; it is constructed with the bot and its ``run()`` is called.
    Errors are logged and the function returns.
    """
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    try:
        parser = argparse.ArgumentParser(description='OpenAI Compatible Inference Bot')
        parser.add_argument('--config', type=str, help='Configuration Prepend (i.e. gemini, openai, etc)', default="Telegram")
        parser.add_argument('--messenger', type=str, help='Messenger type (i.e. telegram)', required=True)
        parser.add_argument('--persona', type=str, help='Path to system prompt file', required=False)
        parser.add_argument('--tools', nargs='+', help='List of allowed function tags', required=False)
        parser.add_argument('--use-large-model', action='store_true', help='Use the large model instead of the small model')
        # Add these to launch.json arguments if you want to limit the toolset available: "--tools", "read", "communicate"
        args = parser.parse_args()

        if args.persona:
            logging.info(f"Using custom persona from: {args.persona}")

        if not args.config:
            # Without a prefix no API settings can be located; the helper would get bot=None.
            raise ValueError("No configuration prefix given (--config); cannot locate API settings.")
        prefix = args.config.upper()

        bot = OpenAICompatibleInferenceBot(
            api_key=os.environ.get(f"{prefix}_API_KEY"),
            # Unset or empty means "use the client's default endpoint" (None), not "".
            base_url=os.environ.get(f"{prefix}_API_BASE_URL") or None,
            small_model_name=os.environ.get(f"{prefix}_SMALL_MODEL"),
            small_model_max_tokens=os.environ.get(f"{prefix}_SMALL_MODEL_MAX_TOKENS"),
            large_model_name=os.environ.get(f"{prefix}_LARGE_MODEL"),
            large_model_max_tokens=os.environ.get(f"{prefix}_LARGE_MODEL_MAX_TOKENS"),
            system_prompt_path=args.persona or None,
            allowed_function_tags=args.tools or None,
            use_large_model=args.use_large_model,
        )

        messenger = args.messenger
        helper_module = importlib.import_module(f'{messenger.lower()}_helper')
        messenger_helper_class_name = f"{messenger.capitalize()}Helper"
        if not hasattr(helper_module, messenger_helper_class_name):
            messenger_helper_class_name = f"{messenger.upper()}Helper"
            if not hasattr(helper_module, messenger_helper_class_name):
                raise ValueError(f"Messenger helper class {messenger_helper_class_name} not found in {helper_module.__name__}.")
        helper_class = getattr(helper_module, messenger_helper_class_name)
        helper = helper_class(bot)
        helper.run()
    except ValueError as e:
        logging.error(f"FATAL: {e}")
        return
    except Exception as e:
        # Catch any other init errors; logging.exception keeps the traceback.
        logging.exception(f"An unexpected error occurred during bot initialization: {e}")
        return


if __name__ == '__main__':
    main()