2025-06-03 13:04:42 -05:00
import importlib
2025-06-02 14:56:23 -05:00
import json
import os
import logging
2025-06-03 13:04:42 -05:00
import inspect
2025-06-02 14:56:23 -05:00
from abc import abstractmethod
2025-06-03 13:04:42 -05:00
from openai import OpenAI
from tools . base_tool import BaseTool
from telegram_helper import TelegramHelper
import argparse
from inference_bot import InferenceBot
2025-06-06 14:25:15 -05:00
import tiktoken # Added this import
2025-06-02 16:43:39 -05:00
2025-06-03 13:04:42 -05:00
class OpenAICompatibleInferenceBot(InferenceBot):
    """Chat bot that talks to an OpenAI-compatible chat-completions endpoint.

    The bot keeps one conversation history per user, can switch between a
    "small" and a "large" model, and exposes tools discovered in the sibling
    ``tools`` directory to the model via function calling.

    NOTE(review): the inference-token limits are read from the environment
    variables ``_SMALL_MODEL_MAX_INFERENCE_TOKENS`` and
    ``_LARGE_MODEL_MAX_INFERENCE_TOKENS`` (leading underscore, no config
    prefix). That looks like a missing prefix; confirm before changing.
    """

    # Upper bound on assistant -> tools -> assistant round trips per message.
    MAX_TOOL_ITERATIONS = 200

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        small_model_name: str | None = None,
        small_model_max_tokens: str | None = None,
        large_model_name: str | None = None,
        large_model_max_tokens: str | None = None,
        allowed_function_tags: list[str] | None = None,
        system_prompt_path: str | None = None,
        use_large_model: bool = False,
    ):
        """Create the client, load the system prompt and tools, pick a model.

        Args:
            api_key: API key handed to the ``OpenAI`` client.
            base_url: Endpoint URL; empty or ``None`` means the client default.
            small_model_name: Model used unless ``use_large_model`` is set.
            small_model_max_tokens: Completion cap for the small model, as text.
            large_model_name: Model used when ``use_large_model`` is set.
            large_model_max_tokens: Completion cap for the large model, as text.
            allowed_function_tags: If given, only functions carrying at least
                one of these ``_tags`` are offered to (and callable by) the model.
            system_prompt_path: File to read the system prompt from.
            use_large_model: Start on the large model instead of the small one.

        Raises:
            ValueError: If an inference-token-limit environment variable is
                set to something that is not an integer.
        """
        self.model_config = {
            "small_model_name": small_model_name,
            "small_model_max_tokens": small_model_max_tokens,
            "large_model_name": large_model_name,
            "large_model_max_tokens": large_model_max_tokens,
        }
        self.allowed_function_tags = allowed_function_tags or None
        self.conversation_history = {}
        self._processing_status = {}
        self.system_prompt_path = system_prompt_path
        self.system_prompt = self.load_system_prompt(file_path=system_prompt_path)
        self.tools, self.functions = self.load_functions()
        # An empty string is not a usable endpoint; pass None so the client
        # falls back to its default, which is what the log line below reports.
        self.client = OpenAI(api_key=api_key, base_url=base_url or None)
        log_msg = f"Initialized OpenAI compatible client. Target URL: {base_url if base_url else 'OpenAI default'}."
        logging.info(log_msg)

        # Load inference token limits
        self.small_model_max_inference_tokens = int(os.getenv("_SMALL_MODEL_MAX_INFERENCE_TOKENS", "32768"))
        self.large_model_max_inference_tokens = int(os.getenv("_LARGE_MODEL_MAX_INFERENCE_TOKENS", "32768"))

        # Configure the actual model name and max_tokens for API calls
        size = "large" if use_large_model else "small"
        self._configure_model_and_tokens(
            self.model_config[f"{size}_model_name"],
            self.model_config[f"{size}_model_max_tokens"],
        )

    @property
    def processing_status(self):
        """Mapping of user id to the in-flight processing record."""
        return self._processing_status

    def clear_conversation_history(self, user_id):
        """Drop ``user_id``'s history and call ``clear()`` on every loaded tool."""
        self.conversation_history.pop(user_id, None)
        for tool in self.tools:
            tool.clear()

    def _configure_model_and_tokens(self, model_name: str | None, max_tokens_str: str | None):
        """Set ``self.model`` and parse ``max_tokens_str`` into ``self.max_tokens``.

        Empty values and the words "none"/"null" (any case) mean "no cap".
        Unparseable values are logged and also treated as "no cap".
        """
        self.model = model_name
        self.max_tokens = None
        if max_tokens_str:
            # str() tolerates callers that pass an int instead of text.
            text = str(max_tokens_str).strip()
            if text.lower() not in ("none", "", "null"):
                try:
                    self.max_tokens = int(text)
                except ValueError:
                    logging.warning(f"Invalid value for max_tokens: {max_tokens_str}. Using API default (None)")
        logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}")

    def get_llm_description(self) -> str:
        """Return a one-line summary of the client type, model and token cap."""
        client_type = type(self.client).__name__
        return f"Client: {client_type}, LLM: {self.model}, Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}"

    def _count_tokens(self, messages, model):
        """Returns the number of tokens in a list of messages.

        This is an estimate: 4 tokens of overhead per message, the encoded
        length of every string field, 1 extra for a ``name`` field, and 2 at
        the end. Unknown models fall back to the ``cl100k_base`` encoding.
        """
        try:
            encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            encoding = tiktoken.get_encoding("cl100k_base")  # Fallback for unknown models
            logging.warning(f"Warning: model {model} not found. Using cl100k_base encoding.")
        num_tokens = 0
        for message in messages:
            num_tokens += 4
            if hasattr(message, "items"):
                fields = message.items()
            else:
                # Assistant replies are stored in history as SDK message
                # objects, not dicts; count their text too so the limit check
                # does not under-count long conversations.
                fields = [
                    (attr, value)
                    for attr in ("role", "content", "name")
                    if (value := getattr(message, attr, None)) is not None
                ]
            for key, value in fields:
                if isinstance(value, str):
                    num_tokens += len(encoding.encode(value))
                if key == "name":
                    num_tokens += 1
        num_tokens += 2
        return num_tokens

    def _allowed_functions(self) -> list:
        """Return the function specs permitted by ``allowed_function_tags``."""
        functions = getattr(self, "functions", None) or []
        allowed_tags = getattr(self, "allowed_function_tags", None)
        if allowed_tags is None:
            return list(functions)
        return [
            func for func in functions
            if any(tag in allowed_tags for tag in func.get("_tags", []))
        ]

    def get_chat_response(self, messages):
        """Send ``messages`` to the configured model and return the raw response.

        Raises:
            ValueError: If the client has not been created.
            Exception: Any error from the API call is logged and re-raised.
        """
        if not self.client:
            logging.error("OpenAI client not initialized before get_chat_response.")
            raise ValueError("OpenAI client not initialized.")
        try:
            # "_tags" is internal metadata and must not be sent to the API.
            tool_specs = [
                {k: v for k, v in func.items() if k != "_tags"}
                for func in self._allowed_functions()
            ]
            request_kwargs = {"model": self.model, "messages": messages}
            # Omit optional keys instead of sending null / an empty array:
            # the API rejects an empty "tools" list, and not every compatible
            # server accepts explicit nulls.
            if tool_specs:
                request_kwargs["tools"] = tool_specs
                request_kwargs["tool_choice"] = "auto"
            if self.max_tokens is not None:
                request_kwargs["max_tokens"] = self.max_tokens
            return self.client.chat.completions.create(**request_kwargs)
        except Exception as e:
            logging.error(f"API call to model {self.model} failed: {e}")
            raise

    def get_bot_status(self):
        """
        Returns a message with the currently enabled model and the system prompt path being used.
        """
        model_name = getattr(self, "model", None)
        prompt_path = self.system_prompt_path or os.getenv("SYSTEM_PROMPT_PATH") or "(default prompt in use)"
        return f"Current model: {model_name}\nSystem prompt path: {prompt_path}"

    def _inference_token_limit(self) -> int | None:
        """Return the pre-inference token limit for the active model, if known."""
        if self.model == self.model_config["small_model_name"]:
            return self.small_model_max_inference_tokens
        if self.model == self.model_config["large_model_name"]:
            return self.large_model_max_inference_tokens
        logging.warning(f"Could not determine inference token limit for model: {self.model}. Proceeding without check.")
        return None

    def _run_tool_calls(self, tool_calls) -> list:
        """Execute each requested tool call and return the ``tool`` messages."""
        # Only functions that were actually offered to the model may run, so a
        # hallucinated call cannot bypass the allowed_function_tags filter.
        callable_names = {func["function"]["name"] for func in self._allowed_functions()}
        results = []
        for tool_call in tool_calls:
            function_name = tool_call.function.name
            function_args_str = tool_call.function.arguments
            logging.info(f"Attempting to call tool: {function_name} with args: {function_args_str}")
            if function_name not in callable_names:
                logging.warning(f"Tool function {function_name} not found in available functions.")
                content = f"Error: Tool function {function_name} not found."
            else:
                try:
                    content = self.call_tool(function_name, function_args_str)
                    if not isinstance(content, str):
                        content = json.dumps(content)
                except Exception as e:
                    logging.error(f"Error calling tool {function_name}: {e}")
                    content = f"Error executing tool {function_name}: {str(e)}"
            results.append({
                "role": "tool",
                "tool_call_id": tool_call.id,
                "name": function_name,
                "content": content,
            })
        return results

    async def handle_message(self, user_id, user_message):
        """Process one user message and return the assistant's text reply.

        Appends the message to the user's history (seeding it with the system
        prompt when the history is empty), refuses requests over the active
        model's inference token limit, then loops: ask the model, run any tool
        calls it requests, and feed the results back, up to
        ``MAX_TOOL_ITERATIONS`` rounds.

        NOTE(review): ``get_chat_response`` is a synchronous call made from
        this coroutine without ``await``.
        """
        history = self.conversation_history.get(user_id)
        if not history:
            history = []
            if self.system_prompt:
                history.append({"role": "system", "content": self.system_prompt})
            self.conversation_history[user_id] = history
        history.append({"role": "user", "content": user_message})
        messages = list(history)

        # Pre-inference token limit check
        inference_token_limit = self._inference_token_limit()
        if inference_token_limit is not None:
            token_count = self._count_tokens(messages, self.model)
            if token_count > inference_token_limit:
                logging.warning(f"Request for user {user_id} exceeds inference token limit ({token_count}/{inference_token_limit}).")
                # Do not persist this message in history as it was not processed by LLM
                if history and history[-1]["role"] == "user" and history[-1]["content"] == user_message:
                    history.pop()
                return "Request exceeds inference token limit. Please use the /clear command, or implement RAG in your application."

        response = self.get_chat_response(messages)
        if not (response.choices and response.choices[0].message):
            logging.error("No valid response choice message from LLM.")
            self.conversation_history[user_id] = messages
            return "Error: Could not get a valid response from the LLM."
        assistant_message = response.choices[0].message
        messages.append(assistant_message)

        pending_calls = list(assistant_message.tool_calls or [])
        tool_use_count = 0
        while pending_calls and tool_use_count < self.MAX_TOOL_ITERATIONS:
            messages.extend(self._run_tool_calls(pending_calls))
            response = self.get_chat_response(messages)
            if not (response.choices and response.choices[0].message):
                logging.error("No valid response choice message from LLM after tool call.")
                self.conversation_history[user_id] = messages
                return "Error: Could not get a valid response from the LLM after tool call."
            assistant_message = response.choices[0].message
            messages.append(assistant_message)
            pending_calls = list(assistant_message.tool_calls or [])
            tool_use_count += 1

        if pending_calls:
            logging.warning(f"Max tool iterations ({self.MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.")
            # Every assistant tool call must be answered by a tool message, or
            # the next request built from this history is rejected. Answer the
            # unexecuted calls with an explicit error.
            messages.extend(
                {
                    "role": "tool",
                    "tool_call_id": call.id,
                    "name": call.function.name,
                    "content": "Error: Tool call skipped because the maximum number of tool iterations was reached.",
                }
                for call in pending_calls
            )

        self.conversation_history[user_id] = messages
        if assistant_message.role == "assistant" and assistant_message.content is not None:
            return assistant_message.content
        return "Assistant did not provide a textual response."

    async def start(self):
        """Log that the bot has started."""
        logging.info(f"{self.__class__.__name__} (Model: {self.model}) started.")

    async def abort_processing(self, user_id):
        """Clear ``user_id``'s processing record, if any, and report the outcome."""
        if user_id in self.processing_status:
            self.clear_processing_status(user_id)
            logging.info(f"Processing aborted for user {user_id}.")
            return "Processing aborted. You can send a new message or /clear the conversation."
        return "No active processing found to abort. If you wish, /clear the conversation history."

    def load_functions(self):
        """Discover and instantiate tools from the ``tools`` directory.

        Every ``.py`` file next to this module in ``tools/`` (except
        ``__init__.py`` and ``base_tool.py``) is imported as ``tools.<name>``
        and each concrete ``BaseTool`` subclass found is instantiated once
        with no arguments.

        Returns:
            A ``(tools, functions)`` pair: the tool instances and the
            concatenation of their ``get_functions()`` results.
        """
        tools = []
        functions = []
        tools_dir = os.path.join(os.path.dirname(__file__), 'tools')
        if not os.path.isdir(tools_dir):
            logging.warning(f"Tools directory not found: {tools_dir}")
            return [], []

        # A tool class imported into another tool module would otherwise be
        # instantiated once per module that exposes it.
        seen_classes = set()
        for filename in sorted(os.listdir(tools_dir)):
            if not filename.endswith('.py') or filename in ('__init__.py', 'base_tool.py'):
                continue
            module_name = f'tools.{filename[:-3]}'
            try:
                module = importlib.import_module(module_name)
                candidates = inspect.getmembers(module, inspect.isclass)
            except Exception as e:
                logging.error(f"Error importing module {module_name}: {e}")
                continue
            for name, obj in candidates:
                if not issubclass(obj, BaseTool) or obj is BaseTool or obj in seen_classes:
                    continue
                seen_classes.add(obj)
                if inspect.isabstract(obj):
                    # Abstract intermediates cannot be instantiated; skip quietly.
                    continue
                try:
                    tools.append(obj())  # This instantiation might be an issue for tools needing config
                except Exception as e:
                    logging.error(f"Error instantiating tool {name} from {filename}: {e}")
        for tool in tools:
            functions.extend(tool.get_functions())
        return tools, functions

    def load_system_prompt(self, direct_content: str | None = None, file_path: str | None = None) -> str:
        """Resolve the system prompt text.

        Precedence: ``direct_content``, then the file at ``file_path``, then
        the file named by the ``SYSTEM_PROMPT_PATH`` environment variable,
        then a built-in default. A missing or unreadable file falls back to
        the default.
        """
        default_prompt = "You are a helpful AI assistant."
        if direct_content:
            logging.info("Using direct content for system prompt.")
            return direct_content.strip()
        prompt_path_to_try = file_path or os.getenv("SYSTEM_PROMPT_PATH")
        if not prompt_path_to_try:
            logging.info("No system prompt path provided (argument or ENV) or direct content. Using default system prompt.")
            return default_prompt
        if not os.path.isfile(prompt_path_to_try):
            logging.warning(f"System prompt file {prompt_path_to_try} not found. Using default system prompt.")
            return default_prompt
        try:
            with open(prompt_path_to_try, "r", encoding="utf-8") as file:
                content = file.read().strip()
        except (OSError, UnicodeDecodeError) as e:
            # UnicodeDecodeError is a ValueError, not an OSError, so it has to
            # be named explicitly for a badly-encoded file to fall back too.
            logging.warning(f"Could not read system prompt file {prompt_path_to_try}: {e}. Using default.")
            return default_prompt
        logging.info(f"Successfully loaded system prompt from {prompt_path_to_try}.")
        return content

    def set_processing_status(self, user_id: int, message_id: int):
        """Record that ``message_id`` is being processed for ``user_id``."""
        self.processing_status[user_id] = {"processing": True, "message_id": message_id}

    def clear_processing_status(self, user_id: int):
        """Remove ``user_id``'s processing record if present."""
        self.processing_status.pop(user_id, None)

    def call_tool(self, function_call_name, function_call_arguments):
        """Run the named tool function and return its result.

        Args:
            function_call_name: Name of the function to execute.
            function_call_arguments: A dict, a JSON object string, or
                ``None``. Empty strings and JSON ``null`` mean "no arguments".

        Returns:
            Whatever the tool's ``execute`` returns, or an error string when
            the arguments are malformed, the function is unknown, or the tool
            raises.
        """
        function_name = function_call_name
        if function_call_arguments is None:
            function_args = {}
        elif isinstance(function_call_arguments, dict):
            function_args = function_call_arguments
        elif isinstance(function_call_arguments, str):
            if not function_call_arguments.strip():
                # Some models send "" for functions that take no parameters.
                function_args = {}
            else:
                try:
                    function_args = json.loads(function_call_arguments)
                except json.JSONDecodeError as e:
                    logging.error(f"Error decoding function call arguments (string) for {function_call_name}: {e}. Arguments: {function_call_arguments}")
                    return f"Error: Malformed arguments for tool call: {e}"
                if function_args is None:
                    function_args = {}
        else:
            logging.error(f"Unexpected type for function_call_arguments for {function_name}: {type(function_call_arguments)}. Arguments: {function_call_arguments}")
            return f"Error: Invalid argument type for tool call: {type(function_call_arguments)}"

        for tool in self.tools:
            for function in tool.get_functions():
                if function["function"]["name"] != function_name:
                    continue
                try:
                    if not isinstance(function_args, dict):
                        logging.error(f"Internal error: function_args not a dict for {function_name} before execution. Args: {function_args}")
                        return f"Internal error preparing arguments for tool {function_name}."
                    return tool.execute(function_name, **function_args)
                except Exception as e:
                    logging.error(f"Error executing tool {function_name} with args {function_args}: {e}")
                    return f"Error executing tool {function_name}: {e}"
        logging.warning(f"Tool function {function_name} not found.")
        return f"Error: Tool function {function_name} not found."

    async def switch_model(self):
        """Toggle between the small and large model and describe the result.

        If the active model matches neither configured name, the small model
        is selected. If either name is missing, nothing changes.
        """
        small_name = self.model_config["small_model_name"]
        large_name = self.model_config["large_model_name"]
        if not small_name or not large_name:
            logging.warning("Small or Large model names are not defined. Cannot switch model.")
            return f"Model switching not fully configured. Currently using {self.model}."

        if self.model == large_name:
            target = "small"
        elif self.model == small_name:
            target = "large"
        else:
            logging.warning(f"Current model {self.model} is unrecognized. Switching to default small model: {self.model_config['small_model_name']}.")
            target = "small"
        self._configure_model_and_tokens(
            self.model_config[f"{target}_model_name"],
            self.model_config[f"{target}_model_max_tokens"],
        )
        return f"Switched model to {self.model}. Max tokens set to {self.max_tokens if self.max_tokens is not None else 'API default'}."
2025-06-03 13:04:42 -05:00
def main():
    """Command-line entry point: build the bot and hand it to a messenger helper.

    Reads API settings from environment variables named
    ``<CONFIG>_API_KEY``, ``<CONFIG>_API_BASE_URL``, ``<CONFIG>_SMALL_MODEL``,
    ``<CONFIG>_LARGE_MODEL``, ``<CONFIG>_SMALL_MODEL_MAX_TOKENS`` and
    ``<CONFIG>_LARGE_MODEL_MAX_TOKENS``, where ``<CONFIG>`` is the upper-cased
    ``--config`` value. The messenger helper class is looked up in the module
    ``<messenger>_helper``. Errors during start-up are logged, not raised.
    """
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    try:
        parser = argparse.ArgumentParser(description='OpenAI Compatible Inference Bot')
        parser.add_argument('--config', type=str, help='Configuration Prepend (i.e. gemini, openai, etc)', default="Telegram")
        parser.add_argument('--messenger', type=str, help='Messenger type (i.e. telegram)', required=True)
        parser.add_argument('--persona', type=str, help='Path to system prompt file', required=False)
        parser.add_argument('--tools', nargs='+', help='List of allowed function tags', required=False)
        parser.add_argument('--use-large-model', action='store_true', help='Use the large model instead of the small model')
        # Add these to launch.json arguments if you want to limit the toolset available: "--tools", "read", "communicate"
        args = parser.parse_args()

        if args.persona:
            logging.info(f"Using custom persona from: {args.persona}")
        system_prompt_path = args.persona or None
        allowed_function_tags = args.tools or None
        messenger = args.messenger

        # Without a prefix there is nothing to look up; fail with a clear
        # message instead of hitting unbound variables further down.
        if not args.config:
            raise ValueError("--config must be a non-empty prefix for the API environment variables.")
        prefix = args.config.upper()

        bot = OpenAICompatibleInferenceBot(
            api_key=os.environ.get(f"{prefix}_API_KEY"),
            # An unset or empty URL must reach the client as None so that it
            # uses its default endpoint rather than an empty base URL.
            base_url=os.environ.get(f"{prefix}_API_BASE_URL") or None,
            small_model_name=os.environ.get(f"{prefix}_SMALL_MODEL"),
            small_model_max_tokens=os.environ.get(f"{prefix}_SMALL_MODEL_MAX_TOKENS"),
            large_model_name=os.environ.get(f"{prefix}_LARGE_MODEL"),
            large_model_max_tokens=os.environ.get(f"{prefix}_LARGE_MODEL_MAX_TOKENS"),
            system_prompt_path=system_prompt_path,
            allowed_function_tags=allowed_function_tags,
            use_large_model=args.use_large_model,
        )

        helper_module = importlib.import_module(f'{messenger.lower()}_helper')
        # Try "TelegramHelper"-style naming first, then "SMSHelper"-style.
        messenger_helper_class_name = f"{messenger.capitalize()}Helper"
        if not hasattr(helper_module, messenger_helper_class_name):
            messenger_helper_class_name = f"{messenger.upper()}Helper"
        if not hasattr(helper_module, messenger_helper_class_name):
            raise ValueError(f"Messenger helper class {messenger_helper_class_name} not found in {helper_module.__name__}.")
        helper_class = getattr(helper_module, messenger_helper_class_name)

        helper = helper_class(bot)
        helper.run()
    except ValueError as e:
        logging.error(f"FATAL: {e}")
        return
    except Exception as e:  # Catch any other init errors
        # logging.exception keeps the traceback, which is what makes an
        # "unexpected" failure diagnosable.
        logging.exception(f"An unexpected error occurred during bot initialization: {e}")
        return
# Start the bot only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    main()