2024-08-19 12:54:13 -05:00
import importlib
2024-08-19 11:34:31 -05:00
import os
import json
2024-08-19 12:54:13 -05:00
import inspect
2025-06-02 14:55:45 -05:00
import logging
2024-08-19 11:34:31 -05:00
from abc import ABC , abstractmethod
2024-08-19 12:54:13 -05:00
from tools . base_tool import BaseTool
2024-08-19 11:34:31 -05:00
class BaseTelegramInferenceBot ( ABC ) :
2025-06-02 16:35:08 -05:00
def __init__ ( self , system_prompt_content : str | None = None , system_prompt_path : str | None = None ) : # MODIFIED
2024-08-19 11:34:31 -05:00
self . conversation_history = { }
self . processing_status = { }
2025-06-02 16:35:08 -05:00
# MODIFIED to pass arguments
self . system_prompt = self . load_system_prompt (
direct_content = system_prompt_content ,
file_path = system_prompt_path
)
2024-08-19 12:54:13 -05:00
self . tools , self . functions = self . load_functions ( )
2025-06-02 16:35:08 -05:00
# Logging the actual source of the system prompt might be more complex now,
# but we can log the final prompt or indicate if it's custom/default.
# We'll also log the source of the prompt inside load_system_prompt.
logging . info ( f ' System Prompt (effective): { " Custom " if self . system_prompt != " You are a helpful AI assistant. " else " Default " } ' )
2025-06-02 14:55:45 -05:00
logging . info ( f ' Github Repository: { os . environ . get ( " GITHUB_REPOSITORY " ) } ' )
2024-08-19 11:34:31 -05:00
2025-06-02 16:35:08 -05:00
def load_system_prompt ( self , direct_content : str | None = None , file_path : str | None = None ) - > str : # MODIFIED
default_prompt = " You are a helpful AI assistant. "
if direct_content :
logging . info ( " Using direct content for system prompt. " )
return direct_content . strip ( )
prompt_path_to_try = file_path or os . getenv ( " SYSTEM_PROMPT_PATH " )
if prompt_path_to_try :
if os . path . isfile ( prompt_path_to_try ) :
try :
with open ( prompt_path_to_try , " r " , encoding = " utf-8 " ) as file :
content = file . read ( ) . strip ( )
logging . info ( f " Successfully loaded system prompt from { prompt_path_to_try } . " )
return content
except IOError as e :
logging . warning ( f " Could not read system prompt file { prompt_path_to_try } : { e } . Using default. " )
return default_prompt
else :
# This condition now also covers if 'file_path' argument was given but invalid
logging . warning ( f " System prompt file { prompt_path_to_try } not found. Using default system prompt. " )
return default_prompt
2024-08-21 08:09:47 -05:00
else :
2025-06-02 16:35:08 -05:00
logging . info ( " No system prompt path provided (argument or ENV) or direct content. Using default system prompt. " )
return default_prompt
2024-08-19 11:34:31 -05:00
2025-06-02 14:55:45 -05:00
def load_functions ( self ) :
2024-08-19 12:54:13 -05:00
tools = [ ]
2025-06-02 14:55:45 -05:00
functions = [ ]
2024-08-19 12:54:13 -05:00
tools_dir = os . path . join ( os . path . dirname ( __file__ ) , ' tools ' )
2025-06-02 14:55:45 -05:00
if not os . path . exists ( tools_dir ) :
logging . warning ( f " Tools directory not found: { tools_dir } " )
return [ ] , [ ]
2024-08-19 12:54:13 -05:00
for filename in os . listdir ( tools_dir ) :
if filename . endswith ( ' .py ' ) and filename != ' __init__.py ' and filename != ' base_tool.py ' :
module_name = f ' tools. { filename [ : - 3 ] } '
2025-06-02 14:55:45 -05:00
try :
module = importlib . import_module ( module_name )
for name , obj in inspect . getmembers ( module ) :
if inspect . isclass ( obj ) and issubclass ( obj , BaseTool ) and obj != BaseTool :
try :
2025-06-02 16:35:08 -05:00
tools . append ( obj ( ) ) # This instantiation might be an issue for tools needing config
2025-06-02 14:55:45 -05:00
except Exception as e :
logging . error ( f " Error instantiating tool { name } from { filename } : { e } " )
except Exception as e :
logging . error ( f " Error importing module { module_name } : { e } " )
2024-08-19 12:54:13 -05:00
for tool in tools :
functions . extend ( tool . get_functions ( ) )
return tools , functions
2024-08-19 11:34:31 -05:00
@abstractmethod
def get_chat_response ( self , messages ) :
pass
@abstractmethod
async def handle_message ( self , user_id , user_message ) :
pass
2025-06-02 14:55:45 -05:00
def clear_conversation_history ( self , user_id ) :
2024-08-19 11:34:31 -05:00
if user_id in self . conversation_history :
del self . conversation_history [ user_id ]
2024-08-19 13:38:39 -05:00
for tool in self . tools :
tool . clear ( )
2024-08-19 11:34:31 -05:00
2025-06-02 14:55:45 -05:00
def set_processing_status ( self , user_id : int , message_id : int ) :
self . processing_status [ user_id ] = { " processing " : True , " message_id " : message_id }
def clear_processing_status ( self , user_id : int ) :
if user_id in self . processing_status :
del self . processing_status [ user_id ]
2024-08-19 12:54:13 -05:00
def call_tool ( self , function_call_name , function_call_arguments ) :
function_name = function_call_name
2025-06-02 15:56:43 -05:00
function_args = None
if isinstance ( function_call_arguments , dict ) :
function_args = function_call_arguments
elif isinstance ( function_call_arguments , str ) :
try :
function_args = json . loads ( function_call_arguments )
except json . JSONDecodeError as e :
logging . error ( f " Error decoding function call arguments (string) for { function_call_name } : { e } . Arguments: { function_call_arguments } " )
return f " Error: Malformed arguments for tool call: { e } "
2025-06-02 16:35:08 -05:00
else :
2025-06-02 15:56:43 -05:00
if function_call_arguments is None :
2025-06-02 16:35:08 -05:00
function_args = { }
2025-06-02 15:56:43 -05:00
else :
logging . error ( f " Unexpected type for function_call_arguments for { function_call_name } : { type ( function_call_arguments ) } . Arguments: { function_call_arguments } " )
return f " Error: Invalid argument type for tool call: { type ( function_call_arguments ) } "
2025-06-02 14:55:45 -05:00
2024-08-19 11:34:31 -05:00
for tool in self . tools :
2025-06-01 11:50:12 -05:00
for function in tool . get_functions ( ) :
if function [ " function " ] [ " name " ] == function_name :
2025-06-02 14:55:45 -05:00
try :
2025-06-02 15:56:43 -05:00
if not isinstance ( function_args , dict ) :
logging . error ( f " Internal error: function_args not a dict for { function_name } before execution. Args: { function_args } " )
return f " Internal error preparing arguments for tool { function_name } . "
2025-06-02 14:55:45 -05:00
return tool . execute ( function_name , * * function_args )
except Exception as e :
logging . error ( f " Error executing tool { function_name } with args { function_args } : { e } " )
return f " Error executing tool { function_name } : { e } "
logging . warning ( f " Tool function { function_name } not found. " )
return f " Error: Tool function { function_name } not found. "
2025-06-02 14:30:08 -05:00
def get_system_prompt_description ( self ) - > str :
2025-06-02 16:35:08 -05:00
# This method could be updated to be more specific about the prompt source if needed.
# For now, it still reflects custom vs default based on the original ENV var logic's spirit.
# A more accurate reflection would require storing how the prompt was loaded.
# For simplicity, let's assume if it's not the default, it's "Custom".
if self . system_prompt != " You are a helpful AI assistant. " :
return " System Prompt: Custom "
# Check original ENV var for backward compatibility in description only
elif os . getenv ( ' SYSTEM_PROMPT_PATH ' ) :
return " System Prompt: Custom (via ENV) "
return " System Prompt: Default "
2024-08-19 11:34:31 -05:00
@abstractmethod
2025-06-02 14:30:08 -05:00
def get_llm_description ( self ) - > str :
2024-08-19 11:34:31 -05:00
pass
2025-06-02 14:55:45 -05:00
async def get_bot_status ( self ) - > str :
2025-06-02 14:30:08 -05:00
prompt_desc = self . get_system_prompt_description ( )
llm_desc = self . get_llm_description ( )
return f " { prompt_desc } \n { llm_desc } "
2024-08-19 11:34:31 -05:00
@abstractmethod
2025-06-02 14:30:08 -05:00
async def start ( self ) :
2024-08-19 11:34:31 -05:00
pass
@abstractmethod
2025-06-02 14:55:45 -05:00
async def abort_processing ( self , user_id ) :
2024-08-19 11:34:31 -05:00
pass
2025-06-02 14:55:45 -05:00
2024-08-19 11:34:31 -05:00
@abstractmethod
2025-06-02 14:55:45 -05:00
async def switch_model ( self ) :
2025-06-02 14:30:08 -05:00
pass