diff --git a/.gitignore b/.gitignore index c3f2f07..b79925f 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ __pycache__/ # Virtual environment .venv/ +venv/ env/ # IDE files diff --git a/ai_providers.py b/ai_providers.py index 957e0e2..2b983f0 100644 --- a/ai_providers.py +++ b/ai_providers.py @@ -1,8 +1,14 @@ import os import json +import logging import anthropic from openai import OpenAI from abc import ABC, abstractmethod +from tools.github_tool import GitHubTool + +# Initialize GitHubTool and get functions +github_tool = GitHubTool() +functions = github_tool.get_functions() class AIProvider(ABC): @abstractmethod @@ -17,6 +23,10 @@ class AIProvider(ABC): def format_tool_calls(self, response): pass + @abstractmethod + def format_tool_result(self, tool_call, tool_response): + pass + class AnthropicProvider(AIProvider): def __init__(self): self.client = anthropic.Anthropic( @@ -56,9 +66,28 @@ class AnthropicProvider(AIProvider): "description": function['description'], "input_schema": function['parameters'] if function['parameters'] not in [None, {}] else {"type": "object", "properties": {"param1": {"type": "string", "description": "Unnecessary"}}, "required": []} } - for function in functions # This assumes 'functions' is globally accessible + for function in functions ] + def format_assistant_reply(self, response): + for message in response.content: + if message.type == "text": + return message.text + return "" + + def get_reply_text(self, response): + return self.format_assistant_reply(response) + + def get_model(self): + return self.model + + def format_tool_result(self, tool_call, tool_response): + return { + "role": "function", + "name": tool_call.name, + "content": json.dumps(tool_response) + } + class OpenAIProvider(AIProvider): def __init__(self, use_smart_model=True): self.client = OpenAI() @@ -72,7 +101,7 @@ class OpenAIProvider(AIProvider): response = self.client.chat.completions.create( model=self.model, messages=self.format_messages(messages), - functions=functions, # This assumes 'functions' is globally accessible + functions=functions, function_call="auto", max_tokens=self.get_max_tokens() ) @@ -91,6 +120,13 @@ class OpenAIProvider(AIProvider): def get_max_tokens(self): return 4096 if self.model == "gpt-4o" else 16384 + def format_tool_result(self, tool_call, tool_response): + return { + "role": "function", + "name": tool_call.name, + "content": json.dumps(tool_response) + } + def create_ai_provider(provider_name="anthropic", use_smart_model=True): if provider_name.lower() == "anthropic": return AnthropicProvider() diff --git a/telegram_inference_bot.py b/telegram_inference_bot.py index 458d916..df35994 100644 --- a/telegram_inference_bot.py +++ b/telegram_inference_bot.py @@ -3,8 +3,7 @@ import os import importlib import inspect import logging -import anthropic -from telegram import Update +from telegram import Update, __version__ as telegram_version from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes from openai import OpenAI from dotenv import load_dotenv