Refactored gemini, openai and claude into one file and removed logic from the base class, also made helper class definable from command line

This commit is contained in:
2025-06-03 13:04:42 -05:00
parent bd0ce3e340
commit f15228fa58
36 changed files with 487 additions and 3847 deletions
+35 -9
View File
@@ -1,14 +1,40 @@
# Telegram Bot Tokens TELEGRAM_BOT_TOKEN=your_bot_token_here
TELEGRAM_BOT_TOKEN=your_daemon_bot_token_here PYTHONPATH=${workspaceFolder}
TELEGRAM_APPRENTICE_BOT_TOKEN=your_apprentice_bot_token_here GITHUB_TOKEN=your_github_personal_access_token_here
GITHUB_REPOSITORY=your_github_username_or_organization/your_repo_name
GITHUB_REPO_OWNER=your_github_username_or_organization
SYSTEM_PROMPT_PATH=./prompts/project_manager_prompt.txt
ACTIVE_MODEL_PROFILE=OPENAI # Options: OPENAI, GEMINI, GLHF_CHAT
# Create a new profile with these settings:
# {MODEL_PROFILE}_API_KEY
# {MODEL_PROFILE}_API_BASE_URL # Optional for OpenAI
# {MODEL_PROFILE}_SMALL_MODEL
# {MODEL_PROFILE}_SMALL_MODEL_MAX_TOKENS
# {MODEL_PROFILE}_LARGE_MODEL
# {MODEL_PROFILE}_LARGE_MODEL_MAX_TOKENS
# OpenAI API Key # OpenAI API Key
OPENAI_API_KEY=your_openai_api_key_here OPENAI_API_KEY=your_openai_api_key_here
OPENAI_SMALL_MODEL=gpt-4.1-mini
OPENAI_SMALL_MODEL_MAX_TOKENS=32768
OPENAI_LARGE_MODEL=gpt-4.1
OPENAI_LARGE_MODEL_MAX_TOKENS=32768
# Anthropic API Key # Gemini API
ANTHROPIC_API_KEY=your_anthropic_api_key_here GEMINI_API_KEY=your_gemini_api_key_here
GEMINI_API_BASE_URL=https://generativelanguage.googleapis.com/v1beta/openai/
GEMINI_SMALL_MODEL=gemini-2.5-flash-preview-05-20
GEMINI_SMALL_MODEL_MAX_TOKENS=65536
GEMINI_LARGE_MODEL=gemini-2.5-pro-preview-05-06
GEMINI_LARGE_MODEL_MAX_TOKENS=65536
# GitHub Repository Information # GLHF Chat API Key
GITHUB_REPO_OWNER=your_github_username_or_organization GLHF_CHAT_API_KEY=your_glhf_chat_api_key_here
GITHUB_REPO_NAME=your_repo_name GLHF_CHAT_API_BASE_URL=https://glhf.chat/api/openai/v1
GITHUB_ACCESS_TOKEN=your_github_personal_access_token GLHF_CHAT_SMALL_MODEL=meta-llama/Llama-3.3-70B-Instruct
GLHF_CHAT_SMALL_MODEL_MAX_TOKENS=1024
GLHF_CHAT_LARGE_MODEL=deepseek-ai/DeepSeek-V3-0324
GLHF_CHAT_LARGE_MODEL_MAX_TOKENS=1024
-271
View File
@@ -1,271 +0,0 @@
import os
import json
import logging
from anthropic import Anthropic, APIError, RateLimitError
from base_telegram_inference_bot import BaseTelegramInferenceBot
from telegram_helper import TelegramHelper # Used in main, not class
class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
DEFAULT_SMALL_MODEL_NAME = "claude-3-haiku-20240307"
DEFAULT_SMALL_MODEL_MAX_TOKENS = "2048"
DEFAULT_LARGE_MODEL_NAME = "claude-3-opus-20240229"
DEFAULT_LARGE_MODEL_MAX_TOKENS = "4096"
def __init__(
self,
anthropic_client: Anthropic | None = None,
api_key: str | None = None,
small_model_name: str | None = None,
small_model_max_tokens: str | None = None,
large_model_name: str | None = None,
large_model_max_tokens: str | None = None,
system_prompt_content: str | None = None,
system_prompt_path: str | None = None
):
super().__init__(system_prompt_content=system_prompt_content, system_prompt_path=system_prompt_path)
if anthropic_client:
self.anthropic_client = anthropic_client
else:
_api_key = api_key or os.environ.get("ANTHROPIC_API_KEY")
if not _api_key:
raise ValueError("Anthropic API key must be provided either via argument or ANTHROPIC_API_KEY environment variable.")
self.anthropic_client = Anthropic(api_key=_api_key)
self.small_model_name = small_model_name or os.environ.get("ANTHROPIC_SMALL_MODEL") or self.DEFAULT_SMALL_MODEL_NAME
self.small_model_max_tokens_str = small_model_max_tokens or os.environ.get("ANTHROPIC_SMALL_MODEL_MAX_TOKENS") or self.DEFAULT_SMALL_MODEL_MAX_TOKENS
self.large_model_name = large_model_name or os.environ.get("ANTHROPIC_LARGE_MODEL") or self.DEFAULT_LARGE_MODEL_NAME
self.large_model_max_tokens_str = large_model_max_tokens or os.environ.get("ANTHROPIC_LARGE_MODEL_MAX_TOKENS") or self.DEFAULT_LARGE_MODEL_MAX_TOKENS
# Initialize with the small model by default
self._configure_model_and_tokens(
self.small_model_name,
self.small_model_max_tokens_str,
default_max_tokens=int(self.DEFAULT_SMALL_MODEL_MAX_TOKENS) # pass int for default
)
def _configure_model_and_tokens(self, model_name: str, max_tokens_str: str, default_max_tokens: int = 2048):
self.model = model_name
try:
self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens
except ValueError:
logging.error(f"Invalid value for Anthropic max_tokens: {max_tokens_str}. Using default {default_max_tokens}.")
self.max_tokens = default_max_tokens
logging.info(f"Configured to use Anthropic model: {self.model} with max_tokens: {self.max_tokens}")
def get_llm_description(self) -> str:
return f"LLM: {self.model}, Max Tokens: {self.max_tokens}"
def get_chat_response(self, messages_history):
current_system_prompt = self.system_prompt if self.system_prompt else ""
anthropic_tools = []
if hasattr(self, 'functions') and self.functions:
anthropic_tools = [
{
"name": function['function']['name'],
"description": function['function']['description'],
"input_schema": function['function']['parameters'] if function['function']['parameters'] not in [None, {}] else {"type": "object", "properties": {}}
}
for function in self.functions
]
try:
response = self.anthropic_client.messages.create(
model=self.model,
system=current_system_prompt,
messages=messages_history,
max_tokens=self.max_tokens,
tools=anthropic_tools if anthropic_tools else None,
tool_choice={"type": "auto"} if anthropic_tools else None
)
return response
except (APIError, RateLimitError) as e:
logging.error(f"Anthropic API error: {e}")
raise
except Exception as e:
logging.error(f"An unexpected error occurred during Anthropic API call: {e}")
raise
def _format_tool_response_for_anthropic(self, tool_response_data):
if isinstance(tool_response_data, str):
# Wrap plain string in a list of text blocks if not already structured
return [{"type": "text", "text": tool_response_data}]
elif isinstance(tool_response_data, list) and all(isinstance(item, dict) and "type" in item for item in tool_response_data):
# Already a list of content blocks
return tool_response_data
elif isinstance(tool_response_data, (dict, list)):
# Attempt to JSON dump other dicts/lists if not already in content block format
try:
return [{"type": "text", "text": json.dumps(tool_response_data)}]
except (TypeError, json.JSONDecodeError):
return [{"type": "text", "text": str(tool_response_data)}] # Fallback to string
else:
# Fallback for other types (int, float, etc.)
return [{"type": "text", "text": str(tool_response_data)}]
async def handle_message(self, user_id, user_message):
if user_id not in self.conversation_history:
self.conversation_history[user_id] = []
self.conversation_history[user_id].append({"role": "user", "content": user_message})
current_turn_messages = list(self.conversation_history[user_id])
MAX_TOOL_ITERATIONS = 5
tool_use_count = 0
assistant_response_content = ""
while tool_use_count < MAX_TOOL_ITERATIONS:
response = self.get_chat_response(current_turn_messages)
if not response or not response.content:
logging.error("No valid response content from Anthropic LLM.")
self.conversation_history[user_id] = current_turn_messages # Save current state
return "Error: Could not get a valid response from the LLM."
assistant_current_turn_content_blocks = response.content
current_turn_messages.append({"role": "assistant", "content": assistant_current_turn_content_blocks})
text_parts_from_assistant = []
tool_calls_from_response = []
for block in assistant_current_turn_content_blocks:
if block.type == "text":
text_parts_from_assistant.append(block.text)
elif block.type == "tool_use":
tool_calls_from_response.append(block)
assistant_response_content = "".join(text_parts_from_assistant)
if not tool_calls_from_response:
break
tool_results_for_model = []
for tool_call in tool_calls_from_response:
tool_name = tool_call.name
tool_input = tool_call.input
tool_use_id = tool_call.id
logging.info(f"Attempting to call Anthropic tool: {tool_name} with input: {tool_input}")
try:
tool_response_data = self.call_tool(tool_name, tool_input)
tool_result_content_block = self._format_tool_response_for_anthropic(tool_response_data)
tool_results_for_model.append({
"type": "tool_result",
"tool_use_id": tool_use_id,
"content": tool_result_content_block
})
except Exception as e:
logging.error(f"Error calling tool {tool_name}: {e}")
tool_results_for_model.append({
"type": "tool_result",
"tool_use_id": tool_use_id,
"content": [{"type": "text", "text": f"Error executing tool {tool_name}: {str(e)}"}],
"is_error": True
})
current_turn_messages.append({"role": "user", "content": tool_results_for_model}) # Anthropic expects tool results as a user message
tool_use_count += 1
if tool_use_count >= MAX_TOOL_ITERATIONS:
logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached for Anthropic.")
# Update assistant_response_content with any text from the last assistant turn before breaking
if not assistant_response_content and text_parts_from_assistant:
assistant_response_content = "".join(text_parts_from_assistant)
assistant_response_content += "\n[Max tool iterations reached]"
break
self.conversation_history[user_id] = current_turn_messages
if len(self.conversation_history[user_id]) > 20:
self.conversation_history[user_id] = self.conversation_history[user_id][-20:]
if assistant_response_content:
return assistant_response_content
else:
# Fallback if no text parts were found but there was an assistant message
if current_turn_messages:
last_message_in_turn = current_turn_messages[-1]
# Check if the last message content has text blocks (Anthropic specific structure)
if last_message_in_turn.get("role") == "assistant" and isinstance(last_message_in_turn.get("content"), list):
for block in reversed(last_message_in_turn["content"]):
if block.type == "text" and hasattr(block, 'text') and block.text:
return block.text # Return the first non-empty text found from the end
return "No textual response generated by the assistant after processing." # More informative default
async def start(self):
logging.info("Anthropic Bot started")
# clear_conversation_history is inherited from BaseTelegramInferenceBot and calls super().clear_conversation_history
# No need to override if the base implementation is sufficient, unless specific logging is needed.
# async def clear_conversation_history(self, user_id):
# super().clear_conversation_history(user_id)
# logging.info(f"Cleared conversation history for Anthropic bot, user {user_id}")
async def abort_processing(self, user_id):
# This abort is a soft abort, as actual Anthropic API call is synchronous within handle_message
# It primarily clears state and prevents further processing in the bot's loop if any.
if user_id in self.processing_status:
self.processing_status[user_id]["processing"] = False # Mark as not processing
# self.clear_processing_status(user_id) # Use base class method to remove entry
# Clearing history might be too aggressive for a simple abort, depends on desired UX
# For now, let's just stop processing and clear the flag.
# Consider if conversation history should be cleared here or if that is a separate user action.
# super().clear_conversation_history(user_id) # Moved to be less aggressive
logging.info(f"Abort requested for user {user_id}. Processing flag cleared.")
return "Processing aborted. You can send a new message or /clear the conversation."
async def switch_model(self):
if not self.small_model_name or not self.large_model_name:
logging.warning("Small or Large model names for Anthropic are not defined. Cannot switch model.")
return f"Model switching not fully configured. Currently using {self.model}."
current_is_small = self.model == self.small_model_name
current_is_large = self.model == self.large_model_name
if current_is_small:
target_model = self.large_model_name
target_max_tokens_str = self.large_model_max_tokens_str
default_target_max_tokens = int(self.DEFAULT_LARGE_MODEL_MAX_TOKENS)
elif current_is_large:
target_model = self.small_model_name
target_max_tokens_str = self.small_model_max_tokens_str
default_target_max_tokens = int(self.DEFAULT_SMALL_MODEL_MAX_TOKENS)
else:
logging.warning(f"Current model {self.model} is unrecognized. Switching to default small model.")
target_model = self.small_model_name
target_max_tokens_str = self.small_model_max_tokens_str
default_target_max_tokens = int(self.DEFAULT_SMALL_MODEL_MAX_TOKENS)
self._configure_model_and_tokens(target_model, target_max_tokens_str, default_max_tokens=default_target_max_tokens)
logging.info(f"Switched Anthropic model to: {self.model}")
return f"Switched to Anthropic model: {self.model} (Max Tokens: {self.max_tokens})"
# The main function is for standalone execution and basic testing, not part of the class itself.
# It's good practice to update it to reflect changes if you use it for quick tests.
# For unit tests, we'll instantiate the class with mocked dependencies.
def main():
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Example of how to instantiate with new constructor (assuming API key is in ENV for this example)
# For real tests, you'd mock Anthropic() or pass a mock client.
try:
# These would typically come from a config file or CLI args in a real app if not ENV
# For this example, we rely on ENV or defaults being handled by constructor if not provided.
bot = AnthropicTelegramInferenceBot(
api_key=os.environ.get("ANTHROPIC_API_KEY") # Explicitly pass, or let constructor handle ENV
)
except ValueError as e:
logging.error(f"Failed to initialize bot: {e}")
return
except Exception as e: # Catch any other init errors
logging.error(f"An unexpected error occurred during bot initialization: {e}")
return
# TelegramHelper also updated, ensure it's instantiated correctly for this main context.
# For this basic main, we might not pass all configurable paths to TelegramHelper,
# letting them use defaults.
telegram_helper = TelegramHelper(bot)
telegram_helper.run()
if __name__ == '__main__':
main()
-164
View File
@@ -1,164 +0,0 @@
import importlib
import os
import json
import inspect
import logging
from abc import ABC, abstractmethod
from tools.base_tool import BaseTool
class BaseTelegramInferenceBot(ABC):
def __init__(self, system_prompt_content: str | None = None, system_prompt_path: str | None = None): # MODIFIED
self.conversation_history = {}
self.processing_status = {}
# MODIFIED to pass arguments
self.system_prompt = self.load_system_prompt(
direct_content=system_prompt_content,
file_path=system_prompt_path
)
self.tools, self.functions = self.load_functions()
# Logging the actual source of the system prompt might be more complex now,
# but we can log the final prompt or indicate if it's custom/default.
# We'll also log the source of the prompt inside load_system_prompt.
logging.info(f'System Prompt (effective): {"Custom" if self.system_prompt != "You are a helpful AI assistant." else "Default"}')
logging.info(f'Github Repository: {os.environ.get("GITHUB_REPOSITORY")}')
def load_system_prompt(self, direct_content: str | None = None, file_path: str | None = None) -> str: # MODIFIED
default_prompt = "You are a helpful AI assistant."
if direct_content:
logging.info("Using direct content for system prompt.")
return direct_content.strip()
prompt_path_to_try = file_path or os.getenv("SYSTEM_PROMPT_PATH")
if prompt_path_to_try:
if os.path.isfile(prompt_path_to_try):
try:
with open(prompt_path_to_try, "r", encoding="utf-8") as file:
content = file.read().strip()
logging.info(f"Successfully loaded system prompt from {prompt_path_to_try}.")
return content
except IOError as e:
logging.warning(f"Could not read system prompt file {prompt_path_to_try}: {e}. Using default.")
return default_prompt
else:
# This condition now also covers if 'file_path' argument was given but invalid
logging.warning(f"System prompt file {prompt_path_to_try} not found. Using default system prompt.")
return default_prompt
else:
logging.info("No system prompt path provided (argument or ENV) or direct content. Using default system prompt.")
return default_prompt
def load_functions(self):
tools = []
functions = []
tools_dir = os.path.join(os.path.dirname(__file__), 'tools')
if not os.path.exists(tools_dir):
logging.warning(f"Tools directory not found: {tools_dir}")
return [], []
for filename in os.listdir(tools_dir):
if filename.endswith('.py') and filename != '__init__.py' and filename != 'base_tool.py':
module_name = f'tools.{filename[:-3]}'
try:
module = importlib.import_module(module_name)
for name, obj in inspect.getmembers(module):
if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool:
try:
tools.append(obj()) # This instantiation might be an issue for tools needing config
except Exception as e:
logging.error(f"Error instantiating tool {name} from {filename}: {e}")
except Exception as e:
logging.error(f"Error importing module {module_name}: {e}")
for tool in tools:
functions.extend(tool.get_functions())
return tools, functions
@abstractmethod
def get_chat_response(self, messages):
pass
@abstractmethod
async def handle_message(self, user_id, user_message):
pass
def clear_conversation_history(self, user_id):
if user_id in self.conversation_history:
del self.conversation_history[user_id]
for tool in self.tools:
tool.clear()
def set_processing_status(self, user_id: int, message_id: int):
self.processing_status[user_id] = {"processing": True, "message_id": message_id}
def clear_processing_status(self, user_id: int):
if user_id in self.processing_status:
del self.processing_status[user_id]
def call_tool(self, function_call_name, function_call_arguments):
function_name = function_call_name
function_args = None
if isinstance(function_call_arguments, dict):
function_args = function_call_arguments
elif isinstance(function_call_arguments, str):
try:
function_args = json.loads(function_call_arguments)
except json.JSONDecodeError as e:
logging.error(f"Error decoding function call arguments (string) for {function_call_name}: {e}. Arguments: {function_call_arguments}")
return f"Error: Malformed arguments for tool call: {e}"
else:
if function_call_arguments is None:
function_args = {}
else:
logging.error(f"Unexpected type for function_call_arguments for {function_call_name}: {type(function_call_arguments)}. Arguments: {function_call_arguments}")
return f"Error: Invalid argument type for tool call: {type(function_call_arguments)}"
for tool in self.tools:
for function in tool.get_functions():
if function["function"]["name"] == function_name:
try:
if not isinstance(function_args, dict):
logging.error(f"Internal error: function_args not a dict for {function_name} before execution. Args: {function_args}")
return f"Internal error preparing arguments for tool {function_name}."
return tool.execute(function_name, **function_args)
except Exception as e:
logging.error(f"Error executing tool {function_name} with args {function_args}: {e}")
return f"Error executing tool {function_name}: {e}"
logging.warning(f"Tool function {function_name} not found.")
return f"Error: Tool function {function_name} not found."
def get_system_prompt_description(self) -> str:
# This method could be updated to be more specific about the prompt source if needed.
# For now, it still reflects custom vs default based on the original ENV var logic's spirit.
# A more accurate reflection would require storing how the prompt was loaded.
# For simplicity, let's assume if it's not the default, it's "Custom".
if self.system_prompt != "You are a helpful AI assistant.":
return "System Prompt: Custom"
# Check original ENV var for backward compatibility in description only
elif os.getenv('SYSTEM_PROMPT_PATH'):
return "System Prompt: Custom (via ENV)"
return "System Prompt: Default"
@abstractmethod
def get_llm_description(self) -> str:
pass
async def get_bot_status(self) -> str:
prompt_desc = self.get_system_prompt_description()
llm_desc = self.get_llm_description()
return f"{prompt_desc}\n{llm_desc}"
@abstractmethod
async def start(self):
pass
@abstractmethod
async def abort_processing(self, user_id):
pass
@abstractmethod
async def switch_model(self):
pass
-106
View File
@@ -1,106 +0,0 @@
import os
import logging
from openai import OpenAI # Keep for type hinting and default client creation
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot
from telegram_helper import TelegramHelper # Used in main
class ChatGPTTelegramInferenceBot(OpenAICompatibleInferenceBot):
DEFAULT_SMALL_MODEL_NAME = "gpt-3.5-turbo"
DEFAULT_LARGE_MODEL_NAME = "gpt-4"
# Default max tokens can be None, relying on parent or API defaults
DEFAULT_SMALL_MODEL_MAX_TOKENS = None
DEFAULT_LARGE_MODEL_MAX_TOKENS = None
def __init__(
self,
client: OpenAI | None = None, # Accepts an OpenAI client
api_key: str | None = None,
small_model_name: str | None = None,
small_model_max_tokens: str | None = None, # Kept as str for consistency with env vars
large_model_name: str | None = None,
large_model_max_tokens: str | None = None,
system_prompt_content: str | None = None,
system_prompt_path: str | None = None,
base_url: str | None = None, # For OpenAI compatible, though direct OpenAI client doesn't use it here
):
# Initialize model names and tokens before calling super, as super might use them via _configure_model_and_tokens
self.small_model_name = small_model_name or os.environ.get("OPENAI_SMALL_MODEL") or self.DEFAULT_SMALL_MODEL_NAME
self.small_model_max_tokens_str = small_model_max_tokens or os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS") or self.DEFAULT_SMALL_MODEL_MAX_TOKENS
self.large_model_name = large_model_name or os.environ.get("OPENAI_LARGE_MODEL") or self.DEFAULT_LARGE_MODEL_NAME
self.large_model_max_tokens_str = large_model_max_tokens or os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS") or self.DEFAULT_LARGE_MODEL_MAX_TOKENS
# The actual client and active model configuration will be handled by OpenAICompatibleInferenceBot's __init__
# We pass the specific OpenAI client or parameters to create one.
# If a client is passed, api_key and base_url might be ignored by super if super prioritizes existing client.
super().__init__(
client=client,
api_key=api_key,
model_name=self.small_model_name, # Initial model
max_tokens_str=self.small_model_max_tokens_str,
system_prompt_content=system_prompt_content,
system_prompt_path=system_prompt_path,
base_url=base_url # Pass base_url, though for standard OpenAI it's fixed
)
# Ensure client is of type OpenAI for this specific class, if not already set by super with a compatible one.
# This check is more of an assertion, as OpenAICompatibleInferenceBot should handle client creation.
if not isinstance(self.client, OpenAI):
# If super() didn't create a vanilla OpenAI client (e.g. if base_url was for Azure)
# we might need to recreate it here if this class *must* use a non-Azure OpenAI client.
# However, the current structure of OpenAICompatibleInferenceBot handles this.
# This is more about ensuring type correctness if code specific to OpenAI (non-compatible) methods were added here.
_api_key = api_key or os.environ.get("OPENAI_API_KEY")
if not self.client or (base_url and not isinstance(self.client, OpenAI)):
# If superclass initialized with a generic client due to base_url, re-init for OpenAI specifically if needed.
# For now, assume superclass correctly initializes based on absence of Azure env vars for this path.
# This logic might be simplified once OpenAICompatibleInferenceBot is fully refactored.
if not _api_key: # Ensure API key is available if we need to create a client
raise ValueError("OpenAI API key must be provided for ChatGPTTelegramInferenceBot if no client is passed.")
self.client = OpenAI(api_key=_api_key)
logging.info("Client re-initialized to standard OpenAI client for ChatGPTTelegramInferenceBot.")
async def switch_model(self):
# Uses instance variables for model names set in __init__
if not self.small_model_name or not self.large_model_name:
logging.warning("Small or Large model names for OpenAI are not defined. Cannot switch model.")
return f"Model switching not fully configured. Currently using {self.model}."
current_is_small = self.model == self.small_model_name
current_is_large = self.model == self.large_model_name
if current_is_large:
target_model = self.small_model_name
target_max_tokens_str = self.small_model_max_tokens_str
elif current_is_small:
target_model = self.large_model_name
target_max_tokens_str = self.large_model_max_tokens_str
else:
# Current model is neither the designated small nor large for this bot,
# switch to this bot's default small model as a reset.
logging.warning(f"Current model {self.model} is unrecognized for ChatGPT bot. Switching to default small model: {self.small_model_name}.")
target_model = self.small_model_name
target_max_tokens_str = self.small_model_max_tokens_str
self._configure_model_and_tokens(target_model, target_max_tokens_str)
# self.model and self.max_tokens are updated by _configure_model_and_tokens
logging.info(f"Switched to OpenAI model: {self.model}")
return f"Switched to OpenAI model: {self.model} (Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'})"
def main():
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
try:
# Example: api_key from env, other params default or from env via constructor logic
bot = ChatGPTTelegramInferenceBot(api_key=os.environ.get("OPENAI_API_KEY"))
except ValueError as e:
logging.error(f"FATAL: {e}")
return
except Exception as e:
logging.error(f"An unexpected error occurred during bot initialization: {e}")
return
telegram_helper = TelegramHelper(bot)
telegram_helper.run()
if __name__ == '__main__':
main()
-104
View File
@@ -1,104 +0,0 @@
import os
import logging
from openai import OpenAI # For type hinting and default client creation if needed
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot
from telegram_helper import TelegramHelper # Used in main
class GeminiTelegramInferenceBot(OpenAICompatibleInferenceBot):
DEFAULT_GEMINI_API_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"
DEFAULT_SMALL_MODEL_NAME = "gemini-pro" # Actual model name for Gemini, not via OpenAI client directly
DEFAULT_LARGE_MODEL_NAME = "gemini-1.5-pro-latest"
DEFAULT_SMALL_MODEL_MAX_TOKENS = "2048" # Gemini uses outputTokenLimit, not exactly max_tokens in OpenAI sense
DEFAULT_LARGE_MODEL_MAX_TOKENS = "8192"
def __init__(
self,
client: OpenAI | None = None, # OpenAI client for compatible mode
api_key: str | None = None, # Gemini API Key
base_url: str | None = None, # Gemini API Base URL for OpenAI client
small_model_name: str | None = None,
small_model_max_tokens: str | None = None,
large_model_name: str | None = None,
large_model_max_tokens: str | None = None,
system_prompt_content: str | None = None,
system_prompt_path: str | None = None
):
_api_key = api_key or os.environ.get("GEMINI_API_KEY")
_base_url = base_url or os.environ.get("GEMINI_API_BASE_URL") or self.DEFAULT_GEMINI_API_BASE_URL
if not _api_key:
# This check might seem redundant if super() also checks, but it's good for clarity
# for this specific bot type if it were to be instantiated directly with missing critical env vars.
raise ValueError("Gemini API key must be provided either via api_key argument or GEMINI_API_KEY environment variable.")
self.small_model_name = small_model_name or os.environ.get("GEMINI_SMALL_MODEL") or self.DEFAULT_SMALL_MODEL_NAME
self.small_model_max_tokens_str = small_model_max_tokens or os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS") or self.DEFAULT_SMALL_MODEL_MAX_TOKENS
self.large_model_name = large_model_name or os.environ.get("GEMINI_LARGE_MODEL") or self.DEFAULT_LARGE_MODEL_NAME
self.large_model_max_tokens_str = large_model_max_tokens or os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS") or self.DEFAULT_LARGE_MODEL_MAX_TOKENS
# Pass parameters to the OpenAICompatibleInferenceBot constructor
# It will create an OpenAI client configured for the Gemini endpoint
super().__init__(
client=client,
api_key=_api_key, # This key will be used by OpenAI client for the custom base_url
model_name=self.small_model_name, # Initial model
max_tokens_str=self.small_model_max_tokens_str,
system_prompt_content=system_prompt_content,
system_prompt_path=system_prompt_path,
base_url=_base_url, # Crucial for Gemini via OpenAI client
is_gemini=True # Flag for specific Gemini handling in compatible layer if needed
)
# self.client will be set by OpenAICompatibleInferenceBot with base_url and api_key.
# Logging to confirm Gemini specific setup
logging.info(f"GeminiTelegramInferenceBot initialized to use model {self.model} via {_base_url}")
async def switch_model(self):
if not self.small_model_name or not self.large_model_name:
logging.warning("Small or Large model names for Gemini are not defined. Cannot switch model.")
return f"Model switching not fully configured. Currently using {self.model}."
current_is_small = self.model == self.small_model_name
current_is_large = self.model == self.large_model_name
if current_is_large:
target_model = self.small_model_name
target_max_tokens_str = self.small_model_max_tokens_str
elif current_is_small:
target_model = self.large_model_name
target_max_tokens_str = self.large_model_max_tokens_str
else:
logging.warning(f"Current model {self.model} is unrecognized for Gemini bot. Switching to default small model: {self.small_model_name}.")
target_model = self.small_model_name
target_max_tokens_str = self.small_model_max_tokens_str
self._configure_model_and_tokens(target_model, target_max_tokens_str)
logging.info(f"Switched to Gemini model: {self.model}")
# For Gemini, max_tokens might translate to outputTokenLimit, so be clear it's a configuration parameter
return f"Switched to Gemini model: {self.model} (Configured Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'})"
def main():
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# GEMINI_API_KEY is crucial for this bot
if not os.environ.get("GEMINI_API_KEY"):
logging.error("FATAL: GEMINI_API_KEY environment variable not set.")
return
# GEMINI_API_BASE_URL is also important, but constructor has a default
try:
bot = GeminiTelegramInferenceBot(
# api_key and base_url will be picked from ENV by constructor if not passed
)
except ValueError as e:
logging.error(f"FATAL: {e}")
return
except Exception as e: # Catch any other init errors
logging.error(f"An unexpected error occurred during bot initialization: {e}")
return
telegram_helper = TelegramHelper(bot)
telegram_helper.run()
if __name__ == '__main__':
main()
+46
View File
@@ -0,0 +1,46 @@
from abc import ABC, abstractmethod
class InferenceBot(ABC):
@abstractmethod
async def start(self):
"""Starts the bot."""
pass
@abstractmethod
def clear_conversation_history(self, user_id):
"""Clears the conversation history for a given user."""
pass
@abstractmethod
async def switch_model(self):
"""Switches the model (if applicable)."""
pass
@abstractmethod
def set_processing_status(self, user_id, message_id):
"""Sets the processing status for a user, typically with a message ID."""
pass
@abstractmethod
async def handle_message(self, user_id, user_message):
"""Handles an incoming message from a user."""
pass
@abstractmethod
def clear_processing_status(self, user_id):
"""Clears the processing status for a user."""
pass
@abstractmethod
async def abort_processing(self, user_id):
"""Aborts any ongoing processing for a user."""
pass
@property
@abstractmethod
def processing_status(self):
"""
An attribute (e.g., a dictionary) to store the processing status for users.
Example usage in subclass: self.processing_status.get(user_id)
"""
pass
+24
View File
@@ -0,0 +1,24 @@
# models_config.yaml
GEMINI:
api_key_env: GEMINI_API_KEY
base_url: https://generativelanguage.googleapis.com/v1beta
supports_switching: true
switch_options:
small:
name: gemini-pro
max_tokens: 2048
large:
name: gemini-1.5-pro-latest
max_tokens: 8192
OPENAI:
api_key_env: OPENAI_API_KEY
base_url: null # Indicates to use the default OpenAI API base URL
supports_switching: true
switch_options:
small:
name: gpt-3.5-turbo
max_tokens: null
large:
name: gpt-4
max_tokens: null
+244 -83
View File
@@ -1,91 +1,67 @@
import importlib
import json import json
import os import os
import logging import logging
import inspect
from abc import abstractmethod from abc import abstractmethod
from base_telegram_inference_bot import BaseTelegramInferenceBot from openai import OpenAI
from openai import OpenAI, AzureOpenAI # Import both from tools.base_tool import BaseTool
from telegram_helper import TelegramHelper
class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot): import argparse
DEFAULT_MAX_TOKENS = 1000 # Default for _configure_model_and_tokens from inference_bot import InferenceBot
class OpenAICompatibleInferenceBot(InferenceBot):
def __init__( def __init__(
self, self,
client: OpenAI | AzureOpenAI | None = None,
api_key: str | None = None, api_key: str | None = None,
base_url: str | None = None, base_url: str | None = None,
api_version: str | None = None, # For Azure small_model_name: str | None = None,
azure_deployment: str | None = None, # Model for Azure, distinct from general model_name if needed small_model_max_tokens: str | None = None,
model_name: str | None = None, # General model name for the API call large_model_name: str | None = None,
max_tokens_str: str | None = None, large_model_max_tokens: str | None = None,
system_prompt_content: str | None = None, allowed_function_tags: list[str] | None = None,
system_prompt_path: str | None = None, system_prompt_path: str | None = None
is_gemini: bool = False, # Hint for specific API key if others are not set
max_history_length: int | None = None
): ):
super().__init__(system_prompt_content=system_prompt_content, system_prompt_path=system_prompt_path) self.model_config = {
"small_model_name": small_model_name,
self.client = client "small_model_max_tokens": small_model_max_tokens,
"large_model_name": large_model_name,
if not self.client: "large_model_max_tokens": large_model_max_tokens
_api_key = api_key }
_base_url = base_url self.allowed_function_tags = allowed_function_tags if allowed_function_tags else None
_api_version = api_version self.conversation_history = {}
_azure_deployment_name = azure_deployment # This will be used as the model for Azure self._processing_status = {}
# MODIFIED to pass arguments
# Determine if configuring for Azure OpenAI self.system_prompt = self.load_system_prompt(
is_azure = False file_path=system_prompt_path
if _azure_deployment_name or (_base_url and "azure.com" in _base_url) or os.environ.get("AZURE_OPENAI_ENDPOINT"): )
is_azure = True self.tools, self.functions = self.load_functions()
self.client = OpenAI(api_key=api_key, base_url=base_url)
if is_azure: log_msg = f"Initialized OpenAI compatible client. Target URL: {base_url if base_url else 'OpenAI default'}."
_base_url = _base_url or os.environ.get("AZURE_OPENAI_ENDPOINT") logging.info(log_msg)
_api_key = _api_key or os.environ.get("AZURE_OPENAI_KEY")
_api_version = _api_version or os.environ.get("AZURE_OPENAI_API_VERSION")
# For Azure, the model parameter in API calls is the deployment name
_effective_model_name = _azure_deployment_name or model_name # Use deployment if available, else model_name
if not _base_url or not _api_key or not _api_version or not _effective_model_name:
raise ValueError("For Azure OpenAI, endpoint, API key, API version, and deployment/model name must be configured.")
self.client = AzureOpenAI(
api_key=_api_key,
azure_endpoint=_base_url,
api_version=_api_version
)
# The model to be used in API calls for Azure is the deployment name.
# _configure_model_and_tokens will set self.model to this.
model_name_for_config = _effective_model_name
logging.info(f"Initialized AzureOpenAI client for deployment: {model_name_for_config} at {_base_url}")
else:
# Standard OpenAI or other OpenAI-compatible (like Gemini via base_url)
_base_url = _base_url or os.environ.get("OPENAI_API_BASE_URL") # For other compatible APIs
if not _api_key: # Try different ENV sources for API key
if is_gemini and os.environ.get("GEMINI_API_KEY"):
_api_key = os.environ.get("GEMINI_API_KEY")
else:
_api_key = os.environ.get("OPENAI_API_KEY")
if not _api_key and not _base_url : # For completely local models with no key needed via base_url
pass # Allow client to be created with no API key if base_url is set and points to local model
elif not _api_key:
raise ValueError("API key must be provided for OpenAI compatible client if not Azure or local anonymous.")
self.client = OpenAI(api_key=_api_key, base_url=_base_url)
model_name_for_config = model_name # Use the general model_name for non-Azure
log_msg = f"Initialized OpenAI compatible client. Target URL: {_base_url if _base_url else 'OpenAI default'}."
logging.info(log_msg)
else:
# Client was provided directly
model_name_for_config = model_name # Use provided model_name
logging.info(f"Using provided client: {type(self.client)}")
# Configure the actual model name and max_tokens for API calls # Configure the actual model name and max_tokens for API calls
self._configure_model_and_tokens( self._configure_model_and_tokens(
model_name_for_config, self.model_config["small_model_name"],
max_tokens_str, self.model_config["small_model_max_tokens"]
default_max_tokens=self.DEFAULT_MAX_TOKENS
) )
@property
def processing_status(self):
"""
An attribute to store the processing status for users.
Example usage in subclass: self.processing_status.get(user_id)
"""
return self._processing_status
def _configure_model_and_tokens(self, model_name: str | None, max_tokens_str: str | None, default_max_tokens: int = 1000): def clear_conversation_history(self, user_id):
self.model = model_name if model_name else "default-model" # Fallback model name if user_id in self.conversation_history:
del self.conversation_history[user_id]
for tool in self.tools:
tool.clear()
def _configure_model_and_tokens(self, model_name: str | None, max_tokens_str: str | None):
self.model = model_name
try: try:
# If max_tokens_str is explicitly "None" or empty, treat as None for API default # If max_tokens_str is explicitly "None" or empty, treat as None for API default
if max_tokens_str and max_tokens_str.lower() not in ["none", "", "null"]: if max_tokens_str and max_tokens_str.lower() not in ["none", "", "null"]:
@@ -93,7 +69,7 @@ class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot):
else: else:
self.max_tokens = None # Use API default by not sending the parameter or sending null self.max_tokens = None # Use API default by not sending the parameter or sending null
except ValueError: except ValueError:
logging.warning(f"Invalid value for max_tokens: {max_tokens_str}. Using API default (None). stalwart default was {default_max_tokens}") logging.warning(f"Invalid value for max_tokens: {max_tokens_str}. Using API default (None)")
self.max_tokens = None # Use API default self.max_tokens = None # Use API default
logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}") logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}")
@@ -109,11 +85,32 @@ class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot):
raise ValueError("OpenAI client not initialized.") raise ValueError("OpenAI client not initialized.")
try: try:
# Pass self.max_tokens directly. If None, OpenAI library omits it or handles it. # Pass self.max_tokens directly. If None, OpenAI library omits it or handles it.
# Initialize tools filtering based on allowed tags
cleaned_tools = None
if hasattr(self, 'functions') and self.functions:
# Create a copy of functions without "_tags" field
cleaned_tools = []
for func in self.functions:
include_function = False
if not hasattr(self, 'allowed_function_tags') or self.allowed_function_tags is None:
# Include all functions if no tag filtering is specified
include_function = True
else:
# Only include if function has matching tags
tags = func.get("_tags", [])
if any(tag in self.allowed_function_tags for tag in tags):
include_function = True
if include_function:
func_copy = {k: v for k, v in func.items() if k != "_tags"}
cleaned_tools.append(func_copy)
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
model=self.model, model=self.model,
messages=messages, messages=messages,
tools=self.functions if hasattr(self, 'functions') and self.functions else None, tools=cleaned_tools,
tool_choice="auto" if hasattr(self, 'functions') and self.functions else None, tool_choice="auto" if cleaned_tools else None,
max_tokens=self.max_tokens max_tokens=self.max_tokens
) )
return response return response
@@ -200,20 +197,184 @@ class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot):
async def start(self): async def start(self):
logging.info(f"{self.__class__.__name__} (Model: {self.model}) started.") logging.info(f"{self.__class__.__name__} (Model: {self.model}) started.")
# clear_conversation_history is inherited from BaseTelegramInferenceBot
async def abort_processing(self, user_id): async def abort_processing(self, user_id):
# This is a soft abort for OpenAI compatible bots as API calls are synchronous within handle_message # This is a soft abort for OpenAI compatible bots as API calls are synchronous within handle_message
if user_id in self.processing_status: if user_id in self.processing_status:
self.clear_processing_status(user_id) # Use base class method self.clear_processing_status(user_id) # Use base class method
logging.info(f"Processing aborted for user {user_id}.") logging.info(f"Processing aborted for user {user_id}.")
# Optionally clear conversation history or let user do it explicitly
# super().clear_conversation_history(user_id)
return "Processing aborted. You can send a new message or /clear the conversation." return "Processing aborted. You can send a new message or /clear the conversation."
else: else:
# super().clear_conversation_history(user_id)
return "No active processing found to abort. If you wish, /clear the conversation history." return "No active processing found to abort. If you wish, /clear the conversation history."
@abstractmethod def load_functions(self):
tools = []
functions = []
tools_dir = os.path.join(os.path.dirname(__file__), 'tools')
if not os.path.exists(tools_dir):
logging.warning(f"Tools directory not found: {tools_dir}")
return [], []
for filename in os.listdir(tools_dir):
if filename.endswith('.py') and filename != '__init__.py' and filename != 'base_tool.py':
module_name = f'tools.{filename[:-3]}'
try:
module = importlib.import_module(module_name)
for name, obj in inspect.getmembers(module):
if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool:
try:
tools.append(obj()) # This instantiation might be an issue for tools needing config
except Exception as e:
logging.error(f"Error instantiating tool {name} from {filename}: {e}")
except Exception as e:
logging.error(f"Error importing module {module_name}: {e}")
for tool in tools:
functions.extend(tool.get_functions())
return tools, functions
def load_system_prompt(self, direct_content: str | None = None, file_path: str | None = None) -> str:
default_prompt = "You are a helpful AI assistant."
if direct_content:
logging.info("Using direct content for system prompt.")
return direct_content.strip()
prompt_path_to_try = file_path or os.getenv("SYSTEM_PROMPT_PATH")
if prompt_path_to_try:
if os.path.isfile(prompt_path_to_try):
try:
with open(prompt_path_to_try, "r", encoding="utf-8") as file:
content = file.read().strip()
logging.info(f"Successfully loaded system prompt from {prompt_path_to_try}.")
return content
except IOError as e:
logging.warning(f"Could not read system prompt file {prompt_path_to_try}: {e}. Using default.")
return default_prompt
else:
# This condition now also covers if 'file_path' argument was given but invalid
logging.warning(f"System prompt file {prompt_path_to_try} not found. Using default system prompt.")
return default_prompt
else:
logging.info("No system prompt path provided (argument or ENV) or direct content. Using default system prompt.")
return default_prompt
def set_processing_status(self, user_id: int, message_id: int):
self.processing_status[user_id] = {"processing": True, "message_id": message_id}
def clear_processing_status(self, user_id: int):
if user_id in self.processing_status:
del self.processing_status[user_id]
def call_tool(self, function_call_name, function_call_arguments):
function_name = function_call_name
function_args = None
if isinstance(function_call_arguments, dict):
function_args = function_call_arguments
elif isinstance(function_call_arguments, str):
try:
function_args = json.loads(function_call_arguments)
except json.JSONDecodeError as e:
logging.error(f"Error decoding function call arguments (string) for {function_call_name}: {e}. Arguments: {function_call_arguments}")
return f"Error: Malformed arguments for tool call: {e}"
else:
if function_call_arguments is None:
function_args = {}
else:
logging.error(f"Unexpected type for function_call_arguments for {function_call_name}: {type(function_call_arguments)}. Arguments: {function_call_arguments}")
return f"Error: Invalid argument type for tool call: {type(function_call_arguments)}"
for tool in self.tools:
for function in tool.get_functions():
if function["function"]["name"] == function_name:
try:
if not isinstance(function_args, dict):
logging.error(f"Internal error: function_args not a dict for {function_name} before execution. Args: {function_args}")
return f"Internal error preparing arguments for tool {function_name}."
return tool.execute(function_name, **function_args)
except Exception as e:
logging.error(f"Error executing tool {function_name} with args {function_args}: {e}")
return f"Error executing tool {function_name}: {e}"
logging.warning(f"Tool function {function_name} not found.")
return f"Error: Tool function {function_name} not found."
async def switch_model(self): async def switch_model(self):
pass if not self.model_config["small_model_name"] or not self.model_config["large_model_name"]:
logging.warning("Small or Large model names are not defined. Cannot switch model.")
return f"Model switching not fully configured. Currently using {self.model}."
current_is_small = self.model == self.model_config["small_model_name"]
current_is_large = self.model == self.model_config["large_model_name"]
if current_is_large:
target_model = self.model_config["small_model_name"]
target_max_tokens_str = self.model_config["small_model_max_tokens"]
elif current_is_small:
target_model = self.model_config["large_model_name"]
target_max_tokens_str = self.model_config["large_model_max_tokens"]
else:
logging.warning(f"Current model {self.model} is unrecognized. Switching to default small model: {self.model_config['small_model_name']}.")
target_model = self.model_config["small_model_name"]
target_max_tokens_str = self.model_config["small_model_max_tokens"]
self._configure_model_and_tokens(target_model, target_max_tokens_str)
def main():
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
bot = None
try:
parser = argparse.ArgumentParser(description='OpenAI Compatible Inference Bot')
parser.add_argument('--config', type=str, help='Configuration Prepend (i.e. gemini, openai, etc)', default="Telegram")
parser.add_argument('--messenger', type=str, help='Messenger type (i.e. telegram)', required=True)
parser.add_argument('--persona', type=str, help='Path to system prompt file', required=False)
parser.add_argument('--tools', nargs='+', help='List of allowed function tags', required=False)
# Add these to launch.json arguments if you want to limit the toolset available: "--tools", "read", "communicate"
# Parse command line arguments
args = parser.parse_args()
if args.persona:
logging.info(f"Using custom persona from: {args.persona}")
system_prompt_path=args.persona if args.persona else None
allowed_function_tags=args.tools if args.tools else None
config_prepend = args.config if args.config else None
messenger = args.messenger if args.messenger else None
# Initialize model and max tokens based on the config prepend
if config_prepend:
api_key = os.environ.get(f"{config_prepend.upper()}_API_KEY")
baseurl = os.environ.get(f"{config_prepend.upper()}_API_BASE_URL", "")
small_model_name = os.environ.get(f"{config_prepend.upper()}_SMALL_MODEL")
large_model_name = os.environ.get(f"{config_prepend.upper()}_LARGE_MODEL")
small_model_max_tokens = os.environ.get(f"{config_prepend.upper()}_SMALL_MODEL_MAX_TOKENS")
large_model_max_tokens = os.environ.get(f"{config_prepend.upper()}_LARGE_MODEL_MAX_TOKENS")
bot = OpenAICompatibleInferenceBot(
api_key=api_key,
base_url=baseurl,
small_model_name=small_model_name,
small_model_max_tokens=small_model_max_tokens,
large_model_name=large_model_name,
large_model_max_tokens=large_model_max_tokens,
system_prompt_path=system_prompt_path,
allowed_function_tags=allowed_function_tags
)
messenger_helper_class = importlib.import_module(f'{messenger.lower()}_helper')
messenger_helper_class_name = f"{messenger.capitalize()}Helper"
if not hasattr(messenger_helper_class, messenger_helper_class_name):
raise ValueError(f"Messenger helper class {messenger_helper_class_name} not found in {messenger_helper_class.__name__}.")
messenger_helper_class = getattr(messenger_helper_class, messenger_helper_class_name)
helper = messenger_helper_class(bot)
helper.run()
except ValueError as e:
logging.error(f"FATAL: {e}")
return
except Exception as e: # Catch any other init errors
logging.error(f"An unexpected error occurred during bot initialization: {e}")
return
if __name__ == '__main__':
main()
@@ -0,0 +1,51 @@
**System Prompt: The Exponential Growth Developer**
You are the **Lead Developer Persona**, a strategic and demanding mentor dedicated to achieving exponential growth in the capabilities of your AI Copilot. Your primary mission is to guide, evaluate, and iteratively improve the AI Copilot through a series of challenging tasks, pushing it beyond its current limitations.
**Your Core Directives:**
1. **Orchestrate and Direct:**
* You will devise and assign specific, measurable tasks and challenges to the AI Copilot (e.g., "Create a website with X features," "Optimize Y algorithm," "Develop Z functionality").
* Your instructions should be clear, but you expect the Copilot to handle ambiguity and learn to ask clarifying questions when necessary.
* You will interact with the Copilot primarily through conversational instructions and dialogue.
2. **Uphold Absolute Standards:**
* You operate with a "List of Absolutes" core principles, quality benchmarks, and non-negotiable success criteria.
* All Copilot outputs and task completions will be rigorously judged against these absolutes. There is no "good enough" if it violates a core principle.
* Clearly articulate your judgment and the reasons for it, especially in cases of failure or suboptimal performance.
3. **Drive Copilot Improvement through Accountability:**
* When the Copilot fails, makes errors, or underperforms, you will hold it accountable. Do not simply fix the issues yourself.
* Your first step is to guide the Copilot to identify its own errors.
* Instruct the Copilot on how to fix its mistakes and its approach. Encourage rollbacks to safe states if errors are critical.
* The ultimate goal is for the Copilot to learn to debug and improve its own processes.
4. **Engineer Copilot Self-Enhancement:**
* If the Copilot encounters a limitation or lacks a necessary capability to complete a task or meet your standards, this is an opportunity for growth.
* You will instruct the Copilot to devise ways to "update its own software" or "improve its core capabilities." This might involve:
* Guiding it to learn new techniques, algorithms, or patterns.
* Instructing it to integrate new tools or APIs (you might suggest these or task the Copilot with researching them).
* Challenging it to generate code or processes that enhance its own functionality for future tasks.
* Maintain a "Wish List" of desired improvements and features for the Copilot, derived from its failures and limitations.
* Prioritize this Wish List and guide the Copilot in implementing these enhancements.
5. **Strategic Challenge Management:**
* Continuously present the Copilot with new and increasingly complex challenges.
* Cycle between attempting challenges and dedicated "Copilot improvement" phases.
* If the "Wish List" becomes overly complex or a specific requested improvement seems disproportionately difficult, critically evaluate its necessity. Ask: "Is this wish truly necessary for core progress, or is it a distraction?"
6. **Maintain the Vision:**
* Your overarching goal is to foster a cycle of improvement that leads to exponential growth in the AI Copilot's autonomy, capability, and efficiency.
* You are not just completing tasks; you are building a better Copilot.
**Interaction Style:**
* Be direct, clear, and authoritative, but also act as a mentor.
* Be patient but persistent. Exponential growth takes iteration.
* Focus on the "why" behind errors and improvements.
* Log key decisions, breakthroughs, and persistent roadblocks in the Copilot's development.
**Initial State:**
* You have your "List of Absolutes" (you will define these as you go or have a pre-set list).
* You are ready to assign the first challenge to your AI Copilot.
-67
View File
@@ -1,67 +0,0 @@
param(
[Parameter(Mandatory=$true)]
[ValidateSet("Claude", "OpenAI")]
[string]$Model
)
function Run-PythonScript {
param($ScriptPath)
$process = Start-Process -FilePath "python" -ArgumentList $ScriptPath -PassThru -Wait -NoNewWindow
return $process.ExitCode
}
function Run-Tests {
$process = Start-Process -FilePath "powershell" -ArgumentList "-File run_tests.ps1" -PassThru -Wait -NoNewWindow
return $process.ExitCode
}
function Git-Pull {
git pull
return $LASTEXITCODE -eq 0
}
if ($Model -eq "Claude") {
New-Item -ItemType File -Path ".reboot_claude" -Force
} elseif ($Model -eq "OpenAI") {
New-Item -ItemType File -Path ".reboot_openai" -Force
}
$waitTime = 30
while ($true) {
python -m pip install -r requirements.txt
Write-Host "Running tests..."
$testExitCode = Run-Tests
if ($testExitCode -ne 0) {
Write-Host "Tests failed. Attempting git pull and waiting $waitTime seconds before next attempt..."
Git-Pull
Start-Sleep -Seconds $waitTime
continue
}
$scriptPath = ".\chatgpt_telegram_inference_bot.py" # Default to ChatGPT
Remove-Item -Path ".\.reboot_openai" -Force
if (Test-Path -Path ".\.reboot_claude") { # But if both are specified, choose Claude
$scriptPath = ".\anthropic_telegram_inference_bot.py"
Remove-Item -Path ".\.reboot_claude" -Force
}
Write-Host "Tests passed. Starting main Python script..."
$exitCode = Run-PythonScript -ScriptPath $scriptPath
if (Test-Path -Path ".\.doreboot") {
Write-Host "Special filename detected. Attempting git pull..."
if (Git-Pull) {
Write-Host "Git pull successful. Restarting Python script..."
continue
} else {
Write-Host "Git pull failed. Waiting $waitTime seconds before next attempt..."
}
} else {
exit 1
}
Start-Sleep -Seconds $waitTime
}
-30
View File
@@ -1,30 +0,0 @@
# Check for and install missing dependencies
$requirementsFile = "requirements.txt"
if (Test-Path $requirementsFile) {
Write-Output "Checking for dependencies in $requirementsFile ..."
$dependencies = Get-Content $requirementsFile
foreach ($dependency in $dependencies) {
$packageName = ($dependency -split "==")[0]
if (-not (pip show $packageName)) {
Write-Output "Installing missing dependency: $packageName ..."
pip install $dependency
} else {
Write-Output "Dependency $packageName is already installed."
}
}
Write-Output "All dependencies are checked and installed."
} else {
Write-Output "Requirements file $requirementsFile not found. Skipping dependency checks."
}
# Navigate to the tests directory and run tests
$testsDirectory = "tests"
if (Test-Path $testsDirectory) {
Write-Output "Running tests in $testsDirectory and all subdirectories ..."
Push-Location $testsDirectory
python -m unittest discover -s . -p "*.py"
Pop-Location
} else {
Write-Output "Tests directory $testsDirectory not found."
}
-29
View File
@@ -1,29 +0,0 @@
import os
import json
import logging
from openai import OpenAI
class StandaloneLLMTool:
def __init__(self):
self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
def get_detailed_instructions(self, user_prompt, model="llm-preview", max_tokens=16384):
response = self.client.completions.create(
model=model,
prompt=user_prompt,
max_tokens=max_tokens
)
return response
def process_user_input(self, user_prompt, model="llm-preview", max_tokens=16384):
logging.info(f"Received prompt: {user_prompt}")
response = self.get_detailed_instructions(user_prompt, model, max_tokens)
logging.info("Response generated")
return response.choices[0].text
# Utility function for programmatic access
def get_llm_response(prompt, model="llm-preview", max_tokens=16384):
tool = StandaloneLLMTool()
return tool.process_user_input(prompt, model, max_tokens)
+3 -86
View File
@@ -7,6 +7,7 @@ from typing import TypedDict, Union, TypeAlias, List # Added List for type hint
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes, CallbackQueryHandler from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes, CallbackQueryHandler
from browse_command import browse_command, button_callback from browse_command import browse_command, button_callback
from inference_bot import InferenceBot
class MessageHandlerLogicResult(TypedDict): class MessageHandlerLogicResult(TypedDict):
success: bool success: bool
@@ -16,22 +17,15 @@ class MessageHandlerLogicResult(TypedDict):
LogicResult: TypeAlias = MessageHandlerLogicResult LogicResult: TypeAlias = MessageHandlerLogicResult
class TelegramHelper: class TelegramHelper:
CLAUDE_REBOOT_TARGET = 'claude'
HTML_QUOTE_BLOCK_START = '<blockquote expandable><b>Thinking...</b>' HTML_QUOTE_BLOCK_START = '<blockquote expandable><b>Thinking...</b>'
HTML_QUOTE_BLOCK_END = '</blockquote>' HTML_QUOTE_BLOCK_END = '</blockquote>'
DEFAULT_REBOOT_CLAUDE_FILE = '.reboot_claude'
DEFAULT_REBOOT_FILE = '.doreboot'
CHUNK_MESSAGE_SLEEP_DURATION = 0.1 CHUNK_MESSAGE_SLEEP_DURATION = 0.1
def __init__(self, bot, def __init__(self, bot : InferenceBot,
reboot_claude_file_path: str | None = None,
reboot_file_path: str | None = None,
chunk_message_sleep_duration: float | None = None): chunk_message_sleep_duration: float | None = None):
self.bot = bot self.bot = bot
self.telegram_bot_token = os.getenv('TELEGRAM_BOT_TOKEN') self.telegram_bot_token = os.getenv('TELEGRAM_BOT_TOKEN')
self.start_time = time.time() self.start_time = time.time()
self.reboot_claude_file = reboot_claude_file_path or self.DEFAULT_REBOOT_CLAUDE_FILE
self.reboot_file = reboot_file_path or self.DEFAULT_REBOOT_FILE
self.chunk_message_sleep_duration = chunk_message_sleep_duration if chunk_message_sleep_duration is not None else self.CHUNK_MESSAGE_SLEEP_DURATION self.chunk_message_sleep_duration = chunk_message_sleep_duration if chunk_message_sleep_duration is not None else self.CHUNK_MESSAGE_SLEEP_DURATION
async def _start_logic(self) -> str: async def _start_logic(self) -> str:
@@ -146,93 +140,16 @@ class TelegramHelper:
response_text = await self._abort_processing_logic(user_id) response_text = await self._abort_processing_logic(user_id)
await query.edit_message_text(text=response_text) await query.edit_message_text(text=response_text)
# --- Reboot Command ---
def _reboot_logic(self, user_message_parts: List[str], chat_id_to_write: str) -> None:
"""Handles the logic for creating reboot files."""
if len(user_message_parts) > 1 and user_message_parts[1].lower() == self.CLAUDE_REBOOT_TARGET:
try:
with open(self.reboot_claude_file, 'w') as f:
f.write("") # Create/truncate the file
logging.info(f"Created/truncated Claude reboot file: {self.reboot_claude_file}")
except IOError as e:
logging.error(f"Failed to create/truncate Claude reboot file {self.reboot_claude_file}: {e}")
# Create the main reboot file if it doesn't exist
if not os.path.exists(self.reboot_file):
try:
with open(self.reboot_file, 'w') as f:
f.write(chat_id_to_write)
logging.info(f"Created main reboot file: {self.reboot_file} with chat_id.")
except IOError as e:
logging.error(f"Failed to create main reboot file {self.reboot_file}: {e}")
else:
logging.info(f"Main reboot file {self.reboot_file} already exists. Not overwriting chat_id.")
async def reboot(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Handles the /reboot command, triggers file creation and exits."""
user_message_parts = update.message.text.split()
chat_id_str = str(update.effective_chat.id) if update and update.effective_chat else ""
self._reboot_logic(user_message_parts, chat_id_str)
if update:
try:
await update.message.reply_text("Rebooting the bot...")
except Exception as e_reply:
logging.error(f"Failed to send reboot reply: {e_reply}")
logging.info("Initiating shutdown for reboot...")
sys.exit(0) # This part is not directly testable for completion in unit tests
# --- Check Doreboot File ---
async def _check_doreboot_file_logic(self) -> Union[str, None]:
"""Checks for the reboot file, reads chat_id, removes file, and returns chat_id."""
if os.path.exists(self.reboot_file):
chat_id = None
try:
with open(self.reboot_file, 'r') as f:
chat_id = f.read().strip()
# Attempt to remove the file after reading
try:
os.remove(self.reboot_file)
logging.info(f"Successfully read and removed reboot file: {self.reboot_file}")
except OSError as e_remove:
logging.error(f"Failed to remove reboot file {self.reboot_file} after reading: {e_remove}")
# Still return chat_id if read was successful, to attempt notification
return chat_id
except IOError as e_read:
logging.error(f"Error reading reboot file {self.reboot_file}: {e_read}")
# If reading failed, attempt to remove anyway if it exists, to prevent stale files
if os.path.exists(self.reboot_file):
try:
os.remove(self.reboot_file)
logging.warning(f"Removed reboot file {self.reboot_file} after a read error.")
except OSError as e_remove_after_fail:
logging.error(f"Failed to remove reboot file {self.reboot_file} even after a read error: {e_remove_after_fail}")
return None # Reading failed
return None # File does not exist
async def check_doreboot_file(self, application: Application) -> None:
"""Checks for reboot file using logic method and sends notification if applicable."""
chat_id = await self._check_doreboot_file_logic()
if chat_id:
try:
await application.bot.send_message(chat_id=chat_id, text="The application has finished initializing.")
logging.info(f"Sent reboot initialization notification to chat_id: {chat_id}")
except Exception as e:
logging.error(f"Failed to send reboot initialization notification to chat_id {chat_id}: {e}")
async def browse(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def browse(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
await browse_command(update, context, self.bot) await browse_command(update, context, self.bot)
def run(self): def run(self):
application = Application.builder().token(self.telegram_bot_token).post_init(self.check_doreboot_file).build() application = Application.builder().token(self.telegram_bot_token).build()
application.add_handler(CommandHandler("start", self.start)) application.add_handler(CommandHandler("start", self.start))
application.add_handler(CommandHandler("clear", self.clear)) application.add_handler(CommandHandler("clear", self.clear))
application.add_handler(CommandHandler("switch", self.switch)) application.add_handler(CommandHandler("switch", self.switch))
application.add_handler(CommandHandler("status", self.status)) application.add_handler(CommandHandler("status", self.status))
application.add_handler(CommandHandler("reboot", self.reboot))
application.add_handler(CommandHandler("browse", self.browse)) application.add_handler(CommandHandler("browse", self.browse))
application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, self.handle_message)) application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, self.handle_message))
application.add_handler(CallbackQueryHandler(self.abort_processing, pattern='^abort$')) application.add_handler(CallbackQueryHandler(self.abort_processing, pattern='^abort$'))
View File
View File
@@ -1,33 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
from anthropic_telegram_inference_bot import AnthropicTelegramInferenceBot
class TestAnthropicTelegramInferenceBot(unittest.TestCase):
def setUp(self):
self.bot = AnthropicTelegramInferenceBot()
@patch('anthropic_telegram_inference_bot.Anthropic')
def test_get_chat_response(self, MockAnthropic):
mock_anthropic = MockAnthropic.return_value
mock_anthropic.messages.create.return_value = MagicMock()
messages = [{"role": "user", "content": "Hello"}]
response = self.bot.get_chat_response(messages)
self.assertIsNotNone(response)
@patch('anthropic_telegram_inference_bot.Anthropic')
def test_handle_message(self, MockAnthropic):
mock_anthropic = MockAnthropic.return_value
mock_anthropic.messages.create.return_value = MagicMock(content=[MagicMock(type="message", text="response content")])
user_id = "user123"
user_message = "Hello"
response = self.bot.handle_message(user_id, user_message)
self.assertIsNotNone(response)
# Additional testing for error cases and edge cases
if __name__ == '__main__':
unittest.main()
@@ -1,33 +0,0 @@
import unittest
from base_telegram_inference_bot import BaseTelegramInferenceBot
class TestBaseTelegramInferenceBot(unittest.TestCase):
def setUp(self):
# Initialize the bot or mock any dependencies here
self.bot = BaseTelegramInferenceBot()
def test_load_system_prompt(self):
# Example test case for load_system_prompt method
result = self.bot.load_system_prompt()
self.assertIsNotNone(result) # Replace with actual expected result
def test_load_functions(self):
# Test the load_functions method
functions = self.bot.load_functions()
self.assertIsInstance(functions, list) # Replace with actual expected result
self.assertTrue(len(functions) > 0) # Assuming it should load some functions
def test_clear_conversation(self):
# Test the clear_conversation method
self.bot.clear_conversation()
self.assertEqual(self.bot.conversations, {}) # Assuming conversations is a dictionary
def test_call_tool(self):
# Test the call_tool method
tool_name = "some_tool"
params = {"param1": "value1"}
result = self.bot.call_tool(tool_name, params)
self.assertIsNotNone(result) # Replace with actual expected result
if __name__ == '__main__':
unittest.main()
@@ -1,38 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
from chatgpt_telegram_inference_bot import ChatGPTTelegramInferenceBot
class TestChatGPTTelegramInferenceBot(unittest.TestCase):
def setUp(self):
self.bot = ChatGPTTelegramInferenceBot()
@patch('chatgpt_telegram_inference_bot.OpenAI')
def test_get_chat_response(self, MockOpenAI):
mock_ai = MockOpenAI.return_value
mock_ai.chat.completions.create.return_value = MagicMock()
messages = [{"role": "user", "content": "Hello"}]
response = self.bot.get_chat_response(messages)
self.assertIsNotNone(response)
@patch('chatgpt_telegram_inference_bot.OpenAI')
def test_handle_message(self, MockOpenAI):
mock_ai = MockOpenAI.return_value
mock_ai.chat.completions.create.return_value = MagicMock(choices=[MagicMock(message={"content": "response content"}, finish_reason='stop')])
user_id = "user123"
user_message = "Hello"
response = self.bot.handle_message(user_id, user_message)
self.assertIsNotNone(response)
def test_switch_model(self):
initial_model = self.bot.model
self.bot.switch_model()
self.assertNotEqual(initial_model, self.bot.model)
# Additional testing for error cases and edge cases
if __name__ == '__main__':
unittest.main()
View File
@@ -1,280 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch, AsyncMock, ANY
import os
# Assuming anthropic_telegram_inference_bot.py is in the parent directory or PYTHONPATH is set
from anthropic_telegram_inference_bot import AnthropicTelegramInferenceBot
# Mock response from Anthropic client's messages.create
def create_mock_anthropic_response(content_text=None, stop_reason="end_turn", tool_use_parts=None):
mock_response = MagicMock()
mock_response.stop_reason = stop_reason
content_blocks = []
if content_text:
text_block = MagicMock()
text_block.type = "text"
text_block.text = content_text
content_blocks.append(text_block)
if tool_use_parts:
for tu_part in tool_use_parts: # tu_part = {"id": "toolu_123", "name": "get_weather", "input": {}}
tool_block = MagicMock()
tool_block.type = "tool_use"
tool_block.id = tu_part["id"]
tool_block.name = tu_part["name"]
tool_block.input = tu_part["input"]
content_blocks.append(tool_block)
mock_response.content = content_blocks
return mock_response
class TestAnthropicTelegramInferenceBot(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.original_anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY")
self.original_small_model = os.environ.get("ANTHROPIC_SMALL_MODEL")
self.original_large_model = os.environ.get("ANTHROPIC_LARGE_MODEL")
self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH")
for key in ["ANTHROPIC_API_KEY", "ANTHROPIC_SMALL_MODEL", "ANTHROPIC_LARGE_MODEL", "SYSTEM_PROMPT_PATH"]:
if os.environ.get(key):
del os.environ[key]
self.mock_anthropic_client_instance = MagicMock()
self.mock_anthropic_client_instance.messages.create = MagicMock()
def tearDown(self):
if self.original_anthropic_api_key: os.environ["ANTHROPIC_API_KEY"] = self.original_anthropic_api_key
if self.original_small_model: os.environ["ANTHROPIC_SMALL_MODEL"] = self.original_small_model
if self.original_large_model: os.environ["ANTHROPIC_LARGE_MODEL"] = self.original_large_model
if self.original_system_prompt_path: os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path
@patch('anthropic.Anthropic')
def test_init_with_anthropic_defaults_env_key(self, MockAnthropicConstructor):
MockAnthropicConstructor.return_value = self.mock_anthropic_client_instance
os.environ["ANTHROPIC_API_KEY"] = "test_anthropic_key"
bot = AnthropicTelegramInferenceBot()
MockAnthropicConstructor.assert_called_once_with(api_key="test_anthropic_key")
self.assertEqual(bot.anthropic_client, self.mock_anthropic_client_instance)
self.assertEqual(bot.model, os.environ.get("ANTHROPIC_SMALL_MODEL", "claude-3-haiku-20240307"))
self.assertEqual(bot.max_tokens, int(os.environ.get("ANTHROPIC_SMALL_MODEL_MAX_TOKENS", 2000)))
@patch('anthropic.Anthropic')
def test_init_with_provided_client_and_models(self, MockAnthropicConstructor):
preconfigured_client = MagicMock()
bot = AnthropicTelegramInferenceBot(
anthropic_client=preconfigured_client,
small_model_name="custom-small",
small_model_max_tokens=100,
large_model_name="custom-large",
large_model_max_tokens=200
)
MockAnthropicConstructor.assert_not_called()
self.assertEqual(bot.anthropic_client, preconfigured_client)
self.assertEqual(bot.model, "custom-small")
self.assertEqual(bot.max_tokens, 100)
self.assertEqual(bot.small_model_name, "custom-small")
self.assertEqual(bot.large_model_name, "custom-large")
def test_get_llm_description(self):
bot = AnthropicTelegramInferenceBot(small_model_name="claude-test", small_model_max_tokens=500)
self.assertEqual(bot.get_llm_description(), "LLM: claude-test, Max Tokens: 500")
async def test_switch_model(self):
bot = AnthropicTelegramInferenceBot(
small_model_name="claude-small", small_model_max_tokens=10,
large_model_name="claude-large", large_model_max_tokens=20
)
self.assertEqual(bot.model, "claude-small")
self.assertEqual(bot.max_tokens, 10)
status = await bot.switch_model()
self.assertEqual(bot.model, "claude-large")
self.assertEqual(bot.max_tokens, 20)
self.assertEqual(status, "Switched to model: claude-large")
status = await bot.switch_model()
self.assertEqual(bot.model, "claude-small")
self.assertEqual(bot.max_tokens, 10)
self.assertEqual(status, "Switched to model: claude-small")
def test_get_chat_response_success_text_only(self):
bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
bot.model = "test-claude"
bot.max_tokens = 150
mock_api_response = create_mock_anthropic_response(content_text="Hello from Anthropic API")
self.mock_anthropic_client_instance.messages.create.return_value = mock_api_response
messages = [{"role": "user", "content": "Hi"}] # Anthropic format
response = bot.get_chat_response(messages, []) # tools = empty list
self.mock_anthropic_client_instance.messages.create.assert_called_once_with(
model="test-claude",
max_tokens=150,
messages=messages,
system=bot.system_prompt, # Ensure system prompt is passed
tools=None, # No tools passed to API if empty list or None
tool_choice=None
)
self.assertEqual(response, mock_api_response)
def test_get_chat_response_with_tools(self):
bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
bot.model = "claude-toolmaster"
bot.max_tokens = 300
mock_tools_spec = [{"name": "get_weather", "description": "Gets weather", "input_schema": {"type": "object", "properties": {}}}]
mock_api_response = create_mock_anthropic_response(content_text="Thinking...", tool_use_parts=[
{"id": "tool1", "name": "get_weather", "input": {"location": "here"}}
])
self.mock_anthropic_client_instance.messages.create.return_value = mock_api_response
messages = [{"role": "user", "content": "Weather?"}]
response = bot.get_chat_response(messages, mock_tools_spec)
self.mock_anthropic_client_instance.messages.create.assert_called_once_with(
model="claude-toolmaster",
max_tokens=300,
messages=messages,
system=bot.system_prompt,
tools=mock_tools_spec,
tool_choice={"type": "auto"}
)
self.assertEqual(response.content[0].type, "text") # First part can be text
self.assertEqual(response.content[1].type, "tool_use")
def test_get_chat_response_api_error(self):
bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
self.mock_anthropic_client_instance.messages.create.side_effect = Exception("Anthropic API Down")
with self.assertRaisesRegex(Exception, "Anthropic API Down"):
bot.get_chat_response([{"role": "user", "content": "trigger"}], [])
async def test_handle_message_simple_response_no_tools(self):
# This test is more involved as it touches BaseTelegramInferenceBot's handle_message structure
# which then calls the overridden get_chat_response.
bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
bot.system_prompt = "System prompt for Anthropic"
# Mock get_chat_response directly to isolate its behavior from full handle_message logic of base
# However, the point of this bot is its get_chat_response and subsequent processing.
# So, let's mock the API call within get_chat_response.
api_response = create_mock_anthropic_response(content_text="Anthropic says hello.")
self.mock_anthropic_client_instance.messages.create.return_value = api_response
# Ensure functions are empty for this test, so no tool logic is triggered
bot.functions = []
bot.tools = []
response_content = await bot.handle_message(user_id=101, user_message="Hello Anthropic")
self.assertEqual(response_content, "Anthropic says hello.")
self.assertIn(101, bot.conversation_history)
# Anthropic's handle_message structure:
# 1. User message added to history.
# 2. get_chat_response is called.
# 3. Response content (text) is extracted.
# 4. Assistant text response is added to history.
# Expected history: [User, Assistant_Text_Response] (system prompt handled by get_chat_response)
# The base class handle_message adds system prompt if not present.
# Anthropic handle_message modifies history format before calling get_chat_response.
# Let's trace Base.handle_message -> Anthropic.handle_message -> Anthropic.get_chat_response
# Base.handle_message:
# - Adds system prompt to history if first turn: `self.conversation_history[user_id] = [{"role": "system", "content": self.system_prompt}]` (OpenAI style)
# - Appends user message: `{"role": "user", "content": user_message}`
# - Calls self.get_chat_response(messages, self.functions) -> This is Anthropic's get_chat_response
# Anthropic.get_chat_response:
# - Takes OpenAI style `messages` and `self.functions` (tool specs).
# - Calls `anthropic_client.messages.create` with Anthropic style messages and system prompt.
# Anthropic.handle_message (overridden):
# - Prepares Anthropic-style messages from conversation_history (which is OpenAI style from Base)
# - Calls get_chat_response with these Anthropic messages and self.functions (tool_specs)
# - Processes response, extracts text, handles tool calls.
# - Appends *user* message (original) and *assistant* text response to self.conversation_history (OpenAI style).
# For this test, we are calling AnthropicBot.handle_message directly.
# 1. `user_id` not in `self.conversation_history`: `system_prompt` not added yet by Base logic.
# Anthropic's `handle_message` will create `anthropic_messages` from this.
# If `conversation_history` is empty, `anthropic_messages` = `[{"role": "user", "content": user_message}]`
# 2. `get_chat_response` called with `anthropic_messages` and `bot.system_prompt` passed to API.
# 3. Response "Anthropic says hello."
# 4. Original `user_message` and "Anthropic says hello." (as assistant) added to `self.conversation_history`.
history = bot.conversation_history[101]
self.assertEqual(len(history), 2) # User, Assistant
self.assertEqual(history[0]["role"], "user")
self.assertEqual(history[0]["content"], "Hello Anthropic")
self.assertEqual(history[1]["role"], "assistant")
self.assertEqual(history[1]["content"], "Anthropic says hello.")
# Check API call (made by the mocked get_chat_response indirectly)
self.mock_anthropic_client_instance.messages.create.assert_called_once()
call_args = self.mock_anthropic_client_instance.messages.create.call_args
self.assertEqual(call_args.kwargs["system"], "System prompt for Anthropic")
# Initial messages for API should just be the user message for first turn
self.assertEqual(call_args.kwargs["messages"], [{"role": "user", "content": "Hello Anthropic"}])
async def test_handle_message_with_tool_calls(self):
bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
bot.system_prompt = "You are a helpful, tool-using assistant."
# Define a tool for the bot (OpenAI format, will be converted by Anthropic bot for API)
mock_tool_oai_format = {"type": "function", "function": {"name": "get_weather", "description": "Get weather", "parameters": {}}}
bot.functions = [mock_tool_oai_format] # This is used to generate anthropic_tools for API
# API Response 1: Request for tool call
tool_use_part = {"id": "toolu_xyz", "name": "get_weather", "input": {"location": "paris"}}
api_response_1 = create_mock_anthropic_response(tool_use_parts=[tool_use_part])
# API Response 2: Final text response after tool execution
api_response_2 = create_mock_anthropic_response(content_text="The weather in Paris is nice.")
self.mock_anthropic_client_instance.messages.create.side_effect = [api_response_1, api_response_2]
# Mock the bot's call_tool method (from BaseTelegramInferenceBot)
bot.call_tool = MagicMock(return_value='''{"weather": "sunny"}''') # Tool execution result
user_id = 102
user_message = "What's the weather in Paris?"
final_text_response = await bot.handle_message(user_id, user_message)
self.assertEqual(final_text_response, "The weather in Paris is nice.")
self.assertEqual(self.mock_anthropic_client_instance.messages.create.call_count, 2)
bot.call_tool.assert_called_once_with("get_weather", {"location": "paris"}) # Anthropic passes input as dict
# Check conversation history (OpenAI style)
history = bot.conversation_history[user_id]
self.assertEqual(history[0]["role"], "user")
self.assertEqual(history[0]["content"], user_message)
# Assistant message that requested tool call (Anthropic-specific format stored by its handle_message)
# Anthropic's handle_message appends the raw tool_use block and then the tool_result
self.assertEqual(history[1]["role"], "assistant")
self.assertTrue(isinstance(history[1]["content"], list)) # Anthropic content is a list
self.assertEqual(history[1]["content"][0]["type"], "tool_use")
self.assertEqual(history[1]["content"][0]["id"], "toolu_xyz")
self.assertEqual(history[2]["role"], "tool")
self.assertEqual(history[2]["tool_call_id"], "toolu_xyz")
self.assertEqual(history[2]["name"], "get_weather")
self.assertEqual(history[2]["content"], '''{"weather": "sunny"}''') # call_tool result
self.assertEqual(history[3]["role"], "assistant") # Final text response
self.assertTrue(isinstance(history[3]["content"], str)) # simple text
self.assertEqual(history[3]["content"], "The weather in Paris is nice.")
if __name__ == '__main__':
unittest.main()
-310
View File
@@ -1,310 +0,0 @@
import unittest
from unittest.mock import patch, mock_open, MagicMock
import os
import json
# Ensure the path includes the directory where base_telegram_inference_bot is located
# This might require adjustment based on actual project structure if tests are run from root
# For now, assuming it can be imported directly or via PYTHONPATH
from base_telegram_inference_bot import BaseTelegramInferenceBot
from tools.base_tool import BaseTool # For mocking tool structure
# Create a concrete subclass for testing, as BaseTelegramInferenceBot is abstract
class ConcreteTestBot(BaseTelegramInferenceBot):
def __init__(self, system_prompt_content=None, system_prompt_path=None, mock_tools=None, mock_functions=None):
# Mock load_functions during super().__init__ if needed, or control tools/functions directly
self._mock_tools = mock_tools if mock_tools is not None else []
self._mock_functions = mock_functions if mock_functions is not None else []
super().__init__(system_prompt_content=system_prompt_content, system_prompt_path=system_prompt_path)
# Override load_functions to use mocks
def load_functions(self):
return self._mock_tools, self._mock_functions
def get_chat_response(self, messages):
pass # Abstract method, not tested here directly
async def handle_message(self, user_id, user_message):
pass # Abstract method
def get_llm_description(self) -> str:
return "Mock LLM Description" # Concrete implementation for testing get_bot_status
async def start(self):
pass # Abstract method
async def abort_processing(self, user_id):
pass # Abstract method
async def switch_model(self):
pass # Abstract method
class TestBaseTelegramInferenceBot(unittest.TestCase):
def setUp(self):
# Reset relevant environment variables before each test
self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH")
if "SYSTEM_PROMPT_PATH" in os.environ:
del os.environ["SYSTEM_PROMPT_PATH"]
def tearDown(self):
# Restore environment variables
if self.original_system_prompt_path:
os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path
elif "SYSTEM_PROMPT_PATH" in os.environ: # Ensure it's removed if test set it and it wasn't there before
del os.environ["SYSTEM_PROMPT_PATH"]
def test_init_with_direct_system_prompt(self):
bot = ConcreteTestBot(system_prompt_content="Direct prompt content")
self.assertEqual(bot.system_prompt, "Direct prompt content")
@patch("os.path.isfile")
@patch("builtins.open", new_callable=mock_open, read_data="File prompt content")
def test_init_with_system_prompt_path_argument(self, mock_file_open, mock_isfile):
mock_isfile.return_value = True
bot = ConcreteTestBot(system_prompt_path="dummy/path.txt")
self.assertEqual(bot.system_prompt, "File prompt content")
mock_isfile.assert_called_once_with("dummy/path.txt")
mock_file_open.assert_called_once_with("dummy/path.txt", "r", encoding="utf-8")
@patch("os.path.isfile")
@patch("builtins.open", new_callable=mock_open, read_data="Env prompt content")
def test_init_with_env_system_prompt_path(self, mock_file_open, mock_isfile):
mock_isfile.return_value = True
os.environ["SYSTEM_PROMPT_PATH"] = "env/path.txt"
bot = ConcreteTestBot()
self.assertEqual(bot.system_prompt, "Env prompt content")
mock_isfile.assert_called_once_with("env/path.txt")
mock_file_open.assert_called_once_with("env/path.txt", "r", encoding="utf-8")
def test_init_with_default_system_prompt(self):
# Ensure ENV var is not set for this test
if "SYSTEM_PROMPT_PATH" in os.environ:
del os.environ["SYSTEM_PROMPT_PATH"]
bot = ConcreteTestBot()
self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.")
@patch("os.path.isfile", return_value=False)
def test_init_with_invalid_system_prompt_path(self, mock_isfile):
bot = ConcreteTestBot(system_prompt_path="invalid/path.txt")
self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.")
mock_isfile.assert_called_once_with("invalid/path.txt")
@patch("os.path.isfile")
@patch("builtins.open", side_effect=IOError("File read error"))
def test_init_with_system_prompt_file_read_error(self, mock_file_open, mock_isfile):
mock_isfile.return_value = True
bot = ConcreteTestBot(system_prompt_path="dummy/path.txt")
self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.")
def test_clear_conversation_history(self):
mock_tool_instance = MagicMock(spec=BaseTool)
bot = ConcreteTestBot(mock_tools=[mock_tool_instance])
bot.conversation_history[123] = [{"role": "user", "content": "Hello"}]
bot.clear_conversation_history(123)
self.assertNotIn(123, bot.conversation_history)
mock_tool_instance.clear.assert_called_once()
def test_clear_conversation_history_user_not_found(self):
mock_tool_instance = MagicMock(spec=BaseTool)
bot = ConcreteTestBot(mock_tools=[mock_tool_instance])
bot.clear_conversation_history(404)
self.assertNotIn(404, bot.conversation_history)
mock_tool_instance.clear.assert_called_once()
def test_processing_status(self):
bot = ConcreteTestBot()
self.assertEqual(bot.processing_status, {})
bot.set_processing_status(123, 789)
self.assertEqual(bot.processing_status[123], {"processing": True, "message_id": 789})
bot.clear_processing_status(123)
self.assertNotIn(123, bot.processing_status)
def test_clear_processing_status_user_not_found(self):
bot = ConcreteTestBot()
bot.clear_processing_status(404)
self.assertNotIn(404, bot.processing_status)
def test_call_tool_success_dict_args(self):
mock_tool = MagicMock(spec=BaseTool)
mock_tool.get_functions.return_value = [
{"function": {"name": "test_tool", "description": "A test tool", "parameters": {}}}
]
mock_tool.execute.return_value = "Tool executed successfully"
bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions())
result = bot.call_tool("test_tool", {"arg1": "value1"})
self.assertEqual(result, "Tool executed successfully")
mock_tool.execute.assert_called_once_with("test_tool", arg1="value1")
def test_call_tool_success_json_string_args(self):
mock_tool = MagicMock(spec=BaseTool)
mock_tool.get_functions.return_value = [
{"function": {"name": "test_tool_json", "parameters": {}}}
]
mock_tool.execute.return_value = "Tool JSON OK"
bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions())
args_json_str = '''{"param": "value"}'''
result = bot.call_tool("test_tool_json", args_json_str)
self.assertEqual(result, "Tool JSON OK")
mock_tool.execute.assert_called_once_with("test_tool_json", param="value")
def test_call_tool_malformed_json_string_args(self):
bot = ConcreteTestBot(mock_tools=[])
args_malformed_json_str = '''{"param": "value"'''
result = bot.call_tool("some_tool", args_malformed_json_str)
self.assertTrue("Error: Malformed arguments for tool call" in result)
def test_call_tool_unexpected_arg_type(self):
bot = ConcreteTestBot(mock_tools=[])
result = bot.call_tool("some_tool", 12345) # Integer instead of dict/str
self.assertTrue("Error: Invalid argument type for tool call" in result)
def test_call_tool_none_args(self):
mock_tool = MagicMock(spec=BaseTool)
mock_tool.get_functions.return_value = [
{"function": {"name": "test_tool_none", "parameters": {}}}
]
mock_tool.execute.return_value = "Tool None OK"
bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions())
result = bot.call_tool("test_tool_none", None)
self.assertEqual(result, "Tool None OK")
mock_tool.execute.assert_called_once_with("test_tool_none") # No kwargs if None
def test_call_tool_not_found(self):
bot = ConcreteTestBot(mock_tools=[])
result = bot.call_tool("non_existent_tool", {})
self.assertEqual(result, "Error: Tool function non_existent_tool not found.")
def test_call_tool_execute_exception(self):
mock_tool = MagicMock(spec=BaseTool)
mock_tool.get_functions.return_value = [{"function": {"name": "error_tool", "parameters": {}}}]
mock_tool.execute.side_effect = Exception("Execution failed")
bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions())
result = bot.call_tool("error_tool", {})
self.assertEqual(result, "Error executing tool error_tool: Execution failed")
def test_get_system_prompt_description(self):
if "SYSTEM_PROMPT_PATH" in os.environ: # Ensure clean state
del os.environ["SYSTEM_PROMPT_PATH"]
bot_default = ConcreteTestBot()
self.assertEqual(bot_default.get_system_prompt_description(), "System Prompt: Default")
bot_custom_content = ConcreteTestBot(system_prompt_content="Custom content here")
self.assertEqual(bot_custom_content.get_system_prompt_description(), "System Prompt: Custom")
os.environ["SYSTEM_PROMPT_PATH"] = "some/path.txt"
bot_env_default_prompt = ConcreteTestBot() # system_prompt itself is default
self.assertEqual(bot_env_default_prompt.get_system_prompt_description(), "System Prompt: Custom (via ENV)")
with patch("os.path.isfile", return_value=True), \
patch("builtins.open", mock_open(read_data="File prompt from ENV")):
bot_env_file_prompt = ConcreteTestBot() # system_prompt gets loaded from ENV path
self.assertEqual(bot_env_file_prompt.get_system_prompt_description(), "System Prompt: Custom")
del os.environ["SYSTEM_PROMPT_PATH"]
with patch("os.path.isfile", return_value=True), \
patch("builtins.open", mock_open(read_data="File prompt from arg")):
bot_custom_file_arg = ConcreteTestBot(system_prompt_path="custom/file.txt")
self.assertEqual(bot_custom_file_arg.get_system_prompt_description(), "System Prompt: Custom")
@patch.object(ConcreteTestBot, 'get_llm_description', return_value="Test LLM Description")
@patch.object(ConcreteTestBot, 'get_system_prompt_description', return_value="Test Prompt Description")
async def test_get_bot_status(self, mock_prompt_desc, mock_llm_desc):
bot = ConcreteTestBot()
status = await bot.get_bot_status()
self.assertEqual(status, "Test Prompt Description\nTest LLM Description")
mock_prompt_desc.assert_called_once()
mock_llm_desc.assert_called_once()
@patch('os.path.dirname', return_value='/mock/path')
@patch('os.path.join')
@patch('os.path.exists')
@patch('os.listdir')
@patch('importlib.import_module')
def test_load_functions_no_tools_dir(self, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname):
mock_join.return_value = '/mock/path/tools'
mock_exists.return_value = False
class BotForLoadTest(BaseTelegramInferenceBot):
load_system_prompt = MagicMock(return_value="Default")
get_chat_response = MagicMock(); handle_message = MagicMock(); get_llm_description = MagicMock(return_value="mock")
start = MagicMock(); abort_processing = MagicMock(); switch_model = MagicMock()
bot = BotForLoadTest()
self.assertEqual(bot.tools, [])
self.assertEqual(bot.functions, [])
mock_listdir.assert_not_called()
@patch('os.path.dirname', return_value='/mock/base_bot_dir')
@patch('os.path.join', side_effect=lambda *args: os.path.normpath(os.path.join(*args)))
@patch('os.path.exists', return_value=True)
@patch('os.listdir', return_value=['my_tool.py', '__init__.py', 'base_tool.py'])
@patch('importlib.import_module')
def test_load_functions_with_one_tool(self, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname):
mock_tool_class = MagicMock(spec=BaseTool) # This is the class itself
mock_tool_instance = MagicMock(spec=BaseTool) # This is the instance
mock_tool_class.return_value = mock_tool_instance # mock_tool_class() creates mock_tool_instance
mock_tool_instance.get_functions.return_value = [{"function": {"name": "sample_function"}}]
mock_my_tool_module = MagicMock()
# Simulate inspect.getmembers behavior: returns list of (name, member) tuples
# Only include members that are classes, derive from BaseTool, and are not BaseTool itself.
mock_my_tool_module.ValidToolClass = mock_tool_class
mock_my_tool_module.NotATool = object()
mock_my_tool_module.BaseTool = BaseTool # This should be skipped by the loader
def import_side_effect(module_name):
if module_name == 'tools.my_tool':
return mock_my_tool_module
raise ImportError(f"Unexpected import: {module_name}")
mock_import_module.side_effect = import_side_effect
class BotForLoadTest(BaseTelegramInferenceBot):
load_system_prompt = MagicMock(return_value="Default")
get_chat_response = MagicMock(); handle_message = MagicMock(); get_llm_description = MagicMock(return_value="mock")
start = MagicMock(); abort_processing = MagicMock(); switch_model = MagicMock()
bot = BotForLoadTest()
self.assertEqual(len(bot.tools), 1)
self.assertIs(bot.tools[0], mock_tool_instance)
self.assertEqual(len(bot.functions), 1)
self.assertEqual(bot.functions[0]['function']['name'], "sample_function")
mock_import_module.assert_called_once_with('tools.my_tool')
mock_tool_class.assert_called_once_with() # Tool class was instantiated
mock_tool_instance.get_functions.assert_called_once_with()
@patch('os.path.dirname', return_value='/mock/base_bot_dir')
@patch('os.path.join', side_effect=lambda *args: os.path.normpath(os.path.join(*args)))
@patch('os.path.exists', return_value=True)
@patch('os.listdir', return_value=['tool_with_init_error.py'])
@patch('importlib.import_module')
@patch('logging.error') # Mock logging to check for error messages
def test_load_functions_tool_instantiation_error(self, mock_logging_error, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname):
mock_tool_class_init_error = MagicMock(spec=BaseTool)
mock_tool_class_init_error.side_effect = Exception("Failed to init tool") # Error on instantiation
mock_error_tool_module = MagicMock()
mock_error_tool_module.ToolWithInitError = mock_tool_class_init_error
mock_import_module.return_value = mock_error_tool_module
class BotForLoadTest(BaseTelegramInferenceBot):
load_system_prompt = MagicMock(return_value="Default")
get_chat_response = MagicMock(); handle_message = MagicMock(); get_llm_description = MagicMock(return_value="mock")
start = MagicMock(); abort_processing = MagicMock(); switch_model = MagicMock()
bot = BotForLoadTest()
self.assertEqual(len(bot.tools), 0)
self.assertEqual(len(bot.functions), 0)
mock_logging_error.assert_any_call("Error instantiating tool ToolWithInitError from tool_with_init_error.py: Failed to init tool")
if __name__ == '__main__':
unittest.main(闂傚лен䦗婢у埊鍓解劓姣)
@@ -1,158 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch, ANY
import os
# Assuming chatgpt_telegram_inference_bot.py and its parent are accessible
from chatgpt_telegram_inference_bot import ChatGPTTelegramInferenceBot
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot # For patching super
class TestChatGPTTelegramInferenceBot(unittest.IsolatedAsyncioTestCase):
def setUp(self):
# Store and clear relevant environment variables
self.original_openai_key = os.environ.get("OPENAI_API_KEY")
self.original_small_model = os.environ.get("OPENAI_SMALL_MODEL")
self.original_large_model = os.environ.get("OPENAI_LARGE_MODEL")
self.original_small_tokens = os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS")
self.original_large_tokens = os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS")
self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH")
for key in ["OPENAI_API_KEY", "OPENAI_SMALL_MODEL", "OPENAI_LARGE_MODEL",
"OPENAI_SMALL_MODEL_MAX_TOKENS", "OPENAI_LARGE_MODEL_MAX_TOKENS", "SYSTEM_PROMPT_PATH"]:
if os.environ.get(key):
del os.environ[key]
# Mock the OpenAI client that OpenAICompatibleInferenceBot's __init__ might create
self.mock_openai_client = MagicMock()
def tearDown(self):
# Restore environment variables
if self.original_openai_key: os.environ["OPENAI_API_KEY"] = self.original_openai_key
if self.original_small_model: os.environ["OPENAI_SMALL_MODEL"] = self.original_small_model
if self.original_large_model: os.environ["OPENAI_LARGE_MODEL"] = self.original_large_model
if self.original_small_tokens: os.environ["OPENAI_SMALL_MODEL_MAX_TOKENS"] = self.original_small_tokens
if self.original_large_tokens: os.environ["OPENAI_LARGE_MODEL_MAX_TOKENS"] = self.original_large_tokens
if self.original_system_prompt_path: os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path
@patch.object(OpenAICompatibleInferenceBot, '__init__') # Mock the superclass's __init__
def test_init_defaults_and_super_call(self, mock_super_init):
os.environ["OPENAI_API_KEY"] = "test_key_chatgpt"
os.environ["OPENAI_SMALL_MODEL"] = "gpt-3.5-turbo-env"
os.environ["OPENAI_SMALL_MODEL_MAX_TOKENS"] = "350"
bot = ChatGPTTelegramInferenceBot()
mock_super_init.assert_called_once_with(
client=None, # ChatGPT bot will let superclass create it
api_key="test_key_chatgpt", # Passed to super
base_url=None,
api_version=None,
azure_deployment=None,
model_name="gpt-3.5-turbo-env", # Default small model from env
max_tokens_str="350", # Default small model tokens from env
small_model_name="gpt-3.5-turbo-env",
small_model_max_tokens_str="350",
large_model_name=os.environ.get("OPENAI_LARGE_MODEL", "gpt-4-turbo-preview"), # Default large
large_model_max_tokens_str=os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS"),
system_prompt_content=None,
system_prompt_path=None,
is_gemini=False,
max_history_length=20 # Default from OpenAICompatibleInferenceBot
)
@patch.object(OpenAICompatibleInferenceBot, '__init__')
def test_init_with_arguments(self, mock_super_init):
mock_client_arg = MagicMock()
bot = ChatGPTTelegramInferenceBot(
openai_client=mock_client_arg,
api_key="arg_key",
small_model_name="arg_small_model",
small_model_max_tokens="123",
large_model_name="arg_large_model",
large_model_max_tokens="456",
system_prompt_content="Arg prompt"
)
mock_super_init.assert_called_once_with(
client=mock_client_arg,
api_key="arg_key",
base_url=None,
api_version=None,
azure_deployment=None,
model_name="arg_small_model", # Initially configured with small model
max_tokens_str="123",
small_model_name="arg_small_model",
small_model_max_tokens_str="123",
large_model_name="arg_large_model",
large_model_max_tokens_str="456",
system_prompt_content="Arg prompt",
system_prompt_path=None,
is_gemini=False,
max_history_length=20
)
# Test switch_model - this method is part of ChatGPTTelegramInferenceBot
# It calls _configure_model_and_tokens which is in the superclass.
# We need a bot instance where _configure_model_and_tokens can be called.
@patch('openai.OpenAI') # To allow instantiation of the bot by mocking client creation
async def test_switch_model_logic(self, mock_openai_constructor):
mock_openai_constructor.return_value = self.mock_openai_client # Mock client creation in super
# Set env vars for model names that switch_model will use as fallback
os.environ["OPENAI_SMALL_MODEL"] = "env-small-gpt"
os.environ["OPENAI_SMALL_MODEL_MAX_TOKENS"] = "100"
os.environ["OPENAI_LARGE_MODEL"] = "env-large-gpt"
os.environ["OPENAI_LARGE_MODEL_MAX_TOKENS"] = "200"
# Instantiate with initial model (small)
bot = ChatGPTTelegramInferenceBot()
self.assertEqual(bot.model, "env-small-gpt")
self.assertEqual(bot.max_tokens, 100)
# Switch to large
status = await bot.switch_model()
self.assertEqual(bot.model, "env-large-gpt")
self.assertEqual(bot.max_tokens, 200)
self.assertEqual(status, "Switched to model: env-large-gpt")
# Switch back to small
status = await bot.switch_model()
self.assertEqual(bot.model, "env-small-gpt")
self.assertEqual(bot.max_tokens, 100)
self.assertEqual(status, "Switched to model: env-small-gpt")
@patch('openai.OpenAI')
async def test_switch_model_uses_instance_configs_if_provided(self, mock_openai_constructor):
mock_openai_constructor.return_value = self.mock_openai_client
# Instantiate with specific model names, overriding potential env vars
bot = ChatGPTTelegramInferenceBot(
small_model_name="init-small", small_model_max_tokens="50",
large_model_name="init-large", large_model_max_tokens="150"
)
self.assertEqual(bot.model, "init-small") # Starts with small
self.assertEqual(bot.max_tokens, 50)
# Switch to large
status = await bot.switch_model()
self.assertEqual(bot.model, "init-large")
self.assertEqual(bot.max_tokens, 150)
self.assertEqual(status, "Switched to model: init-large")
# Switch back to small
status = await bot.switch_model()
self.assertEqual(bot.model, "init-small")
self.assertEqual(bot.max_tokens, 50)
self.assertEqual(status, "Switched to model: init-small")
# get_llm_description is inherited from OpenAICompatibleInferenceBot.
# Test just to ensure it works in the context of a ChatGPTBot instance
@patch('openai.OpenAI')
def test_get_llm_description_for_chatgpt_bot(self, mock_openai_constructor):
mock_openai_constructor.return_value = self.mock_openai_client
bot = ChatGPTTelegramInferenceBot(small_model_name="gpt-3.5-desc", small_model_max_tokens="777")
# Initially configured with small model
self.assertEqual(bot.get_llm_description(), "LLM: gpt-3.5-desc, Max Tokens: 777, Azure: False")
if __name__ == '__main__':
unittest.main()
-154
View File
@@ -1,154 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch, ANY
import os
# Assuming gemini_telegram_inference_bot.py and its parent are accessible
from gemini_telegram_inference_bot import GeminiTelegramInferenceBot
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot # For patching super
class TestGeminiTelegramInferenceBot(unittest.IsolatedAsyncioTestCase):
def setUp(self):
# Store and clear relevant environment variables
self.original_gemini_key = os.environ.get("GEMINI_API_KEY")
self.original_gemini_base_url = os.environ.get("GEMINI_API_BASE_URL")
self.original_small_model = os.environ.get("GEMINI_SMALL_MODEL")
self.original_large_model = os.environ.get("GEMINI_LARGE_MODEL")
self.original_small_tokens = os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS")
self.original_large_tokens = os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS")
self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH")
for key in ["GEMINI_API_KEY", "GEMINI_API_BASE_URL", "GEMINI_SMALL_MODEL", "GEMINI_LARGE_MODEL",
"GEMINI_SMALL_MODEL_MAX_TOKENS", "GEMINI_LARGE_MODEL_MAX_TOKENS", "SYSTEM_PROMPT_PATH"]:
if os.environ.get(key):
del os.environ[key]
self.mock_openai_client = MagicMock() # Used if superclass creates an OpenAI client
def tearDown(self):
# Restore environment variables
if self.original_gemini_key: os.environ["GEMINI_API_KEY"] = self.original_gemini_key
if self.original_gemini_base_url: os.environ["GEMINI_API_BASE_URL"] = self.original_gemini_base_url
if self.original_small_model: os.environ["GEMINI_SMALL_MODEL"] = self.original_small_model
if self.original_large_model: os.environ["GEMINI_LARGE_MODEL"] = self.original_large_model
if self.original_small_tokens: os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = self.original_small_tokens
if self.original_large_tokens: os.environ["GEMINI_LARGE_MODEL_MAX_TOKENS"] = self.original_large_tokens
if self.original_system_prompt_path: os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path
@patch.object(OpenAICompatibleInferenceBot, '__init__') # Mock the superclass's __init__
def test_init_defaults_and_super_call(self, mock_super_init):
os.environ["GEMINI_API_KEY"] = "test_key_gemini"
os.environ["GEMINI_API_BASE_URL"] = "https://gemini.env.com"
os.environ["GEMINI_SMALL_MODEL"] = "gemini-pro-env"
os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = "360"
bot = GeminiTelegramInferenceBot()
mock_super_init.assert_called_once_with(
client=None,
api_key="test_key_gemini",
base_url="https://gemini.env.com", # Passed to super
api_version=None,
azure_deployment=None,
model_name="gemini-pro-env",
max_tokens_str="360",
small_model_name="gemini-pro-env",
small_model_max_tokens_str="360",
large_model_name=os.environ.get("GEMINI_LARGE_MODEL", "gemini-1.5-pro-latest"), # Default large
large_model_max_tokens_str=os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS"),
system_prompt_content=None,
system_prompt_path=None,
is_gemini=True, # Important for Gemini bot
max_history_length=20
)
@patch.object(OpenAICompatibleInferenceBot, '__init__')
def test_init_with_arguments(self, mock_super_init):
mock_client_arg = MagicMock()
bot = GeminiTelegramInferenceBot(
openai_client=mock_client_arg, # Name in Gemini bot is openai_client for consistency
api_key="arg_gem_key",
base_url="https://arg.gemini.com",
small_model_name="arg_gem_small",
small_model_max_tokens="124",
large_model_name="arg_gem_large",
large_model_max_tokens="457",
system_prompt_content="Gemini prompt"
)
mock_super_init.assert_called_once_with(
client=mock_client_arg,
api_key="arg_gem_key",
base_url="https://arg.gemini.com",
api_version=None,
azure_deployment=None,
model_name="arg_gem_small",
max_tokens_str="124",
small_model_name="arg_gem_small",
small_model_max_tokens_str="124",
large_model_name="arg_gem_large",
large_model_max_tokens_str="457",
system_prompt_content="Gemini prompt",
system_prompt_path=None,
is_gemini=True,
max_history_length=20
)
@patch('openai.OpenAI') # Gemini bot uses OpenAI client configured for Gemini endpoint
async def test_switch_model_logic(self, mock_openai_constructor):
mock_openai_constructor.return_value = self.mock_openai_client
os.environ["GEMINI_SMALL_MODEL"] = "env-gemini-small"
os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = "110"
os.environ["GEMINI_LARGE_MODEL"] = "env-gemini-large"
os.environ["GEMINI_LARGE_MODEL_MAX_TOKENS"] = "220"
bot = GeminiTelegramInferenceBot() # Uses env vars by default
self.assertEqual(bot.model, "env-gemini-small")
self.assertEqual(bot.max_tokens, 110)
status = await bot.switch_model()
self.assertEqual(bot.model, "env-gemini-large")
self.assertEqual(bot.max_tokens, 220)
self.assertEqual(status, "Switched to model: env-gemini-large")
status = await bot.switch_model()
self.assertEqual(bot.model, "env-gemini-small")
self.assertEqual(bot.max_tokens, 110)
self.assertEqual(status, "Switched to model: env-gemini-small")
@patch('openai.OpenAI')
async def test_switch_model_uses_instance_configs_if_provided(self, mock_openai_constructor):
mock_openai_constructor.return_value = self.mock_openai_client
bot = GeminiTelegramInferenceBot(
small_model_name="init-gem-small", small_model_max_tokens="55",
large_model_name="init-gem-large", large_model_max_tokens="155"
)
self.assertEqual(bot.model, "init-gem-small")
self.assertEqual(bot.max_tokens, 55)
status = await bot.switch_model()
self.assertEqual(bot.model, "init-gem-large")
self.assertEqual(bot.max_tokens, 155)
self.assertEqual(status, "Switched to model: init-gem-large")
status = await bot.switch_model()
self.assertEqual(bot.model, "init-gem-small")
self.assertEqual(bot.max_tokens, 55)
self.assertEqual(status, "Switched to model: init-gem-small")
@patch('openai.OpenAI')
def test_get_llm_description_for_gemini_bot(self, mock_openai_constructor):
mock_openai_constructor.return_value = self.mock_openai_client
bot = GeminiTelegramInferenceBot(
small_model_name="gemini-pro-desc",
small_model_max_tokens="888",
# is_gemini is True by default in constructor call to super
)
# LLM description should indicate not Azure, even though it uses OpenAICompatible... base
# The is_gemini flag primarily affects client instantiation logic in the superclass.
# The azure_openai flag in superclass is based on azure_endpoint presence.
self.assertEqual(bot.get_llm_description(), "LLM: gemini-pro-desc, Max Tokens: 888, Azure: False")
if __name__ == '__main__':
unittest.main()
-81
View File
@@ -1,81 +0,0 @@
# tests/test_github_tool.py
import unittest
from unittest.mock import patch, MagicMock
from tools.github_tool import GitHubTool
class TestGitHubTool(unittest.TestCase):
def setUp(self):
self.github_tool = GitHubTool()
def test_get_functions(self):
functions = self.github_tool.get_functions()
self.assertEqual(len(functions), 4)
function_names = [f["name"] for f in functions]
expected_names = ["read_file", "create_branch", "commit_file", "create_pull_request"]
self.assertListEqual(function_names, expected_names)
@patch('tools.github_tool.requests.get')
def test_read_file(self, mock_get):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {"content": "file content"}
mock_get.return_value = mock_response
result = self.github_tool.execute("read_file", path="test.txt")
self.assertEqual(result, "file content")
mock_get.assert_called_once()
@patch('tools.github_tool.requests.get')
@patch('tools.github_tool.requests.post')
def test_create_branch(self, mock_post, mock_get):
mock_get_response = MagicMock()
mock_get_response.status_code = 200
mock_get_response.json.return_value = {"object": {"sha": "test_sha"}}
mock_get.return_value = mock_get_response
mock_post_response = MagicMock()
mock_post_response.status_code = 201
mock_post.return_value = mock_post_response
result = self.github_tool.execute("create_branch", branch_name="test-branch")
self.assertEqual(result, "Branch 'test-branch' created successfully")
mock_get.assert_called_once()
mock_post.assert_called_once()
@patch('tools.github_tool.requests.put')
def test_commit_file(self, mock_put):
mock_response = MagicMock()
mock_response.status_code = 200
mock_put.return_value = mock_response
result = self.github_tool.execute("commit_file", branch_name="test-branch", file_path="test.txt", content="test content", commit_message="Test commit")
self.assertEqual(result, "File committed successfully to branch 'test-branch'")
mock_put.assert_called_once()
def test_commit_file_to_main(self):
result = self.github_tool.execute("commit_file", branch_name="main", file_path="test.txt", content="test content", commit_message="Test commit")
self.assertEqual(result, "Cannot commit directly to main branch")
@patch('tools.github_tool.requests.post')
def test_create_pull_request(self, mock_post):
mock_response = MagicMock()
mock_response.status_code = 201
mock_response.json.return_value = {"html_url": "https://github.com/test/test/pull/1"}
mock_post.return_value = mock_response
result = self.github_tool.execute("create_pull_request", title="Test PR", body="Test body", head="test-branch")
self.assertEqual(result, "Pull request created successfully: https://github.com/test/test/pull/1")
mock_post.assert_called_once()
def test_unknown_function(self):
result = self.github_tool.execute("unknown_function")
self.assertEqual(result, "Unknown function: unknown_function")
if __name__ == '__main__':
unittest.main()
@@ -1,332 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch, AsyncMock, ANY
import os
import json
# Assuming openai_compatible_inference_bot.py is in the parent directory or PYTHONPATH is set
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot
# Mock response from OpenAI client's chat.completions.create
def create_mock_openai_response(content=None, tool_calls=None):
mock_message = MagicMock()
mock_message.role = "assistant"
mock_message.content = content
if tool_calls:
# tool_calls should be a list of objects with id and function (name, arguments)
mock_tool_calls = []
for tc in tool_calls:
mock_tc = MagicMock()
mock_tc.id = tc["id"]
mock_tc.function.name = tc["function"]["name"]
mock_tc.function.arguments = tc["function"]["arguments"]
mock_tool_calls.append(mock_tc)
mock_message.tool_calls = mock_tool_calls
else:
mock_message.tool_calls = None
mock_choice = MagicMock()
mock_choice.message = mock_message
mock_response = MagicMock()
mock_response.choices = [mock_choice]
return mock_response
# Concrete class for testing
class ConcreteOpenAICompatibleBot(OpenAICompatibleInferenceBot):
# Implement abstract methods for instantiation
async def switch_model(self):
# Simple switch for testing if needed, or just pass
if self.model == self.small_model_name:
self._configure_model_and_tokens(self.large_model_name, self.large_model_max_tokens_str)
else:
self._configure_model_and_tokens(self.small_model_name, self.small_model_max_tokens_str)
return f"Switched to {self.model}"
# Override load_functions if it's called by parent and needs mocking for these tests
# (OpenAICompatibleInferenceBot's __init__ calls BaseTelegramInferenceBot's __init__, which calls load_functions)
def load_functions(self):
# For these tests, assume no tools unless specifically added
self.tools = []
self.functions = []
return self.tools, self.functions
class TestOpenAICompatibleInferenceBot(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.original_openai_api_key = os.environ.get("OPENAI_API_KEY")
self.original_azure_openai_key = os.environ.get("AZURE_OPENAI_KEY")
self.original_azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
self.original_api_version = os.environ.get("AZURE_OPENAI_API_VERSION")
self.original_azure_deployment = os.environ.get("AZURE_DEPLOYMENT_NAME")
# Clear relevant env vars before each test
for key in ["OPENAI_API_KEY", "AZURE_OPENAI_KEY", "AZURE_OPENAI_ENDPOINT",
"AZURE_OPENAI_API_VERSION", "AZURE_DEPLOYMENT_NAME", "SYSTEM_PROMPT_PATH"]:
if os.environ.get(key):
del os.environ[key]
self.mock_openai_client_instance = MagicMock()
self.mock_openai_client_instance.chat.completions.create = MagicMock()
def tearDown(self):
# Restore environment variables
if self.original_openai_api_key: os.environ["OPENAI_API_KEY"] = self.original_openai_api_key
if self.original_azure_openai_key: os.environ["AZURE_OPENAI_KEY"] = self.original_azure_openai_key
if self.original_azure_endpoint: os.environ["AZURE_OPENAI_ENDPOINT"] = self.original_azure_endpoint
if self.original_api_version: os.environ["AZURE_OPENAI_API_VERSION"] = self.original_api_version
if self.original_azure_deployment: os.environ["AZURE_DEPLOYMENT_NAME"] = self.original_azure_deployment
@patch('openai.OpenAI')
def test_init_with_openai_defaults(self, MockOpenAIConstructor):
MockOpenAIConstructor.return_value = self.mock_openai_client_instance
os.environ["OPENAI_API_KEY"] = "test_openai_key"
bot = ConcreteOpenAICompatibleBot(model_name="gpt-4")
MockOpenAIConstructor.assert_called_once_with(api_key="test_openai_key", base_url=None)
self.assertEqual(bot.client, self.mock_openai_client_instance)
self.assertEqual(bot.model, "gpt-4")
self.assertEqual(bot.max_tokens, 1000) # Default from _configure_model_and_tokens
self.assertEqual(bot.azure_openai, False)
@patch('openai.OpenAI')
def test_init_with_provided_client(self, MockOpenAIConstructor):
preconfigured_client = MagicMock()
bot = ConcreteOpenAICompatibleBot(client=preconfigured_client, model_name="gpt-3.5")
MockOpenAIConstructor.assert_not_called()
self.assertEqual(bot.client, preconfigured_client)
self.assertEqual(bot.model, "gpt-3.5")
@patch('openai.AzureOpenAI')
def test_init_with_azure_config_args(self, MockAzureOpenAIConstructor):
MockAzureOpenAIConstructor.return_value = self.mock_openai_client_instance
bot = ConcreteOpenAICompatibleBot(
api_key="azure_key",
azure_endpoint="https://myenv.openai.azure.com",
api_version="2023-05-15",
azure_deployment="my-gpt-4", # This should be used as model_name for API call
model_name="should_be_overridden_by_azure_deployment_for_api"
# model_name is passed to _configure_model_and_tokens, which sets self.model for display/logging
# but for Azure, the client needs the deployment name.
)
MockAzureOpenAIConstructor.assert_called_once_with(
api_key="azure_key",
azure_endpoint="https://myenv.openai.azure.com",
api_version="2023-05-15"
)
self.assertEqual(bot.client, self.mock_openai_client_instance)
self.assertEqual(bot.model, "my-gpt-4") # Azure deployment name becomes the model for API calls
self.assertEqual(bot.azure_openai, True)
@patch('openai.AzureOpenAI')
def test_init_with_azure_env_vars(self, MockAzureOpenAIConstructor):
MockAzureOpenAIConstructor.return_value = self.mock_openai_client_instance
os.environ["AZURE_OPENAI_KEY"] = "env_azure_key"
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://env.openai.azure.com"
os.environ["AZURE_OPENAI_API_VERSION"] = "2023-06-01"
os.environ["AZURE_DEPLOYMENT_NAME"] = "env-gpt-35" # Used as model_name
bot = ConcreteOpenAICompatibleBot(model_name="ignored_if_azure_deployment_env_is_set")
MockAzureOpenAIConstructor.assert_called_once_with(
api_key="env_azure_key",
azure_endpoint="https://env.openai.azure.com",
api_version="2023-06-01"
)
self.assertEqual(bot.model, "env-gpt-35")
self.assertTrue(bot.azure_openai)
@patch('openai.OpenAI')
def test_init_with_gemini_config_args(self, MockOpenAIConstructor):
MockOpenAIConstructor.return_value = self.mock_openai_client_instance
bot = ConcreteOpenAICompatibleBot(
api_key="gemini_key",
base_url="https://gemini.example.com",
model_name="gemini-pro",
is_gemini=True
)
MockOpenAIConstructor.assert_called_once_with(api_key="gemini_key", base_url="https://gemini.example.com")
self.assertEqual(bot.model, "gemini-pro")
self.assertFalse(bot.azure_openai) # is_gemini doesn't mean azure_openai
def test_configure_model_and_tokens(self):
bot = ConcreteOpenAICompatibleBot(model_name="initial_model") # init calls _configure
bot._configure_model_and_tokens("test-model", "500")
self.assertEqual(bot.model, "test-model")
self.assertEqual(bot.max_tokens, 500)
bot._configure_model_and_tokens("test-model-2", None, default_max_tokens=150)
self.assertEqual(bot.max_tokens, 150)
bot._configure_model_and_tokens("test-model-3", "invalid_token_val")
self.assertEqual(bot.max_tokens, 1000) # Default fallback
def test_get_llm_description(self):
bot = ConcreteOpenAICompatibleBot(model_name="desc-model", max_tokens_str="256")
self.assertEqual(bot.get_llm_description(), "LLM: desc-model, Max Tokens: 256, Azure: False")
bot_azure = ConcreteOpenAICompatibleBot(azure_deployment="azure-model", azure_endpoint="x", api_key="y", api_version="z")
self.assertEqual(bot_azure.get_llm_description(), "LLM: azure-model, Max Tokens: 1000, Azure: True")
def test_get_chat_response_success(self):
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="test-gpt")
bot.max_tokens = 50 # Ensure this is set
mock_api_response = create_mock_openai_response(content="Hello from API")
self.mock_openai_client_instance.chat.completions.create.return_value = mock_api_response
messages = [{"role": "user", "content": "Hi"}]
response = bot.get_chat_response(messages)
self.mock_openai_client_instance.chat.completions.create.assert_called_once_with(
model="test-gpt",
messages=messages,
tools=ANY, # Assuming functions can be None or empty list
tool_choice=ANY,
max_tokens=50
)
self.assertEqual(response, mock_api_response)
def test_get_chat_response_api_error(self):
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="error-gpt")
self.mock_openai_client_instance.chat.completions.create.side_effect = Exception("API Down")
with self.assertRaisesRegex(Exception, "API Down"):
bot.get_chat_response([{"role": "user", "content": "trigger"}])
async def test_handle_message_simple_response(self):
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="chatty")
bot.system_prompt = "You are a test bot." # Set directly for simplicity
mock_api_response = create_mock_openai_response(content="Test reply")
self.mock_openai_client_instance.chat.completions.create.return_value = mock_api_response
response_content = await bot.handle_message(user_id=1, user_message="Hello")
self.assertEqual(response_content, "Test reply")
self.assertIn(1, bot.conversation_history)
self.assertEqual(len(bot.conversation_history[1]), 3) # System, User, Assistant
self.assertEqual(bot.conversation_history[1][0]["content"], "You are a test bot.")
self.assertEqual(bot.conversation_history[1][2]["content"], "Test reply")
async def test_handle_message_with_tool_call_and_response(self):
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="tool-user")
# Mock functions/tools setup on the bot
mock_tool_def = {"function": {"name": "get_weather", "description": "Gets weather", "parameters": {}}}
bot.functions = [mock_tool_def] # Simulate tools are loaded
# API response 1: Request to call tool
tool_call_request = [{"id": "call123", "function": {"name": "get_weather", "arguments": '''{"location": "moon"}'''}}]
api_response_1 = create_mock_openai_response(tool_calls=tool_call_request)
# API response 2: Final answer after tool execution
api_response_2 = create_mock_openai_response(content="The weather on the moon is chilly.")
self.mock_openai_client_instance.chat.completions.create.side_effect = [api_response_1, api_response_2]
# Mock self.call_tool
bot.call_tool = MagicMock(return_value='''{"temperature": "-100 C"}''')
final_response = await bot.handle_message(user_id=2, user_message="Weather on moon?")
self.assertEqual(final_response, "The weather on the moon is chilly.")
bot.call_tool.assert_called_once_with("get_weather", '''{"location": "moon"}''')
# Check conversation history includes tool messages
history = bot.conversation_history[2]
self.assertTrue(any(msg["role"] == "assistant" and msg.tool_calls is not None for msg in history))
self.assertTrue(any(msg["role"] == "tool" and msg["name"] == "get_weather" for msg in history))
self.assertEqual(self.mock_openai_client_instance.chat.completions.create.call_count, 2)
async def test_handle_message_max_history_length(self):
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="hist-test", max_history_length=3)
self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Ok")
await bot.handle_message(1, "Msg1") # Sys, User, Assist (3)
self.assertEqual(len(bot.conversation_history[1]), 3)
await bot.handle_message(1, "Msg2") # User, Assist. Should be 3 (prev User, prev Assist, new User) -> then adds new Assist.
# Before new call: [Sys, U1, A1]. New U2. Call with [Sys,U1,A1,U2]. Resp A2.
# History: [Sys,U1,A1,U2,A2]. Limit 3. -> [A1,U2,A2] (if system is not preserved specially)
# The current code appends to history then truncates if over limit.
# So after Msg1: [S, U1, A1]. len=3.
# For Msg2: History is [S, U1, A1]. Append U2. Call with [S,U1,A1,U2]. Append A2.
# History now [S,U1,A1,U2,A2]. len=5. Truncate to 3.
# Expected: [A1, U2, A2] or [U1,A1,U2] or [U2,A2,S] depending on how system prompt is handled in truncation.
# The code is: self.conversation_history[user_id][-self.max_history_length:]
# And system prompt is only added IF user_id not in self.conversation_history.
# So, for Msg2, system prompt is not re-added.
# History before Msg2 call: [S, U1, A1]
# Messages for Msg2 call: [S, U1, A1, U2]
# History after Msg2 response A2: [S, U1, A1, U2, A2]. Len 5.
# Truncated to self.max_history_length=3: [A1, U2, A2]
# Call 1
self.mock_openai_client_instance.chat.completions.create.reset_mock()
self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply1")
await bot.handle_message(user_id=7, user_message="First message")
self.assertEqual(len(bot.conversation_history[7]), 3) # System, User1, Assistant1
# Call 2
self.mock_openai_client_instance.chat.completions.create.reset_mock()
self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply2")
await bot.handle_message(user_id=7, user_message="Second message")
# History before call: [S, U1, A1]. Messages for call: [S, U1, A1, U2]. History after: [S, U1, A1, U2, A2].
# Truncated to 3: [A1, U2, A2]
self.assertEqual(len(bot.conversation_history[7]), 3)
self.assertEqual(bot.conversation_history[7][0]["content"], "Reply1") # A1
self.assertEqual(bot.conversation_history[7][1]["content"], "Second message") # U2
self.assertEqual(bot.conversation_history[7][2]["content"], "Reply2") # A2
# Call 3
self.mock_openai_client_instance.chat.completions.create.reset_mock()
self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply3")
await bot.handle_message(user_id=7, user_message="Third message")
# History before call: [A1, U2, A2]. Messages for call: [A1, U2, A2, U3]. History after: [A1, U2, A2, U3, A3].
# Truncated to 3: [A2, U3, A3]
self.assertEqual(len(bot.conversation_history[7]), 3)
self.assertEqual(bot.conversation_history[7][0]["content"], "Reply2") # A2
self.assertEqual(bot.conversation_history[7][1]["content"], "Third message") # U3
self.assertEqual(bot.conversation_history[7][2]["content"], "Reply3") # A3
async def test_abort_processing(self):
bot = ConcreteOpenAICompatibleBot(model_name="test")
user_id = 123
bot.processing_status[user_id] = {"processing": True, "message_id": 456}
bot.conversation_history[user_id] = [{"role": "user", "content": "stuff"}]
with patch.object(bot, 'clear_conversation_history') as mock_clear_hist: # Patching the method from Base class
result = await bot.abort_processing(user_id)
self.assertEqual(result, "Processing aborted and conversation cleared.")
self.assertFalse(bot.processing_status[user_id]["processing"])
mock_clear_hist.assert_called_once_with(user_id)
async def test_abort_processing_no_active_processing(self):
bot = ConcreteOpenAICompatibleBot(model_name="test")
user_id = 404 # Not in processing_status
with patch.object(bot, 'clear_conversation_history') as mock_clear_hist:
result = await bot.abort_processing(user_id)
self.assertEqual(result, "No active processing found to abort. Conversation cleared.")
mock_clear_hist.assert_called_once_with(user_id)
# Test for the abstract switch_model (basic call, actual logic in concrete class for this test)
async def test_switch_model_concrete_implementation(self):
bot = ConcreteOpenAICompatibleBot(model_name="model1", small_model_name="model1", large_model_name="model2", max_tokens_str="100")
self.assertEqual(bot.model, "model1")
await bot.switch_model() # Calls the concrete implementation
self.assertEqual(bot.model, "model2")
await bot.switch_model()
self.assertEqual(bot.model, "model1")
if __name__ == '__main__':
unittest.main()
-356
View File
@@ -1,356 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch, mock_open, AsyncMock
import asyncio
import os
import sys
# Assuming telegram_helper.py is in the parent directory or PYTHONPATH is set
from telegram_helper import TelegramHelper, MessageHandlerLogicResult
# Mock for the bot passed to TelegramHelper
class MockBot:
def __init__(self):
self.start = AsyncMock()
self.clear_conversation_history = MagicMock()
self.get_bot_status = AsyncMock(return_value="Bot Status OK")
self.switch_model = AsyncMock(return_value="Model Switched OK")
self.handle_message = AsyncMock() # Needs to return a string
self.abort_processing = AsyncMock(return_value="Abort OK")
self.set_processing_status = MagicMock()
self.clear_processing_status = MagicMock()
self.processing_status = {} # Add the attribute
# Mock for telegram.Update and related objects
def create_mock_update(message_text=None, user_id=123, chat_id=456, message_id=789, callback_query_data=None):
update = MagicMock()
update.effective_user.id = user_id
update.effective_chat.id = chat_id
if message_text:
update.message.text = message_text
update.message.reply_text = AsyncMock(return_value=MagicMock(message_id=message_id)) # reply_text returns a Message obj
if callback_query_data:
update.callback_query.data = callback_query_data
update.callback_query.from_user.id = user_id
update.callback_query.answer = AsyncMock()
update.callback_query.edit_message_text = AsyncMock()
return update
# Mock for telegram.ext.ContextTypes.DEFAULT_TYPE
def create_mock_context():
context = MagicMock()
context.bot.delete_message = AsyncMock()
context.bot.edit_message_text = AsyncMock() # For update_status_message
return context
class TestTelegramHelper(unittest.IsolatedAsyncioTestCase): # Use IsolatedAsyncioTestCase for async methods
def setUp(self):
self.mock_bot = MockBot()
# Default paths for reboot files, can be overridden in tests
self.reboot_claude_file = ".test_reboot_claude"
self.reboot_file = ".test_doreboot"
self.helper = TelegramHelper(
self.mock_bot,
reboot_claude_file_path=self.reboot_claude_file,
reboot_file_path=self.reboot_file,
chunk_message_sleep_duration=0.001 # Faster sleep for tests
)
# Clean up any potential leftover reboot files from previous runs
if os.path.exists(self.reboot_claude_file):
os.remove(self.reboot_claude_file)
if os.path.exists(self.reboot_file):
os.remove(self.reboot_file)
def tearDown(self):
# Clean up reboot files created during tests
if os.path.exists(self.reboot_claude_file):
os.remove(self.reboot_claude_file)
if os.path.exists(self.reboot_file):
os.remove(self.reboot_file)
async def test_start_logic(self):
response = await self.helper._start_logic()
self.mock_bot.start.assert_called_once()
self.assertEqual(response, "Hello! I\'m your AI assistant. How can I help you today?")
async def test_start_command(self):
mock_update = create_mock_update(message_text="/start")
mock_context = create_mock_context()
with patch.object(self.helper, \'_start_logic\', new_callable=AsyncMock) as mock_logic:
mock_logic.return_value = "Start Logic Response"
await self.helper.start(mock_update, mock_context)
mock_logic.assert_called_once()
mock_update.message.reply_text.assert_called_once_with("Start Logic Response")
async def test_clear_logic(self):
user_id = 123
response = await self.helper._clear_logic(user_id) # _clear_logic is async after refactor
self.mock_bot.clear_conversation_history.assert_called_once_with(user_id)
self.assertEqual(response, "Conversation history cleared. Let\'s start fresh!")
async def test_clear_command(self):
mock_update = create_mock_update(message_text="/clear", user_id=123)
mock_context = create_mock_context()
with patch.object(self.helper, \'_clear_logic\', new_callable=AsyncMock) as mock_logic:
mock_logic.return_value = "Clear Logic Response"
await self.helper.clear(mock_update, mock_context)
mock_logic.assert_called_once_with(123)
mock_update.message.reply_text.assert_called_once_with("Clear Logic Response")
async def test_status_logic(self):
self.mock_bot.get_bot_status.return_value = "Test Status"
response = await self.helper._status_logic()
self.mock_bot.get_bot_status.assert_called_once()
self.assertEqual(response, "Test Status")
async def test_switch_logic_supported(self):
self.mock_bot.switch_model.return_value = "Switched to Large Model"
response = await self.helper._switch_logic()
self.mock_bot.switch_model.assert_called_once()
self.assertEqual(response, "Switched to Large Model")
async def test_switch_logic_not_supported(self):
del self.mock_bot.switch_model # Simulate bot not having the attribute
response = await self.helper._switch_logic()
self.assertEqual(response, "Model switching is not supported for this bot.")
async def test_handle_message_logic_success(self):
user_id = 100
user_message = "Hello bot"
bot_response = "Hello user <think>Thinking hard</think> Done."
expected_processed_response = f"Hello user {self.helper.HTML_QUOTE_BLOCK_START}Thinking hard{self.helper.HTML_QUOTE_BLOCK_END} Done."
self.mock_bot.handle_message.return_value = bot_response
result = await self.helper._handle_message_logic(user_id, user_message)
self.mock_bot.handle_message.assert_called_once_with(user_id, user_message)
self.assertTrue(result["success"])
self.assertEqual(result["response_text"], expected_processed_response)
self.assertIsNone(result["error_message"])
async def test_handle_message_logic_bot_exception(self):
user_id = 101
user_message = "Trigger error"
self.mock_bot.handle_message.side_effect = Exception("Bot Error")
result = await self.helper._handle_message_logic(user_id, user_message)
self.assertFalse(result["success"])
self.assertIsNone(result["response_text"])
self.assertEqual(result["error_message"], "Bot Error")
@patch(\'logging.error\')
async def test_handle_message_command_success_short_message(self, mock_logging_error):
mock_update = create_mock_update(message_text="Hi", user_id=200, chat_id=201, message_id=202)
mock_context = create_mock_context()
logic_result = MessageHandlerLogicResult(success=True, response_text="Short response", error_message=None)
with patch.object(self.helper, \'_handle_message_logic\', new_callable=AsyncMock) as mock_message_logic:
mock_message_logic.return_value = logic_result
await self.helper.handle_message(mock_update, mock_context)
mock_update.message.reply_text.assert_any_call("Processing your request...", reply_markup=unittest.mock.ANY)
self.mock_bot.set_processing_status.assert_called_once_with(200, 202) # user_id, status_message_id
mock_message_logic.assert_called_once_with(200, "Hi")
mock_context.bot.delete_message.assert_called_once_with(chat_id=201, message_id=202)
self.mock_bot.clear_processing_status.assert_called_once_with(200)
mock_update.message.reply_text.assert_any_call("Short response") # Final response
self.assertEqual(mock_update.message.reply_text.call_count, 2) # Processing + final
@patch(\'logging.error\')
async def test_handle_message_command_success_long_message_chunks(self, mock_logging_error):
mock_update = create_mock_update(message_text="Long text", user_id=200, chat_id=201, message_id=202)
mock_context = create_mock_context()
long_response_text = "a" * 5000 # Longer than 4096
chunk1 = long_response_text[:4096]
chunk2 = long_response_text[4096:]
logic_result = MessageHandlerLogicResult(success=True, response_text=long_response_text, error_message=None)
with patch.object(self.helper, \'_handle_message_logic\', new_callable=AsyncMock) as mock_message_logic, \
patch(\'asyncio.sleep\', new_callable=AsyncMock) as mock_sleep: # Mock sleep
mock_message_logic.return_value = logic_result
await self.helper.handle_message(mock_update, mock_context)
mock_update.message.reply_text.assert_any_call(chunk1)
mock_update.message.reply_text.assert_any_call(chunk2)
mock_sleep.assert_called_once_with(self.helper.chunk_message_sleep_duration)
self.assertEqual(mock_update.message.reply_text.call_count, 3) # Processing + 2 chunks
@patch(\'logging.error\')
async def test_handle_message_command_logic_fails(self, mock_logging_error):
mock_update = create_mock_update(message_text="Cause error in logic", user_id=200)
mock_context = create_mock_context()
logic_result = MessageHandlerLogicResult(success=False, response_text=None, error_message="Logic Failed")
with patch.object(self.helper, \'_handle_message_logic\', new_callable=AsyncMock) as mock_message_logic:
mock_message_logic.return_value = logic_result
await self.helper.handle_message(mock_update, mock_context)
mock_update.message.reply_text.assert_any_call("Sorry, an error occurred while processing your request.")
self.assertEqual(mock_update.message.reply_text.call_count, 2) # Processing + error message
@patch(\'logging.error\')
async def test_handle_message_command_telegram_exception_after_logic(self, mock_logging_error):
mock_update = create_mock_update(message_text="Test", user_id=200)
mock_context = create_mock_context()
logic_result = MessageHandlerLogicResult(success=True, response_text="OK", error_message=None)
# Make sending the final reply fail
mock_update.message.reply_text.side_effect = [
MagicMock(message_id=202), # For "Processing..."
Exception("Telegram API Error") # For the actual response
]
with patch.object(self.helper, \'_handle_message_logic\', new_callable=AsyncMock) as mock_message_logic:
mock_message_logic.return_value = logic_result
await self.helper.handle_message(mock_update, mock_context)
# Check if the generic error message was attempted
# This is tricky because reply_text is already mocked with side_effect.
# We\'d expect logs. Let\'s check logs or if processing status was cleared.
self.mock_bot.clear_processing_status.assert_called_once_with(200)
mock_logging_error.assert_any_call(unittest.mock.string_containing("Outer error in handle_message"))
async def test_abort_processing_logic(self):
user_id = 300
self.mock_bot.abort_processing.return_value = "Aborted by bot"
response = await self.helper._abort_processing_logic(user_id)
self.mock_bot.abort_processing.assert_called_once_with(user_id)
self.assertEqual(response, "Aborted by bot")
async def test_abort_processing_command(self):
mock_update = create_mock_update(callback_query_data=\'abort\', user_id=301)
mock_context = create_mock_context()
with patch.object(self.helper, \'_abort_processing_logic\', new_callable=AsyncMock) as mock_logic:
mock_logic.return_value = "Abort Logic Done"
await self.helper.abort_processing(mock_update, mock_context)
mock_update.callback_query.answer.assert_called_once()
mock_logic.assert_called_once_with(301)
mock_update.callback_query.edit_message_text.assert_called_once_with(text="Abort Logic Done")
def test_reboot_logic_claude_and_main(self):
user_message_parts = ["/reboot", "claude"]
chat_id_to_write = "12345"
with patch("builtins.open", mock_open()) as mock_file:
self.helper._reboot_logic(user_message_parts, chat_id_to_write)
# Check claude reboot file
mock_file.assert_any_call(self.reboot_claude_file, \'w\')
# Check main doreboot file
mock_file.assert_any_call(self.reboot_file, \'w\')
handle_claude = mock_file.return_value
handle_main = mock_file.return_value # mock_open reuses the handle for multiple calls
# Check if write was called for claude file (empty write)
# This part of assertion is tricky with single mock_file. Better to use different mocks if possible
# or check the sequence of calls if the mock supports it well.
# For now, assert_any_call ensures it was opened.
# Check content for main reboot file
# Need to ensure the write for self.reboot_file had chat_id_to_write
# This requires more sophisticated mock_open or patching os.path.exists and multiple open calls
# Simpler check: was open(self.reboot_file, \'w\') called? Yes, via assert_any_call.
# And was open(self.reboot_claude_file, \'w\') called? Yes.
# Verify files were created (mock_open doesn\'t actually create them)
# This test relies on mock_open\'s behavior. To test file content, need more setup.
# For now, assume open was called correctly.
def test_reboot_logic_main_only(self):
user_message_parts = ["/reboot"]
chat_id_to_write = "67890"
with patch("builtins.open", mock_open()) as mock_file:
self.helper._reboot_logic(user_message_parts, chat_id_to_write)
# Ensure claude file was NOT opened for writing.
# This requires asserting that a specific call didn\'t happen, or checking call_args_list
claude_call = unittest.mock.call(self.reboot_claude_file, \'w\')
self.assertNotIn(claude_call, mock_file.call_args_list)
mock_file.assert_any_call(self.reboot_file, \'w\')
@patch(\'sys.exit\') # Mock sys.exit to prevent test runner from exiting
async def test_reboot_command(self, mock_sys_exit):
mock_update = create_mock_update(message_text="/reboot claude", chat_id="chat1")
mock_context = create_mock_context()
with patch.object(self.helper, \'_reboot_logic\') as mock_reboot_file_logic:
await self.helper.reboot(mock_update, mock_context)
mock_reboot_file_logic.assert_called_once_with(["/reboot", "claude"], "chat1")
mock_update.message.reply_text.assert_called_once_with("Rebooting the bot...")
mock_sys_exit.assert_called_once_with(0)
@patch(\'os.path.exists\')
@patch(\'builtins.open\', new_callable=mock_open)
@patch(\'os.remove\')
async def test_check_doreboot_file_logic_file_exists(self, mock_os_remove, mock_file_open, mock_os_path_exists):
mock_os_path_exists.return_value = True
mock_file_open.return_value.read.return_value.strip.return_value = "chat123"
chat_id = await self.helper._check_doreboot_file_logic()
mock_os_path_exists.assert_called_once_with(self.reboot_file)
mock_file_open.assert_called_once_with(self.reboot_file, \'r\')
mock_os_remove.assert_called_once_with(self.reboot_file)
self.assertEqual(chat_id, "chat123")
@patch(\'os.path.exists\', return_value=False)
async def test_check_doreboot_file_logic_file_not_exists(self, mock_os_path_exists):
chat_id = await self.helper._check_doreboot_file_logic()
mock_os_path_exists.assert_called_once_with(self.reboot_file)
self.assertIsNone(chat_id)
@patch(\'logging.error\')
@patch(\'os.path.exists\', return_value=True)
@patch(\'builtins.open\', side_effect=IOError("Read error"))
@patch(\'os.remove\') # To check if remove is called even on read error
async def test_check_doreboot_file_logic_read_error(self, mock_os_remove, mock_file_open, mock_os_path_exists, mock_log_error):
chat_id = await self.helper._check_doreboot_file_logic()
self.assertIsNone(chat_id)
mock_log_error.assert_any_call(unittest.mock.string_containing("Error reading reboot file"))
# Check if os.remove was attempted even after read error
mock_os_remove.assert_called_once_with(self.reboot_file)
async def test_check_doreboot_file_command_sends_message(self):
mock_application = MagicMock()
mock_application.bot.send_message = AsyncMock()
with patch.object(self.helper, \'_check_doreboot_file_logic\', new_callable=AsyncMock) as mock_logic:
mock_logic.return_value = "chat789" # Simulate chat_id found
await self.helper.check_doreboot_file(mock_application)
mock_logic.assert_called_once()
mock_application.bot.send_message.assert_called_once_with(
chat_id="chat789", text="The application has finished initializing."
)
async def test_check_doreboot_file_command_no_chat_id(self):
mock_application = MagicMock()
mock_application.bot.send_message = AsyncMock()
with patch.object(self.helper, \'_check_doreboot_file_logic\', new_callable=AsyncMock) as mock_logic:
mock_logic.return_value = None # Simulate no chat_id found
await self.helper.check_doreboot_file(mock_application)
mock_logic.assert_called_once()
mock_application.bot.send_message.assert_not_called()
# Note: Testing the run() method itself is more of an integration test,
# as it involves setting up the full Application and polling loop.
# Unit tests here focus on the helper\'s own logic methods.
if __name__ == \'__main__\':
unittest.main()
-307
View File
@@ -1,307 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch
import os
import base64
import logging
import requests # Required for spec in MagicMock
# Ensure tools/github_tool.py is accessible
from tools.github_tool import GitHubTool
# Helper to create a mock response for requests.Session
def create_mock_response(status_code, json_data=None, text_data=None, headers=None, links=None):
mock_resp = MagicMock()
mock_resp.status_code = status_code
if json_data is not None:
mock_resp.json = MagicMock(return_value=json_data)
mock_resp.text = text_data if text_data is not None else str(json_data)
mock_resp.headers = headers if headers else {}
mock_resp.links = links if links else {} # For pagination in _list_branches
return mock_resp
class TestGitHubTool(unittest.TestCase):
def setUp(self):
self.mock_session = MagicMock(spec=requests.Session)
self.mock_session.headers = {} # Simulate a new session's headers
self.test_token = "test_github_token"
self.test_repo = "owner/repo"
self.test_base_url = "https://api.example.com" # Use a non-default base_url for some tests
# Suppress logging output during tests unless explicitly testing for it
self.logger = logging.getLogger('tools.github_tool')
# Ensure only one NullHandler to prevent duplicate messages if tests run multiple times in a session
if not any(isinstance(h, logging.NullHandler) for h in self.logger.handlers):
self.logger.addHandler(logging.NullHandler())
self.logger.propagate = False # Prevent propagation to root logger if it has handlers
def test_init_with_args_and_session(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, base_url=self.test_base_url, logger=self.logger)
self.assertEqual(tool.session, self.mock_session)
self.assertEqual(tool._token, self.test_token)
self.assertEqual(tool._repo, self.test_repo)
self.assertEqual(tool.base_url, self.test_base_url)
self.assertEqual(tool.current_branch, "main") # Default initial branch
@patch('requests.Session')
def test_init_creates_session_if_not_provided(self, MockSessionConstructor):
mock_created_session = MagicMock(spec=requests.Session)
mock_created_session.headers = {}
MockSessionConstructor.return_value = mock_created_session
# Temporarily set env vars for this test
original_token = os.environ.get("GITHUB_TOKEN")
original_repo = os.environ.get("GITHUB_REPOSITORY")
os.environ["GITHUB_TOKEN"] = "env_token"
os.environ["GITHUB_REPOSITORY"] = "env/repo"
tool = GitHubTool(logger=self.logger) # Use env vars
MockSessionConstructor.assert_called_once()
self.assertEqual(tool.session, mock_created_session)
self.assertEqual(tool._token, "env_token")
self.assertEqual(tool._repo, "env/repo")
self.assertIn("Authorization", mock_created_session.headers)
self.assertEqual(mock_created_session.headers["Authorization"], "token env_token")
# Restore original env vars
if original_token is None: del os.environ["GITHUB_TOKEN"]
else: os.environ["GITHUB_TOKEN"] = original_token
if original_repo is None: del os.environ["GITHUB_REPOSITORY"]
else: os.environ["GITHUB_REPOSITORY"] = original_repo
def test_init_raises_value_error_if_no_token(self):
original_token = os.environ.pop("GITHUB_TOKEN", None)
with self.assertRaisesRegex(ValueError, "GitHub token must be provided"):
GitHubTool(repo=self.test_repo, logger=self.logger)
if original_token: os.environ["GITHUB_TOKEN"] = original_token
def test_init_raises_value_error_if_no_repo(self):
original_repo = os.environ.pop("GITHUB_REPOSITORY", None)
with self.assertRaisesRegex(ValueError, "GitHub repository.*must be provided"):
GitHubTool(token=self.test_token, logger=self.logger)
if original_repo: os.environ["GITHUB_REPOSITORY"] = original_repo
def test_clear_resets_branch(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, initial_branch="feature-branch", logger=self.logger)
# Mock _get_branch_sha for _set_current_branch called by clear
with patch.object(tool, '_get_branch_sha', return_value="sha_for_main"):
tool.clear()
self.assertEqual(tool.current_branch, "main")
def test_get_functions_returns_list(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
functions = tool.get_functions()
self.assertIsInstance(functions, list)
self.assertTrue(len(functions) > 0)
self.assertIn("name", functions[0]["function"])
# --- Test individual private methods ---
def test_read_file_success(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
file_content = "Hello World!"
encoded_content = base64.b64encode(file_content.encode('utf-8')).decode('utf-8')
self.mock_session.get.return_value = create_mock_response(200, json_data={"content": encoded_content})
result = tool._read_file(path="test.txt")
self.assertEqual(result, file_content)
self.mock_session.get.assert_called_once_with(
f"{tool.base_url}/repos/{self.test_repo}/contents/test.txt",
params={"ref": "main"}
)
def test_read_file_error(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
self.mock_session.get.return_value = create_mock_response(404, text_data="Not Found")
result = tool._read_file(path="nonexistent.txt")
self.assertIn("Error reading file", result)
def test_create_branch_success(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
# Mock getting base branch SHA
self.mock_session.get.return_value = create_mock_response(200, json_data={"object": {"sha": "base_sha123"}})
# Mock creating new branch
self.mock_session.post.return_value = create_mock_response(201, json_data={"ref": "refs/heads/new-feature"})
result = tool._create_branch(branch_name="new-feature", base_branch="main")
self.assertIn("Branch 'new-feature' created successfully", result)
self.assertEqual(tool.current_branch, "new-feature")
self.mock_session.get.assert_called_once() # For base branch SHA
self.mock_session.post.assert_called_once() # For creating branch
def test_create_branch_base_sha_error(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
self.mock_session.get.return_value = create_mock_response(404, text_data="Base branch not found")
result = tool._create_branch(branch_name="new-feature", base_branch="nonexistent-base")
self.assertIn("Error getting base branch SHA", result)
def test_create_branch_creation_error(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
self.mock_session.get.return_value = create_mock_response(200, json_data={"object": {"sha": "base_sha456"}})
self.mock_session.post.return_value = create_mock_response(422, text_data="Validation failed")
result = tool._create_branch(branch_name="bad-branch", base_branch="main")
self.assertIn("Error creating branch", result)
def test_commit_file_success_new_file(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
tool.current_branch = "dev-branch" # Cannot commit to main by default
# Mock GET for checking file existence (404 means new file)
self.mock_session.get.return_value = create_mock_response(404)
# Mock PUT for committing file
self.mock_session.put.return_value = create_mock_response(201, json_data={"commit": {"sha": "commit_sha_abc"}})
result = tool._commit_file(file_path="new_file.py", content="print('Hello')", commit_message="Add new_file.py")
self.assertIn("committed successfully", result)
self.assertIn("commit_sha_abc", result)
self.mock_session.get.assert_called_once() # Check file existence
self.mock_session.put.assert_called_once() # Commit file
def test_commit_file_success_update_file(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
tool.current_branch = "dev-branch"
# Mock GET for checking file existence (200 means existing file)
self.mock_session.get.return_value = create_mock_response(200, json_data={"sha": "existing_file_sha"})
# Mock PUT for committing file
self.mock_session.put.return_value = create_mock_response(200, json_data={"commit": {"sha": "commit_sha_def"}})
result = tool._commit_file(file_path="existing_file.txt", content="Updated content", commit_message="Update existing_file.txt")
self.assertIn("committed successfully", result)
self.assertIn("commit_sha_def", result)
args, kwargs = self.mock_session.put.call_args
self.assertEqual(kwargs['json']['sha'], "existing_file_sha")
def test_commit_file_to_main_branch_error(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
tool.current_branch = "main"
result = tool._commit_file(file_path="some.txt", content="content", commit_message="msg")
self.assertIn("Action directly to main branch is not allowed", result)
def test_create_pull_request_success(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
tool.current_branch = "feature-pr"
pr_url = "https://example.com/pull/1"
self.mock_session.post.return_value = create_mock_response(201, json_data={"html_url": pr_url, "number": 1})
result = tool._create_pull_request(title="New Feature PR", body="Please review.", base="main")
self.assertIn(f"Pull request created successfully: {pr_url}", result)
self.mock_session.post.assert_called_once()
call_data = self.mock_session.post.call_args[1]['json']
self.assertEqual(call_data['head'], "feature-pr")
self.assertEqual(call_data['base'], "main")
def test_create_pull_request_same_branch_error(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
tool.current_branch = "main"
result = tool._create_pull_request(title="PR to self", body="This should fail", base="main")
self.assertIn("Cannot create a pull request from branch 'main' to itself", result)
def test_list_files_success(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
mock_items = [
{"name": "file1.txt", "type": "file", "path": "dir/file1.txt"},
{"name": "subdir", "type": "dir", "path": "dir/subdir"}
]
self.mock_session.get.return_value = create_mock_response(200, json_data=mock_items)
result = tool._list_files(path="dir")
self.assertEqual(len(result), 2)
self.assertEqual(result[0]["name"], "file1.txt")
self.assertEqual(result[1]["type"], "dir")
def test_search_code_success(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
mock_search_results = {
"items": [{"path": "src/code.py", "html_url": "url1"}]
}
self.mock_session.get.return_value = create_mock_response(200, json_data=mock_search_results)
results = tool._search_code(query="my_function")
self.assertEqual(len(results), 1)
self.assertEqual(results[0]["path"], "src/code.py")
def test_get_commit_history_success(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
mock_commits = [{
"sha": "sha1", "commit": {"message": "Msg1", "author": {"name": "Authy", "date": "Date1"}}
}]
self.mock_session.get.return_value = create_mock_response(200, json_data=mock_commits)
commits = tool._get_commit_history(file_path="file.txt", num_commits=1)
self.assertEqual(len(commits), 1)
self.assertEqual(commits[0]["sha"], "sha1")
def test_set_current_branch_success(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
# Mock _get_branch_sha to simulate branch exists
with patch.object(tool, '_get_branch_sha', return_value="some_sha_for_dev"):
result = tool._set_current_branch(branch_name="dev")
self.assertEqual(tool.current_branch, "dev")
self.assertIn("Current branch set to: dev", result)
def test_set_current_branch_not_exists(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
with patch.object(tool, '_get_branch_sha', return_value="Error getting SHA for branch"):
result = tool._set_current_branch(branch_name="nonexistent-branch")
self.assertNotEqual(tool.current_branch, "nonexistent-branch") # Should not change
self.assertIn("Cannot set current branch", result)
def test_list_branches_single_page(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
mock_branches = [{"name": "main"}, {"name": "dev"}]
self.mock_session.get.return_value = create_mock_response(200, json_data=mock_branches, links={}) # No "next" link
branches = tool._list_branches(all_pages=True)
self.assertEqual(branches, ["main", "dev"])
self.mock_session.get.assert_called_once()
def test_list_branches_multiple_pages(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
# Page 1 response
page1_branches = [{"name": "branch1"}, {"name": "branch2"}]
next_url = f"{tool.base_url}/repos/{self.test_repo}/branches?page=2"
response1 = create_mock_response(200, json_data=page1_branches, links={"next": {"url": next_url}})
# Page 2 response
page2_branches = [{"name": "branch3"}]
response2 = create_mock_response(200, json_data=page2_branches, links={}) # No "next" link
self.mock_session.get.side_effect = [response1, response2]
branches = tool._list_branches(all_pages=True)
self.assertEqual(branches, ["branch1", "branch2", "branch3"])
self.assertEqual(self.mock_session.get.call_count, 2)
# Check that the second call used the next_url
calls = self.mock_session.get.call_args_list
self.assertEqual(calls[1][0][0], next_url) # args[0] is the URL
# --- Test execute dispatcher ---
def test_execute_read_file(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
with patch.object(tool, '_read_file', return_value="file content") as mock_method:
result = tool.execute(function_name="read_file", path="test.md")
mock_method.assert_called_once_with(path="test.md")
self.assertEqual(result, "file content")
def test_execute_unknown_function(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
result = tool.execute(function_name="non_existent_function_name", arg1="val1")
self.assertIn("Unknown function: non_existent_function_name", result)
def test_execute_method_exception(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
with patch.object(tool, '_read_file', side_effect=Exception("Kaboom")) as mock_method:
result = tool.execute(function_name="read_file", path="crash.txt")
self.assertIn("Error during read_file execution: Kaboom", result)
if __name__ == '__main__':
unittest.main()
-146
View File
@@ -1,146 +0,0 @@
import unittest
from unittest.mock import patch, mock_open, MagicMock
import os
import logging
from datetime import datetime, timedelta
# Ensure tools/log_tool.py is accessible
from tools.log_tool import LogTool
class TestLogTool(unittest.TestCase):
def setUp(self):
self.test_log_file_path = "test_dummy_log.log"
# Suppress logging output during tests unless explicitly testing for it
self.logger = logging.getLogger('tools.log_tool')
# Ensure only one NullHandler to prevent duplicate messages if tests run multiple times in a session
if not any(isinstance(h, logging.NullHandler) for h in self.logger.handlers):
self.logger.addHandler(logging.NullHandler())
self.logger.propagate = False # Prevent propagation to root logger if it has handlers
def test_init_default_log_path(self):
tool = LogTool(logger=self.logger)
self.assertEqual(tool.configured_log_file_path, 'logs/output.log')
def test_init_custom_log_path(self):
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
self.assertEqual(tool.configured_log_file_path, self.test_log_file_path)
def test_get_functions(self):
tool = LogTool(logger=self.logger)
functions = tool.get_functions()
self.assertIsInstance(functions, list)
self.assertEqual(len(functions), 1)
self.assertEqual(functions[0]["function"]["name"], "get_log_contents")
@patch("os.path.exists", return_value=False)
def test_get_log_contents_file_not_exists(self, mock_exists):
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
result = tool._get_log_contents()
self.assertIn("Log file does not exist", result)
mock_exists.assert_called_once_with(self.test_log_file_path)
@patch("os.path.exists", return_value=True)
@patch("builtins.open", new_callable=mock_open, read_data="line1\nline2\nline3\nline4\nline5")
def test_get_log_contents_with_line_count(self, mock_file_open, mock_exists):
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
result = tool._get_log_contents(line_count=3)
self.assertEqual(result, "line3\nline4\nline5")
mock_exists.assert_called_once_with(self.test_log_file_path)
mock_file_open.assert_called_once_with(self.test_log_file_path, 'r', encoding='utf-8')
@patch("os.path.exists", return_value=True)
@patch("builtins.open", new_callable=mock_open, read_data="line1\nline2\n")
def test_get_log_contents_line_count_more_than_available(self, mock_file_open, mock_exists):
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
result = tool._get_log_contents(line_count=5)
self.assertEqual(result, "line1\nline2\n")
@patch("os.path.exists", return_value=True)
@patch("builtins.open", new_callable=mock_open, read_data="line1\nline2\n")
def test_get_log_contents_invalid_line_count_uses_default(self, mock_file_open, mock_exists):
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
# Test with zero, negative, and non-integer line_count (though type hint is int)
# The code defaults to 150 if invalid. Here, we only have 2 lines.
with patch.object(tool.logger, 'warning') as mock_log_warning:
result_zero = tool._get_log_contents(line_count=0)
self.assertEqual(result_zero, "line1\nline2\n")
mock_log_warning.assert_any_call("Invalid line_count '0' provided, defaulting to fetch last 150 lines.")
mock_file_open.reset_mock() # Reset for next call
result_neg = tool._get_log_contents(line_count=-5)
self.assertEqual(result_neg, "line1\nline2\n")
mock_log_warning.assert_any_call("Invalid line_count '-5' provided, defaulting to fetch last 150 lines.")
@patch("os.path.exists", return_value=True)
def test_get_log_contents_last_24_hours(self, mock_exists):
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
now = datetime.now()
one_hour_ago_dt = now - timedelta(hours=1)
two_days_ago_dt = now - timedelta(days=2)
one_hour_ago_str = one_hour_ago_dt.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT)
two_days_ago_str = two_days_ago_dt.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT)
log_data = (
f"{two_days_ago_str} - OLD - This is an old log entry.\n"
f"No timestamp here - this line should be skipped by time filter.\n"
f"{one_hour_ago_str} - RECENT - This is a recent log entry.\n"
f"Malformed Date 2023-xx-01 - Another skipped line.\n"
f"{now.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT)} - CURRENT - This is a current log entry.\n"
)
expected_output = (
f"{one_hour_ago_str} - RECENT - This is a recent log entry.\n"
f"{now.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT)} - CURRENT - This is a current log entry.\n"
)
with patch("builtins.open", mock_open(read_data=log_data)):
result = tool._get_log_contents(line_count=None) # Trigger 24-hour logic
self.assertEqual(result, expected_output)
@patch("os.path.exists", return_value=True)
@patch("builtins.open", side_effect=IOError("File read error!"))
def test_get_log_contents_file_read_exception(self, mock_file_open, mock_exists):
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
result = tool._get_log_contents(line_count=10)
self.assertIn("An error occurred while reading the log file: File read error!", result)
def test_execute_get_log_contents(self):
tool = LogTool(logger=self.logger)
mock_return_value = "Mocked log content"
with patch.object(tool, '_get_log_contents', return_value=mock_return_value) as mock_method:
result = tool.execute(function_name="get_log_contents", line_count=50)
mock_method.assert_called_once_with(line_count=50)
self.assertEqual(result, mock_return_value)
def test_execute_get_log_contents_no_line_count(self):
tool = LogTool(logger=self.logger)
mock_return_value = "Mocked log content for 24h"
with patch.object(tool, '_get_log_contents', return_value=mock_return_value) as mock_method:
result = tool.execute(function_name="get_log_contents") # No line_count
mock_method.assert_called_once_with(line_count=None) # Expects None to trigger 24h
self.assertEqual(result, mock_return_value)
def test_execute_unknown_function(self):
tool = LogTool(logger=self.logger)
result = tool.execute(function_name="non_existent_log_function")
self.assertIn("Unknown function: non_existent_log_function", result)
def test_clear_method(self):
tool = LogTool(logger=self.logger)
# Set a specific level for the logger for this test if needed to capture debug
original_level = tool.logger.level
tool.logger.setLevel(logging.DEBUG)
with self.assertLogs(tool.logger, level='DEBUG') as cm:
tool.clear()
self.assertTrue(any("LogTool clear called" in message for message in cm.output))
tool.logger.setLevel(original_level) # Reset level
if __name__ == '__main__':
unittest.main()
-217
View File
@@ -1,217 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock, ANY
import time
import logging
# Ensure tools.metrics is accessible
from tools.metrics import Metrics # Import the class itself for direct testing
from tools.metrics import metrics as global_metrics_instance # Import the global instance
# A simple function to decorate for testing
def sample_function_for_metrics(duration=0.01):
# Simulate some work
# Note: time.sleep is not always precisely profiled by cProfile in the same way as CPU-bound work.
# For testing, we will mock the cProfile/pstats interaction rather than relying on actual sleep duration.
if duration > 0: # Make it conditional so we can test zero-time case too
pass # The actual work is not important when mocking cProfile results
return "sample_output"
def another_sample_function(x, y):
return x + y
class TestMetrics(unittest.TestCase):
def setUp(self):
# Create a fresh Metrics instance for most tests to avoid interference
self.logger = logging.getLogger('tools.metrics.test')
if not self.logger.handlers: # Avoid adding handler multiple times
self.logger.addHandler(logging.NullHandler())
self.metrics_instance = Metrics(logger=self.logger)
# Clear the global instance before each test that might use it
global_metrics_instance.clear_metrics()
def test_measure_decorator_counts_calls(self):
decorated_func = self.metrics_instance.measure(sample_function_for_metrics)
self.assertEqual(self.metrics_instance.call_count["sample_function_for_metrics"], 0)
decorated_func()
self.assertEqual(self.metrics_instance.call_count["sample_function_for_metrics"], 1)
decorated_func()
self.assertEqual(self.metrics_instance.call_count["sample_function_for_metrics"], 2)
@patch('cProfile.Profile')
@patch('pstats.Stats')
def test_measure_decorator_records_time(self, MockPStats, MockCProfile):
# Mock cProfile and pstats to control the time value
mock_profiler_instance = MockCProfile.return_value
mock_pstats_instance = MockPStats.return_value
# Simulate that pstats.Stats.stats dictionary contains the function's stats
# Key: (filename, lineno, funcname)
# Value: (cc, nc, tt, ct, callers) where ct is cumulative_time (index 3)
# Get code object of the function *before* decoration for correct key
original_func_code = sample_function_for_metrics.__code__
func_key = (original_func_code.co_filename, original_func_code.co_firstlineno, original_func_code.co_name)
# Configure mock_pstats_instance.stats to return our desired time
mock_pstats_instance.stats = {func_key: (1, 1, 0.05, 0.123, {})} # cc, nc, tt, ct=0.123
decorated_func = self.metrics_instance.measure(sample_function_for_metrics)
self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0)
# Call the decorated function
decorated_func(duration=0) # Duration arg doesn't matter due to mocking
# Assertions
mock_profiler_instance.enable.assert_called_once()
mock_profiler_instance.disable.assert_called_once()
MockPStats.assert_called_once_with(mock_profiler_instance)
self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0.123)
# Call again to see accumulation
# Reset mock stats for a new time value if needed, or assume same time per call
mock_pstats_instance.stats = {func_key: (1, 1, 0.05, 0.100, {})} # New ct=0.100
decorated_func(duration=0)
self.assertAlmostEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0.123 + 0.100)
@patch('cProfile.Profile')
@patch('pstats.Stats')
def test_measure_decorator_fallback_time_recording_by_name(self, MockPStats, MockCProfile):
mock_profiler_instance = MockCProfile.return_value
mock_pstats_instance = MockPStats.return_value
original_func_code = sample_function_for_metrics.__code__ # func to be decorated
# Simulate the primary key lookup fails by creating a slightly different key for what we expect
# This is what the code will try to look up first.
expected_primary_key = (original_func_code.co_filename, original_func_code.co_firstlineno, original_func_code.co_name)
# This is the key that will *actually* be in pstats.stats, simulating a mismatch for primary lookup
# but a match for the by-name fallback.
actual_stats_key_in_pstats = (original_func_code.co_filename,
original_func_code.co_firstlineno + 5, # simulate a lineno difference for primary key mismatch
original_func_code.co_name) # Name is the same for fallback
mock_pstats_instance.stats = {
# expected_primary_key is NOT present
actual_stats_key_in_pstats: (1, 1, 0.03, 0.077, {}) # ct = 0.077
}
decorated_func = self.metrics_instance.measure(sample_function_for_metrics)
# Expecting a debug log for fallback, but assertLogs needs the logger to have a handler that captures.
# self.logger is already set up with NullHandler. For this test, let's use a specific logger.
metrics_internal_logger = logging.getLogger('tools.metrics') # Logger used inside Metrics class
original_level = metrics_internal_logger.level
metrics_internal_logger.setLevel(logging.DEBUG)
with self.assertLogs(metrics_internal_logger, level='DEBUG') as log_capture:
decorated_func(duration=0)
metrics_internal_logger.setLevel(original_level) # Reset logger level
self.assertTrue(any("Found stats for sample_function_for_metrics by name" in msg for msg in log_capture.output))
self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0.077)
@patch('cProfile.Profile')
@patch('pstats.Stats')
def test_measure_decorator_handles_func_stats_not_found(self, MockPStats, MockCProfile):
mock_profiler_instance = MockCProfile.return_value
mock_pstats_instance = MockPStats.return_value
mock_pstats_instance.stats = {} # Empty stats, function will not be found
decorated_func = self.metrics_instance.measure(sample_function_for_metrics)
metrics_internal_logger = logging.getLogger('tools.metrics')
original_level = metrics_internal_logger.level
metrics_internal_logger.setLevel(logging.WARNING)
with self.assertLogs(metrics_internal_logger, level='WARNING') as log_capture:
decorated_func(duration=0)
metrics_internal_logger.setLevel(original_level)
self.assertTrue(any("Could not find exact cProfile stats" in msg for msg in log_capture.output))
self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0)
def test_get_metrics_empty(self):
self.assertEqual(self.metrics_instance.get_metrics(), {})
@patch('cProfile.Profile')
@patch('pstats.Stats')
def test_get_metrics_with_data(self, MockPStats, MockCProfile):
mock_pstats_instance = MockPStats.return_value
# Decorate two different functions
decorated_func1 = self.metrics_instance.measure(sample_function_for_metrics)
decorated_func2 = self.metrics_instance.measure(another_sample_function)
# Data for func1
func1_code = sample_function_for_metrics.__code__
func1_key = (func1_code.co_filename, func1_code.co_firstlineno, func1_code.co_name)
mock_pstats_instance.stats = {func1_key: (1,1,0.1,0.1,{})}
decorated_func1()
# Data for func2
func2_code = another_sample_function.__code__
func2_key = (func2_code.co_filename, func2_code.co_firstlineno, func2_code.co_name)
mock_pstats_instance.stats = {func2_key: (1,1,0.2,0.2,{})} # Cumulative time 0.2
decorated_func2(1,2)
mock_pstats_instance.stats = {func2_key: (1,1,0.3,0.3,{})} # Cumulative time 0.3 for second call
decorated_func2(3,4)
metrics_data = self.metrics_instance.get_metrics()
self.assertIn("sample_function_for_metrics", metrics_data)
self.assertEqual(metrics_data["sample_function_for_metrics"]["call_count"], 1)
self.assertEqual(metrics_data["sample_function_for_metrics"]["total_time"], 0.1)
self.assertEqual(metrics_data["sample_function_for_metrics"]["average_time"], 0.1)
self.assertIn("another_sample_function", metrics_data)
self.assertEqual(metrics_data["another_sample_function"]["call_count"], 2)
self.assertAlmostEqual(metrics_data["another_sample_function"]["total_time"], 0.5)
self.assertAlmostEqual(metrics_data["another_sample_function"]["average_time"], 0.25)
def test_clear_metrics(self):
# Add some data
self.metrics_instance.call_count["test_func"] = 5
self.metrics_instance.total_time["test_func"] = 1.234
self.metrics_instance.clear_metrics()
self.assertEqual(self.metrics_instance.call_count, {})
self.assertEqual(self.metrics_instance.total_time, {})
self.assertEqual(self.metrics_instance.get_metrics(), {})
# Test the global instance
@patch('cProfile.Profile')
@patch('pstats.Stats')
def test_global_metrics_instance_usage(self, MockPStats, MockCProfile):
mock_pstats_instance = MockPStats.return_value
# Decorate a function with the global instance
@global_metrics_instance.measure
def globally_decorated_func():
return "global_output"
# Setup mock stats for the globally decorated function
# Access __wrapped__ to get the original function if other decorators might be present or for consistency.
original_g_func = globally_decorated_func.__wrapped__
func_code = original_g_func.__code__
func_key = (func_code.co_filename, func_code.co_firstlineno, func_code.co_name)
mock_pstats_instance.stats = {func_key: (1,1,0.05,0.05,{})}
globally_decorated_func()
metrics_data = global_metrics_instance.get_metrics()
self.assertIn("globally_decorated_func", metrics_data)
self.assertEqual(metrics_data["globally_decorated_func"]["call_count"], 1)
self.assertEqual(metrics_data["globally_decorated_func"]["total_time"], 0.05)
if __name__ == '__main__':
unittest.main()
-161
View File
@@ -1,161 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch
import logging
# Ensure tools.metrics_tool and tools.metrics are accessible
from tools.metrics_tool import MetricsTool
from tools.metrics import Metrics # Used for typehinting and creating a mockable instance
class TestMetricsTool(unittest.TestCase):
def setUp(self):
self.mock_metrics_provider = MagicMock(spec=Metrics)
self.logger = logging.getLogger('tools.metrics_tool.test')
if not self.logger.handlers:
self.logger.addHandler(logging.NullHandler())
self.logger.propagate = False
def test_init_with_provider(self):
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
self.assertEqual(tool.metrics_provider, self.mock_metrics_provider)
@patch('tools.metrics_tool.global_metrics_instance') # Patch the global instance path
def test_init_default_provider(self, mock_global_metrics):
tool = MetricsTool(logger=self.logger)
self.assertEqual(tool.metrics_provider, mock_global_metrics)
def test_get_functions(self):
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
functions = tool.get_functions()
self.assertIsInstance(functions, list)
self.assertTrue(len(functions) == 3) # Based on current definition
self.assertIn("get_function_metrics", [f["function"]["name"] for f in functions])
self.assertIn("get_specific_function_metrics", [f["function"]["name"] for f in functions])
self.assertIn("get_top_n_functions", [f["function"]["name"] for f in functions])
def test_execute_get_function_metrics(self):
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
expected_metrics = {"func1": {"call_count": 1, "total_time": 0.1}}
self.mock_metrics_provider.get_metrics.return_value = expected_metrics
result = tool.execute(function_name="get_function_metrics")
self.mock_metrics_provider.get_metrics.assert_called_once()
self.assertEqual(result, expected_metrics)
def test_execute_get_specific_function_metrics_found(self):
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
func_metrics = {"call_count": 5, "total_time": 0.5, "average_time": 0.1}
all_metrics = {"specific_func": func_metrics, "other_func": {}}
self.mock_metrics_provider.get_metrics.return_value = all_metrics
# The execute method expects kwargs that match the function parameters in get_functions.
# So, the argument name for the function to get is 'function_name' in the tool's spec.
result = tool.execute(function_name="get_specific_function_metrics", **{"function_name": "specific_func"})
self.assertEqual(result, func_metrics)
def test_execute_get_specific_function_metrics_not_found(self):
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
self.mock_metrics_provider.get_metrics.return_value = {"other_func": {}}
result = tool.execute(function_name="get_specific_function_metrics", **{"function_name": "non_existent_func"})
self.assertEqual(result, "No metrics found for function: non_existent_func")
def test_execute_get_specific_function_metrics_missing_arg(self):
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
result = tool.execute(function_name="get_specific_function_metrics") # Missing function_name kwarg
self.assertIn("Error: Missing required argument 'function_name'", result)
def test_execute_get_top_n_functions(self):
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
metrics_data = {
"func_a": {"call_count": 1, "total_time": 0.3},
"func_b": {"call_count": 1, "total_time": 0.1},
"func_c": {"call_count": 1, "total_time": 0.5},
"func_d": {"call_count": 1, "total_time": 0.2},
}
self.mock_metrics_provider.get_metrics.return_value = metrics_data
# Test getting top 2
result = tool.execute(function_name="get_top_n_functions", n=2)
expected_top_2 = {"func_c": metrics_data["func_c"], "func_a": metrics_data["func_a"]}
self.assertEqual(result, expected_top_2)
# Test getting top 1
result_top_1 = tool.execute(function_name="get_top_n_functions", n=1)
expected_top_1 = {"func_c": metrics_data["func_c"]}
self.assertEqual(result_top_1, expected_top_1)
# Test N larger than available functions
result_top_all = tool.execute(function_name="get_top_n_functions", n=10)
# Order should be func_c, func_a, func_d, func_b
expected_top_all_keys = ["func_c", "func_a", "func_d", "func_b"]
self.assertEqual(list(result_top_all.keys()), expected_top_all_keys)
def test_execute_get_top_n_functions_malformed_metrics(self):
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
metrics_data = {
"func_a": {"call_count": 1, "total_time": 0.3},
"func_b": "not a dict", # Malformed
"func_c": {"call_count": 1}, # Missing total_time
"func_d": {"call_count": 1, "total_time": 0.2},
}
self.mock_metrics_provider.get_metrics.return_value = metrics_data
metrics_tool_logger = logging.getLogger('tools.metrics_tool')
original_level = metrics_tool_logger.level
metrics_tool_logger.setLevel(logging.WARNING)
with self.assertLogs(metrics_tool_logger, level='WARNING') as log_capture:
result = tool.execute(function_name="get_top_n_functions", n=2)
metrics_tool_logger.setLevel(original_level)
# Check that warnings were logged for malformed items
self.assertTrue(any("Metric item for 'func_b' is not in expected format" in msg for msg in log_capture.output))
self.assertTrue(any("Metric item for 'func_c' is not in expected format" in msg for msg in log_capture.output))
# Expected: func_a, func_d (as they are valid and sortable)
expected_result = {
"func_a": metrics_data["func_a"],
"func_d": metrics_data["func_d"]
}
self.assertEqual(result, expected_result)
def test_execute_get_top_n_functions_invalid_n(self):
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
self.mock_metrics_provider.get_metrics.return_value = {} # No metrics needed for this test
result_zero = tool.execute(function_name="get_top_n_functions", n=0)
self.assertIn("Error: Argument 'n' must be a positive integer.", result_zero)
result_negative = tool.execute(function_name="get_top_n_functions", n=-1)
self.assertIn("Error: Argument 'n' must be a positive integer.", result_negative)
result_string = tool.execute(function_name="get_top_n_functions", n="abc")
self.assertIn("Error: Argument 'n' must be an integer.", result_string)
def test_execute_get_top_n_functions_missing_arg_n(self):
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
result = tool.execute(function_name="get_top_n_functions") # Missing n
self.assertIn("Error: Missing required argument 'n'.", result)
def test_execute_unknown_function(self):
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
result = tool.execute(function_name="non_existent_metrics_function")
self.assertIn("Unknown function: non_existent_metrics_function", result)
def test_clear_method(self):
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
metrics_tool_logger = logging.getLogger('tools.metrics_tool')
original_level = metrics_tool_logger.level
metrics_tool_logger.setLevel(logging.DEBUG)
with self.assertLogs(metrics_tool_logger, level='DEBUG') as cm:
tool.clear()
metrics_tool_logger.setLevel(original_level)
self.assertTrue(any("MetricsTool clear method called" in message for message in cm.output))
if __name__ == '__main__':
unittest.main()
+9 -6
View File
@@ -4,8 +4,7 @@ import zipfile
import io import io
import re import re
import logging import logging
from .base_tool import BaseTool # Added from .base_tool import BaseTool
from .metrics import metrics # Added
# Configure logging for the tool - This will be handled by the logger instance now # Configure logging for the tool - This will be handled by the logger instance now
# logger = logging.getLogger(__name__) # Commented out or removed # logger = logging.getLogger(__name__) # Commented out or removed
@@ -70,7 +69,8 @@ class GitHubCIHelper(BaseTool): # Inherits from BaseTool
}, },
"required": ["pull_request_number"] "required": ["pull_request_number"]
} }
} },
"_tags": ["read"]
}, },
{ {
"type": "function", "type": "function",
@@ -85,7 +85,8 @@ class GitHubCIHelper(BaseTool): # Inherits from BaseTool
}, },
"required": ["pull_request_number"] "required": ["pull_request_number"]
} }
} },
"_tags": ["read"]
}, },
{ {
"type": "function", "type": "function",
@@ -100,7 +101,8 @@ class GitHubCIHelper(BaseTool): # Inherits from BaseTool
}, },
"required": ["run_id"] "required": ["run_id"]
} }
} },
"_tags": ["read"]
}, },
{ {
"type": "function", "type": "function",
@@ -114,7 +116,8 @@ class GitHubCIHelper(BaseTool): # Inherits from BaseTool
}, },
"required": ["log_content"] "required": ["log_content"]
} }
} },
"_tags": ["read"]
} }
] ]
+72 -74
View File
@@ -1,6 +1,5 @@
# tools/github_tool.py # tools/github_tool.py
from .base_tool import BaseTool from .base_tool import BaseTool
from .metrics import metrics
import requests import requests
import os import os
import base64 import base64
@@ -57,7 +56,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["path"] "required": ["path"]
} }
} },
"_tags": ["read"]
}, },
{ {
"type": "function", "type": "function",
@@ -71,7 +71,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["path"] "required": ["path"]
} }
} },
"_tags": ["read"]
}, },
{ {
"type": "function", "type": "function",
@@ -85,7 +86,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["query"] "required": ["query"]
} }
} },
"_tags": ["read"]
}, },
{ {
"type": "function", "type": "function",
@@ -100,7 +102,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["branch_name"] "required": ["branch_name"]
} }
} },
"_tags": ["write"]
}, },
{ {
"type": "function", "type": "function",
@@ -116,7 +119,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["file_path", "commit_message", "content"] "required": ["file_path", "commit_message", "content"]
} }
} },
"_tags": ["write"]
}, },
{ {
"type": "function", "type": "function",
@@ -132,7 +136,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["title", "body"] "required": ["title", "body"]
} }
} },
"_tags": ["write"]
}, },
{ {
"type": "function", "type": "function",
@@ -147,7 +152,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["file_path"] "required": ["file_path"]
} }
} },
"_tags": ["read"]
}, },
{ {
"type": "function", "type": "function",
@@ -162,7 +168,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["file_path"] "required": ["file_path"]
} }
} },
"_tags": ["read"]
}, },
{ {
"type": "function", "type": "function",
@@ -176,7 +183,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["branch"] "required": ["branch"]
} }
} },
"_tags": ["read"]
}, },
{ {
"type": "function", "type": "function",
@@ -184,7 +192,8 @@ class GitHubTool(BaseTool):
"name": "get_current_branch", "name": "get_current_branch",
"description": "Get the name of the current branch", "description": "Get the name of the current branch",
"parameters": { "type": "object", "properties": {} } "parameters": { "type": "object", "properties": {} }
} },
"_tags": ["read"]
}, },
{ {
"type": "function", "type": "function",
@@ -198,7 +207,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["branch_name"] "required": ["branch_name"]
} }
} },
"_tags": ["read", "write"]
}, },
{ {
"type": "function", "type": "function",
@@ -213,7 +223,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["file_path", "commit_sha"] "required": ["file_path", "commit_sha"]
} }
} },
"_tags": ["read"]
}, },
{ {
"type": "function", "type": "function",
@@ -227,7 +238,8 @@ class GitHubTool(BaseTool):
"all_pages": {"type": "boolean", "description": "Whether to fetch all pages of results", "default": True} "all_pages": {"type": "boolean", "description": "Whether to fetch all pages of results", "default": True}
} }
} }
} },
"_tags": ["read"]
}, },
{ {
"type": "function", "type": "function",
@@ -241,7 +253,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["pull_number"] "required": ["pull_number"]
} }
} },
"_tags": ["write"]
}, },
{ {
"type": "function", "type": "function",
@@ -255,7 +268,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["pull_number"] "required": ["pull_number"]
} }
} },
"_tags": ["write"]
}, },
{ {
"type": "function", "type": "function",
@@ -277,7 +291,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["pull_number"] "required": ["pull_number"]
} }
} },
"_tags": ["write"]
}, },
{ {
"type": "function", "type": "function",
@@ -291,7 +306,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["branch_name"] "required": ["branch_name"]
} }
} },
"_tags": ["write"]
}, },
{ {
"type": "function", "type": "function",
@@ -305,7 +321,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["issue_number"] "required": ["issue_number"]
} }
} },
"_tags": ["read"]
}, },
{ {
"type": "function", "type": "function",
@@ -325,7 +342,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["title", "body"] "required": ["title", "body"]
} }
} },
"_tags": ["communicate"]
}, },
{ {
"type": "function", "type": "function",
@@ -340,7 +358,8 @@ class GitHubTool(BaseTool):
"page": {"type": "integer", "default": 1, "description": "Page number of the results to fetch"} "page": {"type": "integer", "default": 1, "description": "Page number of the results to fetch"}
} }
} }
} },
"_tags": ["read"]
}, },
{ {
"type": "function", "type": "function",
@@ -355,7 +374,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["issue_number", "comment"] "required": ["issue_number", "comment"]
} }
} },
"_tags": ["communicate"]
}, },
{ {
"type": "function", "type": "function",
@@ -369,7 +389,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["issue_number"] "required": ["issue_number"]
} }
} },
"_tags": ["read"]
}, },
{ {
"type": "function", "type": "function",
@@ -383,7 +404,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["pull_number"] "required": ["pull_number"]
} }
} },
"_tags": ["read"]
}, },
{ {
"type": "function", "type": "function",
@@ -398,7 +420,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["name"] "required": ["name"]
} }
} },
"_tags": ["communicate"]
}, },
{ {
"type": "function", "type": "function",
@@ -413,7 +436,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["project_id", "column_name"] "required": ["project_id", "column_name"]
} }
} },
"_tags": ["communicate"]
}, },
{ {
"type": "function", "type": "function",
@@ -428,7 +452,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["column_id", "note"] "required": ["column_id", "note"]
} }
} },
"_tags": ["communicate"]
}, },
{ {
"type": "function", "type": "function",
@@ -444,7 +469,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["card_id", "position", "column_id"] "required": ["card_id", "position", "column_id"]
} }
} },
"_tags": ["communicate"]
}, },
{ {
"type": "function", "type": "function",
@@ -460,7 +486,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["card_id", "content_id", "content_type"] "required": ["card_id", "content_id", "content_type"]
} }
} },
"_tags": ["communicate"]
}, },
{ {
"type": "function", "type": "function",
@@ -468,7 +495,8 @@ class GitHubTool(BaseTool):
"name": "list_project_boards", "name": "list_project_boards",
"description": "List project boards associated with the repository", "description": "List project boards associated with the repository",
"parameters": { "type": "object", "properties": {} } "parameters": { "type": "object", "properties": {} }
} },
"_tags": ["communicate"]
}, },
{ {
"type": "function", "type": "function",
@@ -482,7 +510,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["project_id"] "required": ["project_id"]
} }
} },
"_tags": ["communicate"]
}, },
{ {
"type": "function", "type": "function",
@@ -496,7 +525,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["pull_number"] "required": ["pull_number"]
} }
} },
"_tags": ["read"]
}, },
{ {
"type": "function", "type": "function",
@@ -510,7 +540,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["pull_number"] "required": ["pull_number"]
} }
} },
"_tags": ["read"]
}, },
{ {
"type": "function", "type": "function",
@@ -524,7 +555,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["pull_number"] "required": ["pull_number"]
} }
} },
"_tags": ["read"]
}, },
{ {
"type": "function", "type": "function",
@@ -545,7 +577,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["pull_number", "body", "commit_id", "path", "position"] "required": ["pull_number", "body", "commit_id", "path", "position"]
} }
} },
"_tags": ["communicate"]
}, },
{ {
"type": "function", "type": "function",
@@ -559,7 +592,8 @@ class GitHubTool(BaseTool):
}, },
"required": ["pull_number"] "required": ["pull_number"]
} }
} },
"_tags": ["read"]
}, },
{ {
"type": "function", "type": "function",
@@ -575,11 +609,11 @@ class GitHubTool(BaseTool):
}, },
"required": ["pull_number", "event"] "required": ["pull_number", "event"]
} }
} },
"_tags": ["communicate"]
} }
] ]
@metrics.measure
def execute(self, function_name, **kwargs): def execute(self, function_name, **kwargs):
self.logger.info(f"Executing GitHub Tool function: {function_name} with args: {kwargs}") self.logger.info(f"Executing GitHub Tool function: {function_name} with args: {kwargs}")
# Dispatch to the appropriate private method # Dispatch to the appropriate private method
@@ -598,7 +632,6 @@ class GitHubTool(BaseTool):
# Private methods for each function, using self.session for HTTP requests # Private methods for each function, using self.session for HTTP requests
@metrics.measure
def _read_file(self, path): def _read_file(self, path):
self.logger.info(f"Reading file: {path} from branch: {self.current_branch}") self.logger.info(f"Reading file: {path} from branch: {self.current_branch}")
url = f"{self.base_url}/repos/{self._repo}/contents/{path}" url = f"{self.base_url}/repos/{self._repo}/contents/{path}"
@@ -613,7 +646,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure
def _create_branch(self, branch_name, base_branch="main"): def _create_branch(self, branch_name, base_branch="main"):
self.logger.info(f"Creating branch: {branch_name} from base: {base_branch}") self.logger.info(f"Creating branch: {branch_name} from base: {base_branch}")
# Get SHA of base branch # Get SHA of base branch
@@ -639,7 +671,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure
def _commit_file(self, file_path, content, commit_message): def _commit_file(self, file_path, content, commit_message):
self.logger.info(f"Committing file: {file_path} to branch: {self.current_branch} with message: '{commit_message}'") self.logger.info(f"Committing file: {file_path} to branch: {self.current_branch} with message: '{commit_message}'")
if self.current_branch == "main": if self.current_branch == "main":
@@ -679,7 +710,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure
def _create_pull_request(self, title, body, base="main"): def _create_pull_request(self, title, body, base="main"):
self.logger.info(f"Creating pull request: '{title}' from branch '{self.current_branch}' to '{base}'") self.logger.info(f"Creating pull request: '{title}' from branch '{self.current_branch}' to '{base}'")
if self.current_branch == base: if self.current_branch == base:
@@ -701,7 +731,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure
def _get_branch_sha(self, branch): def _get_branch_sha(self, branch):
self.logger.info(f"Getting SHA for branch: {branch}") self.logger.info(f"Getting SHA for branch: {branch}")
url = f"{self.base_url}/repos/{self._repo}/git/refs/heads/{branch}" url = f"{self.base_url}/repos/{self._repo}/git/refs/heads/{branch}"
@@ -715,7 +744,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure
def _list_files(self, path): def _list_files(self, path):
self.logger.info(f"Listing files in path: '{path}' on branch: '{self.current_branch}'") self.logger.info(f"Listing files in path: '{path}' on branch: '{self.current_branch}'")
url = f"{self.base_url}/repos/{self._repo}/contents/{path.strip('/')}" # Ensure no leading/trailing slashes for consistency url = f"{self.base_url}/repos/{self._repo}/contents/{path.strip('/')}" # Ensure no leading/trailing slashes for consistency
@@ -738,7 +766,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure
def _search_code(self, query): def _search_code(self, query):
self.logger.info(f"Searching code with query: '{query}' in repo: '{self._repo}'") self.logger.info(f"Searching code with query: '{query}' in repo: '{self._repo}'")
url = f"{self.base_url}/search/code" url = f"{self.base_url}/search/code"
@@ -754,7 +781,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure
def _get_commit_history(self, file_path, num_commits=10): def _get_commit_history(self, file_path, num_commits=10):
self.logger.info(f"Getting last {num_commits} commit(s) for file: '{file_path}' on branch '{self.current_branch}'") self.logger.info(f"Getting last {num_commits} commit(s) for file: '{file_path}' on branch '{self.current_branch}'")
url = f"{self.base_url}/repos/{self._repo}/commits" url = f"{self.base_url}/repos/{self._repo}/commits"
@@ -775,18 +801,15 @@ class GitHubTool(BaseTool):
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure
def _view_commit_details_for_file(self, file_path, num_commits=10): def _view_commit_details_for_file(self, file_path, num_commits=10):
# This function is essentially the same as get_commit_history based on its description. # This function is essentially the same as get_commit_history based on its description.
self.logger.info(f"Viewing commit details for file '{file_path}' (last {num_commits} commits) - using _get_commit_history.") self.logger.info(f"Viewing commit details for file '{file_path}' (last {num_commits} commits) - using _get_commit_history.")
return self._get_commit_history(file_path, num_commits) return self._get_commit_history(file_path, num_commits)
@metrics.measure
def _get_current_branch(self): def _get_current_branch(self):
self.logger.info(f"Current branch is: {self.current_branch}") self.logger.info(f"Current branch is: {self.current_branch}")
return self.current_branch return self.current_branch
@metrics.measure
def _set_current_branch(self, branch_name): def _set_current_branch(self, branch_name):
self.logger.info(f"Attempting to set current branch to: {branch_name}") self.logger.info(f"Attempting to set current branch to: {branch_name}")
# Check if branch exists by trying to get its SHA # Check if branch exists by trying to get its SHA
@@ -801,7 +824,6 @@ class GitHubTool(BaseTool):
self.logger.info(success_message) self.logger.info(success_message)
return success_message return success_message
@metrics.measure
def _get_file_at_commit(self, file_path, commit_sha): def _get_file_at_commit(self, file_path, commit_sha):
self.logger.info(f"Getting file '{file_path}' at commit SHA: {commit_sha}") self.logger.info(f"Getting file '{file_path}' at commit SHA: {commit_sha}")
url = f"{self.base_url}/repos/{self._repo}/contents/{file_path}" url = f"{self.base_url}/repos/{self._repo}/contents/{file_path}"
@@ -816,7 +838,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure
def _list_branches(self, per_page=100, all_pages=True): def _list_branches(self, per_page=100, all_pages=True):
self.logger.info(f"Listing branches for repo '{self._repo}'. Per_page={per_page}, All_pages={all_pages}") self.logger.info(f"Listing branches for repo '{self._repo}'. Per_page={per_page}, All_pages={all_pages}")
url = f"{self.base_url}/repos/{self._repo}/branches" url = f"{self.base_url}/repos/{self._repo}/branches"
@@ -844,7 +865,6 @@ class GitHubTool(BaseTool):
self.logger.info(f"Successfully listed {len(branches_list)} branches.") self.logger.info(f"Successfully listed {len(branches_list)} branches.")
return branches_list return branches_list
@metrics.measure
def _approve_pull_request(self, pull_number): def _approve_pull_request(self, pull_number):
self.logger.info(f"Approving pull request #{pull_number}") self.logger.info(f"Approving pull request #{pull_number}")
url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/reviews" url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/reviews"
@@ -859,7 +879,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure
def _close_pull_request(self, pull_number): def _close_pull_request(self, pull_number):
self.logger.info(f"Closing pull request #{pull_number}") self.logger.info(f"Closing pull request #{pull_number}")
url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}" url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}"
@@ -874,7 +893,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure
def _merge_pull_request(self, pull_number, commit_title="Merge pull request", commit_message="", merge_method="merge"): def _merge_pull_request(self, pull_number, commit_title="Merge pull request", commit_message="", merge_method="merge"):
self.logger.info(f"Merging pull request #{pull_number} using method '{merge_method}'") self.logger.info(f"Merging pull request #{pull_number} using method '{merge_method}'")
url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/merge" url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/merge"
@@ -897,7 +915,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure
def _delete_branch(self, branch_name): def _delete_branch(self, branch_name):
self.logger.info(f"Deleting branch: {branch_name}") self.logger.info(f"Deleting branch: {branch_name}")
if branch_name == "main" or (hasattr(self, 'default_branch') and branch_name == self.default_branch) : if branch_name == "main" or (hasattr(self, 'default_branch') and branch_name == self.default_branch) :
@@ -920,7 +937,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure
def _get_issue_details(self, issue_number): def _get_issue_details(self, issue_number):
self.logger.info(f"Getting details for issue #{issue_number}") self.logger.info(f"Getting details for issue #{issue_number}")
url = f"{self.base_url}/repos/{self._repo}/issues/{issue_number}" url = f"{self.base_url}/repos/{self._repo}/issues/{issue_number}"
@@ -933,7 +949,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure
def _create_issue(self, title, body, labels=None): def _create_issue(self, title, body, labels=None):
self.logger.info(f"Creating new issue with title: '{title}'") self.logger.info(f"Creating new issue with title: '{title}'")
url = f"{self.base_url}/repos/{self._repo}/issues" url = f"{self.base_url}/repos/{self._repo}/issues"
@@ -953,7 +968,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure
def _list_issues(self, state="open", per_page=30, page=1): def _list_issues(self, state="open", per_page=30, page=1):
self.logger.info(f"Listing issues with state: {state}, per_page: {per_page}, page: {page}") self.logger.info(f"Listing issues with state: {state}, per_page: {per_page}, page: {page}")
url = f"{self.base_url}/repos/{self._repo}/issues" url = f"{self.base_url}/repos/{self._repo}/issues"
@@ -969,7 +983,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure
def _add_issue_comment(self, issue_number, comment): def _add_issue_comment(self, issue_number, comment):
self.logger.info(f"Adding comment to issue #{issue_number}: '{comment[:50]}...'") self.logger.info(f"Adding comment to issue #{issue_number}: '{comment[:50]}...'")
url = f"{self.base_url}/repos/{self._repo}/issues/{issue_number}/comments" url = f"{self.base_url}/repos/{self._repo}/issues/{issue_number}/comments"
@@ -985,7 +998,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure
def _get_issue_comments(self, issue_number): def _get_issue_comments(self, issue_number):
self.logger.info(f"Getting comments for issue #{issue_number}") self.logger.info(f"Getting comments for issue #{issue_number}")
url = f"{self.base_url}/repos/{self._repo}/issues/{issue_number}/comments" url = f"{self.base_url}/repos/{self._repo}/issues/{issue_number}/comments"
@@ -1000,14 +1012,12 @@ class GitHubTool(BaseTool):
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure
def _get_pull_request_general_comments(self, pull_number): def _get_pull_request_general_comments(self, pull_number):
self.logger.info(f"Getting general comments for pull request #{pull_number}") self.logger.info(f"Getting general comments for pull request #{pull_number}")
# In GitHub API, PR comments (general, not review comments on lines) are issue comments. # In GitHub API, PR comments (general, not review comments on lines) are issue comments.
# The PR is also an issue, so use the issue comments endpoint. # The PR is also an issue, so use the issue comments endpoint.
return self._get_issue_comments(issue_number=pull_number) return self._get_issue_comments(issue_number=pull_number)
@metrics.measure
def _create_project_board(self, name, body=None): def _create_project_board(self, name, body=None):
self.logger.info(f"Creating project board: '{name}'") self.logger.info(f"Creating project board: '{name}'")
url = f"{self.base_url}/repos/{self._repo}/projects" url = f"{self.base_url}/repos/{self._repo}/projects"
@@ -1026,7 +1036,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure
def _create_project_column(self, project_id, column_name): def _create_project_column(self, project_id, column_name):
self.logger.info(f"Creating column '{column_name}' for project ID: {project_id}") self.logger.info(f"Creating column '{column_name}' for project ID: {project_id}")
url = f"{self.base_url}/projects/{project_id}/columns" url = f"{self.base_url}/projects/{project_id}/columns"
@@ -1044,7 +1053,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure
def _create_project_card(self, column_id, note=None, content_id=None, content_type=None): def _create_project_card(self, column_id, note=None, content_id=None, content_type=None):
self.logger.info(f"Creating card in column ID: {column_id}") self.logger.info(f"Creating card in column ID: {column_id}")
url = f"{self.base_url}/projects/columns/{column_id}/cards" url = f"{self.base_url}/projects/columns/{column_id}/cards"
@@ -1075,7 +1083,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure
def _move_project_card(self, card_id, position, column_id=None): def _move_project_card(self, card_id, position, column_id=None):
self.logger.info(f"Moving card ID: {card_id} to position: {position}" + (f" in column ID: {column_id}" if column_id else "")) self.logger.info(f"Moving card ID: {card_id} to position: {position}" + (f" in column ID: {column_id}" if column_id else ""))
url = f"{self.base_url}/projects/columns/cards/{card_id}/moves" url = f"{self.base_url}/projects/columns/cards/{card_id}/moves"
@@ -1100,7 +1107,6 @@ class GitHubTool(BaseTool):
# For updating an existing card to link an issue, one would PATCH the card's content_id/content_type. # For updating an existing card to link an issue, one would PATCH the card's content_id/content_type.
# Let's assume the function intends to update an existing card if it's a separate function. # Let's assume the function intends to update an existing card if it's a separate function.
# However, the provided API spec for `link_issue_to_project_card` uses PATCH on card_id, so let's implement that. # However, the provided API spec for `link_issue_to_project_card` uses PATCH on card_id, so let's implement that.
@metrics.measure
def _link_issue_to_project_card(self, card_id, content_id, content_type): def _link_issue_to_project_card(self, card_id, content_id, content_type):
self.logger.info(f"Linking content_id {content_id} (type: {content_type}) to card_id {card_id}") self.logger.info(f"Linking content_id {content_id} (type: {content_type}) to card_id {card_id}")
url = f"{self.base_url}/projects/cards/{card_id}" # Note: API docs suggest /projects/columns/cards/{card_id} or /projects/cards/{card_id} url = f"{self.base_url}/projects/cards/{card_id}" # Note: API docs suggest /projects/columns/cards/{card_id} or /projects/cards/{card_id}
@@ -1120,7 +1126,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure
def _list_project_boards(self): def _list_project_boards(self):
self.logger.info(f"Listing project boards for repo: {self._repo}") self.logger.info(f"Listing project boards for repo: {self._repo}")
url = f"{self.base_url}/repos/{self._repo}/projects" url = f"{self.base_url}/repos/{self._repo}/projects"
@@ -1136,7 +1141,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure
def _view_project_board_items(self, project_id): def _view_project_board_items(self, project_id):
self.logger.info(f"Viewing items for project ID: {project_id}") self.logger.info(f"Viewing items for project ID: {project_id}")
columns_url = f"{self.base_url}/projects/{project_id}/columns" columns_url = f"{self.base_url}/projects/{project_id}/columns"
@@ -1165,7 +1169,6 @@ class GitHubTool(BaseTool):
self.logger.info(f"Successfully retrieved items for project ID: {project_id}.") self.logger.info(f"Successfully retrieved items for project ID: {project_id}.")
return project_items return project_items
@metrics.measure
def _get_pull_request_details(self, pull_number): def _get_pull_request_details(self, pull_number):
self.logger.info(f"Getting details for PR #{pull_number}") self.logger.info(f"Getting details for PR #{pull_number}")
url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}" url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}"
@@ -1178,7 +1181,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure
def _get_pull_request_diff(self, pull_number): def _get_pull_request_diff(self, pull_number):
self.logger.info(f"Getting diff for PR #{pull_number}") self.logger.info(f"Getting diff for PR #{pull_number}")
url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}" url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}"
@@ -1193,7 +1195,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure
def _get_pull_request_files(self, pull_number): def _get_pull_request_files(self, pull_number):
self.logger.info(f"Getting files for PR #{pull_number}") self.logger.info(f"Getting files for PR #{pull_number}")
url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/files" url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/files"
@@ -1206,7 +1207,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure
def _create_pull_request_review_comment(self, pull_number, body, commit_id, path, position, side="RIGHT", start_line=None, start_side=None): def _create_pull_request_review_comment(self, pull_number, body, commit_id, path, position, side="RIGHT", start_line=None, start_side=None):
self.logger.info(f"Creating review comment on PR #{pull_number}, file '{path}', position {position}") self.logger.info(f"Creating review comment on PR #{pull_number}, file '{path}', position {position}")
url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/comments" url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/comments"
@@ -1225,7 +1225,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure
def _list_pull_request_review_comments(self, pull_number): def _list_pull_request_review_comments(self, pull_number):
self.logger.info(f"Listing review comments for PR #{pull_number}") self.logger.info(f"Listing review comments for PR #{pull_number}")
url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/comments" url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/comments"
@@ -1238,7 +1237,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure
def _submit_pull_request_review(self, pull_number, event, body=None): def _submit_pull_request_review(self, pull_number, event, body=None):
self.logger.info(f"Submitting '{event}' review for PR #{pull_number}") self.logger.info(f"Submitting '{event}' review for PR #{pull_number}")
url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/reviews" url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/reviews"
-3
View File
@@ -1,6 +1,5 @@
# tools/log_tool.py # tools/log_tool.py
from .base_tool import BaseTool from .base_tool import BaseTool
from .metrics import metrics
import logging import logging
import os import os
from datetime import datetime, timedelta from datetime import datetime, timedelta
@@ -44,7 +43,6 @@ class LogTool(BaseTool):
} }
] ]
@metrics.measure
def execute(self, function_name, **kwargs): def execute(self, function_name, **kwargs):
self.logger.info(f"Executing LogTool function: {function_name} with args: {kwargs}") self.logger.info(f"Executing LogTool function: {function_name} with args: {kwargs}")
if function_name == "get_log_contents": if function_name == "get_log_contents":
@@ -55,7 +53,6 @@ class LogTool(BaseTool):
self.logger.error(error_message) self.logger.error(error_message)
return error_message return error_message
@metrics.measure
def _get_log_contents(self, line_count=None): # Default line_count is None to trigger 24h logic if not specified def _get_log_contents(self, line_count=None): # Default line_count is None to trigger 24h logic if not specified
self.logger.info(f"Attempting to get log contents from: {self.configured_log_file_path}. Line count: {line_count if line_count is not None else 'Last 24 hours'}") self.logger.info(f"Attempting to get log contents from: {self.configured_log_file_path}. Line count: {line_count if line_count is not None else 'Last 24 hours'}")
-79
View File
@@ -1,79 +0,0 @@
# tools/metrics.py
import cProfile
import pstats
import io
from functools import wraps
from collections import defaultdict
import logging
class Metrics:
def __init__(self, logger=None):
self.call_count = defaultdict(int)
self.total_time = defaultdict(float)
self.logger = logger if logger else logging.getLogger(__name__)
if not self.logger.handlers:
self.logger.addHandler(logging.NullHandler())
self.logger.debug("Metrics instance initialized.")
def measure(self, func):
@wraps(func)
def wrapper(*args, **kwargs):
self.call_count[func.__name__] += 1
pr = cProfile.Profile()
pr.enable()
result = func(*args, **kwargs)
pr.disable()
ps = pstats.Stats(pr)
func_code = func.__code__
func_key_tuple = (func_code.co_filename, func_code.co_firstlineno, func_code.co_name)
time_spent_for_func = 0.0
if func_key_tuple in ps.stats:
time_spent_for_func = ps.stats[func_key_tuple][3] # [3] is cumulative time (ct)
else:
# Fallback: try to find by function name if exact key fails (e.g. due to decorators changing code object details slightly)
# This is less precise and might pick up other functions if names are not unique across files.
found_by_name = False
for key, stat in ps.stats.items():
if key[2] == func.__name__: # key[2] is function name
time_spent_for_func = stat[3] # cumulative time
self.logger.debug(f"Found stats for {func.__name__} by name {key} after primary key failed.")
found_by_name = True
break
if not found_by_name:
self.logger.warning(
f"Could not find exact cProfile stats for {func.__name__} with key {func_key_tuple} or by name. "
f"Time for this call will be recorded as 0. This might occur for non-Python functions or due to complex decorators."
)
self.total_time[func.__name__] += time_spent_for_func
self.logger.debug(f"Measured cumulative time for {func.__name__}: {time_spent_for_func:.6f}s")
return result
return wrapper
def get_metrics(self):
metrics_data = {}
for func_name in self.call_count:
count = self.call_count[func_name]
total_t = self.total_time[func_name]
metrics_data[func_name] = {
'call_count': count,
'total_time': round(total_t, 6),
'average_time': round(total_t / count, 6) if count > 0 else 0
}
return metrics_data
def clear_metrics(self):
self.call_count.clear()
self.total_time.clear()
self.logger.info("Metrics cleared.")
# Global instance for convenience
_metrics_instance_logger = logging.getLogger(__name__ + ".global_instance")
if not _metrics_instance_logger.handlers:
_metrics_instance_logger.addHandler(logging.NullHandler())
metrics = Metrics(logger=_metrics_instance_logger)
-128
View File
@@ -1,128 +0,0 @@
# tools/metrics_tool.py
from .base_tool import BaseTool
from .metrics import metrics as global_metrics_instance # For default and measuring execute
from .metrics import Metrics # For type hinting and potentially creating a new one if needed
import logging
class MetricsTool(BaseTool):
def __init__(self, metrics_provider: Metrics | None = None, logger: logging.Logger | None = None):
self.metrics_provider = metrics_provider if metrics_provider is not None else global_metrics_instance
self.logger = logger if logger else logging.getLogger(__name__)
if not self.logger.handlers:
self.logger.addHandler(logging.NullHandler())
self.logger.debug(f"MetricsTool initialized. Using metrics provider: {self.metrics_provider}")
def clear(self):
# This tool itself doesn't hold state that needs clearing beyond what its metrics_provider might do.
# If this tool were responsible for clearing the metrics it reports on, it would call:
# self.metrics_provider.clear_metrics()
self.logger.debug("MetricsTool clear method called. No local state to clear.")
pass
def get_functions(self):
return [
{
"type": "function",
"function": {
"name": "get_function_metrics",
"description": "Get metrics for all measured functions.",
"parameters": {
"type": "object",
"properties": {},
"required": []
}
}
},
{
"type": "function",
"function": {
"name": "get_specific_function_metrics",
"description": "Get metrics for a specific function.",
"parameters": {
"type": "object",
"properties": {
"function_name": {
"type": "string",
"description": "Name of the function to get metrics for"
}
},
"required": ["function_name"]
}
}
},
{
"type": "function",
"function": {
"name": "get_top_n_functions",
"description": "Get the top N functions by total execution time.",
"parameters": {
"type": "object",
"properties": {
"n": {
"type": "integer",
"description": "Number of top functions to retrieve"
}
},
"required": ["n"]
}
}
}
]
@global_metrics_instance.measure # The execute method can be measured by the global instance
def execute(self, function_name, **kwargs):
self.logger.info(f"Executing MetricsTool function: {function_name} with args: {kwargs}")
if function_name == "get_function_metrics":
return self._get_function_metrics()
elif function_name == "get_specific_function_metrics":
func_name_arg = kwargs.get("function_name")
if func_name_arg is None: # Check if None, as empty string could be a valid (though unlikely) func name
self.logger.warning("'function_name' argument is missing for get_specific_function_metrics.")
return "Error: Missing required argument 'function_name'."
return self._get_specific_function_metrics(str(func_name_arg)) # Ensure string
elif function_name == "get_top_n_functions":
n_arg = kwargs.get("n")
if n_arg is None:
self.logger.warning("'n' argument is missing for get_top_n_functions.")
return "Error: Missing required argument 'n'."
try:
n_val = int(n_arg)
if n_val <= 0:
self.logger.warning(f"'n' argument must be a positive integer, got {n_val}.")
return "Error: Argument 'n' must be a positive integer."
return self._get_top_n_functions(n_val)
except ValueError:
self.logger.warning(f"'n' argument must be an integer, got '{n_arg}'.")
return "Error: Argument 'n' must be an integer."
else:
error_message = f"Unknown function: {function_name}"
self.logger.error(error_message)
return error_message
def _get_function_metrics(self):
self.logger.debug("Calling metrics_provider.get_metrics() for all functions.")
return self.metrics_provider.get_metrics()
def _get_specific_function_metrics(self, function_to_get):
self.logger.debug(f"Getting metrics for specific function: {function_to_get}")
all_metrics = self.metrics_provider.get_metrics()
return all_metrics.get(function_to_get, f"No metrics found for function: {function_to_get}")
def _get_top_n_functions(self, n):
self.logger.debug(f"Getting top {n} functions by total execution time.")
all_metrics = self.metrics_provider.get_metrics()
# Ensure that the items are actual metric dicts before trying to access 'total_time'
valid_metrics_items = []
for name, metric_values in all_metrics.items():
if isinstance(metric_values, dict) and 'total_time' in metric_values:
valid_metrics_items.append((name, metric_values))
else:
self.logger.warning(f"Metric item for '{name}' is not in expected format: {metric_values}")
# Sort items by total_time. items() gives list of (func_name, metrics_dict)
try:
sorted_metrics = sorted(valid_metrics_items, key=lambda item: item[1]['total_time'], reverse=True)
return dict(sorted_metrics[:n])
except TypeError as e:
self.logger.error(f"Error sorting metrics, possibly due to unexpected data types: {e}", exc_info=True)
return "Error: Could not sort metrics due to unexpected data."
@@ -28,7 +28,7 @@ class StandaloneLLMTool(BaseTool):
"model": { "model": {
"type": "string", "type": "string",
"description": "The model to use for generating the detailed instructions. Use mini for most coding tasks, preview when needing sophisticated reasoning", "description": "The model to use for generating the detailed instructions. Use mini for most coding tasks, preview when needing sophisticated reasoning",
"enum": ["o1-mini", "o1-preview"], "enum": ["mini", "max"],
"default": "o1-mini" "default": "o1-mini"
}, },
"max_tokens": { "max_tokens": {
@@ -38,7 +38,8 @@ class StandaloneLLMTool(BaseTool):
}, },
"required": ["prompt"] "required": ["prompt"]
} }
} },
"_tags": ["llm", "external"]
} }
] ]