From 9b16ca0d85303637eaaf3bee293ff4665643d089 Mon Sep 17 00:00:00 2001 From: bucolucas Date: Mon, 19 Aug 2024 11:35:10 -0500 Subject: [PATCH] Refactor ChatGPTTelegramInferenceBot to inherit from BaseTelegramInferenceBot --- chatgpt_telegram_inference_bot.py | 250 ++++++++---------------------- 1 file changed, 61 insertions(+), 189 deletions(-) diff --git a/chatgpt_telegram_inference_bot.py b/chatgpt_telegram_inference_bot.py index f7c6052..6cb5e02 100644 --- a/chatgpt_telegram_inference_bot.py +++ b/chatgpt_telegram_inference_bot.py @@ -1,225 +1,97 @@ import json import os -import importlib -import inspect import logging -import asyncio -from telegram import error as TelegramErrors, Update, __version__ as telegram_version, InlineKeyboardButton, InlineKeyboardMarkup -from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes, CallbackQueryHandler -from dotenv import load_dotenv -from tools.base_tool import BaseTool -from tools.metrics_tool import MetricsTool +from base_telegram_inference_bot import BaseTelegramInferenceBot from openai import OpenAI -# Load environment variables -load_dotenv() +class ChatGPTTelegramInferenceBot(BaseTelegramInferenceBot): + def __init__(self): + super().__init__() + self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) + self.model = "gpt-4o-mini" + self.max_tokens = 16384 -client = OpenAI( - api_key=os.environ.get("OPENAI_API_KEY"), -) + def get_chat_response(self, messages): + response = self.client.chat.completions.create( + model=self.model, + messages=[{"role": "system", "content": self.system_prompt}] + messages, + functions=self.functions, + function_call="auto", + max_tokens=self.max_tokens + ) + return response -GPT_4O = "gpt-4o" -GPT_4O_MINI = "gpt-4o-mini" + async def handle_message(self, user_id, user_message): + if user_id not in self.conversation_history: + self.conversation_history[user_id] = [] -model_max_tokens = { - GPT_4O: 4096, - GPT_4O_MINI: 16384 -} + self.conversation_history[user_id].append({"role": "user", "content": user_message}) + messages = self.conversation_history[user_id] -use_smart_model = False - -# Set up logging to console and file -logging.basicConfig(level=logging.WARNING, handlers=[ - logging.StreamHandler(), - logging.FileHandler('logs/output.log', mode='a') -]) - -# Set up Telegram bot -TELEGRAM_BOT_TOKEN = os.getenv('TELEGRAM_BOT_TOKEN') - -# Load system prompt -with open("prompts/developer_prompt.txt", "r") as file: - system_prompt = file.read().strip() - -# Dictionary to store conversation history for each user -conversation_history = {} - -# Dictionary to store processing status for each user -processing_status = {} - -# Load tools -tools = [MetricsTool()] # Add MetricsTool instance -tools_dir = os.path.join(os.path.dirname(__file__), 'tools') -for filename in os.listdir(tools_dir): - if filename.endswith('.py') and filename not in ['__init__.py', 'base_tool.py', 'metrics_tool.py']: - module_name = f'tools.{filename[:-3]}' - module = importlib.import_module(module_name) - for name, obj in inspect.getmembers(module): - if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool: - tools.append(obj()) - -# Collect all function definitions -functions = [] -for tool in tools: - functions.extend(tool.get_functions()) - -async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - logging.info("Bot started") - await update.message.reply_text( - "Hello! I'm your AI assistant. How can I help you today? You can send me images and then ask questions about them." - ) - -async def clear(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - user_id = update.effective_user.id - if user_id in conversation_history: - del conversation_history[user_id] - for tool in tools: - tool.clear() - - logging.info(f"Cleared conversation history and image for user {user_id}") - await update.message.reply_text("Conversation history and image cleared. Let's start fresh!") - -async def update_status_message(context: ContextTypes.DEFAULT_TYPE, chat_id: int, message_id: int, status: str): - keyboard = [ - [InlineKeyboardButton("Abort", callback_data='abort')] - ] - reply_markup = InlineKeyboardMarkup(keyboard) - await context.bot.edit_message_text( - chat_id=chat_id, - message_id=message_id, - text=f"Current status: {status}", - reply_markup=reply_markup - ) - -async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - try: - user_id = update.effective_user.id - user_message = update.message.text - - logging.info(f"Message from user {user_id}: {user_message}") - - if user_id not in conversation_history: - conversation_history[user_id] = [] - - conversation_history[user_id].append({"role": "user", "content": user_message}) - - # Send initial status message - status_message = await update.message.reply_text("Processing your request...", reply_markup=InlineKeyboardMarkup([[InlineKeyboardButton("Abort", callback_data='abort')]])) - processing_status[user_id] = {"processing": True, "message_id": status_message.message_id} - - messages = conversation_history[user_id] - - response = get_chat_response(messages) + response = self.get_chat_response(messages) assistant_message = response.choices[0].message tool_calls = [] if hasattr(assistant_message, 'function_call') and assistant_message.function_call is not None: tool_calls.append(assistant_message.function_call) - toolUseCount = 0 - previous_function_name = "" - - while len(tool_calls) > 0 and toolUseCount < 50 and processing_status[user_id]["processing"]: + tool_use_count = 0 + while len(tool_calls) > 0 and tool_use_count < 50: tool_use_results = [] - while len(tool_calls) > 0: - tool_call = tool_calls.pop(0) - function_name = tool_call.name - - if function_name != previous_function_name: - # Update status message - await update_status_message(context, update.effective_chat.id, status_message.message_id, f"Using tool: {function_name}") - previous_function_name = function_name - - tool_response = call_tool(tool_call) - tool_use_results.append({"role": "function", "name": function_name, "content": json.dumps(tool_response)}) + for tool_call in tool_calls: + tool_response = self.call_tool(tool_call) + tool_use_results.append({"role": "function", "name": tool_call.name, "content": json.dumps(tool_response)}) messages.extend(tool_use_results) - response = get_chat_response(messages) + response = self.get_chat_response(messages) assistant_message = response.choices[0].message messages.append({"role": "assistant", "content": assistant_message.content}) + tool_calls = [] if hasattr(assistant_message, 'function_call') and assistant_message.function_call is not None: tool_calls.append(assistant_message.function_call) - toolUseCount += 1 + tool_use_count += 1 - if toolUseCount == 0: + if tool_use_count == 0: messages.append({"role": "assistant", "content": assistant_message.content}) - if len(conversation_history[user_id]) > 20: - conversation_history[user_id] = conversation_history[user_id][-20:] + if len(self.conversation_history[user_id]) > 20: + self.conversation_history[user_id] = self.conversation_history[user_id][-20:] - # Remove the status message - await context.bot.delete_message(chat_id=update.effective_chat.id, message_id=status_message.message_id) - del processing_status[user_id] - try: - await update.message.reply_text(messages[-1]["content"]) - except TelegramErrors.BadRequest as e: - logging.error(f"An error occurred when trying to send a message in telegram: {str(e)}") + return messages[-1]["content"] - except Exception as e: - logging.error(f"An error occurred: {str(e)}") - await update.message.reply_text("Sorry, an error occurred while processing your request.") + async def start(self): + logging.info("Bot started") -def call_tool(function_call): - function_name = function_call.name - function_args = json.loads(function_call.arguments) - for tool in tools: - if function_name in [f["name"] for f in tool.get_functions()]: - return tool.execute(function_name, **function_args) + async def clear(self, user_id): + super().clear_conversation(user_id) + logging.info(f"Cleared conversation history for user {user_id}") -def get_chat_response(messages): - model = GPT_4O if use_smart_model else GPT_4O_MINI - response = client.chat.completions.create( - model=model, - messages = [{"role": "system", "content": system_prompt}] + messages, - functions=functions, - function_call="auto", - max_tokens=model_max_tokens[model] - ) - return response + async def status(self): + return f"Currently using: {self.model}" -async def switch(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - global use_smart_model - use_smart_model = not use_smart_model - model = GPT_4O if use_smart_model else GPT_4O_MINI - logging.info(f"Switched to model: {model}") - await update.message.reply_text(f"Switched to model: {model}") + async def abort_processing(self, user_id): + if user_id in self.processing_status: + self.processing_status[user_id]["processing"] = False + await self.clear(user_id) + return "Processing aborted." + else: + return "No active processing to abort." -async def status(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - model = GPT_4O if use_smart_model else GPT_4O_MINI - await update.message.reply_text(f"Currently using: {model}") + async def switch_model(self): + if self.model == "gpt-4o-mini": + self.model = "gpt-4o" + self.max_tokens = 4096 + else: + self.model = "gpt-4o-mini" + self.max_tokens = 16384 + logging.info(f"Switched to model: {self.model}") + return f"Switched to model: {self.model}" -async def abort_processing(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - query = update.callback_query - await query.answer() - - user_id = query.from_user.id - if user_id in processing_status: - processing_status[user_id]["processing"] = False - await context.bot.edit_message_text( - chat_id=query.message.chat_id, - message_id=query.message.message_id, - text="Processing aborted." - ) - await clear(update, context) - else: - await query.edit_message_text(text="No active processing to abort.") - -def main() -> None: - # Create the Application and pass it your bot's token - application = Application.builder().token(TELEGRAM_BOT_TOKEN).build() - - # Add handlers - application.add_handler(CommandHandler("start", start)) - application.add_handler(CommandHandler("clear", clear)) - application.add_handler(CommandHandler("switch", switch)) - application.add_handler(CommandHandler("status", status)) - application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, handle_message)) - application.add_handler(CallbackQueryHandler(abort_processing, pattern='^abort$')) - - # Start the Bot - logging.info("Bot is running...") - application.run_polling() +def main(): + bot = ChatGPTTelegramInferenceBot() + telegram_helper = TelegramHelper(bot) + telegram_helper.run() if __name__ == '__main__': main() \ No newline at end of file