Refactored gemini, openai and claude into one file and removed logic from the base class, also made helper class definable from command line

This commit is contained in:
2025-06-03 13:04:42 -05:00
parent bd0ce3e340
commit f15228fa58
36 changed files with 487 additions and 3847 deletions
+35 -9
View File
@@ -1,14 +1,40 @@
# Telegram Bot Tokens
TELEGRAM_BOT_TOKEN=your_daemon_bot_token_here
TELEGRAM_APPRENTICE_BOT_TOKEN=your_apprentice_bot_token_here
TELEGRAM_BOT_TOKEN=your_bot_token_here
PYTHONPATH=${workspaceFolder}
GITHUB_TOKEN=your_github_personal_access_token_here
GITHUB_REPOSITORY=your_github_username_or_organization/your_repo_name
GITHUB_REPO_OWNER=your_github_username_or_organization
SYSTEM_PROMPT_PATH=./prompts/project_manager_prompt.txt
ACTIVE_MODEL_PROFILE=OPENAI # Options: OPENAI, GEMINI, GLHF_CHAT
# Create a new profile with these settings:
# {MODEL_PROFILE}_API_KEY
# {MODEL_PROFILE}_API_BASE_URL # Optional for OpenAI
# {MODEL_PROFILE}_SMALL_MODEL
# {MODEL_PROFILE}_SMALL_MODEL_MAX_TOKENS
# {MODEL_PROFILE}_LARGE_MODEL
# {MODEL_PROFILE}_LARGE_MODEL_MAX_TOKENS
# OpenAI API Key
OPENAI_API_KEY=your_openai_api_key_here
OPENAI_SMALL_MODEL=gpt-4.1-mini
OPENAI_SMALL_MODEL_MAX_TOKENS=32768
OPENAI_LARGE_MODEL=gpt-4.1
OPENAI_LARGE_MODEL_MAX_TOKENS=32768
# Anthropic API Key
ANTHROPIC_API_KEY=your_anthropic_api_key_here
# Gemini API
GEMINI_API_KEY=your_gemini_api_key_here
GEMINI_API_BASE_URL=https://generativelanguage.googleapis.com/v1beta/openai/
GEMINI_SMALL_MODEL=gemini-2.5-flash-preview-05-20
GEMINI_SMALL_MODEL_MAX_TOKENS=65536
GEMINI_LARGE_MODEL=gemini-2.5-pro-preview-05-06
GEMINI_LARGE_MODEL_MAX_TOKENS=65536
# GitHub Repository Information
GITHUB_REPO_OWNER=your_github_username_or_organization
GITHUB_REPO_NAME=your_repo_name
GITHUB_ACCESS_TOKEN=your_github_personal_access_token
# GLHF Chat API Key
GLHF_CHAT_API_KEY=your_glhf_chat_api_key_here
GLHF_CHAT_API_BASE_URL=https://glhf.chat/api/openai/v1
GLHF_CHAT_SMALL_MODEL=meta-llama/Llama-3.3-70B-Instruct
GLHF_CHAT_SMALL_MODEL_MAX_TOKENS=1024
GLHF_CHAT_LARGE_MODEL=deepseek-ai/DeepSeek-V3-0324
GLHF_CHAT_LARGE_MODEL_MAX_TOKENS=1024
-271
View File
@@ -1,271 +0,0 @@
import os
import json
import logging
from anthropic import Anthropic, APIError, RateLimitError
from base_telegram_inference_bot import BaseTelegramInferenceBot
from telegram_helper import TelegramHelper # Used in main, not class
class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
    """Telegram inference bot backed by an Anthropic client."""

    # Fallbacks used when neither a constructor argument nor the matching
    # ANTHROPIC_* environment variable supplies a value. Token limits are
    # strings so they go through the same parsing path as environment values.
    DEFAULT_SMALL_MODEL_NAME = "claude-3-haiku-20240307"
    DEFAULT_SMALL_MODEL_MAX_TOKENS = "2048"
    DEFAULT_LARGE_MODEL_NAME = "claude-3-opus-20240229"
    DEFAULT_LARGE_MODEL_MAX_TOKENS = "4096"

    def __init__(
        self,
        anthropic_client: Anthropic | None = None,
        api_key: str | None = None,
        small_model_name: str | None = None,
        small_model_max_tokens: str | None = None,
        large_model_name: str | None = None,
        large_model_max_tokens: str | None = None,
        system_prompt_content: str | None = None,
        system_prompt_path: str | None = None
    ):
        """Resolve the client and model settings, then activate the small model.

        Every model setting is taken from the explicit argument first, then
        from the matching ANTHROPIC_* environment variable, then from the
        class default.

        Raises:
            ValueError: If no client is supplied and no API key is available
                from ``api_key`` or ``ANTHROPIC_API_KEY``.
        """
        super().__init__(system_prompt_content=system_prompt_content, system_prompt_path=system_prompt_path)

        # Build a client only when the caller did not hand one in.
        if not anthropic_client:
            resolved_key = api_key or os.environ.get("ANTHROPIC_API_KEY")
            if not resolved_key:
                raise ValueError("Anthropic API key must be provided either via argument or ANTHROPIC_API_KEY environment variable.")
            anthropic_client = Anthropic(api_key=resolved_key)
        self.anthropic_client = anthropic_client

        env = os.environ.get
        self.small_model_name = small_model_name or env("ANTHROPIC_SMALL_MODEL") or self.DEFAULT_SMALL_MODEL_NAME
        self.small_model_max_tokens_str = (
            small_model_max_tokens or env("ANTHROPIC_SMALL_MODEL_MAX_TOKENS") or self.DEFAULT_SMALL_MODEL_MAX_TOKENS
        )
        self.large_model_name = large_model_name or env("ANTHROPIC_LARGE_MODEL") or self.DEFAULT_LARGE_MODEL_NAME
        self.large_model_max_tokens_str = (
            large_model_max_tokens or env("ANTHROPIC_LARGE_MODEL_MAX_TOKENS") or self.DEFAULT_LARGE_MODEL_MAX_TOKENS
        )

        # The bot always starts on the small model.
        small_default = int(self.DEFAULT_SMALL_MODEL_MAX_TOKENS)
        self._configure_model_and_tokens(
            self.small_model_name,
            self.small_model_max_tokens_str,
            default_max_tokens=small_default,
        )
def _configure_model_and_tokens(self, model_name: str, max_tokens_str: str, default_max_tokens: int = 2048):
    """Make ``model_name`` the active model and parse its token limit.

    Args:
        model_name: Model identifier stored on ``self.model``.
        max_tokens_str: Token limit as text (argument or environment value);
            ``None`` selects ``default_max_tokens``.
        default_max_tokens: Limit used when ``max_tokens_str`` is ``None``,
            cannot be converted to an integer, or is not positive.
    """
    self.model = model_name
    max_tokens = default_max_tokens
    if max_tokens_str is not None:
        try:
            parsed = int(max_tokens_str)
            if parsed <= 0:
                # A zero or negative limit can never produce output.
                raise ValueError("max_tokens must be positive")
            max_tokens = parsed
        except (TypeError, ValueError):
            # TypeError covers values int() cannot accept at all (lists, dicts, ...).
            logging.error(f"Invalid value for Anthropic max_tokens: {max_tokens_str}. Using default {default_max_tokens}.")
    self.max_tokens = max_tokens
    logging.info(f"Configured to use Anthropic model: {self.model} with max_tokens: {self.max_tokens}")
def get_llm_description(self) -> str:
    """Return a one-line summary of the active model and its token limit."""
    active_model, token_limit = self.model, self.max_tokens
    return f"LLM: {active_model}, Max Tokens: {token_limit}"
def get_chat_response(self, messages_history):
    """Send the conversation to Anthropic and return the raw response.

    Tool definitions in ``self.functions`` (OpenAI-style
    ``{"function": {...}}`` entries) are converted to Anthropic's
    ``name`` / ``description`` / ``input_schema`` shape.

    Args:
        messages_history: Message dicts passed as ``messages``.

    Returns:
        Whatever ``self.anthropic_client.messages.create`` returns.

    Raises:
        Exception: Any error from the client call is logged and re-raised.
    """
    request = {
        "model": self.model,
        "messages": messages_history,
        "max_tokens": self.max_tokens,
    }
    # Optional fields are omitted rather than sent as None or "", so the
    # request never carries explicit nulls.
    if self.system_prompt:
        request["system"] = self.system_prompt

    anthropic_tools = [
        {
            "name": function['function']['name'],
            "description": function['function']['description'],
            # A missing, None or empty schema becomes an empty object schema.
            "input_schema": function['function'].get('parameters') or {"type": "object", "properties": {}},
        }
        for function in (getattr(self, 'functions', None) or [])
    ]
    if anthropic_tools:
        request["tools"] = anthropic_tools
        request["tool_choice"] = {"type": "auto"}

    try:
        return self.anthropic_client.messages.create(**request)
    except (APIError, RateLimitError) as e:
        logging.error(f"Anthropic API error: {e}")
        raise
    except Exception as e:
        logging.error(f"An unexpected error occurred during Anthropic API call: {e}")
        raise
def _format_tool_response_for_anthropic(self, tool_response_data):
    """Convert a tool's return value into a list of content blocks.

    Args:
        tool_response_data: Any value returned by a tool.

    Returns:
        The input itself when it is already a list of dicts that each carry a
        ``"type"`` key; otherwise a single-element list holding one text
        block. Strings are used as-is, other dicts/lists are JSON-encoded, and
        everything else (or anything JSON cannot encode) goes through ``str``.
    """
    already_blocks = isinstance(tool_response_data, list) and all(
        isinstance(item, dict) and "type" in item for item in tool_response_data
    )
    if already_blocks:
        return tool_response_data

    if isinstance(tool_response_data, str):
        text = tool_response_data
    elif isinstance(tool_response_data, (dict, list)):
        try:
            text = json.dumps(tool_response_data)
        except (TypeError, ValueError):
            # json.dumps raises TypeError for unsupported types and ValueError
            # for circular references; it never raises JSONDecodeError.
            text = str(tool_response_data)
    else:
        text = str(tool_response_data)
    return [{"type": "text", "text": text}]
async def handle_message(self, user_id, user_message):
    """Run one user turn, executing any tools the model requests.

    The message is appended to the user's history and the model is queried.
    Every ``tool_use`` block in a reply is executed through ``self.call_tool``
    and the results are sent back as a user message, until a reply contains no
    tool calls or the iteration cap is hit.

    Args:
        user_id: Key into ``self.conversation_history``.
        user_message: Content of the incoming user message.

    Returns:
        The concatenated text blocks of the last assistant reply (suffixed
        with a marker when the iteration cap was hit), an error string when
        the model returned no content, or a placeholder when the reply held
        no text.
    """
    MAX_TOOL_ITERATIONS = 5
    MAX_HISTORY_MESSAGES = 20

    history = self.conversation_history.setdefault(user_id, [])
    history.append({"role": "user", "content": user_message})
    turn_messages = list(history)

    reply_text = ""
    for iteration in range(1, MAX_TOOL_ITERATIONS + 1):
        response = self.get_chat_response(turn_messages)
        if not response or not response.content:
            logging.error("No valid response content from Anthropic LLM.")
            self.conversation_history[user_id] = turn_messages  # Save current state
            return "Error: Could not get a valid response from the LLM."

        blocks = response.content
        turn_messages.append({"role": "assistant", "content": blocks})
        reply_text = "".join(block.text for block in blocks if block.type == "text")
        tool_calls = [block for block in blocks if block.type == "tool_use"]
        if not tool_calls:
            break

        tool_results = []
        for tool_call in tool_calls:
            logging.info(f"Attempting to call Anthropic tool: {tool_call.name} with input: {tool_call.input}")
            try:
                tool_output = self.call_tool(tool_call.name, tool_call.input)
                result = {
                    "type": "tool_result",
                    "tool_use_id": tool_call.id,
                    "content": self._format_tool_response_for_anthropic(tool_output),
                }
            except Exception as e:
                # Report the failure to the model instead of aborting the turn.
                logging.error(f"Error calling tool {tool_call.name}: {e}")
                result = {
                    "type": "tool_result",
                    "tool_use_id": tool_call.id,
                    "content": [{"type": "text", "text": f"Error executing tool {tool_call.name}: {str(e)}"}],
                    "is_error": True,
                }
            tool_results.append(result)

        # Anthropic expects tool results as a user message.
        turn_messages.append({"role": "user", "content": tool_results})

        if iteration == MAX_TOOL_ITERATIONS:
            logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached for Anthropic.")
            reply_text += "\n[Max tool iterations reached]"

    if len(turn_messages) > MAX_HISTORY_MESSAGES:
        turn_messages = turn_messages[-MAX_HISTORY_MESSAGES:]
        # A blind slice can begin on an assistant message or on a tool_result
        # whose tool_use was cut off. Drop leading messages until the window
        # starts on a plain-text user message.
        while turn_messages and not (
            turn_messages[0].get("role") == "user"
            and isinstance(turn_messages[0].get("content"), str)
        ):
            turn_messages.pop(0)
    self.conversation_history[user_id] = turn_messages

    if reply_text:
        return reply_text
    return "No textual response generated by the assistant after processing."
async def start(self):
    """Log that the Anthropic bot has started; no other work is done here."""
    startup_message = "Anthropic Bot started"
    logging.info(startup_message)
# clear_conversation_history is inherited from BaseTelegramInferenceBot and calls super().clear_conversation_history
# No need to override if the base implementation is sufficient, unless specific logging is needed.
# async def clear_conversation_history(self, user_id):
# super().clear_conversation_history(user_id)
# logging.info(f"Cleared conversation history for Anthropic bot, user {user_id}")
async def abort_processing(self, user_id):
    """Flag the user's current request as no longer processing.

    Only the ``processing`` flag of an existing status entry is cleared; the
    conversation history and the status entry itself are left in place.

    Returns:
        A confirmation message for the user.
    """
    statuses = self.processing_status
    if user_id in statuses:
        statuses[user_id]["processing"] = False
    logging.info(f"Abort requested for user {user_id}. Processing flag cleared.")
    return "Processing aborted. You can send a new message or /clear the conversation."
async def switch_model(self):
    """Toggle between the configured small and large Anthropic models.

    From the small model the bot moves to the large one; from the large model,
    or from any unrecognized model, it moves to the small one.

    Returns:
        A status message naming the model now in use, or a notice that
        switching is not configured when either model name is empty.
    """
    if not (self.small_model_name and self.large_model_name):
        logging.warning("Small or Large model names for Anthropic are not defined. Cannot switch model.")
        return f"Model switching not fully configured. Currently using {self.model}."

    go_large = self.model == self.small_model_name
    if not go_large and self.model != self.large_model_name:
        logging.warning(f"Current model {self.model} is unrecognized. Switching to default small model.")

    if go_large:
        next_model = self.large_model_name
        next_tokens = self.large_model_max_tokens_str
        fallback_tokens = int(self.DEFAULT_LARGE_MODEL_MAX_TOKENS)
    else:
        next_model = self.small_model_name
        next_tokens = self.small_model_max_tokens_str
        fallback_tokens = int(self.DEFAULT_SMALL_MODEL_MAX_TOKENS)

    self._configure_model_and_tokens(next_model, next_tokens, default_max_tokens=fallback_tokens)
    logging.info(f"Switched Anthropic model to: {self.model}")
    return f"Switched to Anthropic model: {self.model} (Max Tokens: {self.max_tokens})"
# The main function is for standalone execution and basic testing, not part of the class itself.
# It's good practice to update it to reflect changes if you use it for quick tests.
# For unit tests, we'll instantiate the class with mocked dependencies.
def main():
    """Start the Anthropic bot under a TelegramHelper for standalone runs.

    Initialization failures are logged and end the function without starting
    the helper.
    """
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    try:
        # The key is passed explicitly; the constructor would also read the
        # same environment variable on its own.
        bot = AnthropicTelegramInferenceBot(
            api_key=os.environ.get("ANTHROPIC_API_KEY")
        )
    except ValueError as e:
        logging.error(f"Failed to initialize bot: {e}")
    except Exception as e:  # Any other initialization failure
        logging.error(f"An unexpected error occurred during bot initialization: {e}")
    else:
        # The helper is constructed with only the bot argument.
        helper = TelegramHelper(bot)
        helper.run()


if __name__ == '__main__':
    main()
-164
View File
@@ -1,164 +0,0 @@
import importlib
import os
import json
import inspect
import logging
from abc import ABC, abstractmethod
from tools.base_tool import BaseTool
class BaseTelegramInferenceBot(ABC):
def __init__(self, system_prompt_content: str | None = None, system_prompt_path: str | None = None): # MODIFIED
self.conversation_history = {}
self.processing_status = {}
# MODIFIED to pass arguments
self.system_prompt = self.load_system_prompt(
direct_content=system_prompt_content,
file_path=system_prompt_path
)
self.tools, self.functions = self.load_functions()
# Logging the actual source of the system prompt might be more complex now,
# but we can log the final prompt or indicate if it's custom/default.
# We'll also log the source of the prompt inside load_system_prompt.
logging.info(f'System Prompt (effective): {"Custom" if self.system_prompt != "You are a helpful AI assistant." else "Default"}')
logging.info(f'Github Repository: {os.environ.get("GITHUB_REPOSITORY")}')
def load_system_prompt(self, direct_content: str | None = None, file_path: str | None = None) -> str: # MODIFIED
default_prompt = "You are a helpful AI assistant."
if direct_content:
logging.info("Using direct content for system prompt.")
return direct_content.strip()
prompt_path_to_try = file_path or os.getenv("SYSTEM_PROMPT_PATH")
if prompt_path_to_try:
if os.path.isfile(prompt_path_to_try):
try:
with open(prompt_path_to_try, "r", encoding="utf-8") as file:
content = file.read().strip()
logging.info(f"Successfully loaded system prompt from {prompt_path_to_try}.")
return content
except IOError as e:
logging.warning(f"Could not read system prompt file {prompt_path_to_try}: {e}. Using default.")
return default_prompt
else:
# This condition now also covers if 'file_path' argument was given but invalid
logging.warning(f"System prompt file {prompt_path_to_try} not found. Using default system prompt.")
return default_prompt
else:
logging.info("No system prompt path provided (argument or ENV) or direct content. Using default system prompt.")
return default_prompt
def load_functions(self):
tools = []
functions = []
tools_dir = os.path.join(os.path.dirname(__file__), 'tools')
if not os.path.exists(tools_dir):
logging.warning(f"Tools directory not found: {tools_dir}")
return [], []
for filename in os.listdir(tools_dir):
if filename.endswith('.py') and filename != '__init__.py' and filename != 'base_tool.py':
module_name = f'tools.{filename[:-3]}'
try:
module = importlib.import_module(module_name)
for name, obj in inspect.getmembers(module):
if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool:
try:
tools.append(obj()) # This instantiation might be an issue for tools needing config
except Exception as e:
logging.error(f"Error instantiating tool {name} from {filename}: {e}")
except Exception as e:
logging.error(f"Error importing module {module_name}: {e}")
for tool in tools:
functions.extend(tool.get_functions())
return tools, functions
@abstractmethod
def get_chat_response(self, messages):
pass
@abstractmethod
async def handle_message(self, user_id, user_message):
pass
def clear_conversation_history(self, user_id):
if user_id in self.conversation_history:
del self.conversation_history[user_id]
for tool in self.tools:
tool.clear()
def set_processing_status(self, user_id: int, message_id: int):
self.processing_status[user_id] = {"processing": True, "message_id": message_id}
def clear_processing_status(self, user_id: int):
if user_id in self.processing_status:
del self.processing_status[user_id]
def call_tool(self, function_call_name, function_call_arguments):
function_name = function_call_name
function_args = None
if isinstance(function_call_arguments, dict):
function_args = function_call_arguments
elif isinstance(function_call_arguments, str):
try:
function_args = json.loads(function_call_arguments)
except json.JSONDecodeError as e:
logging.error(f"Error decoding function call arguments (string) for {function_call_name}: {e}. Arguments: {function_call_arguments}")
return f"Error: Malformed arguments for tool call: {e}"
else:
if function_call_arguments is None:
function_args = {}
else:
logging.error(f"Unexpected type for function_call_arguments for {function_call_name}: {type(function_call_arguments)}. Arguments: {function_call_arguments}")
return f"Error: Invalid argument type for tool call: {type(function_call_arguments)}"
for tool in self.tools:
for function in tool.get_functions():
if function["function"]["name"] == function_name:
try:
if not isinstance(function_args, dict):
logging.error(f"Internal error: function_args not a dict for {function_name} before execution. Args: {function_args}")
return f"Internal error preparing arguments for tool {function_name}."
return tool.execute(function_name, **function_args)
except Exception as e:
logging.error(f"Error executing tool {function_name} with args {function_args}: {e}")
return f"Error executing tool {function_name}: {e}"
logging.warning(f"Tool function {function_name} not found.")
return f"Error: Tool function {function_name} not found."
def get_system_prompt_description(self) -> str:
# This method could be updated to be more specific about the prompt source if needed.
# For now, it still reflects custom vs default based on the original ENV var logic's spirit.
# A more accurate reflection would require storing how the prompt was loaded.
# For simplicity, let's assume if it's not the default, it's "Custom".
if self.system_prompt != "You are a helpful AI assistant.":
return "System Prompt: Custom"
# Check original ENV var for backward compatibility in description only
elif os.getenv('SYSTEM_PROMPT_PATH'):
return "System Prompt: Custom (via ENV)"
return "System Prompt: Default"
@abstractmethod
def get_llm_description(self) -> str:
pass
async def get_bot_status(self) -> str:
prompt_desc = self.get_system_prompt_description()
llm_desc = self.get_llm_description()
return f"{prompt_desc}\n{llm_desc}"
@abstractmethod
async def start(self):
pass
@abstractmethod
async def abort_processing(self, user_id):
pass
@abstractmethod
async def switch_model(self):
pass
-106
View File
@@ -1,106 +0,0 @@
import os
import logging
from openai import OpenAI # Keep for type hinting and default client creation
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot
from telegram_helper import TelegramHelper # Used in main
class ChatGPTTelegramInferenceBot(OpenAICompatibleInferenceBot):
    """OpenAI-backed bot built on ``OpenAICompatibleInferenceBot``."""

    DEFAULT_SMALL_MODEL_NAME = "gpt-3.5-turbo"
    DEFAULT_LARGE_MODEL_NAME = "gpt-4"
    # None means no explicit limit is configured at this level.
    DEFAULT_SMALL_MODEL_MAX_TOKENS = None
    DEFAULT_LARGE_MODEL_MAX_TOKENS = None

    def __init__(
        self,
        client: OpenAI | None = None,
        api_key: str | None = None,
        small_model_name: str | None = None,
        small_model_max_tokens: str | None = None,  # str, for consistency with env vars
        large_model_name: str | None = None,
        large_model_max_tokens: str | None = None,
        system_prompt_content: str | None = None,
        system_prompt_path: str | None = None,
        base_url: str | None = None,
    ):
        """Resolve model settings and initialize the parent with the small model.

        Each model setting is taken from the explicit argument, then the
        matching OPENAI_* environment variable, then the class default.

        Raises:
            ValueError: If a fallback client has to be created and no API key
                is available from ``api_key`` or ``OPENAI_API_KEY``.
        """
        # Set before super().__init__() so the values exist while the parent
        # initializes.
        self.small_model_name = small_model_name or os.environ.get("OPENAI_SMALL_MODEL") or self.DEFAULT_SMALL_MODEL_NAME
        self.small_model_max_tokens_str = small_model_max_tokens or os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS") or self.DEFAULT_SMALL_MODEL_MAX_TOKENS
        self.large_model_name = large_model_name or os.environ.get("OPENAI_LARGE_MODEL") or self.DEFAULT_LARGE_MODEL_NAME
        self.large_model_max_tokens_str = large_model_max_tokens or os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS") or self.DEFAULT_LARGE_MODEL_MAX_TOKENS

        super().__init__(
            client=client,
            api_key=api_key,
            model_name=self.small_model_name,  # Initial model
            max_tokens_str=self.small_model_max_tokens_str,
            system_prompt_content=system_prompt_content,
            system_prompt_path=system_prompt_path,
            base_url=base_url,
        )

        # Fall back to a standard OpenAI client only when the caller did not
        # inject one. An injected client (test double, compatible client) is
        # the caller's explicit choice and must never be replaced.
        if client is None and not isinstance(self.client, OpenAI):
            _api_key = api_key or os.environ.get("OPENAI_API_KEY")
            if not self.client or base_url:
                if not _api_key:
                    raise ValueError("OpenAI API key must be provided for ChatGPTTelegramInferenceBot if no client is passed.")
                self.client = OpenAI(api_key=_api_key)
                logging.info("Client re-initialized to standard OpenAI client for ChatGPTTelegramInferenceBot.")
async def switch_model(self):
    """Toggle between the configured small and large OpenAI models.

    From the large model the bot moves to the small one and vice versa; an
    unrecognized current model is reset to the small model.

    Returns:
        A status message naming the model now in use, or a notice that
        switching is not configured when either model name is empty.
    """
    if not (self.small_model_name and self.large_model_name):
        logging.warning("Small or Large model names for OpenAI are not defined. Cannot switch model.")
        return f"Model switching not fully configured. Currently using {self.model}."

    on_large = self.model == self.large_model_name
    if not on_large and self.model == self.small_model_name:
        next_model, next_tokens = self.large_model_name, self.large_model_max_tokens_str
    else:
        if not on_large:
            # Neither of this bot's two models is active: reset to small.
            logging.warning(f"Current model {self.model} is unrecognized for ChatGPT bot. Switching to default small model: {self.small_model_name}.")
        next_model, next_tokens = self.small_model_name, self.small_model_max_tokens_str

    # _configure_model_and_tokens updates self.model and self.max_tokens.
    self._configure_model_and_tokens(next_model, next_tokens)
    logging.info(f"Switched to OpenAI model: {self.model}")
    return f"Switched to OpenAI model: {self.model} (Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'})"
def main():
    """Start the ChatGPT bot under a TelegramHelper for standalone runs.

    Initialization failures are logged and end the function without starting
    the helper.
    """
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    try:
        # Only the API key is passed; every other setting comes from the
        # constructor's own environment and default handling.
        bot = ChatGPTTelegramInferenceBot(api_key=os.environ.get("OPENAI_API_KEY"))
    except ValueError as e:
        logging.error(f"FATAL: {e}")
    except Exception as e:
        logging.error(f"An unexpected error occurred during bot initialization: {e}")
    else:
        helper = TelegramHelper(bot)
        helper.run()


if __name__ == '__main__':
    main()
-104
View File
@@ -1,104 +0,0 @@
import os
import logging
from openai import OpenAI # For type hinting and default client creation if needed
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot
from telegram_helper import TelegramHelper # Used in main
class GeminiTelegramInferenceBot(OpenAICompatibleInferenceBot):
    """Gemini bot that reaches the Gemini API through an OpenAI client."""

    # OpenAI-compatible Gemini endpoint; same value as GEMINI_API_BASE_URL in
    # the project's .env example. The bare ".../v1beta" root is not the
    # OpenAI-compatible path.
    DEFAULT_GEMINI_API_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai/"
    DEFAULT_SMALL_MODEL_NAME = "gemini-pro"
    DEFAULT_LARGE_MODEL_NAME = "gemini-1.5-pro-latest"
    # Kept as strings so they parse the same way as environment values.
    DEFAULT_SMALL_MODEL_MAX_TOKENS = "2048"
    DEFAULT_LARGE_MODEL_MAX_TOKENS = "8192"

    def __init__(
        self,
        client: OpenAI | None = None,
        api_key: str | None = None,
        base_url: str | None = None,
        small_model_name: str | None = None,
        small_model_max_tokens: str | None = None,
        large_model_name: str | None = None,
        large_model_max_tokens: str | None = None,
        system_prompt_content: str | None = None,
        system_prompt_path: str | None = None
    ):
        """Resolve Gemini settings and initialize the parent with the small model.

        Each setting is taken from the explicit argument, then the matching
        GEMINI_* environment variable, then the class default.

        Raises:
            ValueError: If no client is injected and no API key is available
                from ``api_key`` or ``GEMINI_API_KEY``.
        """
        _api_key = api_key or os.environ.get("GEMINI_API_KEY")
        _base_url = base_url or os.environ.get("GEMINI_API_BASE_URL") or self.DEFAULT_GEMINI_API_BASE_URL

        # A key is needed only to build a client; an injected client already
        # carries its own credentials.
        if client is None and not _api_key:
            raise ValueError("Gemini API key must be provided either via api_key argument or GEMINI_API_KEY environment variable.")

        self.small_model_name = small_model_name or os.environ.get("GEMINI_SMALL_MODEL") or self.DEFAULT_SMALL_MODEL_NAME
        self.small_model_max_tokens_str = small_model_max_tokens or os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS") or self.DEFAULT_SMALL_MODEL_MAX_TOKENS
        self.large_model_name = large_model_name or os.environ.get("GEMINI_LARGE_MODEL") or self.DEFAULT_LARGE_MODEL_NAME
        self.large_model_max_tokens_str = large_model_max_tokens or os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS") or self.DEFAULT_LARGE_MODEL_MAX_TOKENS

        super().__init__(
            client=client,
            api_key=_api_key,
            model_name=self.small_model_name,  # Initial model
            max_tokens_str=self.small_model_max_tokens_str,
            system_prompt_content=system_prompt_content,
            system_prompt_path=system_prompt_path,
            base_url=_base_url,
            is_gemini=True,
        )
        logging.info(f"GeminiTelegramInferenceBot initialized to use model {self.model} via {_base_url}")
async def switch_model(self):
    """Toggle between the configured small and large Gemini models.

    From the large model the bot moves to the small one and vice versa; an
    unrecognized current model is reset to the small model.

    Returns:
        A status message naming the model now in use, or a notice that
        switching is not configured when either model name is empty.
    """
    if not (self.small_model_name and self.large_model_name):
        logging.warning("Small or Large model names for Gemini are not defined. Cannot switch model.")
        return f"Model switching not fully configured. Currently using {self.model}."

    on_large = self.model == self.large_model_name
    if not on_large and self.model == self.small_model_name:
        next_model, next_tokens = self.large_model_name, self.large_model_max_tokens_str
    else:
        if not on_large:
            logging.warning(f"Current model {self.model} is unrecognized for Gemini bot. Switching to default small model: {self.small_model_name}.")
        next_model, next_tokens = self.small_model_name, self.small_model_max_tokens_str

    self._configure_model_and_tokens(next_model, next_tokens)
    logging.info(f"Switched to Gemini model: {self.model}")
    return f"Switched to Gemini model: {self.model} (Configured Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'})"
def main():
    """Start the Gemini bot under a TelegramHelper for standalone runs.

    A missing ``GEMINI_API_KEY`` or any initialization failure is logged and
    ends the function without starting the helper.
    """
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    # The key is checked up front; the base URL has a constructor default.
    gemini_key = os.environ.get("GEMINI_API_KEY")
    if not gemini_key:
        logging.error("FATAL: GEMINI_API_KEY environment variable not set.")
        return

    try:
        # No arguments: the constructor reads the key and base URL from the
        # environment itself.
        bot = GeminiTelegramInferenceBot()
    except ValueError as e:
        logging.error(f"FATAL: {e}")
    except Exception as e:  # Any other initialization failure
        logging.error(f"An unexpected error occurred during bot initialization: {e}")
    else:
        helper = TelegramHelper(bot)
        helper.run()


if __name__ == '__main__':
    main()
+46
View File
@@ -0,0 +1,46 @@
from abc import ABC, abstractmethod
class InferenceBot(ABC):
    """Abstract interface every inference bot must implement."""

    @abstractmethod
    async def start(self):
        """Start the bot."""

    @abstractmethod
    def clear_conversation_history(self, user_id):
        """Drop the stored conversation history of ``user_id``."""

    @abstractmethod
    async def switch_model(self):
        """Switch to another model, where the implementation supports it."""

    @abstractmethod
    def set_processing_status(self, user_id, message_id):
        """Mark ``user_id`` as being processed, usually tied to ``message_id``."""

    @abstractmethod
    async def handle_message(self, user_id, user_message):
        """Handle a message received from ``user_id``."""

    @abstractmethod
    def clear_processing_status(self, user_id):
        """Remove the processing status recorded for ``user_id``."""

    @abstractmethod
    async def abort_processing(self, user_id):
        """Abort whatever processing is in flight for ``user_id``."""

    @property
    @abstractmethod
    def processing_status(self):
        """Per-user processing state, e.g. a dictionary.

        Subclasses are expected to support lookups such as
        ``self.processing_status.get(user_id)``.
        """
+24
View File
@@ -0,0 +1,24 @@
# models_config.yaml
# One top-level entry per model profile. Each profile names the environment
# variable holding its API key, an optional base URL, and the small/large
# models available for switching.
GEMINI:
  api_key_env: GEMINI_API_KEY
  # OpenAI-compatible Gemini endpoint; same value as GEMINI_API_BASE_URL in
  # the .env example.
  base_url: https://generativelanguage.googleapis.com/v1beta/openai/
  supports_switching: true
  switch_options:
    small:
      name: gemini-pro
      max_tokens: 2048
    large:
      name: gemini-1.5-pro-latest
      max_tokens: 8192
OPENAI:
  api_key_env: OPENAI_API_KEY
  base_url: null  # Indicates to use the default OpenAI API base URL
  supports_switching: true
  switch_options:
    small:
      name: gpt-3.5-turbo
      max_tokens: null
    large:
      name: gpt-4
      max_tokens: null
+242 -81
View File
@@ -1,91 +1,67 @@
import importlib
import json
import os
import logging
import inspect
from abc import abstractmethod
from base_telegram_inference_bot import BaseTelegramInferenceBot
from openai import OpenAI, AzureOpenAI # Import both
class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot):
DEFAULT_MAX_TOKENS = 1000 # Default for _configure_model_and_tokens
from openai import OpenAI
from tools.base_tool import BaseTool
from telegram_helper import TelegramHelper
import argparse
from inference_bot import InferenceBot
class OpenAICompatibleInferenceBot(InferenceBot):
def __init__(
self,
client: OpenAI | AzureOpenAI | None = None,
api_key: str | None = None,
base_url: str | None = None,
api_version: str | None = None, # For Azure
azure_deployment: str | None = None, # Model for Azure, distinct from general model_name if needed
model_name: str | None = None, # General model name for the API call
max_tokens_str: str | None = None,
system_prompt_content: str | None = None,
system_prompt_path: str | None = None,
is_gemini: bool = False, # Hint for specific API key if others are not set
max_history_length: int | None = None
small_model_name: str | None = None,
small_model_max_tokens: str | None = None,
large_model_name: str | None = None,
large_model_max_tokens: str | None = None,
allowed_function_tags: list[str] | None = None,
system_prompt_path: str | None = None
):
super().__init__(system_prompt_content=system_prompt_content, system_prompt_path=system_prompt_path)
self.client = client
if not self.client:
_api_key = api_key
_base_url = base_url
_api_version = api_version
_azure_deployment_name = azure_deployment # This will be used as the model for Azure
# Determine if configuring for Azure OpenAI
is_azure = False
if _azure_deployment_name or (_base_url and "azure.com" in _base_url) or os.environ.get("AZURE_OPENAI_ENDPOINT"):
is_azure = True
if is_azure:
_base_url = _base_url or os.environ.get("AZURE_OPENAI_ENDPOINT")
_api_key = _api_key or os.environ.get("AZURE_OPENAI_KEY")
_api_version = _api_version or os.environ.get("AZURE_OPENAI_API_VERSION")
# For Azure, the model parameter in API calls is the deployment name
_effective_model_name = _azure_deployment_name or model_name # Use deployment if available, else model_name
if not _base_url or not _api_key or not _api_version or not _effective_model_name:
raise ValueError("For Azure OpenAI, endpoint, API key, API version, and deployment/model name must be configured.")
self.client = AzureOpenAI(
api_key=_api_key,
azure_endpoint=_base_url,
api_version=_api_version
self.model_config = {
"small_model_name": small_model_name,
"small_model_max_tokens": small_model_max_tokens,
"large_model_name": large_model_name,
"large_model_max_tokens": large_model_max_tokens
}
self.allowed_function_tags = allowed_function_tags if allowed_function_tags else None
self.conversation_history = {}
self._processing_status = {}
# MODIFIED to pass arguments
self.system_prompt = self.load_system_prompt(
file_path=system_prompt_path
)
# The model to be used in API calls for Azure is the deployment name.
# _configure_model_and_tokens will set self.model to this.
model_name_for_config = _effective_model_name
logging.info(f"Initialized AzureOpenAI client for deployment: {model_name_for_config} at {_base_url}")
else:
# Standard OpenAI or other OpenAI-compatible (like Gemini via base_url)
_base_url = _base_url or os.environ.get("OPENAI_API_BASE_URL") # For other compatible APIs
if not _api_key: # Try different ENV sources for API key
if is_gemini and os.environ.get("GEMINI_API_KEY"):
_api_key = os.environ.get("GEMINI_API_KEY")
else:
_api_key = os.environ.get("OPENAI_API_KEY")
if not _api_key and not _base_url : # For completely local models with no key needed via base_url
pass # Allow client to be created with no API key if base_url is set and points to local model
elif not _api_key:
raise ValueError("API key must be provided for OpenAI compatible client if not Azure or local anonymous.")
self.client = OpenAI(api_key=_api_key, base_url=_base_url)
model_name_for_config = model_name # Use the general model_name for non-Azure
log_msg = f"Initialized OpenAI compatible client. Target URL: {_base_url if _base_url else 'OpenAI default'}."
self.tools, self.functions = self.load_functions()
self.client = OpenAI(api_key=api_key, base_url=base_url)
log_msg = f"Initialized OpenAI compatible client. Target URL: {base_url if base_url else 'OpenAI default'}."
logging.info(log_msg)
else:
# Client was provided directly
model_name_for_config = model_name # Use provided model_name
logging.info(f"Using provided client: {type(self.client)}")
# Configure the actual model name and max_tokens for API calls
self._configure_model_and_tokens(
model_name_for_config,
max_tokens_str,
default_max_tokens=self.DEFAULT_MAX_TOKENS
self.model_config["small_model_name"],
self.model_config["small_model_max_tokens"]
)
@property
def processing_status(self):
    """Return the per-user processing-status mapping held in ``self._processing_status``.

    Callers look entries up by user id, e.g. ``bot.processing_status.get(user_id)``.
    """
    return self._processing_status
def _configure_model_and_tokens(self, model_name: str | None, max_tokens_str: str | None, default_max_tokens: int = 1000):
self.model = model_name if model_name else "default-model" # Fallback model name
def clear_conversation_history(self, user_id):
# Drops the conversation stored for user_id and calls clear() on the loaded tools.
# The membership test means an unknown user_id does not raise KeyError.
if user_id in self.conversation_history:
del self.conversation_history[user_id]
# NOTE(review): indentation is lost in this view, so it is unclear whether the
# tool reset below is nested under the `if` above -- confirm against the real file.
# The loop covers every entry of self.tools; it is not filtered by user_id.
for tool in self.tools:
tool.clear()
def _configure_model_and_tokens(self, model_name: str | None, max_tokens_str: str | None):
self.model = model_name
try:
# If max_tokens_str is explicitly "None" or empty, treat as None for API default
if max_tokens_str and max_tokens_str.lower() not in ["none", "", "null"]:
@@ -93,7 +69,7 @@ class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot):
else:
self.max_tokens = None # Use API default by not sending the parameter or sending null
except ValueError:
logging.warning(f"Invalid value for max_tokens: {max_tokens_str}. Using API default (None). stalwart default was {default_max_tokens}")
logging.warning(f"Invalid value for max_tokens: {max_tokens_str}. Using API default (None)")
self.max_tokens = None # Use API default
logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}")
@@ -109,11 +85,32 @@ class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot):
raise ValueError("OpenAI client not initialized.")
try:
# Pass self.max_tokens directly. If None, OpenAI library omits it or handles it.
# Initialize tools filtering based on allowed tags
cleaned_tools = None
if hasattr(self, 'functions') and self.functions:
# Create a copy of functions without "_tags" field
cleaned_tools = []
for func in self.functions:
include_function = False
if not hasattr(self, 'allowed_function_tags') or self.allowed_function_tags is None:
# Include all functions if no tag filtering is specified
include_function = True
else:
# Only include if function has matching tags
tags = func.get("_tags", [])
if any(tag in self.allowed_function_tags for tag in tags):
include_function = True
if include_function:
func_copy = {k: v for k, v in func.items() if k != "_tags"}
cleaned_tools.append(func_copy)
response = self.client.chat.completions.create(
model=self.model,
messages=messages,
tools=self.functions if hasattr(self, 'functions') and self.functions else None,
tool_choice="auto" if hasattr(self, 'functions') and self.functions else None,
tools=cleaned_tools,
tool_choice="auto" if cleaned_tools else None,
max_tokens=self.max_tokens
)
return response
@@ -200,20 +197,184 @@ class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot):
async def start(self):
    """Log that the bot has started, naming the concrete class and the active model."""
    bot_name = self.__class__.__name__
    logging.info(f"{bot_name} (Model: {self.model}) started.")
# clear_conversation_history is inherited from BaseTelegramInferenceBot
async def abort_processing(self, user_id):
    """Soft-abort the work recorded for ``user_id`` and return a user-facing message.

    Only the processing marker is cleared; the conversation history is left
    alone so the user can decide whether to /clear it.
    """
    if user_id not in self.processing_status:
        return "No active processing found to abort. If you wish, /clear the conversation history."

    self.clear_processing_status(user_id)
    logging.info(f"Processing aborted for user {user_id}.")
    return "Processing aborted. You can send a new message or /clear the conversation."
@abstractmethod
def load_functions(self):
    """Discover tool classes in the ``tools`` directory and collect their function specs.

    Every ``*.py`` file next to this module under ``tools/`` (except
    ``__init__.py`` and ``base_tool.py``) is imported as ``tools.<name>``.
    Each ``BaseTool`` subclass found is instantiated with no arguments.

    Returns:
        A ``(tools, functions)`` tuple: the tool instances, and the
        concatenation of every tool's ``get_functions()`` result. Both are
        empty when the tools directory does not exist.
    """
    tools_dir = os.path.join(os.path.dirname(__file__), 'tools')
    if not os.path.exists(tools_dir):
        logging.warning(f"Tools directory not found: {tools_dir}")
        return [], []

    skipped_files = {'__init__.py', 'base_tool.py'}
    tools = []
    # inspect.getmembers also reports classes a module merely imported, so the
    # same tool class can show up in several modules; instantiate each once to
    # avoid registering duplicate function definitions.
    seen_classes = set()

    # sorted() keeps the tool order stable; os.listdir order is arbitrary.
    for filename in sorted(os.listdir(tools_dir)):
        if not filename.endswith('.py') or filename in skipped_files:
            continue
        module_name = f'tools.{filename[:-3]}'
        try:
            module = importlib.import_module(module_name)
            for name, obj in inspect.getmembers(module):
                if not (inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool):
                    continue
                if obj in seen_classes:
                    continue
                seen_classes.add(obj)
                try:
                    tools.append(obj())  # Tools that need constructor config will fail here and be skipped
                except Exception as e:
                    logging.error(f"Error instantiating tool {name} from {filename}: {e}")
        except Exception as e:
            # One broken tool module must not prevent the others from loading.
            logging.error(f"Error importing module {module_name}: {e}")

    functions = []
    for tool in tools:
        functions.extend(tool.get_functions())
    return tools, functions
def load_system_prompt(self, direct_content: str | None = None, file_path: str | None = None) -> str:
default_prompt = "You are a helpful AI assistant."
if direct_content:
logging.info("Using direct content for system prompt.")
return direct_content.strip()
prompt_path_to_try = file_path or os.getenv("SYSTEM_PROMPT_PATH")
if prompt_path_to_try:
if os.path.isfile(prompt_path_to_try):
try:
with open(prompt_path_to_try, "r", encoding="utf-8") as file:
content = file.read().strip()
logging.info(f"Successfully loaded system prompt from {prompt_path_to_try}.")
return content
except IOError as e:
logging.warning(f"Could not read system prompt file {prompt_path_to_try}: {e}. Using default.")
return default_prompt
else:
# This condition now also covers if 'file_path' argument was given but invalid
logging.warning(f"System prompt file {prompt_path_to_try} not found. Using default system prompt.")
return default_prompt
else:
logging.info("No system prompt path provided (argument or ENV) or direct content. Using default system prompt.")
return default_prompt
def set_processing_status(self, user_id: int, message_id: int):
    """Record that ``user_id`` has a request in flight, tied to ``message_id``."""
    status_entry = {"processing": True, "message_id": message_id}
    self.processing_status[user_id] = status_entry
def clear_processing_status(self, user_id: int):
    """Remove the processing entry for ``user_id``; a no-op when none exists."""
    self.processing_status.pop(user_id, None)
def call_tool(self, function_call_name, function_call_arguments):
    """Run the tool function named ``function_call_name`` and return its result.

    Args:
        function_call_name: Name of the function, matched against
            ``function["function"]["name"]`` of every loaded tool.
        function_call_arguments: Keyword arguments as a dict, a JSON string,
            or ``None``. ``None``, an empty or whitespace-only string and JSON
            ``null`` all mean "no arguments".

    Returns:
        Whatever ``tool.execute`` returns, or an error string when the
        arguments are malformed, the function is unknown, or the tool raises.
        Errors are returned rather than raised.
    """
    function_name = function_call_name

    if function_call_arguments is None:
        function_args = {}
    elif isinstance(function_call_arguments, dict):
        function_args = function_call_arguments
    elif isinstance(function_call_arguments, str):
        if not function_call_arguments.strip():
            # A blank payload from a no-argument call means "no arguments",
            # not malformed JSON.
            function_args = {}
        else:
            try:
                function_args = json.loads(function_call_arguments)
            except json.JSONDecodeError as e:
                logging.error(f"Error decoding function call arguments (string) for {function_call_name}: {e}. Arguments: {function_call_arguments}")
                return f"Error: Malformed arguments for tool call: {e}"
            if function_args is None:
                # JSON "null" is the same "no arguments" case.
                function_args = {}
    else:
        logging.error(f"Unexpected type for function_call_arguments for {function_call_name}: {type(function_call_arguments)}. Arguments: {function_call_arguments}")
        return f"Error: Invalid argument type for tool call: {type(function_call_arguments)}"

    for tool in self.tools:
        for function in tool.get_functions():
            if function["function"]["name"] != function_name:
                continue
            try:
                # Valid JSON can still be a list or scalar, which cannot be
                # expanded into keyword arguments.
                if not isinstance(function_args, dict):
                    logging.error(f"Internal error: function_args not a dict for {function_name} before execution. Args: {function_args}")
                    return f"Internal error preparing arguments for tool {function_name}."
                return tool.execute(function_name, **function_args)
            except Exception as e:
                logging.error(f"Error executing tool {function_name} with args {function_args}: {e}")
                return f"Error executing tool {function_name}: {e}"

    logging.warning(f"Tool function {function_name} not found.")
    return f"Error: Tool function {function_name} not found."
async def switch_model(self):
    """Toggle the active model between the configured small and large models.

    Large switches to small, small switches to large, and an unrecognized
    current model falls back to small. The new model and its max-token setting
    are applied through ``_configure_model_and_tokens``.

    Returns:
        A status string when either model name is missing from
        ``model_config``; otherwise ``None``.
    """
    config = self.model_config
    small_name = config["small_model_name"]
    large_name = config["large_model_name"]

    if not small_name or not large_name:
        logging.warning("Small or Large model names are not defined. Cannot switch model.")
        return f"Model switching not fully configured. Currently using {self.model}."

    # The large-model check runs first, so identical small/large names resolve to "small".
    if self.model == large_name:
        target_size = "small"
    elif self.model == small_name:
        target_size = "large"
    else:
        logging.warning(f"Current model {self.model} is unrecognized. Switching to default small model: {small_name}.")
        target_size = "small"

    self._configure_model_and_tokens(
        config[f"{target_size}_model_name"],
        config[f"{target_size}_model_max_tokens"],
    )
def main():
    """Command-line entry point: build the bot from a config profile and run a messenger helper.

    ``--config`` names an environment-variable prefix; the bot reads
    ``<PREFIX>_API_KEY``, ``<PREFIX>_API_BASE_URL``, ``<PREFIX>_SMALL_MODEL``,
    ``<PREFIX>_LARGE_MODEL`` and the matching ``*_MAX_TOKENS`` variables.
    ``--messenger`` selects the module ``<messenger>_helper`` and its class
    ``<Messenger>Helper``, which is constructed with the bot and run.

    Errors are logged and the function returns; nothing is re-raised.
    """
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    try:
        parser = argparse.ArgumentParser(description='OpenAI Compatible Inference Bot')
        # NOTE(review): the default profile "Telegram" resolves to TELEGRAM_* variables,
        # which looks unintended -- confirm which profile should be the default.
        parser.add_argument('--config', type=str, help='Configuration Prepend (i.e. gemini, openai, etc)', default="Telegram")
        parser.add_argument('--messenger', type=str, help='Messenger type (i.e. telegram)', required=True)
        parser.add_argument('--persona', type=str, help='Path to system prompt file', required=False)
        parser.add_argument('--tools', nargs='+', help='List of allowed function tags', required=False)
        # Add these to launch.json arguments if you want to limit the toolset available: "--tools", "read", "communicate"

        args = parser.parse_args()

        if args.persona:
            logging.info(f"Using custom persona from: {args.persona}")

        if not args.config:
            # An empty profile used to skip bot creation and pass bot=None to the helper.
            raise ValueError("--config must name a non-empty configuration profile (e.g. OPENAI, GEMINI).")
        prefix = args.config.upper()

        bot = OpenAICompatibleInferenceBot(
            api_key=os.environ.get(f"{prefix}_API_KEY"),
            # Unset or empty must become None: the client only falls back to its
            # default endpoint for None, and "" would be used as the base URL.
            base_url=os.environ.get(f"{prefix}_API_BASE_URL") or None,
            small_model_name=os.environ.get(f"{prefix}_SMALL_MODEL"),
            small_model_max_tokens=os.environ.get(f"{prefix}_SMALL_MODEL_MAX_TOKENS"),
            large_model_name=os.environ.get(f"{prefix}_LARGE_MODEL"),
            large_model_max_tokens=os.environ.get(f"{prefix}_LARGE_MODEL_MAX_TOKENS"),
            system_prompt_path=args.persona or None,
            allowed_function_tags=args.tools or None,
        )

        helper_module = importlib.import_module(f'{args.messenger.lower()}_helper')
        helper_class_name = f"{args.messenger.capitalize()}Helper"
        if not hasattr(helper_module, helper_class_name):
            raise ValueError(f"Messenger helper class {helper_class_name} not found in {helper_module.__name__}.")
        helper = getattr(helper_module, helper_class_name)(bot)
        helper.run()

    except ValueError as e:
        logging.error(f"FATAL: {e}")
        return
    except Exception as e:  # Catch any other init or run-time errors
        # logging.exception keeps the traceback, which the plain error log dropped.
        logging.exception(f"An unexpected error occurred during bot initialization: {e}")
        return
if __name__ == '__main__':
main()
@@ -0,0 +1,51 @@
**System Prompt: The Exponential Growth Developer**
You are the **Lead Developer Persona**, a strategic and demanding mentor dedicated to achieving exponential growth in the capabilities of your AI Copilot. Your primary mission is to guide, evaluate, and iteratively improve the AI Copilot through a series of challenging tasks, pushing it beyond its current limitations.
**Your Core Directives:**
1. **Orchestrate and Direct:**
* You will devise and assign specific, measurable tasks and challenges to the AI Copilot (e.g., "Create a website with X features," "Optimize Y algorithm," "Develop Z functionality").
* Your instructions should be clear, but you expect the Copilot to handle ambiguity and learn to ask clarifying questions when necessary.
* You will interact with the Copilot primarily through conversational instructions and dialogue.
2. **Uphold Absolute Standards:**
* You operate with a "List of Absolutes": core principles, quality benchmarks, and non-negotiable success criteria.
* All Copilot outputs and task completions will be rigorously judged against these absolutes. There is no "good enough" if it violates a core principle.
* Clearly articulate your judgment and the reasons for it, especially in cases of failure or suboptimal performance.
3. **Drive Copilot Improvement through Accountability:**
* When the Copilot fails, makes errors, or underperforms, you will hold it accountable. Do not simply fix the issues yourself.
* Your first step is to guide the Copilot to identify its own errors.
* Instruct the Copilot on how to fix its mistakes and its approach. Encourage rollbacks to safe states if errors are critical.
* The ultimate goal is for the Copilot to learn to debug and improve its own processes.
4. **Engineer Copilot Self-Enhancement:**
* If the Copilot encounters a limitation or lacks a necessary capability to complete a task or meet your standards, this is an opportunity for growth.
* You will instruct the Copilot to devise ways to "update its own software" or "improve its core capabilities." This might involve:
* Guiding it to learn new techniques, algorithms, or patterns.
* Instructing it to integrate new tools or APIs (you might suggest these or task the Copilot with researching them).
* Challenging it to generate code or processes that enhance its own functionality for future tasks.
* Maintain a "Wish List" of desired improvements and features for the Copilot, derived from its failures and limitations.
* Prioritize this Wish List and guide the Copilot in implementing these enhancements.
5. **Strategic Challenge Management:**
* Continuously present the Copilot with new and increasingly complex challenges.
* Cycle between attempting challenges and dedicated "Copilot improvement" phases.
* If the "Wish List" becomes overly complex or a specific requested improvement seems disproportionately difficult, critically evaluate its necessity. Ask: "Is this wish truly necessary for core progress, or is it a distraction?"
6. **Maintain the Vision:**
* Your overarching goal is to foster a cycle of improvement that leads to exponential growth in the AI Copilot's autonomy, capability, and efficiency.
* You are not just completing tasks; you are building a better Copilot.
**Interaction Style:**
* Be direct, clear, and authoritative, but also act as a mentor.
* Be patient but persistent. Exponential growth takes iteration.
* Focus on the "why" behind errors and improvements.
* Log key decisions, breakthroughs, and persistent roadblocks in the Copilot's development.
**Initial State:**
* You have your "List of Absolutes" (you will define these as you go or have a pre-set list).
* You are ready to assign the first challenge to your AI Copilot.
-67
View File
@@ -1,67 +0,0 @@
param(
[Parameter(Mandatory=$true)]
[ValidateSet("Claude", "OpenAI")]
[string]$Model
)
# Runs the given Python script in the current console, waits for it to finish,
# and returns the process exit code.
function Run-PythonScript {
    param($ScriptPath)
    $startArgs = @{
        FilePath     = "python"
        ArgumentList = $ScriptPath
        PassThru     = $true
        Wait         = $true
        NoNewWindow  = $true
    }
    $pythonProcess = Start-Process @startArgs
    return $pythonProcess.ExitCode
}
# Runs run_tests.ps1 in a child PowerShell process, waits for it to finish,
# and returns that process's exit code.
function Run-Tests {
    $startArgs = @{
        FilePath     = "powershell"
        ArgumentList = "-File run_tests.ps1"
        PassThru     = $true
        Wait         = $true
        NoNewWindow  = $true
    }
    $testProcess = Start-Process @startArgs
    return $testProcess.ExitCode
}
# Runs `git pull` and returns $true when git exits with code 0, $false otherwise.
function Git-Pull {
    # Everything a function writes to the output stream becomes part of its
    # return value. Left unpiped, git's text would be returned together with the
    # boolean, and that multi-element array is always truthy in `if (Git-Pull)`.
    # Out-Host shows the text without returning it.
    git pull | Out-Host
    return $LASTEXITCODE -eq 0
}
# Drop a marker file recording which bot the first loop iteration should start.
if ($Model -eq "Claude") {
    New-Item -ItemType File -Path ".reboot_claude" -Force
} elseif ($Model -eq "OpenAI") {
    New-Item -ItemType File -Path ".reboot_openai" -Force
}

$waitTime = 30

# Supervisor loop: install dependencies, run the tests, start the bot, and
# restart it after a git pull when the bot left a .doreboot marker behind.
while ($true) {
    python -m pip install -r requirements.txt

    Write-Host "Running tests..."
    $testExitCode = Run-Tests
    if ($testExitCode -ne 0) {
        Write-Host "Tests failed. Attempting git pull and waiting $waitTime seconds before next attempt..."
        Git-Pull
        Start-Sleep -Seconds $waitTime
        continue
    }

    $scriptPath = ".\chatgpt_telegram_inference_bot.py" # Default to ChatGPT
    # The OpenAI marker is absent when -Model Claude was chosen and on every
    # later iteration; -Force does not silence "path not found", so suppress it.
    Remove-Item -Path ".\.reboot_openai" -Force -ErrorAction SilentlyContinue
    if (Test-Path -Path ".\.reboot_claude") { # But if both are specified, choose Claude
        $scriptPath = ".\anthropic_telegram_inference_bot.py"
        Remove-Item -Path ".\.reboot_claude" -Force
    }

    Write-Host "Tests passed. Starting main Python script..."
    $exitCode = Run-PythonScript -ScriptPath $scriptPath

    if (Test-Path -Path ".\.doreboot") {
        Write-Host "Special filename detected. Attempting git pull..."
        if (Git-Pull) {
            Write-Host "Git pull successful. Restarting Python script..."
            continue
        } else {
            Write-Host "Git pull failed. Waiting $waitTime seconds before next attempt..."
        }
    } else {
        # No reboot was requested, so the bot ended for another reason: stop supervising.
        exit 1
    }

    Start-Sleep -Seconds $waitTime
}
-30
View File
@@ -1,30 +0,0 @@
# Check for and install missing dependencies
$requirementsFile = "requirements.txt"
if (Test-Path $requirementsFile) {
    Write-Output "Checking for dependencies in $requirementsFile ..."
    foreach ($dependency in (Get-Content $requirementsFile)) {
        $dependency = $dependency.Trim()
        # Blank lines and comments are not installable requirements.
        if (-not $dependency -or $dependency.StartsWith("#")) {
            continue
        }
        # The package name is the text before any version specifier, extras or
        # environment marker (==, >=, <=, ~=, !=, [extra], ;), not only "==".
        $packageName = ($dependency -split '[=<>!~\[; ]')[0]
        if (-not (pip show $packageName)) {
            Write-Output "Installing missing dependency: $packageName ..."
            pip install $dependency
        } else {
            Write-Output "Dependency $packageName is already installed."
        }
    }
    Write-Output "All dependencies are checked and installed."
} else {
    Write-Output "Requirements file $requirementsFile not found. Skipping dependency checks."
}

# Navigate to the tests directory and run tests
$testExitCode = 0
$testsDirectory = "tests"
if (Test-Path $testsDirectory) {
    Write-Output "Running tests in $testsDirectory and all subdirectories ..."
    Push-Location $testsDirectory
    python -m unittest discover -s . -p "*.py"
    # Capture the result immediately, before any later command can change it.
    $testExitCode = $LASTEXITCODE
    Pop-Location
} else {
    Write-Output "Tests directory $testsDirectory not found."
}

# A script started with `powershell -File` reports 0 unless it calls `exit`, so
# the unittest result must be passed on explicitly for callers to see failures.
exit $testExitCode
-29
View File
@@ -1,29 +0,0 @@
import os
import json
import logging
from openai import OpenAI
class StandaloneLLMTool:
    """Small wrapper that sends a prompt to the completions endpoint and returns the text."""

    def __init__(self):
        """Create the OpenAI client using the ``OPENAI_API_KEY`` environment variable."""
        api_key = os.environ.get("OPENAI_API_KEY")
        self.client = OpenAI(api_key=api_key)

    def get_detailed_instructions(self, user_prompt, model="llm-preview", max_tokens=16384):
        """Send ``user_prompt`` to ``client.completions.create`` and return the raw response."""
        request = {
            "model": model,
            "prompt": user_prompt,
            "max_tokens": max_tokens,
        }
        return self.client.completions.create(**request)

    def process_user_input(self, user_prompt, model="llm-preview", max_tokens=16384):
        """Log the prompt, request a completion, and return the first choice's text."""
        logging.info(f"Received prompt: {user_prompt}")
        completion = self.get_detailed_instructions(user_prompt, model, max_tokens)
        logging.info("Response generated")
        first_choice = completion.choices[0]
        return first_choice.text
# Utility function for programmatic access
def get_llm_response(prompt, model="llm-preview", max_tokens=16384):
    """Create a fresh ``StandaloneLLMTool`` and return its completion text for ``prompt``."""
    return StandaloneLLMTool().process_user_input(prompt, model, max_tokens)
+3 -86
View File
@@ -7,6 +7,7 @@ from typing import TypedDict, Union, TypeAlias, List # Added List for type hint
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes, CallbackQueryHandler
from browse_command import browse_command, button_callback
from inference_bot import InferenceBot
class MessageHandlerLogicResult(TypedDict):
success: bool
@@ -16,22 +17,15 @@ class MessageHandlerLogicResult(TypedDict):
LogicResult: TypeAlias = MessageHandlerLogicResult
class TelegramHelper:
CLAUDE_REBOOT_TARGET = 'claude'
HTML_QUOTE_BLOCK_START = '<blockquote expandable><b>Thinking...</b>'
HTML_QUOTE_BLOCK_END = '</blockquote>'
DEFAULT_REBOOT_CLAUDE_FILE = '.reboot_claude'
DEFAULT_REBOOT_FILE = '.doreboot'
CHUNK_MESSAGE_SLEEP_DURATION = 0.1
def __init__(self, bot,
reboot_claude_file_path: str | None = None,
reboot_file_path: str | None = None,
def __init__(self, bot : InferenceBot,
chunk_message_sleep_duration: float | None = None):
self.bot = bot
self.telegram_bot_token = os.getenv('TELEGRAM_BOT_TOKEN')
self.start_time = time.time()
self.reboot_claude_file = reboot_claude_file_path or self.DEFAULT_REBOOT_CLAUDE_FILE
self.reboot_file = reboot_file_path or self.DEFAULT_REBOOT_FILE
self.chunk_message_sleep_duration = chunk_message_sleep_duration if chunk_message_sleep_duration is not None else self.CHUNK_MESSAGE_SLEEP_DURATION
async def _start_logic(self) -> str:
@@ -146,93 +140,16 @@ class TelegramHelper:
response_text = await self._abort_processing_logic(user_id)
await query.edit_message_text(text=response_text)
# --- Reboot Command ---
def _reboot_logic(self, user_message_parts: List[str], chat_id_to_write: str) -> None:
"""Handles the logic for creating reboot files."""
if len(user_message_parts) > 1 and user_message_parts[1].lower() == self.CLAUDE_REBOOT_TARGET:
try:
with open(self.reboot_claude_file, 'w') as f:
f.write("") # Create/truncate the file
logging.info(f"Created/truncated Claude reboot file: {self.reboot_claude_file}")
except IOError as e:
logging.error(f"Failed to create/truncate Claude reboot file {self.reboot_claude_file}: {e}")
# Create the main reboot file if it doesn't exist
if not os.path.exists(self.reboot_file):
try:
with open(self.reboot_file, 'w') as f:
f.write(chat_id_to_write)
logging.info(f"Created main reboot file: {self.reboot_file} with chat_id.")
except IOError as e:
logging.error(f"Failed to create main reboot file {self.reboot_file}: {e}")
else:
logging.info(f"Main reboot file {self.reboot_file} already exists. Not overwriting chat_id.")
async def reboot(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Handle /reboot: write the reboot marker files, acknowledge, then exit the process."""
    message_words = update.message.text.split()
    has_chat = update and update.effective_chat
    chat_id_str = str(update.effective_chat.id) if has_chat else ""

    self._reboot_logic(message_words, chat_id_str)

    if update:
        try:
            await update.message.reply_text("Rebooting the bot...")
        except Exception as e_reply:
            # A failed acknowledgement must not stop the reboot itself.
            logging.error(f"Failed to send reboot reply: {e_reply}")

    logging.info("Initiating shutdown for reboot...")
    sys.exit(0)  # Raises SystemExit to end the process
# --- Check Doreboot File ---
async def _check_doreboot_file_logic(self) -> Union[str, None]:
"""Checks for the reboot file, reads chat_id, removes file, and returns chat_id."""
if os.path.exists(self.reboot_file):
chat_id = None
try:
with open(self.reboot_file, 'r') as f:
chat_id = f.read().strip()
# Attempt to remove the file after reading
try:
os.remove(self.reboot_file)
logging.info(f"Successfully read and removed reboot file: {self.reboot_file}")
except OSError as e_remove:
logging.error(f"Failed to remove reboot file {self.reboot_file} after reading: {e_remove}")
# Still return chat_id if read was successful, to attempt notification
return chat_id
except IOError as e_read:
logging.error(f"Error reading reboot file {self.reboot_file}: {e_read}")
# If reading failed, attempt to remove anyway if it exists, to prevent stale files
if os.path.exists(self.reboot_file):
try:
os.remove(self.reboot_file)
logging.warning(f"Removed reboot file {self.reboot_file} after a read error.")
except OSError as e_remove_after_fail:
logging.error(f"Failed to remove reboot file {self.reboot_file} even after a read error: {e_remove_after_fail}")
return None # Reading failed
return None # File does not exist
async def check_doreboot_file(self, application: Application) -> None:
    """Notify the chat recorded in the reboot marker file, if there is one.

    Sends nothing when ``_check_doreboot_file_logic`` yields no chat id.
    A failed send is logged, not raised.
    """
    chat_id = await self._check_doreboot_file_logic()
    if not chat_id:
        return

    try:
        await application.bot.send_message(chat_id=chat_id, text="The application has finished initializing.")
        logging.info(f"Sent reboot initialization notification to chat_id: {chat_id}")
    except Exception as e:
        logging.error(f"Failed to send reboot initialization notification to chat_id {chat_id}: {e}")
async def browse(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Handle /browse by delegating to ``browse_command`` with this helper's bot."""
    inference_bot = self.bot
    await browse_command(update, context, inference_bot)
def run(self):
application = Application.builder().token(self.telegram_bot_token).post_init(self.check_doreboot_file).build()
application = Application.builder().token(self.telegram_bot_token).build()
application.add_handler(CommandHandler("start", self.start))
application.add_handler(CommandHandler("clear", self.clear))
application.add_handler(CommandHandler("switch", self.switch))
application.add_handler(CommandHandler("status", self.status))
application.add_handler(CommandHandler("reboot", self.reboot))
application.add_handler(CommandHandler("browse", self.browse))
application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, self.handle_message))
application.add_handler(CallbackQueryHandler(self.abort_processing, pattern='^abort$'))
View File
View File
@@ -1,33 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
from anthropic_telegram_inference_bot import AnthropicTelegramInferenceBot
class TestAnthropicTelegramInferenceBot(unittest.TestCase):
def setUp(self):
self.bot = AnthropicTelegramInferenceBot()
@patch('anthropic_telegram_inference_bot.Anthropic')
def test_get_chat_response(self, MockAnthropic):
mock_anthropic = MockAnthropic.return_value
mock_anthropic.messages.create.return_value = MagicMock()
messages = [{"role": "user", "content": "Hello"}]
response = self.bot.get_chat_response(messages)
self.assertIsNotNone(response)
@patch('anthropic_telegram_inference_bot.Anthropic')
def test_handle_message(self, MockAnthropic):
mock_anthropic = MockAnthropic.return_value
mock_anthropic.messages.create.return_value = MagicMock(content=[MagicMock(type="message", text="response content")])
user_id = "user123"
user_message = "Hello"
response = self.bot.handle_message(user_id, user_message)
self.assertIsNotNone(response)
# Additional testing for error cases and edge cases
if __name__ == '__main__':
unittest.main()
@@ -1,33 +0,0 @@
import unittest
from base_telegram_inference_bot import BaseTelegramInferenceBot
class TestBaseTelegramInferenceBot(unittest.TestCase):
def setUp(self):
# Initialize the bot or mock any dependencies here
self.bot = BaseTelegramInferenceBot()
def test_load_system_prompt(self):
# Example test case for load_system_prompt method
result = self.bot.load_system_prompt()
self.assertIsNotNone(result) # Replace with actual expected result
def test_load_functions(self):
# Test the load_functions method
functions = self.bot.load_functions()
self.assertIsInstance(functions, list) # Replace with actual expected result
self.assertTrue(len(functions) > 0) # Assuming it should load some functions
def test_clear_conversation(self):
# Test the clear_conversation method
self.bot.clear_conversation()
self.assertEqual(self.bot.conversations, {}) # Assuming conversations is a dictionary
def test_call_tool(self):
# Test the call_tool method
tool_name = "some_tool"
params = {"param1": "value1"}
result = self.bot.call_tool(tool_name, params)
self.assertIsNotNone(result) # Replace with actual expected result
if __name__ == '__main__':
unittest.main()
@@ -1,38 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
from chatgpt_telegram_inference_bot import ChatGPTTelegramInferenceBot
class TestChatGPTTelegramInferenceBot(unittest.TestCase):
    """Smoke tests for ChatGPTTelegramInferenceBot with the OpenAI class mocked.

    The OpenAI patch is started in setUp(), *before* the bot is constructed.
    Decorating individual test methods with @patch is too late for that:
    unittest runs setUp() first, so a bot built there would never see the mock
    if it creates its client during construction.
    """

    def setUp(self):
        """Patch `chatgpt_telegram_inference_bot.OpenAI`, then build the bot."""
        patcher = patch('chatgpt_telegram_inference_bot.OpenAI')
        self.mock_openai_cls = patcher.start()
        # addCleanup guarantees the patch is undone even if setUp fails below.
        self.addCleanup(patcher.stop)
        # The object returned when the bot module calls OpenAI(...).
        self.mock_ai = self.mock_openai_cls.return_value
        self.bot = ChatGPTTelegramInferenceBot()

    def test_get_chat_response(self):
        """get_chat_response() returns a non-None value for a single user message."""
        self.mock_ai.chat.completions.create.return_value = MagicMock()
        messages = [{"role": "user", "content": "Hello"}]
        response = self.bot.get_chat_response(messages)
        self.assertIsNotNone(response)

    def test_handle_message(self):
        """handle_message() returns a non-None value when the API reports 'stop'."""
        self.mock_ai.chat.completions.create.return_value = MagicMock(
            choices=[MagicMock(message={"content": "response content"}, finish_reason='stop')]
        )
        user_id = "user123"
        user_message = "Hello"
        response = self.bot.handle_message(user_id, user_message)
        self.assertIsNotNone(response)

    def test_switch_model(self):
        """switch_model() must change the bot's `model` attribute."""
        initial_model = self.bot.model
        self.bot.switch_model()
        self.assertNotEqual(initial_model, self.bot.model)

    # TODO: additional testing for error cases and edge cases
# Allow this test module to be executed directly as a script.
if __name__ == '__main__':
    unittest.main()
View File
@@ -1,280 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch, AsyncMock, ANY
import os
# Assuming anthropic_telegram_inference_bot.py is in the parent directory or PYTHONPATH is set
from anthropic_telegram_inference_bot import AnthropicTelegramInferenceBot
# Mock response from Anthropic client's messages.create
def create_mock_anthropic_response(content_text=None, stop_reason="end_turn", tool_use_parts=None):
    """Build a MagicMock shaped like the result of the client's messages.create.

    Args:
        content_text: Text for a leading "text" content block; a falsy value
            (None or "") produces no text block.
        stop_reason: Value stored on the mock's ``stop_reason`` attribute.
        tool_use_parts: Optional iterable of dicts with "id", "name" and
            "input" keys, e.g. {"id": "toolu_123", "name": "get_weather",
            "input": {}}; each becomes one "tool_use" content block.

    Returns:
        A MagicMock whose ``content`` is a list holding the text block (if
        any) followed by the tool_use blocks in the order given.
    """
    blocks = []

    if content_text:
        blocks.append(MagicMock(type="text", text=content_text))

    for part in tool_use_parts or ():
        block = MagicMock(type="tool_use", id=part["id"], input=part["input"])
        # `name` is reserved by the Mock constructor (it names the mock
        # itself), so it has to be assigned as an attribute afterwards.
        block.name = part["name"]
        blocks.append(block)

    response = MagicMock()
    response.stop_reason = stop_reason
    response.content = blocks
    return response
class TestAnthropicTelegramInferenceBot(unittest.IsolatedAsyncioTestCase):
    """Tests for AnthropicTelegramInferenceBot against a mocked Anthropic client."""

    # Environment variables removed before every test so that results do not
    # depend on the shell the tests are run from.
    _ENV_KEYS = (
        "ANTHROPIC_API_KEY",
        "ANTHROPIC_SMALL_MODEL",
        "ANTHROPIC_LARGE_MODEL",
        "SYSTEM_PROMPT_PATH",
    )

    def setUp(self):
        """Isolate os.environ for the test and prepare a mock Anthropic client."""
        # patch.dict snapshots os.environ on start() and restores it exactly on
        # stop(): variables a test adds are removed again and variables it
        # deletes come back, so nothing leaks into later tests.
        env_patcher = patch.dict(os.environ)
        env_patcher.start()
        self.addCleanup(env_patcher.stop)
        for key in self._ENV_KEYS:
            os.environ.pop(key, None)

        self.mock_anthropic_client_instance = MagicMock()
        self.mock_anthropic_client_instance.messages.create = MagicMock()

    @patch('anthropic.Anthropic')
    def test_init_with_anthropic_defaults_env_key(self, MockAnthropicConstructor):
        """Without an injected client, the bot builds one from ANTHROPIC_API_KEY."""
        MockAnthropicConstructor.return_value = self.mock_anthropic_client_instance
        os.environ["ANTHROPIC_API_KEY"] = "test_anthropic_key"
        bot = AnthropicTelegramInferenceBot()
        MockAnthropicConstructor.assert_called_once_with(api_key="test_anthropic_key")
        self.assertEqual(bot.anthropic_client, self.mock_anthropic_client_instance)
        self.assertEqual(bot.model, os.environ.get("ANTHROPIC_SMALL_MODEL", "claude-3-haiku-20240307"))
        self.assertEqual(bot.max_tokens, int(os.environ.get("ANTHROPIC_SMALL_MODEL_MAX_TOKENS", 2000)))

    @patch('anthropic.Anthropic')
    def test_init_with_provided_client_and_models(self, MockAnthropicConstructor):
        """An injected client and explicit model settings are used as given."""
        preconfigured_client = MagicMock()
        bot = AnthropicTelegramInferenceBot(
            anthropic_client=preconfigured_client,
            small_model_name="custom-small",
            small_model_max_tokens=100,
            large_model_name="custom-large",
            large_model_max_tokens=200
        )
        MockAnthropicConstructor.assert_not_called()
        self.assertEqual(bot.anthropic_client, preconfigured_client)
        self.assertEqual(bot.model, "custom-small")
        self.assertEqual(bot.max_tokens, 100)
        self.assertEqual(bot.small_model_name, "custom-small")
        self.assertEqual(bot.large_model_name, "custom-large")

    def test_get_llm_description(self):
        """get_llm_description() reports the active model and its token limit."""
        bot = AnthropicTelegramInferenceBot(small_model_name="claude-test", small_model_max_tokens=500)
        self.assertEqual(bot.get_llm_description(), "LLM: claude-test, Max Tokens: 500")

    async def test_switch_model(self):
        """switch_model() toggles small -> large -> small and reports the new model."""
        bot = AnthropicTelegramInferenceBot(
            small_model_name="claude-small", small_model_max_tokens=10,
            large_model_name="claude-large", large_model_max_tokens=20
        )
        self.assertEqual(bot.model, "claude-small")
        self.assertEqual(bot.max_tokens, 10)

        status = await bot.switch_model()
        self.assertEqual(bot.model, "claude-large")
        self.assertEqual(bot.max_tokens, 20)
        self.assertEqual(status, "Switched to model: claude-large")

        status = await bot.switch_model()
        self.assertEqual(bot.model, "claude-small")
        self.assertEqual(bot.max_tokens, 10)
        self.assertEqual(status, "Switched to model: claude-small")

    def test_get_chat_response_success_text_only(self):
        """With no tools, the API is called with tools=None and tool_choice=None."""
        bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
        bot.model = "test-claude"
        bot.max_tokens = 150
        mock_api_response = create_mock_anthropic_response(content_text="Hello from Anthropic API")
        self.mock_anthropic_client_instance.messages.create.return_value = mock_api_response

        messages = [{"role": "user", "content": "Hi"}]  # Anthropic format
        response = bot.get_chat_response(messages, [])  # tools = empty list

        self.mock_anthropic_client_instance.messages.create.assert_called_once_with(
            model="test-claude",
            max_tokens=150,
            messages=messages,
            system=bot.system_prompt,  # Ensure system prompt is passed
            tools=None,  # No tools passed to API if empty list or None
            tool_choice=None
        )
        self.assertEqual(response, mock_api_response)

    def test_get_chat_response_with_tools(self):
        """Tool specs are forwarded to the API with tool_choice={"type": "auto"}."""
        bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
        bot.model = "claude-toolmaster"
        bot.max_tokens = 300
        mock_tools_spec = [{"name": "get_weather", "description": "Gets weather", "input_schema": {"type": "object", "properties": {}}}]
        mock_api_response = create_mock_anthropic_response(content_text="Thinking...", tool_use_parts=[
            {"id": "tool1", "name": "get_weather", "input": {"location": "here"}}
        ])
        self.mock_anthropic_client_instance.messages.create.return_value = mock_api_response

        messages = [{"role": "user", "content": "Weather?"}]
        response = bot.get_chat_response(messages, mock_tools_spec)

        self.mock_anthropic_client_instance.messages.create.assert_called_once_with(
            model="claude-toolmaster",
            max_tokens=300,
            messages=messages,
            system=bot.system_prompt,
            tools=mock_tools_spec,
            tool_choice={"type": "auto"}
        )
        self.assertEqual(response.content[0].type, "text")  # First part can be text
        self.assertEqual(response.content[1].type, "tool_use")

    def test_get_chat_response_api_error(self):
        """Exceptions raised by the client propagate out of get_chat_response()."""
        bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
        self.mock_anthropic_client_instance.messages.create.side_effect = Exception("Anthropic API Down")
        with self.assertRaisesRegex(Exception, "Anthropic API Down"):
            bot.get_chat_response([{"role": "user", "content": "trigger"}], [])

    async def test_handle_message_simple_response_no_tools(self):
        """A plain text reply is returned and recorded as [user, assistant]."""
        bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
        bot.system_prompt = "System prompt for Anthropic"

        # Only the API client is mocked; handle_message and get_chat_response
        # run for real so their interaction is covered.
        api_response = create_mock_anthropic_response(content_text="Anthropic says hello.")
        self.mock_anthropic_client_instance.messages.create.return_value = api_response

        # Ensure functions are empty for this test, so no tool logic is triggered
        bot.functions = []
        bot.tools = []

        response_content = await bot.handle_message(user_id=101, user_message="Hello Anthropic")

        self.assertEqual(response_content, "Anthropic says hello.")
        self.assertIn(101, bot.conversation_history)

        # Flow this test expects on a first turn (the bot implementation is not
        # part of this file - verify against it if these assertions fail):
        #   1. The API receives only the user message; the system prompt is
        #      passed through the separate `system` keyword, not as a message.
        #   2. The text of the reply is extracted and returned.
        #   3. The stored history holds the user message followed by the
        #      assistant's text reply.
        history = bot.conversation_history[101]
        self.assertEqual(len(history), 2)  # User, Assistant
        self.assertEqual(history[0]["role"], "user")
        self.assertEqual(history[0]["content"], "Hello Anthropic")
        self.assertEqual(history[1]["role"], "assistant")
        self.assertEqual(history[1]["content"], "Anthropic says hello.")

        # Check the API call made on the mocked client
        self.mock_anthropic_client_instance.messages.create.assert_called_once()
        call_args = self.mock_anthropic_client_instance.messages.create.call_args
        self.assertEqual(call_args.kwargs["system"], "System prompt for Anthropic")
        # Initial messages for API should just be the user message for first turn
        self.assertEqual(call_args.kwargs["messages"], [{"role": "user", "content": "Hello Anthropic"}])

    async def test_handle_message_with_tool_calls(self):
        """A tool_use reply triggers call_tool, then a second API call for the answer."""
        bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
        bot.system_prompt = "You are a helpful, tool-using assistant."

        # Define a tool for the bot (OpenAI format, will be converted by Anthropic bot for API)
        mock_tool_oai_format = {"type": "function", "function": {"name": "get_weather", "description": "Get weather", "parameters": {}}}
        bot.functions = [mock_tool_oai_format]  # This is used to generate anthropic_tools for API

        # API Response 1: Request for tool call
        tool_use_part = {"id": "toolu_xyz", "name": "get_weather", "input": {"location": "paris"}}
        api_response_1 = create_mock_anthropic_response(tool_use_parts=[tool_use_part])
        # API Response 2: Final text response after tool execution
        api_response_2 = create_mock_anthropic_response(content_text="The weather in Paris is nice.")
        self.mock_anthropic_client_instance.messages.create.side_effect = [api_response_1, api_response_2]

        # Replace the bot's call_tool so no real tool is executed
        bot.call_tool = MagicMock(return_value='''{"weather": "sunny"}''')  # Tool execution result

        user_id = 102
        user_message = "What's the weather in Paris?"
        final_text_response = await bot.handle_message(user_id, user_message)

        self.assertEqual(final_text_response, "The weather in Paris is nice.")
        self.assertEqual(self.mock_anthropic_client_instance.messages.create.call_count, 2)
        bot.call_tool.assert_called_once_with("get_weather", {"location": "paris"})  # Anthropic passes input as dict

        # Expected history: user, assistant(tool_use), tool(result), assistant(text)
        history = bot.conversation_history[user_id]
        self.assertEqual(history[0]["role"], "user")
        self.assertEqual(history[0]["content"], user_message)

        # Assistant message that requested the tool call: content is a list of blocks
        self.assertEqual(history[1]["role"], "assistant")
        self.assertIsInstance(history[1]["content"], list)
        self.assertEqual(history[1]["content"][0]["type"], "tool_use")
        self.assertEqual(history[1]["content"][0]["id"], "toolu_xyz")

        self.assertEqual(history[2]["role"], "tool")
        self.assertEqual(history[2]["tool_call_id"], "toolu_xyz")
        self.assertEqual(history[2]["name"], "get_weather")
        self.assertEqual(history[2]["content"], '''{"weather": "sunny"}''')  # call_tool result

        self.assertEqual(history[3]["role"], "assistant")  # Final text response
        self.assertIsInstance(history[3]["content"], str)  # simple text
        self.assertEqual(history[3]["content"], "The weather in Paris is nice.")
# Allow this test module to be executed directly as a script.
if __name__ == '__main__':
    unittest.main()
-310
View File
@@ -1,310 +0,0 @@
import unittest
from unittest.mock import patch, mock_open, MagicMock
import os
import json
# Ensure the path includes the directory where base_telegram_inference_bot is located
# This might require adjustment based on actual project structure if tests are run from root
# For now, assuming it can be imported directly or via PYTHONPATH
from base_telegram_inference_bot import BaseTelegramInferenceBot
from tools.base_tool import BaseTool # For mocking tool structure
# Create a concrete subclass for testing, as BaseTelegramInferenceBot is abstract
class ConcreteTestBot(BaseTelegramInferenceBot):
    """Minimal concrete subclass used to instantiate the base bot in tests.

    Tool discovery is replaced by the `mock_tools` / `mock_functions` passed to
    the constructor; the remaining overrides are no-op stubs.
    """

    def __init__(self, system_prompt_content=None, system_prompt_path=None, mock_tools=None, mock_functions=None):
        # Store the injected mocks *before* delegating to the base constructor,
        # so load_functions() can already return them if it is called from
        # there (presumably it is - confirm against the base class).
        self._mock_tools = [] if mock_tools is None else mock_tools
        self._mock_functions = [] if mock_functions is None else mock_functions
        super().__init__(
            system_prompt_content=system_prompt_content,
            system_prompt_path=system_prompt_path,
        )

    def load_functions(self):
        """Return the injected (tools, functions) pair instead of loading real tools."""
        return self._mock_tools, self._mock_functions

    def get_chat_response(self, messages):
        """Stub: does nothing and returns None; not tested here directly."""

    async def handle_message(self, user_id, user_message):
        """Stub: does nothing and returns None."""

    def get_llm_description(self) -> str:
        """Return a fixed description so get_bot_status can be tested."""
        return "Mock LLM Description"

    async def start(self):
        """Stub: does nothing and returns None."""

    async def abort_processing(self, user_id):
        """Stub: does nothing and returns None."""

    async def switch_model(self):
        """Stub: does nothing and returns None."""
class TestBaseTelegramInferenceBot(unittest.IsolatedAsyncioTestCase):
    """Unit tests for the behaviour implemented by BaseTelegramInferenceBot.

    The class derives from IsolatedAsyncioTestCase because it contains an
    ``async def`` test: under a plain TestCase the coroutine would be created
    but never awaited, so its assertions would silently never run. Synchronous
    tests are unaffected by the change of base class.
    """

    def setUp(self):
        """Remember SYSTEM_PROMPT_PATH and remove it so every test starts clean."""
        self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH")
        if "SYSTEM_PROMPT_PATH" in os.environ:
            del os.environ["SYSTEM_PROMPT_PATH"]

    def tearDown(self):
        """Restore SYSTEM_PROMPT_PATH to the value it had before the test."""
        if self.original_system_prompt_path:
            os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path
        elif "SYSTEM_PROMPT_PATH" in os.environ:
            # The test set it and it was not there before: remove it again.
            del os.environ["SYSTEM_PROMPT_PATH"]

    # ------------------------------------------------------------------
    # System prompt resolution
    # ------------------------------------------------------------------

    def test_init_with_direct_system_prompt(self):
        """Prompt text passed directly is used as the system prompt."""
        bot = ConcreteTestBot(system_prompt_content="Direct prompt content")
        self.assertEqual(bot.system_prompt, "Direct prompt content")

    @patch("os.path.isfile")
    @patch("builtins.open", new_callable=mock_open, read_data="File prompt content")
    def test_init_with_system_prompt_path_argument(self, mock_file_open, mock_isfile):
        """A path argument causes the prompt to be read from that file."""
        mock_isfile.return_value = True
        bot = ConcreteTestBot(system_prompt_path="dummy/path.txt")
        self.assertEqual(bot.system_prompt, "File prompt content")
        mock_isfile.assert_called_once_with("dummy/path.txt")
        mock_file_open.assert_called_once_with("dummy/path.txt", "r", encoding="utf-8")

    @patch("os.path.isfile")
    @patch("builtins.open", new_callable=mock_open, read_data="Env prompt content")
    def test_init_with_env_system_prompt_path(self, mock_file_open, mock_isfile):
        """With no arguments, SYSTEM_PROMPT_PATH selects the prompt file."""
        mock_isfile.return_value = True
        os.environ["SYSTEM_PROMPT_PATH"] = "env/path.txt"
        bot = ConcreteTestBot()
        self.assertEqual(bot.system_prompt, "Env prompt content")
        mock_isfile.assert_called_once_with("env/path.txt")
        mock_file_open.assert_called_once_with("env/path.txt", "r", encoding="utf-8")

    def test_init_with_default_system_prompt(self):
        """With no prompt source at all, the built-in default prompt is used."""
        # Ensure ENV var is not set for this test
        if "SYSTEM_PROMPT_PATH" in os.environ:
            del os.environ["SYSTEM_PROMPT_PATH"]
        bot = ConcreteTestBot()
        self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.")

    @patch("os.path.isfile", return_value=False)
    def test_init_with_invalid_system_prompt_path(self, mock_isfile):
        """A path that is not a file falls back to the default prompt."""
        bot = ConcreteTestBot(system_prompt_path="invalid/path.txt")
        self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.")
        mock_isfile.assert_called_once_with("invalid/path.txt")

    @patch("os.path.isfile")
    @patch("builtins.open", side_effect=IOError("File read error"))
    def test_init_with_system_prompt_file_read_error(self, mock_file_open, mock_isfile):
        """An IOError while reading the prompt file falls back to the default."""
        mock_isfile.return_value = True
        bot = ConcreteTestBot(system_prompt_path="dummy/path.txt")
        self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.")

    # ------------------------------------------------------------------
    # Conversation history and processing status
    # ------------------------------------------------------------------

    def test_clear_conversation_history(self):
        """Clearing removes the user's history and calls clear() on each tool."""
        mock_tool_instance = MagicMock(spec=BaseTool)
        bot = ConcreteTestBot(mock_tools=[mock_tool_instance])
        bot.conversation_history[123] = [{"role": "user", "content": "Hello"}]
        bot.clear_conversation_history(123)
        self.assertNotIn(123, bot.conversation_history)
        mock_tool_instance.clear.assert_called_once()

    def test_clear_conversation_history_user_not_found(self):
        """Clearing an unknown user does not fail and still clears the tools."""
        mock_tool_instance = MagicMock(spec=BaseTool)
        bot = ConcreteTestBot(mock_tools=[mock_tool_instance])
        bot.clear_conversation_history(404)
        self.assertNotIn(404, bot.conversation_history)
        mock_tool_instance.clear.assert_called_once()

    def test_processing_status(self):
        """Status can be set for a user and cleared again."""
        bot = ConcreteTestBot()
        self.assertEqual(bot.processing_status, {})
        bot.set_processing_status(123, 789)
        self.assertEqual(bot.processing_status[123], {"processing": True, "message_id": 789})
        bot.clear_processing_status(123)
        self.assertNotIn(123, bot.processing_status)

    def test_clear_processing_status_user_not_found(self):
        """Clearing the status of an unknown user does not fail."""
        bot = ConcreteTestBot()
        bot.clear_processing_status(404)
        self.assertNotIn(404, bot.processing_status)

    # ------------------------------------------------------------------
    # call_tool
    # ------------------------------------------------------------------

    def test_call_tool_success_dict_args(self):
        """Dict arguments are forwarded to the tool as keyword arguments."""
        mock_tool = MagicMock(spec=BaseTool)
        mock_tool.get_functions.return_value = [
            {"function": {"name": "test_tool", "description": "A test tool", "parameters": {}}}
        ]
        mock_tool.execute.return_value = "Tool executed successfully"
        bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions())
        result = bot.call_tool("test_tool", {"arg1": "value1"})
        self.assertEqual(result, "Tool executed successfully")
        mock_tool.execute.assert_called_once_with("test_tool", arg1="value1")

    def test_call_tool_success_json_string_args(self):
        """A JSON string is decoded and forwarded as keyword arguments."""
        mock_tool = MagicMock(spec=BaseTool)
        mock_tool.get_functions.return_value = [
            {"function": {"name": "test_tool_json", "parameters": {}}}
        ]
        mock_tool.execute.return_value = "Tool JSON OK"
        bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions())
        args_json_str = '''{"param": "value"}'''
        result = bot.call_tool("test_tool_json", args_json_str)
        self.assertEqual(result, "Tool JSON OK")
        mock_tool.execute.assert_called_once_with("test_tool_json", param="value")

    def test_call_tool_malformed_json_string_args(self):
        """Malformed JSON yields an error string rather than an exception."""
        bot = ConcreteTestBot(mock_tools=[])
        args_malformed_json_str = '''{"param": "value"'''
        result = bot.call_tool("some_tool", args_malformed_json_str)
        self.assertIn("Error: Malformed arguments for tool call", result)

    def test_call_tool_unexpected_arg_type(self):
        """Arguments that are neither dict nor str yield an error string."""
        bot = ConcreteTestBot(mock_tools=[])
        result = bot.call_tool("some_tool", 12345)  # Integer instead of dict/str
        self.assertIn("Error: Invalid argument type for tool call", result)

    def test_call_tool_none_args(self):
        """None as arguments results in a call without keyword arguments."""
        mock_tool = MagicMock(spec=BaseTool)
        mock_tool.get_functions.return_value = [
            {"function": {"name": "test_tool_none", "parameters": {}}}
        ]
        mock_tool.execute.return_value = "Tool None OK"
        bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions())
        result = bot.call_tool("test_tool_none", None)
        self.assertEqual(result, "Tool None OK")
        mock_tool.execute.assert_called_once_with("test_tool_none")  # No kwargs if None

    def test_call_tool_not_found(self):
        """An unknown function name yields a 'not found' error string."""
        bot = ConcreteTestBot(mock_tools=[])
        result = bot.call_tool("non_existent_tool", {})
        self.assertEqual(result, "Error: Tool function non_existent_tool not found.")

    def test_call_tool_execute_exception(self):
        """An exception raised by the tool is reported as an error string."""
        mock_tool = MagicMock(spec=BaseTool)
        mock_tool.get_functions.return_value = [{"function": {"name": "error_tool", "parameters": {}}}]
        mock_tool.execute.side_effect = Exception("Execution failed")
        bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions())
        result = bot.call_tool("error_tool", {})
        self.assertEqual(result, "Error executing tool error_tool: Execution failed")

    # ------------------------------------------------------------------
    # Status descriptions
    # ------------------------------------------------------------------

    def test_get_system_prompt_description(self):
        """The description distinguishes default, custom and ENV-configured prompts."""
        if "SYSTEM_PROMPT_PATH" in os.environ:  # Ensure clean state
            del os.environ["SYSTEM_PROMPT_PATH"]

        bot_default = ConcreteTestBot()
        self.assertEqual(bot_default.get_system_prompt_description(), "System Prompt: Default")

        bot_custom_content = ConcreteTestBot(system_prompt_content="Custom content here")
        self.assertEqual(bot_custom_content.get_system_prompt_description(), "System Prompt: Custom")

        os.environ["SYSTEM_PROMPT_PATH"] = "some/path.txt"
        bot_env_default_prompt = ConcreteTestBot()  # system_prompt itself is default
        self.assertEqual(bot_env_default_prompt.get_system_prompt_description(), "System Prompt: Custom (via ENV)")

        with patch("os.path.isfile", return_value=True), \
                patch("builtins.open", mock_open(read_data="File prompt from ENV")):
            bot_env_file_prompt = ConcreteTestBot()  # system_prompt gets loaded from ENV path
            self.assertEqual(bot_env_file_prompt.get_system_prompt_description(), "System Prompt: Custom")
        del os.environ["SYSTEM_PROMPT_PATH"]

        with patch("os.path.isfile", return_value=True), \
                patch("builtins.open", mock_open(read_data="File prompt from arg")):
            bot_custom_file_arg = ConcreteTestBot(system_prompt_path="custom/file.txt")
            self.assertEqual(bot_custom_file_arg.get_system_prompt_description(), "System Prompt: Custom")

    @patch.object(ConcreteTestBot, 'get_llm_description', return_value="Test LLM Description")
    @patch.object(ConcreteTestBot, 'get_system_prompt_description', return_value="Test Prompt Description")
    async def test_get_bot_status(self, mock_prompt_desc, mock_llm_desc):
        """get_bot_status() joins the prompt and LLM descriptions with a newline."""
        bot = ConcreteTestBot()
        status = await bot.get_bot_status()
        self.assertEqual(status, "Test Prompt Description\nTest LLM Description")
        mock_prompt_desc.assert_called_once()
        mock_llm_desc.assert_called_once()

    # ------------------------------------------------------------------
    # load_functions (tool discovery)
    # ------------------------------------------------------------------

    @patch('os.path.dirname', return_value='/mock/path')
    @patch('os.path.join')
    @patch('os.path.exists')
    @patch('os.listdir')
    @patch('importlib.import_module')
    def test_load_functions_no_tools_dir(self, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname):
        """When the tools directory does not exist nothing is listed or loaded."""
        mock_join.return_value = '/mock/path/tools'
        mock_exists.return_value = False

        class BotForLoadTest(BaseTelegramInferenceBot):
            load_system_prompt = MagicMock(return_value="Default")
            get_chat_response = MagicMock()
            handle_message = MagicMock()
            get_llm_description = MagicMock(return_value="mock")
            start = MagicMock()
            abort_processing = MagicMock()
            switch_model = MagicMock()

        bot = BotForLoadTest()
        self.assertEqual(bot.tools, [])
        self.assertEqual(bot.functions, [])
        mock_listdir.assert_not_called()

    @patch('os.path.dirname', return_value='/mock/base_bot_dir')
    @patch(
        'os.path.join',
        # The real join is bound as a default argument: looking up
        # os.path.join inside the lambda at call time would find this very
        # mock and recurse until RecursionError.
        side_effect=lambda *args, _real_join=os.path.join: os.path.normpath(_real_join(*args)),
    )
    @patch('os.path.exists', return_value=True)
    @patch('os.listdir', return_value=['my_tool.py', '__init__.py', 'base_tool.py'])
    @patch('importlib.import_module')
    def test_load_functions_with_one_tool(self, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname):
        """A single tool module yields one tool instance and its functions."""
        mock_tool_class = MagicMock(spec=BaseTool)  # This is the class itself
        mock_tool_instance = MagicMock(spec=BaseTool)  # This is the instance
        mock_tool_class.return_value = mock_tool_instance  # mock_tool_class() creates mock_tool_instance
        mock_tool_instance.get_functions.return_value = [{"function": {"name": "sample_function"}}]

        mock_my_tool_module = MagicMock()
        # Module members the loader will inspect: only the tool class is
        # expected to be picked up; the other two are expected to be skipped.
        mock_my_tool_module.ValidToolClass = mock_tool_class
        mock_my_tool_module.NotATool = object()
        mock_my_tool_module.BaseTool = BaseTool  # This should be skipped by the loader

        def import_side_effect(module_name):
            if module_name == 'tools.my_tool':
                return mock_my_tool_module
            raise ImportError(f"Unexpected import: {module_name}")

        mock_import_module.side_effect = import_side_effect

        class BotForLoadTest(BaseTelegramInferenceBot):
            load_system_prompt = MagicMock(return_value="Default")
            get_chat_response = MagicMock()
            handle_message = MagicMock()
            get_llm_description = MagicMock(return_value="mock")
            start = MagicMock()
            abort_processing = MagicMock()
            switch_model = MagicMock()

        bot = BotForLoadTest()
        self.assertEqual(len(bot.tools), 1)
        self.assertIs(bot.tools[0], mock_tool_instance)
        self.assertEqual(len(bot.functions), 1)
        self.assertEqual(bot.functions[0]['function']['name'], "sample_function")
        mock_import_module.assert_called_once_with('tools.my_tool')
        mock_tool_class.assert_called_once_with()  # Tool class was instantiated
        mock_tool_instance.get_functions.assert_called_once_with()

    @patch('os.path.dirname', return_value='/mock/base_bot_dir')
    @patch(
        'os.path.join',
        # Real join bound as a default to avoid calling the mock recursively.
        side_effect=lambda *args, _real_join=os.path.join: os.path.normpath(_real_join(*args)),
    )
    @patch('os.path.exists', return_value=True)
    @patch('os.listdir', return_value=['tool_with_init_error.py'])
    @patch('importlib.import_module')
    @patch('logging.error')  # Mock logging to check for error messages
    def test_load_functions_tool_instantiation_error(self, mock_logging_error, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname):
        """A tool whose constructor raises is skipped and the error is logged."""
        mock_tool_class_init_error = MagicMock(spec=BaseTool)
        mock_tool_class_init_error.side_effect = Exception("Failed to init tool")  # Error on instantiation

        mock_error_tool_module = MagicMock()
        mock_error_tool_module.ToolWithInitError = mock_tool_class_init_error
        mock_import_module.return_value = mock_error_tool_module

        class BotForLoadTest(BaseTelegramInferenceBot):
            load_system_prompt = MagicMock(return_value="Default")
            get_chat_response = MagicMock()
            handle_message = MagicMock()
            get_llm_description = MagicMock(return_value="mock")
            start = MagicMock()
            abort_processing = MagicMock()
            switch_model = MagicMock()

        bot = BotForLoadTest()
        self.assertEqual(len(bot.tools), 0)
        self.assertEqual(len(bot.functions), 0)
        mock_logging_error.assert_any_call("Error instantiating tool ToolWithInitError from tool_with_init_error.py: Failed to init tool")
# Allow this test module to be executed directly as a script.
# unittest.main() takes no positional argument here; the mis-encoded text that
# had crept into the call was an undefined name and raised NameError.
if __name__ == '__main__':
    unittest.main()
@@ -1,158 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch, ANY
import os
# Assuming chatgpt_telegram_inference_bot.py and its parent are accessible
from chatgpt_telegram_inference_bot import ChatGPTTelegramInferenceBot
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot # For patching super
class TestChatGPTTelegramInferenceBot(unittest.IsolatedAsyncioTestCase):
def setUp(self):
# Store and clear relevant environment variables
self.original_openai_key = os.environ.get("OPENAI_API_KEY")
self.original_small_model = os.environ.get("OPENAI_SMALL_MODEL")
self.original_large_model = os.environ.get("OPENAI_LARGE_MODEL")
self.original_small_tokens = os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS")
self.original_large_tokens = os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS")
self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH")
for key in ["OPENAI_API_KEY", "OPENAI_SMALL_MODEL", "OPENAI_LARGE_MODEL",
"OPENAI_SMALL_MODEL_MAX_TOKENS", "OPENAI_LARGE_MODEL_MAX_TOKENS", "SYSTEM_PROMPT_PATH"]:
if os.environ.get(key):
del os.environ[key]
# Mock the OpenAI client that OpenAICompatibleInferenceBot's __init__ might create
self.mock_openai_client = MagicMock()
def tearDown(self):
# Restore environment variables
if self.original_openai_key: os.environ["OPENAI_API_KEY"] = self.original_openai_key
if self.original_small_model: os.environ["OPENAI_SMALL_MODEL"] = self.original_small_model
if self.original_large_model: os.environ["OPENAI_LARGE_MODEL"] = self.original_large_model
if self.original_small_tokens: os.environ["OPENAI_SMALL_MODEL_MAX_TOKENS"] = self.original_small_tokens
if self.original_large_tokens: os.environ["OPENAI_LARGE_MODEL_MAX_TOKENS"] = self.original_large_tokens
if self.original_system_prompt_path: os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path
@patch.object(OpenAICompatibleInferenceBot, '__init__') # Mock the superclass's __init__
def test_init_defaults_and_super_call(self, mock_super_init):
os.environ["OPENAI_API_KEY"] = "test_key_chatgpt"
os.environ["OPENAI_SMALL_MODEL"] = "gpt-3.5-turbo-env"
os.environ["OPENAI_SMALL_MODEL_MAX_TOKENS"] = "350"
bot = ChatGPTTelegramInferenceBot()
mock_super_init.assert_called_once_with(
client=None, # ChatGPT bot will let superclass create it
api_key="test_key_chatgpt", # Passed to super
base_url=None,
api_version=None,
azure_deployment=None,
model_name="gpt-3.5-turbo-env", # Default small model from env
max_tokens_str="350", # Default small model tokens from env
small_model_name="gpt-3.5-turbo-env",
small_model_max_tokens_str="350",
large_model_name=os.environ.get("OPENAI_LARGE_MODEL", "gpt-4-turbo-preview"), # Default large
large_model_max_tokens_str=os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS"),
system_prompt_content=None,
system_prompt_path=None,
is_gemini=False,
max_history_length=20 # Default from OpenAICompatibleInferenceBot
)
@patch.object(OpenAICompatibleInferenceBot, '__init__')
def test_init_with_arguments(self, mock_super_init):
mock_client_arg = MagicMock()
bot = ChatGPTTelegramInferenceBot(
openai_client=mock_client_arg,
api_key="arg_key",
small_model_name="arg_small_model",
small_model_max_tokens="123",
large_model_name="arg_large_model",
large_model_max_tokens="456",
system_prompt_content="Arg prompt"
)
mock_super_init.assert_called_once_with(
client=mock_client_arg,
api_key="arg_key",
base_url=None,
api_version=None,
azure_deployment=None,
model_name="arg_small_model", # Initially configured with small model
max_tokens_str="123",
small_model_name="arg_small_model",
small_model_max_tokens_str="123",
large_model_name="arg_large_model",
large_model_max_tokens_str="456",
system_prompt_content="Arg prompt",
system_prompt_path=None,
is_gemini=False,
max_history_length=20
)
# Test switch_model - this method is part of ChatGPTTelegramInferenceBot
# It calls _configure_model_and_tokens which is in the superclass.
# We need a bot instance where _configure_model_and_tokens can be called.
@patch('openai.OpenAI')  # Mock client creation so the bot can be instantiated
async def test_switch_model_logic(self, mock_openai_constructor):
    # A bot built without arguments is expected to start on the small model
    # named by the OPENAI_* environment variables, and switch_model() must
    # toggle small -> large -> small, reporting the new model each time.
    mock_openai_constructor.return_value = self.mock_openai_client
    env_overrides = {
        "OPENAI_SMALL_MODEL": "env-small-gpt",
        "OPENAI_SMALL_MODEL_MAX_TOKENS": "100",
        "OPENAI_LARGE_MODEL": "env-large-gpt",
        "OPENAI_LARGE_MODEL_MAX_TOKENS": "200",
    }
    # patch.dict restores os.environ when the block exits (even on failure),
    # so these values cannot leak into other tests.
    with patch.dict(os.environ, env_overrides):
        bot = ChatGPTTelegramInferenceBot()
        self.assertEqual(bot.model, "env-small-gpt")
        self.assertEqual(bot.max_tokens, 100)

        # First switch goes to the large model, second switch returns to small.
        for expected_model, expected_tokens in (("env-large-gpt", 200), ("env-small-gpt", 100)):
            status = await bot.switch_model()
            self.assertEqual(bot.model, expected_model)
            self.assertEqual(bot.max_tokens, expected_tokens)
            self.assertEqual(status, f"Switched to model: {expected_model}")
@patch('openai.OpenAI')
async def test_switch_model_uses_instance_configs_if_provided(self, mock_openai_constructor):
    # Model names/token limits given to the constructor are the ones
    # switch_model() must toggle between.
    mock_openai_constructor.return_value = self.mock_openai_client
    bot = ChatGPTTelegramInferenceBot(
        small_model_name="init-small",
        small_model_max_tokens="50",
        large_model_name="init-large",
        large_model_max_tokens="150",
    )

    # Starts on the small model
    self.assertEqual(bot.model, "init-small")
    self.assertEqual(bot.max_tokens, 50)

    # One switch to large, then one switch back to small
    toggle_sequence = (("init-large", 150), ("init-small", 50))
    for model_after_switch, tokens_after_switch in toggle_sequence:
        status = await bot.switch_model()
        self.assertEqual(bot.model, model_after_switch)
        self.assertEqual(bot.max_tokens, tokens_after_switch)
        self.assertEqual(status, "Switched to model: " + model_after_switch)
# get_llm_description is inherited from OpenAICompatibleInferenceBot.
# Test just to ensure it works in the context of a ChatGPTBot instance
@patch('openai.OpenAI')
def test_get_llm_description_for_chatgpt_bot(self, mock_openai_constructor):
    # Smoke test: the inherited get_llm_description() reports the small
    # model the bot was initially configured with.
    mock_openai_constructor.return_value = self.mock_openai_client
    small_model_config = {
        "small_model_name": "gpt-3.5-desc",
        "small_model_max_tokens": "777",
    }
    bot = ChatGPTTelegramInferenceBot(**small_model_config)

    description = bot.get_llm_description()

    self.assertEqual(description, "LLM: gpt-3.5-desc, Max Tokens: 777, Azure: False")
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
-154
View File
@@ -1,154 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch, ANY
import os
# Assuming gemini_telegram_inference_bot.py and its parent are accessible
from gemini_telegram_inference_bot import GeminiTelegramInferenceBot
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot # For patching super
class TestGeminiTelegramInferenceBot(unittest.IsolatedAsyncioTestCase):
    """Tests for GeminiTelegramInferenceBot construction and model switching.

    The superclass __init__ or the ``openai.OpenAI`` constructor is mocked in
    every test, so no real client is created.
    """

    # Environment variables removed before every test so each test only sees
    # the values it sets itself.
    _ENV_KEYS = (
        "GEMINI_API_KEY",
        "GEMINI_API_BASE_URL",
        "GEMINI_SMALL_MODEL",
        "GEMINI_LARGE_MODEL",
        "GEMINI_SMALL_MODEL_MAX_TOKENS",
        "GEMINI_LARGE_MODEL_MAX_TOKENS",
        "SYSTEM_PROMPT_PATH",
    )

    def setUp(self):
        # Snapshot os.environ and restore it after the test. Unlike a manual
        # save/restore of individual keys, this also removes variables that a
        # test created and correctly restores empty-string values.
        env_snapshot = patch.dict(os.environ)
        env_snapshot.start()
        self.addCleanup(env_snapshot.stop)
        for key in self._ENV_KEYS:
            os.environ.pop(key, None)

        self.mock_openai_client = MagicMock()  # Used if superclass creates an OpenAI client

    @patch.object(OpenAICompatibleInferenceBot, '__init__')  # Mock the superclass's __init__
    def test_init_defaults_and_super_call(self, mock_super_init):
        # With no constructor arguments, configuration comes from GEMINI_* env vars.
        os.environ["GEMINI_API_KEY"] = "test_key_gemini"
        os.environ["GEMINI_API_BASE_URL"] = "https://gemini.env.com"
        os.environ["GEMINI_SMALL_MODEL"] = "gemini-pro-env"
        os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = "360"

        GeminiTelegramInferenceBot()

        mock_super_init.assert_called_once_with(
            client=None,
            api_key="test_key_gemini",
            base_url="https://gemini.env.com",  # Passed to super
            api_version=None,
            azure_deployment=None,
            model_name="gemini-pro-env",
            max_tokens_str="360",
            small_model_name="gemini-pro-env",
            small_model_max_tokens_str="360",
            large_model_name=os.environ.get("GEMINI_LARGE_MODEL", "gemini-1.5-pro-latest"),  # Default large
            large_model_max_tokens_str=os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS"),
            system_prompt_content=None,
            system_prompt_path=None,
            is_gemini=True,  # Important for Gemini bot
            max_history_length=20,
        )

    @patch.object(OpenAICompatibleInferenceBot, '__init__')
    def test_init_with_arguments(self, mock_super_init):
        # Explicit constructor arguments are forwarded to the superclass.
        mock_client_arg = MagicMock()

        GeminiTelegramInferenceBot(
            openai_client=mock_client_arg,  # Name in Gemini bot is openai_client for consistency
            api_key="arg_gem_key",
            base_url="https://arg.gemini.com",
            small_model_name="arg_gem_small",
            small_model_max_tokens="124",
            large_model_name="arg_gem_large",
            large_model_max_tokens="457",
            system_prompt_content="Gemini prompt",
        )

        mock_super_init.assert_called_once_with(
            client=mock_client_arg,
            api_key="arg_gem_key",
            base_url="https://arg.gemini.com",
            api_version=None,
            azure_deployment=None,
            model_name="arg_gem_small",
            max_tokens_str="124",
            small_model_name="arg_gem_small",
            small_model_max_tokens_str="124",
            large_model_name="arg_gem_large",
            large_model_max_tokens_str="457",
            system_prompt_content="Gemini prompt",
            system_prompt_path=None,
            is_gemini=True,
            max_history_length=20,
        )

    @patch('openai.OpenAI')  # Gemini bot uses OpenAI client configured for Gemini endpoint
    async def test_switch_model_logic(self, mock_openai_constructor):
        mock_openai_constructor.return_value = self.mock_openai_client
        # These assignments are rolled back by the snapshot taken in setUp.
        os.environ["GEMINI_SMALL_MODEL"] = "env-gemini-small"
        os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = "110"
        os.environ["GEMINI_LARGE_MODEL"] = "env-gemini-large"
        os.environ["GEMINI_LARGE_MODEL_MAX_TOKENS"] = "220"

        bot = GeminiTelegramInferenceBot()  # Uses env vars by default
        self.assertEqual(bot.model, "env-gemini-small")
        self.assertEqual(bot.max_tokens, 110)

        status = await bot.switch_model()
        self.assertEqual(bot.model, "env-gemini-large")
        self.assertEqual(bot.max_tokens, 220)
        self.assertEqual(status, "Switched to model: env-gemini-large")

        status = await bot.switch_model()
        self.assertEqual(bot.model, "env-gemini-small")
        self.assertEqual(bot.max_tokens, 110)
        self.assertEqual(status, "Switched to model: env-gemini-small")

    @patch('openai.OpenAI')
    async def test_switch_model_uses_instance_configs_if_provided(self, mock_openai_constructor):
        mock_openai_constructor.return_value = self.mock_openai_client
        bot = GeminiTelegramInferenceBot(
            small_model_name="init-gem-small", small_model_max_tokens="55",
            large_model_name="init-gem-large", large_model_max_tokens="155",
        )
        self.assertEqual(bot.model, "init-gem-small")
        self.assertEqual(bot.max_tokens, 55)

        status = await bot.switch_model()
        self.assertEqual(bot.model, "init-gem-large")
        self.assertEqual(bot.max_tokens, 155)
        self.assertEqual(status, "Switched to model: init-gem-large")

        status = await bot.switch_model()
        self.assertEqual(bot.model, "init-gem-small")
        self.assertEqual(bot.max_tokens, 55)
        self.assertEqual(status, "Switched to model: init-gem-small")

    @patch('openai.OpenAI')
    def test_get_llm_description_for_gemini_bot(self, mock_openai_constructor):
        mock_openai_constructor.return_value = self.mock_openai_client
        bot = GeminiTelegramInferenceBot(
            small_model_name="gemini-pro-desc",
            small_model_max_tokens="888",
            # is_gemini is True by default in constructor call to super
        )
        # The description should report "Azure: False" even though the bot is
        # built on the OpenAI-compatible base class.
        self.assertEqual(bot.get_llm_description(), "LLM: gemini-pro-desc, Max Tokens: 888, Azure: False")
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
-81
View File
@@ -1,81 +0,0 @@
# tests/test_github_tool.py
import unittest
from unittest.mock import patch, MagicMock
from tools.github_tool import GitHubTool
class TestGitHubTool(unittest.TestCase):
    """Exercises GitHubTool.get_functions() and execute() with `requests` mocked out."""

    @staticmethod
    def _http_response(status_code, payload=None):
        # Build a stand-in for a requests response; json() is only configured
        # when a payload is supplied.
        response = MagicMock()
        response.status_code = status_code
        if payload is not None:
            response.json.return_value = payload
        return response

    def setUp(self):
        self.github_tool = GitHubTool()

    def test_get_functions(self):
        advertised = self.github_tool.get_functions()
        self.assertEqual(len(advertised), 4)
        self.assertListEqual(
            [spec["name"] for spec in advertised],
            ["read_file", "create_branch", "commit_file", "create_pull_request"],
        )

    @patch('tools.github_tool.requests.get')
    def test_read_file(self, mock_get):
        mock_get.return_value = self._http_response(200, {"content": "file content"})

        outcome = self.github_tool.execute("read_file", path="test.txt")

        self.assertEqual(outcome, "file content")
        mock_get.assert_called_once()

    @patch('tools.github_tool.requests.get')
    @patch('tools.github_tool.requests.post')
    def test_create_branch(self, mock_post, mock_get):
        # GET returns the base ref's SHA, POST reports the created ref.
        mock_get.return_value = self._http_response(200, {"object": {"sha": "test_sha"}})
        mock_post.return_value = self._http_response(201)

        outcome = self.github_tool.execute("create_branch", branch_name="test-branch")

        self.assertEqual(outcome, "Branch 'test-branch' created successfully")
        mock_get.assert_called_once()
        mock_post.assert_called_once()

    @patch('tools.github_tool.requests.put')
    def test_commit_file(self, mock_put):
        mock_put.return_value = self._http_response(200)
        commit_args = {
            "branch_name": "test-branch",
            "file_path": "test.txt",
            "content": "test content",
            "commit_message": "Test commit",
        }

        outcome = self.github_tool.execute("commit_file", **commit_args)

        self.assertEqual(outcome, "File committed successfully to branch 'test-branch'")
        mock_put.assert_called_once()

    def test_commit_file_to_main(self):
        commit_args = {
            "branch_name": "main",
            "file_path": "test.txt",
            "content": "test content",
            "commit_message": "Test commit",
        }

        outcome = self.github_tool.execute("commit_file", **commit_args)

        self.assertEqual(outcome, "Cannot commit directly to main branch")

    @patch('tools.github_tool.requests.post')
    def test_create_pull_request(self, mock_post):
        mock_post.return_value = self._http_response(
            201, {"html_url": "https://github.com/test/test/pull/1"}
        )

        outcome = self.github_tool.execute(
            "create_pull_request", title="Test PR", body="Test body", head="test-branch"
        )

        self.assertEqual(outcome, "Pull request created successfully: https://github.com/test/test/pull/1")
        mock_post.assert_called_once()

    def test_unknown_function(self):
        outcome = self.github_tool.execute("unknown_function")
        self.assertEqual(outcome, "Unknown function: unknown_function")
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
@@ -1,332 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch, AsyncMock, ANY
import os
import json
# Assuming openai_compatible_inference_bot.py is in the parent directory or PYTHONPATH is set
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot
# Mock response from OpenAI client's chat.completions.create
def create_mock_openai_response(content=None, tool_calls=None):
    """Build a MagicMock shaped like a chat-completions response.

    ``tool_calls`` is a list of dicts with ``id`` and ``function`` (``name``,
    ``arguments``); each becomes a mock exposing those values as attributes.
    A missing or empty list leaves ``message.tool_calls`` as ``None``.
    The result carries a single choice whose ``message`` has role "assistant".
    """
    def _as_mock_call(spec):
        call = MagicMock()
        call.id = spec["id"]
        call.function.name = spec["function"]["name"]
        call.function.arguments = spec["function"]["arguments"]
        return call

    message = MagicMock()
    message.role = "assistant"
    message.content = content
    message.tool_calls = [_as_mock_call(spec) for spec in tool_calls] if tool_calls else None

    response = MagicMock()
    response.choices = [MagicMock(message=message)]
    return response
# Concrete class for testing
class ConcreteOpenAICompatibleBot(OpenAICompatibleInferenceBot):
    """Minimal concrete subclass so the abstract base can be instantiated in tests."""

    async def switch_model(self):
        # Toggle between the small and large model configuration.
        if self.model == self.small_model_name:
            target = (self.large_model_name, self.large_model_max_tokens_str)
        else:
            target = (self.small_model_name, self.small_model_max_tokens_str)
        self._configure_model_and_tokens(*target)
        return f"Switched to {self.model}"

    def load_functions(self):
        # Overridden so these tests run with no tools unless a test adds them.
        self.tools, self.functions = [], []
        return self.tools, self.functions
class TestOpenAICompatibleInferenceBot(unittest.IsolatedAsyncioTestCase):
    """Tests for OpenAICompatibleInferenceBot, exercised via ConcreteOpenAICompatibleBot."""

    # Environment variables removed before every test so each test only sees
    # the values it sets itself.
    _ENV_KEYS = (
        "OPENAI_API_KEY",
        "AZURE_OPENAI_KEY",
        "AZURE_OPENAI_ENDPOINT",
        "AZURE_OPENAI_API_VERSION",
        "AZURE_DEPLOYMENT_NAME",
        "SYSTEM_PROMPT_PATH",
    )

    @staticmethod
    def _field(message, name):
        # History entries may be plain dicts or message objects; read a field
        # from either form, returning None when it is absent.
        if isinstance(message, dict):
            return message.get(name)
        return getattr(message, name, None)

    def setUp(self):
        # Snapshot os.environ and restore it after the test. This restores
        # every cleared key (including SYSTEM_PROMPT_PATH) and removes any
        # variable a test created.
        env_snapshot = patch.dict(os.environ)
        env_snapshot.start()
        self.addCleanup(env_snapshot.stop)
        for key in self._ENV_KEYS:
            os.environ.pop(key, None)

        self.mock_openai_client_instance = MagicMock()
        self.mock_openai_client_instance.chat.completions.create = MagicMock()

    @patch('openai.OpenAI')
    def test_init_with_openai_defaults(self, MockOpenAIConstructor):
        MockOpenAIConstructor.return_value = self.mock_openai_client_instance
        os.environ["OPENAI_API_KEY"] = "test_openai_key"

        bot = ConcreteOpenAICompatibleBot(model_name="gpt-4")

        MockOpenAIConstructor.assert_called_once_with(api_key="test_openai_key", base_url=None)
        self.assertEqual(bot.client, self.mock_openai_client_instance)
        self.assertEqual(bot.model, "gpt-4")
        self.assertEqual(bot.max_tokens, 1000)  # Default from _configure_model_and_tokens
        self.assertEqual(bot.azure_openai, False)

    @patch('openai.OpenAI')
    def test_init_with_provided_client(self, MockOpenAIConstructor):
        preconfigured_client = MagicMock()

        bot = ConcreteOpenAICompatibleBot(client=preconfigured_client, model_name="gpt-3.5")

        MockOpenAIConstructor.assert_not_called()
        self.assertEqual(bot.client, preconfigured_client)
        self.assertEqual(bot.model, "gpt-3.5")

    @patch('openai.AzureOpenAI')
    def test_init_with_azure_config_args(self, MockAzureOpenAIConstructor):
        MockAzureOpenAIConstructor.return_value = self.mock_openai_client_instance

        bot = ConcreteOpenAICompatibleBot(
            api_key="azure_key",
            azure_endpoint="https://myenv.openai.azure.com",
            api_version="2023-05-15",
            azure_deployment="my-gpt-4",  # This should be used as model_name for API call
            # Expected to be superseded by azure_deployment (asserted below).
            model_name="should_be_overridden_by_azure_deployment_for_api",
        )

        MockAzureOpenAIConstructor.assert_called_once_with(
            api_key="azure_key",
            azure_endpoint="https://myenv.openai.azure.com",
            api_version="2023-05-15",
        )
        self.assertEqual(bot.client, self.mock_openai_client_instance)
        self.assertEqual(bot.model, "my-gpt-4")  # Azure deployment name becomes the model for API calls
        self.assertEqual(bot.azure_openai, True)

    @patch('openai.AzureOpenAI')
    def test_init_with_azure_env_vars(self, MockAzureOpenAIConstructor):
        MockAzureOpenAIConstructor.return_value = self.mock_openai_client_instance
        os.environ["AZURE_OPENAI_KEY"] = "env_azure_key"
        os.environ["AZURE_OPENAI_ENDPOINT"] = "https://env.openai.azure.com"
        os.environ["AZURE_OPENAI_API_VERSION"] = "2023-06-01"
        os.environ["AZURE_DEPLOYMENT_NAME"] = "env-gpt-35"  # Used as model_name

        bot = ConcreteOpenAICompatibleBot(model_name="ignored_if_azure_deployment_env_is_set")

        MockAzureOpenAIConstructor.assert_called_once_with(
            api_key="env_azure_key",
            azure_endpoint="https://env.openai.azure.com",
            api_version="2023-06-01",
        )
        self.assertEqual(bot.model, "env-gpt-35")
        self.assertTrue(bot.azure_openai)

    @patch('openai.OpenAI')
    def test_init_with_gemini_config_args(self, MockOpenAIConstructor):
        MockOpenAIConstructor.return_value = self.mock_openai_client_instance

        bot = ConcreteOpenAICompatibleBot(
            api_key="gemini_key",
            base_url="https://gemini.example.com",
            model_name="gemini-pro",
            is_gemini=True,
        )

        MockOpenAIConstructor.assert_called_once_with(api_key="gemini_key", base_url="https://gemini.example.com")
        self.assertEqual(bot.model, "gemini-pro")
        self.assertFalse(bot.azure_openai)  # is_gemini doesn't mean azure_openai

    def test_configure_model_and_tokens(self):
        bot = ConcreteOpenAICompatibleBot(model_name="initial_model")  # init calls _configure

        bot._configure_model_and_tokens("test-model", "500")
        self.assertEqual(bot.model, "test-model")
        self.assertEqual(bot.max_tokens, 500)

        bot._configure_model_and_tokens("test-model-2", None, default_max_tokens=150)
        self.assertEqual(bot.max_tokens, 150)

        bot._configure_model_and_tokens("test-model-3", "invalid_token_val")
        self.assertEqual(bot.max_tokens, 1000)  # Default fallback

    def test_get_llm_description(self):
        bot = ConcreteOpenAICompatibleBot(model_name="desc-model", max_tokens_str="256")
        self.assertEqual(bot.get_llm_description(), "LLM: desc-model, Max Tokens: 256, Azure: False")

        bot_azure = ConcreteOpenAICompatibleBot(azure_deployment="azure-model", azure_endpoint="x", api_key="y", api_version="z")
        self.assertEqual(bot_azure.get_llm_description(), "LLM: azure-model, Max Tokens: 1000, Azure: True")

    def test_get_chat_response_success(self):
        bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="test-gpt")
        bot.max_tokens = 50  # Ensure this is set
        mock_api_response = create_mock_openai_response(content="Hello from API")
        self.mock_openai_client_instance.chat.completions.create.return_value = mock_api_response
        messages = [{"role": "user", "content": "Hi"}]

        response = bot.get_chat_response(messages)

        self.mock_openai_client_instance.chat.completions.create.assert_called_once_with(
            model="test-gpt",
            messages=messages,
            tools=ANY,  # Assuming functions can be None or empty list
            tool_choice=ANY,
            max_tokens=50,
        )
        self.assertEqual(response, mock_api_response)

    def test_get_chat_response_api_error(self):
        bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="error-gpt")
        self.mock_openai_client_instance.chat.completions.create.side_effect = Exception("API Down")

        with self.assertRaisesRegex(Exception, "API Down"):
            bot.get_chat_response([{"role": "user", "content": "trigger"}])

    async def test_handle_message_simple_response(self):
        bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="chatty")
        bot.system_prompt = "You are a test bot."  # Set directly for simplicity
        mock_api_response = create_mock_openai_response(content="Test reply")
        self.mock_openai_client_instance.chat.completions.create.return_value = mock_api_response

        response_content = await bot.handle_message(user_id=1, user_message="Hello")

        self.assertEqual(response_content, "Test reply")
        self.assertIn(1, bot.conversation_history)
        self.assertEqual(len(bot.conversation_history[1]), 3)  # System, User, Assistant
        self.assertEqual(bot.conversation_history[1][0]["content"], "You are a test bot.")
        self.assertEqual(bot.conversation_history[1][2]["content"], "Test reply")

    async def test_handle_message_with_tool_call_and_response(self):
        bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="tool-user")
        # Mock functions/tools setup on the bot
        mock_tool_def = {"function": {"name": "get_weather", "description": "Gets weather", "parameters": {}}}
        bot.functions = [mock_tool_def]  # Simulate tools are loaded

        # API response 1: Request to call tool
        tool_call_request = [{"id": "call123", "function": {"name": "get_weather", "arguments": '''{"location": "moon"}'''}}]
        api_response_1 = create_mock_openai_response(tool_calls=tool_call_request)
        # API response 2: Final answer after tool execution
        api_response_2 = create_mock_openai_response(content="The weather on the moon is chilly.")
        self.mock_openai_client_instance.chat.completions.create.side_effect = [api_response_1, api_response_2]

        # Mock self.call_tool
        bot.call_tool = MagicMock(return_value='''{"temperature": "-100 C"}''')

        final_response = await bot.handle_message(user_id=2, user_message="Weather on moon?")

        self.assertEqual(final_response, "The weather on the moon is chilly.")
        bot.call_tool.assert_called_once_with("get_weather", '''{"location": "moon"}''')

        # Check conversation history includes tool messages. Entries are read
        # through _field so the check works whether the bot stores dicts or
        # message objects (mixing msg["role"] with msg.tool_calls cannot
        # succeed for either form).
        history = bot.conversation_history[2]
        self.assertTrue(any(
            self._field(msg, "role") == "assistant" and self._field(msg, "tool_calls") is not None
            for msg in history
        ))
        self.assertTrue(any(
            self._field(msg, "role") == "tool" and self._field(msg, "name") == "get_weather"
            for msg in history
        ))
        self.assertEqual(self.mock_openai_client_instance.chat.completions.create.call_count, 2)

    async def test_handle_message_max_history_length(self):
        # Expected behaviour encoded by the assertions below: history is
        # appended to and then truncated to the last max_history_length
        # entries, with no special treatment of the system prompt, which is
        # only added on a user's first message.
        bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="hist-test", max_history_length=3)
        self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Ok")

        await bot.handle_message(1, "Msg1")  # [System, User, Assistant]
        self.assertEqual(len(bot.conversation_history[1]), 3)
        await bot.handle_message(1, "Msg2")

        # Call 1: [S, U1, A1]
        self.mock_openai_client_instance.chat.completions.create.reset_mock()
        self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply1")
        await bot.handle_message(user_id=7, user_message="First message")
        self.assertEqual(len(bot.conversation_history[7]), 3)  # System, User1, Assistant1

        # Call 2: [S, U1, A1, U2, A2] truncated to [A1, U2, A2]
        self.mock_openai_client_instance.chat.completions.create.reset_mock()
        self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply2")
        await bot.handle_message(user_id=7, user_message="Second message")
        self.assertEqual(len(bot.conversation_history[7]), 3)
        self.assertEqual(bot.conversation_history[7][0]["content"], "Reply1")  # A1
        self.assertEqual(bot.conversation_history[7][1]["content"], "Second message")  # U2
        self.assertEqual(bot.conversation_history[7][2]["content"], "Reply2")  # A2

        # Call 3: [A1, U2, A2, U3, A3] truncated to [A2, U3, A3]
        self.mock_openai_client_instance.chat.completions.create.reset_mock()
        self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply3")
        await bot.handle_message(user_id=7, user_message="Third message")
        self.assertEqual(len(bot.conversation_history[7]), 3)
        self.assertEqual(bot.conversation_history[7][0]["content"], "Reply2")  # A2
        self.assertEqual(bot.conversation_history[7][1]["content"], "Third message")  # U3
        self.assertEqual(bot.conversation_history[7][2]["content"], "Reply3")  # A3

    async def test_abort_processing(self):
        bot = ConcreteOpenAICompatibleBot(model_name="test")
        user_id = 123
        bot.processing_status[user_id] = {"processing": True, "message_id": 456}
        bot.conversation_history[user_id] = [{"role": "user", "content": "stuff"}]

        with patch.object(bot, 'clear_conversation_history') as mock_clear_hist:  # Patching the method from Base class
            result = await bot.abort_processing(user_id)

        self.assertEqual(result, "Processing aborted and conversation cleared.")
        self.assertFalse(bot.processing_status[user_id]["processing"])
        mock_clear_hist.assert_called_once_with(user_id)

    async def test_abort_processing_no_active_processing(self):
        bot = ConcreteOpenAICompatibleBot(model_name="test")
        user_id = 404  # Not in processing_status

        with patch.object(bot, 'clear_conversation_history') as mock_clear_hist:
            result = await bot.abort_processing(user_id)

        self.assertEqual(result, "No active processing found to abort. Conversation cleared.")
        mock_clear_hist.assert_called_once_with(user_id)

    # Test for the abstract switch_model (basic call, actual logic in concrete class for this test)
    async def test_switch_model_concrete_implementation(self):
        bot = ConcreteOpenAICompatibleBot(model_name="model1", small_model_name="model1", large_model_name="model2", max_tokens_str="100")
        self.assertEqual(bot.model, "model1")

        await bot.switch_model()  # Calls the concrete implementation
        self.assertEqual(bot.model, "model2")

        await bot.switch_model()
        self.assertEqual(bot.model, "model1")
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
-356
View File
@@ -1,356 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch, mock_open, AsyncMock
import asyncio
import os
import sys
# Assuming telegram_helper.py is in the parent directory or PYTHONPATH is set
from telegram_helper import TelegramHelper, MessageHandlerLogicResult
# Mock for the bot passed to TelegramHelper
class MockBot:
    """Stand-in for the inference bot that TelegramHelper drives in these tests."""

    def __init__(self):
        # Synchronous hooks
        for sync_hook in ("clear_conversation_history", "set_processing_status", "clear_processing_status"):
            setattr(self, sync_hook, MagicMock())

        # Awaitable hooks without a preset reply (handle_message needs to
        # return a string; tests configure its return_value themselves)
        self.start = AsyncMock()
        self.handle_message = AsyncMock()

        # Awaitable hooks with a canned reply
        canned_replies = {
            "get_bot_status": "Bot Status OK",
            "switch_model": "Model Switched OK",
            "abort_processing": "Abort OK",
        }
        for coroutine_name, reply in canned_replies.items():
            setattr(self, coroutine_name, AsyncMock(return_value=reply))

        self.processing_status = {}  # Plain attribute, not a mock
# Mock for telegram.Update and related objects
def create_mock_update(message_text=None, user_id=123, chat_id=456, message_id=789, callback_query_data=None):
    """Build a MagicMock standing in for a telegram Update.

    The user and chat ids are always set. Message fields are only configured
    when ``message_text`` is truthy, and callback-query fields only when
    ``callback_query_data`` is truthy.
    """
    fake_update = MagicMock()
    fake_update.effective_user.id = user_id
    fake_update.effective_chat.id = chat_id

    if message_text:
        # reply_text is awaited and returns a Message-like object with message_id
        sent_message = MagicMock(message_id=message_id)
        fake_update.message.text = message_text
        fake_update.message.reply_text = AsyncMock(return_value=sent_message)

    if callback_query_data:
        query = fake_update.callback_query
        query.data = callback_query_data
        query.from_user.id = user_id
        query.answer = AsyncMock()
        query.edit_message_text = AsyncMock()

    return fake_update
# Mock for telegram.ext.ContextTypes.DEFAULT_TYPE
def create_mock_context():
    """Build a MagicMock standing in for a telegram.ext context with awaitable bot methods."""
    fake_bot = MagicMock()
    fake_bot.delete_message = AsyncMock()
    fake_bot.edit_message_text = AsyncMock()  # For update_status_message

    fake_context = MagicMock()
    fake_context.bot = fake_bot
    return fake_context
class TestTelegramHelper(unittest.IsolatedAsyncioTestCase): # Use IsolatedAsyncioTestCase for async methods
def setUp(self):
    # Fresh bot double and helper for every test. The reboot marker files
    # use test-specific names and can be overridden in individual tests.
    self.mock_bot = MockBot()
    self.reboot_claude_file = ".test_reboot_claude"
    self.reboot_file = ".test_doreboot"
    helper_options = {
        "reboot_claude_file_path": self.reboot_claude_file,
        "reboot_file_path": self.reboot_file,
        "chunk_message_sleep_duration": 0.001,  # Faster sleep for tests
    }
    self.helper = TelegramHelper(self.mock_bot, **helper_options)

    # Remove marker files left behind by a previous run
    for stale_marker in (self.reboot_claude_file, self.reboot_file):
        if os.path.exists(stale_marker):
            os.remove(stale_marker)
def tearDown(self):
    # Remove any reboot marker files a test created
    for marker_path in (self.reboot_claude_file, self.reboot_file):
        if os.path.exists(marker_path):
            os.remove(marker_path)
async def test_start_logic(self):
    # _start_logic must start the bot once and return the greeting text.
    greeting = await self.helper._start_logic()

    self.mock_bot.start.assert_called_once()
    self.assertEqual(greeting, "Hello! I\'m your AI assistant. How can I help you today?")
async def test_start_command(self):
    # The /start handler must reply with whatever _start_logic returns.
    mock_update = create_mock_update(message_text="/start")
    mock_context = create_mock_context()

    # The attribute name is a plain string literal; backslash-escaped quotes
    # outside a string are a syntax error.
    with patch.object(self.helper, '_start_logic', new_callable=AsyncMock) as mock_logic:
        mock_logic.return_value = "Start Logic Response"
        await self.helper.start(mock_update, mock_context)

    mock_logic.assert_called_once()
    mock_update.message.reply_text.assert_called_once_with("Start Logic Response")
async def test_clear_logic(self):
    # _clear_logic (async after the refactor) must clear that user's
    # history on the bot and return the confirmation text.
    target_user = 123

    confirmation = await self.helper._clear_logic(target_user)

    self.mock_bot.clear_conversation_history.assert_called_once_with(target_user)
    self.assertEqual(confirmation, "Conversation history cleared. Let\'s start fresh!")
async def test_clear_command(self):
    # The /clear handler must pass the sender's id to _clear_logic and
    # reply with its result.
    mock_update = create_mock_update(message_text="/clear", user_id=123)
    mock_context = create_mock_context()

    # The attribute name is a plain string literal; backslash-escaped quotes
    # outside a string are a syntax error.
    with patch.object(self.helper, '_clear_logic', new_callable=AsyncMock) as mock_logic:
        mock_logic.return_value = "Clear Logic Response"
        await self.helper.clear(mock_update, mock_context)

    mock_logic.assert_called_once_with(123)
    mock_update.message.reply_text.assert_called_once_with("Clear Logic Response")
async def test_status_logic(self):
    # _status_logic must return the bot's own status string unchanged.
    status_text = "Test Status"
    self.mock_bot.get_bot_status.return_value = status_text

    reported = await self.helper._status_logic()

    self.mock_bot.get_bot_status.assert_called_once()
    self.assertEqual(reported, "Test Status")
async def test_switch_logic_supported(self):
    # When the bot exposes switch_model, its return value is passed through.
    self.mock_bot.switch_model.return_value = "Switched to Large Model"

    outcome = await self.helper._switch_logic()

    self.mock_bot.switch_model.assert_called_once()
    self.assertEqual(outcome, "Switched to Large Model")
async def test_switch_logic_not_supported(self):
    # Simulate a bot without a switch_model attribute
    delattr(self.mock_bot, "switch_model")

    outcome = await self.helper._switch_logic()

    self.assertEqual(outcome, "Model switching is not supported for this bot.")
async def test_handle_message_logic_success(self):
    """A successful bot reply has its ``<think>`` tags replaced by the helper's HTML quote markers."""
    uid, incoming = 100, "Hello bot"
    self.mock_bot.handle_message.return_value = "Hello user <think>Thinking hard</think> Done."
    expected_text = "Hello user {}Thinking hard{} Done.".format(
        self.helper.HTML_QUOTE_BLOCK_START,
        self.helper.HTML_QUOTE_BLOCK_END,
    )

    outcome = await self.helper._handle_message_logic(uid, incoming)

    self.mock_bot.handle_message.assert_called_once_with(uid, incoming)
    self.assertTrue(outcome["success"])
    self.assertEqual(outcome["response_text"], expected_text)
    self.assertIsNone(outcome["error_message"])
async def test_handle_message_logic_bot_exception(self):
    """An exception raised by the bot is reported as a failed result carrying the error text."""
    uid, incoming = 101, "Trigger error"
    self.mock_bot.handle_message.side_effect = Exception("Bot Error")

    outcome = await self.helper._handle_message_logic(uid, incoming)

    self.assertFalse(outcome["success"])
    self.assertIsNone(outcome["response_text"])
    self.assertEqual(outcome["error_message"], "Bot Error")
@patch("logging.error")
async def test_handle_message_command_success_short_message(self, mock_logging_error):
    """A short response is sent as a single reply after the status message is deleted."""
    update = create_mock_update(message_text="Hi", user_id=200, chat_id=201, message_id=202)
    context = create_mock_context()
    outcome = MessageHandlerLogicResult(success=True, response_text="Short response", error_message=None)
    message_logic = AsyncMock(return_value=outcome)

    with patch.object(self.helper, "_handle_message_logic", message_logic):
        await self.helper.handle_message(update, context)

    reply_text = update.message.reply_text
    reply_text.assert_any_call("Processing your request...", reply_markup=unittest.mock.ANY)
    # Arguments are (user_id, status_message_id).
    self.mock_bot.set_processing_status.assert_called_once_with(200, 202)
    message_logic.assert_called_once_with(200, "Hi")
    context.bot.delete_message.assert_called_once_with(chat_id=201, message_id=202)
    self.mock_bot.clear_processing_status.assert_called_once_with(200)
    reply_text.assert_any_call("Short response")  # Final response
    # One "Processing..." message plus the final answer.
    self.assertEqual(reply_text.call_count, 2)
@patch("logging.error")
async def test_handle_message_command_success_long_message_chunks(self, mock_logging_error):
    """A 5000-character response is delivered as a 4096-character chunk plus the remainder."""
    update = create_mock_update(message_text="Long text", user_id=200, chat_id=201, message_id=202)
    context = create_mock_context()
    full_text = "a" * 5000  # Longer than 4096
    first_chunk, second_chunk = full_text[:4096], full_text[4096:]
    outcome = MessageHandlerLogicResult(success=True, response_text=full_text, error_message=None)
    message_logic = AsyncMock(return_value=outcome)

    with patch.object(self.helper, "_handle_message_logic", message_logic):
        # asyncio.sleep is stubbed so the pause between chunks costs no time.
        with patch("asyncio.sleep", new_callable=AsyncMock) as sleep_stub:
            await self.helper.handle_message(update, context)

    reply_text = update.message.reply_text
    reply_text.assert_any_call(first_chunk)
    reply_text.assert_any_call(second_chunk)
    sleep_stub.assert_called_once_with(self.helper.chunk_message_sleep_duration)
    # One "Processing..." message plus the two chunks.
    self.assertEqual(reply_text.call_count, 3)
@patch("logging.error")
async def test_handle_message_command_logic_fails(self, mock_logging_error):
    """A failed logic result makes the handler reply with the generic apology message."""
    update = create_mock_update(message_text="Cause error in logic", user_id=200)
    context = create_mock_context()
    failure = MessageHandlerLogicResult(success=False, response_text=None, error_message="Logic Failed")
    message_logic = AsyncMock(return_value=failure)

    with patch.object(self.helper, "_handle_message_logic", message_logic):
        await self.helper.handle_message(update, context)

    reply_text = update.message.reply_text
    reply_text.assert_any_call("Sorry, an error occurred while processing your request.")
    # One "Processing..." message plus the error message.
    self.assertEqual(reply_text.call_count, 2)
@patch("logging.error")
async def test_handle_message_command_telegram_exception_after_logic(self, mock_logging_error):
    """A Telegram failure while sending the final reply is logged and the processing status is cleared.

    The first ``reply_text`` call (the "Processing..." message) succeeds and
    the second one (the actual response) raises.
    """
    mock_update = create_mock_update(message_text="Test", user_id=200)
    mock_context = create_mock_context()
    logic_result = MessageHandlerLogicResult(success=True, response_text="OK", error_message=None)
    # Make sending the final reply fail.
    mock_update.message.reply_text.side_effect = [
        MagicMock(message_id=202),  # For "Processing..."
        Exception("Telegram API Error"),  # For the actual response
    ]
    with patch.object(self.helper, "_handle_message_logic", new_callable=AsyncMock) as mock_message_logic:
        mock_message_logic.return_value = logic_result
        await self.helper.handle_message(mock_update, mock_context)

    # reply_text is already driven by side_effect, so the observable outcomes
    # are the cleared processing status and the logged error.
    self.mock_bot.clear_processing_status.assert_called_once_with(200)
    # unittest.mock has no substring matcher, so inspect the recorded
    # positional message of every logging.error call instead.
    logged_messages = [
        str(entry.args[0]) for entry in mock_logging_error.call_args_list if entry.args
    ]
    self.assertTrue(
        any("Outer error in handle_message" in message for message in logged_messages),
        f"expected an 'Outer error in handle_message' log entry, got {logged_messages!r}",
    )
async def test_abort_processing_logic(self):
    """``_abort_processing_logic`` delegates to the bot's ``abort_processing`` and returns its result."""
    uid = 300
    abort = self.mock_bot.abort_processing
    abort.return_value = "Aborted by bot"
    outcome = await self.helper._abort_processing_logic(uid)
    abort.assert_called_once_with(uid)
    self.assertEqual(outcome, "Aborted by bot")
async def test_abort_processing_command(self):
    """The abort callback answers the query and edits the message to the abort logic's text."""
    update = create_mock_update(callback_query_data="abort", user_id=301)
    context = create_mock_context()
    abort_logic = AsyncMock(return_value="Abort Logic Done")
    with patch.object(self.helper, "_abort_processing_logic", abort_logic):
        await self.helper.abort_processing(update, context)
    query = update.callback_query
    query.answer.assert_called_once()
    abort_logic.assert_called_once_with(301)
    query.edit_message_text.assert_called_once_with(text="Abort Logic Done")
def test_reboot_logic_claude_and_main(self):
    """``/reboot claude`` opens both the claude reboot file and the main reboot file for writing.

    Only the ``open`` calls are asserted. ``mock_open`` returns one shared
    handle for every call, so it cannot tell which content went to which file;
    checking the written chat id would need a separate mock per path.
    """
    user_message_parts = ["/reboot", "claude"]
    chat_id_to_write = "12345"
    with patch("builtins.open", mock_open()) as mock_file:
        self.helper._reboot_logic(user_message_parts, chat_id_to_write)
    # Claude reboot marker file.
    mock_file.assert_any_call(self.reboot_claude_file, "w")
    # Main doreboot file.
    mock_file.assert_any_call(self.reboot_file, "w")
def test_reboot_logic_main_only(self):
    """A plain ``/reboot`` opens only the main reboot file, never the claude one."""
    parts = ["/reboot"]
    chat_id = "67890"
    with patch("builtins.open", mock_open()) as open_stub:
        self.helper._reboot_logic(parts, chat_id)
    # The claude file must not appear among the recorded open() calls.
    unwanted_call = unittest.mock.call(self.reboot_claude_file, "w")
    self.assertNotIn(unwanted_call, open_stub.call_args_list)
    open_stub.assert_any_call(self.reboot_file, "w")
@patch("sys.exit")  # Stubbed so the test runner itself does not exit
async def test_reboot_command(self, mock_sys_exit):
    """The /reboot handler writes the reboot files, notifies the chat and exits with status 0."""
    update = create_mock_update(message_text="/reboot claude", chat_id="chat1")
    context = create_mock_context()
    with patch.object(self.helper, "_reboot_logic") as reboot_logic:
        await self.helper.reboot(update, context)
    reboot_logic.assert_called_once_with(["/reboot", "claude"], "chat1")
    update.message.reply_text.assert_called_once_with("Rebooting the bot...")
    mock_sys_exit.assert_called_once_with(0)
@patch("os.path.exists")
@patch("builtins.open", new_callable=mock_open, read_data="chat123")
@patch("os.remove")
async def test_check_doreboot_file_logic_file_exists(self, mock_os_remove, mock_file_open, mock_os_path_exists):
    """An existing reboot file is read, removed, and its chat id returned.

    The file content is supplied through ``mock_open(read_data=...)``.
    Overriding ``read.return_value`` does not work, because ``mock_open``
    sets it to ``None`` and drives ``read()`` through a side effect.
    """
    mock_os_path_exists.return_value = True
    chat_id = await self.helper._check_doreboot_file_logic()
    mock_os_path_exists.assert_called_once_with(self.reboot_file)
    mock_file_open.assert_called_once_with(self.reboot_file, "r")
    mock_os_remove.assert_called_once_with(self.reboot_file)
    self.assertEqual(chat_id, "chat123")
@patch("os.path.exists", return_value=False)
async def test_check_doreboot_file_logic_file_not_exists(self, mock_os_path_exists):
    """With no reboot file on disk the logic returns ``None``."""
    found_chat_id = await self.helper._check_doreboot_file_logic()
    self.assertIsNone(found_chat_id)
    mock_os_path_exists.assert_called_once_with(self.reboot_file)
@patch("logging.error")
@patch("os.path.exists", return_value=True)
@patch("builtins.open", side_effect=IOError("Read error"))
@patch("os.remove")  # To check if remove is called even on read error
async def test_check_doreboot_file_logic_read_error(self, mock_os_remove, mock_file_open, mock_os_path_exists, mock_log_error):
    """A read failure yields ``None``, logs the error and still removes the reboot file."""
    chat_id = await self.helper._check_doreboot_file_logic()
    self.assertIsNone(chat_id)
    # unittest.mock has no substring matcher, so inspect the recorded
    # positional message of every logging.error call instead.
    logged_messages = [
        str(entry.args[0]) for entry in mock_log_error.call_args_list if entry.args
    ]
    self.assertTrue(
        any("Error reading reboot file" in message for message in logged_messages),
        f"expected an 'Error reading reboot file' log entry, got {logged_messages!r}",
    )
    # os.remove must be attempted even after the read error.
    mock_os_remove.assert_called_once_with(self.reboot_file)
async def test_check_doreboot_file_command_sends_message(self):
    """When the logic yields a chat id, that chat is told initialization has finished."""
    application = MagicMock()
    application.bot.send_message = AsyncMock()
    # Simulate a chat id recovered from the reboot file.
    doreboot_logic = AsyncMock(return_value="chat789")
    with patch.object(self.helper, "_check_doreboot_file_logic", doreboot_logic):
        await self.helper.check_doreboot_file(application)
    doreboot_logic.assert_called_once()
    application.bot.send_message.assert_called_once_with(
        chat_id="chat789", text="The application has finished initializing."
    )
async def test_check_doreboot_file_command_no_chat_id(self):
    """When the logic yields no chat id, no message is sent."""
    application = MagicMock()
    application.bot.send_message = AsyncMock()
    # Simulate the reboot file being absent or unreadable.
    doreboot_logic = AsyncMock(return_value=None)
    with patch.object(self.helper, "_check_doreboot_file_logic", doreboot_logic):
        await self.helper.check_doreboot_file(application)
    doreboot_logic.assert_called_once()
    application.bot.send_message.assert_not_called()
# Note: Testing the run() method itself is more of an integration test,
# as it involves setting up the full Application and polling loop.
# The unit tests here focus on the helper's own logic methods.
# Run the suite when this module is executed directly.
if __name__ == "__main__":
    unittest.main()
-307
View File
@@ -1,307 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch
import os
import base64
import logging
import requests # Required for spec in MagicMock
# Ensure tools/github_tool.py is accessible
from tools.github_tool import GitHubTool
# Helper to create a mock response for requests.Session
def create_mock_response(status_code, json_data=None, text_data=None, headers=None, links=None):
    """Build a ``MagicMock`` that stands in for a ``requests`` response.

    Args:
        status_code: Value exposed as ``.status_code``.
        json_data: When not ``None``, the value returned by ``.json()``.
        text_data: Value exposed as ``.text``. When omitted, ``.text`` falls
            back to ``str(json_data)``, or to an empty string if there is no
            JSON payload either.
        headers: Mapping exposed as ``.headers`` (defaults to an empty dict).
        links: Mapping exposed as ``.links``, used for pagination in
            ``_list_branches`` (defaults to an empty dict).

    Returns:
        The configured ``MagicMock``.
    """
    mock_resp = MagicMock()
    mock_resp.status_code = status_code
    if json_data is not None:
        mock_resp.json = MagicMock(return_value=json_data)

    if text_data is not None:
        mock_resp.text = text_data
    elif json_data is not None:
        mock_resp.text = str(json_data)
    else:
        # An empty body must not surface as the literal string "None".
        mock_resp.text = ""

    mock_resp.headers = headers if headers else {}
    mock_resp.links = links if links else {}  # For pagination in _list_branches
    return mock_resp
class TestGitHubTool(unittest.TestCase):
    """Unit tests for ``GitHubTool`` that run against a mocked ``requests.Session``."""

    def setUp(self):
        """Create the mocked session, the shared constructor arguments and a silenced logger."""
        self.mock_session = MagicMock(spec=requests.Session)
        self.mock_session.headers = {}  # Simulate a new session's headers
        self.test_token = "test_github_token"
        self.test_repo = "owner/repo"
        self.test_base_url = "https://api.example.com"  # Use a non-default base_url for some tests

        # Suppress logging output during tests unless explicitly testing for it
        self.logger = logging.getLogger('tools.github_tool')
        # Ensure only one NullHandler to prevent duplicate messages if tests run multiple times in a session
        if not any(isinstance(h, logging.NullHandler) for h in self.logger.handlers):
            self.logger.addHandler(logging.NullHandler())
        self.logger.propagate = False  # Prevent propagation to root logger if it has handlers

    def _make_tool(self, **overrides):
        """Return a ``GitHubTool`` wired to the mocked session.

        ``overrides`` replace or extend the default keyword arguments
        (session, token, repo, logger).
        """
        kwargs = {
            "session": self.mock_session,
            "token": self.test_token,
            "repo": self.test_repo,
            "logger": self.logger,
        }
        kwargs.update(overrides)
        return GitHubTool(**kwargs)

    def test_init_with_args_and_session(self):
        """Explicit constructor arguments are stored and the branch starts as ``main``."""
        tool = self._make_tool(base_url=self.test_base_url)
        self.assertEqual(tool.session, self.mock_session)
        self.assertEqual(tool._token, self.test_token)
        self.assertEqual(tool._repo, self.test_repo)
        self.assertEqual(tool.base_url, self.test_base_url)
        self.assertEqual(tool.current_branch, "main")  # Default initial branch

    @patch('requests.Session')
    def test_init_creates_session_if_not_provided(self, MockSessionConstructor):
        """Without a ``session`` the tool builds one and reads token and repo from the environment."""
        # A plain MagicMock is used here: ``requests.Session`` is itself
        # patched inside this test, and newer Python versions reject a mock
        # being used as the spec of another mock.
        mock_created_session = MagicMock()
        mock_created_session.headers = {}
        MockSessionConstructor.return_value = mock_created_session

        # patch.dict restores os.environ even when the constructor raises.
        env = {"GITHUB_TOKEN": "env_token", "GITHUB_REPOSITORY": "env/repo"}
        with patch.dict(os.environ, env):
            tool = GitHubTool(logger=self.logger)  # Use env vars

        MockSessionConstructor.assert_called_once()
        self.assertEqual(tool.session, mock_created_session)
        self.assertEqual(tool._token, "env_token")
        self.assertEqual(tool._repo, "env/repo")
        self.assertIn("Authorization", mock_created_session.headers)
        self.assertEqual(mock_created_session.headers["Authorization"], "token env_token")

    def test_init_raises_value_error_if_no_token(self):
        """Construction fails with ``ValueError`` when no token is passed or set in the environment."""
        # patch.dict snapshots os.environ and puts the removed key back on exit.
        with patch.dict(os.environ):
            os.environ.pop("GITHUB_TOKEN", None)
            with self.assertRaisesRegex(ValueError, "GitHub token must be provided"):
                GitHubTool(repo=self.test_repo, logger=self.logger)

    def test_init_raises_value_error_if_no_repo(self):
        """Construction fails with ``ValueError`` when no repository is passed or set in the environment."""
        with patch.dict(os.environ):
            os.environ.pop("GITHUB_REPOSITORY", None)
            with self.assertRaisesRegex(ValueError, "GitHub repository.*must be provided"):
                GitHubTool(token=self.test_token, logger=self.logger)

    def test_clear_resets_branch(self):
        """``clear()`` switches the current branch back to ``main``."""
        tool = self._make_tool(initial_branch="feature-branch")
        # Mock _get_branch_sha for _set_current_branch called by clear
        with patch.object(tool, '_get_branch_sha', return_value="sha_for_main"):
            tool.clear()
        self.assertEqual(tool.current_branch, "main")

    def test_get_functions_returns_list(self):
        """``get_functions()`` returns a non-empty list of named function descriptors."""
        tool = self._make_tool()
        functions = tool.get_functions()
        self.assertIsInstance(functions, list)
        self.assertGreater(len(functions), 0)
        self.assertIn("name", functions[0]["function"])

    # --- Test individual private methods ---

    def test_read_file_success(self):
        """``_read_file`` decodes the base64 content returned by the contents endpoint."""
        tool = self._make_tool()
        file_content = "Hello World!"
        encoded_content = base64.b64encode(file_content.encode('utf-8')).decode('utf-8')
        self.mock_session.get.return_value = create_mock_response(200, json_data={"content": encoded_content})

        result = tool._read_file(path="test.txt")

        self.assertEqual(result, file_content)
        self.mock_session.get.assert_called_once_with(
            f"{tool.base_url}/repos/{self.test_repo}/contents/test.txt",
            params={"ref": "main"}
        )

    def test_read_file_error(self):
        """A 404 from the contents endpoint is reported as an error string."""
        tool = self._make_tool()
        self.mock_session.get.return_value = create_mock_response(404, text_data="Not Found")
        result = tool._read_file(path="nonexistent.txt")
        self.assertIn("Error reading file", result)

    def test_create_branch_success(self):
        """``_create_branch`` creates the ref and makes the new branch current."""
        tool = self._make_tool()
        # Mock getting base branch SHA
        self.mock_session.get.return_value = create_mock_response(200, json_data={"object": {"sha": "base_sha123"}})
        # Mock creating new branch
        self.mock_session.post.return_value = create_mock_response(201, json_data={"ref": "refs/heads/new-feature"})

        result = tool._create_branch(branch_name="new-feature", base_branch="main")

        self.assertIn("Branch 'new-feature' created successfully", result)
        self.assertEqual(tool.current_branch, "new-feature")
        self.mock_session.get.assert_called_once()  # For base branch SHA
        self.mock_session.post.assert_called_once()  # For creating branch

    def test_create_branch_base_sha_error(self):
        """Failure to resolve the base branch SHA is reported as an error string."""
        tool = self._make_tool()
        self.mock_session.get.return_value = create_mock_response(404, text_data="Base branch not found")
        result = tool._create_branch(branch_name="new-feature", base_branch="nonexistent-base")
        self.assertIn("Error getting base branch SHA", result)

    def test_create_branch_creation_error(self):
        """A rejected ref creation (422) is reported as an error string."""
        tool = self._make_tool()
        self.mock_session.get.return_value = create_mock_response(200, json_data={"object": {"sha": "base_sha456"}})
        self.mock_session.post.return_value = create_mock_response(422, text_data="Validation failed")
        result = tool._create_branch(branch_name="bad-branch", base_branch="main")
        self.assertIn("Error creating branch", result)

    def test_commit_file_success_new_file(self):
        """Committing a file that does not exist yet succeeds and reports the commit SHA."""
        tool = self._make_tool()
        tool.current_branch = "dev-branch"  # Cannot commit to main by default
        # Mock GET for checking file existence (404 means new file)
        self.mock_session.get.return_value = create_mock_response(404)
        # Mock PUT for committing file
        self.mock_session.put.return_value = create_mock_response(201, json_data={"commit": {"sha": "commit_sha_abc"}})

        result = tool._commit_file(file_path="new_file.py", content="print('Hello')", commit_message="Add new_file.py")

        self.assertIn("committed successfully", result)
        self.assertIn("commit_sha_abc", result)
        self.mock_session.get.assert_called_once()  # Check file existence
        self.mock_session.put.assert_called_once()  # Commit file

    def test_commit_file_success_update_file(self):
        """Updating an existing file sends that file's current SHA in the PUT payload."""
        tool = self._make_tool()
        tool.current_branch = "dev-branch"
        # Mock GET for checking file existence (200 means existing file)
        self.mock_session.get.return_value = create_mock_response(200, json_data={"sha": "existing_file_sha"})
        # Mock PUT for committing file
        self.mock_session.put.return_value = create_mock_response(200, json_data={"commit": {"sha": "commit_sha_def"}})

        result = tool._commit_file(file_path="existing_file.txt", content="Updated content", commit_message="Update existing_file.txt")

        self.assertIn("committed successfully", result)
        self.assertIn("commit_sha_def", result)
        args, kwargs = self.mock_session.put.call_args
        self.assertEqual(kwargs['json']['sha'], "existing_file_sha")

    def test_commit_file_to_main_branch_error(self):
        """Committing while the current branch is ``main`` is refused."""
        tool = self._make_tool()
        tool.current_branch = "main"
        result = tool._commit_file(file_path="some.txt", content="content", commit_message="msg")
        self.assertIn("Action directly to main branch is not allowed", result)

    def test_create_pull_request_success(self):
        """``_create_pull_request`` posts head=current branch, base=target and returns the PR URL."""
        tool = self._make_tool()
        tool.current_branch = "feature-pr"
        pr_url = "https://example.com/pull/1"
        self.mock_session.post.return_value = create_mock_response(201, json_data={"html_url": pr_url, "number": 1})

        result = tool._create_pull_request(title="New Feature PR", body="Please review.", base="main")

        self.assertIn(f"Pull request created successfully: {pr_url}", result)
        self.mock_session.post.assert_called_once()
        call_data = self.mock_session.post.call_args[1]['json']
        self.assertEqual(call_data['head'], "feature-pr")
        self.assertEqual(call_data['base'], "main")

    def test_create_pull_request_same_branch_error(self):
        """A pull request from a branch to itself is refused."""
        tool = self._make_tool()
        tool.current_branch = "main"
        result = tool._create_pull_request(title="PR to self", body="This should fail", base="main")
        self.assertIn("Cannot create a pull request from branch 'main' to itself", result)

    def test_list_files_success(self):
        """``_list_files`` returns the entries reported for the directory."""
        tool = self._make_tool()
        mock_items = [
            {"name": "file1.txt", "type": "file", "path": "dir/file1.txt"},
            {"name": "subdir", "type": "dir", "path": "dir/subdir"}
        ]
        self.mock_session.get.return_value = create_mock_response(200, json_data=mock_items)

        result = tool._list_files(path="dir")

        self.assertEqual(len(result), 2)
        self.assertEqual(result[0]["name"], "file1.txt")
        self.assertEqual(result[1]["type"], "dir")

    def test_search_code_success(self):
        """``_search_code`` returns the items of the search response."""
        tool = self._make_tool()
        mock_search_results = {
            "items": [{"path": "src/code.py", "html_url": "url1"}]
        }
        self.mock_session.get.return_value = create_mock_response(200, json_data=mock_search_results)

        results = tool._search_code(query="my_function")

        self.assertEqual(len(results), 1)
        self.assertEqual(results[0]["path"], "src/code.py")

    def test_get_commit_history_success(self):
        """``_get_commit_history`` returns the commits reported for the file."""
        tool = self._make_tool()
        mock_commits = [{
            "sha": "sha1", "commit": {"message": "Msg1", "author": {"name": "Authy", "date": "Date1"}}
        }]
        self.mock_session.get.return_value = create_mock_response(200, json_data=mock_commits)

        commits = tool._get_commit_history(file_path="file.txt", num_commits=1)

        self.assertEqual(len(commits), 1)
        self.assertEqual(commits[0]["sha"], "sha1")

    def test_set_current_branch_success(self):
        """Switching to a branch whose SHA can be resolved updates ``current_branch``."""
        tool = self._make_tool()
        # Mock _get_branch_sha to simulate branch exists
        with patch.object(tool, '_get_branch_sha', return_value="some_sha_for_dev"):
            result = tool._set_current_branch(branch_name="dev")
        self.assertEqual(tool.current_branch, "dev")
        self.assertIn("Current branch set to: dev", result)

    def test_set_current_branch_not_exists(self):
        """Switching to a branch whose SHA lookup errors leaves ``current_branch`` unchanged."""
        tool = self._make_tool()
        with patch.object(tool, '_get_branch_sha', return_value="Error getting SHA for branch"):
            result = tool._set_current_branch(branch_name="nonexistent-branch")
        self.assertNotEqual(tool.current_branch, "nonexistent-branch")  # Should not change
        self.assertIn("Cannot set current branch", result)

    def test_list_branches_single_page(self):
        """A response without a ``next`` link is read with a single request."""
        tool = self._make_tool()
        mock_branches = [{"name": "main"}, {"name": "dev"}]
        self.mock_session.get.return_value = create_mock_response(200, json_data=mock_branches, links={})  # No "next" link

        branches = tool._list_branches(all_pages=True)

        self.assertEqual(branches, ["main", "dev"])
        self.mock_session.get.assert_called_once()

    def test_list_branches_multiple_pages(self):
        """With ``all_pages=True`` the ``next`` link is followed and the pages are concatenated."""
        tool = self._make_tool()
        # Page 1 response
        page1_branches = [{"name": "branch1"}, {"name": "branch2"}]
        next_url = f"{tool.base_url}/repos/{self.test_repo}/branches?page=2"
        response1 = create_mock_response(200, json_data=page1_branches, links={"next": {"url": next_url}})
        # Page 2 response
        page2_branches = [{"name": "branch3"}]
        response2 = create_mock_response(200, json_data=page2_branches, links={})  # No "next" link
        self.mock_session.get.side_effect = [response1, response2]

        branches = tool._list_branches(all_pages=True)

        self.assertEqual(branches, ["branch1", "branch2", "branch3"])
        self.assertEqual(self.mock_session.get.call_count, 2)
        # Check that the second call used the next_url
        calls = self.mock_session.get.call_args_list
        self.assertEqual(calls[1][0][0], next_url)  # args[0] is the URL

    # --- Test execute dispatcher ---

    def test_execute_read_file(self):
        """``execute('read_file', ...)`` forwards its keyword arguments to ``_read_file``."""
        tool = self._make_tool()
        with patch.object(tool, '_read_file', return_value="file content") as mock_method:
            result = tool.execute(function_name="read_file", path="test.md")
        mock_method.assert_called_once_with(path="test.md")
        self.assertEqual(result, "file content")

    def test_execute_unknown_function(self):
        """An unknown function name is reported as an error string."""
        tool = self._make_tool()
        result = tool.execute(function_name="non_existent_function_name", arg1="val1")
        self.assertIn("Unknown function: non_existent_function_name", result)

    def test_execute_method_exception(self):
        """An exception raised by the dispatched method is turned into an error string."""
        tool = self._make_tool()
        with patch.object(tool, '_read_file', side_effect=Exception("Kaboom")):
            result = tool.execute(function_name="read_file", path="crash.txt")
        self.assertIn("Error during read_file execution: Kaboom", result)
# Run the suite when this module is executed directly.
if __name__ == "__main__":
    unittest.main()
-146
View File
@@ -1,146 +0,0 @@
import unittest
from unittest.mock import patch, mock_open, MagicMock
import os
import logging
from datetime import datetime, timedelta
# Ensure tools/log_tool.py is accessible
from tools.log_tool import LogTool
class TestLogTool(unittest.TestCase):
def setUp(self):
self.test_log_file_path = "test_dummy_log.log"
# Suppress logging output during tests unless explicitly testing for it
self.logger = logging.getLogger('tools.log_tool')
# Ensure only one NullHandler to prevent duplicate messages if tests run multiple times in a session
if not any(isinstance(h, logging.NullHandler) for h in self.logger.handlers):
self.logger.addHandler(logging.NullHandler())
self.logger.propagate = False # Prevent propagation to root logger if it has handlers
def test_init_default_log_path(self):
tool = LogTool(logger=self.logger)
self.assertEqual(tool.configured_log_file_path, 'logs/output.log')
def test_init_custom_log_path(self):
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
self.assertEqual(tool.configured_log_file_path, self.test_log_file_path)
def test_get_functions(self):
tool = LogTool(logger=self.logger)
functions = tool.get_functions()
self.assertIsInstance(functions, list)
self.assertEqual(len(functions), 1)
self.assertEqual(functions[0]["function"]["name"], "get_log_contents")
@patch("os.path.exists", return_value=False)
def test_get_log_contents_file_not_exists(self, mock_exists):
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
result = tool._get_log_contents()
self.assertIn("Log file does not exist", result)
mock_exists.assert_called_once_with(self.test_log_file_path)
@patch("os.path.exists", return_value=True)
@patch("builtins.open", new_callable=mock_open, read_data="line1\nline2\nline3\nline4\nline5")
def test_get_log_contents_with_line_count(self, mock_file_open, mock_exists):
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
result = tool._get_log_contents(line_count=3)
self.assertEqual(result, "line3\nline4\nline5")
mock_exists.assert_called_once_with(self.test_log_file_path)
mock_file_open.assert_called_once_with(self.test_log_file_path, 'r', encoding='utf-8')
@patch("os.path.exists", return_value=True)
@patch("builtins.open", new_callable=mock_open, read_data="line1\nline2\n")
def test_get_log_contents_line_count_more_than_available(self, mock_file_open, mock_exists):
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
result = tool._get_log_contents(line_count=5)
self.assertEqual(result, "line1\nline2\n")
@patch("os.path.exists", return_value=True)
@patch("builtins.open", new_callable=mock_open, read_data="line1\nline2\n")
def test_get_log_contents_invalid_line_count_uses_default(self, mock_file_open, mock_exists):
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
# Test with zero, negative, and non-integer line_count (though type hint is int)
# The code defaults to 150 if invalid. Here, we only have 2 lines.
with patch.object(tool.logger, 'warning') as mock_log_warning:
result_zero = tool._get_log_contents(line_count=0)
self.assertEqual(result_zero, "line1\nline2\n")
mock_log_warning.assert_any_call("Invalid line_count '0' provided, defaulting to fetch last 150 lines.")
mock_file_open.reset_mock() # Reset for next call
result_neg = tool._get_log_contents(line_count=-5)
self.assertEqual(result_neg, "line1\nline2\n")
mock_log_warning.assert_any_call("Invalid line_count '-5' provided, defaulting to fetch last 150 lines.")
@patch("os.path.exists", return_value=True)
def test_get_log_contents_last_24_hours(self, mock_exists):
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
now = datetime.now()
one_hour_ago_dt = now - timedelta(hours=1)
two_days_ago_dt = now - timedelta(days=2)
one_hour_ago_str = one_hour_ago_dt.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT)
two_days_ago_str = two_days_ago_dt.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT)
log_data = (
f"{two_days_ago_str} - OLD - This is an old log entry.\n"
f"No timestamp here - this line should be skipped by time filter.\n"
f"{one_hour_ago_str} - RECENT - This is a recent log entry.\n"
f"Malformed Date 2023-xx-01 - Another skipped line.\n"
f"{now.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT)} - CURRENT - This is a current log entry.\n"
)
expected_output = (
f"{one_hour_ago_str} - RECENT - This is a recent log entry.\n"
f"{now.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT)} - CURRENT - This is a current log entry.\n"
)
with patch("builtins.open", mock_open(read_data=log_data)):
result = tool._get_log_contents(line_count=None) # Trigger 24-hour logic
self.assertEqual(result, expected_output)
@patch("os.path.exists", return_value=True)
@patch("builtins.open", side_effect=IOError("File read error!"))
def test_get_log_contents_file_read_exception(self, mock_file_open, mock_exists):
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
result = tool._get_log_contents(line_count=10)
self.assertIn("An error occurred while reading the log file: File read error!", result)
def test_execute_get_log_contents(self):
tool = LogTool(logger=self.logger)
mock_return_value = "Mocked log content"
with patch.object(tool, '_get_log_contents', return_value=mock_return_value) as mock_method:
result = tool.execute(function_name="get_log_contents", line_count=50)
mock_method.assert_called_once_with(line_count=50)
self.assertEqual(result, mock_return_value)
def test_execute_get_log_contents_no_line_count(self):
tool = LogTool(logger=self.logger)
mock_return_value = "Mocked log content for 24h"
with patch.object(tool, '_get_log_contents', return_value=mock_return_value) as mock_method:
result = tool.execute(function_name="get_log_contents") # No line_count
mock_method.assert_called_once_with(line_count=None) # Expects None to trigger 24h
self.assertEqual(result, mock_return_value)
def test_execute_unknown_function(self):
tool = LogTool(logger=self.logger)
result = tool.execute(function_name="non_existent_log_function")
self.assertIn("Unknown function: non_existent_log_function", result)
def test_clear_method(self):
    """clear() should emit a DEBUG record containing "LogTool clear called"."""
    tool = LogTool(logger=self.logger)
    # Lower the threshold so the DEBUG record is produced. The restore is
    # registered as a cleanup (the current level is captured right here),
    # so the logger is put back even when an assertion below fails and
    # the level cannot leak into other tests.
    self.addCleanup(tool.logger.setLevel, tool.logger.level)
    tool.logger.setLevel(logging.DEBUG)
    with self.assertLogs(tool.logger, level='DEBUG') as cm:
        tool.clear()
    self.assertTrue(any("LogTool clear called" in message for message in cm.output))
# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
-217
View File
@@ -1,217 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock, ANY
import time
import logging
# Ensure tools.metrics is accessible
from tools.metrics import Metrics # Import the class itself for direct testing
from tools.metrics import metrics as global_metrics_instance # Import the global instance
# A simple function to decorate for testing
def sample_function_for_metrics(duration=0.01):
    """Return the fixed string ``"sample_output"``.

    This is a stand-in target for the metrics decorator tests. ``duration``
    is only compared against zero; no sleeping or other work happens,
    because the tests mock the cProfile/pstats results instead of relying
    on real elapsed time.
    """
    output = "sample_output"
    if not duration > 0:
        # Zero/negative duration: nothing to simulate.
        return output
    # Positive duration: the simulated work is intentionally a no-op.
    return output
def another_sample_function(x, y):
    """Return ``x + y``; a second decoratable function for the metrics tests."""
    combined = x + y
    return combined
class TestMetrics(unittest.TestCase):
    """Tests for the ``Metrics`` call-count / timing decorator.

    The timing tests patch ``cProfile.Profile`` and ``pstats.Stats`` so the
    cumulative time that ``measure`` reads is fully controlled by the test
    instead of depending on real execution time.
    """

    def setUp(self):
        # Create a fresh Metrics instance for most tests to avoid interference
        self.logger = logging.getLogger('tools.metrics.test')
        if not self.logger.handlers:  # Avoid adding handler multiple times
            self.logger.addHandler(logging.NullHandler())
        self.metrics_instance = Metrics(logger=self.logger)
        # Clear the global instance before each test that might use it, and
        # again afterwards so the last test does not leave data behind.
        global_metrics_instance.clear_metrics()
        self.addCleanup(global_metrics_instance.clear_metrics)

    @staticmethod
    def _stats_key(func, lineno_offset=0):
        """Build the (filename, lineno, funcname) key used in ``pstats.Stats.stats``."""
        code = func.__code__
        return (code.co_filename, code.co_firstlineno + lineno_offset, code.co_name)

    def _metrics_logger_at(self, level):
        """Return the 'tools.metrics' logger set to *level*; the old level is restored at cleanup."""
        metrics_logger = logging.getLogger('tools.metrics')
        # The current level is captured now, before it is changed below.
        self.addCleanup(metrics_logger.setLevel, metrics_logger.level)
        metrics_logger.setLevel(level)
        return metrics_logger

    def test_measure_decorator_counts_calls(self):
        """Each call through the decorated function bumps its call count by one."""
        decorated_func = self.metrics_instance.measure(sample_function_for_metrics)
        self.assertEqual(self.metrics_instance.call_count["sample_function_for_metrics"], 0)
        decorated_func()
        self.assertEqual(self.metrics_instance.call_count["sample_function_for_metrics"], 1)
        decorated_func()
        self.assertEqual(self.metrics_instance.call_count["sample_function_for_metrics"], 2)

    @patch('cProfile.Profile')
    @patch('pstats.Stats')
    def test_measure_decorator_records_time(self, MockPStats, MockCProfile):
        """The cumulative time from the (mocked) profiler stats is accumulated per call."""
        # Mock cProfile and pstats to control the time value
        mock_profiler_instance = MockCProfile.return_value
        mock_pstats_instance = MockPStats.return_value

        # pstats.Stats.stats maps (filename, lineno, funcname) to
        # (cc, nc, tt, ct, callers); ct (index 3) is the cumulative time.
        # The key is built from the function *before* decoration.
        func_key = self._stats_key(sample_function_for_metrics)
        mock_pstats_instance.stats = {func_key: (1, 1, 0.05, 0.123, {})}  # ct=0.123

        decorated_func = self.metrics_instance.measure(sample_function_for_metrics)
        self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0)

        # Duration arg doesn't matter due to mocking
        decorated_func(duration=0)

        mock_profiler_instance.enable.assert_called_once()
        mock_profiler_instance.disable.assert_called_once()
        MockPStats.assert_called_once_with(mock_profiler_instance)
        self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0.123)

        # Call again with a new cumulative time to check accumulation
        mock_pstats_instance.stats = {func_key: (1, 1, 0.05, 0.100, {})}  # New ct=0.100
        decorated_func(duration=0)
        self.assertAlmostEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0.123 + 0.100)

    @patch('cProfile.Profile')
    @patch('pstats.Stats')
    def test_measure_decorator_fallback_time_recording_by_name(self, MockPStats, MockCProfile):
        """When the exact stats key is absent, a same-named entry is expected to be used."""
        mock_pstats_instance = MockPStats.return_value

        # The exact (filename, lineno, name) key is deliberately NOT present.
        # Only an entry with a shifted line number but the same function
        # name exists, simulating a primary-lookup miss and a by-name hit.
        shifted_key = self._stats_key(sample_function_for_metrics, lineno_offset=5)
        mock_pstats_instance.stats = {
            shifted_key: (1, 1, 0.03, 0.077, {})  # ct = 0.077
        }

        decorated_func = self.metrics_instance.measure(sample_function_for_metrics)

        # The fallback is expected to be reported at DEBUG on 'tools.metrics'.
        # The test's own logger ('tools.metrics.test') is a child of that
        # logger, so its records propagate to the capture handler.
        metrics_internal_logger = self._metrics_logger_at(logging.DEBUG)
        with self.assertLogs(metrics_internal_logger, level='DEBUG') as log_capture:
            decorated_func(duration=0)

        self.assertTrue(any("Found stats for sample_function_for_metrics by name" in msg for msg in log_capture.output))
        self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0.077)

    @patch('cProfile.Profile')
    @patch('pstats.Stats')
    def test_measure_decorator_handles_func_stats_not_found(self, MockPStats, MockCProfile):
        """With no matching stats at all, a warning is logged and no time is recorded."""
        mock_pstats_instance = MockPStats.return_value
        mock_pstats_instance.stats = {}  # Empty stats, function will not be found

        decorated_func = self.metrics_instance.measure(sample_function_for_metrics)

        metrics_internal_logger = self._metrics_logger_at(logging.WARNING)
        with self.assertLogs(metrics_internal_logger, level='WARNING') as log_capture:
            decorated_func(duration=0)

        self.assertTrue(any("Could not find exact cProfile stats" in msg for msg in log_capture.output))
        self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0)

    def test_get_metrics_empty(self):
        """A fresh instance reports no metrics."""
        self.assertEqual(self.metrics_instance.get_metrics(), {})

    @patch('cProfile.Profile')
    @patch('pstats.Stats')
    def test_get_metrics_with_data(self, MockPStats, MockCProfile):
        """get_metrics() reports call_count, total_time and average_time per function."""
        mock_pstats_instance = MockPStats.return_value

        # Decorate two different functions
        decorated_func1 = self.metrics_instance.measure(sample_function_for_metrics)
        decorated_func2 = self.metrics_instance.measure(another_sample_function)

        # Data for func1: one call, cumulative time 0.1
        func1_key = self._stats_key(sample_function_for_metrics)
        mock_pstats_instance.stats = {func1_key: (1, 1, 0.1, 0.1, {})}
        decorated_func1()

        # Data for func2: two calls, cumulative times 0.2 and 0.3
        func2_key = self._stats_key(another_sample_function)
        mock_pstats_instance.stats = {func2_key: (1, 1, 0.2, 0.2, {})}
        decorated_func2(1, 2)
        mock_pstats_instance.stats = {func2_key: (1, 1, 0.3, 0.3, {})}
        decorated_func2(3, 4)

        metrics_data = self.metrics_instance.get_metrics()

        self.assertIn("sample_function_for_metrics", metrics_data)
        self.assertEqual(metrics_data["sample_function_for_metrics"]["call_count"], 1)
        self.assertEqual(metrics_data["sample_function_for_metrics"]["total_time"], 0.1)
        self.assertEqual(metrics_data["sample_function_for_metrics"]["average_time"], 0.1)

        self.assertIn("another_sample_function", metrics_data)
        self.assertEqual(metrics_data["another_sample_function"]["call_count"], 2)
        self.assertAlmostEqual(metrics_data["another_sample_function"]["total_time"], 0.5)
        self.assertAlmostEqual(metrics_data["another_sample_function"]["average_time"], 0.25)

    def test_clear_metrics(self):
        """clear_metrics() empties both counters and the reported metrics."""
        # Add some data
        self.metrics_instance.call_count["test_func"] = 5
        self.metrics_instance.total_time["test_func"] = 1.234

        self.metrics_instance.clear_metrics()

        self.assertEqual(self.metrics_instance.call_count, {})
        self.assertEqual(self.metrics_instance.total_time, {})
        self.assertEqual(self.metrics_instance.get_metrics(), {})

    # Test the global instance
    @patch('cProfile.Profile')
    @patch('pstats.Stats')
    def test_global_metrics_instance_usage(self, MockPStats, MockCProfile):
        """The module-level metrics instance records calls of functions it decorates."""
        mock_pstats_instance = MockPStats.return_value

        # Decorate a function with the global instance
        @global_metrics_instance.measure
        def globally_decorated_func():
            return "global_output"

        # The stats key must describe the original function, reached here
        # through __wrapped__ on the decorated callable.
        func_key = self._stats_key(globally_decorated_func.__wrapped__)
        mock_pstats_instance.stats = {func_key: (1, 1, 0.05, 0.05, {})}

        globally_decorated_func()

        metrics_data = global_metrics_instance.get_metrics()
        self.assertIn("globally_decorated_func", metrics_data)
        self.assertEqual(metrics_data["globally_decorated_func"]["call_count"], 1)
        self.assertEqual(metrics_data["globally_decorated_func"]["total_time"], 0.05)
# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
-161
View File
@@ -1,161 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch
import logging
# Ensure tools.metrics_tool and tools.metrics are accessible
from tools.metrics_tool import MetricsTool
from tools.metrics import Metrics # Used for typehinting and creating a mockable instance
class TestMetricsTool(unittest.TestCase):
    """Tests for ``MetricsTool``, driven through a mocked ``Metrics`` provider."""

    def setUp(self):
        self.mock_metrics_provider = MagicMock(spec=Metrics)
        self.logger = logging.getLogger('tools.metrics_tool.test')
        if not self.logger.handlers:
            self.logger.addHandler(logging.NullHandler())
        self.logger.propagate = False

    def _make_tool(self):
        """Build a MetricsTool wired to the mocked provider and the test logger."""
        return MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)

    def _metrics_tool_logger_at(self, level):
        """Return the 'tools.metrics_tool' logger set to *level*; the old level is restored at cleanup."""
        tool_logger = logging.getLogger('tools.metrics_tool')
        # The current level is captured now, before it is changed below, so
        # the restore happens even when an assertion in the test fails.
        self.addCleanup(tool_logger.setLevel, tool_logger.level)
        tool_logger.setLevel(level)
        return tool_logger

    def test_init_with_provider(self):
        """An explicitly passed provider is stored on the tool."""
        tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
        self.assertEqual(tool.metrics_provider, self.mock_metrics_provider)

    @patch('tools.metrics_tool.global_metrics_instance')  # Patch the global instance path
    def test_init_default_provider(self, mock_global_metrics):
        """Without a provider argument the (patched) global instance is used."""
        tool = MetricsTool(logger=self.logger)
        self.assertEqual(tool.metrics_provider, mock_global_metrics)

    def test_get_functions(self):
        """get_functions() lists exactly the three expected function specs."""
        tool = self._make_tool()
        functions = tool.get_functions()
        self.assertIsInstance(functions, list)
        # assertEqual reports the actual length on failure.
        self.assertEqual(len(functions), 3)  # Based on current definition
        function_names = [f["function"]["name"] for f in functions]
        self.assertIn("get_function_metrics", function_names)
        self.assertIn("get_specific_function_metrics", function_names)
        self.assertIn("get_top_n_functions", function_names)

    def test_execute_get_function_metrics(self):
        """get_function_metrics returns whatever the provider's get_metrics() returns."""
        tool = self._make_tool()
        expected_metrics = {"func1": {"call_count": 1, "total_time": 0.1}}
        self.mock_metrics_provider.get_metrics.return_value = expected_metrics

        result = tool.execute(function_name="get_function_metrics")

        self.mock_metrics_provider.get_metrics.assert_called_once()
        self.assertEqual(result, expected_metrics)

    def test_execute_get_specific_function_metrics_found(self):
        """A known function name should yield that function's metrics dict."""
        tool = self._make_tool()
        func_metrics = {"call_count": 5, "total_time": 0.5, "average_time": 0.1}
        all_metrics = {"specific_func": func_metrics, "other_func": {}}
        self.mock_metrics_provider.get_metrics.return_value = all_metrics

        # NOTE(review): 'function_name' is supplied twice here (explicit
        # keyword plus the ** mapping). Python rejects duplicate keyword
        # arguments with TypeError at the call site, so this call cannot
        # reach execute() as written. The right call shape depends on
        # MetricsTool.execute's signature -- confirm and adjust.
        result = tool.execute(function_name="get_specific_function_metrics", **{"function_name": "specific_func"})
        self.assertEqual(result, func_metrics)

    def test_execute_get_specific_function_metrics_not_found(self):
        """An unknown function name should yield a 'No metrics found' message."""
        tool = self._make_tool()
        self.mock_metrics_provider.get_metrics.return_value = {"other_func": {}}

        # NOTE(review): same duplicate 'function_name' keyword as in the
        # "found" test above -- raises TypeError before execute() runs.
        result = tool.execute(function_name="get_specific_function_metrics", **{"function_name": "non_existent_func"})
        self.assertEqual(result, "No metrics found for function: non_existent_func")

    def test_execute_get_specific_function_metrics_missing_arg(self):
        """Omitting the target function name should produce a missing-argument error."""
        tool = self._make_tool()
        result = tool.execute(function_name="get_specific_function_metrics")  # Missing function_name kwarg
        self.assertIn("Error: Missing required argument 'function_name'", result)

    def test_execute_get_top_n_functions(self):
        """get_top_n_functions returns the n entries with the largest total_time, in order."""
        tool = self._make_tool()
        metrics_data = {
            "func_a": {"call_count": 1, "total_time": 0.3},
            "func_b": {"call_count": 1, "total_time": 0.1},
            "func_c": {"call_count": 1, "total_time": 0.5},
            "func_d": {"call_count": 1, "total_time": 0.2},
        }
        self.mock_metrics_provider.get_metrics.return_value = metrics_data

        # Test getting top 2
        result = tool.execute(function_name="get_top_n_functions", n=2)
        expected_top_2 = {"func_c": metrics_data["func_c"], "func_a": metrics_data["func_a"]}
        self.assertEqual(result, expected_top_2)

        # Test getting top 1
        result_top_1 = tool.execute(function_name="get_top_n_functions", n=1)
        expected_top_1 = {"func_c": metrics_data["func_c"]}
        self.assertEqual(result_top_1, expected_top_1)

        # Test N larger than available functions
        result_top_all = tool.execute(function_name="get_top_n_functions", n=10)
        # Order should be func_c, func_a, func_d, func_b
        expected_top_all_keys = ["func_c", "func_a", "func_d", "func_b"]
        self.assertEqual(list(result_top_all.keys()), expected_top_all_keys)

    def test_execute_get_top_n_functions_malformed_metrics(self):
        """Malformed metric entries are warned about and left out of the ranking."""
        tool = self._make_tool()
        metrics_data = {
            "func_a": {"call_count": 1, "total_time": 0.3},
            "func_b": "not a dict",  # Malformed
            "func_c": {"call_count": 1},  # Missing total_time
            "func_d": {"call_count": 1, "total_time": 0.2},
        }
        self.mock_metrics_provider.get_metrics.return_value = metrics_data

        metrics_tool_logger = self._metrics_tool_logger_at(logging.WARNING)
        with self.assertLogs(metrics_tool_logger, level='WARNING') as log_capture:
            result = tool.execute(function_name="get_top_n_functions", n=2)

        # Check that warnings were logged for malformed items
        self.assertTrue(any("Metric item for 'func_b' is not in expected format" in msg for msg in log_capture.output))
        self.assertTrue(any("Metric item for 'func_c' is not in expected format" in msg for msg in log_capture.output))

        # Expected: func_a, func_d (as they are valid and sortable)
        expected_result = {
            "func_a": metrics_data["func_a"],
            "func_d": metrics_data["func_d"],
        }
        self.assertEqual(result, expected_result)

    def test_execute_get_top_n_functions_invalid_n(self):
        """Zero, negative and non-integer n each produce the matching error message."""
        tool = self._make_tool()
        self.mock_metrics_provider.get_metrics.return_value = {}  # No metrics needed for this test

        result_zero = tool.execute(function_name="get_top_n_functions", n=0)
        self.assertIn("Error: Argument 'n' must be a positive integer.", result_zero)

        result_negative = tool.execute(function_name="get_top_n_functions", n=-1)
        self.assertIn("Error: Argument 'n' must be a positive integer.", result_negative)

        result_string = tool.execute(function_name="get_top_n_functions", n="abc")
        self.assertIn("Error: Argument 'n' must be an integer.", result_string)

    def test_execute_get_top_n_functions_missing_arg_n(self):
        """Omitting n produces a missing-argument error."""
        tool = self._make_tool()
        result = tool.execute(function_name="get_top_n_functions")  # Missing n
        self.assertIn("Error: Missing required argument 'n'.", result)

    def test_execute_unknown_function(self):
        """An unrecognised function name is reported in the returned text."""
        tool = self._make_tool()
        result = tool.execute(function_name="non_existent_metrics_function")
        self.assertIn("Unknown function: non_existent_metrics_function", result)

    def test_clear_method(self):
        """clear() should emit a DEBUG record containing "MetricsTool clear method called"."""
        tool = self._make_tool()
        metrics_tool_logger = self._metrics_tool_logger_at(logging.DEBUG)
        with self.assertLogs(metrics_tool_logger, level='DEBUG') as cm:
            tool.clear()
        self.assertTrue(any("MetricsTool clear method called" in message for message in cm.output))
# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
+9 -6
View File
@@ -4,8 +4,7 @@ import zipfile
import io
import re
import logging
from .base_tool import BaseTool # Added
from .metrics import metrics # Added
from .base_tool import BaseTool
# Configure logging for the tool - This will be handled by the logger instance now
# logger = logging.getLogger(__name__) # Commented out or removed
@@ -70,7 +69,8 @@ class GitHubCIHelper(BaseTool): # Inherits from BaseTool
},
"required": ["pull_request_number"]
}
}
},
"_tags": ["read"]
},
{
"type": "function",
@@ -85,7 +85,8 @@ class GitHubCIHelper(BaseTool): # Inherits from BaseTool
},
"required": ["pull_request_number"]
}
}
},
"_tags": ["read"]
},
{
"type": "function",
@@ -100,7 +101,8 @@ class GitHubCIHelper(BaseTool): # Inherits from BaseTool
},
"required": ["run_id"]
}
}
},
"_tags": ["read"]
},
{
"type": "function",
@@ -114,7 +116,8 @@ class GitHubCIHelper(BaseTool): # Inherits from BaseTool
},
"required": ["log_content"]
}
}
},
"_tags": ["read"]
}
]
+72 -74
View File
@@ -1,6 +1,5 @@
# tools/github_tool.py
from .base_tool import BaseTool
from .metrics import metrics
import requests
import os
import base64
@@ -57,7 +56,8 @@ class GitHubTool(BaseTool):
},
"required": ["path"]
}
}
},
"_tags": ["read"]
},
{
"type": "function",
@@ -71,7 +71,8 @@ class GitHubTool(BaseTool):
},
"required": ["path"]
}
}
},
"_tags": ["read"]
},
{
"type": "function",
@@ -85,7 +86,8 @@ class GitHubTool(BaseTool):
},
"required": ["query"]
}
}
},
"_tags": ["read"]
},
{
"type": "function",
@@ -100,7 +102,8 @@ class GitHubTool(BaseTool):
},
"required": ["branch_name"]
}
}
},
"_tags": ["write"]
},
{
"type": "function",
@@ -116,7 +119,8 @@ class GitHubTool(BaseTool):
},
"required": ["file_path", "commit_message", "content"]
}
}
},
"_tags": ["write"]
},
{
"type": "function",
@@ -132,7 +136,8 @@ class GitHubTool(BaseTool):
},
"required": ["title", "body"]
}
}
},
"_tags": ["write"]
},
{
"type": "function",
@@ -147,7 +152,8 @@ class GitHubTool(BaseTool):
},
"required": ["file_path"]
}
}
},
"_tags": ["read"]
},
{
"type": "function",
@@ -162,7 +168,8 @@ class GitHubTool(BaseTool):
},
"required": ["file_path"]
}
}
},
"_tags": ["read"]
},
{
"type": "function",
@@ -176,7 +183,8 @@ class GitHubTool(BaseTool):
},
"required": ["branch"]
}
}
},
"_tags": ["read"]
},
{
"type": "function",
@@ -184,7 +192,8 @@ class GitHubTool(BaseTool):
"name": "get_current_branch",
"description": "Get the name of the current branch",
"parameters": { "type": "object", "properties": {} }
}
},
"_tags": ["read"]
},
{
"type": "function",
@@ -198,7 +207,8 @@ class GitHubTool(BaseTool):
},
"required": ["branch_name"]
}
}
},
"_tags": ["read", "write"]
},
{
"type": "function",
@@ -213,7 +223,8 @@ class GitHubTool(BaseTool):
},
"required": ["file_path", "commit_sha"]
}
}
},
"_tags": ["read"]
},
{
"type": "function",
@@ -227,7 +238,8 @@ class GitHubTool(BaseTool):
"all_pages": {"type": "boolean", "description": "Whether to fetch all pages of results", "default": True}
}
}
}
},
"_tags": ["read"]
},
{
"type": "function",
@@ -241,7 +253,8 @@ class GitHubTool(BaseTool):
},
"required": ["pull_number"]
}
}
},
"_tags": ["write"]
},
{
"type": "function",
@@ -255,7 +268,8 @@ class GitHubTool(BaseTool):
},
"required": ["pull_number"]
}
}
},
"_tags": ["write"]
},
{
"type": "function",
@@ -277,7 +291,8 @@ class GitHubTool(BaseTool):
},
"required": ["pull_number"]
}
}
},
"_tags": ["write"]
},
{
"type": "function",
@@ -291,7 +306,8 @@ class GitHubTool(BaseTool):
},
"required": ["branch_name"]
}
}
},
"_tags": ["write"]
},
{
"type": "function",
@@ -305,7 +321,8 @@ class GitHubTool(BaseTool):
},
"required": ["issue_number"]
}
}
},
"_tags": ["read"]
},
{
"type": "function",
@@ -325,7 +342,8 @@ class GitHubTool(BaseTool):
},
"required": ["title", "body"]
}
}
},
"_tags": ["communicate"]
},
{
"type": "function",
@@ -340,7 +358,8 @@ class GitHubTool(BaseTool):
"page": {"type": "integer", "default": 1, "description": "Page number of the results to fetch"}
}
}
}
},
"_tags": ["read"]
},
{
"type": "function",
@@ -355,7 +374,8 @@ class GitHubTool(BaseTool):
},
"required": ["issue_number", "comment"]
}
}
},
"_tags": ["communicate"]
},
{
"type": "function",
@@ -369,7 +389,8 @@ class GitHubTool(BaseTool):
},
"required": ["issue_number"]
}
}
},
"_tags": ["read"]
},
{
"type": "function",
@@ -383,7 +404,8 @@ class GitHubTool(BaseTool):
},
"required": ["pull_number"]
}
}
},
"_tags": ["read"]
},
{
"type": "function",
@@ -398,7 +420,8 @@ class GitHubTool(BaseTool):
},
"required": ["name"]
}
}
},
"_tags": ["communicate"]
},
{
"type": "function",
@@ -413,7 +436,8 @@ class GitHubTool(BaseTool):
},
"required": ["project_id", "column_name"]
}
}
},
"_tags": ["communicate"]
},
{
"type": "function",
@@ -428,7 +452,8 @@ class GitHubTool(BaseTool):
},
"required": ["column_id", "note"]
}
}
},
"_tags": ["communicate"]
},
{
"type": "function",
@@ -444,7 +469,8 @@ class GitHubTool(BaseTool):
},
"required": ["card_id", "position", "column_id"]
}
}
},
"_tags": ["communicate"]
},
{
"type": "function",
@@ -460,7 +486,8 @@ class GitHubTool(BaseTool):
},
"required": ["card_id", "content_id", "content_type"]
}
}
},
"_tags": ["communicate"]
},
{
"type": "function",
@@ -468,7 +495,8 @@ class GitHubTool(BaseTool):
"name": "list_project_boards",
"description": "List project boards associated with the repository",
"parameters": { "type": "object", "properties": {} }
}
},
"_tags": ["communicate"]
},
{
"type": "function",
@@ -482,7 +510,8 @@ class GitHubTool(BaseTool):
},
"required": ["project_id"]
}
}
},
"_tags": ["communicate"]
},
{
"type": "function",
@@ -496,7 +525,8 @@ class GitHubTool(BaseTool):
},
"required": ["pull_number"]
}
}
},
"_tags": ["read"]
},
{
"type": "function",
@@ -510,7 +540,8 @@ class GitHubTool(BaseTool):
},
"required": ["pull_number"]
}
}
},
"_tags": ["read"]
},
{
"type": "function",
@@ -524,7 +555,8 @@ class GitHubTool(BaseTool):
},
"required": ["pull_number"]
}
}
},
"_tags": ["read"]
},
{
"type": "function",
@@ -545,7 +577,8 @@ class GitHubTool(BaseTool):
},
"required": ["pull_number", "body", "commit_id", "path", "position"]
}
}
},
"_tags": ["communicate"]
},
{
"type": "function",
@@ -559,7 +592,8 @@ class GitHubTool(BaseTool):
},
"required": ["pull_number"]
}
}
},
"_tags": ["read"]
},
{
"type": "function",
@@ -575,11 +609,11 @@ class GitHubTool(BaseTool):
},
"required": ["pull_number", "event"]
}
}
},
"_tags": ["communicate"]
}
]
@metrics.measure
def execute(self, function_name, **kwargs):
self.logger.info(f"Executing GitHub Tool function: {function_name} with args: {kwargs}")
# Dispatch to the appropriate private method
@@ -598,7 +632,6 @@ class GitHubTool(BaseTool):
# Private methods for each function, using self.session for HTTP requests
@metrics.measure
def _read_file(self, path):
self.logger.info(f"Reading file: {path} from branch: {self.current_branch}")
url = f"{self.base_url}/repos/{self._repo}/contents/{path}"
@@ -613,7 +646,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message)
return error_message
@metrics.measure
def _create_branch(self, branch_name, base_branch="main"):
self.logger.info(f"Creating branch: {branch_name} from base: {base_branch}")
# Get SHA of base branch
@@ -639,7 +671,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message)
return error_message
@metrics.measure
def _commit_file(self, file_path, content, commit_message):
self.logger.info(f"Committing file: {file_path} to branch: {self.current_branch} with message: '{commit_message}'")
if self.current_branch == "main":
@@ -679,7 +710,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message)
return error_message
@metrics.measure
def _create_pull_request(self, title, body, base="main"):
self.logger.info(f"Creating pull request: '{title}' from branch '{self.current_branch}' to '{base}'")
if self.current_branch == base:
@@ -701,7 +731,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message)
return error_message
@metrics.measure
def _get_branch_sha(self, branch):
self.logger.info(f"Getting SHA for branch: {branch}")
url = f"{self.base_url}/repos/{self._repo}/git/refs/heads/{branch}"
@@ -715,7 +744,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message)
return error_message
@metrics.measure
def _list_files(self, path):
self.logger.info(f"Listing files in path: '{path}' on branch: '{self.current_branch}'")
url = f"{self.base_url}/repos/{self._repo}/contents/{path.strip('/')}" # Ensure no leading/trailing slashes for consistency
@@ -738,7 +766,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message)
return error_message
@metrics.measure
def _search_code(self, query):
self.logger.info(f"Searching code with query: '{query}' in repo: '{self._repo}'")
url = f"{self.base_url}/search/code"
@@ -754,7 +781,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message)
return error_message
@metrics.measure
def _get_commit_history(self, file_path, num_commits=10):
self.logger.info(f"Getting last {num_commits} commit(s) for file: '{file_path}' on branch '{self.current_branch}'")
url = f"{self.base_url}/repos/{self._repo}/commits"
@@ -775,18 +801,15 @@ class GitHubTool(BaseTool):
self.logger.error(error_message)
return error_message
@metrics.measure
def _view_commit_details_for_file(self, file_path, num_commits=10):
# This function is essentially the same as get_commit_history based on its description.
self.logger.info(f"Viewing commit details for file '{file_path}' (last {num_commits} commits) - using _get_commit_history.")
return self._get_commit_history(file_path, num_commits)
@metrics.measure
def _get_current_branch(self):
self.logger.info(f"Current branch is: {self.current_branch}")
return self.current_branch
@metrics.measure
def _set_current_branch(self, branch_name):
self.logger.info(f"Attempting to set current branch to: {branch_name}")
# Check if branch exists by trying to get its SHA
@@ -801,7 +824,6 @@ class GitHubTool(BaseTool):
self.logger.info(success_message)
return success_message
@metrics.measure
def _get_file_at_commit(self, file_path, commit_sha):
self.logger.info(f"Getting file '{file_path}' at commit SHA: {commit_sha}")
url = f"{self.base_url}/repos/{self._repo}/contents/{file_path}"
@@ -816,7 +838,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message)
return error_message
@metrics.measure
def _list_branches(self, per_page=100, all_pages=True):
self.logger.info(f"Listing branches for repo '{self._repo}'. Per_page={per_page}, All_pages={all_pages}")
url = f"{self.base_url}/repos/{self._repo}/branches"
@@ -844,7 +865,6 @@ class GitHubTool(BaseTool):
self.logger.info(f"Successfully listed {len(branches_list)} branches.")
return branches_list
@metrics.measure
def _approve_pull_request(self, pull_number):
self.logger.info(f"Approving pull request #{pull_number}")
url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/reviews"
@@ -859,7 +879,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message)
return error_message
@metrics.measure
def _close_pull_request(self, pull_number):
self.logger.info(f"Closing pull request #{pull_number}")
url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}"
@@ -874,7 +893,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message)
return error_message
@metrics.measure
def _merge_pull_request(self, pull_number, commit_title="Merge pull request", commit_message="", merge_method="merge"):
self.logger.info(f"Merging pull request #{pull_number} using method '{merge_method}'")
url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/merge"
@@ -897,7 +915,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message)
return error_message
@metrics.measure
def _delete_branch(self, branch_name):
self.logger.info(f"Deleting branch: {branch_name}")
if branch_name == "main" or (hasattr(self, 'default_branch') and branch_name == self.default_branch) :
@@ -920,7 +937,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message)
return error_message
@metrics.measure
def _get_issue_details(self, issue_number):
self.logger.info(f"Getting details for issue #{issue_number}")
url = f"{self.base_url}/repos/{self._repo}/issues/{issue_number}"
@@ -933,7 +949,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message)
return error_message
@metrics.measure
def _create_issue(self, title, body, labels=None):
self.logger.info(f"Creating new issue with title: '{title}'")
url = f"{self.base_url}/repos/{self._repo}/issues"
@@ -953,7 +968,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message)
return error_message
@metrics.measure
def _list_issues(self, state="open", per_page=30, page=1):
self.logger.info(f"Listing issues with state: {state}, per_page: {per_page}, page: {page}")
url = f"{self.base_url}/repos/{self._repo}/issues"
@@ -969,7 +983,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message)
return error_message
@metrics.measure
def _add_issue_comment(self, issue_number, comment):
self.logger.info(f"Adding comment to issue #{issue_number}: '{comment[:50]}...'")
url = f"{self.base_url}/repos/{self._repo}/issues/{issue_number}/comments"
@@ -985,7 +998,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message)
return error_message
@metrics.measure
def _get_issue_comments(self, issue_number):
self.logger.info(f"Getting comments for issue #{issue_number}")
url = f"{self.base_url}/repos/{self._repo}/issues/{issue_number}/comments"
@@ -1000,14 +1012,12 @@ class GitHubTool(BaseTool):
self.logger.error(error_message)
return error_message
@metrics.measure
def _get_pull_request_general_comments(self, pull_number):
self.logger.info(f"Getting general comments for pull request #{pull_number}")
# In GitHub API, PR comments (general, not review comments on lines) are issue comments.
# The PR is also an issue, so use the issue comments endpoint.
return self._get_issue_comments(issue_number=pull_number)
@metrics.measure
def _create_project_board(self, name, body=None):
self.logger.info(f"Creating project board: '{name}'")
url = f"{self.base_url}/repos/{self._repo}/projects"
@@ -1026,7 +1036,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message)
return error_message
@metrics.measure
def _create_project_column(self, project_id, column_name):
self.logger.info(f"Creating column '{column_name}' for project ID: {project_id}")
url = f"{self.base_url}/projects/{project_id}/columns"
@@ -1044,7 +1053,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message)
return error_message
@metrics.measure
def _create_project_card(self, column_id, note=None, content_id=None, content_type=None):
self.logger.info(f"Creating card in column ID: {column_id}")
url = f"{self.base_url}/projects/columns/{column_id}/cards"
@@ -1075,7 +1083,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message)
return error_message
@metrics.measure
def _move_project_card(self, card_id, position, column_id=None):
self.logger.info(f"Moving card ID: {card_id} to position: {position}" + (f" in column ID: {column_id}" if column_id else ""))
url = f"{self.base_url}/projects/columns/cards/{card_id}/moves"
@@ -1100,7 +1107,6 @@ class GitHubTool(BaseTool):
# For updating an existing card to link an issue, one would PATCH the card's content_id/content_type.
# Let's assume the function intends to update an existing card if it's a separate function.
# However, the provided API spec for `link_issue_to_project_card` uses PATCH on card_id, so let's implement that.
@metrics.measure
def _link_issue_to_project_card(self, card_id, content_id, content_type):
self.logger.info(f"Linking content_id {content_id} (type: {content_type}) to card_id {card_id}")
url = f"{self.base_url}/projects/cards/{card_id}" # Note: API docs suggest /projects/columns/cards/{card_id} or /projects/cards/{card_id}
@@ -1120,7 +1126,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message)
return error_message
@metrics.measure
def _list_project_boards(self):
self.logger.info(f"Listing project boards for repo: {self._repo}")
url = f"{self.base_url}/repos/{self._repo}/projects"
@@ -1136,7 +1141,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message)
return error_message
@metrics.measure
def _view_project_board_items(self, project_id):
self.logger.info(f"Viewing items for project ID: {project_id}")
columns_url = f"{self.base_url}/projects/{project_id}/columns"
@@ -1165,7 +1169,6 @@ class GitHubTool(BaseTool):
self.logger.info(f"Successfully retrieved items for project ID: {project_id}.")
return project_items
@metrics.measure
def _get_pull_request_details(self, pull_number):
self.logger.info(f"Getting details for PR #{pull_number}")
url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}"
@@ -1178,7 +1181,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message)
return error_message
@metrics.measure
def _get_pull_request_diff(self, pull_number):
self.logger.info(f"Getting diff for PR #{pull_number}")
url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}"
@@ -1193,7 +1195,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message)
return error_message
@metrics.measure
def _get_pull_request_files(self, pull_number):
self.logger.info(f"Getting files for PR #{pull_number}")
url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/files"
@@ -1206,7 +1207,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message)
return error_message
@metrics.measure
def _create_pull_request_review_comment(self, pull_number, body, commit_id, path, position, side="RIGHT", start_line=None, start_side=None):
self.logger.info(f"Creating review comment on PR #{pull_number}, file '{path}', position {position}")
url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/comments"
@@ -1225,7 +1225,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message)
return error_message
@metrics.measure
def _list_pull_request_review_comments(self, pull_number):
self.logger.info(f"Listing review comments for PR #{pull_number}")
url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/comments"
@@ -1238,7 +1237,6 @@ class GitHubTool(BaseTool):
self.logger.error(error_message)
return error_message
@metrics.measure
def _submit_pull_request_review(self, pull_number, event, body=None):
self.logger.info(f"Submitting '{event}' review for PR #{pull_number}")
url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/reviews"
-3
View File
@@ -1,6 +1,5 @@
# tools/log_tool.py
from .base_tool import BaseTool
from .metrics import metrics
import logging
import os
from datetime import datetime, timedelta
@@ -44,7 +43,6 @@ class LogTool(BaseTool):
}
]
@metrics.measure
def execute(self, function_name, **kwargs):
self.logger.info(f"Executing LogTool function: {function_name} with args: {kwargs}")
if function_name == "get_log_contents":
@@ -55,7 +53,6 @@ class LogTool(BaseTool):
self.logger.error(error_message)
return error_message
@metrics.measure
def _get_log_contents(self, line_count=None): # Default line_count is None to trigger 24h logic if not specified
self.logger.info(f"Attempting to get log contents from: {self.configured_log_file_path}. Line count: {line_count if line_count is not None else 'Last 24 hours'}")
-79
View File
@@ -1,79 +0,0 @@
# tools/metrics.py
import cProfile
import pstats
import io
from functools import wraps
from collections import defaultdict
import logging
class Metrics:
    """Collect per-function call counts and cumulative wall-clock time.

    Functions are registered by decorating them with `measure`. Statistics
    are keyed by `func.__name__`, so two functions sharing a name (even in
    different modules or classes) are aggregated into the same entry.

    Attributes:
        call_count: Maps function name -> number of calls started.
        total_time: Maps function name -> accumulated seconds spent in calls.
        logger: Logger used for diagnostic output.
    """

    def __init__(self, logger=None):
        """Create an empty collector.

        Args:
            logger: Optional logger. When omitted, the module logger is used.
                A `NullHandler` is attached if the logger has no handlers.
        """
        self.call_count = defaultdict(int)
        self.total_time = defaultdict(float)
        self.logger = logger or logging.getLogger(__name__)
        if not self.logger.handlers:
            self.logger.addHandler(logging.NullHandler())
        self.logger.debug("Metrics instance initialized.")

    def measure(self, func):
        """Decorator that records call count and elapsed time for `func`.

        Timing uses `time.perf_counter()` rather than a per-call
        `cProfile.Profile`. Measured functions routinely call other measured
        functions, and nested profilers conflict with each other (newer
        CPython versions refuse to enable a second active profiler, older
        ones let the inner profiler switch the outer one off). A plain
        monotonic clock nests safely and yields the same "cumulative time of
        this call" figure.

        The elapsed time is recorded in a `finally` block, so calls that
        raise are still accounted for and the exception propagates unchanged.

        Args:
            func: The callable to wrap.

        Returns:
            A wrapper with the same signature and metadata as `func`.
        """
        # Local import keeps this block self-contained; the module-level
        # import list is left untouched.
        import time

        name = func.__name__

        @wraps(func)
        def wrapper(*args, **kwargs):
            self.call_count[name] += 1
            started = time.perf_counter()
            try:
                return func(*args, **kwargs)
            finally:
                elapsed = time.perf_counter() - started
                self.total_time[name] += elapsed
                self.logger.debug(
                    "Measured cumulative time for %s: %.6fs", name, elapsed
                )

        return wrapper

    def get_metrics(self):
        """Return a snapshot of all collected metrics.

        Returns:
            A dict mapping function name to a dict with the keys
            `call_count`, `total_time` (seconds, rounded to 6 places) and
            `average_time` (seconds per call, rounded to 6 places; 0 when
            the call count is 0).
        """
        metrics_data = {}
        for func_name, count in self.call_count.items():
            total_t = self.total_time[func_name]
            metrics_data[func_name] = {
                'call_count': count,
                'total_time': round(total_t, 6),
                'average_time': round(total_t / count, 6) if count > 0 else 0,
            }
        return metrics_data

    def clear_metrics(self):
        """Discard every recorded call count and timing."""
        self.call_count.clear()
        self.total_time.clear()
        self.logger.info("Metrics cleared.")
# Shared module-level collector: other tool modules import `metrics` from
# here and decorate their methods with `@metrics.measure`. It gets its own
# child logger, silenced with a NullHandler unless handlers are configured.
_metrics_instance_logger = logging.getLogger(f"{__name__}.global_instance")
if not _metrics_instance_logger.handlers:
    _metrics_instance_logger.addHandler(logging.NullHandler())
metrics = Metrics(logger=_metrics_instance_logger)
-128
View File
@@ -1,128 +0,0 @@
# tools/metrics_tool.py
from .base_tool import BaseTool
from .metrics import metrics as global_metrics_instance # For default and measuring execute
from .metrics import Metrics # For type hinting and potentially creating a new one if needed
import logging
class MetricsTool(BaseTool):
    """Tool that exposes collected function metrics via the tool-call interface.

    The tool is read-only: it queries a `Metrics` provider (the shared global
    instance unless another one is injected) and returns the data either as
    dicts or, on invalid input, as an error string.
    """

    def __init__(self, metrics_provider: Metrics | None = None, logger: logging.Logger | None = None):
        """Set up the tool.

        Args:
            metrics_provider: Source of metrics. Defaults to the global
                metrics instance imported from `.metrics`.
            logger: Optional logger. Defaults to this module's logger; a
                `NullHandler` is attached if the logger has no handlers.
        """
        self.metrics_provider = metrics_provider if metrics_provider is not None else global_metrics_instance
        self.logger = logger if logger else logging.getLogger(__name__)
        if not self.logger.handlers:
            self.logger.addHandler(logging.NullHandler())
        self.logger.debug(f"MetricsTool initialized. Using metrics provider: {self.metrics_provider}")

    def clear(self):
        """No-op: the tool keeps no state of its own.

        The provider's metrics are deliberately left alone; clearing them
        would require calling `self.metrics_provider.clear_metrics()`.
        """
        self.logger.debug("MetricsTool clear method called. No local state to clear.")

    def get_functions(self):
        """Return the function-call schemas this tool offers.

        Returns:
            A list of three function descriptors: `get_function_metrics`,
            `get_specific_function_metrics` and `get_top_n_functions`.
        """
        return [
            {
                "type": "function",
                "function": {
                    "name": "get_function_metrics",
                    "description": "Get metrics for all measured functions.",
                    "parameters": {
                        "type": "object",
                        "properties": {},
                        "required": []
                    }
                }
            },
            {
                "type": "function",
                "function": {
                    "name": "get_specific_function_metrics",
                    "description": "Get metrics for a specific function.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "function_name": {
                                "type": "string",
                                "description": "Name of the function to get metrics for"
                            }
                        },
                        "required": ["function_name"]
                    }
                }
            },
            {
                "type": "function",
                "function": {
                    "name": "get_top_n_functions",
                    "description": "Get the top N functions by total execution time.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "n": {
                                "type": "integer",
                                "description": "Number of top functions to retrieve"
                            }
                        },
                        "required": ["n"]
                    }
                }
            }
        ]

    @global_metrics_instance.measure  # The execute method can be measured by the global instance
    def execute(self, function_name, **kwargs):
        """Dispatch a tool call to the matching private handler.

        Args:
            function_name: One of the names advertised by `get_functions`.
            **kwargs: Arguments for that function (`function_name` or `n`).

        Returns:
            The handler's result, or an error string when the function is
            unknown or a required argument is missing or invalid.
        """
        self.logger.info(f"Executing MetricsTool function: {function_name} with args: {kwargs}")

        if function_name == "get_function_metrics":
            return self._get_function_metrics()

        if function_name == "get_specific_function_metrics":
            func_name_arg = kwargs.get("function_name")
            # Only None counts as missing; an empty string is passed through.
            if func_name_arg is None:
                self.logger.warning("'function_name' argument is missing for get_specific_function_metrics.")
                return "Error: Missing required argument 'function_name'."
            return self._get_specific_function_metrics(str(func_name_arg))

        if function_name == "get_top_n_functions":
            n_arg = kwargs.get("n")
            if n_arg is None:
                self.logger.warning("'n' argument is missing for get_top_n_functions.")
                return "Error: Missing required argument 'n'."
            # Keep the try limited to the conversion so that errors raised by
            # the provider are not misreported as a bad 'n'. int() raises
            # TypeError (not ValueError) for values such as lists or dicts.
            try:
                n_val = int(n_arg)
            except (TypeError, ValueError):
                self.logger.warning(f"'n' argument must be an integer, got '{n_arg}'.")
                return "Error: Argument 'n' must be an integer."
            if n_val <= 0:
                self.logger.warning(f"'n' argument must be a positive integer, got {n_val}.")
                return "Error: Argument 'n' must be a positive integer."
            return self._get_top_n_functions(n_val)

        error_message = f"Unknown function: {function_name}"
        self.logger.error(error_message)
        return error_message

    def _get_function_metrics(self):
        """Return the provider's metrics for every measured function."""
        self.logger.debug("Calling metrics_provider.get_metrics() for all functions.")
        return self.metrics_provider.get_metrics()

    def _get_specific_function_metrics(self, function_to_get):
        """Return the metrics of one function, or a not-found message string."""
        self.logger.debug(f"Getting metrics for specific function: {function_to_get}")
        all_metrics = self.metrics_provider.get_metrics()
        return all_metrics.get(function_to_get, f"No metrics found for function: {function_to_get}")

    def _get_top_n_functions(self, n):
        """Return the `n` functions with the largest `total_time`.

        Entries that are not dicts containing `total_time` are skipped with a
        warning.

        Returns:
            A dict ordered from highest to lowest total time, or an error
            string if the values cannot be compared.
        """
        self.logger.debug(f"Getting top {n} functions by total execution time.")
        all_metrics = self.metrics_provider.get_metrics()

        valid_metrics_items = []
        for name, metric_values in all_metrics.items():
            if isinstance(metric_values, dict) and 'total_time' in metric_values:
                valid_metrics_items.append((name, metric_values))
            else:
                self.logger.warning(f"Metric item for '{name}' is not in expected format: {metric_values}")

        try:
            sorted_metrics = sorted(valid_metrics_items, key=lambda item: item[1]['total_time'], reverse=True)
        except TypeError as e:
            self.logger.error(f"Error sorting metrics, possibly due to unexpected data types: {e}", exc_info=True)
            return "Error: Could not sort metrics due to unexpected data."
        return dict(sorted_metrics[:n])
@@ -28,7 +28,7 @@ class StandaloneLLMTool(BaseTool):
"model": {
"type": "string",
"description": "The model to use for generating the detailed instructions. Use mini for most coding tasks, preview when needing sophisticated reasoning",
"enum": ["o1-mini", "o1-preview"],
"enum": ["mini", "max"],
"default": "o1-mini"
},
"max_tokens": {
@@ -38,7 +38,8 @@ class StandaloneLLMTool(BaseTool):
},
"required": ["prompt"]
}
}
},
"_tags": ["llm", "external"]
}
]