diff --git a/README.md b/README.md new file mode 100644 index 0000000..37adc14 --- /dev/null +++ b/README.md @@ -0,0 +1,51 @@ +# Telegram Inference Bot Refactoring + +This repository contains a refactored version of the Telegram Inference Bot, which now uses a more flexible and maintainable approach for handling different AI providers. + +## Changes + +1. Introduced an abstract `AIProvider` class and concrete implementations for Anthropic and OpenAI. +2. Refactored the main bot code to use the new AI provider classes. +3. Implemented a factory function `create_ai_provider` for easy provider instantiation. +4. Updated command handlers to work with the new AI provider system. + +## How to Use + +1. Set up your environment variables in a `.env` file: + ``` + TELEGRAM_BOT_TOKEN=your_telegram_bot_token + ANTHROPIC_API_KEY=your_anthropic_api_key + OPENAI_API_KEY=your_openai_api_key + ``` + +2. Install the required dependencies: + ``` + pip install -r requirements.txt + ``` + +3. Run the bot: + ``` + python telegram_inference_bot.py + ``` + +## Commands + +- `/start`: Start the bot and receive a welcome message. +- `/clear`: Clear the conversation history and any stored images. +- `/switch`: Switch between smart and regular models (OpenAI only). +- `/toggle`: Toggle between Anthropic and OpenAI providers. +- `/status`: Display the current AI provider and model being used. + +## Extending the Bot + +To add a new AI provider: + +1. Create a new class in `ai_providers.py` that inherits from `AIProvider`. +2. Implement the required methods: `get_chat_response`, `format_messages`, `format_tool_calls`, etc. +3. Update the `create_ai_provider` function to include the new provider. + +## Future Improvements + +- Implement more robust error handling and logging. +- Add unit tests for the AI provider classes and main bot functionality. +- Extend the README with more detailed usage instructions and examples. diff --git a/ai_providers.py b/ai_providers.py new file mode 100644 index 0000000..957e0e2 --- /dev/null +++ b/ai_providers.py @@ -0,0 +1,100 @@ +import os +import json +import anthropic +from openai import OpenAI +from abc import ABC, abstractmethod + +class AIProvider(ABC): + @abstractmethod + def get_chat_response(self, messages): + pass + + @abstractmethod + def format_messages(self, messages): + pass + + @abstractmethod + def format_tool_calls(self, response): + pass + +class AnthropicProvider(AIProvider): + def __init__(self): + self.client = anthropic.Anthropic( + api_key=os.environ.get("ANTHROPIC_API_KEY"), + default_headers={"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"} + ) + self.model = "claude-3-5-sonnet-20240620" + + def get_chat_response(self, messages): + try: + response = self.client.messages.create( + model=self.model, + system=messages[0]['content'], + messages=self.format_messages(messages[1:]), + max_tokens=8192, + tools=self.format_tools() + ) + return response + except Exception as e: + logging.error(f"An error occurred: {str(e)}") + return None + + def format_messages(self, messages): + return messages + + def format_tool_calls(self, response): + tool_calls = [] + for message in response.content: + if message.type == "tool_use": + tool_calls.append(message) + return tool_calls + + def format_tools(self): + return [ + { + "name": function['name'], + "description": function['description'], + "input_schema": function['parameters'] if function['parameters'] not in [None, {}] else {"type": "object", "properties": {"param1": {"type": "string", "description": "Unnecessary"}}, "required": []} + } + for function in functions # This assumes 'functions' is globally accessible + ] + +class OpenAIProvider(AIProvider): + def __init__(self, use_smart_model=True): + self.client = OpenAI() + self.use_smart_model = use_smart_model + self.model = self.get_model() + + def get_model(self): + return "gpt-4o" if self.use_smart_model else "gpt-4o-mini" + + def get_chat_response(self, messages): + response = self.client.chat.completions.create( + model=self.model, + messages=self.format_messages(messages), + functions=functions, # This assumes 'functions' is globally accessible + function_call="auto", + max_tokens=self.get_max_tokens() + ) + return response + + def format_messages(self, messages): + return messages + + def format_tool_calls(self, response): + tool_calls = [] + assistant_message = response.choices[0].message + if hasattr(assistant_message, 'function_call') and assistant_message.function_call is not None: + tool_calls.append(assistant_message.function_call) + return tool_calls + + def get_max_tokens(self): + return 4096 if self.model == "gpt-4o" else 16384 + +def create_ai_provider(provider_name="anthropic", use_smart_model=True): + if provider_name.lower() == "anthropic": + return AnthropicProvider() + elif provider_name.lower() == "openai": + return OpenAIProvider(use_smart_model) + else: + raise ValueError(f"Unknown provider: {provider_name}") \ No newline at end of file diff --git a/telegram_inference_bot.py b/telegram_inference_bot.py index 85fc61b..855c16b 100644 --- a/telegram_inference_bot.py +++ b/telegram_inference_bot.py @@ -3,34 +3,15 @@ import os import importlib import inspect import logging -import anthropic from telegram import Update from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes -from openai import OpenAI from dotenv import load_dotenv from tools.base_tool import BaseTool +from ai_providers import create_ai_provider # Load environment variables load_dotenv() -openai_client = OpenAI() - -anthropic_client = anthropic.Anthropic( - api_key=os.environ.get("ANTHROPIC_API_KEY"), - default_headers={"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"} -) - -GPT_4O = "gpt-4o" -GPT_4O_MINI = "gpt-4o-mini" - -model_max_tokens = { - GPT_4O: 4096, - GPT_4O_MINI: 16384 -} - -use_smart_model = True -use_anthropic = True - # Set up logging to console and file logging.basicConfig(level=logging.WARNING, handlers=[ logging.StreamHandler(), @@ -63,6 +44,9 @@ functions = [] for tool in tools: functions.extend(tool.get_functions()) +# Initialize AI provider +ai_provider = create_ai_provider("anthropic") + async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: logging.info("Bot started") await update.message.reply_text("Hello! I'm your AI assistant. How can I help you today? You can send me images and then ask questions about them.") @@ -91,138 +75,72 @@ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) -> messages = conversation_history[user_id] - response = get_chat_response(messages) - tool_calls = [] - if use_anthropic: - for message in response.content: - if message.type == "tool_use": - tool_calls.append(message) - else: - messages.append({"role": "assistant", "content": response.content}) - else: - assistant_message = response.choices[0].message - if hasattr(assistant_message, 'function_call') and assistant_message.function_call is not None: - tool_calls.append(assistant_message.function_call) + response = ai_provider.get_chat_response([{"role": "system", "content": system_prompt}] + messages) + tool_calls = ai_provider.format_tool_calls(response) toolUseCount = 0 while len(tool_calls) > 0 and toolUseCount < 50: - tool_call = tool_calls.pop(0) function_name = tool_call.name tool_response = call_tool(tool_call) - formatted_result = {} - - if use_anthropic: - formatted_result = {"role": "user", "content":[{"type": "tool_result", "tool_use_id": tool_call.id, "content": json.dumps(tool_response)}]} - else: - formatted_result = {"role": "function", "name": function_name, "content": json.dumps(tool_response)} - + formatted_result = ai_provider.format_tool_result(tool_call, tool_response) messages.append(formatted_result) - response = get_chat_response(messages) - assistant_message = "" - if use_anthropic: - for message in response.content: - if message.type == "tool_use": - tool_calls.append(message) - else: - messages.append({"role": "assistant", "content": response.content}) - else: - assistant_message = response.choices[0].message - conversation_history[user_id].append({"role": "assistant", "content": assistant_message}) - if hasattr(assistant_message, 'function_call') and assistant_message.function_call is not None: - tool_calls.append(assistant_message.function_call) - assistant_reply = assistant_message + response = ai_provider.get_chat_response([{"role": "system", "content": system_prompt}] + messages) + tool_calls = ai_provider.format_tool_calls(response) toolUseCount += 1 - - if (toolUseCount == 0): - if use_anthropic: - assistant_reply = response.content - else: - assistant_reply = assistant_message + if toolUseCount == 0: + assistant_reply = ai_provider.format_assistant_reply(response) conversation_history[user_id].append({"role": "assistant", "content": assistant_reply}) if len(conversation_history[user_id]) > 20: conversation_history[user_id] = conversation_history[user_id][-20:] - if use_anthropic: - await update.message.reply_text(messages[-1]["content"][0].text) - else: - await update.message.reply_text(assistant_reply.content) + await update.message.reply_text(ai_provider.get_reply_text(response)) except Exception as e: logging.error(f"An error occurred: {str(e)}") await update.message.reply_text("Sorry, an error occurred while processing your request.") def call_tool(function_call): - function_name = function_call.name if use_anthropic else function_call.name - function_args = json.dumps(function_call.input) if use_anthropic else function_call.arguments + function_name = function_call.name + function_args = json.loads(function_call.arguments if hasattr(function_call, 'arguments') else json.dumps(function_call.input)) for tool in tools: if function_name in [f["name"] for f in tool.get_functions()]: - return tool.execute(function_name, **json.loads(function_args)) - -def get_chat_response(messages): - return get_claude_response(messages) if use_anthropic else get_openai_response(messages) - -def get_openai_response(messages): - model = GPT_4O if use_smart_model else GPT_4O_MINI - response = openai_client.chat.completions.create( - model=model, - messages = [{"role": "system", "content": system_prompt}] + messages, - functions=functions, - function_call="auto", - max_tokens=model_max_tokens[model] - ) - return response - -def get_claude_response(messages): - anthropic_tools = [ - { - "name": function['name'], - "description": function['description'], - "input_schema": function['parameters'] if function['parameters'] not in [None, {}] else {"type": "object", "properties": {"param1": {"type": "string", "description": "Unnecessary"}}, "required": []} - } - for function in functions - ] - try: - response = anthropic_client.messages.create( - model="claude-3-5-sonnet-20240620", - system=system_prompt, - messages=messages, - max_tokens=8192, - tools=anthropic_tools - ) - except Exception as e: - logging.error(f"An error occurred: {str(e)}") - return None - - return response + return tool.execute(function_name, **function_args) async def switch(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - global use_smart_model - use_smart_model = not use_smart_model - model = GPT_4O if use_smart_model else GPT_4O_MINI - logging.info(f"Switched to model: {model}") - await update.message.reply_text(f"Switched to model: {model}") + global ai_provider + if isinstance(ai_provider, OpenAIProvider): + ai_provider.use_smart_model = not ai_provider.use_smart_model + model = ai_provider.get_model() + logging.info(f"Switched to model: {model}") + await update.message.reply_text(f"Switched to model: {model}") + else: + await update.message.reply_text("Switching models is only available for OpenAI provider.") async def switch_providers(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: await clear(update, context) - global use_anthropic - use_anthropic = not use_anthropic - logging.info("Using Anthropic" if use_anthropic else "Using OpenAI") - await update.message.reply_text("Using Anthropic" if use_anthropic else "Using OpenAI") + global ai_provider + if isinstance(ai_provider, AnthropicProvider): + ai_provider = create_ai_provider("openai") + logging.info("Switched to OpenAI provider") + await update.message.reply_text("Switched to OpenAI provider") + else: + ai_provider = create_ai_provider("anthropic") + logging.info("Switched to Anthropic provider") + await update.message.reply_text("Switched to Anthropic provider") async def status(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - if use_anthropic: - await update.message.reply_text("Currently using claude-3-5-sonnet-20240620") + if isinstance(ai_provider, AnthropicProvider): + await update.message.reply_text(f"Currently using Anthropic: {ai_provider.model}") else: - model = GPT_4O if use_smart_model else GPT_4O_MINI - await update.message.reply_text(f"Currently using: {model}") + await update.message.reply_text(f"Currently using OpenAI: {ai_provider.get_model()}") def main() -> None: # Create the Application and pass it your bot's token