100 lines
3.3 KiB
Python
100 lines
3.3 KiB
Python
import json
import logging
import os
from abc import ABC, abstractmethod

import anthropic
from openai import OpenAI
|
|
|
|
class AIProvider(ABC):
    """Abstract interface shared by every chat backend in this module.

    Concrete providers must implement all three hooks below; instantiating a
    subclass that leaves any of them abstract raises ``TypeError``.
    """

    @abstractmethod
    def get_chat_response(self, messages):
        """Send *messages* to the backend and return the backend's raw response."""

    @abstractmethod
    def format_messages(self, messages):
        """Return *messages* converted to the form the backend expects."""

    @abstractmethod
    def format_tool_calls(self, response):
        """Return the list of tool/function calls found in a raw *response*."""
|
|
|
|
class AnthropicProvider(AIProvider):
    """Chat provider backed by Anthropic's Messages API.

    Accepts OpenAI-style message dicts. A leading ``{"role": "system", ...}``
    entry is sent through Anthropic's separate ``system`` parameter instead of
    the ``messages`` list.
    """

    def __init__(self, functions=None):
        """Create the Anthropic client and select the model.

        Args:
            functions: Optional list of OpenAI-style function definitions
                (dicts with ``name``, ``description`` and an optional
                ``parameters`` JSON schema). When None, the module-level
                ``functions`` global is used, as in the original code.
        """
        self.client = anthropic.Anthropic(
            api_key=os.environ.get("ANTHROPIC_API_KEY"),
            # Beta header enabling the 8192-token max_tokens requested in
            # get_chat_response for this model.
            default_headers={"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"},
        )
        self.model = "claude-3-5-sonnet-20240620"
        self.functions = functions

    def get_chat_response(self, messages):
        """Send *messages* to Claude and return the raw API response.

        Args:
            messages: OpenAI-style list of ``{"role": ..., "content": ...}``
                dicts, optionally starting with a system message.

        Returns:
            The ``client.messages.create`` response, or None if any exception
            is raised while building or sending the request (the error is
            logged with its traceback).
        """
        try:
            system, conversation = self._split_system(messages)
            request = {
                "model": self.model,
                "messages": self.format_messages(conversation),
                "max_tokens": 8192,
                "tools": self.format_tools(),
            }
            # Only send ``system`` when one was supplied; never send null.
            if system is not None:
                request["system"] = system
            return self.client.messages.create(**request)
        except Exception as e:
            # Deliberate best-effort contract kept from the original: callers
            # receive None on failure instead of an exception.
            logging.exception("An error occurred: %s", e)
            return None

    @staticmethod
    def _split_system(messages):
        """Split a leading system message off *messages*.

        Returns:
            ``(system_content_or_None, remaining_messages)``.
        """
        if messages and messages[0].get("role") == "system":
            return messages[0]["content"], messages[1:]
        # No system prompt: every message belongs to the conversation.
        return None, list(messages)

    def format_messages(self, messages):
        """Return *messages* unchanged.

        NOTE(review): assumes callers already supply Anthropic-compatible
        role/content dicts; no conversion is performed here.
        """
        return messages

    def format_tool_calls(self, response):
        """Return the ``tool_use`` content blocks of *response*.

        A None *response* (what get_chat_response returns on failure) yields
        an empty list.
        """
        if response is None:
            return []
        return [block for block in response.content if block.type == "tool_use"]

    def format_tools(self):
        """Convert the function definitions into Anthropic ``tools`` entries."""
        tool_defs = self.functions if self.functions is not None else functions
        return [
            {
                "name": function["name"],
                "description": function["description"],
                # ``parameters`` is optional in OpenAI-style specs, so read it
                # with .get() instead of raising KeyError.
                "input_schema": self._input_schema(function.get("parameters")),
            }
            for function in tool_defs
        ]

    @staticmethod
    def _input_schema(parameters):
        """Return *parameters*, or a placeholder schema when it is None or empty."""
        if parameters not in (None, {}):
            return parameters
        # Placeholder payload kept byte-for-byte from the original.
        # NOTE(review): Anthropic may accept {"type": "object", "properties": {}}
        # for parameterless tools; confirm before simplifying.
        return {
            "type": "object",
            "properties": {"param1": {"type": "string", "description": "Unnecessary"}},
            "required": [],
        }
|
|
|
|
class OpenAIProvider(AIProvider):
    """Chat provider backed by OpenAI's Chat Completions API (gpt-4o family)."""

    def __init__(self, use_smart_model=True, functions=None):
        """Create the OpenAI client and choose the model.

        Args:
            use_smart_model: True selects "gpt-4o", False selects "gpt-4o-mini".
            functions: Optional list of function definitions to expose. When
                None, the module-level ``functions`` global is used, as in the
                original code.
        """
        self.client = OpenAI()
        self.use_smart_model = use_smart_model
        self.functions = functions
        self.model = self.get_model()

    def get_model(self):
        """Return the model name selected by ``use_smart_model``."""
        return "gpt-4o" if self.use_smart_model else "gpt-4o-mini"

    def get_chat_response(self, messages):
        """Send *messages* to the chat completions endpoint and return the response.

        Exceptions from the client propagate to the caller.
        """
        request = {
            "model": self.model,
            "messages": self.format_messages(messages),
            "max_tokens": self.get_max_tokens(),
        }
        function_defs = self._function_definitions()
        # The API rejects an empty ``functions`` array, so only attach function
        # calling when there is something to expose.
        # NOTE(review): ``functions``/``function_call`` are deprecated in favor of
        # ``tools``/``tool_choice``; migrate together with format_tool_calls.
        if function_defs:
            request["functions"] = function_defs
            request["function_call"] = "auto"
        return self.client.chat.completions.create(**request)

    def _function_definitions(self):
        """Return the injected function definitions, else the module global."""
        return self.functions if self.functions is not None else functions

    def format_messages(self, messages):
        """Return *messages* unchanged; they are already in OpenAI format."""
        return messages

    def format_tool_calls(self, response):
        """Return a list holding the first choice's ``function_call``, or ``[]``."""
        function_call = getattr(response.choices[0].message, "function_call", None)
        return [function_call] if function_call is not None else []

    def get_max_tokens(self):
        """Return the ``max_tokens`` value used for the selected model."""
        # NOTE(review): 4096 matches the original gpt-4o snapshot's output cap;
        # later gpt-4o snapshots allow more. Confirm before raising it.
        return 4096 if self.model == "gpt-4o" else 16384
|
|
|
|
def create_ai_provider(provider_name="anthropic", use_smart_model=True):
    """Build the chat provider named by *provider_name* (case-insensitive).

    Args:
        provider_name: "anthropic" or "openai".
        use_smart_model: Forwarded to OpenAIProvider only; ignored for Anthropic.

    Raises:
        ValueError: if *provider_name* names no known provider.
    """
    factories = {
        "anthropic": lambda: AnthropicProvider(),
        "openai": lambda: OpenAIProvider(use_smart_model),
    }
    factory = factories.get(provider_name.lower())
    if factory is None:
        raise ValueError(f"Unknown provider: {provider_name}")
    return factory()