Refactor: Generalize OpenAICompatibleInferenceBot initialization
This commit is contained in:
@@ -3,32 +3,114 @@ import os
|
|||||||
import logging
|
import logging
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from base_telegram_inference_bot import BaseTelegramInferenceBot
|
from base_telegram_inference_bot import BaseTelegramInferenceBot
|
||||||
from openai import OpenAI
|
from openai import OpenAI, AzureOpenAI # Import both
|
||||||
|
|
||||||
class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot):
|
class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot):
|
||||||
def __init__(self):
|
DEFAULT_MAX_HISTORY_LENGTH = 20
|
||||||
super().__init__()
|
DEFAULT_MAX_TOKENS = 1000 # Default for _configure_model_and_tokens
|
||||||
# Client and model configuration will be handled by subclasses
|
|
||||||
self.client = None
|
|
||||||
self.model = None
|
|
||||||
self.max_tokens = None
|
|
||||||
|
|
||||||
def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=1000):
|
def __init__(
    self,
    client: OpenAI | AzureOpenAI | None = None,
    api_key: str | None = None,
    base_url: str | None = None,
    api_version: str | None = None,
    azure_deployment: str | None = None,
    model_name: str | None = None,
    max_tokens_str: str | None = None,
    system_prompt_content: str | None = None,
    system_prompt_path: str | None = None,
    is_gemini: bool = False,
    max_history_length: int | None = None
):
    """Create the bot and, unless one is supplied, build its API client.

    Args:
        client: Ready-made client. When truthy it is used as-is and no
            credentials are resolved.
        api_key: API key; falls back to environment variables (see below).
        base_url: Endpoint URL. An explicit value decides on its own
            whether the target is Azure (it must contain "azure.com").
        api_version: Azure API version; falls back to
            ``AZURE_OPENAI_API_VERSION``.
        azure_deployment: Azure deployment name. When given it selects
            Azure mode and is used as the model for API calls.
        model_name: Model name for API calls (also the Azure fallback
            when ``azure_deployment`` is not given).
        max_tokens_str: Raw max-tokens setting, handed to
            ``_configure_model_and_tokens``.
        system_prompt_content: Forwarded to the base-class initializer.
        system_prompt_path: Forwarded to the base-class initializer.
        is_gemini: Prefer ``GEMINI_API_KEY`` over ``OPENAI_API_KEY`` when
            no ``api_key`` is passed.
        max_history_length: History limit; defaults to
            ``DEFAULT_MAX_HISTORY_LENGTH`` when ``None``.

    Raises:
        ValueError: If Azure mode lacks endpoint, key, API version or
            deployment/model name, or if a non-Azure client has neither an
            API key nor a base URL.
    """
    super().__init__(system_prompt_content=system_prompt_content, system_prompt_path=system_prompt_path)

    self.max_history_length = max_history_length if max_history_length is not None else self.DEFAULT_MAX_HISTORY_LENGTH
    self.client = client

    if not self.client:
        _api_key = api_key
        _base_url = base_url
        _api_version = api_version
        _azure_deployment_name = azure_deployment  # Used as the model for Azure

        # Determine if configuring for Azure OpenAI. The environment endpoint
        # is only a fallback: an explicitly passed base_url must not be
        # overridden by AZURE_OPENAI_ENDPOINT happening to be set.
        env_azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
        is_azure = bool(
            _azure_deployment_name
            or (_base_url and "azure.com" in _base_url)
            or (not _base_url and env_azure_endpoint)
        )

        if is_azure:
            _base_url = _base_url or env_azure_endpoint
            _api_key = _api_key or os.environ.get("AZURE_OPENAI_KEY")
            _api_version = _api_version or os.environ.get("AZURE_OPENAI_API_VERSION")
            # For Azure, the model parameter in API calls is the deployment name
            _effective_model_name = _azure_deployment_name or model_name
            if not _base_url or not _api_key or not _api_version or not _effective_model_name:
                raise ValueError("For Azure OpenAI, endpoint, API key, API version, and deployment/model name must be configured.")
            self.client = AzureOpenAI(
                api_key=_api_key,
                azure_endpoint=_base_url,
                api_version=_api_version
            )
            # _configure_model_and_tokens will set self.model to this.
            model_name_for_config = _effective_model_name
            logging.info(f"Initialized AzureOpenAI client for deployment: {model_name_for_config} at {_base_url}")
        else:
            # Standard OpenAI or other OpenAI-compatible (like Gemini via base_url)
            _base_url = _base_url or os.environ.get("OPENAI_API_BASE_URL")
            if not _api_key:
                # Try the Gemini key first when hinted, then the OpenAI key.
                gemini_key = os.environ.get("GEMINI_API_KEY") if is_gemini else None
                _api_key = gemini_key or os.environ.get("OPENAI_API_KEY")

            if not _api_key:
                if not _base_url:
                    # No key and no custom endpoint: nothing usable to talk to.
                    raise ValueError("API key must be provided for OpenAI compatible client if not Azure or local anonymous.")
                # Key-less endpoint (e.g. a local model server). The client
                # constructor rejects a missing key, so pass a placeholder.
                _api_key = "not-needed"
                logging.info(f"No API key configured; using a placeholder key for {_base_url}.")

            self.client = OpenAI(api_key=_api_key, base_url=_base_url)
            model_name_for_config = model_name  # General model_name for non-Azure
            log_msg = f"Initialized OpenAI compatible client. Target URL: {_base_url if _base_url else 'OpenAI default'}."
            logging.info(log_msg)
    else:
        # Client was provided directly
        model_name_for_config = model_name
        logging.info(f"Using provided client: {type(self.client)}")

    # Configure the actual model name and max_tokens for API calls
    self._configure_model_and_tokens(
        model_name_for_config,
        max_tokens_str,
        default_max_tokens=self.DEFAULT_MAX_TOKENS
    )
|
||||||
|
|
||||||
|
def _configure_model_and_tokens(self, model_name: str | None, max_tokens_str: str | None, default_max_tokens: int = 1000):
|
||||||
|
self.model = model_name if model_name else "default-model" # Fallback model name
|
||||||
try:
|
try:
|
||||||
self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens
|
# If max_tokens_str is explicitly "None" or empty, treat as None for API default
|
||||||
|
if max_tokens_str and max_tokens_str.lower() not in ["none", "", "null"]:
|
||||||
|
self.max_tokens = int(max_tokens_str)
|
||||||
|
else:
|
||||||
|
self.max_tokens = None # Use API default by not sending the parameter or sending null
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logging.error(f"Invalid value for max_tokens: {max_tokens_str}. Using default {default_max_tokens}.")
|
logging.warning(f"Invalid value for max_tokens: {max_tokens_str}. Using API default (None). stalwart default was {default_max_tokens}")
|
||||||
self.max_tokens = default_max_tokens
|
self.max_tokens = None # Use API default
|
||||||
logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens}")
|
|
||||||
|
logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}")
|
||||||
|
|
||||||
def get_llm_description(self) -> str:
    """Return a one-line summary of the client class, model and token limit.

    A ``max_tokens`` of ``None`` is shown as "API default".
    """
    client_kind = type(self.client).__name__
    token_limit = "API default" if self.max_tokens is None else self.max_tokens
    return f"Client: {client_kind}, LLM: {self.model}, Max Tokens: {token_limit}"
|
||||||
|
|
||||||
def get_chat_response(self, messages):
|
def get_chat_response(self, messages):
|
||||||
if not self.client:
|
if not self.client:
|
||||||
raise ValueError("OpenAI client not initialized. Subclasses must initialize it.")
|
# This should ideally not be hit if __init__ is successful
|
||||||
|
logging.error("OpenAI client not initialized before get_chat_response.")
|
||||||
|
raise ValueError("OpenAI client not initialized.")
|
||||||
try:
|
try:
|
||||||
|
# Pass self.max_tokens directly. If None, OpenAI library omits it or handles it.
|
||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
@@ -38,32 +120,33 @@ class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot):
|
|||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"API call failed: {e}")
|
logging.error(f"API call to model {self.model} failed: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def handle_message(self, user_id, user_message):
|
async def handle_message(self, user_id, user_message):
|
||||||
if user_id not in self.conversation_history:
|
if user_id not in self.conversation_history or not self.conversation_history[user_id]:
|
||||||
self.conversation_history[user_id] = []
|
self.conversation_history[user_id] = []
|
||||||
if hasattr(self, 'system_prompt') and self.system_prompt:
|
if self.system_prompt: # Use the loaded system_prompt
|
||||||
self.conversation_history[user_id].append({"role": "system", "content": self.system_prompt})
|
self.conversation_history[user_id].append({"role": "system", "content": self.system_prompt})
|
||||||
|
|
||||||
self.conversation_history[user_id].append({"role": "user", "content": user_message})
|
self.conversation_history[user_id].append({"role": "user", "content": user_message})
|
||||||
messages = self.conversation_history[user_id]
|
messages = list(self.conversation_history[user_id]) # Work with a copy for this turn
|
||||||
|
|
||||||
response = self.get_chat_response(messages)
|
response = self.get_chat_response(messages)
|
||||||
|
|
||||||
if not (response.choices and response.choices[0].message):
|
if not (response.choices and response.choices[0].message):
|
||||||
logging.error("No valid response choice message from LLM.")
|
logging.error("No valid response choice message from LLM.")
|
||||||
|
# Persist the user message in history even if LLM fails this turn
|
||||||
|
self.conversation_history[user_id] = messages
|
||||||
return "Error: Could not get a valid response from the LLM."
|
return "Error: Could not get a valid response from the LLM."
|
||||||
|
|
||||||
messages.append(response.choices[0].message) # Append the assistant's response message
|
assistant_message = response.choices[0].message
|
||||||
|
messages.append(assistant_message)
|
||||||
|
|
||||||
tool_calls_from_response = []
|
tool_calls_from_response = list(assistant_message.tool_calls) if assistant_message.tool_calls else []
|
||||||
if response.choices[0].message.tool_calls:
|
|
||||||
tool_calls_from_response.extend(response.choices[0].message.tool_calls)
|
|
||||||
|
|
||||||
tool_use_count = 0
|
tool_use_count = 0
|
||||||
MAX_TOOL_ITERATIONS = 200
|
MAX_TOOL_ITERATIONS = 5 # OpenAI compatible typically uses fewer iterations than Anthropic
|
||||||
|
|
||||||
while tool_calls_from_response and tool_use_count < MAX_TOOL_ITERATIONS:
|
while tool_calls_from_response and tool_use_count < MAX_TOOL_ITERATIONS:
|
||||||
tool_results_for_model = []
|
tool_results_for_model = []
|
||||||
@@ -71,20 +154,24 @@ class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot):
|
|||||||
for tool_call in tool_calls_from_response:
|
for tool_call in tool_calls_from_response:
|
||||||
tool_call_id = tool_call.id
|
tool_call_id = tool_call.id
|
||||||
function_to_call = tool_call.function
|
function_to_call = tool_call.function
|
||||||
|
function_name = function_to_call.name
|
||||||
|
function_args_str = function_to_call.arguments
|
||||||
|
|
||||||
logging.info(f"Attempting to call tool: {function_to_call.name} with args: {function_to_call.arguments}")
|
logging.info(f"Attempting to call tool: {function_name} with args: {function_args_str}")
|
||||||
try:
|
try:
|
||||||
tool_response_content = self.call_tool(function_to_call.name, function_to_call.arguments)
|
# Arguments are already a string from the API, self.call_tool expects dict or string
|
||||||
|
tool_response_content = self.call_tool(function_name, function_args_str)
|
||||||
|
# Ensure content is string for OpenAI tool role
|
||||||
if not isinstance(tool_response_content, str):
|
if not isinstance(tool_response_content, str):
|
||||||
tool_response_content = json.dumps(tool_response_content)
|
tool_response_content = json.dumps(tool_response_content)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error calling tool {function_to_call.name}: {e}")
|
logging.error(f"Error calling tool {function_name}: {e}")
|
||||||
tool_response_content = f"Error executing tool {function_to_call.name}: {str(e)}"
|
tool_response_content = f"Error executing tool {function_name}: {str(e)}"
|
||||||
|
|
||||||
tool_results_for_model.append({
|
tool_results_for_model.append({
|
||||||
"role": "tool",
|
"role": "tool",
|
||||||
"tool_call_id": tool_call_id,
|
"tool_call_id": tool_call_id,
|
||||||
"name": function_to_call.name,
|
"name": function_name,
|
||||||
"content": tool_response_content
|
"content": tool_response_content
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -93,40 +180,50 @@ class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot):
|
|||||||
response = self.get_chat_response(messages)
|
response = self.get_chat_response(messages)
|
||||||
if not (response.choices and response.choices[0].message):
|
if not (response.choices and response.choices[0].message):
|
||||||
logging.error("No valid response choice message from LLM after tool call.")
|
logging.error("No valid response choice message from LLM after tool call.")
|
||||||
|
self.conversation_history[user_id] = messages # Persist state before error
|
||||||
return "Error: Could not get a valid response from the LLM after tool call."
|
return "Error: Could not get a valid response from the LLM after tool call."
|
||||||
|
|
||||||
messages.append(response.choices[0].message)
|
assistant_message = response.choices[0].message
|
||||||
|
messages.append(assistant_message)
|
||||||
|
|
||||||
tool_calls_from_response = []
|
tool_calls_from_response = list(assistant_message.tool_calls) if assistant_message.tool_calls else []
|
||||||
if response.choices[0].message.tool_calls:
|
|
||||||
tool_calls_from_response.extend(response.choices[0].message.tool_calls)
|
|
||||||
|
|
||||||
tool_use_count += 1
|
tool_use_count += 1
|
||||||
if tool_use_count >= MAX_TOOL_ITERATIONS and tool_calls_from_response:
|
if tool_use_count >= MAX_TOOL_ITERATIONS and tool_calls_from_response:
|
||||||
logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.")
|
logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.")
|
||||||
|
# Ensure final content is returned even if max iterations hit with pending tool calls
|
||||||
|
break
|
||||||
|
|
||||||
# Conversation history management
|
self.conversation_history[user_id] = messages # Persist the full exchange for this turn
|
||||||
# This limit should be reviewed and potentially made configurable
|
# Apply history length limit
|
||||||
if len(self.conversation_history[user_id]) > 20: # Example limit, adjust as needed
|
if len(self.conversation_history[user_id]) > self.max_history_length:
|
||||||
self.conversation_history[user_id] = self.conversation_history[user_id][-20:]
|
# Keep system prompt if present as the first message, then trim the rest
|
||||||
|
if self.conversation_history[user_id][0]["role"] == "system":
|
||||||
|
system_msg = [self.conversation_history[user_id][0]]
|
||||||
|
trimmed_history = self.conversation_history[user_id][-(self.max_history_length-1):]
|
||||||
|
self.conversation_history[user_id] = system_msg + trimmed_history
|
||||||
|
else:
|
||||||
|
self.conversation_history[user_id] = self.conversation_history[user_id][-self.max_history_length:]
|
||||||
|
|
||||||
final_assistant_message = messages[-1]
|
final_assistant_message = messages[-1]
|
||||||
return final_assistant_message.content if final_assistant_message.role == "assistant" and final_assistant_message.content else "No content in final message."
|
return final_assistant_message.content if final_assistant_message.role == "assistant" and final_assistant_message.content is not None else "Assistant did not provide a textual response."
|
||||||
|
|
||||||
async def start(self):
    """Log that the bot has started, naming its concrete class and model."""
    bot_name = self.__class__.__name__
    logging.info(f"{bot_name} (Model: {self.model}) started.")
|
||||||
|
|
||||||
def clear(self, user_id):
|
# clear_conversation_history is inherited from BaseTelegramInferenceBot
|
||||||
super().clear_conversation_history(user_id)
|
|
||||||
|
|
||||||
async def abort_processing(self, user_id):
    """Soft-abort any in-flight processing for ``user_id``.

    Conversation history is left untouched in both cases; the returned text
    tells the user how to clear it explicitly.

    Returns:
        A user-facing status message.
    """
    # Nothing is tracked for this user: report that and leave state alone.
    if user_id not in self.processing_status:
        return "No active processing found to abort. If you wish, /clear the conversation history."

    # API calls run synchronously inside handle_message, so the abort only
    # resets the processing status rather than interrupting a request.
    self.clear_processing_status(user_id)
    logging.info(f"Processing aborted for user {user_id}.")
    return "Processing aborted. You can send a new message or /clear the conversation."
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def switch_model(self):
|
async def switch_model(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user