From a0f2682660649658366e5679213051d1f03556f8 Mon Sep 17 00:00:00 2001 From: bucolucas Date: Sun, 18 Aug 2024 13:59:34 -0500 Subject: [PATCH] Add format_tool_result method to AIProvider classes --- ai_providers.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/ai_providers.py b/ai_providers.py index 74d7cd9..2b983f0 100644 --- a/ai_providers.py +++ b/ai_providers.py @@ -23,6 +23,10 @@ class AIProvider(ABC): def format_tool_calls(self, response): pass + @abstractmethod + def format_tool_result(self, tool_call, tool_response): + pass + class AnthropicProvider(AIProvider): def __init__(self): self.client = anthropic.Anthropic( @@ -77,6 +81,13 @@ class AnthropicProvider(AIProvider): def get_model(self): return self.model + def format_tool_result(self, tool_call, tool_response): + return { + "role": "function", + "name": tool_call.name, + "content": json.dumps(tool_response) + } + class OpenAIProvider(AIProvider): def __init__(self, use_smart_model=True): self.client = OpenAI() @@ -109,6 +120,13 @@ class OpenAIProvider(AIProvider): def get_max_tokens(self): return 4096 if self.model == "gpt-4o" else 16384 + def format_tool_result(self, tool_call, tool_response): + return { + "role": "function", + "name": tool_call.name, + "content": json.dumps(tool_response) + } + def create_ai_provider(provider_name="anthropic", use_smart_model=True): if provider_name.lower() == "anthropic": return AnthropicProvider()