This commit is contained in:
2024-08-18 14:20:58 -05:00
3 changed files with 40 additions and 4 deletions
+1
View File
@@ -8,6 +8,7 @@ __pycache__/
# Virtual environment
.venv/
venv/
env/
# IDE files
+38 -2
View File
@@ -1,8 +1,14 @@
import os
import json
import logging
import anthropic
from openai import OpenAI
from abc import ABC, abstractmethod
from tools.github_tool import GitHubTool
# Initialize GitHubTool and get functions
github_tool = GitHubTool()
functions = github_tool.get_functions()
class AIProvider(ABC):
@abstractmethod
@@ -17,6 +23,10 @@ class AIProvider(ABC):
def format_tool_calls(self, response):
pass
@abstractmethod
def format_tool_result(self, tool_call, tool_response):
pass
class AnthropicProvider(AIProvider):
def __init__(self):
self.client = anthropic.Anthropic(
@@ -56,9 +66,28 @@ class AnthropicProvider(AIProvider):
"description": function['description'],
"input_schema": function['parameters'] if function['parameters'] not in [None, {}] else {"type": "object", "properties": {"param1": {"type": "string", "description": "Unnecessary"}}, "required": []}
}
for function in functions # This assumes 'functions' is globally accessible
for function in functions
]
def format_assistant_reply(self, response):
for message in response.content:
if message.type == "text":
return message.text
return ""
def get_reply_text(self, response):
return self.format_assistant_reply(response)
def get_model(self):
return self.model
def format_tool_result(self, tool_call, tool_response):
return {
"role": "function",
"name": tool_call.name,
"content": json.dumps(tool_response)
}
class OpenAIProvider(AIProvider):
def __init__(self, use_smart_model=True):
self.client = OpenAI()
@@ -72,7 +101,7 @@ class OpenAIProvider(AIProvider):
response = self.client.chat.completions.create(
model=self.model,
messages=self.format_messages(messages),
functions=functions, # This assumes 'functions' is globally accessible
functions=functions,
function_call="auto",
max_tokens=self.get_max_tokens()
)
@@ -91,6 +120,13 @@ class OpenAIProvider(AIProvider):
def get_max_tokens(self):
return 4096 if self.model == "gpt-4o" else 16384
def format_tool_result(self, tool_call, tool_response):
return {
"role": "function",
"name": tool_call.name,
"content": json.dumps(tool_response)
}
def create_ai_provider(provider_name="anthropic", use_smart_model=True):
if provider_name.lower() == "anthropic":
return AnthropicProvider()
+1 -2
View File
@@ -3,8 +3,7 @@ import os
import importlib
import inspect
import logging
import anthropic
from telegram import Update
from telegram import Update, __version__ as telegram_version
from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes
from openai import OpenAI
from dotenv import load_dotenv