diff --git a/ai_providers.py b/ai_providers.py new file mode 100644 index 0000000..957e0e2 --- /dev/null +++ b/ai_providers.py @@ -0,0 +1,100 @@ +import os +import json +import anthropic +from openai import OpenAI +from abc import ABC, abstractmethod + +class AIProvider(ABC): + @abstractmethod + def get_chat_response(self, messages): + pass + + @abstractmethod + def format_messages(self, messages): + pass + + @abstractmethod + def format_tool_calls(self, response): + pass + +class AnthropicProvider(AIProvider): + def __init__(self): + self.client = anthropic.Anthropic( + api_key=os.environ.get("ANTHROPIC_API_KEY"), + default_headers={"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"} + ) + self.model = "claude-3-5-sonnet-20240620" + + def get_chat_response(self, messages): + try: + response = self.client.messages.create( + model=self.model, + system=messages[0]['content'], + messages=self.format_messages(messages[1:]), + max_tokens=8192, + tools=self.format_tools() + ) + return response + except Exception as e: + logging.error(f"An error occurred: {str(e)}") + return None + + def format_messages(self, messages): + return messages + + def format_tool_calls(self, response): + tool_calls = [] + for message in response.content: + if message.type == "tool_use": + tool_calls.append(message) + return tool_calls + + def format_tools(self): + return [ + { + "name": function['name'], + "description": function['description'], + "input_schema": function['parameters'] if function['parameters'] not in [None, {}] else {"type": "object", "properties": {"param1": {"type": "string", "description": "Unnecessary"}}, "required": []} + } + for function in functions # This assumes 'functions' is globally accessible + ] + +class OpenAIProvider(AIProvider): + def __init__(self, use_smart_model=True): + self.client = OpenAI() + self.use_smart_model = use_smart_model + self.model = self.get_model() + + def get_model(self): + return "gpt-4o" if self.use_smart_model else "gpt-4o-mini" + + def get_chat_response(self, messages): + response = self.client.chat.completions.create( + model=self.model, + messages=self.format_messages(messages), + functions=functions, # This assumes 'functions' is globally accessible + function_call="auto", + max_tokens=self.get_max_tokens() + ) + return response + + def format_messages(self, messages): + return messages + + def format_tool_calls(self, response): + tool_calls = [] + assistant_message = response.choices[0].message + if hasattr(assistant_message, 'function_call') and assistant_message.function_call is not None: + tool_calls.append(assistant_message.function_call) + return tool_calls + + def get_max_tokens(self): + return 4096 if self.model == "gpt-4o" else 16384 + +def create_ai_provider(provider_name="anthropic", use_smart_model=True): + if provider_name.lower() == "anthropic": + return AnthropicProvider() + elif provider_name.lower() == "openai": + return OpenAIProvider(use_smart_model) + else: + raise ValueError(f"Unknown provider: {provider_name}") \ No newline at end of file