This commit is contained in:
2024-08-18 14:20:58 -05:00
3 changed files with 40 additions and 4 deletions
+1
View File
@@ -8,6 +8,7 @@ __pycache__/
# Virtual environment # Virtual environment
.venv/ .venv/
venv/
env/ env/
# IDE files # IDE files
+38 -2
View File
@@ -1,8 +1,14 @@
import os import os
import json import json
import logging
import anthropic import anthropic
from openai import OpenAI from openai import OpenAI
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from tools.github_tool import GitHubTool
# Initialize GitHubTool and get functions
github_tool = GitHubTool()
functions = github_tool.get_functions()
class AIProvider(ABC): class AIProvider(ABC):
@abstractmethod @abstractmethod
@@ -17,6 +23,10 @@ class AIProvider(ABC):
def format_tool_calls(self, response): def format_tool_calls(self, response):
pass pass
@abstractmethod
def format_tool_result(self, tool_call, tool_response):
pass
class AnthropicProvider(AIProvider): class AnthropicProvider(AIProvider):
def __init__(self): def __init__(self):
self.client = anthropic.Anthropic( self.client = anthropic.Anthropic(
@@ -56,9 +66,28 @@ class AnthropicProvider(AIProvider):
"description": function['description'], "description": function['description'],
"input_schema": function['parameters'] if function['parameters'] not in [None, {}] else {"type": "object", "properties": {"param1": {"type": "string", "description": "Unnecessary"}}, "required": []} "input_schema": function['parameters'] if function['parameters'] not in [None, {}] else {"type": "object", "properties": {"param1": {"type": "string", "description": "Unnecessary"}}, "required": []}
} }
for function in functions # This assumes 'functions' is globally accessible for function in functions
] ]
def format_assistant_reply(self, response):
for message in response.content:
if message.type == "text":
return message.text
return ""
def get_reply_text(self, response):
return self.format_assistant_reply(response)
def get_model(self):
return self.model
def format_tool_result(self, tool_call, tool_response):
return {
"role": "function",
"name": tool_call.name,
"content": json.dumps(tool_response)
}
class OpenAIProvider(AIProvider): class OpenAIProvider(AIProvider):
def __init__(self, use_smart_model=True): def __init__(self, use_smart_model=True):
self.client = OpenAI() self.client = OpenAI()
@@ -72,7 +101,7 @@ class OpenAIProvider(AIProvider):
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
model=self.model, model=self.model,
messages=self.format_messages(messages), messages=self.format_messages(messages),
functions=functions, # This assumes 'functions' is globally accessible functions=functions,
function_call="auto", function_call="auto",
max_tokens=self.get_max_tokens() max_tokens=self.get_max_tokens()
) )
@@ -91,6 +120,13 @@ class OpenAIProvider(AIProvider):
def get_max_tokens(self): def get_max_tokens(self):
return 4096 if self.model == "gpt-4o" else 16384 return 4096 if self.model == "gpt-4o" else 16384
def format_tool_result(self, tool_call, tool_response):
return {
"role": "function",
"name": tool_call.name,
"content": json.dumps(tool_response)
}
def create_ai_provider(provider_name="anthropic", use_smart_model=True): def create_ai_provider(provider_name="anthropic", use_smart_model=True):
if provider_name.lower() == "anthropic": if provider_name.lower() == "anthropic":
return AnthropicProvider() return AnthropicProvider()
+1 -2
View File
@@ -3,8 +3,7 @@ import os
import importlib import importlib
import inspect import inspect
import logging import logging
import anthropic from telegram import Update, __version__ as telegram_version
from telegram import Update
from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes
from openai import OpenAI from openai import OpenAI
from dotenv import load_dotenv from dotenv import load_dotenv