From 19a12fccbcbf15e61b363b66d433cae94466be0c Mon Sep 17 00:00:00 2001 From: Jonathan Lucas Date: Sun, 18 Aug 2024 07:47:36 -0500 Subject: [PATCH] added dual models --- tools/persona_tool.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/tools/persona_tool.py b/tools/persona_tool.py index 30e109f..cee6ea8 100644 --- a/tools/persona_tool.py +++ b/tools/persona_tool.py @@ -4,11 +4,15 @@ import json from tools.base_tool import BaseTool class PersonaTool(BaseTool): + + def __init__(self): super().__init__() - self.api_key = os.environ.get("OPENAI_API_KEY") + GPT_4O = "gpt-4o" + GPT_4O_MINI = "gpt-4o-mini" + def generate_response(self, persona_description: str, query: str) -> str: """ Makes a call to the OpenAI API using the persona as a system prompt. @@ -57,4 +61,19 @@ class PersonaTool(BaseTool): if function_name == "generate_response": return self.generate_response(kwargs.get("persona_description"), kwargs.get("query")) else: - raise ValueError(f"Function {function_name} not found") \ No newline at end of file + raise ValueError(f"Function {function_name} not found") + + def get_chat_response(client, messages, model): + + model_max_tokens = { + GPT_4O: 4096, + GPT_4O_MINI: 16384 + } + + response = client.chat.completions.create( + model=model, + messages=messages, + function_call="none", + max_tokens=model_max_tokens[model] + ) + return response \ No newline at end of file