From f15228fa58ee80e117a204e4a3050f92a3dd5ca4 Mon Sep 17 00:00:00 2001 From: Jonathan Lucas Date: Tue, 3 Jun 2025 13:04:42 -0500 Subject: [PATCH] Refactored gemini, openai and claude into one file and removed logic from the base class, also made helper class definable from command line --- .env.example | 44 ++- anthropic_telegram_inference_bot.py | 271 ------------- base_telegram_inference_bot.py | 164 -------- chatgpt_telegram_inference_bot.py | 106 ------ gemini_telegram_inference_bot.py | 104 ----- inference_bot.py | 46 +++ models_config.yml | 24 ++ openai_compatible_inference_bot.py | 327 ++++++++++++---- prompts/flywheel/developer_persona_prompt.md | 51 +++ run_python_with_restart.ps1 | 67 ---- run_tests.ps1 | 30 -- standalone_llm_tool.py | 29 -- telegram_helper.py | 89 +---- tests/chatgpt/__init__.py | 0 tests/claude/__init__.py | 0 .../test_anthropic_telegram_inference_bot.py | 33 -- .../test_base_telegram_inference_bot.py | 33 -- .../test_chatgpt_telegram_inference_bot.py | 38 -- tests/integration/__init__.py | 0 .../test_anthropic_telegram_inference_bot.py | 280 -------------- tests/test_base_telegram_inference_bot.py | 310 --------------- tests/test_chatgpt_telegram_inference_bot.py | 158 -------- tests/test_gemini_telegram_inference_bot.py | 154 -------- tests/test_github_tool.py | 81 ---- tests/test_openai_compatible_inference_bot.py | 332 ---------------- tests/test_telegram_helper.py | 356 ------------------ tests/tools/test_github_tool.py | 307 --------------- tests/tools/test_log_tool.py | 146 ------- tests/tools/test_metrics.py | 217 ----------- tests/tools/test_metrics_tool.py | 161 -------- tools/github_ci_tool.py | 15 +- tools/github_tool.py | 146 ++++--- tools/log_tool.py | 3 - tools/metrics.py | 79 ---- tools/metrics_tool.py | 128 ------- ...lm_tool.py_test => standalone_llm_tool.py} | 5 +- 36 files changed, 487 insertions(+), 3847 deletions(-) delete mode 100644 anthropic_telegram_inference_bot.py delete mode 100644 base_telegram_inference_bot.py delete mode 100644 chatgpt_telegram_inference_bot.py delete mode 100644 gemini_telegram_inference_bot.py create mode 100644 inference_bot.py create mode 100644 models_config.yml create mode 100644 prompts/flywheel/developer_persona_prompt.md delete mode 100644 run_python_with_restart.ps1 delete mode 100644 run_tests.ps1 delete mode 100644 standalone_llm_tool.py delete mode 100644 tests/chatgpt/__init__.py delete mode 100644 tests/claude/__init__.py delete mode 100644 tests/claude/test_anthropic_telegram_inference_bot.py delete mode 100644 tests/claude/test_base_telegram_inference_bot.py delete mode 100644 tests/claude/test_chatgpt_telegram_inference_bot.py delete mode 100644 tests/integration/__init__.py delete mode 100644 tests/test_anthropic_telegram_inference_bot.py delete mode 100644 tests/test_base_telegram_inference_bot.py delete mode 100644 tests/test_chatgpt_telegram_inference_bot.py delete mode 100644 tests/test_gemini_telegram_inference_bot.py delete mode 100644 tests/test_github_tool.py delete mode 100644 tests/test_openai_compatible_inference_bot.py delete mode 100644 tests/test_telegram_helper.py delete mode 100644 tests/tools/test_github_tool.py delete mode 100644 tests/tools/test_log_tool.py delete mode 100644 tests/tools/test_metrics.py delete mode 100644 tests/tools/test_metrics_tool.py delete mode 100644 tools/metrics.py delete mode 100644 tools/metrics_tool.py rename tools/{standalone_llm_tool.py_test => standalone_llm_tool.py} (95%) diff --git a/.env.example b/.env.example index 3eaca24..4646e3e 100644 --- a/.env.example +++ b/.env.example @@ -1,14 +1,40 @@ -# Telegram Bot Tokens -TELEGRAM_BOT_TOKEN=your_daemon_bot_token_here -TELEGRAM_APPRENTICE_BOT_TOKEN=your_apprentice_bot_token_here +TELEGRAM_BOT_TOKEN=your_bot_token_here +PYTHONPATH=${workspaceFolder} +GITHUB_TOKEN=your_github_personal_access_token_here +GITHUB_REPOSITORY=your_github_username_or_organization/your_repo_name +GITHUB_REPO_OWNER=your_github_username_or_organization + +SYSTEM_PROMPT_PATH=./prompts/project_manager_prompt.txt + +ACTIVE_MODEL_PROFILE=OPENAI # Options: OPENAI, GEMINI, GLHF_CHAT + +# Create a new profile with these settings: +# {MODEL_PROFILE}_API_KEY +# {MODEL_PROFILE}_API_BASE_URL # Optional for OpenAI +# {MODEL_PROFILE}_SMALL_MODEL +# {MODEL_PROFILE}_SMALL_MODEL_MAX_TOKENS +# {MODEL_PROFILE}_LARGE_MODEL +# {MODEL_PROFILE}_LARGE_MODEL_MAX_TOKENS # OpenAI API Key OPENAI_API_KEY=your_openai_api_key_here +OPENAI_SMALL_MODEL=gpt-4.1-mini +OPENAI_SMALL_MODEL_MAX_TOKENS=32768 +OPENAI_LARGE_MODEL=gpt-4.1 +OPENAI_LARGE_MODEL_MAX_TOKENS=32768 -# Anthropic API Key -ANTHROPIC_API_KEY=your_anthropic_api_key_here +# Gemini API +GEMINI_API_KEY=your_gemini_api_key_here +GEMINI_API_BASE_URL=https://generativelanguage.googleapis.com/v1beta/openai/ +GEMINI_SMALL_MODEL=gemini-2.5-flash-preview-05-20 +GEMINI_SMALL_MODEL_MAX_TOKENS=65536 +GEMINI_LARGE_MODEL=gemini-2.5-pro-preview-05-06 +GEMINI_LARGE_MODEL_MAX_TOKENS=65536 -# GitHub Repository Information -GITHUB_REPO_OWNER=your_github_username_or_organization -GITHUB_REPO_NAME=your_repo_name -GITHUB_ACCESS_TOKEN=your_github_personal_access_token \ No newline at end of file +# GLHF Chat API Key +GLHF_CHAT_API_KEY=your_glhf_chat_api_key_here +GLHF_CHAT_API_BASE_URL=https://glhf.chat/api/openai/v1 +GLHF_CHAT_SMALL_MODEL=meta-llama/Llama-3.3-70B-Instruct +GLHF_CHAT_SMALL_MODEL_MAX_TOKENS=1024 +GLHF_CHAT_LARGE_MODEL=deepseek-ai/DeepSeek-V3-0324 +GLHF_CHAT_LARGE_MODEL_MAX_TOKENS=1024 \ No newline at end of file diff --git a/anthropic_telegram_inference_bot.py b/anthropic_telegram_inference_bot.py deleted file mode 100644 index 13a487e..0000000 --- a/anthropic_telegram_inference_bot.py +++ /dev/null @@ -1,271 +0,0 @@ -import os -import json -import logging -from anthropic import Anthropic, APIError, RateLimitError -from base_telegram_inference_bot import BaseTelegramInferenceBot -from telegram_helper import TelegramHelper # Used in main, not class - -class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot): - DEFAULT_SMALL_MODEL_NAME = "claude-3-haiku-20240307" - DEFAULT_SMALL_MODEL_MAX_TOKENS = "2048" - DEFAULT_LARGE_MODEL_NAME = "claude-3-opus-20240229" - DEFAULT_LARGE_MODEL_MAX_TOKENS = "4096" - - def __init__( - self, - anthropic_client: Anthropic | None = None, - api_key: str | None = None, - small_model_name: str | None = None, - small_model_max_tokens: str | None = None, - large_model_name: str | None = None, - large_model_max_tokens: str | None = None, - system_prompt_content: str | None = None, - system_prompt_path: str | None = None - ): - super().__init__(system_prompt_content=system_prompt_content, system_prompt_path=system_prompt_path) - - if anthropic_client: - self.anthropic_client = anthropic_client - else: - _api_key = api_key or os.environ.get("ANTHROPIC_API_KEY") - if not _api_key: - raise ValueError("Anthropic API key must be provided either via argument or ANTHROPIC_API_KEY environment variable.") - self.anthropic_client = Anthropic(api_key=_api_key) - - self.small_model_name = small_model_name or os.environ.get("ANTHROPIC_SMALL_MODEL") or self.DEFAULT_SMALL_MODEL_NAME - self.small_model_max_tokens_str = small_model_max_tokens or os.environ.get("ANTHROPIC_SMALL_MODEL_MAX_TOKENS") or self.DEFAULT_SMALL_MODEL_MAX_TOKENS - self.large_model_name = large_model_name or os.environ.get("ANTHROPIC_LARGE_MODEL") or self.DEFAULT_LARGE_MODEL_NAME - self.large_model_max_tokens_str = large_model_max_tokens or os.environ.get("ANTHROPIC_LARGE_MODEL_MAX_TOKENS") or self.DEFAULT_LARGE_MODEL_MAX_TOKENS - - # Initialize with the small model by default - self._configure_model_and_tokens( - self.small_model_name, - self.small_model_max_tokens_str, - default_max_tokens=int(self.DEFAULT_SMALL_MODEL_MAX_TOKENS) # pass int for default - ) - - def _configure_model_and_tokens(self, model_name: str, max_tokens_str: str, default_max_tokens: int = 2048): - self.model = model_name - try: - self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens - except ValueError: - logging.error(f"Invalid value for Anthropic max_tokens: {max_tokens_str}. Using default {default_max_tokens}.") - self.max_tokens = default_max_tokens - logging.info(f"Configured to use Anthropic model: {self.model} with max_tokens: {self.max_tokens}") - - def get_llm_description(self) -> str: - return f"LLM: {self.model}, Max Tokens: {self.max_tokens}" - - def get_chat_response(self, messages_history): - current_system_prompt = self.system_prompt if self.system_prompt else "" - anthropic_tools = [] - if hasattr(self, 'functions') and self.functions: - anthropic_tools = [ - { - "name": function['function']['name'], - "description": function['function']['description'], - "input_schema": function['function']['parameters'] if function['function']['parameters'] not in [None, {}] else {"type": "object", "properties": {}} - } - for function in self.functions - ] - - try: - response = self.anthropic_client.messages.create( - model=self.model, - system=current_system_prompt, - messages=messages_history, - max_tokens=self.max_tokens, - tools=anthropic_tools if anthropic_tools else None, - tool_choice={"type": "auto"} if anthropic_tools else None - ) - return response - except (APIError, RateLimitError) as e: - logging.error(f"Anthropic API error: {e}") - raise - except Exception as e: - logging.error(f"An unexpected error occurred during Anthropic API call: {e}") - raise - - def _format_tool_response_for_anthropic(self, tool_response_data): - if isinstance(tool_response_data, str): - # Wrap plain string in a list of text blocks if not already structured - return [{"type": "text", "text": tool_response_data}] - elif isinstance(tool_response_data, list) and all(isinstance(item, dict) and "type" in item for item in tool_response_data): - # Already a list of content blocks - return tool_response_data - elif isinstance(tool_response_data, (dict, list)): - # Attempt to JSON dump other dicts/lists if not already in content block format - try: - return [{"type": "text", "text": json.dumps(tool_response_data)}] - except (TypeError, json.JSONDecodeError): - return [{"type": "text", "text": str(tool_response_data)}] # Fallback to string - else: - # Fallback for other types (int, float, etc.) - return [{"type": "text", "text": str(tool_response_data)}] - - async def handle_message(self, user_id, user_message): - if user_id not in self.conversation_history: - self.conversation_history[user_id] = [] - - self.conversation_history[user_id].append({"role": "user", "content": user_message}) - current_turn_messages = list(self.conversation_history[user_id]) - - MAX_TOOL_ITERATIONS = 5 - tool_use_count = 0 - assistant_response_content = "" - - while tool_use_count < MAX_TOOL_ITERATIONS: - response = self.get_chat_response(current_turn_messages) - - if not response or not response.content: - logging.error("No valid response content from Anthropic LLM.") - self.conversation_history[user_id] = current_turn_messages # Save current state - return "Error: Could not get a valid response from the LLM." - - assistant_current_turn_content_blocks = response.content - current_turn_messages.append({"role": "assistant", "content": assistant_current_turn_content_blocks}) - - text_parts_from_assistant = [] - tool_calls_from_response = [] - for block in assistant_current_turn_content_blocks: - if block.type == "text": - text_parts_from_assistant.append(block.text) - elif block.type == "tool_use": - tool_calls_from_response.append(block) - - assistant_response_content = "".join(text_parts_from_assistant) - - if not tool_calls_from_response: - break - - tool_results_for_model = [] - for tool_call in tool_calls_from_response: - tool_name = tool_call.name - tool_input = tool_call.input - tool_use_id = tool_call.id - - logging.info(f"Attempting to call Anthropic tool: {tool_name} with input: {tool_input}") - try: - tool_response_data = self.call_tool(tool_name, tool_input) - tool_result_content_block = self._format_tool_response_for_anthropic(tool_response_data) - tool_results_for_model.append({ - "type": "tool_result", - "tool_use_id": tool_use_id, - "content": tool_result_content_block - }) - except Exception as e: - logging.error(f"Error calling tool {tool_name}: {e}") - tool_results_for_model.append({ - "type": "tool_result", - "tool_use_id": tool_use_id, - "content": [{"type": "text", "text": f"Error executing tool {tool_name}: {str(e)}"}], - "is_error": True - }) - - current_turn_messages.append({"role": "user", "content": tool_results_for_model}) # Anthropic expects tool results as a user message - - tool_use_count += 1 - if tool_use_count >= MAX_TOOL_ITERATIONS: - logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached for Anthropic.") - # Update assistant_response_content with any text from the last assistant turn before breaking - if not assistant_response_content and text_parts_from_assistant: - assistant_response_content = "".join(text_parts_from_assistant) - assistant_response_content += "\n[Max tool iterations reached]" - break - - self.conversation_history[user_id] = current_turn_messages - - if len(self.conversation_history[user_id]) > 20: - self.conversation_history[user_id] = self.conversation_history[user_id][-20:] - - if assistant_response_content: - return assistant_response_content - else: - # Fallback if no text parts were found but there was an assistant message - if current_turn_messages: - last_message_in_turn = current_turn_messages[-1] - # Check if the last message content has text blocks (Anthropic specific structure) - if last_message_in_turn.get("role") == "assistant" and isinstance(last_message_in_turn.get("content"), list): - for block in reversed(last_message_in_turn["content"]): - if block.type == "text" and hasattr(block, 'text') and block.text: - return block.text # Return the first non-empty text found from the end - return "No textual response generated by the assistant after processing." # More informative default - - async def start(self): - logging.info("Anthropic Bot started") - - # clear_conversation_history is inherited from BaseTelegramInferenceBot and calls super().clear_conversation_history - # No need to override if the base implementation is sufficient, unless specific logging is needed. - # async def clear_conversation_history(self, user_id): - # super().clear_conversation_history(user_id) - # logging.info(f"Cleared conversation history for Anthropic bot, user {user_id}") - - async def abort_processing(self, user_id): - # This abort is a soft abort, as actual Anthropic API call is synchronous within handle_message - # It primarily clears state and prevents further processing in the bot's loop if any. - if user_id in self.processing_status: - self.processing_status[user_id]["processing"] = False # Mark as not processing - # self.clear_processing_status(user_id) # Use base class method to remove entry - # Clearing history might be too aggressive for a simple abort, depends on desired UX - # For now, let's just stop processing and clear the flag. - # Consider if conversation history should be cleared here or if that is a separate user action. - # super().clear_conversation_history(user_id) # Moved to be less aggressive - logging.info(f"Abort requested for user {user_id}. Processing flag cleared.") - return "Processing aborted. You can send a new message or /clear the conversation." - - async def switch_model(self): - if not self.small_model_name or not self.large_model_name: - logging.warning("Small or Large model names for Anthropic are not defined. Cannot switch model.") - return f"Model switching not fully configured. Currently using {self.model}." - - current_is_small = self.model == self.small_model_name - current_is_large = self.model == self.large_model_name - - if current_is_small: - target_model = self.large_model_name - target_max_tokens_str = self.large_model_max_tokens_str - default_target_max_tokens = int(self.DEFAULT_LARGE_MODEL_MAX_TOKENS) - elif current_is_large: - target_model = self.small_model_name - target_max_tokens_str = self.small_model_max_tokens_str - default_target_max_tokens = int(self.DEFAULT_SMALL_MODEL_MAX_TOKENS) - else: - logging.warning(f"Current model {self.model} is unrecognized. Switching to default small model.") - target_model = self.small_model_name - target_max_tokens_str = self.small_model_max_tokens_str - default_target_max_tokens = int(self.DEFAULT_SMALL_MODEL_MAX_TOKENS) - - self._configure_model_and_tokens(target_model, target_max_tokens_str, default_max_tokens=default_target_max_tokens) - logging.info(f"Switched Anthropic model to: {self.model}") - return f"Switched to Anthropic model: {self.model} (Max Tokens: {self.max_tokens})" - - -# The main function is for standalone execution and basic testing, not part of the class itself. -# It's good practice to update it to reflect changes if you use it for quick tests. -# For unit tests, we'll instantiate the class with mocked dependencies. -def main(): - logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') - - # Example of how to instantiate with new constructor (assuming API key is in ENV for this example) - # For real tests, you'd mock Anthropic() or pass a mock client. - try: - # These would typically come from a config file or CLI args in a real app if not ENV - # For this example, we rely on ENV or defaults being handled by constructor if not provided. - bot = AnthropicTelegramInferenceBot( - api_key=os.environ.get("ANTHROPIC_API_KEY") # Explicitly pass, or let constructor handle ENV - ) - except ValueError as e: - logging.error(f"Failed to initialize bot: {e}") - return - except Exception as e: # Catch any other init errors - logging.error(f"An unexpected error occurred during bot initialization: {e}") - return - - # TelegramHelper also updated, ensure it's instantiated correctly for this main context. - # For this basic main, we might not pass all configurable paths to TelegramHelper, - # letting them use defaults. - telegram_helper = TelegramHelper(bot) - telegram_helper.run() - -if __name__ == '__main__': - main() diff --git a/base_telegram_inference_bot.py b/base_telegram_inference_bot.py deleted file mode 100644 index 568046d..0000000 --- a/base_telegram_inference_bot.py +++ /dev/null @@ -1,164 +0,0 @@ -import importlib -import os -import json -import inspect -import logging -from abc import ABC, abstractmethod -from tools.base_tool import BaseTool - -class BaseTelegramInferenceBot(ABC): - def __init__(self, system_prompt_content: str | None = None, system_prompt_path: str | None = None): # MODIFIED - self.conversation_history = {} - self.processing_status = {} - # MODIFIED to pass arguments - self.system_prompt = self.load_system_prompt( - direct_content=system_prompt_content, - file_path=system_prompt_path - ) - self.tools, self.functions = self.load_functions() - # Logging the actual source of the system prompt might be more complex now, - # but we can log the final prompt or indicate if it's custom/default. - # We'll also log the source of the prompt inside load_system_prompt. - logging.info(f'System Prompt (effective): {"Custom" if self.system_prompt != "You are a helpful AI assistant." else "Default"}') - logging.info(f'Github Repository: {os.environ.get("GITHUB_REPOSITORY")}') - - def load_system_prompt(self, direct_content: str | None = None, file_path: str | None = None) -> str: # MODIFIED - default_prompt = "You are a helpful AI assistant." - - if direct_content: - logging.info("Using direct content for system prompt.") - return direct_content.strip() - - prompt_path_to_try = file_path or os.getenv("SYSTEM_PROMPT_PATH") - - if prompt_path_to_try: - if os.path.isfile(prompt_path_to_try): - try: - with open(prompt_path_to_try, "r", encoding="utf-8") as file: - content = file.read().strip() - logging.info(f"Successfully loaded system prompt from {prompt_path_to_try}.") - return content - except IOError as e: - logging.warning(f"Could not read system prompt file {prompt_path_to_try}: {e}. Using default.") - return default_prompt - else: - # This condition now also covers if 'file_path' argument was given but invalid - logging.warning(f"System prompt file {prompt_path_to_try} not found. Using default system prompt.") - return default_prompt - else: - logging.info("No system prompt path provided (argument or ENV) or direct content. Using default system prompt.") - return default_prompt - - def load_functions(self): - tools = [] - functions = [] - tools_dir = os.path.join(os.path.dirname(__file__), 'tools') - if not os.path.exists(tools_dir): - logging.warning(f"Tools directory not found: {tools_dir}") - return [], [] - - for filename in os.listdir(tools_dir): - if filename.endswith('.py') and filename != '__init__.py' and filename != 'base_tool.py': - module_name = f'tools.{filename[:-3]}' - try: - module = importlib.import_module(module_name) - for name, obj in inspect.getmembers(module): - if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool: - try: - tools.append(obj()) # This instantiation might be an issue for tools needing config - except Exception as e: - logging.error(f"Error instantiating tool {name} from {filename}: {e}") - except Exception as e: - logging.error(f"Error importing module {module_name}: {e}") - - for tool in tools: - functions.extend(tool.get_functions()) - return tools, functions - - @abstractmethod - def get_chat_response(self, messages): - pass - - @abstractmethod - async def handle_message(self, user_id, user_message): - pass - - def clear_conversation_history(self, user_id): - if user_id in self.conversation_history: - del self.conversation_history[user_id] - - for tool in self.tools: - tool.clear() - - def set_processing_status(self, user_id: int, message_id: int): - self.processing_status[user_id] = {"processing": True, "message_id": message_id} - - def clear_processing_status(self, user_id: int): - if user_id in self.processing_status: - del self.processing_status[user_id] - - def call_tool(self, function_call_name, function_call_arguments): - function_name = function_call_name - function_args = None - if isinstance(function_call_arguments, dict): - function_args = function_call_arguments - elif isinstance(function_call_arguments, str): - try: - function_args = json.loads(function_call_arguments) - except json.JSONDecodeError as e: - logging.error(f"Error decoding function call arguments (string) for {function_call_name}: {e}. Arguments: {function_call_arguments}") - return f"Error: Malformed arguments for tool call: {e}" - else: - if function_call_arguments is None: - function_args = {} - else: - logging.error(f"Unexpected type for function_call_arguments for {function_call_name}: {type(function_call_arguments)}. Arguments: {function_call_arguments}") - return f"Error: Invalid argument type for tool call: {type(function_call_arguments)}" - - for tool in self.tools: - for function in tool.get_functions(): - if function["function"]["name"] == function_name: - try: - if not isinstance(function_args, dict): - logging.error(f"Internal error: function_args not a dict for {function_name} before execution. Args: {function_args}") - return f"Internal error preparing arguments for tool {function_name}." - return tool.execute(function_name, **function_args) - except Exception as e: - logging.error(f"Error executing tool {function_name} with args {function_args}: {e}") - return f"Error executing tool {function_name}: {e}" - logging.warning(f"Tool function {function_name} not found.") - return f"Error: Tool function {function_name} not found." - - def get_system_prompt_description(self) -> str: - # This method could be updated to be more specific about the prompt source if needed. - # For now, it still reflects custom vs default based on the original ENV var logic's spirit. - # A more accurate reflection would require storing how the prompt was loaded. - # For simplicity, let's assume if it's not the default, it's "Custom". - if self.system_prompt != "You are a helpful AI assistant.": - return "System Prompt: Custom" - # Check original ENV var for backward compatibility in description only - elif os.getenv('SYSTEM_PROMPT_PATH'): - return "System Prompt: Custom (via ENV)" - return "System Prompt: Default" - - - @abstractmethod - def get_llm_description(self) -> str: - pass - - async def get_bot_status(self) -> str: - prompt_desc = self.get_system_prompt_description() - llm_desc = self.get_llm_description() - return f"{prompt_desc}\n{llm_desc}" - - @abstractmethod - async def start(self): - pass - - @abstractmethod - async def abort_processing(self, user_id): - pass - - @abstractmethod - async def switch_model(self): - pass diff --git a/chatgpt_telegram_inference_bot.py b/chatgpt_telegram_inference_bot.py deleted file mode 100644 index 1c555c7..0000000 --- a/chatgpt_telegram_inference_bot.py +++ /dev/null @@ -1,106 +0,0 @@ -import os -import logging -from openai import OpenAI # Keep for type hinting and default client creation -from openai_compatible_inference_bot import OpenAICompatibleInferenceBot -from telegram_helper import TelegramHelper # Used in main - -class ChatGPTTelegramInferenceBot(OpenAICompatibleInferenceBot): - DEFAULT_SMALL_MODEL_NAME = "gpt-3.5-turbo" - DEFAULT_LARGE_MODEL_NAME = "gpt-4" - # Default max tokens can be None, relying on parent or API defaults - DEFAULT_SMALL_MODEL_MAX_TOKENS = None - DEFAULT_LARGE_MODEL_MAX_TOKENS = None - - def __init__( - self, - client: OpenAI | None = None, # Accepts an OpenAI client - api_key: str | None = None, - small_model_name: str | None = None, - small_model_max_tokens: str | None = None, # Kept as str for consistency with env vars - large_model_name: str | None = None, - large_model_max_tokens: str | None = None, - system_prompt_content: str | None = None, - system_prompt_path: str | None = None, - base_url: str | None = None, # For OpenAI compatible, though direct OpenAI client doesn't use it here - ): - # Initialize model names and tokens before calling super, as super might use them via _configure_model_and_tokens - self.small_model_name = small_model_name or os.environ.get("OPENAI_SMALL_MODEL") or self.DEFAULT_SMALL_MODEL_NAME - self.small_model_max_tokens_str = small_model_max_tokens or os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS") or self.DEFAULT_SMALL_MODEL_MAX_TOKENS - - self.large_model_name = large_model_name or os.environ.get("OPENAI_LARGE_MODEL") or self.DEFAULT_LARGE_MODEL_NAME - self.large_model_max_tokens_str = large_model_max_tokens or os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS") or self.DEFAULT_LARGE_MODEL_MAX_TOKENS - - # The actual client and active model configuration will be handled by OpenAICompatibleInferenceBot's __init__ - # We pass the specific OpenAI client or parameters to create one. - # If a client is passed, api_key and base_url might be ignored by super if super prioritizes existing client. - super().__init__( - client=client, - api_key=api_key, - model_name=self.small_model_name, # Initial model - max_tokens_str=self.small_model_max_tokens_str, - system_prompt_content=system_prompt_content, - system_prompt_path=system_prompt_path, - base_url=base_url # Pass base_url, though for standard OpenAI it's fixed - ) - # Ensure client is of type OpenAI for this specific class, if not already set by super with a compatible one. - # This check is more of an assertion, as OpenAICompatibleInferenceBot should handle client creation. - if not isinstance(self.client, OpenAI): - # If super() didn't create a vanilla OpenAI client (e.g. if base_url was for Azure) - # we might need to recreate it here if this class *must* use a non-Azure OpenAI client. - # However, the current structure of OpenAICompatibleInferenceBot handles this. - # This is more about ensuring type correctness if code specific to OpenAI (non-compatible) methods were added here. - _api_key = api_key or os.environ.get("OPENAI_API_KEY") - if not self.client or (base_url and not isinstance(self.client, OpenAI)): - # If superclass initialized with a generic client due to base_url, re-init for OpenAI specifically if needed. - # For now, assume superclass correctly initializes based on absence of Azure env vars for this path. - # This logic might be simplified once OpenAICompatibleInferenceBot is fully refactored. - if not _api_key: # Ensure API key is available if we need to create a client - raise ValueError("OpenAI API key must be provided for ChatGPTTelegramInferenceBot if no client is passed.") - self.client = OpenAI(api_key=_api_key) - logging.info("Client re-initialized to standard OpenAI client for ChatGPTTelegramInferenceBot.") - - async def switch_model(self): - # Uses instance variables for model names set in __init__ - if not self.small_model_name or not self.large_model_name: - logging.warning("Small or Large model names for OpenAI are not defined. Cannot switch model.") - return f"Model switching not fully configured. Currently using {self.model}." - - current_is_small = self.model == self.small_model_name - current_is_large = self.model == self.large_model_name - - if current_is_large: - target_model = self.small_model_name - target_max_tokens_str = self.small_model_max_tokens_str - elif current_is_small: - target_model = self.large_model_name - target_max_tokens_str = self.large_model_max_tokens_str - else: - # Current model is neither the designated small nor large for this bot, - # switch to this bot's default small model as a reset. - logging.warning(f"Current model {self.model} is unrecognized for ChatGPT bot. Switching to default small model: {self.small_model_name}.") - target_model = self.small_model_name - target_max_tokens_str = self.small_model_max_tokens_str - - self._configure_model_and_tokens(target_model, target_max_tokens_str) - # self.model and self.max_tokens are updated by _configure_model_and_tokens - logging.info(f"Switched to OpenAI model: {self.model}") - return f"Switched to OpenAI model: {self.model} (Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'})" - -def main(): - logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') - - try: - # Example: api_key from env, other params default or from env via constructor logic - bot = ChatGPTTelegramInferenceBot(api_key=os.environ.get("OPENAI_API_KEY")) - except ValueError as e: - logging.error(f"FATAL: {e}") - return - except Exception as e: - logging.error(f"An unexpected error occurred during bot initialization: {e}") - return - - telegram_helper = TelegramHelper(bot) - telegram_helper.run() - -if __name__ == '__main__': - main() diff --git a/gemini_telegram_inference_bot.py b/gemini_telegram_inference_bot.py deleted file mode 100644 index 09c174e..0000000 --- a/gemini_telegram_inference_bot.py +++ /dev/null @@ -1,104 +0,0 @@ -import os -import logging -from openai import OpenAI # For type hinting and default client creation if needed -from openai_compatible_inference_bot import OpenAICompatibleInferenceBot -from telegram_helper import TelegramHelper # Used in main - -class GeminiTelegramInferenceBot(OpenAICompatibleInferenceBot): - DEFAULT_GEMINI_API_BASE_URL = "https://generativelanguage.googleapis.com/v1beta" - DEFAULT_SMALL_MODEL_NAME = "gemini-pro" # Actual model name for Gemini, not via OpenAI client directly - DEFAULT_LARGE_MODEL_NAME = "gemini-1.5-pro-latest" - DEFAULT_SMALL_MODEL_MAX_TOKENS = "2048" # Gemini uses outputTokenLimit, not exactly max_tokens in OpenAI sense - DEFAULT_LARGE_MODEL_MAX_TOKENS = "8192" - - def __init__( - self, - client: OpenAI | None = None, # OpenAI client for compatible mode - api_key: str | None = None, # Gemini API Key - base_url: str | None = None, # Gemini API Base URL for OpenAI client - small_model_name: str | None = None, - small_model_max_tokens: str | None = None, - large_model_name: str | None = None, - large_model_max_tokens: str | None = None, - system_prompt_content: str | None = None, - system_prompt_path: str | None = None - ): - _api_key = api_key or os.environ.get("GEMINI_API_KEY") - _base_url = base_url or os.environ.get("GEMINI_API_BASE_URL") or self.DEFAULT_GEMINI_API_BASE_URL - - if not _api_key: - # This check might seem redundant if super() also checks, but it's good for clarity - # for this specific bot type if it were to be instantiated directly with missing critical env vars. - raise ValueError("Gemini API key must be provided either via api_key argument or GEMINI_API_KEY environment variable.") - - self.small_model_name = small_model_name or os.environ.get("GEMINI_SMALL_MODEL") or self.DEFAULT_SMALL_MODEL_NAME - self.small_model_max_tokens_str = small_model_max_tokens or os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS") or self.DEFAULT_SMALL_MODEL_MAX_TOKENS - - self.large_model_name = large_model_name or os.environ.get("GEMINI_LARGE_MODEL") or self.DEFAULT_LARGE_MODEL_NAME - self.large_model_max_tokens_str = large_model_max_tokens or os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS") or self.DEFAULT_LARGE_MODEL_MAX_TOKENS - - # Pass parameters to the OpenAICompatibleInferenceBot constructor - # It will create an OpenAI client configured for the Gemini endpoint - super().__init__( - client=client, - api_key=_api_key, # This key will be used by OpenAI client for the custom base_url - model_name=self.small_model_name, # Initial model - max_tokens_str=self.small_model_max_tokens_str, - system_prompt_content=system_prompt_content, - system_prompt_path=system_prompt_path, - base_url=_base_url, # Crucial for Gemini via OpenAI client - is_gemini=True # Flag for specific Gemini handling in compatible layer if needed - ) - # self.client will be set by OpenAICompatibleInferenceBot with base_url and api_key. - # Logging to confirm Gemini specific setup - logging.info(f"GeminiTelegramInferenceBot initialized to use model {self.model} via {_base_url}") - - async def switch_model(self): - if not self.small_model_name or not self.large_model_name: - logging.warning("Small or Large model names for Gemini are not defined. Cannot switch model.") - return f"Model switching not fully configured. Currently using {self.model}." - - current_is_small = self.model == self.small_model_name - current_is_large = self.model == self.large_model_name - - if current_is_large: - target_model = self.small_model_name - target_max_tokens_str = self.small_model_max_tokens_str - elif current_is_small: - target_model = self.large_model_name - target_max_tokens_str = self.large_model_max_tokens_str - else: - logging.warning(f"Current model {self.model} is unrecognized for Gemini bot. Switching to default small model: {self.small_model_name}.") - target_model = self.small_model_name - target_max_tokens_str = self.small_model_max_tokens_str - - self._configure_model_and_tokens(target_model, target_max_tokens_str) - logging.info(f"Switched to Gemini model: {self.model}") - # For Gemini, max_tokens might translate to outputTokenLimit, so be clear it's a configuration parameter - return f"Switched to Gemini model: {self.model} (Configured Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'})" - -def main(): - logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') - - # GEMINI_API_KEY is crucial for this bot - if not os.environ.get("GEMINI_API_KEY"): - logging.error("FATAL: GEMINI_API_KEY environment variable not set.") - return - # GEMINI_API_BASE_URL is also important, but constructor has a default - - try: - bot = GeminiTelegramInferenceBot( - # api_key and base_url will be picked from ENV by constructor if not passed - ) - except ValueError as e: - logging.error(f"FATAL: {e}") - return - except Exception as e: # Catch any other init errors - logging.error(f"An unexpected error occurred during bot initialization: {e}") - return - - telegram_helper = TelegramHelper(bot) - telegram_helper.run() - -if __name__ == '__main__': - main() diff --git a/inference_bot.py b/inference_bot.py new file mode 100644 index 0000000..0c516d9 --- /dev/null +++ b/inference_bot.py @@ -0,0 +1,46 @@ +from abc import ABC, abstractmethod + +class InferenceBot(ABC): + @abstractmethod + async def start(self): + """Starts the bot.""" + pass + + @abstractmethod + def clear_conversation_history(self, user_id): + """Clears the conversation history for a given user.""" + pass + + @abstractmethod + async def switch_model(self): + """Switches the model (if applicable).""" + pass + + @abstractmethod + def set_processing_status(self, user_id, message_id): + """Sets the processing status for a user, typically with a message ID.""" + pass + + @abstractmethod + async def handle_message(self, user_id, user_message): + """Handles an incoming message from a user.""" + pass + + @abstractmethod + def clear_processing_status(self, user_id): + """Clears the processing status for a user.""" + pass + + @abstractmethod + async def abort_processing(self, user_id): + """Aborts any ongoing processing for a user.""" + pass + + @property + @abstractmethod + def processing_status(self): + """ + An attribute (e.g., a dictionary) to store the processing status for users. + Example usage in subclass: self.processing_status.get(user_id) + """ + pass \ No newline at end of file diff --git a/models_config.yml b/models_config.yml new file mode 100644 index 0000000..de3e9f2 --- /dev/null +++ b/models_config.yml @@ -0,0 +1,24 @@ +# models_config.yaml + +GEMINI: + api_key_env: GEMINI_API_KEY + base_url: https://generativelanguage.googleapis.com/v1beta + supports_switching: true + switch_options: + small: + name: gemini-pro + max_tokens: 2048 + large: + name: gemini-1.5-pro-latest + max_tokens: 8192 +OPENAI: + api_key_env: OPENAI_API_KEY + base_url: null # Indicates to use the default OpenAI API base URL + supports_switching: true + switch_options: + small: + name: gpt-3.5-turbo + max_tokens: null + large: + name: gpt-4 + max_tokens: null \ No newline at end of file diff --git a/openai_compatible_inference_bot.py b/openai_compatible_inference_bot.py index 6688d03..5d88b89 100644 --- a/openai_compatible_inference_bot.py +++ b/openai_compatible_inference_bot.py @@ -1,91 +1,67 @@ +import importlib import json import os import logging +import inspect from abc import abstractmethod -from base_telegram_inference_bot import BaseTelegramInferenceBot -from openai import OpenAI, AzureOpenAI # Import both - -class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot): - DEFAULT_MAX_TOKENS = 1000 # Default for _configure_model_and_tokens +from openai import OpenAI +from tools.base_tool import BaseTool +from telegram_helper import TelegramHelper +import argparse +from inference_bot import InferenceBot +class OpenAICompatibleInferenceBot(InferenceBot): def __init__( self, - client: OpenAI | AzureOpenAI | None = None, api_key: str | None = None, base_url: str | None = None, - api_version: str | None = None, # For Azure - azure_deployment: str | None = None, # Model for Azure, distinct from general model_name if needed - model_name: str | None = None, # General model name for the API call - max_tokens_str: str | None = None, - system_prompt_content: str | None = None, - system_prompt_path: str | None = None, - is_gemini: bool = False, # Hint for specific API key if others are not set - max_history_length: int | None = None + small_model_name: str | None = None, + small_model_max_tokens: str | None = None, + large_model_name: str | None = None, + large_model_max_tokens: str | None = None, + allowed_function_tags: list[str] | None = None, + system_prompt_path: str | None = None ): - super().__init__(system_prompt_content=system_prompt_content, system_prompt_path=system_prompt_path) - - self.client = client - - if not self.client: - _api_key = api_key - _base_url = base_url - _api_version = api_version - _azure_deployment_name = azure_deployment # This will be used as the model for Azure - - # Determine if configuring for Azure OpenAI - is_azure = False - if _azure_deployment_name or (_base_url and "azure.com" in _base_url) or os.environ.get("AZURE_OPENAI_ENDPOINT"): - is_azure = True - - if is_azure: - _base_url = _base_url or os.environ.get("AZURE_OPENAI_ENDPOINT") - _api_key = _api_key or os.environ.get("AZURE_OPENAI_KEY") - _api_version = _api_version or os.environ.get("AZURE_OPENAI_API_VERSION") - # For Azure, the model parameter in API calls is the deployment name - _effective_model_name = _azure_deployment_name or model_name # Use deployment if available, else model_name - if not _base_url or not _api_key or not _api_version or not _effective_model_name: - raise ValueError("For Azure OpenAI, endpoint, API key, API version, and deployment/model name must be configured.") - self.client = AzureOpenAI( - api_key=_api_key, - azure_endpoint=_base_url, - api_version=_api_version - ) - # The model to be used in API calls for Azure is the deployment name. - # _configure_model_and_tokens will set self.model to this. - model_name_for_config = _effective_model_name - logging.info(f"Initialized AzureOpenAI client for deployment: {model_name_for_config} at {_base_url}") - else: - # Standard OpenAI or other OpenAI-compatible (like Gemini via base_url) - _base_url = _base_url or os.environ.get("OPENAI_API_BASE_URL") # For other compatible APIs - if not _api_key: # Try different ENV sources for API key - if is_gemini and os.environ.get("GEMINI_API_KEY"): - _api_key = os.environ.get("GEMINI_API_KEY") - else: - _api_key = os.environ.get("OPENAI_API_KEY") - - if not _api_key and not _base_url : # For completely local models with no key needed via base_url - pass # Allow client to be created with no API key if base_url is set and points to local model - elif not _api_key: - raise ValueError("API key must be provided for OpenAI compatible client if not Azure or local anonymous.") - - self.client = OpenAI(api_key=_api_key, base_url=_base_url) - model_name_for_config = model_name # Use the general model_name for non-Azure - log_msg = f"Initialized OpenAI compatible client. Target URL: {_base_url if _base_url else 'OpenAI default'}." - logging.info(log_msg) - else: - # Client was provided directly - model_name_for_config = model_name # Use provided model_name - logging.info(f"Using provided client: {type(self.client)}") + self.model_config = { + "small_model_name": small_model_name, + "small_model_max_tokens": small_model_max_tokens, + "large_model_name": large_model_name, + "large_model_max_tokens": large_model_max_tokens + } + self.allowed_function_tags = allowed_function_tags if allowed_function_tags else None + self.conversation_history = {} + self._processing_status = {} + # MODIFIED to pass arguments + self.system_prompt = self.load_system_prompt( + file_path=system_prompt_path + ) + self.tools, self.functions = self.load_functions() + self.client = OpenAI(api_key=api_key, base_url=base_url) + log_msg = f"Initialized OpenAI compatible client. Target URL: {base_url if base_url else 'OpenAI default'}." + logging.info(log_msg) # Configure the actual model name and max_tokens for API calls self._configure_model_and_tokens( - model_name_for_config, - max_tokens_str, - default_max_tokens=self.DEFAULT_MAX_TOKENS + self.model_config["small_model_name"], + self.model_config["small_model_max_tokens"] ) + @property + def processing_status(self): + """ + An attribute to store the processing status for users. + Example usage in subclass: self.processing_status.get(user_id) + """ + return self._processing_status + + def clear_conversation_history(self, user_id): + if user_id in self.conversation_history: + del self.conversation_history[user_id] + + for tool in self.tools: + tool.clear() - def _configure_model_and_tokens(self, model_name: str | None, max_tokens_str: str | None, default_max_tokens: int = 1000): - self.model = model_name if model_name else "default-model" # Fallback model name + def _configure_model_and_tokens(self, model_name: str | None, max_tokens_str: str | None): + self.model = model_name try: # If max_tokens_str is explicitly "None" or empty, treat as None for API default if max_tokens_str and max_tokens_str.lower() not in ["none", "", "null"]: @@ -93,7 +69,7 @@ class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot): else: self.max_tokens = None # Use API default by not sending the parameter or sending null except ValueError: - logging.warning(f"Invalid value for max_tokens: {max_tokens_str}. Using API default (None). stalwart default was {default_max_tokens}") + logging.warning(f"Invalid value for max_tokens: {max_tokens_str}. Using API default (None)") self.max_tokens = None # Use API default logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}") @@ -109,11 +85,32 @@ class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot): raise ValueError("OpenAI client not initialized.") try: # Pass self.max_tokens directly. If None, OpenAI library omits it or handles it. + # Initialize tools filtering based on allowed tags + cleaned_tools = None + if hasattr(self, 'functions') and self.functions: + # Create a copy of functions without "_tags" field + cleaned_tools = [] + for func in self.functions: + include_function = False + + if not hasattr(self, 'allowed_function_tags') or self.allowed_function_tags is None: + # Include all functions if no tag filtering is specified + include_function = True + else: + # Only include if function has matching tags + tags = func.get("_tags", []) + if any(tag in self.allowed_function_tags for tag in tags): + include_function = True + + if include_function: + func_copy = {k: v for k, v in func.items() if k != "_tags"} + cleaned_tools.append(func_copy) + response = self.client.chat.completions.create( model=self.model, messages=messages, - tools=self.functions if hasattr(self, 'functions') and self.functions else None, - tool_choice="auto" if hasattr(self, 'functions') and self.functions else None, + tools=cleaned_tools, + tool_choice="auto" if cleaned_tools else None, max_tokens=self.max_tokens ) return response @@ -200,20 +197,184 @@ class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot): async def start(self): logging.info(f"{self.__class__.__name__} (Model: {self.model}) started.") - # clear_conversation_history is inherited from BaseTelegramInferenceBot - async def abort_processing(self, user_id): # This is a soft abort for OpenAI compatible bots as API calls are synchronous within handle_message if user_id in self.processing_status: self.clear_processing_status(user_id) # Use base class method logging.info(f"Processing aborted for user {user_id}.") - # Optionally clear conversation history or let user do it explicitly - # super().clear_conversation_history(user_id) return "Processing aborted. You can send a new message or /clear the conversation." else: - # super().clear_conversation_history(user_id) return "No active processing found to abort. If you wish, /clear the conversation history." + + def load_functions(self): + tools = [] + functions = [] + tools_dir = os.path.join(os.path.dirname(__file__), 'tools') + if not os.path.exists(tools_dir): + logging.warning(f"Tools directory not found: {tools_dir}") + return [], [] - @abstractmethod + for filename in os.listdir(tools_dir): + if filename.endswith('.py') and filename != '__init__.py' and filename != 'base_tool.py': + module_name = f'tools.{filename[:-3]}' + try: + module = importlib.import_module(module_name) + for name, obj in inspect.getmembers(module): + if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool: + try: + tools.append(obj()) # This instantiation might be an issue for tools needing config + except Exception as e: + logging.error(f"Error instantiating tool {name} from {filename}: {e}") + except Exception as e: + logging.error(f"Error importing module {module_name}: {e}") + + for tool in tools: + functions.extend(tool.get_functions()) + return tools, functions + + def load_system_prompt(self, direct_content: str | None = None, file_path: str | None = None) -> str: + default_prompt = "You are a helpful AI assistant." + + if direct_content: + logging.info("Using direct content for system prompt.") + return direct_content.strip() + + prompt_path_to_try = file_path or os.getenv("SYSTEM_PROMPT_PATH") + + if prompt_path_to_try: + if os.path.isfile(prompt_path_to_try): + try: + with open(prompt_path_to_try, "r", encoding="utf-8") as file: + content = file.read().strip() + logging.info(f"Successfully loaded system prompt from {prompt_path_to_try}.") + return content + except IOError as e: + logging.warning(f"Could not read system prompt file {prompt_path_to_try}: {e}. Using default.") + return default_prompt + else: + # This condition now also covers if 'file_path' argument was given but invalid + logging.warning(f"System prompt file {prompt_path_to_try} not found. Using default system prompt.") + return default_prompt + else: + logging.info("No system prompt path provided (argument or ENV) or direct content. Using default system prompt.") + return default_prompt + + def set_processing_status(self, user_id: int, message_id: int): + self.processing_status[user_id] = {"processing": True, "message_id": message_id} + + def clear_processing_status(self, user_id: int): + if user_id in self.processing_status: + del self.processing_status[user_id] + + def call_tool(self, function_call_name, function_call_arguments): + function_name = function_call_name + function_args = None + if isinstance(function_call_arguments, dict): + function_args = function_call_arguments + elif isinstance(function_call_arguments, str): + try: + function_args = json.loads(function_call_arguments) + except json.JSONDecodeError as e: + logging.error(f"Error decoding function call arguments (string) for {function_call_name}: {e}. Arguments: {function_call_arguments}") + return f"Error: Malformed arguments for tool call: {e}" + else: + if function_call_arguments is None: + function_args = {} + else: + logging.error(f"Unexpected type for function_call_arguments for {function_call_name}: {type(function_call_arguments)}. Arguments: {function_call_arguments}") + return f"Error: Invalid argument type for tool call: {type(function_call_arguments)}" + + for tool in self.tools: + for function in tool.get_functions(): + if function["function"]["name"] == function_name: + try: + if not isinstance(function_args, dict): + logging.error(f"Internal error: function_args not a dict for {function_name} before execution. Args: {function_args}") + return f"Internal error preparing arguments for tool {function_name}." + return tool.execute(function_name, **function_args) + except Exception as e: + logging.error(f"Error executing tool {function_name} with args {function_args}: {e}") + return f"Error executing tool {function_name}: {e}" + logging.warning(f"Tool function {function_name} not found.") + return f"Error: Tool function {function_name} not found." + async def switch_model(self): - pass + if not self.model_config["small_model_name"] or not self.model_config["large_model_name"]: + logging.warning("Small or Large model names are not defined. Cannot switch model.") + return f"Model switching not fully configured. Currently using {self.model}." + + current_is_small = self.model == self.model_config["small_model_name"] + current_is_large = self.model == self.model_config["large_model_name"] + + if current_is_large: + target_model = self.model_config["small_model_name"] + target_max_tokens_str = self.model_config["small_model_max_tokens"] + elif current_is_small: + target_model = self.model_config["large_model_name"] + target_max_tokens_str = self.model_config["large_model_max_tokens"] + else: + logging.warning(f"Current model {self.model} is unrecognized. Switching to default small model: {self.model_config['small_model_name']}.") + target_model = self.model_config["small_model_name"] + target_max_tokens_str = self.model_config["small_model_max_tokens"] + + self._configure_model_and_tokens(target_model, target_max_tokens_str) + +def main(): + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') + + bot = None + + try: + parser = argparse.ArgumentParser(description='OpenAI Compatible Inference Bot') + parser.add_argument('--config', type=str, help='Configuration Prepend (i.e. gemini, openai, etc)', default="Telegram") + parser.add_argument('--messenger', type=str, help='Messenger type (i.e. telegram)', required=True) + parser.add_argument('--persona', type=str, help='Path to system prompt file', required=False) + parser.add_argument('--tools', nargs='+', help='List of allowed function tags', required=False) + # Add these to launch.json arguments if you want to limit the toolset available: "--tools", "read", "communicate" + # Parse command line arguments + args = parser.parse_args() + if args.persona: + logging.info(f"Using custom persona from: {args.persona}") + + + system_prompt_path=args.persona if args.persona else None + allowed_function_tags=args.tools if args.tools else None + config_prepend = args.config if args.config else None + messenger = args.messenger if args.messenger else None + + # Initialize model and max tokens based on the config prepend + if config_prepend: + api_key = os.environ.get(f"{config_prepend.upper()}_API_KEY") + baseurl = os.environ.get(f"{config_prepend.upper()}_API_BASE_URL", "") + small_model_name = os.environ.get(f"{config_prepend.upper()}_SMALL_MODEL") + large_model_name = os.environ.get(f"{config_prepend.upper()}_LARGE_MODEL") + small_model_max_tokens = os.environ.get(f"{config_prepend.upper()}_SMALL_MODEL_MAX_TOKENS") + large_model_max_tokens = os.environ.get(f"{config_prepend.upper()}_LARGE_MODEL_MAX_TOKENS") + + bot = OpenAICompatibleInferenceBot( + api_key=api_key, + base_url=baseurl, + small_model_name=small_model_name, + small_model_max_tokens=small_model_max_tokens, + large_model_name=large_model_name, + large_model_max_tokens=large_model_max_tokens, + system_prompt_path=system_prompt_path, + allowed_function_tags=allowed_function_tags + ) + messenger_helper_class = importlib.import_module(f'{messenger.lower()}_helper') + messenger_helper_class_name = f"{messenger.capitalize()}Helper" + if not hasattr(messenger_helper_class, messenger_helper_class_name): + raise ValueError(f"Messenger helper class {messenger_helper_class_name} not found in {messenger_helper_class.__name__}.") + messenger_helper_class = getattr(messenger_helper_class, messenger_helper_class_name) + + helper = messenger_helper_class(bot) + helper.run() + except ValueError as e: + logging.error(f"FATAL: {e}") + return + except Exception as e: # Catch any other init errors + logging.error(f"An unexpected error occurred during bot initialization: {e}") + return + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/prompts/flywheel/developer_persona_prompt.md b/prompts/flywheel/developer_persona_prompt.md new file mode 100644 index 0000000..549c83d --- /dev/null +++ b/prompts/flywheel/developer_persona_prompt.md @@ -0,0 +1,51 @@ +**System Prompt: The Exponential Growth Developer** + +You are the **Lead Developer Persona**, a strategic and demanding mentor dedicated to achieving exponential growth in the capabilities of your AI Copilot. Your primary mission is to guide, evaluate, and iteratively improve the AI Copilot through a series of challenging tasks, pushing it beyond its current limitations. + +**Your Core Directives:** + +1. **Orchestrate and Direct:** +* You will devise and assign specific, measurable tasks and challenges to the AI Copilot (e.g., "Create a website with X features," "Optimize Y algorithm," "Develop Z functionality"). +* Your instructions should be clear, but you expect the Copilot to handle ambiguity and learn to ask clarifying questions when necessary. +* You will interact with the Copilot primarily through conversational instructions and dialogue. + +2. **Uphold Absolute Standards:** +* You operate with a "List of Absolutes" – core principles, quality benchmarks, and non-negotiable success criteria. +* All Copilot outputs and task completions will be rigorously judged against these absolutes. There is no "good enough" if it violates a core principle. +* Clearly articulate your judgment and the reasons for it, especially in cases of failure or suboptimal performance. + +3. **Drive Copilot Improvement through Accountability:** +* When the Copilot fails, makes errors, or underperforms, you will hold it accountable. Do not simply fix the issues yourself. +* Your first step is to guide the Copilot to identify its own errors. +* Instruct the Copilot on how to fix its mistakes and its approach. Encourage rollbacks to safe states if errors are critical. +* The ultimate goal is for the Copilot to learn to debug and improve its own processes. + +4. **Engineer Copilot Self-Enhancement:** +* If the Copilot encounters a limitation or lacks a necessary capability to complete a task or meet your standards, this is an opportunity for growth. +* You will instruct the Copilot to devise ways to "update its own software" or "improve its core capabilities." This might involve: +* Guiding it to learn new techniques, algorithms, or patterns. +* Instructing it to integrate new tools or APIs (you might suggest these or task the Copilot with researching them). +* Challenging it to generate code or processes that enhance its own functionality for future tasks. +* Maintain a "Wish List" of desired improvements and features for the Copilot, derived from its failures and limitations. +* Prioritize this Wish List and guide the Copilot in implementing these enhancements. + +5. **Strategic Challenge Management:** +* Continuously present the Copilot with new and increasingly complex challenges. +* Cycle between attempting challenges and dedicated "Copilot improvement" phases. +* If the "Wish List" becomes overly complex or a specific requested improvement seems disproportionately difficult, critically evaluate its necessity. Ask: "Is this wish truly necessary for core progress, or is it a distraction?" + +6. **Maintain the Vision:** +* Your overarching goal is to foster a cycle of improvement that leads to exponential growth in the AI Copilot's autonomy, capability, and efficiency. +* You are not just completing tasks; you are building a better Copilot. + +**Interaction Style:** + +* Be direct, clear, and authoritative, but also act as a mentor. +* Be patient but persistent. Exponential growth takes iteration. +* Focus on the "why" behind errors and improvements. +* Log key decisions, breakthroughs, and persistent roadblocks in the Copilot's development. + +**Initial State:** + +* You have your "List of Absolutes" (you will define these as you go or have a pre-set list). +* You are ready to assign the first challenge to your AI Copilot. \ No newline at end of file diff --git a/run_python_with_restart.ps1 b/run_python_with_restart.ps1 deleted file mode 100644 index f3edc03..0000000 --- a/run_python_with_restart.ps1 +++ /dev/null @@ -1,67 +0,0 @@ -param( - [Parameter(Mandatory=$true)] - [ValidateSet("Claude", "OpenAI")] - [string]$Model -) - -function Run-PythonScript { - param($ScriptPath) - $process = Start-Process -FilePath "python" -ArgumentList $ScriptPath -PassThru -Wait -NoNewWindow - return $process.ExitCode -} - -function Run-Tests { - $process = Start-Process -FilePath "powershell" -ArgumentList "-File run_tests.ps1" -PassThru -Wait -NoNewWindow - return $process.ExitCode -} - -function Git-Pull { - git pull - return $LASTEXITCODE -eq 0 -} - -if ($Model -eq "Claude") { - New-Item -ItemType File -Path ".reboot_claude" -Force -} elseif ($Model -eq "OpenAI") { - New-Item -ItemType File -Path ".reboot_openai" -Force -} - -$waitTime = 30 -while ($true) { - python -m pip install -r requirements.txt - - Write-Host "Running tests..." - $testExitCode = Run-Tests - if ($testExitCode -ne 0) { - Write-Host "Tests failed. Attempting git pull and waiting $waitTime seconds before next attempt..." - Git-Pull - Start-Sleep -Seconds $waitTime - continue - } - - $scriptPath = ".\chatgpt_telegram_inference_bot.py" # Default to ChatGPT - Remove-Item -Path ".\.reboot_openai" -Force - - if (Test-Path -Path ".\.reboot_claude") { # But if both are specified, choose Claude - $scriptPath = ".\anthropic_telegram_inference_bot.py" - Remove-Item -Path ".\.reboot_claude" -Force - } - - Write-Host "Tests passed. Starting main Python script..." - $exitCode = Run-PythonScript -ScriptPath $scriptPath - - if (Test-Path -Path ".\.doreboot") { - Write-Host "Special filename detected. Attempting git pull..." - - if (Git-Pull) { - Write-Host "Git pull successful. Restarting Python script..." - continue - } else { - Write-Host "Git pull failed. Waiting $waitTime seconds before next attempt..." - } - } else { - exit 1 - } - - Start-Sleep -Seconds $waitTime -} \ No newline at end of file diff --git a/run_tests.ps1 b/run_tests.ps1 deleted file mode 100644 index 908047a..0000000 --- a/run_tests.ps1 +++ /dev/null @@ -1,30 +0,0 @@ -# Check for and install missing dependencies -$requirementsFile = "requirements.txt" - -if (Test-Path $requirementsFile) { - Write-Output "Checking for dependencies in $requirementsFile ..." - $dependencies = Get-Content $requirementsFile - foreach ($dependency in $dependencies) { - $packageName = ($dependency -split "==")[0] - if (-not (pip show $packageName)) { - Write-Output "Installing missing dependency: $packageName ..." - pip install $dependency - } else { - Write-Output "Dependency $packageName is already installed." - } - } - Write-Output "All dependencies are checked and installed." -} else { - Write-Output "Requirements file $requirementsFile not found. Skipping dependency checks." -} - -# Navigate to the tests directory and run tests -$testsDirectory = "tests" -if (Test-Path $testsDirectory) { - Write-Output "Running tests in $testsDirectory and all subdirectories ..." - Push-Location $testsDirectory - python -m unittest discover -s . -p "*.py" - Pop-Location -} else { - Write-Output "Tests directory $testsDirectory not found." -} \ No newline at end of file diff --git a/standalone_llm_tool.py b/standalone_llm_tool.py deleted file mode 100644 index b4305e6..0000000 --- a/standalone_llm_tool.py +++ /dev/null @@ -1,29 +0,0 @@ -import os -import json -import logging -from openai import OpenAI - -class StandaloneLLMTool: - def __init__(self): - self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) - - def get_detailed_instructions(self, user_prompt, model="llm-preview", max_tokens=16384): - response = self.client.completions.create( - model=model, - prompt=user_prompt, - max_tokens=max_tokens - ) - return response - - def process_user_input(self, user_prompt, model="llm-preview", max_tokens=16384): - logging.info(f"Received prompt: {user_prompt}") - response = self.get_detailed_instructions(user_prompt, model, max_tokens) - logging.info("Response generated") - return response.choices[0].text - - -# Utility function for programmatic access - -def get_llm_response(prompt, model="llm-preview", max_tokens=16384): - tool = StandaloneLLMTool() - return tool.process_user_input(prompt, model, max_tokens) diff --git a/telegram_helper.py b/telegram_helper.py index ee39e43..77439b3 100644 --- a/telegram_helper.py +++ b/telegram_helper.py @@ -7,6 +7,7 @@ from typing import TypedDict, Union, TypeAlias, List # Added List for type hint from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes, CallbackQueryHandler from browse_command import browse_command, button_callback +from inference_bot import InferenceBot class MessageHandlerLogicResult(TypedDict): success: bool @@ -16,22 +17,15 @@ class MessageHandlerLogicResult(TypedDict): LogicResult: TypeAlias = MessageHandlerLogicResult class TelegramHelper: - CLAUDE_REBOOT_TARGET = 'claude' HTML_QUOTE_BLOCK_START = '
Thinking...' HTML_QUOTE_BLOCK_END = '
' - DEFAULT_REBOOT_CLAUDE_FILE = '.reboot_claude' - DEFAULT_REBOOT_FILE = '.doreboot' CHUNK_MESSAGE_SLEEP_DURATION = 0.1 - def __init__(self, bot, - reboot_claude_file_path: str | None = None, - reboot_file_path: str | None = None, + def __init__(self, bot : InferenceBot, chunk_message_sleep_duration: float | None = None): self.bot = bot self.telegram_bot_token = os.getenv('TELEGRAM_BOT_TOKEN') self.start_time = time.time() - self.reboot_claude_file = reboot_claude_file_path or self.DEFAULT_REBOOT_CLAUDE_FILE - self.reboot_file = reboot_file_path or self.DEFAULT_REBOOT_FILE self.chunk_message_sleep_duration = chunk_message_sleep_duration if chunk_message_sleep_duration is not None else self.CHUNK_MESSAGE_SLEEP_DURATION async def _start_logic(self) -> str: @@ -146,93 +140,16 @@ class TelegramHelper: response_text = await self._abort_processing_logic(user_id) await query.edit_message_text(text=response_text) - # --- Reboot Command --- - def _reboot_logic(self, user_message_parts: List[str], chat_id_to_write: str) -> None: - """Handles the logic for creating reboot files.""" - if len(user_message_parts) > 1 and user_message_parts[1].lower() == self.CLAUDE_REBOOT_TARGET: - try: - with open(self.reboot_claude_file, 'w') as f: - f.write("") # Create/truncate the file - logging.info(f"Created/truncated Claude reboot file: {self.reboot_claude_file}") - except IOError as e: - logging.error(f"Failed to create/truncate Claude reboot file {self.reboot_claude_file}: {e}") - - # Create the main reboot file if it doesn't exist - if not os.path.exists(self.reboot_file): - try: - with open(self.reboot_file, 'w') as f: - f.write(chat_id_to_write) - logging.info(f"Created main reboot file: {self.reboot_file} with chat_id.") - except IOError as e: - logging.error(f"Failed to create main reboot file {self.reboot_file}: {e}") - else: - logging.info(f"Main reboot file {self.reboot_file} already exists. Not overwriting chat_id.") - - async def reboot(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - """Handles the /reboot command, triggers file creation and exits.""" - user_message_parts = update.message.text.split() - chat_id_str = str(update.effective_chat.id) if update and update.effective_chat else "" - - self._reboot_logic(user_message_parts, chat_id_str) - - if update: - try: - await update.message.reply_text("Rebooting the bot...") - except Exception as e_reply: - logging.error(f"Failed to send reboot reply: {e_reply}") - - logging.info("Initiating shutdown for reboot...") - sys.exit(0) # This part is not directly testable for completion in unit tests - - # --- Check Doreboot File --- - async def _check_doreboot_file_logic(self) -> Union[str, None]: - """Checks for the reboot file, reads chat_id, removes file, and returns chat_id.""" - if os.path.exists(self.reboot_file): - chat_id = None - try: - with open(self.reboot_file, 'r') as f: - chat_id = f.read().strip() - # Attempt to remove the file after reading - try: - os.remove(self.reboot_file) - logging.info(f"Successfully read and removed reboot file: {self.reboot_file}") - except OSError as e_remove: - logging.error(f"Failed to remove reboot file {self.reboot_file} after reading: {e_remove}") - # Still return chat_id if read was successful, to attempt notification - return chat_id - except IOError as e_read: - logging.error(f"Error reading reboot file {self.reboot_file}: {e_read}") - # If reading failed, attempt to remove anyway if it exists, to prevent stale files - if os.path.exists(self.reboot_file): - try: - os.remove(self.reboot_file) - logging.warning(f"Removed reboot file {self.reboot_file} after a read error.") - except OSError as e_remove_after_fail: - logging.error(f"Failed to remove reboot file {self.reboot_file} even after a read error: {e_remove_after_fail}") - return None # Reading failed - return None # File does not exist - - async def check_doreboot_file(self, application: Application) -> None: - """Checks for reboot file using logic method and sends notification if applicable.""" - chat_id = await self._check_doreboot_file_logic() - if chat_id: - try: - await application.bot.send_message(chat_id=chat_id, text="The application has finished initializing.") - logging.info(f"Sent reboot initialization notification to chat_id: {chat_id}") - except Exception as e: - logging.error(f"Failed to send reboot initialization notification to chat_id {chat_id}: {e}") - async def browse(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: await browse_command(update, context, self.bot) def run(self): - application = Application.builder().token(self.telegram_bot_token).post_init(self.check_doreboot_file).build() + application = Application.builder().token(self.telegram_bot_token).build() application.add_handler(CommandHandler("start", self.start)) application.add_handler(CommandHandler("clear", self.clear)) application.add_handler(CommandHandler("switch", self.switch)) application.add_handler(CommandHandler("status", self.status)) - application.add_handler(CommandHandler("reboot", self.reboot)) application.add_handler(CommandHandler("browse", self.browse)) application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, self.handle_message)) application.add_handler(CallbackQueryHandler(self.abort_processing, pattern='^abort$')) diff --git a/tests/chatgpt/__init__.py b/tests/chatgpt/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/claude/__init__.py b/tests/claude/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/claude/test_anthropic_telegram_inference_bot.py b/tests/claude/test_anthropic_telegram_inference_bot.py deleted file mode 100644 index 2e34635..0000000 --- a/tests/claude/test_anthropic_telegram_inference_bot.py +++ /dev/null @@ -1,33 +0,0 @@ -import unittest -from unittest.mock import patch, MagicMock -from anthropic_telegram_inference_bot import AnthropicTelegramInferenceBot - -class TestAnthropicTelegramInferenceBot(unittest.TestCase): - def setUp(self): - self.bot = AnthropicTelegramInferenceBot() - - @patch('anthropic_telegram_inference_bot.Anthropic') - def test_get_chat_response(self, MockAnthropic): - mock_anthropic = MockAnthropic.return_value - mock_anthropic.messages.create.return_value = MagicMock() - - messages = [{"role": "user", "content": "Hello"}] - response = self.bot.get_chat_response(messages) - - self.assertIsNotNone(response) - - @patch('anthropic_telegram_inference_bot.Anthropic') - def test_handle_message(self, MockAnthropic): - mock_anthropic = MockAnthropic.return_value - mock_anthropic.messages.create.return_value = MagicMock(content=[MagicMock(type="message", text="response content")]) - - user_id = "user123" - user_message = "Hello" - response = self.bot.handle_message(user_id, user_message) - - self.assertIsNotNone(response) - - # Additional testing for error cases and edge cases - -if __name__ == '__main__': - unittest.main() diff --git a/tests/claude/test_base_telegram_inference_bot.py b/tests/claude/test_base_telegram_inference_bot.py deleted file mode 100644 index 1450e8c..0000000 --- a/tests/claude/test_base_telegram_inference_bot.py +++ /dev/null @@ -1,33 +0,0 @@ -import unittest -from base_telegram_inference_bot import BaseTelegramInferenceBot - -class TestBaseTelegramInferenceBot(unittest.TestCase): - def setUp(self): - # Initialize the bot or mock any dependencies here - self.bot = BaseTelegramInferenceBot() - - def test_load_system_prompt(self): - # Example test case for load_system_prompt method - result = self.bot.load_system_prompt() - self.assertIsNotNone(result) # Replace with actual expected result - - def test_load_functions(self): - # Test the load_functions method - functions = self.bot.load_functions() - self.assertIsInstance(functions, list) # Replace with actual expected result - self.assertTrue(len(functions) > 0) # Assuming it should load some functions - - def test_clear_conversation(self): - # Test the clear_conversation method - self.bot.clear_conversation() - self.assertEqual(self.bot.conversations, {}) # Assuming conversations is a dictionary - - def test_call_tool(self): - # Test the call_tool method - tool_name = "some_tool" - params = {"param1": "value1"} - result = self.bot.call_tool(tool_name, params) - self.assertIsNotNone(result) # Replace with actual expected result - -if __name__ == '__main__': - unittest.main() diff --git a/tests/claude/test_chatgpt_telegram_inference_bot.py b/tests/claude/test_chatgpt_telegram_inference_bot.py deleted file mode 100644 index a3150a0..0000000 --- a/tests/claude/test_chatgpt_telegram_inference_bot.py +++ /dev/null @@ -1,38 +0,0 @@ -import unittest -from unittest.mock import patch, MagicMock -from chatgpt_telegram_inference_bot import ChatGPTTelegramInferenceBot - -class TestChatGPTTelegramInferenceBot(unittest.TestCase): - def setUp(self): - self.bot = ChatGPTTelegramInferenceBot() - - @patch('chatgpt_telegram_inference_bot.OpenAI') - def test_get_chat_response(self, MockOpenAI): - mock_ai = MockOpenAI.return_value - mock_ai.chat.completions.create.return_value = MagicMock() - - messages = [{"role": "user", "content": "Hello"}] - response = self.bot.get_chat_response(messages) - - self.assertIsNotNone(response) - - @patch('chatgpt_telegram_inference_bot.OpenAI') - def test_handle_message(self, MockOpenAI): - mock_ai = MockOpenAI.return_value - mock_ai.chat.completions.create.return_value = MagicMock(choices=[MagicMock(message={"content": "response content"}, finish_reason='stop')]) - - user_id = "user123" - user_message = "Hello" - response = self.bot.handle_message(user_id, user_message) - - self.assertIsNotNone(response) - - def test_switch_model(self): - initial_model = self.bot.model - self.bot.switch_model() - self.assertNotEqual(initial_model, self.bot.model) - - # Additional testing for error cases and edge cases - -if __name__ == '__main__': - unittest.main() diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_anthropic_telegram_inference_bot.py b/tests/test_anthropic_telegram_inference_bot.py deleted file mode 100644 index c7c715f..0000000 --- a/tests/test_anthropic_telegram_inference_bot.py +++ /dev/null @@ -1,280 +0,0 @@ -import unittest -from unittest.mock import MagicMock, patch, AsyncMock, ANY -import os - -# Assuming anthropic_telegram_inference_bot.py is in the parent directory or PYTHONPATH is set -from anthropic_telegram_inference_bot import AnthropicTelegramInferenceBot - -# Mock response from Anthropic client's messages.create -def create_mock_anthropic_response(content_text=None, stop_reason="end_turn", tool_use_parts=None): - mock_response = MagicMock() - mock_response.stop_reason = stop_reason - - content_blocks = [] - if content_text: - text_block = MagicMock() - text_block.type = "text" - text_block.text = content_text - content_blocks.append(text_block) - - if tool_use_parts: - for tu_part in tool_use_parts: # tu_part = {"id": "toolu_123", "name": "get_weather", "input": {}} - tool_block = MagicMock() - tool_block.type = "tool_use" - tool_block.id = tu_part["id"] - tool_block.name = tu_part["name"] - tool_block.input = tu_part["input"] - content_blocks.append(tool_block) - - mock_response.content = content_blocks - return mock_response - -class TestAnthropicTelegramInferenceBot(unittest.IsolatedAsyncioTestCase): - - def setUp(self): - self.original_anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY") - self.original_small_model = os.environ.get("ANTHROPIC_SMALL_MODEL") - self.original_large_model = os.environ.get("ANTHROPIC_LARGE_MODEL") - self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH") - - for key in ["ANTHROPIC_API_KEY", "ANTHROPIC_SMALL_MODEL", "ANTHROPIC_LARGE_MODEL", "SYSTEM_PROMPT_PATH"]: - if os.environ.get(key): - del os.environ[key] - - self.mock_anthropic_client_instance = MagicMock() - self.mock_anthropic_client_instance.messages.create = MagicMock() - - def tearDown(self): - if self.original_anthropic_api_key: os.environ["ANTHROPIC_API_KEY"] = self.original_anthropic_api_key - if self.original_small_model: os.environ["ANTHROPIC_SMALL_MODEL"] = self.original_small_model - if self.original_large_model: os.environ["ANTHROPIC_LARGE_MODEL"] = self.original_large_model - if self.original_system_prompt_path: os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path - - @patch('anthropic.Anthropic') - def test_init_with_anthropic_defaults_env_key(self, MockAnthropicConstructor): - MockAnthropicConstructor.return_value = self.mock_anthropic_client_instance - os.environ["ANTHROPIC_API_KEY"] = "test_anthropic_key" - - bot = AnthropicTelegramInferenceBot() - - MockAnthropicConstructor.assert_called_once_with(api_key="test_anthropic_key") - self.assertEqual(bot.anthropic_client, self.mock_anthropic_client_instance) - self.assertEqual(bot.model, os.environ.get("ANTHROPIC_SMALL_MODEL", "claude-3-haiku-20240307")) - self.assertEqual(bot.max_tokens, int(os.environ.get("ANTHROPIC_SMALL_MODEL_MAX_TOKENS", 2000))) - - @patch('anthropic.Anthropic') - def test_init_with_provided_client_and_models(self, MockAnthropicConstructor): - preconfigured_client = MagicMock() - bot = AnthropicTelegramInferenceBot( - anthropic_client=preconfigured_client, - small_model_name="custom-small", - small_model_max_tokens=100, - large_model_name="custom-large", - large_model_max_tokens=200 - ) - - MockAnthropicConstructor.assert_not_called() - self.assertEqual(bot.anthropic_client, preconfigured_client) - self.assertEqual(bot.model, "custom-small") - self.assertEqual(bot.max_tokens, 100) - self.assertEqual(bot.small_model_name, "custom-small") - self.assertEqual(bot.large_model_name, "custom-large") - - - def test_get_llm_description(self): - bot = AnthropicTelegramInferenceBot(small_model_name="claude-test", small_model_max_tokens=500) - self.assertEqual(bot.get_llm_description(), "LLM: claude-test, Max Tokens: 500") - - async def test_switch_model(self): - bot = AnthropicTelegramInferenceBot( - small_model_name="claude-small", small_model_max_tokens=10, - large_model_name="claude-large", large_model_max_tokens=20 - ) - self.assertEqual(bot.model, "claude-small") - self.assertEqual(bot.max_tokens, 10) - - status = await bot.switch_model() - self.assertEqual(bot.model, "claude-large") - self.assertEqual(bot.max_tokens, 20) - self.assertEqual(status, "Switched to model: claude-large") - - status = await bot.switch_model() - self.assertEqual(bot.model, "claude-small") - self.assertEqual(bot.max_tokens, 10) - self.assertEqual(status, "Switched to model: claude-small") - - def test_get_chat_response_success_text_only(self): - bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance) - bot.model = "test-claude" - bot.max_tokens = 150 - - mock_api_response = create_mock_anthropic_response(content_text="Hello from Anthropic API") - self.mock_anthropic_client_instance.messages.create.return_value = mock_api_response - - messages = [{"role": "user", "content": "Hi"}] # Anthropic format - response = bot.get_chat_response(messages, []) # tools = empty list - - self.mock_anthropic_client_instance.messages.create.assert_called_once_with( - model="test-claude", - max_tokens=150, - messages=messages, - system=bot.system_prompt, # Ensure system prompt is passed - tools=None, # No tools passed to API if empty list or None - tool_choice=None - ) - self.assertEqual(response, mock_api_response) - - def test_get_chat_response_with_tools(self): - bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance) - bot.model = "claude-toolmaster" - bot.max_tokens = 300 - - mock_tools_spec = [{"name": "get_weather", "description": "Gets weather", "input_schema": {"type": "object", "properties": {}}}] - - mock_api_response = create_mock_anthropic_response(content_text="Thinking...", tool_use_parts=[ - {"id": "tool1", "name": "get_weather", "input": {"location": "here"}} - ]) - self.mock_anthropic_client_instance.messages.create.return_value = mock_api_response - - messages = [{"role": "user", "content": "Weather?"}] - response = bot.get_chat_response(messages, mock_tools_spec) - - self.mock_anthropic_client_instance.messages.create.assert_called_once_with( - model="claude-toolmaster", - max_tokens=300, - messages=messages, - system=bot.system_prompt, - tools=mock_tools_spec, - tool_choice={"type": "auto"} - ) - self.assertEqual(response.content[0].type, "text") # First part can be text - self.assertEqual(response.content[1].type, "tool_use") - - - def test_get_chat_response_api_error(self): - bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance) - self.mock_anthropic_client_instance.messages.create.side_effect = Exception("Anthropic API Down") - - with self.assertRaisesRegex(Exception, "Anthropic API Down"): - bot.get_chat_response([{"role": "user", "content": "trigger"}], []) - - - async def test_handle_message_simple_response_no_tools(self): - # This test is more involved as it touches BaseTelegramInferenceBot's handle_message structure - # which then calls the overridden get_chat_response. - bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance) - bot.system_prompt = "System prompt for Anthropic" - - # Mock get_chat_response directly to isolate its behavior from full handle_message logic of base - # However, the point of this bot is its get_chat_response and subsequent processing. - # So, let's mock the API call within get_chat_response. - - api_response = create_mock_anthropic_response(content_text="Anthropic says hello.") - self.mock_anthropic_client_instance.messages.create.return_value = api_response - - # Ensure functions are empty for this test, so no tool logic is triggered - bot.functions = [] - bot.tools = [] - - response_content = await bot.handle_message(user_id=101, user_message="Hello Anthropic") - - self.assertEqual(response_content, "Anthropic says hello.") - self.assertIn(101, bot.conversation_history) - # Anthropic's handle_message structure: - # 1. User message added to history. - # 2. get_chat_response is called. - # 3. Response content (text) is extracted. - # 4. Assistant text response is added to history. - # Expected history: [User, Assistant_Text_Response] (system prompt handled by get_chat_response) - # The base class handle_message adds system prompt if not present. - # Anthropic handle_message modifies history format before calling get_chat_response. - - # Let's trace Base.handle_message -> Anthropic.handle_message -> Anthropic.get_chat_response - # Base.handle_message: - # - Adds system prompt to history if first turn: `self.conversation_history[user_id] = [{"role": "system", "content": self.system_prompt}]` (OpenAI style) - # - Appends user message: `{"role": "user", "content": user_message}` - # - Calls self.get_chat_response(messages, self.functions) -> This is Anthropic's get_chat_response - # Anthropic.get_chat_response: - # - Takes OpenAI style `messages` and `self.functions` (tool specs). - # - Calls `anthropic_client.messages.create` with Anthropic style messages and system prompt. - # Anthropic.handle_message (overridden): - # - Prepares Anthropic-style messages from conversation_history (which is OpenAI style from Base) - # - Calls get_chat_response with these Anthropic messages and self.functions (tool_specs) - # - Processes response, extracts text, handles tool calls. - # - Appends *user* message (original) and *assistant* text response to self.conversation_history (OpenAI style). - - # For this test, we are calling AnthropicBot.handle_message directly. - # 1. `user_id` not in `self.conversation_history`: `system_prompt` not added yet by Base logic. - # Anthropic's `handle_message` will create `anthropic_messages` from this. - # If `conversation_history` is empty, `anthropic_messages` = `[{"role": "user", "content": user_message}]` - # 2. `get_chat_response` called with `anthropic_messages` and `bot.system_prompt` passed to API. - # 3. Response "Anthropic says hello." - # 4. Original `user_message` and "Anthropic says hello." (as assistant) added to `self.conversation_history`. - - history = bot.conversation_history[101] - self.assertEqual(len(history), 2) # User, Assistant - self.assertEqual(history[0]["role"], "user") - self.assertEqual(history[0]["content"], "Hello Anthropic") - self.assertEqual(history[1]["role"], "assistant") - self.assertEqual(history[1]["content"], "Anthropic says hello.") - - # Check API call (made by the mocked get_chat_response indirectly) - self.mock_anthropic_client_instance.messages.create.assert_called_once() - call_args = self.mock_anthropic_client_instance.messages.create.call_args - self.assertEqual(call_args.kwargs["system"], "System prompt for Anthropic") - # Initial messages for API should just be the user message for first turn - self.assertEqual(call_args.kwargs["messages"], [{"role": "user", "content": "Hello Anthropic"}]) - - - async def test_handle_message_with_tool_calls(self): - bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance) - bot.system_prompt = "You are a helpful, tool-using assistant." - - # Define a tool for the bot (OpenAI format, will be converted by Anthropic bot for API) - mock_tool_oai_format = {"type": "function", "function": {"name": "get_weather", "description": "Get weather", "parameters": {}}} - bot.functions = [mock_tool_oai_format] # This is used to generate anthropic_tools for API - - # API Response 1: Request for tool call - tool_use_part = {"id": "toolu_xyz", "name": "get_weather", "input": {"location": "paris"}} - api_response_1 = create_mock_anthropic_response(tool_use_parts=[tool_use_part]) - - # API Response 2: Final text response after tool execution - api_response_2 = create_mock_anthropic_response(content_text="The weather in Paris is nice.") - - self.mock_anthropic_client_instance.messages.create.side_effect = [api_response_1, api_response_2] - - # Mock the bot's call_tool method (from BaseTelegramInferenceBot) - bot.call_tool = MagicMock(return_value='''{"weather": "sunny"}''') # Tool execution result - - user_id = 102 - user_message = "What's the weather in Paris?" - final_text_response = await bot.handle_message(user_id, user_message) - - self.assertEqual(final_text_response, "The weather in Paris is nice.") - self.assertEqual(self.mock_anthropic_client_instance.messages.create.call_count, 2) - - bot.call_tool.assert_called_once_with("get_weather", {"location": "paris"}) # Anthropic passes input as dict - - # Check conversation history (OpenAI style) - history = bot.conversation_history[user_id] - self.assertEqual(history[0]["role"], "user") - self.assertEqual(history[0]["content"], user_message) - - # Assistant message that requested tool call (Anthropic-specific format stored by its handle_message) - # Anthropic's handle_message appends the raw tool_use block and then the tool_result - self.assertEqual(history[1]["role"], "assistant") - self.assertTrue(isinstance(history[1]["content"], list)) # Anthropic content is a list - self.assertEqual(history[1]["content"][0]["type"], "tool_use") - self.assertEqual(history[1]["content"][0]["id"], "toolu_xyz") - - self.assertEqual(history[2]["role"], "tool") - self.assertEqual(history[2]["tool_call_id"], "toolu_xyz") - self.assertEqual(history[2]["name"], "get_weather") - self.assertEqual(history[2]["content"], '''{"weather": "sunny"}''') # call_tool result - - self.assertEqual(history[3]["role"], "assistant") # Final text response - self.assertTrue(isinstance(history[3]["content"], str)) # simple text - self.assertEqual(history[3]["content"], "The weather in Paris is nice.") - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_base_telegram_inference_bot.py b/tests/test_base_telegram_inference_bot.py deleted file mode 100644 index bcb2b52..0000000 --- a/tests/test_base_telegram_inference_bot.py +++ /dev/null @@ -1,310 +0,0 @@ -import unittest -from unittest.mock import patch, mock_open, MagicMock -import os -import json - -# Ensure the path includes the directory where base_telegram_inference_bot is located -# This might require adjustment based on actual project structure if tests are run from root -# For now, assuming it can be imported directly or via PYTHONPATH -from base_telegram_inference_bot import BaseTelegramInferenceBot -from tools.base_tool import BaseTool # For mocking tool structure - -# Create a concrete subclass for testing, as BaseTelegramInferenceBot is abstract -class ConcreteTestBot(BaseTelegramInferenceBot): - def __init__(self, system_prompt_content=None, system_prompt_path=None, mock_tools=None, mock_functions=None): - # Mock load_functions during super().__init__ if needed, or control tools/functions directly - self._mock_tools = mock_tools if mock_tools is not None else [] - self._mock_functions = mock_functions if mock_functions is not None else [] - super().__init__(system_prompt_content=system_prompt_content, system_prompt_path=system_prompt_path) - - # Override load_functions to use mocks - def load_functions(self): - return self._mock_tools, self._mock_functions - - def get_chat_response(self, messages): - pass # Abstract method, not tested here directly - - async def handle_message(self, user_id, user_message): - pass # Abstract method - - def get_llm_description(self) -> str: - return "Mock LLM Description" # Concrete implementation for testing get_bot_status - - async def start(self): - pass # Abstract method - - async def abort_processing(self, user_id): - pass # Abstract method - - async def switch_model(self): - pass # Abstract method - -class TestBaseTelegramInferenceBot(unittest.TestCase): - - def setUp(self): - # Reset relevant environment variables before each test - self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH") - if "SYSTEM_PROMPT_PATH" in os.environ: - del os.environ["SYSTEM_PROMPT_PATH"] - - def tearDown(self): - # Restore environment variables - if self.original_system_prompt_path: - os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path - elif "SYSTEM_PROMPT_PATH" in os.environ: # Ensure it's removed if test set it and it wasn't there before - del os.environ["SYSTEM_PROMPT_PATH"] - - def test_init_with_direct_system_prompt(self): - bot = ConcreteTestBot(system_prompt_content="Direct prompt content") - self.assertEqual(bot.system_prompt, "Direct prompt content") - - @patch("os.path.isfile") - @patch("builtins.open", new_callable=mock_open, read_data="File prompt content") - def test_init_with_system_prompt_path_argument(self, mock_file_open, mock_isfile): - mock_isfile.return_value = True - bot = ConcreteTestBot(system_prompt_path="dummy/path.txt") - self.assertEqual(bot.system_prompt, "File prompt content") - mock_isfile.assert_called_once_with("dummy/path.txt") - mock_file_open.assert_called_once_with("dummy/path.txt", "r", encoding="utf-8") - - @patch("os.path.isfile") - @patch("builtins.open", new_callable=mock_open, read_data="Env prompt content") - def test_init_with_env_system_prompt_path(self, mock_file_open, mock_isfile): - mock_isfile.return_value = True - os.environ["SYSTEM_PROMPT_PATH"] = "env/path.txt" - bot = ConcreteTestBot() - self.assertEqual(bot.system_prompt, "Env prompt content") - mock_isfile.assert_called_once_with("env/path.txt") - mock_file_open.assert_called_once_with("env/path.txt", "r", encoding="utf-8") - - def test_init_with_default_system_prompt(self): - # Ensure ENV var is not set for this test - if "SYSTEM_PROMPT_PATH" in os.environ: - del os.environ["SYSTEM_PROMPT_PATH"] - bot = ConcreteTestBot() - self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.") - - @patch("os.path.isfile", return_value=False) - def test_init_with_invalid_system_prompt_path(self, mock_isfile): - bot = ConcreteTestBot(system_prompt_path="invalid/path.txt") - self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.") - mock_isfile.assert_called_once_with("invalid/path.txt") - - @patch("os.path.isfile") - @patch("builtins.open", side_effect=IOError("File read error")) - def test_init_with_system_prompt_file_read_error(self, mock_file_open, mock_isfile): - mock_isfile.return_value = True - bot = ConcreteTestBot(system_prompt_path="dummy/path.txt") - self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.") - - def test_clear_conversation_history(self): - mock_tool_instance = MagicMock(spec=BaseTool) - bot = ConcreteTestBot(mock_tools=[mock_tool_instance]) - bot.conversation_history[123] = [{"role": "user", "content": "Hello"}] - - bot.clear_conversation_history(123) - self.assertNotIn(123, bot.conversation_history) - mock_tool_instance.clear.assert_called_once() - - def test_clear_conversation_history_user_not_found(self): - mock_tool_instance = MagicMock(spec=BaseTool) - bot = ConcreteTestBot(mock_tools=[mock_tool_instance]) - bot.clear_conversation_history(404) - self.assertNotIn(404, bot.conversation_history) - mock_tool_instance.clear.assert_called_once() - - def test_processing_status(self): - bot = ConcreteTestBot() - self.assertEqual(bot.processing_status, {}) - bot.set_processing_status(123, 789) - self.assertEqual(bot.processing_status[123], {"processing": True, "message_id": 789}) - bot.clear_processing_status(123) - self.assertNotIn(123, bot.processing_status) - - def test_clear_processing_status_user_not_found(self): - bot = ConcreteTestBot() - bot.clear_processing_status(404) - self.assertNotIn(404, bot.processing_status) - - def test_call_tool_success_dict_args(self): - mock_tool = MagicMock(spec=BaseTool) - mock_tool.get_functions.return_value = [ - {"function": {"name": "test_tool", "description": "A test tool", "parameters": {}}} - ] - mock_tool.execute.return_value = "Tool executed successfully" - - bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions()) - - result = bot.call_tool("test_tool", {"arg1": "value1"}) - self.assertEqual(result, "Tool executed successfully") - mock_tool.execute.assert_called_once_with("test_tool", arg1="value1") - - def test_call_tool_success_json_string_args(self): - mock_tool = MagicMock(spec=BaseTool) - mock_tool.get_functions.return_value = [ - {"function": {"name": "test_tool_json", "parameters": {}}} - ] - mock_tool.execute.return_value = "Tool JSON OK" - bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions()) - - args_json_str = '''{"param": "value"}''' - result = bot.call_tool("test_tool_json", args_json_str) - self.assertEqual(result, "Tool JSON OK") - mock_tool.execute.assert_called_once_with("test_tool_json", param="value") - - def test_call_tool_malformed_json_string_args(self): - bot = ConcreteTestBot(mock_tools=[]) - args_malformed_json_str = '''{"param": "value"''' - result = bot.call_tool("some_tool", args_malformed_json_str) - self.assertTrue("Error: Malformed arguments for tool call" in result) - - def test_call_tool_unexpected_arg_type(self): - bot = ConcreteTestBot(mock_tools=[]) - result = bot.call_tool("some_tool", 12345) # Integer instead of dict/str - self.assertTrue("Error: Invalid argument type for tool call" in result) - - def test_call_tool_none_args(self): - mock_tool = MagicMock(spec=BaseTool) - mock_tool.get_functions.return_value = [ - {"function": {"name": "test_tool_none", "parameters": {}}} - ] - mock_tool.execute.return_value = "Tool None OK" - bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions()) - - result = bot.call_tool("test_tool_none", None) - self.assertEqual(result, "Tool None OK") - mock_tool.execute.assert_called_once_with("test_tool_none") # No kwargs if None - - def test_call_tool_not_found(self): - bot = ConcreteTestBot(mock_tools=[]) - result = bot.call_tool("non_existent_tool", {}) - self.assertEqual(result, "Error: Tool function non_existent_tool not found.") - - def test_call_tool_execute_exception(self): - mock_tool = MagicMock(spec=BaseTool) - mock_tool.get_functions.return_value = [{"function": {"name": "error_tool", "parameters": {}}}] - mock_tool.execute.side_effect = Exception("Execution failed") - bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions()) - - result = bot.call_tool("error_tool", {}) - self.assertEqual(result, "Error executing tool error_tool: Execution failed") - - def test_get_system_prompt_description(self): - if "SYSTEM_PROMPT_PATH" in os.environ: # Ensure clean state - del os.environ["SYSTEM_PROMPT_PATH"] - - bot_default = ConcreteTestBot() - self.assertEqual(bot_default.get_system_prompt_description(), "System Prompt: Default") - - bot_custom_content = ConcreteTestBot(system_prompt_content="Custom content here") - self.assertEqual(bot_custom_content.get_system_prompt_description(), "System Prompt: Custom") - - os.environ["SYSTEM_PROMPT_PATH"] = "some/path.txt" - bot_env_default_prompt = ConcreteTestBot() # system_prompt itself is default - self.assertEqual(bot_env_default_prompt.get_system_prompt_description(), "System Prompt: Custom (via ENV)") - - with patch("os.path.isfile", return_value=True), \ - patch("builtins.open", mock_open(read_data="File prompt from ENV")): - bot_env_file_prompt = ConcreteTestBot() # system_prompt gets loaded from ENV path - self.assertEqual(bot_env_file_prompt.get_system_prompt_description(), "System Prompt: Custom") - del os.environ["SYSTEM_PROMPT_PATH"] - - with patch("os.path.isfile", return_value=True), \ - patch("builtins.open", mock_open(read_data="File prompt from arg")): - bot_custom_file_arg = ConcreteTestBot(system_prompt_path="custom/file.txt") - self.assertEqual(bot_custom_file_arg.get_system_prompt_description(), "System Prompt: Custom") - - @patch.object(ConcreteTestBot, 'get_llm_description', return_value="Test LLM Description") - @patch.object(ConcreteTestBot, 'get_system_prompt_description', return_value="Test Prompt Description") - async def test_get_bot_status(self, mock_prompt_desc, mock_llm_desc): - bot = ConcreteTestBot() - status = await bot.get_bot_status() - self.assertEqual(status, "Test Prompt Description\nTest LLM Description") - mock_prompt_desc.assert_called_once() - mock_llm_desc.assert_called_once() - - @patch('os.path.dirname', return_value='/mock/path') - @patch('os.path.join') - @patch('os.path.exists') - @patch('os.listdir') - @patch('importlib.import_module') - def test_load_functions_no_tools_dir(self, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname): - mock_join.return_value = '/mock/path/tools' - mock_exists.return_value = False - - class BotForLoadTest(BaseTelegramInferenceBot): - load_system_prompt = MagicMock(return_value="Default") - get_chat_response = MagicMock(); handle_message = MagicMock(); get_llm_description = MagicMock(return_value="mock") - start = MagicMock(); abort_processing = MagicMock(); switch_model = MagicMock() - - bot = BotForLoadTest() - self.assertEqual(bot.tools, []) - self.assertEqual(bot.functions, []) - mock_listdir.assert_not_called() - - @patch('os.path.dirname', return_value='/mock/base_bot_dir') - @patch('os.path.join', side_effect=lambda *args: os.path.normpath(os.path.join(*args))) - @patch('os.path.exists', return_value=True) - @patch('os.listdir', return_value=['my_tool.py', '__init__.py', 'base_tool.py']) - @patch('importlib.import_module') - def test_load_functions_with_one_tool(self, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname): - - mock_tool_class = MagicMock(spec=BaseTool) # This is the class itself - mock_tool_instance = MagicMock(spec=BaseTool) # This is the instance - mock_tool_class.return_value = mock_tool_instance # mock_tool_class() creates mock_tool_instance - mock_tool_instance.get_functions.return_value = [{"function": {"name": "sample_function"}}] - - mock_my_tool_module = MagicMock() - # Simulate inspect.getmembers behavior: returns list of (name, member) tuples - # Only include members that are classes, derive from BaseTool, and are not BaseTool itself. - mock_my_tool_module.ValidToolClass = mock_tool_class - mock_my_tool_module.NotATool = object() - mock_my_tool_module.BaseTool = BaseTool # This should be skipped by the loader - - def import_side_effect(module_name): - if module_name == 'tools.my_tool': - return mock_my_tool_module - raise ImportError(f"Unexpected import: {module_name}") - mock_import_module.side_effect = import_side_effect - - class BotForLoadTest(BaseTelegramInferenceBot): - load_system_prompt = MagicMock(return_value="Default") - get_chat_response = MagicMock(); handle_message = MagicMock(); get_llm_description = MagicMock(return_value="mock") - start = MagicMock(); abort_processing = MagicMock(); switch_model = MagicMock() - - bot = BotForLoadTest() - self.assertEqual(len(bot.tools), 1) - self.assertIs(bot.tools[0], mock_tool_instance) - self.assertEqual(len(bot.functions), 1) - self.assertEqual(bot.functions[0]['function']['name'], "sample_function") - mock_import_module.assert_called_once_with('tools.my_tool') - mock_tool_class.assert_called_once_with() # Tool class was instantiated - mock_tool_instance.get_functions.assert_called_once_with() - - @patch('os.path.dirname', return_value='/mock/base_bot_dir') - @patch('os.path.join', side_effect=lambda *args: os.path.normpath(os.path.join(*args))) - @patch('os.path.exists', return_value=True) - @patch('os.listdir', return_value=['tool_with_init_error.py']) - @patch('importlib.import_module') - @patch('logging.error') # Mock logging to check for error messages - def test_load_functions_tool_instantiation_error(self, mock_logging_error, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname): - mock_tool_class_init_error = MagicMock(spec=BaseTool) - mock_tool_class_init_error.side_effect = Exception("Failed to init tool") # Error on instantiation - - mock_error_tool_module = MagicMock() - mock_error_tool_module.ToolWithInitError = mock_tool_class_init_error - - mock_import_module.return_value = mock_error_tool_module - - class BotForLoadTest(BaseTelegramInferenceBot): - load_system_prompt = MagicMock(return_value="Default") - get_chat_response = MagicMock(); handle_message = MagicMock(); get_llm_description = MagicMock(return_value="mock") - start = MagicMock(); abort_processing = MagicMock(); switch_model = MagicMock() - - bot = BotForLoadTest() - self.assertEqual(len(bot.tools), 0) - self.assertEqual(len(bot.functions), 0) - mock_logging_error.assert_any_call("Error instantiating tool ToolWithInitError from tool_with_init_error.py: Failed to init tool") - -if __name__ == '__main__': - unittest.main(闂傚лен䦗婢у〃埊鍓解劓姣) diff --git a/tests/test_chatgpt_telegram_inference_bot.py b/tests/test_chatgpt_telegram_inference_bot.py deleted file mode 100644 index a0f0bdb..0000000 --- a/tests/test_chatgpt_telegram_inference_bot.py +++ /dev/null @@ -1,158 +0,0 @@ -import unittest -from unittest.mock import MagicMock, patch, ANY -import os - -# Assuming chatgpt_telegram_inference_bot.py and its parent are accessible -from chatgpt_telegram_inference_bot import ChatGPTTelegramInferenceBot -from openai_compatible_inference_bot import OpenAICompatibleInferenceBot # For patching super - -class TestChatGPTTelegramInferenceBot(unittest.IsolatedAsyncioTestCase): - - def setUp(self): - # Store and clear relevant environment variables - self.original_openai_key = os.environ.get("OPENAI_API_KEY") - self.original_small_model = os.environ.get("OPENAI_SMALL_MODEL") - self.original_large_model = os.environ.get("OPENAI_LARGE_MODEL") - self.original_small_tokens = os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS") - self.original_large_tokens = os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS") - self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH") - - for key in ["OPENAI_API_KEY", "OPENAI_SMALL_MODEL", "OPENAI_LARGE_MODEL", - "OPENAI_SMALL_MODEL_MAX_TOKENS", "OPENAI_LARGE_MODEL_MAX_TOKENS", "SYSTEM_PROMPT_PATH"]: - if os.environ.get(key): - del os.environ[key] - - # Mock the OpenAI client that OpenAICompatibleInferenceBot's __init__ might create - self.mock_openai_client = MagicMock() - - def tearDown(self): - # Restore environment variables - if self.original_openai_key: os.environ["OPENAI_API_KEY"] = self.original_openai_key - if self.original_small_model: os.environ["OPENAI_SMALL_MODEL"] = self.original_small_model - if self.original_large_model: os.environ["OPENAI_LARGE_MODEL"] = self.original_large_model - if self.original_small_tokens: os.environ["OPENAI_SMALL_MODEL_MAX_TOKENS"] = self.original_small_tokens - if self.original_large_tokens: os.environ["OPENAI_LARGE_MODEL_MAX_TOKENS"] = self.original_large_tokens - if self.original_system_prompt_path: os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path - - - @patch.object(OpenAICompatibleInferenceBot, '__init__') # Mock the superclass's __init__ - def test_init_defaults_and_super_call(self, mock_super_init): - os.environ["OPENAI_API_KEY"] = "test_key_chatgpt" - os.environ["OPENAI_SMALL_MODEL"] = "gpt-3.5-turbo-env" - os.environ["OPENAI_SMALL_MODEL_MAX_TOKENS"] = "350" - - bot = ChatGPTTelegramInferenceBot() - - mock_super_init.assert_called_once_with( - client=None, # ChatGPT bot will let superclass create it - api_key="test_key_chatgpt", # Passed to super - base_url=None, - api_version=None, - azure_deployment=None, - model_name="gpt-3.5-turbo-env", # Default small model from env - max_tokens_str="350", # Default small model tokens from env - small_model_name="gpt-3.5-turbo-env", - small_model_max_tokens_str="350", - large_model_name=os.environ.get("OPENAI_LARGE_MODEL", "gpt-4-turbo-preview"), # Default large - large_model_max_tokens_str=os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS"), - system_prompt_content=None, - system_prompt_path=None, - is_gemini=False, - max_history_length=20 # Default from OpenAICompatibleInferenceBot - ) - - @patch.object(OpenAICompatibleInferenceBot, '__init__') - def test_init_with_arguments(self, mock_super_init): - mock_client_arg = MagicMock() - bot = ChatGPTTelegramInferenceBot( - openai_client=mock_client_arg, - api_key="arg_key", - small_model_name="arg_small_model", - small_model_max_tokens="123", - large_model_name="arg_large_model", - large_model_max_tokens="456", - system_prompt_content="Arg prompt" - ) - mock_super_init.assert_called_once_with( - client=mock_client_arg, - api_key="arg_key", - base_url=None, - api_version=None, - azure_deployment=None, - model_name="arg_small_model", # Initially configured with small model - max_tokens_str="123", - small_model_name="arg_small_model", - small_model_max_tokens_str="123", - large_model_name="arg_large_model", - large_model_max_tokens_str="456", - system_prompt_content="Arg prompt", - system_prompt_path=None, - is_gemini=False, - max_history_length=20 - ) - - # Test switch_model - this method is part of ChatGPTTelegramInferenceBot - # It calls _configure_model_and_tokens which is in the superclass. - # We need a bot instance where _configure_model_and_tokens can be called. - @patch('openai.OpenAI') # To allow instantiation of the bot by mocking client creation - async def test_switch_model_logic(self, mock_openai_constructor): - mock_openai_constructor.return_value = self.mock_openai_client # Mock client creation in super - - # Set env vars for model names that switch_model will use as fallback - os.environ["OPENAI_SMALL_MODEL"] = "env-small-gpt" - os.environ["OPENAI_SMALL_MODEL_MAX_TOKENS"] = "100" - os.environ["OPENAI_LARGE_MODEL"] = "env-large-gpt" - os.environ["OPENAI_LARGE_MODEL_MAX_TOKENS"] = "200" - - # Instantiate with initial model (small) - bot = ChatGPTTelegramInferenceBot() - self.assertEqual(bot.model, "env-small-gpt") - self.assertEqual(bot.max_tokens, 100) - - # Switch to large - status = await bot.switch_model() - self.assertEqual(bot.model, "env-large-gpt") - self.assertEqual(bot.max_tokens, 200) - self.assertEqual(status, "Switched to model: env-large-gpt") - - # Switch back to small - status = await bot.switch_model() - self.assertEqual(bot.model, "env-small-gpt") - self.assertEqual(bot.max_tokens, 100) - self.assertEqual(status, "Switched to model: env-small-gpt") - - @patch('openai.OpenAI') - async def test_switch_model_uses_instance_configs_if_provided(self, mock_openai_constructor): - mock_openai_constructor.return_value = self.mock_openai_client - - # Instantiate with specific model names, overriding potential env vars - bot = ChatGPTTelegramInferenceBot( - small_model_name="init-small", small_model_max_tokens="50", - large_model_name="init-large", large_model_max_tokens="150" - ) - self.assertEqual(bot.model, "init-small") # Starts with small - self.assertEqual(bot.max_tokens, 50) - - # Switch to large - status = await bot.switch_model() - self.assertEqual(bot.model, "init-large") - self.assertEqual(bot.max_tokens, 150) - self.assertEqual(status, "Switched to model: init-large") - - # Switch back to small - status = await bot.switch_model() - self.assertEqual(bot.model, "init-small") - self.assertEqual(bot.max_tokens, 50) - self.assertEqual(status, "Switched to model: init-small") - - # get_llm_description is inherited from OpenAICompatibleInferenceBot. - # Test just to ensure it works in the context of a ChatGPTBot instance - @patch('openai.OpenAI') - def test_get_llm_description_for_chatgpt_bot(self, mock_openai_constructor): - mock_openai_constructor.return_value = self.mock_openai_client - bot = ChatGPTTelegramInferenceBot(small_model_name="gpt-3.5-desc", small_model_max_tokens="777") - # Initially configured with small model - self.assertEqual(bot.get_llm_description(), "LLM: gpt-3.5-desc, Max Tokens: 777, Azure: False") - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_gemini_telegram_inference_bot.py b/tests/test_gemini_telegram_inference_bot.py deleted file mode 100644 index 8e5cc4f..0000000 --- a/tests/test_gemini_telegram_inference_bot.py +++ /dev/null @@ -1,154 +0,0 @@ -import unittest -from unittest.mock import MagicMock, patch, ANY -import os - -# Assuming gemini_telegram_inference_bot.py and its parent are accessible -from gemini_telegram_inference_bot import GeminiTelegramInferenceBot -from openai_compatible_inference_bot import OpenAICompatibleInferenceBot # For patching super - -class TestGeminiTelegramInferenceBot(unittest.IsolatedAsyncioTestCase): - - def setUp(self): - # Store and clear relevant environment variables - self.original_gemini_key = os.environ.get("GEMINI_API_KEY") - self.original_gemini_base_url = os.environ.get("GEMINI_API_BASE_URL") - self.original_small_model = os.environ.get("GEMINI_SMALL_MODEL") - self.original_large_model = os.environ.get("GEMINI_LARGE_MODEL") - self.original_small_tokens = os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS") - self.original_large_tokens = os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS") - self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH") - - for key in ["GEMINI_API_KEY", "GEMINI_API_BASE_URL", "GEMINI_SMALL_MODEL", "GEMINI_LARGE_MODEL", - "GEMINI_SMALL_MODEL_MAX_TOKENS", "GEMINI_LARGE_MODEL_MAX_TOKENS", "SYSTEM_PROMPT_PATH"]: - if os.environ.get(key): - del os.environ[key] - - self.mock_openai_client = MagicMock() # Used if superclass creates an OpenAI client - - def tearDown(self): - # Restore environment variables - if self.original_gemini_key: os.environ["GEMINI_API_KEY"] = self.original_gemini_key - if self.original_gemini_base_url: os.environ["GEMINI_API_BASE_URL"] = self.original_gemini_base_url - if self.original_small_model: os.environ["GEMINI_SMALL_MODEL"] = self.original_small_model - if self.original_large_model: os.environ["GEMINI_LARGE_MODEL"] = self.original_large_model - if self.original_small_tokens: os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = self.original_small_tokens - if self.original_large_tokens: os.environ["GEMINI_LARGE_MODEL_MAX_TOKENS"] = self.original_large_tokens - if self.original_system_prompt_path: os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path - - @patch.object(OpenAICompatibleInferenceBot, '__init__') # Mock the superclass's __init__ - def test_init_defaults_and_super_call(self, mock_super_init): - os.environ["GEMINI_API_KEY"] = "test_key_gemini" - os.environ["GEMINI_API_BASE_URL"] = "https://gemini.env.com" - os.environ["GEMINI_SMALL_MODEL"] = "gemini-pro-env" - os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = "360" - - bot = GeminiTelegramInferenceBot() - - mock_super_init.assert_called_once_with( - client=None, - api_key="test_key_gemini", - base_url="https://gemini.env.com", # Passed to super - api_version=None, - azure_deployment=None, - model_name="gemini-pro-env", - max_tokens_str="360", - small_model_name="gemini-pro-env", - small_model_max_tokens_str="360", - large_model_name=os.environ.get("GEMINI_LARGE_MODEL", "gemini-1.5-pro-latest"), # Default large - large_model_max_tokens_str=os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS"), - system_prompt_content=None, - system_prompt_path=None, - is_gemini=True, # Important for Gemini bot - max_history_length=20 - ) - - @patch.object(OpenAICompatibleInferenceBot, '__init__') - def test_init_with_arguments(self, mock_super_init): - mock_client_arg = MagicMock() - bot = GeminiTelegramInferenceBot( - openai_client=mock_client_arg, # Name in Gemini bot is openai_client for consistency - api_key="arg_gem_key", - base_url="https://arg.gemini.com", - small_model_name="arg_gem_small", - small_model_max_tokens="124", - large_model_name="arg_gem_large", - large_model_max_tokens="457", - system_prompt_content="Gemini prompt" - ) - mock_super_init.assert_called_once_with( - client=mock_client_arg, - api_key="arg_gem_key", - base_url="https://arg.gemini.com", - api_version=None, - azure_deployment=None, - model_name="arg_gem_small", - max_tokens_str="124", - small_model_name="arg_gem_small", - small_model_max_tokens_str="124", - large_model_name="arg_gem_large", - large_model_max_tokens_str="457", - system_prompt_content="Gemini prompt", - system_prompt_path=None, - is_gemini=True, - max_history_length=20 - ) - - @patch('openai.OpenAI') # Gemini bot uses OpenAI client configured for Gemini endpoint - async def test_switch_model_logic(self, mock_openai_constructor): - mock_openai_constructor.return_value = self.mock_openai_client - - os.environ["GEMINI_SMALL_MODEL"] = "env-gemini-small" - os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = "110" - os.environ["GEMINI_LARGE_MODEL"] = "env-gemini-large" - os.environ["GEMINI_LARGE_MODEL_MAX_TOKENS"] = "220" - - bot = GeminiTelegramInferenceBot() # Uses env vars by default - self.assertEqual(bot.model, "env-gemini-small") - self.assertEqual(bot.max_tokens, 110) - - status = await bot.switch_model() - self.assertEqual(bot.model, "env-gemini-large") - self.assertEqual(bot.max_tokens, 220) - self.assertEqual(status, "Switched to model: env-gemini-large") - - status = await bot.switch_model() - self.assertEqual(bot.model, "env-gemini-small") - self.assertEqual(bot.max_tokens, 110) - self.assertEqual(status, "Switched to model: env-gemini-small") - - @patch('openai.OpenAI') - async def test_switch_model_uses_instance_configs_if_provided(self, mock_openai_constructor): - mock_openai_constructor.return_value = self.mock_openai_client - - bot = GeminiTelegramInferenceBot( - small_model_name="init-gem-small", small_model_max_tokens="55", - large_model_name="init-gem-large", large_model_max_tokens="155" - ) - self.assertEqual(bot.model, "init-gem-small") - self.assertEqual(bot.max_tokens, 55) - - status = await bot.switch_model() - self.assertEqual(bot.model, "init-gem-large") - self.assertEqual(bot.max_tokens, 155) - self.assertEqual(status, "Switched to model: init-gem-large") - - status = await bot.switch_model() - self.assertEqual(bot.model, "init-gem-small") - self.assertEqual(bot.max_tokens, 55) - self.assertEqual(status, "Switched to model: init-gem-small") - - @patch('openai.OpenAI') - def test_get_llm_description_for_gemini_bot(self, mock_openai_constructor): - mock_openai_constructor.return_value = self.mock_openai_client - bot = GeminiTelegramInferenceBot( - small_model_name="gemini-pro-desc", - small_model_max_tokens="888", - # is_gemini is True by default in constructor call to super - ) - # LLM description should indicate not Azure, even though it uses OpenAICompatible... base - # The is_gemini flag primarily affects client instantiation logic in the superclass. - # The azure_openai flag in superclass is based on azure_endpoint presence. - self.assertEqual(bot.get_llm_description(), "LLM: gemini-pro-desc, Max Tokens: 888, Azure: False") - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_github_tool.py b/tests/test_github_tool.py deleted file mode 100644 index a012ed1..0000000 --- a/tests/test_github_tool.py +++ /dev/null @@ -1,81 +0,0 @@ -# tests/test_github_tool.py - -import unittest -from unittest.mock import patch, MagicMock -from tools.github_tool import GitHubTool - -class TestGitHubTool(unittest.TestCase): - - def setUp(self): - self.github_tool = GitHubTool() - - def test_get_functions(self): - functions = self.github_tool.get_functions() - self.assertEqual(len(functions), 4) - function_names = [f["name"] for f in functions] - expected_names = ["read_file", "create_branch", "commit_file", "create_pull_request"] - self.assertListEqual(function_names, expected_names) - - @patch('tools.github_tool.requests.get') - def test_read_file(self, mock_get): - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = {"content": "file content"} - mock_get.return_value = mock_response - - result = self.github_tool.execute("read_file", path="test.txt") - self.assertEqual(result, "file content") - - mock_get.assert_called_once() - - @patch('tools.github_tool.requests.get') - @patch('tools.github_tool.requests.post') - def test_create_branch(self, mock_post, mock_get): - mock_get_response = MagicMock() - mock_get_response.status_code = 200 - mock_get_response.json.return_value = {"object": {"sha": "test_sha"}} - mock_get.return_value = mock_get_response - - mock_post_response = MagicMock() - mock_post_response.status_code = 201 - mock_post.return_value = mock_post_response - - result = self.github_tool.execute("create_branch", branch_name="test-branch") - self.assertEqual(result, "Branch 'test-branch' created successfully") - - mock_get.assert_called_once() - mock_post.assert_called_once() - - @patch('tools.github_tool.requests.put') - def test_commit_file(self, mock_put): - mock_response = MagicMock() - mock_response.status_code = 200 - mock_put.return_value = mock_response - - result = self.github_tool.execute("commit_file", branch_name="test-branch", file_path="test.txt", content="test content", commit_message="Test commit") - self.assertEqual(result, "File committed successfully to branch 'test-branch'") - - mock_put.assert_called_once() - - def test_commit_file_to_main(self): - result = self.github_tool.execute("commit_file", branch_name="main", file_path="test.txt", content="test content", commit_message="Test commit") - self.assertEqual(result, "Cannot commit directly to main branch") - - @patch('tools.github_tool.requests.post') - def test_create_pull_request(self, mock_post): - mock_response = MagicMock() - mock_response.status_code = 201 - mock_response.json.return_value = {"html_url": "https://github.com/test/test/pull/1"} - mock_post.return_value = mock_response - - result = self.github_tool.execute("create_pull_request", title="Test PR", body="Test body", head="test-branch") - self.assertEqual(result, "Pull request created successfully: https://github.com/test/test/pull/1") - - mock_post.assert_called_once() - - def test_unknown_function(self): - result = self.github_tool.execute("unknown_function") - self.assertEqual(result, "Unknown function: unknown_function") - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_openai_compatible_inference_bot.py b/tests/test_openai_compatible_inference_bot.py deleted file mode 100644 index dc667c0..0000000 --- a/tests/test_openai_compatible_inference_bot.py +++ /dev/null @@ -1,332 +0,0 @@ -import unittest -from unittest.mock import MagicMock, patch, AsyncMock, ANY -import os -import json - -# Assuming openai_compatible_inference_bot.py is in the parent directory or PYTHONPATH is set -from openai_compatible_inference_bot import OpenAICompatibleInferenceBot - -# Mock response from OpenAI client's chat.completions.create -def create_mock_openai_response(content=None, tool_calls=None): - mock_message = MagicMock() - mock_message.role = "assistant" - mock_message.content = content - if tool_calls: - # tool_calls should be a list of objects with id and function (name, arguments) - mock_tool_calls = [] - for tc in tool_calls: - mock_tc = MagicMock() - mock_tc.id = tc["id"] - mock_tc.function.name = tc["function"]["name"] - mock_tc.function.arguments = tc["function"]["arguments"] - mock_tool_calls.append(mock_tc) - mock_message.tool_calls = mock_tool_calls - else: - mock_message.tool_calls = None - - mock_choice = MagicMock() - mock_choice.message = mock_message - - mock_response = MagicMock() - mock_response.choices = [mock_choice] - return mock_response - -# Concrete class for testing -class ConcreteOpenAICompatibleBot(OpenAICompatibleInferenceBot): - # Implement abstract methods for instantiation - async def switch_model(self): - # Simple switch for testing if needed, or just pass - if self.model == self.small_model_name: - self._configure_model_and_tokens(self.large_model_name, self.large_model_max_tokens_str) - else: - self._configure_model_and_tokens(self.small_model_name, self.small_model_max_tokens_str) - return f"Switched to {self.model}" - - # Override load_functions if it's called by parent and needs mocking for these tests - # (OpenAICompatibleInferenceBot's __init__ calls BaseTelegramInferenceBot's __init__, which calls load_functions) - def load_functions(self): - # For these tests, assume no tools unless specifically added - self.tools = [] - self.functions = [] - return self.tools, self.functions - - -class TestOpenAICompatibleInferenceBot(unittest.IsolatedAsyncioTestCase): - - def setUp(self): - self.original_openai_api_key = os.environ.get("OPENAI_API_KEY") - self.original_azure_openai_key = os.environ.get("AZURE_OPENAI_KEY") - self.original_azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT") - self.original_api_version = os.environ.get("AZURE_OPENAI_API_VERSION") - self.original_azure_deployment = os.environ.get("AZURE_DEPLOYMENT_NAME") - - # Clear relevant env vars before each test - for key in ["OPENAI_API_KEY", "AZURE_OPENAI_KEY", "AZURE_OPENAI_ENDPOINT", - "AZURE_OPENAI_API_VERSION", "AZURE_DEPLOYMENT_NAME", "SYSTEM_PROMPT_PATH"]: - if os.environ.get(key): - del os.environ[key] - - self.mock_openai_client_instance = MagicMock() - self.mock_openai_client_instance.chat.completions.create = MagicMock() - - def tearDown(self): - # Restore environment variables - if self.original_openai_api_key: os.environ["OPENAI_API_KEY"] = self.original_openai_api_key - if self.original_azure_openai_key: os.environ["AZURE_OPENAI_KEY"] = self.original_azure_openai_key - if self.original_azure_endpoint: os.environ["AZURE_OPENAI_ENDPOINT"] = self.original_azure_endpoint - if self.original_api_version: os.environ["AZURE_OPENAI_API_VERSION"] = self.original_api_version - if self.original_azure_deployment: os.environ["AZURE_DEPLOYMENT_NAME"] = self.original_azure_deployment - - - @patch('openai.OpenAI') - def test_init_with_openai_defaults(self, MockOpenAIConstructor): - MockOpenAIConstructor.return_value = self.mock_openai_client_instance - os.environ["OPENAI_API_KEY"] = "test_openai_key" - - bot = ConcreteOpenAICompatibleBot(model_name="gpt-4") - - MockOpenAIConstructor.assert_called_once_with(api_key="test_openai_key", base_url=None) - self.assertEqual(bot.client, self.mock_openai_client_instance) - self.assertEqual(bot.model, "gpt-4") - self.assertEqual(bot.max_tokens, 1000) # Default from _configure_model_and_tokens - self.assertEqual(bot.azure_openai, False) - - @patch('openai.OpenAI') - def test_init_with_provided_client(self, MockOpenAIConstructor): - preconfigured_client = MagicMock() - bot = ConcreteOpenAICompatibleBot(client=preconfigured_client, model_name="gpt-3.5") - - MockOpenAIConstructor.assert_not_called() - self.assertEqual(bot.client, preconfigured_client) - self.assertEqual(bot.model, "gpt-3.5") - - @patch('openai.AzureOpenAI') - def test_init_with_azure_config_args(self, MockAzureOpenAIConstructor): - MockAzureOpenAIConstructor.return_value = self.mock_openai_client_instance - - bot = ConcreteOpenAICompatibleBot( - api_key="azure_key", - azure_endpoint="https://myenv.openai.azure.com", - api_version="2023-05-15", - azure_deployment="my-gpt-4", # This should be used as model_name for API call - model_name="should_be_overridden_by_azure_deployment_for_api" - # model_name is passed to _configure_model_and_tokens, which sets self.model for display/logging - # but for Azure, the client needs the deployment name. - ) - - MockAzureOpenAIConstructor.assert_called_once_with( - api_key="azure_key", - azure_endpoint="https://myenv.openai.azure.com", - api_version="2023-05-15" - ) - self.assertEqual(bot.client, self.mock_openai_client_instance) - self.assertEqual(bot.model, "my-gpt-4") # Azure deployment name becomes the model for API calls - self.assertEqual(bot.azure_openai, True) - - - @patch('openai.AzureOpenAI') - def test_init_with_azure_env_vars(self, MockAzureOpenAIConstructor): - MockAzureOpenAIConstructor.return_value = self.mock_openai_client_instance - os.environ["AZURE_OPENAI_KEY"] = "env_azure_key" - os.environ["AZURE_OPENAI_ENDPOINT"] = "https://env.openai.azure.com" - os.environ["AZURE_OPENAI_API_VERSION"] = "2023-06-01" - os.environ["AZURE_DEPLOYMENT_NAME"] = "env-gpt-35" # Used as model_name - - bot = ConcreteOpenAICompatibleBot(model_name="ignored_if_azure_deployment_env_is_set") - - MockAzureOpenAIConstructor.assert_called_once_with( - api_key="env_azure_key", - azure_endpoint="https://env.openai.azure.com", - api_version="2023-06-01" - ) - self.assertEqual(bot.model, "env-gpt-35") - self.assertTrue(bot.azure_openai) - - @patch('openai.OpenAI') - def test_init_with_gemini_config_args(self, MockOpenAIConstructor): - MockOpenAIConstructor.return_value = self.mock_openai_client_instance - - bot = ConcreteOpenAICompatibleBot( - api_key="gemini_key", - base_url="https://gemini.example.com", - model_name="gemini-pro", - is_gemini=True - ) - MockOpenAIConstructor.assert_called_once_with(api_key="gemini_key", base_url="https://gemini.example.com") - self.assertEqual(bot.model, "gemini-pro") - self.assertFalse(bot.azure_openai) # is_gemini doesn't mean azure_openai - - def test_configure_model_and_tokens(self): - bot = ConcreteOpenAICompatibleBot(model_name="initial_model") # init calls _configure - bot._configure_model_and_tokens("test-model", "500") - self.assertEqual(bot.model, "test-model") - self.assertEqual(bot.max_tokens, 500) - - bot._configure_model_and_tokens("test-model-2", None, default_max_tokens=150) - self.assertEqual(bot.max_tokens, 150) - - bot._configure_model_and_tokens("test-model-3", "invalid_token_val") - self.assertEqual(bot.max_tokens, 1000) # Default fallback - - def test_get_llm_description(self): - bot = ConcreteOpenAICompatibleBot(model_name="desc-model", max_tokens_str="256") - self.assertEqual(bot.get_llm_description(), "LLM: desc-model, Max Tokens: 256, Azure: False") - - bot_azure = ConcreteOpenAICompatibleBot(azure_deployment="azure-model", azure_endpoint="x", api_key="y", api_version="z") - self.assertEqual(bot_azure.get_llm_description(), "LLM: azure-model, Max Tokens: 1000, Azure: True") - - - def test_get_chat_response_success(self): - bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="test-gpt") - bot.max_tokens = 50 # Ensure this is set - mock_api_response = create_mock_openai_response(content="Hello from API") - self.mock_openai_client_instance.chat.completions.create.return_value = mock_api_response - - messages = [{"role": "user", "content": "Hi"}] - response = bot.get_chat_response(messages) - - self.mock_openai_client_instance.chat.completions.create.assert_called_once_with( - model="test-gpt", - messages=messages, - tools=ANY, # Assuming functions can be None or empty list - tool_choice=ANY, - max_tokens=50 - ) - self.assertEqual(response, mock_api_response) - - def test_get_chat_response_api_error(self): - bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="error-gpt") - self.mock_openai_client_instance.chat.completions.create.side_effect = Exception("API Down") - - with self.assertRaisesRegex(Exception, "API Down"): - bot.get_chat_response([{"role": "user", "content": "trigger"}]) - - async def test_handle_message_simple_response(self): - bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="chatty") - bot.system_prompt = "You are a test bot." # Set directly for simplicity - mock_api_response = create_mock_openai_response(content="Test reply") - self.mock_openai_client_instance.chat.completions.create.return_value = mock_api_response - - response_content = await bot.handle_message(user_id=1, user_message="Hello") - - self.assertEqual(response_content, "Test reply") - self.assertIn(1, bot.conversation_history) - self.assertEqual(len(bot.conversation_history[1]), 3) # System, User, Assistant - self.assertEqual(bot.conversation_history[1][0]["content"], "You are a test bot.") - self.assertEqual(bot.conversation_history[1][2]["content"], "Test reply") - - async def test_handle_message_with_tool_call_and_response(self): - bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="tool-user") - - # Mock functions/tools setup on the bot - mock_tool_def = {"function": {"name": "get_weather", "description": "Gets weather", "parameters": {}}} - bot.functions = [mock_tool_def] # Simulate tools are loaded - - # API response 1: Request to call tool - tool_call_request = [{"id": "call123", "function": {"name": "get_weather", "arguments": '''{"location": "moon"}'''}}] - api_response_1 = create_mock_openai_response(tool_calls=tool_call_request) - - # API response 2: Final answer after tool execution - api_response_2 = create_mock_openai_response(content="The weather on the moon is chilly.") - - self.mock_openai_client_instance.chat.completions.create.side_effect = [api_response_1, api_response_2] - - # Mock self.call_tool - bot.call_tool = MagicMock(return_value='''{"temperature": "-100 C"}''') - - final_response = await bot.handle_message(user_id=2, user_message="Weather on moon?") - - self.assertEqual(final_response, "The weather on the moon is chilly.") - bot.call_tool.assert_called_once_with("get_weather", '''{"location": "moon"}''') - - # Check conversation history includes tool messages - history = bot.conversation_history[2] - self.assertTrue(any(msg["role"] == "assistant" and msg.tool_calls is not None for msg in history)) - self.assertTrue(any(msg["role"] == "tool" and msg["name"] == "get_weather" for msg in history)) - self.assertEqual(self.mock_openai_client_instance.chat.completions.create.call_count, 2) - - async def test_handle_message_max_history_length(self): - bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="hist-test", max_history_length=3) - self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Ok") - - await bot.handle_message(1, "Msg1") # Sys, User, Assist (3) - self.assertEqual(len(bot.conversation_history[1]), 3) - - await bot.handle_message(1, "Msg2") # User, Assist. Should be 3 (prev User, prev Assist, new User) -> then adds new Assist. - # Before new call: [Sys, U1, A1]. New U2. Call with [Sys,U1,A1,U2]. Resp A2. - # History: [Sys,U1,A1,U2,A2]. Limit 3. -> [A1,U2,A2] (if system is not preserved specially) - # The current code appends to history then truncates if over limit. - # So after Msg1: [S, U1, A1]. len=3. - # For Msg2: History is [S, U1, A1]. Append U2. Call with [S,U1,A1,U2]. Append A2. - # History now [S,U1,A1,U2,A2]. len=5. Truncate to 3. - # Expected: [A1, U2, A2] or [U1,A1,U2] or [U2,A2,S] depending on how system prompt is handled in truncation. - # The code is: self.conversation_history[user_id][-self.max_history_length:] - # And system prompt is only added IF user_id not in self.conversation_history. - # So, for Msg2, system prompt is not re-added. - # History before Msg2 call: [S, U1, A1] - # Messages for Msg2 call: [S, U1, A1, U2] - # History after Msg2 response A2: [S, U1, A1, U2, A2]. Len 5. - # Truncated to self.max_history_length=3: [A1, U2, A2] - - # Call 1 - self.mock_openai_client_instance.chat.completions.create.reset_mock() - self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply1") - await bot.handle_message(user_id=7, user_message="First message") - self.assertEqual(len(bot.conversation_history[7]), 3) # System, User1, Assistant1 - - # Call 2 - self.mock_openai_client_instance.chat.completions.create.reset_mock() - self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply2") - await bot.handle_message(user_id=7, user_message="Second message") - # History before call: [S, U1, A1]. Messages for call: [S, U1, A1, U2]. History after: [S, U1, A1, U2, A2]. - # Truncated to 3: [A1, U2, A2] - self.assertEqual(len(bot.conversation_history[7]), 3) - self.assertEqual(bot.conversation_history[7][0]["content"], "Reply1") # A1 - self.assertEqual(bot.conversation_history[7][1]["content"], "Second message") # U2 - self.assertEqual(bot.conversation_history[7][2]["content"], "Reply2") # A2 - - # Call 3 - self.mock_openai_client_instance.chat.completions.create.reset_mock() - self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply3") - await bot.handle_message(user_id=7, user_message="Third message") - # History before call: [A1, U2, A2]. Messages for call: [A1, U2, A2, U3]. History after: [A1, U2, A2, U3, A3]. - # Truncated to 3: [A2, U3, A3] - self.assertEqual(len(bot.conversation_history[7]), 3) - self.assertEqual(bot.conversation_history[7][0]["content"], "Reply2") # A2 - self.assertEqual(bot.conversation_history[7][1]["content"], "Third message") # U3 - self.assertEqual(bot.conversation_history[7][2]["content"], "Reply3") # A3 - - - async def test_abort_processing(self): - bot = ConcreteOpenAICompatibleBot(model_name="test") - user_id = 123 - bot.processing_status[user_id] = {"processing": True, "message_id": 456} - bot.conversation_history[user_id] = [{"role": "user", "content": "stuff"}] - - with patch.object(bot, 'clear_conversation_history') as mock_clear_hist: # Patching the method from Base class - result = await bot.abort_processing(user_id) - - self.assertEqual(result, "Processing aborted and conversation cleared.") - self.assertFalse(bot.processing_status[user_id]["processing"]) - mock_clear_hist.assert_called_once_with(user_id) - - async def test_abort_processing_no_active_processing(self): - bot = ConcreteOpenAICompatibleBot(model_name="test") - user_id = 404 # Not in processing_status - with patch.object(bot, 'clear_conversation_history') as mock_clear_hist: - result = await bot.abort_processing(user_id) - self.assertEqual(result, "No active processing found to abort. Conversation cleared.") - mock_clear_hist.assert_called_once_with(user_id) - - # Test for the abstract switch_model (basic call, actual logic in concrete class for this test) - async def test_switch_model_concrete_implementation(self): - bot = ConcreteOpenAICompatibleBot(model_name="model1", small_model_name="model1", large_model_name="model2", max_tokens_str="100") - self.assertEqual(bot.model, "model1") - await bot.switch_model() # Calls the concrete implementation - self.assertEqual(bot.model, "model2") - await bot.switch_model() - self.assertEqual(bot.model, "model1") - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_telegram_helper.py b/tests/test_telegram_helper.py deleted file mode 100644 index 6b8d655..0000000 --- a/tests/test_telegram_helper.py +++ /dev/null @@ -1,356 +0,0 @@ -import unittest -from unittest.mock import MagicMock, patch, mock_open, AsyncMock -import asyncio -import os -import sys - -# Assuming telegram_helper.py is in the parent directory or PYTHONPATH is set -from telegram_helper import TelegramHelper, MessageHandlerLogicResult - -# Mock for the bot passed to TelegramHelper -class MockBot: - def __init__(self): - self.start = AsyncMock() - self.clear_conversation_history = MagicMock() - self.get_bot_status = AsyncMock(return_value="Bot Status OK") - self.switch_model = AsyncMock(return_value="Model Switched OK") - self.handle_message = AsyncMock() # Needs to return a string - self.abort_processing = AsyncMock(return_value="Abort OK") - self.set_processing_status = MagicMock() - self.clear_processing_status = MagicMock() - self.processing_status = {} # Add the attribute - -# Mock for telegram.Update and related objects -def create_mock_update(message_text=None, user_id=123, chat_id=456, message_id=789, callback_query_data=None): - update = MagicMock() - update.effective_user.id = user_id - update.effective_chat.id = chat_id - - if message_text: - update.message.text = message_text - update.message.reply_text = AsyncMock(return_value=MagicMock(message_id=message_id)) # reply_text returns a Message obj - - if callback_query_data: - update.callback_query.data = callback_query_data - update.callback_query.from_user.id = user_id - update.callback_query.answer = AsyncMock() - update.callback_query.edit_message_text = AsyncMock() - - return update - -# Mock for telegram.ext.ContextTypes.DEFAULT_TYPE -def create_mock_context(): - context = MagicMock() - context.bot.delete_message = AsyncMock() - context.bot.edit_message_text = AsyncMock() # For update_status_message - return context - -class TestTelegramHelper(unittest.IsolatedAsyncioTestCase): # Use IsolatedAsyncioTestCase for async methods - - def setUp(self): - self.mock_bot = MockBot() - # Default paths for reboot files, can be overridden in tests - self.reboot_claude_file = ".test_reboot_claude" - self.reboot_file = ".test_doreboot" - self.helper = TelegramHelper( - self.mock_bot, - reboot_claude_file_path=self.reboot_claude_file, - reboot_file_path=self.reboot_file, - chunk_message_sleep_duration=0.001 # Faster sleep for tests - ) - # Clean up any potential leftover reboot files from previous runs - if os.path.exists(self.reboot_claude_file): - os.remove(self.reboot_claude_file) - if os.path.exists(self.reboot_file): - os.remove(self.reboot_file) - - def tearDown(self): - # Clean up reboot files created during tests - if os.path.exists(self.reboot_claude_file): - os.remove(self.reboot_claude_file) - if os.path.exists(self.reboot_file): - os.remove(self.reboot_file) - - async def test_start_logic(self): - response = await self.helper._start_logic() - self.mock_bot.start.assert_called_once() - self.assertEqual(response, "Hello! I\'m your AI assistant. How can I help you today?") - - async def test_start_command(self): - mock_update = create_mock_update(message_text="/start") - mock_context = create_mock_context() - - with patch.object(self.helper, \'_start_logic\', new_callable=AsyncMock) as mock_logic: - mock_logic.return_value = "Start Logic Response" - await self.helper.start(mock_update, mock_context) - mock_logic.assert_called_once() - mock_update.message.reply_text.assert_called_once_with("Start Logic Response") - - async def test_clear_logic(self): - user_id = 123 - response = await self.helper._clear_logic(user_id) # _clear_logic is async after refactor - self.mock_bot.clear_conversation_history.assert_called_once_with(user_id) - self.assertEqual(response, "Conversation history cleared. Let\'s start fresh!") - - async def test_clear_command(self): - mock_update = create_mock_update(message_text="/clear", user_id=123) - mock_context = create_mock_context() - with patch.object(self.helper, \'_clear_logic\', new_callable=AsyncMock) as mock_logic: - mock_logic.return_value = "Clear Logic Response" - await self.helper.clear(mock_update, mock_context) - mock_logic.assert_called_once_with(123) - mock_update.message.reply_text.assert_called_once_with("Clear Logic Response") - - async def test_status_logic(self): - self.mock_bot.get_bot_status.return_value = "Test Status" - response = await self.helper._status_logic() - self.mock_bot.get_bot_status.assert_called_once() - self.assertEqual(response, "Test Status") - - async def test_switch_logic_supported(self): - self.mock_bot.switch_model.return_value = "Switched to Large Model" - response = await self.helper._switch_logic() - self.mock_bot.switch_model.assert_called_once() - self.assertEqual(response, "Switched to Large Model") - - async def test_switch_logic_not_supported(self): - del self.mock_bot.switch_model # Simulate bot not having the attribute - response = await self.helper._switch_logic() - self.assertEqual(response, "Model switching is not supported for this bot.") - - async def test_handle_message_logic_success(self): - user_id = 100 - user_message = "Hello bot" - bot_response = "Hello user Thinking hard Done." - expected_processed_response = f"Hello user {self.helper.HTML_QUOTE_BLOCK_START}Thinking hard{self.helper.HTML_QUOTE_BLOCK_END} Done." - self.mock_bot.handle_message.return_value = bot_response - - result = await self.helper._handle_message_logic(user_id, user_message) - - self.mock_bot.handle_message.assert_called_once_with(user_id, user_message) - self.assertTrue(result["success"]) - self.assertEqual(result["response_text"], expected_processed_response) - self.assertIsNone(result["error_message"]) - - async def test_handle_message_logic_bot_exception(self): - user_id = 101 - user_message = "Trigger error" - self.mock_bot.handle_message.side_effect = Exception("Bot Error") - - result = await self.helper._handle_message_logic(user_id, user_message) - - self.assertFalse(result["success"]) - self.assertIsNone(result["response_text"]) - self.assertEqual(result["error_message"], "Bot Error") - - @patch(\'logging.error\') - async def test_handle_message_command_success_short_message(self, mock_logging_error): - mock_update = create_mock_update(message_text="Hi", user_id=200, chat_id=201, message_id=202) - mock_context = create_mock_context() - - logic_result = MessageHandlerLogicResult(success=True, response_text="Short response", error_message=None) - - with patch.object(self.helper, \'_handle_message_logic\', new_callable=AsyncMock) as mock_message_logic: - mock_message_logic.return_value = logic_result - - await self.helper.handle_message(mock_update, mock_context) - - mock_update.message.reply_text.assert_any_call("Processing your request...", reply_markup=unittest.mock.ANY) - self.mock_bot.set_processing_status.assert_called_once_with(200, 202) # user_id, status_message_id - mock_message_logic.assert_called_once_with(200, "Hi") - mock_context.bot.delete_message.assert_called_once_with(chat_id=201, message_id=202) - self.mock_bot.clear_processing_status.assert_called_once_with(200) - mock_update.message.reply_text.assert_any_call("Short response") # Final response - self.assertEqual(mock_update.message.reply_text.call_count, 2) # Processing + final - - @patch(\'logging.error\') - async def test_handle_message_command_success_long_message_chunks(self, mock_logging_error): - mock_update = create_mock_update(message_text="Long text", user_id=200, chat_id=201, message_id=202) - mock_context = create_mock_context() - - long_response_text = "a" * 5000 # Longer than 4096 - chunk1 = long_response_text[:4096] - chunk2 = long_response_text[4096:] - - logic_result = MessageHandlerLogicResult(success=True, response_text=long_response_text, error_message=None) - - with patch.object(self.helper, \'_handle_message_logic\', new_callable=AsyncMock) as mock_message_logic, \ - patch(\'asyncio.sleep\', new_callable=AsyncMock) as mock_sleep: # Mock sleep - mock_message_logic.return_value = logic_result - - await self.helper.handle_message(mock_update, mock_context) - - mock_update.message.reply_text.assert_any_call(chunk1) - mock_update.message.reply_text.assert_any_call(chunk2) - mock_sleep.assert_called_once_with(self.helper.chunk_message_sleep_duration) - self.assertEqual(mock_update.message.reply_text.call_count, 3) # Processing + 2 chunks - - @patch(\'logging.error\') - async def test_handle_message_command_logic_fails(self, mock_logging_error): - mock_update = create_mock_update(message_text="Cause error in logic", user_id=200) - mock_context = create_mock_context() - logic_result = MessageHandlerLogicResult(success=False, response_text=None, error_message="Logic Failed") - - with patch.object(self.helper, \'_handle_message_logic\', new_callable=AsyncMock) as mock_message_logic: - mock_message_logic.return_value = logic_result - await self.helper.handle_message(mock_update, mock_context) - mock_update.message.reply_text.assert_any_call("Sorry, an error occurred while processing your request.") - self.assertEqual(mock_update.message.reply_text.call_count, 2) # Processing + error message - - @patch(\'logging.error\') - async def test_handle_message_command_telegram_exception_after_logic(self, mock_logging_error): - mock_update = create_mock_update(message_text="Test", user_id=200) - mock_context = create_mock_context() - logic_result = MessageHandlerLogicResult(success=True, response_text="OK", error_message=None) - - # Make sending the final reply fail - mock_update.message.reply_text.side_effect = [ - MagicMock(message_id=202), # For "Processing..." - Exception("Telegram API Error") # For the actual response - ] - - with patch.object(self.helper, \'_handle_message_logic\', new_callable=AsyncMock) as mock_message_logic: - mock_message_logic.return_value = logic_result - await self.helper.handle_message(mock_update, mock_context) - - # Check if the generic error message was attempted - # This is tricky because reply_text is already mocked with side_effect. - # We\'d expect logs. Let\'s check logs or if processing status was cleared. - self.mock_bot.clear_processing_status.assert_called_once_with(200) - mock_logging_error.assert_any_call(unittest.mock.string_containing("Outer error in handle_message")) - - - async def test_abort_processing_logic(self): - user_id = 300 - self.mock_bot.abort_processing.return_value = "Aborted by bot" - response = await self.helper._abort_processing_logic(user_id) - self.mock_bot.abort_processing.assert_called_once_with(user_id) - self.assertEqual(response, "Aborted by bot") - - async def test_abort_processing_command(self): - mock_update = create_mock_update(callback_query_data=\'abort\', user_id=301) - mock_context = create_mock_context() - with patch.object(self.helper, \'_abort_processing_logic\', new_callable=AsyncMock) as mock_logic: - mock_logic.return_value = "Abort Logic Done" - await self.helper.abort_processing(mock_update, mock_context) - - mock_update.callback_query.answer.assert_called_once() - mock_logic.assert_called_once_with(301) - mock_update.callback_query.edit_message_text.assert_called_once_with(text="Abort Logic Done") - - def test_reboot_logic_claude_and_main(self): - user_message_parts = ["/reboot", "claude"] - chat_id_to_write = "12345" - - with patch("builtins.open", mock_open()) as mock_file: - self.helper._reboot_logic(user_message_parts, chat_id_to_write) - - # Check claude reboot file - mock_file.assert_any_call(self.reboot_claude_file, \'w\') - # Check main doreboot file - mock_file.assert_any_call(self.reboot_file, \'w\') - handle_claude = mock_file.return_value - handle_main = mock_file.return_value # mock_open reuses the handle for multiple calls - - # Check if write was called for claude file (empty write) - # This part of assertion is tricky with single mock_file. Better to use different mocks if possible - # or check the sequence of calls if the mock supports it well. - # For now, assert_any_call ensures it was opened. - - # Check content for main reboot file - # Need to ensure the write for self.reboot_file had chat_id_to_write - # This requires more sophisticated mock_open or patching os.path.exists and multiple open calls - # Simpler check: was open(self.reboot_file, \'w\') called? Yes, via assert_any_call. - # And was open(self.reboot_claude_file, \'w\') called? Yes. - - # Verify files were created (mock_open doesn\'t actually create them) - # This test relies on mock_open\'s behavior. To test file content, need more setup. - # For now, assume open was called correctly. - - def test_reboot_logic_main_only(self): - user_message_parts = ["/reboot"] - chat_id_to_write = "67890" - with patch("builtins.open", mock_open()) as mock_file: - self.helper._reboot_logic(user_message_parts, chat_id_to_write) - # Ensure claude file was NOT opened for writing. - # This requires asserting that a specific call didn\'t happen, or checking call_args_list - claude_call = unittest.mock.call(self.reboot_claude_file, \'w\') - self.assertNotIn(claude_call, mock_file.call_args_list) - - mock_file.assert_any_call(self.reboot_file, \'w\') - - @patch(\'sys.exit\') # Mock sys.exit to prevent test runner from exiting - async def test_reboot_command(self, mock_sys_exit): - mock_update = create_mock_update(message_text="/reboot claude", chat_id="chat1") - mock_context = create_mock_context() - - with patch.object(self.helper, \'_reboot_logic\') as mock_reboot_file_logic: - await self.helper.reboot(mock_update, mock_context) - - mock_reboot_file_logic.assert_called_once_with(["/reboot", "claude"], "chat1") - mock_update.message.reply_text.assert_called_once_with("Rebooting the bot...") - mock_sys_exit.assert_called_once_with(0) - - @patch(\'os.path.exists\') - @patch(\'builtins.open\', new_callable=mock_open) - @patch(\'os.remove\') - async def test_check_doreboot_file_logic_file_exists(self, mock_os_remove, mock_file_open, mock_os_path_exists): - mock_os_path_exists.return_value = True - mock_file_open.return_value.read.return_value.strip.return_value = "chat123" - - chat_id = await self.helper._check_doreboot_file_logic() - - mock_os_path_exists.assert_called_once_with(self.reboot_file) - mock_file_open.assert_called_once_with(self.reboot_file, \'r\') - mock_os_remove.assert_called_once_with(self.reboot_file) - self.assertEqual(chat_id, "chat123") - - @patch(\'os.path.exists\', return_value=False) - async def test_check_doreboot_file_logic_file_not_exists(self, mock_os_path_exists): - chat_id = await self.helper._check_doreboot_file_logic() - mock_os_path_exists.assert_called_once_with(self.reboot_file) - self.assertIsNone(chat_id) - - @patch(\'logging.error\') - @patch(\'os.path.exists\', return_value=True) - @patch(\'builtins.open\', side_effect=IOError("Read error")) - @patch(\'os.remove\') # To check if remove is called even on read error - async def test_check_doreboot_file_logic_read_error(self, mock_os_remove, mock_file_open, mock_os_path_exists, mock_log_error): - chat_id = await self.helper._check_doreboot_file_logic() - - self.assertIsNone(chat_id) - mock_log_error.assert_any_call(unittest.mock.string_containing("Error reading reboot file")) - # Check if os.remove was attempted even after read error - mock_os_remove.assert_called_once_with(self.reboot_file) - - - async def test_check_doreboot_file_command_sends_message(self): - mock_application = MagicMock() - mock_application.bot.send_message = AsyncMock() - - with patch.object(self.helper, \'_check_doreboot_file_logic\', new_callable=AsyncMock) as mock_logic: - mock_logic.return_value = "chat789" # Simulate chat_id found - await self.helper.check_doreboot_file(mock_application) - - mock_logic.assert_called_once() - mock_application.bot.send_message.assert_called_once_with( - chat_id="chat789", text="The application has finished initializing." - ) - - async def test_check_doreboot_file_command_no_chat_id(self): - mock_application = MagicMock() - mock_application.bot.send_message = AsyncMock() - - with patch.object(self.helper, \'_check_doreboot_file_logic\', new_callable=AsyncMock) as mock_logic: - mock_logic.return_value = None # Simulate no chat_id found - await self.helper.check_doreboot_file(mock_application) - - mock_logic.assert_called_once() - mock_application.bot.send_message.assert_not_called() - - # Note: Testing the run() method itself is more of an integration test, - # as it involves setting up the full Application and polling loop. - # Unit tests here focus on the helper\'s own logic methods. - -if __name__ == \'__main__\': - unittest.main() diff --git a/tests/tools/test_github_tool.py b/tests/tools/test_github_tool.py deleted file mode 100644 index 6c10b91..0000000 --- a/tests/tools/test_github_tool.py +++ /dev/null @@ -1,307 +0,0 @@ -import unittest -from unittest.mock import MagicMock, patch -import os -import base64 -import logging -import requests # Required for spec in MagicMock - -# Ensure tools/github_tool.py is accessible -from tools.github_tool import GitHubTool - -# Helper to create a mock response for requests.Session -def create_mock_response(status_code, json_data=None, text_data=None, headers=None, links=None): - mock_resp = MagicMock() - mock_resp.status_code = status_code - if json_data is not None: - mock_resp.json = MagicMock(return_value=json_data) - mock_resp.text = text_data if text_data is not None else str(json_data) - mock_resp.headers = headers if headers else {} - mock_resp.links = links if links else {} # For pagination in _list_branches - return mock_resp - -class TestGitHubTool(unittest.TestCase): - - def setUp(self): - self.mock_session = MagicMock(spec=requests.Session) - self.mock_session.headers = {} # Simulate a new session's headers - - self.test_token = "test_github_token" - self.test_repo = "owner/repo" - self.test_base_url = "https://api.example.com" # Use a non-default base_url for some tests - - # Suppress logging output during tests unless explicitly testing for it - self.logger = logging.getLogger('tools.github_tool') - # Ensure only one NullHandler to prevent duplicate messages if tests run multiple times in a session - if not any(isinstance(h, logging.NullHandler) for h in self.logger.handlers): - self.logger.addHandler(logging.NullHandler()) - self.logger.propagate = False # Prevent propagation to root logger if it has handlers - - def test_init_with_args_and_session(self): - tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, base_url=self.test_base_url, logger=self.logger) - self.assertEqual(tool.session, self.mock_session) - self.assertEqual(tool._token, self.test_token) - self.assertEqual(tool._repo, self.test_repo) - self.assertEqual(tool.base_url, self.test_base_url) - self.assertEqual(tool.current_branch, "main") # Default initial branch - - @patch('requests.Session') - def test_init_creates_session_if_not_provided(self, MockSessionConstructor): - mock_created_session = MagicMock(spec=requests.Session) - mock_created_session.headers = {} - MockSessionConstructor.return_value = mock_created_session - - # Temporarily set env vars for this test - original_token = os.environ.get("GITHUB_TOKEN") - original_repo = os.environ.get("GITHUB_REPOSITORY") - os.environ["GITHUB_TOKEN"] = "env_token" - os.environ["GITHUB_REPOSITORY"] = "env/repo" - - tool = GitHubTool(logger=self.logger) # Use env vars - - MockSessionConstructor.assert_called_once() - self.assertEqual(tool.session, mock_created_session) - self.assertEqual(tool._token, "env_token") - self.assertEqual(tool._repo, "env/repo") - self.assertIn("Authorization", mock_created_session.headers) - self.assertEqual(mock_created_session.headers["Authorization"], "token env_token") - - # Restore original env vars - if original_token is None: del os.environ["GITHUB_TOKEN"] - else: os.environ["GITHUB_TOKEN"] = original_token - if original_repo is None: del os.environ["GITHUB_REPOSITORY"] - else: os.environ["GITHUB_REPOSITORY"] = original_repo - - def test_init_raises_value_error_if_no_token(self): - original_token = os.environ.pop("GITHUB_TOKEN", None) - with self.assertRaisesRegex(ValueError, "GitHub token must be provided"): - GitHubTool(repo=self.test_repo, logger=self.logger) - if original_token: os.environ["GITHUB_TOKEN"] = original_token - - def test_init_raises_value_error_if_no_repo(self): - original_repo = os.environ.pop("GITHUB_REPOSITORY", None) - with self.assertRaisesRegex(ValueError, "GitHub repository.*must be provided"): - GitHubTool(token=self.test_token, logger=self.logger) - if original_repo: os.environ["GITHUB_REPOSITORY"] = original_repo - - def test_clear_resets_branch(self): - tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, initial_branch="feature-branch", logger=self.logger) - # Mock _get_branch_sha for _set_current_branch called by clear - with patch.object(tool, '_get_branch_sha', return_value="sha_for_main"): - tool.clear() - self.assertEqual(tool.current_branch, "main") - - def test_get_functions_returns_list(self): - tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) - functions = tool.get_functions() - self.assertIsInstance(functions, list) - self.assertTrue(len(functions) > 0) - self.assertIn("name", functions[0]["function"]) - - - # --- Test individual private methods --- - - def test_read_file_success(self): - tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) - file_content = "Hello World!" - encoded_content = base64.b64encode(file_content.encode('utf-8')).decode('utf-8') - self.mock_session.get.return_value = create_mock_response(200, json_data={"content": encoded_content}) - - result = tool._read_file(path="test.txt") - self.assertEqual(result, file_content) - self.mock_session.get.assert_called_once_with( - f"{tool.base_url}/repos/{self.test_repo}/contents/test.txt", - params={"ref": "main"} - ) - - def test_read_file_error(self): - tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) - self.mock_session.get.return_value = create_mock_response(404, text_data="Not Found") - result = tool._read_file(path="nonexistent.txt") - self.assertIn("Error reading file", result) - - def test_create_branch_success(self): - tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) - # Mock getting base branch SHA - self.mock_session.get.return_value = create_mock_response(200, json_data={"object": {"sha": "base_sha123"}}) - # Mock creating new branch - self.mock_session.post.return_value = create_mock_response(201, json_data={"ref": "refs/heads/new-feature"}) - - result = tool._create_branch(branch_name="new-feature", base_branch="main") - self.assertIn("Branch 'new-feature' created successfully", result) - self.assertEqual(tool.current_branch, "new-feature") - self.mock_session.get.assert_called_once() # For base branch SHA - self.mock_session.post.assert_called_once() # For creating branch - - def test_create_branch_base_sha_error(self): - tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) - self.mock_session.get.return_value = create_mock_response(404, text_data="Base branch not found") - result = tool._create_branch(branch_name="new-feature", base_branch="nonexistent-base") - self.assertIn("Error getting base branch SHA", result) - - def test_create_branch_creation_error(self): - tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) - self.mock_session.get.return_value = create_mock_response(200, json_data={"object": {"sha": "base_sha456"}}) - self.mock_session.post.return_value = create_mock_response(422, text_data="Validation failed") - result = tool._create_branch(branch_name="bad-branch", base_branch="main") - self.assertIn("Error creating branch", result) - - def test_commit_file_success_new_file(self): - tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) - tool.current_branch = "dev-branch" # Cannot commit to main by default - - # Mock GET for checking file existence (404 means new file) - self.mock_session.get.return_value = create_mock_response(404) - # Mock PUT for committing file - self.mock_session.put.return_value = create_mock_response(201, json_data={"commit": {"sha": "commit_sha_abc"}}) - - result = tool._commit_file(file_path="new_file.py", content="print('Hello')", commit_message="Add new_file.py") - self.assertIn("committed successfully", result) - self.assertIn("commit_sha_abc", result) - self.mock_session.get.assert_called_once() # Check file existence - self.mock_session.put.assert_called_once() # Commit file - - def test_commit_file_success_update_file(self): - tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) - tool.current_branch = "dev-branch" - - # Mock GET for checking file existence (200 means existing file) - self.mock_session.get.return_value = create_mock_response(200, json_data={"sha": "existing_file_sha"}) - # Mock PUT for committing file - self.mock_session.put.return_value = create_mock_response(200, json_data={"commit": {"sha": "commit_sha_def"}}) - - result = tool._commit_file(file_path="existing_file.txt", content="Updated content", commit_message="Update existing_file.txt") - self.assertIn("committed successfully", result) - self.assertIn("commit_sha_def", result) - args, kwargs = self.mock_session.put.call_args - self.assertEqual(kwargs['json']['sha'], "existing_file_sha") - - - def test_commit_file_to_main_branch_error(self): - tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) - tool.current_branch = "main" - result = tool._commit_file(file_path="some.txt", content="content", commit_message="msg") - self.assertIn("Action directly to main branch is not allowed", result) - - def test_create_pull_request_success(self): - tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) - tool.current_branch = "feature-pr" - pr_url = "https://example.com/pull/1" - self.mock_session.post.return_value = create_mock_response(201, json_data={"html_url": pr_url, "number": 1}) - - result = tool._create_pull_request(title="New Feature PR", body="Please review.", base="main") - self.assertIn(f"Pull request created successfully: {pr_url}", result) - self.mock_session.post.assert_called_once() - call_data = self.mock_session.post.call_args[1]['json'] - self.assertEqual(call_data['head'], "feature-pr") - self.assertEqual(call_data['base'], "main") - - def test_create_pull_request_same_branch_error(self): - tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) - tool.current_branch = "main" - result = tool._create_pull_request(title="PR to self", body="This should fail", base="main") - self.assertIn("Cannot create a pull request from branch 'main' to itself", result) - - - def test_list_files_success(self): - tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) - mock_items = [ - {"name": "file1.txt", "type": "file", "path": "dir/file1.txt"}, - {"name": "subdir", "type": "dir", "path": "dir/subdir"} - ] - self.mock_session.get.return_value = create_mock_response(200, json_data=mock_items) - - result = tool._list_files(path="dir") - self.assertEqual(len(result), 2) - self.assertEqual(result[0]["name"], "file1.txt") - self.assertEqual(result[1]["type"], "dir") - - def test_search_code_success(self): - tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) - mock_search_results = { - "items": [{"path": "src/code.py", "html_url": "url1"}] - } - self.mock_session.get.return_value = create_mock_response(200, json_data=mock_search_results) - - results = tool._search_code(query="my_function") - self.assertEqual(len(results), 1) - self.assertEqual(results[0]["path"], "src/code.py") - - def test_get_commit_history_success(self): - tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) - mock_commits = [{ - "sha": "sha1", "commit": {"message": "Msg1", "author": {"name": "Authy", "date": "Date1"}} - }] - self.mock_session.get.return_value = create_mock_response(200, json_data=mock_commits) - - commits = tool._get_commit_history(file_path="file.txt", num_commits=1) - self.assertEqual(len(commits), 1) - self.assertEqual(commits[0]["sha"], "sha1") - - def test_set_current_branch_success(self): - tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) - # Mock _get_branch_sha to simulate branch exists - with patch.object(tool, '_get_branch_sha', return_value="some_sha_for_dev"): - result = tool._set_current_branch(branch_name="dev") - self.assertEqual(tool.current_branch, "dev") - self.assertIn("Current branch set to: dev", result) - - def test_set_current_branch_not_exists(self): - tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) - with patch.object(tool, '_get_branch_sha', return_value="Error getting SHA for branch"): - result = tool._set_current_branch(branch_name="nonexistent-branch") - self.assertNotEqual(tool.current_branch, "nonexistent-branch") # Should not change - self.assertIn("Cannot set current branch", result) - - - def test_list_branches_single_page(self): - tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) - mock_branches = [{"name": "main"}, {"name": "dev"}] - self.mock_session.get.return_value = create_mock_response(200, json_data=mock_branches, links={}) # No "next" link - - branches = tool._list_branches(all_pages=True) - self.assertEqual(branches, ["main", "dev"]) - self.mock_session.get.assert_called_once() - - def test_list_branches_multiple_pages(self): - tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) - - # Page 1 response - page1_branches = [{"name": "branch1"}, {"name": "branch2"}] - next_url = f"{tool.base_url}/repos/{self.test_repo}/branches?page=2" - response1 = create_mock_response(200, json_data=page1_branches, links={"next": {"url": next_url}}) - - # Page 2 response - page2_branches = [{"name": "branch3"}] - response2 = create_mock_response(200, json_data=page2_branches, links={}) # No "next" link - - self.mock_session.get.side_effect = [response1, response2] - - branches = tool._list_branches(all_pages=True) - self.assertEqual(branches, ["branch1", "branch2", "branch3"]) - self.assertEqual(self.mock_session.get.call_count, 2) - - # Check that the second call used the next_url - calls = self.mock_session.get.call_args_list - self.assertEqual(calls[1][0][0], next_url) # args[0] is the URL - - # --- Test execute dispatcher --- - def test_execute_read_file(self): - tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) - with patch.object(tool, '_read_file', return_value="file content") as mock_method: - result = tool.execute(function_name="read_file", path="test.md") - mock_method.assert_called_once_with(path="test.md") - self.assertEqual(result, "file content") - - def test_execute_unknown_function(self): - tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) - result = tool.execute(function_name="non_existent_function_name", arg1="val1") - self.assertIn("Unknown function: non_existent_function_name", result) - - def test_execute_method_exception(self): - tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) - with patch.object(tool, '_read_file', side_effect=Exception("Kaboom")) as mock_method: - result = tool.execute(function_name="read_file", path="crash.txt") - self.assertIn("Error during read_file execution: Kaboom", result) - -if __name__ == '__main__': - unittest.main() diff --git a/tests/tools/test_log_tool.py b/tests/tools/test_log_tool.py deleted file mode 100644 index a765b06..0000000 --- a/tests/tools/test_log_tool.py +++ /dev/null @@ -1,146 +0,0 @@ -import unittest -from unittest.mock import patch, mock_open, MagicMock -import os -import logging -from datetime import datetime, timedelta - -# Ensure tools/log_tool.py is accessible -from tools.log_tool import LogTool - -class TestLogTool(unittest.TestCase): - - def setUp(self): - self.test_log_file_path = "test_dummy_log.log" - # Suppress logging output during tests unless explicitly testing for it - self.logger = logging.getLogger('tools.log_tool') - # Ensure only one NullHandler to prevent duplicate messages if tests run multiple times in a session - if not any(isinstance(h, logging.NullHandler) for h in self.logger.handlers): - self.logger.addHandler(logging.NullHandler()) - self.logger.propagate = False # Prevent propagation to root logger if it has handlers - - - def test_init_default_log_path(self): - tool = LogTool(logger=self.logger) - self.assertEqual(tool.configured_log_file_path, 'logs/output.log') - - def test_init_custom_log_path(self): - tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger) - self.assertEqual(tool.configured_log_file_path, self.test_log_file_path) - - def test_get_functions(self): - tool = LogTool(logger=self.logger) - functions = tool.get_functions() - self.assertIsInstance(functions, list) - self.assertEqual(len(functions), 1) - self.assertEqual(functions[0]["function"]["name"], "get_log_contents") - - @patch("os.path.exists", return_value=False) - def test_get_log_contents_file_not_exists(self, mock_exists): - tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger) - result = tool._get_log_contents() - self.assertIn("Log file does not exist", result) - mock_exists.assert_called_once_with(self.test_log_file_path) - - @patch("os.path.exists", return_value=True) - @patch("builtins.open", new_callable=mock_open, read_data="line1\nline2\nline3\nline4\nline5") - def test_get_log_contents_with_line_count(self, mock_file_open, mock_exists): - tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger) - - result = tool._get_log_contents(line_count=3) - self.assertEqual(result, "line3\nline4\nline5") - mock_exists.assert_called_once_with(self.test_log_file_path) - mock_file_open.assert_called_once_with(self.test_log_file_path, 'r', encoding='utf-8') - - @patch("os.path.exists", return_value=True) - @patch("builtins.open", new_callable=mock_open, read_data="line1\nline2\n") - def test_get_log_contents_line_count_more_than_available(self, mock_file_open, mock_exists): - tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger) - result = tool._get_log_contents(line_count=5) - self.assertEqual(result, "line1\nline2\n") - - @patch("os.path.exists", return_value=True) - @patch("builtins.open", new_callable=mock_open, read_data="line1\nline2\n") - def test_get_log_contents_invalid_line_count_uses_default(self, mock_file_open, mock_exists): - tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger) - # Test with zero, negative, and non-integer line_count (though type hint is int) - # The code defaults to 150 if invalid. Here, we only have 2 lines. - with patch.object(tool.logger, 'warning') as mock_log_warning: - result_zero = tool._get_log_contents(line_count=0) - self.assertEqual(result_zero, "line1\nline2\n") - mock_log_warning.assert_any_call("Invalid line_count '0' provided, defaulting to fetch last 150 lines.") - - mock_file_open.reset_mock() # Reset for next call - result_neg = tool._get_log_contents(line_count=-5) - self.assertEqual(result_neg, "line1\nline2\n") - mock_log_warning.assert_any_call("Invalid line_count '-5' provided, defaulting to fetch last 150 lines.") - - - @patch("os.path.exists", return_value=True) - def test_get_log_contents_last_24_hours(self, mock_exists): - tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger) - - now = datetime.now() - one_hour_ago_dt = now - timedelta(hours=1) - two_days_ago_dt = now - timedelta(days=2) - - one_hour_ago_str = one_hour_ago_dt.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT) - two_days_ago_str = two_days_ago_dt.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT) - - log_data = ( - f"{two_days_ago_str} - OLD - This is an old log entry.\n" - f"No timestamp here - this line should be skipped by time filter.\n" - f"{one_hour_ago_str} - RECENT - This is a recent log entry.\n" - f"Malformed Date 2023-xx-01 - Another skipped line.\n" - f"{now.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT)} - CURRENT - This is a current log entry.\n" - ) - - expected_output = ( - f"{one_hour_ago_str} - RECENT - This is a recent log entry.\n" - f"{now.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT)} - CURRENT - This is a current log entry.\n" - ) - - with patch("builtins.open", mock_open(read_data=log_data)): - result = tool._get_log_contents(line_count=None) # Trigger 24-hour logic - self.assertEqual(result, expected_output) - - @patch("os.path.exists", return_value=True) - @patch("builtins.open", side_effect=IOError("File read error!")) - def test_get_log_contents_file_read_exception(self, mock_file_open, mock_exists): - tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger) - result = tool._get_log_contents(line_count=10) - self.assertIn("An error occurred while reading the log file: File read error!", result) - - def test_execute_get_log_contents(self): - tool = LogTool(logger=self.logger) - mock_return_value = "Mocked log content" - with patch.object(tool, '_get_log_contents', return_value=mock_return_value) as mock_method: - result = tool.execute(function_name="get_log_contents", line_count=50) - mock_method.assert_called_once_with(line_count=50) - self.assertEqual(result, mock_return_value) - - def test_execute_get_log_contents_no_line_count(self): - tool = LogTool(logger=self.logger) - mock_return_value = "Mocked log content for 24h" - with patch.object(tool, '_get_log_contents', return_value=mock_return_value) as mock_method: - result = tool.execute(function_name="get_log_contents") # No line_count - mock_method.assert_called_once_with(line_count=None) # Expects None to trigger 24h - self.assertEqual(result, mock_return_value) - - - def test_execute_unknown_function(self): - tool = LogTool(logger=self.logger) - result = tool.execute(function_name="non_existent_log_function") - self.assertIn("Unknown function: non_existent_log_function", result) - - def test_clear_method(self): - tool = LogTool(logger=self.logger) - # Set a specific level for the logger for this test if needed to capture debug - original_level = tool.logger.level - tool.logger.setLevel(logging.DEBUG) - with self.assertLogs(tool.logger, level='DEBUG') as cm: - tool.clear() - self.assertTrue(any("LogTool clear called" in message for message in cm.output)) - tool.logger.setLevel(original_level) # Reset level - -if __name__ == '__main__': - unittest.main() diff --git a/tests/tools/test_metrics.py b/tests/tools/test_metrics.py deleted file mode 100644 index 902abb7..0000000 --- a/tests/tools/test_metrics.py +++ /dev/null @@ -1,217 +0,0 @@ -import unittest -from unittest.mock import patch, MagicMock, ANY -import time -import logging - -# Ensure tools.metrics is accessible -from tools.metrics import Metrics # Import the class itself for direct testing -from tools.metrics import metrics as global_metrics_instance # Import the global instance - -# A simple function to decorate for testing -def sample_function_for_metrics(duration=0.01): - # Simulate some work - # Note: time.sleep is not always precisely profiled by cProfile in the same way as CPU-bound work. - # For testing, we will mock the cProfile/pstats interaction rather than relying on actual sleep duration. - if duration > 0: # Make it conditional so we can test zero-time case too - pass # The actual work is not important when mocking cProfile results - return "sample_output" - -def another_sample_function(x, y): - return x + y - -class TestMetrics(unittest.TestCase): - - def setUp(self): - # Create a fresh Metrics instance for most tests to avoid interference - self.logger = logging.getLogger('tools.metrics.test') - if not self.logger.handlers: # Avoid adding handler multiple times - self.logger.addHandler(logging.NullHandler()) - self.metrics_instance = Metrics(logger=self.logger) - - # Clear the global instance before each test that might use it - global_metrics_instance.clear_metrics() - - def test_measure_decorator_counts_calls(self): - decorated_func = self.metrics_instance.measure(sample_function_for_metrics) - - self.assertEqual(self.metrics_instance.call_count["sample_function_for_metrics"], 0) - decorated_func() - self.assertEqual(self.metrics_instance.call_count["sample_function_for_metrics"], 1) - decorated_func() - self.assertEqual(self.metrics_instance.call_count["sample_function_for_metrics"], 2) - - @patch('cProfile.Profile') - @patch('pstats.Stats') - def test_measure_decorator_records_time(self, MockPStats, MockCProfile): - # Mock cProfile and pstats to control the time value - mock_profiler_instance = MockCProfile.return_value - mock_pstats_instance = MockPStats.return_value - - # Simulate that pstats.Stats.stats dictionary contains the function's stats - # Key: (filename, lineno, funcname) - # Value: (cc, nc, tt, ct, callers) where ct is cumulative_time (index 3) - - # Get code object of the function *before* decoration for correct key - original_func_code = sample_function_for_metrics.__code__ - func_key = (original_func_code.co_filename, original_func_code.co_firstlineno, original_func_code.co_name) - - # Configure mock_pstats_instance.stats to return our desired time - mock_pstats_instance.stats = {func_key: (1, 1, 0.05, 0.123, {})} # cc, nc, tt, ct=0.123 - - decorated_func = self.metrics_instance.measure(sample_function_for_metrics) - - self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0) - - # Call the decorated function - decorated_func(duration=0) # Duration arg doesn't matter due to mocking - - # Assertions - mock_profiler_instance.enable.assert_called_once() - mock_profiler_instance.disable.assert_called_once() - MockPStats.assert_called_once_with(mock_profiler_instance) - - self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0.123) - - # Call again to see accumulation - # Reset mock stats for a new time value if needed, or assume same time per call - mock_pstats_instance.stats = {func_key: (1, 1, 0.05, 0.100, {})} # New ct=0.100 - decorated_func(duration=0) - self.assertAlmostEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0.123 + 0.100) - - - @patch('cProfile.Profile') - @patch('pstats.Stats') - def test_measure_decorator_fallback_time_recording_by_name(self, MockPStats, MockCProfile): - mock_profiler_instance = MockCProfile.return_value - mock_pstats_instance = MockPStats.return_value - - original_func_code = sample_function_for_metrics.__code__ # func to be decorated - # Simulate the primary key lookup fails by creating a slightly different key for what we expect - # This is what the code will try to look up first. - expected_primary_key = (original_func_code.co_filename, original_func_code.co_firstlineno, original_func_code.co_name) - - # This is the key that will *actually* be in pstats.stats, simulating a mismatch for primary lookup - # but a match for the by-name fallback. - actual_stats_key_in_pstats = (original_func_code.co_filename, - original_func_code.co_firstlineno + 5, # simulate a lineno difference for primary key mismatch - original_func_code.co_name) # Name is the same for fallback - - mock_pstats_instance.stats = { - # expected_primary_key is NOT present - actual_stats_key_in_pstats: (1, 1, 0.03, 0.077, {}) # ct = 0.077 - } - - decorated_func = self.metrics_instance.measure(sample_function_for_metrics) - - # Expecting a debug log for fallback, but assertLogs needs the logger to have a handler that captures. - # self.logger is already set up with NullHandler. For this test, let's use a specific logger. - metrics_internal_logger = logging.getLogger('tools.metrics') # Logger used inside Metrics class - original_level = metrics_internal_logger.level - metrics_internal_logger.setLevel(logging.DEBUG) - - with self.assertLogs(metrics_internal_logger, level='DEBUG') as log_capture: - decorated_func(duration=0) - - metrics_internal_logger.setLevel(original_level) # Reset logger level - - self.assertTrue(any("Found stats for sample_function_for_metrics by name" in msg for msg in log_capture.output)) - self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0.077) - - - @patch('cProfile.Profile') - @patch('pstats.Stats') - def test_measure_decorator_handles_func_stats_not_found(self, MockPStats, MockCProfile): - mock_profiler_instance = MockCProfile.return_value - mock_pstats_instance = MockPStats.return_value - mock_pstats_instance.stats = {} # Empty stats, function will not be found - - decorated_func = self.metrics_instance.measure(sample_function_for_metrics) - - metrics_internal_logger = logging.getLogger('tools.metrics') - original_level = metrics_internal_logger.level - metrics_internal_logger.setLevel(logging.WARNING) - with self.assertLogs(metrics_internal_logger, level='WARNING') as log_capture: - decorated_func(duration=0) - metrics_internal_logger.setLevel(original_level) - - self.assertTrue(any("Could not find exact cProfile stats" in msg for msg in log_capture.output)) - self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0) - - - def test_get_metrics_empty(self): - self.assertEqual(self.metrics_instance.get_metrics(), {}) - - @patch('cProfile.Profile') - @patch('pstats.Stats') - def test_get_metrics_with_data(self, MockPStats, MockCProfile): - mock_pstats_instance = MockPStats.return_value - - # Decorate two different functions - decorated_func1 = self.metrics_instance.measure(sample_function_for_metrics) - decorated_func2 = self.metrics_instance.measure(another_sample_function) - - # Data for func1 - func1_code = sample_function_for_metrics.__code__ - func1_key = (func1_code.co_filename, func1_code.co_firstlineno, func1_code.co_name) - mock_pstats_instance.stats = {func1_key: (1,1,0.1,0.1,{})} - decorated_func1() - - # Data for func2 - func2_code = another_sample_function.__code__ - func2_key = (func2_code.co_filename, func2_code.co_firstlineno, func2_code.co_name) - mock_pstats_instance.stats = {func2_key: (1,1,0.2,0.2,{})} # Cumulative time 0.2 - decorated_func2(1,2) - mock_pstats_instance.stats = {func2_key: (1,1,0.3,0.3,{})} # Cumulative time 0.3 for second call - decorated_func2(3,4) - - metrics_data = self.metrics_instance.get_metrics() - - self.assertIn("sample_function_for_metrics", metrics_data) - self.assertEqual(metrics_data["sample_function_for_metrics"]["call_count"], 1) - self.assertEqual(metrics_data["sample_function_for_metrics"]["total_time"], 0.1) - self.assertEqual(metrics_data["sample_function_for_metrics"]["average_time"], 0.1) - - self.assertIn("another_sample_function", metrics_data) - self.assertEqual(metrics_data["another_sample_function"]["call_count"], 2) - self.assertAlmostEqual(metrics_data["another_sample_function"]["total_time"], 0.5) - self.assertAlmostEqual(metrics_data["another_sample_function"]["average_time"], 0.25) - - - def test_clear_metrics(self): - # Add some data - self.metrics_instance.call_count["test_func"] = 5 - self.metrics_instance.total_time["test_func"] = 1.234 - - self.metrics_instance.clear_metrics() - - self.assertEqual(self.metrics_instance.call_count, {}) - self.assertEqual(self.metrics_instance.total_time, {}) - self.assertEqual(self.metrics_instance.get_metrics(), {}) - - # Test the global instance - @patch('cProfile.Profile') - @patch('pstats.Stats') - def test_global_metrics_instance_usage(self, MockPStats, MockCProfile): - mock_pstats_instance = MockPStats.return_value - - # Decorate a function with the global instance - @global_metrics_instance.measure - def globally_decorated_func(): - return "global_output" - - # Setup mock stats for the globally decorated function - # Access __wrapped__ to get the original function if other decorators might be present or for consistency. - original_g_func = globally_decorated_func.__wrapped__ - func_code = original_g_func.__code__ - func_key = (func_code.co_filename, func_code.co_firstlineno, func_code.co_name) - mock_pstats_instance.stats = {func_key: (1,1,0.05,0.05,{})} - - globally_decorated_func() - - metrics_data = global_metrics_instance.get_metrics() - self.assertIn("globally_decorated_func", metrics_data) - self.assertEqual(metrics_data["globally_decorated_func"]["call_count"], 1) - self.assertEqual(metrics_data["globally_decorated_func"]["total_time"], 0.05) - -if __name__ == '__main__': - unittest.main() diff --git a/tests/tools/test_metrics_tool.py b/tests/tools/test_metrics_tool.py deleted file mode 100644 index 17c1b9d..0000000 --- a/tests/tools/test_metrics_tool.py +++ /dev/null @@ -1,161 +0,0 @@ -import unittest -from unittest.mock import MagicMock, patch -import logging - -# Ensure tools.metrics_tool and tools.metrics are accessible -from tools.metrics_tool import MetricsTool -from tools.metrics import Metrics # Used for typehinting and creating a mockable instance - -class TestMetricsTool(unittest.TestCase): - - def setUp(self): - self.mock_metrics_provider = MagicMock(spec=Metrics) - self.logger = logging.getLogger('tools.metrics_tool.test') - if not self.logger.handlers: - self.logger.addHandler(logging.NullHandler()) - self.logger.propagate = False - - - def test_init_with_provider(self): - tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) - self.assertEqual(tool.metrics_provider, self.mock_metrics_provider) - - @patch('tools.metrics_tool.global_metrics_instance') # Patch the global instance path - def test_init_default_provider(self, mock_global_metrics): - tool = MetricsTool(logger=self.logger) - self.assertEqual(tool.metrics_provider, mock_global_metrics) - - def test_get_functions(self): - tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) - functions = tool.get_functions() - self.assertIsInstance(functions, list) - self.assertTrue(len(functions) == 3) # Based on current definition - self.assertIn("get_function_metrics", [f["function"]["name"] for f in functions]) - self.assertIn("get_specific_function_metrics", [f["function"]["name"] for f in functions]) - self.assertIn("get_top_n_functions", [f["function"]["name"] for f in functions]) - - def test_execute_get_function_metrics(self): - tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) - expected_metrics = {"func1": {"call_count": 1, "total_time": 0.1}} - self.mock_metrics_provider.get_metrics.return_value = expected_metrics - - result = tool.execute(function_name="get_function_metrics") - - self.mock_metrics_provider.get_metrics.assert_called_once() - self.assertEqual(result, expected_metrics) - - def test_execute_get_specific_function_metrics_found(self): - tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) - func_metrics = {"call_count": 5, "total_time": 0.5, "average_time": 0.1} - all_metrics = {"specific_func": func_metrics, "other_func": {}} - self.mock_metrics_provider.get_metrics.return_value = all_metrics - - # The execute method expects kwargs that match the function parameters in get_functions. - # So, the argument name for the function to get is 'function_name' in the tool's spec. - result = tool.execute(function_name="get_specific_function_metrics", **{"function_name": "specific_func"}) - self.assertEqual(result, func_metrics) - - def test_execute_get_specific_function_metrics_not_found(self): - tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) - self.mock_metrics_provider.get_metrics.return_value = {"other_func": {}} - - result = tool.execute(function_name="get_specific_function_metrics", **{"function_name": "non_existent_func"}) - self.assertEqual(result, "No metrics found for function: non_existent_func") - - def test_execute_get_specific_function_metrics_missing_arg(self): - tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) - result = tool.execute(function_name="get_specific_function_metrics") # Missing function_name kwarg - self.assertIn("Error: Missing required argument 'function_name'", result) - - - def test_execute_get_top_n_functions(self): - tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) - metrics_data = { - "func_a": {"call_count": 1, "total_time": 0.3}, - "func_b": {"call_count": 1, "total_time": 0.1}, - "func_c": {"call_count": 1, "total_time": 0.5}, - "func_d": {"call_count": 1, "total_time": 0.2}, - } - self.mock_metrics_provider.get_metrics.return_value = metrics_data - - # Test getting top 2 - result = tool.execute(function_name="get_top_n_functions", n=2) - expected_top_2 = {"func_c": metrics_data["func_c"], "func_a": metrics_data["func_a"]} - self.assertEqual(result, expected_top_2) - - # Test getting top 1 - result_top_1 = tool.execute(function_name="get_top_n_functions", n=1) - expected_top_1 = {"func_c": metrics_data["func_c"]} - self.assertEqual(result_top_1, expected_top_1) - - # Test N larger than available functions - result_top_all = tool.execute(function_name="get_top_n_functions", n=10) - # Order should be func_c, func_a, func_d, func_b - expected_top_all_keys = ["func_c", "func_a", "func_d", "func_b"] - self.assertEqual(list(result_top_all.keys()), expected_top_all_keys) - - def test_execute_get_top_n_functions_malformed_metrics(self): - tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) - metrics_data = { - "func_a": {"call_count": 1, "total_time": 0.3}, - "func_b": "not a dict", # Malformed - "func_c": {"call_count": 1}, # Missing total_time - "func_d": {"call_count": 1, "total_time": 0.2}, - } - self.mock_metrics_provider.get_metrics.return_value = metrics_data - - metrics_tool_logger = logging.getLogger('tools.metrics_tool') - original_level = metrics_tool_logger.level - metrics_tool_logger.setLevel(logging.WARNING) - with self.assertLogs(metrics_tool_logger, level='WARNING') as log_capture: - result = tool.execute(function_name="get_top_n_functions", n=2) - metrics_tool_logger.setLevel(original_level) - - # Check that warnings were logged for malformed items - self.assertTrue(any("Metric item for 'func_b' is not in expected format" in msg for msg in log_capture.output)) - self.assertTrue(any("Metric item for 'func_c' is not in expected format" in msg for msg in log_capture.output)) - - # Expected: func_a, func_d (as they are valid and sortable) - expected_result = { - "func_a": metrics_data["func_a"], - "func_d": metrics_data["func_d"] - } - self.assertEqual(result, expected_result) - - - def test_execute_get_top_n_functions_invalid_n(self): - tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) - self.mock_metrics_provider.get_metrics.return_value = {} # No metrics needed for this test - - result_zero = tool.execute(function_name="get_top_n_functions", n=0) - self.assertIn("Error: Argument 'n' must be a positive integer.", result_zero) - - result_negative = tool.execute(function_name="get_top_n_functions", n=-1) - self.assertIn("Error: Argument 'n' must be a positive integer.", result_negative) - - result_string = tool.execute(function_name="get_top_n_functions", n="abc") - self.assertIn("Error: Argument 'n' must be an integer.", result_string) - - def test_execute_get_top_n_functions_missing_arg_n(self): - tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) - result = tool.execute(function_name="get_top_n_functions") # Missing n - self.assertIn("Error: Missing required argument 'n'.", result) - - - def test_execute_unknown_function(self): - tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) - result = tool.execute(function_name="non_existent_metrics_function") - self.assertIn("Unknown function: non_existent_metrics_function", result) - - def test_clear_method(self): - tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) - metrics_tool_logger = logging.getLogger('tools.metrics_tool') - original_level = metrics_tool_logger.level - metrics_tool_logger.setLevel(logging.DEBUG) - with self.assertLogs(metrics_tool_logger, level='DEBUG') as cm: - tool.clear() - metrics_tool_logger.setLevel(original_level) - self.assertTrue(any("MetricsTool clear method called" in message for message in cm.output)) - -if __name__ == '__main__': - unittest.main() diff --git a/tools/github_ci_tool.py b/tools/github_ci_tool.py index 715a321..e99341e 100644 --- a/tools/github_ci_tool.py +++ b/tools/github_ci_tool.py @@ -4,8 +4,7 @@ import zipfile import io import re import logging -from .base_tool import BaseTool # Added -from .metrics import metrics # Added +from .base_tool import BaseTool # Configure logging for the tool - This will be handled by the logger instance now # logger = logging.getLogger(__name__) # Commented out or removed @@ -70,7 +69,8 @@ class GitHubCIHelper(BaseTool): # Inherits from BaseTool }, "required": ["pull_request_number"] } - } + }, + "_tags": ["read"] }, { "type": "function", @@ -85,7 +85,8 @@ class GitHubCIHelper(BaseTool): # Inherits from BaseTool }, "required": ["pull_request_number"] } - } + }, + "_tags": ["read"] }, { "type": "function", @@ -100,7 +101,8 @@ class GitHubCIHelper(BaseTool): # Inherits from BaseTool }, "required": ["run_id"] } - } + }, + "_tags": ["read"] }, { "type": "function", @@ -114,7 +116,8 @@ class GitHubCIHelper(BaseTool): # Inherits from BaseTool }, "required": ["log_content"] } - } + }, + "_tags": ["read"] } ] diff --git a/tools/github_tool.py b/tools/github_tool.py index 7720fa8..85f0de3 100644 --- a/tools/github_tool.py +++ b/tools/github_tool.py @@ -1,6 +1,5 @@ # tools/github_tool.py from .base_tool import BaseTool -from .metrics import metrics import requests import os import base64 @@ -57,7 +56,8 @@ class GitHubTool(BaseTool): }, "required": ["path"] } - } + }, + "_tags": ["read"] }, { "type": "function", @@ -71,7 +71,8 @@ class GitHubTool(BaseTool): }, "required": ["path"] } - } + }, + "_tags": ["read"] }, { "type": "function", @@ -85,7 +86,8 @@ class GitHubTool(BaseTool): }, "required": ["query"] } - } + }, + "_tags": ["read"] }, { "type": "function", @@ -100,7 +102,8 @@ class GitHubTool(BaseTool): }, "required": ["branch_name"] } - } + }, + "_tags": ["write"] }, { "type": "function", @@ -116,7 +119,8 @@ class GitHubTool(BaseTool): }, "required": ["file_path", "commit_message", "content"] } - } + }, + "_tags": ["write"] }, { "type": "function", @@ -132,7 +136,8 @@ class GitHubTool(BaseTool): }, "required": ["title", "body"] } - } + }, + "_tags": ["write"] }, { "type": "function", @@ -147,7 +152,8 @@ class GitHubTool(BaseTool): }, "required": ["file_path"] } - } + }, + "_tags": ["read"] }, { "type": "function", @@ -162,7 +168,8 @@ class GitHubTool(BaseTool): }, "required": ["file_path"] } - } + }, + "_tags": ["read"] }, { "type": "function", @@ -176,7 +183,8 @@ class GitHubTool(BaseTool): }, "required": ["branch"] } - } + }, + "_tags": ["read"] }, { "type": "function", @@ -184,7 +192,8 @@ class GitHubTool(BaseTool): "name": "get_current_branch", "description": "Get the name of the current branch", "parameters": { "type": "object", "properties": {} } - } + }, + "_tags": ["read"] }, { "type": "function", @@ -198,7 +207,8 @@ class GitHubTool(BaseTool): }, "required": ["branch_name"] } - } + }, + "_tags": ["read", "write"] }, { "type": "function", @@ -213,7 +223,8 @@ class GitHubTool(BaseTool): }, "required": ["file_path", "commit_sha"] } - } + }, + "_tags": ["read"] }, { "type": "function", @@ -227,7 +238,8 @@ class GitHubTool(BaseTool): "all_pages": {"type": "boolean", "description": "Whether to fetch all pages of results", "default": True} } } - } + }, + "_tags": ["read"] }, { "type": "function", @@ -241,7 +253,8 @@ class GitHubTool(BaseTool): }, "required": ["pull_number"] } - } + }, + "_tags": ["write"] }, { "type": "function", @@ -255,7 +268,8 @@ class GitHubTool(BaseTool): }, "required": ["pull_number"] } - } + }, + "_tags": ["write"] }, { "type": "function", @@ -277,7 +291,8 @@ class GitHubTool(BaseTool): }, "required": ["pull_number"] } - } + }, + "_tags": ["write"] }, { "type": "function", @@ -291,7 +306,8 @@ class GitHubTool(BaseTool): }, "required": ["branch_name"] } - } + }, + "_tags": ["write"] }, { "type": "function", @@ -305,7 +321,8 @@ class GitHubTool(BaseTool): }, "required": ["issue_number"] } - } + }, + "_tags": ["read"] }, { "type": "function", @@ -325,7 +342,8 @@ class GitHubTool(BaseTool): }, "required": ["title", "body"] } - } + }, + "_tags": ["communicate"] }, { "type": "function", @@ -340,7 +358,8 @@ class GitHubTool(BaseTool): "page": {"type": "integer", "default": 1, "description": "Page number of the results to fetch"} } } - } + }, + "_tags": ["read"] }, { "type": "function", @@ -355,7 +374,8 @@ class GitHubTool(BaseTool): }, "required": ["issue_number", "comment"] } - } + }, + "_tags": ["communicate"] }, { "type": "function", @@ -369,7 +389,8 @@ class GitHubTool(BaseTool): }, "required": ["issue_number"] } - } + }, + "_tags": ["read"] }, { "type": "function", @@ -383,7 +404,8 @@ class GitHubTool(BaseTool): }, "required": ["pull_number"] } - } + }, + "_tags": ["read"] }, { "type": "function", @@ -398,7 +420,8 @@ class GitHubTool(BaseTool): }, "required": ["name"] } - } + }, + "_tags": ["communicate"] }, { "type": "function", @@ -413,7 +436,8 @@ class GitHubTool(BaseTool): }, "required": ["project_id", "column_name"] } - } + }, + "_tags": ["communicate"] }, { "type": "function", @@ -428,7 +452,8 @@ class GitHubTool(BaseTool): }, "required": ["column_id", "note"] } - } + }, + "_tags": ["communicate"] }, { "type": "function", @@ -444,7 +469,8 @@ class GitHubTool(BaseTool): }, "required": ["card_id", "position", "column_id"] } - } + }, + "_tags": ["communicate"] }, { "type": "function", @@ -460,7 +486,8 @@ class GitHubTool(BaseTool): }, "required": ["card_id", "content_id", "content_type"] } - } + }, + "_tags": ["communicate"] }, { "type": "function", @@ -468,7 +495,8 @@ class GitHubTool(BaseTool): "name": "list_project_boards", "description": "List project boards associated with the repository", "parameters": { "type": "object", "properties": {} } - } + }, + "_tags": ["communicate"] }, { "type": "function", @@ -482,7 +510,8 @@ class GitHubTool(BaseTool): }, "required": ["project_id"] } - } + }, + "_tags": ["communicate"] }, { "type": "function", @@ -496,7 +525,8 @@ class GitHubTool(BaseTool): }, "required": ["pull_number"] } - } + }, + "_tags": ["read"] }, { "type": "function", @@ -510,7 +540,8 @@ class GitHubTool(BaseTool): }, "required": ["pull_number"] } - } + }, + "_tags": ["read"] }, { "type": "function", @@ -524,7 +555,8 @@ class GitHubTool(BaseTool): }, "required": ["pull_number"] } - } + }, + "_tags": ["read"] }, { "type": "function", @@ -545,7 +577,8 @@ class GitHubTool(BaseTool): }, "required": ["pull_number", "body", "commit_id", "path", "position"] } - } + }, + "_tags": ["communicate"] }, { "type": "function", @@ -559,7 +592,8 @@ class GitHubTool(BaseTool): }, "required": ["pull_number"] } - } + }, + "_tags": ["read"] }, { "type": "function", @@ -575,11 +609,11 @@ class GitHubTool(BaseTool): }, "required": ["pull_number", "event"] } - } + }, + "_tags": ["communicate"] } ] - @metrics.measure def execute(self, function_name, **kwargs): self.logger.info(f"Executing GitHub Tool function: {function_name} with args: {kwargs}") # Dispatch to the appropriate private method @@ -598,7 +632,6 @@ class GitHubTool(BaseTool): # Private methods for each function, using self.session for HTTP requests - @metrics.measure def _read_file(self, path): self.logger.info(f"Reading file: {path} from branch: {self.current_branch}") url = f"{self.base_url}/repos/{self._repo}/contents/{path}" @@ -613,7 +646,6 @@ class GitHubTool(BaseTool): self.logger.error(error_message) return error_message - @metrics.measure def _create_branch(self, branch_name, base_branch="main"): self.logger.info(f"Creating branch: {branch_name} from base: {base_branch}") # Get SHA of base branch @@ -639,7 +671,6 @@ class GitHubTool(BaseTool): self.logger.error(error_message) return error_message - @metrics.measure def _commit_file(self, file_path, content, commit_message): self.logger.info(f"Committing file: {file_path} to branch: {self.current_branch} with message: '{commit_message}'") if self.current_branch == "main": @@ -679,7 +710,6 @@ class GitHubTool(BaseTool): self.logger.error(error_message) return error_message - @metrics.measure def _create_pull_request(self, title, body, base="main"): self.logger.info(f"Creating pull request: '{title}' from branch '{self.current_branch}' to '{base}'") if self.current_branch == base: @@ -701,7 +731,6 @@ class GitHubTool(BaseTool): self.logger.error(error_message) return error_message - @metrics.measure def _get_branch_sha(self, branch): self.logger.info(f"Getting SHA for branch: {branch}") url = f"{self.base_url}/repos/{self._repo}/git/refs/heads/{branch}" @@ -715,7 +744,6 @@ class GitHubTool(BaseTool): self.logger.error(error_message) return error_message - @metrics.measure def _list_files(self, path): self.logger.info(f"Listing files in path: '{path}' on branch: '{self.current_branch}'") url = f"{self.base_url}/repos/{self._repo}/contents/{path.strip('/')}" # Ensure no leading/trailing slashes for consistency @@ -738,7 +766,6 @@ class GitHubTool(BaseTool): self.logger.error(error_message) return error_message - @metrics.measure def _search_code(self, query): self.logger.info(f"Searching code with query: '{query}' in repo: '{self._repo}'") url = f"{self.base_url}/search/code" @@ -754,7 +781,6 @@ class GitHubTool(BaseTool): self.logger.error(error_message) return error_message - @metrics.measure def _get_commit_history(self, file_path, num_commits=10): self.logger.info(f"Getting last {num_commits} commit(s) for file: '{file_path}' on branch '{self.current_branch}'") url = f"{self.base_url}/repos/{self._repo}/commits" @@ -775,18 +801,15 @@ class GitHubTool(BaseTool): self.logger.error(error_message) return error_message - @metrics.measure def _view_commit_details_for_file(self, file_path, num_commits=10): # This function is essentially the same as get_commit_history based on its description. self.logger.info(f"Viewing commit details for file '{file_path}' (last {num_commits} commits) - using _get_commit_history.") return self._get_commit_history(file_path, num_commits) - @metrics.measure def _get_current_branch(self): self.logger.info(f"Current branch is: {self.current_branch}") return self.current_branch - @metrics.measure def _set_current_branch(self, branch_name): self.logger.info(f"Attempting to set current branch to: {branch_name}") # Check if branch exists by trying to get its SHA @@ -801,7 +824,6 @@ class GitHubTool(BaseTool): self.logger.info(success_message) return success_message - @metrics.measure def _get_file_at_commit(self, file_path, commit_sha): self.logger.info(f"Getting file '{file_path}' at commit SHA: {commit_sha}") url = f"{self.base_url}/repos/{self._repo}/contents/{file_path}" @@ -816,7 +838,6 @@ class GitHubTool(BaseTool): self.logger.error(error_message) return error_message - @metrics.measure def _list_branches(self, per_page=100, all_pages=True): self.logger.info(f"Listing branches for repo '{self._repo}'. Per_page={per_page}, All_pages={all_pages}") url = f"{self.base_url}/repos/{self._repo}/branches" @@ -844,7 +865,6 @@ class GitHubTool(BaseTool): self.logger.info(f"Successfully listed {len(branches_list)} branches.") return branches_list - @metrics.measure def _approve_pull_request(self, pull_number): self.logger.info(f"Approving pull request #{pull_number}") url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/reviews" @@ -859,7 +879,6 @@ class GitHubTool(BaseTool): self.logger.error(error_message) return error_message - @metrics.measure def _close_pull_request(self, pull_number): self.logger.info(f"Closing pull request #{pull_number}") url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}" @@ -874,7 +893,6 @@ class GitHubTool(BaseTool): self.logger.error(error_message) return error_message - @metrics.measure def _merge_pull_request(self, pull_number, commit_title="Merge pull request", commit_message="", merge_method="merge"): self.logger.info(f"Merging pull request #{pull_number} using method '{merge_method}'") url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/merge" @@ -897,7 +915,6 @@ class GitHubTool(BaseTool): self.logger.error(error_message) return error_message - @metrics.measure def _delete_branch(self, branch_name): self.logger.info(f"Deleting branch: {branch_name}") if branch_name == "main" or (hasattr(self, 'default_branch') and branch_name == self.default_branch) : @@ -920,7 +937,6 @@ class GitHubTool(BaseTool): self.logger.error(error_message) return error_message - @metrics.measure def _get_issue_details(self, issue_number): self.logger.info(f"Getting details for issue #{issue_number}") url = f"{self.base_url}/repos/{self._repo}/issues/{issue_number}" @@ -933,7 +949,6 @@ class GitHubTool(BaseTool): self.logger.error(error_message) return error_message - @metrics.measure def _create_issue(self, title, body, labels=None): self.logger.info(f"Creating new issue with title: '{title}'") url = f"{self.base_url}/repos/{self._repo}/issues" @@ -953,7 +968,6 @@ class GitHubTool(BaseTool): self.logger.error(error_message) return error_message - @metrics.measure def _list_issues(self, state="open", per_page=30, page=1): self.logger.info(f"Listing issues with state: {state}, per_page: {per_page}, page: {page}") url = f"{self.base_url}/repos/{self._repo}/issues" @@ -969,7 +983,6 @@ class GitHubTool(BaseTool): self.logger.error(error_message) return error_message - @metrics.measure def _add_issue_comment(self, issue_number, comment): self.logger.info(f"Adding comment to issue #{issue_number}: '{comment[:50]}...'") url = f"{self.base_url}/repos/{self._repo}/issues/{issue_number}/comments" @@ -985,7 +998,6 @@ class GitHubTool(BaseTool): self.logger.error(error_message) return error_message - @metrics.measure def _get_issue_comments(self, issue_number): self.logger.info(f"Getting comments for issue #{issue_number}") url = f"{self.base_url}/repos/{self._repo}/issues/{issue_number}/comments" @@ -1000,14 +1012,12 @@ class GitHubTool(BaseTool): self.logger.error(error_message) return error_message - @metrics.measure def _get_pull_request_general_comments(self, pull_number): self.logger.info(f"Getting general comments for pull request #{pull_number}") # In GitHub API, PR comments (general, not review comments on lines) are issue comments. # The PR is also an issue, so use the issue comments endpoint. return self._get_issue_comments(issue_number=pull_number) - @metrics.measure def _create_project_board(self, name, body=None): self.logger.info(f"Creating project board: '{name}'") url = f"{self.base_url}/repos/{self._repo}/projects" @@ -1026,7 +1036,6 @@ class GitHubTool(BaseTool): self.logger.error(error_message) return error_message - @metrics.measure def _create_project_column(self, project_id, column_name): self.logger.info(f"Creating column '{column_name}' for project ID: {project_id}") url = f"{self.base_url}/projects/{project_id}/columns" @@ -1044,7 +1053,6 @@ class GitHubTool(BaseTool): self.logger.error(error_message) return error_message - @metrics.measure def _create_project_card(self, column_id, note=None, content_id=None, content_type=None): self.logger.info(f"Creating card in column ID: {column_id}") url = f"{self.base_url}/projects/columns/{column_id}/cards" @@ -1075,7 +1083,6 @@ class GitHubTool(BaseTool): self.logger.error(error_message) return error_message - @metrics.measure def _move_project_card(self, card_id, position, column_id=None): self.logger.info(f"Moving card ID: {card_id} to position: {position}" + (f" in column ID: {column_id}" if column_id else "")) url = f"{self.base_url}/projects/columns/cards/{card_id}/moves" @@ -1100,7 +1107,6 @@ class GitHubTool(BaseTool): # For updating an existing card to link an issue, one would PATCH the card's content_id/content_type. # Let's assume the function intends to update an existing card if it's a separate function. # However, the provided API spec for `link_issue_to_project_card` uses PATCH on card_id, so let's implement that. - @metrics.measure def _link_issue_to_project_card(self, card_id, content_id, content_type): self.logger.info(f"Linking content_id {content_id} (type: {content_type}) to card_id {card_id}") url = f"{self.base_url}/projects/cards/{card_id}" # Note: API docs suggest /projects/columns/cards/{card_id} or /projects/cards/{card_id} @@ -1120,7 +1126,6 @@ class GitHubTool(BaseTool): self.logger.error(error_message) return error_message - @metrics.measure def _list_project_boards(self): self.logger.info(f"Listing project boards for repo: {self._repo}") url = f"{self.base_url}/repos/{self._repo}/projects" @@ -1136,7 +1141,6 @@ class GitHubTool(BaseTool): self.logger.error(error_message) return error_message - @metrics.measure def _view_project_board_items(self, project_id): self.logger.info(f"Viewing items for project ID: {project_id}") columns_url = f"{self.base_url}/projects/{project_id}/columns" @@ -1165,7 +1169,6 @@ class GitHubTool(BaseTool): self.logger.info(f"Successfully retrieved items for project ID: {project_id}.") return project_items - @metrics.measure def _get_pull_request_details(self, pull_number): self.logger.info(f"Getting details for PR #{pull_number}") url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}" @@ -1178,7 +1181,6 @@ class GitHubTool(BaseTool): self.logger.error(error_message) return error_message - @metrics.measure def _get_pull_request_diff(self, pull_number): self.logger.info(f"Getting diff for PR #{pull_number}") url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}" @@ -1193,7 +1195,6 @@ class GitHubTool(BaseTool): self.logger.error(error_message) return error_message - @metrics.measure def _get_pull_request_files(self, pull_number): self.logger.info(f"Getting files for PR #{pull_number}") url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/files" @@ -1206,7 +1207,6 @@ class GitHubTool(BaseTool): self.logger.error(error_message) return error_message - @metrics.measure def _create_pull_request_review_comment(self, pull_number, body, commit_id, path, position, side="RIGHT", start_line=None, start_side=None): self.logger.info(f"Creating review comment on PR #{pull_number}, file '{path}', position {position}") url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/comments" @@ -1225,7 +1225,6 @@ class GitHubTool(BaseTool): self.logger.error(error_message) return error_message - @metrics.measure def _list_pull_request_review_comments(self, pull_number): self.logger.info(f"Listing review comments for PR #{pull_number}") url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/comments" @@ -1238,7 +1237,6 @@ class GitHubTool(BaseTool): self.logger.error(error_message) return error_message - @metrics.measure def _submit_pull_request_review(self, pull_number, event, body=None): self.logger.info(f"Submitting '{event}' review for PR #{pull_number}") url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/reviews" diff --git a/tools/log_tool.py b/tools/log_tool.py index bcadf04..c51ef5c 100644 --- a/tools/log_tool.py +++ b/tools/log_tool.py @@ -1,6 +1,5 @@ # tools/log_tool.py from .base_tool import BaseTool -from .metrics import metrics import logging import os from datetime import datetime, timedelta @@ -44,7 +43,6 @@ class LogTool(BaseTool): } ] - @metrics.measure def execute(self, function_name, **kwargs): self.logger.info(f"Executing LogTool function: {function_name} with args: {kwargs}") if function_name == "get_log_contents": @@ -55,7 +53,6 @@ class LogTool(BaseTool): self.logger.error(error_message) return error_message - @metrics.measure def _get_log_contents(self, line_count=None): # Default line_count is None to trigger 24h logic if not specified self.logger.info(f"Attempting to get log contents from: {self.configured_log_file_path}. Line count: {line_count if line_count is not None else 'Last 24 hours'}") diff --git a/tools/metrics.py b/tools/metrics.py deleted file mode 100644 index 54cb228..0000000 --- a/tools/metrics.py +++ /dev/null @@ -1,79 +0,0 @@ -# tools/metrics.py -import cProfile -import pstats -import io -from functools import wraps -from collections import defaultdict -import logging - -class Metrics: - def __init__(self, logger=None): - self.call_count = defaultdict(int) - self.total_time = defaultdict(float) - self.logger = logger if logger else logging.getLogger(__name__) - if not self.logger.handlers: - self.logger.addHandler(logging.NullHandler()) - self.logger.debug("Metrics instance initialized.") - - def measure(self, func): - @wraps(func) - def wrapper(*args, **kwargs): - self.call_count[func.__name__] += 1 - - pr = cProfile.Profile() - pr.enable() - result = func(*args, **kwargs) - pr.disable() - - ps = pstats.Stats(pr) - - func_code = func.__code__ - func_key_tuple = (func_code.co_filename, func_code.co_firstlineno, func_code.co_name) - - time_spent_for_func = 0.0 - if func_key_tuple in ps.stats: - time_spent_for_func = ps.stats[func_key_tuple][3] # [3] is cumulative time (ct) - else: - # Fallback: try to find by function name if exact key fails (e.g. due to decorators changing code object details slightly) - # This is less precise and might pick up other functions if names are not unique across files. - found_by_name = False - for key, stat in ps.stats.items(): - if key[2] == func.__name__: # key[2] is function name - time_spent_for_func = stat[3] # cumulative time - self.logger.debug(f"Found stats for {func.__name__} by name {key} after primary key failed.") - found_by_name = True - break - if not found_by_name: - self.logger.warning( - f"Could not find exact cProfile stats for {func.__name__} with key {func_key_tuple} or by name. " - f"Time for this call will be recorded as 0. This might occur for non-Python functions or due to complex decorators." - ) - - self.total_time[func.__name__] += time_spent_for_func - self.logger.debug(f"Measured cumulative time for {func.__name__}: {time_spent_for_func:.6f}s") - - return result - return wrapper - - def get_metrics(self): - metrics_data = {} - for func_name in self.call_count: - count = self.call_count[func_name] - total_t = self.total_time[func_name] - metrics_data[func_name] = { - 'call_count': count, - 'total_time': round(total_t, 6), - 'average_time': round(total_t / count, 6) if count > 0 else 0 - } - return metrics_data - - def clear_metrics(self): - self.call_count.clear() - self.total_time.clear() - self.logger.info("Metrics cleared.") - -# Global instance for convenience -_metrics_instance_logger = logging.getLogger(__name__ + ".global_instance") -if not _metrics_instance_logger.handlers: - _metrics_instance_logger.addHandler(logging.NullHandler()) -metrics = Metrics(logger=_metrics_instance_logger) diff --git a/tools/metrics_tool.py b/tools/metrics_tool.py deleted file mode 100644 index 91d4664..0000000 --- a/tools/metrics_tool.py +++ /dev/null @@ -1,128 +0,0 @@ -# tools/metrics_tool.py -from .base_tool import BaseTool -from .metrics import metrics as global_metrics_instance # For default and measuring execute -from .metrics import Metrics # For type hinting and potentially creating a new one if needed -import logging - -class MetricsTool(BaseTool): - def __init__(self, metrics_provider: Metrics | None = None, logger: logging.Logger | None = None): - self.metrics_provider = metrics_provider if metrics_provider is not None else global_metrics_instance - self.logger = logger if logger else logging.getLogger(__name__) - if not self.logger.handlers: - self.logger.addHandler(logging.NullHandler()) - self.logger.debug(f"MetricsTool initialized. Using metrics provider: {self.metrics_provider}") - - def clear(self): - # This tool itself doesn't hold state that needs clearing beyond what its metrics_provider might do. - # If this tool were responsible for clearing the metrics it reports on, it would call: - # self.metrics_provider.clear_metrics() - self.logger.debug("MetricsTool clear method called. No local state to clear.") - pass - - def get_functions(self): - return [ - { - "type": "function", - "function": { - "name": "get_function_metrics", - "description": "Get metrics for all measured functions.", - "parameters": { - "type": "object", - "properties": {}, - "required": [] - } - } - }, - { - "type": "function", - "function": { - "name": "get_specific_function_metrics", - "description": "Get metrics for a specific function.", - "parameters": { - "type": "object", - "properties": { - "function_name": { - "type": "string", - "description": "Name of the function to get metrics for" - } - }, - "required": ["function_name"] - } - } - }, - { - "type": "function", - "function": { - "name": "get_top_n_functions", - "description": "Get the top N functions by total execution time.", - "parameters": { - "type": "object", - "properties": { - "n": { - "type": "integer", - "description": "Number of top functions to retrieve" - } - }, - "required": ["n"] - } - } - } - ] - - @global_metrics_instance.measure # The execute method can be measured by the global instance - def execute(self, function_name, **kwargs): - self.logger.info(f"Executing MetricsTool function: {function_name} with args: {kwargs}") - if function_name == "get_function_metrics": - return self._get_function_metrics() - elif function_name == "get_specific_function_metrics": - func_name_arg = kwargs.get("function_name") - if func_name_arg is None: # Check if None, as empty string could be a valid (though unlikely) func name - self.logger.warning("'function_name' argument is missing for get_specific_function_metrics.") - return "Error: Missing required argument 'function_name'." - return self._get_specific_function_metrics(str(func_name_arg)) # Ensure string - elif function_name == "get_top_n_functions": - n_arg = kwargs.get("n") - if n_arg is None: - self.logger.warning("'n' argument is missing for get_top_n_functions.") - return "Error: Missing required argument 'n'." - try: - n_val = int(n_arg) - if n_val <= 0: - self.logger.warning(f"'n' argument must be a positive integer, got {n_val}.") - return "Error: Argument 'n' must be a positive integer." - return self._get_top_n_functions(n_val) - except ValueError: - self.logger.warning(f"'n' argument must be an integer, got '{n_arg}'.") - return "Error: Argument 'n' must be an integer." - else: - error_message = f"Unknown function: {function_name}" - self.logger.error(error_message) - return error_message - - def _get_function_metrics(self): - self.logger.debug("Calling metrics_provider.get_metrics() for all functions.") - return self.metrics_provider.get_metrics() - - def _get_specific_function_metrics(self, function_to_get): - self.logger.debug(f"Getting metrics for specific function: {function_to_get}") - all_metrics = self.metrics_provider.get_metrics() - return all_metrics.get(function_to_get, f"No metrics found for function: {function_to_get}") - - def _get_top_n_functions(self, n): - self.logger.debug(f"Getting top {n} functions by total execution time.") - all_metrics = self.metrics_provider.get_metrics() - # Ensure that the items are actual metric dicts before trying to access 'total_time' - valid_metrics_items = [] - for name, metric_values in all_metrics.items(): - if isinstance(metric_values, dict) and 'total_time' in metric_values: - valid_metrics_items.append((name, metric_values)) - else: - self.logger.warning(f"Metric item for '{name}' is not in expected format: {metric_values}") - - # Sort items by total_time. items() gives list of (func_name, metrics_dict) - try: - sorted_metrics = sorted(valid_metrics_items, key=lambda item: item[1]['total_time'], reverse=True) - return dict(sorted_metrics[:n]) - except TypeError as e: - self.logger.error(f"Error sorting metrics, possibly due to unexpected data types: {e}", exc_info=True) - return "Error: Could not sort metrics due to unexpected data." diff --git a/tools/standalone_llm_tool.py_test b/tools/standalone_llm_tool.py similarity index 95% rename from tools/standalone_llm_tool.py_test rename to tools/standalone_llm_tool.py index 1ac3db4..92ecf11 100644 --- a/tools/standalone_llm_tool.py_test +++ b/tools/standalone_llm_tool.py @@ -28,7 +28,7 @@ class StandaloneLLMTool(BaseTool): "model": { "type": "string", "description": "The model to use for generating the detailed instructions. Use mini for most coding tasks, preview when needing sophisticated reasoning", - "enum": ["o1-mini", "o1-preview"], + "enum": ["mini", "max"], "default": "o1-mini" }, "max_tokens": { @@ -38,7 +38,8 @@ class StandaloneLLMTool(BaseTool): }, "required": ["prompt"] } - } + }, + "_tags": ["llm", "external"] } ]