Refactor to use AI provider classes

This commit is contained in:
2024-08-18 12:54:11 -05:00
parent 731e655ce1
commit a7abc5ebd0
+35 -117
View File
@@ -3,34 +3,15 @@ import os
import importlib import importlib
import inspect import inspect
import logging import logging
import anthropic
from telegram import Update from telegram import Update
from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes
from openai import OpenAI
from dotenv import load_dotenv from dotenv import load_dotenv
from tools.base_tool import BaseTool from tools.base_tool import BaseTool
from ai_providers import create_ai_provider
# Load environment variables # Load environment variables
load_dotenv() load_dotenv()
openai_client = OpenAI()
anthropic_client = anthropic.Anthropic(
api_key=os.environ.get("ANTHROPIC_API_KEY"),
default_headers={"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"}
)
GPT_4O = "gpt-4o"
GPT_4O_MINI = "gpt-4o-mini"
model_max_tokens = {
GPT_4O: 4096,
GPT_4O_MINI: 16384
}
use_smart_model = True
use_anthropic = True
# Set up logging to console and file # Set up logging to console and file
logging.basicConfig(level=logging.WARNING, handlers=[ logging.basicConfig(level=logging.WARNING, handlers=[
logging.StreamHandler(), logging.StreamHandler(),
@@ -63,6 +44,9 @@ functions = []
for tool in tools: for tool in tools:
functions.extend(tool.get_functions()) functions.extend(tool.get_functions())
# Initialize AI provider
ai_provider = create_ai_provider("anthropic")
async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
logging.info("Bot started") logging.info("Bot started")
await update.message.reply_text("Hello! I'm your AI assistant. How can I help you today? You can send me images and then ask questions about them.") await update.message.reply_text("Hello! I'm your AI assistant. How can I help you today? You can send me images and then ask questions about them.")
@@ -91,138 +75,72 @@ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) ->
messages = conversation_history[user_id] messages = conversation_history[user_id]
response = get_chat_response(messages) response = ai_provider.get_chat_response([{"role": "system", "content": system_prompt}] + messages)
tool_calls = [] tool_calls = ai_provider.format_tool_calls(response)
if use_anthropic:
for message in response.content:
if message.type == "tool_use":
tool_calls.append(message)
else:
messages.append({"role": "assistant", "content": response.content})
else:
assistant_message = response.choices[0].message
if hasattr(assistant_message, 'function_call') and assistant_message.function_call is not None:
tool_calls.append(assistant_message.function_call)
toolUseCount = 0 toolUseCount = 0
while len(tool_calls) > 0 and toolUseCount < 50: while len(tool_calls) > 0 and toolUseCount < 50:
tool_call = tool_calls.pop(0) tool_call = tool_calls.pop(0)
function_name = tool_call.name function_name = tool_call.name
tool_response = call_tool(tool_call) tool_response = call_tool(tool_call)
formatted_result = {} formatted_result = ai_provider.format_tool_result(tool_call, tool_response)
if use_anthropic:
formatted_result = {"role": "user", "content":[{"type": "tool_result", "tool_use_id": tool_call.id, "content": json.dumps(tool_response)}]}
else:
formatted_result = {"role": "function", "name": function_name, "content": json.dumps(tool_response)}
messages.append(formatted_result) messages.append(formatted_result)
response = get_chat_response(messages) response = ai_provider.get_chat_response([{"role": "system", "content": system_prompt}] + messages)
assistant_message = "" tool_calls = ai_provider.format_tool_calls(response)
if use_anthropic:
for message in response.content:
if message.type == "tool_use":
tool_calls.append(message)
else:
messages.append({"role": "assistant", "content": response.content})
else:
assistant_message = response.choices[0].message
conversation_history[user_id].append({"role": "assistant", "content": assistant_message})
if hasattr(assistant_message, 'function_call') and assistant_message.function_call is not None:
tool_calls.append(assistant_message.function_call)
assistant_reply = assistant_message
toolUseCount += 1 toolUseCount += 1
if (toolUseCount == 0): if toolUseCount == 0:
if use_anthropic: assistant_reply = ai_provider.format_assistant_reply(response)
assistant_reply = response.content
else:
assistant_reply = assistant_message
conversation_history[user_id].append({"role": "assistant", "content": assistant_reply}) conversation_history[user_id].append({"role": "assistant", "content": assistant_reply})
if len(conversation_history[user_id]) > 20: if len(conversation_history[user_id]) > 20:
conversation_history[user_id] = conversation_history[user_id][-20:] conversation_history[user_id] = conversation_history[user_id][-20:]
if use_anthropic: await update.message.reply_text(ai_provider.get_reply_text(response))
await update.message.reply_text(messages[-1]["content"][0].text)
else:
await update.message.reply_text(assistant_reply.content)
except Exception as e: except Exception as e:
logging.error(f"An error occurred: {str(e)}") logging.error(f"An error occurred: {str(e)}")
await update.message.reply_text("Sorry, an error occurred while processing your request.") await update.message.reply_text("Sorry, an error occurred while processing your request.")
def call_tool(function_call): def call_tool(function_call):
function_name = function_call.name if use_anthropic else function_call.name function_name = function_call.name
function_args = json.dumps(function_call.input) if use_anthropic else function_call.arguments function_args = json.loads(function_call.arguments if hasattr(function_call, 'arguments') else json.dumps(function_call.input))
for tool in tools: for tool in tools:
if function_name in [f["name"] for f in tool.get_functions()]: if function_name in [f["name"] for f in tool.get_functions()]:
return tool.execute(function_name, **json.loads(function_args)) return tool.execute(function_name, **function_args)
def get_chat_response(messages):
return get_claude_response(messages) if use_anthropic else get_openai_response(messages)
def get_openai_response(messages):
model = GPT_4O if use_smart_model else GPT_4O_MINI
response = openai_client.chat.completions.create(
model=model,
messages = [{"role": "system", "content": system_prompt}] + messages,
functions=functions,
function_call="auto",
max_tokens=model_max_tokens[model]
)
return response
def get_claude_response(messages):
anthropic_tools = [
{
"name": function['name'],
"description": function['description'],
"input_schema": function['parameters'] if function['parameters'] not in [None, {}] else {"type": "object", "properties": {"param1": {"type": "string", "description": "Unnecessary"}}, "required": []}
}
for function in functions
]
try:
response = anthropic_client.messages.create(
model="claude-3-5-sonnet-20240620",
system=system_prompt,
messages=messages,
max_tokens=8192,
tools=anthropic_tools
)
except Exception as e:
logging.error(f"An error occurred: {str(e)}")
return None
return response
async def switch(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def switch(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
global use_smart_model global ai_provider
use_smart_model = not use_smart_model if isinstance(ai_provider, OpenAIProvider):
model = GPT_4O if use_smart_model else GPT_4O_MINI ai_provider.use_smart_model = not ai_provider.use_smart_model
logging.info(f"Switched to model: {model}") model = ai_provider.get_model()
await update.message.reply_text(f"Switched to model: {model}") logging.info(f"Switched to model: {model}")
await update.message.reply_text(f"Switched to model: {model}")
else:
await update.message.reply_text("Switching models is only available for OpenAI provider.")
async def switch_providers(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def switch_providers(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
await clear(update, context) await clear(update, context)
global use_anthropic global ai_provider
use_anthropic = not use_anthropic if isinstance(ai_provider, AnthropicProvider):
logging.info("Using Anthropic" if use_anthropic else "Using OpenAI") ai_provider = create_ai_provider("openai")
await update.message.reply_text("Using Anthropic" if use_anthropic else "Using OpenAI") logging.info("Switched to OpenAI provider")
await update.message.reply_text("Switched to OpenAI provider")
else:
ai_provider = create_ai_provider("anthropic")
logging.info("Switched to Anthropic provider")
await update.message.reply_text("Switched to Anthropic provider")
async def status(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def status(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
if use_anthropic: if isinstance(ai_provider, AnthropicProvider):
await update.message.reply_text("Currently using claude-3-5-sonnet-20240620") await update.message.reply_text(f"Currently using Anthropic: {ai_provider.model}")
else: else:
model = GPT_4O if use_smart_model else GPT_4O_MINI await update.message.reply_text(f"Currently using OpenAI: {ai_provider.get_model()}")
await update.message.reply_text(f"Currently using: {model}")
def main() -> None: def main() -> None:
# Create the Application and pass it your bot's token # Create the Application and pass it your bot's token