2025-06-02 14:56:23 -05:00
import json
import os
import logging
from abc import abstractmethod
from base_telegram_inference_bot import BaseTelegramInferenceBot
2025-06-02 16:43:39 -05:00
from openai import OpenAI , AzureOpenAI # Import both
2025-06-02 14:56:23 -05:00
class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot):
    """Telegram inference bot backed by an OpenAI-compatible chat-completions API.

    Supports the OpenAI API, Azure OpenAI, and other servers that expose the
    OpenAI chat-completions interface (e.g. Gemini or a local model server
    reached through ``base_url``). Subclasses must implement ``switch_model``.
    """

    DEFAULT_MAX_TOKENS = 1000  # Passed to _configure_model_and_tokens; only reported in its invalid-value warning.
    # Upper bound on model -> tool -> model round trips per user message.
    MAX_TOOL_ITERATIONS = 5  # OpenAI compatible typically uses fewer iterations than Anthropic
    # API key sent to keyless (e.g. local) servers: the openai client refuses to be
    # constructed without an api_key, and servers that need no key are expected to ignore it.
    _PLACEHOLDER_API_KEY = "no-key-required"

    def __init__(
        self,
        client: OpenAI | AzureOpenAI | None = None,
        api_key: str | None = None,
        base_url: str | None = None,
        api_version: str | None = None,  # For Azure
        azure_deployment: str | None = None,  # Model for Azure, distinct from general model_name if needed
        model_name: str | None = None,  # General model name for the API call
        max_tokens_str: str | None = None,
        system_prompt_content: str | None = None,
        system_prompt_path: str | None = None,
        is_gemini: bool = False,  # Hint for specific API key if others are not set
        max_history_length: int | None = None,
    ):
        """Create the bot and, unless *client* is supplied, its API client.

        Args:
            client: Pre-built client used as-is. When given, api_key, base_url,
                api_version, azure_deployment and is_gemini are ignored.
            api_key: API key. Falls back to AZURE_OPENAI_KEY (Azure) or to
                GEMINI_API_KEY / OPENAI_API_KEY (other endpoints).
            base_url: Endpoint URL. A URL containing "azure.com" selects Azure.
                Falls back to AZURE_OPENAI_ENDPOINT (Azure) or
                OPENAI_API_BASE_URL (other endpoints).
            api_version: Azure API version; falls back to AZURE_OPENAI_API_VERSION.
            azure_deployment: Azure deployment name. When set, selects Azure and
                is used as the model name in API calls.
            model_name: Model name for API calls (for Azure, used only when no
                deployment name is available).
            max_tokens_str: Completion token limit as text; see
                _configure_model_and_tokens.
            system_prompt_content: Forwarded to the base class.
            system_prompt_path: Forwarded to the base class.
            is_gemini: Prefer GEMINI_API_KEY over OPENAI_API_KEY for non-Azure
                endpoints when no api_key is passed.
            max_history_length: Accepted for interface compatibility.

        Raises:
            ValueError: If Azure is selected but endpoint, key, API version or
                deployment/model name cannot be resolved, or if a non-Azure
                endpoint has neither an API key nor a base_url.
        """
        super().__init__(system_prompt_content=system_prompt_content, system_prompt_path=system_prompt_path)
        # NOTE(review): max_history_length is accepted but not used anywhere in this class;
        # confirm whether history trimming is expected to consume it.
        self.client = client
        if self.client is not None:
            # Client was provided directly
            model_name_for_config = model_name
            logging.info(f"Using provided client: {type(self.client)}")
        elif self._should_use_azure(base_url, azure_deployment):
            self.client, model_name_for_config = self._create_azure_client(
                api_key, base_url, api_version, azure_deployment, model_name
            )
        else:
            # Standard OpenAI or other OpenAI-compatible (like Gemini via base_url)
            self.client = self._create_openai_client(api_key, base_url, is_gemini)
            model_name_for_config = model_name
        # Configure the actual model name and max_tokens for API calls
        self._configure_model_and_tokens(
            model_name_for_config,
            max_tokens_str,
            default_max_tokens=self.DEFAULT_MAX_TOKENS,
        )

    @staticmethod
    def _should_use_azure(base_url: str | None, azure_deployment: str | None) -> bool:
        """Return True when an AzureOpenAI client should be built.

        An explicit deployment name, or an explicit base_url on azure.com,
        selects Azure. The AZURE_OPENAI_ENDPOINT environment variable is only a
        fallback for when no base_url was passed: it must not turn an explicitly
        configured non-Azure endpoint into an Azure one.
        """
        if azure_deployment:
            return True
        if base_url:
            return "azure.com" in base_url
        return bool(os.environ.get("AZURE_OPENAI_ENDPOINT"))

    @staticmethod
    def _create_azure_client(
        api_key: str | None,
        base_url: str | None,
        api_version: str | None,
        azure_deployment: str | None,
        model_name: str | None,
    ) -> tuple[AzureOpenAI, str]:
        """Build an AzureOpenAI client, filling missing settings from AZURE_OPENAI_* env vars.

        Returns:
            ``(client, model)`` where *model* is the name to pass in API calls:
            the deployment name, or *model_name* when no deployment is given.

        Raises:
            ValueError: If endpoint, key, API version or deployment/model name
                cannot be resolved.
        """
        endpoint = base_url or os.environ.get("AZURE_OPENAI_ENDPOINT")
        key = api_key or os.environ.get("AZURE_OPENAI_KEY")
        version = api_version or os.environ.get("AZURE_OPENAI_API_VERSION")
        # For Azure, the model parameter in API calls is the deployment name.
        deployment = azure_deployment or model_name
        if not (endpoint and key and version and deployment):
            raise ValueError("For Azure OpenAI, endpoint, API key, API version, and deployment/model name must be configured.")
        client = AzureOpenAI(api_key=key, azure_endpoint=endpoint, api_version=version)
        logging.info(f"Initialized AzureOpenAI client for deployment: {deployment} at {endpoint}")
        return client, deployment

    @classmethod
    def _create_openai_client(cls, api_key: str | None, base_url: str | None, is_gemini: bool) -> OpenAI:
        """Build an OpenAI (or OpenAI-compatible) client.

        Resolves the base URL from OPENAI_API_BASE_URL and the key from
        GEMINI_API_KEY (when *is_gemini*) or OPENAI_API_KEY when not passed.
        A base_url without any key is treated as a keyless local server and
        gets a placeholder key.

        Raises:
            ValueError: If there is neither an API key nor a base_url.
        """
        base_url = base_url or os.environ.get("OPENAI_API_BASE_URL")  # For other compatible APIs
        if not api_key:  # Try different ENV sources for API key
            if is_gemini and os.environ.get("GEMINI_API_KEY"):
                api_key = os.environ.get("GEMINI_API_KEY")
            else:
                api_key = os.environ.get("OPENAI_API_KEY")
        if not api_key:
            if not base_url:
                raise ValueError("API key must be provided for OpenAI compatible client if not Azure or local anonymous.")
            # Local/anonymous server reached via base_url: no real key needed.
            api_key = cls._PLACEHOLDER_API_KEY
        client = OpenAI(api_key=api_key, base_url=base_url)
        logging.info(f"Initialized OpenAI compatible client. Target URL: {base_url if base_url else 'OpenAI default'}.")
        return client

    def _configure_model_and_tokens(self, model_name: str | None, max_tokens_str: str | None, default_max_tokens: int = 1000):
        """Set ``self.model`` and ``self.max_tokens`` for subsequent API calls.

        Args:
            model_name: Model (or Azure deployment) name; "default-model" when empty.
            max_tokens_str: Token limit as text (an int is also accepted).
                None, "", "none" or "null" (case-insensitive, surrounding
                whitespace ignored) mean "let the API decide"
                (``self.max_tokens = None``). Unparseable or non-positive values
                are logged and also fall back to None.
            default_max_tokens: Only reported in the invalid-value warning; it
                is not applied as a fallback.
        """
        self.model = model_name if model_name else "default-model"  # Fallback model name
        self.max_tokens = None  # None -> max_tokens is omitted and the API default applies
        text = "" if max_tokens_str is None else str(max_tokens_str).strip()
        if text.lower() not in ("", "none", "null"):
            try:
                value = int(text)
            except ValueError:
                value = None
            # A token limit of 0 or less can never be satisfied by the API.
            if value is not None and value > 0:
                self.max_tokens = value
            else:
                logging.warning(f"Invalid value for max_tokens: {max_tokens_str}. Using API default (None). stalwart default was {default_max_tokens}")
        logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}")

    def get_llm_description(self) -> str:
        """Return a one-line summary of the client type, model and token limit."""
        client_type = type(self.client).__name__
        return f"Client: {client_type}, LLM: {self.model}, Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}"

    def get_chat_response(self, messages):
        """Send *messages* to the chat-completions endpoint and return the raw response.

        ``tools``/``tool_choice`` are sent only when ``self.functions`` is set
        and non-empty, and ``max_tokens`` only when configured. Optional fields
        are omitted rather than sent as null, because some OpenAI-compatible
        servers reject explicit nulls.

        Raises:
            ValueError: If no client has been initialized.
            Exception: Any error raised by the client is logged and re-raised.
        """
        if not self.client:
            # This should ideally not be hit if __init__ is successful
            logging.error("OpenAI client not initialized before get_chat_response.")
            raise ValueError("OpenAI client not initialized.")

        request = {"model": self.model, "messages": messages}
        tools = getattr(self, "functions", None)
        if tools:
            request["tools"] = tools
            request["tool_choice"] = "auto"
        if self.max_tokens is not None:
            request["max_tokens"] = self.max_tokens

        try:
            return self.client.chat.completions.create(**request)
        except Exception as e:
            logging.error(f"API call to model {self.model} failed: {e}")
            raise

    def _run_tool_call(self, tool_call) -> dict:
        """Execute one model-requested tool call and wrap its result as a "tool" message.

        Non-string results are JSON-encoded. Any exception from the tool (or
        from encoding its result) is logged and reported to the model as text
        instead of propagating.
        """
        function_name = tool_call.function.name
        function_args_str = tool_call.function.arguments
        logging.info(f"Attempting to call tool: {function_name} with args: {function_args_str}")
        try:
            # Arguments are already a string from the API, self.call_tool expects dict or string
            content = self.call_tool(function_name, function_args_str)
            # Ensure content is string for OpenAI tool role
            if not isinstance(content, str):
                content = json.dumps(content)
        except Exception as e:
            logging.error(f"Error calling tool {function_name}: {e}")
            content = f"Error executing tool {function_name}: {str(e)}"
        return self._tool_result_message(tool_call, content)

    @staticmethod
    def _tool_result_message(tool_call, content: str) -> dict:
        """Build the "tool" role message answering *tool_call*."""
        return {
            "role": "tool",
            "tool_call_id": tool_call.id,
            "name": tool_call.function.name,
            "content": content,
        }

    async def handle_message(self, user_id, user_message):
        """Process one user message and return the assistant's reply text.

        Appends the user message (preceded by the system prompt when the
        conversation is new) to ``self.conversation_history[user_id]``, queries
        the model, executes requested tool calls for up to
        ``MAX_TOOL_ITERATIONS`` rounds, and stores the whole exchange back into
        the history.

        Returns:
            The final assistant text, or an error/fallback string when the
            model returns no usable message or no text content.

        Raises:
            Whatever ``get_chat_response`` raises for a failed API call.
        """
        # NOTE(review): get_chat_response and call_tool run synchronously, so this coroutine
        # blocks the event loop for the whole turn. Consider asyncio.to_thread once concurrent
        # turns for the same user are serialized (they would otherwise race on the history).
        if user_id not in self.conversation_history or not self.conversation_history[user_id]:
            self.conversation_history[user_id] = []
            if self.system_prompt:  # Use the loaded system_prompt
                self.conversation_history[user_id].append({"role": "system", "content": self.system_prompt})
        self.conversation_history[user_id].append({"role": "user", "content": user_message})

        # Work on a copy for this turn; it replaces the stored history when the turn ends.
        messages = list(self.conversation_history[user_id])

        response = self.get_chat_response(messages)
        if not (response.choices and response.choices[0].message):
            logging.error("No valid response choice message from LLM.")
            # Persist the user message in history even if LLM fails this turn
            self.conversation_history[user_id] = messages
            return "Error: Could not get a valid response from the LLM."
        assistant_message = response.choices[0].message
        messages.append(assistant_message)
        pending_tool_calls = list(assistant_message.tool_calls or [])

        rounds = 0
        while pending_tool_calls:
            if rounds >= self.MAX_TOOL_ITERATIONS:
                logging.warning(f"Max tool iterations ({self.MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.")
                # Answer every outstanding call: the API rejects a history in which an
                # assistant message's tool_calls are not each followed by a tool message,
                # which would otherwise make every later turn for this user fail.
                messages.extend(
                    self._tool_result_message(call, "Tool call not executed: maximum tool iterations reached.")
                    for call in pending_tool_calls
                )
                break

            tool_results = [self._run_tool_call(call) for call in pending_tool_calls]
            messages.extend(tool_results)

            response = self.get_chat_response(messages)
            if not (response.choices and response.choices[0].message):
                logging.error("No valid response choice message from LLM after tool call.")
                self.conversation_history[user_id] = messages  # Persist state before error
                return "Error: Could not get a valid response from the LLM after tool call."
            assistant_message = response.choices[0].message
            messages.append(assistant_message)
            pending_tool_calls = list(assistant_message.tool_calls or [])
            rounds += 1

        self.conversation_history[user_id] = messages
        # Tool messages may have been appended after the last assistant message,
        # so read the reply from that assistant message rather than messages[-1].
        if assistant_message.role == "assistant" and assistant_message.content is not None:
            return assistant_message.content
        return "Assistant did not provide a textual response."

    async def start(self):
        """Log that the bot has started with its configured model."""
        logging.info(f"{self.__class__.__name__} (Model: {self.model}) started.")

    # clear_conversation_history is inherited from BaseTelegramInferenceBot

    async def abort_processing(self, user_id):
        """Clear the processing status for *user_id* and return a status message.

        This is a soft abort: API calls in handle_message run synchronously, so
        nothing in flight is cancelled. The conversation history is kept; the
        user can /clear it explicitly.
        """
        if user_id not in self.processing_status:
            return "No active processing found to abort. If you wish, /clear the conversation history."
        self.clear_processing_status(user_id)  # Use base class method
        logging.info(f"Processing aborted for user {user_id}.")
        return "Processing aborted. You can send a new message or /clear the conversation."

    @abstractmethod
    async def switch_model(self):
        """Switch the model this bot uses; must be implemented by subclasses."""