Refactor chatgpt_telegram_inference_bot.py to use OpenAICompatibleInferenceBot.

This commit is contained in:
cyclop-bot
2025-06-02 14:56:28 -05:00
parent 56ffb70af0
commit af8fbfec80
+7 -137
View File
@@ -1,156 +1,27 @@
import json
import os import os
import logging import logging
from base_telegram_inference_bot import BaseTelegramInferenceBot
from telegram_helper import TelegramHelper
from openai import OpenAI from openai import OpenAI
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot
from telegram_helper import TelegramHelper
# logging.basicConfig(level=logging.INFO) # Usually configured in main execution script class ChatGPTTelegramInferenceBot(OpenAICompatibleInferenceBot):
class ChatGPTTelegramInferenceBot(BaseTelegramInferenceBot):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
self._configure_model_and_tokens( self._configure_model_and_tokens(
os.environ.get("OPENAI_SMALL_MODEL", "gpt-3.5-turbo"), # Default to a common small model os.environ.get("OPENAI_SMALL_MODEL", "gpt-3.5-turbo"),
os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS") os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS")
) )
def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=1000):
self.model = model_name if model_name else "gpt-3.5-turbo" # Ensure model has a default
try:
self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens
except ValueError:
logging.error(f"Invalid value for max_tokens: {max_tokens_str}. Using default {default_max_tokens}.")
self.max_tokens = default_max_tokens
logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens}")
def get_system_prompt_description(self) -> str:
system_prompt_path = os.getenv("SYSTEM_PROMPT_PATH")
if system_prompt_path and os.path.isfile(system_prompt_path):
return f"System Prompt File: {os.path.basename(system_prompt_path)}"
elif system_prompt_path: # Path is set but file not found
return f"System Prompt File: {os.path.basename(system_prompt_path)} (Not found at path: {system_prompt_path})"
else: # Path not set
return "System Prompt File: Not configured (SYSTEM_PROMPT_PATH not set)."
def get_llm_description(self) -> str:
return f"LLM: {self.model}, Max Tokens: {self.max_tokens}"
def get_chat_response(self, messages):
try:
response = self.client.chat.completions.create(
model=self.model,
messages=messages,
tools=self.functions if hasattr(self, 'functions') and self.functions else None,
tool_choice="auto" if hasattr(self, 'functions') and self.functions else None,
max_tokens=self.max_tokens
)
return response
except Exception as e:
logging.error(f"OpenAI API call failed: {e}")
raise
async def handle_message(self, user_id, user_message):
if user_id not in self.conversation_history:
self.conversation_history[user_id] = []
if hasattr(self, 'system_prompt') and self.system_prompt:
self.conversation_history[user_id].append({"role": "system", "content": self.system_prompt})
self.conversation_history[user_id].append({"role": "user", "content": user_message})
messages = self.conversation_history[user_id]
response = self.get_chat_response(messages)
if not (response.choices and response.choices[0].message):
logging.error("No valid response choice message from LLM.")
return "Error: Could not get a valid response from the LLM."
messages.append(response.choices[0].message) # Append the assistant's response message
tool_calls_from_response = []
if response.choices[0].message.tool_calls:
tool_calls_from_response.extend(response.choices[0].message.tool_calls)
tool_use_count = 0
MAX_TOOL_ITERATIONS = 5
while tool_calls_from_response and tool_use_count < MAX_TOOL_ITERATIONS:
tool_results_for_model = []
for tool_call in tool_calls_from_response:
tool_call_id = tool_call.id
function_to_call = tool_call.function
logging.info(f"Attempting to call tool: {function_to_call.name} with args: {function_to_call.arguments}")
try:
tool_response_content = self.call_tool(function_to_call.name, function_to_call.arguments)
if not isinstance(tool_response_content, str):
tool_response_content = json.dumps(tool_response_content)
except Exception as e:
logging.error(f"Error calling tool {function_to_call.name}: {e}")
tool_response_content = f"Error executing tool {function_to_call.name}: {str(e)}"
tool_results_for_model.append({
"role": "tool",
"tool_call_id": tool_call_id,
"name": function_to_call.name,
"content": tool_response_content
})
messages.extend(tool_results_for_model)
response = self.get_chat_response(messages)
if not (response.choices and response.choices[0].message):
logging.error("No valid response choice message from LLM after tool call.")
return "Error: Could not get a valid response from the LLM after tool call."
messages.append(response.choices[0].message)
tool_calls_from_response = []
if response.choices[0].message.tool_calls:
tool_calls_from_response.extend(response.choices[0].message.tool_calls)
tool_use_count += 1
if tool_use_count >= MAX_TOOL_ITERATIONS and tool_calls_from_response:
logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.")
if len(self.conversation_history[user_id]) > 20: # This limit seems small, consider increasing
self.conversation_history[user_id] = self.conversation_history[user_id][-20:]
final_assistant_message = messages[-1]
return final_assistant_message.content if final_assistant_message.role == "assistant" and final_assistant_message.content else "No content in final message."
async def start(self):
logging.info("ChatGPT Bot started")
# super().start() if Base class start() has common logic
async def clear(self, user_id):
super().clear_conversation(user_id)
# status() method is inherited from BaseTelegramInferenceBot
async def abort_processing(self, user_id):
if user_id in self.processing_status: # Relies on processing_status from Base
self.processing_status[user_id]["processing"] = False
await self.clear(user_id)
return "Processing aborted and conversation cleared."
else:
await self.clear(user_id)
return "No active processing found to abort. Conversation cleared."
async def switch_model(self): async def switch_model(self):
# Ensure environment variables for model names are set for this to work meaningfully
current_small_model = os.environ.get("OPENAI_SMALL_MODEL", "gpt-3.5-turbo") current_small_model = os.environ.get("OPENAI_SMALL_MODEL", "gpt-3.5-turbo")
current_large_model = os.environ.get("OPENAI_LARGE_MODEL", "gpt-4") # Example large model current_large_model = os.environ.get("OPENAI_LARGE_MODEL", "gpt-4")
# Default to small model if current model is not recognized or if it's the large one if self.model == current_large_model or self.model != current_small_model:
if self.model == current_large_model or self.model != current_small_model :
target_model = current_small_model target_model = current_small_model
target_max_tokens = os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS") target_max_tokens = os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS")
else: # Current is small (or default), switch to large else:
target_model = current_large_model target_model = current_large_model
target_max_tokens = os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS") target_max_tokens = os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS")
@@ -163,7 +34,6 @@ def main():
logging.error("FATAL: OPENAI_API_KEY environment variable not set.") logging.error("FATAL: OPENAI_API_KEY environment variable not set.")
return return
# Configure logging here if it's the main entry point
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
bot = ChatGPTTelegramInferenceBot() bot = ChatGPTTelegramInferenceBot()