Added tests
This commit is contained in:
+63
-11
@@ -1,6 +1,8 @@
|
||||
import os
|
||||
import importlib
|
||||
import inspect
|
||||
import tempfile
|
||||
import base64
|
||||
from telegram import Update
|
||||
from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes
|
||||
from openai import OpenAI
|
||||
@@ -12,12 +14,17 @@ load_dotenv()
|
||||
|
||||
client = OpenAI()
|
||||
|
||||
GPT_4O = "gpt-4o"
|
||||
GPT_4O_MINI = "gpt-4o-mini"
|
||||
# Set up Telegram bot
|
||||
TELEGRAM_BOT_TOKEN = os.getenv('TELEGRAM_BOT_TOKEN')
|
||||
|
||||
# Dictionary to store conversation history for each user
|
||||
conversation_history = {}
|
||||
|
||||
# Dictionary to store the last image file for each user
|
||||
user_images = {}
|
||||
|
||||
# Load tools
|
||||
tools = []
|
||||
tools_dir = os.path.join(os.path.dirname(__file__), 'tools')
|
||||
@@ -35,13 +42,32 @@ for tool in tools:
|
||||
functions.extend(tool.get_functions())
|
||||
|
||||
async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
await update.message.reply_text("Hello! I'm your AI assistant. How can I help you today?")
|
||||
await update.message.reply_text("Hello! I'm your AI assistant. How can I help you today? You can send me images and then ask questions about them.")
|
||||
|
||||
async def clear(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
user_id = update.effective_user.id
|
||||
if user_id in conversation_history:
|
||||
del conversation_history[user_id]
|
||||
await update.message.reply_text("Conversation history cleared. Let's start fresh!")
|
||||
if user_id in user_images:
|
||||
os.remove(user_images[user_id])
|
||||
del user_images[user_id]
|
||||
await update.message.reply_text("Conversation history and image cleared. Let's start fresh!")
|
||||
|
||||
async def handle_image(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
user_id = update.effective_user.id
|
||||
|
||||
# Get the largest available photo
|
||||
photo = max(update.message.photo, key=lambda x: x.file_size)
|
||||
|
||||
# Download the photo
|
||||
photo_file = await context.bot.get_file(photo.file_id)
|
||||
|
||||
# Create a temporary file to store the image
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as temp_file:
|
||||
await photo_file.download_to_drive(custom_path=temp_file.name)
|
||||
user_images[user_id] = temp_file.name
|
||||
|
||||
await update.message.reply_text("I've received your image. What would you like to know about it?")
|
||||
|
||||
async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
try:
|
||||
@@ -58,18 +84,43 @@ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) ->
|
||||
# Prepare messages for OpenAI API
|
||||
messages = [{"role": "system", "content": "You are a helpful assistant."}] + conversation_history[user_id]
|
||||
|
||||
# Call OpenAI API for inference
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-4",
|
||||
messages=messages,
|
||||
functions=functions,
|
||||
function_call="auto"
|
||||
)
|
||||
# Check if there's an image to process
|
||||
if user_id in user_images:
|
||||
with open(user_images[user_id], "rb") as image_file:
|
||||
response = client.chat.completions.create(
|
||||
model=GPT_4O_MINI,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": user_message},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64.b64encode(image_file.read()).decode('utf-8')}"
|
||||
}
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
max_tokens=2048,
|
||||
)
|
||||
# Remove the temporary image file
|
||||
os.remove(user_images[user_id])
|
||||
del user_images[user_id]
|
||||
else:
|
||||
# Call OpenAI API for inference (text-only)
|
||||
response = client.chat.completions.create(
|
||||
model=GPT_4O,
|
||||
messages=messages,
|
||||
functions=functions,
|
||||
function_call="auto"
|
||||
)
|
||||
|
||||
# Extract the assistant's reply
|
||||
assistant_message = response.choices[0].message
|
||||
|
||||
if assistant_message.function_call:
|
||||
if hasattr(assistant_message, 'function_call') and assistant_message.function_call:
|
||||
# Execute the function
|
||||
function_name = assistant_message.function_call.name
|
||||
function_args = assistant_message.function_call.arguments
|
||||
@@ -83,7 +134,7 @@ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) ->
|
||||
})
|
||||
# Call API again to get the final response
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-4",
|
||||
model=GPT_4O,
|
||||
messages=messages
|
||||
)
|
||||
assistant_reply = response.choices[0].message.content
|
||||
@@ -112,6 +163,7 @@ def main() -> None:
|
||||
# Add handlers
|
||||
application.add_handler(CommandHandler("start", start))
|
||||
application.add_handler(CommandHandler("clear", clear))
|
||||
application.add_handler(MessageHandler(filters.PHOTO, handle_image))
|
||||
application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, handle_message))
|
||||
|
||||
# Start the Bot
|
||||
|
||||
Reference in New Issue
Block a user