# Source listing metadata: 578 lines, 28 KiB, Python.
import importlib
|
|
import json
|
|
import os
|
|
import logging
|
|
import inspect
|
|
import re
|
|
from abc import abstractmethod
|
|
from openai import OpenAI
|
|
from tools.base_tool import BaseTool
|
|
from tools.github_tool import GitHubTool
|
|
from telegram_helper import TelegramHelper
|
|
import argparse
|
|
from inference_bot import InferenceBot
|
|
import tiktoken # Added this import
|
|
|
|
class OpenAICompatibleInferenceBot(InferenceBot):
|
|
def __init__(
    self,
    api_key: str | None = None,
    base_url: str | None = None,
    small_model_name: str | None = None,
    small_model_max_tokens: str | None = None,
    large_model_name: str | None = None,
    large_model_max_tokens: str | None = None,
    allowed_function_tags: list[str] | None = None,
    system_prompt_path: str | None = None,
    use_large_model: bool = False,
):
    """Set up state, tools, the API client and the active model.

    Args:
        api_key: Passed straight to the ``OpenAI`` client.
        base_url: Passed straight to the ``OpenAI`` client; only used
            here otherwise for the startup log line.
        small_model_name / large_model_name: Names of the two models
            the bot can run with.
        small_model_max_tokens / large_model_max_tokens: Max-token
            settings in string form; parsed by
            ``_configure_model_and_tokens``.
        allowed_function_tags: Tags used to filter offered tools; an
            empty or missing list is stored as ``None``.
        system_prompt_path: File handed to ``load_system_prompt``.
        use_large_model: Start with the large model instead of the
            small one.

    Raises:
        ValueError: If either inference-limit environment variable is
            set to something ``int()`` cannot parse.
    """
    # Keep both model configurations so the model can be swapped later.
    self.model_config = dict(
        small_model_name=small_model_name,
        small_model_max_tokens=small_model_max_tokens,
        large_model_name=large_model_name,
        large_model_max_tokens=large_model_max_tokens,
    )
    self.allowed_function_tags = allowed_function_tags or None
    self.conversation_history = {}
    self._processing_status = {}

    self.system_prompt_path = system_prompt_path
    self.system_prompt = self.load_system_prompt(file_path=system_prompt_path)

    loaded_tools, loaded_functions = self.load_functions()
    self.tools = loaded_tools
    self.functions = loaded_functions

    self.client = OpenAI(api_key=api_key, base_url=base_url)
    target_url = base_url if base_url else 'OpenAI default'
    logging.info(f"Initialized OpenAI compatible client. Target URL: {target_url}.")

    # Inference token limits come from the environment (defaults: small=16k, large=32k).
    self.small_model_max_inference_tokens = int(os.getenv("_SMALL_MODEL_MAX_INFERENCE_TOKENS", "16384"))
    self.large_model_max_inference_tokens = int(os.getenv("_LARGE_MODEL_MAX_INFERENCE_TOKENS", "32768"))

    # Select which of the two stored configurations is active for API calls.
    size = "large" if use_large_model else "small"
    self._configure_model_and_tokens(
        self.model_config[f"{size}_model_name"],
        self.model_config[f"{size}_model_max_tokens"],
    )
|
|
|
|
@property
def processing_status(self):
    """Return the mapping stored in ``_processing_status`` (not a copy)."""
    status_by_user = self._processing_status
    return status_by_user
|
|
|
|
def clear_conversation_history(self, user_id):
    """Forget the stored conversation for ``user_id`` and reset every tool.

    A user without stored history is not an error. ``clear()`` is called
    on each entry of ``self.tools``.
    """
    self.conversation_history.pop(user_id, None)

    for loaded_tool in self.tools:
        loaded_tool.clear()
|
|
|
|
def _configure_model_and_tokens(self, model_name: str | None, max_tokens_str: str | None):
|
|
self.model = model_name
|
|
try:
|
|
if max_tokens_str and max_tokens_str.lower() not in ["none", "", "null"]:
|
|
self.max_tokens = int(max_tokens_str)
|
|
else:
|
|
self.max_tokens = None
|
|
except ValueError:
|
|
logging.warning(f"Invalid value for max_tokens: {max_tokens_str}. Using API default (None)")
|
|
self.max_tokens = None
|
|
|
|
logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}")
|
|
|
|
def get_llm_description(self) -> str:
    """Describe the client class, active model and token limit in one line."""
    token_cap = "API default" if self.max_tokens is None else self.max_tokens
    client_type = type(self.client).__name__
    return f"Client: {client_type}, LLM: {self.model}, Max Tokens: {token_cap}"
|
|
|
|
def _encoding_for_model(self, model: str | None):
    """Return the tiktoken encoding for ``model``.

    A missing model name, or one whose lookup raises ``KeyError``, yields
    the ``cl100k_base`` encoding.
    """
    try:
        if model:
            return tiktoken.encoding_for_model(model)
        return tiktoken.get_encoding("cl100k_base")
    except KeyError:
        logging.warning(f"Warning: model {model} not found. Using cl100k_base encoding.")
        return tiktoken.get_encoding("cl100k_base")
|
|
|
|
def _normalize_messages(self, messages):
|
|
"""Return a list of plain dict chat messages acceptable by the API.
|
|
- Converts OpenAI SDK message objects into dicts
|
|
- Preserves tool_calls structure where present
|
|
"""
|
|
normalized = []
|
|
for m in messages:
|
|
if isinstance(m, dict):
|
|
# Ensure only known keys are present; copy shallowly
|
|
entry = {k: v for k, v in m.items() if k in {"role", "content", "name", "tool_call_id", "tool_calls"}}
|
|
normalized.append(entry)
|
|
else:
|
|
# Likely an OpenAI message object
|
|
role = getattr(m, "role", None)
|
|
content = getattr(m, "content", None)
|
|
name = getattr(m, "name", None)
|
|
tool_calls = []
|
|
tc_list = getattr(m, "tool_calls", None)
|
|
if tc_list:
|
|
for tc in tc_list:
|
|
try:
|
|
tool_calls.append({
|
|
"id": getattr(tc, "id", None),
|
|
"type": getattr(tc, "type", "function"),
|
|
"function": {
|
|
"name": getattr(getattr(tc, "function", None), "name", None),
|
|
"arguments": getattr(getattr(tc, "function", None), "arguments", "{}"),
|
|
}
|
|
})
|
|
except Exception:
|
|
# Best-effort fallback
|
|
tool_calls.append({"id": None, "type": "function", "function": {"name": "unknown", "arguments": "{}"}})
|
|
entry = {"role": role, "content": content}
|
|
if name:
|
|
entry["name"] = name
|
|
if tool_calls:
|
|
entry["tool_calls"] = tool_calls
|
|
normalized.append(entry)
|
|
return normalized
|
|
|
|
def _estimate_tokens(self, messages):
|
|
"""Estimate tokens for messages with tiktoken, including tool_calls arguments.
|
|
Based on OpenAI's chat token counting rules approximation.
|
|
"""
|
|
enc = self._encoding_for_model(self.model)
|
|
num_tokens = 0
|
|
for m in messages:
|
|
num_tokens += 4 # per-message overhead
|
|
if not isinstance(m, dict):
|
|
continue
|
|
# role/content
|
|
for key in ("role", "name", "content"):
|
|
v = m.get(key)
|
|
if isinstance(v, str):
|
|
num_tokens += len(enc.encode(v))
|
|
# tool calls request portion (arguments)
|
|
tcs = m.get("tool_calls")
|
|
if tcs and isinstance(tcs, list):
|
|
# approximate cost of the tool_calls JSON the model sees
|
|
for tc in tcs:
|
|
fn = tc.get("function", {}) if isinstance(tc, dict) else {}
|
|
fname = fn.get("name")
|
|
fargs = fn.get("arguments")
|
|
if isinstance(fname, str):
|
|
num_tokens += len(enc.encode(fname))
|
|
if isinstance(fargs, str):
|
|
num_tokens += len(enc.encode(fargs))
|
|
num_tokens += 2 # assistant priming
|
|
return num_tokens
|
|
|
|
def _get_inference_limit(self):
|
|
current_model_is_small = self.model == self.model_config["small_model_name"]
|
|
current_model_is_large = self.model == self.model_config["large_model_name"]
|
|
if current_model_is_small:
|
|
return self.small_model_max_inference_tokens
|
|
if current_model_is_large:
|
|
return self.large_model_max_inference_tokens
|
|
logging.warning(f"Could not determine inference token limit for model: {self.model}. Proceeding without check.")
|
|
return None
|
|
|
|
def _summarize_tool_args(self, args_str: str, max_chars: int = 512) -> str:
|
|
"""Summarize tool-call request arguments without altering tool responses.
|
|
- If JSON, keep keys and short previews of string values.
|
|
- If plain string, truncate with an indicator.
|
|
"""
|
|
try:
|
|
parsed = json.loads(args_str)
|
|
if isinstance(parsed, dict):
|
|
summary = {}
|
|
for k, v in parsed.items():
|
|
if isinstance(v, str):
|
|
if len(v) > 160:
|
|
summary[k] = v[:120] + f"... [len={len(v)}]"
|
|
else:
|
|
summary[k] = v
|
|
elif isinstance(v, (list, dict)):
|
|
# structural summary only
|
|
summary[k] = f"<{type(v).__name__} size={len(v)}>"
|
|
else:
|
|
summary[k] = v
|
|
s = json.dumps(summary, ensure_ascii=False)
|
|
if len(s) > max_chars:
|
|
s = s[: max_chars - 20] + "... [summarized]"
|
|
return s
|
|
except Exception:
|
|
pass
|
|
# Fallback: truncate raw string
|
|
return (args_str[: max_chars - 20] + "... [summarized]") if len(args_str) > max_chars else args_str
|
|
|
|
def _summarize_tool_call_requests_in_messages(self, messages):
|
|
changed = False
|
|
for m in messages:
|
|
if isinstance(m, dict) and m.get("tool_calls"):
|
|
new_tool_calls = []
|
|
for tc in m["tool_calls"]:
|
|
if not isinstance(tc, dict):
|
|
new_tool_calls.append(tc)
|
|
continue
|
|
fn = tc.get("function", {})
|
|
args = fn.get("arguments")
|
|
if isinstance(args, str) and args and len(args) > 700:
|
|
# summarize long request arguments only
|
|
fn = dict(fn)
|
|
fn["arguments"] = self._summarize_tool_args(args)
|
|
tc = dict(tc)
|
|
tc["function"] = fn
|
|
changed = True
|
|
new_tool_calls.append(tc)
|
|
if changed:
|
|
m["tool_calls"] = new_tool_calls
|
|
return changed
|
|
|
|
def _elide_redundant_code_blocks(self, messages):
|
|
"""As a last resort, remove large code blocks from older assistant messages.
|
|
Keep the latest assistant message intact.
|
|
"""
|
|
changed = False
|
|
# Identify indices of assistant messages
|
|
assistant_indices = [i for i, m in enumerate(messages) if isinstance(m, dict) and m.get("role") == "assistant" and m.get("content")]
|
|
if len(assistant_indices) <= 1:
|
|
return changed
|
|
# Protect the last assistant message
|
|
for i in assistant_indices[:-1]:
|
|
m = messages[i]
|
|
content = m.get("content")
|
|
if not isinstance(content, str):
|
|
continue
|
|
if "```" in content or "\n " in content:
|
|
# Replace code blocks fenced by ``` with succinct markers
|
|
orig = content
|
|
content = re.sub(r"```[\s\S]*?```", "[code block omitted]", content)
|
|
# Also collapse long indented blocks
|
|
content = re.sub(r"(?:\n\s{4,}.+)+", "\n[long block omitted]", content)
|
|
if content != orig:
|
|
m["content"] = content
|
|
changed = True
|
|
return changed
|
|
|
|
def get_chat_response(self, messages):
    """Send ``messages`` to the chat-completions endpoint and return the response.

    Tool definitions from ``self.functions`` are offered to the model.
    When ``self.allowed_function_tags`` is set, only functions whose
    ``_tags`` share at least one tag with it are offered. The internal
    ``_tags`` key is stripped from the copies that are sent.

    Optional request fields are only included when they carry a value:
    ``tools``/``tool_choice`` when at least one tool is offered, and
    ``max_tokens`` when a limit is configured.

    Raises:
        ValueError: If ``self.client`` is not set.
        Exception: Whatever the API call raises, re-raised after logging.
    """
    if not self.client:
        logging.error("OpenAI client not initialized before get_chat_response.")
        raise ValueError("OpenAI client not initialized.")

    allowed_tags = getattr(self, "allowed_function_tags", None)
    offered_tools = []
    for func in getattr(self, "functions", None) or []:
        tags = func.get("_tags") or []
        if allowed_tags is None or any(tag in allowed_tags for tag in tags):
            offered_tools.append({k: v for k, v in func.items() if k != "_tags"})

    request = {"model": self.model, "messages": messages}
    if offered_tools:
        # An empty tools array (every function filtered out) must be
        # omitted rather than sent.
        request["tools"] = offered_tools
        request["tool_choice"] = "auto"
    if self.max_tokens is not None:
        request["max_tokens"] = self.max_tokens

    try:
        return self.client.chat.completions.create(**request)
    except Exception as e:
        logging.error(f"API call to model {self.model} failed: {e}")
        raise
|
|
|
|
def get_bot_status(self):
    """Report the active model and where the system prompt comes from.

    The prompt path is the one given at construction, else the
    ``SYSTEM_PROMPT_PATH`` environment variable, else a placeholder
    saying the default prompt is in use.
    """
    active_model = getattr(self, 'model', None)

    prompt_path = self.system_prompt_path
    if not prompt_path:
        prompt_path = os.getenv("SYSTEM_PROMPT_PATH")
    if not prompt_path:
        prompt_path = "(default prompt in use)"

    return f"Current model: {active_model}\nSystem prompt path: {prompt_path}"
|
|
|
|
async def handle_message(self, user_id, user_message):
    """Run one user turn, executing tool calls until the model answers.

    Flow:
      1. On the first message of a conversation, seed the history with
         the system prompt, formatted with ``repo_name`` (from the
         ``GITHUB_REPOSITORY`` environment variable) and ``branch`` (from
         the loaded GitHub tool, if there is one).
      2. Append the user message and ask the model.
      3. While the model requests tool calls, execute each one, append a
         ``tool`` message per call, and ask again, for at most
         ``MAX_TOOL_ITERATIONS`` rounds.
      4. Store the full exchange in ``self.conversation_history``.

    Args:
        user_id: Key into ``self.conversation_history``.
        user_message: Text of the user's message.

    Returns:
        The final assistant message content, or an explanatory string
        when the model returned no usable message or no text.

    Raises:
        Exception: Anything raised by ``get_chat_response`` propagates.
    """
    MAX_TOOL_ITERATIONS = 200

    if not self.conversation_history.get(user_id):
        self.conversation_history[user_id] = []
        if self.system_prompt:
            # Use the tool instance load_functions() stored. Calling
            # GitHubTool(...) on it would build a new object, not cast it.
            github_tool = getattr(self, "github_tool", None)
            branch = github_tool._get_current_branch() if github_tool is not None else None
            repo_name = os.environ.get("GITHUB_REPOSITORY")
            try:
                sysprompt = self.system_prompt.format(repo_name=repo_name, branch=branch)
            except (KeyError, IndexError, ValueError) as e:
                # Prompts with literal braces (e.g. JSON samples) are not
                # valid format strings; use them verbatim.
                logging.warning(f"System prompt could not be formatted ({e!r}). Using it unformatted.")
                sysprompt = self.system_prompt
            self.conversation_history[user_id].append({"role": "system", "content": sysprompt})

    self.conversation_history[user_id].append({"role": "user", "content": user_message})
    messages = list(self.conversation_history[user_id])

    # Built once instead of once per tool call.
    known_function_names = {f["function"]["name"] for f in self.functions}

    response = self.get_chat_response(messages)
    if not (response.choices and response.choices[0].message):
        logging.error("No valid response choice message from LLM.")
        self.conversation_history[user_id] = messages
        return "Error: Could not get a valid response from the LLM."

    assistant_message = response.choices[0].message
    messages.append(assistant_message)
    pending_calls = list(assistant_message.tool_calls) if assistant_message.tool_calls else []

    tool_use_count = 0
    while pending_calls and tool_use_count < MAX_TOOL_ITERATIONS:
        for tool_call in pending_calls:
            tool_call_id = tool_call.id
            function_name = tool_call.function.name
            function_args_str = tool_call.function.arguments

            logging.info(f"Attempting to call tool: {function_name} with args: [request summarized if large]")
            if function_name not in known_function_names:
                logging.warning(f"Tool function {function_name} not found in available functions.")
                tool_response_content = f"Error: Tool function {function_name} not found."
            else:
                try:
                    tool_response_content = self.call_tool(function_name, function_args_str)
                    if not isinstance(tool_response_content, str):
                        tool_response_content = json.dumps(tool_response_content)
                except Exception as e:
                    logging.error(f"Error calling tool {function_name}: {e}")
                    tool_response_content = f"Error executing tool {function_name}: {str(e)}"

            messages.append({
                "role": "tool",
                "tool_call_id": tool_call_id,
                "name": function_name,
                "content": tool_response_content,
            })

        # NOTE: no token-budget enforcement happens here; messages are
        # sent exactly as accumulated.
        response = self.get_chat_response(messages)
        if not (response.choices and response.choices[0].message):
            logging.error("No valid response choice message from LLM after tool call.")
            self.conversation_history[user_id] = messages
            return "Error: Could not get a valid response from the LLM after tool call."

        assistant_message = response.choices[0].message
        messages.append(assistant_message)
        pending_calls = list(assistant_message.tool_calls) if assistant_message.tool_calls else []
        tool_use_count += 1

    if pending_calls:
        logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.")
        # Answer every outstanding call so the stored history never ends
        # with an assistant tool request that has no matching tool message.
        for tool_call in pending_calls:
            messages.append({
                "role": "tool",
                "tool_call_id": tool_call.id,
                "name": tool_call.function.name,
                "content": "Error: tool iteration limit reached; this call was not executed.",
            })

    self.conversation_history[user_id] = messages

    content = getattr(assistant_message, "content", None)
    if getattr(assistant_message, "role", None) == "assistant" and content is not None:
        return content
    return "Assistant did not provide a textual response."
|
|
|
|
async def start(self):
    """Log that the bot has started; performs no other work."""
    bot_name = self.__class__.__name__
    logging.info(f"{bot_name} (Model: {self.model}) started.")
|
|
|
|
async def abort_processing(self, user_id):
    """Clear the processing flag for ``user_id`` and return a user-facing reply.

    Only the status entry is removed (via ``clear_processing_status``).
    The reply differs depending on whether an entry existed.
    """
    if user_id not in self.processing_status:
        return "No active processing found to abort. If you wish, /clear the conversation history."

    self.clear_processing_status(user_id)
    logging.info(f"Processing aborted for user {user_id}.")
    return "Processing aborted. You can send a new message or /clear the conversation."
|
|
|
|
def load_functions(self):
    """Discover and instantiate the tools in the sibling ``tools`` directory.

    Every ``*.py`` file there (except ``__init__.py`` and
    ``base_tool.py``) is imported as ``tools.<name>``. Each concrete
    ``BaseTool`` subclass found is instantiated once with no arguments,
    even if several modules expose the same class. Files are processed in
    sorted order so the resulting tool order is stable. Import and
    instantiation failures are logged and skipped.

    Side effect:
        ``self.github_tool`` is set to the ``GitHubTool`` instance, or to
        ``None`` when none was loaded.

    Returns:
        ``(tools, functions)``: the tool instances and the concatenation
        of their ``get_functions()`` results. Both are empty when the
        tools directory does not exist.
    """
    # Always define the attribute so later readers never hit AttributeError.
    self.github_tool = None

    tools_dir = os.path.join(os.path.dirname(__file__), 'tools')
    if not os.path.exists(tools_dir):
        logging.warning(f"Tools directory not found: {tools_dir}")
        return [], []

    tools = []
    seen_classes = set()
    # os.listdir() order is arbitrary, so sort for a stable tool order.
    for filename in sorted(os.listdir(tools_dir)):
        if not filename.endswith('.py') or filename in ('__init__.py', 'base_tool.py'):
            continue

        module_name = f'tools.{filename[:-3]}'
        try:
            module = importlib.import_module(module_name)
            candidates = inspect.getmembers(module, inspect.isclass)
        except Exception as e:
            logging.error(f"Error importing module {module_name}: {e}")
            continue

        for name, obj in candidates:
            if not issubclass(obj, BaseTool) or obj is BaseTool:
                continue
            # A tool class imported into another tool module shows up in
            # both; abstract intermediates cannot be instantiated at all.
            if obj in seen_classes or inspect.isabstract(obj):
                continue
            seen_classes.add(obj)

            try:
                # Tools are built without arguments; one that needs
                # configuration fails here and is skipped.
                instance = obj()
            except Exception as e:
                logging.error(f"Error instantiating tool {name} from {filename}: {e}")
                continue

            if obj is GitHubTool:
                self.github_tool = instance
            tools.append(instance)

    functions = []
    for tool in tools:
        functions.extend(tool.get_functions())
    return tools, functions
|
|
|
|
def load_system_prompt(self, direct_content: str | None = None, file_path: str | None = None) -> str:
|
|
default_prompt = "You are a helpful AI assistant."
|
|
|
|
if direct_content:
|
|
logging.info("Using direct content for system prompt.")
|
|
return direct_content.strip()
|
|
|
|
prompt_path_to_try = file_path or os.getenv("SYSTEM_PROMPT_PATH")
|
|
|
|
if prompt_path_to_try:
|
|
if os.path.isfile(prompt_path_to_try):
|
|
try:
|
|
with open(prompt_path_to_try, "r", encoding="utf-8") as file:
|
|
content = file.read().strip()
|
|
logging.info(f"Successfully loaded system prompt from {prompt_path_to_try}.")
|
|
return content
|
|
except IOError as e:
|
|
logging.warning(f"Could not read system prompt file {prompt_path_to_try}: {e}. Using default.")
|
|
return default_prompt
|
|
else:
|
|
logging.warning(f"System prompt file {prompt_path_to_try} not found. Using default system prompt.")
|
|
return default_prompt
|
|
else:
|
|
logging.info("No system prompt path provided (argument or ENV) or direct content. Using default system prompt.")
|
|
return default_prompt
|
|
|
|
def set_processing_status(self, user_id: int, message_id: int):
    """Mark ``user_id`` as being processed, remembering ``message_id``."""
    status_entry = {"processing": True, "message_id": message_id}
    self.processing_status[user_id] = status_entry
|
|
|
|
def clear_processing_status(self, user_id: int):
    """Remove the processing entry for ``user_id``; a missing entry is ignored."""
    self.processing_status.pop(user_id, None)
|
|
|
|
def call_tool(self, function_call_name, function_call_arguments):
    """Execute the named tool function and return its result.

    Args:
        function_call_name: Name to look up in each tool's
            ``get_functions()`` list.
        function_call_arguments: Keyword arguments as a dict, as a JSON
            string, or ``None``. ``None``, a blank string and JSON
            ``null`` all mean "no arguments".

    Returns:
        Whatever ``tool.execute`` returns. Failures are not raised:
        malformed JSON, non-object arguments, an unknown function name
        or an exception inside the tool each produce an error string.
    """
    function_name = function_call_name

    if function_call_arguments is None:
        function_args = {}
    elif isinstance(function_call_arguments, dict):
        function_args = function_call_arguments
    elif isinstance(function_call_arguments, str):
        if not function_call_arguments.strip():
            # A blank argument string is taken to mean a no-argument call.
            function_args = {}
        else:
            try:
                function_args = json.loads(function_call_arguments)
            except json.JSONDecodeError as e:
                logging.error(f"Error decoding function call arguments (string) for {function_call_name}: {e}. Arguments: {function_call_arguments}")
                return f"Error: Malformed arguments for tool call: {e}"
            if function_args is None:
                # JSON "null" is likewise treated as "no arguments".
                function_args = {}
    else:
        logging.error(f"Unexpected type for function_call_arguments for {function_name}: {type(function_call_arguments)}. Arguments: {function_call_arguments}")
        return f"Error: Invalid argument type for tool call: {type(function_call_arguments)}"

    for tool in self.tools:
        for function in tool.get_functions():
            if function["function"]["name"] != function_name:
                continue
            if not isinstance(function_args, dict):
                # Valid JSON that is not an object (list, number, ...)
                # cannot be expanded into keyword arguments.
                logging.error(f"Internal error: function_args not a dict for {function_name} before execution. Args: {function_args}")
                return f"Internal error preparing arguments for tool {function_name}."
            try:
                return tool.execute(function_name, **function_args)
            except Exception as e:
                logging.error(f"Error executing tool {function_name} with args {function_args}: {e}")
                return f"Error executing tool {function_name}: {e}"

    logging.warning(f"Tool function {function_name} not found.")
    return f"Error: Tool function {function_name} not found."
|
|
|
|
async def switch_model(self):
    """Toggle between the small and large model and report the result.

    Large switches to small, small switches to large, and an
    unrecognized active model falls back to the small one. Nothing
    changes when either model name is missing from the configuration.

    Returns:
        A user-facing status string.
    """
    config = self.model_config
    small_name = config["small_model_name"]
    large_name = config["large_model_name"]

    if not small_name or not large_name:
        logging.warning("Small or Large model names are not defined. Cannot switch model.")
        return f"Model switching not fully configured. Currently using {self.model}."

    # The large check comes first, as in the original ordering.
    if self.model == large_name:
        target_size = "small"
    elif self.model == small_name:
        target_size = "large"
    else:
        logging.warning(f"Current model {self.model} is unrecognized. Switching to default small model: {small_name}.")
        target_size = "small"

    self._configure_model_and_tokens(
        config[f"{target_size}_model_name"],
        config[f"{target_size}_model_max_tokens"],
    )

    token_cap = self.max_tokens if self.max_tokens is not None else 'API default'
    return f"Switched model to {self.model}. Max tokens set to {token_cap}."
|
|
|
|
def main():
    """Command-line entry point: build the bot and run a messenger helper.

    Model settings are read from environment variables prefixed with the
    upper-cased ``--config`` value (``<PREFIX>_API_KEY``,
    ``<PREFIX>_API_BASE_URL``, ``<PREFIX>_SMALL_MODEL``, ...). The helper
    class is looked up in the module ``<messenger>_helper`` as
    ``<Messenger>Helper`` and then ``<MESSENGER>Helper``.

    The system prompt path is ``--persona`` when given, otherwise
    ``<PREFIX>_SMALL_MODEL_SYSTEM_PROMPT_PATH``.

    Errors during setup or while running are logged and make the
    function return; they are not re-raised. Argument-parsing errors
    still exit through argparse.
    """
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    try:
        parser = argparse.ArgumentParser(description='OpenAI Compatible Inference Bot')
        parser.add_argument('--config', type=str, help='Configuration Prepend (i.e. gemini, openai, etc)', default="Telegram")
        parser.add_argument('--messenger', type=str, help='Messenger type (i.e. telegram)', required=True)
        parser.add_argument('--persona', type=str, help='Path to system prompt file', required=False)
        parser.add_argument('--tools', nargs='+', help='List of allowed function tags', required=False)
        parser.add_argument('--use-large-model', action='store_true', help='Use the large model instead of the small model')
        # Add these to launch.json arguments to limit the available toolset: "--tools", "read", "communicate"
        args = parser.parse_args()

        if args.persona:
            logging.info(f"Using custom persona from: {args.persona}")

        allowed_function_tags = args.tools or None
        messenger = args.messenger
        prefix = (args.config or "").upper()

        def env(suffix):
            # Without a prefix there is nothing to look up.
            return os.environ.get(f"{prefix}_{suffix}") if prefix else None

        # The explicit --persona flag wins over the environment; it used
        # to be overwritten by the env lookup unconditionally.
        system_prompt_path = args.persona or env("SMALL_MODEL_SYSTEM_PROMPT_PATH")

        bot = OpenAICompatibleInferenceBot(
            api_key=env("API_KEY"),
            # Pass None, not "", for an unset URL so the client falls back
            # to its default endpoint.
            base_url=env("API_BASE_URL") or None,
            small_model_name=env("SMALL_MODEL"),
            small_model_max_tokens=env("SMALL_MODEL_MAX_TOKENS"),
            large_model_name=env("LARGE_MODEL"),
            large_model_max_tokens=env("LARGE_MODEL_MAX_TOKENS"),
            system_prompt_path=system_prompt_path,
            allowed_function_tags=allowed_function_tags,
            use_large_model=args.use_large_model,
        )

        helper_module = importlib.import_module(f'{messenger.lower()}_helper')
        candidate_names = (f"{messenger.capitalize()}Helper", f"{messenger.upper()}Helper")
        for messenger_helper_class_name in candidate_names:
            if hasattr(helper_module, messenger_helper_class_name):
                break
        else:
            raise ValueError(f"Messenger helper class {messenger_helper_class_name} not found in {helper_module.__name__}.")

        helper_class = getattr(helper_module, messenger_helper_class_name)
        helper = helper_class(bot)
        helper.run()
    except ValueError as e:
        logging.error(f"FATAL: {e}")
        return
    except Exception as e:  # Top-level boundary: log with traceback, then stop.
        logging.exception(f"An unexpected error occurred during bot initialization: {e}")
        return
|
|
|
|
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|