diff --git a/anthropic_telegram_inference_bot.py b/anthropic_telegram_inference_bot.py index 16d5344..2c39863 100644 --- a/anthropic_telegram_inference_bot.py +++ b/anthropic_telegram_inference_bot.py @@ -28,7 +28,8 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot): system=self.system_prompt, messages=messages, max_tokens=8192, - tools=anthropic_tools + tools=anthropic_tools, + tool_choice={"type": "auto"} ) except Exception as e: logging.error(f"An error occurred: {str(e)}") @@ -55,15 +56,16 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot): tool_use_count = 0 while len(tool_calls) > 0 and tool_use_count < 50: tool_use_results = [] - for tool_call in tool_calls: - tool_response = self.call_tool(tool_call) + while len(tool_calls) > 0: + tool_call = tool_calls.pop(0) + tool_response = self.call_tool(tool_call.name, json.dumps(tool_call.input)) tool_use_results.append({"type": "tool_result", "tool_use_id": tool_call.id, "content": json.dumps(tool_response)}) messages.append({"role": "user", "content": tool_use_results}) response = self.get_chat_response(messages) full_message = [] - tool_calls = [] + for message_part in response.content: full_message.append(message_part) if message_part.type == "tool_use": @@ -71,7 +73,11 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot): messages.append({"role": "assistant", "content": full_message}) tool_use_count += 1 - + + if (tool_use_count == 0): + assistant_reply = response.content + self.conversation_history[user_id].append({"role": "assistant", "content": assistant_reply}) + if len(self.conversation_history[user_id]) > 20: self.conversation_history[user_id] = self.conversation_history[user_id][-20:] diff --git a/base_inference_bot.py b/base_inference_bot.py deleted file mode 100644 index 68a4609..0000000 --- a/base_inference_bot.py +++ /dev/null @@ -1,80 +0,0 @@ -import json -import os -import importlib -import inspect -import logging -from abc import ABC, abstractmethod -from dotenv import load_dotenv -from tools.base_tool import BaseTool -from tools.metrics_tool import MetricsTool - -class BaseInferenceBot(ABC): - def __init__(self): - load_dotenv() - self.setup_logging() - self.load_system_prompt() - self.load_tools() - self.conversation_history = {} - self.processing_status = {} - - def setup_logging(self): - logging.basicConfig(level=logging.WARNING, handlers=[ - logging.StreamHandler(), - logging.FileHandler('logs/output.log', mode='a') - ]) - - def load_system_prompt(self): - with open("prompts/developer_prompt.txt", "r") as file: - self.system_prompt = file.read().strip() - - def load_tools(self): - self.tools = [MetricsTool()] - tools_dir = os.path.join(os.path.dirname(__file__), 'tools') - for filename in os.listdir(tools_dir): - if filename.endswith('.py') and filename not in ['__init__.py', 'base_tool.py', 'metrics_tool.py']: - module_name = f'tools.{filename[:-3]}' - module = importlib.import_module(module_name) - for name, obj in inspect.getmembers(module): - if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool: - self.tools.append(obj()) - - self.functions = [] - for tool in self.tools: - self.functions.extend(tool.get_functions()) - - def clear_conversation(self, user_id): - if user_id in self.conversation_history: - del self.conversation_history[user_id] - for tool in self.tools: - tool.clear() - - def call_tool(self, function_call): - function_name = function_call.name - function_args = json.loads(function_call.arguments) - for tool in self.tools: - if function_name in [f["name"] for f in tool.get_functions()]: - return tool.execute(function_name, **function_args) - - @abstractmethod - def get_chat_response(self, messages): - pass - - @abstractmethod - async def handle_message(self, user_id, user_message): - pass - - @abstractmethod - async def start(self): - pass - - @abstractmethod - async def clear(self, user_id): - pass - - @abstractmethod - async def status(self): - pass - - @abstractmethod - async def abort_processing(self, user_id): - pass \ No newline at end of file diff --git a/base_telegram_inference_bot.py b/base_telegram_inference_bot.py index 2d26516..cbc8279 100644 --- a/base_telegram_inference_bot.py +++ b/base_telegram_inference_bot.py @@ -1,14 +1,17 @@ +import importlib import os import json import logging +import inspect from abc import ABC, abstractmethod +from tools.base_tool import BaseTool class BaseTelegramInferenceBot(ABC): def __init__(self): self.conversation_history = {} self.processing_status = {} self.system_prompt = self.load_system_prompt() - self.functions = self.load_functions() + self.tools, self.functions = self.load_functions() @staticmethod def load_system_prompt(): @@ -17,8 +20,21 @@ class BaseTelegramInferenceBot(ABC): @staticmethod def load_functions(): - # Implement function loading logic here - return [] + tools = [] + tools_dir = os.path.join(os.path.dirname(__file__), 'tools') + for filename in os.listdir(tools_dir): + if filename.endswith('.py') and filename != '__init__.py' and filename != 'base_tool.py': + module_name = f'tools.{filename[:-3]}' + module = importlib.import_module(module_name) + for name, obj in inspect.getmembers(module): + if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool: + tools.append(obj()) + + # Collect all function definitions + functions = [] + for tool in tools: + functions.extend(tool.get_functions()) + return tools, functions @abstractmethod def get_chat_response(self, messages): @@ -32,9 +48,9 @@ class BaseTelegramInferenceBot(ABC): if user_id in self.conversation_history: del self.conversation_history[user_id] - def call_tool(self, function_call): - function_name = function_call.name - function_args = json.loads(function_call.arguments) + def call_tool(self, function_call_name, function_call_arguments): + function_name = function_call_name + function_args = json.loads(function_call_arguments if function_call_arguments is not None else "{}") for tool in self.tools: if function_name in [f["name"] for f in tool.get_functions()]: return tool.execute(function_name, **function_args)