Merge pull request #203 from bucolucas/refactor/testability-base-files
Refactor: Improve Testability of Base Bots and Core Tools
This commit is contained in:
@@ -3,25 +3,48 @@ import json
|
||||
import logging
|
||||
from anthropic import Anthropic, APIError, RateLimitError
|
||||
from base_telegram_inference_bot import BaseTelegramInferenceBot
|
||||
from telegram_helper import TelegramHelper
|
||||
from telegram_helper import TelegramHelper # Used in main, not class
|
||||
|
||||
class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.anthropic_client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
|
||||
DEFAULT_SMALL_MODEL_NAME = "claude-3-haiku-20240307"
|
||||
DEFAULT_SMALL_MODEL_MAX_TOKENS = "2048"
|
||||
DEFAULT_LARGE_MODEL_NAME = "claude-3-opus-20240229"
|
||||
DEFAULT_LARGE_MODEL_MAX_TOKENS = "4096"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
anthropic_client: Anthropic | None = None,
|
||||
api_key: str | None = None,
|
||||
small_model_name: str | None = None,
|
||||
small_model_max_tokens: str | None = None,
|
||||
large_model_name: str | None = None,
|
||||
large_model_max_tokens: str | None = None,
|
||||
system_prompt_content: str | None = None,
|
||||
system_prompt_path: str | None = None
|
||||
):
|
||||
super().__init__(system_prompt_content=system_prompt_content, system_prompt_path=system_prompt_path)
|
||||
|
||||
if anthropic_client:
|
||||
self.anthropic_client = anthropic_client
|
||||
else:
|
||||
_api_key = api_key or os.environ.get("ANTHROPIC_API_KEY")
|
||||
if not _api_key:
|
||||
raise ValueError("Anthropic API key must be provided either via argument or ANTHROPIC_API_KEY environment variable.")
|
||||
self.anthropic_client = Anthropic(api_key=_api_key)
|
||||
|
||||
self.small_model_name = small_model_name or os.environ.get("ANTHROPIC_SMALL_MODEL") or self.DEFAULT_SMALL_MODEL_NAME
|
||||
self.small_model_max_tokens_str = small_model_max_tokens or os.environ.get("ANTHROPIC_SMALL_MODEL_MAX_TOKENS") or self.DEFAULT_SMALL_MODEL_MAX_TOKENS
|
||||
self.large_model_name = large_model_name or os.environ.get("ANTHROPIC_LARGE_MODEL") or self.DEFAULT_LARGE_MODEL_NAME
|
||||
self.large_model_max_tokens_str = large_model_max_tokens or os.environ.get("ANTHROPIC_LARGE_MODEL_MAX_TOKENS") or self.DEFAULT_LARGE_MODEL_MAX_TOKENS
|
||||
|
||||
# Initialize with the small model by default
|
||||
self.small_model_name = os.environ.get("ANTHROPIC_SMALL_MODEL", "claude-3-haiku-20240307")
|
||||
self.small_model_max_tokens = os.environ.get("ANTHROPIC_SMALL_MODEL_MAX_TOKENS", "2048")
|
||||
self.large_model_name = os.environ.get("ANTHROPIC_LARGE_MODEL", "claude-3-opus-20240229")
|
||||
self.large_model_max_tokens = os.environ.get("ANTHROPIC_LARGE_MODEL_MAX_TOKENS", "4096")
|
||||
|
||||
self._configure_model_and_tokens(
|
||||
self.small_model_name,
|
||||
self.small_model_max_tokens
|
||||
self.small_model_max_tokens_str,
|
||||
default_max_tokens=int(self.DEFAULT_SMALL_MODEL_MAX_TOKENS) # pass int for default
|
||||
)
|
||||
|
||||
def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=2048): # Default max_tokens adjusted for typical "small"
|
||||
def _configure_model_and_tokens(self, model_name: str, max_tokens_str: str, default_max_tokens: int = 2048):
|
||||
self.model = model_name
|
||||
try:
|
||||
self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens
|
||||
@@ -65,17 +88,19 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
|
||||
|
||||
def _format_tool_response_for_anthropic(self, tool_response_data):
|
||||
if isinstance(tool_response_data, str):
|
||||
# Wrap plain string in a list of text blocks if not already structured
|
||||
return [{"type": "text", "text": tool_response_data}]
|
||||
elif isinstance(tool_response_data, list) and all(isinstance(item, dict) and "type" in item for item in tool_response_data):
|
||||
# Already a list of content blocks
|
||||
return tool_response_data
|
||||
elif isinstance(tool_response_data, (dict, list)):
|
||||
# Attempt to JSON dump other dicts/lists if not already in content block format
|
||||
try:
|
||||
is_valid_block_list = isinstance(tool_response_data, list) and all(isinstance(item, dict) and "type" in item for item in tool_response_data)
|
||||
if is_valid_block_list:
|
||||
return tool_response_data
|
||||
else:
|
||||
return [{"type": "text", "text": json.dumps(tool_response_data)}]
|
||||
return [{"type": "text", "text": json.dumps(tool_response_data)}]
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
return [{"type": "text", "text": str(tool_response_data)}]
|
||||
return [{"type": "text", "text": str(tool_response_data)}] # Fallback to string
|
||||
else:
|
||||
# Fallback for other types (int, float, etc.)
|
||||
return [{"type": "text", "text": str(tool_response_data)}]
|
||||
|
||||
async def handle_message(self, user_id, user_message):
|
||||
@@ -94,7 +119,7 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
|
||||
|
||||
if not response or not response.content:
|
||||
logging.error("No valid response content from Anthropic LLM.")
|
||||
self.conversation_history[user_id] = current_turn_messages
|
||||
self.conversation_history[user_id] = current_turn_messages # Save current state
|
||||
return "Error: Could not get a valid response from the LLM."
|
||||
|
||||
assistant_current_turn_content_blocks = response.content
|
||||
@@ -123,7 +148,6 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
|
||||
try:
|
||||
tool_response_data = self.call_tool(tool_name, tool_input)
|
||||
tool_result_content_block = self._format_tool_response_for_anthropic(tool_response_data)
|
||||
|
||||
tool_results_for_model.append({
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_use_id,
|
||||
@@ -138,11 +162,15 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
|
||||
"is_error": True
|
||||
})
|
||||
|
||||
current_turn_messages.append({"role": "user", "content": tool_results_for_model})
|
||||
current_turn_messages.append({"role": "user", "content": tool_results_for_model}) # Anthropic expects tool results as a user message
|
||||
|
||||
tool_use_count += 1
|
||||
if tool_use_count >= MAX_TOOL_ITERATIONS:
|
||||
logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached for Anthropic.")
|
||||
# Update assistant_response_content with any text from the last assistant turn before breaking
|
||||
if not assistant_response_content and text_parts_from_assistant:
|
||||
assistant_response_content = "".join(text_parts_from_assistant)
|
||||
assistant_response_content += "\n[Max tool iterations reached]"
|
||||
break
|
||||
|
||||
self.conversation_history[user_id] = current_turn_messages
|
||||
@@ -153,70 +181,89 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
|
||||
if assistant_response_content:
|
||||
return assistant_response_content
|
||||
else:
|
||||
# Fallback if no text parts were found but there was an assistant message
|
||||
if current_turn_messages:
|
||||
last_message_in_turn = current_turn_messages[-1]
|
||||
# Check if the last message content has text blocks (Anthropic specific structure)
|
||||
if last_message_in_turn.get("role") == "assistant" and isinstance(last_message_in_turn.get("content"), list):
|
||||
for block in reversed(last_message_in_turn["content"]):
|
||||
if block.type == "text":
|
||||
return block.text
|
||||
return "No textual response from assistant."
|
||||
|
||||
if block.type == "text" and hasattr(block, 'text') and block.text:
|
||||
return block.text # Return the first non-empty text found from the end
|
||||
return "No textual response generated by the assistant after processing." # More informative default
|
||||
|
||||
async def start(self):
|
||||
logging.info("Anthropic Bot started")
|
||||
|
||||
async def clear_conversation_history(self, user_id):
|
||||
super().clear_conversation_history(user_id)
|
||||
logging.info(f"Cleared conversation history for Anthropic bot, user {user_id}")
|
||||
# clear_conversation_history is inherited from BaseTelegramInferenceBot and calls super().clear_conversation_history
|
||||
# No need to override if the base implementation is sufficient, unless specific logging is needed.
|
||||
# async def clear_conversation_history(self, user_id):
|
||||
# super().clear_conversation_history(user_id)
|
||||
# logging.info(f"Cleared conversation history for Anthropic bot, user {user_id}")
|
||||
|
||||
async def abort_processing(self, user_id):
|
||||
# This abort is a soft abort, as actual Anthropic API call is synchronous within handle_message
|
||||
# It primarily clears state and prevents further processing in the bot's loop if any.
|
||||
if user_id in self.processing_status:
|
||||
self.processing_status[user_id]["processing"] = False
|
||||
await self.clear_conversation_history(user_id)
|
||||
return "Processing aborted and conversation cleared."
|
||||
else:
|
||||
await self.clear_conversation_history(user_id)
|
||||
return "No active processing found to abort. Conversation cleared."
|
||||
self.processing_status[user_id]["processing"] = False # Mark as not processing
|
||||
# self.clear_processing_status(user_id) # Use base class method to remove entry
|
||||
# Clearing history might be too aggressive for a simple abort, depends on desired UX
|
||||
# For now, let's just stop processing and clear the flag.
|
||||
# Consider if conversation history should be cleared here or if that is a separate user action.
|
||||
# super().clear_conversation_history(user_id) # Moved to be less aggressive
|
||||
logging.info(f"Abort requested for user {user_id}. Processing flag cleared.")
|
||||
return "Processing aborted. You can send a new message or /clear the conversation."
|
||||
|
||||
async def switch_model(self):
|
||||
# Ensure ANTHROPIC_SMALL_MODEL and ANTHROPIC_LARGE_MODEL related env vars are loaded in __init__
|
||||
# or ensure they are freshly checked here if they can change during runtime (less common for model names).
|
||||
# For this implementation, we rely on the values stored during __init__.
|
||||
|
||||
if not self.small_model_name or not self.large_model_name:
|
||||
logging.warning("Small or Large model names for Anthropic are not defined. Cannot switch model.")
|
||||
return f"Model switching not fully configured. Currently using {self.model}."
|
||||
|
||||
if self.model == self.small_model_name:
|
||||
current_is_small = self.model == self.small_model_name
|
||||
current_is_large = self.model == self.large_model_name
|
||||
|
||||
if current_is_small:
|
||||
target_model = self.large_model_name
|
||||
target_max_tokens = self.large_model_max_tokens
|
||||
# Use default large max_tokens if specific one isn't set or invalid
|
||||
default_max_tokens_for_large = "4096"
|
||||
elif self.model == self.large_model_name:
|
||||
target_max_tokens_str = self.large_model_max_tokens_str
|
||||
default_target_max_tokens = int(self.DEFAULT_LARGE_MODEL_MAX_TOKENS)
|
||||
elif current_is_large:
|
||||
target_model = self.small_model_name
|
||||
target_max_tokens = self.small_model_max_tokens
|
||||
# Use default small max_tokens if specific one isn't set or invalid
|
||||
default_max_tokens_for_large = "2048"
|
||||
target_max_tokens_str = self.small_model_max_tokens_str
|
||||
default_target_max_tokens = int(self.DEFAULT_SMALL_MODEL_MAX_TOKENS)
|
||||
else:
|
||||
# Current model is neither the designated small nor large, switch to small as a reset
|
||||
logging.warning(f"Current model {self.model} is neither the configured small nor large model. Switching to small model.")
|
||||
logging.warning(f"Current model {self.model} is unrecognized. Switching to default small model.")
|
||||
target_model = self.small_model_name
|
||||
target_max_tokens = self.small_model_max_tokens
|
||||
default_max_tokens_for_large = "2048"
|
||||
target_max_tokens_str = self.small_model_max_tokens_str
|
||||
default_target_max_tokens = int(self.DEFAULT_SMALL_MODEL_MAX_TOKENS)
|
||||
|
||||
|
||||
self._configure_model_and_tokens(target_model, target_max_tokens, default_max_tokens=int(default_max_tokens_for_large)) # Pass appropriate default
|
||||
self._configure_model_and_tokens(target_model, target_max_tokens_str, default_max_tokens=default_target_max_tokens)
|
||||
logging.info(f"Switched Anthropic model to: {self.model}")
|
||||
return f"Switched to Anthropic model: {self.model} (Max Tokens: {self.max_tokens})"#Provide token info
|
||||
return f"Switched to Anthropic model: {self.model} (Max Tokens: {self.max_tokens})"
|
||||
|
||||
|
||||
# The main function is for standalone execution and basic testing, not part of the class itself.
|
||||
# It's good practice to update it to reflect changes if you use it for quick tests.
|
||||
# For unit tests, we'll instantiate the class with mocked dependencies.
|
||||
def main():
|
||||
if not os.environ.get("ANTHROPIC_API_KEY"):
|
||||
logging.error("FATAL: ANTHROPIC_API_KEY environment variable not set.")
|
||||
return
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
|
||||
bot = AnthropicTelegramInferenceBot()
|
||||
# Example of how to instantiate with new constructor (assuming API key is in ENV for this example)
|
||||
# For real tests, you'd mock Anthropic() or pass a mock client.
|
||||
try:
|
||||
# These would typically come from a config file or CLI args in a real app if not ENV
|
||||
# For this example, we rely on ENV or defaults being handled by constructor if not provided.
|
||||
bot = AnthropicTelegramInferenceBot(
|
||||
api_key=os.environ.get("ANTHROPIC_API_KEY") # Explicitly pass, or let constructor handle ENV
|
||||
)
|
||||
except ValueError as e:
|
||||
logging.error(f"Failed to initialize bot: {e}")
|
||||
return
|
||||
except Exception as e: # Catch any other init errors
|
||||
logging.error(f"An unexpected error occurred during bot initialization: {e}")
|
||||
return
|
||||
|
||||
# TelegramHelper also updated, ensure it's instantiated correctly for this main context.
|
||||
# For this basic main, we might not pass all configurable paths to TelegramHelper,
|
||||
# letting them use defaults.
|
||||
telegram_helper = TelegramHelper(bot)
|
||||
telegram_helper.run()
|
||||
|
||||
|
||||
@@ -7,26 +7,47 @@ from abc import ABC, abstractmethod
|
||||
from tools.base_tool import BaseTool
|
||||
|
||||
class BaseTelegramInferenceBot(ABC):
|
||||
def __init__(self):
|
||||
def __init__(self, system_prompt_content: str | None = None, system_prompt_path: str | None = None): # MODIFIED
|
||||
self.conversation_history = {}
|
||||
self.processing_status = {}
|
||||
self.system_prompt = self.load_system_prompt()
|
||||
# MODIFIED to pass arguments
|
||||
self.system_prompt = self.load_system_prompt(
|
||||
direct_content=system_prompt_content,
|
||||
file_path=system_prompt_path
|
||||
)
|
||||
self.tools, self.functions = self.load_functions()
|
||||
logging.info(f'System Prompt: {os.environ.get("SYSTEM_PROMPT_PATH")}')
|
||||
# Logging the actual source of the system prompt might be more complex now,
|
||||
# but we can log the final prompt or indicate if it's custom/default.
|
||||
# We'll also log the source of the prompt inside load_system_prompt.
|
||||
logging.info(f'System Prompt (effective): {"Custom" if self.system_prompt != "You are a helpful AI assistant." else "Default"}')
|
||||
logging.info(f'Github Repository: {os.environ.get("GITHUB_REPOSITORY")}')
|
||||
|
||||
def load_system_prompt(self):
|
||||
system_prompt_path = os.getenv("SYSTEM_PROMPT_PATH")
|
||||
if system_prompt_path and os.path.isfile(system_prompt_path):
|
||||
try:
|
||||
with open(system_prompt_path, "r", encoding="utf-8") as file:
|
||||
return file.read().strip()
|
||||
except IOError as e:
|
||||
logging.warning(f"Could not read system prompt file {system_prompt_path}: {e}")
|
||||
return "You are a helpful AI assistant."
|
||||
def load_system_prompt(self, direct_content: str | None = None, file_path: str | None = None) -> str: # MODIFIED
|
||||
default_prompt = "You are a helpful AI assistant."
|
||||
|
||||
if direct_content:
|
||||
logging.info("Using direct content for system prompt.")
|
||||
return direct_content.strip()
|
||||
|
||||
prompt_path_to_try = file_path or os.getenv("SYSTEM_PROMPT_PATH")
|
||||
|
||||
if prompt_path_to_try:
|
||||
if os.path.isfile(prompt_path_to_try):
|
||||
try:
|
||||
with open(prompt_path_to_try, "r", encoding="utf-8") as file:
|
||||
content = file.read().strip()
|
||||
logging.info(f"Successfully loaded system prompt from {prompt_path_to_try}.")
|
||||
return content
|
||||
except IOError as e:
|
||||
logging.warning(f"Could not read system prompt file {prompt_path_to_try}: {e}. Using default.")
|
||||
return default_prompt
|
||||
else:
|
||||
# This condition now also covers if 'file_path' argument was given but invalid
|
||||
logging.warning(f"System prompt file {prompt_path_to_try} not found. Using default system prompt.")
|
||||
return default_prompt
|
||||
else:
|
||||
logging.warning("SYSTEM_PROMPT_PATH is not set or file does not exist. Using default system prompt.")
|
||||
return "You are a helpful AI assistant."
|
||||
logging.info("No system prompt path provided (argument or ENV) or direct content. Using default system prompt.")
|
||||
return default_prompt
|
||||
|
||||
def load_functions(self):
|
||||
tools = []
|
||||
@@ -44,7 +65,7 @@ class BaseTelegramInferenceBot(ABC):
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool:
|
||||
try:
|
||||
tools.append(obj())
|
||||
tools.append(obj()) # This instantiation might be an issue for tools needing config
|
||||
except Exception as e:
|
||||
logging.error(f"Error instantiating tool {name} from {filename}: {e}")
|
||||
except Exception as e:
|
||||
@@ -87,9 +108,9 @@ class BaseTelegramInferenceBot(ABC):
|
||||
except json.JSONDecodeError as e:
|
||||
logging.error(f"Error decoding function call arguments (string) for {function_call_name}: {e}. Arguments: {function_call_arguments}")
|
||||
return f"Error: Malformed arguments for tool call: {e}"
|
||||
else: # Handle cases where arguments might be None or other unexpected types
|
||||
else:
|
||||
if function_call_arguments is None:
|
||||
function_args = {} # Default to empty dict if arguments are None
|
||||
function_args = {}
|
||||
else:
|
||||
logging.error(f"Unexpected type for function_call_arguments for {function_call_name}: {type(function_call_arguments)}. Arguments: {function_call_arguments}")
|
||||
return f"Error: Invalid argument type for tool call: {type(function_call_arguments)}"
|
||||
@@ -98,7 +119,6 @@ class BaseTelegramInferenceBot(ABC):
|
||||
for function in tool.get_functions():
|
||||
if function["function"]["name"] == function_name:
|
||||
try:
|
||||
# Ensure function_args is a dictionary before unpacking
|
||||
if not isinstance(function_args, dict):
|
||||
logging.error(f"Internal error: function_args not a dict for {function_name} before execution. Args: {function_args}")
|
||||
return f"Internal error preparing arguments for tool {function_name}."
|
||||
@@ -110,16 +130,23 @@ class BaseTelegramInferenceBot(ABC):
|
||||
return f"Error: Tool function {function_name} not found."
|
||||
|
||||
def get_system_prompt_description(self) -> str:
|
||||
"""Returns a description of the system prompt being used."""
|
||||
return f"System Prompt: {'Custom' if os.getenv('SYSTEM_PROMPT_PATH') else 'Default'}"
|
||||
# This method could be updated to be more specific about the prompt source if needed.
|
||||
# For now, it still reflects custom vs default based on the original ENV var logic's spirit.
|
||||
# A more accurate reflection would require storing how the prompt was loaded.
|
||||
# For simplicity, let's assume if it's not the default, it's "Custom".
|
||||
if self.system_prompt != "You are a helpful AI assistant.":
|
||||
return "System Prompt: Custom"
|
||||
# Check original ENV var for backward compatibility in description only
|
||||
elif os.getenv('SYSTEM_PROMPT_PATH'):
|
||||
return "System Prompt: Custom (via ENV)"
|
||||
return "System Prompt: Default"
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def get_llm_description(self) -> str:
|
||||
"""Returns a description of the LLM being used."""
|
||||
pass
|
||||
|
||||
async def get_bot_status(self) -> str:
|
||||
"""Provides a status message including prompt and LLM information."""
|
||||
prompt_desc = self.get_system_prompt_description()
|
||||
llm_desc = self.get_llm_description()
|
||||
return f"{prompt_desc}\n{llm_desc}"
|
||||
@@ -134,5 +161,4 @@ class BaseTelegramInferenceBot(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def switch_model(self):
|
||||
"""Switches the underlying model if supported by the bot."""
|
||||
pass
|
||||
|
||||
@@ -1,42 +1,104 @@
|
||||
import os
|
||||
import logging
|
||||
from openai import OpenAI
|
||||
from openai import OpenAI # Keep for type hinting and default client creation
|
||||
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot
|
||||
from telegram_helper import TelegramHelper
|
||||
from telegram_helper import TelegramHelper # Used in main
|
||||
|
||||
class ChatGPTTelegramInferenceBot(OpenAICompatibleInferenceBot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
||||
DEFAULT_SMALL_MODEL_NAME = "gpt-3.5-turbo"
|
||||
DEFAULT_LARGE_MODEL_NAME = "gpt-4"
|
||||
# Default max tokens can be None, relying on parent or API defaults
|
||||
DEFAULT_SMALL_MODEL_MAX_TOKENS = None
|
||||
DEFAULT_LARGE_MODEL_MAX_TOKENS = None
|
||||
|
||||
self._configure_model_and_tokens(
|
||||
os.environ.get("OPENAI_SMALL_MODEL", "gpt-3.5-turbo"),
|
||||
os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS")
|
||||
def __init__(
|
||||
self,
|
||||
client: OpenAI | None = None, # Accepts an OpenAI client
|
||||
api_key: str | None = None,
|
||||
small_model_name: str | None = None,
|
||||
small_model_max_tokens: str | None = None, # Kept as str for consistency with env vars
|
||||
large_model_name: str | None = None,
|
||||
large_model_max_tokens: str | None = None,
|
||||
system_prompt_content: str | None = None,
|
||||
system_prompt_path: str | None = None,
|
||||
base_url: str | None = None, # For OpenAI compatible, though direct OpenAI client doesn't use it here
|
||||
):
|
||||
# Initialize model names and tokens before calling super, as super might use them via _configure_model_and_tokens
|
||||
self.small_model_name = small_model_name or os.environ.get("OPENAI_SMALL_MODEL") or self.DEFAULT_SMALL_MODEL_NAME
|
||||
self.small_model_max_tokens_str = small_model_max_tokens or os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS") or self.DEFAULT_SMALL_MODEL_MAX_TOKENS
|
||||
|
||||
self.large_model_name = large_model_name or os.environ.get("OPENAI_LARGE_MODEL") or self.DEFAULT_LARGE_MODEL_NAME
|
||||
self.large_model_max_tokens_str = large_model_max_tokens or os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS") or self.DEFAULT_LARGE_MODEL_MAX_TOKENS
|
||||
|
||||
# The actual client and active model configuration will be handled by OpenAICompatibleInferenceBot's __init__
|
||||
# We pass the specific OpenAI client or parameters to create one.
|
||||
# If a client is passed, api_key and base_url might be ignored by super if super prioritizes existing client.
|
||||
super().__init__(
|
||||
client=client,
|
||||
api_key=api_key,
|
||||
model_name=self.small_model_name, # Initial model
|
||||
max_tokens_str=self.small_model_max_tokens_str,
|
||||
system_prompt_content=system_prompt_content,
|
||||
system_prompt_path=system_prompt_path,
|
||||
base_url=base_url # Pass base_url, though for standard OpenAI it's fixed
|
||||
)
|
||||
# Ensure client is of type OpenAI for this specific class, if not already set by super with a compatible one.
|
||||
# This check is more of an assertion, as OpenAICompatibleInferenceBot should handle client creation.
|
||||
if not isinstance(self.client, OpenAI):
|
||||
# If super() didn't create a vanilla OpenAI client (e.g. if base_url was for Azure)
|
||||
# we might need to recreate it here if this class *must* use a non-Azure OpenAI client.
|
||||
# However, the current structure of OpenAICompatibleInferenceBot handles this.
|
||||
# This is more about ensuring type correctness if code specific to OpenAI (non-compatible) methods were added here.
|
||||
_api_key = api_key or os.environ.get("OPENAI_API_KEY")
|
||||
if not self.client or (base_url and not isinstance(self.client, OpenAI)):
|
||||
# If superclass initialized with a generic client due to base_url, re-init for OpenAI specifically if needed.
|
||||
# For now, assume superclass correctly initializes based on absence of Azure env vars for this path.
|
||||
# This logic might be simplified once OpenAICompatibleInferenceBot is fully refactored.
|
||||
if not _api_key: # Ensure API key is available if we need to create a client
|
||||
raise ValueError("OpenAI API key must be provided for ChatGPTTelegramInferenceBot if no client is passed.")
|
||||
self.client = OpenAI(api_key=_api_key)
|
||||
logging.info("Client re-initialized to standard OpenAI client for ChatGPTTelegramInferenceBot.")
|
||||
|
||||
async def switch_model(self):
|
||||
current_small_model = os.environ.get("OPENAI_SMALL_MODEL", "gpt-3.5-turbo")
|
||||
current_large_model = os.environ.get("OPENAI_LARGE_MODEL", "gpt-4")
|
||||
# Uses instance variables for model names set in __init__
|
||||
if not self.small_model_name or not self.large_model_name:
|
||||
logging.warning("Small or Large model names for OpenAI are not defined. Cannot switch model.")
|
||||
return f"Model switching not fully configured. Currently using {self.model}."
|
||||
|
||||
if self.model == current_large_model or self.model != current_small_model:
|
||||
target_model = current_small_model
|
||||
target_max_tokens = os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS")
|
||||
current_is_small = self.model == self.small_model_name
|
||||
current_is_large = self.model == self.large_model_name
|
||||
|
||||
if current_is_large:
|
||||
target_model = self.small_model_name
|
||||
target_max_tokens_str = self.small_model_max_tokens_str
|
||||
elif current_is_small:
|
||||
target_model = self.large_model_name
|
||||
target_max_tokens_str = self.large_model_max_tokens_str
|
||||
else:
|
||||
target_model = current_large_model
|
||||
target_max_tokens = os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS")
|
||||
# Current model is neither the designated small nor large for this bot,
|
||||
# switch to this bot's default small model as a reset.
|
||||
logging.warning(f"Current model {self.model} is unrecognized for ChatGPT bot. Switching to default small model: {self.small_model_name}.")
|
||||
target_model = self.small_model_name
|
||||
target_max_tokens_str = self.small_model_max_tokens_str
|
||||
|
||||
self._configure_model_and_tokens(target_model, target_max_tokens)
|
||||
logging.info(f"Switched to model: {self.model}")
|
||||
return f"Switched to model: {self.model}"
|
||||
self._configure_model_and_tokens(target_model, target_max_tokens_str)
|
||||
# self.model and self.max_tokens are updated by _configure_model_and_tokens
|
||||
logging.info(f"Switched to OpenAI model: {self.model}")
|
||||
return f"Switched to OpenAI model: {self.model} (Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'})"
|
||||
|
||||
def main():
|
||||
if not os.environ.get("OPENAI_API_KEY"):
|
||||
logging.error("FATAL: OPENAI_API_KEY environment variable not set.")
|
||||
return
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
|
||||
bot = ChatGPTTelegramInferenceBot()
|
||||
try:
|
||||
# Example: api_key from env, other params default or from env via constructor logic
|
||||
bot = ChatGPTTelegramInferenceBot(api_key=os.environ.get("OPENAI_API_KEY"))
|
||||
except ValueError as e:
|
||||
logging.error(f"FATAL: {e}")
|
||||
return
|
||||
except Exception as e:
|
||||
logging.error(f"An unexpected error occurred during bot initialization: {e}")
|
||||
return
|
||||
|
||||
telegram_helper = TelegramHelper(bot)
|
||||
telegram_helper.run()
|
||||
|
||||
|
||||
@@ -1,42 +1,102 @@
|
||||
import os
|
||||
import logging
|
||||
from openai import OpenAI
|
||||
from openai import OpenAI # For type hinting and default client creation if needed
|
||||
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot
|
||||
from telegram_helper import TelegramHelper
|
||||
from telegram_helper import TelegramHelper # Used in main
|
||||
|
||||
class GeminiTelegramInferenceBot(OpenAICompatibleInferenceBot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.client = OpenAI(api_key=os.environ.get("GEMINI_API_KEY"), base_url=os.environ.get("GEMINI_API_BASE_URL"))
|
||||
DEFAULT_GEMINI_API_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"
|
||||
DEFAULT_SMALL_MODEL_NAME = "gemini-pro" # Actual model name for Gemini, not via OpenAI client directly
|
||||
DEFAULT_LARGE_MODEL_NAME = "gemini-1.5-pro-latest"
|
||||
DEFAULT_SMALL_MODEL_MAX_TOKENS = "2048" # Gemini uses outputTokenLimit, not exactly max_tokens in OpenAI sense
|
||||
DEFAULT_LARGE_MODEL_MAX_TOKENS = "8192"
|
||||
|
||||
self._configure_model_and_tokens(
|
||||
os.environ.get("GEMINI_SMALL_MODEL", "gemini-pro"),
|
||||
os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS")
|
||||
def __init__(
|
||||
self,
|
||||
client: OpenAI | None = None, # OpenAI client for compatible mode
|
||||
api_key: str | None = None, # Gemini API Key
|
||||
base_url: str | None = None, # Gemini API Base URL for OpenAI client
|
||||
small_model_name: str | None = None,
|
||||
small_model_max_tokens: str | None = None,
|
||||
large_model_name: str | None = None,
|
||||
large_model_max_tokens: str | None = None,
|
||||
system_prompt_content: str | None = None,
|
||||
system_prompt_path: str | None = None
|
||||
):
|
||||
_api_key = api_key or os.environ.get("GEMINI_API_KEY")
|
||||
_base_url = base_url or os.environ.get("GEMINI_API_BASE_URL") or self.DEFAULT_GEMINI_API_BASE_URL
|
||||
|
||||
if not _api_key:
|
||||
# This check might seem redundant if super() also checks, but it's good for clarity
|
||||
# for this specific bot type if it were to be instantiated directly with missing critical env vars.
|
||||
raise ValueError("Gemini API key must be provided either via api_key argument or GEMINI_API_KEY environment variable.")
|
||||
|
||||
self.small_model_name = small_model_name or os.environ.get("GEMINI_SMALL_MODEL") or self.DEFAULT_SMALL_MODEL_NAME
|
||||
self.small_model_max_tokens_str = small_model_max_tokens or os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS") or self.DEFAULT_SMALL_MODEL_MAX_TOKENS
|
||||
|
||||
self.large_model_name = large_model_name or os.environ.get("GEMINI_LARGE_MODEL") or self.DEFAULT_LARGE_MODEL_NAME
|
||||
self.large_model_max_tokens_str = large_model_max_tokens or os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS") or self.DEFAULT_LARGE_MODEL_MAX_TOKENS
|
||||
|
||||
# Pass parameters to the OpenAICompatibleInferenceBot constructor
|
||||
# It will create an OpenAI client configured for the Gemini endpoint
|
||||
super().__init__(
|
||||
client=client,
|
||||
api_key=_api_key, # This key will be used by OpenAI client for the custom base_url
|
||||
model_name=self.small_model_name, # Initial model
|
||||
max_tokens_str=self.small_model_max_tokens_str,
|
||||
system_prompt_content=system_prompt_content,
|
||||
system_prompt_path=system_prompt_path,
|
||||
base_url=_base_url, # Crucial for Gemini via OpenAI client
|
||||
is_gemini=True # Flag for specific Gemini handling in compatible layer if needed
|
||||
)
|
||||
# self.client will be set by OpenAICompatibleInferenceBot with base_url and api_key.
|
||||
# Logging to confirm Gemini specific setup
|
||||
logging.info(f"GeminiTelegramInferenceBot initialized to use model {self.model} via {_base_url}")
|
||||
|
||||
async def switch_model(self):
|
||||
current_small_model = os.environ.get("GEMINI_SMALL_MODEL", "gemini-pro")
|
||||
current_large_model = os.environ.get("GEMINI_LARGE_MODEL", "gemini-1.5-pro-latest")
|
||||
if not self.small_model_name or not self.large_model_name:
|
||||
logging.warning("Small or Large model names for Gemini are not defined. Cannot switch model.")
|
||||
return f"Model switching not fully configured. Currently using {self.model}."
|
||||
|
||||
if self.model == current_large_model or self.model != current_small_model :
|
||||
target_model = current_small_model
|
||||
target_max_tokens = os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS")
|
||||
current_is_small = self.model == self.small_model_name
|
||||
current_is_large = self.model == self.large_model_name
|
||||
|
||||
if current_is_large:
|
||||
target_model = self.small_model_name
|
||||
target_max_tokens_str = self.small_model_max_tokens_str
|
||||
elif current_is_small:
|
||||
target_model = self.large_model_name
|
||||
target_max_tokens_str = self.large_model_max_tokens_str
|
||||
else:
|
||||
target_model = current_large_model
|
||||
target_max_tokens = os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS")
|
||||
logging.warning(f"Current model {self.model} is unrecognized for Gemini bot. Switching to default small model: {self.small_model_name}.")
|
||||
target_model = self.small_model_name
|
||||
target_max_tokens_str = self.small_model_max_tokens_str
|
||||
|
||||
self._configure_model_and_tokens(target_model, target_max_tokens)
|
||||
logging.info(f"Switched to model: {self.model}")
|
||||
return f"Switched to model: {self.model}"
|
||||
self._configure_model_and_tokens(target_model, target_max_tokens_str)
|
||||
logging.info(f"Switched to Gemini model: {self.model}")
|
||||
# For Gemini, max_tokens might translate to outputTokenLimit, so be clear it's a configuration parameter
|
||||
return f"Switched to Gemini model: {self.model} (Configured Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'})"
|
||||
|
||||
def main():
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
|
||||
# GEMINI_API_KEY is crucial for this bot
|
||||
if not os.environ.get("GEMINI_API_KEY"):
|
||||
logging.error("FATAL: GEMINI_API_KEY environment variable not set.")
|
||||
return
|
||||
# GEMINI_API_BASE_URL is also important, but constructor has a default
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
try:
|
||||
bot = GeminiTelegramInferenceBot(
|
||||
# api_key and base_url will be picked from ENV by constructor if not passed
|
||||
)
|
||||
except ValueError as e:
|
||||
logging.error(f"FATAL: {e}")
|
||||
return
|
||||
except Exception as e: # Catch any other init errors
|
||||
logging.error(f"An unexpected error occurred during bot initialization: {e}")
|
||||
return
|
||||
|
||||
bot = GeminiTelegramInferenceBot()
|
||||
telegram_helper = TelegramHelper(bot)
|
||||
telegram_helper.run()
|
||||
|
||||
|
||||
@@ -3,32 +3,114 @@ import os
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from base_telegram_inference_bot import BaseTelegramInferenceBot
|
||||
from openai import OpenAI
|
||||
from openai import OpenAI, AzureOpenAI # Import both
|
||||
|
||||
class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Client and model configuration will be handled by subclasses
|
||||
self.client = None
|
||||
self.model = None
|
||||
self.max_tokens = None
|
||||
DEFAULT_MAX_HISTORY_LENGTH = 20
|
||||
DEFAULT_MAX_TOKENS = 1000 # Default for _configure_model_and_tokens
|
||||
|
||||
def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=1000):
|
||||
self.model = model_name if model_name else "default-model"
|
||||
def __init__(
|
||||
self,
|
||||
client: OpenAI | AzureOpenAI | None = None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
api_version: str | None = None, # For Azure
|
||||
azure_deployment: str | None = None, # Model for Azure, distinct from general model_name if needed
|
||||
model_name: str | None = None, # General model name for the API call
|
||||
max_tokens_str: str | None = None,
|
||||
system_prompt_content: str | None = None,
|
||||
system_prompt_path: str | None = None,
|
||||
is_gemini: bool = False, # Hint for specific API key if others are not set
|
||||
max_history_length: int | None = None
|
||||
):
|
||||
super().__init__(system_prompt_content=system_prompt_content, system_prompt_path=system_prompt_path)
|
||||
|
||||
self.max_history_length = max_history_length if max_history_length is not None else self.DEFAULT_MAX_HISTORY_LENGTH
|
||||
self.client = client
|
||||
|
||||
if not self.client:
|
||||
_api_key = api_key
|
||||
_base_url = base_url
|
||||
_api_version = api_version
|
||||
_azure_deployment_name = azure_deployment # This will be used as the model for Azure
|
||||
|
||||
# Determine if configuring for Azure OpenAI
|
||||
is_azure = False
|
||||
if _azure_deployment_name or (_base_url and "azure.com" in _base_url) or os.environ.get("AZURE_OPENAI_ENDPOINT"):
|
||||
is_azure = True
|
||||
|
||||
if is_azure:
|
||||
_base_url = _base_url or os.environ.get("AZURE_OPENAI_ENDPOINT")
|
||||
_api_key = _api_key or os.environ.get("AZURE_OPENAI_KEY")
|
||||
_api_version = _api_version or os.environ.get("AZURE_OPENAI_API_VERSION")
|
||||
# For Azure, the model parameter in API calls is the deployment name
|
||||
_effective_model_name = _azure_deployment_name or model_name # Use deployment if available, else model_name
|
||||
if not _base_url or not _api_key or not _api_version or not _effective_model_name:
|
||||
raise ValueError("For Azure OpenAI, endpoint, API key, API version, and deployment/model name must be configured.")
|
||||
self.client = AzureOpenAI(
|
||||
api_key=_api_key,
|
||||
azure_endpoint=_base_url,
|
||||
api_version=_api_version
|
||||
)
|
||||
# The model to be used in API calls for Azure is the deployment name.
|
||||
# _configure_model_and_tokens will set self.model to this.
|
||||
model_name_for_config = _effective_model_name
|
||||
logging.info(f"Initialized AzureOpenAI client for deployment: {model_name_for_config} at {_base_url}")
|
||||
else:
|
||||
# Standard OpenAI or other OpenAI-compatible (like Gemini via base_url)
|
||||
_base_url = _base_url or os.environ.get("OPENAI_API_BASE_URL") # For other compatible APIs
|
||||
if not _api_key: # Try different ENV sources for API key
|
||||
if is_gemini and os.environ.get("GEMINI_API_KEY"):
|
||||
_api_key = os.environ.get("GEMINI_API_KEY")
|
||||
else:
|
||||
_api_key = os.environ.get("OPENAI_API_KEY")
|
||||
|
||||
if not _api_key and not _base_url : # For completely local models with no key needed via base_url
|
||||
pass # Allow client to be created with no API key if base_url is set and points to local model
|
||||
elif not _api_key:
|
||||
raise ValueError("API key must be provided for OpenAI compatible client if not Azure or local anonymous.")
|
||||
|
||||
self.client = OpenAI(api_key=_api_key, base_url=_base_url)
|
||||
model_name_for_config = model_name # Use the general model_name for non-Azure
|
||||
log_msg = f"Initialized OpenAI compatible client. Target URL: {_base_url if _base_url else 'OpenAI default'}."
|
||||
logging.info(log_msg)
|
||||
else:
|
||||
# Client was provided directly
|
||||
model_name_for_config = model_name # Use provided model_name
|
||||
logging.info(f"Using provided client: {type(self.client)}")
|
||||
|
||||
# Configure the actual model name and max_tokens for API calls
|
||||
self._configure_model_and_tokens(
|
||||
model_name_for_config,
|
||||
max_tokens_str,
|
||||
default_max_tokens=self.DEFAULT_MAX_TOKENS
|
||||
)
|
||||
|
||||
def _configure_model_and_tokens(self, model_name: str | None, max_tokens_str: str | None, default_max_tokens: int = 1000):
|
||||
self.model = model_name if model_name else "default-model" # Fallback model name
|
||||
try:
|
||||
self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens
|
||||
# If max_tokens_str is explicitly "None" or empty, treat as None for API default
|
||||
if max_tokens_str and max_tokens_str.lower() not in ["none", "", "null"]:
|
||||
self.max_tokens = int(max_tokens_str)
|
||||
else:
|
||||
self.max_tokens = None # Use API default by not sending the parameter or sending null
|
||||
except ValueError:
|
||||
logging.error(f"Invalid value for max_tokens: {max_tokens_str}. Using default {default_max_tokens}.")
|
||||
self.max_tokens = default_max_tokens
|
||||
logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens}")
|
||||
logging.warning(f"Invalid value for max_tokens: {max_tokens_str}. Using API default (None). stalwart default was {default_max_tokens}")
|
||||
self.max_tokens = None # Use API default
|
||||
|
||||
logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}")
|
||||
|
||||
def get_llm_description(self) -> str:
|
||||
return f"LLM: {self.model}, Max Tokens: {self.max_tokens}"
|
||||
client_type = type(self.client).__name__
|
||||
return f"Client: {client_type}, LLM: {self.model}, Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}"
|
||||
|
||||
def get_chat_response(self, messages):
|
||||
if not self.client:
|
||||
raise ValueError("OpenAI client not initialized. Subclasses must initialize it.")
|
||||
# This should ideally not be hit if __init__ is successful
|
||||
logging.error("OpenAI client not initialized before get_chat_response.")
|
||||
raise ValueError("OpenAI client not initialized.")
|
||||
try:
|
||||
# Pass self.max_tokens directly. If None, OpenAI library omits it or handles it.
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
@@ -38,32 +120,33 @@ class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot):
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
logging.error(f"API call failed: {e}")
|
||||
logging.error(f"API call to model {self.model} failed: {e}")
|
||||
raise
|
||||
|
||||
async def handle_message(self, user_id, user_message):
|
||||
if user_id not in self.conversation_history:
|
||||
if user_id not in self.conversation_history or not self.conversation_history[user_id]:
|
||||
self.conversation_history[user_id] = []
|
||||
if hasattr(self, 'system_prompt') and self.system_prompt:
|
||||
if self.system_prompt: # Use the loaded system_prompt
|
||||
self.conversation_history[user_id].append({"role": "system", "content": self.system_prompt})
|
||||
|
||||
self.conversation_history[user_id].append({"role": "user", "content": user_message})
|
||||
messages = self.conversation_history[user_id]
|
||||
messages = list(self.conversation_history[user_id]) # Work with a copy for this turn
|
||||
|
||||
response = self.get_chat_response(messages)
|
||||
|
||||
if not (response.choices and response.choices[0].message):
|
||||
logging.error("No valid response choice message from LLM.")
|
||||
# Persist the user message in history even if LLM fails this turn
|
||||
self.conversation_history[user_id] = messages
|
||||
return "Error: Could not get a valid response from the LLM."
|
||||
|
||||
messages.append(response.choices[0].message) # Append the assistant's response message
|
||||
assistant_message = response.choices[0].message
|
||||
messages.append(assistant_message)
|
||||
|
||||
tool_calls_from_response = []
|
||||
if response.choices[0].message.tool_calls:
|
||||
tool_calls_from_response.extend(response.choices[0].message.tool_calls)
|
||||
tool_calls_from_response = list(assistant_message.tool_calls) if assistant_message.tool_calls else []
|
||||
|
||||
tool_use_count = 0
|
||||
MAX_TOOL_ITERATIONS = 200
|
||||
MAX_TOOL_ITERATIONS = 5 # OpenAI compatible typically uses fewer iterations than Anthropic
|
||||
|
||||
while tool_calls_from_response and tool_use_count < MAX_TOOL_ITERATIONS:
|
||||
tool_results_for_model = []
|
||||
@@ -71,20 +154,24 @@ class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot):
|
||||
for tool_call in tool_calls_from_response:
|
||||
tool_call_id = tool_call.id
|
||||
function_to_call = tool_call.function
|
||||
function_name = function_to_call.name
|
||||
function_args_str = function_to_call.arguments
|
||||
|
||||
logging.info(f"Attempting to call tool: {function_to_call.name} with args: {function_to_call.arguments}")
|
||||
logging.info(f"Attempting to call tool: {function_name} with args: {function_args_str}")
|
||||
try:
|
||||
tool_response_content = self.call_tool(function_to_call.name, function_to_call.arguments)
|
||||
# Arguments are already a string from the API, self.call_tool expects dict or string
|
||||
tool_response_content = self.call_tool(function_name, function_args_str)
|
||||
# Ensure content is string for OpenAI tool role
|
||||
if not isinstance(tool_response_content, str):
|
||||
tool_response_content = json.dumps(tool_response_content)
|
||||
except Exception as e:
|
||||
logging.error(f"Error calling tool {function_to_call.name}: {e}")
|
||||
tool_response_content = f"Error executing tool {function_to_call.name}: {str(e)}"
|
||||
logging.error(f"Error calling tool {function_name}: {e}")
|
||||
tool_response_content = f"Error executing tool {function_name}: {str(e)}"
|
||||
|
||||
tool_results_for_model.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"name": function_to_call.name,
|
||||
"name": function_name,
|
||||
"content": tool_response_content
|
||||
})
|
||||
|
||||
@@ -93,40 +180,50 @@ class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot):
|
||||
response = self.get_chat_response(messages)
|
||||
if not (response.choices and response.choices[0].message):
|
||||
logging.error("No valid response choice message from LLM after tool call.")
|
||||
self.conversation_history[user_id] = messages # Persist state before error
|
||||
return "Error: Could not get a valid response from the LLM after tool call."
|
||||
|
||||
messages.append(response.choices[0].message)
|
||||
assistant_message = response.choices[0].message
|
||||
messages.append(assistant_message)
|
||||
|
||||
tool_calls_from_response = []
|
||||
if response.choices[0].message.tool_calls:
|
||||
tool_calls_from_response.extend(response.choices[0].message.tool_calls)
|
||||
tool_calls_from_response = list(assistant_message.tool_calls) if assistant_message.tool_calls else []
|
||||
|
||||
tool_use_count += 1
|
||||
if tool_use_count >= MAX_TOOL_ITERATIONS and tool_calls_from_response:
|
||||
logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.")
|
||||
# Ensure final content is returned even if max iterations hit with pending tool calls
|
||||
break
|
||||
|
||||
# Conversation history management
|
||||
# This limit should be reviewed and potentially made configurable
|
||||
if len(self.conversation_history[user_id]) > 20: # Example limit, adjust as needed
|
||||
self.conversation_history[user_id] = self.conversation_history[user_id][-20:]
|
||||
self.conversation_history[user_id] = messages # Persist the full exchange for this turn
|
||||
# Apply history length limit
|
||||
if len(self.conversation_history[user_id]) > self.max_history_length:
|
||||
# Keep system prompt if present as the first message, then trim the rest
|
||||
if self.conversation_history[user_id][0]["role"] == "system":
|
||||
system_msg = [self.conversation_history[user_id][0]]
|
||||
trimmed_history = self.conversation_history[user_id][-(self.max_history_length-1):]
|
||||
self.conversation_history[user_id] = system_msg + trimmed_history
|
||||
else:
|
||||
self.conversation_history[user_id] = self.conversation_history[user_id][-self.max_history_length:]
|
||||
|
||||
final_assistant_message = messages[-1]
|
||||
return final_assistant_message.content if final_assistant_message.role == "assistant" and final_assistant_message.content else "No content in final message."
|
||||
return final_assistant_message.content if final_assistant_message.role == "assistant" and final_assistant_message.content is not None else "Assistant did not provide a textual response."
|
||||
|
||||
async def start(self):
|
||||
logging.info(f"{self.__class__.__name__} started.")
|
||||
logging.info(f"{self.__class__.__name__} (Model: {self.model}) started.")
|
||||
|
||||
def clear(self, user_id):
|
||||
super().clear_conversation_history(user_id)
|
||||
# clear_conversation_history is inherited from BaseTelegramInferenceBot
|
||||
|
||||
async def abort_processing(self, user_id):
|
||||
# This is a soft abort for OpenAI compatible bots as API calls are synchronous within handle_message
|
||||
if user_id in self.processing_status:
|
||||
self.processing_status[user_id]["processing"] = False
|
||||
self.clear(user_id)
|
||||
return "Processing aborted and conversation cleared."
|
||||
self.clear_processing_status(user_id) # Use base class method
|
||||
logging.info(f"Processing aborted for user {user_id}.")
|
||||
# Optionally clear conversation history or let user do it explicitly
|
||||
# super().clear_conversation_history(user_id)
|
||||
return "Processing aborted. You can send a new message or /clear the conversation."
|
||||
else:
|
||||
self.clear(user_id)
|
||||
return "No active processing found to abort. Conversation cleared."
|
||||
# super().clear_conversation_history(user_id)
|
||||
return "No active processing found to abort. If you wish, /clear the conversation history."
|
||||
|
||||
@abstractmethod
|
||||
async def switch_model(self):
|
||||
|
||||
@@ -30,3 +30,19 @@ Pull Requests and Issues: The Collaborative Symphony
|
||||
Pull Request Mastery: Treat pull requests as complete change proposals. They evolve with each commit to their branch.
|
||||
Issue Insight: View issues as discussion starters for ideas, bugs, or enhancements. They may inspire multiple pull requests.
|
||||
Ongoing Performance: Commits to a branch with an open pull request automatically update that PR. No need for new PRs per commit.
|
||||
|
||||
**Focus on Testability and Robust Design (Lessons Learned):**
|
||||
|
||||
When implementing or refactoring, *aggressively prioritize testability*. This includes:
|
||||
* **Dependency Injection:** Consistently apply Dependency Injection for all external services (e.g., API clients, database connections), configurations (e.g., API keys, file paths, model names, feature flags), and system resources (e.g., file system access via `open`, network requests via `requests.Session`, time/clock functions if timing is critical and needs mocking).
|
||||
* **Configuration Management:** Externalize configurations. Allow them to be passed via constructor arguments, with environment variables or sensible defaults as fallbacks. Avoid hardcoding paths, keys, or URLs directly within functions or methods.
|
||||
* **Separation of Concerns:** Clearly separate core business logic from framework-specific code, I/O operations, or direct external service interactions. This often involves creating internal `_logic` methods that can be tested independently of, for example, Telegram API update/context objects.
|
||||
* **Logging for Libraries/Tools:** Components like tools or libraries should use `logging.getLogger(__name__)` for their logging. They should *not* configure handlers (e.g., `FileHandler`, `StreamHandler`) directly. Logging setup (handlers, formatters, levels) is the responsibility of the main application. Tools can accept an optional `logger` instance via their constructor for more explicit control by the application or for testing.
|
||||
* **State Management for Testability:** For stateful components, tools, or classes, ensure there's a mechanism to reset or clear their state (e.g., a `clear()` or `reset()` method). This is crucial for test isolation and predictable behavior during testing.
|
||||
* **Robust Metrics & Profiling:** When implementing metrics collection (e.g., using `cProfile` via decorators), ensure that data extraction (like execution time) is robust. Rely on stable APIs or attributes of the profiling tools (e.g., `pstats.Stats.stats` dictionary) rather than fragile string parsing of their output. Provide methods to clear/reset collected metrics to facilitate testing of the metrics system itself.
|
||||
* **Comprehensive Unit Testing Strategy:** When generating unit tests:
|
||||
* For abstract base classes, create simple concrete subclasses within the test file to enable instantiation and testing of shared, non-abstract logic.
|
||||
* Employ `unittest.mock` (`MagicMock`, `patch`, `AsyncMock`, `mock_open`) extensively to isolate the unit under test from its dependencies.
|
||||
* Cover various scenarios: initialization with different configurations, success paths for public methods, error conditions (e.g., API errors, file not found, invalid arguments), and relevant edge cases.
|
||||
* Thoroughly mock external dependencies like file system operations, network calls, and any injected client objects.
|
||||
* **Iterative Development Cycle:** For significant changes or new features, propose refactoring for testability *first*, then proceed to write comprehensive unit tests against the refactored code. This leads to more robust, maintainable, and reliable components.
|
||||
|
||||
+170
-58
@@ -3,44 +3,70 @@ import logging
|
||||
import sys
|
||||
import asyncio
|
||||
import time
|
||||
from typing import TypedDict, Union, TypeAlias, List # Added List for type hint
|
||||
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
|
||||
from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes, CallbackQueryHandler
|
||||
from browse_command import browse_command, button_callback
|
||||
|
||||
class MessageHandlerLogicResult(TypedDict):
|
||||
success: bool
|
||||
response_text: Union[str, None]
|
||||
error_message: Union[str, None]
|
||||
|
||||
LogicResult: TypeAlias = MessageHandlerLogicResult
|
||||
|
||||
class TelegramHelper:
|
||||
# --- Constants for configurable paths and magic strings ---
|
||||
REBOOT_CLAUDE_FILE = '.reboot_claude'
|
||||
REBOOT_FILE = '.doreboot'
|
||||
CLAUDE_REBOOT_TARGET = 'claude'
|
||||
HTML_QUOTE_BLOCK_START = '<blockquote expandable><b>Thinking...</b>'
|
||||
HTML_QUOTE_BLOCK_END = '</blockquote>'
|
||||
DEFAULT_REBOOT_CLAUDE_FILE = '.reboot_claude'
|
||||
DEFAULT_REBOOT_FILE = '.doreboot'
|
||||
CHUNK_MESSAGE_SLEEP_DURATION = 0.1
|
||||
|
||||
def __init__(self, bot):
|
||||
def __init__(self, bot,
|
||||
reboot_claude_file_path: str | None = None,
|
||||
reboot_file_path: str | None = None,
|
||||
chunk_message_sleep_duration: float | None = None):
|
||||
self.bot = bot
|
||||
self.telegram_bot_token = os.getenv('TELEGRAM_BOT_TOKEN')
|
||||
self.start_time = time.time()
|
||||
self.reboot_claude_file = reboot_claude_file_path or self.DEFAULT_REBOOT_CLAUDE_FILE
|
||||
self.reboot_file = reboot_file_path or self.DEFAULT_REBOOT_FILE
|
||||
self.chunk_message_sleep_duration = chunk_message_sleep_duration if chunk_message_sleep_duration is not None else self.CHUNK_MESSAGE_SLEEP_DURATION
|
||||
|
||||
async def _start_logic(self) -> str:
|
||||
await self.bot.start()
|
||||
return "Hello! I'm your AI assistant. How can I help you today?"
|
||||
|
||||
async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
await self.bot.start()
|
||||
await update.message.reply_text(
|
||||
"Hello! I'm your AI assistant. How can I help you today?"
|
||||
)
|
||||
response_message = await self._start_logic()
|
||||
await update.message.reply_text(response_message)
|
||||
|
||||
async def _clear_logic(self, user_id: int) -> str:
|
||||
self.bot.clear_conversation_history(user_id)
|
||||
return "Conversation history cleared. Let's start fresh!"
|
||||
|
||||
async def clear(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
user_id = update.effective_user.id
|
||||
self.bot.clear_conversation_history(user_id)
|
||||
await update.message.reply_text("Conversation history cleared. Let's start fresh!")
|
||||
response_message = await self._clear_logic(user_id)
|
||||
await update.message.reply_text(response_message)
|
||||
|
||||
async def _status_logic(self) -> str:
|
||||
return await self.bot.get_bot_status()
|
||||
|
||||
async def status(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
status_message = await self.bot.get_bot_status()
|
||||
await update.message.reply_text(status_message)
|
||||
response_message = await self._status_logic()
|
||||
await update.message.reply_text(response_message)
|
||||
|
||||
async def _switch_logic(self) -> str:
|
||||
if hasattr(self.bot, 'switch_model'):
|
||||
return await self.bot.switch_model()
|
||||
else:
|
||||
return "Model switching is not supported for this bot."
|
||||
|
||||
async def switch(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
if hasattr(self.bot, 'switch_model'):
|
||||
status_message = await self.bot.switch_model()
|
||||
await update.message.reply_text(status_message)
|
||||
else:
|
||||
await update.message.reply_text("Model switching is not supported for this bot.")
|
||||
response_message = await self._switch_logic()
|
||||
await update.message.reply_text(response_message)
|
||||
|
||||
async def update_status_message(self, context: ContextTypes.DEFAULT_TYPE, chat_id: int, message_id: int, status: str):
|
||||
keyboard = [
|
||||
@@ -54,65 +80,147 @@ class TelegramHelper:
|
||||
reply_markup=reply_markup
|
||||
)
|
||||
|
||||
async def handle_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
async def _handle_message_logic(self, user_id: int, user_message: str) -> LogicResult:
|
||||
try:
|
||||
user_id = update.effective_user.id
|
||||
user_message = update.message.text
|
||||
|
||||
logging.info(f"Message from user {user_id}: {user_message}")
|
||||
|
||||
status_message = await update.message.reply_text("Processing your request...", reply_markup=InlineKeyboardMarkup([[InlineKeyboardButton("Abort", callback_data='abort')]]))
|
||||
self.bot.set_processing_status(user_id, status_message.message_id)
|
||||
|
||||
response = await self.bot.handle_message(user_id, user_message)
|
||||
processed_response = response.replace("<think>", self.HTML_QUOTE_BLOCK_START).replace("</think>", self.HTML_QUOTE_BLOCK_END)
|
||||
return LogicResult(success=True, response_text=processed_response, error_message=None)
|
||||
except Exception as e:
|
||||
logging.error(f"Error in _handle_message_logic for user {user_id}: {str(e)}")
|
||||
return LogicResult(success=False, response_text=None, error_message=str(e))
|
||||
|
||||
await context.bot.delete_message(chat_id=update.effective_chat.id, message_id=status_message.message_id)
|
||||
async def handle_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
user_id = update.effective_user.id
|
||||
user_message = update.message.text
|
||||
chat_id = update.effective_chat.id
|
||||
status_message_obj = None
|
||||
|
||||
try:
|
||||
status_message_obj = await update.message.reply_text(
|
||||
"Processing your request...",
|
||||
reply_markup=InlineKeyboardMarkup([[InlineKeyboardButton("Abort", callback_data='abort')]])
|
||||
)
|
||||
self.bot.set_processing_status(user_id, status_message_obj.message_id)
|
||||
|
||||
logic_result = await self._handle_message_logic(user_id, user_message)
|
||||
|
||||
if status_message_obj:
|
||||
try:
|
||||
await context.bot.delete_message(chat_id=chat_id, message_id=status_message_obj.message_id)
|
||||
except Exception as e_del:
|
||||
logging.warning(f"Failed to delete status message: {e_del}")
|
||||
self.bot.clear_processing_status(user_id)
|
||||
|
||||
response = response.replace("<think>", self.HTML_QUOTE_BLOCK_START).replace("</think>", self.HTML_QUOTE_BLOCK_END)
|
||||
|
||||
if len(response) > 4096:
|
||||
chunks = [response[i:i + 4096] for i in range(0, len(response), 4096)]
|
||||
for chunk in chunks:
|
||||
await update.message.reply_text(chunk)
|
||||
await asyncio.sleep(0.1)
|
||||
if logic_result["success"]:
|
||||
response_text = logic_result["response_text"]
|
||||
if response_text:
|
||||
if len(response_text) > 4096:
|
||||
chunks = [response_text[i:i + 4096] for i in range(0, len(response_text), 4096)]
|
||||
for chunk in chunks:
|
||||
await update.message.reply_text(chunk)
|
||||
await asyncio.sleep(self.chunk_message_sleep_duration)
|
||||
else:
|
||||
await update.message.reply_text(response_text)
|
||||
else:
|
||||
logging.warning("Successful logic result but no response text.")
|
||||
await update.message.reply_text("Something went unexpectedly well, but I have nothing to say.")
|
||||
else:
|
||||
await update.message.reply_text(response)
|
||||
await update.message.reply_text("Sorry, an error occurred while processing your request.")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"An error occurred: {str(e)}")
|
||||
await update.message.reply_text("Sorry, an error occurred while processing your request.")
|
||||
logging.error(f"Outer error in handle_message for user {user_id}: {str(e)}")
|
||||
if status_message_obj and self.bot.processing_status.get(user_id):
|
||||
self.bot.clear_processing_status(user_id)
|
||||
try:
|
||||
await update.message.reply_text("Sorry, an unexpected error occurred with the bot.")
|
||||
except Exception as e_reply:
|
||||
logging.error(f"Failed to send error reply: {e_reply}")
|
||||
|
||||
async def _abort_processing_logic(self, user_id: int) -> str:
|
||||
return await self.bot.abort_processing(user_id)
|
||||
|
||||
async def abort_processing(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
query = update.callback_query
|
||||
await query.answer()
|
||||
|
||||
user_id = query.from_user.id
|
||||
result = await self.bot.abort_processing(user_id)
|
||||
await query.edit_message_text(text=result)
|
||||
response_text = await self._abort_processing_logic(user_id)
|
||||
await query.edit_message_text(text=response_text)
|
||||
|
||||
# --- Reboot Command ---
|
||||
def _reboot_logic(self, user_message_parts: List[str], chat_id_to_write: str) -> None:
|
||||
"""Handles the logic for creating reboot files."""
|
||||
if len(user_message_parts) > 1 and user_message_parts[1].lower() == self.CLAUDE_REBOOT_TARGET:
|
||||
try:
|
||||
with open(self.reboot_claude_file, 'w') as f:
|
||||
f.write("") # Create/truncate the file
|
||||
logging.info(f"Created/truncated Claude reboot file: {self.reboot_claude_file}")
|
||||
except IOError as e:
|
||||
logging.error(f"Failed to create/truncate Claude reboot file {self.reboot_claude_file}: {e}")
|
||||
|
||||
# Create the main reboot file if it doesn't exist
|
||||
if not os.path.exists(self.reboot_file):
|
||||
try:
|
||||
with open(self.reboot_file, 'w') as f:
|
||||
f.write(chat_id_to_write)
|
||||
logging.info(f"Created main reboot file: {self.reboot_file} with chat_id.")
|
||||
except IOError as e:
|
||||
logging.error(f"Failed to create main reboot file {self.reboot_file}: {e}")
|
||||
else:
|
||||
logging.info(f"Main reboot file {self.reboot_file} already exists. Not overwriting chat_id.")
|
||||
|
||||
async def reboot(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
user_message = update.message.text.split()
|
||||
if len(user_message) > 1 and user_message[1].lower() == self.CLAUDE_REBOOT_TARGET:
|
||||
open(self.REBOOT_CLAUDE_FILE, 'w').close()
|
||||
"""Handles the /reboot command, triggers file creation and exits."""
|
||||
user_message_parts = update.message.text.split()
|
||||
chat_id_str = str(update.effective_chat.id) if update and update.effective_chat else ""
|
||||
|
||||
self._reboot_logic(user_message_parts, chat_id_str)
|
||||
|
||||
if update:
|
||||
await update.message.reply_text("Rebooting the bot...")
|
||||
logging.info("Received reboot command. Exiting process...")
|
||||
reboot_file_path = self.REBOOT_FILE
|
||||
if not os.path.exists(reboot_file_path):
|
||||
with open(reboot_file_path, 'w') as f:
|
||||
f.write(str(update.effective_chat.id) if update else "")
|
||||
sys.exit(0)
|
||||
try:
|
||||
await update.message.reply_text("Rebooting the bot...")
|
||||
except Exception as e_reply:
|
||||
logging.error(f"Failed to send reboot reply: {e_reply}")
|
||||
|
||||
async def check_doreboot_file(self, application: Application):
|
||||
reboot_file_path = self.REBOOT_FILE
|
||||
if os.path.exists(reboot_file_path):
|
||||
with open(reboot_file_path, 'r') as f:
|
||||
chat_id = f.read().strip()
|
||||
if chat_id:
|
||||
logging.info("Initiating shutdown for reboot...")
|
||||
sys.exit(0) # This part is not directly testable for completion in unit tests
|
||||
|
||||
# --- Check Doreboot File ---
|
||||
async def _check_doreboot_file_logic(self) -> Union[str, None]:
|
||||
"""Checks for the reboot file, reads chat_id, removes file, and returns chat_id."""
|
||||
if os.path.exists(self.reboot_file):
|
||||
chat_id = None
|
||||
try:
|
||||
with open(self.reboot_file, 'r') as f:
|
||||
chat_id = f.read().strip()
|
||||
# Attempt to remove the file after reading
|
||||
try:
|
||||
os.remove(self.reboot_file)
|
||||
logging.info(f"Successfully read and removed reboot file: {self.reboot_file}")
|
||||
except OSError as e_remove:
|
||||
logging.error(f"Failed to remove reboot file {self.reboot_file} after reading: {e_remove}")
|
||||
# Still return chat_id if read was successful, to attempt notification
|
||||
return chat_id
|
||||
except IOError as e_read:
|
||||
logging.error(f"Error reading reboot file {self.reboot_file}: {e_read}")
|
||||
# If reading failed, attempt to remove anyway if it exists, to prevent stale files
|
||||
if os.path.exists(self.reboot_file):
|
||||
try:
|
||||
os.remove(self.reboot_file)
|
||||
logging.warning(f"Removed reboot file {self.reboot_file} after a read error.")
|
||||
except OSError as e_remove_after_fail:
|
||||
logging.error(f"Failed to remove reboot file {self.reboot_file} even after a read error: {e_remove_after_fail}")
|
||||
return None # Reading failed
|
||||
return None # File does not exist
|
||||
|
||||
async def check_doreboot_file(self, application: Application) -> None:
|
||||
"""Checks for reboot file using logic method and sends notification if applicable."""
|
||||
chat_id = await self._check_doreboot_file_logic()
|
||||
if chat_id:
|
||||
try:
|
||||
await application.bot.send_message(chat_id=chat_id, text="The application has finished initializing.")
|
||||
os.remove(reboot_file_path)
|
||||
logging.info(f"Sent reboot initialization notification to chat_id: {chat_id}")
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to send reboot initialization notification to chat_id {chat_id}: {e}")
|
||||
|
||||
async def browse(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
await browse_command(update, context, self.bot)
|
||||
@@ -132,6 +240,10 @@ class TelegramHelper:
|
||||
|
||||
logging.info("Bot is running...")
|
||||
|
||||
asyncio.get_event_loop().create_task(self.check_doreboot_file(application))
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running(): # pragma: no cover
|
||||
loop.create_task(self.check_doreboot_file(application))
|
||||
else: # pragma: no cover
|
||||
asyncio.run(self.check_doreboot_file(application))
|
||||
|
||||
application.run_polling()
|
||||
|
||||
@@ -0,0 +1,280 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock, ANY
|
||||
import os
|
||||
|
||||
# Assuming anthropic_telegram_inference_bot.py is in the parent directory or PYTHONPATH is set
|
||||
from anthropic_telegram_inference_bot import AnthropicTelegramInferenceBot
|
||||
|
||||
# Mock response from Anthropic client's messages.create
|
||||
def create_mock_anthropic_response(content_text=None, stop_reason="end_turn", tool_use_parts=None):
|
||||
mock_response = MagicMock()
|
||||
mock_response.stop_reason = stop_reason
|
||||
|
||||
content_blocks = []
|
||||
if content_text:
|
||||
text_block = MagicMock()
|
||||
text_block.type = "text"
|
||||
text_block.text = content_text
|
||||
content_blocks.append(text_block)
|
||||
|
||||
if tool_use_parts:
|
||||
for tu_part in tool_use_parts: # tu_part = {"id": "toolu_123", "name": "get_weather", "input": {}}
|
||||
tool_block = MagicMock()
|
||||
tool_block.type = "tool_use"
|
||||
tool_block.id = tu_part["id"]
|
||||
tool_block.name = tu_part["name"]
|
||||
tool_block.input = tu_part["input"]
|
||||
content_blocks.append(tool_block)
|
||||
|
||||
mock_response.content = content_blocks
|
||||
return mock_response
|
||||
|
||||
class TestAnthropicTelegramInferenceBot(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.original_anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY")
|
||||
self.original_small_model = os.environ.get("ANTHROPIC_SMALL_MODEL")
|
||||
self.original_large_model = os.environ.get("ANTHROPIC_LARGE_MODEL")
|
||||
self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH")
|
||||
|
||||
for key in ["ANTHROPIC_API_KEY", "ANTHROPIC_SMALL_MODEL", "ANTHROPIC_LARGE_MODEL", "SYSTEM_PROMPT_PATH"]:
|
||||
if os.environ.get(key):
|
||||
del os.environ[key]
|
||||
|
||||
self.mock_anthropic_client_instance = MagicMock()
|
||||
self.mock_anthropic_client_instance.messages.create = MagicMock()
|
||||
|
||||
def tearDown(self):
|
||||
if self.original_anthropic_api_key: os.environ["ANTHROPIC_API_KEY"] = self.original_anthropic_api_key
|
||||
if self.original_small_model: os.environ["ANTHROPIC_SMALL_MODEL"] = self.original_small_model
|
||||
if self.original_large_model: os.environ["ANTHROPIC_LARGE_MODEL"] = self.original_large_model
|
||||
if self.original_system_prompt_path: os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path
|
||||
|
||||
@patch('anthropic.Anthropic')
|
||||
def test_init_with_anthropic_defaults_env_key(self, MockAnthropicConstructor):
|
||||
MockAnthropicConstructor.return_value = self.mock_anthropic_client_instance
|
||||
os.environ["ANTHROPIC_API_KEY"] = "test_anthropic_key"
|
||||
|
||||
bot = AnthropicTelegramInferenceBot()
|
||||
|
||||
MockAnthropicConstructor.assert_called_once_with(api_key="test_anthropic_key")
|
||||
self.assertEqual(bot.anthropic_client, self.mock_anthropic_client_instance)
|
||||
self.assertEqual(bot.model, os.environ.get("ANTHROPIC_SMALL_MODEL", "claude-3-haiku-20240307"))
|
||||
self.assertEqual(bot.max_tokens, int(os.environ.get("ANTHROPIC_SMALL_MODEL_MAX_TOKENS", 2000)))
|
||||
|
||||
@patch('anthropic.Anthropic')
|
||||
def test_init_with_provided_client_and_models(self, MockAnthropicConstructor):
|
||||
preconfigured_client = MagicMock()
|
||||
bot = AnthropicTelegramInferenceBot(
|
||||
anthropic_client=preconfigured_client,
|
||||
small_model_name="custom-small",
|
||||
small_model_max_tokens=100,
|
||||
large_model_name="custom-large",
|
||||
large_model_max_tokens=200
|
||||
)
|
||||
|
||||
MockAnthropicConstructor.assert_not_called()
|
||||
self.assertEqual(bot.anthropic_client, preconfigured_client)
|
||||
self.assertEqual(bot.model, "custom-small")
|
||||
self.assertEqual(bot.max_tokens, 100)
|
||||
self.assertEqual(bot.small_model_name, "custom-small")
|
||||
self.assertEqual(bot.large_model_name, "custom-large")
|
||||
|
||||
|
||||
def test_get_llm_description(self):
|
||||
bot = AnthropicTelegramInferenceBot(small_model_name="claude-test", small_model_max_tokens=500)
|
||||
self.assertEqual(bot.get_llm_description(), "LLM: claude-test, Max Tokens: 500")
|
||||
|
||||
async def test_switch_model(self):
|
||||
bot = AnthropicTelegramInferenceBot(
|
||||
small_model_name="claude-small", small_model_max_tokens=10,
|
||||
large_model_name="claude-large", large_model_max_tokens=20
|
||||
)
|
||||
self.assertEqual(bot.model, "claude-small")
|
||||
self.assertEqual(bot.max_tokens, 10)
|
||||
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "claude-large")
|
||||
self.assertEqual(bot.max_tokens, 20)
|
||||
self.assertEqual(status, "Switched to model: claude-large")
|
||||
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "claude-small")
|
||||
self.assertEqual(bot.max_tokens, 10)
|
||||
self.assertEqual(status, "Switched to model: claude-small")
|
||||
|
||||
def test_get_chat_response_success_text_only(self):
|
||||
bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
|
||||
bot.model = "test-claude"
|
||||
bot.max_tokens = 150
|
||||
|
||||
mock_api_response = create_mock_anthropic_response(content_text="Hello from Anthropic API")
|
||||
self.mock_anthropic_client_instance.messages.create.return_value = mock_api_response
|
||||
|
||||
messages = [{"role": "user", "content": "Hi"}] # Anthropic format
|
||||
response = bot.get_chat_response(messages, []) # tools = empty list
|
||||
|
||||
self.mock_anthropic_client_instance.messages.create.assert_called_once_with(
|
||||
model="test-claude",
|
||||
max_tokens=150,
|
||||
messages=messages,
|
||||
system=bot.system_prompt, # Ensure system prompt is passed
|
||||
tools=None, # No tools passed to API if empty list or None
|
||||
tool_choice=None
|
||||
)
|
||||
self.assertEqual(response, mock_api_response)
|
||||
|
||||
def test_get_chat_response_with_tools(self):
|
||||
bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
|
||||
bot.model = "claude-toolmaster"
|
||||
bot.max_tokens = 300
|
||||
|
||||
mock_tools_spec = [{"name": "get_weather", "description": "Gets weather", "input_schema": {"type": "object", "properties": {}}}]
|
||||
|
||||
mock_api_response = create_mock_anthropic_response(content_text="Thinking...", tool_use_parts=[
|
||||
{"id": "tool1", "name": "get_weather", "input": {"location": "here"}}
|
||||
])
|
||||
self.mock_anthropic_client_instance.messages.create.return_value = mock_api_response
|
||||
|
||||
messages = [{"role": "user", "content": "Weather?"}]
|
||||
response = bot.get_chat_response(messages, mock_tools_spec)
|
||||
|
||||
self.mock_anthropic_client_instance.messages.create.assert_called_once_with(
|
||||
model="claude-toolmaster",
|
||||
max_tokens=300,
|
||||
messages=messages,
|
||||
system=bot.system_prompt,
|
||||
tools=mock_tools_spec,
|
||||
tool_choice={"type": "auto"}
|
||||
)
|
||||
self.assertEqual(response.content[0].type, "text") # First part can be text
|
||||
self.assertEqual(response.content[1].type, "tool_use")
|
||||
|
||||
|
||||
def test_get_chat_response_api_error(self):
|
||||
bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
|
||||
self.mock_anthropic_client_instance.messages.create.side_effect = Exception("Anthropic API Down")
|
||||
|
||||
with self.assertRaisesRegex(Exception, "Anthropic API Down"):
|
||||
bot.get_chat_response([{"role": "user", "content": "trigger"}], [])
|
||||
|
||||
|
||||
async def test_handle_message_simple_response_no_tools(self):
|
||||
# This test is more involved as it touches BaseTelegramInferenceBot's handle_message structure
|
||||
# which then calls the overridden get_chat_response.
|
||||
bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
|
||||
bot.system_prompt = "System prompt for Anthropic"
|
||||
|
||||
# Mock get_chat_response directly to isolate its behavior from full handle_message logic of base
|
||||
# However, the point of this bot is its get_chat_response and subsequent processing.
|
||||
# So, let's mock the API call within get_chat_response.
|
||||
|
||||
api_response = create_mock_anthropic_response(content_text="Anthropic says hello.")
|
||||
self.mock_anthropic_client_instance.messages.create.return_value = api_response
|
||||
|
||||
# Ensure functions are empty for this test, so no tool logic is triggered
|
||||
bot.functions = []
|
||||
bot.tools = []
|
||||
|
||||
response_content = await bot.handle_message(user_id=101, user_message="Hello Anthropic")
|
||||
|
||||
self.assertEqual(response_content, "Anthropic says hello.")
|
||||
self.assertIn(101, bot.conversation_history)
|
||||
# Anthropic's handle_message structure:
|
||||
# 1. User message added to history.
|
||||
# 2. get_chat_response is called.
|
||||
# 3. Response content (text) is extracted.
|
||||
# 4. Assistant text response is added to history.
|
||||
# Expected history: [User, Assistant_Text_Response] (system prompt handled by get_chat_response)
|
||||
# The base class handle_message adds system prompt if not present.
|
||||
# Anthropic handle_message modifies history format before calling get_chat_response.
|
||||
|
||||
# Let's trace Base.handle_message -> Anthropic.handle_message -> Anthropic.get_chat_response
|
||||
# Base.handle_message:
|
||||
# - Adds system prompt to history if first turn: `self.conversation_history[user_id] = [{"role": "system", "content": self.system_prompt}]` (OpenAI style)
|
||||
# - Appends user message: `{"role": "user", "content": user_message}`
|
||||
# - Calls self.get_chat_response(messages, self.functions) -> This is Anthropic's get_chat_response
|
||||
# Anthropic.get_chat_response:
|
||||
# - Takes OpenAI style `messages` and `self.functions` (tool specs).
|
||||
# - Calls `anthropic_client.messages.create` with Anthropic style messages and system prompt.
|
||||
# Anthropic.handle_message (overridden):
|
||||
# - Prepares Anthropic-style messages from conversation_history (which is OpenAI style from Base)
|
||||
# - Calls get_chat_response with these Anthropic messages and self.functions (tool_specs)
|
||||
# - Processes response, extracts text, handles tool calls.
|
||||
# - Appends *user* message (original) and *assistant* text response to self.conversation_history (OpenAI style).
|
||||
|
||||
# For this test, we are calling AnthropicBot.handle_message directly.
|
||||
# 1. `user_id` not in `self.conversation_history`: `system_prompt` not added yet by Base logic.
|
||||
# Anthropic's `handle_message` will create `anthropic_messages` from this.
|
||||
# If `conversation_history` is empty, `anthropic_messages` = `[{"role": "user", "content": user_message}]`
|
||||
# 2. `get_chat_response` called with `anthropic_messages` and `bot.system_prompt` passed to API.
|
||||
# 3. Response "Anthropic says hello."
|
||||
# 4. Original `user_message` and "Anthropic says hello." (as assistant) added to `self.conversation_history`.
|
||||
|
||||
history = bot.conversation_history[101]
|
||||
self.assertEqual(len(history), 2) # User, Assistant
|
||||
self.assertEqual(history[0]["role"], "user")
|
||||
self.assertEqual(history[0]["content"], "Hello Anthropic")
|
||||
self.assertEqual(history[1]["role"], "assistant")
|
||||
self.assertEqual(history[1]["content"], "Anthropic says hello.")
|
||||
|
||||
# Check API call (made by the mocked get_chat_response indirectly)
|
||||
self.mock_anthropic_client_instance.messages.create.assert_called_once()
|
||||
call_args = self.mock_anthropic_client_instance.messages.create.call_args
|
||||
self.assertEqual(call_args.kwargs["system"], "System prompt for Anthropic")
|
||||
# Initial messages for API should just be the user message for first turn
|
||||
self.assertEqual(call_args.kwargs["messages"], [{"role": "user", "content": "Hello Anthropic"}])
|
||||
|
||||
|
||||
async def test_handle_message_with_tool_calls(self):
|
||||
bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
|
||||
bot.system_prompt = "You are a helpful, tool-using assistant."
|
||||
|
||||
# Define a tool for the bot (OpenAI format, will be converted by Anthropic bot for API)
|
||||
mock_tool_oai_format = {"type": "function", "function": {"name": "get_weather", "description": "Get weather", "parameters": {}}}
|
||||
bot.functions = [mock_tool_oai_format] # This is used to generate anthropic_tools for API
|
||||
|
||||
# API Response 1: Request for tool call
|
||||
tool_use_part = {"id": "toolu_xyz", "name": "get_weather", "input": {"location": "paris"}}
|
||||
api_response_1 = create_mock_anthropic_response(tool_use_parts=[tool_use_part])
|
||||
|
||||
# API Response 2: Final text response after tool execution
|
||||
api_response_2 = create_mock_anthropic_response(content_text="The weather in Paris is nice.")
|
||||
|
||||
self.mock_anthropic_client_instance.messages.create.side_effect = [api_response_1, api_response_2]
|
||||
|
||||
# Mock the bot's call_tool method (from BaseTelegramInferenceBot)
|
||||
bot.call_tool = MagicMock(return_value='''{"weather": "sunny"}''') # Tool execution result
|
||||
|
||||
user_id = 102
|
||||
user_message = "What's the weather in Paris?"
|
||||
final_text_response = await bot.handle_message(user_id, user_message)
|
||||
|
||||
self.assertEqual(final_text_response, "The weather in Paris is nice.")
|
||||
self.assertEqual(self.mock_anthropic_client_instance.messages.create.call_count, 2)
|
||||
|
||||
bot.call_tool.assert_called_once_with("get_weather", {"location": "paris"}) # Anthropic passes input as dict
|
||||
|
||||
# Check conversation history (OpenAI style)
|
||||
history = bot.conversation_history[user_id]
|
||||
self.assertEqual(history[0]["role"], "user")
|
||||
self.assertEqual(history[0]["content"], user_message)
|
||||
|
||||
# Assistant message that requested tool call (Anthropic-specific format stored by its handle_message)
|
||||
# Anthropic's handle_message appends the raw tool_use block and then the tool_result
|
||||
self.assertEqual(history[1]["role"], "assistant")
|
||||
self.assertTrue(isinstance(history[1]["content"], list)) # Anthropic content is a list
|
||||
self.assertEqual(history[1]["content"][0]["type"], "tool_use")
|
||||
self.assertEqual(history[1]["content"][0]["id"], "toolu_xyz")
|
||||
|
||||
self.assertEqual(history[2]["role"], "tool")
|
||||
self.assertEqual(history[2]["tool_call_id"], "toolu_xyz")
|
||||
self.assertEqual(history[2]["name"], "get_weather")
|
||||
self.assertEqual(history[2]["content"], '''{"weather": "sunny"}''') # call_tool result
|
||||
|
||||
self.assertEqual(history[3]["role"], "assistant") # Final text response
|
||||
self.assertTrue(isinstance(history[3]["content"], str)) # simple text
|
||||
self.assertEqual(history[3]["content"], "The weather in Paris is nice.")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -0,0 +1,310 @@
|
||||
import unittest
|
||||
from unittest.mock import patch, mock_open, MagicMock
|
||||
import os
|
||||
import json
|
||||
|
||||
# Ensure the path includes the directory where base_telegram_inference_bot is located
|
||||
# This might require adjustment based on actual project structure if tests are run from root
|
||||
# For now, assuming it can be imported directly or via PYTHONPATH
|
||||
from base_telegram_inference_bot import BaseTelegramInferenceBot
|
||||
from tools.base_tool import BaseTool # For mocking tool structure
|
||||
|
||||
# Create a concrete subclass for testing, as BaseTelegramInferenceBot is abstract
|
||||
class ConcreteTestBot(BaseTelegramInferenceBot):
|
||||
def __init__(self, system_prompt_content=None, system_prompt_path=None, mock_tools=None, mock_functions=None):
|
||||
# Mock load_functions during super().__init__ if needed, or control tools/functions directly
|
||||
self._mock_tools = mock_tools if mock_tools is not None else []
|
||||
self._mock_functions = mock_functions if mock_functions is not None else []
|
||||
super().__init__(system_prompt_content=system_prompt_content, system_prompt_path=system_prompt_path)
|
||||
|
||||
# Override load_functions to use mocks
|
||||
def load_functions(self):
|
||||
return self._mock_tools, self._mock_functions
|
||||
|
||||
def get_chat_response(self, messages):
|
||||
pass # Abstract method, not tested here directly
|
||||
|
||||
async def handle_message(self, user_id, user_message):
|
||||
pass # Abstract method
|
||||
|
||||
def get_llm_description(self) -> str:
|
||||
return "Mock LLM Description" # Concrete implementation for testing get_bot_status
|
||||
|
||||
async def start(self):
|
||||
pass # Abstract method
|
||||
|
||||
async def abort_processing(self, user_id):
|
||||
pass # Abstract method
|
||||
|
||||
async def switch_model(self):
|
||||
pass # Abstract method
|
||||
|
||||
class TestBaseTelegramInferenceBot(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Reset relevant environment variables before each test
|
||||
self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH")
|
||||
if "SYSTEM_PROMPT_PATH" in os.environ:
|
||||
del os.environ["SYSTEM_PROMPT_PATH"]
|
||||
|
||||
def tearDown(self):
|
||||
# Restore environment variables
|
||||
if self.original_system_prompt_path:
|
||||
os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path
|
||||
elif "SYSTEM_PROMPT_PATH" in os.environ: # Ensure it's removed if test set it and it wasn't there before
|
||||
del os.environ["SYSTEM_PROMPT_PATH"]
|
||||
|
||||
def test_init_with_direct_system_prompt(self):
|
||||
bot = ConcreteTestBot(system_prompt_content="Direct prompt content")
|
||||
self.assertEqual(bot.system_prompt, "Direct prompt content")
|
||||
|
||||
@patch("os.path.isfile")
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="File prompt content")
|
||||
def test_init_with_system_prompt_path_argument(self, mock_file_open, mock_isfile):
|
||||
mock_isfile.return_value = True
|
||||
bot = ConcreteTestBot(system_prompt_path="dummy/path.txt")
|
||||
self.assertEqual(bot.system_prompt, "File prompt content")
|
||||
mock_isfile.assert_called_once_with("dummy/path.txt")
|
||||
mock_file_open.assert_called_once_with("dummy/path.txt", "r", encoding="utf-8")
|
||||
|
||||
@patch("os.path.isfile")
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="Env prompt content")
|
||||
def test_init_with_env_system_prompt_path(self, mock_file_open, mock_isfile):
|
||||
mock_isfile.return_value = True
|
||||
os.environ["SYSTEM_PROMPT_PATH"] = "env/path.txt"
|
||||
bot = ConcreteTestBot()
|
||||
self.assertEqual(bot.system_prompt, "Env prompt content")
|
||||
mock_isfile.assert_called_once_with("env/path.txt")
|
||||
mock_file_open.assert_called_once_with("env/path.txt", "r", encoding="utf-8")
|
||||
|
||||
def test_init_with_default_system_prompt(self):
|
||||
# Ensure ENV var is not set for this test
|
||||
if "SYSTEM_PROMPT_PATH" in os.environ:
|
||||
del os.environ["SYSTEM_PROMPT_PATH"]
|
||||
bot = ConcreteTestBot()
|
||||
self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.")
|
||||
|
||||
@patch("os.path.isfile", return_value=False)
|
||||
def test_init_with_invalid_system_prompt_path(self, mock_isfile):
|
||||
bot = ConcreteTestBot(system_prompt_path="invalid/path.txt")
|
||||
self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.")
|
||||
mock_isfile.assert_called_once_with("invalid/path.txt")
|
||||
|
||||
@patch("os.path.isfile")
|
||||
@patch("builtins.open", side_effect=IOError("File read error"))
|
||||
def test_init_with_system_prompt_file_read_error(self, mock_file_open, mock_isfile):
|
||||
mock_isfile.return_value = True
|
||||
bot = ConcreteTestBot(system_prompt_path="dummy/path.txt")
|
||||
self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.")
|
||||
|
||||
def test_clear_conversation_history(self):
|
||||
mock_tool_instance = MagicMock(spec=BaseTool)
|
||||
bot = ConcreteTestBot(mock_tools=[mock_tool_instance])
|
||||
bot.conversation_history[123] = [{"role": "user", "content": "Hello"}]
|
||||
|
||||
bot.clear_conversation_history(123)
|
||||
self.assertNotIn(123, bot.conversation_history)
|
||||
mock_tool_instance.clear.assert_called_once()
|
||||
|
||||
def test_clear_conversation_history_user_not_found(self):
|
||||
mock_tool_instance = MagicMock(spec=BaseTool)
|
||||
bot = ConcreteTestBot(mock_tools=[mock_tool_instance])
|
||||
bot.clear_conversation_history(404)
|
||||
self.assertNotIn(404, bot.conversation_history)
|
||||
mock_tool_instance.clear.assert_called_once()
|
||||
|
||||
def test_processing_status(self):
|
||||
bot = ConcreteTestBot()
|
||||
self.assertEqual(bot.processing_status, {})
|
||||
bot.set_processing_status(123, 789)
|
||||
self.assertEqual(bot.processing_status[123], {"processing": True, "message_id": 789})
|
||||
bot.clear_processing_status(123)
|
||||
self.assertNotIn(123, bot.processing_status)
|
||||
|
||||
def test_clear_processing_status_user_not_found(self):
|
||||
bot = ConcreteTestBot()
|
||||
bot.clear_processing_status(404)
|
||||
self.assertNotIn(404, bot.processing_status)
|
||||
|
||||
def test_call_tool_success_dict_args(self):
|
||||
mock_tool = MagicMock(spec=BaseTool)
|
||||
mock_tool.get_functions.return_value = [
|
||||
{"function": {"name": "test_tool", "description": "A test tool", "parameters": {}}}
|
||||
]
|
||||
mock_tool.execute.return_value = "Tool executed successfully"
|
||||
|
||||
bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions())
|
||||
|
||||
result = bot.call_tool("test_tool", {"arg1": "value1"})
|
||||
self.assertEqual(result, "Tool executed successfully")
|
||||
mock_tool.execute.assert_called_once_with("test_tool", arg1="value1")
|
||||
|
||||
def test_call_tool_success_json_string_args(self):
|
||||
mock_tool = MagicMock(spec=BaseTool)
|
||||
mock_tool.get_functions.return_value = [
|
||||
{"function": {"name": "test_tool_json", "parameters": {}}}
|
||||
]
|
||||
mock_tool.execute.return_value = "Tool JSON OK"
|
||||
bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions())
|
||||
|
||||
args_json_str = '''{"param": "value"}'''
|
||||
result = bot.call_tool("test_tool_json", args_json_str)
|
||||
self.assertEqual(result, "Tool JSON OK")
|
||||
mock_tool.execute.assert_called_once_with("test_tool_json", param="value")
|
||||
|
||||
def test_call_tool_malformed_json_string_args(self):
|
||||
bot = ConcreteTestBot(mock_tools=[])
|
||||
args_malformed_json_str = '''{"param": "value"'''
|
||||
result = bot.call_tool("some_tool", args_malformed_json_str)
|
||||
self.assertTrue("Error: Malformed arguments for tool call" in result)
|
||||
|
||||
def test_call_tool_unexpected_arg_type(self):
|
||||
bot = ConcreteTestBot(mock_tools=[])
|
||||
result = bot.call_tool("some_tool", 12345) # Integer instead of dict/str
|
||||
self.assertTrue("Error: Invalid argument type for tool call" in result)
|
||||
|
||||
def test_call_tool_none_args(self):
|
||||
mock_tool = MagicMock(spec=BaseTool)
|
||||
mock_tool.get_functions.return_value = [
|
||||
{"function": {"name": "test_tool_none", "parameters": {}}}
|
||||
]
|
||||
mock_tool.execute.return_value = "Tool None OK"
|
||||
bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions())
|
||||
|
||||
result = bot.call_tool("test_tool_none", None)
|
||||
self.assertEqual(result, "Tool None OK")
|
||||
mock_tool.execute.assert_called_once_with("test_tool_none") # No kwargs if None
|
||||
|
||||
def test_call_tool_not_found(self):
|
||||
bot = ConcreteTestBot(mock_tools=[])
|
||||
result = bot.call_tool("non_existent_tool", {})
|
||||
self.assertEqual(result, "Error: Tool function non_existent_tool not found.")
|
||||
|
||||
def test_call_tool_execute_exception(self):
|
||||
mock_tool = MagicMock(spec=BaseTool)
|
||||
mock_tool.get_functions.return_value = [{"function": {"name": "error_tool", "parameters": {}}}]
|
||||
mock_tool.execute.side_effect = Exception("Execution failed")
|
||||
bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions())
|
||||
|
||||
result = bot.call_tool("error_tool", {})
|
||||
self.assertEqual(result, "Error executing tool error_tool: Execution failed")
|
||||
|
||||
def test_get_system_prompt_description(self):
|
||||
if "SYSTEM_PROMPT_PATH" in os.environ: # Ensure clean state
|
||||
del os.environ["SYSTEM_PROMPT_PATH"]
|
||||
|
||||
bot_default = ConcreteTestBot()
|
||||
self.assertEqual(bot_default.get_system_prompt_description(), "System Prompt: Default")
|
||||
|
||||
bot_custom_content = ConcreteTestBot(system_prompt_content="Custom content here")
|
||||
self.assertEqual(bot_custom_content.get_system_prompt_description(), "System Prompt: Custom")
|
||||
|
||||
os.environ["SYSTEM_PROMPT_PATH"] = "some/path.txt"
|
||||
bot_env_default_prompt = ConcreteTestBot() # system_prompt itself is default
|
||||
self.assertEqual(bot_env_default_prompt.get_system_prompt_description(), "System Prompt: Custom (via ENV)")
|
||||
|
||||
with patch("os.path.isfile", return_value=True), \
|
||||
patch("builtins.open", mock_open(read_data="File prompt from ENV")):
|
||||
bot_env_file_prompt = ConcreteTestBot() # system_prompt gets loaded from ENV path
|
||||
self.assertEqual(bot_env_file_prompt.get_system_prompt_description(), "System Prompt: Custom")
|
||||
del os.environ["SYSTEM_PROMPT_PATH"]
|
||||
|
||||
with patch("os.path.isfile", return_value=True), \
|
||||
patch("builtins.open", mock_open(read_data="File prompt from arg")):
|
||||
bot_custom_file_arg = ConcreteTestBot(system_prompt_path="custom/file.txt")
|
||||
self.assertEqual(bot_custom_file_arg.get_system_prompt_description(), "System Prompt: Custom")
|
||||
|
||||
@patch.object(ConcreteTestBot, 'get_llm_description', return_value="Test LLM Description")
|
||||
@patch.object(ConcreteTestBot, 'get_system_prompt_description', return_value="Test Prompt Description")
|
||||
async def test_get_bot_status(self, mock_prompt_desc, mock_llm_desc):
|
||||
bot = ConcreteTestBot()
|
||||
status = await bot.get_bot_status()
|
||||
self.assertEqual(status, "Test Prompt Description\nTest LLM Description")
|
||||
mock_prompt_desc.assert_called_once()
|
||||
mock_llm_desc.assert_called_once()
|
||||
|
||||
@patch('os.path.dirname', return_value='/mock/path')
|
||||
@patch('os.path.join')
|
||||
@patch('os.path.exists')
|
||||
@patch('os.listdir')
|
||||
@patch('importlib.import_module')
|
||||
def test_load_functions_no_tools_dir(self, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname):
|
||||
mock_join.return_value = '/mock/path/tools'
|
||||
mock_exists.return_value = False
|
||||
|
||||
class BotForLoadTest(BaseTelegramInferenceBot):
|
||||
load_system_prompt = MagicMock(return_value="Default")
|
||||
get_chat_response = MagicMock(); handle_message = MagicMock(); get_llm_description = MagicMock(return_value="mock")
|
||||
start = MagicMock(); abort_processing = MagicMock(); switch_model = MagicMock()
|
||||
|
||||
bot = BotForLoadTest()
|
||||
self.assertEqual(bot.tools, [])
|
||||
self.assertEqual(bot.functions, [])
|
||||
mock_listdir.assert_not_called()
|
||||
|
||||
@patch('os.path.dirname', return_value='/mock/base_bot_dir')
|
||||
@patch('os.path.join', side_effect=lambda *args: os.path.normpath(os.path.join(*args)))
|
||||
@patch('os.path.exists', return_value=True)
|
||||
@patch('os.listdir', return_value=['my_tool.py', '__init__.py', 'base_tool.py'])
|
||||
@patch('importlib.import_module')
|
||||
def test_load_functions_with_one_tool(self, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname):
|
||||
|
||||
mock_tool_class = MagicMock(spec=BaseTool) # This is the class itself
|
||||
mock_tool_instance = MagicMock(spec=BaseTool) # This is the instance
|
||||
mock_tool_class.return_value = mock_tool_instance # mock_tool_class() creates mock_tool_instance
|
||||
mock_tool_instance.get_functions.return_value = [{"function": {"name": "sample_function"}}]
|
||||
|
||||
mock_my_tool_module = MagicMock()
|
||||
# Simulate inspect.getmembers behavior: returns list of (name, member) tuples
|
||||
# Only include members that are classes, derive from BaseTool, and are not BaseTool itself.
|
||||
mock_my_tool_module.ValidToolClass = mock_tool_class
|
||||
mock_my_tool_module.NotATool = object()
|
||||
mock_my_tool_module.BaseTool = BaseTool # This should be skipped by the loader
|
||||
|
||||
def import_side_effect(module_name):
|
||||
if module_name == 'tools.my_tool':
|
||||
return mock_my_tool_module
|
||||
raise ImportError(f"Unexpected import: {module_name}")
|
||||
mock_import_module.side_effect = import_side_effect
|
||||
|
||||
class BotForLoadTest(BaseTelegramInferenceBot):
|
||||
load_system_prompt = MagicMock(return_value="Default")
|
||||
get_chat_response = MagicMock(); handle_message = MagicMock(); get_llm_description = MagicMock(return_value="mock")
|
||||
start = MagicMock(); abort_processing = MagicMock(); switch_model = MagicMock()
|
||||
|
||||
bot = BotForLoadTest()
|
||||
self.assertEqual(len(bot.tools), 1)
|
||||
self.assertIs(bot.tools[0], mock_tool_instance)
|
||||
self.assertEqual(len(bot.functions), 1)
|
||||
self.assertEqual(bot.functions[0]['function']['name'], "sample_function")
|
||||
mock_import_module.assert_called_once_with('tools.my_tool')
|
||||
mock_tool_class.assert_called_once_with() # Tool class was instantiated
|
||||
mock_tool_instance.get_functions.assert_called_once_with()
|
||||
|
||||
@patch('os.path.dirname', return_value='/mock/base_bot_dir')
|
||||
@patch('os.path.join', side_effect=lambda *args: os.path.normpath(os.path.join(*args)))
|
||||
@patch('os.path.exists', return_value=True)
|
||||
@patch('os.listdir', return_value=['tool_with_init_error.py'])
|
||||
@patch('importlib.import_module')
|
||||
@patch('logging.error') # Mock logging to check for error messages
|
||||
def test_load_functions_tool_instantiation_error(self, mock_logging_error, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname):
|
||||
mock_tool_class_init_error = MagicMock(spec=BaseTool)
|
||||
mock_tool_class_init_error.side_effect = Exception("Failed to init tool") # Error on instantiation
|
||||
|
||||
mock_error_tool_module = MagicMock()
|
||||
mock_error_tool_module.ToolWithInitError = mock_tool_class_init_error
|
||||
|
||||
mock_import_module.return_value = mock_error_tool_module
|
||||
|
||||
class BotForLoadTest(BaseTelegramInferenceBot):
|
||||
load_system_prompt = MagicMock(return_value="Default")
|
||||
get_chat_response = MagicMock(); handle_message = MagicMock(); get_llm_description = MagicMock(return_value="mock")
|
||||
start = MagicMock(); abort_processing = MagicMock(); switch_model = MagicMock()
|
||||
|
||||
bot = BotForLoadTest()
|
||||
self.assertEqual(len(bot.tools), 0)
|
||||
self.assertEqual(len(bot.functions), 0)
|
||||
mock_logging_error.assert_any_call("Error instantiating tool ToolWithInitError from tool_with_init_error.py: Failed to init tool")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(闂傚лен䦗婢у〃埊鍓解劓姣)
|
||||
@@ -0,0 +1,158 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch, ANY
|
||||
import os
|
||||
|
||||
# Assuming chatgpt_telegram_inference_bot.py and its parent are accessible
|
||||
from chatgpt_telegram_inference_bot import ChatGPTTelegramInferenceBot
|
||||
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot # For patching super
|
||||
|
||||
class TestChatGPTTelegramInferenceBot(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Store and clear relevant environment variables
|
||||
self.original_openai_key = os.environ.get("OPENAI_API_KEY")
|
||||
self.original_small_model = os.environ.get("OPENAI_SMALL_MODEL")
|
||||
self.original_large_model = os.environ.get("OPENAI_LARGE_MODEL")
|
||||
self.original_small_tokens = os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS")
|
||||
self.original_large_tokens = os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS")
|
||||
self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH")
|
||||
|
||||
for key in ["OPENAI_API_KEY", "OPENAI_SMALL_MODEL", "OPENAI_LARGE_MODEL",
|
||||
"OPENAI_SMALL_MODEL_MAX_TOKENS", "OPENAI_LARGE_MODEL_MAX_TOKENS", "SYSTEM_PROMPT_PATH"]:
|
||||
if os.environ.get(key):
|
||||
del os.environ[key]
|
||||
|
||||
# Mock the OpenAI client that OpenAICompatibleInferenceBot's __init__ might create
|
||||
self.mock_openai_client = MagicMock()
|
||||
|
||||
def tearDown(self):
|
||||
# Restore environment variables
|
||||
if self.original_openai_key: os.environ["OPENAI_API_KEY"] = self.original_openai_key
|
||||
if self.original_small_model: os.environ["OPENAI_SMALL_MODEL"] = self.original_small_model
|
||||
if self.original_large_model: os.environ["OPENAI_LARGE_MODEL"] = self.original_large_model
|
||||
if self.original_small_tokens: os.environ["OPENAI_SMALL_MODEL_MAX_TOKENS"] = self.original_small_tokens
|
||||
if self.original_large_tokens: os.environ["OPENAI_LARGE_MODEL_MAX_TOKENS"] = self.original_large_tokens
|
||||
if self.original_system_prompt_path: os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path
|
||||
|
||||
|
||||
@patch.object(OpenAICompatibleInferenceBot, '__init__') # Mock the superclass's __init__
|
||||
def test_init_defaults_and_super_call(self, mock_super_init):
|
||||
os.environ["OPENAI_API_KEY"] = "test_key_chatgpt"
|
||||
os.environ["OPENAI_SMALL_MODEL"] = "gpt-3.5-turbo-env"
|
||||
os.environ["OPENAI_SMALL_MODEL_MAX_TOKENS"] = "350"
|
||||
|
||||
bot = ChatGPTTelegramInferenceBot()
|
||||
|
||||
mock_super_init.assert_called_once_with(
|
||||
client=None, # ChatGPT bot will let superclass create it
|
||||
api_key="test_key_chatgpt", # Passed to super
|
||||
base_url=None,
|
||||
api_version=None,
|
||||
azure_deployment=None,
|
||||
model_name="gpt-3.5-turbo-env", # Default small model from env
|
||||
max_tokens_str="350", # Default small model tokens from env
|
||||
small_model_name="gpt-3.5-turbo-env",
|
||||
small_model_max_tokens_str="350",
|
||||
large_model_name=os.environ.get("OPENAI_LARGE_MODEL", "gpt-4-turbo-preview"), # Default large
|
||||
large_model_max_tokens_str=os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS"),
|
||||
system_prompt_content=None,
|
||||
system_prompt_path=None,
|
||||
is_gemini=False,
|
||||
max_history_length=20 # Default from OpenAICompatibleInferenceBot
|
||||
)
|
||||
|
||||
@patch.object(OpenAICompatibleInferenceBot, '__init__')
|
||||
def test_init_with_arguments(self, mock_super_init):
|
||||
mock_client_arg = MagicMock()
|
||||
bot = ChatGPTTelegramInferenceBot(
|
||||
openai_client=mock_client_arg,
|
||||
api_key="arg_key",
|
||||
small_model_name="arg_small_model",
|
||||
small_model_max_tokens="123",
|
||||
large_model_name="arg_large_model",
|
||||
large_model_max_tokens="456",
|
||||
system_prompt_content="Arg prompt"
|
||||
)
|
||||
mock_super_init.assert_called_once_with(
|
||||
client=mock_client_arg,
|
||||
api_key="arg_key",
|
||||
base_url=None,
|
||||
api_version=None,
|
||||
azure_deployment=None,
|
||||
model_name="arg_small_model", # Initially configured with small model
|
||||
max_tokens_str="123",
|
||||
small_model_name="arg_small_model",
|
||||
small_model_max_tokens_str="123",
|
||||
large_model_name="arg_large_model",
|
||||
large_model_max_tokens_str="456",
|
||||
system_prompt_content="Arg prompt",
|
||||
system_prompt_path=None,
|
||||
is_gemini=False,
|
||||
max_history_length=20
|
||||
)
|
||||
|
||||
# Test switch_model - this method is part of ChatGPTTelegramInferenceBot
|
||||
# It calls _configure_model_and_tokens which is in the superclass.
|
||||
# We need a bot instance where _configure_model_and_tokens can be called.
|
||||
@patch('openai.OpenAI') # To allow instantiation of the bot by mocking client creation
|
||||
async def test_switch_model_logic(self, mock_openai_constructor):
|
||||
mock_openai_constructor.return_value = self.mock_openai_client # Mock client creation in super
|
||||
|
||||
# Set env vars for model names that switch_model will use as fallback
|
||||
os.environ["OPENAI_SMALL_MODEL"] = "env-small-gpt"
|
||||
os.environ["OPENAI_SMALL_MODEL_MAX_TOKENS"] = "100"
|
||||
os.environ["OPENAI_LARGE_MODEL"] = "env-large-gpt"
|
||||
os.environ["OPENAI_LARGE_MODEL_MAX_TOKENS"] = "200"
|
||||
|
||||
# Instantiate with initial model (small)
|
||||
bot = ChatGPTTelegramInferenceBot()
|
||||
self.assertEqual(bot.model, "env-small-gpt")
|
||||
self.assertEqual(bot.max_tokens, 100)
|
||||
|
||||
# Switch to large
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "env-large-gpt")
|
||||
self.assertEqual(bot.max_tokens, 200)
|
||||
self.assertEqual(status, "Switched to model: env-large-gpt")
|
||||
|
||||
# Switch back to small
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "env-small-gpt")
|
||||
self.assertEqual(bot.max_tokens, 100)
|
||||
self.assertEqual(status, "Switched to model: env-small-gpt")
|
||||
|
||||
@patch('openai.OpenAI')
|
||||
async def test_switch_model_uses_instance_configs_if_provided(self, mock_openai_constructor):
|
||||
mock_openai_constructor.return_value = self.mock_openai_client
|
||||
|
||||
# Instantiate with specific model names, overriding potential env vars
|
||||
bot = ChatGPTTelegramInferenceBot(
|
||||
small_model_name="init-small", small_model_max_tokens="50",
|
||||
large_model_name="init-large", large_model_max_tokens="150"
|
||||
)
|
||||
self.assertEqual(bot.model, "init-small") # Starts with small
|
||||
self.assertEqual(bot.max_tokens, 50)
|
||||
|
||||
# Switch to large
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "init-large")
|
||||
self.assertEqual(bot.max_tokens, 150)
|
||||
self.assertEqual(status, "Switched to model: init-large")
|
||||
|
||||
# Switch back to small
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "init-small")
|
||||
self.assertEqual(bot.max_tokens, 50)
|
||||
self.assertEqual(status, "Switched to model: init-small")
|
||||
|
||||
# get_llm_description is inherited from OpenAICompatibleInferenceBot.
|
||||
# Test just to ensure it works in the context of a ChatGPTBot instance
|
||||
@patch('openai.OpenAI')
|
||||
def test_get_llm_description_for_chatgpt_bot(self, mock_openai_constructor):
|
||||
mock_openai_constructor.return_value = self.mock_openai_client
|
||||
bot = ChatGPTTelegramInferenceBot(small_model_name="gpt-3.5-desc", small_model_max_tokens="777")
|
||||
# Initially configured with small model
|
||||
self.assertEqual(bot.get_llm_description(), "LLM: gpt-3.5-desc, Max Tokens: 777, Azure: False")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -0,0 +1,154 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch, ANY
|
||||
import os
|
||||
|
||||
# Assuming gemini_telegram_inference_bot.py and its parent are accessible
|
||||
from gemini_telegram_inference_bot import GeminiTelegramInferenceBot
|
||||
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot # For patching super
|
||||
|
||||
class TestGeminiTelegramInferenceBot(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Store and clear relevant environment variables
|
||||
self.original_gemini_key = os.environ.get("GEMINI_API_KEY")
|
||||
self.original_gemini_base_url = os.environ.get("GEMINI_API_BASE_URL")
|
||||
self.original_small_model = os.environ.get("GEMINI_SMALL_MODEL")
|
||||
self.original_large_model = os.environ.get("GEMINI_LARGE_MODEL")
|
||||
self.original_small_tokens = os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS")
|
||||
self.original_large_tokens = os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS")
|
||||
self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH")
|
||||
|
||||
for key in ["GEMINI_API_KEY", "GEMINI_API_BASE_URL", "GEMINI_SMALL_MODEL", "GEMINI_LARGE_MODEL",
|
||||
"GEMINI_SMALL_MODEL_MAX_TOKENS", "GEMINI_LARGE_MODEL_MAX_TOKENS", "SYSTEM_PROMPT_PATH"]:
|
||||
if os.environ.get(key):
|
||||
del os.environ[key]
|
||||
|
||||
self.mock_openai_client = MagicMock() # Used if superclass creates an OpenAI client
|
||||
|
||||
def tearDown(self):
|
||||
# Restore environment variables
|
||||
if self.original_gemini_key: os.environ["GEMINI_API_KEY"] = self.original_gemini_key
|
||||
if self.original_gemini_base_url: os.environ["GEMINI_API_BASE_URL"] = self.original_gemini_base_url
|
||||
if self.original_small_model: os.environ["GEMINI_SMALL_MODEL"] = self.original_small_model
|
||||
if self.original_large_model: os.environ["GEMINI_LARGE_MODEL"] = self.original_large_model
|
||||
if self.original_small_tokens: os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = self.original_small_tokens
|
||||
if self.original_large_tokens: os.environ["GEMINI_LARGE_MODEL_MAX_TOKENS"] = self.original_large_tokens
|
||||
if self.original_system_prompt_path: os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path
|
||||
|
||||
@patch.object(OpenAICompatibleInferenceBot, '__init__') # Mock the superclass's __init__
|
||||
def test_init_defaults_and_super_call(self, mock_super_init):
|
||||
os.environ["GEMINI_API_KEY"] = "test_key_gemini"
|
||||
os.environ["GEMINI_API_BASE_URL"] = "https://gemini.env.com"
|
||||
os.environ["GEMINI_SMALL_MODEL"] = "gemini-pro-env"
|
||||
os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = "360"
|
||||
|
||||
bot = GeminiTelegramInferenceBot()
|
||||
|
||||
mock_super_init.assert_called_once_with(
|
||||
client=None,
|
||||
api_key="test_key_gemini",
|
||||
base_url="https://gemini.env.com", # Passed to super
|
||||
api_version=None,
|
||||
azure_deployment=None,
|
||||
model_name="gemini-pro-env",
|
||||
max_tokens_str="360",
|
||||
small_model_name="gemini-pro-env",
|
||||
small_model_max_tokens_str="360",
|
||||
large_model_name=os.environ.get("GEMINI_LARGE_MODEL", "gemini-1.5-pro-latest"), # Default large
|
||||
large_model_max_tokens_str=os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS"),
|
||||
system_prompt_content=None,
|
||||
system_prompt_path=None,
|
||||
is_gemini=True, # Important for Gemini bot
|
||||
max_history_length=20
|
||||
)
|
||||
|
||||
@patch.object(OpenAICompatibleInferenceBot, '__init__')
|
||||
def test_init_with_arguments(self, mock_super_init):
|
||||
mock_client_arg = MagicMock()
|
||||
bot = GeminiTelegramInferenceBot(
|
||||
openai_client=mock_client_arg, # Name in Gemini bot is openai_client for consistency
|
||||
api_key="arg_gem_key",
|
||||
base_url="https://arg.gemini.com",
|
||||
small_model_name="arg_gem_small",
|
||||
small_model_max_tokens="124",
|
||||
large_model_name="arg_gem_large",
|
||||
large_model_max_tokens="457",
|
||||
system_prompt_content="Gemini prompt"
|
||||
)
|
||||
mock_super_init.assert_called_once_with(
|
||||
client=mock_client_arg,
|
||||
api_key="arg_gem_key",
|
||||
base_url="https://arg.gemini.com",
|
||||
api_version=None,
|
||||
azure_deployment=None,
|
||||
model_name="arg_gem_small",
|
||||
max_tokens_str="124",
|
||||
small_model_name="arg_gem_small",
|
||||
small_model_max_tokens_str="124",
|
||||
large_model_name="arg_gem_large",
|
||||
large_model_max_tokens_str="457",
|
||||
system_prompt_content="Gemini prompt",
|
||||
system_prompt_path=None,
|
||||
is_gemini=True,
|
||||
max_history_length=20
|
||||
)
|
||||
|
||||
@patch('openai.OpenAI') # Gemini bot uses OpenAI client configured for Gemini endpoint
|
||||
async def test_switch_model_logic(self, mock_openai_constructor):
|
||||
mock_openai_constructor.return_value = self.mock_openai_client
|
||||
|
||||
os.environ["GEMINI_SMALL_MODEL"] = "env-gemini-small"
|
||||
os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = "110"
|
||||
os.environ["GEMINI_LARGE_MODEL"] = "env-gemini-large"
|
||||
os.environ["GEMINI_LARGE_MODEL_MAX_TOKENS"] = "220"
|
||||
|
||||
bot = GeminiTelegramInferenceBot() # Uses env vars by default
|
||||
self.assertEqual(bot.model, "env-gemini-small")
|
||||
self.assertEqual(bot.max_tokens, 110)
|
||||
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "env-gemini-large")
|
||||
self.assertEqual(bot.max_tokens, 220)
|
||||
self.assertEqual(status, "Switched to model: env-gemini-large")
|
||||
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "env-gemini-small")
|
||||
self.assertEqual(bot.max_tokens, 110)
|
||||
self.assertEqual(status, "Switched to model: env-gemini-small")
|
||||
|
||||
@patch('openai.OpenAI')
|
||||
async def test_switch_model_uses_instance_configs_if_provided(self, mock_openai_constructor):
|
||||
mock_openai_constructor.return_value = self.mock_openai_client
|
||||
|
||||
bot = GeminiTelegramInferenceBot(
|
||||
small_model_name="init-gem-small", small_model_max_tokens="55",
|
||||
large_model_name="init-gem-large", large_model_max_tokens="155"
|
||||
)
|
||||
self.assertEqual(bot.model, "init-gem-small")
|
||||
self.assertEqual(bot.max_tokens, 55)
|
||||
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "init-gem-large")
|
||||
self.assertEqual(bot.max_tokens, 155)
|
||||
self.assertEqual(status, "Switched to model: init-gem-large")
|
||||
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "init-gem-small")
|
||||
self.assertEqual(bot.max_tokens, 55)
|
||||
self.assertEqual(status, "Switched to model: init-gem-small")
|
||||
|
||||
@patch('openai.OpenAI')
|
||||
def test_get_llm_description_for_gemini_bot(self, mock_openai_constructor):
|
||||
mock_openai_constructor.return_value = self.mock_openai_client
|
||||
bot = GeminiTelegramInferenceBot(
|
||||
small_model_name="gemini-pro-desc",
|
||||
small_model_max_tokens="888",
|
||||
# is_gemini is True by default in constructor call to super
|
||||
)
|
||||
# LLM description should indicate not Azure, even though it uses OpenAICompatible... base
|
||||
# The is_gemini flag primarily affects client instantiation logic in the superclass.
|
||||
# The azure_openai flag in superclass is based on azure_endpoint presence.
|
||||
self.assertEqual(bot.get_llm_description(), "LLM: gemini-pro-desc, Max Tokens: 888, Azure: False")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -0,0 +1,332 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock, ANY
|
||||
import os
|
||||
import json
|
||||
|
||||
# Assuming openai_compatible_inference_bot.py is in the parent directory or PYTHONPATH is set
|
||||
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot
|
||||
|
||||
# Mock response from OpenAI client's chat.completions.create
|
||||
def create_mock_openai_response(content=None, tool_calls=None):
|
||||
mock_message = MagicMock()
|
||||
mock_message.role = "assistant"
|
||||
mock_message.content = content
|
||||
if tool_calls:
|
||||
# tool_calls should be a list of objects with id and function (name, arguments)
|
||||
mock_tool_calls = []
|
||||
for tc in tool_calls:
|
||||
mock_tc = MagicMock()
|
||||
mock_tc.id = tc["id"]
|
||||
mock_tc.function.name = tc["function"]["name"]
|
||||
mock_tc.function.arguments = tc["function"]["arguments"]
|
||||
mock_tool_calls.append(mock_tc)
|
||||
mock_message.tool_calls = mock_tool_calls
|
||||
else:
|
||||
mock_message.tool_calls = None
|
||||
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message = mock_message
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [mock_choice]
|
||||
return mock_response
|
||||
|
||||
# Concrete class for testing
|
||||
class ConcreteOpenAICompatibleBot(OpenAICompatibleInferenceBot):
|
||||
# Implement abstract methods for instantiation
|
||||
async def switch_model(self):
|
||||
# Simple switch for testing if needed, or just pass
|
||||
if self.model == self.small_model_name:
|
||||
self._configure_model_and_tokens(self.large_model_name, self.large_model_max_tokens_str)
|
||||
else:
|
||||
self._configure_model_and_tokens(self.small_model_name, self.small_model_max_tokens_str)
|
||||
return f"Switched to {self.model}"
|
||||
|
||||
# Override load_functions if it's called by parent and needs mocking for these tests
|
||||
# (OpenAICompatibleInferenceBot's __init__ calls BaseTelegramInferenceBot's __init__, which calls load_functions)
|
||||
def load_functions(self):
|
||||
# For these tests, assume no tools unless specifically added
|
||||
self.tools = []
|
||||
self.functions = []
|
||||
return self.tools, self.functions
|
||||
|
||||
|
||||
class TestOpenAICompatibleInferenceBot(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.original_openai_api_key = os.environ.get("OPENAI_API_KEY")
|
||||
self.original_azure_openai_key = os.environ.get("AZURE_OPENAI_KEY")
|
||||
self.original_azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
|
||||
self.original_api_version = os.environ.get("AZURE_OPENAI_API_VERSION")
|
||||
self.original_azure_deployment = os.environ.get("AZURE_DEPLOYMENT_NAME")
|
||||
|
||||
# Clear relevant env vars before each test
|
||||
for key in ["OPENAI_API_KEY", "AZURE_OPENAI_KEY", "AZURE_OPENAI_ENDPOINT",
|
||||
"AZURE_OPENAI_API_VERSION", "AZURE_DEPLOYMENT_NAME", "SYSTEM_PROMPT_PATH"]:
|
||||
if os.environ.get(key):
|
||||
del os.environ[key]
|
||||
|
||||
self.mock_openai_client_instance = MagicMock()
|
||||
self.mock_openai_client_instance.chat.completions.create = MagicMock()
|
||||
|
||||
def tearDown(self):
|
||||
# Restore environment variables
|
||||
if self.original_openai_api_key: os.environ["OPENAI_API_KEY"] = self.original_openai_api_key
|
||||
if self.original_azure_openai_key: os.environ["AZURE_OPENAI_KEY"] = self.original_azure_openai_key
|
||||
if self.original_azure_endpoint: os.environ["AZURE_OPENAI_ENDPOINT"] = self.original_azure_endpoint
|
||||
if self.original_api_version: os.environ["AZURE_OPENAI_API_VERSION"] = self.original_api_version
|
||||
if self.original_azure_deployment: os.environ["AZURE_DEPLOYMENT_NAME"] = self.original_azure_deployment
|
||||
|
||||
|
||||
@patch('openai.OpenAI')
|
||||
def test_init_with_openai_defaults(self, MockOpenAIConstructor):
|
||||
MockOpenAIConstructor.return_value = self.mock_openai_client_instance
|
||||
os.environ["OPENAI_API_KEY"] = "test_openai_key"
|
||||
|
||||
bot = ConcreteOpenAICompatibleBot(model_name="gpt-4")
|
||||
|
||||
MockOpenAIConstructor.assert_called_once_with(api_key="test_openai_key", base_url=None)
|
||||
self.assertEqual(bot.client, self.mock_openai_client_instance)
|
||||
self.assertEqual(bot.model, "gpt-4")
|
||||
self.assertEqual(bot.max_tokens, 1000) # Default from _configure_model_and_tokens
|
||||
self.assertEqual(bot.azure_openai, False)
|
||||
|
||||
@patch('openai.OpenAI')
|
||||
def test_init_with_provided_client(self, MockOpenAIConstructor):
|
||||
preconfigured_client = MagicMock()
|
||||
bot = ConcreteOpenAICompatibleBot(client=preconfigured_client, model_name="gpt-3.5")
|
||||
|
||||
MockOpenAIConstructor.assert_not_called()
|
||||
self.assertEqual(bot.client, preconfigured_client)
|
||||
self.assertEqual(bot.model, "gpt-3.5")
|
||||
|
||||
@patch('openai.AzureOpenAI')
|
||||
def test_init_with_azure_config_args(self, MockAzureOpenAIConstructor):
|
||||
MockAzureOpenAIConstructor.return_value = self.mock_openai_client_instance
|
||||
|
||||
bot = ConcreteOpenAICompatibleBot(
|
||||
api_key="azure_key",
|
||||
azure_endpoint="https://myenv.openai.azure.com",
|
||||
api_version="2023-05-15",
|
||||
azure_deployment="my-gpt-4", # This should be used as model_name for API call
|
||||
model_name="should_be_overridden_by_azure_deployment_for_api"
|
||||
# model_name is passed to _configure_model_and_tokens, which sets self.model for display/logging
|
||||
# but for Azure, the client needs the deployment name.
|
||||
)
|
||||
|
||||
MockAzureOpenAIConstructor.assert_called_once_with(
|
||||
api_key="azure_key",
|
||||
azure_endpoint="https://myenv.openai.azure.com",
|
||||
api_version="2023-05-15"
|
||||
)
|
||||
self.assertEqual(bot.client, self.mock_openai_client_instance)
|
||||
self.assertEqual(bot.model, "my-gpt-4") # Azure deployment name becomes the model for API calls
|
||||
self.assertEqual(bot.azure_openai, True)
|
||||
|
||||
|
||||
@patch('openai.AzureOpenAI')
|
||||
def test_init_with_azure_env_vars(self, MockAzureOpenAIConstructor):
|
||||
MockAzureOpenAIConstructor.return_value = self.mock_openai_client_instance
|
||||
os.environ["AZURE_OPENAI_KEY"] = "env_azure_key"
|
||||
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://env.openai.azure.com"
|
||||
os.environ["AZURE_OPENAI_API_VERSION"] = "2023-06-01"
|
||||
os.environ["AZURE_DEPLOYMENT_NAME"] = "env-gpt-35" # Used as model_name
|
||||
|
||||
bot = ConcreteOpenAICompatibleBot(model_name="ignored_if_azure_deployment_env_is_set")
|
||||
|
||||
MockAzureOpenAIConstructor.assert_called_once_with(
|
||||
api_key="env_azure_key",
|
||||
azure_endpoint="https://env.openai.azure.com",
|
||||
api_version="2023-06-01"
|
||||
)
|
||||
self.assertEqual(bot.model, "env-gpt-35")
|
||||
self.assertTrue(bot.azure_openai)
|
||||
|
||||
@patch('openai.OpenAI')
|
||||
def test_init_with_gemini_config_args(self, MockOpenAIConstructor):
|
||||
MockOpenAIConstructor.return_value = self.mock_openai_client_instance
|
||||
|
||||
bot = ConcreteOpenAICompatibleBot(
|
||||
api_key="gemini_key",
|
||||
base_url="https://gemini.example.com",
|
||||
model_name="gemini-pro",
|
||||
is_gemini=True
|
||||
)
|
||||
MockOpenAIConstructor.assert_called_once_with(api_key="gemini_key", base_url="https://gemini.example.com")
|
||||
self.assertEqual(bot.model, "gemini-pro")
|
||||
self.assertFalse(bot.azure_openai) # is_gemini doesn't mean azure_openai
|
||||
|
||||
def test_configure_model_and_tokens(self):
|
||||
bot = ConcreteOpenAICompatibleBot(model_name="initial_model") # init calls _configure
|
||||
bot._configure_model_and_tokens("test-model", "500")
|
||||
self.assertEqual(bot.model, "test-model")
|
||||
self.assertEqual(bot.max_tokens, 500)
|
||||
|
||||
bot._configure_model_and_tokens("test-model-2", None, default_max_tokens=150)
|
||||
self.assertEqual(bot.max_tokens, 150)
|
||||
|
||||
bot._configure_model_and_tokens("test-model-3", "invalid_token_val")
|
||||
self.assertEqual(bot.max_tokens, 1000) # Default fallback
|
||||
|
||||
def test_get_llm_description(self):
|
||||
bot = ConcreteOpenAICompatibleBot(model_name="desc-model", max_tokens_str="256")
|
||||
self.assertEqual(bot.get_llm_description(), "LLM: desc-model, Max Tokens: 256, Azure: False")
|
||||
|
||||
bot_azure = ConcreteOpenAICompatibleBot(azure_deployment="azure-model", azure_endpoint="x", api_key="y", api_version="z")
|
||||
self.assertEqual(bot_azure.get_llm_description(), "LLM: azure-model, Max Tokens: 1000, Azure: True")
|
||||
|
||||
|
||||
def test_get_chat_response_success(self):
|
||||
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="test-gpt")
|
||||
bot.max_tokens = 50 # Ensure this is set
|
||||
mock_api_response = create_mock_openai_response(content="Hello from API")
|
||||
self.mock_openai_client_instance.chat.completions.create.return_value = mock_api_response
|
||||
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
response = bot.get_chat_response(messages)
|
||||
|
||||
self.mock_openai_client_instance.chat.completions.create.assert_called_once_with(
|
||||
model="test-gpt",
|
||||
messages=messages,
|
||||
tools=ANY, # Assuming functions can be None or empty list
|
||||
tool_choice=ANY,
|
||||
max_tokens=50
|
||||
)
|
||||
self.assertEqual(response, mock_api_response)
|
||||
|
||||
def test_get_chat_response_api_error(self):
|
||||
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="error-gpt")
|
||||
self.mock_openai_client_instance.chat.completions.create.side_effect = Exception("API Down")
|
||||
|
||||
with self.assertRaisesRegex(Exception, "API Down"):
|
||||
bot.get_chat_response([{"role": "user", "content": "trigger"}])
|
||||
|
||||
async def test_handle_message_simple_response(self):
|
||||
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="chatty")
|
||||
bot.system_prompt = "You are a test bot." # Set directly for simplicity
|
||||
mock_api_response = create_mock_openai_response(content="Test reply")
|
||||
self.mock_openai_client_instance.chat.completions.create.return_value = mock_api_response
|
||||
|
||||
response_content = await bot.handle_message(user_id=1, user_message="Hello")
|
||||
|
||||
self.assertEqual(response_content, "Test reply")
|
||||
self.assertIn(1, bot.conversation_history)
|
||||
self.assertEqual(len(bot.conversation_history[1]), 3) # System, User, Assistant
|
||||
self.assertEqual(bot.conversation_history[1][0]["content"], "You are a test bot.")
|
||||
self.assertEqual(bot.conversation_history[1][2]["content"], "Test reply")
|
||||
|
||||
async def test_handle_message_with_tool_call_and_response(self):
|
||||
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="tool-user")
|
||||
|
||||
# Mock functions/tools setup on the bot
|
||||
mock_tool_def = {"function": {"name": "get_weather", "description": "Gets weather", "parameters": {}}}
|
||||
bot.functions = [mock_tool_def] # Simulate tools are loaded
|
||||
|
||||
# API response 1: Request to call tool
|
||||
tool_call_request = [{"id": "call123", "function": {"name": "get_weather", "arguments": '''{"location": "moon"}'''}}]
|
||||
api_response_1 = create_mock_openai_response(tool_calls=tool_call_request)
|
||||
|
||||
# API response 2: Final answer after tool execution
|
||||
api_response_2 = create_mock_openai_response(content="The weather on the moon is chilly.")
|
||||
|
||||
self.mock_openai_client_instance.chat.completions.create.side_effect = [api_response_1, api_response_2]
|
||||
|
||||
# Mock self.call_tool
|
||||
bot.call_tool = MagicMock(return_value='''{"temperature": "-100 C"}''')
|
||||
|
||||
final_response = await bot.handle_message(user_id=2, user_message="Weather on moon?")
|
||||
|
||||
self.assertEqual(final_response, "The weather on the moon is chilly.")
|
||||
bot.call_tool.assert_called_once_with("get_weather", '''{"location": "moon"}''')
|
||||
|
||||
# Check conversation history includes tool messages
|
||||
history = bot.conversation_history[2]
|
||||
self.assertTrue(any(msg["role"] == "assistant" and msg.tool_calls is not None for msg in history))
|
||||
self.assertTrue(any(msg["role"] == "tool" and msg["name"] == "get_weather" for msg in history))
|
||||
self.assertEqual(self.mock_openai_client_instance.chat.completions.create.call_count, 2)
|
||||
|
||||
async def test_handle_message_max_history_length(self):
|
||||
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="hist-test", max_history_length=3)
|
||||
self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Ok")
|
||||
|
||||
await bot.handle_message(1, "Msg1") # Sys, User, Assist (3)
|
||||
self.assertEqual(len(bot.conversation_history[1]), 3)
|
||||
|
||||
await bot.handle_message(1, "Msg2") # User, Assist. Should be 3 (prev User, prev Assist, new User) -> then adds new Assist.
|
||||
# Before new call: [Sys, U1, A1]. New U2. Call with [Sys,U1,A1,U2]. Resp A2.
|
||||
# History: [Sys,U1,A1,U2,A2]. Limit 3. -> [A1,U2,A2] (if system is not preserved specially)
|
||||
# The current code appends to history then truncates if over limit.
|
||||
# So after Msg1: [S, U1, A1]. len=3.
|
||||
# For Msg2: History is [S, U1, A1]. Append U2. Call with [S,U1,A1,U2]. Append A2.
|
||||
# History now [S,U1,A1,U2,A2]. len=5. Truncate to 3.
|
||||
# Expected: [A1, U2, A2] or [U1,A1,U2] or [U2,A2,S] depending on how system prompt is handled in truncation.
|
||||
# The code is: self.conversation_history[user_id][-self.max_history_length:]
|
||||
# And system prompt is only added IF user_id not in self.conversation_history.
|
||||
# So, for Msg2, system prompt is not re-added.
|
||||
# History before Msg2 call: [S, U1, A1]
|
||||
# Messages for Msg2 call: [S, U1, A1, U2]
|
||||
# History after Msg2 response A2: [S, U1, A1, U2, A2]. Len 5.
|
||||
# Truncated to self.max_history_length=3: [A1, U2, A2]
|
||||
|
||||
# Call 1
|
||||
self.mock_openai_client_instance.chat.completions.create.reset_mock()
|
||||
self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply1")
|
||||
await bot.handle_message(user_id=7, user_message="First message")
|
||||
self.assertEqual(len(bot.conversation_history[7]), 3) # System, User1, Assistant1
|
||||
|
||||
# Call 2
|
||||
self.mock_openai_client_instance.chat.completions.create.reset_mock()
|
||||
self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply2")
|
||||
await bot.handle_message(user_id=7, user_message="Second message")
|
||||
# History before call: [S, U1, A1]. Messages for call: [S, U1, A1, U2]. History after: [S, U1, A1, U2, A2].
|
||||
# Truncated to 3: [A1, U2, A2]
|
||||
self.assertEqual(len(bot.conversation_history[7]), 3)
|
||||
self.assertEqual(bot.conversation_history[7][0]["content"], "Reply1") # A1
|
||||
self.assertEqual(bot.conversation_history[7][1]["content"], "Second message") # U2
|
||||
self.assertEqual(bot.conversation_history[7][2]["content"], "Reply2") # A2
|
||||
|
||||
# Call 3
|
||||
self.mock_openai_client_instance.chat.completions.create.reset_mock()
|
||||
self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply3")
|
||||
await bot.handle_message(user_id=7, user_message="Third message")
|
||||
# History before call: [A1, U2, A2]. Messages for call: [A1, U2, A2, U3]. History after: [A1, U2, A2, U3, A3].
|
||||
# Truncated to 3: [A2, U3, A3]
|
||||
self.assertEqual(len(bot.conversation_history[7]), 3)
|
||||
self.assertEqual(bot.conversation_history[7][0]["content"], "Reply2") # A2
|
||||
self.assertEqual(bot.conversation_history[7][1]["content"], "Third message") # U3
|
||||
self.assertEqual(bot.conversation_history[7][2]["content"], "Reply3") # A3
|
||||
|
||||
|
||||
async def test_abort_processing(self):
|
||||
bot = ConcreteOpenAICompatibleBot(model_name="test")
|
||||
user_id = 123
|
||||
bot.processing_status[user_id] = {"processing": True, "message_id": 456}
|
||||
bot.conversation_history[user_id] = [{"role": "user", "content": "stuff"}]
|
||||
|
||||
with patch.object(bot, 'clear_conversation_history') as mock_clear_hist: # Patching the method from Base class
|
||||
result = await bot.abort_processing(user_id)
|
||||
|
||||
self.assertEqual(result, "Processing aborted and conversation cleared.")
|
||||
self.assertFalse(bot.processing_status[user_id]["processing"])
|
||||
mock_clear_hist.assert_called_once_with(user_id)
|
||||
|
||||
async def test_abort_processing_no_active_processing(self):
|
||||
bot = ConcreteOpenAICompatibleBot(model_name="test")
|
||||
user_id = 404 # Not in processing_status
|
||||
with patch.object(bot, 'clear_conversation_history') as mock_clear_hist:
|
||||
result = await bot.abort_processing(user_id)
|
||||
self.assertEqual(result, "No active processing found to abort. Conversation cleared.")
|
||||
mock_clear_hist.assert_called_once_with(user_id)
|
||||
|
||||
# Test for the abstract switch_model (basic call, actual logic in concrete class for this test)
|
||||
async def test_switch_model_concrete_implementation(self):
|
||||
bot = ConcreteOpenAICompatibleBot(model_name="model1", small_model_name="model1", large_model_name="model2", max_tokens_str="100")
|
||||
self.assertEqual(bot.model, "model1")
|
||||
await bot.switch_model() # Calls the concrete implementation
|
||||
self.assertEqual(bot.model, "model2")
|
||||
await bot.switch_model()
|
||||
self.assertEqual(bot.model, "model1")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -0,0 +1,356 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch, mock_open, AsyncMock
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Assuming telegram_helper.py is in the parent directory or PYTHONPATH is set
|
||||
from telegram_helper import TelegramHelper, MessageHandlerLogicResult
|
||||
|
||||
# Mock for the bot passed to TelegramHelper
|
||||
class MockBot:
|
||||
def __init__(self):
|
||||
self.start = AsyncMock()
|
||||
self.clear_conversation_history = MagicMock()
|
||||
self.get_bot_status = AsyncMock(return_value="Bot Status OK")
|
||||
self.switch_model = AsyncMock(return_value="Model Switched OK")
|
||||
self.handle_message = AsyncMock() # Needs to return a string
|
||||
self.abort_processing = AsyncMock(return_value="Abort OK")
|
||||
self.set_processing_status = MagicMock()
|
||||
self.clear_processing_status = MagicMock()
|
||||
self.processing_status = {} # Add the attribute
|
||||
|
||||
# Mock for telegram.Update and related objects
|
||||
def create_mock_update(message_text=None, user_id=123, chat_id=456, message_id=789, callback_query_data=None):
|
||||
update = MagicMock()
|
||||
update.effective_user.id = user_id
|
||||
update.effective_chat.id = chat_id
|
||||
|
||||
if message_text:
|
||||
update.message.text = message_text
|
||||
update.message.reply_text = AsyncMock(return_value=MagicMock(message_id=message_id)) # reply_text returns a Message obj
|
||||
|
||||
if callback_query_data:
|
||||
update.callback_query.data = callback_query_data
|
||||
update.callback_query.from_user.id = user_id
|
||||
update.callback_query.answer = AsyncMock()
|
||||
update.callback_query.edit_message_text = AsyncMock()
|
||||
|
||||
return update
|
||||
|
||||
# Mock for telegram.ext.ContextTypes.DEFAULT_TYPE
|
||||
def create_mock_context():
|
||||
context = MagicMock()
|
||||
context.bot.delete_message = AsyncMock()
|
||||
context.bot.edit_message_text = AsyncMock() # For update_status_message
|
||||
return context
|
||||
|
||||
class TestTelegramHelper(unittest.IsolatedAsyncioTestCase): # Use IsolatedAsyncioTestCase for async methods
|
||||
|
||||
def setUp(self):
|
||||
self.mock_bot = MockBot()
|
||||
# Default paths for reboot files, can be overridden in tests
|
||||
self.reboot_claude_file = ".test_reboot_claude"
|
||||
self.reboot_file = ".test_doreboot"
|
||||
self.helper = TelegramHelper(
|
||||
self.mock_bot,
|
||||
reboot_claude_file_path=self.reboot_claude_file,
|
||||
reboot_file_path=self.reboot_file,
|
||||
chunk_message_sleep_duration=0.001 # Faster sleep for tests
|
||||
)
|
||||
# Clean up any potential leftover reboot files from previous runs
|
||||
if os.path.exists(self.reboot_claude_file):
|
||||
os.remove(self.reboot_claude_file)
|
||||
if os.path.exists(self.reboot_file):
|
||||
os.remove(self.reboot_file)
|
||||
|
||||
def tearDown(self):
|
||||
# Clean up reboot files created during tests
|
||||
if os.path.exists(self.reboot_claude_file):
|
||||
os.remove(self.reboot_claude_file)
|
||||
if os.path.exists(self.reboot_file):
|
||||
os.remove(self.reboot_file)
|
||||
|
||||
async def test_start_logic(self):
|
||||
response = await self.helper._start_logic()
|
||||
self.mock_bot.start.assert_called_once()
|
||||
self.assertEqual(response, "Hello! I\'m your AI assistant. How can I help you today?")
|
||||
|
||||
async def test_start_command(self):
|
||||
mock_update = create_mock_update(message_text="/start")
|
||||
mock_context = create_mock_context()
|
||||
|
||||
with patch.object(self.helper, \'_start_logic\', new_callable=AsyncMock) as mock_logic:
|
||||
mock_logic.return_value = "Start Logic Response"
|
||||
await self.helper.start(mock_update, mock_context)
|
||||
mock_logic.assert_called_once()
|
||||
mock_update.message.reply_text.assert_called_once_with("Start Logic Response")
|
||||
|
||||
async def test_clear_logic(self):
|
||||
user_id = 123
|
||||
response = await self.helper._clear_logic(user_id) # _clear_logic is async after refactor
|
||||
self.mock_bot.clear_conversation_history.assert_called_once_with(user_id)
|
||||
self.assertEqual(response, "Conversation history cleared. Let\'s start fresh!")
|
||||
|
||||
async def test_clear_command(self):
|
||||
mock_update = create_mock_update(message_text="/clear", user_id=123)
|
||||
mock_context = create_mock_context()
|
||||
with patch.object(self.helper, \'_clear_logic\', new_callable=AsyncMock) as mock_logic:
|
||||
mock_logic.return_value = "Clear Logic Response"
|
||||
await self.helper.clear(mock_update, mock_context)
|
||||
mock_logic.assert_called_once_with(123)
|
||||
mock_update.message.reply_text.assert_called_once_with("Clear Logic Response")
|
||||
|
||||
async def test_status_logic(self):
|
||||
self.mock_bot.get_bot_status.return_value = "Test Status"
|
||||
response = await self.helper._status_logic()
|
||||
self.mock_bot.get_bot_status.assert_called_once()
|
||||
self.assertEqual(response, "Test Status")
|
||||
|
||||
async def test_switch_logic_supported(self):
|
||||
self.mock_bot.switch_model.return_value = "Switched to Large Model"
|
||||
response = await self.helper._switch_logic()
|
||||
self.mock_bot.switch_model.assert_called_once()
|
||||
self.assertEqual(response, "Switched to Large Model")
|
||||
|
||||
async def test_switch_logic_not_supported(self):
|
||||
del self.mock_bot.switch_model # Simulate bot not having the attribute
|
||||
response = await self.helper._switch_logic()
|
||||
self.assertEqual(response, "Model switching is not supported for this bot.")
|
||||
|
||||
async def test_handle_message_logic_success(self):
|
||||
user_id = 100
|
||||
user_message = "Hello bot"
|
||||
bot_response = "Hello user <think>Thinking hard</think> Done."
|
||||
expected_processed_response = f"Hello user {self.helper.HTML_QUOTE_BLOCK_START}Thinking hard{self.helper.HTML_QUOTE_BLOCK_END} Done."
|
||||
self.mock_bot.handle_message.return_value = bot_response
|
||||
|
||||
result = await self.helper._handle_message_logic(user_id, user_message)
|
||||
|
||||
self.mock_bot.handle_message.assert_called_once_with(user_id, user_message)
|
||||
self.assertTrue(result["success"])
|
||||
self.assertEqual(result["response_text"], expected_processed_response)
|
||||
self.assertIsNone(result["error_message"])
|
||||
|
||||
async def test_handle_message_logic_bot_exception(self):
|
||||
user_id = 101
|
||||
user_message = "Trigger error"
|
||||
self.mock_bot.handle_message.side_effect = Exception("Bot Error")
|
||||
|
||||
result = await self.helper._handle_message_logic(user_id, user_message)
|
||||
|
||||
self.assertFalse(result["success"])
|
||||
self.assertIsNone(result["response_text"])
|
||||
self.assertEqual(result["error_message"], "Bot Error")
|
||||
|
||||
@patch(\'logging.error\')
|
||||
async def test_handle_message_command_success_short_message(self, mock_logging_error):
|
||||
mock_update = create_mock_update(message_text="Hi", user_id=200, chat_id=201, message_id=202)
|
||||
mock_context = create_mock_context()
|
||||
|
||||
logic_result = MessageHandlerLogicResult(success=True, response_text="Short response", error_message=None)
|
||||
|
||||
with patch.object(self.helper, \'_handle_message_logic\', new_callable=AsyncMock) as mock_message_logic:
|
||||
mock_message_logic.return_value = logic_result
|
||||
|
||||
await self.helper.handle_message(mock_update, mock_context)
|
||||
|
||||
mock_update.message.reply_text.assert_any_call("Processing your request...", reply_markup=unittest.mock.ANY)
|
||||
self.mock_bot.set_processing_status.assert_called_once_with(200, 202) # user_id, status_message_id
|
||||
mock_message_logic.assert_called_once_with(200, "Hi")
|
||||
mock_context.bot.delete_message.assert_called_once_with(chat_id=201, message_id=202)
|
||||
self.mock_bot.clear_processing_status.assert_called_once_with(200)
|
||||
mock_update.message.reply_text.assert_any_call("Short response") # Final response
|
||||
self.assertEqual(mock_update.message.reply_text.call_count, 2) # Processing + final
|
||||
|
||||
@patch(\'logging.error\')
|
||||
async def test_handle_message_command_success_long_message_chunks(self, mock_logging_error):
|
||||
mock_update = create_mock_update(message_text="Long text", user_id=200, chat_id=201, message_id=202)
|
||||
mock_context = create_mock_context()
|
||||
|
||||
long_response_text = "a" * 5000 # Longer than 4096
|
||||
chunk1 = long_response_text[:4096]
|
||||
chunk2 = long_response_text[4096:]
|
||||
|
||||
logic_result = MessageHandlerLogicResult(success=True, response_text=long_response_text, error_message=None)
|
||||
|
||||
with patch.object(self.helper, \'_handle_message_logic\', new_callable=AsyncMock) as mock_message_logic, \
|
||||
patch(\'asyncio.sleep\', new_callable=AsyncMock) as mock_sleep: # Mock sleep
|
||||
mock_message_logic.return_value = logic_result
|
||||
|
||||
await self.helper.handle_message(mock_update, mock_context)
|
||||
|
||||
mock_update.message.reply_text.assert_any_call(chunk1)
|
||||
mock_update.message.reply_text.assert_any_call(chunk2)
|
||||
mock_sleep.assert_called_once_with(self.helper.chunk_message_sleep_duration)
|
||||
self.assertEqual(mock_update.message.reply_text.call_count, 3) # Processing + 2 chunks
|
||||
|
||||
@patch(\'logging.error\')
|
||||
async def test_handle_message_command_logic_fails(self, mock_logging_error):
|
||||
mock_update = create_mock_update(message_text="Cause error in logic", user_id=200)
|
||||
mock_context = create_mock_context()
|
||||
logic_result = MessageHandlerLogicResult(success=False, response_text=None, error_message="Logic Failed")
|
||||
|
||||
with patch.object(self.helper, \'_handle_message_logic\', new_callable=AsyncMock) as mock_message_logic:
|
||||
mock_message_logic.return_value = logic_result
|
||||
await self.helper.handle_message(mock_update, mock_context)
|
||||
mock_update.message.reply_text.assert_any_call("Sorry, an error occurred while processing your request.")
|
||||
self.assertEqual(mock_update.message.reply_text.call_count, 2) # Processing + error message
|
||||
|
||||
@patch(\'logging.error\')
|
||||
async def test_handle_message_command_telegram_exception_after_logic(self, mock_logging_error):
|
||||
mock_update = create_mock_update(message_text="Test", user_id=200)
|
||||
mock_context = create_mock_context()
|
||||
logic_result = MessageHandlerLogicResult(success=True, response_text="OK", error_message=None)
|
||||
|
||||
# Make sending the final reply fail
|
||||
mock_update.message.reply_text.side_effect = [
|
||||
MagicMock(message_id=202), # For "Processing..."
|
||||
Exception("Telegram API Error") # For the actual response
|
||||
]
|
||||
|
||||
with patch.object(self.helper, \'_handle_message_logic\', new_callable=AsyncMock) as mock_message_logic:
|
||||
mock_message_logic.return_value = logic_result
|
||||
await self.helper.handle_message(mock_update, mock_context)
|
||||
|
||||
# Check if the generic error message was attempted
|
||||
# This is tricky because reply_text is already mocked with side_effect.
|
||||
# We\'d expect logs. Let\'s check logs or if processing status was cleared.
|
||||
self.mock_bot.clear_processing_status.assert_called_once_with(200)
|
||||
mock_logging_error.assert_any_call(unittest.mock.string_containing("Outer error in handle_message"))
|
||||
|
||||
|
||||
async def test_abort_processing_logic(self):
|
||||
user_id = 300
|
||||
self.mock_bot.abort_processing.return_value = "Aborted by bot"
|
||||
response = await self.helper._abort_processing_logic(user_id)
|
||||
self.mock_bot.abort_processing.assert_called_once_with(user_id)
|
||||
self.assertEqual(response, "Aborted by bot")
|
||||
|
||||
async def test_abort_processing_command(self):
|
||||
mock_update = create_mock_update(callback_query_data=\'abort\', user_id=301)
|
||||
mock_context = create_mock_context()
|
||||
with patch.object(self.helper, \'_abort_processing_logic\', new_callable=AsyncMock) as mock_logic:
|
||||
mock_logic.return_value = "Abort Logic Done"
|
||||
await self.helper.abort_processing(mock_update, mock_context)
|
||||
|
||||
mock_update.callback_query.answer.assert_called_once()
|
||||
mock_logic.assert_called_once_with(301)
|
||||
mock_update.callback_query.edit_message_text.assert_called_once_with(text="Abort Logic Done")
|
||||
|
||||
def test_reboot_logic_claude_and_main(self):
|
||||
user_message_parts = ["/reboot", "claude"]
|
||||
chat_id_to_write = "12345"
|
||||
|
||||
with patch("builtins.open", mock_open()) as mock_file:
|
||||
self.helper._reboot_logic(user_message_parts, chat_id_to_write)
|
||||
|
||||
# Check claude reboot file
|
||||
mock_file.assert_any_call(self.reboot_claude_file, \'w\')
|
||||
# Check main doreboot file
|
||||
mock_file.assert_any_call(self.reboot_file, \'w\')
|
||||
handle_claude = mock_file.return_value
|
||||
handle_main = mock_file.return_value # mock_open reuses the handle for multiple calls
|
||||
|
||||
# Check if write was called for claude file (empty write)
|
||||
# This part of assertion is tricky with single mock_file. Better to use different mocks if possible
|
||||
# or check the sequence of calls if the mock supports it well.
|
||||
# For now, assert_any_call ensures it was opened.
|
||||
|
||||
# Check content for main reboot file
|
||||
# Need to ensure the write for self.reboot_file had chat_id_to_write
|
||||
# This requires more sophisticated mock_open or patching os.path.exists and multiple open calls
|
||||
# Simpler check: was open(self.reboot_file, \'w\') called? Yes, via assert_any_call.
|
||||
# And was open(self.reboot_claude_file, \'w\') called? Yes.
|
||||
|
||||
# Verify files were created (mock_open doesn\'t actually create them)
|
||||
# This test relies on mock_open\'s behavior. To test file content, need more setup.
|
||||
# For now, assume open was called correctly.
|
||||
|
||||
def test_reboot_logic_main_only(self):
|
||||
user_message_parts = ["/reboot"]
|
||||
chat_id_to_write = "67890"
|
||||
with patch("builtins.open", mock_open()) as mock_file:
|
||||
self.helper._reboot_logic(user_message_parts, chat_id_to_write)
|
||||
# Ensure claude file was NOT opened for writing.
|
||||
# This requires asserting that a specific call didn\'t happen, or checking call_args_list
|
||||
claude_call = unittest.mock.call(self.reboot_claude_file, \'w\')
|
||||
self.assertNotIn(claude_call, mock_file.call_args_list)
|
||||
|
||||
mock_file.assert_any_call(self.reboot_file, \'w\')
|
||||
|
||||
@patch(\'sys.exit\') # Mock sys.exit to prevent test runner from exiting
|
||||
async def test_reboot_command(self, mock_sys_exit):
|
||||
mock_update = create_mock_update(message_text="/reboot claude", chat_id="chat1")
|
||||
mock_context = create_mock_context()
|
||||
|
||||
with patch.object(self.helper, \'_reboot_logic\') as mock_reboot_file_logic:
|
||||
await self.helper.reboot(mock_update, mock_context)
|
||||
|
||||
mock_reboot_file_logic.assert_called_once_with(["/reboot", "claude"], "chat1")
|
||||
mock_update.message.reply_text.assert_called_once_with("Rebooting the bot...")
|
||||
mock_sys_exit.assert_called_once_with(0)
|
||||
|
||||
@patch(\'os.path.exists\')
|
||||
@patch(\'builtins.open\', new_callable=mock_open)
|
||||
@patch(\'os.remove\')
|
||||
async def test_check_doreboot_file_logic_file_exists(self, mock_os_remove, mock_file_open, mock_os_path_exists):
|
||||
mock_os_path_exists.return_value = True
|
||||
mock_file_open.return_value.read.return_value.strip.return_value = "chat123"
|
||||
|
||||
chat_id = await self.helper._check_doreboot_file_logic()
|
||||
|
||||
mock_os_path_exists.assert_called_once_with(self.reboot_file)
|
||||
mock_file_open.assert_called_once_with(self.reboot_file, \'r\')
|
||||
mock_os_remove.assert_called_once_with(self.reboot_file)
|
||||
self.assertEqual(chat_id, "chat123")
|
||||
|
||||
@patch(\'os.path.exists\', return_value=False)
|
||||
async def test_check_doreboot_file_logic_file_not_exists(self, mock_os_path_exists):
|
||||
chat_id = await self.helper._check_doreboot_file_logic()
|
||||
mock_os_path_exists.assert_called_once_with(self.reboot_file)
|
||||
self.assertIsNone(chat_id)
|
||||
|
||||
@patch(\'logging.error\')
|
||||
@patch(\'os.path.exists\', return_value=True)
|
||||
@patch(\'builtins.open\', side_effect=IOError("Read error"))
|
||||
@patch(\'os.remove\') # To check if remove is called even on read error
|
||||
async def test_check_doreboot_file_logic_read_error(self, mock_os_remove, mock_file_open, mock_os_path_exists, mock_log_error):
|
||||
chat_id = await self.helper._check_doreboot_file_logic()
|
||||
|
||||
self.assertIsNone(chat_id)
|
||||
mock_log_error.assert_any_call(unittest.mock.string_containing("Error reading reboot file"))
|
||||
# Check if os.remove was attempted even after read error
|
||||
mock_os_remove.assert_called_once_with(self.reboot_file)
|
||||
|
||||
|
||||
async def test_check_doreboot_file_command_sends_message(self):
|
||||
mock_application = MagicMock()
|
||||
mock_application.bot.send_message = AsyncMock()
|
||||
|
||||
with patch.object(self.helper, \'_check_doreboot_file_logic\', new_callable=AsyncMock) as mock_logic:
|
||||
mock_logic.return_value = "chat789" # Simulate chat_id found
|
||||
await self.helper.check_doreboot_file(mock_application)
|
||||
|
||||
mock_logic.assert_called_once()
|
||||
mock_application.bot.send_message.assert_called_once_with(
|
||||
chat_id="chat789", text="The application has finished initializing."
|
||||
)
|
||||
|
||||
async def test_check_doreboot_file_command_no_chat_id(self):
|
||||
mock_application = MagicMock()
|
||||
mock_application.bot.send_message = AsyncMock()
|
||||
|
||||
with patch.object(self.helper, \'_check_doreboot_file_logic\', new_callable=AsyncMock) as mock_logic:
|
||||
mock_logic.return_value = None # Simulate no chat_id found
|
||||
await self.helper.check_doreboot_file(mock_application)
|
||||
|
||||
mock_logic.assert_called_once()
|
||||
mock_application.bot.send_message.assert_not_called()
|
||||
|
||||
# Note: Testing the run() method itself is more of an integration test,
|
||||
# as it involves setting up the full Application and polling loop.
|
||||
# Unit tests here focus on the helper\'s own logic methods.
|
||||
|
||||
if __name__ == \'__main__\':
|
||||
unittest.main()
|
||||
@@ -0,0 +1,307 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
import os
|
||||
import base64
|
||||
import logging
|
||||
import requests # Required for spec in MagicMock
|
||||
|
||||
# Ensure tools/github_tool.py is accessible
|
||||
from tools.github_tool import GitHubTool
|
||||
|
||||
# Helper to create a mock response for requests.Session
|
||||
def create_mock_response(status_code, json_data=None, text_data=None, headers=None, links=None):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = status_code
|
||||
if json_data is not None:
|
||||
mock_resp.json = MagicMock(return_value=json_data)
|
||||
mock_resp.text = text_data if text_data is not None else str(json_data)
|
||||
mock_resp.headers = headers if headers else {}
|
||||
mock_resp.links = links if links else {} # For pagination in _list_branches
|
||||
return mock_resp
|
||||
|
||||
class TestGitHubTool(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.mock_session = MagicMock(spec=requests.Session)
|
||||
self.mock_session.headers = {} # Simulate a new session's headers
|
||||
|
||||
self.test_token = "test_github_token"
|
||||
self.test_repo = "owner/repo"
|
||||
self.test_base_url = "https://api.example.com" # Use a non-default base_url for some tests
|
||||
|
||||
# Suppress logging output during tests unless explicitly testing for it
|
||||
self.logger = logging.getLogger('tools.github_tool')
|
||||
# Ensure only one NullHandler to prevent duplicate messages if tests run multiple times in a session
|
||||
if not any(isinstance(h, logging.NullHandler) for h in self.logger.handlers):
|
||||
self.logger.addHandler(logging.NullHandler())
|
||||
self.logger.propagate = False # Prevent propagation to root logger if it has handlers
|
||||
|
||||
def test_init_with_args_and_session(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, base_url=self.test_base_url, logger=self.logger)
|
||||
self.assertEqual(tool.session, self.mock_session)
|
||||
self.assertEqual(tool._token, self.test_token)
|
||||
self.assertEqual(tool._repo, self.test_repo)
|
||||
self.assertEqual(tool.base_url, self.test_base_url)
|
||||
self.assertEqual(tool.current_branch, "main") # Default initial branch
|
||||
|
||||
@patch('requests.Session')
|
||||
def test_init_creates_session_if_not_provided(self, MockSessionConstructor):
|
||||
mock_created_session = MagicMock(spec=requests.Session)
|
||||
mock_created_session.headers = {}
|
||||
MockSessionConstructor.return_value = mock_created_session
|
||||
|
||||
# Temporarily set env vars for this test
|
||||
original_token = os.environ.get("GITHUB_TOKEN")
|
||||
original_repo = os.environ.get("GITHUB_REPOSITORY")
|
||||
os.environ["GITHUB_TOKEN"] = "env_token"
|
||||
os.environ["GITHUB_REPOSITORY"] = "env/repo"
|
||||
|
||||
tool = GitHubTool(logger=self.logger) # Use env vars
|
||||
|
||||
MockSessionConstructor.assert_called_once()
|
||||
self.assertEqual(tool.session, mock_created_session)
|
||||
self.assertEqual(tool._token, "env_token")
|
||||
self.assertEqual(tool._repo, "env/repo")
|
||||
self.assertIn("Authorization", mock_created_session.headers)
|
||||
self.assertEqual(mock_created_session.headers["Authorization"], "token env_token")
|
||||
|
||||
# Restore original env vars
|
||||
if original_token is None: del os.environ["GITHUB_TOKEN"]
|
||||
else: os.environ["GITHUB_TOKEN"] = original_token
|
||||
if original_repo is None: del os.environ["GITHUB_REPOSITORY"]
|
||||
else: os.environ["GITHUB_REPOSITORY"] = original_repo
|
||||
|
||||
def test_init_raises_value_error_if_no_token(self):
|
||||
original_token = os.environ.pop("GITHUB_TOKEN", None)
|
||||
with self.assertRaisesRegex(ValueError, "GitHub token must be provided"):
|
||||
GitHubTool(repo=self.test_repo, logger=self.logger)
|
||||
if original_token: os.environ["GITHUB_TOKEN"] = original_token
|
||||
|
||||
def test_init_raises_value_error_if_no_repo(self):
|
||||
original_repo = os.environ.pop("GITHUB_REPOSITORY", None)
|
||||
with self.assertRaisesRegex(ValueError, "GitHub repository.*must be provided"):
|
||||
GitHubTool(token=self.test_token, logger=self.logger)
|
||||
if original_repo: os.environ["GITHUB_REPOSITORY"] = original_repo
|
||||
|
||||
def test_clear_resets_branch(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, initial_branch="feature-branch", logger=self.logger)
|
||||
# Mock _get_branch_sha for _set_current_branch called by clear
|
||||
with patch.object(tool, '_get_branch_sha', return_value="sha_for_main"):
|
||||
tool.clear()
|
||||
self.assertEqual(tool.current_branch, "main")
|
||||
|
||||
def test_get_functions_returns_list(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
functions = tool.get_functions()
|
||||
self.assertIsInstance(functions, list)
|
||||
self.assertTrue(len(functions) > 0)
|
||||
self.assertIn("name", functions[0]["function"])
|
||||
|
||||
|
||||
# --- Test individual private methods ---
|
||||
|
||||
def test_read_file_success(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
file_content = "Hello World!"
|
||||
encoded_content = base64.b64encode(file_content.encode('utf-8')).decode('utf-8')
|
||||
self.mock_session.get.return_value = create_mock_response(200, json_data={"content": encoded_content})
|
||||
|
||||
result = tool._read_file(path="test.txt")
|
||||
self.assertEqual(result, file_content)
|
||||
self.mock_session.get.assert_called_once_with(
|
||||
f"{tool.base_url}/repos/{self.test_repo}/contents/test.txt",
|
||||
params={"ref": "main"}
|
||||
)
|
||||
|
||||
def test_read_file_error(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
self.mock_session.get.return_value = create_mock_response(404, text_data="Not Found")
|
||||
result = tool._read_file(path="nonexistent.txt")
|
||||
self.assertIn("Error reading file", result)
|
||||
|
||||
def test_create_branch_success(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
# Mock getting base branch SHA
|
||||
self.mock_session.get.return_value = create_mock_response(200, json_data={"object": {"sha": "base_sha123"}})
|
||||
# Mock creating new branch
|
||||
self.mock_session.post.return_value = create_mock_response(201, json_data={"ref": "refs/heads/new-feature"})
|
||||
|
||||
result = tool._create_branch(branch_name="new-feature", base_branch="main")
|
||||
self.assertIn("Branch 'new-feature' created successfully", result)
|
||||
self.assertEqual(tool.current_branch, "new-feature")
|
||||
self.mock_session.get.assert_called_once() # For base branch SHA
|
||||
self.mock_session.post.assert_called_once() # For creating branch
|
||||
|
||||
def test_create_branch_base_sha_error(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
self.mock_session.get.return_value = create_mock_response(404, text_data="Base branch not found")
|
||||
result = tool._create_branch(branch_name="new-feature", base_branch="nonexistent-base")
|
||||
self.assertIn("Error getting base branch SHA", result)
|
||||
|
||||
def test_create_branch_creation_error(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
self.mock_session.get.return_value = create_mock_response(200, json_data={"object": {"sha": "base_sha456"}})
|
||||
self.mock_session.post.return_value = create_mock_response(422, text_data="Validation failed")
|
||||
result = tool._create_branch(branch_name="bad-branch", base_branch="main")
|
||||
self.assertIn("Error creating branch", result)
|
||||
|
||||
def test_commit_file_success_new_file(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
tool.current_branch = "dev-branch" # Cannot commit to main by default
|
||||
|
||||
# Mock GET for checking file existence (404 means new file)
|
||||
self.mock_session.get.return_value = create_mock_response(404)
|
||||
# Mock PUT for committing file
|
||||
self.mock_session.put.return_value = create_mock_response(201, json_data={"commit": {"sha": "commit_sha_abc"}})
|
||||
|
||||
result = tool._commit_file(file_path="new_file.py", content="print('Hello')", commit_message="Add new_file.py")
|
||||
self.assertIn("committed successfully", result)
|
||||
self.assertIn("commit_sha_abc", result)
|
||||
self.mock_session.get.assert_called_once() # Check file existence
|
||||
self.mock_session.put.assert_called_once() # Commit file
|
||||
|
||||
def test_commit_file_success_update_file(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
tool.current_branch = "dev-branch"
|
||||
|
||||
# Mock GET for checking file existence (200 means existing file)
|
||||
self.mock_session.get.return_value = create_mock_response(200, json_data={"sha": "existing_file_sha"})
|
||||
# Mock PUT for committing file
|
||||
self.mock_session.put.return_value = create_mock_response(200, json_data={"commit": {"sha": "commit_sha_def"}})
|
||||
|
||||
result = tool._commit_file(file_path="existing_file.txt", content="Updated content", commit_message="Update existing_file.txt")
|
||||
self.assertIn("committed successfully", result)
|
||||
self.assertIn("commit_sha_def", result)
|
||||
args, kwargs = self.mock_session.put.call_args
|
||||
self.assertEqual(kwargs['json']['sha'], "existing_file_sha")
|
||||
|
||||
|
||||
def test_commit_file_to_main_branch_error(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
tool.current_branch = "main"
|
||||
result = tool._commit_file(file_path="some.txt", content="content", commit_message="msg")
|
||||
self.assertIn("Action directly to main branch is not allowed", result)
|
||||
|
||||
def test_create_pull_request_success(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
tool.current_branch = "feature-pr"
|
||||
pr_url = "https://example.com/pull/1"
|
||||
self.mock_session.post.return_value = create_mock_response(201, json_data={"html_url": pr_url, "number": 1})
|
||||
|
||||
result = tool._create_pull_request(title="New Feature PR", body="Please review.", base="main")
|
||||
self.assertIn(f"Pull request created successfully: {pr_url}", result)
|
||||
self.mock_session.post.assert_called_once()
|
||||
call_data = self.mock_session.post.call_args[1]['json']
|
||||
self.assertEqual(call_data['head'], "feature-pr")
|
||||
self.assertEqual(call_data['base'], "main")
|
||||
|
||||
def test_create_pull_request_same_branch_error(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
tool.current_branch = "main"
|
||||
result = tool._create_pull_request(title="PR to self", body="This should fail", base="main")
|
||||
self.assertIn("Cannot create a pull request from branch 'main' to itself", result)
|
||||
|
||||
|
||||
def test_list_files_success(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
mock_items = [
|
||||
{"name": "file1.txt", "type": "file", "path": "dir/file1.txt"},
|
||||
{"name": "subdir", "type": "dir", "path": "dir/subdir"}
|
||||
]
|
||||
self.mock_session.get.return_value = create_mock_response(200, json_data=mock_items)
|
||||
|
||||
result = tool._list_files(path="dir")
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertEqual(result[0]["name"], "file1.txt")
|
||||
self.assertEqual(result[1]["type"], "dir")
|
||||
|
||||
def test_search_code_success(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
mock_search_results = {
|
||||
"items": [{"path": "src/code.py", "html_url": "url1"}]
|
||||
}
|
||||
self.mock_session.get.return_value = create_mock_response(200, json_data=mock_search_results)
|
||||
|
||||
results = tool._search_code(query="my_function")
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertEqual(results[0]["path"], "src/code.py")
|
||||
|
||||
def test_get_commit_history_success(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
mock_commits = [{
|
||||
"sha": "sha1", "commit": {"message": "Msg1", "author": {"name": "Authy", "date": "Date1"}}
|
||||
}]
|
||||
self.mock_session.get.return_value = create_mock_response(200, json_data=mock_commits)
|
||||
|
||||
commits = tool._get_commit_history(file_path="file.txt", num_commits=1)
|
||||
self.assertEqual(len(commits), 1)
|
||||
self.assertEqual(commits[0]["sha"], "sha1")
|
||||
|
||||
def test_set_current_branch_success(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
# Mock _get_branch_sha to simulate branch exists
|
||||
with patch.object(tool, '_get_branch_sha', return_value="some_sha_for_dev"):
|
||||
result = tool._set_current_branch(branch_name="dev")
|
||||
self.assertEqual(tool.current_branch, "dev")
|
||||
self.assertIn("Current branch set to: dev", result)
|
||||
|
||||
def test_set_current_branch_not_exists(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
with patch.object(tool, '_get_branch_sha', return_value="Error getting SHA for branch"):
|
||||
result = tool._set_current_branch(branch_name="nonexistent-branch")
|
||||
self.assertNotEqual(tool.current_branch, "nonexistent-branch") # Should not change
|
||||
self.assertIn("Cannot set current branch", result)
|
||||
|
||||
|
||||
def test_list_branches_single_page(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
mock_branches = [{"name": "main"}, {"name": "dev"}]
|
||||
self.mock_session.get.return_value = create_mock_response(200, json_data=mock_branches, links={}) # No "next" link
|
||||
|
||||
branches = tool._list_branches(all_pages=True)
|
||||
self.assertEqual(branches, ["main", "dev"])
|
||||
self.mock_session.get.assert_called_once()
|
||||
|
||||
def test_list_branches_multiple_pages(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
|
||||
# Page 1 response
|
||||
page1_branches = [{"name": "branch1"}, {"name": "branch2"}]
|
||||
next_url = f"{tool.base_url}/repos/{self.test_repo}/branches?page=2"
|
||||
response1 = create_mock_response(200, json_data=page1_branches, links={"next": {"url": next_url}})
|
||||
|
||||
# Page 2 response
|
||||
page2_branches = [{"name": "branch3"}]
|
||||
response2 = create_mock_response(200, json_data=page2_branches, links={}) # No "next" link
|
||||
|
||||
self.mock_session.get.side_effect = [response1, response2]
|
||||
|
||||
branches = tool._list_branches(all_pages=True)
|
||||
self.assertEqual(branches, ["branch1", "branch2", "branch3"])
|
||||
self.assertEqual(self.mock_session.get.call_count, 2)
|
||||
|
||||
# Check that the second call used the next_url
|
||||
calls = self.mock_session.get.call_args_list
|
||||
self.assertEqual(calls[1][0][0], next_url) # args[0] is the URL
|
||||
|
||||
# --- Test execute dispatcher ---
|
||||
def test_execute_read_file(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
with patch.object(tool, '_read_file', return_value="file content") as mock_method:
|
||||
result = tool.execute(function_name="read_file", path="test.md")
|
||||
mock_method.assert_called_once_with(path="test.md")
|
||||
self.assertEqual(result, "file content")
|
||||
|
||||
def test_execute_unknown_function(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
result = tool.execute(function_name="non_existent_function_name", arg1="val1")
|
||||
self.assertIn("Unknown function: non_existent_function_name", result)
|
||||
|
||||
def test_execute_method_exception(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
with patch.object(tool, '_read_file', side_effect=Exception("Kaboom")) as mock_method:
|
||||
result = tool.execute(function_name="read_file", path="crash.txt")
|
||||
self.assertIn("Error during read_file execution: Kaboom", result)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -0,0 +1,146 @@
|
||||
import unittest
|
||||
from unittest.mock import patch, mock_open, MagicMock
|
||||
import os
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Ensure tools/log_tool.py is accessible
|
||||
from tools.log_tool import LogTool
|
||||
|
||||
class TestLogTool(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.test_log_file_path = "test_dummy_log.log"
|
||||
# Suppress logging output during tests unless explicitly testing for it
|
||||
self.logger = logging.getLogger('tools.log_tool')
|
||||
# Ensure only one NullHandler to prevent duplicate messages if tests run multiple times in a session
|
||||
if not any(isinstance(h, logging.NullHandler) for h in self.logger.handlers):
|
||||
self.logger.addHandler(logging.NullHandler())
|
||||
self.logger.propagate = False # Prevent propagation to root logger if it has handlers
|
||||
|
||||
|
||||
def test_init_default_log_path(self):
|
||||
tool = LogTool(logger=self.logger)
|
||||
self.assertEqual(tool.configured_log_file_path, 'logs/output.log')
|
||||
|
||||
def test_init_custom_log_path(self):
|
||||
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
|
||||
self.assertEqual(tool.configured_log_file_path, self.test_log_file_path)
|
||||
|
||||
def test_get_functions(self):
|
||||
tool = LogTool(logger=self.logger)
|
||||
functions = tool.get_functions()
|
||||
self.assertIsInstance(functions, list)
|
||||
self.assertEqual(len(functions), 1)
|
||||
self.assertEqual(functions[0]["function"]["name"], "get_log_contents")
|
||||
|
||||
@patch("os.path.exists", return_value=False)
|
||||
def test_get_log_contents_file_not_exists(self, mock_exists):
|
||||
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
|
||||
result = tool._get_log_contents()
|
||||
self.assertIn("Log file does not exist", result)
|
||||
mock_exists.assert_called_once_with(self.test_log_file_path)
|
||||
|
||||
@patch("os.path.exists", return_value=True)
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="line1\nline2\nline3\nline4\nline5")
|
||||
def test_get_log_contents_with_line_count(self, mock_file_open, mock_exists):
|
||||
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
|
||||
|
||||
result = tool._get_log_contents(line_count=3)
|
||||
self.assertEqual(result, "line3\nline4\nline5")
|
||||
mock_exists.assert_called_once_with(self.test_log_file_path)
|
||||
mock_file_open.assert_called_once_with(self.test_log_file_path, 'r', encoding='utf-8')
|
||||
|
||||
@patch("os.path.exists", return_value=True)
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="line1\nline2\n")
|
||||
def test_get_log_contents_line_count_more_than_available(self, mock_file_open, mock_exists):
|
||||
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
|
||||
result = tool._get_log_contents(line_count=5)
|
||||
self.assertEqual(result, "line1\nline2\n")
|
||||
|
||||
@patch("os.path.exists", return_value=True)
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="line1\nline2\n")
|
||||
def test_get_log_contents_invalid_line_count_uses_default(self, mock_file_open, mock_exists):
|
||||
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
|
||||
# Test with zero, negative, and non-integer line_count (though type hint is int)
|
||||
# The code defaults to 150 if invalid. Here, we only have 2 lines.
|
||||
with patch.object(tool.logger, 'warning') as mock_log_warning:
|
||||
result_zero = tool._get_log_contents(line_count=0)
|
||||
self.assertEqual(result_zero, "line1\nline2\n")
|
||||
mock_log_warning.assert_any_call("Invalid line_count '0' provided, defaulting to fetch last 150 lines.")
|
||||
|
||||
mock_file_open.reset_mock() # Reset for next call
|
||||
result_neg = tool._get_log_contents(line_count=-5)
|
||||
self.assertEqual(result_neg, "line1\nline2\n")
|
||||
mock_log_warning.assert_any_call("Invalid line_count '-5' provided, defaulting to fetch last 150 lines.")
|
||||
|
||||
|
||||
@patch("os.path.exists", return_value=True)
|
||||
def test_get_log_contents_last_24_hours(self, mock_exists):
|
||||
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
|
||||
|
||||
now = datetime.now()
|
||||
one_hour_ago_dt = now - timedelta(hours=1)
|
||||
two_days_ago_dt = now - timedelta(days=2)
|
||||
|
||||
one_hour_ago_str = one_hour_ago_dt.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT)
|
||||
two_days_ago_str = two_days_ago_dt.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT)
|
||||
|
||||
log_data = (
|
||||
f"{two_days_ago_str} - OLD - This is an old log entry.\n"
|
||||
f"No timestamp here - this line should be skipped by time filter.\n"
|
||||
f"{one_hour_ago_str} - RECENT - This is a recent log entry.\n"
|
||||
f"Malformed Date 2023-xx-01 - Another skipped line.\n"
|
||||
f"{now.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT)} - CURRENT - This is a current log entry.\n"
|
||||
)
|
||||
|
||||
expected_output = (
|
||||
f"{one_hour_ago_str} - RECENT - This is a recent log entry.\n"
|
||||
f"{now.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT)} - CURRENT - This is a current log entry.\n"
|
||||
)
|
||||
|
||||
with patch("builtins.open", mock_open(read_data=log_data)):
|
||||
result = tool._get_log_contents(line_count=None) # Trigger 24-hour logic
|
||||
self.assertEqual(result, expected_output)
|
||||
|
||||
@patch("os.path.exists", return_value=True)
|
||||
@patch("builtins.open", side_effect=IOError("File read error!"))
|
||||
def test_get_log_contents_file_read_exception(self, mock_file_open, mock_exists):
|
||||
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
|
||||
result = tool._get_log_contents(line_count=10)
|
||||
self.assertIn("An error occurred while reading the log file: File read error!", result)
|
||||
|
||||
def test_execute_get_log_contents(self):
|
||||
tool = LogTool(logger=self.logger)
|
||||
mock_return_value = "Mocked log content"
|
||||
with patch.object(tool, '_get_log_contents', return_value=mock_return_value) as mock_method:
|
||||
result = tool.execute(function_name="get_log_contents", line_count=50)
|
||||
mock_method.assert_called_once_with(line_count=50)
|
||||
self.assertEqual(result, mock_return_value)
|
||||
|
||||
def test_execute_get_log_contents_no_line_count(self):
|
||||
tool = LogTool(logger=self.logger)
|
||||
mock_return_value = "Mocked log content for 24h"
|
||||
with patch.object(tool, '_get_log_contents', return_value=mock_return_value) as mock_method:
|
||||
result = tool.execute(function_name="get_log_contents") # No line_count
|
||||
mock_method.assert_called_once_with(line_count=None) # Expects None to trigger 24h
|
||||
self.assertEqual(result, mock_return_value)
|
||||
|
||||
|
||||
def test_execute_unknown_function(self):
|
||||
tool = LogTool(logger=self.logger)
|
||||
result = tool.execute(function_name="non_existent_log_function")
|
||||
self.assertIn("Unknown function: non_existent_log_function", result)
|
||||
|
||||
def test_clear_method(self):
|
||||
tool = LogTool(logger=self.logger)
|
||||
# Set a specific level for the logger for this test if needed to capture debug
|
||||
original_level = tool.logger.level
|
||||
tool.logger.setLevel(logging.DEBUG)
|
||||
with self.assertLogs(tool.logger, level='DEBUG') as cm:
|
||||
tool.clear()
|
||||
self.assertTrue(any("LogTool clear called" in message for message in cm.output))
|
||||
tool.logger.setLevel(original_level) # Reset level
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -0,0 +1,217 @@
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock, ANY
|
||||
import time
|
||||
import logging
|
||||
|
||||
# Ensure tools.metrics is accessible
|
||||
from tools.metrics import Metrics # Import the class itself for direct testing
|
||||
from tools.metrics import metrics as global_metrics_instance # Import the global instance
|
||||
|
||||
# A simple function to decorate for testing
|
||||
def sample_function_for_metrics(duration=0.01):
|
||||
# Simulate some work
|
||||
# Note: time.sleep is not always precisely profiled by cProfile in the same way as CPU-bound work.
|
||||
# For testing, we will mock the cProfile/pstats interaction rather than relying on actual sleep duration.
|
||||
if duration > 0: # Make it conditional so we can test zero-time case too
|
||||
pass # The actual work is not important when mocking cProfile results
|
||||
return "sample_output"
|
||||
|
||||
def another_sample_function(x, y):
|
||||
return x + y
|
||||
|
||||
class TestMetrics(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Create a fresh Metrics instance for most tests to avoid interference
|
||||
self.logger = logging.getLogger('tools.metrics.test')
|
||||
if not self.logger.handlers: # Avoid adding handler multiple times
|
||||
self.logger.addHandler(logging.NullHandler())
|
||||
self.metrics_instance = Metrics(logger=self.logger)
|
||||
|
||||
# Clear the global instance before each test that might use it
|
||||
global_metrics_instance.clear_metrics()
|
||||
|
||||
def test_measure_decorator_counts_calls(self):
|
||||
decorated_func = self.metrics_instance.measure(sample_function_for_metrics)
|
||||
|
||||
self.assertEqual(self.metrics_instance.call_count["sample_function_for_metrics"], 0)
|
||||
decorated_func()
|
||||
self.assertEqual(self.metrics_instance.call_count["sample_function_for_metrics"], 1)
|
||||
decorated_func()
|
||||
self.assertEqual(self.metrics_instance.call_count["sample_function_for_metrics"], 2)
|
||||
|
||||
@patch('cProfile.Profile')
|
||||
@patch('pstats.Stats')
|
||||
def test_measure_decorator_records_time(self, MockPStats, MockCProfile):
|
||||
# Mock cProfile and pstats to control the time value
|
||||
mock_profiler_instance = MockCProfile.return_value
|
||||
mock_pstats_instance = MockPStats.return_value
|
||||
|
||||
# Simulate that pstats.Stats.stats dictionary contains the function's stats
|
||||
# Key: (filename, lineno, funcname)
|
||||
# Value: (cc, nc, tt, ct, callers) where ct is cumulative_time (index 3)
|
||||
|
||||
# Get code object of the function *before* decoration for correct key
|
||||
original_func_code = sample_function_for_metrics.__code__
|
||||
func_key = (original_func_code.co_filename, original_func_code.co_firstlineno, original_func_code.co_name)
|
||||
|
||||
# Configure mock_pstats_instance.stats to return our desired time
|
||||
mock_pstats_instance.stats = {func_key: (1, 1, 0.05, 0.123, {})} # cc, nc, tt, ct=0.123
|
||||
|
||||
decorated_func = self.metrics_instance.measure(sample_function_for_metrics)
|
||||
|
||||
self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0)
|
||||
|
||||
# Call the decorated function
|
||||
decorated_func(duration=0) # Duration arg doesn't matter due to mocking
|
||||
|
||||
# Assertions
|
||||
mock_profiler_instance.enable.assert_called_once()
|
||||
mock_profiler_instance.disable.assert_called_once()
|
||||
MockPStats.assert_called_once_with(mock_profiler_instance)
|
||||
|
||||
self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0.123)
|
||||
|
||||
# Call again to see accumulation
|
||||
# Reset mock stats for a new time value if needed, or assume same time per call
|
||||
mock_pstats_instance.stats = {func_key: (1, 1, 0.05, 0.100, {})} # New ct=0.100
|
||||
decorated_func(duration=0)
|
||||
self.assertAlmostEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0.123 + 0.100)
|
||||
|
||||
|
||||
@patch('cProfile.Profile')
|
||||
@patch('pstats.Stats')
|
||||
def test_measure_decorator_fallback_time_recording_by_name(self, MockPStats, MockCProfile):
|
||||
mock_profiler_instance = MockCProfile.return_value
|
||||
mock_pstats_instance = MockPStats.return_value
|
||||
|
||||
original_func_code = sample_function_for_metrics.__code__ # func to be decorated
|
||||
# Simulate the primary key lookup fails by creating a slightly different key for what we expect
|
||||
# This is what the code will try to look up first.
|
||||
expected_primary_key = (original_func_code.co_filename, original_func_code.co_firstlineno, original_func_code.co_name)
|
||||
|
||||
# This is the key that will *actually* be in pstats.stats, simulating a mismatch for primary lookup
|
||||
# but a match for the by-name fallback.
|
||||
actual_stats_key_in_pstats = (original_func_code.co_filename,
|
||||
original_func_code.co_firstlineno + 5, # simulate a lineno difference for primary key mismatch
|
||||
original_func_code.co_name) # Name is the same for fallback
|
||||
|
||||
mock_pstats_instance.stats = {
|
||||
# expected_primary_key is NOT present
|
||||
actual_stats_key_in_pstats: (1, 1, 0.03, 0.077, {}) # ct = 0.077
|
||||
}
|
||||
|
||||
decorated_func = self.metrics_instance.measure(sample_function_for_metrics)
|
||||
|
||||
# Expecting a debug log for fallback, but assertLogs needs the logger to have a handler that captures.
|
||||
# self.logger is already set up with NullHandler. For this test, let's use a specific logger.
|
||||
metrics_internal_logger = logging.getLogger('tools.metrics') # Logger used inside Metrics class
|
||||
original_level = metrics_internal_logger.level
|
||||
metrics_internal_logger.setLevel(logging.DEBUG)
|
||||
|
||||
with self.assertLogs(metrics_internal_logger, level='DEBUG') as log_capture:
|
||||
decorated_func(duration=0)
|
||||
|
||||
metrics_internal_logger.setLevel(original_level) # Reset logger level
|
||||
|
||||
self.assertTrue(any("Found stats for sample_function_for_metrics by name" in msg for msg in log_capture.output))
|
||||
self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0.077)
|
||||
|
||||
|
||||
@patch('cProfile.Profile')
|
||||
@patch('pstats.Stats')
|
||||
def test_measure_decorator_handles_func_stats_not_found(self, MockPStats, MockCProfile):
|
||||
mock_profiler_instance = MockCProfile.return_value
|
||||
mock_pstats_instance = MockPStats.return_value
|
||||
mock_pstats_instance.stats = {} # Empty stats, function will not be found
|
||||
|
||||
decorated_func = self.metrics_instance.measure(sample_function_for_metrics)
|
||||
|
||||
metrics_internal_logger = logging.getLogger('tools.metrics')
|
||||
original_level = metrics_internal_logger.level
|
||||
metrics_internal_logger.setLevel(logging.WARNING)
|
||||
with self.assertLogs(metrics_internal_logger, level='WARNING') as log_capture:
|
||||
decorated_func(duration=0)
|
||||
metrics_internal_logger.setLevel(original_level)
|
||||
|
||||
self.assertTrue(any("Could not find exact cProfile stats" in msg for msg in log_capture.output))
|
||||
self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0)
|
||||
|
||||
|
||||
def test_get_metrics_empty(self):
|
||||
self.assertEqual(self.metrics_instance.get_metrics(), {})
|
||||
|
||||
@patch('cProfile.Profile')
|
||||
@patch('pstats.Stats')
|
||||
def test_get_metrics_with_data(self, MockPStats, MockCProfile):
|
||||
mock_pstats_instance = MockPStats.return_value
|
||||
|
||||
# Decorate two different functions
|
||||
decorated_func1 = self.metrics_instance.measure(sample_function_for_metrics)
|
||||
decorated_func2 = self.metrics_instance.measure(another_sample_function)
|
||||
|
||||
# Data for func1
|
||||
func1_code = sample_function_for_metrics.__code__
|
||||
func1_key = (func1_code.co_filename, func1_code.co_firstlineno, func1_code.co_name)
|
||||
mock_pstats_instance.stats = {func1_key: (1,1,0.1,0.1,{})}
|
||||
decorated_func1()
|
||||
|
||||
# Data for func2
|
||||
func2_code = another_sample_function.__code__
|
||||
func2_key = (func2_code.co_filename, func2_code.co_firstlineno, func2_code.co_name)
|
||||
mock_pstats_instance.stats = {func2_key: (1,1,0.2,0.2,{})} # Cumulative time 0.2
|
||||
decorated_func2(1,2)
|
||||
mock_pstats_instance.stats = {func2_key: (1,1,0.3,0.3,{})} # Cumulative time 0.3 for second call
|
||||
decorated_func2(3,4)
|
||||
|
||||
metrics_data = self.metrics_instance.get_metrics()
|
||||
|
||||
self.assertIn("sample_function_for_metrics", metrics_data)
|
||||
self.assertEqual(metrics_data["sample_function_for_metrics"]["call_count"], 1)
|
||||
self.assertEqual(metrics_data["sample_function_for_metrics"]["total_time"], 0.1)
|
||||
self.assertEqual(metrics_data["sample_function_for_metrics"]["average_time"], 0.1)
|
||||
|
||||
self.assertIn("another_sample_function", metrics_data)
|
||||
self.assertEqual(metrics_data["another_sample_function"]["call_count"], 2)
|
||||
self.assertAlmostEqual(metrics_data["another_sample_function"]["total_time"], 0.5)
|
||||
self.assertAlmostEqual(metrics_data["another_sample_function"]["average_time"], 0.25)
|
||||
|
||||
|
||||
def test_clear_metrics(self):
|
||||
# Add some data
|
||||
self.metrics_instance.call_count["test_func"] = 5
|
||||
self.metrics_instance.total_time["test_func"] = 1.234
|
||||
|
||||
self.metrics_instance.clear_metrics()
|
||||
|
||||
self.assertEqual(self.metrics_instance.call_count, {})
|
||||
self.assertEqual(self.metrics_instance.total_time, {})
|
||||
self.assertEqual(self.metrics_instance.get_metrics(), {})
|
||||
|
||||
# Test the global instance
|
||||
@patch('cProfile.Profile')
|
||||
@patch('pstats.Stats')
|
||||
def test_global_metrics_instance_usage(self, MockPStats, MockCProfile):
|
||||
mock_pstats_instance = MockPStats.return_value
|
||||
|
||||
# Decorate a function with the global instance
|
||||
@global_metrics_instance.measure
|
||||
def globally_decorated_func():
|
||||
return "global_output"
|
||||
|
||||
# Setup mock stats for the globally decorated function
|
||||
# Access __wrapped__ to get the original function if other decorators might be present or for consistency.
|
||||
original_g_func = globally_decorated_func.__wrapped__
|
||||
func_code = original_g_func.__code__
|
||||
func_key = (func_code.co_filename, func_code.co_firstlineno, func_code.co_name)
|
||||
mock_pstats_instance.stats = {func_key: (1,1,0.05,0.05,{})}
|
||||
|
||||
globally_decorated_func()
|
||||
|
||||
metrics_data = global_metrics_instance.get_metrics()
|
||||
self.assertIn("globally_decorated_func", metrics_data)
|
||||
self.assertEqual(metrics_data["globally_decorated_func"]["call_count"], 1)
|
||||
self.assertEqual(metrics_data["globally_decorated_func"]["total_time"], 0.05)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -0,0 +1,161 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
import logging
|
||||
|
||||
# Ensure tools.metrics_tool and tools.metrics are accessible
|
||||
from tools.metrics_tool import MetricsTool
|
||||
from tools.metrics import Metrics # Used for typehinting and creating a mockable instance
|
||||
|
||||
class TestMetricsTool(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.mock_metrics_provider = MagicMock(spec=Metrics)
|
||||
self.logger = logging.getLogger('tools.metrics_tool.test')
|
||||
if not self.logger.handlers:
|
||||
self.logger.addHandler(logging.NullHandler())
|
||||
self.logger.propagate = False
|
||||
|
||||
|
||||
def test_init_with_provider(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
self.assertEqual(tool.metrics_provider, self.mock_metrics_provider)
|
||||
|
||||
@patch('tools.metrics_tool.global_metrics_instance') # Patch the global instance path
|
||||
def test_init_default_provider(self, mock_global_metrics):
|
||||
tool = MetricsTool(logger=self.logger)
|
||||
self.assertEqual(tool.metrics_provider, mock_global_metrics)
|
||||
|
||||
def test_get_functions(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
functions = tool.get_functions()
|
||||
self.assertIsInstance(functions, list)
|
||||
self.assertTrue(len(functions) == 3) # Based on current definition
|
||||
self.assertIn("get_function_metrics", [f["function"]["name"] for f in functions])
|
||||
self.assertIn("get_specific_function_metrics", [f["function"]["name"] for f in functions])
|
||||
self.assertIn("get_top_n_functions", [f["function"]["name"] for f in functions])
|
||||
|
||||
def test_execute_get_function_metrics(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
expected_metrics = {"func1": {"call_count": 1, "total_time": 0.1}}
|
||||
self.mock_metrics_provider.get_metrics.return_value = expected_metrics
|
||||
|
||||
result = tool.execute(function_name="get_function_metrics")
|
||||
|
||||
self.mock_metrics_provider.get_metrics.assert_called_once()
|
||||
self.assertEqual(result, expected_metrics)
|
||||
|
||||
def test_execute_get_specific_function_metrics_found(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
func_metrics = {"call_count": 5, "total_time": 0.5, "average_time": 0.1}
|
||||
all_metrics = {"specific_func": func_metrics, "other_func": {}}
|
||||
self.mock_metrics_provider.get_metrics.return_value = all_metrics
|
||||
|
||||
# The execute method expects kwargs that match the function parameters in get_functions.
|
||||
# So, the argument name for the function to get is 'function_name' in the tool's spec.
|
||||
result = tool.execute(function_name="get_specific_function_metrics", **{"function_name": "specific_func"})
|
||||
self.assertEqual(result, func_metrics)
|
||||
|
||||
def test_execute_get_specific_function_metrics_not_found(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
self.mock_metrics_provider.get_metrics.return_value = {"other_func": {}}
|
||||
|
||||
result = tool.execute(function_name="get_specific_function_metrics", **{"function_name": "non_existent_func"})
|
||||
self.assertEqual(result, "No metrics found for function: non_existent_func")
|
||||
|
||||
def test_execute_get_specific_function_metrics_missing_arg(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
result = tool.execute(function_name="get_specific_function_metrics") # Missing function_name kwarg
|
||||
self.assertIn("Error: Missing required argument 'function_name'", result)
|
||||
|
||||
|
||||
def test_execute_get_top_n_functions(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
metrics_data = {
|
||||
"func_a": {"call_count": 1, "total_time": 0.3},
|
||||
"func_b": {"call_count": 1, "total_time": 0.1},
|
||||
"func_c": {"call_count": 1, "total_time": 0.5},
|
||||
"func_d": {"call_count": 1, "total_time": 0.2},
|
||||
}
|
||||
self.mock_metrics_provider.get_metrics.return_value = metrics_data
|
||||
|
||||
# Test getting top 2
|
||||
result = tool.execute(function_name="get_top_n_functions", n=2)
|
||||
expected_top_2 = {"func_c": metrics_data["func_c"], "func_a": metrics_data["func_a"]}
|
||||
self.assertEqual(result, expected_top_2)
|
||||
|
||||
# Test getting top 1
|
||||
result_top_1 = tool.execute(function_name="get_top_n_functions", n=1)
|
||||
expected_top_1 = {"func_c": metrics_data["func_c"]}
|
||||
self.assertEqual(result_top_1, expected_top_1)
|
||||
|
||||
# Test N larger than available functions
|
||||
result_top_all = tool.execute(function_name="get_top_n_functions", n=10)
|
||||
# Order should be func_c, func_a, func_d, func_b
|
||||
expected_top_all_keys = ["func_c", "func_a", "func_d", "func_b"]
|
||||
self.assertEqual(list(result_top_all.keys()), expected_top_all_keys)
|
||||
|
||||
def test_execute_get_top_n_functions_malformed_metrics(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
metrics_data = {
|
||||
"func_a": {"call_count": 1, "total_time": 0.3},
|
||||
"func_b": "not a dict", # Malformed
|
||||
"func_c": {"call_count": 1}, # Missing total_time
|
||||
"func_d": {"call_count": 1, "total_time": 0.2},
|
||||
}
|
||||
self.mock_metrics_provider.get_metrics.return_value = metrics_data
|
||||
|
||||
metrics_tool_logger = logging.getLogger('tools.metrics_tool')
|
||||
original_level = metrics_tool_logger.level
|
||||
metrics_tool_logger.setLevel(logging.WARNING)
|
||||
with self.assertLogs(metrics_tool_logger, level='WARNING') as log_capture:
|
||||
result = tool.execute(function_name="get_top_n_functions", n=2)
|
||||
metrics_tool_logger.setLevel(original_level)
|
||||
|
||||
# Check that warnings were logged for malformed items
|
||||
self.assertTrue(any("Metric item for 'func_b' is not in expected format" in msg for msg in log_capture.output))
|
||||
self.assertTrue(any("Metric item for 'func_c' is not in expected format" in msg for msg in log_capture.output))
|
||||
|
||||
# Expected: func_a, func_d (as they are valid and sortable)
|
||||
expected_result = {
|
||||
"func_a": metrics_data["func_a"],
|
||||
"func_d": metrics_data["func_d"]
|
||||
}
|
||||
self.assertEqual(result, expected_result)
|
||||
|
||||
|
||||
def test_execute_get_top_n_functions_invalid_n(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
self.mock_metrics_provider.get_metrics.return_value = {} # No metrics needed for this test
|
||||
|
||||
result_zero = tool.execute(function_name="get_top_n_functions", n=0)
|
||||
self.assertIn("Error: Argument 'n' must be a positive integer.", result_zero)
|
||||
|
||||
result_negative = tool.execute(function_name="get_top_n_functions", n=-1)
|
||||
self.assertIn("Error: Argument 'n' must be a positive integer.", result_negative)
|
||||
|
||||
result_string = tool.execute(function_name="get_top_n_functions", n="abc")
|
||||
self.assertIn("Error: Argument 'n' must be an integer.", result_string)
|
||||
|
||||
def test_execute_get_top_n_functions_missing_arg_n(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
result = tool.execute(function_name="get_top_n_functions") # Missing n
|
||||
self.assertIn("Error: Missing required argument 'n'.", result)
|
||||
|
||||
|
||||
def test_execute_unknown_function(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
result = tool.execute(function_name="non_existent_metrics_function")
|
||||
self.assertIn("Unknown function: non_existent_metrics_function", result)
|
||||
|
||||
def test_clear_method(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
metrics_tool_logger = logging.getLogger('tools.metrics_tool')
|
||||
original_level = metrics_tool_logger.level
|
||||
metrics_tool_logger.setLevel(logging.DEBUG)
|
||||
with self.assertLogs(metrics_tool_logger, level='DEBUG') as cm:
|
||||
tool.clear()
|
||||
metrics_tool_logger.setLevel(original_level)
|
||||
self.assertTrue(any("MetricsTool clear method called" in message for message in cm.output))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
+411
-474
File diff suppressed because it is too large
Load Diff
+55
-42
@@ -1,5 +1,4 @@
|
||||
# tools/log_tool.py
|
||||
|
||||
from .base_tool import BaseTool
|
||||
from .metrics import metrics
|
||||
import logging
|
||||
@@ -7,48 +6,39 @@ import os
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
class LogTool(BaseTool):
|
||||
def __init__(self):
|
||||
# Set up logging
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.logger.setLevel(logging.INFO)
|
||||
# Default log format string that _get_log_contents expects for time-based filtering.
|
||||
# Making it a class variable so it's visible and could be overridden by a subclass if needed,
|
||||
# though the parser is still hardcoded in this version.
|
||||
EXPECTED_LOG_TIMESTAMP_FORMAT = '%Y-%m-%d %H:%M:%S,%f'
|
||||
|
||||
# Create a file handler
|
||||
file_handler = logging.FileHandler('log_tool.log')
|
||||
file_handler.setLevel(logging.INFO)
|
||||
|
||||
# Create a console handler
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.INFO)
|
||||
|
||||
# Create a formatting for the logs
|
||||
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
file_handler.setFormatter(formatter)
|
||||
console_handler.setFormatter(formatter)
|
||||
|
||||
# Add the handlers to the logger
|
||||
self.logger.addHandler(file_handler)
|
||||
self.logger.addHandler(console_handler)
|
||||
def __init__(self, log_file_path=None, logger=None):
|
||||
self.configured_log_file_path = log_file_path if log_file_path else 'logs/output.log'
|
||||
self.logger = logger if logger else logging.getLogger(__name__)
|
||||
if not self.logger.handlers:
|
||||
self.logger.addHandler(logging.NullHandler())
|
||||
self.logger.info(f"LogTool initialized. Log file path: {self.configured_log_file_path}")
|
||||
|
||||
def clear(self):
|
||||
# No specific state to clear for LogTool in this version.
|
||||
self.logger.debug("LogTool clear called, no action taken.")
|
||||
pass
|
||||
|
||||
def get_functions(self):
|
||||
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_log_contents",
|
||||
"description": "Get the contents of the log file.",
|
||||
"description": "Get the contents of the log file. If line_count is not provided, it attempts to return logs from the last 24 hours based on timestamp.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"line_count": {
|
||||
"type": "integer",
|
||||
"description": "Number of lines from the end of the log file to retrieve"
|
||||
"description": "Number of lines from the end of the log file to retrieve. If omitted, logs from last 24 hours are returned."
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
"required": [] # line_count is optional
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -56,37 +46,60 @@ class LogTool(BaseTool):
|
||||
|
||||
@metrics.measure
|
||||
def execute(self, function_name, **kwargs):
|
||||
self.logger.info(f"Executing: {function_name}")
|
||||
|
||||
self.logger.info(f"Executing LogTool function: {function_name} with args: {kwargs}")
|
||||
if function_name == "get_log_contents":
|
||||
return self._get_log_contents(kwargs.get("line_count"))
|
||||
# kwargs.get("line_count") will be None if not provided, which is handled by _get_log_contents
|
||||
return self._get_log_contents(line_count=kwargs.get("line_count"))
|
||||
else:
|
||||
error_message = f"Unknown function: {function_name}"
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _get_log_contents(self, line_count=150):
|
||||
log_file_path = 'logs/output.log'
|
||||
def _get_log_contents(self, line_count=None): # Default line_count is None to trigger 24h logic if not specified
|
||||
self.logger.info(f"Attempting to get log contents from: {self.configured_log_file_path}. Line count: {line_count if line_count is not None else 'Last 24 hours'}")
|
||||
|
||||
if not os.path.exists(log_file_path):
|
||||
error_message = "Log file does not exist."
|
||||
if not os.path.exists(self.configured_log_file_path):
|
||||
error_message = f"Log file does not exist at path: {self.configured_log_file_path}"
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
try:
|
||||
with open(log_file_path, 'r') as log_file:
|
||||
with open(self.configured_log_file_path, 'r', encoding='utf-8') as log_file:
|
||||
log_lines = log_file.readlines()
|
||||
self.logger.debug(f"Read {len(log_lines)} total lines from log file.")
|
||||
|
||||
if line_count is not None:
|
||||
log_lines = log_lines[-line_count:]
|
||||
else:
|
||||
now = datetime.now()
|
||||
twenty_four_hours_ago = now - timedelta(days=1)
|
||||
log_lines = [line for line in log_lines if datetime.strptime(line.split(' - ')[0], '%Y-%m-%d %H:%M:%S,%f') > twenty_four_hours_ago]
|
||||
if line_count is not None:
|
||||
# Ensure line_count is positive if specified, otherwise could lead to unexpected slicing
|
||||
if not isinstance(line_count, int) or line_count <= 0:
|
||||
self.logger.warning(f"Invalid line_count '{line_count}' provided, defaulting to fetch last 150 lines.")
|
||||
line_count = 150 # Default to a sensible number if invalid value provided
|
||||
log_lines = log_lines[-line_count:]
|
||||
self.logger.info(f"Returning last {len(log_lines)} lines based on line_count: {line_count}")
|
||||
else:
|
||||
# Default to last 24 hours if line_count is explicitly None or not provided
|
||||
self.logger.info(f"Filtering logs for the last 24 hours. Expected timestamp format: {self.EXPECTED_LOG_TIMESTAMP_FORMAT}")
|
||||
now = datetime.now()
|
||||
twenty_four_hours_ago = now - timedelta(days=1)
|
||||
|
||||
return "".join(log_lines)
|
||||
filtered_lines = []
|
||||
for line in log_lines:
|
||||
try:
|
||||
# Attempt to parse timestamp from the beginning of the line
|
||||
timestamp_str = line.split(' - ', 1)[0]
|
||||
log_time = datetime.strptime(timestamp_str, self.EXPECTED_LOG_TIMESTAMP_FORMAT)
|
||||
if log_time > twenty_four_hours_ago:
|
||||
filtered_lines.append(line)
|
||||
except (ValueError, IndexError) as e:
|
||||
# This means the line doesn't start with a parsable timestamp in the expected format.
|
||||
# Depending on requirements, these lines might be skipped or included.
|
||||
# For strict 24-hour filtering, we skip them.
|
||||
self.logger.debug(f"Skipping line due to timestamp parse error ('{e}') or format mismatch: {line.strip()[:100]}...")
|
||||
log_lines = filtered_lines
|
||||
self.logger.info(f"Returning {len(log_lines)} lines from the last 24 hours.")
|
||||
|
||||
return "".join(log_lines)
|
||||
except Exception as e:
|
||||
error_message = f"An error occurred while reading the log file: {e}"
|
||||
self.logger.error(error_message)
|
||||
error_message = f"An error occurred while reading the log file '{self.configured_log_file_path}': {e}"
|
||||
self.logger.error(error_message, exc_info=True)
|
||||
return error_message
|
||||
+51
-17
@@ -1,13 +1,19 @@
|
||||
# tools/metrics.py
|
||||
import cProfile
|
||||
import pstats
|
||||
import io
|
||||
from functools import wraps
|
||||
from collections import defaultdict
|
||||
import logging
|
||||
|
||||
class Metrics:
|
||||
def __init__(self):
|
||||
def __init__(self, logger=None):
|
||||
self.call_count = defaultdict(int)
|
||||
self.total_time = defaultdict(float)
|
||||
self.logger = logger if logger else logging.getLogger(__name__)
|
||||
if not self.logger.handlers:
|
||||
self.logger.addHandler(logging.NullHandler())
|
||||
self.logger.debug("Metrics instance initialized.")
|
||||
|
||||
def measure(self, func):
|
||||
@wraps(func)
|
||||
@@ -16,30 +22,58 @@ class Metrics:
|
||||
|
||||
pr = cProfile.Profile()
|
||||
pr.enable()
|
||||
|
||||
result = func(*args, **kwargs)
|
||||
|
||||
pr.disable()
|
||||
s = io.StringIO()
|
||||
ps = pstats.Stats(pr, stream=s).sort_stats('cumulative')
|
||||
ps.print_stats()
|
||||
|
||||
# Extract the total time spent in the function
|
||||
time_spent = float(s.getvalue().split('\n')[0].split()[-2])
|
||||
self.total_time[func.__name__] += time_spent
|
||||
ps = pstats.Stats(pr)
|
||||
|
||||
func_code = func.__code__
|
||||
func_key_tuple = (func_code.co_filename, func_code.co_firstlineno, func_code.co_name)
|
||||
|
||||
time_spent_for_func = 0.0
|
||||
if func_key_tuple in ps.stats:
|
||||
time_spent_for_func = ps.stats[func_key_tuple][3] # [3] is cumulative time (ct)
|
||||
else:
|
||||
# Fallback: try to find by function name if exact key fails (e.g. due to decorators changing code object details slightly)
|
||||
# This is less precise and might pick up other functions if names are not unique across files.
|
||||
found_by_name = False
|
||||
for key, stat in ps.stats.items():
|
||||
if key[2] == func.__name__: # key[2] is function name
|
||||
time_spent_for_func = stat[3] # cumulative time
|
||||
self.logger.debug(f"Found stats for {func.__name__} by name {key} after primary key failed.")
|
||||
found_by_name = True
|
||||
break
|
||||
if not found_by_name:
|
||||
self.logger.warning(
|
||||
f"Could not find exact cProfile stats for {func.__name__} with key {func_key_tuple} or by name. "
|
||||
f"Time for this call will be recorded as 0. This might occur for non-Python functions or due to complex decorators."
|
||||
)
|
||||
|
||||
self.total_time[func.__name__] += time_spent_for_func
|
||||
self.logger.debug(f"Measured cumulative time for {func.__name__}: {time_spent_for_func:.6f}s")
|
||||
|
||||
return result
|
||||
return wrapper
|
||||
|
||||
def get_metrics(self):
|
||||
metrics = {}
|
||||
metrics_data = {}
|
||||
for func_name in self.call_count:
|
||||
metrics[func_name] = {
|
||||
'call_count': self.call_count[func_name],
|
||||
'total_time': self.total_time[func_name],
|
||||
'average_time': self.total_time[func_name] / self.call_count[func_name] if self.call_count[func_name] > 0 else 0
|
||||
count = self.call_count[func_name]
|
||||
total_t = self.total_time[func_name]
|
||||
metrics_data[func_name] = {
|
||||
'call_count': count,
|
||||
'total_time': round(total_t, 6),
|
||||
'average_time': round(total_t / count, 6) if count > 0 else 0
|
||||
}
|
||||
return metrics
|
||||
return metrics_data
|
||||
|
||||
# Create a global instance of Metrics
|
||||
metrics = Metrics()
|
||||
def clear_metrics(self):
|
||||
self.call_count.clear()
|
||||
self.total_time.clear()
|
||||
self.logger.info("Metrics cleared.")
|
||||
|
||||
# Global instance for convenience
|
||||
_metrics_instance_logger = logging.getLogger(__name__ + ".global_instance")
|
||||
if not _metrics_instance_logger.handlers:
|
||||
_metrics_instance_logger.addHandler(logging.NullHandler())
|
||||
metrics = Metrics(logger=_metrics_instance_logger)
|
||||
|
||||
+59
-15
@@ -1,13 +1,22 @@
|
||||
# tools/metrics_tool.py
|
||||
|
||||
from .base_tool import BaseTool
|
||||
from .metrics import metrics
|
||||
from .metrics import metrics as global_metrics_instance # For default and measuring execute
|
||||
from .metrics import Metrics # For type hinting and potentially creating a new one if needed
|
||||
import logging
|
||||
|
||||
class MetricsTool(BaseTool):
|
||||
def __init__(self):
|
||||
self.metrics = metrics
|
||||
def __init__(self, metrics_provider: Metrics | None = None, logger: logging.Logger | None = None):
|
||||
self.metrics_provider = metrics_provider if metrics_provider is not None else global_metrics_instance
|
||||
self.logger = logger if logger else logging.getLogger(__name__)
|
||||
if not self.logger.handlers:
|
||||
self.logger.addHandler(logging.NullHandler())
|
||||
self.logger.debug(f"MetricsTool initialized. Using metrics provider: {self.metrics_provider}")
|
||||
|
||||
def clear(self):
|
||||
# This tool itself doesn't hold state that needs clearing beyond what its metrics_provider might do.
|
||||
# If this tool were responsible for clearing the metrics it reports on, it would call:
|
||||
# self.metrics_provider.clear_metrics()
|
||||
self.logger.debug("MetricsTool clear method called. No local state to clear.")
|
||||
pass
|
||||
|
||||
def get_functions(self):
|
||||
@@ -60,25 +69,60 @@ class MetricsTool(BaseTool):
|
||||
}
|
||||
]
|
||||
|
||||
@metrics.measure
|
||||
@global_metrics_instance.measure # The execute method can be measured by the global instance
|
||||
def execute(self, function_name, **kwargs):
|
||||
self.logger.info(f"Executing MetricsTool function: {function_name} with args: {kwargs}")
|
||||
if function_name == "get_function_metrics":
|
||||
return self._get_function_metrics()
|
||||
elif function_name == "get_specific_function_metrics":
|
||||
return self._get_specific_function_metrics(kwargs.get("function_name"))
|
||||
func_name_arg = kwargs.get("function_name")
|
||||
if func_name_arg is None: # Check if None, as empty string could be a valid (though unlikely) func name
|
||||
self.logger.warning("'function_name' argument is missing for get_specific_function_metrics.")
|
||||
return "Error: Missing required argument 'function_name'."
|
||||
return self._get_specific_function_metrics(str(func_name_arg)) # Ensure string
|
||||
elif function_name == "get_top_n_functions":
|
||||
return self._get_top_n_functions(kwargs.get("n"))
|
||||
n_arg = kwargs.get("n")
|
||||
if n_arg is None:
|
||||
self.logger.warning("'n' argument is missing for get_top_n_functions.")
|
||||
return "Error: Missing required argument 'n'."
|
||||
try:
|
||||
n_val = int(n_arg)
|
||||
if n_val <= 0:
|
||||
self.logger.warning(f"'n' argument must be a positive integer, got {n_val}.")
|
||||
return "Error: Argument 'n' must be a positive integer."
|
||||
return self._get_top_n_functions(n_val)
|
||||
except ValueError:
|
||||
self.logger.warning(f"'n' argument must be an integer, got '{n_arg}'.")
|
||||
return "Error: Argument 'n' must be an integer."
|
||||
else:
|
||||
return f"Unknown function: {function_name}"
|
||||
error_message = f"Unknown function: {function_name}"
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
def _get_function_metrics(self):
|
||||
return self.metrics.get_metrics()
|
||||
self.logger.debug("Calling metrics_provider.get_metrics() for all functions.")
|
||||
return self.metrics_provider.get_metrics()
|
||||
|
||||
def _get_specific_function_metrics(self, function_name):
|
||||
all_metrics = self.metrics.get_metrics()
|
||||
return all_metrics.get(function_name, f"No metrics found for function: {function_name}")
|
||||
def _get_specific_function_metrics(self, function_to_get):
|
||||
self.logger.debug(f"Getting metrics for specific function: {function_to_get}")
|
||||
all_metrics = self.metrics_provider.get_metrics()
|
||||
return all_metrics.get(function_to_get, f"No metrics found for function: {function_to_get}")
|
||||
|
||||
def _get_top_n_functions(self, n):
|
||||
all_metrics = self.metrics.get_metrics()
|
||||
sorted_metrics = sorted(all_metrics.items(), key=lambda x: x[1]['total_time'], reverse=True)
|
||||
return dict(sorted_metrics[:n])
|
||||
self.logger.debug(f"Getting top {n} functions by total execution time.")
|
||||
all_metrics = self.metrics_provider.get_metrics()
|
||||
# Ensure that the items are actual metric dicts before trying to access 'total_time'
|
||||
valid_metrics_items = []
|
||||
for name, metric_values in all_metrics.items():
|
||||
if isinstance(metric_values, dict) and 'total_time' in metric_values:
|
||||
valid_metrics_items.append((name, metric_values))
|
||||
else:
|
||||
self.logger.warning(f"Metric item for '{name}' is not in expected format: {metric_values}")
|
||||
|
||||
# Sort items by total_time. items() gives list of (func_name, metrics_dict)
|
||||
try:
|
||||
sorted_metrics = sorted(valid_metrics_items, key=lambda item: item[1]['total_time'], reverse=True)
|
||||
return dict(sorted_metrics[:n])
|
||||
except TypeError as e:
|
||||
self.logger.error(f"Error sorting metrics, possibly due to unexpected data types: {e}", exc_info=True)
|
||||
return "Error: Could not sort metrics due to unexpected data."
|
||||
|
||||
Reference in New Issue
Block a user