Files
cyclop/openai_compatible_inference_bot.py
T

621 lines
30 KiB
Python
Raw Normal View History

import importlib
import json
import os
import logging
import inspect
import re
from abc import abstractmethod
from openai import OpenAI
from tools.base_tool import BaseTool
from telegram_helper import TelegramHelper
import argparse
from inference_bot import InferenceBot
import tiktoken # Added this import
class OpenAICompatibleInferenceBot(InferenceBot):
    """Chat bot backed by any OpenAI-compatible Chat Completions endpoint.

    Keeps a per-user conversation history, offers the model the tools
    discovered in the sibling ``tools`` directory (optionally filtered by tag),
    executes the tool calls the model requests, and trims each request
    (tool-call arguments first, then code in older assistant replies) to stay
    within a per-model token budget. A "small" and a "large" model can be
    configured and toggled with :meth:`switch_model`.
    """

    # Upper bound on model -> tools -> model round trips for one user message.
    MAX_TOOL_ITERATIONS = 200
    # Completion tokens reserved out of the inference limit when max_tokens is unset.
    DEFAULT_COMPLETION_RESERVE = 1024
    # The prompt budget never drops below this many tokens.
    MIN_PROMPT_BUDGET = 1024
    # Tool-call argument strings longer than this are summarized when over budget.
    TOOL_ARGS_SUMMARY_THRESHOLD = 700
    # Message keys forwarded to the API; anything else on a dict message is dropped.
    _MESSAGE_KEYS = frozenset({"role", "content", "name", "tool_call_id", "tool_calls"})
    _FENCED_CODE_RE = re.compile(r"```[\s\S]*?```")
    # Runs of lines indented by 4+ spaces/tabs. [ \t] (not \s) so that runs of
    # blank lines are not mistaken for indentation and swallowed.
    _INDENTED_BLOCK_RE = re.compile(r"(?:\n[ \t]{4,}\S.*)+")

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        small_model_name: str | None = None,
        small_model_max_tokens: str | None = None,
        large_model_name: str | None = None,
        large_model_max_tokens: str | None = None,
        allowed_function_tags: list[str] | None = None,
        system_prompt_path: str | None = None,
        use_large_model: bool = False
    ):
        """Create the client, load the system prompt and tools, and select the initial model.

        Args:
            api_key: API key handed to the OpenAI client.
            base_url: Endpoint of the OpenAI-compatible server. ``None`` or ``""``
                means "use the SDK default".
            small_model_name: Model identifier for the small tier.
            small_model_max_tokens: Completion token cap for the small tier, as a
                string; ``None``/``""``/``"none"``/``"null"`` means API default.
            large_model_name: Model identifier for the large tier.
            large_model_max_tokens: Completion token cap for the large tier.
            allowed_function_tags: If given, only functions whose ``_tags``
                intersect this list are offered to the model and accepted from
                its tool calls.
            system_prompt_path: File with the system prompt; falls back to
                ``$SYSTEM_PROMPT_PATH``, then to a built-in default.
            use_large_model: Start on the large model instead of the small one.
        """
        self.model_config = {
            "small_model_name": small_model_name,
            "small_model_max_tokens": small_model_max_tokens,
            "large_model_name": large_model_name,
            "large_model_max_tokens": large_model_max_tokens
        }
        self.allowed_function_tags = allowed_function_tags if allowed_function_tags else None
        self.conversation_history = {}
        self._processing_status = {}
        self.system_prompt_path = system_prompt_path
        self.system_prompt = self.load_system_prompt(file_path=system_prompt_path)
        self.tools, self.functions = self.load_functions()
        # The OpenAI SDK only applies its default endpoint (or $OPENAI_BASE_URL)
        # when base_url is None; an empty string would be used verbatim.
        self.client = OpenAI(api_key=api_key, base_url=base_url or None)
        log_msg = f"Initialized OpenAI compatible client. Target URL: {base_url if base_url else 'OpenAI default'}."
        logging.info(log_msg)
        # Prompt-size limits per tier (defaults: small=16k, large=32k tokens)
        self.small_model_max_inference_tokens = self._int_from_env("_SMALL_MODEL_MAX_INFERENCE_TOKENS", 16384)
        self.large_model_max_inference_tokens = self._int_from_env("_LARGE_MODEL_MAX_INFERENCE_TOKENS", 32768)
        # Configure the actual model name and max_tokens for API calls
        tier = "large" if use_large_model else "small"
        self._configure_model_and_tokens(
            self.model_config[f"{tier}_model_name"],
            self.model_config[f"{tier}_model_max_tokens"],
        )

    @staticmethod
    def _int_from_env(name: str, default: int) -> int:
        """Read an integer environment variable, falling back to ``default`` if unset, empty or invalid."""
        raw = os.getenv(name)
        if raw is None or not raw.strip():
            return default
        try:
            return int(raw)
        except ValueError:
            logging.warning(f"Invalid integer for {name}: {raw!r}. Using default {default}.")
            return default

    @property
    def processing_status(self):
        """Mapping of user_id -> {"processing": True, "message_id": ...} for in-flight requests."""
        return self._processing_status

    def clear_conversation_history(self, user_id):
        """Forget ``user_id``'s conversation and call ``clear()`` on every loaded tool.

        Note that the tool reset is not scoped to the user: all tools are cleared.
        """
        self.conversation_history.pop(user_id, None)
        for tool in self.tools:
            tool.clear()

    @staticmethod
    def _parse_max_tokens(raw) -> int | None:
        """Convert a configured max_tokens value to a positive int, or None for "API default"."""
        if raw is None:
            return None
        if isinstance(raw, int) and not isinstance(raw, bool):
            value = raw
        else:
            text = str(raw).strip()
            if text.lower() in ("", "none", "null"):
                return None
            try:
                value = int(text)
            except ValueError:
                logging.warning(f"Invalid value for max_tokens: {raw}. Using API default (None)")
                return None
        if value <= 0:
            # The API rejects non-positive completion limits
            logging.warning(f"Invalid value for max_tokens: {raw}. Using API default (None)")
            return None
        return value

    def _configure_model_and_tokens(self, model_name: str | None, max_tokens_str: str | None):
        """Set ``self.model`` and ``self.max_tokens`` (None means the API default is used)."""
        self.model = model_name
        self.max_tokens = self._parse_max_tokens(max_tokens_str)
        logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}")

    def get_llm_description(self) -> str:
        """Return a one-line description of the client class, model and max_tokens."""
        client_type = type(self.client).__name__
        return f"Client: {client_type}, LLM: {self.model}, Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}"

    def _encoding_for_model(self, model: str | None):
        """Return the tiktoken encoding for ``model``, or cl100k_base if unknown/unset."""
        try:
            return tiktoken.encoding_for_model(model) if model else tiktoken.get_encoding("cl100k_base")
        except KeyError:
            logging.warning(f"Warning: model {model} not found. Using cl100k_base encoding.")
            return tiktoken.get_encoding("cl100k_base")

    @staticmethod
    def _tool_call_to_dict(tc) -> dict:
        """Convert an SDK tool-call object into the plain dict shape the API accepts."""
        try:
            function = getattr(tc, "function", None)
            return {
                "id": getattr(tc, "id", None),
                "type": getattr(tc, "type", "function"),
                "function": {
                    "name": getattr(function, "name", None),
                    # Missing/None arguments are sent as an empty JSON object, not null
                    "arguments": getattr(function, "arguments", None) or "{}",
                },
            }
        except Exception:
            # Best-effort fallback: keep the message shape valid
            return {"id": None, "type": "function", "function": {"name": "unknown", "arguments": "{}"}}

    def _normalize_messages(self, messages):
        """Return a list of plain dict chat messages acceptable by the API.

        - Dict messages are shallow-copied, keeping only keys in ``_MESSAGE_KEYS``.
        - OpenAI SDK message objects are converted to dicts; ``name`` and
          ``tool_calls`` are included only when present.
        The input list and its dicts are not modified.
        """
        normalized = []
        for message in messages:
            if isinstance(message, dict):
                normalized.append({k: v for k, v in message.items() if k in self._MESSAGE_KEYS})
                continue
            # Likely an OpenAI SDK message object
            entry = {"role": getattr(message, "role", None), "content": getattr(message, "content", None)}
            name = getattr(message, "name", None)
            if name:
                entry["name"] = name
            tool_calls = [self._tool_call_to_dict(tc) for tc in (getattr(message, "tool_calls", None) or [])]
            if tool_calls:
                entry["tool_calls"] = tool_calls
            normalized.append(entry)
        return normalized

    def _estimate_tokens(self, messages):
        """Approximate the prompt token count of normalized messages with tiktoken.

        Counts role/name/content strings plus tool-call names and argument
        strings, with 4 tokens of per-message overhead and 2 for assistant
        priming (an approximation of OpenAI's chat counting rules).
        """
        enc = self._encoding_for_model(self.model)
        num_tokens = 0
        for message in messages:
            num_tokens += 4  # per-message overhead
            if not isinstance(message, dict):
                continue
            for key in ("role", "name", "content"):
                value = message.get(key)
                if isinstance(value, str):
                    num_tokens += len(enc.encode(value))
            tool_calls = message.get("tool_calls")
            if tool_calls and isinstance(tool_calls, list):
                # approximate cost of the tool_calls JSON the model sees
                for tc in tool_calls:
                    fn = (tc.get("function") if isinstance(tc, dict) else None) or {}
                    for value in (fn.get("name"), fn.get("arguments")):
                        if isinstance(value, str):
                            num_tokens += len(enc.encode(value))
        num_tokens += 2  # assistant priming
        return num_tokens

    def _get_inference_limit(self):
        """Return the prompt-size limit for the current model tier, or None if the model matches neither tier."""
        if self.model == self.model_config["small_model_name"]:
            return self.small_model_max_inference_tokens
        if self.model == self.model_config["large_model_name"]:
            return self.large_model_max_inference_tokens
        logging.warning(f"Could not determine inference token limit for model: {self.model}. Proceeding without check.")
        return None

    def _compute_budget(self, limit: int) -> int:
        """Prompt token budget: ``limit`` minus the completion reserve, floored at MIN_PROMPT_BUDGET."""
        reserve = self.max_tokens if isinstance(self.max_tokens, int) else self.DEFAULT_COMPLETION_RESERVE
        return max(self.MIN_PROMPT_BUDGET, limit - reserve)

    def _summarize_tool_args(self, args_str: str, max_chars: int = 512) -> str:
        """Summarize tool-call request arguments without altering tool responses.

        - If a JSON object, keep keys; strings over 160 chars become a 120-char
          preview plus their length; lists/dicts become "<type size=N>".
        - Otherwise truncate the raw string with a "... [summarized]" marker.
        The result is at most about ``max_chars`` characters.
        """
        try:
            parsed = json.loads(args_str)
        except (ValueError, TypeError):
            parsed = None
        if isinstance(parsed, dict):
            summary = {}
            for key, value in parsed.items():
                if isinstance(value, str):
                    summary[key] = value[:120] + f"... [len={len(value)}]" if len(value) > 160 else value
                elif isinstance(value, (list, dict)):
                    # structural summary only
                    summary[key] = f"<{type(value).__name__} size={len(value)}>"
                else:
                    summary[key] = value
            text = json.dumps(summary, ensure_ascii=False)
            if len(text) > max_chars:
                text = text[: max_chars - 20] + "... [summarized]"
            return text
        # Fallback: truncate raw string
        return (args_str[: max_chars - 20] + "... [summarized]") if len(args_str) > max_chars else args_str

    def _summarize_tool_call_requests_in_messages(self, messages):
        """Shorten long tool-call argument strings; return True if anything changed.

        Only the request side (``tool_calls[*].function.arguments`` longer than
        TOOL_ARGS_SUMMARY_THRESHOLD chars) is rewritten; tool responses are kept.
        Affected messages get a new ``tool_calls`` list of new dicts, so this is
        meant for the copies produced by ``_normalize_messages``.
        """
        changed_any = False
        for message in messages:
            if not (isinstance(message, dict) and message.get("tool_calls")):
                continue
            rewritten = []
            changed_here = False
            for tc in message["tool_calls"]:
                fn = tc.get("function") if isinstance(tc, dict) else None
                args = fn.get("arguments") if isinstance(fn, dict) else None
                if isinstance(args, str) and len(args) > self.TOOL_ARGS_SUMMARY_THRESHOLD:
                    tc = {**tc, "function": {**fn, "arguments": self._summarize_tool_args(args)}}
                    changed_here = True
                rewritten.append(tc)
            if changed_here:
                message["tool_calls"] = rewritten
                changed_any = True
        return changed_any

    def _elide_redundant_code_blocks(self, messages):
        """As a last resort, remove large code blocks from older assistant messages.

        Fenced ``` blocks become "[code block omitted]" and runs of lines
        indented by 4+ spaces/tabs become "[long block omitted]". The latest
        assistant message is kept intact. Mutates the message dicts in place;
        returns True if any content changed.
        """
        assistant_indices = [
            i for i, m in enumerate(messages)
            if isinstance(m, dict) and m.get("role") == "assistant" and m.get("content")
        ]
        changed = False
        # [:-1] protects the newest assistant message (and is empty if there is only one)
        for i in assistant_indices[:-1]:
            content = messages[i].get("content")
            if not isinstance(content, str):
                continue
            trimmed = self._FENCED_CODE_RE.sub("[code block omitted]", content)
            trimmed = self._INDENTED_BLOCK_RE.sub("\n[long block omitted]", trimmed)
            if trimmed != content:
                messages[i]["content"] = trimmed
                changed = True
        return changed

    def _enforce_budget(self, messages):
        """Normalize messages and try to fit them within the current model's token budget.

        First summarizes long tool-call request arguments, then elides code
        blocks from older assistant messages. If the result is still over
        budget a warning is logged and the messages are returned anyway.
        Returns the normalized (and possibly trimmed) copy; the input is untouched.
        """
        normalized = self._normalize_messages(messages)
        limit = self._get_inference_limit()
        if not limit:
            return normalized
        budget = self._compute_budget(limit)
        tokens = self._estimate_tokens(normalized)
        if tokens <= budget:
            return normalized
        # Step 1: summarize only tool-call request arguments
        if self._summarize_tool_call_requests_in_messages(normalized):
            tokens = self._estimate_tokens(normalized)
            logging.info(f"Applied tool-call request summarization. tokens={tokens}/{budget}")
            if tokens <= budget:
                return normalized
        # Step 2: elide redundant code blocks from older assistant messages
        if self._elide_redundant_code_blocks(normalized):
            tokens = self._estimate_tokens(normalized)
            logging.info(f"Elided redundant code blocks. tokens={tokens}/{budget}")
            if tokens <= budget:
                return normalized
        # Still over: the API may reject the request; callers may choose to abort
        logging.warning(f"Projected tokens still exceed budget after optimizations: {tokens}/{budget}")
        return normalized

    def _allowed_tools(self):
        """Return the tool schemas the model may use, without the internal ``_tags`` key.

        All loaded functions are allowed when ``allowed_function_tags`` is None;
        otherwise a function needs at least one of its ``_tags`` in that list.
        """
        allowed_tags = getattr(self, "allowed_function_tags", None)
        selected = []
        for func in getattr(self, "functions", None) or []:
            if allowed_tags is not None and not any(tag in allowed_tags for tag in (func.get("_tags") or [])):
                continue
            selected.append({k: v for k, v in func.items() if k != "_tags"})
        return selected

    def get_chat_response(self, messages):
        """Send ``messages`` (budget-enforced) to the current model and return the raw response.

        ``tools``/``tool_choice`` are only sent when at least one tool is
        allowed (the API rejects an empty tools list), and ``max_tokens`` only
        when configured.

        Raises:
            ValueError: if the client is not initialized.
            Exception: any API error is logged and re-raised.
        """
        if not self.client:
            logging.error("OpenAI client not initialized before get_chat_response.")
            raise ValueError("OpenAI client not initialized.")
        try:
            tools = self._allowed_tools()
            request = {
                "model": self.model,
                # Enforce token budget prior to API call
                "messages": self._enforce_budget(messages),
            }
            if tools:
                request["tools"] = tools
                request["tool_choice"] = "auto"
            if self.max_tokens is not None:
                request["max_tokens"] = self.max_tokens
            return self.client.chat.completions.create(**request)
        except Exception as e:
            logging.error(f"API call to model {self.model} failed: {e}")
            raise

    def get_bot_status(self):
        """
        Returns a message with the currently enabled model and the system prompt path being used.
        """
        model_name = self.model if hasattr(self, 'model') else None
        prompt_path = self.system_prompt_path or os.getenv("SYSTEM_PROMPT_PATH") or "(default prompt in use)"
        return f"Current model: {model_name}\nSystem prompt path: {prompt_path}"

    @staticmethod
    def _first_choice_message(response):
        """Return the first choice's message from a completion response, or None if absent."""
        choices = getattr(response, "choices", None)
        if not choices:
            return None
        return choices[0].message or None

    def _run_tool_calls(self, tool_calls, allowed_names):
        """Execute the model's tool calls and return the matching ``role: tool`` messages.

        Calls to names outside ``allowed_names`` are answered with an error
        instead of being executed. Non-string results are JSON-encoded; any
        exception becomes an error string for the model.
        """
        results = []
        for tool_call in tool_calls:
            function_name = tool_call.function.name
            logging.info(f"Attempting to call tool: {function_name} with args: [request summarized if large]")
            if function_name not in allowed_names:
                logging.warning(f"Tool function {function_name} not found in available functions.")
                content = f"Error: Tool function {function_name} not found."
            else:
                try:
                    content = self.call_tool(function_name, tool_call.function.arguments)
                    if not isinstance(content, str):
                        content = json.dumps(content)
                except Exception as e:
                    logging.error(f"Error calling tool {function_name}: {e}")
                    content = f"Error executing tool {function_name}: {str(e)}"
            results.append({
                "role": "tool",
                "tool_call_id": tool_call.id,
                "name": function_name,
                "content": content
            })
        return results

    @staticmethod
    def _unexecuted_tool_results(tool_calls):
        """Tool messages answering calls that were not run, so the history stays valid for the API."""
        return [
            {
                "role": "tool",
                "tool_call_id": tool_call.id,
                "name": tool_call.function.name,
                "content": "Error: Tool call not executed because the tool iteration limit was reached."
            }
            for tool_call in tool_calls
        ]

    @staticmethod
    def _reply_text(message) -> str:
        """Extract the assistant's text from a dict or SDK message, with a fallback notice."""
        if isinstance(message, dict):
            content = message.get("content")
        elif getattr(message, "role", None) == "assistant":
            content = getattr(message, "content", None)
        else:
            content = None
        return content if content is not None else "Assistant did not provide a textual response."

    async def handle_message(self, user_id, user_message):
        """Process one user message and return the assistant's text reply.

        Appends the message to the user's history (seeding it with the system
        prompt on first use) and rejects it if the projected prompt exceeds the
        current model's budget. Otherwise it alternates model calls and tool
        execution until the model stops requesting tools or
        MAX_TOOL_ITERATIONS is reached. The resulting history, including tool
        traffic, is stored in ``conversation_history[user_id]``.

        NOTE(review): ``self.client`` is the synchronous ``OpenAI`` client, so
        every model call blocks the event loop while it waits.
        """
        history = self.conversation_history.get(user_id)
        if not history:
            history = []
            if self.system_prompt:
                history.append({"role": "system", "content": self.system_prompt})
            self.conversation_history[user_id] = history
        user_entry = {"role": "user", "content": user_message}
        history.append(user_entry)
        messages = list(history)

        # Pre-inference token limit check with budgeted optimizations
        limit = self._get_inference_limit()
        if limit is not None:
            token_count = self._estimate_tokens(self._enforce_budget(messages))
            budget = self._compute_budget(limit)
            if token_count > budget:
                logging.warning(f"Request for user {user_id} exceeds inference token budget even after optimizations ({token_count}/{budget}).")
                # Do not persist this message in history as it was not processed by LLM
                if history and history[-1] is user_entry:
                    history.pop()
                return "Request exceeds inference token limit after optimization. Please shorten your request, use /clear, or implement RAG in your application."

        response = self.get_chat_response(messages)
        assistant_message = self._first_choice_message(response)
        if assistant_message is None:
            logging.error("No valid response choice message from LLM.")
            self.conversation_history[user_id] = messages
            return "Error: Could not get a valid response from the LLM."
        messages.append(assistant_message)
        pending_calls = list(assistant_message.tool_calls or [])

        # Only tools actually offered to the model may be executed
        allowed_names = {f.get("function", {}).get("name") for f in self._allowed_tools()}
        iterations = 0
        while pending_calls and iterations < self.MAX_TOOL_ITERATIONS:
            messages.extend(self._run_tool_calls(pending_calls, allowed_names))
            # get_chat_response enforces the budget (request arguments only; tool responses preserved)
            response = self.get_chat_response(messages)
            assistant_message = self._first_choice_message(response)
            if assistant_message is None:
                logging.error("No valid response choice message from LLM after tool call.")
                self.conversation_history[user_id] = messages
                return "Error: Could not get a valid response from the LLM after tool call."
            messages.append(assistant_message)
            pending_calls = list(assistant_message.tool_calls or [])
            iterations += 1

        if pending_calls:
            logging.warning(f"Max tool iterations ({self.MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.")
            # Every tool_call must be answered, or the next request with this history is rejected
            messages.extend(self._unexecuted_tool_results(pending_calls))

        self.conversation_history[user_id] = messages
        return self._reply_text(assistant_message)

    async def start(self):
        """Log that the bot has started."""
        logging.info(f"{self.__class__.__name__} (Model: {self.model}) started.")

    async def abort_processing(self, user_id):
        """Clear ``user_id``'s processing status and return a user-facing message."""
        if user_id in self.processing_status:
            self.clear_processing_status(user_id)
            logging.info(f"Processing aborted for user {user_id}.")
            return "Processing aborted. You can send a new message or /clear the conversation."
        return "No active processing found to abort. If you wish, /clear the conversation history."

    def load_functions(self):
        """Discover and instantiate tools from the ``tools`` directory next to this file.

        Every concrete ``BaseTool`` subclass found in ``tools/*.py`` (except
        ``__init__.py`` and ``base_tool.py``) is instantiated with no
        arguments. A class seen in several modules (e.g. imported by a
        subclass's module) is instantiated only once. Import and instantiation
        errors are logged and skipped.

        Returns:
            (tools, functions): the tool instances, and the concatenation of
            each tool's ``get_functions()`` schemas.
        """
        tools = []
        tools_dir = os.path.join(os.path.dirname(__file__), 'tools')
        if not os.path.isdir(tools_dir):
            logging.warning(f"Tools directory not found: {tools_dir}")
            return [], []
        seen_classes = set()
        # Sorted so the tool order does not depend on the filesystem
        for filename in sorted(os.listdir(tools_dir)):
            if not filename.endswith('.py') or filename in ('__init__.py', 'base_tool.py'):
                continue
            module_name = f'tools.{filename[:-3]}'
            try:
                module = importlib.import_module(module_name)
            except Exception as e:
                logging.error(f"Error importing module {module_name}: {e}")
                continue
            for name, obj in inspect.getmembers(module, inspect.isclass):
                if obj is BaseTool or not issubclass(obj, BaseTool):
                    continue
                if obj in seen_classes or inspect.isabstract(obj):
                    continue
                seen_classes.add(obj)
                try:
                    tools.append(obj())  # tools that need constructor config are not supported here
                except Exception as e:
                    logging.error(f"Error instantiating tool {name} from {filename}: {e}")
        functions = [schema for tool in tools for schema in tool.get_functions()]
        return tools, functions

    def load_system_prompt(self, direct_content: str | None = None, file_path: str | None = None) -> str:
        """Return the system prompt text.

        Precedence: ``direct_content`` (stripped), then the file at
        ``file_path`` or ``$SYSTEM_PROMPT_PATH`` (stripped), then a built-in
        default. Unreadable or undecodable files fall back to the default.
        """
        default_prompt = "You are a helpful AI assistant."
        if direct_content:
            logging.info("Using direct content for system prompt.")
            return direct_content.strip()
        prompt_path_to_try = file_path or os.getenv("SYSTEM_PROMPT_PATH")
        if not prompt_path_to_try:
            logging.info("No system prompt path provided (argument or ENV) or direct content. Using default system prompt.")
            return default_prompt
        if not os.path.isfile(prompt_path_to_try):
            logging.warning(f"System prompt file {prompt_path_to_try} not found. Using default system prompt.")
            return default_prompt
        try:
            with open(prompt_path_to_try, "r", encoding="utf-8") as file:
                content = file.read().strip()
        except (OSError, UnicodeDecodeError) as e:
            logging.warning(f"Could not read system prompt file {prompt_path_to_try}: {e}. Using default.")
            return default_prompt
        logging.info(f"Successfully loaded system prompt from {prompt_path_to_try}.")
        return content

    def set_processing_status(self, user_id: int, message_id: int):
        """Mark ``user_id`` as having an in-flight request tied to ``message_id``."""
        self.processing_status[user_id] = {"processing": True, "message_id": message_id}

    def clear_processing_status(self, user_id: int):
        """Remove ``user_id``'s processing marker if present."""
        self.processing_status.pop(user_id, None)

    def call_tool(self, function_call_name, function_call_arguments):
        """Execute the named tool function with JSON-string, dict, or None arguments.

        Returns whatever ``tool.execute`` returns. Malformed or invalid
        arguments, an unknown name, and execution failures are reported as
        error strings rather than raised.
        """
        function_name = function_call_name
        if function_call_arguments is None:
            function_args = {}
        elif isinstance(function_call_arguments, dict):
            function_args = function_call_arguments
        elif isinstance(function_call_arguments, str):
            if not function_call_arguments.strip():
                # Empty/whitespace means "no arguments"; json.loads would reject it
                function_args = {}
            else:
                try:
                    function_args = json.loads(function_call_arguments)
                except json.JSONDecodeError as e:
                    logging.error(f"Error decoding function call arguments (string) for {function_call_name}: {e}. Arguments: {function_call_arguments}")
                    return f"Error: Malformed arguments for tool call: {e}"
        else:
            logging.error(f"Unexpected type for function_call_arguments for {function_name}: {type(function_call_arguments)}. Arguments: {function_call_arguments}")
            return f"Error: Invalid argument type for tool call: {type(function_call_arguments)}"
        for tool in self.tools:
            for function in tool.get_functions():
                if function["function"]["name"] != function_name:
                    continue
                if not isinstance(function_args, dict):
                    # e.g. the JSON decoded to a list or scalar
                    logging.error(f"Internal error: function_args not a dict for {function_name} before execution. Args: {function_args}")
                    return f"Internal error preparing arguments for tool {function_name}."
                try:
                    return tool.execute(function_name, **function_args)
                except Exception as e:
                    logging.error(f"Error executing tool {function_name} with args {function_args}: {e}")
                    return f"Error executing tool {function_name}: {e}"
        logging.warning(f"Tool function {function_name} not found.")
        return f"Error: Tool function {function_name} not found."

    async def switch_model(self):
        """Toggle between the small and large model and return a status message.

        An unrecognized current model switches to the small model. Both model
        names must be configured for switching to happen.
        """
        small_name = self.model_config["small_model_name"]
        large_name = self.model_config["large_model_name"]
        if not small_name or not large_name:
            logging.warning("Small or Large model names are not defined. Cannot switch model.")
            return f"Model switching not fully configured. Currently using {self.model}."
        if self.model == large_name:
            target_tier = "small"
        elif self.model == small_name:
            target_tier = "large"
        else:
            logging.warning(f"Current model {self.model} is unrecognized. Switching to default small model: {small_name}.")
            target_tier = "small"
        self._configure_model_and_tokens(
            self.model_config[f"{target_tier}_model_name"],
            self.model_config[f"{target_tier}_model_max_tokens"],
        )
        return f"Switched model to {self.model}. Max tokens set to {self.max_tokens if self.max_tokens is not None else 'API default'}."
def _resolve_messenger_helper(messenger: str):
    """Import ``<messenger>_helper`` and return its ``<Messenger>Helper`` or ``<MESSENGER>Helper`` class.

    Raises:
        ValueError: if neither class name exists in the module.
        ImportError: if the helper module cannot be imported.
    """
    module = importlib.import_module(f'{messenger.lower()}_helper')
    candidates = (f"{messenger.capitalize()}Helper", f"{messenger.upper()}Helper")
    for class_name in candidates:
        if hasattr(module, class_name):
            return getattr(module, class_name)
    raise ValueError(f"Messenger helper class {' or '.join(candidates)} not found in {module.__name__}.")


def main():
    """Parse CLI arguments, build the bot from ``<PREFIX>_*`` env vars, and run the messenger helper.

    ``<PREFIX>`` is the upper-cased ``--config`` value. The variables read are
    ``_API_KEY``, ``_API_BASE_URL``, ``_SMALL_MODEL``, ``_LARGE_MODEL``,
    ``_SMALL_MODEL_MAX_TOKENS`` and ``_LARGE_MODEL_MAX_TOKENS``. An unset or
    empty base URL selects the OpenAI SDK's default endpoint. ValueErrors are
    logged as fatal; other exceptions are logged with a traceback.
    """
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    try:
        parser = argparse.ArgumentParser(description='OpenAI Compatible Inference Bot')
        parser.add_argument('--config', type=str, help='Configuration Prepend (i.e. gemini, openai, etc)', default="Telegram")
        parser.add_argument('--messenger', type=str, help='Messenger type (i.e. telegram)', required=True)
        parser.add_argument('--persona', type=str, help='Path to system prompt file', required=False)
        parser.add_argument('--tools', nargs='+', help='List of allowed function tags', required=False)
        parser.add_argument('--use-large-model', action='store_true', help='Use the large model instead of the small model')
        # Add these to launch.json arguments if you want to limit the toolset available: "--tools", "read", "communicate"
        args = parser.parse_args()
        if args.persona:
            logging.info(f"Using custom persona from: {args.persona}")
        if not args.config:
            # Without a prefix no API settings can be located, and the helper would get no bot
            raise ValueError("A non-empty --config prefix is required to locate API settings.")
        prefix = args.config.upper()
        bot = OpenAICompatibleInferenceBot(
            api_key=os.environ.get(f"{prefix}_API_KEY"),
            # None (not "") lets the SDK pick its default endpoint
            base_url=os.environ.get(f"{prefix}_API_BASE_URL") or None,
            small_model_name=os.environ.get(f"{prefix}_SMALL_MODEL"),
            small_model_max_tokens=os.environ.get(f"{prefix}_SMALL_MODEL_MAX_TOKENS"),
            large_model_name=os.environ.get(f"{prefix}_LARGE_MODEL"),
            large_model_max_tokens=os.environ.get(f"{prefix}_LARGE_MODEL_MAX_TOKENS"),
            system_prompt_path=args.persona or None,
            allowed_function_tags=args.tools or None,
            use_large_model=args.use_large_model
        )
        helper_class = _resolve_messenger_helper(args.messenger)
        helper = helper_class(bot)
        helper.run()
    except ValueError as e:
        logging.error(f"FATAL: {e}")
        return
    except Exception as e:  # Catch any other init errors
        logging.exception(f"An unexpected error occurred during bot initialization: {e}")
        return


if __name__ == '__main__':
    main()