Merge pull request #35 from bucolucas/refactor-ai-providers

Refactor AI Providers
This commit is contained in:
2024-08-18 12:56:15 -05:00
committed by GitHub
3 changed files with 186 additions and 117 deletions
+51
View File
@@ -0,0 +1,51 @@
# Telegram Inference Bot Refactoring
This repository contains a refactored version of the Telegram Inference Bot, which now uses a more flexible and maintainable approach for handling different AI providers.
## Changes
1. Introduced an abstract `AIProvider` class and concrete implementations for Anthropic and OpenAI.
2. Refactored the main bot code to use the new AI provider classes.
3. Implemented a factory function `create_ai_provider` for easy provider instantiation.
4. Updated command handlers to work with the new AI provider system.
## How to Use
1. Set up your environment variables in a `.env` file:
```
TELEGRAM_BOT_TOKEN=your_telegram_bot_token
ANTHROPIC_API_KEY=your_anthropic_api_key
OPENAI_API_KEY=your_openai_api_key
```
2. Install the required dependencies:
```
pip install -r requirements.txt
```
3. Run the bot:
```
python telegram_inference_bot.py
```
## Commands
- `/start`: Start the bot and receive a welcome message.
- `/clear`: Clear the conversation history and any stored images.
- `/switch`: Switch between smart and regular models (OpenAI only).
- `/toggle`: Toggle between Anthropic and OpenAI providers.
- `/status`: Display the current AI provider and model being used.
## Extending the Bot
To add a new AI provider:
1. Create a new class in `ai_providers.py` that inherits from `AIProvider`.
2. Implement the required methods: `get_chat_response`, `format_messages`, `format_tool_calls`, etc.
3. Update the `create_ai_provider` function to include the new provider.
## Future Improvements
- Implement more robust error handling and logging.
- Add unit tests for the AI provider classes and main bot functionality.
- Extend the README with more detailed usage instructions and examples.
+100
View File
@@ -0,0 +1,100 @@
import os
import json
import anthropic
from openai import OpenAI
from abc import ABC, abstractmethod
class AIProvider(ABC):
@abstractmethod
def get_chat_response(self, messages):
pass
@abstractmethod
def format_messages(self, messages):
pass
@abstractmethod
def format_tool_calls(self, response):
pass
class AnthropicProvider(AIProvider):
def __init__(self):
self.client = anthropic.Anthropic(
api_key=os.environ.get("ANTHROPIC_API_KEY"),
default_headers={"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"}
)
self.model = "claude-3-5-sonnet-20240620"
def get_chat_response(self, messages):
try:
response = self.client.messages.create(
model=self.model,
system=messages[0]['content'],
messages=self.format_messages(messages[1:]),
max_tokens=8192,
tools=self.format_tools()
)
return response
except Exception as e:
logging.error(f"An error occurred: {str(e)}")
return None
def format_messages(self, messages):
return messages
def format_tool_calls(self, response):
tool_calls = []
for message in response.content:
if message.type == "tool_use":
tool_calls.append(message)
return tool_calls
def format_tools(self):
return [
{
"name": function['name'],
"description": function['description'],
"input_schema": function['parameters'] if function['parameters'] not in [None, {}] else {"type": "object", "properties": {"param1": {"type": "string", "description": "Unnecessary"}}, "required": []}
}
for function in functions # This assumes 'functions' is globally accessible
]
class OpenAIProvider(AIProvider):
def __init__(self, use_smart_model=True):
self.client = OpenAI()
self.use_smart_model = use_smart_model
self.model = self.get_model()
def get_model(self):
return "gpt-4o" if self.use_smart_model else "gpt-4o-mini"
def get_chat_response(self, messages):
response = self.client.chat.completions.create(
model=self.model,
messages=self.format_messages(messages),
functions=functions, # This assumes 'functions' is globally accessible
function_call="auto",
max_tokens=self.get_max_tokens()
)
return response
def format_messages(self, messages):
return messages
def format_tool_calls(self, response):
tool_calls = []
assistant_message = response.choices[0].message
if hasattr(assistant_message, 'function_call') and assistant_message.function_call is not None:
tool_calls.append(assistant_message.function_call)
return tool_calls
def get_max_tokens(self):
return 4096 if self.model == "gpt-4o" else 16384
def create_ai_provider(provider_name="anthropic", use_smart_model=True):
if provider_name.lower() == "anthropic":
return AnthropicProvider()
elif provider_name.lower() == "openai":
return OpenAIProvider(use_smart_model)
else:
raise ValueError(f"Unknown provider: {provider_name}")
+35 -117
View File
@@ -3,34 +3,15 @@ import os
import importlib import importlib
import inspect import inspect
import logging import logging
import anthropic
from telegram import Update from telegram import Update
from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes
from openai import OpenAI
from dotenv import load_dotenv from dotenv import load_dotenv
from tools.base_tool import BaseTool from tools.base_tool import BaseTool
from ai_providers import create_ai_provider
# Load environment variables # Load environment variables
load_dotenv() load_dotenv()
openai_client = OpenAI()
anthropic_client = anthropic.Anthropic(
api_key=os.environ.get("ANTHROPIC_API_KEY"),
default_headers={"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"}
)
GPT_4O = "gpt-4o"
GPT_4O_MINI = "gpt-4o-mini"
model_max_tokens = {
GPT_4O: 4096,
GPT_4O_MINI: 16384
}
use_smart_model = True
use_anthropic = True
# Set up logging to console and file # Set up logging to console and file
logging.basicConfig(level=logging.WARNING, handlers=[ logging.basicConfig(level=logging.WARNING, handlers=[
logging.StreamHandler(), logging.StreamHandler(),
@@ -63,6 +44,9 @@ functions = []
for tool in tools: for tool in tools:
functions.extend(tool.get_functions()) functions.extend(tool.get_functions())
# Initialize AI provider
ai_provider = create_ai_provider("anthropic")
async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
logging.info("Bot started") logging.info("Bot started")
await update.message.reply_text("Hello! I'm your AI assistant. How can I help you today? You can send me images and then ask questions about them.") await update.message.reply_text("Hello! I'm your AI assistant. How can I help you today? You can send me images and then ask questions about them.")
@@ -91,138 +75,72 @@ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) ->
messages = conversation_history[user_id] messages = conversation_history[user_id]
response = get_chat_response(messages) response = ai_provider.get_chat_response([{"role": "system", "content": system_prompt}] + messages)
tool_calls = [] tool_calls = ai_provider.format_tool_calls(response)
if use_anthropic:
for message in response.content:
if message.type == "tool_use":
tool_calls.append(message)
else:
messages.append({"role": "assistant", "content": response.content})
else:
assistant_message = response.choices[0].message
if hasattr(assistant_message, 'function_call') and assistant_message.function_call is not None:
tool_calls.append(assistant_message.function_call)
toolUseCount = 0 toolUseCount = 0
while len(tool_calls) > 0 and toolUseCount < 50: while len(tool_calls) > 0 and toolUseCount < 50:
tool_call = tool_calls.pop(0) tool_call = tool_calls.pop(0)
function_name = tool_call.name function_name = tool_call.name
tool_response = call_tool(tool_call) tool_response = call_tool(tool_call)
formatted_result = {} formatted_result = ai_provider.format_tool_result(tool_call, tool_response)
if use_anthropic:
formatted_result = {"role": "user", "content":[{"type": "tool_result", "tool_use_id": tool_call.id, "content": json.dumps(tool_response)}]}
else:
formatted_result = {"role": "function", "name": function_name, "content": json.dumps(tool_response)}
messages.append(formatted_result) messages.append(formatted_result)
response = get_chat_response(messages) response = ai_provider.get_chat_response([{"role": "system", "content": system_prompt}] + messages)
assistant_message = "" tool_calls = ai_provider.format_tool_calls(response)
if use_anthropic:
for message in response.content:
if message.type == "tool_use":
tool_calls.append(message)
else:
messages.append({"role": "assistant", "content": response.content})
else:
assistant_message = response.choices[0].message
conversation_history[user_id].append({"role": "assistant", "content": assistant_message})
if hasattr(assistant_message, 'function_call') and assistant_message.function_call is not None:
tool_calls.append(assistant_message.function_call)
assistant_reply = assistant_message
toolUseCount += 1 toolUseCount += 1
if (toolUseCount == 0): if toolUseCount == 0:
if use_anthropic: assistant_reply = ai_provider.format_assistant_reply(response)
assistant_reply = response.content
else:
assistant_reply = assistant_message
conversation_history[user_id].append({"role": "assistant", "content": assistant_reply}) conversation_history[user_id].append({"role": "assistant", "content": assistant_reply})
if len(conversation_history[user_id]) > 20: if len(conversation_history[user_id]) > 20:
conversation_history[user_id] = conversation_history[user_id][-20:] conversation_history[user_id] = conversation_history[user_id][-20:]
if use_anthropic: await update.message.reply_text(ai_provider.get_reply_text(response))
await update.message.reply_text(messages[-1]["content"][0].text)
else:
await update.message.reply_text(assistant_reply.content)
except Exception as e: except Exception as e:
logging.error(f"An error occurred: {str(e)}") logging.error(f"An error occurred: {str(e)}")
await update.message.reply_text("Sorry, an error occurred while processing your request.") await update.message.reply_text("Sorry, an error occurred while processing your request.")
def call_tool(function_call): def call_tool(function_call):
function_name = function_call.name if use_anthropic else function_call.name function_name = function_call.name
function_args = json.dumps(function_call.input) if use_anthropic else function_call.arguments function_args = json.loads(function_call.arguments if hasattr(function_call, 'arguments') else json.dumps(function_call.input))
for tool in tools: for tool in tools:
if function_name in [f["name"] for f in tool.get_functions()]: if function_name in [f["name"] for f in tool.get_functions()]:
return tool.execute(function_name, **json.loads(function_args)) return tool.execute(function_name, **function_args)
def get_chat_response(messages):
return get_claude_response(messages) if use_anthropic else get_openai_response(messages)
def get_openai_response(messages):
model = GPT_4O if use_smart_model else GPT_4O_MINI
response = openai_client.chat.completions.create(
model=model,
messages = [{"role": "system", "content": system_prompt}] + messages,
functions=functions,
function_call="auto",
max_tokens=model_max_tokens[model]
)
return response
def get_claude_response(messages):
anthropic_tools = [
{
"name": function['name'],
"description": function['description'],
"input_schema": function['parameters'] if function['parameters'] not in [None, {}] else {"type": "object", "properties": {"param1": {"type": "string", "description": "Unnecessary"}}, "required": []}
}
for function in functions
]
try:
response = anthropic_client.messages.create(
model="claude-3-5-sonnet-20240620",
system=system_prompt,
messages=messages,
max_tokens=8192,
tools=anthropic_tools
)
except Exception as e:
logging.error(f"An error occurred: {str(e)}")
return None
return response
async def switch(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def switch(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
global use_smart_model global ai_provider
use_smart_model = not use_smart_model if isinstance(ai_provider, OpenAIProvider):
model = GPT_4O if use_smart_model else GPT_4O_MINI ai_provider.use_smart_model = not ai_provider.use_smart_model
logging.info(f"Switched to model: {model}") model = ai_provider.get_model()
await update.message.reply_text(f"Switched to model: {model}") logging.info(f"Switched to model: {model}")
await update.message.reply_text(f"Switched to model: {model}")
else:
await update.message.reply_text("Switching models is only available for OpenAI provider.")
async def switch_providers(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def switch_providers(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
await clear(update, context) await clear(update, context)
global use_anthropic global ai_provider
use_anthropic = not use_anthropic if isinstance(ai_provider, AnthropicProvider):
logging.info("Using Anthropic" if use_anthropic else "Using OpenAI") ai_provider = create_ai_provider("openai")
await update.message.reply_text("Using Anthropic" if use_anthropic else "Using OpenAI") logging.info("Switched to OpenAI provider")
await update.message.reply_text("Switched to OpenAI provider")
else:
ai_provider = create_ai_provider("anthropic")
logging.info("Switched to Anthropic provider")
await update.message.reply_text("Switched to Anthropic provider")
async def status(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def status(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
if use_anthropic: if isinstance(ai_provider, AnthropicProvider):
await update.message.reply_text("Currently using claude-3-5-sonnet-20240620") await update.message.reply_text(f"Currently using Anthropic: {ai_provider.model}")
else: else:
model = GPT_4O if use_smart_model else GPT_4O_MINI await update.message.reply_text(f"Currently using OpenAI: {ai_provider.get_model()}")
await update.message.reply_text(f"Currently using: {model}")
def main() -> None: def main() -> None:
# Create the Application and pass it your bot's token # Create the Application and pass it your bot's token