diff --git a/base_inference_bot.py b/base_inference_bot.py new file mode 100644 index 0000000..68a4609 --- /dev/null +++ b/base_inference_bot.py @@ -0,0 +1,80 @@ +import json +import os +import importlib +import inspect +import logging +from abc import ABC, abstractmethod +from dotenv import load_dotenv +from tools.base_tool import BaseTool +from tools.metrics_tool import MetricsTool + +class BaseInferenceBot(ABC): + def __init__(self): + load_dotenv() + self.setup_logging() + self.load_system_prompt() + self.load_tools() + self.conversation_history = {} + self.processing_status = {} + + def setup_logging(self): + logging.basicConfig(level=logging.WARNING, handlers=[ + logging.StreamHandler(), + logging.FileHandler('logs/output.log', mode='a') + ]) + + def load_system_prompt(self): + with open("prompts/developer_prompt.txt", "r") as file: + self.system_prompt = file.read().strip() + + def load_tools(self): + self.tools = [MetricsTool()] + tools_dir = os.path.join(os.path.dirname(__file__), 'tools') + for filename in os.listdir(tools_dir): + if filename.endswith('.py') and filename not in ['__init__.py', 'base_tool.py', 'metrics_tool.py']: + module_name = f'tools.{filename[:-3]}' + module = importlib.import_module(module_name) + for name, obj in inspect.getmembers(module): + if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool: + self.tools.append(obj()) + + self.functions = [] + for tool in self.tools: + self.functions.extend(tool.get_functions()) + + def clear_conversation(self, user_id): + if user_id in self.conversation_history: + del self.conversation_history[user_id] + for tool in self.tools: + tool.clear() + + def call_tool(self, function_call): + function_name = function_call.name + function_args = json.loads(function_call.arguments) + for tool in self.tools: + if function_name in [f["name"] for f in tool.get_functions()]: + return tool.execute(function_name, **function_args) + + @abstractmethod + def get_chat_response(self, messages): + pass + + @abstractmethod + async def handle_message(self, user_id, user_message): + pass + + @abstractmethod + async def start(self): + pass + + @abstractmethod + async def clear(self, user_id): + pass + + @abstractmethod + async def status(self): + pass + + @abstractmethod + async def abort_processing(self, user_id): + pass \ No newline at end of file