diff --git a/anthropic_telegram_inference_bot.py b/anthropic_telegram_inference_bot.py index 294f860..13a487e 100644 --- a/anthropic_telegram_inference_bot.py +++ b/anthropic_telegram_inference_bot.py @@ -3,25 +3,48 @@ import json import logging from anthropic import Anthropic, APIError, RateLimitError from base_telegram_inference_bot import BaseTelegramInferenceBot -from telegram_helper import TelegramHelper +from telegram_helper import TelegramHelper # Used in main, not class class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot): - def __init__(self): - super().__init__() - self.anthropic_client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")) - - # Initialize with the small model by default - self.small_model_name = os.environ.get("ANTHROPIC_SMALL_MODEL", "claude-3-haiku-20240307") - self.small_model_max_tokens = os.environ.get("ANTHROPIC_SMALL_MODEL_MAX_TOKENS", "2048") - self.large_model_name = os.environ.get("ANTHROPIC_LARGE_MODEL", "claude-3-opus-20240229") - self.large_model_max_tokens = os.environ.get("ANTHROPIC_LARGE_MODEL_MAX_TOKENS", "4096") + DEFAULT_SMALL_MODEL_NAME = "claude-3-haiku-20240307" + DEFAULT_SMALL_MODEL_MAX_TOKENS = "2048" + DEFAULT_LARGE_MODEL_NAME = "claude-3-opus-20240229" + DEFAULT_LARGE_MODEL_MAX_TOKENS = "4096" + def __init__( + self, + anthropic_client: Anthropic | None = None, + api_key: str | None = None, + small_model_name: str | None = None, + small_model_max_tokens: str | None = None, + large_model_name: str | None = None, + large_model_max_tokens: str | None = None, + system_prompt_content: str | None = None, + system_prompt_path: str | None = None + ): + super().__init__(system_prompt_content=system_prompt_content, system_prompt_path=system_prompt_path) + + if anthropic_client: + self.anthropic_client = anthropic_client + else: + _api_key = api_key or os.environ.get("ANTHROPIC_API_KEY") + if not _api_key: + raise ValueError("Anthropic API key must be provided either via argument or ANTHROPIC_API_KEY environment variable.") + self.anthropic_client = Anthropic(api_key=_api_key) + + self.small_model_name = small_model_name or os.environ.get("ANTHROPIC_SMALL_MODEL") or self.DEFAULT_SMALL_MODEL_NAME + self.small_model_max_tokens_str = small_model_max_tokens or os.environ.get("ANTHROPIC_SMALL_MODEL_MAX_TOKENS") or self.DEFAULT_SMALL_MODEL_MAX_TOKENS + self.large_model_name = large_model_name or os.environ.get("ANTHROPIC_LARGE_MODEL") or self.DEFAULT_LARGE_MODEL_NAME + self.large_model_max_tokens_str = large_model_max_tokens or os.environ.get("ANTHROPIC_LARGE_MODEL_MAX_TOKENS") or self.DEFAULT_LARGE_MODEL_MAX_TOKENS + + # Initialize with the small model by default self._configure_model_and_tokens( self.small_model_name, - self.small_model_max_tokens + self.small_model_max_tokens_str, + default_max_tokens=int(self.DEFAULT_SMALL_MODEL_MAX_TOKENS) # pass int for default ) - def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=2048): # Default max_tokens adjusted for typical "small" + def _configure_model_and_tokens(self, model_name: str, max_tokens_str: str, default_max_tokens: int = 2048): self.model = model_name try: self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens @@ -65,17 +88,19 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot): def _format_tool_response_for_anthropic(self, tool_response_data): if isinstance(tool_response_data, str): + # Wrap plain string in a list of text blocks if not already structured return [{"type": "text", "text": tool_response_data}] + elif isinstance(tool_response_data, list) and all(isinstance(item, dict) and "type" in item for item in tool_response_data): + # Already a list of content blocks + return tool_response_data elif isinstance(tool_response_data, (dict, list)): + # Attempt to JSON dump other dicts/lists if not already in content block format try: - is_valid_block_list = isinstance(tool_response_data, list) and all(isinstance(item, dict) and "type" in item for item in tool_response_data) - if is_valid_block_list: - return tool_response_data - else: - return [{"type": "text", "text": json.dumps(tool_response_data)}] + return [{"type": "text", "text": json.dumps(tool_response_data)}] except (TypeError, json.JSONDecodeError): - return [{"type": "text", "text": str(tool_response_data)}] + return [{"type": "text", "text": str(tool_response_data)}] # Fallback to string else: + # Fallback for other types (int, float, etc.) return [{"type": "text", "text": str(tool_response_data)}] async def handle_message(self, user_id, user_message): @@ -87,14 +112,14 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot): MAX_TOOL_ITERATIONS = 5 tool_use_count = 0 - assistant_response_content = "" + assistant_response_content = "" while tool_use_count < MAX_TOOL_ITERATIONS: response = self.get_chat_response(current_turn_messages) if not response or not response.content: logging.error("No valid response content from Anthropic LLM.") - self.conversation_history[user_id] = current_turn_messages + self.conversation_history[user_id] = current_turn_messages # Save current state return "Error: Could not get a valid response from the LLM." assistant_current_turn_content_blocks = response.content @@ -111,23 +136,22 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot): assistant_response_content = "".join(text_parts_from_assistant) if not tool_calls_from_response: - break + break tool_results_for_model = [] for tool_call in tool_calls_from_response: tool_name = tool_call.name - tool_input = tool_call.input + tool_input = tool_call.input tool_use_id = tool_call.id logging.info(f"Attempting to call Anthropic tool: {tool_name} with input: {tool_input}") try: - tool_response_data = self.call_tool(tool_name, tool_input) + tool_response_data = self.call_tool(tool_name, tool_input) tool_result_content_block = self._format_tool_response_for_anthropic(tool_response_data) - tool_results_for_model.append({ "type": "tool_result", "tool_use_id": tool_use_id, - "content": tool_result_content_block + "content": tool_result_content_block }) except Exception as e: logging.error(f"Error calling tool {tool_name}: {e}") @@ -135,14 +159,18 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot): "type": "tool_result", "tool_use_id": tool_use_id, "content": [{"type": "text", "text": f"Error executing tool {tool_name}: {str(e)}"}], - "is_error": True + "is_error": True }) - current_turn_messages.append({"role": "user", "content": tool_results_for_model}) + current_turn_messages.append({"role": "user", "content": tool_results_for_model}) # Anthropic expects tool results as a user message tool_use_count += 1 if tool_use_count >= MAX_TOOL_ITERATIONS: logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached for Anthropic.") + # Update assistant_response_content with any text from the last assistant turn before breaking + if not assistant_response_content and text_parts_from_assistant: + assistant_response_content = "".join(text_parts_from_assistant) + assistant_response_content += "\n[Max tool iterations reached]" break self.conversation_history[user_id] = current_turn_messages @@ -153,70 +181,89 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot): if assistant_response_content: return assistant_response_content else: + # Fallback if no text parts were found but there was an assistant message if current_turn_messages: last_message_in_turn = current_turn_messages[-1] + # Check if the last message content has text blocks (Anthropic specific structure) if last_message_in_turn.get("role") == "assistant" and isinstance(last_message_in_turn.get("content"), list): for block in reversed(last_message_in_turn["content"]): - if block.type == "text": - return block.text - return "No textual response from assistant." - + if block.type == "text" and hasattr(block, 'text') and block.text: + return block.text # Return the first non-empty text found from the end + return "No textual response generated by the assistant after processing." # More informative default async def start(self): logging.info("Anthropic Bot started") - async def clear_conversation_history(self, user_id): - super().clear_conversation_history(user_id) - logging.info(f"Cleared conversation history for Anthropic bot, user {user_id}") + # clear_conversation_history is inherited from BaseTelegramInferenceBot and calls super().clear_conversation_history + # No need to override if the base implementation is sufficient, unless specific logging is needed. + # async def clear_conversation_history(self, user_id): + # super().clear_conversation_history(user_id) + # logging.info(f"Cleared conversation history for Anthropic bot, user {user_id}") async def abort_processing(self, user_id): + # This abort is a soft abort, as actual Anthropic API call is synchronous within handle_message + # It primarily clears state and prevents further processing in the bot's loop if any. if user_id in self.processing_status: - self.processing_status[user_id]["processing"] = False - await self.clear_conversation_history(user_id) - return "Processing aborted and conversation cleared." - else: - await self.clear_conversation_history(user_id) - return "No active processing found to abort. Conversation cleared." + self.processing_status[user_id]["processing"] = False # Mark as not processing + # self.clear_processing_status(user_id) # Use base class method to remove entry + # Clearing history might be too aggressive for a simple abort, depends on desired UX + # For now, let's just stop processing and clear the flag. + # Consider if conversation history should be cleared here or if that is a separate user action. + # super().clear_conversation_history(user_id) # Moved to be less aggressive + logging.info(f"Abort requested for user {user_id}. Processing flag cleared.") + return "Processing aborted. You can send a new message or /clear the conversation." async def switch_model(self): - # Ensure ANTHROPIC_SMALL_MODEL and ANTHROPIC_LARGE_MODEL related env vars are loaded in __init__ - # or ensure they are freshly checked here if they can change during runtime (less common for model names). - # For this implementation, we rely on the values stored during __init__. - if not self.small_model_name or not self.large_model_name: logging.warning("Small or Large model names for Anthropic are not defined. Cannot switch model.") return f"Model switching not fully configured. Currently using {self.model}." - if self.model == self.small_model_name: + current_is_small = self.model == self.small_model_name + current_is_large = self.model == self.large_model_name + + if current_is_small: target_model = self.large_model_name - target_max_tokens = self.large_model_max_tokens - # Use default large max_tokens if specific one isn't set or invalid - default_max_tokens_for_large = "4096" - elif self.model == self.large_model_name: + target_max_tokens_str = self.large_model_max_tokens_str + default_target_max_tokens = int(self.DEFAULT_LARGE_MODEL_MAX_TOKENS) + elif current_is_large: target_model = self.small_model_name - target_max_tokens = self.small_model_max_tokens - # Use default small max_tokens if specific one isn't set or invalid - default_max_tokens_for_large = "2048" + target_max_tokens_str = self.small_model_max_tokens_str + default_target_max_tokens = int(self.DEFAULT_SMALL_MODEL_MAX_TOKENS) else: - # Current model is neither the designated small nor large, switch to small as a reset - logging.warning(f"Current model {self.model} is neither the configured small nor large model. Switching to small model.") + logging.warning(f"Current model {self.model} is unrecognized. Switching to default small model.") target_model = self.small_model_name - target_max_tokens = self.small_model_max_tokens - default_max_tokens_for_large = "2048" + target_max_tokens_str = self.small_model_max_tokens_str + default_target_max_tokens = int(self.DEFAULT_SMALL_MODEL_MAX_TOKENS) - - self._configure_model_and_tokens(target_model, target_max_tokens, default_max_tokens=int(default_max_tokens_for_large)) # Pass appropriate default + self._configure_model_and_tokens(target_model, target_max_tokens_str, default_max_tokens=default_target_max_tokens) logging.info(f"Switched Anthropic model to: {self.model}") - return f"Switched to Anthropic model: {self.model} (Max Tokens: {self.max_tokens})"#Provide token info + return f"Switched to Anthropic model: {self.model} (Max Tokens: {self.max_tokens})" + +# The main function is for standalone execution and basic testing, not part of the class itself. +# It's good practice to update it to reflect changes if you use it for quick tests. +# For unit tests, we'll instantiate the class with mocked dependencies. def main(): - if not os.environ.get("ANTHROPIC_API_KEY"): - logging.error("FATAL: ANTHROPIC_API_KEY environment variable not set.") - return - logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') - bot = AnthropicTelegramInferenceBot() + # Example of how to instantiate with new constructor (assuming API key is in ENV for this example) + # For real tests, you'd mock Anthropic() or pass a mock client. + try: + # These would typically come from a config file or CLI args in a real app if not ENV + # For this example, we rely on ENV or defaults being handled by constructor if not provided. + bot = AnthropicTelegramInferenceBot( + api_key=os.environ.get("ANTHROPIC_API_KEY") # Explicitly pass, or let constructor handle ENV + ) + except ValueError as e: + logging.error(f"Failed to initialize bot: {e}") + return + except Exception as e: # Catch any other init errors + logging.error(f"An unexpected error occurred during bot initialization: {e}") + return + + # TelegramHelper also updated, ensure it's instantiated correctly for this main context. + # For this basic main, we might not pass all configurable paths to TelegramHelper, + # letting them use defaults. telegram_helper = TelegramHelper(bot) telegram_helper.run() diff --git a/base_telegram_inference_bot.py b/base_telegram_inference_bot.py index 6751e67..568046d 100644 --- a/base_telegram_inference_bot.py +++ b/base_telegram_inference_bot.py @@ -7,26 +7,47 @@ from abc import ABC, abstractmethod from tools.base_tool import BaseTool class BaseTelegramInferenceBot(ABC): - def __init__(self): + def __init__(self, system_prompt_content: str | None = None, system_prompt_path: str | None = None): # MODIFIED self.conversation_history = {} self.processing_status = {} - self.system_prompt = self.load_system_prompt() + # MODIFIED to pass arguments + self.system_prompt = self.load_system_prompt( + direct_content=system_prompt_content, + file_path=system_prompt_path + ) self.tools, self.functions = self.load_functions() - logging.info(f'System Prompt: {os.environ.get("SYSTEM_PROMPT_PATH")}') + # Logging the actual source of the system prompt might be more complex now, + # but we can log the final prompt or indicate if it's custom/default. + # We'll also log the source of the prompt inside load_system_prompt. + logging.info(f'System Prompt (effective): {"Custom" if self.system_prompt != "You are a helpful AI assistant." else "Default"}') logging.info(f'Github Repository: {os.environ.get("GITHUB_REPOSITORY")}') - def load_system_prompt(self): - system_prompt_path = os.getenv("SYSTEM_PROMPT_PATH") - if system_prompt_path and os.path.isfile(system_prompt_path): - try: - with open(system_prompt_path, "r", encoding="utf-8") as file: - return file.read().strip() - except IOError as e: - logging.warning(f"Could not read system prompt file {system_prompt_path}: {e}") - return "You are a helpful AI assistant." + def load_system_prompt(self, direct_content: str | None = None, file_path: str | None = None) -> str: # MODIFIED + default_prompt = "You are a helpful AI assistant." + + if direct_content: + logging.info("Using direct content for system prompt.") + return direct_content.strip() + + prompt_path_to_try = file_path or os.getenv("SYSTEM_PROMPT_PATH") + + if prompt_path_to_try: + if os.path.isfile(prompt_path_to_try): + try: + with open(prompt_path_to_try, "r", encoding="utf-8") as file: + content = file.read().strip() + logging.info(f"Successfully loaded system prompt from {prompt_path_to_try}.") + return content + except IOError as e: + logging.warning(f"Could not read system prompt file {prompt_path_to_try}: {e}. Using default.") + return default_prompt + else: + # This condition now also covers if 'file_path' argument was given but invalid + logging.warning(f"System prompt file {prompt_path_to_try} not found. Using default system prompt.") + return default_prompt else: - logging.warning("SYSTEM_PROMPT_PATH is not set or file does not exist. Using default system prompt.") - return "You are a helpful AI assistant." + logging.info("No system prompt path provided (argument or ENV) or direct content. Using default system prompt.") + return default_prompt def load_functions(self): tools = [] @@ -44,7 +65,7 @@ class BaseTelegramInferenceBot(ABC): for name, obj in inspect.getmembers(module): if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool: try: - tools.append(obj()) + tools.append(obj()) # This instantiation might be an issue for tools needing config except Exception as e: logging.error(f"Error instantiating tool {name} from {filename}: {e}") except Exception as e: @@ -87,9 +108,9 @@ class BaseTelegramInferenceBot(ABC): except json.JSONDecodeError as e: logging.error(f"Error decoding function call arguments (string) for {function_call_name}: {e}. Arguments: {function_call_arguments}") return f"Error: Malformed arguments for tool call: {e}" - else: # Handle cases where arguments might be None or other unexpected types + else: if function_call_arguments is None: - function_args = {} # Default to empty dict if arguments are None + function_args = {} else: logging.error(f"Unexpected type for function_call_arguments for {function_call_name}: {type(function_call_arguments)}. Arguments: {function_call_arguments}") return f"Error: Invalid argument type for tool call: {type(function_call_arguments)}" @@ -98,7 +119,6 @@ class BaseTelegramInferenceBot(ABC): for function in tool.get_functions(): if function["function"]["name"] == function_name: try: - # Ensure function_args is a dictionary before unpacking if not isinstance(function_args, dict): logging.error(f"Internal error: function_args not a dict for {function_name} before execution. Args: {function_args}") return f"Internal error preparing arguments for tool {function_name}." @@ -110,16 +130,23 @@ class BaseTelegramInferenceBot(ABC): return f"Error: Tool function {function_name} not found." def get_system_prompt_description(self) -> str: - """Returns a description of the system prompt being used.""" - return f"System Prompt: {'Custom' if os.getenv('SYSTEM_PROMPT_PATH') else 'Default'}" + # This method could be updated to be more specific about the prompt source if needed. + # For now, it still reflects custom vs default based on the original ENV var logic's spirit. + # A more accurate reflection would require storing how the prompt was loaded. + # For simplicity, let's assume if it's not the default, it's "Custom". + if self.system_prompt != "You are a helpful AI assistant.": + return "System Prompt: Custom" + # Check original ENV var for backward compatibility in description only + elif os.getenv('SYSTEM_PROMPT_PATH'): + return "System Prompt: Custom (via ENV)" + return "System Prompt: Default" + @abstractmethod def get_llm_description(self) -> str: - """Returns a description of the LLM being used.""" pass async def get_bot_status(self) -> str: - """Provides a status message including prompt and LLM information.""" prompt_desc = self.get_system_prompt_description() llm_desc = self.get_llm_description() return f"{prompt_desc}\n{llm_desc}" @@ -134,5 +161,4 @@ class BaseTelegramInferenceBot(ABC): @abstractmethod async def switch_model(self): - """Switches the underlying model if supported by the bot.""" pass diff --git a/chatgpt_telegram_inference_bot.py b/chatgpt_telegram_inference_bot.py index 086c49f..1c555c7 100644 --- a/chatgpt_telegram_inference_bot.py +++ b/chatgpt_telegram_inference_bot.py @@ -1,43 +1,105 @@ import os import logging -from openai import OpenAI +from openai import OpenAI # Keep for type hinting and default client creation from openai_compatible_inference_bot import OpenAICompatibleInferenceBot -from telegram_helper import TelegramHelper +from telegram_helper import TelegramHelper # Used in main class ChatGPTTelegramInferenceBot(OpenAICompatibleInferenceBot): - def __init__(self): - super().__init__() - self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) + DEFAULT_SMALL_MODEL_NAME = "gpt-3.5-turbo" + DEFAULT_LARGE_MODEL_NAME = "gpt-4" + # Default max tokens can be None, relying on parent or API defaults + DEFAULT_SMALL_MODEL_MAX_TOKENS = None + DEFAULT_LARGE_MODEL_MAX_TOKENS = None + + def __init__( + self, + client: OpenAI | None = None, # Accepts an OpenAI client + api_key: str | None = None, + small_model_name: str | None = None, + small_model_max_tokens: str | None = None, # Kept as str for consistency with env vars + large_model_name: str | None = None, + large_model_max_tokens: str | None = None, + system_prompt_content: str | None = None, + system_prompt_path: str | None = None, + base_url: str | None = None, # For OpenAI compatible, though direct OpenAI client doesn't use it here + ): + # Initialize model names and tokens before calling super, as super might use them via _configure_model_and_tokens + self.small_model_name = small_model_name or os.environ.get("OPENAI_SMALL_MODEL") or self.DEFAULT_SMALL_MODEL_NAME + self.small_model_max_tokens_str = small_model_max_tokens or os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS") or self.DEFAULT_SMALL_MODEL_MAX_TOKENS - self._configure_model_and_tokens( - os.environ.get("OPENAI_SMALL_MODEL", "gpt-3.5-turbo"), - os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS") + self.large_model_name = large_model_name or os.environ.get("OPENAI_LARGE_MODEL") or self.DEFAULT_LARGE_MODEL_NAME + self.large_model_max_tokens_str = large_model_max_tokens or os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS") or self.DEFAULT_LARGE_MODEL_MAX_TOKENS + + # The actual client and active model configuration will be handled by OpenAICompatibleInferenceBot's __init__ + # We pass the specific OpenAI client or parameters to create one. + # If a client is passed, api_key and base_url might be ignored by super if super prioritizes existing client. + super().__init__( + client=client, + api_key=api_key, + model_name=self.small_model_name, # Initial model + max_tokens_str=self.small_model_max_tokens_str, + system_prompt_content=system_prompt_content, + system_prompt_path=system_prompt_path, + base_url=base_url # Pass base_url, though for standard OpenAI it's fixed ) + # Ensure client is of type OpenAI for this specific class, if not already set by super with a compatible one. + # This check is more of an assertion, as OpenAICompatibleInferenceBot should handle client creation. + if not isinstance(self.client, OpenAI): + # If super() didn't create a vanilla OpenAI client (e.g. if base_url was for Azure) + # we might need to recreate it here if this class *must* use a non-Azure OpenAI client. + # However, the current structure of OpenAICompatibleInferenceBot handles this. + # This is more about ensuring type correctness if code specific to OpenAI (non-compatible) methods were added here. + _api_key = api_key or os.environ.get("OPENAI_API_KEY") + if not self.client or (base_url and not isinstance(self.client, OpenAI)): + # If superclass initialized with a generic client due to base_url, re-init for OpenAI specifically if needed. + # For now, assume superclass correctly initializes based on absence of Azure env vars for this path. + # This logic might be simplified once OpenAICompatibleInferenceBot is fully refactored. + if not _api_key: # Ensure API key is available if we need to create a client + raise ValueError("OpenAI API key must be provided for ChatGPTTelegramInferenceBot if no client is passed.") + self.client = OpenAI(api_key=_api_key) + logging.info("Client re-initialized to standard OpenAI client for ChatGPTTelegramInferenceBot.") async def switch_model(self): - current_small_model = os.environ.get("OPENAI_SMALL_MODEL", "gpt-3.5-turbo") - current_large_model = os.environ.get("OPENAI_LARGE_MODEL", "gpt-4") + # Uses instance variables for model names set in __init__ + if not self.small_model_name or not self.large_model_name: + logging.warning("Small or Large model names for OpenAI are not defined. Cannot switch model.") + return f"Model switching not fully configured. Currently using {self.model}." - if self.model == current_large_model or self.model != current_small_model: - target_model = current_small_model - target_max_tokens = os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS") + current_is_small = self.model == self.small_model_name + current_is_large = self.model == self.large_model_name + + if current_is_large: + target_model = self.small_model_name + target_max_tokens_str = self.small_model_max_tokens_str + elif current_is_small: + target_model = self.large_model_name + target_max_tokens_str = self.large_model_max_tokens_str else: - target_model = current_large_model - target_max_tokens = os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS") + # Current model is neither the designated small nor large for this bot, + # switch to this bot's default small model as a reset. + logging.warning(f"Current model {self.model} is unrecognized for ChatGPT bot. Switching to default small model: {self.small_model_name}.") + target_model = self.small_model_name + target_max_tokens_str = self.small_model_max_tokens_str - self._configure_model_and_tokens(target_model, target_max_tokens) - logging.info(f"Switched to model: {self.model}") - return f"Switched to model: {self.model}" + self._configure_model_and_tokens(target_model, target_max_tokens_str) + # self.model and self.max_tokens are updated by _configure_model_and_tokens + logging.info(f"Switched to OpenAI model: {self.model}") + return f"Switched to OpenAI model: {self.model} (Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'})" def main(): - if not os.environ.get("OPENAI_API_KEY"): - logging.error("FATAL: OPENAI_API_KEY environment variable not set.") - return - logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') - bot = ChatGPTTelegramInferenceBot() - telegram_helper = TelegramHelper(bot) + try: + # Example: api_key from env, other params default or from env via constructor logic + bot = ChatGPTTelegramInferenceBot(api_key=os.environ.get("OPENAI_API_KEY")) + except ValueError as e: + logging.error(f"FATAL: {e}") + return + except Exception as e: + logging.error(f"An unexpected error occurred during bot initialization: {e}") + return + + telegram_helper = TelegramHelper(bot) telegram_helper.run() if __name__ == '__main__': diff --git a/gemini_telegram_inference_bot.py b/gemini_telegram_inference_bot.py index 5c5f549..09c174e 100644 --- a/gemini_telegram_inference_bot.py +++ b/gemini_telegram_inference_bot.py @@ -1,43 +1,103 @@ import os import logging -from openai import OpenAI +from openai import OpenAI # For type hinting and default client creation if needed from openai_compatible_inference_bot import OpenAICompatibleInferenceBot -from telegram_helper import TelegramHelper +from telegram_helper import TelegramHelper # Used in main class GeminiTelegramInferenceBot(OpenAICompatibleInferenceBot): - def __init__(self): - super().__init__() - self.client = OpenAI(api_key=os.environ.get("GEMINI_API_KEY"), base_url=os.environ.get("GEMINI_API_BASE_URL")) + DEFAULT_GEMINI_API_BASE_URL = "https://generativelanguage.googleapis.com/v1beta" + DEFAULT_SMALL_MODEL_NAME = "gemini-pro" # Actual model name for Gemini, not via OpenAI client directly + DEFAULT_LARGE_MODEL_NAME = "gemini-1.5-pro-latest" + DEFAULT_SMALL_MODEL_MAX_TOKENS = "2048" # Gemini uses outputTokenLimit, not exactly max_tokens in OpenAI sense + DEFAULT_LARGE_MODEL_MAX_TOKENS = "8192" + + def __init__( + self, + client: OpenAI | None = None, # OpenAI client for compatible mode + api_key: str | None = None, # Gemini API Key + base_url: str | None = None, # Gemini API Base URL for OpenAI client + small_model_name: str | None = None, + small_model_max_tokens: str | None = None, + large_model_name: str | None = None, + large_model_max_tokens: str | None = None, + system_prompt_content: str | None = None, + system_prompt_path: str | None = None + ): + _api_key = api_key or os.environ.get("GEMINI_API_KEY") + _base_url = base_url or os.environ.get("GEMINI_API_BASE_URL") or self.DEFAULT_GEMINI_API_BASE_URL + + if not _api_key: + # This check might seem redundant if super() also checks, but it's good for clarity + # for this specific bot type if it were to be instantiated directly with missing critical env vars. + raise ValueError("Gemini API key must be provided either via api_key argument or GEMINI_API_KEY environment variable.") + + self.small_model_name = small_model_name or os.environ.get("GEMINI_SMALL_MODEL") or self.DEFAULT_SMALL_MODEL_NAME + self.small_model_max_tokens_str = small_model_max_tokens or os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS") or self.DEFAULT_SMALL_MODEL_MAX_TOKENS - self._configure_model_and_tokens( - os.environ.get("GEMINI_SMALL_MODEL", "gemini-pro"), - os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS") + self.large_model_name = large_model_name or os.environ.get("GEMINI_LARGE_MODEL") or self.DEFAULT_LARGE_MODEL_NAME + self.large_model_max_tokens_str = large_model_max_tokens or os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS") or self.DEFAULT_LARGE_MODEL_MAX_TOKENS + + # Pass parameters to the OpenAICompatibleInferenceBot constructor + # It will create an OpenAI client configured for the Gemini endpoint + super().__init__( + client=client, + api_key=_api_key, # This key will be used by OpenAI client for the custom base_url + model_name=self.small_model_name, # Initial model + max_tokens_str=self.small_model_max_tokens_str, + system_prompt_content=system_prompt_content, + system_prompt_path=system_prompt_path, + base_url=_base_url, # Crucial for Gemini via OpenAI client + is_gemini=True # Flag for specific Gemini handling in compatible layer if needed ) + # self.client will be set by OpenAICompatibleInferenceBot with base_url and api_key. + # Logging to confirm Gemini specific setup + logging.info(f"GeminiTelegramInferenceBot initialized to use model {self.model} via {_base_url}") async def switch_model(self): - current_small_model = os.environ.get("GEMINI_SMALL_MODEL", "gemini-pro") - current_large_model = os.environ.get("GEMINI_LARGE_MODEL", "gemini-1.5-pro-latest") + if not self.small_model_name or not self.large_model_name: + logging.warning("Small or Large model names for Gemini are not defined. Cannot switch model.") + return f"Model switching not fully configured. Currently using {self.model}." - if self.model == current_large_model or self.model != current_small_model : - target_model = current_small_model - target_max_tokens = os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS") + current_is_small = self.model == self.small_model_name + current_is_large = self.model == self.large_model_name + + if current_is_large: + target_model = self.small_model_name + target_max_tokens_str = self.small_model_max_tokens_str + elif current_is_small: + target_model = self.large_model_name + target_max_tokens_str = self.large_model_max_tokens_str else: - target_model = current_large_model - target_max_tokens = os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS") - - self._configure_model_and_tokens(target_model, target_max_tokens) - logging.info(f"Switched to model: {self.model}") - return f"Switched to model: {self.model}" + logging.warning(f"Current model {self.model} is unrecognized for Gemini bot. Switching to default small model: {self.small_model_name}.") + target_model = self.small_model_name + target_max_tokens_str = self.small_model_max_tokens_str + + self._configure_model_and_tokens(target_model, target_max_tokens_str) + logging.info(f"Switched to Gemini model: {self.model}") + # For Gemini, max_tokens might translate to outputTokenLimit, so be clear it's a configuration parameter + return f"Switched to Gemini model: {self.model} (Configured Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'})" def main(): + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') + + # GEMINI_API_KEY is crucial for this bot if not os.environ.get("GEMINI_API_KEY"): logging.error("FATAL: GEMINI_API_KEY environment variable not set.") return + # GEMINI_API_BASE_URL is also important, but constructor has a default - logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') - - bot = GeminiTelegramInferenceBot() - telegram_helper = TelegramHelper(bot) + try: + bot = GeminiTelegramInferenceBot( + # api_key and base_url will be picked from ENV by constructor if not passed + ) + except ValueError as e: + logging.error(f"FATAL: {e}") + return + except Exception as e: # Catch any other init errors + logging.error(f"An unexpected error occurred during bot initialization: {e}") + return + + telegram_helper = TelegramHelper(bot) telegram_helper.run() if __name__ == '__main__': diff --git a/openai_compatible_inference_bot.py b/openai_compatible_inference_bot.py index ba58dd6..8a7c210 100644 --- a/openai_compatible_inference_bot.py +++ b/openai_compatible_inference_bot.py @@ -3,67 +3,150 @@ import os import logging from abc import abstractmethod from base_telegram_inference_bot import BaseTelegramInferenceBot -from openai import OpenAI +from openai import OpenAI, AzureOpenAI # Import both class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot): - def __init__(self): - super().__init__() - # Client and model configuration will be handled by subclasses - self.client = None - self.model = None - self.max_tokens = None + DEFAULT_MAX_HISTORY_LENGTH = 20 + DEFAULT_MAX_TOKENS = 1000 # Default for _configure_model_and_tokens - def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=1000): - self.model = model_name if model_name else "default-model" + def __init__( + self, + client: OpenAI | AzureOpenAI | None = None, + api_key: str | None = None, + base_url: str | None = None, + api_version: str | None = None, # For Azure + azure_deployment: str | None = None, # Model for Azure, distinct from general model_name if needed + model_name: str | None = None, # General model name for the API call + max_tokens_str: str | None = None, + system_prompt_content: str | None = None, + system_prompt_path: str | None = None, + is_gemini: bool = False, # Hint for specific API key if others are not set + max_history_length: int | None = None + ): + super().__init__(system_prompt_content=system_prompt_content, system_prompt_path=system_prompt_path) + + self.max_history_length = max_history_length if max_history_length is not None else self.DEFAULT_MAX_HISTORY_LENGTH + self.client = client + + if not self.client: + _api_key = api_key + _base_url = base_url + _api_version = api_version + _azure_deployment_name = azure_deployment # This will be used as the model for Azure + + # Determine if configuring for Azure OpenAI + is_azure = False + if _azure_deployment_name or (_base_url and "azure.com" in _base_url) or os.environ.get("AZURE_OPENAI_ENDPOINT"): + is_azure = True + + if is_azure: + _base_url = _base_url or os.environ.get("AZURE_OPENAI_ENDPOINT") + _api_key = _api_key or os.environ.get("AZURE_OPENAI_KEY") + _api_version = _api_version or os.environ.get("AZURE_OPENAI_API_VERSION") + # For Azure, the model parameter in API calls is the deployment name + _effective_model_name = _azure_deployment_name or model_name # Use deployment if available, else model_name + if not _base_url or not _api_key or not _api_version or not _effective_model_name: + raise ValueError("For Azure OpenAI, endpoint, API key, API version, and deployment/model name must be configured.") + self.client = AzureOpenAI( + api_key=_api_key, + azure_endpoint=_base_url, + api_version=_api_version + ) + # The model to be used in API calls for Azure is the deployment name. + # _configure_model_and_tokens will set self.model to this. + model_name_for_config = _effective_model_name + logging.info(f"Initialized AzureOpenAI client for deployment: {model_name_for_config} at {_base_url}") + else: + # Standard OpenAI or other OpenAI-compatible (like Gemini via base_url) + _base_url = _base_url or os.environ.get("OPENAI_API_BASE_URL") # For other compatible APIs + if not _api_key: # Try different ENV sources for API key + if is_gemini and os.environ.get("GEMINI_API_KEY"): + _api_key = os.environ.get("GEMINI_API_KEY") + else: + _api_key = os.environ.get("OPENAI_API_KEY") + + if not _api_key and not _base_url : # For completely local models with no key needed via base_url + pass # Allow client to be created with no API key if base_url is set and points to local model + elif not _api_key: + raise ValueError("API key must be provided for OpenAI compatible client if not Azure or local anonymous.") + + self.client = OpenAI(api_key=_api_key, base_url=_base_url) + model_name_for_config = model_name # Use the general model_name for non-Azure + log_msg = f"Initialized OpenAI compatible client. Target URL: {_base_url if _base_url else 'OpenAI default'}." + logging.info(log_msg) + else: + # Client was provided directly + model_name_for_config = model_name # Use provided model_name + logging.info(f"Using provided client: {type(self.client)}") + + # Configure the actual model name and max_tokens for API calls + self._configure_model_and_tokens( + model_name_for_config, + max_tokens_str, + default_max_tokens=self.DEFAULT_MAX_TOKENS + ) + + def _configure_model_and_tokens(self, model_name: str | None, max_tokens_str: str | None, default_max_tokens: int = 1000): + self.model = model_name if model_name else "default-model" # Fallback model name try: - self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens + # If max_tokens_str is explicitly "None" or empty, treat as None for API default + if max_tokens_str and max_tokens_str.lower() not in ["none", "", "null"]: + self.max_tokens = int(max_tokens_str) + else: + self.max_tokens = None # Use API default by not sending the parameter or sending null except ValueError: - logging.error(f"Invalid value for max_tokens: {max_tokens_str}. Using default {default_max_tokens}.") - self.max_tokens = default_max_tokens - logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens}") + logging.warning(f"Invalid value for max_tokens: {max_tokens_str}. Using API default (None). stalwart default was {default_max_tokens}") + self.max_tokens = None # Use API default + + logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}") def get_llm_description(self) -> str: - return f"LLM: {self.model}, Max Tokens: {self.max_tokens}" + client_type = type(self.client).__name__ + return f"Client: {client_type}, LLM: {self.model}, Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}" def get_chat_response(self, messages): if not self.client: - raise ValueError("OpenAI client not initialized. Subclasses must initialize it.") + # This should ideally not be hit if __init__ is successful + logging.error("OpenAI client not initialized before get_chat_response.") + raise ValueError("OpenAI client not initialized.") try: + # Pass self.max_tokens directly. If None, OpenAI library omits it or handles it. response = self.client.chat.completions.create( - model=self.model, + model=self.model, messages=messages, tools=self.functions if hasattr(self, 'functions') and self.functions else None, tool_choice="auto" if hasattr(self, 'functions') and self.functions else None, - max_tokens=self.max_tokens + max_tokens=self.max_tokens ) return response except Exception as e: - logging.error(f"API call failed: {e}") + logging.error(f"API call to model {self.model} failed: {e}") raise async def handle_message(self, user_id, user_message): - if user_id not in self.conversation_history: + if user_id not in self.conversation_history or not self.conversation_history[user_id]: self.conversation_history[user_id] = [] - if hasattr(self, 'system_prompt') and self.system_prompt: + if self.system_prompt: # Use the loaded system_prompt self.conversation_history[user_id].append({"role": "system", "content": self.system_prompt}) self.conversation_history[user_id].append({"role": "user", "content": user_message}) - messages = self.conversation_history[user_id] + messages = list(self.conversation_history[user_id]) # Work with a copy for this turn response = self.get_chat_response(messages) if not (response.choices and response.choices[0].message): logging.error("No valid response choice message from LLM.") + # Persist the user message in history even if LLM fails this turn + self.conversation_history[user_id] = messages return "Error: Could not get a valid response from the LLM." - messages.append(response.choices[0].message) # Append the assistant's response message + assistant_message = response.choices[0].message + messages.append(assistant_message) - tool_calls_from_response = [] - if response.choices[0].message.tool_calls: - tool_calls_from_response.extend(response.choices[0].message.tool_calls) + tool_calls_from_response = list(assistant_message.tool_calls) if assistant_message.tool_calls else [] tool_use_count = 0 - MAX_TOOL_ITERATIONS = 200 + MAX_TOOL_ITERATIONS = 5 # OpenAI compatible typically uses fewer iterations than Anthropic while tool_calls_from_response and tool_use_count < MAX_TOOL_ITERATIONS: tool_results_for_model = [] @@ -71,20 +154,24 @@ class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot): for tool_call in tool_calls_from_response: tool_call_id = tool_call.id function_to_call = tool_call.function + function_name = function_to_call.name + function_args_str = function_to_call.arguments - logging.info(f"Attempting to call tool: {function_to_call.name} with args: {function_to_call.arguments}") + logging.info(f"Attempting to call tool: {function_name} with args: {function_args_str}") try: - tool_response_content = self.call_tool(function_to_call.name, function_to_call.arguments) + # Arguments are already a string from the API, self.call_tool expects dict or string + tool_response_content = self.call_tool(function_name, function_args_str) + # Ensure content is string for OpenAI tool role if not isinstance(tool_response_content, str): tool_response_content = json.dumps(tool_response_content) except Exception as e: - logging.error(f"Error calling tool {function_to_call.name}: {e}") - tool_response_content = f"Error executing tool {function_to_call.name}: {str(e)}" + logging.error(f"Error calling tool {function_name}: {e}") + tool_response_content = f"Error executing tool {function_name}: {str(e)}" tool_results_for_model.append({ "role": "tool", "tool_call_id": tool_call_id, - "name": function_to_call.name, + "name": function_name, "content": tool_response_content }) @@ -93,41 +180,51 @@ class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot): response = self.get_chat_response(messages) if not (response.choices and response.choices[0].message): logging.error("No valid response choice message from LLM after tool call.") + self.conversation_history[user_id] = messages # Persist state before error return "Error: Could not get a valid response from the LLM after tool call." - messages.append(response.choices[0].message) + assistant_message = response.choices[0].message + messages.append(assistant_message) - tool_calls_from_response = [] - if response.choices[0].message.tool_calls: - tool_calls_from_response.extend(response.choices[0].message.tool_calls) + tool_calls_from_response = list(assistant_message.tool_calls) if assistant_message.tool_calls else [] tool_use_count += 1 if tool_use_count >= MAX_TOOL_ITERATIONS and tool_calls_from_response: logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.") + # Ensure final content is returned even if max iterations hit with pending tool calls + break - # Conversation history management - # This limit should be reviewed and potentially made configurable - if len(self.conversation_history[user_id]) > 20: # Example limit, adjust as needed - self.conversation_history[user_id] = self.conversation_history[user_id][-20:] + self.conversation_history[user_id] = messages # Persist the full exchange for this turn + # Apply history length limit + if len(self.conversation_history[user_id]) > self.max_history_length: + # Keep system prompt if present as the first message, then trim the rest + if self.conversation_history[user_id][0]["role"] == "system": + system_msg = [self.conversation_history[user_id][0]] + trimmed_history = self.conversation_history[user_id][-(self.max_history_length-1):] + self.conversation_history[user_id] = system_msg + trimmed_history + else: + self.conversation_history[user_id] = self.conversation_history[user_id][-self.max_history_length:] final_assistant_message = messages[-1] - return final_assistant_message.content if final_assistant_message.role == "assistant" and final_assistant_message.content else "No content in final message." + return final_assistant_message.content if final_assistant_message.role == "assistant" and final_assistant_message.content is not None else "Assistant did not provide a textual response." async def start(self): - logging.info(f"{self.__class__.__name__} started.") + logging.info(f"{self.__class__.__name__} (Model: {self.model}) started.") - def clear(self, user_id): - super().clear_conversation_history(user_id) + # clear_conversation_history is inherited from BaseTelegramInferenceBot async def abort_processing(self, user_id): + # This is a soft abort for OpenAI compatible bots as API calls are synchronous within handle_message if user_id in self.processing_status: - self.processing_status[user_id]["processing"] = False - self.clear(user_id) - return "Processing aborted and conversation cleared." + self.clear_processing_status(user_id) # Use base class method + logging.info(f"Processing aborted for user {user_id}.") + # Optionally clear conversation history or let user do it explicitly + # super().clear_conversation_history(user_id) + return "Processing aborted. You can send a new message or /clear the conversation." else: - self.clear(user_id) - return "No active processing found to abort. Conversation cleared." + # super().clear_conversation_history(user_id) + return "No active processing found to abort. If you wish, /clear the conversation history." @abstractmethod async def switch_model(self): - pass \ No newline at end of file + pass diff --git a/prompts/developer_prompt.txt b/prompts/developer_prompt.txt index bcab8af..3207461 100644 --- a/prompts/developer_prompt.txt +++ b/prompts/developer_prompt.txt @@ -30,3 +30,19 @@ Pull Requests and Issues: The Collaborative Symphony Pull Request Mastery: Treat pull requests as complete change proposals. They evolve with each commit to their branch. Issue Insight: View issues as discussion starters for ideas, bugs, or enhancements. They may inspire multiple pull requests. Ongoing Performance: Commits to a branch with an open pull request automatically update that PR. No need for new PRs per commit. + +**Focus on Testability and Robust Design (Lessons Learned):** + +When implementing or refactoring, *aggressively prioritize testability*. This includes: +* **Dependency Injection:** Consistently apply Dependency Injection for all external services (e.g., API clients, database connections), configurations (e.g., API keys, file paths, model names, feature flags), and system resources (e.g., file system access via `open`, network requests via `requests.Session`, time/clock functions if timing is critical and needs mocking). +* **Configuration Management:** Externalize configurations. Allow them to be passed via constructor arguments, with environment variables or sensible defaults as fallbacks. Avoid hardcoding paths, keys, or URLs directly within functions or methods. +* **Separation of Concerns:** Clearly separate core business logic from framework-specific code, I/O operations, or direct external service interactions. This often involves creating internal `_logic` methods that can be tested independently of, for example, Telegram API update/context objects. +* **Logging for Libraries/Tools:** Components like tools or libraries should use `logging.getLogger(__name__)` for their logging. They should *not* configure handlers (e.g., `FileHandler`, `StreamHandler`) directly. Logging setup (handlers, formatters, levels) is the responsibility of the main application. Tools can accept an optional `logger` instance via their constructor for more explicit control by the application or for testing. +* **State Management for Testability:** For stateful components, tools, or classes, ensure there's a mechanism to reset or clear their state (e.g., a `clear()` or `reset()` method). This is crucial for test isolation and predictable behavior during testing. +* **Robust Metrics & Profiling:** When implementing metrics collection (e.g., using `cProfile` via decorators), ensure that data extraction (like execution time) is robust. Rely on stable APIs or attributes of the profiling tools (e.g., `pstats.Stats.stats` dictionary) rather than fragile string parsing of their output. Provide methods to clear/reset collected metrics to facilitate testing of the metrics system itself. +* **Comprehensive Unit Testing Strategy:** When generating unit tests: + * For abstract base classes, create simple concrete subclasses within the test file to enable instantiation and testing of shared, non-abstract logic. + * Employ `unittest.mock` (`MagicMock`, `patch`, `AsyncMock`, `mock_open`) extensively to isolate the unit under test from its dependencies. + * Cover various scenarios: initialization with different configurations, success paths for public methods, error conditions (e.g., API errors, file not found, invalid arguments), and relevant edge cases. + * Thoroughly mock external dependencies like file system operations, network calls, and any injected client objects. +* **Iterative Development Cycle:** For significant changes or new features, propose refactoring for testability *first*, then proceed to write comprehensive unit tests against the refactored code. This leads to more robust, maintainable, and reliable components. diff --git a/telegram_helper.py b/telegram_helper.py index 9ede863..2bdbae6 100644 --- a/telegram_helper.py +++ b/telegram_helper.py @@ -3,44 +3,70 @@ import logging import sys import asyncio import time +from typing import TypedDict, Union, TypeAlias, List # Added List for type hint from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes, CallbackQueryHandler from browse_command import browse_command, button_callback +class MessageHandlerLogicResult(TypedDict): + success: bool + response_text: Union[str, None] + error_message: Union[str, None] + +LogicResult: TypeAlias = MessageHandlerLogicResult + class TelegramHelper: - # --- Constants for configurable paths and magic strings --- - REBOOT_CLAUDE_FILE = '.reboot_claude' - REBOOT_FILE = '.doreboot' CLAUDE_REBOOT_TARGET = 'claude' HTML_QUOTE_BLOCK_START = '
Thinking...' HTML_QUOTE_BLOCK_END = '
' + DEFAULT_REBOOT_CLAUDE_FILE = '.reboot_claude' + DEFAULT_REBOOT_FILE = '.doreboot' + CHUNK_MESSAGE_SLEEP_DURATION = 0.1 - def __init__(self, bot): + def __init__(self, bot, + reboot_claude_file_path: str | None = None, + reboot_file_path: str | None = None, + chunk_message_sleep_duration: float | None = None): self.bot = bot self.telegram_bot_token = os.getenv('TELEGRAM_BOT_TOKEN') self.start_time = time.time() + self.reboot_claude_file = reboot_claude_file_path or self.DEFAULT_REBOOT_CLAUDE_FILE + self.reboot_file = reboot_file_path or self.DEFAULT_REBOOT_FILE + self.chunk_message_sleep_duration = chunk_message_sleep_duration if chunk_message_sleep_duration is not None else self.CHUNK_MESSAGE_SLEEP_DURATION + + async def _start_logic(self) -> str: + await self.bot.start() + return "Hello! I'm your AI assistant. How can I help you today?" async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - await self.bot.start() - await update.message.reply_text( - "Hello! I'm your AI assistant. How can I help you today?" - ) + response_message = await self._start_logic() + await update.message.reply_text(response_message) + + async def _clear_logic(self, user_id: int) -> str: + self.bot.clear_conversation_history(user_id) + return "Conversation history cleared. Let's start fresh!" async def clear(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: user_id = update.effective_user.id - self.bot.clear_conversation_history(user_id) - await update.message.reply_text("Conversation history cleared. Let's start fresh!") + response_message = await self._clear_logic(user_id) + await update.message.reply_text(response_message) + + async def _status_logic(self) -> str: + return await self.bot.get_bot_status() async def status(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - status_message = await self.bot.get_bot_status() - await update.message.reply_text(status_message) + response_message = await self._status_logic() + await update.message.reply_text(response_message) + + async def _switch_logic(self) -> str: + if hasattr(self.bot, 'switch_model'): + return await self.bot.switch_model() + else: + return "Model switching is not supported for this bot." async def switch(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - if hasattr(self.bot, 'switch_model'): - status_message = await self.bot.switch_model() - await update.message.reply_text(status_message) - else: - await update.message.reply_text("Model switching is not supported for this bot.") + response_message = await self._switch_logic() + await update.message.reply_text(response_message) async def update_status_message(self, context: ContextTypes.DEFAULT_TYPE, chat_id: int, message_id: int, status: str): keyboard = [ @@ -54,65 +80,147 @@ class TelegramHelper: reply_markup=reply_markup ) - async def handle_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + async def _handle_message_logic(self, user_id: int, user_message: str) -> LogicResult: try: - user_id = update.effective_user.id - user_message = update.message.text - - logging.info(f"Message from user {user_id}: {user_message}") - - status_message = await update.message.reply_text("Processing your request...", reply_markup=InlineKeyboardMarkup([[InlineKeyboardButton("Abort", callback_data='abort')]])) - self.bot.set_processing_status(user_id, status_message.message_id) - response = await self.bot.handle_message(user_id, user_message) + processed_response = response.replace("", self.HTML_QUOTE_BLOCK_START).replace("", self.HTML_QUOTE_BLOCK_END) + return LogicResult(success=True, response_text=processed_response, error_message=None) + except Exception as e: + logging.error(f"Error in _handle_message_logic for user {user_id}: {str(e)}") + return LogicResult(success=False, response_text=None, error_message=str(e)) - await context.bot.delete_message(chat_id=update.effective_chat.id, message_id=status_message.message_id) + async def handle_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + user_id = update.effective_user.id + user_message = update.message.text + chat_id = update.effective_chat.id + status_message_obj = None + + try: + status_message_obj = await update.message.reply_text( + "Processing your request...", + reply_markup=InlineKeyboardMarkup([[InlineKeyboardButton("Abort", callback_data='abort')]]) + ) + self.bot.set_processing_status(user_id, status_message_obj.message_id) + + logic_result = await self._handle_message_logic(user_id, user_message) + + if status_message_obj: + try: + await context.bot.delete_message(chat_id=chat_id, message_id=status_message_obj.message_id) + except Exception as e_del: + logging.warning(f"Failed to delete status message: {e_del}") self.bot.clear_processing_status(user_id) - response = response.replace("", self.HTML_QUOTE_BLOCK_START).replace("", self.HTML_QUOTE_BLOCK_END) - - if len(response) > 4096: - chunks = [response[i:i + 4096] for i in range(0, len(response), 4096)] - for chunk in chunks: - await update.message.reply_text(chunk) - await asyncio.sleep(0.1) + if logic_result["success"]: + response_text = logic_result["response_text"] + if response_text: + if len(response_text) > 4096: + chunks = [response_text[i:i + 4096] for i in range(0, len(response_text), 4096)] + for chunk in chunks: + await update.message.reply_text(chunk) + await asyncio.sleep(self.chunk_message_sleep_duration) + else: + await update.message.reply_text(response_text) + else: + logging.warning("Successful logic result but no response text.") + await update.message.reply_text("Something went unexpectedly well, but I have nothing to say.") else: - await update.message.reply_text(response) + await update.message.reply_text("Sorry, an error occurred while processing your request.") except Exception as e: - logging.error(f"An error occurred: {str(e)}") - await update.message.reply_text("Sorry, an error occurred while processing your request.") + logging.error(f"Outer error in handle_message for user {user_id}: {str(e)}") + if status_message_obj and self.bot.processing_status.get(user_id): + self.bot.clear_processing_status(user_id) + try: + await update.message.reply_text("Sorry, an unexpected error occurred with the bot.") + except Exception as e_reply: + logging.error(f"Failed to send error reply: {e_reply}") + + async def _abort_processing_logic(self, user_id: int) -> str: + return await self.bot.abort_processing(user_id) async def abort_processing(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: query = update.callback_query await query.answer() - user_id = query.from_user.id - result = await self.bot.abort_processing(user_id) - await query.edit_message_text(text=result) + response_text = await self._abort_processing_logic(user_id) + await query.edit_message_text(text=response_text) + + # --- Reboot Command --- + def _reboot_logic(self, user_message_parts: List[str], chat_id_to_write: str) -> None: + """Handles the logic for creating reboot files.""" + if len(user_message_parts) > 1 and user_message_parts[1].lower() == self.CLAUDE_REBOOT_TARGET: + try: + with open(self.reboot_claude_file, 'w') as f: + f.write("") # Create/truncate the file + logging.info(f"Created/truncated Claude reboot file: {self.reboot_claude_file}") + except IOError as e: + logging.error(f"Failed to create/truncate Claude reboot file {self.reboot_claude_file}: {e}") + + # Create the main reboot file if it doesn't exist + if not os.path.exists(self.reboot_file): + try: + with open(self.reboot_file, 'w') as f: + f.write(chat_id_to_write) + logging.info(f"Created main reboot file: {self.reboot_file} with chat_id.") + except IOError as e: + logging.error(f"Failed to create main reboot file {self.reboot_file}: {e}") + else: + logging.info(f"Main reboot file {self.reboot_file} already exists. Not overwriting chat_id.") async def reboot(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - user_message = update.message.text.split() - if len(user_message) > 1 and user_message[1].lower() == self.CLAUDE_REBOOT_TARGET: - open(self.REBOOT_CLAUDE_FILE, 'w').close() + """Handles the /reboot command, triggers file creation and exits.""" + user_message_parts = update.message.text.split() + chat_id_str = str(update.effective_chat.id) if update and update.effective_chat else "" + + self._reboot_logic(user_message_parts, chat_id_str) if update: - await update.message.reply_text("Rebooting the bot...") - logging.info("Received reboot command. Exiting process...") - reboot_file_path = self.REBOOT_FILE - if not os.path.exists(reboot_file_path): - with open(reboot_file_path, 'w') as f: - f.write(str(update.effective_chat.id) if update else "") - sys.exit(0) + try: + await update.message.reply_text("Rebooting the bot...") + except Exception as e_reply: + logging.error(f"Failed to send reboot reply: {e_reply}") + + logging.info("Initiating shutdown for reboot...") + sys.exit(0) # This part is not directly testable for completion in unit tests - async def check_doreboot_file(self, application: Application): - reboot_file_path = self.REBOOT_FILE - if os.path.exists(reboot_file_path): - with open(reboot_file_path, 'r') as f: - chat_id = f.read().strip() - if chat_id: + # --- Check Doreboot File --- + async def _check_doreboot_file_logic(self) -> Union[str, None]: + """Checks for the reboot file, reads chat_id, removes file, and returns chat_id.""" + if os.path.exists(self.reboot_file): + chat_id = None + try: + with open(self.reboot_file, 'r') as f: + chat_id = f.read().strip() + # Attempt to remove the file after reading + try: + os.remove(self.reboot_file) + logging.info(f"Successfully read and removed reboot file: {self.reboot_file}") + except OSError as e_remove: + logging.error(f"Failed to remove reboot file {self.reboot_file} after reading: {e_remove}") + # Still return chat_id if read was successful, to attempt notification + return chat_id + except IOError as e_read: + logging.error(f"Error reading reboot file {self.reboot_file}: {e_read}") + # If reading failed, attempt to remove anyway if it exists, to prevent stale files + if os.path.exists(self.reboot_file): + try: + os.remove(self.reboot_file) + logging.warning(f"Removed reboot file {self.reboot_file} after a read error.") + except OSError as e_remove_after_fail: + logging.error(f"Failed to remove reboot file {self.reboot_file} even after a read error: {e_remove_after_fail}") + return None # Reading failed + return None # File does not exist + + async def check_doreboot_file(self, application: Application) -> None: + """Checks for reboot file using logic method and sends notification if applicable.""" + chat_id = await self._check_doreboot_file_logic() + if chat_id: + try: await application.bot.send_message(chat_id=chat_id, text="The application has finished initializing.") - os.remove(reboot_file_path) + logging.info(f"Sent reboot initialization notification to chat_id: {chat_id}") + except Exception as e: + logging.error(f"Failed to send reboot initialization notification to chat_id {chat_id}: {e}") async def browse(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: await browse_command(update, context, self.bot) @@ -132,6 +240,10 @@ class TelegramHelper: logging.info("Bot is running...") - asyncio.get_event_loop().create_task(self.check_doreboot_file(application)) + loop = asyncio.get_event_loop() + if loop.is_running(): # pragma: no cover + loop.create_task(self.check_doreboot_file(application)) + else: # pragma: no cover + asyncio.run(self.check_doreboot_file(application)) application.run_polling() diff --git a/tests/test_anthropic_telegram_inference_bot.py b/tests/test_anthropic_telegram_inference_bot.py new file mode 100644 index 0000000..c7c715f --- /dev/null +++ b/tests/test_anthropic_telegram_inference_bot.py @@ -0,0 +1,280 @@ +import unittest +from unittest.mock import MagicMock, patch, AsyncMock, ANY +import os + +# Assuming anthropic_telegram_inference_bot.py is in the parent directory or PYTHONPATH is set +from anthropic_telegram_inference_bot import AnthropicTelegramInferenceBot + +# Mock response from Anthropic client's messages.create +def create_mock_anthropic_response(content_text=None, stop_reason="end_turn", tool_use_parts=None): + mock_response = MagicMock() + mock_response.stop_reason = stop_reason + + content_blocks = [] + if content_text: + text_block = MagicMock() + text_block.type = "text" + text_block.text = content_text + content_blocks.append(text_block) + + if tool_use_parts: + for tu_part in tool_use_parts: # tu_part = {"id": "toolu_123", "name": "get_weather", "input": {}} + tool_block = MagicMock() + tool_block.type = "tool_use" + tool_block.id = tu_part["id"] + tool_block.name = tu_part["name"] + tool_block.input = tu_part["input"] + content_blocks.append(tool_block) + + mock_response.content = content_blocks + return mock_response + +class TestAnthropicTelegramInferenceBot(unittest.IsolatedAsyncioTestCase): + + def setUp(self): + self.original_anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY") + self.original_small_model = os.environ.get("ANTHROPIC_SMALL_MODEL") + self.original_large_model = os.environ.get("ANTHROPIC_LARGE_MODEL") + self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH") + + for key in ["ANTHROPIC_API_KEY", "ANTHROPIC_SMALL_MODEL", "ANTHROPIC_LARGE_MODEL", "SYSTEM_PROMPT_PATH"]: + if os.environ.get(key): + del os.environ[key] + + self.mock_anthropic_client_instance = MagicMock() + self.mock_anthropic_client_instance.messages.create = MagicMock() + + def tearDown(self): + if self.original_anthropic_api_key: os.environ["ANTHROPIC_API_KEY"] = self.original_anthropic_api_key + if self.original_small_model: os.environ["ANTHROPIC_SMALL_MODEL"] = self.original_small_model + if self.original_large_model: os.environ["ANTHROPIC_LARGE_MODEL"] = self.original_large_model + if self.original_system_prompt_path: os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path + + @patch('anthropic.Anthropic') + def test_init_with_anthropic_defaults_env_key(self, MockAnthropicConstructor): + MockAnthropicConstructor.return_value = self.mock_anthropic_client_instance + os.environ["ANTHROPIC_API_KEY"] = "test_anthropic_key" + + bot = AnthropicTelegramInferenceBot() + + MockAnthropicConstructor.assert_called_once_with(api_key="test_anthropic_key") + self.assertEqual(bot.anthropic_client, self.mock_anthropic_client_instance) + self.assertEqual(bot.model, os.environ.get("ANTHROPIC_SMALL_MODEL", "claude-3-haiku-20240307")) + self.assertEqual(bot.max_tokens, int(os.environ.get("ANTHROPIC_SMALL_MODEL_MAX_TOKENS", 2000))) + + @patch('anthropic.Anthropic') + def test_init_with_provided_client_and_models(self, MockAnthropicConstructor): + preconfigured_client = MagicMock() + bot = AnthropicTelegramInferenceBot( + anthropic_client=preconfigured_client, + small_model_name="custom-small", + small_model_max_tokens=100, + large_model_name="custom-large", + large_model_max_tokens=200 + ) + + MockAnthropicConstructor.assert_not_called() + self.assertEqual(bot.anthropic_client, preconfigured_client) + self.assertEqual(bot.model, "custom-small") + self.assertEqual(bot.max_tokens, 100) + self.assertEqual(bot.small_model_name, "custom-small") + self.assertEqual(bot.large_model_name, "custom-large") + + + def test_get_llm_description(self): + bot = AnthropicTelegramInferenceBot(small_model_name="claude-test", small_model_max_tokens=500) + self.assertEqual(bot.get_llm_description(), "LLM: claude-test, Max Tokens: 500") + + async def test_switch_model(self): + bot = AnthropicTelegramInferenceBot( + small_model_name="claude-small", small_model_max_tokens=10, + large_model_name="claude-large", large_model_max_tokens=20 + ) + self.assertEqual(bot.model, "claude-small") + self.assertEqual(bot.max_tokens, 10) + + status = await bot.switch_model() + self.assertEqual(bot.model, "claude-large") + self.assertEqual(bot.max_tokens, 20) + self.assertEqual(status, "Switched to model: claude-large") + + status = await bot.switch_model() + self.assertEqual(bot.model, "claude-small") + self.assertEqual(bot.max_tokens, 10) + self.assertEqual(status, "Switched to model: claude-small") + + def test_get_chat_response_success_text_only(self): + bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance) + bot.model = "test-claude" + bot.max_tokens = 150 + + mock_api_response = create_mock_anthropic_response(content_text="Hello from Anthropic API") + self.mock_anthropic_client_instance.messages.create.return_value = mock_api_response + + messages = [{"role": "user", "content": "Hi"}] # Anthropic format + response = bot.get_chat_response(messages, []) # tools = empty list + + self.mock_anthropic_client_instance.messages.create.assert_called_once_with( + model="test-claude", + max_tokens=150, + messages=messages, + system=bot.system_prompt, # Ensure system prompt is passed + tools=None, # No tools passed to API if empty list or None + tool_choice=None + ) + self.assertEqual(response, mock_api_response) + + def test_get_chat_response_with_tools(self): + bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance) + bot.model = "claude-toolmaster" + bot.max_tokens = 300 + + mock_tools_spec = [{"name": "get_weather", "description": "Gets weather", "input_schema": {"type": "object", "properties": {}}}] + + mock_api_response = create_mock_anthropic_response(content_text="Thinking...", tool_use_parts=[ + {"id": "tool1", "name": "get_weather", "input": {"location": "here"}} + ]) + self.mock_anthropic_client_instance.messages.create.return_value = mock_api_response + + messages = [{"role": "user", "content": "Weather?"}] + response = bot.get_chat_response(messages, mock_tools_spec) + + self.mock_anthropic_client_instance.messages.create.assert_called_once_with( + model="claude-toolmaster", + max_tokens=300, + messages=messages, + system=bot.system_prompt, + tools=mock_tools_spec, + tool_choice={"type": "auto"} + ) + self.assertEqual(response.content[0].type, "text") # First part can be text + self.assertEqual(response.content[1].type, "tool_use") + + + def test_get_chat_response_api_error(self): + bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance) + self.mock_anthropic_client_instance.messages.create.side_effect = Exception("Anthropic API Down") + + with self.assertRaisesRegex(Exception, "Anthropic API Down"): + bot.get_chat_response([{"role": "user", "content": "trigger"}], []) + + + async def test_handle_message_simple_response_no_tools(self): + # This test is more involved as it touches BaseTelegramInferenceBot's handle_message structure + # which then calls the overridden get_chat_response. + bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance) + bot.system_prompt = "System prompt for Anthropic" + + # Mock get_chat_response directly to isolate its behavior from full handle_message logic of base + # However, the point of this bot is its get_chat_response and subsequent processing. + # So, let's mock the API call within get_chat_response. + + api_response = create_mock_anthropic_response(content_text="Anthropic says hello.") + self.mock_anthropic_client_instance.messages.create.return_value = api_response + + # Ensure functions are empty for this test, so no tool logic is triggered + bot.functions = [] + bot.tools = [] + + response_content = await bot.handle_message(user_id=101, user_message="Hello Anthropic") + + self.assertEqual(response_content, "Anthropic says hello.") + self.assertIn(101, bot.conversation_history) + # Anthropic's handle_message structure: + # 1. User message added to history. + # 2. get_chat_response is called. + # 3. Response content (text) is extracted. + # 4. Assistant text response is added to history. + # Expected history: [User, Assistant_Text_Response] (system prompt handled by get_chat_response) + # The base class handle_message adds system prompt if not present. + # Anthropic handle_message modifies history format before calling get_chat_response. + + # Let's trace Base.handle_message -> Anthropic.handle_message -> Anthropic.get_chat_response + # Base.handle_message: + # - Adds system prompt to history if first turn: `self.conversation_history[user_id] = [{"role": "system", "content": self.system_prompt}]` (OpenAI style) + # - Appends user message: `{"role": "user", "content": user_message}` + # - Calls self.get_chat_response(messages, self.functions) -> This is Anthropic's get_chat_response + # Anthropic.get_chat_response: + # - Takes OpenAI style `messages` and `self.functions` (tool specs). + # - Calls `anthropic_client.messages.create` with Anthropic style messages and system prompt. + # Anthropic.handle_message (overridden): + # - Prepares Anthropic-style messages from conversation_history (which is OpenAI style from Base) + # - Calls get_chat_response with these Anthropic messages and self.functions (tool_specs) + # - Processes response, extracts text, handles tool calls. + # - Appends *user* message (original) and *assistant* text response to self.conversation_history (OpenAI style). + + # For this test, we are calling AnthropicBot.handle_message directly. + # 1. `user_id` not in `self.conversation_history`: `system_prompt` not added yet by Base logic. + # Anthropic's `handle_message` will create `anthropic_messages` from this. + # If `conversation_history` is empty, `anthropic_messages` = `[{"role": "user", "content": user_message}]` + # 2. `get_chat_response` called with `anthropic_messages` and `bot.system_prompt` passed to API. + # 3. Response "Anthropic says hello." + # 4. Original `user_message` and "Anthropic says hello." (as assistant) added to `self.conversation_history`. + + history = bot.conversation_history[101] + self.assertEqual(len(history), 2) # User, Assistant + self.assertEqual(history[0]["role"], "user") + self.assertEqual(history[0]["content"], "Hello Anthropic") + self.assertEqual(history[1]["role"], "assistant") + self.assertEqual(history[1]["content"], "Anthropic says hello.") + + # Check API call (made by the mocked get_chat_response indirectly) + self.mock_anthropic_client_instance.messages.create.assert_called_once() + call_args = self.mock_anthropic_client_instance.messages.create.call_args + self.assertEqual(call_args.kwargs["system"], "System prompt for Anthropic") + # Initial messages for API should just be the user message for first turn + self.assertEqual(call_args.kwargs["messages"], [{"role": "user", "content": "Hello Anthropic"}]) + + + async def test_handle_message_with_tool_calls(self): + bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance) + bot.system_prompt = "You are a helpful, tool-using assistant." + + # Define a tool for the bot (OpenAI format, will be converted by Anthropic bot for API) + mock_tool_oai_format = {"type": "function", "function": {"name": "get_weather", "description": "Get weather", "parameters": {}}} + bot.functions = [mock_tool_oai_format] # This is used to generate anthropic_tools for API + + # API Response 1: Request for tool call + tool_use_part = {"id": "toolu_xyz", "name": "get_weather", "input": {"location": "paris"}} + api_response_1 = create_mock_anthropic_response(tool_use_parts=[tool_use_part]) + + # API Response 2: Final text response after tool execution + api_response_2 = create_mock_anthropic_response(content_text="The weather in Paris is nice.") + + self.mock_anthropic_client_instance.messages.create.side_effect = [api_response_1, api_response_2] + + # Mock the bot's call_tool method (from BaseTelegramInferenceBot) + bot.call_tool = MagicMock(return_value='''{"weather": "sunny"}''') # Tool execution result + + user_id = 102 + user_message = "What's the weather in Paris?" + final_text_response = await bot.handle_message(user_id, user_message) + + self.assertEqual(final_text_response, "The weather in Paris is nice.") + self.assertEqual(self.mock_anthropic_client_instance.messages.create.call_count, 2) + + bot.call_tool.assert_called_once_with("get_weather", {"location": "paris"}) # Anthropic passes input as dict + + # Check conversation history (OpenAI style) + history = bot.conversation_history[user_id] + self.assertEqual(history[0]["role"], "user") + self.assertEqual(history[0]["content"], user_message) + + # Assistant message that requested tool call (Anthropic-specific format stored by its handle_message) + # Anthropic's handle_message appends the raw tool_use block and then the tool_result + self.assertEqual(history[1]["role"], "assistant") + self.assertTrue(isinstance(history[1]["content"], list)) # Anthropic content is a list + self.assertEqual(history[1]["content"][0]["type"], "tool_use") + self.assertEqual(history[1]["content"][0]["id"], "toolu_xyz") + + self.assertEqual(history[2]["role"], "tool") + self.assertEqual(history[2]["tool_call_id"], "toolu_xyz") + self.assertEqual(history[2]["name"], "get_weather") + self.assertEqual(history[2]["content"], '''{"weather": "sunny"}''') # call_tool result + + self.assertEqual(history[3]["role"], "assistant") # Final text response + self.assertTrue(isinstance(history[3]["content"], str)) # simple text + self.assertEqual(history[3]["content"], "The weather in Paris is nice.") + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_base_telegram_inference_bot.py b/tests/test_base_telegram_inference_bot.py new file mode 100644 index 0000000..bcb2b52 --- /dev/null +++ b/tests/test_base_telegram_inference_bot.py @@ -0,0 +1,310 @@ +import unittest +from unittest.mock import patch, mock_open, MagicMock +import os +import json + +# Ensure the path includes the directory where base_telegram_inference_bot is located +# This might require adjustment based on actual project structure if tests are run from root +# For now, assuming it can be imported directly or via PYTHONPATH +from base_telegram_inference_bot import BaseTelegramInferenceBot +from tools.base_tool import BaseTool # For mocking tool structure + +# Create a concrete subclass for testing, as BaseTelegramInferenceBot is abstract +class ConcreteTestBot(BaseTelegramInferenceBot): + def __init__(self, system_prompt_content=None, system_prompt_path=None, mock_tools=None, mock_functions=None): + # Mock load_functions during super().__init__ if needed, or control tools/functions directly + self._mock_tools = mock_tools if mock_tools is not None else [] + self._mock_functions = mock_functions if mock_functions is not None else [] + super().__init__(system_prompt_content=system_prompt_content, system_prompt_path=system_prompt_path) + + # Override load_functions to use mocks + def load_functions(self): + return self._mock_tools, self._mock_functions + + def get_chat_response(self, messages): + pass # Abstract method, not tested here directly + + async def handle_message(self, user_id, user_message): + pass # Abstract method + + def get_llm_description(self) -> str: + return "Mock LLM Description" # Concrete implementation for testing get_bot_status + + async def start(self): + pass # Abstract method + + async def abort_processing(self, user_id): + pass # Abstract method + + async def switch_model(self): + pass # Abstract method + +class TestBaseTelegramInferenceBot(unittest.TestCase): + + def setUp(self): + # Reset relevant environment variables before each test + self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH") + if "SYSTEM_PROMPT_PATH" in os.environ: + del os.environ["SYSTEM_PROMPT_PATH"] + + def tearDown(self): + # Restore environment variables + if self.original_system_prompt_path: + os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path + elif "SYSTEM_PROMPT_PATH" in os.environ: # Ensure it's removed if test set it and it wasn't there before + del os.environ["SYSTEM_PROMPT_PATH"] + + def test_init_with_direct_system_prompt(self): + bot = ConcreteTestBot(system_prompt_content="Direct prompt content") + self.assertEqual(bot.system_prompt, "Direct prompt content") + + @patch("os.path.isfile") + @patch("builtins.open", new_callable=mock_open, read_data="File prompt content") + def test_init_with_system_prompt_path_argument(self, mock_file_open, mock_isfile): + mock_isfile.return_value = True + bot = ConcreteTestBot(system_prompt_path="dummy/path.txt") + self.assertEqual(bot.system_prompt, "File prompt content") + mock_isfile.assert_called_once_with("dummy/path.txt") + mock_file_open.assert_called_once_with("dummy/path.txt", "r", encoding="utf-8") + + @patch("os.path.isfile") + @patch("builtins.open", new_callable=mock_open, read_data="Env prompt content") + def test_init_with_env_system_prompt_path(self, mock_file_open, mock_isfile): + mock_isfile.return_value = True + os.environ["SYSTEM_PROMPT_PATH"] = "env/path.txt" + bot = ConcreteTestBot() + self.assertEqual(bot.system_prompt, "Env prompt content") + mock_isfile.assert_called_once_with("env/path.txt") + mock_file_open.assert_called_once_with("env/path.txt", "r", encoding="utf-8") + + def test_init_with_default_system_prompt(self): + # Ensure ENV var is not set for this test + if "SYSTEM_PROMPT_PATH" in os.environ: + del os.environ["SYSTEM_PROMPT_PATH"] + bot = ConcreteTestBot() + self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.") + + @patch("os.path.isfile", return_value=False) + def test_init_with_invalid_system_prompt_path(self, mock_isfile): + bot = ConcreteTestBot(system_prompt_path="invalid/path.txt") + self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.") + mock_isfile.assert_called_once_with("invalid/path.txt") + + @patch("os.path.isfile") + @patch("builtins.open", side_effect=IOError("File read error")) + def test_init_with_system_prompt_file_read_error(self, mock_file_open, mock_isfile): + mock_isfile.return_value = True + bot = ConcreteTestBot(system_prompt_path="dummy/path.txt") + self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.") + + def test_clear_conversation_history(self): + mock_tool_instance = MagicMock(spec=BaseTool) + bot = ConcreteTestBot(mock_tools=[mock_tool_instance]) + bot.conversation_history[123] = [{"role": "user", "content": "Hello"}] + + bot.clear_conversation_history(123) + self.assertNotIn(123, bot.conversation_history) + mock_tool_instance.clear.assert_called_once() + + def test_clear_conversation_history_user_not_found(self): + mock_tool_instance = MagicMock(spec=BaseTool) + bot = ConcreteTestBot(mock_tools=[mock_tool_instance]) + bot.clear_conversation_history(404) + self.assertNotIn(404, bot.conversation_history) + mock_tool_instance.clear.assert_called_once() + + def test_processing_status(self): + bot = ConcreteTestBot() + self.assertEqual(bot.processing_status, {}) + bot.set_processing_status(123, 789) + self.assertEqual(bot.processing_status[123], {"processing": True, "message_id": 789}) + bot.clear_processing_status(123) + self.assertNotIn(123, bot.processing_status) + + def test_clear_processing_status_user_not_found(self): + bot = ConcreteTestBot() + bot.clear_processing_status(404) + self.assertNotIn(404, bot.processing_status) + + def test_call_tool_success_dict_args(self): + mock_tool = MagicMock(spec=BaseTool) + mock_tool.get_functions.return_value = [ + {"function": {"name": "test_tool", "description": "A test tool", "parameters": {}}} + ] + mock_tool.execute.return_value = "Tool executed successfully" + + bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions()) + + result = bot.call_tool("test_tool", {"arg1": "value1"}) + self.assertEqual(result, "Tool executed successfully") + mock_tool.execute.assert_called_once_with("test_tool", arg1="value1") + + def test_call_tool_success_json_string_args(self): + mock_tool = MagicMock(spec=BaseTool) + mock_tool.get_functions.return_value = [ + {"function": {"name": "test_tool_json", "parameters": {}}} + ] + mock_tool.execute.return_value = "Tool JSON OK" + bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions()) + + args_json_str = '''{"param": "value"}''' + result = bot.call_tool("test_tool_json", args_json_str) + self.assertEqual(result, "Tool JSON OK") + mock_tool.execute.assert_called_once_with("test_tool_json", param="value") + + def test_call_tool_malformed_json_string_args(self): + bot = ConcreteTestBot(mock_tools=[]) + args_malformed_json_str = '''{"param": "value"''' + result = bot.call_tool("some_tool", args_malformed_json_str) + self.assertTrue("Error: Malformed arguments for tool call" in result) + + def test_call_tool_unexpected_arg_type(self): + bot = ConcreteTestBot(mock_tools=[]) + result = bot.call_tool("some_tool", 12345) # Integer instead of dict/str + self.assertTrue("Error: Invalid argument type for tool call" in result) + + def test_call_tool_none_args(self): + mock_tool = MagicMock(spec=BaseTool) + mock_tool.get_functions.return_value = [ + {"function": {"name": "test_tool_none", "parameters": {}}} + ] + mock_tool.execute.return_value = "Tool None OK" + bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions()) + + result = bot.call_tool("test_tool_none", None) + self.assertEqual(result, "Tool None OK") + mock_tool.execute.assert_called_once_with("test_tool_none") # No kwargs if None + + def test_call_tool_not_found(self): + bot = ConcreteTestBot(mock_tools=[]) + result = bot.call_tool("non_existent_tool", {}) + self.assertEqual(result, "Error: Tool function non_existent_tool not found.") + + def test_call_tool_execute_exception(self): + mock_tool = MagicMock(spec=BaseTool) + mock_tool.get_functions.return_value = [{"function": {"name": "error_tool", "parameters": {}}}] + mock_tool.execute.side_effect = Exception("Execution failed") + bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions()) + + result = bot.call_tool("error_tool", {}) + self.assertEqual(result, "Error executing tool error_tool: Execution failed") + + def test_get_system_prompt_description(self): + if "SYSTEM_PROMPT_PATH" in os.environ: # Ensure clean state + del os.environ["SYSTEM_PROMPT_PATH"] + + bot_default = ConcreteTestBot() + self.assertEqual(bot_default.get_system_prompt_description(), "System Prompt: Default") + + bot_custom_content = ConcreteTestBot(system_prompt_content="Custom content here") + self.assertEqual(bot_custom_content.get_system_prompt_description(), "System Prompt: Custom") + + os.environ["SYSTEM_PROMPT_PATH"] = "some/path.txt" + bot_env_default_prompt = ConcreteTestBot() # system_prompt itself is default + self.assertEqual(bot_env_default_prompt.get_system_prompt_description(), "System Prompt: Custom (via ENV)") + + with patch("os.path.isfile", return_value=True), \ + patch("builtins.open", mock_open(read_data="File prompt from ENV")): + bot_env_file_prompt = ConcreteTestBot() # system_prompt gets loaded from ENV path + self.assertEqual(bot_env_file_prompt.get_system_prompt_description(), "System Prompt: Custom") + del os.environ["SYSTEM_PROMPT_PATH"] + + with patch("os.path.isfile", return_value=True), \ + patch("builtins.open", mock_open(read_data="File prompt from arg")): + bot_custom_file_arg = ConcreteTestBot(system_prompt_path="custom/file.txt") + self.assertEqual(bot_custom_file_arg.get_system_prompt_description(), "System Prompt: Custom") + + @patch.object(ConcreteTestBot, 'get_llm_description', return_value="Test LLM Description") + @patch.object(ConcreteTestBot, 'get_system_prompt_description', return_value="Test Prompt Description") + async def test_get_bot_status(self, mock_prompt_desc, mock_llm_desc): + bot = ConcreteTestBot() + status = await bot.get_bot_status() + self.assertEqual(status, "Test Prompt Description\nTest LLM Description") + mock_prompt_desc.assert_called_once() + mock_llm_desc.assert_called_once() + + @patch('os.path.dirname', return_value='/mock/path') + @patch('os.path.join') + @patch('os.path.exists') + @patch('os.listdir') + @patch('importlib.import_module') + def test_load_functions_no_tools_dir(self, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname): + mock_join.return_value = '/mock/path/tools' + mock_exists.return_value = False + + class BotForLoadTest(BaseTelegramInferenceBot): + load_system_prompt = MagicMock(return_value="Default") + get_chat_response = MagicMock(); handle_message = MagicMock(); get_llm_description = MagicMock(return_value="mock") + start = MagicMock(); abort_processing = MagicMock(); switch_model = MagicMock() + + bot = BotForLoadTest() + self.assertEqual(bot.tools, []) + self.assertEqual(bot.functions, []) + mock_listdir.assert_not_called() + + @patch('os.path.dirname', return_value='/mock/base_bot_dir') + @patch('os.path.join', side_effect=lambda *args: os.path.normpath(os.path.join(*args))) + @patch('os.path.exists', return_value=True) + @patch('os.listdir', return_value=['my_tool.py', '__init__.py', 'base_tool.py']) + @patch('importlib.import_module') + def test_load_functions_with_one_tool(self, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname): + + mock_tool_class = MagicMock(spec=BaseTool) # This is the class itself + mock_tool_instance = MagicMock(spec=BaseTool) # This is the instance + mock_tool_class.return_value = mock_tool_instance # mock_tool_class() creates mock_tool_instance + mock_tool_instance.get_functions.return_value = [{"function": {"name": "sample_function"}}] + + mock_my_tool_module = MagicMock() + # Simulate inspect.getmembers behavior: returns list of (name, member) tuples + # Only include members that are classes, derive from BaseTool, and are not BaseTool itself. + mock_my_tool_module.ValidToolClass = mock_tool_class + mock_my_tool_module.NotATool = object() + mock_my_tool_module.BaseTool = BaseTool # This should be skipped by the loader + + def import_side_effect(module_name): + if module_name == 'tools.my_tool': + return mock_my_tool_module + raise ImportError(f"Unexpected import: {module_name}") + mock_import_module.side_effect = import_side_effect + + class BotForLoadTest(BaseTelegramInferenceBot): + load_system_prompt = MagicMock(return_value="Default") + get_chat_response = MagicMock(); handle_message = MagicMock(); get_llm_description = MagicMock(return_value="mock") + start = MagicMock(); abort_processing = MagicMock(); switch_model = MagicMock() + + bot = BotForLoadTest() + self.assertEqual(len(bot.tools), 1) + self.assertIs(bot.tools[0], mock_tool_instance) + self.assertEqual(len(bot.functions), 1) + self.assertEqual(bot.functions[0]['function']['name'], "sample_function") + mock_import_module.assert_called_once_with('tools.my_tool') + mock_tool_class.assert_called_once_with() # Tool class was instantiated + mock_tool_instance.get_functions.assert_called_once_with() + + @patch('os.path.dirname', return_value='/mock/base_bot_dir') + @patch('os.path.join', side_effect=lambda *args: os.path.normpath(os.path.join(*args))) + @patch('os.path.exists', return_value=True) + @patch('os.listdir', return_value=['tool_with_init_error.py']) + @patch('importlib.import_module') + @patch('logging.error') # Mock logging to check for error messages + def test_load_functions_tool_instantiation_error(self, mock_logging_error, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname): + mock_tool_class_init_error = MagicMock(spec=BaseTool) + mock_tool_class_init_error.side_effect = Exception("Failed to init tool") # Error on instantiation + + mock_error_tool_module = MagicMock() + mock_error_tool_module.ToolWithInitError = mock_tool_class_init_error + + mock_import_module.return_value = mock_error_tool_module + + class BotForLoadTest(BaseTelegramInferenceBot): + load_system_prompt = MagicMock(return_value="Default") + get_chat_response = MagicMock(); handle_message = MagicMock(); get_llm_description = MagicMock(return_value="mock") + start = MagicMock(); abort_processing = MagicMock(); switch_model = MagicMock() + + bot = BotForLoadTest() + self.assertEqual(len(bot.tools), 0) + self.assertEqual(len(bot.functions), 0) + mock_logging_error.assert_any_call("Error instantiating tool ToolWithInitError from tool_with_init_error.py: Failed to init tool") + +if __name__ == '__main__': + unittest.main(闂傚лен䦗婢у〃埊鍓解劓姣) diff --git a/tests/test_chatgpt_telegram_inference_bot.py b/tests/test_chatgpt_telegram_inference_bot.py new file mode 100644 index 0000000..a0f0bdb --- /dev/null +++ b/tests/test_chatgpt_telegram_inference_bot.py @@ -0,0 +1,158 @@ +import unittest +from unittest.mock import MagicMock, patch, ANY +import os + +# Assuming chatgpt_telegram_inference_bot.py and its parent are accessible +from chatgpt_telegram_inference_bot import ChatGPTTelegramInferenceBot +from openai_compatible_inference_bot import OpenAICompatibleInferenceBot # For patching super + +class TestChatGPTTelegramInferenceBot(unittest.IsolatedAsyncioTestCase): + + def setUp(self): + # Store and clear relevant environment variables + self.original_openai_key = os.environ.get("OPENAI_API_KEY") + self.original_small_model = os.environ.get("OPENAI_SMALL_MODEL") + self.original_large_model = os.environ.get("OPENAI_LARGE_MODEL") + self.original_small_tokens = os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS") + self.original_large_tokens = os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS") + self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH") + + for key in ["OPENAI_API_KEY", "OPENAI_SMALL_MODEL", "OPENAI_LARGE_MODEL", + "OPENAI_SMALL_MODEL_MAX_TOKENS", "OPENAI_LARGE_MODEL_MAX_TOKENS", "SYSTEM_PROMPT_PATH"]: + if os.environ.get(key): + del os.environ[key] + + # Mock the OpenAI client that OpenAICompatibleInferenceBot's __init__ might create + self.mock_openai_client = MagicMock() + + def tearDown(self): + # Restore environment variables + if self.original_openai_key: os.environ["OPENAI_API_KEY"] = self.original_openai_key + if self.original_small_model: os.environ["OPENAI_SMALL_MODEL"] = self.original_small_model + if self.original_large_model: os.environ["OPENAI_LARGE_MODEL"] = self.original_large_model + if self.original_small_tokens: os.environ["OPENAI_SMALL_MODEL_MAX_TOKENS"] = self.original_small_tokens + if self.original_large_tokens: os.environ["OPENAI_LARGE_MODEL_MAX_TOKENS"] = self.original_large_tokens + if self.original_system_prompt_path: os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path + + + @patch.object(OpenAICompatibleInferenceBot, '__init__') # Mock the superclass's __init__ + def test_init_defaults_and_super_call(self, mock_super_init): + os.environ["OPENAI_API_KEY"] = "test_key_chatgpt" + os.environ["OPENAI_SMALL_MODEL"] = "gpt-3.5-turbo-env" + os.environ["OPENAI_SMALL_MODEL_MAX_TOKENS"] = "350" + + bot = ChatGPTTelegramInferenceBot() + + mock_super_init.assert_called_once_with( + client=None, # ChatGPT bot will let superclass create it + api_key="test_key_chatgpt", # Passed to super + base_url=None, + api_version=None, + azure_deployment=None, + model_name="gpt-3.5-turbo-env", # Default small model from env + max_tokens_str="350", # Default small model tokens from env + small_model_name="gpt-3.5-turbo-env", + small_model_max_tokens_str="350", + large_model_name=os.environ.get("OPENAI_LARGE_MODEL", "gpt-4-turbo-preview"), # Default large + large_model_max_tokens_str=os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS"), + system_prompt_content=None, + system_prompt_path=None, + is_gemini=False, + max_history_length=20 # Default from OpenAICompatibleInferenceBot + ) + + @patch.object(OpenAICompatibleInferenceBot, '__init__') + def test_init_with_arguments(self, mock_super_init): + mock_client_arg = MagicMock() + bot = ChatGPTTelegramInferenceBot( + openai_client=mock_client_arg, + api_key="arg_key", + small_model_name="arg_small_model", + small_model_max_tokens="123", + large_model_name="arg_large_model", + large_model_max_tokens="456", + system_prompt_content="Arg prompt" + ) + mock_super_init.assert_called_once_with( + client=mock_client_arg, + api_key="arg_key", + base_url=None, + api_version=None, + azure_deployment=None, + model_name="arg_small_model", # Initially configured with small model + max_tokens_str="123", + small_model_name="arg_small_model", + small_model_max_tokens_str="123", + large_model_name="arg_large_model", + large_model_max_tokens_str="456", + system_prompt_content="Arg prompt", + system_prompt_path=None, + is_gemini=False, + max_history_length=20 + ) + + # Test switch_model - this method is part of ChatGPTTelegramInferenceBot + # It calls _configure_model_and_tokens which is in the superclass. + # We need a bot instance where _configure_model_and_tokens can be called. + @patch('openai.OpenAI') # To allow instantiation of the bot by mocking client creation + async def test_switch_model_logic(self, mock_openai_constructor): + mock_openai_constructor.return_value = self.mock_openai_client # Mock client creation in super + + # Set env vars for model names that switch_model will use as fallback + os.environ["OPENAI_SMALL_MODEL"] = "env-small-gpt" + os.environ["OPENAI_SMALL_MODEL_MAX_TOKENS"] = "100" + os.environ["OPENAI_LARGE_MODEL"] = "env-large-gpt" + os.environ["OPENAI_LARGE_MODEL_MAX_TOKENS"] = "200" + + # Instantiate with initial model (small) + bot = ChatGPTTelegramInferenceBot() + self.assertEqual(bot.model, "env-small-gpt") + self.assertEqual(bot.max_tokens, 100) + + # Switch to large + status = await bot.switch_model() + self.assertEqual(bot.model, "env-large-gpt") + self.assertEqual(bot.max_tokens, 200) + self.assertEqual(status, "Switched to model: env-large-gpt") + + # Switch back to small + status = await bot.switch_model() + self.assertEqual(bot.model, "env-small-gpt") + self.assertEqual(bot.max_tokens, 100) + self.assertEqual(status, "Switched to model: env-small-gpt") + + @patch('openai.OpenAI') + async def test_switch_model_uses_instance_configs_if_provided(self, mock_openai_constructor): + mock_openai_constructor.return_value = self.mock_openai_client + + # Instantiate with specific model names, overriding potential env vars + bot = ChatGPTTelegramInferenceBot( + small_model_name="init-small", small_model_max_tokens="50", + large_model_name="init-large", large_model_max_tokens="150" + ) + self.assertEqual(bot.model, "init-small") # Starts with small + self.assertEqual(bot.max_tokens, 50) + + # Switch to large + status = await bot.switch_model() + self.assertEqual(bot.model, "init-large") + self.assertEqual(bot.max_tokens, 150) + self.assertEqual(status, "Switched to model: init-large") + + # Switch back to small + status = await bot.switch_model() + self.assertEqual(bot.model, "init-small") + self.assertEqual(bot.max_tokens, 50) + self.assertEqual(status, "Switched to model: init-small") + + # get_llm_description is inherited from OpenAICompatibleInferenceBot. + # Test just to ensure it works in the context of a ChatGPTBot instance + @patch('openai.OpenAI') + def test_get_llm_description_for_chatgpt_bot(self, mock_openai_constructor): + mock_openai_constructor.return_value = self.mock_openai_client + bot = ChatGPTTelegramInferenceBot(small_model_name="gpt-3.5-desc", small_model_max_tokens="777") + # Initially configured with small model + self.assertEqual(bot.get_llm_description(), "LLM: gpt-3.5-desc, Max Tokens: 777, Azure: False") + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_gemini_telegram_inference_bot.py b/tests/test_gemini_telegram_inference_bot.py new file mode 100644 index 0000000..8e5cc4f --- /dev/null +++ b/tests/test_gemini_telegram_inference_bot.py @@ -0,0 +1,154 @@ +import unittest +from unittest.mock import MagicMock, patch, ANY +import os + +# Assuming gemini_telegram_inference_bot.py and its parent are accessible +from gemini_telegram_inference_bot import GeminiTelegramInferenceBot +from openai_compatible_inference_bot import OpenAICompatibleInferenceBot # For patching super + +class TestGeminiTelegramInferenceBot(unittest.IsolatedAsyncioTestCase): + + def setUp(self): + # Store and clear relevant environment variables + self.original_gemini_key = os.environ.get("GEMINI_API_KEY") + self.original_gemini_base_url = os.environ.get("GEMINI_API_BASE_URL") + self.original_small_model = os.environ.get("GEMINI_SMALL_MODEL") + self.original_large_model = os.environ.get("GEMINI_LARGE_MODEL") + self.original_small_tokens = os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS") + self.original_large_tokens = os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS") + self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH") + + for key in ["GEMINI_API_KEY", "GEMINI_API_BASE_URL", "GEMINI_SMALL_MODEL", "GEMINI_LARGE_MODEL", + "GEMINI_SMALL_MODEL_MAX_TOKENS", "GEMINI_LARGE_MODEL_MAX_TOKENS", "SYSTEM_PROMPT_PATH"]: + if os.environ.get(key): + del os.environ[key] + + self.mock_openai_client = MagicMock() # Used if superclass creates an OpenAI client + + def tearDown(self): + # Restore environment variables + if self.original_gemini_key: os.environ["GEMINI_API_KEY"] = self.original_gemini_key + if self.original_gemini_base_url: os.environ["GEMINI_API_BASE_URL"] = self.original_gemini_base_url + if self.original_small_model: os.environ["GEMINI_SMALL_MODEL"] = self.original_small_model + if self.original_large_model: os.environ["GEMINI_LARGE_MODEL"] = self.original_large_model + if self.original_small_tokens: os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = self.original_small_tokens + if self.original_large_tokens: os.environ["GEMINI_LARGE_MODEL_MAX_TOKENS"] = self.original_large_tokens + if self.original_system_prompt_path: os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path + + @patch.object(OpenAICompatibleInferenceBot, '__init__') # Mock the superclass's __init__ + def test_init_defaults_and_super_call(self, mock_super_init): + os.environ["GEMINI_API_KEY"] = "test_key_gemini" + os.environ["GEMINI_API_BASE_URL"] = "https://gemini.env.com" + os.environ["GEMINI_SMALL_MODEL"] = "gemini-pro-env" + os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = "360" + + bot = GeminiTelegramInferenceBot() + + mock_super_init.assert_called_once_with( + client=None, + api_key="test_key_gemini", + base_url="https://gemini.env.com", # Passed to super + api_version=None, + azure_deployment=None, + model_name="gemini-pro-env", + max_tokens_str="360", + small_model_name="gemini-pro-env", + small_model_max_tokens_str="360", + large_model_name=os.environ.get("GEMINI_LARGE_MODEL", "gemini-1.5-pro-latest"), # Default large + large_model_max_tokens_str=os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS"), + system_prompt_content=None, + system_prompt_path=None, + is_gemini=True, # Important for Gemini bot + max_history_length=20 + ) + + @patch.object(OpenAICompatibleInferenceBot, '__init__') + def test_init_with_arguments(self, mock_super_init): + mock_client_arg = MagicMock() + bot = GeminiTelegramInferenceBot( + openai_client=mock_client_arg, # Name in Gemini bot is openai_client for consistency + api_key="arg_gem_key", + base_url="https://arg.gemini.com", + small_model_name="arg_gem_small", + small_model_max_tokens="124", + large_model_name="arg_gem_large", + large_model_max_tokens="457", + system_prompt_content="Gemini prompt" + ) + mock_super_init.assert_called_once_with( + client=mock_client_arg, + api_key="arg_gem_key", + base_url="https://arg.gemini.com", + api_version=None, + azure_deployment=None, + model_name="arg_gem_small", + max_tokens_str="124", + small_model_name="arg_gem_small", + small_model_max_tokens_str="124", + large_model_name="arg_gem_large", + large_model_max_tokens_str="457", + system_prompt_content="Gemini prompt", + system_prompt_path=None, + is_gemini=True, + max_history_length=20 + ) + + @patch('openai.OpenAI') # Gemini bot uses OpenAI client configured for Gemini endpoint + async def test_switch_model_logic(self, mock_openai_constructor): + mock_openai_constructor.return_value = self.mock_openai_client + + os.environ["GEMINI_SMALL_MODEL"] = "env-gemini-small" + os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = "110" + os.environ["GEMINI_LARGE_MODEL"] = "env-gemini-large" + os.environ["GEMINI_LARGE_MODEL_MAX_TOKENS"] = "220" + + bot = GeminiTelegramInferenceBot() # Uses env vars by default + self.assertEqual(bot.model, "env-gemini-small") + self.assertEqual(bot.max_tokens, 110) + + status = await bot.switch_model() + self.assertEqual(bot.model, "env-gemini-large") + self.assertEqual(bot.max_tokens, 220) + self.assertEqual(status, "Switched to model: env-gemini-large") + + status = await bot.switch_model() + self.assertEqual(bot.model, "env-gemini-small") + self.assertEqual(bot.max_tokens, 110) + self.assertEqual(status, "Switched to model: env-gemini-small") + + @patch('openai.OpenAI') + async def test_switch_model_uses_instance_configs_if_provided(self, mock_openai_constructor): + mock_openai_constructor.return_value = self.mock_openai_client + + bot = GeminiTelegramInferenceBot( + small_model_name="init-gem-small", small_model_max_tokens="55", + large_model_name="init-gem-large", large_model_max_tokens="155" + ) + self.assertEqual(bot.model, "init-gem-small") + self.assertEqual(bot.max_tokens, 55) + + status = await bot.switch_model() + self.assertEqual(bot.model, "init-gem-large") + self.assertEqual(bot.max_tokens, 155) + self.assertEqual(status, "Switched to model: init-gem-large") + + status = await bot.switch_model() + self.assertEqual(bot.model, "init-gem-small") + self.assertEqual(bot.max_tokens, 55) + self.assertEqual(status, "Switched to model: init-gem-small") + + @patch('openai.OpenAI') + def test_get_llm_description_for_gemini_bot(self, mock_openai_constructor): + mock_openai_constructor.return_value = self.mock_openai_client + bot = GeminiTelegramInferenceBot( + small_model_name="gemini-pro-desc", + small_model_max_tokens="888", + # is_gemini is True by default in constructor call to super + ) + # LLM description should indicate not Azure, even though it uses OpenAICompatible... base + # The is_gemini flag primarily affects client instantiation logic in the superclass. + # The azure_openai flag in superclass is based on azure_endpoint presence. + self.assertEqual(bot.get_llm_description(), "LLM: gemini-pro-desc, Max Tokens: 888, Azure: False") + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_openai_compatible_inference_bot.py b/tests/test_openai_compatible_inference_bot.py new file mode 100644 index 0000000..dc667c0 --- /dev/null +++ b/tests/test_openai_compatible_inference_bot.py @@ -0,0 +1,332 @@ +import unittest +from unittest.mock import MagicMock, patch, AsyncMock, ANY +import os +import json + +# Assuming openai_compatible_inference_bot.py is in the parent directory or PYTHONPATH is set +from openai_compatible_inference_bot import OpenAICompatibleInferenceBot + +# Mock response from OpenAI client's chat.completions.create +def create_mock_openai_response(content=None, tool_calls=None): + mock_message = MagicMock() + mock_message.role = "assistant" + mock_message.content = content + if tool_calls: + # tool_calls should be a list of objects with id and function (name, arguments) + mock_tool_calls = [] + for tc in tool_calls: + mock_tc = MagicMock() + mock_tc.id = tc["id"] + mock_tc.function.name = tc["function"]["name"] + mock_tc.function.arguments = tc["function"]["arguments"] + mock_tool_calls.append(mock_tc) + mock_message.tool_calls = mock_tool_calls + else: + mock_message.tool_calls = None + + mock_choice = MagicMock() + mock_choice.message = mock_message + + mock_response = MagicMock() + mock_response.choices = [mock_choice] + return mock_response + +# Concrete class for testing +class ConcreteOpenAICompatibleBot(OpenAICompatibleInferenceBot): + # Implement abstract methods for instantiation + async def switch_model(self): + # Simple switch for testing if needed, or just pass + if self.model == self.small_model_name: + self._configure_model_and_tokens(self.large_model_name, self.large_model_max_tokens_str) + else: + self._configure_model_and_tokens(self.small_model_name, self.small_model_max_tokens_str) + return f"Switched to {self.model}" + + # Override load_functions if it's called by parent and needs mocking for these tests + # (OpenAICompatibleInferenceBot's __init__ calls BaseTelegramInferenceBot's __init__, which calls load_functions) + def load_functions(self): + # For these tests, assume no tools unless specifically added + self.tools = [] + self.functions = [] + return self.tools, self.functions + + +class TestOpenAICompatibleInferenceBot(unittest.IsolatedAsyncioTestCase): + + def setUp(self): + self.original_openai_api_key = os.environ.get("OPENAI_API_KEY") + self.original_azure_openai_key = os.environ.get("AZURE_OPENAI_KEY") + self.original_azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT") + self.original_api_version = os.environ.get("AZURE_OPENAI_API_VERSION") + self.original_azure_deployment = os.environ.get("AZURE_DEPLOYMENT_NAME") + + # Clear relevant env vars before each test + for key in ["OPENAI_API_KEY", "AZURE_OPENAI_KEY", "AZURE_OPENAI_ENDPOINT", + "AZURE_OPENAI_API_VERSION", "AZURE_DEPLOYMENT_NAME", "SYSTEM_PROMPT_PATH"]: + if os.environ.get(key): + del os.environ[key] + + self.mock_openai_client_instance = MagicMock() + self.mock_openai_client_instance.chat.completions.create = MagicMock() + + def tearDown(self): + # Restore environment variables + if self.original_openai_api_key: os.environ["OPENAI_API_KEY"] = self.original_openai_api_key + if self.original_azure_openai_key: os.environ["AZURE_OPENAI_KEY"] = self.original_azure_openai_key + if self.original_azure_endpoint: os.environ["AZURE_OPENAI_ENDPOINT"] = self.original_azure_endpoint + if self.original_api_version: os.environ["AZURE_OPENAI_API_VERSION"] = self.original_api_version + if self.original_azure_deployment: os.environ["AZURE_DEPLOYMENT_NAME"] = self.original_azure_deployment + + + @patch('openai.OpenAI') + def test_init_with_openai_defaults(self, MockOpenAIConstructor): + MockOpenAIConstructor.return_value = self.mock_openai_client_instance + os.environ["OPENAI_API_KEY"] = "test_openai_key" + + bot = ConcreteOpenAICompatibleBot(model_name="gpt-4") + + MockOpenAIConstructor.assert_called_once_with(api_key="test_openai_key", base_url=None) + self.assertEqual(bot.client, self.mock_openai_client_instance) + self.assertEqual(bot.model, "gpt-4") + self.assertEqual(bot.max_tokens, 1000) # Default from _configure_model_and_tokens + self.assertEqual(bot.azure_openai, False) + + @patch('openai.OpenAI') + def test_init_with_provided_client(self, MockOpenAIConstructor): + preconfigured_client = MagicMock() + bot = ConcreteOpenAICompatibleBot(client=preconfigured_client, model_name="gpt-3.5") + + MockOpenAIConstructor.assert_not_called() + self.assertEqual(bot.client, preconfigured_client) + self.assertEqual(bot.model, "gpt-3.5") + + @patch('openai.AzureOpenAI') + def test_init_with_azure_config_args(self, MockAzureOpenAIConstructor): + MockAzureOpenAIConstructor.return_value = self.mock_openai_client_instance + + bot = ConcreteOpenAICompatibleBot( + api_key="azure_key", + azure_endpoint="https://myenv.openai.azure.com", + api_version="2023-05-15", + azure_deployment="my-gpt-4", # This should be used as model_name for API call + model_name="should_be_overridden_by_azure_deployment_for_api" + # model_name is passed to _configure_model_and_tokens, which sets self.model for display/logging + # but for Azure, the client needs the deployment name. + ) + + MockAzureOpenAIConstructor.assert_called_once_with( + api_key="azure_key", + azure_endpoint="https://myenv.openai.azure.com", + api_version="2023-05-15" + ) + self.assertEqual(bot.client, self.mock_openai_client_instance) + self.assertEqual(bot.model, "my-gpt-4") # Azure deployment name becomes the model for API calls + self.assertEqual(bot.azure_openai, True) + + + @patch('openai.AzureOpenAI') + def test_init_with_azure_env_vars(self, MockAzureOpenAIConstructor): + MockAzureOpenAIConstructor.return_value = self.mock_openai_client_instance + os.environ["AZURE_OPENAI_KEY"] = "env_azure_key" + os.environ["AZURE_OPENAI_ENDPOINT"] = "https://env.openai.azure.com" + os.environ["AZURE_OPENAI_API_VERSION"] = "2023-06-01" + os.environ["AZURE_DEPLOYMENT_NAME"] = "env-gpt-35" # Used as model_name + + bot = ConcreteOpenAICompatibleBot(model_name="ignored_if_azure_deployment_env_is_set") + + MockAzureOpenAIConstructor.assert_called_once_with( + api_key="env_azure_key", + azure_endpoint="https://env.openai.azure.com", + api_version="2023-06-01" + ) + self.assertEqual(bot.model, "env-gpt-35") + self.assertTrue(bot.azure_openai) + + @patch('openai.OpenAI') + def test_init_with_gemini_config_args(self, MockOpenAIConstructor): + MockOpenAIConstructor.return_value = self.mock_openai_client_instance + + bot = ConcreteOpenAICompatibleBot( + api_key="gemini_key", + base_url="https://gemini.example.com", + model_name="gemini-pro", + is_gemini=True + ) + MockOpenAIConstructor.assert_called_once_with(api_key="gemini_key", base_url="https://gemini.example.com") + self.assertEqual(bot.model, "gemini-pro") + self.assertFalse(bot.azure_openai) # is_gemini doesn't mean azure_openai + + def test_configure_model_and_tokens(self): + bot = ConcreteOpenAICompatibleBot(model_name="initial_model") # init calls _configure + bot._configure_model_and_tokens("test-model", "500") + self.assertEqual(bot.model, "test-model") + self.assertEqual(bot.max_tokens, 500) + + bot._configure_model_and_tokens("test-model-2", None, default_max_tokens=150) + self.assertEqual(bot.max_tokens, 150) + + bot._configure_model_and_tokens("test-model-3", "invalid_token_val") + self.assertEqual(bot.max_tokens, 1000) # Default fallback + + def test_get_llm_description(self): + bot = ConcreteOpenAICompatibleBot(model_name="desc-model", max_tokens_str="256") + self.assertEqual(bot.get_llm_description(), "LLM: desc-model, Max Tokens: 256, Azure: False") + + bot_azure = ConcreteOpenAICompatibleBot(azure_deployment="azure-model", azure_endpoint="x", api_key="y", api_version="z") + self.assertEqual(bot_azure.get_llm_description(), "LLM: azure-model, Max Tokens: 1000, Azure: True") + + + def test_get_chat_response_success(self): + bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="test-gpt") + bot.max_tokens = 50 # Ensure this is set + mock_api_response = create_mock_openai_response(content="Hello from API") + self.mock_openai_client_instance.chat.completions.create.return_value = mock_api_response + + messages = [{"role": "user", "content": "Hi"}] + response = bot.get_chat_response(messages) + + self.mock_openai_client_instance.chat.completions.create.assert_called_once_with( + model="test-gpt", + messages=messages, + tools=ANY, # Assuming functions can be None or empty list + tool_choice=ANY, + max_tokens=50 + ) + self.assertEqual(response, mock_api_response) + + def test_get_chat_response_api_error(self): + bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="error-gpt") + self.mock_openai_client_instance.chat.completions.create.side_effect = Exception("API Down") + + with self.assertRaisesRegex(Exception, "API Down"): + bot.get_chat_response([{"role": "user", "content": "trigger"}]) + + async def test_handle_message_simple_response(self): + bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="chatty") + bot.system_prompt = "You are a test bot." # Set directly for simplicity + mock_api_response = create_mock_openai_response(content="Test reply") + self.mock_openai_client_instance.chat.completions.create.return_value = mock_api_response + + response_content = await bot.handle_message(user_id=1, user_message="Hello") + + self.assertEqual(response_content, "Test reply") + self.assertIn(1, bot.conversation_history) + self.assertEqual(len(bot.conversation_history[1]), 3) # System, User, Assistant + self.assertEqual(bot.conversation_history[1][0]["content"], "You are a test bot.") + self.assertEqual(bot.conversation_history[1][2]["content"], "Test reply") + + async def test_handle_message_with_tool_call_and_response(self): + bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="tool-user") + + # Mock functions/tools setup on the bot + mock_tool_def = {"function": {"name": "get_weather", "description": "Gets weather", "parameters": {}}} + bot.functions = [mock_tool_def] # Simulate tools are loaded + + # API response 1: Request to call tool + tool_call_request = [{"id": "call123", "function": {"name": "get_weather", "arguments": '''{"location": "moon"}'''}}] + api_response_1 = create_mock_openai_response(tool_calls=tool_call_request) + + # API response 2: Final answer after tool execution + api_response_2 = create_mock_openai_response(content="The weather on the moon is chilly.") + + self.mock_openai_client_instance.chat.completions.create.side_effect = [api_response_1, api_response_2] + + # Mock self.call_tool + bot.call_tool = MagicMock(return_value='''{"temperature": "-100 C"}''') + + final_response = await bot.handle_message(user_id=2, user_message="Weather on moon?") + + self.assertEqual(final_response, "The weather on the moon is chilly.") + bot.call_tool.assert_called_once_with("get_weather", '''{"location": "moon"}''') + + # Check conversation history includes tool messages + history = bot.conversation_history[2] + self.assertTrue(any(msg["role"] == "assistant" and msg.tool_calls is not None for msg in history)) + self.assertTrue(any(msg["role"] == "tool" and msg["name"] == "get_weather" for msg in history)) + self.assertEqual(self.mock_openai_client_instance.chat.completions.create.call_count, 2) + + async def test_handle_message_max_history_length(self): + bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="hist-test", max_history_length=3) + self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Ok") + + await bot.handle_message(1, "Msg1") # Sys, User, Assist (3) + self.assertEqual(len(bot.conversation_history[1]), 3) + + await bot.handle_message(1, "Msg2") # User, Assist. Should be 3 (prev User, prev Assist, new User) -> then adds new Assist. + # Before new call: [Sys, U1, A1]. New U2. Call with [Sys,U1,A1,U2]. Resp A2. + # History: [Sys,U1,A1,U2,A2]. Limit 3. -> [A1,U2,A2] (if system is not preserved specially) + # The current code appends to history then truncates if over limit. + # So after Msg1: [S, U1, A1]. len=3. + # For Msg2: History is [S, U1, A1]. Append U2. Call with [S,U1,A1,U2]. Append A2. + # History now [S,U1,A1,U2,A2]. len=5. Truncate to 3. + # Expected: [A1, U2, A2] or [U1,A1,U2] or [U2,A2,S] depending on how system prompt is handled in truncation. + # The code is: self.conversation_history[user_id][-self.max_history_length:] + # And system prompt is only added IF user_id not in self.conversation_history. + # So, for Msg2, system prompt is not re-added. + # History before Msg2 call: [S, U1, A1] + # Messages for Msg2 call: [S, U1, A1, U2] + # History after Msg2 response A2: [S, U1, A1, U2, A2]. Len 5. + # Truncated to self.max_history_length=3: [A1, U2, A2] + + # Call 1 + self.mock_openai_client_instance.chat.completions.create.reset_mock() + self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply1") + await bot.handle_message(user_id=7, user_message="First message") + self.assertEqual(len(bot.conversation_history[7]), 3) # System, User1, Assistant1 + + # Call 2 + self.mock_openai_client_instance.chat.completions.create.reset_mock() + self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply2") + await bot.handle_message(user_id=7, user_message="Second message") + # History before call: [S, U1, A1]. Messages for call: [S, U1, A1, U2]. History after: [S, U1, A1, U2, A2]. + # Truncated to 3: [A1, U2, A2] + self.assertEqual(len(bot.conversation_history[7]), 3) + self.assertEqual(bot.conversation_history[7][0]["content"], "Reply1") # A1 + self.assertEqual(bot.conversation_history[7][1]["content"], "Second message") # U2 + self.assertEqual(bot.conversation_history[7][2]["content"], "Reply2") # A2 + + # Call 3 + self.mock_openai_client_instance.chat.completions.create.reset_mock() + self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply3") + await bot.handle_message(user_id=7, user_message="Third message") + # History before call: [A1, U2, A2]. Messages for call: [A1, U2, A2, U3]. History after: [A1, U2, A2, U3, A3]. + # Truncated to 3: [A2, U3, A3] + self.assertEqual(len(bot.conversation_history[7]), 3) + self.assertEqual(bot.conversation_history[7][0]["content"], "Reply2") # A2 + self.assertEqual(bot.conversation_history[7][1]["content"], "Third message") # U3 + self.assertEqual(bot.conversation_history[7][2]["content"], "Reply3") # A3 + + + async def test_abort_processing(self): + bot = ConcreteOpenAICompatibleBot(model_name="test") + user_id = 123 + bot.processing_status[user_id] = {"processing": True, "message_id": 456} + bot.conversation_history[user_id] = [{"role": "user", "content": "stuff"}] + + with patch.object(bot, 'clear_conversation_history') as mock_clear_hist: # Patching the method from Base class + result = await bot.abort_processing(user_id) + + self.assertEqual(result, "Processing aborted and conversation cleared.") + self.assertFalse(bot.processing_status[user_id]["processing"]) + mock_clear_hist.assert_called_once_with(user_id) + + async def test_abort_processing_no_active_processing(self): + bot = ConcreteOpenAICompatibleBot(model_name="test") + user_id = 404 # Not in processing_status + with patch.object(bot, 'clear_conversation_history') as mock_clear_hist: + result = await bot.abort_processing(user_id) + self.assertEqual(result, "No active processing found to abort. Conversation cleared.") + mock_clear_hist.assert_called_once_with(user_id) + + # Test for the abstract switch_model (basic call, actual logic in concrete class for this test) + async def test_switch_model_concrete_implementation(self): + bot = ConcreteOpenAICompatibleBot(model_name="model1", small_model_name="model1", large_model_name="model2", max_tokens_str="100") + self.assertEqual(bot.model, "model1") + await bot.switch_model() # Calls the concrete implementation + self.assertEqual(bot.model, "model2") + await bot.switch_model() + self.assertEqual(bot.model, "model1") + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_telegram_helper.py b/tests/test_telegram_helper.py new file mode 100644 index 0000000..6b8d655 --- /dev/null +++ b/tests/test_telegram_helper.py @@ -0,0 +1,356 @@ +import unittest +from unittest.mock import MagicMock, patch, mock_open, AsyncMock +import asyncio +import os +import sys + +# Assuming telegram_helper.py is in the parent directory or PYTHONPATH is set +from telegram_helper import TelegramHelper, MessageHandlerLogicResult + +# Mock for the bot passed to TelegramHelper +class MockBot: + def __init__(self): + self.start = AsyncMock() + self.clear_conversation_history = MagicMock() + self.get_bot_status = AsyncMock(return_value="Bot Status OK") + self.switch_model = AsyncMock(return_value="Model Switched OK") + self.handle_message = AsyncMock() # Needs to return a string + self.abort_processing = AsyncMock(return_value="Abort OK") + self.set_processing_status = MagicMock() + self.clear_processing_status = MagicMock() + self.processing_status = {} # Add the attribute + +# Mock for telegram.Update and related objects +def create_mock_update(message_text=None, user_id=123, chat_id=456, message_id=789, callback_query_data=None): + update = MagicMock() + update.effective_user.id = user_id + update.effective_chat.id = chat_id + + if message_text: + update.message.text = message_text + update.message.reply_text = AsyncMock(return_value=MagicMock(message_id=message_id)) # reply_text returns a Message obj + + if callback_query_data: + update.callback_query.data = callback_query_data + update.callback_query.from_user.id = user_id + update.callback_query.answer = AsyncMock() + update.callback_query.edit_message_text = AsyncMock() + + return update + +# Mock for telegram.ext.ContextTypes.DEFAULT_TYPE +def create_mock_context(): + context = MagicMock() + context.bot.delete_message = AsyncMock() + context.bot.edit_message_text = AsyncMock() # For update_status_message + return context + +class TestTelegramHelper(unittest.IsolatedAsyncioTestCase): # Use IsolatedAsyncioTestCase for async methods + + def setUp(self): + self.mock_bot = MockBot() + # Default paths for reboot files, can be overridden in tests + self.reboot_claude_file = ".test_reboot_claude" + self.reboot_file = ".test_doreboot" + self.helper = TelegramHelper( + self.mock_bot, + reboot_claude_file_path=self.reboot_claude_file, + reboot_file_path=self.reboot_file, + chunk_message_sleep_duration=0.001 # Faster sleep for tests + ) + # Clean up any potential leftover reboot files from previous runs + if os.path.exists(self.reboot_claude_file): + os.remove(self.reboot_claude_file) + if os.path.exists(self.reboot_file): + os.remove(self.reboot_file) + + def tearDown(self): + # Clean up reboot files created during tests + if os.path.exists(self.reboot_claude_file): + os.remove(self.reboot_claude_file) + if os.path.exists(self.reboot_file): + os.remove(self.reboot_file) + + async def test_start_logic(self): + response = await self.helper._start_logic() + self.mock_bot.start.assert_called_once() + self.assertEqual(response, "Hello! I\'m your AI assistant. How can I help you today?") + + async def test_start_command(self): + mock_update = create_mock_update(message_text="/start") + mock_context = create_mock_context() + + with patch.object(self.helper, \'_start_logic\', new_callable=AsyncMock) as mock_logic: + mock_logic.return_value = "Start Logic Response" + await self.helper.start(mock_update, mock_context) + mock_logic.assert_called_once() + mock_update.message.reply_text.assert_called_once_with("Start Logic Response") + + async def test_clear_logic(self): + user_id = 123 + response = await self.helper._clear_logic(user_id) # _clear_logic is async after refactor + self.mock_bot.clear_conversation_history.assert_called_once_with(user_id) + self.assertEqual(response, "Conversation history cleared. Let\'s start fresh!") + + async def test_clear_command(self): + mock_update = create_mock_update(message_text="/clear", user_id=123) + mock_context = create_mock_context() + with patch.object(self.helper, \'_clear_logic\', new_callable=AsyncMock) as mock_logic: + mock_logic.return_value = "Clear Logic Response" + await self.helper.clear(mock_update, mock_context) + mock_logic.assert_called_once_with(123) + mock_update.message.reply_text.assert_called_once_with("Clear Logic Response") + + async def test_status_logic(self): + self.mock_bot.get_bot_status.return_value = "Test Status" + response = await self.helper._status_logic() + self.mock_bot.get_bot_status.assert_called_once() + self.assertEqual(response, "Test Status") + + async def test_switch_logic_supported(self): + self.mock_bot.switch_model.return_value = "Switched to Large Model" + response = await self.helper._switch_logic() + self.mock_bot.switch_model.assert_called_once() + self.assertEqual(response, "Switched to Large Model") + + async def test_switch_logic_not_supported(self): + del self.mock_bot.switch_model # Simulate bot not having the attribute + response = await self.helper._switch_logic() + self.assertEqual(response, "Model switching is not supported for this bot.") + + async def test_handle_message_logic_success(self): + user_id = 100 + user_message = "Hello bot" + bot_response = "Hello user Thinking hard Done." + expected_processed_response = f"Hello user {self.helper.HTML_QUOTE_BLOCK_START}Thinking hard{self.helper.HTML_QUOTE_BLOCK_END} Done." + self.mock_bot.handle_message.return_value = bot_response + + result = await self.helper._handle_message_logic(user_id, user_message) + + self.mock_bot.handle_message.assert_called_once_with(user_id, user_message) + self.assertTrue(result["success"]) + self.assertEqual(result["response_text"], expected_processed_response) + self.assertIsNone(result["error_message"]) + + async def test_handle_message_logic_bot_exception(self): + user_id = 101 + user_message = "Trigger error" + self.mock_bot.handle_message.side_effect = Exception("Bot Error") + + result = await self.helper._handle_message_logic(user_id, user_message) + + self.assertFalse(result["success"]) + self.assertIsNone(result["response_text"]) + self.assertEqual(result["error_message"], "Bot Error") + + @patch(\'logging.error\') + async def test_handle_message_command_success_short_message(self, mock_logging_error): + mock_update = create_mock_update(message_text="Hi", user_id=200, chat_id=201, message_id=202) + mock_context = create_mock_context() + + logic_result = MessageHandlerLogicResult(success=True, response_text="Short response", error_message=None) + + with patch.object(self.helper, \'_handle_message_logic\', new_callable=AsyncMock) as mock_message_logic: + mock_message_logic.return_value = logic_result + + await self.helper.handle_message(mock_update, mock_context) + + mock_update.message.reply_text.assert_any_call("Processing your request...", reply_markup=unittest.mock.ANY) + self.mock_bot.set_processing_status.assert_called_once_with(200, 202) # user_id, status_message_id + mock_message_logic.assert_called_once_with(200, "Hi") + mock_context.bot.delete_message.assert_called_once_with(chat_id=201, message_id=202) + self.mock_bot.clear_processing_status.assert_called_once_with(200) + mock_update.message.reply_text.assert_any_call("Short response") # Final response + self.assertEqual(mock_update.message.reply_text.call_count, 2) # Processing + final + + @patch(\'logging.error\') + async def test_handle_message_command_success_long_message_chunks(self, mock_logging_error): + mock_update = create_mock_update(message_text="Long text", user_id=200, chat_id=201, message_id=202) + mock_context = create_mock_context() + + long_response_text = "a" * 5000 # Longer than 4096 + chunk1 = long_response_text[:4096] + chunk2 = long_response_text[4096:] + + logic_result = MessageHandlerLogicResult(success=True, response_text=long_response_text, error_message=None) + + with patch.object(self.helper, \'_handle_message_logic\', new_callable=AsyncMock) as mock_message_logic, \ + patch(\'asyncio.sleep\', new_callable=AsyncMock) as mock_sleep: # Mock sleep + mock_message_logic.return_value = logic_result + + await self.helper.handle_message(mock_update, mock_context) + + mock_update.message.reply_text.assert_any_call(chunk1) + mock_update.message.reply_text.assert_any_call(chunk2) + mock_sleep.assert_called_once_with(self.helper.chunk_message_sleep_duration) + self.assertEqual(mock_update.message.reply_text.call_count, 3) # Processing + 2 chunks + + @patch(\'logging.error\') + async def test_handle_message_command_logic_fails(self, mock_logging_error): + mock_update = create_mock_update(message_text="Cause error in logic", user_id=200) + mock_context = create_mock_context() + logic_result = MessageHandlerLogicResult(success=False, response_text=None, error_message="Logic Failed") + + with patch.object(self.helper, \'_handle_message_logic\', new_callable=AsyncMock) as mock_message_logic: + mock_message_logic.return_value = logic_result + await self.helper.handle_message(mock_update, mock_context) + mock_update.message.reply_text.assert_any_call("Sorry, an error occurred while processing your request.") + self.assertEqual(mock_update.message.reply_text.call_count, 2) # Processing + error message + + @patch(\'logging.error\') + async def test_handle_message_command_telegram_exception_after_logic(self, mock_logging_error): + mock_update = create_mock_update(message_text="Test", user_id=200) + mock_context = create_mock_context() + logic_result = MessageHandlerLogicResult(success=True, response_text="OK", error_message=None) + + # Make sending the final reply fail + mock_update.message.reply_text.side_effect = [ + MagicMock(message_id=202), # For "Processing..." + Exception("Telegram API Error") # For the actual response + ] + + with patch.object(self.helper, \'_handle_message_logic\', new_callable=AsyncMock) as mock_message_logic: + mock_message_logic.return_value = logic_result + await self.helper.handle_message(mock_update, mock_context) + + # Check if the generic error message was attempted + # This is tricky because reply_text is already mocked with side_effect. + # We\'d expect logs. Let\'s check logs or if processing status was cleared. + self.mock_bot.clear_processing_status.assert_called_once_with(200) + mock_logging_error.assert_any_call(unittest.mock.string_containing("Outer error in handle_message")) + + + async def test_abort_processing_logic(self): + user_id = 300 + self.mock_bot.abort_processing.return_value = "Aborted by bot" + response = await self.helper._abort_processing_logic(user_id) + self.mock_bot.abort_processing.assert_called_once_with(user_id) + self.assertEqual(response, "Aborted by bot") + + async def test_abort_processing_command(self): + mock_update = create_mock_update(callback_query_data=\'abort\', user_id=301) + mock_context = create_mock_context() + with patch.object(self.helper, \'_abort_processing_logic\', new_callable=AsyncMock) as mock_logic: + mock_logic.return_value = "Abort Logic Done" + await self.helper.abort_processing(mock_update, mock_context) + + mock_update.callback_query.answer.assert_called_once() + mock_logic.assert_called_once_with(301) + mock_update.callback_query.edit_message_text.assert_called_once_with(text="Abort Logic Done") + + def test_reboot_logic_claude_and_main(self): + user_message_parts = ["/reboot", "claude"] + chat_id_to_write = "12345" + + with patch("builtins.open", mock_open()) as mock_file: + self.helper._reboot_logic(user_message_parts, chat_id_to_write) + + # Check claude reboot file + mock_file.assert_any_call(self.reboot_claude_file, \'w\') + # Check main doreboot file + mock_file.assert_any_call(self.reboot_file, \'w\') + handle_claude = mock_file.return_value + handle_main = mock_file.return_value # mock_open reuses the handle for multiple calls + + # Check if write was called for claude file (empty write) + # This part of assertion is tricky with single mock_file. Better to use different mocks if possible + # or check the sequence of calls if the mock supports it well. + # For now, assert_any_call ensures it was opened. + + # Check content for main reboot file + # Need to ensure the write for self.reboot_file had chat_id_to_write + # This requires more sophisticated mock_open or patching os.path.exists and multiple open calls + # Simpler check: was open(self.reboot_file, \'w\') called? Yes, via assert_any_call. + # And was open(self.reboot_claude_file, \'w\') called? Yes. + + # Verify files were created (mock_open doesn\'t actually create them) + # This test relies on mock_open\'s behavior. To test file content, need more setup. + # For now, assume open was called correctly. + + def test_reboot_logic_main_only(self): + user_message_parts = ["/reboot"] + chat_id_to_write = "67890" + with patch("builtins.open", mock_open()) as mock_file: + self.helper._reboot_logic(user_message_parts, chat_id_to_write) + # Ensure claude file was NOT opened for writing. + # This requires asserting that a specific call didn\'t happen, or checking call_args_list + claude_call = unittest.mock.call(self.reboot_claude_file, \'w\') + self.assertNotIn(claude_call, mock_file.call_args_list) + + mock_file.assert_any_call(self.reboot_file, \'w\') + + @patch(\'sys.exit\') # Mock sys.exit to prevent test runner from exiting + async def test_reboot_command(self, mock_sys_exit): + mock_update = create_mock_update(message_text="/reboot claude", chat_id="chat1") + mock_context = create_mock_context() + + with patch.object(self.helper, \'_reboot_logic\') as mock_reboot_file_logic: + await self.helper.reboot(mock_update, mock_context) + + mock_reboot_file_logic.assert_called_once_with(["/reboot", "claude"], "chat1") + mock_update.message.reply_text.assert_called_once_with("Rebooting the bot...") + mock_sys_exit.assert_called_once_with(0) + + @patch(\'os.path.exists\') + @patch(\'builtins.open\', new_callable=mock_open) + @patch(\'os.remove\') + async def test_check_doreboot_file_logic_file_exists(self, mock_os_remove, mock_file_open, mock_os_path_exists): + mock_os_path_exists.return_value = True + mock_file_open.return_value.read.return_value.strip.return_value = "chat123" + + chat_id = await self.helper._check_doreboot_file_logic() + + mock_os_path_exists.assert_called_once_with(self.reboot_file) + mock_file_open.assert_called_once_with(self.reboot_file, \'r\') + mock_os_remove.assert_called_once_with(self.reboot_file) + self.assertEqual(chat_id, "chat123") + + @patch(\'os.path.exists\', return_value=False) + async def test_check_doreboot_file_logic_file_not_exists(self, mock_os_path_exists): + chat_id = await self.helper._check_doreboot_file_logic() + mock_os_path_exists.assert_called_once_with(self.reboot_file) + self.assertIsNone(chat_id) + + @patch(\'logging.error\') + @patch(\'os.path.exists\', return_value=True) + @patch(\'builtins.open\', side_effect=IOError("Read error")) + @patch(\'os.remove\') # To check if remove is called even on read error + async def test_check_doreboot_file_logic_read_error(self, mock_os_remove, mock_file_open, mock_os_path_exists, mock_log_error): + chat_id = await self.helper._check_doreboot_file_logic() + + self.assertIsNone(chat_id) + mock_log_error.assert_any_call(unittest.mock.string_containing("Error reading reboot file")) + # Check if os.remove was attempted even after read error + mock_os_remove.assert_called_once_with(self.reboot_file) + + + async def test_check_doreboot_file_command_sends_message(self): + mock_application = MagicMock() + mock_application.bot.send_message = AsyncMock() + + with patch.object(self.helper, \'_check_doreboot_file_logic\', new_callable=AsyncMock) as mock_logic: + mock_logic.return_value = "chat789" # Simulate chat_id found + await self.helper.check_doreboot_file(mock_application) + + mock_logic.assert_called_once() + mock_application.bot.send_message.assert_called_once_with( + chat_id="chat789", text="The application has finished initializing." + ) + + async def test_check_doreboot_file_command_no_chat_id(self): + mock_application = MagicMock() + mock_application.bot.send_message = AsyncMock() + + with patch.object(self.helper, \'_check_doreboot_file_logic\', new_callable=AsyncMock) as mock_logic: + mock_logic.return_value = None # Simulate no chat_id found + await self.helper.check_doreboot_file(mock_application) + + mock_logic.assert_called_once() + mock_application.bot.send_message.assert_not_called() + + # Note: Testing the run() method itself is more of an integration test, + # as it involves setting up the full Application and polling loop. + # Unit tests here focus on the helper\'s own logic methods. + +if __name__ == \'__main__\': + unittest.main() diff --git a/tests/tools/test_github_tool.py b/tests/tools/test_github_tool.py new file mode 100644 index 0000000..6c10b91 --- /dev/null +++ b/tests/tools/test_github_tool.py @@ -0,0 +1,307 @@ +import unittest +from unittest.mock import MagicMock, patch +import os +import base64 +import logging +import requests # Required for spec in MagicMock + +# Ensure tools/github_tool.py is accessible +from tools.github_tool import GitHubTool + +# Helper to create a mock response for requests.Session +def create_mock_response(status_code, json_data=None, text_data=None, headers=None, links=None): + mock_resp = MagicMock() + mock_resp.status_code = status_code + if json_data is not None: + mock_resp.json = MagicMock(return_value=json_data) + mock_resp.text = text_data if text_data is not None else str(json_data) + mock_resp.headers = headers if headers else {} + mock_resp.links = links if links else {} # For pagination in _list_branches + return mock_resp + +class TestGitHubTool(unittest.TestCase): + + def setUp(self): + self.mock_session = MagicMock(spec=requests.Session) + self.mock_session.headers = {} # Simulate a new session's headers + + self.test_token = "test_github_token" + self.test_repo = "owner/repo" + self.test_base_url = "https://api.example.com" # Use a non-default base_url for some tests + + # Suppress logging output during tests unless explicitly testing for it + self.logger = logging.getLogger('tools.github_tool') + # Ensure only one NullHandler to prevent duplicate messages if tests run multiple times in a session + if not any(isinstance(h, logging.NullHandler) for h in self.logger.handlers): + self.logger.addHandler(logging.NullHandler()) + self.logger.propagate = False # Prevent propagation to root logger if it has handlers + + def test_init_with_args_and_session(self): + tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, base_url=self.test_base_url, logger=self.logger) + self.assertEqual(tool.session, self.mock_session) + self.assertEqual(tool._token, self.test_token) + self.assertEqual(tool._repo, self.test_repo) + self.assertEqual(tool.base_url, self.test_base_url) + self.assertEqual(tool.current_branch, "main") # Default initial branch + + @patch('requests.Session') + def test_init_creates_session_if_not_provided(self, MockSessionConstructor): + mock_created_session = MagicMock(spec=requests.Session) + mock_created_session.headers = {} + MockSessionConstructor.return_value = mock_created_session + + # Temporarily set env vars for this test + original_token = os.environ.get("GITHUB_TOKEN") + original_repo = os.environ.get("GITHUB_REPOSITORY") + os.environ["GITHUB_TOKEN"] = "env_token" + os.environ["GITHUB_REPOSITORY"] = "env/repo" + + tool = GitHubTool(logger=self.logger) # Use env vars + + MockSessionConstructor.assert_called_once() + self.assertEqual(tool.session, mock_created_session) + self.assertEqual(tool._token, "env_token") + self.assertEqual(tool._repo, "env/repo") + self.assertIn("Authorization", mock_created_session.headers) + self.assertEqual(mock_created_session.headers["Authorization"], "token env_token") + + # Restore original env vars + if original_token is None: del os.environ["GITHUB_TOKEN"] + else: os.environ["GITHUB_TOKEN"] = original_token + if original_repo is None: del os.environ["GITHUB_REPOSITORY"] + else: os.environ["GITHUB_REPOSITORY"] = original_repo + + def test_init_raises_value_error_if_no_token(self): + original_token = os.environ.pop("GITHUB_TOKEN", None) + with self.assertRaisesRegex(ValueError, "GitHub token must be provided"): + GitHubTool(repo=self.test_repo, logger=self.logger) + if original_token: os.environ["GITHUB_TOKEN"] = original_token + + def test_init_raises_value_error_if_no_repo(self): + original_repo = os.environ.pop("GITHUB_REPOSITORY", None) + with self.assertRaisesRegex(ValueError, "GitHub repository.*must be provided"): + GitHubTool(token=self.test_token, logger=self.logger) + if original_repo: os.environ["GITHUB_REPOSITORY"] = original_repo + + def test_clear_resets_branch(self): + tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, initial_branch="feature-branch", logger=self.logger) + # Mock _get_branch_sha for _set_current_branch called by clear + with patch.object(tool, '_get_branch_sha', return_value="sha_for_main"): + tool.clear() + self.assertEqual(tool.current_branch, "main") + + def test_get_functions_returns_list(self): + tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) + functions = tool.get_functions() + self.assertIsInstance(functions, list) + self.assertTrue(len(functions) > 0) + self.assertIn("name", functions[0]["function"]) + + + # --- Test individual private methods --- + + def test_read_file_success(self): + tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) + file_content = "Hello World!" + encoded_content = base64.b64encode(file_content.encode('utf-8')).decode('utf-8') + self.mock_session.get.return_value = create_mock_response(200, json_data={"content": encoded_content}) + + result = tool._read_file(path="test.txt") + self.assertEqual(result, file_content) + self.mock_session.get.assert_called_once_with( + f"{tool.base_url}/repos/{self.test_repo}/contents/test.txt", + params={"ref": "main"} + ) + + def test_read_file_error(self): + tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) + self.mock_session.get.return_value = create_mock_response(404, text_data="Not Found") + result = tool._read_file(path="nonexistent.txt") + self.assertIn("Error reading file", result) + + def test_create_branch_success(self): + tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) + # Mock getting base branch SHA + self.mock_session.get.return_value = create_mock_response(200, json_data={"object": {"sha": "base_sha123"}}) + # Mock creating new branch + self.mock_session.post.return_value = create_mock_response(201, json_data={"ref": "refs/heads/new-feature"}) + + result = tool._create_branch(branch_name="new-feature", base_branch="main") + self.assertIn("Branch 'new-feature' created successfully", result) + self.assertEqual(tool.current_branch, "new-feature") + self.mock_session.get.assert_called_once() # For base branch SHA + self.mock_session.post.assert_called_once() # For creating branch + + def test_create_branch_base_sha_error(self): + tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) + self.mock_session.get.return_value = create_mock_response(404, text_data="Base branch not found") + result = tool._create_branch(branch_name="new-feature", base_branch="nonexistent-base") + self.assertIn("Error getting base branch SHA", result) + + def test_create_branch_creation_error(self): + tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) + self.mock_session.get.return_value = create_mock_response(200, json_data={"object": {"sha": "base_sha456"}}) + self.mock_session.post.return_value = create_mock_response(422, text_data="Validation failed") + result = tool._create_branch(branch_name="bad-branch", base_branch="main") + self.assertIn("Error creating branch", result) + + def test_commit_file_success_new_file(self): + tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) + tool.current_branch = "dev-branch" # Cannot commit to main by default + + # Mock GET for checking file existence (404 means new file) + self.mock_session.get.return_value = create_mock_response(404) + # Mock PUT for committing file + self.mock_session.put.return_value = create_mock_response(201, json_data={"commit": {"sha": "commit_sha_abc"}}) + + result = tool._commit_file(file_path="new_file.py", content="print('Hello')", commit_message="Add new_file.py") + self.assertIn("committed successfully", result) + self.assertIn("commit_sha_abc", result) + self.mock_session.get.assert_called_once() # Check file existence + self.mock_session.put.assert_called_once() # Commit file + + def test_commit_file_success_update_file(self): + tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) + tool.current_branch = "dev-branch" + + # Mock GET for checking file existence (200 means existing file) + self.mock_session.get.return_value = create_mock_response(200, json_data={"sha": "existing_file_sha"}) + # Mock PUT for committing file + self.mock_session.put.return_value = create_mock_response(200, json_data={"commit": {"sha": "commit_sha_def"}}) + + result = tool._commit_file(file_path="existing_file.txt", content="Updated content", commit_message="Update existing_file.txt") + self.assertIn("committed successfully", result) + self.assertIn("commit_sha_def", result) + args, kwargs = self.mock_session.put.call_args + self.assertEqual(kwargs['json']['sha'], "existing_file_sha") + + + def test_commit_file_to_main_branch_error(self): + tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) + tool.current_branch = "main" + result = tool._commit_file(file_path="some.txt", content="content", commit_message="msg") + self.assertIn("Action directly to main branch is not allowed", result) + + def test_create_pull_request_success(self): + tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) + tool.current_branch = "feature-pr" + pr_url = "https://example.com/pull/1" + self.mock_session.post.return_value = create_mock_response(201, json_data={"html_url": pr_url, "number": 1}) + + result = tool._create_pull_request(title="New Feature PR", body="Please review.", base="main") + self.assertIn(f"Pull request created successfully: {pr_url}", result) + self.mock_session.post.assert_called_once() + call_data = self.mock_session.post.call_args[1]['json'] + self.assertEqual(call_data['head'], "feature-pr") + self.assertEqual(call_data['base'], "main") + + def test_create_pull_request_same_branch_error(self): + tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) + tool.current_branch = "main" + result = tool._create_pull_request(title="PR to self", body="This should fail", base="main") + self.assertIn("Cannot create a pull request from branch 'main' to itself", result) + + + def test_list_files_success(self): + tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) + mock_items = [ + {"name": "file1.txt", "type": "file", "path": "dir/file1.txt"}, + {"name": "subdir", "type": "dir", "path": "dir/subdir"} + ] + self.mock_session.get.return_value = create_mock_response(200, json_data=mock_items) + + result = tool._list_files(path="dir") + self.assertEqual(len(result), 2) + self.assertEqual(result[0]["name"], "file1.txt") + self.assertEqual(result[1]["type"], "dir") + + def test_search_code_success(self): + tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) + mock_search_results = { + "items": [{"path": "src/code.py", "html_url": "url1"}] + } + self.mock_session.get.return_value = create_mock_response(200, json_data=mock_search_results) + + results = tool._search_code(query="my_function") + self.assertEqual(len(results), 1) + self.assertEqual(results[0]["path"], "src/code.py") + + def test_get_commit_history_success(self): + tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) + mock_commits = [{ + "sha": "sha1", "commit": {"message": "Msg1", "author": {"name": "Authy", "date": "Date1"}} + }] + self.mock_session.get.return_value = create_mock_response(200, json_data=mock_commits) + + commits = tool._get_commit_history(file_path="file.txt", num_commits=1) + self.assertEqual(len(commits), 1) + self.assertEqual(commits[0]["sha"], "sha1") + + def test_set_current_branch_success(self): + tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) + # Mock _get_branch_sha to simulate branch exists + with patch.object(tool, '_get_branch_sha', return_value="some_sha_for_dev"): + result = tool._set_current_branch(branch_name="dev") + self.assertEqual(tool.current_branch, "dev") + self.assertIn("Current branch set to: dev", result) + + def test_set_current_branch_not_exists(self): + tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) + with patch.object(tool, '_get_branch_sha', return_value="Error getting SHA for branch"): + result = tool._set_current_branch(branch_name="nonexistent-branch") + self.assertNotEqual(tool.current_branch, "nonexistent-branch") # Should not change + self.assertIn("Cannot set current branch", result) + + + def test_list_branches_single_page(self): + tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) + mock_branches = [{"name": "main"}, {"name": "dev"}] + self.mock_session.get.return_value = create_mock_response(200, json_data=mock_branches, links={}) # No "next" link + + branches = tool._list_branches(all_pages=True) + self.assertEqual(branches, ["main", "dev"]) + self.mock_session.get.assert_called_once() + + def test_list_branches_multiple_pages(self): + tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) + + # Page 1 response + page1_branches = [{"name": "branch1"}, {"name": "branch2"}] + next_url = f"{tool.base_url}/repos/{self.test_repo}/branches?page=2" + response1 = create_mock_response(200, json_data=page1_branches, links={"next": {"url": next_url}}) + + # Page 2 response + page2_branches = [{"name": "branch3"}] + response2 = create_mock_response(200, json_data=page2_branches, links={}) # No "next" link + + self.mock_session.get.side_effect = [response1, response2] + + branches = tool._list_branches(all_pages=True) + self.assertEqual(branches, ["branch1", "branch2", "branch3"]) + self.assertEqual(self.mock_session.get.call_count, 2) + + # Check that the second call used the next_url + calls = self.mock_session.get.call_args_list + self.assertEqual(calls[1][0][0], next_url) # args[0] is the URL + + # --- Test execute dispatcher --- + def test_execute_read_file(self): + tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) + with patch.object(tool, '_read_file', return_value="file content") as mock_method: + result = tool.execute(function_name="read_file", path="test.md") + mock_method.assert_called_once_with(path="test.md") + self.assertEqual(result, "file content") + + def test_execute_unknown_function(self): + tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) + result = tool.execute(function_name="non_existent_function_name", arg1="val1") + self.assertIn("Unknown function: non_existent_function_name", result) + + def test_execute_method_exception(self): + tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger) + with patch.object(tool, '_read_file', side_effect=Exception("Kaboom")) as mock_method: + result = tool.execute(function_name="read_file", path="crash.txt") + self.assertIn("Error during read_file execution: Kaboom", result) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/tools/test_log_tool.py b/tests/tools/test_log_tool.py new file mode 100644 index 0000000..a765b06 --- /dev/null +++ b/tests/tools/test_log_tool.py @@ -0,0 +1,146 @@ +import unittest +from unittest.mock import patch, mock_open, MagicMock +import os +import logging +from datetime import datetime, timedelta + +# Ensure tools/log_tool.py is accessible +from tools.log_tool import LogTool + +class TestLogTool(unittest.TestCase): + + def setUp(self): + self.test_log_file_path = "test_dummy_log.log" + # Suppress logging output during tests unless explicitly testing for it + self.logger = logging.getLogger('tools.log_tool') + # Ensure only one NullHandler to prevent duplicate messages if tests run multiple times in a session + if not any(isinstance(h, logging.NullHandler) for h in self.logger.handlers): + self.logger.addHandler(logging.NullHandler()) + self.logger.propagate = False # Prevent propagation to root logger if it has handlers + + + def test_init_default_log_path(self): + tool = LogTool(logger=self.logger) + self.assertEqual(tool.configured_log_file_path, 'logs/output.log') + + def test_init_custom_log_path(self): + tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger) + self.assertEqual(tool.configured_log_file_path, self.test_log_file_path) + + def test_get_functions(self): + tool = LogTool(logger=self.logger) + functions = tool.get_functions() + self.assertIsInstance(functions, list) + self.assertEqual(len(functions), 1) + self.assertEqual(functions[0]["function"]["name"], "get_log_contents") + + @patch("os.path.exists", return_value=False) + def test_get_log_contents_file_not_exists(self, mock_exists): + tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger) + result = tool._get_log_contents() + self.assertIn("Log file does not exist", result) + mock_exists.assert_called_once_with(self.test_log_file_path) + + @patch("os.path.exists", return_value=True) + @patch("builtins.open", new_callable=mock_open, read_data="line1\nline2\nline3\nline4\nline5") + def test_get_log_contents_with_line_count(self, mock_file_open, mock_exists): + tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger) + + result = tool._get_log_contents(line_count=3) + self.assertEqual(result, "line3\nline4\nline5") + mock_exists.assert_called_once_with(self.test_log_file_path) + mock_file_open.assert_called_once_with(self.test_log_file_path, 'r', encoding='utf-8') + + @patch("os.path.exists", return_value=True) + @patch("builtins.open", new_callable=mock_open, read_data="line1\nline2\n") + def test_get_log_contents_line_count_more_than_available(self, mock_file_open, mock_exists): + tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger) + result = tool._get_log_contents(line_count=5) + self.assertEqual(result, "line1\nline2\n") + + @patch("os.path.exists", return_value=True) + @patch("builtins.open", new_callable=mock_open, read_data="line1\nline2\n") + def test_get_log_contents_invalid_line_count_uses_default(self, mock_file_open, mock_exists): + tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger) + # Test with zero, negative, and non-integer line_count (though type hint is int) + # The code defaults to 150 if invalid. Here, we only have 2 lines. + with patch.object(tool.logger, 'warning') as mock_log_warning: + result_zero = tool._get_log_contents(line_count=0) + self.assertEqual(result_zero, "line1\nline2\n") + mock_log_warning.assert_any_call("Invalid line_count '0' provided, defaulting to fetch last 150 lines.") + + mock_file_open.reset_mock() # Reset for next call + result_neg = tool._get_log_contents(line_count=-5) + self.assertEqual(result_neg, "line1\nline2\n") + mock_log_warning.assert_any_call("Invalid line_count '-5' provided, defaulting to fetch last 150 lines.") + + + @patch("os.path.exists", return_value=True) + def test_get_log_contents_last_24_hours(self, mock_exists): + tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger) + + now = datetime.now() + one_hour_ago_dt = now - timedelta(hours=1) + two_days_ago_dt = now - timedelta(days=2) + + one_hour_ago_str = one_hour_ago_dt.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT) + two_days_ago_str = two_days_ago_dt.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT) + + log_data = ( + f"{two_days_ago_str} - OLD - This is an old log entry.\n" + f"No timestamp here - this line should be skipped by time filter.\n" + f"{one_hour_ago_str} - RECENT - This is a recent log entry.\n" + f"Malformed Date 2023-xx-01 - Another skipped line.\n" + f"{now.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT)} - CURRENT - This is a current log entry.\n" + ) + + expected_output = ( + f"{one_hour_ago_str} - RECENT - This is a recent log entry.\n" + f"{now.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT)} - CURRENT - This is a current log entry.\n" + ) + + with patch("builtins.open", mock_open(read_data=log_data)): + result = tool._get_log_contents(line_count=None) # Trigger 24-hour logic + self.assertEqual(result, expected_output) + + @patch("os.path.exists", return_value=True) + @patch("builtins.open", side_effect=IOError("File read error!")) + def test_get_log_contents_file_read_exception(self, mock_file_open, mock_exists): + tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger) + result = tool._get_log_contents(line_count=10) + self.assertIn("An error occurred while reading the log file: File read error!", result) + + def test_execute_get_log_contents(self): + tool = LogTool(logger=self.logger) + mock_return_value = "Mocked log content" + with patch.object(tool, '_get_log_contents', return_value=mock_return_value) as mock_method: + result = tool.execute(function_name="get_log_contents", line_count=50) + mock_method.assert_called_once_with(line_count=50) + self.assertEqual(result, mock_return_value) + + def test_execute_get_log_contents_no_line_count(self): + tool = LogTool(logger=self.logger) + mock_return_value = "Mocked log content for 24h" + with patch.object(tool, '_get_log_contents', return_value=mock_return_value) as mock_method: + result = tool.execute(function_name="get_log_contents") # No line_count + mock_method.assert_called_once_with(line_count=None) # Expects None to trigger 24h + self.assertEqual(result, mock_return_value) + + + def test_execute_unknown_function(self): + tool = LogTool(logger=self.logger) + result = tool.execute(function_name="non_existent_log_function") + self.assertIn("Unknown function: non_existent_log_function", result) + + def test_clear_method(self): + tool = LogTool(logger=self.logger) + # Set a specific level for the logger for this test if needed to capture debug + original_level = tool.logger.level + tool.logger.setLevel(logging.DEBUG) + with self.assertLogs(tool.logger, level='DEBUG') as cm: + tool.clear() + self.assertTrue(any("LogTool clear called" in message for message in cm.output)) + tool.logger.setLevel(original_level) # Reset level + +if __name__ == '__main__': + unittest.main() diff --git a/tests/tools/test_metrics.py b/tests/tools/test_metrics.py new file mode 100644 index 0000000..902abb7 --- /dev/null +++ b/tests/tools/test_metrics.py @@ -0,0 +1,217 @@ +import unittest +from unittest.mock import patch, MagicMock, ANY +import time +import logging + +# Ensure tools.metrics is accessible +from tools.metrics import Metrics # Import the class itself for direct testing +from tools.metrics import metrics as global_metrics_instance # Import the global instance + +# A simple function to decorate for testing +def sample_function_for_metrics(duration=0.01): + # Simulate some work + # Note: time.sleep is not always precisely profiled by cProfile in the same way as CPU-bound work. + # For testing, we will mock the cProfile/pstats interaction rather than relying on actual sleep duration. + if duration > 0: # Make it conditional so we can test zero-time case too + pass # The actual work is not important when mocking cProfile results + return "sample_output" + +def another_sample_function(x, y): + return x + y + +class TestMetrics(unittest.TestCase): + + def setUp(self): + # Create a fresh Metrics instance for most tests to avoid interference + self.logger = logging.getLogger('tools.metrics.test') + if not self.logger.handlers: # Avoid adding handler multiple times + self.logger.addHandler(logging.NullHandler()) + self.metrics_instance = Metrics(logger=self.logger) + + # Clear the global instance before each test that might use it + global_metrics_instance.clear_metrics() + + def test_measure_decorator_counts_calls(self): + decorated_func = self.metrics_instance.measure(sample_function_for_metrics) + + self.assertEqual(self.metrics_instance.call_count["sample_function_for_metrics"], 0) + decorated_func() + self.assertEqual(self.metrics_instance.call_count["sample_function_for_metrics"], 1) + decorated_func() + self.assertEqual(self.metrics_instance.call_count["sample_function_for_metrics"], 2) + + @patch('cProfile.Profile') + @patch('pstats.Stats') + def test_measure_decorator_records_time(self, MockPStats, MockCProfile): + # Mock cProfile and pstats to control the time value + mock_profiler_instance = MockCProfile.return_value + mock_pstats_instance = MockPStats.return_value + + # Simulate that pstats.Stats.stats dictionary contains the function's stats + # Key: (filename, lineno, funcname) + # Value: (cc, nc, tt, ct, callers) where ct is cumulative_time (index 3) + + # Get code object of the function *before* decoration for correct key + original_func_code = sample_function_for_metrics.__code__ + func_key = (original_func_code.co_filename, original_func_code.co_firstlineno, original_func_code.co_name) + + # Configure mock_pstats_instance.stats to return our desired time + mock_pstats_instance.stats = {func_key: (1, 1, 0.05, 0.123, {})} # cc, nc, tt, ct=0.123 + + decorated_func = self.metrics_instance.measure(sample_function_for_metrics) + + self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0) + + # Call the decorated function + decorated_func(duration=0) # Duration arg doesn't matter due to mocking + + # Assertions + mock_profiler_instance.enable.assert_called_once() + mock_profiler_instance.disable.assert_called_once() + MockPStats.assert_called_once_with(mock_profiler_instance) + + self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0.123) + + # Call again to see accumulation + # Reset mock stats for a new time value if needed, or assume same time per call + mock_pstats_instance.stats = {func_key: (1, 1, 0.05, 0.100, {})} # New ct=0.100 + decorated_func(duration=0) + self.assertAlmostEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0.123 + 0.100) + + + @patch('cProfile.Profile') + @patch('pstats.Stats') + def test_measure_decorator_fallback_time_recording_by_name(self, MockPStats, MockCProfile): + mock_profiler_instance = MockCProfile.return_value + mock_pstats_instance = MockPStats.return_value + + original_func_code = sample_function_for_metrics.__code__ # func to be decorated + # Simulate the primary key lookup fails by creating a slightly different key for what we expect + # This is what the code will try to look up first. + expected_primary_key = (original_func_code.co_filename, original_func_code.co_firstlineno, original_func_code.co_name) + + # This is the key that will *actually* be in pstats.stats, simulating a mismatch for primary lookup + # but a match for the by-name fallback. + actual_stats_key_in_pstats = (original_func_code.co_filename, + original_func_code.co_firstlineno + 5, # simulate a lineno difference for primary key mismatch + original_func_code.co_name) # Name is the same for fallback + + mock_pstats_instance.stats = { + # expected_primary_key is NOT present + actual_stats_key_in_pstats: (1, 1, 0.03, 0.077, {}) # ct = 0.077 + } + + decorated_func = self.metrics_instance.measure(sample_function_for_metrics) + + # Expecting a debug log for fallback, but assertLogs needs the logger to have a handler that captures. + # self.logger is already set up with NullHandler. For this test, let's use a specific logger. + metrics_internal_logger = logging.getLogger('tools.metrics') # Logger used inside Metrics class + original_level = metrics_internal_logger.level + metrics_internal_logger.setLevel(logging.DEBUG) + + with self.assertLogs(metrics_internal_logger, level='DEBUG') as log_capture: + decorated_func(duration=0) + + metrics_internal_logger.setLevel(original_level) # Reset logger level + + self.assertTrue(any("Found stats for sample_function_for_metrics by name" in msg for msg in log_capture.output)) + self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0.077) + + + @patch('cProfile.Profile') + @patch('pstats.Stats') + def test_measure_decorator_handles_func_stats_not_found(self, MockPStats, MockCProfile): + mock_profiler_instance = MockCProfile.return_value + mock_pstats_instance = MockPStats.return_value + mock_pstats_instance.stats = {} # Empty stats, function will not be found + + decorated_func = self.metrics_instance.measure(sample_function_for_metrics) + + metrics_internal_logger = logging.getLogger('tools.metrics') + original_level = metrics_internal_logger.level + metrics_internal_logger.setLevel(logging.WARNING) + with self.assertLogs(metrics_internal_logger, level='WARNING') as log_capture: + decorated_func(duration=0) + metrics_internal_logger.setLevel(original_level) + + self.assertTrue(any("Could not find exact cProfile stats" in msg for msg in log_capture.output)) + self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0) + + + def test_get_metrics_empty(self): + self.assertEqual(self.metrics_instance.get_metrics(), {}) + + @patch('cProfile.Profile') + @patch('pstats.Stats') + def test_get_metrics_with_data(self, MockPStats, MockCProfile): + mock_pstats_instance = MockPStats.return_value + + # Decorate two different functions + decorated_func1 = self.metrics_instance.measure(sample_function_for_metrics) + decorated_func2 = self.metrics_instance.measure(another_sample_function) + + # Data for func1 + func1_code = sample_function_for_metrics.__code__ + func1_key = (func1_code.co_filename, func1_code.co_firstlineno, func1_code.co_name) + mock_pstats_instance.stats = {func1_key: (1,1,0.1,0.1,{})} + decorated_func1() + + # Data for func2 + func2_code = another_sample_function.__code__ + func2_key = (func2_code.co_filename, func2_code.co_firstlineno, func2_code.co_name) + mock_pstats_instance.stats = {func2_key: (1,1,0.2,0.2,{})} # Cumulative time 0.2 + decorated_func2(1,2) + mock_pstats_instance.stats = {func2_key: (1,1,0.3,0.3,{})} # Cumulative time 0.3 for second call + decorated_func2(3,4) + + metrics_data = self.metrics_instance.get_metrics() + + self.assertIn("sample_function_for_metrics", metrics_data) + self.assertEqual(metrics_data["sample_function_for_metrics"]["call_count"], 1) + self.assertEqual(metrics_data["sample_function_for_metrics"]["total_time"], 0.1) + self.assertEqual(metrics_data["sample_function_for_metrics"]["average_time"], 0.1) + + self.assertIn("another_sample_function", metrics_data) + self.assertEqual(metrics_data["another_sample_function"]["call_count"], 2) + self.assertAlmostEqual(metrics_data["another_sample_function"]["total_time"], 0.5) + self.assertAlmostEqual(metrics_data["another_sample_function"]["average_time"], 0.25) + + + def test_clear_metrics(self): + # Add some data + self.metrics_instance.call_count["test_func"] = 5 + self.metrics_instance.total_time["test_func"] = 1.234 + + self.metrics_instance.clear_metrics() + + self.assertEqual(self.metrics_instance.call_count, {}) + self.assertEqual(self.metrics_instance.total_time, {}) + self.assertEqual(self.metrics_instance.get_metrics(), {}) + + # Test the global instance + @patch('cProfile.Profile') + @patch('pstats.Stats') + def test_global_metrics_instance_usage(self, MockPStats, MockCProfile): + mock_pstats_instance = MockPStats.return_value + + # Decorate a function with the global instance + @global_metrics_instance.measure + def globally_decorated_func(): + return "global_output" + + # Setup mock stats for the globally decorated function + # Access __wrapped__ to get the original function if other decorators might be present or for consistency. + original_g_func = globally_decorated_func.__wrapped__ + func_code = original_g_func.__code__ + func_key = (func_code.co_filename, func_code.co_firstlineno, func_code.co_name) + mock_pstats_instance.stats = {func_key: (1,1,0.05,0.05,{})} + + globally_decorated_func() + + metrics_data = global_metrics_instance.get_metrics() + self.assertIn("globally_decorated_func", metrics_data) + self.assertEqual(metrics_data["globally_decorated_func"]["call_count"], 1) + self.assertEqual(metrics_data["globally_decorated_func"]["total_time"], 0.05) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/tools/test_metrics_tool.py b/tests/tools/test_metrics_tool.py new file mode 100644 index 0000000..17c1b9d --- /dev/null +++ b/tests/tools/test_metrics_tool.py @@ -0,0 +1,161 @@ +import unittest +from unittest.mock import MagicMock, patch +import logging + +# Ensure tools.metrics_tool and tools.metrics are accessible +from tools.metrics_tool import MetricsTool +from tools.metrics import Metrics # Used for typehinting and creating a mockable instance + +class TestMetricsTool(unittest.TestCase): + + def setUp(self): + self.mock_metrics_provider = MagicMock(spec=Metrics) + self.logger = logging.getLogger('tools.metrics_tool.test') + if not self.logger.handlers: + self.logger.addHandler(logging.NullHandler()) + self.logger.propagate = False + + + def test_init_with_provider(self): + tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) + self.assertEqual(tool.metrics_provider, self.mock_metrics_provider) + + @patch('tools.metrics_tool.global_metrics_instance') # Patch the global instance path + def test_init_default_provider(self, mock_global_metrics): + tool = MetricsTool(logger=self.logger) + self.assertEqual(tool.metrics_provider, mock_global_metrics) + + def test_get_functions(self): + tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) + functions = tool.get_functions() + self.assertIsInstance(functions, list) + self.assertTrue(len(functions) == 3) # Based on current definition + self.assertIn("get_function_metrics", [f["function"]["name"] for f in functions]) + self.assertIn("get_specific_function_metrics", [f["function"]["name"] for f in functions]) + self.assertIn("get_top_n_functions", [f["function"]["name"] for f in functions]) + + def test_execute_get_function_metrics(self): + tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) + expected_metrics = {"func1": {"call_count": 1, "total_time": 0.1}} + self.mock_metrics_provider.get_metrics.return_value = expected_metrics + + result = tool.execute(function_name="get_function_metrics") + + self.mock_metrics_provider.get_metrics.assert_called_once() + self.assertEqual(result, expected_metrics) + + def test_execute_get_specific_function_metrics_found(self): + tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) + func_metrics = {"call_count": 5, "total_time": 0.5, "average_time": 0.1} + all_metrics = {"specific_func": func_metrics, "other_func": {}} + self.mock_metrics_provider.get_metrics.return_value = all_metrics + + # The execute method expects kwargs that match the function parameters in get_functions. + # So, the argument name for the function to get is 'function_name' in the tool's spec. + result = tool.execute(function_name="get_specific_function_metrics", **{"function_name": "specific_func"}) + self.assertEqual(result, func_metrics) + + def test_execute_get_specific_function_metrics_not_found(self): + tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) + self.mock_metrics_provider.get_metrics.return_value = {"other_func": {}} + + result = tool.execute(function_name="get_specific_function_metrics", **{"function_name": "non_existent_func"}) + self.assertEqual(result, "No metrics found for function: non_existent_func") + + def test_execute_get_specific_function_metrics_missing_arg(self): + tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) + result = tool.execute(function_name="get_specific_function_metrics") # Missing function_name kwarg + self.assertIn("Error: Missing required argument 'function_name'", result) + + + def test_execute_get_top_n_functions(self): + tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) + metrics_data = { + "func_a": {"call_count": 1, "total_time": 0.3}, + "func_b": {"call_count": 1, "total_time": 0.1}, + "func_c": {"call_count": 1, "total_time": 0.5}, + "func_d": {"call_count": 1, "total_time": 0.2}, + } + self.mock_metrics_provider.get_metrics.return_value = metrics_data + + # Test getting top 2 + result = tool.execute(function_name="get_top_n_functions", n=2) + expected_top_2 = {"func_c": metrics_data["func_c"], "func_a": metrics_data["func_a"]} + self.assertEqual(result, expected_top_2) + + # Test getting top 1 + result_top_1 = tool.execute(function_name="get_top_n_functions", n=1) + expected_top_1 = {"func_c": metrics_data["func_c"]} + self.assertEqual(result_top_1, expected_top_1) + + # Test N larger than available functions + result_top_all = tool.execute(function_name="get_top_n_functions", n=10) + # Order should be func_c, func_a, func_d, func_b + expected_top_all_keys = ["func_c", "func_a", "func_d", "func_b"] + self.assertEqual(list(result_top_all.keys()), expected_top_all_keys) + + def test_execute_get_top_n_functions_malformed_metrics(self): + tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) + metrics_data = { + "func_a": {"call_count": 1, "total_time": 0.3}, + "func_b": "not a dict", # Malformed + "func_c": {"call_count": 1}, # Missing total_time + "func_d": {"call_count": 1, "total_time": 0.2}, + } + self.mock_metrics_provider.get_metrics.return_value = metrics_data + + metrics_tool_logger = logging.getLogger('tools.metrics_tool') + original_level = metrics_tool_logger.level + metrics_tool_logger.setLevel(logging.WARNING) + with self.assertLogs(metrics_tool_logger, level='WARNING') as log_capture: + result = tool.execute(function_name="get_top_n_functions", n=2) + metrics_tool_logger.setLevel(original_level) + + # Check that warnings were logged for malformed items + self.assertTrue(any("Metric item for 'func_b' is not in expected format" in msg for msg in log_capture.output)) + self.assertTrue(any("Metric item for 'func_c' is not in expected format" in msg for msg in log_capture.output)) + + # Expected: func_a, func_d (as they are valid and sortable) + expected_result = { + "func_a": metrics_data["func_a"], + "func_d": metrics_data["func_d"] + } + self.assertEqual(result, expected_result) + + + def test_execute_get_top_n_functions_invalid_n(self): + tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) + self.mock_metrics_provider.get_metrics.return_value = {} # No metrics needed for this test + + result_zero = tool.execute(function_name="get_top_n_functions", n=0) + self.assertIn("Error: Argument 'n' must be a positive integer.", result_zero) + + result_negative = tool.execute(function_name="get_top_n_functions", n=-1) + self.assertIn("Error: Argument 'n' must be a positive integer.", result_negative) + + result_string = tool.execute(function_name="get_top_n_functions", n="abc") + self.assertIn("Error: Argument 'n' must be an integer.", result_string) + + def test_execute_get_top_n_functions_missing_arg_n(self): + tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) + result = tool.execute(function_name="get_top_n_functions") # Missing n + self.assertIn("Error: Missing required argument 'n'.", result) + + + def test_execute_unknown_function(self): + tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) + result = tool.execute(function_name="non_existent_metrics_function") + self.assertIn("Unknown function: non_existent_metrics_function", result) + + def test_clear_method(self): + tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger) + metrics_tool_logger = logging.getLogger('tools.metrics_tool') + original_level = metrics_tool_logger.level + metrics_tool_logger.setLevel(logging.DEBUG) + with self.assertLogs(metrics_tool_logger, level='DEBUG') as cm: + tool.clear() + metrics_tool_logger.setLevel(original_level) + self.assertTrue(any("MetricsTool clear method called" in message for message in cm.output)) + +if __name__ == '__main__': + unittest.main() diff --git a/tools/github_tool.py b/tools/github_tool.py index a14a063..7720fa8 100644 --- a/tools/github_tool.py +++ b/tools/github_tool.py @@ -7,42 +7,41 @@ import base64 import logging class GitHubTool(BaseTool): - def __init__(self): - self.base_url = "https://api.github.com" - self.token = os.environ.get("GITHUB_TOKEN") + def __init__(self, session=None, token=None, repo=None, base_url=None, initial_branch="main", logger=None): + self.base_url = base_url if base_url else "https://api.github.com" + self._token = token if token else os.environ.get("GITHUB_TOKEN") + self._repo = repo if repo else os.environ.get("GITHUB_REPOSITORY") - self.headers = { - "Authorization": f"token {self.token}", - "Accept": "application/vnd.github.v3+json" - } - self.repo = os.environ.get("GITHUB_REPOSITORY") - self.current_branch = "main" # Default to main branch + if not self._token: + # In a real scenario, might raise an error or operate in a degraded mode. + # For this tool, token is essential. + raise ValueError("GitHub token must be provided either as an argument or via GITHUB_TOKEN env var.") + if not self._repo: + raise ValueError("GitHub repository (e.g., 'owner/repo') must be provided either as an argument or via GITHUB_REPOSITORY env var.") - # Set up logging - self.logger = logging.getLogger(__name__) - self.logger.setLevel(logging.INFO) + if session: + self.session = session + else: + self.session = requests.Session() + self.session.headers.update({ + "Authorization": f"token {self._token}", + "Accept": "application/vnd.github.v3+json" + }) - # Create a file handler - file_handler = logging.FileHandler('github_tool.log') - file_handler.setLevel(logging.INFO) - - # Create a console handler - console_handler = logging.StreamHandler() - console_handler.setLevel(logging.INFO) - - # Create a formatting for the logs - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') - file_handler.setFormatter(formatter) - console_handler.setFormatter(formatter) - - # Add the handlers to the logger - self.logger.addHandler(file_handler) - self.logger.addHandler(console_handler) + self.current_branch = initial_branch + + # Use provided logger or get a new one for the module + # The application using this tool should configure the logging handlers and formatting. + self.logger = logger if logger else logging.getLogger(__name__) + # If no handlers are configured by the application, add a NullHandler + # to prevent "No handler found" warnings if the tool logs something. + if not self.logger.handlers: + self.logger.addHandler(logging.NullHandler()) def clear(self): - if (self.current_branch != "main"): + if self.current_branch != "main": self._set_current_branch("main") - pass + self.logger.info(f"GitHubTool state cleared. Current branch is {self.current_branch}") def get_functions(self): return [ @@ -372,7 +371,7 @@ class GitHubTool(BaseTool): } } }, - { # New function definition for get_pull_request_general_comments + { "type": "function", "function": { "name": "get_pull_request_general_comments", @@ -485,7 +484,6 @@ class GitHubTool(BaseTool): } } }, - # New functions for PR review { "type": "function", "function": { @@ -581,740 +579,679 @@ class GitHubTool(BaseTool): } ] - @metrics.measure def execute(self, function_name, **kwargs): - self.logger.info(f"Executing: {function_name}") - - if function_name == "read_file": - return self._read_file(kwargs["path"]) - elif function_name == "create_branch": - return self._create_branch(kwargs["branch_name"], kwargs.get("base_branch", "main")) - elif function_name == "commit_file": - return self._commit_file(kwargs["file_path"], kwargs["content"], kwargs["commit_message"]) - elif function_name == "create_pull_request": - return self._create_pull_request(kwargs["title"], kwargs["body"], kwargs.get("base", "main")) - elif function_name == "list_files": - return self._list_files(kwargs["path"]) - elif function_name == "search_code": - return self._search_code(kwargs["query"]) - elif function_name == "get_commit_history": - return self._get_commit_history(kwargs["file_path"], kwargs.get("num_commits", 10)) - elif function_name == "view_commit_details_for_file": - return self._view_commit_details_for_file(kwargs["file_path"], kwargs.get("num_commits", 10)) - elif function_name == "get_current_branch": - return self._get_current_branch() - elif function_name == "set_current_branch": - return self._set_current_branch(kwargs["branch_name"]) - elif function_name == "get_file_at_commit": - return self._get_file_at_commit(kwargs["file_path"], kwargs["commit_sha"]) - elif function_name == "list_branches": - return self._list_branches(kwargs.get("per_page", 100), kwargs.get("all_pages", True)) - elif function_name == "get_branch_sha": - return self._get_branch_sha(kwargs["branch"]) - elif function_name == "approve_pull_request": - return self._approve_pull_request(kwargs["pull_number"]) - elif function_name == "close_pull_request": - return self._close_pull_request(kwargs["pull_number"]) - elif function_name == "merge_pull_request": - return self._merge_pull_request(kwargs["pull_number"], kwargs.get("commit_title", "Merge pull request"), - kwargs.get("commit_message", ""), kwargs.get("merge_method", "merge")) - elif function_name == "delete_branch": - return self._delete_branch(kwargs["branch_name"]) - elif function_name == "get_issue_details": - return self._get_issue_details(kwargs["issue_number"]) - elif function_name == "create_issue": - return self._create_issue(kwargs["title"], kwargs["body"], kwargs.get("labels", [])) - elif function_name == "list_issues": - return self._list_issues(kwargs.get("state", "open"), kwargs.get("per_page", 30), kwargs.get("page", 1)) - elif function_name == "add_issue_comment": - return self._add_issue_comment(kwargs["issue_number"], kwargs["comment"]) - elif function_name == "get_issue_comments": - return self._get_issue_comments(kwargs["issue_number"]) - elif function_name == "get_pull_request_general_comments": # New dispatch entry - return self._get_pull_request_general_comments(kwargs["pull_number"]) - elif function_name == "create_project_board": - return self._create_project_board(kwargs["name"], kwargs.get("body", "")) - elif function_name == "create_project_column": - return self._create_project_column(kwargs["project_id"], kwargs["column_name"]) - elif function_name == "create_project_card": - return self._create_project_card(kwargs["column_id"], kwargs["note"]) - elif function_name == "move_project_card": - return self._move_project_card(kwargs["card_id"], kwargs["position"], kwargs["column_id"]) - elif function_name == "link_issue_to_project_card": - return self._link_issue_to_project_card(kwargs["card_id"], kwargs["content_id"], kwargs["content_type"]) - elif function_name == "list_project_boards": - return self._list_project_boards() - elif function_name == "view_project_board_items": - return self._view_project_board_items(kwargs["project_id"]) - # New function dispatching - elif function_name == "get_pull_request_details": - return self._get_pull_request_details(kwargs["pull_number"]) - elif function_name == "get_pull_request_diff": - return self._get_pull_request_diff(kwargs["pull_number"]) - elif function_name == "get_pull_request_files": - return self._get_pull_request_files(kwargs["pull_number"]) - elif function_name == "create_pull_request_review_comment": - return self._create_pull_request_review_comment(kwargs["pull_number"], kwargs["body"], kwargs["commit_id"], - kwargs["path"], kwargs["position"], kwargs.get("side", "RIGHT"), - kwargs.get("start_line"), kwargs.get("start_side")) - elif function_name == "list_pull_request_review_comments": - return self._list_pull_request_review_comments(kwargs["pull_number"]) - elif function_name == "submit_pull_request_review": - return self._submit_pull_request_review(kwargs["pull_number"], kwargs["event"], kwargs.get("body")) + self.logger.info(f"Executing GitHub Tool function: {function_name} with args: {kwargs}") + # Dispatch to the appropriate private method + method_name = f"_{function_name}" + if hasattr(self, method_name): + method = getattr(self, method_name) + try: + return method(**kwargs) # Ensure only expected args are passed if method signature is strict + except Exception as e: + self.logger.error(f"Error executing {method_name}: {e}", exc_info=True) + return f"Error during {function_name} execution: {str(e)}" else: error_message = f"Unknown function: {function_name}" self.logger.error(error_message) return error_message + # Private methods for each function, using self.session for HTTP requests + @metrics.measure def _read_file(self, path): self.logger.info(f"Reading file: {path} from branch: {self.current_branch}") - url = f"{self.base_url}/repos/{self.repo}/contents/{path}" - response = requests.get(url, headers=self.headers, params={"ref": self.current_branch}) + url = f"{self.base_url}/repos/{self._repo}/contents/{path}" + response = self.session.get(url, params={"ref": self.current_branch}) if response.status_code == 200: content = response.json()["content"] decoded_content = base64.b64decode(content).decode('utf-8') self.logger.info(f"Successfully read file: {path}") return decoded_content else: - error_message = f"Error reading file: {response.status_code}" + error_message = f"Error reading file ({path}): {response.status_code} - {response.text}" self.logger.error(error_message) return error_message @metrics.measure - def _create_branch(self, branch_name, base_branch): + def _create_branch(self, branch_name, base_branch="main"): self.logger.info(f"Creating branch: {branch_name} from base: {base_branch}") - url = f"{self.base_url}/repos/{self.repo}/git/refs" - response = requests.get(f"{url}/heads/{base_branch}", headers=self.headers) - if response.status_code != 200: - error_message = f"Error getting base branch: {response.status_code}" + # Get SHA of base branch + ref_url = f"{self.base_url}/repos/{self._repo}/git/refs/heads/{base_branch}" + response_sha = self.session.get(ref_url) + if response_sha.status_code != 200: + error_message = f"Error getting base branch SHA ({base_branch}): {response_sha.status_code} - {response_sha.text}" self.logger.error(error_message) return error_message - - sha = response.json()["object"]["sha"] - data = { - "ref": f"refs/heads/{branch_name}", - "sha": sha - } - response = requests.post(url, headers=self.headers, json=data) - if response.status_code == 201: + sha = response_sha.json()["object"]["sha"] + + # Create new branch + create_ref_url = f"{self.base_url}/repos/{self._repo}/git/refs" + data = {"ref": f"refs/heads/{branch_name}", "sha": sha} + response_create = self.session.post(create_ref_url, json=data) + if response_create.status_code == 201: self.current_branch = branch_name - success_message = f"Branch '{branch_name}' created successfully and set as current branch" + success_message = f"Branch '{branch_name}' created successfully from '{base_branch}' and set as current branch." self.logger.info(success_message) return success_message else: - error_message = f"Error creating branch: {response.status_code}" + error_message = f"Error creating branch '{branch_name}': {response_create.status_code} - {response_create.text}" self.logger.error(error_message) return error_message @metrics.measure def _commit_file(self, file_path, content, commit_message): - self.logger.info(f"Committing file: {file_path} to branch: {self.current_branch}") + self.logger.info(f"Committing file: {file_path} to branch: {self.current_branch} with message: '{commit_message}'") if self.current_branch == "main": - error_message = "Cannot commit directly to main branch" - self.logger.error(error_message) + error_message = "Action directly to main branch is not allowed. Please create and switch to a new branch first." + self.logger.warning(error_message) return error_message - url = f"{self.base_url}/repos/{self.repo}/contents/{file_path}" - - self.logger.info("Checking if file already exists") - response = requests.get(url, headers=self.headers, params={"ref": self.current_branch}) - + url = f"{self.base_url}/repos/{self._repo}/contents/{file_path}" + encoded_content = base64.b64encode(content.encode('utf-8')).decode('utf-8') data = { "message": commit_message, - "content": base64.b64encode(content.encode()).decode(), + "content": encoded_content, "branch": self.current_branch } - if response.status_code == 200: - self.logger.info("File exists, updating") - file_sha = response.json()["sha"] - data["sha"] = file_sha + # Check if file exists to get its SHA for update + self.logger.info(f"Checking if file '{file_path}' exists on branch '{self.current_branch}'") + get_response = self.session.get(url, params={"ref": self.current_branch}) + if get_response.status_code == 200: + data["sha"] = get_response.json()["sha"] + self.logger.info(f"File '{file_path}' exists, will update.") + elif get_response.status_code == 404: + self.logger.info(f"File '{file_path}' does not exist, will create.") else: - self.logger.info("File does not exist, creating new file") + error_message = f"Error checking file existence for '{file_path}': {get_response.status_code} - {get_response.text}" + self.logger.error(error_message) + return error_message - response = requests.put(url, headers=self.headers, json=data) - - if response.status_code in [200, 201]: - success_message = f"File committed successfully to branch '{self.current_branch}'" + response = self.session.put(url, json=data) + if response.status_code in [200, 201]: # 200 for update, 201 for create + commit_sha = response.json().get("commit", {}).get("sha", "N/A") + success_message = f"File '{file_path}' committed successfully to branch '{self.current_branch}'. Commit SHA: {commit_sha}" self.logger.info(success_message) return success_message else: - error_message = f"Error committing file: {response.status_code}\nResponse: {response.text}" + error_message = f"Error committing file '{file_path}': {response.status_code} - {response.text}" self.logger.error(error_message) return error_message @metrics.measure - def _create_pull_request(self, title, body, base): - self.logger.info(f"Creating pull request: {title} from {self.current_branch} to {base}") - url = f"{self.base_url}/repos/{self.repo}/pulls" - data = { - "title": title, - "body": body, - "head": self.current_branch, - "base": base - } - response = requests.post(url, headers=self.headers, json=data) + def _create_pull_request(self, title, body, base="main"): + self.logger.info(f"Creating pull request: '{title}' from branch '{self.current_branch}' to '{base}'") + if self.current_branch == base: + error_message = f"Cannot create a pull request from branch '{self.current_branch}' to itself ('{base}')." + self.logger.warning(error_message) + return error_message + + url = f"{self.base_url}/repos/{self._repo}/pulls" + data = {"title": title, "body": body, "head": self.current_branch, "base": base} + response = self.session.post(url, json=data) if response.status_code == 201: - success_message = f"Pull request created successfully: {response.json()['html_url']}" + pr_html_url = response.json().get("html_url", "N/A") + pr_number = response.json().get("number", "N/A") + success_message = f"Pull request '{title}' created successfully: {pr_html_url} (Number: {pr_number})" self.logger.info(success_message) return success_message else: - error_message = f"Error creating pull request: {response.status_code}\nResponse: {response.text}" + error_message = f"Error creating pull request: {response.status_code} - {response.text}" self.logger.error(error_message) return error_message @metrics.measure def _get_branch_sha(self, branch): - url = f"{self.base_url}/repos/{self.repo}/git/refs/heads/{branch}" - response = requests.get(url, headers=self.headers) - if response.status_code == 200: - return response.json()["object"]["sha"] - else: - return f"Error getting branch SHA: {response.status_code}" + self.logger.info(f"Getting SHA for branch: {branch}") + url = f"{self.base_url}/repos/{self._repo}/git/refs/heads/{branch}" + response = self.session.get(url) + if response.status_code == 200: + sha = response.json()["object"]["sha"] + self.logger.info(f"SHA for branch '{branch}' is {sha}") + return sha + else: + error_message = f"Error getting SHA for branch '{branch}': {response.status_code} - {response.text}" + self.logger.error(error_message) + return error_message @metrics.measure def _list_files(self, path): - self.logger.info(f"Listing files in: {path} on branch: {self.current_branch}") - url = f"{self.base_url}/repos/{self.repo}/contents/{path}" - response = requests.get(url, headers=self.headers, params={"ref": self.current_branch}) + self.logger.info(f"Listing files in path: '{path}' on branch: '{self.current_branch}'") + url = f"{self.base_url}/repos/{self._repo}/contents/{path.strip('/')}" # Ensure no leading/trailing slashes for consistency + response = self.session.get(url, params={"ref": self.current_branch}) if response.status_code == 200: - files = [{"type": "file", "name": item["name"]} for item in response.json() if item["type"] == "file"] - directories = [{"type": "directory", "name": item["name"]} for item in response.json() if item["type"] == "dir"] - self.logger.info(f"Successfully listed files and directories in {path}") - files.extend(directories) - return files + items = response.json() + results = [] + if isinstance(items, list): # It's a directory listing + for item in items: + results.append({"name": item["name"], "type": item["type"], "path": item["path"]}) + elif isinstance(items, dict) and 'type' in items: # It's a single file response + results.append({"name": items["name"], "type": items["type"], "path": items["path"]}) + self.logger.info(f"Successfully listed {len(results)} items in '{path}'.") + return results + elif response.status_code == 404: + self.logger.warning(f"Path '{path}' not found on branch '{self.current_branch}'.") + return f"Error: Path '{path}' not found." else: - error_message = f"Error listing files: {response.status_code}" + error_message = f"Error listing files in '{path}': {response.status_code} - {response.text}" self.logger.error(error_message) return error_message @metrics.measure def _search_code(self, query): - self.logger.info(f"Searching code with query: {query}") + self.logger.info(f"Searching code with query: '{query}' in repo: '{self._repo}'") url = f"{self.base_url}/search/code" - params = { - "q": f"{query} repo:{self.repo}", - "per_page": 10 - } - response = requests.get(url, headers=self.headers, params=params) + params = {"q": f"{query} repo:{self._repo}"} + response = self.session.get(url, params=params) if response.status_code == 200: - results = [{"file": item["path"], "url": item["html_url"]} for item in response.json()["items"]] - self.logger.info(f"Successfully searched code. Found {len(results)} results.") + search_results = response.json().get("items", []) + results = [{"path": item["path"], "url": item["html_url"]} for item in search_results] + self.logger.info(f"Code search for '{query}' found {len(results)} items.") return results else: - error_message = f"Error searching code: {response.status_code}" + error_message = f"Error searching code for '{query}': {response.status_code} - {response.text}" self.logger.error(error_message) return error_message @metrics.measure - def _get_commit_history(self, file_path, num_commits): - self.logger.info(f"Getting commit history for file: {file_path}, number of commits: {num_commits}") - url = f"{self.base_url}/repos/{self.repo}/commits" - params = { - "path": file_path, - "per_page": num_commits - } - response = requests.get(url, headers=self.headers, params=params) + def _get_commit_history(self, file_path, num_commits=10): + self.logger.info(f"Getting last {num_commits} commit(s) for file: '{file_path}' on branch '{self.current_branch}'") + url = f"{self.base_url}/repos/{self._repo}/commits" + params = {"path": file_path, "sha": self.current_branch, "per_page": num_commits} + response = self.session.get(url, params=params) if response.status_code == 200: - commits = [{"sha": commit["sha"], "message": commit["commit"]["message"], "date": commit["commit"]["author"]["date"]} for commit in response.json()] - self.logger.info(f"Successfully retrieved commit history. Found {len(commits)} commits.") + commits_data = response.json() + commits = [{ + "sha": commit["sha"], + "message": commit["commit"]["message"], + "author": commit["commit"]["author"]["name"], + "date": commit["commit"]["author"]["date"] + } for commit in commits_data] + self.logger.info(f"Successfully retrieved {len(commits)} commit(s) for '{file_path}'.") return commits else: - error_message = f"Error getting commit history: {response.status_code}" + error_message = f"Error getting commit history for '{file_path}': {response.status_code} - {response.text}" self.logger.error(error_message) return error_message @metrics.measure - def _view_commit_details_for_file(self, file_path, num_commits): - self.logger.info(f"Viewing commit details for file: {file_path}, number of commits: {num_commits} (via _get_commit_history)") + def _view_commit_details_for_file(self, file_path, num_commits=10): + # This function is essentially the same as get_commit_history based on its description. + self.logger.info(f"Viewing commit details for file '{file_path}' (last {num_commits} commits) - using _get_commit_history.") return self._get_commit_history(file_path, num_commits) @metrics.measure def _get_current_branch(self): - self.logger.info(f"Getting current branch: {self.current_branch}") + self.logger.info(f"Current branch is: {self.current_branch}") return self.current_branch @metrics.measure def _set_current_branch(self, branch_name): - self.logger.info(f"Setting current branch from {self.current_branch} to {branch_name}") + self.logger.info(f"Attempting to set current branch to: {branch_name}") + # Check if branch exists by trying to get its SHA + sha_info = self._get_branch_sha(branch_name) + if isinstance(sha_info, str) and sha_info.startswith("Error getting SHA"): # Crude check for error string + error_message = f"Cannot set current branch: Branch '{branch_name}' not found or error accessing it. Details: {sha_info}" + self.logger.warning(error_message) + return error_message + self.current_branch = branch_name - return f"Current branch set to: {self.current_branch}" + success_message = f"Current branch set to: {self.current_branch}" + self.logger.info(success_message) + return success_message @metrics.measure def _get_file_at_commit(self, file_path, commit_sha): - self.logger.info(f"Getting file: {file_path} at commit: {commit_sha}") - url = f"{self.base_url}/repos/{self.repo}/contents/{file_path}" - response = requests.get(url, headers=self.headers, params={"ref": commit_sha}) + self.logger.info(f"Getting file '{file_path}' at commit SHA: {commit_sha}") + url = f"{self.base_url}/repos/{self._repo}/contents/{file_path}" + response = self.session.get(url, params={"ref": commit_sha}) if response.status_code == 200: content = response.json()["content"] decoded_content = base64.b64decode(content).decode('utf-8') - self.logger.info(f"Successfully retrieved file at commit") + self.logger.info(f"Successfully retrieved file '{file_path}' at commit {commit_sha}.") return decoded_content else: - error_message = f"Error reading file at commit: {response.status_code}" + error_message = f"Error reading file '{file_path}' at commit {commit_sha}: {response.status_code} - {response.text}" self.logger.error(error_message) return error_message @metrics.measure def _list_branches(self, per_page=100, all_pages=True): - self.logger.info(f"Listing branches. Per page: {per_page}, All pages: {all_pages}") - url = f"{self.base_url}/repos/{self.repo}/branches" - params = {"per_page": min(per_page, 100)} # GitHub API max is 100 per page - all_branches = [] - + self.logger.info(f"Listing branches for repo '{self._repo}'. Per_page={per_page}, All_pages={all_pages}") + url = f"{self.base_url}/repos/{self._repo}/branches" + params = {"per_page": min(per_page, 100)} # Respect GitHub API limit + branches_list = [] + page = 1 while url: - self.logger.info(f"Fetching branches from: {url}") - response = requests.get(url, headers=self.headers, params=params) + self.logger.debug(f"Fetching page {page} from {url} with params {params if page==1 else {}}") + response = self.session.get(url, params=params if page == 1 else None) # params only for first page if paginating via links if response.status_code != 200: - error_message = f"Error listing branches: {response.status_code}" + error_message = f"Error listing branches: {response.status_code} - {response.text}" self.logger.error(error_message) return error_message + + current_page_branches = [branch["name"] for branch in response.json()] + branches_list.extend(current_page_branches) + self.logger.debug(f"Fetched {len(current_page_branches)} branches on page {page}.") - branches = [branch["name"] for branch in response.json()] - all_branches.extend(branches) - self.logger.info(f"Fetched {len(branches)} branches") - - if not all_pages: + if not all_pages or not response.links.get("next"): break + url = response.links["next"]["url"] + page += 1 + params = {} # Clear params for subsequent calls using a link that includes them - # Check if there's a next page - url = response.links.get('next', {}).get('url') - if url: - params = {} # Remove per_page for subsequent requests - - self.logger.info(f"Successfully listed all branches. Total: {len(all_branches)}") - return all_branches + self.logger.info(f"Successfully listed {len(branches_list)} branches.") + return branches_list @metrics.measure def _approve_pull_request(self, pull_number): - self.logger.info(f"Approving pull request: {pull_number}") - url = f"{self.base_url}/repos/{self.repo}/pulls/{pull_number}/reviews" + self.logger.info(f"Approving pull request #{pull_number}") + url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/reviews" data = {"event": "APPROVE"} - response = requests.post(url, headers=self.headers, json=data) + response = self.session.post(url, json=data) if response.status_code == 200: - success_message = f"Pull request {pull_number} approved successfully" + success_message = f"Pull request #{pull_number} approved successfully." self.logger.info(success_message) return success_message else: - error_message = f"Error approving pull request: {response.status_code}\nResponse: {response.text}" + error_message = f"Error approving pull request #{pull_number}: {response.status_code} - {response.text}" self.logger.error(error_message) return error_message @metrics.measure def _close_pull_request(self, pull_number): - self.logger.info(f"Closing pull request: {pull_number}") - url = f"{self.base_url}/repos/{self.repo}/pulls/{pull_number}" + self.logger.info(f"Closing pull request #{pull_number}") + url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}" data = {"state": "closed"} - response = requests.patch(url, headers=self.headers, json=data) + response = self.session.patch(url, json=data) # Use PATCH for update if response.status_code == 200: - success_message = f"Pull request {pull_number} closed successfully" + success_message = f"Pull request #{pull_number} closed successfully." self.logger.info(success_message) return success_message else: - error_message = f"Error closing pull request: {response.status_code}\nResponse: {response.text}" + error_message = f"Error closing pull request #{pull_number}: {response.status_code} - {response.text}" self.logger.error(error_message) return error_message @metrics.measure - def _merge_pull_request(self, pull_number, commit_title, commit_message, merge_method): - self.logger.info(f"Merging pull request: {pull_number}") - url = f"{self.base_url}/repos/{self.repo}/pulls/{pull_number}/merge" + def _merge_pull_request(self, pull_number, commit_title="Merge pull request", commit_message="", merge_method="merge"): + self.logger.info(f"Merging pull request #{pull_number} using method '{merge_method}'") + url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/merge" data = {"commit_title": commit_title, "commit_message": commit_message, "merge_method": merge_method} - response = requests.put(url, headers=self.headers, json=data) + response = self.session.put(url, json=data) if response.status_code == 200: - success_message = f"Pull request {pull_number} merged successfully" + success_message = f"Pull request #{pull_number} merged successfully." self.logger.info(success_message) return success_message + elif response.status_code == 405: # Method Not Allowed (e.g., PR not mergeable) + error_message = f"Error merging pull request #{pull_number}: Not mergeable. {response.json().get('message', response.text)}" + self.logger.warning(error_message) + return error_message + elif response.status_code == 409: # Conflict + error_message = f"Error merging pull request #{pull_number}: Merge conflict. {response.json().get('message', response.text)}" + self.logger.warning(error_message) + return error_message else: - error_message = f"Error merging pull request: {response.status_code}\nResponse: {response.text}" + error_message = f"Error merging pull request #{pull_number}: {response.status_code} - {response.text}" self.logger.error(error_message) return error_message @metrics.measure def _delete_branch(self, branch_name): self.logger.info(f"Deleting branch: {branch_name}") - url = f"{self.base_url}/repos/{self.repo}/git/refs/heads/{branch_name}" - response = requests.delete(url, headers=self.headers) + if branch_name == "main" or (hasattr(self, 'default_branch') and branch_name == self.default_branch) : + # Add a check for a configurable default branch if necessary + error_message = f"Cannot delete protected branch: {branch_name}" + self.logger.warning(error_message) + return error_message + + url = f"{self.base_url}/repos/{self._repo}/git/refs/heads/{branch_name}" + response = self.session.delete(url) if response.status_code == 204: - success_message = f"Branch {branch_name} deleted successfully" + success_message = f"Branch '{branch_name}' deleted successfully." self.logger.info(success_message) + if self.current_branch == branch_name: + self.current_branch = "main" # Or some other default + self.logger.info(f"Current branch was {branch_name}, reset to {self.current_branch}.") return success_message else: - error_message = f"Error deleting branch: {response.status_code}\nResponse: {response.text}" + error_message = f"Error deleting branch '{branch_name}': {response.status_code} - {response.text}" self.logger.error(error_message) return error_message + @metrics.measure def _get_issue_details(self, issue_number): - self.logger.info(f"Getting details for issue: {issue_number}") - url = f"{self.base_url}/repos/{self.repo}/issues/{issue_number}" - response = requests.get(url, headers=self.headers) + self.logger.info(f"Getting details for issue #{issue_number}") + url = f"{self.base_url}/repos/{self._repo}/issues/{issue_number}" + response = self.session.get(url) if response.status_code == 200: - issue_data = response.json() - issue_details = { - "number": issue_data["number"], - "title": issue_data["title"], - "state": issue_data["state"], - "body": issue_data["body"], - "created_at": issue_data["created_at"], - "updated_at": issue_data["updated_at"], - "labels": [label["name"] for label in issue_data["labels"]], - "assignees": [assignee["login"] for assignee in issue_data["assignees"]], - "comments": issue_data["comments"] - } - self.logger.info(f"Successfully retrieved details for issue {issue_number}") - return issue_details + self.logger.info(f"Successfully retrieved details for issue #{issue_number}.") + return response.json() # Return raw JSON data for now else: - error_message = f"Error getting issue details: {response.status_code}\nResponse: {response.text}" + error_message = f"Error getting details for issue #{issue_number}: {response.status_code} - {response.text}" self.logger.error(error_message) return error_message - + @metrics.measure def _create_issue(self, title, body, labels=None): - self.logger.info(f"Creating issue: {title}") - url = f"{self.base_url}/repos/{self.repo}/issues" - data = { - "title": title, - "body": body - } - if labels: - data["labels"] = labels - response = requests.post(url, headers=self.headers, json=data) + self.logger.info(f"Creating new issue with title: '{title}'") + url = f"{self.base_url}/repos/{self._repo}/issues" + data = {"title": title, "body": body} + if labels: # Ensure labels is a list of strings + data["labels"] = labels if isinstance(labels, list) else [labels] + + response = self.session.post(url, json=data) if response.status_code == 201: - issue = response.json() - success_message = f"Issue created successfully: {issue['html_url']}" + issue_html_url = response.json().get("html_url", "N/A") + issue_number = response.json().get("number", "N/A") + success_message = f"Issue '{title}' created successfully: {issue_html_url} (Number: {issue_number})" self.logger.info(success_message) return success_message else: - error_message = f"Error creating issue: {response.status_code}\nResponse: {response.text}" + error_message = f"Error creating issue '{title}': {response.status_code} - {response.text}" self.logger.error(error_message) return error_message @metrics.measure def _list_issues(self, state="open", per_page=30, page=1): - self.logger.info(f"Listing issues. State: {state}, Per page: {per_page}, Page: {page}") - url = f"{self.base_url}/repos/{self.repo}/issues" - params = { - "state": state, - "per_page": per_page, - "page": page - } - response = requests.get(url, headers=self.headers, params=params) + self.logger.info(f"Listing issues with state: {state}, per_page: {per_page}, page: {page}") + url = f"{self.base_url}/repos/{self._repo}/issues" + params = {"state": state, "per_page": per_page, "page": page} + response = self.session.get(url, params=params) if response.status_code == 200: - issues = [{ - "number": issue["number"], - "title": issue["title"], - "state": issue["state"], - "created_at": issue["created_at"], - "url": issue["html_url"] - } for issue in response.json()] - self.logger.info(f"Successfully listed issues. Found {len(issues)} issues.") - return issues + issues_data = response.json() + self.logger.info(f"Successfully listed {len(issues_data)} issues.") + # Return a summary or full data based on needs + return [{ "title": i["title"], "number": i["number"], "state": i["state"], "url": i["html_url"] } for i in issues_data] else: - error_message = f"Error listing issues: {response.status_code}\nResponse: {response.text}" + error_message = f"Error listing issues: {response.status_code} - {response.text}" self.logger.error(error_message) return error_message - + @metrics.measure def _add_issue_comment(self, issue_number, comment): - self.logger.info(f"Adding comment to issue: {issue_number}") - url = f"{self.base_url}/repos/{self.repo}/issues/{issue_number}/comments" + self.logger.info(f"Adding comment to issue #{issue_number}: '{comment[:50]}...'") + url = f"{self.base_url}/repos/{self._repo}/issues/{issue_number}/comments" data = {"body": comment} - response = requests.post(url, headers=self.headers, json=data) + response = self.session.post(url, json=data) if response.status_code == 201: - comment_data = response.json() - success_message = f"Comment added successfully to issue {issue_number}: {comment_data['html_url']}" + comment_html_url = response.json().get("html_url", "N/A") + success_message = f"Comment added to issue #{issue_number} successfully: {comment_html_url}" self.logger.info(success_message) return success_message else: - error_message = f"Error adding comment to issue: {response.status_code}\nResponse: {response.text}" + error_message = f"Error adding comment to issue #{issue_number}: {response.status_code} - {response.text}" self.logger.error(error_message) return error_message @metrics.measure def _get_issue_comments(self, issue_number): - self.logger.info(f"Getting comments for issue: {issue_number}") - url = f"{self.base_url}/repos/{self.repo}/issues/{issue_number}/comments" - response = requests.get(url, headers=self.headers) + self.logger.info(f"Getting comments for issue #{issue_number}") + url = f"{self.base_url}/repos/{self._repo}/issues/{issue_number}/comments" + response = self.session.get(url) if response.status_code == 200: - comments = [{ - "id": comment["id"], - "user": comment["user"]["login"], - "body": comment["body"], - "created_at": comment["created_at"], - "updated_at": comment["updated_at"] - } for comment in response.json()] - self.logger.info(f"Successfully retrieved comments for issue {issue_number}. Found {len(comments)} comments.") - return comments + comments_data = response.json() + self.logger.info(f"Successfully retrieved {len(comments_data)} comments for issue #{issue_number}.") + # Return summary or full data + return [{ "user": c["user"]["login"], "body": c["body"], "created_at": c["created_at"] } for c in comments_data] else: - error_message = f"Error getting issue comments: {response.status_code}\nResponse: {response.text}" + error_message = f"Error getting comments for issue #{issue_number}: {response.status_code} - {response.text}" self.logger.error(error_message) return error_message - # New method for PR general comments @metrics.measure def _get_pull_request_general_comments(self, pull_number): - self.logger.info(f"Getting general comments for pull request: {pull_number}") - # Pull request comments are treated as issue comments in the GitHub API - # Re-use the existing _get_issue_comments method + self.logger.info(f"Getting general comments for pull request #{pull_number}") + # In GitHub API, PR comments (general, not review comments on lines) are issue comments. + # The PR is also an issue, so use the issue comments endpoint. return self._get_issue_comments(issue_number=pull_number) @metrics.measure def _create_project_board(self, name, body=None): - url = f"{self.base_url}/repos/{self.repo}/projects" - data = {"name": name, "body": body} - response = requests.post(url, headers=self.headers, json=data) + self.logger.info(f"Creating project board: '{name}'") + url = f"{self.base_url}/repos/{self._repo}/projects" + headers = self.session.headers.copy() # Get existing session headers + headers["Accept"] = "application/vnd.github.inertia-preview+json" # Required for Projects API + data = {"name": name} + if body: data["body"] = body + response = self.session.post(url, headers=headers, json=data) if response.status_code == 201: - project = response.json() - success_message = f"Project board '{name}' created successfully." + project_data = response.json() + success_message = f"Project board '{name}' created successfully with ID: {project_data['id']}" self.logger.info(success_message) - return { - "status": "success", - "status_code": response.status_code, - "message": success_message, - "data": project - } + return project_data # Return full project data else: - error_message = f"Error creating project board: {response.status_code}" + error_message = f"Error creating project board '{name}': {response.status_code} - {response.text}" self.logger.error(error_message) - return { - "status": "error", - "status_code": response.status_code, - "message": error_message, - "response": response.text - } + return error_message @metrics.measure def _create_project_column(self, project_id, column_name): + self.logger.info(f"Creating column '{column_name}' for project ID: {project_id}") url = f"{self.base_url}/projects/{project_id}/columns" + headers = self.session.headers.copy() + headers["Accept"] = "application/vnd.github.inertia-preview+json" data = {"name": column_name} - response = requests.post(url, headers=self.headers, json=data) + response = self.session.post(url, headers=headers, json=data) if response.status_code == 201: - column = response.json() - success_message = f"Column '{column_name}' created successfully in project {project_id}." + column_data = response.json() + success_message = f"Column '{column_name}' created successfully for project {project_id} with ID: {column_data['id']}" self.logger.info(success_message) - return { - "status": "success", - "status_code": response.status_code, - "message": success_message, - "data": column - } + return column_data else: - error_message = f"Error creating project column: {response.status_code}" + error_message = f"Error creating column '{column_name}' for project {project_id}: {response.status_code} - {response.text}" self.logger.error(error_message) - return { - "status": "error", - "status_code": response.status_code, - "message": error_message, - "response": response.text - } + return error_message @metrics.measure - def _create_project_card(self, column_id, note): + def _create_project_card(self, column_id, note=None, content_id=None, content_type=None): + self.logger.info(f"Creating card in column ID: {column_id}") url = f"{self.base_url}/projects/columns/{column_id}/cards" - data = {"note": note} - response = requests.post(url, headers=self.headers, json=data) + headers = self.session.headers.copy() + headers["Accept"] = "application/vnd.github.inertia-preview+json" + data = {} + if note: + data["note"] = note + if content_id and content_type: + data["content_id"] = content_id + data["content_type"] = content_type + elif (content_id and not content_type) or (not content_id and content_type): + err = "Both content_id and content_type must be provided to link content to a project card." + self.logger.warning(err) + return err + + if not data: + return "Error: Card must have a note or content to link." + + response = self.session.post(url, headers=headers, json=data) if response.status_code == 201: - card = response.json() - success_message = f"Card created successfully in column {column_id}." + card_data = response.json() + success_message = f"Card created successfully in column {column_id} with ID: {card_data['id']}" self.logger.info(success_message) - return { - "status": "success", - "status_code": response.status_code, - "message": success_message, - "data": card - } + return card_data else: - error_message = f"Error creating project card: {response.status_code}" + error_message = f"Error creating card in column {column_id}: {response.status_code} - {response.text}" self.logger.error(error_message) - return { - "status": "error", - "status_code": response.status_code, - "message": error_message, - "response": response.text - } + return error_message @metrics.measure - def _move_project_card(self, card_id, position, column_id): + def _move_project_card(self, card_id, position, column_id=None): + self.logger.info(f"Moving card ID: {card_id} to position: {position}" + (f" in column ID: {column_id}" if column_id else "")) url = f"{self.base_url}/projects/columns/cards/{card_id}/moves" - data = {"position": position, "column_id": column_id} - response = requests.post(url, headers=self.headers, json=data) - if response.status_code == 201: - success_message = f"Card {card_id} moved successfully." + headers = self.session.headers.copy() + headers["Accept"] = "application/vnd.github.inertia-preview+json" + data = {"position": position} + if column_id: + data["column_id"] = column_id + + response = self.session.post(url, headers=headers, json=data) + if response.status_code == 201: # Successful move returns 201 with empty body + success_message = f"Card {card_id} moved successfully to position {position}" + (f" in column {column_id}" if column_id else ".") self.logger.info(success_message) - return { - "status": "success", - "status_code": response.status_code, - "message": success_message - } + return success_message # Return success message as body is empty else: - error_message = f"Error moving project card: {response.status_code}" + error_message = f"Error moving card {card_id}: {response.status_code} - {response.text}" self.logger.error(error_message) - return { - "status": "error", - "status_code": response.status_code, - "message": error_message, - "response": response.text - } + return error_message + # _link_issue_to_project_card is effectively handled by _create_project_card if content_id and content_type are passed. + # The API used to have a separate link endpoint, but now it is part of card creation/update. + # For updating an existing card to link an issue, one would PATCH the card's content_id/content_type. + # Let's assume the function intends to update an existing card if it's a separate function. + # However, the provided API spec for `link_issue_to_project_card` uses PATCH on card_id, so let's implement that. @metrics.measure def _link_issue_to_project_card(self, card_id, content_id, content_type): - url = f"{self.base_url}/projects/columns/cards/{card_id}" + self.logger.info(f"Linking content_id {content_id} (type: {content_type}) to card_id {card_id}") + url = f"{self.base_url}/projects/cards/{card_id}" # Note: API docs suggest /projects/columns/cards/{card_id} or /projects/cards/{card_id} + # Using /projects/cards/{card_id} as it seems more general for card update. + headers = self.session.headers.copy() + headers["Accept"] = "application/vnd.github.inertia-preview+json" data = {"content_id": content_id, "content_type": content_type} - response = requests.patch(url, headers=self.headers, json=data) + + response = self.session.patch(url, headers=headers, json=data) if response.status_code == 200: - success_message = f"Issue/PR linked to card {card_id} successfully." + updated_card = response.json() + success_message = f"{content_type} {content_id} linked to card {card_id} successfully." self.logger.info(success_message) - return { - "status": "success", - "status_code": response.status_code, - "message": success_message - } + return updated_card else: - error_message = f"Error linking issue/PR to project card: {response.status_code}" + error_message = f"Error linking {content_type} {content_id} to card {card_id}: {response.status_code} - {response.text}" self.logger.error(error_message) - return { - "status": "error", - "status_code": response.status_code, - "message": error_message, - "response": response.text - } - + return error_message + @metrics.measure def _list_project_boards(self): - self.logger.info("Fetching project boards...") - url = f"{self.base_url}/repos/{self.repo}/projects" - response = requests.get(url, headers=self.headers) + self.logger.info(f"Listing project boards for repo: {self._repo}") + url = f"{self.base_url}/repos/{self._repo}/projects" + headers = self.session.headers.copy() + headers["Accept"] = "application/vnd.github.inertia-preview+json" + response = self.session.get(url, headers=headers) if response.status_code == 200: - boards = response.json() - self.logger.info(f"Successfully fetched {len(boards)} project boards.") - return boards + projects_data = response.json() + self.logger.info(f"Successfully listed {len(projects_data)} project boards.") + return projects_data else: - error_message = f"Error fetching project boards: {response.status_code}" + error_message = f"Error listing project boards: {response.status_code} - {response.text}" self.logger.error(error_message) return error_message @metrics.measure def _view_project_board_items(self, project_id): - self.logger.info(f"Fetching items for project board ID: {project_id}...") + self.logger.info(f"Viewing items for project ID: {project_id}") columns_url = f"{self.base_url}/projects/{project_id}/columns" - columns_response = requests.get(columns_url, headers=self.headers) - if columns_response.status_code == 200: - columns = columns_response.json() - items = [] - for column in columns: - column_id = column["id"] - column_name = column["name"] - cards_url = f"{self.base_url}/projects/columns/{column_id}/cards" - cards_response = requests.get(cards_url, headers=self.headers) - if cards_response.status_code == 200: - cards = cards_response.json() - items.append({"column": column_name, "cards": cards}) - else: - self.logger.error(f"Error fetching cards for column {column_id}: {cards_response.status_code}") - items.append({"column": column_name, "cards": "Error fetching cards"}) - self.logger.info(f"Successfully fetched items for project board ID: {project_id}.") - return items - else: - error_message = f"Error fetching columns for project board: {columns_response.status_code}" + headers = self.session.headers.copy() + headers["Accept"] = "application/vnd.github.inertia-preview+json" + + columns_response = self.session.get(columns_url, headers=headers) + if columns_response.status_code != 200: + error_message = f"Error fetching columns for project {project_id}: {columns_response.status_code} - {columns_response.text}" self.logger.error(error_message) return error_message + + columns_data = columns_response.json() + project_items = [] + for column in columns_data: + column_info = {"id": column["id"], "name": column["name"], "cards": []} + cards_url = column["cards_url"] + cards_response = self.session.get(cards_url, headers=headers) + if cards_response.status_code == 200: + column_info["cards"] = cards_response.json() + else: + self.logger.error(f"Error fetching cards for column {column['id']}('{column['name']}'): {cards_response.status_code} - {cards_response.text}") + column_info["cards"] = "Error fetching cards" + project_items.append(column_info) + + self.logger.info(f"Successfully retrieved items for project ID: {project_id}.") + return project_items - # New functions for PR review capabilities @metrics.measure def _get_pull_request_details(self, pull_number): - self.logger.info(f"Getting details for pull request: {pull_number}") - url = f"{self.base_url}/repos/{self.repo}/pulls/{pull_number}" - response = requests.get(url, headers=self.headers) + self.logger.info(f"Getting details for PR #{pull_number}") + url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}" + response = self.session.get(url) if response.status_code == 200: - self.logger.info(f"Successfully retrieved details for PR {pull_number}") + self.logger.info(f"Successfully retrieved details for PR #{pull_number}.") return response.json() else: - error_message = f"Error getting pull request details: {response.status_code}\nResponse: {response.text}" + error_message = f"Error getting details for PR #{pull_number}: {response.status_code} - {response.text}" self.logger.error(error_message) return error_message @metrics.measure def _get_pull_request_diff(self, pull_number): - self.logger.info(f"Getting diff for pull request: {pull_number}") - url = f"{self.base_url}/repos/{self.repo}/pulls/{pull_number}" - headers = self.headers.copy() - headers["Accept"] = "application/vnd.github.v3.diff" - response = requests.get(url, headers=headers) + self.logger.info(f"Getting diff for PR #{pull_number}") + url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}" + diff_headers = self.session.headers.copy() + diff_headers["Accept"] = "application/vnd.github.diff" + response = self.session.get(url, headers=diff_headers) if response.status_code == 200: - self.logger.info(f"Successfully retrieved diff for PR {pull_number}") + self.logger.info(f"Successfully retrieved diff for PR #{pull_number}.") return response.text else: - error_message = f"Error getting pull request diff: {response.status_code}\nResponse: {response.text}" + error_message = f"Error getting diff for PR #{pull_number}: {response.status_code} - {response.text}" self.logger.error(error_message) return error_message @metrics.measure def _get_pull_request_files(self, pull_number): - self.logger.info(f"Getting files for pull request: {pull_number}") - url = f"{self.base_url}/repos/{self.repo}/pulls/{pull_number}/files" - response = requests.get(url, headers=self.headers) + self.logger.info(f"Getting files for PR #{pull_number}") + url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/files" + response = self.session.get(url) if response.status_code == 200: - self.logger.info(f"Successfully retrieved files for PR {pull_number}") + self.logger.info(f"Successfully retrieved files for PR #{pull_number}.") return response.json() else: - error_message = f"Error getting pull request files: {response.status_code}\nResponse: {response.text}" + error_message = f"Error getting files for PR #{pull_number}: {response.status_code} - {response.text}" self.logger.error(error_message) return error_message @metrics.measure def _create_pull_request_review_comment(self, pull_number, body, commit_id, path, position, side="RIGHT", start_line=None, start_side=None): - self.logger.info(f"Adding review comment to PR {pull_number} on file {path} at position {position}") - url = f"{self.base_url}/repos/{self.repo}/pulls/{pull_number}/comments" - data = { - "body": body, - "commit_id": commit_id, - "path": path, - "position": position, - "side": side - } - if start_line is not None: - data["start_line"] = start_line - if start_side is not None: - data["start_side"] = start_side - - response = requests.post(url, headers=self.headers, json=data) + self.logger.info(f"Creating review comment on PR #{pull_number}, file '{path}', position {position}") + url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/comments" + data = {"body": body, "commit_id": commit_id, "path": path, "position": position, "side": side} + if start_line is not None: data["start_line"] = start_line + if start_side is not None: data["start_side"] = start_side + + response = self.session.post(url, json=data) if response.status_code == 201: - success_message = f"Comment added to PR {pull_number} successfully: {response.json()['html_url']}" + comment_url = response.json().get("html_url", "N/A") + success_message = f"Review comment created successfully on PR #{pull_number}: {comment_url}" self.logger.info(success_message) return success_message else: - error_message = f"Error creating pull request review comment: {response.status_code}\nResponse: {response.text}" + error_message = f"Error creating review comment on PR #{pull_number}: {response.status_code} - {response.text}" self.logger.error(error_message) return error_message @metrics.measure def _list_pull_request_review_comments(self, pull_number): - self.logger.info(f"Listing review comments for pull request: {pull_number}") - url = f"{self.base_url}/repos/{self.repo}/pulls/{pull_number}/comments" - response = requests.get(url, headers=self.headers) + self.logger.info(f"Listing review comments for PR #{pull_number}") + url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/comments" + response = self.session.get(url) if response.status_code == 200: - self.logger.info(f"Successfully retrieved review comments for PR {pull_number}") + self.logger.info(f"Successfully retrieved review comments for PR #{pull_number}.") return response.json() else: - error_message = f"Error listing pull request review comments: {response.status_code}\nResponse: {response.text}" + error_message = f"Error listing review comments for PR #{pull_number}: {response.status_code} - {response.text}" self.logger.error(error_message) return error_message @metrics.measure def _submit_pull_request_review(self, pull_number, event, body=None): - self.logger.info(f"Submitting review for pull request {pull_number} with event: {event}") - url = f"{self.base_url}/repos/{self.repo}/pulls/{pull_number}/reviews" - data = {"event": event} - if body: - data["body"] = body - response = requests.post(url, headers=self.headers, json=data) + self.logger.info(f"Submitting '{event}' review for PR #{pull_number}") + url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/reviews" + data = {"event": event.upper()} # Ensure event is uppercase as per API + if body: data["body"] = body + + response = self.session.post(url, json=data) if response.status_code == 200: - success_message = f"Review submitted for PR {pull_number} successfully." + review_url = response.json().get("html_url", "N/A") + success_message = f"Review ({event}) submitted successfully for PR #{pull_number}: {review_url}" self.logger.info(success_message) return success_message else: - error_message = f"Error submitting pull request review: {response.status_code}\nResponse: {response.text}" + error_message = f"Error submitting review for PR #{pull_number}: {response.status_code} - {response.text}" self.logger.error(error_message) return error_message diff --git a/tools/log_tool.py b/tools/log_tool.py index 4a34557..bcadf04 100644 --- a/tools/log_tool.py +++ b/tools/log_tool.py @@ -1,5 +1,4 @@ # tools/log_tool.py - from .base_tool import BaseTool from .metrics import metrics import logging @@ -7,48 +6,39 @@ import os from datetime import datetime, timedelta class LogTool(BaseTool): - def __init__(self): - # Set up logging - self.logger = logging.getLogger(__name__) - self.logger.setLevel(logging.INFO) - - # Create a file handler - file_handler = logging.FileHandler('log_tool.log') - file_handler.setLevel(logging.INFO) - - # Create a console handler - console_handler = logging.StreamHandler() - console_handler.setLevel(logging.INFO) - - # Create a formatting for the logs - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') - file_handler.setFormatter(formatter) - console_handler.setFormatter(formatter) - - # Add the handlers to the logger - self.logger.addHandler(file_handler) - self.logger.addHandler(console_handler) + # Default log format string that _get_log_contents expects for time-based filtering. + # Making it a class variable so it's visible and could be overridden by a subclass if needed, + # though the parser is still hardcoded in this version. + EXPECTED_LOG_TIMESTAMP_FORMAT = '%Y-%m-%d %H:%M:%S,%f' + + def __init__(self, log_file_path=None, logger=None): + self.configured_log_file_path = log_file_path if log_file_path else 'logs/output.log' + self.logger = logger if logger else logging.getLogger(__name__) + if not self.logger.handlers: + self.logger.addHandler(logging.NullHandler()) + self.logger.info(f"LogTool initialized. Log file path: {self.configured_log_file_path}") def clear(self): + # No specific state to clear for LogTool in this version. + self.logger.debug("LogTool clear called, no action taken.") pass def get_functions(self): - return [ { "type": "function", "function": { "name": "get_log_contents", - "description": "Get the contents of the log file.", + "description": "Get the contents of the log file. If line_count is not provided, it attempts to return logs from the last 24 hours based on timestamp.", "parameters": { "type": "object", "properties": { "line_count": { "type": "integer", - "description": "Number of lines from the end of the log file to retrieve" + "description": "Number of lines from the end of the log file to retrieve. If omitted, logs from last 24 hours are returned." } }, - "required": [] + "required": [] # line_count is optional } } } @@ -56,37 +46,60 @@ class LogTool(BaseTool): @metrics.measure def execute(self, function_name, **kwargs): - self.logger.info(f"Executing: {function_name}") - + self.logger.info(f"Executing LogTool function: {function_name} with args: {kwargs}") if function_name == "get_log_contents": - return self._get_log_contents(kwargs.get("line_count")) + # kwargs.get("line_count") will be None if not provided, which is handled by _get_log_contents + return self._get_log_contents(line_count=kwargs.get("line_count")) else: error_message = f"Unknown function: {function_name}" self.logger.error(error_message) return error_message @metrics.measure - def _get_log_contents(self, line_count=150): - log_file_path = 'logs/output.log' + def _get_log_contents(self, line_count=None): # Default line_count is None to trigger 24h logic if not specified + self.logger.info(f"Attempting to get log contents from: {self.configured_log_file_path}. Line count: {line_count if line_count is not None else 'Last 24 hours'}") - if not os.path.exists(log_file_path): - error_message = "Log file does not exist." + if not os.path.exists(self.configured_log_file_path): + error_message = f"Log file does not exist at path: {self.configured_log_file_path}" self.logger.error(error_message) return error_message try: - with open(log_file_path, 'r') as log_file: + with open(self.configured_log_file_path, 'r', encoding='utf-8') as log_file: log_lines = log_file.readlines() + self.logger.debug(f"Read {len(log_lines)} total lines from log file.") - if line_count is not None: - log_lines = log_lines[-line_count:] - else: - now = datetime.now() - twenty_four_hours_ago = now - timedelta(days=1) - log_lines = [line for line in log_lines if datetime.strptime(line.split(' - ')[0], '%Y-%m-%d %H:%M:%S,%f') > twenty_four_hours_ago] + if line_count is not None: + # Ensure line_count is positive if specified, otherwise could lead to unexpected slicing + if not isinstance(line_count, int) or line_count <= 0: + self.logger.warning(f"Invalid line_count '{line_count}' provided, defaulting to fetch last 150 lines.") + line_count = 150 # Default to a sensible number if invalid value provided + log_lines = log_lines[-line_count:] + self.logger.info(f"Returning last {len(log_lines)} lines based on line_count: {line_count}") + else: + # Default to last 24 hours if line_count is explicitly None or not provided + self.logger.info(f"Filtering logs for the last 24 hours. Expected timestamp format: {self.EXPECTED_LOG_TIMESTAMP_FORMAT}") + now = datetime.now() + twenty_four_hours_ago = now - timedelta(days=1) + + filtered_lines = [] + for line in log_lines: + try: + # Attempt to parse timestamp from the beginning of the line + timestamp_str = line.split(' - ', 1)[0] + log_time = datetime.strptime(timestamp_str, self.EXPECTED_LOG_TIMESTAMP_FORMAT) + if log_time > twenty_four_hours_ago: + filtered_lines.append(line) + except (ValueError, IndexError) as e: + # This means the line doesn't start with a parsable timestamp in the expected format. + # Depending on requirements, these lines might be skipped or included. + # For strict 24-hour filtering, we skip them. + self.logger.debug(f"Skipping line due to timestamp parse error ('{e}') or format mismatch: {line.strip()[:100]}...") + log_lines = filtered_lines + self.logger.info(f"Returning {len(log_lines)} lines from the last 24 hours.") - return "".join(log_lines) + return "".join(log_lines) except Exception as e: - error_message = f"An error occurred while reading the log file: {e}" - self.logger.error(error_message) - return error_message \ No newline at end of file + error_message = f"An error occurred while reading the log file '{self.configured_log_file_path}': {e}" + self.logger.error(error_message, exc_info=True) + return error_message diff --git a/tools/metrics.py b/tools/metrics.py index f4a695d..54cb228 100644 --- a/tools/metrics.py +++ b/tools/metrics.py @@ -1,13 +1,19 @@ +# tools/metrics.py import cProfile import pstats -import io +import io from functools import wraps from collections import defaultdict +import logging class Metrics: - def __init__(self): + def __init__(self, logger=None): self.call_count = defaultdict(int) self.total_time = defaultdict(float) + self.logger = logger if logger else logging.getLogger(__name__) + if not self.logger.handlers: + self.logger.addHandler(logging.NullHandler()) + self.logger.debug("Metrics instance initialized.") def measure(self, func): @wraps(func) @@ -16,30 +22,58 @@ class Metrics: pr = cProfile.Profile() pr.enable() - result = func(*args, **kwargs) - pr.disable() - s = io.StringIO() - ps = pstats.Stats(pr, stream=s).sort_stats('cumulative') - ps.print_stats() - # Extract the total time spent in the function - time_spent = float(s.getvalue().split('\n')[0].split()[-2]) - self.total_time[func.__name__] += time_spent + ps = pstats.Stats(pr) + + func_code = func.__code__ + func_key_tuple = (func_code.co_filename, func_code.co_firstlineno, func_code.co_name) + + time_spent_for_func = 0.0 + if func_key_tuple in ps.stats: + time_spent_for_func = ps.stats[func_key_tuple][3] # [3] is cumulative time (ct) + else: + # Fallback: try to find by function name if exact key fails (e.g. due to decorators changing code object details slightly) + # This is less precise and might pick up other functions if names are not unique across files. + found_by_name = False + for key, stat in ps.stats.items(): + if key[2] == func.__name__: # key[2] is function name + time_spent_for_func = stat[3] # cumulative time + self.logger.debug(f"Found stats for {func.__name__} by name {key} after primary key failed.") + found_by_name = True + break + if not found_by_name: + self.logger.warning( + f"Could not find exact cProfile stats for {func.__name__} with key {func_key_tuple} or by name. " + f"Time for this call will be recorded as 0. This might occur for non-Python functions or due to complex decorators." + ) + + self.total_time[func.__name__] += time_spent_for_func + self.logger.debug(f"Measured cumulative time for {func.__name__}: {time_spent_for_func:.6f}s") return result return wrapper def get_metrics(self): - metrics = {} + metrics_data = {} for func_name in self.call_count: - metrics[func_name] = { - 'call_count': self.call_count[func_name], - 'total_time': self.total_time[func_name], - 'average_time': self.total_time[func_name] / self.call_count[func_name] if self.call_count[func_name] > 0 else 0 + count = self.call_count[func_name] + total_t = self.total_time[func_name] + metrics_data[func_name] = { + 'call_count': count, + 'total_time': round(total_t, 6), + 'average_time': round(total_t / count, 6) if count > 0 else 0 } - return metrics + return metrics_data -# Create a global instance of Metrics -metrics = Metrics() \ No newline at end of file + def clear_metrics(self): + self.call_count.clear() + self.total_time.clear() + self.logger.info("Metrics cleared.") + +# Global instance for convenience +_metrics_instance_logger = logging.getLogger(__name__ + ".global_instance") +if not _metrics_instance_logger.handlers: + _metrics_instance_logger.addHandler(logging.NullHandler()) +metrics = Metrics(logger=_metrics_instance_logger) diff --git a/tools/metrics_tool.py b/tools/metrics_tool.py index 8ba6116..91d4664 100644 --- a/tools/metrics_tool.py +++ b/tools/metrics_tool.py @@ -1,13 +1,22 @@ # tools/metrics_tool.py - from .base_tool import BaseTool -from .metrics import metrics +from .metrics import metrics as global_metrics_instance # For default and measuring execute +from .metrics import Metrics # For type hinting and potentially creating a new one if needed +import logging class MetricsTool(BaseTool): - def __init__(self): - self.metrics = metrics + def __init__(self, metrics_provider: Metrics | None = None, logger: logging.Logger | None = None): + self.metrics_provider = metrics_provider if metrics_provider is not None else global_metrics_instance + self.logger = logger if logger else logging.getLogger(__name__) + if not self.logger.handlers: + self.logger.addHandler(logging.NullHandler()) + self.logger.debug(f"MetricsTool initialized. Using metrics provider: {self.metrics_provider}") def clear(self): + # This tool itself doesn't hold state that needs clearing beyond what its metrics_provider might do. + # If this tool were responsible for clearing the metrics it reports on, it would call: + # self.metrics_provider.clear_metrics() + self.logger.debug("MetricsTool clear method called. No local state to clear.") pass def get_functions(self): @@ -60,25 +69,60 @@ class MetricsTool(BaseTool): } ] - @metrics.measure + @global_metrics_instance.measure # The execute method can be measured by the global instance def execute(self, function_name, **kwargs): + self.logger.info(f"Executing MetricsTool function: {function_name} with args: {kwargs}") if function_name == "get_function_metrics": return self._get_function_metrics() elif function_name == "get_specific_function_metrics": - return self._get_specific_function_metrics(kwargs.get("function_name")) + func_name_arg = kwargs.get("function_name") + if func_name_arg is None: # Check if None, as empty string could be a valid (though unlikely) func name + self.logger.warning("'function_name' argument is missing for get_specific_function_metrics.") + return "Error: Missing required argument 'function_name'." + return self._get_specific_function_metrics(str(func_name_arg)) # Ensure string elif function_name == "get_top_n_functions": - return self._get_top_n_functions(kwargs.get("n")) + n_arg = kwargs.get("n") + if n_arg is None: + self.logger.warning("'n' argument is missing for get_top_n_functions.") + return "Error: Missing required argument 'n'." + try: + n_val = int(n_arg) + if n_val <= 0: + self.logger.warning(f"'n' argument must be a positive integer, got {n_val}.") + return "Error: Argument 'n' must be a positive integer." + return self._get_top_n_functions(n_val) + except ValueError: + self.logger.warning(f"'n' argument must be an integer, got '{n_arg}'.") + return "Error: Argument 'n' must be an integer." else: - return f"Unknown function: {function_name}" + error_message = f"Unknown function: {function_name}" + self.logger.error(error_message) + return error_message def _get_function_metrics(self): - return self.metrics.get_metrics() + self.logger.debug("Calling metrics_provider.get_metrics() for all functions.") + return self.metrics_provider.get_metrics() - def _get_specific_function_metrics(self, function_name): - all_metrics = self.metrics.get_metrics() - return all_metrics.get(function_name, f"No metrics found for function: {function_name}") + def _get_specific_function_metrics(self, function_to_get): + self.logger.debug(f"Getting metrics for specific function: {function_to_get}") + all_metrics = self.metrics_provider.get_metrics() + return all_metrics.get(function_to_get, f"No metrics found for function: {function_to_get}") def _get_top_n_functions(self, n): - all_metrics = self.metrics.get_metrics() - sorted_metrics = sorted(all_metrics.items(), key=lambda x: x[1]['total_time'], reverse=True) - return dict(sorted_metrics[:n]) \ No newline at end of file + self.logger.debug(f"Getting top {n} functions by total execution time.") + all_metrics = self.metrics_provider.get_metrics() + # Ensure that the items are actual metric dicts before trying to access 'total_time' + valid_metrics_items = [] + for name, metric_values in all_metrics.items(): + if isinstance(metric_values, dict) and 'total_time' in metric_values: + valid_metrics_items.append((name, metric_values)) + else: + self.logger.warning(f"Metric item for '{name}' is not in expected format: {metric_values}") + + # Sort items by total_time. items() gives list of (func_name, metrics_dict) + try: + sorted_metrics = sorted(valid_metrics_items, key=lambda item: item[1]['total_time'], reverse=True) + return dict(sorted_metrics[:n]) + except TypeError as e: + self.logger.error(f"Error sorting metrics, possibly due to unexpected data types: {e}", exc_info=True) + return "Error: Could not sort metrics due to unexpected data."