diff --git a/tools/persona_tool.py b/tools/persona_tool.py index f5b4970..6596d1a 100644 --- a/tools/persona_tool.py +++ b/tools/persona_tool.py @@ -1,33 +1,58 @@ import openai +import json from tools.base_tool import BaseTool class PersonaTool(BaseTool): - def __init__(self): + def __init__(self, api_key: str): super().__init__() - # Initialize OpenAI API key - openai.api_key = "YOUR_OPENAI_API_KEY" + openai.api_key = api_key def generate_response(self, persona_description: str, query: str) -> str: """ Makes a call to the OpenAI API using the persona as a system prompt. - + Parameters: persona_description (str): Description of the persona. query (str): Query to be processed. - + Returns: str: The response generated by the OpenAI API. """ - try: - response = openai.ChatCompletion.create( - model="gpt-3.5-turbo", # Specify the model - messages=[ - {"role": "system", "content": persona_description}, - {"role": "user", "content": query}, - ], - max_tokens=150 # Adjust token limit as needed - ) - return response['choices'][0]['message']['content'] - - except Exception as e: - return f"An error occurred: {str(e)}" \ No newline at end of file + response = openai.ChatCompletion.create( + model="gpt-3.5-turbo", + messages=[ + {"role": "system", "content": persona_description}, + {"role": "user", "content": query} + ] + ) + return response.choices[0].message['content'] + + def get_functions(self): + return json.dumps({ + "functions": [ + { + "name": "generate_response", + "description": "Generates a response based on a persona description and a user query.", + "parameters": { + "type": "object", + "properties": { + "persona_description": { + "type": "string", + "description": "Description of the persona." + }, + "query": { + "type": "string", + "description": "User's query to be processed." + } + }, + "required": ["persona_description", "query"] + } + } + ] + }) + + def execute(self, function_name, **kwargs): + if function_name == "generate_response": + return self.generate_response(kwargs.get("persona_description"), kwargs.get("query")) + else: + raise ValueError(f"Function {function_name} not found") \ No newline at end of file