2025-06-03 13:04:42 -05:00
import importlib
2025-06-02 14:56:23 -05:00
import json
import os
import logging
2025-06-03 13:04:42 -05:00
import inspect
2025-08-13 14:25:13 -05:00
import re
2025-06-02 14:56:23 -05:00
from abc import abstractmethod
2025-06-03 13:04:42 -05:00
from openai import OpenAI
from tools . base_tool import BaseTool
from telegram_helper import TelegramHelper
import argparse
from inference_bot import InferenceBot
2025-06-06 14:25:15 -05:00
import tiktoken # Added this import
2025-06-02 16:43:39 -05:00
2025-06-03 13:04:42 -05:00
class OpenAICompatibleInferenceBot ( InferenceBot ) :
2025-06-02 16:43:39 -05:00
def __init__(
    self,
    api_key: str | None = None,
    base_url: str | None = None,
    small_model_name: str | None = None,
    small_model_max_tokens: str | None = None,
    large_model_name: str | None = None,
    large_model_max_tokens: str | None = None,
    allowed_function_tags: list[str] | None = None,
    system_prompt_path: str | None = None,
    use_large_model: bool = False
):
    """Build the bot: load prompt and tools, create the API client, pick a model.

    Args:
        api_key: Key handed to the ``OpenAI`` client.
        base_url: Endpoint of the OpenAI-compatible server. ``None`` or an
            empty string leaves the choice of endpoint to the client.
        small_model_name: Model used unless ``use_large_model`` is set.
        small_model_max_tokens: Completion cap for the small model, as text;
            parsed by ``_configure_model_and_tokens``.
        large_model_name: Model used when ``use_large_model`` is set.
        large_model_max_tokens: Completion cap for the large model, as text.
        allowed_function_tags: When given, only functions carrying one of
            these tags are offered to the model; an empty list means no filter.
        system_prompt_path: File to read the system prompt from.
        use_large_model: Start on the large model instead of the small one.

    Raises:
        ValueError: If ``_SMALL_MODEL_MAX_INFERENCE_TOKENS`` or
            ``_LARGE_MODEL_MAX_INFERENCE_TOKENS`` is set to a non-integer.
    """
    self.model_config = {
        "small_model_name": small_model_name,
        "small_model_max_tokens": small_model_max_tokens,
        "large_model_name": large_model_name,
        "large_model_max_tokens": large_model_max_tokens,
    }
    # An empty tag list is treated the same as "no filtering".
    self.allowed_function_tags = allowed_function_tags or None
    self.conversation_history = {}
    self._processing_status = {}
    self.system_prompt_path = system_prompt_path
    self.system_prompt = self.load_system_prompt(file_path=system_prompt_path)
    self.tools, self.functions = self.load_functions()

    # Callers may pass "" for an unset endpoint (e.g. an environment lookup
    # with an empty default). Forward that as None so the client uses its own
    # default endpoint, which is what the log line below already reports.
    self.client = OpenAI(api_key=api_key, base_url=base_url or None)
    log_msg = f"Initialized OpenAI compatible client. Target URL: {base_url if base_url else 'OpenAI default'}."
    logging.info(log_msg)

    # Load inference token limits (defaults: small=16k, large=32k)
    self.small_model_max_inference_tokens = int(os.getenv("_SMALL_MODEL_MAX_INFERENCE_TOKENS", "16384"))
    self.large_model_max_inference_tokens = int(os.getenv("_LARGE_MODEL_MAX_INFERENCE_TOKENS", "32768"))

    # Configure the actual model name and max_tokens for API calls
    size = "large" if use_large_model else "small"
    self._configure_model_and_tokens(
        self.model_config[f"{size}_model_name"],
        self.model_config[f"{size}_model_max_tokens"],
    )
2025-06-06 14:25:15 -05:00
2025-06-03 13:04:42 -05:00
@property
def processing_status(self):
    """Expose the mapping of user id to in-flight processing details."""
    return self._processing_status
def clear_conversation_history(self, user_id):
    """Drop the stored conversation for ``user_id`` and reset the loaded tools."""
    # pop() with a default is a no-op for users without any history.
    self.conversation_history.pop(user_id, None)
    # NOTE(review): the flattened source does not show whether the tool reset
    # was nested under the history check; it is treated as unconditional here.
    for loaded_tool in self.tools:
        loaded_tool.clear()
2025-06-02 16:43:39 -05:00
2025-06-03 13:04:42 -05:00
def _configure_model_and_tokens ( self , model_name : str | None , max_tokens_str : str | None ) :
self . model = model_name
2025-06-02 14:56:23 -05:00
try :
2025-06-02 16:43:39 -05:00
if max_tokens_str and max_tokens_str . lower ( ) not in [ " none " , " " , " null " ] :
self . max_tokens = int ( max_tokens_str )
else :
2025-06-06 14:25:15 -05:00
self . max_tokens = None
2025-06-02 14:56:23 -05:00
except ValueError :
2025-06-03 13:04:42 -05:00
logging . warning ( f " Invalid value for max_tokens: { max_tokens_str } . Using API default (None) " )
2025-06-06 14:25:15 -05:00
self . max_tokens = None
2025-06-02 16:43:39 -05:00
logging . info ( f " Configured to use model: { self . model } with max_tokens: { self . max_tokens if self . max_tokens is not None else ' API default ' } " )
2025-06-02 14:56:23 -05:00
def get_llm_description(self) -> str:
    """Summarise the client class, active model and completion-token cap."""
    client_kind = type(self.client).__name__
    limit_label = 'API default' if self.max_tokens is None else self.max_tokens
    return f"Client: {client_kind}, LLM: {self.model}, Max Tokens: {limit_label}"
2025-06-02 14:56:23 -05:00
2025-08-13 14:25:13 -05:00
def _encoding_for_model(self, model: str | None):
    """Return the tiktoken encoding for ``model``, or ``cl100k_base`` as fallback.

    The fallback is used when no model name is given and when tiktoken raises
    ``KeyError`` for the name.
    """
    try:
        if model:
            return tiktoken.encoding_for_model(model)
        return tiktoken.get_encoding("cl100k_base")
    except KeyError:
        logging.warning(f"Warning: model {model} not found. Using cl100k_base encoding.")
        return tiktoken.get_encoding("cl100k_base")
2025-06-06 14:25:15 -05:00
2025-08-13 14:25:13 -05:00
def _normalize_messages(self, messages):
    """Convert a mixed message list into plain dicts.

    Dict messages are shallow-copied with only the keys ``role``, ``content``,
    ``name``, ``tool_call_id`` and ``tool_calls`` kept. Any other object is
    read through attribute access (``role``, ``content``, ``name``,
    ``tool_calls``) and rebuilt as a dict; ``name`` and ``tool_calls`` are
    included only when they are non-empty.
    """
    kept_keys = {"role", "content", "name", "tool_call_id", "tool_calls"}
    placeholder_call = {"id": None, "type": "function", "function": {"name": "unknown", "arguments": "{}"}}

    result = []
    for message in messages:
        if isinstance(message, dict):
            result.append({key: value for key, value in message.items() if key in kept_keys})
            continue

        # Attribute-style message (likely an SDK message object).
        entry = {"role": getattr(message, "role", None), "content": getattr(message, "content", None)}
        sender = getattr(message, "name", None)

        converted_calls = []
        for call in getattr(message, "tool_calls", None) or []:
            try:
                fn_part = getattr(call, "function", None)
                converted_calls.append({
                    "id": getattr(call, "id", None),
                    "type": getattr(call, "type", "function"),
                    "function": {
                        "name": getattr(fn_part, "name", None),
                        "arguments": getattr(fn_part, "arguments", "{}"),
                    },
                })
            except Exception:
                # Best effort: keep the slot so the call count stays the same.
                converted_calls.append(dict(placeholder_call, function=dict(placeholder_call["function"])))

        if sender:
            entry["name"] = sender
        if converted_calls:
            entry["tool_calls"] = converted_calls
        result.append(entry)
    return result
def _estimate_tokens(self, messages):
    """Approximate the prompt size of ``messages`` in tokens.

    Each message costs a flat 4 tokens. Dict messages additionally cost the
    encoded length of their string ``role``, ``name`` and ``content`` values
    and of every tool call's function name and argument string. A final 2
    tokens are added once for the whole list.
    """
    encoder = self._encoding_for_model(self.model)

    def cost(text):
        # Only strings are encoded; None and structured content count as zero.
        return len(encoder.encode(text)) if isinstance(text, str) else 0

    total = 0
    for message in messages:
        total += 4  # per-message overhead
        if not isinstance(message, dict):
            continue
        total += sum(cost(message.get(field)) for field in ("role", "name", "content"))

        calls = message.get("tool_calls")
        if not (calls and isinstance(calls, list)):
            continue
        for call in calls:
            fn_part = call.get("function", {}) if isinstance(call, dict) else {}
            total += cost(fn_part.get("name")) + cost(fn_part.get("arguments"))
    return total + 2  # assistant priming
2025-08-13 14:25:13 -05:00
def _get_inference_limit(self):
    """Return the inference token limit of the active model, or ``None``.

    The active model is matched against the configured small model first, then
    the large one. An unrecognised model logs a warning and yields ``None``.
    """
    active = self.model
    if active == self.model_config["small_model_name"]:
        return self.small_model_max_inference_tokens
    if active == self.model_config["large_model_name"]:
        return self.large_model_max_inference_tokens
    logging.warning(f"Could not determine inference token limit for model: {self.model}. Proceeding without check.")
    return None
def _summarize_tool_args(self, args_str: str, max_chars: int = 512) -> str:
    """Shorten a tool-call argument string for storage in the history.

    If ``args_str`` is a JSON object, its keys are kept and the values reduced:
    strings longer than 160 characters keep their first 120 characters plus a
    length note, lists and dicts become a ``<type size=n>`` marker, and other
    values are kept. Anything else is handled as plain text. In both cases a
    result longer than ``max_chars`` is cut and suffixed with
    ``... [summarized]``.

    Args:
        args_str: Raw argument text of a tool call.
        max_chars: Length above which the result is cut.

    Returns:
        The summarised text. It is not guaranteed to be valid JSON once cut.
    """
    def clip(text):
        if len(text) <= max_chars:
            return text
        # Reserve room for the marker; clamp so a small max_chars cannot turn
        # into a negative slice that keeps nearly the whole text.
        return text[:max(max_chars - 20, 0)] + "... [summarized]"

    try:
        parsed = json.loads(args_str)
    except (TypeError, ValueError, RecursionError):
        # Not JSON (or not parseable here): fall back to plain truncation.
        parsed = None

    if not isinstance(parsed, dict):
        return clip(args_str)

    summary = {}
    for key, value in parsed.items():
        if isinstance(value, str) and len(value) > 160:
            summary[key] = value[:120] + f"... [len={len(value)}]"
        elif isinstance(value, (list, dict)):
            # structural summary only
            summary[key] = f"<{type(value).__name__} size={len(value)}>"
        else:
            summary[key] = value
    return clip(json.dumps(summary, ensure_ascii=False))
def _summarize_tool_call_requests_in_messages(self, messages):
    """Shrink oversized tool-call arguments inside ``messages`` in place.

    Only dict messages with a non-empty ``tool_calls`` list are inspected, and
    only argument strings longer than 700 characters are replaced (via
    ``_summarize_tool_args``). Affected tool-call dicts are copied before being
    changed; the message's ``tool_calls`` entry is reassigned only when at
    least one of its own calls was summarised.

    Returns:
        ``True`` if any message was modified, otherwise ``False``.
    """
    any_changed = False
    for message in messages:
        if not (isinstance(message, dict) and message.get("tool_calls")):
            continue

        rewritten_calls = []
        message_changed = False  # tracked per message so untouched ones stay as they are
        for call in message["tool_calls"]:
            fn_part = call.get("function") if isinstance(call, dict) else None
            raw_args = fn_part.get("arguments") if isinstance(fn_part, dict) else None
            if isinstance(raw_args, str) and len(raw_args) > 700:
                # summarize long request arguments only
                call = {**call, "function": {**fn_part, "arguments": self._summarize_tool_args(raw_args)}}
                message_changed = True
            rewritten_calls.append(call)

        if message_changed:
            message["tool_calls"] = rewritten_calls
            any_changed = True
    return any_changed
def _elide_redundant_code_blocks(self, messages):
    """Strip code from every assistant message except the most recent one.

    Fenced blocks are replaced with ``[code block omitted]`` and runs of lines
    that start with four or more whitespace characters with
    ``[long block omitted]``. Messages are edited in place.

    Returns:
        ``True`` if any message content was changed, otherwise ``False``.
    """
    # Positions of assistant messages that actually carry content.
    positions = [
        idx for idx, msg in enumerate(messages)
        if isinstance(msg, dict) and msg.get("role") == "assistant" and msg.get("content")
    ]

    any_changed = False
    # Slicing off the last position protects the latest assistant message and
    # leaves nothing to do when there are fewer than two of them.
    for idx in positions[:-1]:
        message = messages[idx]
        before = message.get("content")
        if not isinstance(before, str):
            continue
        # NOTE(review): whitespace inside this guard's second literal was lost
        # in the flattened source; a bare newline is assumed - confirm.
        if "```" not in before and "\n" not in before:
            continue
        # Replace code blocks fenced by ``` with succinct markers
        after = re.sub(r"```[\s\S]*?```", "[code block omitted]", before)
        # Also collapse long indented blocks
        after = re.sub(r"(?:\n\s{4,}.+)+", "\n[long block omitted]", after)
        if after != before:
            message["content"] = after
            any_changed = True
    return any_changed
2025-06-02 14:56:23 -05:00
def get_chat_response(self, messages):
    """Send ``messages`` to the chat completions endpoint and return the response.

    Functions from ``self.functions`` are offered as tools after removing the
    private ``_tags`` key. When ``self.allowed_function_tags`` is set, only
    functions sharing at least one tag with it are offered.

    Optional request fields are left out rather than sent empty: ``tools`` and
    ``tool_choice`` are included only when at least one tool is offered, and
    ``max_tokens`` only when a cap is configured.

    Raises:
        ValueError: If no client has been created.
        Exception: Any error raised by the client call is logged and re-raised.
    """
    if not self.client:
        logging.error("OpenAI client not initialized before get_chat_response.")
        raise ValueError("OpenAI client not initialized.")
    try:
        allowed_tags = getattr(self, "allowed_function_tags", None)
        offered_tools = [
            {key: value for key, value in func.items() if key != "_tags"}
            for func in (getattr(self, "functions", None) or [])
            if allowed_tags is None or any(tag in allowed_tags for tag in func.get("_tags", []))
        ]

        request = {"model": self.model, "messages": messages}
        if offered_tools:
            # An empty or null tools array is not a valid request value, so
            # both fields are added only when there is something to offer.
            request["tools"] = offered_tools
            request["tool_choice"] = "auto"
        if self.max_tokens is not None:
            request["max_tokens"] = self.max_tokens

        return self.client.chat.completions.create(**request)
    except Exception as e:
        logging.error(f"API call to model {self.model} failed: {e}")
        raise
2025-06-03 14:04:27 -05:00
def get_bot_status(self):
    """Report the active model and where the system prompt comes from.

    The prompt source is the configured path, else the ``SYSTEM_PROMPT_PATH``
    environment variable, else a note that the default prompt is in use.
    """
    active_model = getattr(self, 'model', None)
    prompt_source = (
        self.system_prompt_path
        or os.getenv("SYSTEM_PROMPT_PATH")
        or "(default prompt in use)"
    )
    return f"Current model: {active_model}\nSystem prompt path: {prompt_source}"
2025-06-02 14:56:23 -05:00
async def handle_message(self, user_id, user_message):
    """Run one user turn, executing requested tools, and return the reply text.

    A new or empty conversation is seeded with the system prompt (when one is
    set). The user message is appended, the model is queried, and for as long
    as the model answers with tool calls (at most ``MAX_TOOL_ITERATIONS``
    rounds) each call is executed through ``call_tool`` and its result is fed
    back. The updated message list is stored as the user's history.

    ``get_chat_response`` and ``call_tool`` are called without ``await``.

    Returns:
        The final assistant text, an error string when the model returned no
        usable choice, or a fixed notice when the last assistant message has
        no text.
    """
    MAX_TOOL_ITERATIONS = 200

    if user_id not in self.conversation_history or not self.conversation_history[user_id]:
        self.conversation_history[user_id] = []
        if self.system_prompt:
            self.conversation_history[user_id].append({"role": "system", "content": self.system_prompt})
    self.conversation_history[user_id].append({"role": "user", "content": user_message})

    # Work on a copy; the history is replaced once the turn has an outcome.
    messages = list(self.conversation_history[user_id])

    response = self.get_chat_response(messages)
    if not (response.choices and response.choices[0].message):
        logging.error("No valid response choice message from LLM.")
        self.conversation_history[user_id] = messages
        return "Error: Could not get a valid response from the LLM."
    assistant_message = response.choices[0].message
    messages.append(assistant_message)
    pending_calls = list(assistant_message.tool_calls) if assistant_message.tool_calls else []

    tool_use_count = 0
    while pending_calls and tool_use_count < MAX_TOOL_ITERATIONS:
        # Built once per round instead of once per tool call.
        known_names = {func["function"]["name"] for func in self.functions}
        tool_results_for_model = []
        for tool_call in pending_calls:
            function_name = tool_call.function.name
            logging.info(f"Attempting to call tool: {function_name} with args: [request summarized if large]")
            if function_name not in known_names:
                logging.warning(f"Tool function {function_name} not found in available functions.")
                tool_response_content = f"Error: Tool function {function_name} not found."
            else:
                try:
                    tool_response_content = self.call_tool(function_name, tool_call.function.arguments)
                    if not isinstance(tool_response_content, str):
                        tool_response_content = json.dumps(tool_response_content)
                except Exception as e:
                    logging.error(f"Error calling tool {function_name}: {e}")
                    tool_response_content = f"Error executing tool {function_name}: {str(e)}"
            tool_results_for_model.append({
                "role": "tool",
                "tool_call_id": tool_call.id,
                "name": function_name,
                "content": tool_response_content,
            })
        messages.extend(tool_results_for_model)

        response = self.get_chat_response(messages)
        if not (response.choices and response.choices[0].message):
            logging.error("No valid response choice message from LLM after tool call.")
            self.conversation_history[user_id] = messages
            return "Error: Could not get a valid response from the LLM after tool call."
        assistant_message = response.choices[0].message
        messages.append(assistant_message)
        pending_calls = list(assistant_message.tool_calls) if assistant_message.tool_calls else []
        tool_use_count += 1

    # The reply is taken from the last assistant message, captured before any
    # placeholder tool results are appended below.
    final_assistant_message = messages[-1]

    if pending_calls:
        # Only reachable when the iteration cap stopped the loop.
        logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.")
        # Give every unanswered tool call a reply so the stored history does
        # not end with an assistant message whose tool calls have no results,
        # which would make the next request for this user invalid.
        for tool_call in pending_calls:
            messages.append({
                "role": "tool",
                "tool_call_id": tool_call.id,
                "name": tool_call.function.name,
                "content": "Error: Tool call skipped because the maximum number of tool iterations was reached.",
            })

    self.conversation_history[user_id] = messages

    if isinstance(final_assistant_message, dict):
        return final_assistant_message.get("content")
    if (
        getattr(final_assistant_message, "role", None) == "assistant"
        and getattr(final_assistant_message, "content", None) is not None
    ):
        return final_assistant_message.content
    return "Assistant did not provide a textual response."
2025-06-02 14:56:23 -05:00
async def start(self):
    """Log that the bot has started; nothing else is initialised here."""
    bot_kind = self.__class__.__name__
    logging.info(f"{bot_kind} (Model: {self.model}) started.")
2025-06-02 14:56:23 -05:00
async def abort_processing(self, user_id):
    """Clear the processing marker of ``user_id`` and return a status text.

    Only the marker is removed here; the returned text tells the user whether
    there was anything to abort.
    """
    if user_id not in self.processing_status:
        return "No active processing found to abort. If you wish, /clear the conversation history."

    self.clear_processing_status(user_id)
    logging.info(f"Processing aborted for user {user_id}.")
    return "Processing aborted. You can send a new message or /clear the conversation."
2025-06-03 13:04:42 -05:00
def load_functions(self):
    """Discover tool classes in the ``tools`` directory and collect their functions.

    Every ``.py`` file next to this module under ``tools`` (except
    ``__init__.py`` and ``base_tool.py``) is imported as ``tools.<name>``.
    Each ``BaseTool`` subclass found in a module is instantiated without
    arguments, once per class even if several modules expose it. Import and
    instantiation errors are logged and skipped.

    Returns:
        A ``(tools, functions)`` pair: the tool instances and the concatenated
        results of their ``get_functions()`` calls. Both are empty when the
        directory does not exist.
    """
    tools = []
    functions = []
    tools_dir = os.path.join(os.path.dirname(__file__), 'tools')
    if not os.path.exists(tools_dir):
        logging.warning(f"Tools directory not found: {tools_dir}")
        return [], []

    # getmembers() also lists classes a module merely imported, so remember
    # which classes were already handled to avoid duplicate tool instances.
    seen_classes = set()
    # Sorted so the tool order does not depend on directory listing order.
    for filename in sorted(os.listdir(tools_dir)):
        if not filename.endswith('.py') or filename in ('__init__.py', 'base_tool.py'):
            continue
        module_name = f'tools.{filename[:-3]}'
        try:
            module = importlib.import_module(module_name)
            for name, obj in inspect.getmembers(module):
                if not (inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool):
                    continue
                if obj in seen_classes:
                    continue
                seen_classes.add(obj)
                try:
                    tools.append(obj())  # This instantiation might be an issue for tools needing config
                except Exception as e:
                    logging.error(f"Error instantiating tool {name} from {filename}: {e}")
        except Exception as e:
            logging.error(f"Error importing module {module_name}: {e}")

    for tool in tools:
        functions.extend(tool.get_functions())
    return tools, functions
def load_system_prompt ( self , direct_content : str | None = None , file_path : str | None = None ) - > str :
default_prompt = " You are a helpful AI assistant. "
if direct_content :
logging . info ( " Using direct content for system prompt. " )
return direct_content . strip ( )
prompt_path_to_try = file_path or os . getenv ( " SYSTEM_PROMPT_PATH " )
if prompt_path_to_try :
if os . path . isfile ( prompt_path_to_try ) :
try :
with open ( prompt_path_to_try , " r " , encoding = " utf-8 " ) as file :
content = file . read ( ) . strip ( )
logging . info ( f " Successfully loaded system prompt from { prompt_path_to_try } . " )
return content
except IOError as e :
logging . warning ( f " Could not read system prompt file { prompt_path_to_try } : { e } . Using default. " )
return default_prompt
else :
logging . warning ( f " System prompt file { prompt_path_to_try } not found. Using default system prompt. " )
return default_prompt
else :
logging . info ( " No system prompt path provided (argument or ENV) or direct content. Using default system prompt. " )
return default_prompt
def set_processing_status(self, user_id: int, message_id: int):
    """Mark ``user_id`` as being processed for the given ``message_id``."""
    status_entry = {"processing": True, "message_id": message_id}
    self.processing_status[user_id] = status_entry
def clear_processing_status(self, user_id: int):
    """Remove the processing marker of ``user_id``; unknown users are ignored."""
    self.processing_status.pop(user_id, None)
def call_tool(self, function_call_name, function_call_arguments):
    """Execute the named tool function and return its result.

    Args:
        function_call_name: Name of the function to run.
        function_call_arguments: A dict of keyword arguments, a JSON string
            encoding one, or ``None``. ``None`` and an empty or
            whitespace-only string both mean "no arguments".

    Returns:
        Whatever the tool's ``execute`` returns. Every failure (malformed
        JSON, unsupported argument type, arguments that are not a JSON
        object, unknown function, exception inside the tool) is reported as
        an error string instead of being raised.
    """
    function_name = function_call_name

    if function_call_arguments is None:
        function_args = {}
    elif isinstance(function_call_arguments, dict):
        function_args = function_call_arguments
    elif isinstance(function_call_arguments, str):
        if not function_call_arguments.strip():
            # A call without parameters may arrive as an empty string, which
            # is not valid JSON; treat it as an empty argument object.
            function_args = {}
        else:
            try:
                function_args = json.loads(function_call_arguments)
            except json.JSONDecodeError as e:
                logging.error(f"Error decoding function call arguments (string) for {function_call_name}: {e}. Arguments: {function_call_arguments}")
                return f"Error: Malformed arguments for tool call: {e}"
    else:
        logging.error(f"Unexpected type for function_call_arguments for {function_name}: {type(function_call_arguments)}. Arguments: {function_call_arguments}")
        return f"Error: Invalid argument type for tool call: {type(function_call_arguments)}"

    for tool in self.tools:
        for function in tool.get_functions():
            if function["function"]["name"] != function_name:
                continue
            # JSON text can decode to a list, number, etc.; only an object
            # can be expanded into keyword arguments.
            if not isinstance(function_args, dict):
                logging.error(f"Internal error: function_args not a dict for {function_name} before execution. Args: {function_args}")
                return f"Internal error preparing arguments for tool {function_name}."
            try:
                return tool.execute(function_name, **function_args)
            except Exception as e:
                logging.error(f"Error executing tool {function_name} with args {function_args}: {e}")
                return f"Error executing tool {function_name}: {e}"

    logging.warning(f"Tool function {function_name} not found.")
    return f"Error: Tool function {function_name} not found."
2025-06-02 14:56:23 -05:00
async def switch_model(self):
    """Toggle between the configured small and large model.

    The large model switches to the small one and vice versa; an unrecognised
    active model switches to the small one. Nothing changes when either model
    name is missing from the configuration.

    Returns:
        A status sentence naming the model now in use.
    """
    config = self.model_config
    small_name = config["small_model_name"]
    large_name = config["large_model_name"]
    if not small_name or not large_name:
        logging.warning("Small or Large model names are not defined. Cannot switch model.")
        return f"Model switching not fully configured. Currently using {self.model}."

    # The large check comes first, so it wins if both names are identical.
    if self.model == large_name:
        target_size = "small"
    elif self.model == small_name:
        target_size = "large"
    else:
        logging.warning(f"Current model {self.model} is unrecognized. Switching to default small model: {self.model_config['small_model_name']}.")
        target_size = "small"

    self._configure_model_and_tokens(
        config[f"{target_size}_model_name"],
        config[f"{target_size}_model_max_tokens"],
    )
    limit_label = 'API default' if self.max_tokens is None else self.max_tokens
    return f"Switched model to {self.model}. Max tokens set to {limit_label}."
2025-06-03 13:04:42 -05:00
def main():
    """Command-line entry point: build the bot and hand it to a messenger helper.

    Connection and model settings are read from environment variables named
    ``<CONFIG>_API_KEY``, ``<CONFIG>_API_BASE_URL``, ``<CONFIG>_SMALL_MODEL``,
    ``<CONFIG>_LARGE_MODEL`` and the matching ``_MAX_TOKENS`` variables, where
    ``<CONFIG>`` is the upper-cased ``--config`` value. The helper class is
    looked up in the module ``<messenger>_helper``. All errors are logged and
    end the function without re-raising.
    """
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    bot = None
    try:
        parser = argparse.ArgumentParser(description='OpenAI Compatible Inference Bot')
        parser.add_argument('--config', type=str, help='Configuration Prepend (i.e. gemini, openai, etc)', default="Telegram")
        parser.add_argument('--messenger', type=str, help='Messenger type (i.e. telegram)', required=True)
        parser.add_argument('--persona', type=str, help='Path to system prompt file', required=False)
        parser.add_argument('--tools', nargs='+', help='List of allowed function tags', required=False)
        parser.add_argument('--use-large-model', action='store_true', help='Use the large model instead of the small model')
        # Add these to launch.json arguments if you want to limit the toolset available: "--tools", "read", "communicate"
        args = parser.parse_args()

        if args.persona:
            logging.info(f"Using custom persona from: {args.persona}")

        # Empty values are normalised to None.
        persona_path = args.persona or None
        tag_filter = args.tools or None
        config_prepend = args.config or None
        messenger = args.messenger or None

        # Initialize model and max tokens based on the config prepend
        # NOTE(review): the flattened source does not show where this branch
        # ends; only the bot construction is treated as conditional - confirm.
        if config_prepend:
            env_prefix = config_prepend.upper()

            def setting(suffix, *fallback):
                # Look up one "<PREFIX>_<SUFFIX>" environment variable.
                return os.environ.get(f"{env_prefix}_{suffix}", *fallback)

            bot = OpenAICompatibleInferenceBot(
                api_key=setting("API_KEY"),
                base_url=setting("API_BASE_URL", ""),
                small_model_name=setting("SMALL_MODEL"),
                small_model_max_tokens=setting("SMALL_MODEL_MAX_TOKENS"),
                large_model_name=setting("LARGE_MODEL"),
                large_model_max_tokens=setting("LARGE_MODEL_MAX_TOKENS"),
                system_prompt_path=persona_path,
                allowed_function_tags=tag_filter,
                use_large_model=args.use_large_model,
            )

        helper_module = importlib.import_module(f'{messenger.lower()}_helper')
        # Try "<Name>Helper" first, then the all-caps spelling "<NAME>Helper".
        helper_class_name = f"{messenger.capitalize()}Helper"
        if not hasattr(helper_module, helper_class_name):
            helper_class_name = f"{messenger.upper()}Helper"
            if not hasattr(helper_module, helper_class_name):
                raise ValueError(f"Messenger helper class {helper_class_name} not found in {helper_module.__name__}.")

        helper = getattr(helper_module, helper_class_name)(bot)
        helper.run()
    except ValueError as e:
        logging.error(f"FATAL: {e}")
        return
    except Exception as e:  # Catch any other init errors
        logging.error(f"An unexpected error occurred during bot initialization: {e}")
        return
# Start the bot only when this file is run as a script, not when imported.
if __name__ == '__main__':
    main()