import json
import os
import logging
from abc import abstractmethod

from base_telegram_inference_bot import BaseTelegramInferenceBot
from openai import OpenAI, AzureOpenAI  # Import both


class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot):
    """Telegram inference bot backed by an OpenAI-compatible Chat Completions client.

    The client is either injected directly or built from constructor arguments and
    environment variables. It is an ``AzureOpenAI`` client, or a plain ``OpenAI``
    client; the plain client also covers OpenAI-compatible servers (e.g. Gemini)
    reached through ``base_url``. Subclasses must implement :meth:`switch_model`.
    """

    # Passed to _configure_model_and_tokens; currently it only appears in the
    # warning logged for an unparsable max_tokens value.
    DEFAULT_MAX_TOKENS = 1000
    # Upper bound on model -> tool -> model round trips within one user turn.
    MAX_TOOL_ITERATIONS = 5  # OpenAI compatible typically uses fewer iterations than Anthropic
    # Sent when a keyless OpenAI-compatible server is configured via base_url:
    # the openai SDK refuses to build a client without any API key.
    KEYLESS_PLACEHOLDER_API_KEY = "not-needed"
    # Tool-result content used to answer tool calls left unexecuted when
    # MAX_TOOL_ITERATIONS is reached, so the stored history stays valid.
    SKIPPED_TOOL_CALL_CONTENT = "Tool call not executed: maximum tool iterations reached."

    def __init__(
        self,
        client: OpenAI | AzureOpenAI | None = None,
        api_key: str | None = None,
        base_url: str | None = None,
        api_version: str | None = None,  # For Azure
        azure_deployment: str | None = None,  # Model for Azure, distinct from general model_name if needed
        model_name: str | None = None,  # General model name for the API call
        max_tokens_str: str | None = None,
        system_prompt_content: str | None = None,
        system_prompt_path: str | None = None,
        is_gemini: bool = False,  # Hint for specific API key if others are not set
        max_history_length: int | None = None,
    ):
        """Create or adopt the API client, then configure model name and max_tokens.

        Args:
            client: Pre-built client. When given, no client construction or env lookup happens.
            api_key: API key. Falls back to AZURE_OPENAI_KEY (Azure) or
                GEMINI_API_KEY / OPENAI_API_KEY (non-Azure).
            base_url: Endpoint URL. A URL containing "azure.com" selects Azure; an
                explicit URL takes precedence over AZURE_OPENAI_ENDPOINT.
            api_version: Azure API version (falls back to AZURE_OPENAI_API_VERSION).
            azure_deployment: Azure deployment name; setting it selects Azure and is
                used as the model for API calls.
            model_name: Model name for API calls (Azure fallback when no deployment).
            max_tokens_str: max_tokens as text; "none"/"null"/empty means API default.
            system_prompt_content: Forwarded to the base class.
            system_prompt_path: Forwarded to the base class.
            is_gemini: Prefer GEMINI_API_KEY when no api_key is given (non-Azure only).
            max_history_length: Maximum number of prior non-system messages carried
                into each new turn; None keeps the full history.

        Raises:
            ValueError: Required Azure settings are missing, or neither an API key
                nor a base_url is available for the non-Azure client.
        """
        super().__init__(system_prompt_content=system_prompt_content, system_prompt_path=system_prompt_path)
        self._max_history_length = max_history_length

        self.client = client
        if self.client:  # Client was provided directly
            model_name_for_config = model_name
            logging.info(f"Using provided client: {type(self.client)}")
        elif self._should_use_azure(base_url, azure_deployment):
            self.client, model_name_for_config = self._create_azure_client(
                api_key, base_url, api_version, azure_deployment, model_name
            )
        else:
            self.client, model_name_for_config = self._create_openai_client(
                api_key, base_url, model_name, is_gemini
            )

        # Configure the actual model name and max_tokens for API calls
        self._configure_model_and_tokens(
            model_name_for_config, max_tokens_str, default_max_tokens=self.DEFAULT_MAX_TOKENS
        )

    @staticmethod
    def _should_use_azure(base_url: str | None, azure_deployment: str | None) -> bool:
        """Decide whether the client to build is an Azure OpenAI client."""
        if azure_deployment:
            return True
        if base_url:
            # An explicitly supplied endpoint wins over environment configuration,
            # so a stray AZURE_OPENAI_ENDPOINT cannot hijack e.g. a Gemini base_url.
            return "azure.com" in base_url
        return bool(os.environ.get("AZURE_OPENAI_ENDPOINT"))

    @staticmethod
    def _create_azure_client(
        api_key: str | None,
        base_url: str | None,
        api_version: str | None,
        azure_deployment: str | None,
        model_name: str | None,
    ) -> tuple[AzureOpenAI, str]:
        """Build an AzureOpenAI client; return it with the deployment used as the model.

        Raises:
            ValueError: If endpoint, key, API version or deployment/model is missing.
        """
        endpoint = base_url or os.environ.get("AZURE_OPENAI_ENDPOINT")
        key = api_key or os.environ.get("AZURE_OPENAI_KEY")
        version = api_version or os.environ.get("AZURE_OPENAI_API_VERSION")
        # For Azure, the model parameter in API calls is the deployment name.
        deployment = azure_deployment or model_name
        if not endpoint or not key or not version or not deployment:
            raise ValueError("For Azure OpenAI, endpoint, API key, API version, and deployment/model name must be configured.")
        client = AzureOpenAI(api_key=key, azure_endpoint=endpoint, api_version=version)
        logging.info(f"Initialized AzureOpenAI client for deployment: {deployment} at {endpoint}")
        return client, deployment

    @classmethod
    def _create_openai_client(
        cls,
        api_key: str | None,
        base_url: str | None,
        model_name: str | None,
        is_gemini: bool,
    ) -> tuple[OpenAI, str | None]:
        """Build a standard OpenAI / OpenAI-compatible client; return it with model_name.

        Raises:
            ValueError: If there is neither an API key nor a base_url to talk to.
        """
        url = base_url or os.environ.get("OPENAI_API_BASE_URL")  # For other compatible APIs
        key = api_key
        if not key:  # Try different ENV sources for API key
            if is_gemini and os.environ.get("GEMINI_API_KEY"):
                key = os.environ.get("GEMINI_API_KEY")
            else:
                key = os.environ.get("OPENAI_API_KEY")
        if not key:
            if not url:
                raise ValueError("API key must be provided for OpenAI compatible client if not Azure or local anonymous.")
            # Keyless local/self-hosted server behind base_url: the SDK still demands
            # a key, so send a placeholder that such servers ignore.
            key = cls.KEYLESS_PLACEHOLDER_API_KEY
        client = OpenAI(api_key=key, base_url=url)
        log_msg = f"Initialized OpenAI compatible client. Target URL: {url if url else 'OpenAI default'}."
        logging.info(log_msg)
        return client, model_name

    def _configure_model_and_tokens(self, model_name: str | None, max_tokens_str: str | None, default_max_tokens: int = 1000):
        """Set ``self.model`` and ``self.max_tokens`` (None means "use the API default")."""
        self.model = model_name if model_name else "default-model"  # Fallback model name
        try:
            # If max_tokens_str is explicitly "None" or empty, treat as None for API default
            if max_tokens_str and max_tokens_str.lower() not in ["none", "", "null"]:
                self.max_tokens = int(max_tokens_str)
            else:
                self.max_tokens = None
        except ValueError:
            logging.warning(f"Invalid value for max_tokens: {max_tokens_str}. Using API default (None). stalwart default was {default_max_tokens}")
            self.max_tokens = None
        logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}")

    def get_llm_description(self) -> str:
        """Return a short human-readable description of the client, model and token limit."""
        client_type = type(self.client).__name__
        return f"Client: {client_type}, LLM: {self.model}, Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}"

    def get_chat_response(self, messages):
        """Send ``messages`` to the Chat Completions API and return the raw response.

        ``tools``/``tool_choice`` are sent only when ``self.functions`` is non-empty,
        and ``max_tokens`` only when configured; otherwise the fields are omitted
        (the SDK would serialize an explicit None as JSON null).

        Raises:
            ValueError: If no client is configured.
            Exception: Any error from the API call is logged and re-raised.
        """
        if not self.client:  # This should ideally not be hit if __init__ is successful
            logging.error("OpenAI client not initialized before get_chat_response.")
            raise ValueError("OpenAI client not initialized.")

        request = {"model": self.model, "messages": messages}
        functions = getattr(self, "functions", None)
        if functions:
            request["tools"] = functions
            request["tool_choice"] = "auto"
        if self.max_tokens is not None:
            request["max_tokens"] = self.max_tokens

        try:
            return self.client.chat.completions.create(**request)
        except Exception as e:
            logging.error(f"API call to model {self.model} failed: {e}")
            raise

    @staticmethod
    def _message_role(message):
        """Return the role of a history entry (plain dict or SDK message object)."""
        if isinstance(message, dict):
            return message.get("role")
        return getattr(message, "role", None)

    def _trim_history(self, history: list) -> list:
        """Apply max_history_length to a stored history and return the retained list.

        Leading system messages are always kept. Of the remaining messages, at most
        the last ``max_history_length`` are kept, and the window is advanced to the
        first user message. This avoids keeping a tool result whose assistant
        ``tool_calls`` message was cut off, which the API would reject.
        """
        # getattr: tolerate subclasses/instances that bypassed this __init__.
        limit = getattr(self, "_max_history_length", None)
        if limit is None:
            return history

        split = 0
        while split < len(history) and self._message_role(history[split]) == "system":
            split += 1
        head, tail = history[:split], history[split:]

        keep = max(0, limit)
        tail = tail[-keep:] if keep else []  # tail[-0:] would be the whole list
        start = next(
            (i for i, msg in enumerate(tail) if self._message_role(msg) == "user"),
            len(tail),
        )
        return head + tail[start:]

    def _execute_tool_calls(self, tool_calls) -> list[dict]:
        """Run each requested tool via ``self.call_tool``; build the matching ``tool`` messages.

        Tool failures are logged and reported to the model as error text, not raised.
        """
        tool_results_for_model = []
        for tool_call in tool_calls:
            function_name = tool_call.function.name
            function_args_str = tool_call.function.arguments
            logging.info(f"Attempting to call tool: {function_name} with args: {function_args_str}")
            try:
                # Arguments are already a string from the API, self.call_tool expects dict or string
                tool_response_content = self.call_tool(function_name, function_args_str)
                # Ensure content is string for OpenAI tool role
                if not isinstance(tool_response_content, str):
                    tool_response_content = json.dumps(tool_response_content)
            except Exception as e:
                logging.error(f"Error calling tool {function_name}: {e}")
                tool_response_content = f"Error executing tool {function_name}: {str(e)}"
            tool_results_for_model.append({
                "role": "tool",
                "tool_call_id": tool_call.id,
                "name": function_name,
                "content": tool_response_content,
            })
        return tool_results_for_model

    def _skipped_tool_results(self, tool_calls) -> list[dict]:
        """Build placeholder ``tool`` messages for tool calls that will not be executed."""
        return [
            {
                "role": "tool",
                "tool_call_id": tool_call.id,
                "name": tool_call.function.name,
                "content": self.SKIPPED_TOOL_CALL_CONTENT,
            }
            for tool_call in tool_calls
        ]

    async def handle_message(self, user_id, user_message):
        """Process one user message, running tool calls, and return the reply text.

        The per-user history is started with the system prompt (if any), or trimmed
        per max_history_length. The user message is appended, then the model is called
        repeatedly while it requests tools, up to MAX_TOOL_ITERATIONS round trips.
        Calls still pending at the cap are answered with placeholder tool results,
        so the stored history remains acceptable to the API on the next turn.

        Returns:
            The last assistant message's text, or an error/notice string.
        """
        if user_id not in self.conversation_history or not self.conversation_history[user_id]:
            self.conversation_history[user_id] = []
            if self.system_prompt:  # Use the loaded system_prompt
                self.conversation_history[user_id].append({"role": "system", "content": self.system_prompt})
        else:
            # Trim before appending, so the new user message is always retained.
            self.conversation_history[user_id] = self._trim_history(self.conversation_history[user_id])
        self.conversation_history[user_id].append({"role": "user", "content": user_message})
        messages = list(self.conversation_history[user_id])  # Work with a copy for this turn

        response = self.get_chat_response(messages)
        if not (response.choices and response.choices[0].message):
            logging.error("No valid response choice message from LLM.")
            # Persist the user message in history even if LLM fails this turn
            self.conversation_history[user_id] = messages
            return "Error: Could not get a valid response from the LLM."

        assistant_message = response.choices[0].message
        messages.append(assistant_message)
        pending_tool_calls = list(assistant_message.tool_calls) if assistant_message.tool_calls else []

        tool_use_count = 0
        while pending_tool_calls and tool_use_count < self.MAX_TOOL_ITERATIONS:
            messages.extend(self._execute_tool_calls(pending_tool_calls))
            response = self.get_chat_response(messages)
            if not (response.choices and response.choices[0].message):
                logging.error("No valid response choice message from LLM after tool call.")
                self.conversation_history[user_id] = messages  # Persist state before error
                return "Error: Could not get a valid response from the LLM after tool call."
            assistant_message = response.choices[0].message
            messages.append(assistant_message)
            pending_tool_calls = list(assistant_message.tool_calls) if assistant_message.tool_calls else []
            tool_use_count += 1

        if pending_tool_calls:
            logging.warning(f"Max tool iterations ({self.MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.")
            # Every assistant tool_call must be followed by a tool message with its id,
            # otherwise the API rejects the next turn built from this history.
            messages.extend(self._skipped_tool_results(pending_tool_calls))

        self.conversation_history[user_id] = messages
        if assistant_message.role == "assistant" and assistant_message.content is not None:
            return assistant_message.content
        return "Assistant did not provide a textual response."

    async def start(self):
        """Log that the bot has started, including the configured model."""
        logging.info(f"{self.__class__.__name__} (Model: {self.model}) started.")

    # clear_conversation_history is inherited from BaseTelegramInferenceBot

    async def abort_processing(self, user_id):
        """Clear the user's processing status, if any, and return a status message.

        This is a soft abort: handle_message runs its API calls synchronously, so
        nothing in flight is interrupted. Conversation history is left untouched.
        """
        if user_id in self.processing_status:
            self.clear_processing_status(user_id)  # Use base class method
            logging.info(f"Processing aborted for user {user_id}.")
            return "Processing aborted. You can send a new message or /clear the conversation."
        return "No active processing found to abort. If you wish, /clear the conversation history."

    @abstractmethod
    async def switch_model(self):
        """Switch the model used by this bot; implemented by concrete subclasses."""
        pass