Add format_tool_result method to AIProvider classes

This commit is contained in:
2024-08-18 13:59:34 -05:00
parent f33123b176
commit a0f2682660
+18
View File
@@ -23,6 +23,10 @@ class AIProvider(ABC):
def format_tool_calls(self, response): def format_tool_calls(self, response):
pass pass
@abstractmethod
def format_tool_result(self, tool_call, tool_response):
pass
class AnthropicProvider(AIProvider): class AnthropicProvider(AIProvider):
def __init__(self): def __init__(self):
self.client = anthropic.Anthropic( self.client = anthropic.Anthropic(
@@ -77,6 +81,13 @@ class AnthropicProvider(AIProvider):
def get_model(self): def get_model(self):
return self.model return self.model
def format_tool_result(self, tool_call, tool_response):
return {
"role": "function",
"name": tool_call.name,
"content": json.dumps(tool_response)
}
class OpenAIProvider(AIProvider): class OpenAIProvider(AIProvider):
def __init__(self, use_smart_model=True): def __init__(self, use_smart_model=True):
self.client = OpenAI() self.client = OpenAI()
@@ -109,6 +120,13 @@ class OpenAIProvider(AIProvider):
def get_max_tokens(self): def get_max_tokens(self):
return 4096 if self.model == "gpt-4o" else 16384 return 4096 if self.model == "gpt-4o" else 16384
def format_tool_result(self, tool_call, tool_response):
return {
"role": "function",
"name": tool_call.name,
"content": json.dumps(tool_response)
}
def create_ai_provider(provider_name="anthropic", use_smart_model=True): def create_ai_provider(provider_name="anthropic", use_smart_model=True):
if provider_name.lower() == "anthropic": if provider_name.lower() == "anthropic":
return AnthropicProvider() return AnthropicProvider()