2025-06-02 14:56:23 -05:00
|
|
|
import json
|
|
|
|
|
import os
|
|
|
|
|
import logging
|
|
|
|
|
from abc import abstractmethod
|
|
|
|
|
from base_telegram_inference_bot import BaseTelegramInferenceBot
|
|
|
|
|
from openai import OpenAI
|
|
|
|
|
|
|
|
|
|
class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot):
    """Telegram inference bot backed by an OpenAI-compatible chat-completions API.

    Subclasses are responsible for creating ``self.client`` (an OpenAI-style
    client whose ``chat.completions.create`` returns a completion directly, i.e.
    a synchronous client such as ``openai.OpenAI``), choosing the model and token
    limit (e.g. via ``_configure_model_and_tokens``) and implementing
    ``switch_model``.

    Optional attributes read by this class:
      * ``self.functions`` - tool schemas passed to the API as ``tools``.
      * ``self.system_prompt`` - seeded as the first message of each new
        conversation.

    ``self.call_tool``, ``self.conversation_history`` and
    ``self.processing_status`` are used but not defined here; presumably they
    come from the base class (verify against ``BaseTelegramInferenceBot``).
    """

    # Maximum number of model -> tool -> model round trips for one user message.
    # Subclasses may override.
    max_tool_iterations = 200
    # Maximum number of messages kept per user after each turn. A leading
    # system prompt is always kept and counts toward this limit.
    # Subclasses may override.
    max_history_messages = 20

    # Content sent back to the model for tool calls that were never executed
    # because max_tool_iterations was reached.
    _SKIPPED_TOOL_CALL_CONTENT = "Tool call not executed: maximum tool iterations reached."

    def __init__(self):
        super().__init__()
        # Client and model configuration will be handled by subclasses
        self.client = None
        self.model = None
        self.max_tokens = None

    def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=1000):
        """Set ``self.model`` and ``self.max_tokens`` from raw configuration values.

        Args:
            model_name: Model identifier; falls back to ``"default-model"`` when falsy.
            max_tokens_str: Token limit as a string (or anything ``int()`` accepts).
                ``None`` selects ``default_max_tokens`` silently.
            default_max_tokens: Fallback used when ``max_tokens_str`` is ``None``,
                not an integer, or not positive.
        """
        self.model = model_name if model_name else "default-model"

        if max_tokens_str is None:
            self.max_tokens = default_max_tokens
        else:
            try:
                parsed = int(max_tokens_str)
            except (TypeError, ValueError):
                parsed = None
            # A zero/negative token limit would only fail later at request time.
            if parsed is None or parsed <= 0:
                logging.error(f"Invalid value for max_tokens: {max_tokens_str}. Using default {default_max_tokens}.")
                self.max_tokens = default_max_tokens
            else:
                self.max_tokens = parsed

        logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens}")

    def get_llm_description(self) -> str:
        """Return a short human-readable description of the configured model."""
        return f"LLM: {self.model}, Max Tokens: {self.max_tokens}"

    def get_chat_response(self, messages):
        """Request one chat completion for ``messages`` and return the raw response.

        ``tools``/``tool_choice`` are only sent when ``self.functions`` is
        non-empty, and ``max_tokens`` only when configured. Optional keys are
        omitted rather than sent as JSON ``null``, which some OpenAI-compatible
        servers reject.

        Raises:
            ValueError: if ``self.client`` has not been initialized.
            Exception: any error raised by the client is logged and re-raised.
        """
        if not self.client:
            raise ValueError("OpenAI client not initialized. Subclasses must initialize it.")

        request = {"model": self.model, "messages": messages}
        tools = getattr(self, "functions", None)
        if tools:
            request["tools"] = tools
            request["tool_choice"] = "auto"
        if self.max_tokens is not None:
            request["max_tokens"] = self.max_tokens

        try:
            return self.client.chat.completions.create(**request)
        except Exception as e:
            # logging.exception keeps the traceback, which the bare message lost.
            logging.exception(f"API call failed: {e}")
            raise

    @staticmethod
    def _message_role(message):
        """Return the ``role`` of a history entry (plain dict or SDK message object)."""
        if isinstance(message, dict):
            return message.get("role")
        return getattr(message, "role", None)

    @staticmethod
    def _first_choice_message(response):
        """Return the first choice's message from a completion, or ``None`` if absent."""
        if response.choices and response.choices[0].message:
            return response.choices[0].message
        return None

    def _execute_tool_calls(self, tool_calls):
        """Run each requested tool and return the ``role: "tool"`` messages for the model.

        Tool failures are deliberately not raised: the error text is returned to
        the model as the tool result so it can recover or explain.
        """
        results = []
        for tool_call in tool_calls:
            function = tool_call.function
            logging.info(f"Attempting to call tool: {function.name} with args: {function.arguments}")
            try:
                content = self.call_tool(function.name, function.arguments)
                if not isinstance(content, str):
                    content = json.dumps(content)
            except Exception as e:
                logging.exception(f"Error calling tool {function.name}: {e}")
                content = f"Error executing tool {function.name}: {str(e)}"
            results.append({
                "role": "tool",
                "tool_call_id": tool_call.id,
                "name": function.name,
                "content": content,
            })
        return results

    def _skipped_tool_results(self, tool_calls):
        """Build placeholder tool results for calls that were never executed.

        Every assistant ``tool_calls`` entry needs a matching tool message, or
        the next request built from this history is rejected by the API.
        """
        return [
            {
                "role": "tool",
                "tool_call_id": tool_call.id,
                "name": tool_call.function.name,
                "content": self._SKIPPED_TOOL_CALL_CONTENT,
            }
            for tool_call in tool_calls
        ]

    def _run_conversation_turn(self, messages):
        """Drive the model/tool loop for one user turn, appending to ``messages`` in place.

        Returns the final assistant text, or an error string when the model
        returns no usable message.
        """
        response = self.get_chat_response(messages)
        assistant_message = self._first_choice_message(response)
        if assistant_message is None:
            logging.error("No valid response choice message from LLM.")
            return "Error: Could not get a valid response from the LLM."
        messages.append(assistant_message)  # Append the assistant's response message

        pending_tool_calls = list(assistant_message.tool_calls or [])
        max_iterations = self.max_tool_iterations
        iterations = 0
        while pending_tool_calls and iterations < max_iterations:
            messages.extend(self._execute_tool_calls(pending_tool_calls))

            response = self.get_chat_response(messages)
            assistant_message = self._first_choice_message(response)
            if assistant_message is None:
                logging.error("No valid response choice message from LLM after tool call.")
                return "Error: Could not get a valid response from the LLM after tool call."
            messages.append(assistant_message)

            pending_tool_calls = list(assistant_message.tool_calls or [])
            iterations += 1

        if pending_tool_calls:
            # The loop stopped on the iteration cap, not because the model finished.
            logging.warning(f"Max tool iterations ({max_iterations}) reached. Returning last assistant message.")
            messages.extend(self._skipped_tool_results(pending_tool_calls))

        if assistant_message.role == "assistant" and assistant_message.content:
            return assistant_message.content
        return "No content in final message."

    def _trim_conversation_history(self, user_id):
        """Cap the stored history for ``user_id`` at ``max_history_messages``.

        Unlike a plain ``[-limit:]`` slice, this keeps a leading system prompt.
        It also never lets the kept window start with ``tool`` messages whose
        originating assistant ``tool_calls`` message was cut off, because the
        API rejects such orphaned tool results.
        """
        if user_id not in self.conversation_history:
            return
        history = self.conversation_history[user_id]
        limit = self.max_history_messages
        if len(history) <= limit:
            return

        pinned = history[:1] if history and self._message_role(history[0]) == "system" else []
        budget = max(limit - len(pinned), 0)
        # Guard against budget == 0: history[-0:] would return the whole list.
        window = history[len(pinned):][-budget:] if budget else []

        start = 0
        while start < len(window) and self._message_role(window[start]) == "tool":
            start += 1
        self.conversation_history[user_id] = pinned + window[start:]

    async def handle_message(self, user_id, user_message):
        """Send ``user_message`` to the LLM for ``user_id`` and return the reply text.

        Creates the user's history on first contact, seeding it with
        ``self.system_prompt`` when set. Tool calls requested by the model are
        executed, up to ``max_tool_iterations`` rounds. History is trimmed to
        ``max_history_messages`` afterwards, including on error returns.

        NOTE: ``get_chat_response`` and ``call_tool`` are called without
        ``await``, so the event loop is blocked while each LLM/tool call runs.

        Raises:
            Exception: errors from ``get_chat_response`` propagate to the caller.
        """
        if user_id not in self.conversation_history:
            self.conversation_history[user_id] = []
            system_prompt = getattr(self, "system_prompt", None)
            if system_prompt:
                self.conversation_history[user_id].append({"role": "system", "content": system_prompt})

        messages = self.conversation_history[user_id]
        messages.append({"role": "user", "content": user_message})

        try:
            return self._run_conversation_turn(messages)
        finally:
            self._trim_conversation_history(user_id)

    async def start(self):
        """Log that the bot has started."""
        logging.info(f"{self.__class__.__name__} started.")

    async def clear(self, user_id):
        """Clear ``user_id``'s conversation via the base class's ``clear_conversation_history``."""
        super().clear_conversation_history(user_id)

    async def abort_processing(self, user_id):
        """Flag ``user_id``'s processing as stopped (if tracked) and clear the conversation.

        Only ``processing_status[user_id]["processing"]`` is written here.
        ``handle_message`` in this class never reads that flag, so whether it
        interrupts in-flight work depends on code not shown here.
        """
        was_tracked = user_id in self.processing_status
        if was_tracked:
            self.processing_status[user_id]["processing"] = False
        await self.clear(user_id)
        if was_tracked:
            return "Processing aborted and conversation cleared."
        return "No active processing found to abort. Conversation cleared."

    @abstractmethod
    async def switch_model(self):
        """Switch the active model; must be implemented by subclasses.

        Enforcement of ``@abstractmethod`` requires the class's metaclass to be
        ``ABCMeta``; presumably inherited from the base class (verify).
        """
        pass
|