added dual models
This commit is contained in:
+21
-2
@@ -4,11 +4,15 @@ import json
|
|||||||
from tools.base_tool import BaseTool
|
from tools.base_tool import BaseTool
|
||||||
|
|
||||||
class PersonaTool(BaseTool):
|
class PersonaTool(BaseTool):
|
||||||
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.api_key = os.environ.get("OPENAI_API_KEY")
|
self.api_key = os.environ.get("OPENAI_API_KEY")
|
||||||
|
|
||||||
|
GPT_4O = "gpt-4o"
|
||||||
|
GPT_4O_MINI = "gpt-4o-mini"
|
||||||
|
|
||||||
def generate_response(self, persona_description: str, query: str) -> str:
|
def generate_response(self, persona_description: str, query: str) -> str:
|
||||||
"""
|
"""
|
||||||
Makes a call to the OpenAI API using the persona as a system prompt.
|
Makes a call to the OpenAI API using the persona as a system prompt.
|
||||||
@@ -57,4 +61,19 @@ class PersonaTool(BaseTool):
|
|||||||
if function_name == "generate_response":
|
if function_name == "generate_response":
|
||||||
return self.generate_response(kwargs.get("persona_description"), kwargs.get("query"))
|
return self.generate_response(kwargs.get("persona_description"), kwargs.get("query"))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Function {function_name} not found")
|
raise ValueError(f"Function {function_name} not found")
|
||||||
|
|
||||||
|
def get_chat_response(client, messages, model):
|
||||||
|
|
||||||
|
model_max_tokens = {
|
||||||
|
GPT_4O: 4096,
|
||||||
|
GPT_4O_MINI: 16384
|
||||||
|
}
|
||||||
|
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
function_call="none",
|
||||||
|
max_tokens=model_max_tokens[model]
|
||||||
|
)
|
||||||
|
return response
|
||||||
Reference in New Issue
Block a user