Files
cyclop/ai_providers.py
T

100 lines
3.3 KiB
Python
Raw Normal View History

import json
import logging
import os
from abc import ABC, abstractmethod

import anthropic
from openai import OpenAI
class AIProvider(ABC):
    """Interface shared by every chat backend in this module.

    Concrete providers wrap a vendor client, send a conversation to it, and
    pull tool/function calls back out of the vendor's response object.
    """

    @abstractmethod
    def get_chat_response(self, messages):
        """Send *messages* to the backend and return its raw response."""

    @abstractmethod
    def format_messages(self, messages):
        """Adapt *messages* to the shape the backend's API expects."""

    @abstractmethod
    def format_tool_calls(self, response):
        """Return the list of tool/function calls found in *response*."""
class AnthropicProvider(AIProvider):
    """Chat provider backed by Anthropic's Messages API (Claude 3.5 Sonnet).

    Tool definitions are OpenAI-style function dicts (``name``,
    ``description`` and an optional JSON-schema ``parameters``);
    ``format_tools`` converts them into Anthropic tool specs.
    """

    def __init__(self, functions=None):
        """Create the Anthropic client.

        Args:
            functions: Optional list of function definitions to expose as
                tools. When omitted, a module-level ``functions`` list is
                looked up at request time, matching the previous behavior.
        """
        self.client = anthropic.Anthropic(
            api_key=os.environ.get("ANTHROPIC_API_KEY"),
            # Beta opt-in header for the extended output limit that
            # max_tokens=8192 in get_chat_response relies on.
            default_headers={"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"},
        )
        self.model = "claude-3-5-sonnet-20240620"
        self.functions = functions

    def _get_functions(self):
        """Return the constructor's function list, else the module global, else []."""
        configured = getattr(self, "functions", None)
        if configured is not None:
            return configured
        return globals().get("functions") or []

    def get_chat_response(self, messages):
        """Send a conversation to Claude and return the raw API response.

        Args:
            messages: Conversation list. The first entry's ``content`` is used
                as the system prompt. The remaining entries are sent as the
                message history.

        Returns:
            The response object from ``client.messages.create``, or ``None``
            if anything failed. Failures are logged with a traceback.
        """
        try:
            request = {
                "model": self.model,
                "system": messages[0]["content"],
                "messages": self.format_messages(messages[1:]),
                "max_tokens": 8192,
            }
            tools = self.format_tools()
            if tools:
                # Only advertise tools when some are actually defined.
                request["tools"] = tools
            return self.client.messages.create(**request)
        except Exception as e:
            # Best-effort contract: callers receive None on any failure.
            logging.exception("An error occurred: %s", e)
            return None

    def format_messages(self, messages):
        """Return *messages* unchanged.

        NOTE(review): this assumes the messages already use Anthropic's
        role/content shape. Confirm this holds for tool-result turns.
        """
        return messages

    def format_tool_calls(self, response):
        """Return the ``tool_use`` content blocks of *response*, in order."""
        return [block for block in response.content if block.type == "tool_use"]

    def format_tools(self):
        """Convert the function definitions into Anthropic tool specs.

        Returns:
            A list of dicts with ``name``, ``description`` and
            ``input_schema``. A function whose ``parameters`` is missing,
            ``None`` or ``{}`` gets a placeholder object schema with a single
            optional string property.
        """
        tools = []
        for function in self._get_functions():
            schema = function.get("parameters")
            if schema in (None, {}):
                # Build a fresh placeholder per tool so no dict is shared.
                schema = {
                    "type": "object",
                    "properties": {"param1": {"type": "string", "description": "Unnecessary"}},
                    "required": [],
                }
            tools.append(
                {
                    "name": function["name"],
                    "description": function["description"],
                    "input_schema": schema,
                }
            )
        return tools
class OpenAIProvider(AIProvider):
    """Chat provider backed by OpenAI's Chat Completions API.

    NOTE: it uses the legacy ``functions``/``function_call`` request fields,
    which OpenAI has deprecated in favor of ``tools``/``tool_choice``.
    Migrating would change what ``format_tool_calls`` returns, so it is left
    as-is here.
    """

    def __init__(self, use_smart_model=True, functions=None):
        """Create the OpenAI client and pick the model.

        Args:
            use_smart_model: True selects "gpt-4o"; False selects "gpt-4o-mini".
            functions: Optional list of function definitions. When omitted,
                a module-level ``functions`` list is looked up at request
                time, matching the previous behavior.
        """
        self.client = OpenAI()
        self.use_smart_model = use_smart_model
        self.model = self.get_model()
        self.functions = functions

    def _get_functions(self):
        """Return the constructor's function list, else the module global, else []."""
        configured = getattr(self, "functions", None)
        if configured is not None:
            return configured
        return globals().get("functions") or []

    def get_model(self):
        """Return the model name chosen by ``use_smart_model``."""
        return "gpt-4o" if self.use_smart_model else "gpt-4o-mini"

    def get_chat_response(self, messages):
        """Send *messages* to the chat completions endpoint.

        Returns:
            The raw completion response. API errors propagate to the caller.
        """
        request = {
            "model": self.model,
            "messages": self.format_messages(messages),
            "max_tokens": self.get_max_tokens(),
        }
        functions = self._get_functions()
        if functions:
            # function_call is only valid alongside a non-empty functions
            # list, so both are omitted when no functions are defined.
            request["functions"] = functions
            request["function_call"] = "auto"
        return self.client.chat.completions.create(**request)

    def format_messages(self, messages):
        """Return *messages* unchanged; they are assumed to be in OpenAI chat format."""
        return messages

    def format_tool_calls(self, response):
        """Return a list holding the first choice's ``function_call``, or [] if there is none."""
        function_call = getattr(response.choices[0].message, "function_call", None)
        return [function_call] if function_call is not None else []

    def get_max_tokens(self):
        """Return the completion-token cap: 4096 for "gpt-4o", 16384 otherwise."""
        return 4096 if self.model == "gpt-4o" else 16384
def create_ai_provider(provider_name="anthropic", use_smart_model=True):
    """Build the chat provider named by *provider_name* (case-insensitive).

    Args:
        provider_name: "anthropic" or "openai", in any letter case.
        use_smart_model: Forwarded to OpenAIProvider; not used for Anthropic.

    Returns:
        A new AnthropicProvider or OpenAIProvider.

    Raises:
        ValueError: If *provider_name* is not a recognized provider.
    """
    normalized = provider_name.lower()
    if normalized == "anthropic":
        return AnthropicProvider()
    if normalized == "openai":
        return OpenAIProvider(use_smart_model)
    raise ValueError(f"Unknown provider: {provider_name}")