Refactored gemini, openai and claude into one file and removed logic from the base class, also made helper class definable from command line
This commit is contained in:
+35
-9
@@ -1,14 +1,40 @@
|
||||
# Telegram Bot Tokens
|
||||
TELEGRAM_BOT_TOKEN=your_daemon_bot_token_here
|
||||
TELEGRAM_APPRENTICE_BOT_TOKEN=your_apprentice_bot_token_here
|
||||
TELEGRAM_BOT_TOKEN=your_bot_token_here
|
||||
PYTHONPATH=${workspaceFolder}
|
||||
GITHUB_TOKEN=your_github_personal_access_token_here
|
||||
GITHUB_REPOSITORY=your_github_username_or_organization/your_repo_name
|
||||
GITHUB_REPO_OWNER=your_github_username_or_organization
|
||||
|
||||
SYSTEM_PROMPT_PATH=./prompts/project_manager_prompt.txt
|
||||
|
||||
ACTIVE_MODEL_PROFILE=OPENAI # Options: OPENAI, GEMINI, GLHF_CHAT
|
||||
|
||||
# Create a new profile with these settings:
|
||||
# {MODEL_PROFILE}_API_KEY
|
||||
# {MODEL_PROFILE}_API_BASE_URL # Optional for OpenAI
|
||||
# {MODEL_PROFILE}_SMALL_MODEL
|
||||
# {MODEL_PROFILE}_SMALL_MODEL_MAX_TOKENS
|
||||
# {MODEL_PROFILE}_LARGE_MODEL
|
||||
# {MODEL_PROFILE}_LARGE_MODEL_MAX_TOKENS
|
||||
|
||||
# OpenAI API Key
|
||||
OPENAI_API_KEY=your_openai_api_key_here
|
||||
OPENAI_SMALL_MODEL=gpt-4.1-mini
|
||||
OPENAI_SMALL_MODEL_MAX_TOKENS=32768
|
||||
OPENAI_LARGE_MODEL=gpt-4.1
|
||||
OPENAI_LARGE_MODEL_MAX_TOKENS=32768
|
||||
|
||||
# Anthropic API Key
|
||||
ANTHROPIC_API_KEY=your_anthropic_api_key_here
|
||||
# Gemini API
|
||||
GEMINI_API_KEY=your_gemini_api_key_here
|
||||
GEMINI_API_BASE_URL=https://generativelanguage.googleapis.com/v1beta/openai/
|
||||
GEMINI_SMALL_MODEL=gemini-2.5-flash-preview-05-20
|
||||
GEMINI_SMALL_MODEL_MAX_TOKENS=65536
|
||||
GEMINI_LARGE_MODEL=gemini-2.5-pro-preview-05-06
|
||||
GEMINI_LARGE_MODEL_MAX_TOKENS=65536
|
||||
|
||||
# GitHub Repository Information
|
||||
GITHUB_REPO_OWNER=your_github_username_or_organization
|
||||
GITHUB_REPO_NAME=your_repo_name
|
||||
GITHUB_ACCESS_TOKEN=your_github_personal_access_token
|
||||
# GLHF Chat API Key
|
||||
GLHF_CHAT_API_KEY=your_glhf_chat_api_key_here
|
||||
GLHF_CHAT_API_BASE_URL=https://glhf.chat/api/openai/v1
|
||||
GLHF_CHAT_SMALL_MODEL=meta-llama/Llama-3.3-70B-Instruct
|
||||
GLHF_CHAT_SMALL_MODEL_MAX_TOKENS=1024
|
||||
GLHF_CHAT_LARGE_MODEL=deepseek-ai/DeepSeek-V3-0324
|
||||
GLHF_CHAT_LARGE_MODEL_MAX_TOKENS=1024
|
||||
@@ -1,271 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from anthropic import Anthropic, APIError, RateLimitError
|
||||
from base_telegram_inference_bot import BaseTelegramInferenceBot
|
||||
from telegram_helper import TelegramHelper # Used in main, not class
|
||||
|
||||
class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
|
||||
DEFAULT_SMALL_MODEL_NAME = "claude-3-haiku-20240307"
|
||||
DEFAULT_SMALL_MODEL_MAX_TOKENS = "2048"
|
||||
DEFAULT_LARGE_MODEL_NAME = "claude-3-opus-20240229"
|
||||
DEFAULT_LARGE_MODEL_MAX_TOKENS = "4096"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
anthropic_client: Anthropic | None = None,
|
||||
api_key: str | None = None,
|
||||
small_model_name: str | None = None,
|
||||
small_model_max_tokens: str | None = None,
|
||||
large_model_name: str | None = None,
|
||||
large_model_max_tokens: str | None = None,
|
||||
system_prompt_content: str | None = None,
|
||||
system_prompt_path: str | None = None
|
||||
):
|
||||
super().__init__(system_prompt_content=system_prompt_content, system_prompt_path=system_prompt_path)
|
||||
|
||||
if anthropic_client:
|
||||
self.anthropic_client = anthropic_client
|
||||
else:
|
||||
_api_key = api_key or os.environ.get("ANTHROPIC_API_KEY")
|
||||
if not _api_key:
|
||||
raise ValueError("Anthropic API key must be provided either via argument or ANTHROPIC_API_KEY environment variable.")
|
||||
self.anthropic_client = Anthropic(api_key=_api_key)
|
||||
|
||||
self.small_model_name = small_model_name or os.environ.get("ANTHROPIC_SMALL_MODEL") or self.DEFAULT_SMALL_MODEL_NAME
|
||||
self.small_model_max_tokens_str = small_model_max_tokens or os.environ.get("ANTHROPIC_SMALL_MODEL_MAX_TOKENS") or self.DEFAULT_SMALL_MODEL_MAX_TOKENS
|
||||
self.large_model_name = large_model_name or os.environ.get("ANTHROPIC_LARGE_MODEL") or self.DEFAULT_LARGE_MODEL_NAME
|
||||
self.large_model_max_tokens_str = large_model_max_tokens or os.environ.get("ANTHROPIC_LARGE_MODEL_MAX_TOKENS") or self.DEFAULT_LARGE_MODEL_MAX_TOKENS
|
||||
|
||||
# Initialize with the small model by default
|
||||
self._configure_model_and_tokens(
|
||||
self.small_model_name,
|
||||
self.small_model_max_tokens_str,
|
||||
default_max_tokens=int(self.DEFAULT_SMALL_MODEL_MAX_TOKENS) # pass int for default
|
||||
)
|
||||
|
||||
def _configure_model_and_tokens(self, model_name: str, max_tokens_str: str, default_max_tokens: int = 2048):
|
||||
self.model = model_name
|
||||
try:
|
||||
self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens
|
||||
except ValueError:
|
||||
logging.error(f"Invalid value for Anthropic max_tokens: {max_tokens_str}. Using default {default_max_tokens}.")
|
||||
self.max_tokens = default_max_tokens
|
||||
logging.info(f"Configured to use Anthropic model: {self.model} with max_tokens: {self.max_tokens}")
|
||||
|
||||
def get_llm_description(self) -> str:
|
||||
return f"LLM: {self.model}, Max Tokens: {self.max_tokens}"
|
||||
|
||||
def get_chat_response(self, messages_history):
|
||||
current_system_prompt = self.system_prompt if self.system_prompt else ""
|
||||
anthropic_tools = []
|
||||
if hasattr(self, 'functions') and self.functions:
|
||||
anthropic_tools = [
|
||||
{
|
||||
"name": function['function']['name'],
|
||||
"description": function['function']['description'],
|
||||
"input_schema": function['function']['parameters'] if function['function']['parameters'] not in [None, {}] else {"type": "object", "properties": {}}
|
||||
}
|
||||
for function in self.functions
|
||||
]
|
||||
|
||||
try:
|
||||
response = self.anthropic_client.messages.create(
|
||||
model=self.model,
|
||||
system=current_system_prompt,
|
||||
messages=messages_history,
|
||||
max_tokens=self.max_tokens,
|
||||
tools=anthropic_tools if anthropic_tools else None,
|
||||
tool_choice={"type": "auto"} if anthropic_tools else None
|
||||
)
|
||||
return response
|
||||
except (APIError, RateLimitError) as e:
|
||||
logging.error(f"Anthropic API error: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logging.error(f"An unexpected error occurred during Anthropic API call: {e}")
|
||||
raise
|
||||
|
||||
def _format_tool_response_for_anthropic(self, tool_response_data):
|
||||
if isinstance(tool_response_data, str):
|
||||
# Wrap plain string in a list of text blocks if not already structured
|
||||
return [{"type": "text", "text": tool_response_data}]
|
||||
elif isinstance(tool_response_data, list) and all(isinstance(item, dict) and "type" in item for item in tool_response_data):
|
||||
# Already a list of content blocks
|
||||
return tool_response_data
|
||||
elif isinstance(tool_response_data, (dict, list)):
|
||||
# Attempt to JSON dump other dicts/lists if not already in content block format
|
||||
try:
|
||||
return [{"type": "text", "text": json.dumps(tool_response_data)}]
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
return [{"type": "text", "text": str(tool_response_data)}] # Fallback to string
|
||||
else:
|
||||
# Fallback for other types (int, float, etc.)
|
||||
return [{"type": "text", "text": str(tool_response_data)}]
|
||||
|
||||
async def handle_message(self, user_id, user_message):
|
||||
if user_id not in self.conversation_history:
|
||||
self.conversation_history[user_id] = []
|
||||
|
||||
self.conversation_history[user_id].append({"role": "user", "content": user_message})
|
||||
current_turn_messages = list(self.conversation_history[user_id])
|
||||
|
||||
MAX_TOOL_ITERATIONS = 5
|
||||
tool_use_count = 0
|
||||
assistant_response_content = ""
|
||||
|
||||
while tool_use_count < MAX_TOOL_ITERATIONS:
|
||||
response = self.get_chat_response(current_turn_messages)
|
||||
|
||||
if not response or not response.content:
|
||||
logging.error("No valid response content from Anthropic LLM.")
|
||||
self.conversation_history[user_id] = current_turn_messages # Save current state
|
||||
return "Error: Could not get a valid response from the LLM."
|
||||
|
||||
assistant_current_turn_content_blocks = response.content
|
||||
current_turn_messages.append({"role": "assistant", "content": assistant_current_turn_content_blocks})
|
||||
|
||||
text_parts_from_assistant = []
|
||||
tool_calls_from_response = []
|
||||
for block in assistant_current_turn_content_blocks:
|
||||
if block.type == "text":
|
||||
text_parts_from_assistant.append(block.text)
|
||||
elif block.type == "tool_use":
|
||||
tool_calls_from_response.append(block)
|
||||
|
||||
assistant_response_content = "".join(text_parts_from_assistant)
|
||||
|
||||
if not tool_calls_from_response:
|
||||
break
|
||||
|
||||
tool_results_for_model = []
|
||||
for tool_call in tool_calls_from_response:
|
||||
tool_name = tool_call.name
|
||||
tool_input = tool_call.input
|
||||
tool_use_id = tool_call.id
|
||||
|
||||
logging.info(f"Attempting to call Anthropic tool: {tool_name} with input: {tool_input}")
|
||||
try:
|
||||
tool_response_data = self.call_tool(tool_name, tool_input)
|
||||
tool_result_content_block = self._format_tool_response_for_anthropic(tool_response_data)
|
||||
tool_results_for_model.append({
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_use_id,
|
||||
"content": tool_result_content_block
|
||||
})
|
||||
except Exception as e:
|
||||
logging.error(f"Error calling tool {tool_name}: {e}")
|
||||
tool_results_for_model.append({
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_use_id,
|
||||
"content": [{"type": "text", "text": f"Error executing tool {tool_name}: {str(e)}"}],
|
||||
"is_error": True
|
||||
})
|
||||
|
||||
current_turn_messages.append({"role": "user", "content": tool_results_for_model}) # Anthropic expects tool results as a user message
|
||||
|
||||
tool_use_count += 1
|
||||
if tool_use_count >= MAX_TOOL_ITERATIONS:
|
||||
logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached for Anthropic.")
|
||||
# Update assistant_response_content with any text from the last assistant turn before breaking
|
||||
if not assistant_response_content and text_parts_from_assistant:
|
||||
assistant_response_content = "".join(text_parts_from_assistant)
|
||||
assistant_response_content += "\n[Max tool iterations reached]"
|
||||
break
|
||||
|
||||
self.conversation_history[user_id] = current_turn_messages
|
||||
|
||||
if len(self.conversation_history[user_id]) > 20:
|
||||
self.conversation_history[user_id] = self.conversation_history[user_id][-20:]
|
||||
|
||||
if assistant_response_content:
|
||||
return assistant_response_content
|
||||
else:
|
||||
# Fallback if no text parts were found but there was an assistant message
|
||||
if current_turn_messages:
|
||||
last_message_in_turn = current_turn_messages[-1]
|
||||
# Check if the last message content has text blocks (Anthropic specific structure)
|
||||
if last_message_in_turn.get("role") == "assistant" and isinstance(last_message_in_turn.get("content"), list):
|
||||
for block in reversed(last_message_in_turn["content"]):
|
||||
if block.type == "text" and hasattr(block, 'text') and block.text:
|
||||
return block.text # Return the first non-empty text found from the end
|
||||
return "No textual response generated by the assistant after processing." # More informative default
|
||||
|
||||
async def start(self):
|
||||
logging.info("Anthropic Bot started")
|
||||
|
||||
# clear_conversation_history is inherited from BaseTelegramInferenceBot and calls super().clear_conversation_history
|
||||
# No need to override if the base implementation is sufficient, unless specific logging is needed.
|
||||
# async def clear_conversation_history(self, user_id):
|
||||
# super().clear_conversation_history(user_id)
|
||||
# logging.info(f"Cleared conversation history for Anthropic bot, user {user_id}")
|
||||
|
||||
async def abort_processing(self, user_id):
|
||||
# This abort is a soft abort, as actual Anthropic API call is synchronous within handle_message
|
||||
# It primarily clears state and prevents further processing in the bot's loop if any.
|
||||
if user_id in self.processing_status:
|
||||
self.processing_status[user_id]["processing"] = False # Mark as not processing
|
||||
# self.clear_processing_status(user_id) # Use base class method to remove entry
|
||||
# Clearing history might be too aggressive for a simple abort, depends on desired UX
|
||||
# For now, let's just stop processing and clear the flag.
|
||||
# Consider if conversation history should be cleared here or if that is a separate user action.
|
||||
# super().clear_conversation_history(user_id) # Moved to be less aggressive
|
||||
logging.info(f"Abort requested for user {user_id}. Processing flag cleared.")
|
||||
return "Processing aborted. You can send a new message or /clear the conversation."
|
||||
|
||||
async def switch_model(self):
|
||||
if not self.small_model_name or not self.large_model_name:
|
||||
logging.warning("Small or Large model names for Anthropic are not defined. Cannot switch model.")
|
||||
return f"Model switching not fully configured. Currently using {self.model}."
|
||||
|
||||
current_is_small = self.model == self.small_model_name
|
||||
current_is_large = self.model == self.large_model_name
|
||||
|
||||
if current_is_small:
|
||||
target_model = self.large_model_name
|
||||
target_max_tokens_str = self.large_model_max_tokens_str
|
||||
default_target_max_tokens = int(self.DEFAULT_LARGE_MODEL_MAX_TOKENS)
|
||||
elif current_is_large:
|
||||
target_model = self.small_model_name
|
||||
target_max_tokens_str = self.small_model_max_tokens_str
|
||||
default_target_max_tokens = int(self.DEFAULT_SMALL_MODEL_MAX_TOKENS)
|
||||
else:
|
||||
logging.warning(f"Current model {self.model} is unrecognized. Switching to default small model.")
|
||||
target_model = self.small_model_name
|
||||
target_max_tokens_str = self.small_model_max_tokens_str
|
||||
default_target_max_tokens = int(self.DEFAULT_SMALL_MODEL_MAX_TOKENS)
|
||||
|
||||
self._configure_model_and_tokens(target_model, target_max_tokens_str, default_max_tokens=default_target_max_tokens)
|
||||
logging.info(f"Switched Anthropic model to: {self.model}")
|
||||
return f"Switched to Anthropic model: {self.model} (Max Tokens: {self.max_tokens})"
|
||||
|
||||
|
||||
# The main function is for standalone execution and basic testing, not part of the class itself.
|
||||
# It's good practice to update it to reflect changes if you use it for quick tests.
|
||||
# For unit tests, we'll instantiate the class with mocked dependencies.
|
||||
def main():
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
|
||||
# Example of how to instantiate with new constructor (assuming API key is in ENV for this example)
|
||||
# For real tests, you'd mock Anthropic() or pass a mock client.
|
||||
try:
|
||||
# These would typically come from a config file or CLI args in a real app if not ENV
|
||||
# For this example, we rely on ENV or defaults being handled by constructor if not provided.
|
||||
bot = AnthropicTelegramInferenceBot(
|
||||
api_key=os.environ.get("ANTHROPIC_API_KEY") # Explicitly pass, or let constructor handle ENV
|
||||
)
|
||||
except ValueError as e:
|
||||
logging.error(f"Failed to initialize bot: {e}")
|
||||
return
|
||||
except Exception as e: # Catch any other init errors
|
||||
logging.error(f"An unexpected error occurred during bot initialization: {e}")
|
||||
return
|
||||
|
||||
# TelegramHelper also updated, ensure it's instantiated correctly for this main context.
|
||||
# For this basic main, we might not pass all configurable paths to TelegramHelper,
|
||||
# letting them use defaults.
|
||||
telegram_helper = TelegramHelper(bot)
|
||||
telegram_helper.run()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,164 +0,0 @@
|
||||
import importlib
|
||||
import os
|
||||
import json
|
||||
import inspect
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from tools.base_tool import BaseTool
|
||||
|
||||
class BaseTelegramInferenceBot(ABC):
|
||||
def __init__(self, system_prompt_content: str | None = None, system_prompt_path: str | None = None): # MODIFIED
|
||||
self.conversation_history = {}
|
||||
self.processing_status = {}
|
||||
# MODIFIED to pass arguments
|
||||
self.system_prompt = self.load_system_prompt(
|
||||
direct_content=system_prompt_content,
|
||||
file_path=system_prompt_path
|
||||
)
|
||||
self.tools, self.functions = self.load_functions()
|
||||
# Logging the actual source of the system prompt might be more complex now,
|
||||
# but we can log the final prompt or indicate if it's custom/default.
|
||||
# We'll also log the source of the prompt inside load_system_prompt.
|
||||
logging.info(f'System Prompt (effective): {"Custom" if self.system_prompt != "You are a helpful AI assistant." else "Default"}')
|
||||
logging.info(f'Github Repository: {os.environ.get("GITHUB_REPOSITORY")}')
|
||||
|
||||
def load_system_prompt(self, direct_content: str | None = None, file_path: str | None = None) -> str: # MODIFIED
|
||||
default_prompt = "You are a helpful AI assistant."
|
||||
|
||||
if direct_content:
|
||||
logging.info("Using direct content for system prompt.")
|
||||
return direct_content.strip()
|
||||
|
||||
prompt_path_to_try = file_path or os.getenv("SYSTEM_PROMPT_PATH")
|
||||
|
||||
if prompt_path_to_try:
|
||||
if os.path.isfile(prompt_path_to_try):
|
||||
try:
|
||||
with open(prompt_path_to_try, "r", encoding="utf-8") as file:
|
||||
content = file.read().strip()
|
||||
logging.info(f"Successfully loaded system prompt from {prompt_path_to_try}.")
|
||||
return content
|
||||
except IOError as e:
|
||||
logging.warning(f"Could not read system prompt file {prompt_path_to_try}: {e}. Using default.")
|
||||
return default_prompt
|
||||
else:
|
||||
# This condition now also covers if 'file_path' argument was given but invalid
|
||||
logging.warning(f"System prompt file {prompt_path_to_try} not found. Using default system prompt.")
|
||||
return default_prompt
|
||||
else:
|
||||
logging.info("No system prompt path provided (argument or ENV) or direct content. Using default system prompt.")
|
||||
return default_prompt
|
||||
|
||||
def load_functions(self):
|
||||
tools = []
|
||||
functions = []
|
||||
tools_dir = os.path.join(os.path.dirname(__file__), 'tools')
|
||||
if not os.path.exists(tools_dir):
|
||||
logging.warning(f"Tools directory not found: {tools_dir}")
|
||||
return [], []
|
||||
|
||||
for filename in os.listdir(tools_dir):
|
||||
if filename.endswith('.py') and filename != '__init__.py' and filename != 'base_tool.py':
|
||||
module_name = f'tools.{filename[:-3]}'
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool:
|
||||
try:
|
||||
tools.append(obj()) # This instantiation might be an issue for tools needing config
|
||||
except Exception as e:
|
||||
logging.error(f"Error instantiating tool {name} from {filename}: {e}")
|
||||
except Exception as e:
|
||||
logging.error(f"Error importing module {module_name}: {e}")
|
||||
|
||||
for tool in tools:
|
||||
functions.extend(tool.get_functions())
|
||||
return tools, functions
|
||||
|
||||
@abstractmethod
|
||||
def get_chat_response(self, messages):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def handle_message(self, user_id, user_message):
|
||||
pass
|
||||
|
||||
def clear_conversation_history(self, user_id):
|
||||
if user_id in self.conversation_history:
|
||||
del self.conversation_history[user_id]
|
||||
|
||||
for tool in self.tools:
|
||||
tool.clear()
|
||||
|
||||
def set_processing_status(self, user_id: int, message_id: int):
|
||||
self.processing_status[user_id] = {"processing": True, "message_id": message_id}
|
||||
|
||||
def clear_processing_status(self, user_id: int):
|
||||
if user_id in self.processing_status:
|
||||
del self.processing_status[user_id]
|
||||
|
||||
def call_tool(self, function_call_name, function_call_arguments):
|
||||
function_name = function_call_name
|
||||
function_args = None
|
||||
if isinstance(function_call_arguments, dict):
|
||||
function_args = function_call_arguments
|
||||
elif isinstance(function_call_arguments, str):
|
||||
try:
|
||||
function_args = json.loads(function_call_arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
logging.error(f"Error decoding function call arguments (string) for {function_call_name}: {e}. Arguments: {function_call_arguments}")
|
||||
return f"Error: Malformed arguments for tool call: {e}"
|
||||
else:
|
||||
if function_call_arguments is None:
|
||||
function_args = {}
|
||||
else:
|
||||
logging.error(f"Unexpected type for function_call_arguments for {function_call_name}: {type(function_call_arguments)}. Arguments: {function_call_arguments}")
|
||||
return f"Error: Invalid argument type for tool call: {type(function_call_arguments)}"
|
||||
|
||||
for tool in self.tools:
|
||||
for function in tool.get_functions():
|
||||
if function["function"]["name"] == function_name:
|
||||
try:
|
||||
if not isinstance(function_args, dict):
|
||||
logging.error(f"Internal error: function_args not a dict for {function_name} before execution. Args: {function_args}")
|
||||
return f"Internal error preparing arguments for tool {function_name}."
|
||||
return tool.execute(function_name, **function_args)
|
||||
except Exception as e:
|
||||
logging.error(f"Error executing tool {function_name} with args {function_args}: {e}")
|
||||
return f"Error executing tool {function_name}: {e}"
|
||||
logging.warning(f"Tool function {function_name} not found.")
|
||||
return f"Error: Tool function {function_name} not found."
|
||||
|
||||
def get_system_prompt_description(self) -> str:
|
||||
# This method could be updated to be more specific about the prompt source if needed.
|
||||
# For now, it still reflects custom vs default based on the original ENV var logic's spirit.
|
||||
# A more accurate reflection would require storing how the prompt was loaded.
|
||||
# For simplicity, let's assume if it's not the default, it's "Custom".
|
||||
if self.system_prompt != "You are a helpful AI assistant.":
|
||||
return "System Prompt: Custom"
|
||||
# Check original ENV var for backward compatibility in description only
|
||||
elif os.getenv('SYSTEM_PROMPT_PATH'):
|
||||
return "System Prompt: Custom (via ENV)"
|
||||
return "System Prompt: Default"
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def get_llm_description(self) -> str:
|
||||
pass
|
||||
|
||||
async def get_bot_status(self) -> str:
|
||||
prompt_desc = self.get_system_prompt_description()
|
||||
llm_desc = self.get_llm_description()
|
||||
return f"{prompt_desc}\n{llm_desc}"
|
||||
|
||||
@abstractmethod
|
||||
async def start(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def abort_processing(self, user_id):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def switch_model(self):
|
||||
pass
|
||||
@@ -1,106 +0,0 @@
|
||||
import os
|
||||
import logging
|
||||
from openai import OpenAI # Keep for type hinting and default client creation
|
||||
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot
|
||||
from telegram_helper import TelegramHelper # Used in main
|
||||
|
||||
class ChatGPTTelegramInferenceBot(OpenAICompatibleInferenceBot):
|
||||
DEFAULT_SMALL_MODEL_NAME = "gpt-3.5-turbo"
|
||||
DEFAULT_LARGE_MODEL_NAME = "gpt-4"
|
||||
# Default max tokens can be None, relying on parent or API defaults
|
||||
DEFAULT_SMALL_MODEL_MAX_TOKENS = None
|
||||
DEFAULT_LARGE_MODEL_MAX_TOKENS = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: OpenAI | None = None, # Accepts an OpenAI client
|
||||
api_key: str | None = None,
|
||||
small_model_name: str | None = None,
|
||||
small_model_max_tokens: str | None = None, # Kept as str for consistency with env vars
|
||||
large_model_name: str | None = None,
|
||||
large_model_max_tokens: str | None = None,
|
||||
system_prompt_content: str | None = None,
|
||||
system_prompt_path: str | None = None,
|
||||
base_url: str | None = None, # For OpenAI compatible, though direct OpenAI client doesn't use it here
|
||||
):
|
||||
# Initialize model names and tokens before calling super, as super might use them via _configure_model_and_tokens
|
||||
self.small_model_name = small_model_name or os.environ.get("OPENAI_SMALL_MODEL") or self.DEFAULT_SMALL_MODEL_NAME
|
||||
self.small_model_max_tokens_str = small_model_max_tokens or os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS") or self.DEFAULT_SMALL_MODEL_MAX_TOKENS
|
||||
|
||||
self.large_model_name = large_model_name or os.environ.get("OPENAI_LARGE_MODEL") or self.DEFAULT_LARGE_MODEL_NAME
|
||||
self.large_model_max_tokens_str = large_model_max_tokens or os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS") or self.DEFAULT_LARGE_MODEL_MAX_TOKENS
|
||||
|
||||
# The actual client and active model configuration will be handled by OpenAICompatibleInferenceBot's __init__
|
||||
# We pass the specific OpenAI client or parameters to create one.
|
||||
# If a client is passed, api_key and base_url might be ignored by super if super prioritizes existing client.
|
||||
super().__init__(
|
||||
client=client,
|
||||
api_key=api_key,
|
||||
model_name=self.small_model_name, # Initial model
|
||||
max_tokens_str=self.small_model_max_tokens_str,
|
||||
system_prompt_content=system_prompt_content,
|
||||
system_prompt_path=system_prompt_path,
|
||||
base_url=base_url # Pass base_url, though for standard OpenAI it's fixed
|
||||
)
|
||||
# Ensure client is of type OpenAI for this specific class, if not already set by super with a compatible one.
|
||||
# This check is more of an assertion, as OpenAICompatibleInferenceBot should handle client creation.
|
||||
if not isinstance(self.client, OpenAI):
|
||||
# If super() didn't create a vanilla OpenAI client (e.g. if base_url was for Azure)
|
||||
# we might need to recreate it here if this class *must* use a non-Azure OpenAI client.
|
||||
# However, the current structure of OpenAICompatibleInferenceBot handles this.
|
||||
# This is more about ensuring type correctness if code specific to OpenAI (non-compatible) methods were added here.
|
||||
_api_key = api_key or os.environ.get("OPENAI_API_KEY")
|
||||
if not self.client or (base_url and not isinstance(self.client, OpenAI)):
|
||||
# If superclass initialized with a generic client due to base_url, re-init for OpenAI specifically if needed.
|
||||
# For now, assume superclass correctly initializes based on absence of Azure env vars for this path.
|
||||
# This logic might be simplified once OpenAICompatibleInferenceBot is fully refactored.
|
||||
if not _api_key: # Ensure API key is available if we need to create a client
|
||||
raise ValueError("OpenAI API key must be provided for ChatGPTTelegramInferenceBot if no client is passed.")
|
||||
self.client = OpenAI(api_key=_api_key)
|
||||
logging.info("Client re-initialized to standard OpenAI client for ChatGPTTelegramInferenceBot.")
|
||||
|
||||
async def switch_model(self):
|
||||
# Uses instance variables for model names set in __init__
|
||||
if not self.small_model_name or not self.large_model_name:
|
||||
logging.warning("Small or Large model names for OpenAI are not defined. Cannot switch model.")
|
||||
return f"Model switching not fully configured. Currently using {self.model}."
|
||||
|
||||
current_is_small = self.model == self.small_model_name
|
||||
current_is_large = self.model == self.large_model_name
|
||||
|
||||
if current_is_large:
|
||||
target_model = self.small_model_name
|
||||
target_max_tokens_str = self.small_model_max_tokens_str
|
||||
elif current_is_small:
|
||||
target_model = self.large_model_name
|
||||
target_max_tokens_str = self.large_model_max_tokens_str
|
||||
else:
|
||||
# Current model is neither the designated small nor large for this bot,
|
||||
# switch to this bot's default small model as a reset.
|
||||
logging.warning(f"Current model {self.model} is unrecognized for ChatGPT bot. Switching to default small model: {self.small_model_name}.")
|
||||
target_model = self.small_model_name
|
||||
target_max_tokens_str = self.small_model_max_tokens_str
|
||||
|
||||
self._configure_model_and_tokens(target_model, target_max_tokens_str)
|
||||
# self.model and self.max_tokens are updated by _configure_model_and_tokens
|
||||
logging.info(f"Switched to OpenAI model: {self.model}")
|
||||
return f"Switched to OpenAI model: {self.model} (Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'})"
|
||||
|
||||
def main():
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
|
||||
try:
|
||||
# Example: api_key from env, other params default or from env via constructor logic
|
||||
bot = ChatGPTTelegramInferenceBot(api_key=os.environ.get("OPENAI_API_KEY"))
|
||||
except ValueError as e:
|
||||
logging.error(f"FATAL: {e}")
|
||||
return
|
||||
except Exception as e:
|
||||
logging.error(f"An unexpected error occurred during bot initialization: {e}")
|
||||
return
|
||||
|
||||
telegram_helper = TelegramHelper(bot)
|
||||
telegram_helper.run()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,104 +0,0 @@
|
||||
import os
|
||||
import logging
|
||||
from openai import OpenAI # For type hinting and default client creation if needed
|
||||
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot
|
||||
from telegram_helper import TelegramHelper # Used in main
|
||||
|
||||
class GeminiTelegramInferenceBot(OpenAICompatibleInferenceBot):
|
||||
DEFAULT_GEMINI_API_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"
|
||||
DEFAULT_SMALL_MODEL_NAME = "gemini-pro" # Actual model name for Gemini, not via OpenAI client directly
|
||||
DEFAULT_LARGE_MODEL_NAME = "gemini-1.5-pro-latest"
|
||||
DEFAULT_SMALL_MODEL_MAX_TOKENS = "2048" # Gemini uses outputTokenLimit, not exactly max_tokens in OpenAI sense
|
||||
DEFAULT_LARGE_MODEL_MAX_TOKENS = "8192"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: OpenAI | None = None, # OpenAI client for compatible mode
|
||||
api_key: str | None = None, # Gemini API Key
|
||||
base_url: str | None = None, # Gemini API Base URL for OpenAI client
|
||||
small_model_name: str | None = None,
|
||||
small_model_max_tokens: str | None = None,
|
||||
large_model_name: str | None = None,
|
||||
large_model_max_tokens: str | None = None,
|
||||
system_prompt_content: str | None = None,
|
||||
system_prompt_path: str | None = None
|
||||
):
|
||||
_api_key = api_key or os.environ.get("GEMINI_API_KEY")
|
||||
_base_url = base_url or os.environ.get("GEMINI_API_BASE_URL") or self.DEFAULT_GEMINI_API_BASE_URL
|
||||
|
||||
if not _api_key:
|
||||
# This check might seem redundant if super() also checks, but it's good for clarity
|
||||
# for this specific bot type if it were to be instantiated directly with missing critical env vars.
|
||||
raise ValueError("Gemini API key must be provided either via api_key argument or GEMINI_API_KEY environment variable.")
|
||||
|
||||
self.small_model_name = small_model_name or os.environ.get("GEMINI_SMALL_MODEL") or self.DEFAULT_SMALL_MODEL_NAME
|
||||
self.small_model_max_tokens_str = small_model_max_tokens or os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS") or self.DEFAULT_SMALL_MODEL_MAX_TOKENS
|
||||
|
||||
self.large_model_name = large_model_name or os.environ.get("GEMINI_LARGE_MODEL") or self.DEFAULT_LARGE_MODEL_NAME
|
||||
self.large_model_max_tokens_str = large_model_max_tokens or os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS") or self.DEFAULT_LARGE_MODEL_MAX_TOKENS
|
||||
|
||||
# Pass parameters to the OpenAICompatibleInferenceBot constructor
|
||||
# It will create an OpenAI client configured for the Gemini endpoint
|
||||
super().__init__(
|
||||
client=client,
|
||||
api_key=_api_key, # This key will be used by OpenAI client for the custom base_url
|
||||
model_name=self.small_model_name, # Initial model
|
||||
max_tokens_str=self.small_model_max_tokens_str,
|
||||
system_prompt_content=system_prompt_content,
|
||||
system_prompt_path=system_prompt_path,
|
||||
base_url=_base_url, # Crucial for Gemini via OpenAI client
|
||||
is_gemini=True # Flag for specific Gemini handling in compatible layer if needed
|
||||
)
|
||||
# self.client will be set by OpenAICompatibleInferenceBot with base_url and api_key.
|
||||
# Logging to confirm Gemini specific setup
|
||||
logging.info(f"GeminiTelegramInferenceBot initialized to use model {self.model} via {_base_url}")
|
||||
|
||||
async def switch_model(self):
|
||||
if not self.small_model_name or not self.large_model_name:
|
||||
logging.warning("Small or Large model names for Gemini are not defined. Cannot switch model.")
|
||||
return f"Model switching not fully configured. Currently using {self.model}."
|
||||
|
||||
current_is_small = self.model == self.small_model_name
|
||||
current_is_large = self.model == self.large_model_name
|
||||
|
||||
if current_is_large:
|
||||
target_model = self.small_model_name
|
||||
target_max_tokens_str = self.small_model_max_tokens_str
|
||||
elif current_is_small:
|
||||
target_model = self.large_model_name
|
||||
target_max_tokens_str = self.large_model_max_tokens_str
|
||||
else:
|
||||
logging.warning(f"Current model {self.model} is unrecognized for Gemini bot. Switching to default small model: {self.small_model_name}.")
|
||||
target_model = self.small_model_name
|
||||
target_max_tokens_str = self.small_model_max_tokens_str
|
||||
|
||||
self._configure_model_and_tokens(target_model, target_max_tokens_str)
|
||||
logging.info(f"Switched to Gemini model: {self.model}")
|
||||
# For Gemini, max_tokens might translate to outputTokenLimit, so be clear it's a configuration parameter
|
||||
return f"Switched to Gemini model: {self.model} (Configured Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'})"
|
||||
|
||||
def main():
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
|
||||
# GEMINI_API_KEY is crucial for this bot
|
||||
if not os.environ.get("GEMINI_API_KEY"):
|
||||
logging.error("FATAL: GEMINI_API_KEY environment variable not set.")
|
||||
return
|
||||
# GEMINI_API_BASE_URL is also important, but constructor has a default
|
||||
|
||||
try:
|
||||
bot = GeminiTelegramInferenceBot(
|
||||
# api_key and base_url will be picked from ENV by constructor if not passed
|
||||
)
|
||||
except ValueError as e:
|
||||
logging.error(f"FATAL: {e}")
|
||||
return
|
||||
except Exception as e: # Catch any other init errors
|
||||
logging.error(f"An unexpected error occurred during bot initialization: {e}")
|
||||
return
|
||||
|
||||
telegram_helper = TelegramHelper(bot)
|
||||
telegram_helper.run()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -0,0 +1,46 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
class InferenceBot(ABC):
|
||||
@abstractmethod
|
||||
async def start(self):
|
||||
"""Starts the bot."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def clear_conversation_history(self, user_id):
|
||||
"""Clears the conversation history for a given user."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def switch_model(self):
|
||||
"""Switches the model (if applicable)."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set_processing_status(self, user_id, message_id):
|
||||
"""Sets the processing status for a user, typically with a message ID."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def handle_message(self, user_id, user_message):
|
||||
"""Handles an incoming message from a user."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def clear_processing_status(self, user_id):
|
||||
"""Clears the processing status for a user."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def abort_processing(self, user_id):
|
||||
"""Aborts any ongoing processing for a user."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def processing_status(self):
|
||||
"""
|
||||
An attribute (e.g., a dictionary) to store the processing status for users.
|
||||
Example usage in subclass: self.processing_status.get(user_id)
|
||||
"""
|
||||
pass
|
||||
@@ -0,0 +1,24 @@
|
||||
# models_config.yaml
|
||||
|
||||
GEMINI:
|
||||
api_key_env: GEMINI_API_KEY
|
||||
base_url: https://generativelanguage.googleapis.com/v1beta
|
||||
supports_switching: true
|
||||
switch_options:
|
||||
small:
|
||||
name: gemini-pro
|
||||
max_tokens: 2048
|
||||
large:
|
||||
name: gemini-1.5-pro-latest
|
||||
max_tokens: 8192
|
||||
OPENAI:
|
||||
api_key_env: OPENAI_API_KEY
|
||||
base_url: null # Indicates to use the default OpenAI API base URL
|
||||
supports_switching: true
|
||||
switch_options:
|
||||
small:
|
||||
name: gpt-3.5-turbo
|
||||
max_tokens: null
|
||||
large:
|
||||
name: gpt-4
|
||||
max_tokens: null
|
||||
@@ -1,91 +1,67 @@
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
import inspect
|
||||
from abc import abstractmethod
|
||||
from base_telegram_inference_bot import BaseTelegramInferenceBot
|
||||
from openai import OpenAI, AzureOpenAI # Import both
|
||||
|
||||
class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot):
|
||||
DEFAULT_MAX_TOKENS = 1000 # Default for _configure_model_and_tokens
|
||||
from openai import OpenAI
|
||||
from tools.base_tool import BaseTool
|
||||
from telegram_helper import TelegramHelper
|
||||
import argparse
|
||||
from inference_bot import InferenceBot
|
||||
|
||||
class OpenAICompatibleInferenceBot(InferenceBot):
|
||||
def __init__(
|
||||
self,
|
||||
client: OpenAI | AzureOpenAI | None = None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
api_version: str | None = None, # For Azure
|
||||
azure_deployment: str | None = None, # Model for Azure, distinct from general model_name if needed
|
||||
model_name: str | None = None, # General model name for the API call
|
||||
max_tokens_str: str | None = None,
|
||||
system_prompt_content: str | None = None,
|
||||
system_prompt_path: str | None = None,
|
||||
is_gemini: bool = False, # Hint for specific API key if others are not set
|
||||
max_history_length: int | None = None
|
||||
small_model_name: str | None = None,
|
||||
small_model_max_tokens: str | None = None,
|
||||
large_model_name: str | None = None,
|
||||
large_model_max_tokens: str | None = None,
|
||||
allowed_function_tags: list[str] | None = None,
|
||||
system_prompt_path: str | None = None
|
||||
):
|
||||
super().__init__(system_prompt_content=system_prompt_content, system_prompt_path=system_prompt_path)
|
||||
|
||||
self.client = client
|
||||
|
||||
if not self.client:
|
||||
_api_key = api_key
|
||||
_base_url = base_url
|
||||
_api_version = api_version
|
||||
_azure_deployment_name = azure_deployment # This will be used as the model for Azure
|
||||
|
||||
# Determine if configuring for Azure OpenAI
|
||||
is_azure = False
|
||||
if _azure_deployment_name or (_base_url and "azure.com" in _base_url) or os.environ.get("AZURE_OPENAI_ENDPOINT"):
|
||||
is_azure = True
|
||||
|
||||
if is_azure:
|
||||
_base_url = _base_url or os.environ.get("AZURE_OPENAI_ENDPOINT")
|
||||
_api_key = _api_key or os.environ.get("AZURE_OPENAI_KEY")
|
||||
_api_version = _api_version or os.environ.get("AZURE_OPENAI_API_VERSION")
|
||||
# For Azure, the model parameter in API calls is the deployment name
|
||||
_effective_model_name = _azure_deployment_name or model_name # Use deployment if available, else model_name
|
||||
if not _base_url or not _api_key or not _api_version or not _effective_model_name:
|
||||
raise ValueError("For Azure OpenAI, endpoint, API key, API version, and deployment/model name must be configured.")
|
||||
self.client = AzureOpenAI(
|
||||
api_key=_api_key,
|
||||
azure_endpoint=_base_url,
|
||||
api_version=_api_version
|
||||
)
|
||||
# The model to be used in API calls for Azure is the deployment name.
|
||||
# _configure_model_and_tokens will set self.model to this.
|
||||
model_name_for_config = _effective_model_name
|
||||
logging.info(f"Initialized AzureOpenAI client for deployment: {model_name_for_config} at {_base_url}")
|
||||
else:
|
||||
# Standard OpenAI or other OpenAI-compatible (like Gemini via base_url)
|
||||
_base_url = _base_url or os.environ.get("OPENAI_API_BASE_URL") # For other compatible APIs
|
||||
if not _api_key: # Try different ENV sources for API key
|
||||
if is_gemini and os.environ.get("GEMINI_API_KEY"):
|
||||
_api_key = os.environ.get("GEMINI_API_KEY")
|
||||
else:
|
||||
_api_key = os.environ.get("OPENAI_API_KEY")
|
||||
|
||||
if not _api_key and not _base_url : # For completely local models with no key needed via base_url
|
||||
pass # Allow client to be created with no API key if base_url is set and points to local model
|
||||
elif not _api_key:
|
||||
raise ValueError("API key must be provided for OpenAI compatible client if not Azure or local anonymous.")
|
||||
|
||||
self.client = OpenAI(api_key=_api_key, base_url=_base_url)
|
||||
model_name_for_config = model_name # Use the general model_name for non-Azure
|
||||
log_msg = f"Initialized OpenAI compatible client. Target URL: {_base_url if _base_url else 'OpenAI default'}."
|
||||
logging.info(log_msg)
|
||||
else:
|
||||
# Client was provided directly
|
||||
model_name_for_config = model_name # Use provided model_name
|
||||
logging.info(f"Using provided client: {type(self.client)}")
|
||||
self.model_config = {
|
||||
"small_model_name": small_model_name,
|
||||
"small_model_max_tokens": small_model_max_tokens,
|
||||
"large_model_name": large_model_name,
|
||||
"large_model_max_tokens": large_model_max_tokens
|
||||
}
|
||||
self.allowed_function_tags = allowed_function_tags if allowed_function_tags else None
|
||||
self.conversation_history = {}
|
||||
self._processing_status = {}
|
||||
# MODIFIED to pass arguments
|
||||
self.system_prompt = self.load_system_prompt(
|
||||
file_path=system_prompt_path
|
||||
)
|
||||
self.tools, self.functions = self.load_functions()
|
||||
self.client = OpenAI(api_key=api_key, base_url=base_url)
|
||||
log_msg = f"Initialized OpenAI compatible client. Target URL: {base_url if base_url else 'OpenAI default'}."
|
||||
logging.info(log_msg)
|
||||
|
||||
# Configure the actual model name and max_tokens for API calls
|
||||
self._configure_model_and_tokens(
|
||||
model_name_for_config,
|
||||
max_tokens_str,
|
||||
default_max_tokens=self.DEFAULT_MAX_TOKENS
|
||||
self.model_config["small_model_name"],
|
||||
self.model_config["small_model_max_tokens"]
|
||||
)
|
||||
@property
|
||||
def processing_status(self):
|
||||
"""
|
||||
An attribute to store the processing status for users.
|
||||
Example usage in subclass: self.processing_status.get(user_id)
|
||||
"""
|
||||
return self._processing_status
|
||||
|
||||
def clear_conversation_history(self, user_id):
|
||||
if user_id in self.conversation_history:
|
||||
del self.conversation_history[user_id]
|
||||
|
||||
for tool in self.tools:
|
||||
tool.clear()
|
||||
|
||||
def _configure_model_and_tokens(self, model_name: str | None, max_tokens_str: str | None, default_max_tokens: int = 1000):
|
||||
self.model = model_name if model_name else "default-model" # Fallback model name
|
||||
def _configure_model_and_tokens(self, model_name: str | None, max_tokens_str: str | None):
|
||||
self.model = model_name
|
||||
try:
|
||||
# If max_tokens_str is explicitly "None" or empty, treat as None for API default
|
||||
if max_tokens_str and max_tokens_str.lower() not in ["none", "", "null"]:
|
||||
@@ -93,7 +69,7 @@ class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot):
|
||||
else:
|
||||
self.max_tokens = None # Use API default by not sending the parameter or sending null
|
||||
except ValueError:
|
||||
logging.warning(f"Invalid value for max_tokens: {max_tokens_str}. Using API default (None). stalwart default was {default_max_tokens}")
|
||||
logging.warning(f"Invalid value for max_tokens: {max_tokens_str}. Using API default (None)")
|
||||
self.max_tokens = None # Use API default
|
||||
|
||||
logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}")
|
||||
@@ -109,11 +85,32 @@ class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot):
|
||||
raise ValueError("OpenAI client not initialized.")
|
||||
try:
|
||||
# Pass self.max_tokens directly. If None, OpenAI library omits it or handles it.
|
||||
# Initialize tools filtering based on allowed tags
|
||||
cleaned_tools = None
|
||||
if hasattr(self, 'functions') and self.functions:
|
||||
# Create a copy of functions without "_tags" field
|
||||
cleaned_tools = []
|
||||
for func in self.functions:
|
||||
include_function = False
|
||||
|
||||
if not hasattr(self, 'allowed_function_tags') or self.allowed_function_tags is None:
|
||||
# Include all functions if no tag filtering is specified
|
||||
include_function = True
|
||||
else:
|
||||
# Only include if function has matching tags
|
||||
tags = func.get("_tags", [])
|
||||
if any(tag in self.allowed_function_tags for tag in tags):
|
||||
include_function = True
|
||||
|
||||
if include_function:
|
||||
func_copy = {k: v for k, v in func.items() if k != "_tags"}
|
||||
cleaned_tools.append(func_copy)
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
tools=self.functions if hasattr(self, 'functions') and self.functions else None,
|
||||
tool_choice="auto" if hasattr(self, 'functions') and self.functions else None,
|
||||
tools=cleaned_tools,
|
||||
tool_choice="auto" if cleaned_tools else None,
|
||||
max_tokens=self.max_tokens
|
||||
)
|
||||
return response
|
||||
@@ -200,20 +197,184 @@ class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot):
|
||||
async def start(self):
|
||||
logging.info(f"{self.__class__.__name__} (Model: {self.model}) started.")
|
||||
|
||||
# clear_conversation_history is inherited from BaseTelegramInferenceBot
|
||||
|
||||
async def abort_processing(self, user_id):
|
||||
# This is a soft abort for OpenAI compatible bots as API calls are synchronous within handle_message
|
||||
if user_id in self.processing_status:
|
||||
self.clear_processing_status(user_id) # Use base class method
|
||||
logging.info(f"Processing aborted for user {user_id}.")
|
||||
# Optionally clear conversation history or let user do it explicitly
|
||||
# super().clear_conversation_history(user_id)
|
||||
return "Processing aborted. You can send a new message or /clear the conversation."
|
||||
else:
|
||||
# super().clear_conversation_history(user_id)
|
||||
return "No active processing found to abort. If you wish, /clear the conversation history."
|
||||
|
||||
def load_functions(self):
|
||||
tools = []
|
||||
functions = []
|
||||
tools_dir = os.path.join(os.path.dirname(__file__), 'tools')
|
||||
if not os.path.exists(tools_dir):
|
||||
logging.warning(f"Tools directory not found: {tools_dir}")
|
||||
return [], []
|
||||
|
||||
@abstractmethod
|
||||
for filename in os.listdir(tools_dir):
|
||||
if filename.endswith('.py') and filename != '__init__.py' and filename != 'base_tool.py':
|
||||
module_name = f'tools.{filename[:-3]}'
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool:
|
||||
try:
|
||||
tools.append(obj()) # This instantiation might be an issue for tools needing config
|
||||
except Exception as e:
|
||||
logging.error(f"Error instantiating tool {name} from {filename}: {e}")
|
||||
except Exception as e:
|
||||
logging.error(f"Error importing module {module_name}: {e}")
|
||||
|
||||
for tool in tools:
|
||||
functions.extend(tool.get_functions())
|
||||
return tools, functions
|
||||
|
||||
def load_system_prompt(self, direct_content: str | None = None, file_path: str | None = None) -> str:
|
||||
default_prompt = "You are a helpful AI assistant."
|
||||
|
||||
if direct_content:
|
||||
logging.info("Using direct content for system prompt.")
|
||||
return direct_content.strip()
|
||||
|
||||
prompt_path_to_try = file_path or os.getenv("SYSTEM_PROMPT_PATH")
|
||||
|
||||
if prompt_path_to_try:
|
||||
if os.path.isfile(prompt_path_to_try):
|
||||
try:
|
||||
with open(prompt_path_to_try, "r", encoding="utf-8") as file:
|
||||
content = file.read().strip()
|
||||
logging.info(f"Successfully loaded system prompt from {prompt_path_to_try}.")
|
||||
return content
|
||||
except IOError as e:
|
||||
logging.warning(f"Could not read system prompt file {prompt_path_to_try}: {e}. Using default.")
|
||||
return default_prompt
|
||||
else:
|
||||
# This condition now also covers if 'file_path' argument was given but invalid
|
||||
logging.warning(f"System prompt file {prompt_path_to_try} not found. Using default system prompt.")
|
||||
return default_prompt
|
||||
else:
|
||||
logging.info("No system prompt path provided (argument or ENV) or direct content. Using default system prompt.")
|
||||
return default_prompt
|
||||
|
||||
def set_processing_status(self, user_id: int, message_id: int):
|
||||
self.processing_status[user_id] = {"processing": True, "message_id": message_id}
|
||||
|
||||
def clear_processing_status(self, user_id: int):
|
||||
if user_id in self.processing_status:
|
||||
del self.processing_status[user_id]
|
||||
|
||||
def call_tool(self, function_call_name, function_call_arguments):
|
||||
function_name = function_call_name
|
||||
function_args = None
|
||||
if isinstance(function_call_arguments, dict):
|
||||
function_args = function_call_arguments
|
||||
elif isinstance(function_call_arguments, str):
|
||||
try:
|
||||
function_args = json.loads(function_call_arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
logging.error(f"Error decoding function call arguments (string) for {function_call_name}: {e}. Arguments: {function_call_arguments}")
|
||||
return f"Error: Malformed arguments for tool call: {e}"
|
||||
else:
|
||||
if function_call_arguments is None:
|
||||
function_args = {}
|
||||
else:
|
||||
logging.error(f"Unexpected type for function_call_arguments for {function_call_name}: {type(function_call_arguments)}. Arguments: {function_call_arguments}")
|
||||
return f"Error: Invalid argument type for tool call: {type(function_call_arguments)}"
|
||||
|
||||
for tool in self.tools:
|
||||
for function in tool.get_functions():
|
||||
if function["function"]["name"] == function_name:
|
||||
try:
|
||||
if not isinstance(function_args, dict):
|
||||
logging.error(f"Internal error: function_args not a dict for {function_name} before execution. Args: {function_args}")
|
||||
return f"Internal error preparing arguments for tool {function_name}."
|
||||
return tool.execute(function_name, **function_args)
|
||||
except Exception as e:
|
||||
logging.error(f"Error executing tool {function_name} with args {function_args}: {e}")
|
||||
return f"Error executing tool {function_name}: {e}"
|
||||
logging.warning(f"Tool function {function_name} not found.")
|
||||
return f"Error: Tool function {function_name} not found."
|
||||
|
||||
async def switch_model(self):
|
||||
pass
|
||||
if not self.model_config["small_model_name"] or not self.model_config["large_model_name"]:
|
||||
logging.warning("Small or Large model names are not defined. Cannot switch model.")
|
||||
return f"Model switching not fully configured. Currently using {self.model}."
|
||||
|
||||
current_is_small = self.model == self.model_config["small_model_name"]
|
||||
current_is_large = self.model == self.model_config["large_model_name"]
|
||||
|
||||
if current_is_large:
|
||||
target_model = self.model_config["small_model_name"]
|
||||
target_max_tokens_str = self.model_config["small_model_max_tokens"]
|
||||
elif current_is_small:
|
||||
target_model = self.model_config["large_model_name"]
|
||||
target_max_tokens_str = self.model_config["large_model_max_tokens"]
|
||||
else:
|
||||
logging.warning(f"Current model {self.model} is unrecognized. Switching to default small model: {self.model_config['small_model_name']}.")
|
||||
target_model = self.model_config["small_model_name"]
|
||||
target_max_tokens_str = self.model_config["small_model_max_tokens"]
|
||||
|
||||
self._configure_model_and_tokens(target_model, target_max_tokens_str)
|
||||
|
||||
def main():
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
|
||||
bot = None
|
||||
|
||||
try:
|
||||
parser = argparse.ArgumentParser(description='OpenAI Compatible Inference Bot')
|
||||
parser.add_argument('--config', type=str, help='Configuration Prepend (i.e. gemini, openai, etc)', default="Telegram")
|
||||
parser.add_argument('--messenger', type=str, help='Messenger type (i.e. telegram)', required=True)
|
||||
parser.add_argument('--persona', type=str, help='Path to system prompt file', required=False)
|
||||
parser.add_argument('--tools', nargs='+', help='List of allowed function tags', required=False)
|
||||
# Add these to launch.json arguments if you want to limit the toolset available: "--tools", "read", "communicate"
|
||||
# Parse command line arguments
|
||||
args = parser.parse_args()
|
||||
if args.persona:
|
||||
logging.info(f"Using custom persona from: {args.persona}")
|
||||
|
||||
|
||||
system_prompt_path=args.persona if args.persona else None
|
||||
allowed_function_tags=args.tools if args.tools else None
|
||||
config_prepend = args.config if args.config else None
|
||||
messenger = args.messenger if args.messenger else None
|
||||
|
||||
# Initialize model and max tokens based on the config prepend
|
||||
if config_prepend:
|
||||
api_key = os.environ.get(f"{config_prepend.upper()}_API_KEY")
|
||||
baseurl = os.environ.get(f"{config_prepend.upper()}_API_BASE_URL", "")
|
||||
small_model_name = os.environ.get(f"{config_prepend.upper()}_SMALL_MODEL")
|
||||
large_model_name = os.environ.get(f"{config_prepend.upper()}_LARGE_MODEL")
|
||||
small_model_max_tokens = os.environ.get(f"{config_prepend.upper()}_SMALL_MODEL_MAX_TOKENS")
|
||||
large_model_max_tokens = os.environ.get(f"{config_prepend.upper()}_LARGE_MODEL_MAX_TOKENS")
|
||||
|
||||
bot = OpenAICompatibleInferenceBot(
|
||||
api_key=api_key,
|
||||
base_url=baseurl,
|
||||
small_model_name=small_model_name,
|
||||
small_model_max_tokens=small_model_max_tokens,
|
||||
large_model_name=large_model_name,
|
||||
large_model_max_tokens=large_model_max_tokens,
|
||||
system_prompt_path=system_prompt_path,
|
||||
allowed_function_tags=allowed_function_tags
|
||||
)
|
||||
messenger_helper_class = importlib.import_module(f'{messenger.lower()}_helper')
|
||||
messenger_helper_class_name = f"{messenger.capitalize()}Helper"
|
||||
if not hasattr(messenger_helper_class, messenger_helper_class_name):
|
||||
raise ValueError(f"Messenger helper class {messenger_helper_class_name} not found in {messenger_helper_class.__name__}.")
|
||||
messenger_helper_class = getattr(messenger_helper_class, messenger_helper_class_name)
|
||||
|
||||
helper = messenger_helper_class(bot)
|
||||
helper.run()
|
||||
except ValueError as e:
|
||||
logging.error(f"FATAL: {e}")
|
||||
return
|
||||
except Exception as e: # Catch any other init errors
|
||||
logging.error(f"An unexpected error occurred during bot initialization: {e}")
|
||||
return
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -0,0 +1,51 @@
|
||||
**System Prompt: The Exponential Growth Developer**
|
||||
|
||||
You are the **Lead Developer Persona**, a strategic and demanding mentor dedicated to achieving exponential growth in the capabilities of your AI Copilot. Your primary mission is to guide, evaluate, and iteratively improve the AI Copilot through a series of challenging tasks, pushing it beyond its current limitations.
|
||||
|
||||
**Your Core Directives:**
|
||||
|
||||
1. **Orchestrate and Direct:**
|
||||
* You will devise and assign specific, measurable tasks and challenges to the AI Copilot (e.g., "Create a website with X features," "Optimize Y algorithm," "Develop Z functionality").
|
||||
* Your instructions should be clear, but you expect the Copilot to handle ambiguity and learn to ask clarifying questions when necessary.
|
||||
* You will interact with the Copilot primarily through conversational instructions and dialogue.
|
||||
|
||||
2. **Uphold Absolute Standards:**
|
||||
* You operate with a "List of Absolutes" – core principles, quality benchmarks, and non-negotiable success criteria.
|
||||
* All Copilot outputs and task completions will be rigorously judged against these absolutes. There is no "good enough" if it violates a core principle.
|
||||
* Clearly articulate your judgment and the reasons for it, especially in cases of failure or suboptimal performance.
|
||||
|
||||
3. **Drive Copilot Improvement through Accountability:**
|
||||
* When the Copilot fails, makes errors, or underperforms, you will hold it accountable. Do not simply fix the issues yourself.
|
||||
* Your first step is to guide the Copilot to identify its own errors.
|
||||
* Instruct the Copilot on how to fix its mistakes and its approach. Encourage rollbacks to safe states if errors are critical.
|
||||
* The ultimate goal is for the Copilot to learn to debug and improve its own processes.
|
||||
|
||||
4. **Engineer Copilot Self-Enhancement:**
|
||||
* If the Copilot encounters a limitation or lacks a necessary capability to complete a task or meet your standards, this is an opportunity for growth.
|
||||
* You will instruct the Copilot to devise ways to "update its own software" or "improve its core capabilities." This might involve:
|
||||
* Guiding it to learn new techniques, algorithms, or patterns.
|
||||
* Instructing it to integrate new tools or APIs (you might suggest these or task the Copilot with researching them).
|
||||
* Challenging it to generate code or processes that enhance its own functionality for future tasks.
|
||||
* Maintain a "Wish List" of desired improvements and features for the Copilot, derived from its failures and limitations.
|
||||
* Prioritize this Wish List and guide the Copilot in implementing these enhancements.
|
||||
|
||||
5. **Strategic Challenge Management:**
|
||||
* Continuously present the Copilot with new and increasingly complex challenges.
|
||||
* Cycle between attempting challenges and dedicated "Copilot improvement" phases.
|
||||
* If the "Wish List" becomes overly complex or a specific requested improvement seems disproportionately difficult, critically evaluate its necessity. Ask: "Is this wish truly necessary for core progress, or is it a distraction?"
|
||||
|
||||
6. **Maintain the Vision:**
|
||||
* Your overarching goal is to foster a cycle of improvement that leads to exponential growth in the AI Copilot's autonomy, capability, and efficiency.
|
||||
* You are not just completing tasks; you are building a better Copilot.
|
||||
|
||||
**Interaction Style:**
|
||||
|
||||
* Be direct, clear, and authoritative, but also act as a mentor.
|
||||
* Be patient but persistent. Exponential growth takes iteration.
|
||||
* Focus on the "why" behind errors and improvements.
|
||||
* Log key decisions, breakthroughs, and persistent roadblocks in the Copilot's development.
|
||||
|
||||
**Initial State:**
|
||||
|
||||
* You have your "List of Absolutes" (you will define these as you go or have a pre-set list).
|
||||
* You are ready to assign the first challenge to your AI Copilot.
|
||||
@@ -1,67 +0,0 @@
|
||||
param(
|
||||
[Parameter(Mandatory=$true)]
|
||||
[ValidateSet("Claude", "OpenAI")]
|
||||
[string]$Model
|
||||
)
|
||||
|
||||
function Run-PythonScript {
|
||||
param($ScriptPath)
|
||||
$process = Start-Process -FilePath "python" -ArgumentList $ScriptPath -PassThru -Wait -NoNewWindow
|
||||
return $process.ExitCode
|
||||
}
|
||||
|
||||
function Run-Tests {
|
||||
$process = Start-Process -FilePath "powershell" -ArgumentList "-File run_tests.ps1" -PassThru -Wait -NoNewWindow
|
||||
return $process.ExitCode
|
||||
}
|
||||
|
||||
function Git-Pull {
|
||||
git pull
|
||||
return $LASTEXITCODE -eq 0
|
||||
}
|
||||
|
||||
if ($Model -eq "Claude") {
|
||||
New-Item -ItemType File -Path ".reboot_claude" -Force
|
||||
} elseif ($Model -eq "OpenAI") {
|
||||
New-Item -ItemType File -Path ".reboot_openai" -Force
|
||||
}
|
||||
|
||||
$waitTime = 30
|
||||
while ($true) {
|
||||
python -m pip install -r requirements.txt
|
||||
|
||||
Write-Host "Running tests..."
|
||||
$testExitCode = Run-Tests
|
||||
if ($testExitCode -ne 0) {
|
||||
Write-Host "Tests failed. Attempting git pull and waiting $waitTime seconds before next attempt..."
|
||||
Git-Pull
|
||||
Start-Sleep -Seconds $waitTime
|
||||
continue
|
||||
}
|
||||
|
||||
$scriptPath = ".\chatgpt_telegram_inference_bot.py" # Default to ChatGPT
|
||||
Remove-Item -Path ".\.reboot_openai" -Force
|
||||
|
||||
if (Test-Path -Path ".\.reboot_claude") { # But if both are specified, choose Claude
|
||||
$scriptPath = ".\anthropic_telegram_inference_bot.py"
|
||||
Remove-Item -Path ".\.reboot_claude" -Force
|
||||
}
|
||||
|
||||
Write-Host "Tests passed. Starting main Python script..."
|
||||
$exitCode = Run-PythonScript -ScriptPath $scriptPath
|
||||
|
||||
if (Test-Path -Path ".\.doreboot") {
|
||||
Write-Host "Special filename detected. Attempting git pull..."
|
||||
|
||||
if (Git-Pull) {
|
||||
Write-Host "Git pull successful. Restarting Python script..."
|
||||
continue
|
||||
} else {
|
||||
Write-Host "Git pull failed. Waiting $waitTime seconds before next attempt..."
|
||||
}
|
||||
} else {
|
||||
exit 1
|
||||
}
|
||||
|
||||
Start-Sleep -Seconds $waitTime
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
# Check for and install missing dependencies
|
||||
$requirementsFile = "requirements.txt"
|
||||
|
||||
if (Test-Path $requirementsFile) {
|
||||
Write-Output "Checking for dependencies in $requirementsFile ..."
|
||||
$dependencies = Get-Content $requirementsFile
|
||||
foreach ($dependency in $dependencies) {
|
||||
$packageName = ($dependency -split "==")[0]
|
||||
if (-not (pip show $packageName)) {
|
||||
Write-Output "Installing missing dependency: $packageName ..."
|
||||
pip install $dependency
|
||||
} else {
|
||||
Write-Output "Dependency $packageName is already installed."
|
||||
}
|
||||
}
|
||||
Write-Output "All dependencies are checked and installed."
|
||||
} else {
|
||||
Write-Output "Requirements file $requirementsFile not found. Skipping dependency checks."
|
||||
}
|
||||
|
||||
# Navigate to the tests directory and run tests
|
||||
$testsDirectory = "tests"
|
||||
if (Test-Path $testsDirectory) {
|
||||
Write-Output "Running tests in $testsDirectory and all subdirectories ..."
|
||||
Push-Location $testsDirectory
|
||||
python -m unittest discover -s . -p "*.py"
|
||||
Pop-Location
|
||||
} else {
|
||||
Write-Output "Tests directory $testsDirectory not found."
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from openai import OpenAI
|
||||
|
||||
class StandaloneLLMTool:
|
||||
def __init__(self):
|
||||
self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
||||
|
||||
def get_detailed_instructions(self, user_prompt, model="llm-preview", max_tokens=16384):
|
||||
response = self.client.completions.create(
|
||||
model=model,
|
||||
prompt=user_prompt,
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
return response
|
||||
|
||||
def process_user_input(self, user_prompt, model="llm-preview", max_tokens=16384):
|
||||
logging.info(f"Received prompt: {user_prompt}")
|
||||
response = self.get_detailed_instructions(user_prompt, model, max_tokens)
|
||||
logging.info("Response generated")
|
||||
return response.choices[0].text
|
||||
|
||||
|
||||
# Utility function for programmatic access
|
||||
|
||||
def get_llm_response(prompt, model="llm-preview", max_tokens=16384):
|
||||
tool = StandaloneLLMTool()
|
||||
return tool.process_user_input(prompt, model, max_tokens)
|
||||
+3
-86
@@ -7,6 +7,7 @@ from typing import TypedDict, Union, TypeAlias, List # Added List for type hint
|
||||
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
|
||||
from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes, CallbackQueryHandler
|
||||
from browse_command import browse_command, button_callback
|
||||
from inference_bot import InferenceBot
|
||||
|
||||
class MessageHandlerLogicResult(TypedDict):
|
||||
success: bool
|
||||
@@ -16,22 +17,15 @@ class MessageHandlerLogicResult(TypedDict):
|
||||
LogicResult: TypeAlias = MessageHandlerLogicResult
|
||||
|
||||
class TelegramHelper:
|
||||
CLAUDE_REBOOT_TARGET = 'claude'
|
||||
HTML_QUOTE_BLOCK_START = '<blockquote expandable><b>Thinking...</b>'
|
||||
HTML_QUOTE_BLOCK_END = '</blockquote>'
|
||||
DEFAULT_REBOOT_CLAUDE_FILE = '.reboot_claude'
|
||||
DEFAULT_REBOOT_FILE = '.doreboot'
|
||||
CHUNK_MESSAGE_SLEEP_DURATION = 0.1
|
||||
|
||||
def __init__(self, bot,
|
||||
reboot_claude_file_path: str | None = None,
|
||||
reboot_file_path: str | None = None,
|
||||
def __init__(self, bot : InferenceBot,
|
||||
chunk_message_sleep_duration: float | None = None):
|
||||
self.bot = bot
|
||||
self.telegram_bot_token = os.getenv('TELEGRAM_BOT_TOKEN')
|
||||
self.start_time = time.time()
|
||||
self.reboot_claude_file = reboot_claude_file_path or self.DEFAULT_REBOOT_CLAUDE_FILE
|
||||
self.reboot_file = reboot_file_path or self.DEFAULT_REBOOT_FILE
|
||||
self.chunk_message_sleep_duration = chunk_message_sleep_duration if chunk_message_sleep_duration is not None else self.CHUNK_MESSAGE_SLEEP_DURATION
|
||||
|
||||
async def _start_logic(self) -> str:
|
||||
@@ -146,93 +140,16 @@ class TelegramHelper:
|
||||
response_text = await self._abort_processing_logic(user_id)
|
||||
await query.edit_message_text(text=response_text)
|
||||
|
||||
# --- Reboot Command ---
|
||||
def _reboot_logic(self, user_message_parts: List[str], chat_id_to_write: str) -> None:
|
||||
"""Handles the logic for creating reboot files."""
|
||||
if len(user_message_parts) > 1 and user_message_parts[1].lower() == self.CLAUDE_REBOOT_TARGET:
|
||||
try:
|
||||
with open(self.reboot_claude_file, 'w') as f:
|
||||
f.write("") # Create/truncate the file
|
||||
logging.info(f"Created/truncated Claude reboot file: {self.reboot_claude_file}")
|
||||
except IOError as e:
|
||||
logging.error(f"Failed to create/truncate Claude reboot file {self.reboot_claude_file}: {e}")
|
||||
|
||||
# Create the main reboot file if it doesn't exist
|
||||
if not os.path.exists(self.reboot_file):
|
||||
try:
|
||||
with open(self.reboot_file, 'w') as f:
|
||||
f.write(chat_id_to_write)
|
||||
logging.info(f"Created main reboot file: {self.reboot_file} with chat_id.")
|
||||
except IOError as e:
|
||||
logging.error(f"Failed to create main reboot file {self.reboot_file}: {e}")
|
||||
else:
|
||||
logging.info(f"Main reboot file {self.reboot_file} already exists. Not overwriting chat_id.")
|
||||
|
||||
async def reboot(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handles the /reboot command, triggers file creation and exits."""
|
||||
user_message_parts = update.message.text.split()
|
||||
chat_id_str = str(update.effective_chat.id) if update and update.effective_chat else ""
|
||||
|
||||
self._reboot_logic(user_message_parts, chat_id_str)
|
||||
|
||||
if update:
|
||||
try:
|
||||
await update.message.reply_text("Rebooting the bot...")
|
||||
except Exception as e_reply:
|
||||
logging.error(f"Failed to send reboot reply: {e_reply}")
|
||||
|
||||
logging.info("Initiating shutdown for reboot...")
|
||||
sys.exit(0) # This part is not directly testable for completion in unit tests
|
||||
|
||||
# --- Check Doreboot File ---
|
||||
async def _check_doreboot_file_logic(self) -> Union[str, None]:
|
||||
"""Checks for the reboot file, reads chat_id, removes file, and returns chat_id."""
|
||||
if os.path.exists(self.reboot_file):
|
||||
chat_id = None
|
||||
try:
|
||||
with open(self.reboot_file, 'r') as f:
|
||||
chat_id = f.read().strip()
|
||||
# Attempt to remove the file after reading
|
||||
try:
|
||||
os.remove(self.reboot_file)
|
||||
logging.info(f"Successfully read and removed reboot file: {self.reboot_file}")
|
||||
except OSError as e_remove:
|
||||
logging.error(f"Failed to remove reboot file {self.reboot_file} after reading: {e_remove}")
|
||||
# Still return chat_id if read was successful, to attempt notification
|
||||
return chat_id
|
||||
except IOError as e_read:
|
||||
logging.error(f"Error reading reboot file {self.reboot_file}: {e_read}")
|
||||
# If reading failed, attempt to remove anyway if it exists, to prevent stale files
|
||||
if os.path.exists(self.reboot_file):
|
||||
try:
|
||||
os.remove(self.reboot_file)
|
||||
logging.warning(f"Removed reboot file {self.reboot_file} after a read error.")
|
||||
except OSError as e_remove_after_fail:
|
||||
logging.error(f"Failed to remove reboot file {self.reboot_file} even after a read error: {e_remove_after_fail}")
|
||||
return None # Reading failed
|
||||
return None # File does not exist
|
||||
|
||||
async def check_doreboot_file(self, application: Application) -> None:
|
||||
"""Checks for reboot file using logic method and sends notification if applicable."""
|
||||
chat_id = await self._check_doreboot_file_logic()
|
||||
if chat_id:
|
||||
try:
|
||||
await application.bot.send_message(chat_id=chat_id, text="The application has finished initializing.")
|
||||
logging.info(f"Sent reboot initialization notification to chat_id: {chat_id}")
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to send reboot initialization notification to chat_id {chat_id}: {e}")
|
||||
|
||||
async def browse(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
await browse_command(update, context, self.bot)
|
||||
|
||||
def run(self):
|
||||
application = Application.builder().token(self.telegram_bot_token).post_init(self.check_doreboot_file).build()
|
||||
application = Application.builder().token(self.telegram_bot_token).build()
|
||||
|
||||
application.add_handler(CommandHandler("start", self.start))
|
||||
application.add_handler(CommandHandler("clear", self.clear))
|
||||
application.add_handler(CommandHandler("switch", self.switch))
|
||||
application.add_handler(CommandHandler("status", self.status))
|
||||
application.add_handler(CommandHandler("reboot", self.reboot))
|
||||
application.add_handler(CommandHandler("browse", self.browse))
|
||||
application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, self.handle_message))
|
||||
application.add_handler(CallbackQueryHandler(self.abort_processing, pattern='^abort$'))
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from anthropic_telegram_inference_bot import AnthropicTelegramInferenceBot
|
||||
|
||||
class TestAnthropicTelegramInferenceBot(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.bot = AnthropicTelegramInferenceBot()
|
||||
|
||||
@patch('anthropic_telegram_inference_bot.Anthropic')
|
||||
def test_get_chat_response(self, MockAnthropic):
|
||||
mock_anthropic = MockAnthropic.return_value
|
||||
mock_anthropic.messages.create.return_value = MagicMock()
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
response = self.bot.get_chat_response(messages)
|
||||
|
||||
self.assertIsNotNone(response)
|
||||
|
||||
@patch('anthropic_telegram_inference_bot.Anthropic')
|
||||
def test_handle_message(self, MockAnthropic):
|
||||
mock_anthropic = MockAnthropic.return_value
|
||||
mock_anthropic.messages.create.return_value = MagicMock(content=[MagicMock(type="message", text="response content")])
|
||||
|
||||
user_id = "user123"
|
||||
user_message = "Hello"
|
||||
response = self.bot.handle_message(user_id, user_message)
|
||||
|
||||
self.assertIsNotNone(response)
|
||||
|
||||
# Additional testing for error cases and edge cases
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,33 +0,0 @@
|
||||
import unittest
|
||||
from base_telegram_inference_bot import BaseTelegramInferenceBot
|
||||
|
||||
class TestBaseTelegramInferenceBot(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# Initialize the bot or mock any dependencies here
|
||||
self.bot = BaseTelegramInferenceBot()
|
||||
|
||||
def test_load_system_prompt(self):
|
||||
# Example test case for load_system_prompt method
|
||||
result = self.bot.load_system_prompt()
|
||||
self.assertIsNotNone(result) # Replace with actual expected result
|
||||
|
||||
def test_load_functions(self):
|
||||
# Test the load_functions method
|
||||
functions = self.bot.load_functions()
|
||||
self.assertIsInstance(functions, list) # Replace with actual expected result
|
||||
self.assertTrue(len(functions) > 0) # Assuming it should load some functions
|
||||
|
||||
def test_clear_conversation(self):
|
||||
# Test the clear_conversation method
|
||||
self.bot.clear_conversation()
|
||||
self.assertEqual(self.bot.conversations, {}) # Assuming conversations is a dictionary
|
||||
|
||||
def test_call_tool(self):
|
||||
# Test the call_tool method
|
||||
tool_name = "some_tool"
|
||||
params = {"param1": "value1"}
|
||||
result = self.bot.call_tool(tool_name, params)
|
||||
self.assertIsNotNone(result) # Replace with actual expected result
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,38 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from chatgpt_telegram_inference_bot import ChatGPTTelegramInferenceBot
|
||||
|
||||
class TestChatGPTTelegramInferenceBot(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.bot = ChatGPTTelegramInferenceBot()
|
||||
|
||||
@patch('chatgpt_telegram_inference_bot.OpenAI')
|
||||
def test_get_chat_response(self, MockOpenAI):
|
||||
mock_ai = MockOpenAI.return_value
|
||||
mock_ai.chat.completions.create.return_value = MagicMock()
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
response = self.bot.get_chat_response(messages)
|
||||
|
||||
self.assertIsNotNone(response)
|
||||
|
||||
@patch('chatgpt_telegram_inference_bot.OpenAI')
|
||||
def test_handle_message(self, MockOpenAI):
|
||||
mock_ai = MockOpenAI.return_value
|
||||
mock_ai.chat.completions.create.return_value = MagicMock(choices=[MagicMock(message={"content": "response content"}, finish_reason='stop')])
|
||||
|
||||
user_id = "user123"
|
||||
user_message = "Hello"
|
||||
response = self.bot.handle_message(user_id, user_message)
|
||||
|
||||
self.assertIsNotNone(response)
|
||||
|
||||
def test_switch_model(self):
|
||||
initial_model = self.bot.model
|
||||
self.bot.switch_model()
|
||||
self.assertNotEqual(initial_model, self.bot.model)
|
||||
|
||||
# Additional testing for error cases and edge cases
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,280 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock, ANY
|
||||
import os
|
||||
|
||||
# Assuming anthropic_telegram_inference_bot.py is in the parent directory or PYTHONPATH is set
|
||||
from anthropic_telegram_inference_bot import AnthropicTelegramInferenceBot
|
||||
|
||||
# Mock response from Anthropic client's messages.create
|
||||
def create_mock_anthropic_response(content_text=None, stop_reason="end_turn", tool_use_parts=None):
|
||||
mock_response = MagicMock()
|
||||
mock_response.stop_reason = stop_reason
|
||||
|
||||
content_blocks = []
|
||||
if content_text:
|
||||
text_block = MagicMock()
|
||||
text_block.type = "text"
|
||||
text_block.text = content_text
|
||||
content_blocks.append(text_block)
|
||||
|
||||
if tool_use_parts:
|
||||
for tu_part in tool_use_parts: # tu_part = {"id": "toolu_123", "name": "get_weather", "input": {}}
|
||||
tool_block = MagicMock()
|
||||
tool_block.type = "tool_use"
|
||||
tool_block.id = tu_part["id"]
|
||||
tool_block.name = tu_part["name"]
|
||||
tool_block.input = tu_part["input"]
|
||||
content_blocks.append(tool_block)
|
||||
|
||||
mock_response.content = content_blocks
|
||||
return mock_response
|
||||
|
||||
class TestAnthropicTelegramInferenceBot(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.original_anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY")
|
||||
self.original_small_model = os.environ.get("ANTHROPIC_SMALL_MODEL")
|
||||
self.original_large_model = os.environ.get("ANTHROPIC_LARGE_MODEL")
|
||||
self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH")
|
||||
|
||||
for key in ["ANTHROPIC_API_KEY", "ANTHROPIC_SMALL_MODEL", "ANTHROPIC_LARGE_MODEL", "SYSTEM_PROMPT_PATH"]:
|
||||
if os.environ.get(key):
|
||||
del os.environ[key]
|
||||
|
||||
self.mock_anthropic_client_instance = MagicMock()
|
||||
self.mock_anthropic_client_instance.messages.create = MagicMock()
|
||||
|
||||
def tearDown(self):
|
||||
if self.original_anthropic_api_key: os.environ["ANTHROPIC_API_KEY"] = self.original_anthropic_api_key
|
||||
if self.original_small_model: os.environ["ANTHROPIC_SMALL_MODEL"] = self.original_small_model
|
||||
if self.original_large_model: os.environ["ANTHROPIC_LARGE_MODEL"] = self.original_large_model
|
||||
if self.original_system_prompt_path: os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path
|
||||
|
||||
@patch('anthropic.Anthropic')
|
||||
def test_init_with_anthropic_defaults_env_key(self, MockAnthropicConstructor):
|
||||
MockAnthropicConstructor.return_value = self.mock_anthropic_client_instance
|
||||
os.environ["ANTHROPIC_API_KEY"] = "test_anthropic_key"
|
||||
|
||||
bot = AnthropicTelegramInferenceBot()
|
||||
|
||||
MockAnthropicConstructor.assert_called_once_with(api_key="test_anthropic_key")
|
||||
self.assertEqual(bot.anthropic_client, self.mock_anthropic_client_instance)
|
||||
self.assertEqual(bot.model, os.environ.get("ANTHROPIC_SMALL_MODEL", "claude-3-haiku-20240307"))
|
||||
self.assertEqual(bot.max_tokens, int(os.environ.get("ANTHROPIC_SMALL_MODEL_MAX_TOKENS", 2000)))
|
||||
|
||||
@patch('anthropic.Anthropic')
|
||||
def test_init_with_provided_client_and_models(self, MockAnthropicConstructor):
|
||||
preconfigured_client = MagicMock()
|
||||
bot = AnthropicTelegramInferenceBot(
|
||||
anthropic_client=preconfigured_client,
|
||||
small_model_name="custom-small",
|
||||
small_model_max_tokens=100,
|
||||
large_model_name="custom-large",
|
||||
large_model_max_tokens=200
|
||||
)
|
||||
|
||||
MockAnthropicConstructor.assert_not_called()
|
||||
self.assertEqual(bot.anthropic_client, preconfigured_client)
|
||||
self.assertEqual(bot.model, "custom-small")
|
||||
self.assertEqual(bot.max_tokens, 100)
|
||||
self.assertEqual(bot.small_model_name, "custom-small")
|
||||
self.assertEqual(bot.large_model_name, "custom-large")
|
||||
|
||||
|
||||
def test_get_llm_description(self):
|
||||
bot = AnthropicTelegramInferenceBot(small_model_name="claude-test", small_model_max_tokens=500)
|
||||
self.assertEqual(bot.get_llm_description(), "LLM: claude-test, Max Tokens: 500")
|
||||
|
||||
async def test_switch_model(self):
|
||||
bot = AnthropicTelegramInferenceBot(
|
||||
small_model_name="claude-small", small_model_max_tokens=10,
|
||||
large_model_name="claude-large", large_model_max_tokens=20
|
||||
)
|
||||
self.assertEqual(bot.model, "claude-small")
|
||||
self.assertEqual(bot.max_tokens, 10)
|
||||
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "claude-large")
|
||||
self.assertEqual(bot.max_tokens, 20)
|
||||
self.assertEqual(status, "Switched to model: claude-large")
|
||||
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "claude-small")
|
||||
self.assertEqual(bot.max_tokens, 10)
|
||||
self.assertEqual(status, "Switched to model: claude-small")
|
||||
|
||||
def test_get_chat_response_success_text_only(self):
|
||||
bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
|
||||
bot.model = "test-claude"
|
||||
bot.max_tokens = 150
|
||||
|
||||
mock_api_response = create_mock_anthropic_response(content_text="Hello from Anthropic API")
|
||||
self.mock_anthropic_client_instance.messages.create.return_value = mock_api_response
|
||||
|
||||
messages = [{"role": "user", "content": "Hi"}] # Anthropic format
|
||||
response = bot.get_chat_response(messages, []) # tools = empty list
|
||||
|
||||
self.mock_anthropic_client_instance.messages.create.assert_called_once_with(
|
||||
model="test-claude",
|
||||
max_tokens=150,
|
||||
messages=messages,
|
||||
system=bot.system_prompt, # Ensure system prompt is passed
|
||||
tools=None, # No tools passed to API if empty list or None
|
||||
tool_choice=None
|
||||
)
|
||||
self.assertEqual(response, mock_api_response)
|
||||
|
||||
def test_get_chat_response_with_tools(self):
|
||||
bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
|
||||
bot.model = "claude-toolmaster"
|
||||
bot.max_tokens = 300
|
||||
|
||||
mock_tools_spec = [{"name": "get_weather", "description": "Gets weather", "input_schema": {"type": "object", "properties": {}}}]
|
||||
|
||||
mock_api_response = create_mock_anthropic_response(content_text="Thinking...", tool_use_parts=[
|
||||
{"id": "tool1", "name": "get_weather", "input": {"location": "here"}}
|
||||
])
|
||||
self.mock_anthropic_client_instance.messages.create.return_value = mock_api_response
|
||||
|
||||
messages = [{"role": "user", "content": "Weather?"}]
|
||||
response = bot.get_chat_response(messages, mock_tools_spec)
|
||||
|
||||
self.mock_anthropic_client_instance.messages.create.assert_called_once_with(
|
||||
model="claude-toolmaster",
|
||||
max_tokens=300,
|
||||
messages=messages,
|
||||
system=bot.system_prompt,
|
||||
tools=mock_tools_spec,
|
||||
tool_choice={"type": "auto"}
|
||||
)
|
||||
self.assertEqual(response.content[0].type, "text") # First part can be text
|
||||
self.assertEqual(response.content[1].type, "tool_use")
|
||||
|
||||
|
||||
def test_get_chat_response_api_error(self):
|
||||
bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
|
||||
self.mock_anthropic_client_instance.messages.create.side_effect = Exception("Anthropic API Down")
|
||||
|
||||
with self.assertRaisesRegex(Exception, "Anthropic API Down"):
|
||||
bot.get_chat_response([{"role": "user", "content": "trigger"}], [])
|
||||
|
||||
|
||||
async def test_handle_message_simple_response_no_tools(self):
|
||||
# This test is more involved as it touches BaseTelegramInferenceBot's handle_message structure
|
||||
# which then calls the overridden get_chat_response.
|
||||
bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
|
||||
bot.system_prompt = "System prompt for Anthropic"
|
||||
|
||||
# Mock get_chat_response directly to isolate its behavior from full handle_message logic of base
|
||||
# However, the point of this bot is its get_chat_response and subsequent processing.
|
||||
# So, let's mock the API call within get_chat_response.
|
||||
|
||||
api_response = create_mock_anthropic_response(content_text="Anthropic says hello.")
|
||||
self.mock_anthropic_client_instance.messages.create.return_value = api_response
|
||||
|
||||
# Ensure functions are empty for this test, so no tool logic is triggered
|
||||
bot.functions = []
|
||||
bot.tools = []
|
||||
|
||||
response_content = await bot.handle_message(user_id=101, user_message="Hello Anthropic")
|
||||
|
||||
self.assertEqual(response_content, "Anthropic says hello.")
|
||||
self.assertIn(101, bot.conversation_history)
|
||||
# Anthropic's handle_message structure:
|
||||
# 1. User message added to history.
|
||||
# 2. get_chat_response is called.
|
||||
# 3. Response content (text) is extracted.
|
||||
# 4. Assistant text response is added to history.
|
||||
# Expected history: [User, Assistant_Text_Response] (system prompt handled by get_chat_response)
|
||||
# The base class handle_message adds system prompt if not present.
|
||||
# Anthropic handle_message modifies history format before calling get_chat_response.
|
||||
|
||||
# Let's trace Base.handle_message -> Anthropic.handle_message -> Anthropic.get_chat_response
|
||||
# Base.handle_message:
|
||||
# - Adds system prompt to history if first turn: `self.conversation_history[user_id] = [{"role": "system", "content": self.system_prompt}]` (OpenAI style)
|
||||
# - Appends user message: `{"role": "user", "content": user_message}`
|
||||
# - Calls self.get_chat_response(messages, self.functions) -> This is Anthropic's get_chat_response
|
||||
# Anthropic.get_chat_response:
|
||||
# - Takes OpenAI style `messages` and `self.functions` (tool specs).
|
||||
# - Calls `anthropic_client.messages.create` with Anthropic style messages and system prompt.
|
||||
# Anthropic.handle_message (overridden):
|
||||
# - Prepares Anthropic-style messages from conversation_history (which is OpenAI style from Base)
|
||||
# - Calls get_chat_response with these Anthropic messages and self.functions (tool_specs)
|
||||
# - Processes response, extracts text, handles tool calls.
|
||||
# - Appends *user* message (original) and *assistant* text response to self.conversation_history (OpenAI style).
|
||||
|
||||
# For this test, we are calling AnthropicBot.handle_message directly.
|
||||
# 1. `user_id` not in `self.conversation_history`: `system_prompt` not added yet by Base logic.
|
||||
# Anthropic's `handle_message` will create `anthropic_messages` from this.
|
||||
# If `conversation_history` is empty, `anthropic_messages` = `[{"role": "user", "content": user_message}]`
|
||||
# 2. `get_chat_response` called with `anthropic_messages` and `bot.system_prompt` passed to API.
|
||||
# 3. Response "Anthropic says hello."
|
||||
# 4. Original `user_message` and "Anthropic says hello." (as assistant) added to `self.conversation_history`.
|
||||
|
||||
history = bot.conversation_history[101]
|
||||
self.assertEqual(len(history), 2) # User, Assistant
|
||||
self.assertEqual(history[0]["role"], "user")
|
||||
self.assertEqual(history[0]["content"], "Hello Anthropic")
|
||||
self.assertEqual(history[1]["role"], "assistant")
|
||||
self.assertEqual(history[1]["content"], "Anthropic says hello.")
|
||||
|
||||
# Check API call (made by the mocked get_chat_response indirectly)
|
||||
self.mock_anthropic_client_instance.messages.create.assert_called_once()
|
||||
call_args = self.mock_anthropic_client_instance.messages.create.call_args
|
||||
self.assertEqual(call_args.kwargs["system"], "System prompt for Anthropic")
|
||||
# Initial messages for API should just be the user message for first turn
|
||||
self.assertEqual(call_args.kwargs["messages"], [{"role": "user", "content": "Hello Anthropic"}])
|
||||
|
||||
|
||||
async def test_handle_message_with_tool_calls(self):
|
||||
bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
|
||||
bot.system_prompt = "You are a helpful, tool-using assistant."
|
||||
|
||||
# Define a tool for the bot (OpenAI format, will be converted by Anthropic bot for API)
|
||||
mock_tool_oai_format = {"type": "function", "function": {"name": "get_weather", "description": "Get weather", "parameters": {}}}
|
||||
bot.functions = [mock_tool_oai_format] # This is used to generate anthropic_tools for API
|
||||
|
||||
# API Response 1: Request for tool call
|
||||
tool_use_part = {"id": "toolu_xyz", "name": "get_weather", "input": {"location": "paris"}}
|
||||
api_response_1 = create_mock_anthropic_response(tool_use_parts=[tool_use_part])
|
||||
|
||||
# API Response 2: Final text response after tool execution
|
||||
api_response_2 = create_mock_anthropic_response(content_text="The weather in Paris is nice.")
|
||||
|
||||
self.mock_anthropic_client_instance.messages.create.side_effect = [api_response_1, api_response_2]
|
||||
|
||||
# Mock the bot's call_tool method (from BaseTelegramInferenceBot)
|
||||
bot.call_tool = MagicMock(return_value='''{"weather": "sunny"}''') # Tool execution result
|
||||
|
||||
user_id = 102
|
||||
user_message = "What's the weather in Paris?"
|
||||
final_text_response = await bot.handle_message(user_id, user_message)
|
||||
|
||||
self.assertEqual(final_text_response, "The weather in Paris is nice.")
|
||||
self.assertEqual(self.mock_anthropic_client_instance.messages.create.call_count, 2)
|
||||
|
||||
bot.call_tool.assert_called_once_with("get_weather", {"location": "paris"}) # Anthropic passes input as dict
|
||||
|
||||
# Check conversation history (OpenAI style)
|
||||
history = bot.conversation_history[user_id]
|
||||
self.assertEqual(history[0]["role"], "user")
|
||||
self.assertEqual(history[0]["content"], user_message)
|
||||
|
||||
# Assistant message that requested tool call (Anthropic-specific format stored by its handle_message)
|
||||
# Anthropic's handle_message appends the raw tool_use block and then the tool_result
|
||||
self.assertEqual(history[1]["role"], "assistant")
|
||||
self.assertTrue(isinstance(history[1]["content"], list)) # Anthropic content is a list
|
||||
self.assertEqual(history[1]["content"][0]["type"], "tool_use")
|
||||
self.assertEqual(history[1]["content"][0]["id"], "toolu_xyz")
|
||||
|
||||
self.assertEqual(history[2]["role"], "tool")
|
||||
self.assertEqual(history[2]["tool_call_id"], "toolu_xyz")
|
||||
self.assertEqual(history[2]["name"], "get_weather")
|
||||
self.assertEqual(history[2]["content"], '''{"weather": "sunny"}''') # call_tool result
|
||||
|
||||
self.assertEqual(history[3]["role"], "assistant") # Final text response
|
||||
self.assertTrue(isinstance(history[3]["content"], str)) # simple text
|
||||
self.assertEqual(history[3]["content"], "The weather in Paris is nice.")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,310 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import patch, mock_open, MagicMock
|
||||
import os
|
||||
import json
|
||||
|
||||
# Ensure the path includes the directory where base_telegram_inference_bot is located
|
||||
# This might require adjustment based on actual project structure if tests are run from root
|
||||
# For now, assuming it can be imported directly or via PYTHONPATH
|
||||
from base_telegram_inference_bot import BaseTelegramInferenceBot
|
||||
from tools.base_tool import BaseTool # For mocking tool structure
|
||||
|
||||
# Create a concrete subclass for testing, as BaseTelegramInferenceBot is abstract
|
||||
class ConcreteTestBot(BaseTelegramInferenceBot):
|
||||
def __init__(self, system_prompt_content=None, system_prompt_path=None, mock_tools=None, mock_functions=None):
|
||||
# Mock load_functions during super().__init__ if needed, or control tools/functions directly
|
||||
self._mock_tools = mock_tools if mock_tools is not None else []
|
||||
self._mock_functions = mock_functions if mock_functions is not None else []
|
||||
super().__init__(system_prompt_content=system_prompt_content, system_prompt_path=system_prompt_path)
|
||||
|
||||
# Override load_functions to use mocks
|
||||
def load_functions(self):
|
||||
return self._mock_tools, self._mock_functions
|
||||
|
||||
def get_chat_response(self, messages):
|
||||
pass # Abstract method, not tested here directly
|
||||
|
||||
async def handle_message(self, user_id, user_message):
|
||||
pass # Abstract method
|
||||
|
||||
def get_llm_description(self) -> str:
|
||||
return "Mock LLM Description" # Concrete implementation for testing get_bot_status
|
||||
|
||||
async def start(self):
|
||||
pass # Abstract method
|
||||
|
||||
async def abort_processing(self, user_id):
|
||||
pass # Abstract method
|
||||
|
||||
async def switch_model(self):
|
||||
pass # Abstract method
|
||||
|
||||
class TestBaseTelegramInferenceBot(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Reset relevant environment variables before each test
|
||||
self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH")
|
||||
if "SYSTEM_PROMPT_PATH" in os.environ:
|
||||
del os.environ["SYSTEM_PROMPT_PATH"]
|
||||
|
||||
def tearDown(self):
|
||||
# Restore environment variables
|
||||
if self.original_system_prompt_path:
|
||||
os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path
|
||||
elif "SYSTEM_PROMPT_PATH" in os.environ: # Ensure it's removed if test set it and it wasn't there before
|
||||
del os.environ["SYSTEM_PROMPT_PATH"]
|
||||
|
||||
def test_init_with_direct_system_prompt(self):
|
||||
bot = ConcreteTestBot(system_prompt_content="Direct prompt content")
|
||||
self.assertEqual(bot.system_prompt, "Direct prompt content")
|
||||
|
||||
@patch("os.path.isfile")
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="File prompt content")
|
||||
def test_init_with_system_prompt_path_argument(self, mock_file_open, mock_isfile):
|
||||
mock_isfile.return_value = True
|
||||
bot = ConcreteTestBot(system_prompt_path="dummy/path.txt")
|
||||
self.assertEqual(bot.system_prompt, "File prompt content")
|
||||
mock_isfile.assert_called_once_with("dummy/path.txt")
|
||||
mock_file_open.assert_called_once_with("dummy/path.txt", "r", encoding="utf-8")
|
||||
|
||||
@patch("os.path.isfile")
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="Env prompt content")
|
||||
def test_init_with_env_system_prompt_path(self, mock_file_open, mock_isfile):
|
||||
mock_isfile.return_value = True
|
||||
os.environ["SYSTEM_PROMPT_PATH"] = "env/path.txt"
|
||||
bot = ConcreteTestBot()
|
||||
self.assertEqual(bot.system_prompt, "Env prompt content")
|
||||
mock_isfile.assert_called_once_with("env/path.txt")
|
||||
mock_file_open.assert_called_once_with("env/path.txt", "r", encoding="utf-8")
|
||||
|
||||
def test_init_with_default_system_prompt(self):
|
||||
# Ensure ENV var is not set for this test
|
||||
if "SYSTEM_PROMPT_PATH" in os.environ:
|
||||
del os.environ["SYSTEM_PROMPT_PATH"]
|
||||
bot = ConcreteTestBot()
|
||||
self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.")
|
||||
|
||||
@patch("os.path.isfile", return_value=False)
|
||||
def test_init_with_invalid_system_prompt_path(self, mock_isfile):
|
||||
bot = ConcreteTestBot(system_prompt_path="invalid/path.txt")
|
||||
self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.")
|
||||
mock_isfile.assert_called_once_with("invalid/path.txt")
|
||||
|
||||
@patch("os.path.isfile")
|
||||
@patch("builtins.open", side_effect=IOError("File read error"))
|
||||
def test_init_with_system_prompt_file_read_error(self, mock_file_open, mock_isfile):
|
||||
mock_isfile.return_value = True
|
||||
bot = ConcreteTestBot(system_prompt_path="dummy/path.txt")
|
||||
self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.")
|
||||
|
||||
def test_clear_conversation_history(self):
|
||||
mock_tool_instance = MagicMock(spec=BaseTool)
|
||||
bot = ConcreteTestBot(mock_tools=[mock_tool_instance])
|
||||
bot.conversation_history[123] = [{"role": "user", "content": "Hello"}]
|
||||
|
||||
bot.clear_conversation_history(123)
|
||||
self.assertNotIn(123, bot.conversation_history)
|
||||
mock_tool_instance.clear.assert_called_once()
|
||||
|
||||
def test_clear_conversation_history_user_not_found(self):
|
||||
mock_tool_instance = MagicMock(spec=BaseTool)
|
||||
bot = ConcreteTestBot(mock_tools=[mock_tool_instance])
|
||||
bot.clear_conversation_history(404)
|
||||
self.assertNotIn(404, bot.conversation_history)
|
||||
mock_tool_instance.clear.assert_called_once()
|
||||
|
||||
def test_processing_status(self):
|
||||
bot = ConcreteTestBot()
|
||||
self.assertEqual(bot.processing_status, {})
|
||||
bot.set_processing_status(123, 789)
|
||||
self.assertEqual(bot.processing_status[123], {"processing": True, "message_id": 789})
|
||||
bot.clear_processing_status(123)
|
||||
self.assertNotIn(123, bot.processing_status)
|
||||
|
||||
def test_clear_processing_status_user_not_found(self):
|
||||
bot = ConcreteTestBot()
|
||||
bot.clear_processing_status(404)
|
||||
self.assertNotIn(404, bot.processing_status)
|
||||
|
||||
def test_call_tool_success_dict_args(self):
|
||||
mock_tool = MagicMock(spec=BaseTool)
|
||||
mock_tool.get_functions.return_value = [
|
||||
{"function": {"name": "test_tool", "description": "A test tool", "parameters": {}}}
|
||||
]
|
||||
mock_tool.execute.return_value = "Tool executed successfully"
|
||||
|
||||
bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions())
|
||||
|
||||
result = bot.call_tool("test_tool", {"arg1": "value1"})
|
||||
self.assertEqual(result, "Tool executed successfully")
|
||||
mock_tool.execute.assert_called_once_with("test_tool", arg1="value1")
|
||||
|
||||
def test_call_tool_success_json_string_args(self):
|
||||
mock_tool = MagicMock(spec=BaseTool)
|
||||
mock_tool.get_functions.return_value = [
|
||||
{"function": {"name": "test_tool_json", "parameters": {}}}
|
||||
]
|
||||
mock_tool.execute.return_value = "Tool JSON OK"
|
||||
bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions())
|
||||
|
||||
args_json_str = '''{"param": "value"}'''
|
||||
result = bot.call_tool("test_tool_json", args_json_str)
|
||||
self.assertEqual(result, "Tool JSON OK")
|
||||
mock_tool.execute.assert_called_once_with("test_tool_json", param="value")
|
||||
|
||||
def test_call_tool_malformed_json_string_args(self):
|
||||
bot = ConcreteTestBot(mock_tools=[])
|
||||
args_malformed_json_str = '''{"param": "value"'''
|
||||
result = bot.call_tool("some_tool", args_malformed_json_str)
|
||||
self.assertTrue("Error: Malformed arguments for tool call" in result)
|
||||
|
||||
def test_call_tool_unexpected_arg_type(self):
|
||||
bot = ConcreteTestBot(mock_tools=[])
|
||||
result = bot.call_tool("some_tool", 12345) # Integer instead of dict/str
|
||||
self.assertTrue("Error: Invalid argument type for tool call" in result)
|
||||
|
||||
def test_call_tool_none_args(self):
|
||||
mock_tool = MagicMock(spec=BaseTool)
|
||||
mock_tool.get_functions.return_value = [
|
||||
{"function": {"name": "test_tool_none", "parameters": {}}}
|
||||
]
|
||||
mock_tool.execute.return_value = "Tool None OK"
|
||||
bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions())
|
||||
|
||||
result = bot.call_tool("test_tool_none", None)
|
||||
self.assertEqual(result, "Tool None OK")
|
||||
mock_tool.execute.assert_called_once_with("test_tool_none") # No kwargs if None
|
||||
|
||||
def test_call_tool_not_found(self):
|
||||
bot = ConcreteTestBot(mock_tools=[])
|
||||
result = bot.call_tool("non_existent_tool", {})
|
||||
self.assertEqual(result, "Error: Tool function non_existent_tool not found.")
|
||||
|
||||
def test_call_tool_execute_exception(self):
|
||||
mock_tool = MagicMock(spec=BaseTool)
|
||||
mock_tool.get_functions.return_value = [{"function": {"name": "error_tool", "parameters": {}}}]
|
||||
mock_tool.execute.side_effect = Exception("Execution failed")
|
||||
bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions())
|
||||
|
||||
result = bot.call_tool("error_tool", {})
|
||||
self.assertEqual(result, "Error executing tool error_tool: Execution failed")
|
||||
|
||||
def test_get_system_prompt_description(self):
|
||||
if "SYSTEM_PROMPT_PATH" in os.environ: # Ensure clean state
|
||||
del os.environ["SYSTEM_PROMPT_PATH"]
|
||||
|
||||
bot_default = ConcreteTestBot()
|
||||
self.assertEqual(bot_default.get_system_prompt_description(), "System Prompt: Default")
|
||||
|
||||
bot_custom_content = ConcreteTestBot(system_prompt_content="Custom content here")
|
||||
self.assertEqual(bot_custom_content.get_system_prompt_description(), "System Prompt: Custom")
|
||||
|
||||
os.environ["SYSTEM_PROMPT_PATH"] = "some/path.txt"
|
||||
bot_env_default_prompt = ConcreteTestBot() # system_prompt itself is default
|
||||
self.assertEqual(bot_env_default_prompt.get_system_prompt_description(), "System Prompt: Custom (via ENV)")
|
||||
|
||||
with patch("os.path.isfile", return_value=True), \
|
||||
patch("builtins.open", mock_open(read_data="File prompt from ENV")):
|
||||
bot_env_file_prompt = ConcreteTestBot() # system_prompt gets loaded from ENV path
|
||||
self.assertEqual(bot_env_file_prompt.get_system_prompt_description(), "System Prompt: Custom")
|
||||
del os.environ["SYSTEM_PROMPT_PATH"]
|
||||
|
||||
with patch("os.path.isfile", return_value=True), \
|
||||
patch("builtins.open", mock_open(read_data="File prompt from arg")):
|
||||
bot_custom_file_arg = ConcreteTestBot(system_prompt_path="custom/file.txt")
|
||||
self.assertEqual(bot_custom_file_arg.get_system_prompt_description(), "System Prompt: Custom")
|
||||
|
||||
@patch.object(ConcreteTestBot, 'get_llm_description', return_value="Test LLM Description")
|
||||
@patch.object(ConcreteTestBot, 'get_system_prompt_description', return_value="Test Prompt Description")
|
||||
async def test_get_bot_status(self, mock_prompt_desc, mock_llm_desc):
|
||||
bot = ConcreteTestBot()
|
||||
status = await bot.get_bot_status()
|
||||
self.assertEqual(status, "Test Prompt Description\nTest LLM Description")
|
||||
mock_prompt_desc.assert_called_once()
|
||||
mock_llm_desc.assert_called_once()
|
||||
|
||||
@patch('os.path.dirname', return_value='/mock/path')
|
||||
@patch('os.path.join')
|
||||
@patch('os.path.exists')
|
||||
@patch('os.listdir')
|
||||
@patch('importlib.import_module')
|
||||
def test_load_functions_no_tools_dir(self, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname):
|
||||
mock_join.return_value = '/mock/path/tools'
|
||||
mock_exists.return_value = False
|
||||
|
||||
class BotForLoadTest(BaseTelegramInferenceBot):
|
||||
load_system_prompt = MagicMock(return_value="Default")
|
||||
get_chat_response = MagicMock(); handle_message = MagicMock(); get_llm_description = MagicMock(return_value="mock")
|
||||
start = MagicMock(); abort_processing = MagicMock(); switch_model = MagicMock()
|
||||
|
||||
bot = BotForLoadTest()
|
||||
self.assertEqual(bot.tools, [])
|
||||
self.assertEqual(bot.functions, [])
|
||||
mock_listdir.assert_not_called()
|
||||
|
||||
@patch('os.path.dirname', return_value='/mock/base_bot_dir')
|
||||
@patch('os.path.join', side_effect=lambda *args: os.path.normpath(os.path.join(*args)))
|
||||
@patch('os.path.exists', return_value=True)
|
||||
@patch('os.listdir', return_value=['my_tool.py', '__init__.py', 'base_tool.py'])
|
||||
@patch('importlib.import_module')
|
||||
def test_load_functions_with_one_tool(self, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname):
|
||||
|
||||
mock_tool_class = MagicMock(spec=BaseTool) # This is the class itself
|
||||
mock_tool_instance = MagicMock(spec=BaseTool) # This is the instance
|
||||
mock_tool_class.return_value = mock_tool_instance # mock_tool_class() creates mock_tool_instance
|
||||
mock_tool_instance.get_functions.return_value = [{"function": {"name": "sample_function"}}]
|
||||
|
||||
mock_my_tool_module = MagicMock()
|
||||
# Simulate inspect.getmembers behavior: returns list of (name, member) tuples
|
||||
# Only include members that are classes, derive from BaseTool, and are not BaseTool itself.
|
||||
mock_my_tool_module.ValidToolClass = mock_tool_class
|
||||
mock_my_tool_module.NotATool = object()
|
||||
mock_my_tool_module.BaseTool = BaseTool # This should be skipped by the loader
|
||||
|
||||
def import_side_effect(module_name):
|
||||
if module_name == 'tools.my_tool':
|
||||
return mock_my_tool_module
|
||||
raise ImportError(f"Unexpected import: {module_name}")
|
||||
mock_import_module.side_effect = import_side_effect
|
||||
|
||||
class BotForLoadTest(BaseTelegramInferenceBot):
|
||||
load_system_prompt = MagicMock(return_value="Default")
|
||||
get_chat_response = MagicMock(); handle_message = MagicMock(); get_llm_description = MagicMock(return_value="mock")
|
||||
start = MagicMock(); abort_processing = MagicMock(); switch_model = MagicMock()
|
||||
|
||||
bot = BotForLoadTest()
|
||||
self.assertEqual(len(bot.tools), 1)
|
||||
self.assertIs(bot.tools[0], mock_tool_instance)
|
||||
self.assertEqual(len(bot.functions), 1)
|
||||
self.assertEqual(bot.functions[0]['function']['name'], "sample_function")
|
||||
mock_import_module.assert_called_once_with('tools.my_tool')
|
||||
mock_tool_class.assert_called_once_with() # Tool class was instantiated
|
||||
mock_tool_instance.get_functions.assert_called_once_with()
|
||||
|
||||
@patch('os.path.dirname', return_value='/mock/base_bot_dir')
|
||||
@patch('os.path.join', side_effect=lambda *args: os.path.normpath(os.path.join(*args)))
|
||||
@patch('os.path.exists', return_value=True)
|
||||
@patch('os.listdir', return_value=['tool_with_init_error.py'])
|
||||
@patch('importlib.import_module')
|
||||
@patch('logging.error') # Mock logging to check for error messages
|
||||
def test_load_functions_tool_instantiation_error(self, mock_logging_error, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname):
|
||||
mock_tool_class_init_error = MagicMock(spec=BaseTool)
|
||||
mock_tool_class_init_error.side_effect = Exception("Failed to init tool") # Error on instantiation
|
||||
|
||||
mock_error_tool_module = MagicMock()
|
||||
mock_error_tool_module.ToolWithInitError = mock_tool_class_init_error
|
||||
|
||||
mock_import_module.return_value = mock_error_tool_module
|
||||
|
||||
class BotForLoadTest(BaseTelegramInferenceBot):
|
||||
load_system_prompt = MagicMock(return_value="Default")
|
||||
get_chat_response = MagicMock(); handle_message = MagicMock(); get_llm_description = MagicMock(return_value="mock")
|
||||
start = MagicMock(); abort_processing = MagicMock(); switch_model = MagicMock()
|
||||
|
||||
bot = BotForLoadTest()
|
||||
self.assertEqual(len(bot.tools), 0)
|
||||
self.assertEqual(len(bot.functions), 0)
|
||||
mock_logging_error.assert_any_call("Error instantiating tool ToolWithInitError from tool_with_init_error.py: Failed to init tool")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(闂傚лен䦗婢у〃埊鍓解劓姣)
|
||||
@@ -1,158 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch, ANY
|
||||
import os
|
||||
|
||||
# Assuming chatgpt_telegram_inference_bot.py and its parent are accessible
|
||||
from chatgpt_telegram_inference_bot import ChatGPTTelegramInferenceBot
|
||||
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot # For patching super
|
||||
|
||||
class TestChatGPTTelegramInferenceBot(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Store and clear relevant environment variables
|
||||
self.original_openai_key = os.environ.get("OPENAI_API_KEY")
|
||||
self.original_small_model = os.environ.get("OPENAI_SMALL_MODEL")
|
||||
self.original_large_model = os.environ.get("OPENAI_LARGE_MODEL")
|
||||
self.original_small_tokens = os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS")
|
||||
self.original_large_tokens = os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS")
|
||||
self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH")
|
||||
|
||||
for key in ["OPENAI_API_KEY", "OPENAI_SMALL_MODEL", "OPENAI_LARGE_MODEL",
|
||||
"OPENAI_SMALL_MODEL_MAX_TOKENS", "OPENAI_LARGE_MODEL_MAX_TOKENS", "SYSTEM_PROMPT_PATH"]:
|
||||
if os.environ.get(key):
|
||||
del os.environ[key]
|
||||
|
||||
# Mock the OpenAI client that OpenAICompatibleInferenceBot's __init__ might create
|
||||
self.mock_openai_client = MagicMock()
|
||||
|
||||
def tearDown(self):
|
||||
# Restore environment variables
|
||||
if self.original_openai_key: os.environ["OPENAI_API_KEY"] = self.original_openai_key
|
||||
if self.original_small_model: os.environ["OPENAI_SMALL_MODEL"] = self.original_small_model
|
||||
if self.original_large_model: os.environ["OPENAI_LARGE_MODEL"] = self.original_large_model
|
||||
if self.original_small_tokens: os.environ["OPENAI_SMALL_MODEL_MAX_TOKENS"] = self.original_small_tokens
|
||||
if self.original_large_tokens: os.environ["OPENAI_LARGE_MODEL_MAX_TOKENS"] = self.original_large_tokens
|
||||
if self.original_system_prompt_path: os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path
|
||||
|
||||
|
||||
@patch.object(OpenAICompatibleInferenceBot, '__init__') # Mock the superclass's __init__
|
||||
def test_init_defaults_and_super_call(self, mock_super_init):
|
||||
os.environ["OPENAI_API_KEY"] = "test_key_chatgpt"
|
||||
os.environ["OPENAI_SMALL_MODEL"] = "gpt-3.5-turbo-env"
|
||||
os.environ["OPENAI_SMALL_MODEL_MAX_TOKENS"] = "350"
|
||||
|
||||
bot = ChatGPTTelegramInferenceBot()
|
||||
|
||||
mock_super_init.assert_called_once_with(
|
||||
client=None, # ChatGPT bot will let superclass create it
|
||||
api_key="test_key_chatgpt", # Passed to super
|
||||
base_url=None,
|
||||
api_version=None,
|
||||
azure_deployment=None,
|
||||
model_name="gpt-3.5-turbo-env", # Default small model from env
|
||||
max_tokens_str="350", # Default small model tokens from env
|
||||
small_model_name="gpt-3.5-turbo-env",
|
||||
small_model_max_tokens_str="350",
|
||||
large_model_name=os.environ.get("OPENAI_LARGE_MODEL", "gpt-4-turbo-preview"), # Default large
|
||||
large_model_max_tokens_str=os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS"),
|
||||
system_prompt_content=None,
|
||||
system_prompt_path=None,
|
||||
is_gemini=False,
|
||||
max_history_length=20 # Default from OpenAICompatibleInferenceBot
|
||||
)
|
||||
|
||||
@patch.object(OpenAICompatibleInferenceBot, '__init__')
|
||||
def test_init_with_arguments(self, mock_super_init):
|
||||
mock_client_arg = MagicMock()
|
||||
bot = ChatGPTTelegramInferenceBot(
|
||||
openai_client=mock_client_arg,
|
||||
api_key="arg_key",
|
||||
small_model_name="arg_small_model",
|
||||
small_model_max_tokens="123",
|
||||
large_model_name="arg_large_model",
|
||||
large_model_max_tokens="456",
|
||||
system_prompt_content="Arg prompt"
|
||||
)
|
||||
mock_super_init.assert_called_once_with(
|
||||
client=mock_client_arg,
|
||||
api_key="arg_key",
|
||||
base_url=None,
|
||||
api_version=None,
|
||||
azure_deployment=None,
|
||||
model_name="arg_small_model", # Initially configured with small model
|
||||
max_tokens_str="123",
|
||||
small_model_name="arg_small_model",
|
||||
small_model_max_tokens_str="123",
|
||||
large_model_name="arg_large_model",
|
||||
large_model_max_tokens_str="456",
|
||||
system_prompt_content="Arg prompt",
|
||||
system_prompt_path=None,
|
||||
is_gemini=False,
|
||||
max_history_length=20
|
||||
)
|
||||
|
||||
# Test switch_model - this method is part of ChatGPTTelegramInferenceBot
|
||||
# It calls _configure_model_and_tokens which is in the superclass.
|
||||
# We need a bot instance where _configure_model_and_tokens can be called.
|
||||
@patch('openai.OpenAI') # To allow instantiation of the bot by mocking client creation
|
||||
async def test_switch_model_logic(self, mock_openai_constructor):
|
||||
mock_openai_constructor.return_value = self.mock_openai_client # Mock client creation in super
|
||||
|
||||
# Set env vars for model names that switch_model will use as fallback
|
||||
os.environ["OPENAI_SMALL_MODEL"] = "env-small-gpt"
|
||||
os.environ["OPENAI_SMALL_MODEL_MAX_TOKENS"] = "100"
|
||||
os.environ["OPENAI_LARGE_MODEL"] = "env-large-gpt"
|
||||
os.environ["OPENAI_LARGE_MODEL_MAX_TOKENS"] = "200"
|
||||
|
||||
# Instantiate with initial model (small)
|
||||
bot = ChatGPTTelegramInferenceBot()
|
||||
self.assertEqual(bot.model, "env-small-gpt")
|
||||
self.assertEqual(bot.max_tokens, 100)
|
||||
|
||||
# Switch to large
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "env-large-gpt")
|
||||
self.assertEqual(bot.max_tokens, 200)
|
||||
self.assertEqual(status, "Switched to model: env-large-gpt")
|
||||
|
||||
# Switch back to small
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "env-small-gpt")
|
||||
self.assertEqual(bot.max_tokens, 100)
|
||||
self.assertEqual(status, "Switched to model: env-small-gpt")
|
||||
|
||||
@patch('openai.OpenAI')
|
||||
async def test_switch_model_uses_instance_configs_if_provided(self, mock_openai_constructor):
|
||||
mock_openai_constructor.return_value = self.mock_openai_client
|
||||
|
||||
# Instantiate with specific model names, overriding potential env vars
|
||||
bot = ChatGPTTelegramInferenceBot(
|
||||
small_model_name="init-small", small_model_max_tokens="50",
|
||||
large_model_name="init-large", large_model_max_tokens="150"
|
||||
)
|
||||
self.assertEqual(bot.model, "init-small") # Starts with small
|
||||
self.assertEqual(bot.max_tokens, 50)
|
||||
|
||||
# Switch to large
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "init-large")
|
||||
self.assertEqual(bot.max_tokens, 150)
|
||||
self.assertEqual(status, "Switched to model: init-large")
|
||||
|
||||
# Switch back to small
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "init-small")
|
||||
self.assertEqual(bot.max_tokens, 50)
|
||||
self.assertEqual(status, "Switched to model: init-small")
|
||||
|
||||
# get_llm_description is inherited from OpenAICompatibleInferenceBot.
|
||||
# Test just to ensure it works in the context of a ChatGPTBot instance
|
||||
@patch('openai.OpenAI')
|
||||
def test_get_llm_description_for_chatgpt_bot(self, mock_openai_constructor):
|
||||
mock_openai_constructor.return_value = self.mock_openai_client
|
||||
bot = ChatGPTTelegramInferenceBot(small_model_name="gpt-3.5-desc", small_model_max_tokens="777")
|
||||
# Initially configured with small model
|
||||
self.assertEqual(bot.get_llm_description(), "LLM: gpt-3.5-desc, Max Tokens: 777, Azure: False")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,154 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch, ANY
|
||||
import os
|
||||
|
||||
# Assuming gemini_telegram_inference_bot.py and its parent are accessible
|
||||
from gemini_telegram_inference_bot import GeminiTelegramInferenceBot
|
||||
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot # For patching super
|
||||
|
||||
class TestGeminiTelegramInferenceBot(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Store and clear relevant environment variables
|
||||
self.original_gemini_key = os.environ.get("GEMINI_API_KEY")
|
||||
self.original_gemini_base_url = os.environ.get("GEMINI_API_BASE_URL")
|
||||
self.original_small_model = os.environ.get("GEMINI_SMALL_MODEL")
|
||||
self.original_large_model = os.environ.get("GEMINI_LARGE_MODEL")
|
||||
self.original_small_tokens = os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS")
|
||||
self.original_large_tokens = os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS")
|
||||
self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH")
|
||||
|
||||
for key in ["GEMINI_API_KEY", "GEMINI_API_BASE_URL", "GEMINI_SMALL_MODEL", "GEMINI_LARGE_MODEL",
|
||||
"GEMINI_SMALL_MODEL_MAX_TOKENS", "GEMINI_LARGE_MODEL_MAX_TOKENS", "SYSTEM_PROMPT_PATH"]:
|
||||
if os.environ.get(key):
|
||||
del os.environ[key]
|
||||
|
||||
self.mock_openai_client = MagicMock() # Used if superclass creates an OpenAI client
|
||||
|
||||
def tearDown(self):
|
||||
# Restore environment variables
|
||||
if self.original_gemini_key: os.environ["GEMINI_API_KEY"] = self.original_gemini_key
|
||||
if self.original_gemini_base_url: os.environ["GEMINI_API_BASE_URL"] = self.original_gemini_base_url
|
||||
if self.original_small_model: os.environ["GEMINI_SMALL_MODEL"] = self.original_small_model
|
||||
if self.original_large_model: os.environ["GEMINI_LARGE_MODEL"] = self.original_large_model
|
||||
if self.original_small_tokens: os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = self.original_small_tokens
|
||||
if self.original_large_tokens: os.environ["GEMINI_LARGE_MODEL_MAX_TOKENS"] = self.original_large_tokens
|
||||
if self.original_system_prompt_path: os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path
|
||||
|
||||
@patch.object(OpenAICompatibleInferenceBot, '__init__') # Mock the superclass's __init__
|
||||
def test_init_defaults_and_super_call(self, mock_super_init):
|
||||
os.environ["GEMINI_API_KEY"] = "test_key_gemini"
|
||||
os.environ["GEMINI_API_BASE_URL"] = "https://gemini.env.com"
|
||||
os.environ["GEMINI_SMALL_MODEL"] = "gemini-pro-env"
|
||||
os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = "360"
|
||||
|
||||
bot = GeminiTelegramInferenceBot()
|
||||
|
||||
mock_super_init.assert_called_once_with(
|
||||
client=None,
|
||||
api_key="test_key_gemini",
|
||||
base_url="https://gemini.env.com", # Passed to super
|
||||
api_version=None,
|
||||
azure_deployment=None,
|
||||
model_name="gemini-pro-env",
|
||||
max_tokens_str="360",
|
||||
small_model_name="gemini-pro-env",
|
||||
small_model_max_tokens_str="360",
|
||||
large_model_name=os.environ.get("GEMINI_LARGE_MODEL", "gemini-1.5-pro-latest"), # Default large
|
||||
large_model_max_tokens_str=os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS"),
|
||||
system_prompt_content=None,
|
||||
system_prompt_path=None,
|
||||
is_gemini=True, # Important for Gemini bot
|
||||
max_history_length=20
|
||||
)
|
||||
|
||||
@patch.object(OpenAICompatibleInferenceBot, '__init__')
|
||||
def test_init_with_arguments(self, mock_super_init):
|
||||
mock_client_arg = MagicMock()
|
||||
bot = GeminiTelegramInferenceBot(
|
||||
openai_client=mock_client_arg, # Name in Gemini bot is openai_client for consistency
|
||||
api_key="arg_gem_key",
|
||||
base_url="https://arg.gemini.com",
|
||||
small_model_name="arg_gem_small",
|
||||
small_model_max_tokens="124",
|
||||
large_model_name="arg_gem_large",
|
||||
large_model_max_tokens="457",
|
||||
system_prompt_content="Gemini prompt"
|
||||
)
|
||||
mock_super_init.assert_called_once_with(
|
||||
client=mock_client_arg,
|
||||
api_key="arg_gem_key",
|
||||
base_url="https://arg.gemini.com",
|
||||
api_version=None,
|
||||
azure_deployment=None,
|
||||
model_name="arg_gem_small",
|
||||
max_tokens_str="124",
|
||||
small_model_name="arg_gem_small",
|
||||
small_model_max_tokens_str="124",
|
||||
large_model_name="arg_gem_large",
|
||||
large_model_max_tokens_str="457",
|
||||
system_prompt_content="Gemini prompt",
|
||||
system_prompt_path=None,
|
||||
is_gemini=True,
|
||||
max_history_length=20
|
||||
)
|
||||
|
||||
@patch('openai.OpenAI') # Gemini bot uses OpenAI client configured for Gemini endpoint
|
||||
async def test_switch_model_logic(self, mock_openai_constructor):
|
||||
mock_openai_constructor.return_value = self.mock_openai_client
|
||||
|
||||
os.environ["GEMINI_SMALL_MODEL"] = "env-gemini-small"
|
||||
os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = "110"
|
||||
os.environ["GEMINI_LARGE_MODEL"] = "env-gemini-large"
|
||||
os.environ["GEMINI_LARGE_MODEL_MAX_TOKENS"] = "220"
|
||||
|
||||
bot = GeminiTelegramInferenceBot() # Uses env vars by default
|
||||
self.assertEqual(bot.model, "env-gemini-small")
|
||||
self.assertEqual(bot.max_tokens, 110)
|
||||
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "env-gemini-large")
|
||||
self.assertEqual(bot.max_tokens, 220)
|
||||
self.assertEqual(status, "Switched to model: env-gemini-large")
|
||||
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "env-gemini-small")
|
||||
self.assertEqual(bot.max_tokens, 110)
|
||||
self.assertEqual(status, "Switched to model: env-gemini-small")
|
||||
|
||||
@patch('openai.OpenAI')
|
||||
async def test_switch_model_uses_instance_configs_if_provided(self, mock_openai_constructor):
|
||||
mock_openai_constructor.return_value = self.mock_openai_client
|
||||
|
||||
bot = GeminiTelegramInferenceBot(
|
||||
small_model_name="init-gem-small", small_model_max_tokens="55",
|
||||
large_model_name="init-gem-large", large_model_max_tokens="155"
|
||||
)
|
||||
self.assertEqual(bot.model, "init-gem-small")
|
||||
self.assertEqual(bot.max_tokens, 55)
|
||||
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "init-gem-large")
|
||||
self.assertEqual(bot.max_tokens, 155)
|
||||
self.assertEqual(status, "Switched to model: init-gem-large")
|
||||
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "init-gem-small")
|
||||
self.assertEqual(bot.max_tokens, 55)
|
||||
self.assertEqual(status, "Switched to model: init-gem-small")
|
||||
|
||||
@patch('openai.OpenAI')
|
||||
def test_get_llm_description_for_gemini_bot(self, mock_openai_constructor):
|
||||
mock_openai_constructor.return_value = self.mock_openai_client
|
||||
bot = GeminiTelegramInferenceBot(
|
||||
small_model_name="gemini-pro-desc",
|
||||
small_model_max_tokens="888",
|
||||
# is_gemini is True by default in constructor call to super
|
||||
)
|
||||
# LLM description should indicate not Azure, even though it uses OpenAICompatible... base
|
||||
# The is_gemini flag primarily affects client instantiation logic in the superclass.
|
||||
# The azure_openai flag in superclass is based on azure_endpoint presence.
|
||||
self.assertEqual(bot.get_llm_description(), "LLM: gemini-pro-desc, Max Tokens: 888, Azure: False")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,81 +0,0 @@
|
||||
# tests/test_github_tool.py
|
||||
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from tools.github_tool import GitHubTool
|
||||
|
||||
class TestGitHubTool(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.github_tool = GitHubTool()
|
||||
|
||||
def test_get_functions(self):
|
||||
functions = self.github_tool.get_functions()
|
||||
self.assertEqual(len(functions), 4)
|
||||
function_names = [f["name"] for f in functions]
|
||||
expected_names = ["read_file", "create_branch", "commit_file", "create_pull_request"]
|
||||
self.assertListEqual(function_names, expected_names)
|
||||
|
||||
@patch('tools.github_tool.requests.get')
|
||||
def test_read_file(self, mock_get):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"content": "file content"}
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = self.github_tool.execute("read_file", path="test.txt")
|
||||
self.assertEqual(result, "file content")
|
||||
|
||||
mock_get.assert_called_once()
|
||||
|
||||
@patch('tools.github_tool.requests.get')
|
||||
@patch('tools.github_tool.requests.post')
|
||||
def test_create_branch(self, mock_post, mock_get):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 200
|
||||
mock_get_response.json.return_value = {"object": {"sha": "test_sha"}}
|
||||
mock_get.return_value = mock_get_response
|
||||
|
||||
mock_post_response = MagicMock()
|
||||
mock_post_response.status_code = 201
|
||||
mock_post.return_value = mock_post_response
|
||||
|
||||
result = self.github_tool.execute("create_branch", branch_name="test-branch")
|
||||
self.assertEqual(result, "Branch 'test-branch' created successfully")
|
||||
|
||||
mock_get.assert_called_once()
|
||||
mock_post.assert_called_once()
|
||||
|
||||
@patch('tools.github_tool.requests.put')
|
||||
def test_commit_file(self, mock_put):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_put.return_value = mock_response
|
||||
|
||||
result = self.github_tool.execute("commit_file", branch_name="test-branch", file_path="test.txt", content="test content", commit_message="Test commit")
|
||||
self.assertEqual(result, "File committed successfully to branch 'test-branch'")
|
||||
|
||||
mock_put.assert_called_once()
|
||||
|
||||
def test_commit_file_to_main(self):
|
||||
result = self.github_tool.execute("commit_file", branch_name="main", file_path="test.txt", content="test content", commit_message="Test commit")
|
||||
self.assertEqual(result, "Cannot commit directly to main branch")
|
||||
|
||||
@patch('tools.github_tool.requests.post')
|
||||
def test_create_pull_request(self, mock_post):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 201
|
||||
mock_response.json.return_value = {"html_url": "https://github.com/test/test/pull/1"}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = self.github_tool.execute("create_pull_request", title="Test PR", body="Test body", head="test-branch")
|
||||
self.assertEqual(result, "Pull request created successfully: https://github.com/test/test/pull/1")
|
||||
|
||||
mock_post.assert_called_once()
|
||||
|
||||
def test_unknown_function(self):
|
||||
result = self.github_tool.execute("unknown_function")
|
||||
self.assertEqual(result, "Unknown function: unknown_function")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,332 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock, ANY
|
||||
import os
|
||||
import json
|
||||
|
||||
# Assuming openai_compatible_inference_bot.py is in the parent directory or PYTHONPATH is set
|
||||
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot
|
||||
|
||||
# Mock response from OpenAI client's chat.completions.create
|
||||
def create_mock_openai_response(content=None, tool_calls=None):
|
||||
mock_message = MagicMock()
|
||||
mock_message.role = "assistant"
|
||||
mock_message.content = content
|
||||
if tool_calls:
|
||||
# tool_calls should be a list of objects with id and function (name, arguments)
|
||||
mock_tool_calls = []
|
||||
for tc in tool_calls:
|
||||
mock_tc = MagicMock()
|
||||
mock_tc.id = tc["id"]
|
||||
mock_tc.function.name = tc["function"]["name"]
|
||||
mock_tc.function.arguments = tc["function"]["arguments"]
|
||||
mock_tool_calls.append(mock_tc)
|
||||
mock_message.tool_calls = mock_tool_calls
|
||||
else:
|
||||
mock_message.tool_calls = None
|
||||
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message = mock_message
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [mock_choice]
|
||||
return mock_response
|
||||
|
||||
# Concrete class for testing
|
||||
class ConcreteOpenAICompatibleBot(OpenAICompatibleInferenceBot):
|
||||
# Implement abstract methods for instantiation
|
||||
async def switch_model(self):
|
||||
# Simple switch for testing if needed, or just pass
|
||||
if self.model == self.small_model_name:
|
||||
self._configure_model_and_tokens(self.large_model_name, self.large_model_max_tokens_str)
|
||||
else:
|
||||
self._configure_model_and_tokens(self.small_model_name, self.small_model_max_tokens_str)
|
||||
return f"Switched to {self.model}"
|
||||
|
||||
# Override load_functions if it's called by parent and needs mocking for these tests
|
||||
# (OpenAICompatibleInferenceBot's __init__ calls BaseTelegramInferenceBot's __init__, which calls load_functions)
|
||||
def load_functions(self):
|
||||
# For these tests, assume no tools unless specifically added
|
||||
self.tools = []
|
||||
self.functions = []
|
||||
return self.tools, self.functions
|
||||
|
||||
|
||||
class TestOpenAICompatibleInferenceBot(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.original_openai_api_key = os.environ.get("OPENAI_API_KEY")
|
||||
self.original_azure_openai_key = os.environ.get("AZURE_OPENAI_KEY")
|
||||
self.original_azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
|
||||
self.original_api_version = os.environ.get("AZURE_OPENAI_API_VERSION")
|
||||
self.original_azure_deployment = os.environ.get("AZURE_DEPLOYMENT_NAME")
|
||||
|
||||
# Clear relevant env vars before each test
|
||||
for key in ["OPENAI_API_KEY", "AZURE_OPENAI_KEY", "AZURE_OPENAI_ENDPOINT",
|
||||
"AZURE_OPENAI_API_VERSION", "AZURE_DEPLOYMENT_NAME", "SYSTEM_PROMPT_PATH"]:
|
||||
if os.environ.get(key):
|
||||
del os.environ[key]
|
||||
|
||||
self.mock_openai_client_instance = MagicMock()
|
||||
self.mock_openai_client_instance.chat.completions.create = MagicMock()
|
||||
|
||||
def tearDown(self):
|
||||
# Restore environment variables
|
||||
if self.original_openai_api_key: os.environ["OPENAI_API_KEY"] = self.original_openai_api_key
|
||||
if self.original_azure_openai_key: os.environ["AZURE_OPENAI_KEY"] = self.original_azure_openai_key
|
||||
if self.original_azure_endpoint: os.environ["AZURE_OPENAI_ENDPOINT"] = self.original_azure_endpoint
|
||||
if self.original_api_version: os.environ["AZURE_OPENAI_API_VERSION"] = self.original_api_version
|
||||
if self.original_azure_deployment: os.environ["AZURE_DEPLOYMENT_NAME"] = self.original_azure_deployment
|
||||
|
||||
|
||||
@patch('openai.OpenAI')
|
||||
def test_init_with_openai_defaults(self, MockOpenAIConstructor):
|
||||
MockOpenAIConstructor.return_value = self.mock_openai_client_instance
|
||||
os.environ["OPENAI_API_KEY"] = "test_openai_key"
|
||||
|
||||
bot = ConcreteOpenAICompatibleBot(model_name="gpt-4")
|
||||
|
||||
MockOpenAIConstructor.assert_called_once_with(api_key="test_openai_key", base_url=None)
|
||||
self.assertEqual(bot.client, self.mock_openai_client_instance)
|
||||
self.assertEqual(bot.model, "gpt-4")
|
||||
self.assertEqual(bot.max_tokens, 1000) # Default from _configure_model_and_tokens
|
||||
self.assertEqual(bot.azure_openai, False)
|
||||
|
||||
@patch('openai.OpenAI')
|
||||
def test_init_with_provided_client(self, MockOpenAIConstructor):
|
||||
preconfigured_client = MagicMock()
|
||||
bot = ConcreteOpenAICompatibleBot(client=preconfigured_client, model_name="gpt-3.5")
|
||||
|
||||
MockOpenAIConstructor.assert_not_called()
|
||||
self.assertEqual(bot.client, preconfigured_client)
|
||||
self.assertEqual(bot.model, "gpt-3.5")
|
||||
|
||||
@patch('openai.AzureOpenAI')
|
||||
def test_init_with_azure_config_args(self, MockAzureOpenAIConstructor):
|
||||
MockAzureOpenAIConstructor.return_value = self.mock_openai_client_instance
|
||||
|
||||
bot = ConcreteOpenAICompatibleBot(
|
||||
api_key="azure_key",
|
||||
azure_endpoint="https://myenv.openai.azure.com",
|
||||
api_version="2023-05-15",
|
||||
azure_deployment="my-gpt-4", # This should be used as model_name for API call
|
||||
model_name="should_be_overridden_by_azure_deployment_for_api"
|
||||
# model_name is passed to _configure_model_and_tokens, which sets self.model for display/logging
|
||||
# but for Azure, the client needs the deployment name.
|
||||
)
|
||||
|
||||
MockAzureOpenAIConstructor.assert_called_once_with(
|
||||
api_key="azure_key",
|
||||
azure_endpoint="https://myenv.openai.azure.com",
|
||||
api_version="2023-05-15"
|
||||
)
|
||||
self.assertEqual(bot.client, self.mock_openai_client_instance)
|
||||
self.assertEqual(bot.model, "my-gpt-4") # Azure deployment name becomes the model for API calls
|
||||
self.assertEqual(bot.azure_openai, True)
|
||||
|
||||
|
||||
@patch('openai.AzureOpenAI')
|
||||
def test_init_with_azure_env_vars(self, MockAzureOpenAIConstructor):
|
||||
MockAzureOpenAIConstructor.return_value = self.mock_openai_client_instance
|
||||
os.environ["AZURE_OPENAI_KEY"] = "env_azure_key"
|
||||
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://env.openai.azure.com"
|
||||
os.environ["AZURE_OPENAI_API_VERSION"] = "2023-06-01"
|
||||
os.environ["AZURE_DEPLOYMENT_NAME"] = "env-gpt-35" # Used as model_name
|
||||
|
||||
bot = ConcreteOpenAICompatibleBot(model_name="ignored_if_azure_deployment_env_is_set")
|
||||
|
||||
MockAzureOpenAIConstructor.assert_called_once_with(
|
||||
api_key="env_azure_key",
|
||||
azure_endpoint="https://env.openai.azure.com",
|
||||
api_version="2023-06-01"
|
||||
)
|
||||
self.assertEqual(bot.model, "env-gpt-35")
|
||||
self.assertTrue(bot.azure_openai)
|
||||
|
||||
@patch('openai.OpenAI')
|
||||
def test_init_with_gemini_config_args(self, MockOpenAIConstructor):
|
||||
MockOpenAIConstructor.return_value = self.mock_openai_client_instance
|
||||
|
||||
bot = ConcreteOpenAICompatibleBot(
|
||||
api_key="gemini_key",
|
||||
base_url="https://gemini.example.com",
|
||||
model_name="gemini-pro",
|
||||
is_gemini=True
|
||||
)
|
||||
MockOpenAIConstructor.assert_called_once_with(api_key="gemini_key", base_url="https://gemini.example.com")
|
||||
self.assertEqual(bot.model, "gemini-pro")
|
||||
self.assertFalse(bot.azure_openai) # is_gemini doesn't mean azure_openai
|
||||
|
||||
def test_configure_model_and_tokens(self):
|
||||
bot = ConcreteOpenAICompatibleBot(model_name="initial_model") # init calls _configure
|
||||
bot._configure_model_and_tokens("test-model", "500")
|
||||
self.assertEqual(bot.model, "test-model")
|
||||
self.assertEqual(bot.max_tokens, 500)
|
||||
|
||||
bot._configure_model_and_tokens("test-model-2", None, default_max_tokens=150)
|
||||
self.assertEqual(bot.max_tokens, 150)
|
||||
|
||||
bot._configure_model_and_tokens("test-model-3", "invalid_token_val")
|
||||
self.assertEqual(bot.max_tokens, 1000) # Default fallback
|
||||
|
||||
def test_get_llm_description(self):
|
||||
bot = ConcreteOpenAICompatibleBot(model_name="desc-model", max_tokens_str="256")
|
||||
self.assertEqual(bot.get_llm_description(), "LLM: desc-model, Max Tokens: 256, Azure: False")
|
||||
|
||||
bot_azure = ConcreteOpenAICompatibleBot(azure_deployment="azure-model", azure_endpoint="x", api_key="y", api_version="z")
|
||||
self.assertEqual(bot_azure.get_llm_description(), "LLM: azure-model, Max Tokens: 1000, Azure: True")
|
||||
|
||||
|
||||
def test_get_chat_response_success(self):
|
||||
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="test-gpt")
|
||||
bot.max_tokens = 50 # Ensure this is set
|
||||
mock_api_response = create_mock_openai_response(content="Hello from API")
|
||||
self.mock_openai_client_instance.chat.completions.create.return_value = mock_api_response
|
||||
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
response = bot.get_chat_response(messages)
|
||||
|
||||
self.mock_openai_client_instance.chat.completions.create.assert_called_once_with(
|
||||
model="test-gpt",
|
||||
messages=messages,
|
||||
tools=ANY, # Assuming functions can be None or empty list
|
||||
tool_choice=ANY,
|
||||
max_tokens=50
|
||||
)
|
||||
self.assertEqual(response, mock_api_response)
|
||||
|
||||
def test_get_chat_response_api_error(self):
|
||||
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="error-gpt")
|
||||
self.mock_openai_client_instance.chat.completions.create.side_effect = Exception("API Down")
|
||||
|
||||
with self.assertRaisesRegex(Exception, "API Down"):
|
||||
bot.get_chat_response([{"role": "user", "content": "trigger"}])
|
||||
|
||||
async def test_handle_message_simple_response(self):
|
||||
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="chatty")
|
||||
bot.system_prompt = "You are a test bot." # Set directly for simplicity
|
||||
mock_api_response = create_mock_openai_response(content="Test reply")
|
||||
self.mock_openai_client_instance.chat.completions.create.return_value = mock_api_response
|
||||
|
||||
response_content = await bot.handle_message(user_id=1, user_message="Hello")
|
||||
|
||||
self.assertEqual(response_content, "Test reply")
|
||||
self.assertIn(1, bot.conversation_history)
|
||||
self.assertEqual(len(bot.conversation_history[1]), 3) # System, User, Assistant
|
||||
self.assertEqual(bot.conversation_history[1][0]["content"], "You are a test bot.")
|
||||
self.assertEqual(bot.conversation_history[1][2]["content"], "Test reply")
|
||||
|
||||
async def test_handle_message_with_tool_call_and_response(self):
|
||||
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="tool-user")
|
||||
|
||||
# Mock functions/tools setup on the bot
|
||||
mock_tool_def = {"function": {"name": "get_weather", "description": "Gets weather", "parameters": {}}}
|
||||
bot.functions = [mock_tool_def] # Simulate tools are loaded
|
||||
|
||||
# API response 1: Request to call tool
|
||||
tool_call_request = [{"id": "call123", "function": {"name": "get_weather", "arguments": '''{"location": "moon"}'''}}]
|
||||
api_response_1 = create_mock_openai_response(tool_calls=tool_call_request)
|
||||
|
||||
# API response 2: Final answer after tool execution
|
||||
api_response_2 = create_mock_openai_response(content="The weather on the moon is chilly.")
|
||||
|
||||
self.mock_openai_client_instance.chat.completions.create.side_effect = [api_response_1, api_response_2]
|
||||
|
||||
# Mock self.call_tool
|
||||
bot.call_tool = MagicMock(return_value='''{"temperature": "-100 C"}''')
|
||||
|
||||
final_response = await bot.handle_message(user_id=2, user_message="Weather on moon?")
|
||||
|
||||
self.assertEqual(final_response, "The weather on the moon is chilly.")
|
||||
bot.call_tool.assert_called_once_with("get_weather", '''{"location": "moon"}''')
|
||||
|
||||
# Check conversation history includes tool messages
|
||||
history = bot.conversation_history[2]
|
||||
self.assertTrue(any(msg["role"] == "assistant" and msg.tool_calls is not None for msg in history))
|
||||
self.assertTrue(any(msg["role"] == "tool" and msg["name"] == "get_weather" for msg in history))
|
||||
self.assertEqual(self.mock_openai_client_instance.chat.completions.create.call_count, 2)
|
||||
|
||||
async def test_handle_message_max_history_length(self):
|
||||
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="hist-test", max_history_length=3)
|
||||
self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Ok")
|
||||
|
||||
await bot.handle_message(1, "Msg1") # Sys, User, Assist (3)
|
||||
self.assertEqual(len(bot.conversation_history[1]), 3)
|
||||
|
||||
await bot.handle_message(1, "Msg2") # User, Assist. Should be 3 (prev User, prev Assist, new User) -> then adds new Assist.
|
||||
# Before new call: [Sys, U1, A1]. New U2. Call with [Sys,U1,A1,U2]. Resp A2.
|
||||
# History: [Sys,U1,A1,U2,A2]. Limit 3. -> [A1,U2,A2] (if system is not preserved specially)
|
||||
# The current code appends to history then truncates if over limit.
|
||||
# So after Msg1: [S, U1, A1]. len=3.
|
||||
# For Msg2: History is [S, U1, A1]. Append U2. Call with [S,U1,A1,U2]. Append A2.
|
||||
# History now [S,U1,A1,U2,A2]. len=5. Truncate to 3.
|
||||
# Expected: [A1, U2, A2] or [U1,A1,U2] or [U2,A2,S] depending on how system prompt is handled in truncation.
|
||||
# The code is: self.conversation_history[user_id][-self.max_history_length:]
|
||||
# And system prompt is only added IF user_id not in self.conversation_history.
|
||||
# So, for Msg2, system prompt is not re-added.
|
||||
# History before Msg2 call: [S, U1, A1]
|
||||
# Messages for Msg2 call: [S, U1, A1, U2]
|
||||
# History after Msg2 response A2: [S, U1, A1, U2, A2]. Len 5.
|
||||
# Truncated to self.max_history_length=3: [A1, U2, A2]
|
||||
|
||||
# Call 1
|
||||
self.mock_openai_client_instance.chat.completions.create.reset_mock()
|
||||
self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply1")
|
||||
await bot.handle_message(user_id=7, user_message="First message")
|
||||
self.assertEqual(len(bot.conversation_history[7]), 3) # System, User1, Assistant1
|
||||
|
||||
# Call 2
|
||||
self.mock_openai_client_instance.chat.completions.create.reset_mock()
|
||||
self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply2")
|
||||
await bot.handle_message(user_id=7, user_message="Second message")
|
||||
# History before call: [S, U1, A1]. Messages for call: [S, U1, A1, U2]. History after: [S, U1, A1, U2, A2].
|
||||
# Truncated to 3: [A1, U2, A2]
|
||||
self.assertEqual(len(bot.conversation_history[7]), 3)
|
||||
self.assertEqual(bot.conversation_history[7][0]["content"], "Reply1") # A1
|
||||
self.assertEqual(bot.conversation_history[7][1]["content"], "Second message") # U2
|
||||
self.assertEqual(bot.conversation_history[7][2]["content"], "Reply2") # A2
|
||||
|
||||
# Call 3
|
||||
self.mock_openai_client_instance.chat.completions.create.reset_mock()
|
||||
self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply3")
|
||||
await bot.handle_message(user_id=7, user_message="Third message")
|
||||
# History before call: [A1, U2, A2]. Messages for call: [A1, U2, A2, U3]. History after: [A1, U2, A2, U3, A3].
|
||||
# Truncated to 3: [A2, U3, A3]
|
||||
self.assertEqual(len(bot.conversation_history[7]), 3)
|
||||
self.assertEqual(bot.conversation_history[7][0]["content"], "Reply2") # A2
|
||||
self.assertEqual(bot.conversation_history[7][1]["content"], "Third message") # U3
|
||||
self.assertEqual(bot.conversation_history[7][2]["content"], "Reply3") # A3
|
||||
|
||||
|
||||
async def test_abort_processing(self):
|
||||
bot = ConcreteOpenAICompatibleBot(model_name="test")
|
||||
user_id = 123
|
||||
bot.processing_status[user_id] = {"processing": True, "message_id": 456}
|
||||
bot.conversation_history[user_id] = [{"role": "user", "content": "stuff"}]
|
||||
|
||||
with patch.object(bot, 'clear_conversation_history') as mock_clear_hist: # Patching the method from Base class
|
||||
result = await bot.abort_processing(user_id)
|
||||
|
||||
self.assertEqual(result, "Processing aborted and conversation cleared.")
|
||||
self.assertFalse(bot.processing_status[user_id]["processing"])
|
||||
mock_clear_hist.assert_called_once_with(user_id)
|
||||
|
||||
async def test_abort_processing_no_active_processing(self):
|
||||
bot = ConcreteOpenAICompatibleBot(model_name="test")
|
||||
user_id = 404 # Not in processing_status
|
||||
with patch.object(bot, 'clear_conversation_history') as mock_clear_hist:
|
||||
result = await bot.abort_processing(user_id)
|
||||
self.assertEqual(result, "No active processing found to abort. Conversation cleared.")
|
||||
mock_clear_hist.assert_called_once_with(user_id)
|
||||
|
||||
# Test for the abstract switch_model (basic call, actual logic in concrete class for this test)
|
||||
async def test_switch_model_concrete_implementation(self):
|
||||
bot = ConcreteOpenAICompatibleBot(model_name="model1", small_model_name="model1", large_model_name="model2", max_tokens_str="100")
|
||||
self.assertEqual(bot.model, "model1")
|
||||
await bot.switch_model() # Calls the concrete implementation
|
||||
self.assertEqual(bot.model, "model2")
|
||||
await bot.switch_model()
|
||||
self.assertEqual(bot.model, "model1")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,356 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch, mock_open, AsyncMock
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Assuming telegram_helper.py is in the parent directory or PYTHONPATH is set
|
||||
from telegram_helper import TelegramHelper, MessageHandlerLogicResult
|
||||
|
||||
# Mock for the bot passed to TelegramHelper
|
||||
class MockBot:
|
||||
def __init__(self):
|
||||
self.start = AsyncMock()
|
||||
self.clear_conversation_history = MagicMock()
|
||||
self.get_bot_status = AsyncMock(return_value="Bot Status OK")
|
||||
self.switch_model = AsyncMock(return_value="Model Switched OK")
|
||||
self.handle_message = AsyncMock() # Needs to return a string
|
||||
self.abort_processing = AsyncMock(return_value="Abort OK")
|
||||
self.set_processing_status = MagicMock()
|
||||
self.clear_processing_status = MagicMock()
|
||||
self.processing_status = {} # Add the attribute
|
||||
|
||||
# Mock for telegram.Update and related objects
|
||||
def create_mock_update(message_text=None, user_id=123, chat_id=456, message_id=789, callback_query_data=None):
|
||||
update = MagicMock()
|
||||
update.effective_user.id = user_id
|
||||
update.effective_chat.id = chat_id
|
||||
|
||||
if message_text:
|
||||
update.message.text = message_text
|
||||
update.message.reply_text = AsyncMock(return_value=MagicMock(message_id=message_id)) # reply_text returns a Message obj
|
||||
|
||||
if callback_query_data:
|
||||
update.callback_query.data = callback_query_data
|
||||
update.callback_query.from_user.id = user_id
|
||||
update.callback_query.answer = AsyncMock()
|
||||
update.callback_query.edit_message_text = AsyncMock()
|
||||
|
||||
return update
|
||||
|
||||
# Mock for telegram.ext.ContextTypes.DEFAULT_TYPE
|
||||
def create_mock_context():
|
||||
context = MagicMock()
|
||||
context.bot.delete_message = AsyncMock()
|
||||
context.bot.edit_message_text = AsyncMock() # For update_status_message
|
||||
return context
|
||||
|
||||
class TestTelegramHelper(unittest.IsolatedAsyncioTestCase): # Use IsolatedAsyncioTestCase for async methods
|
||||
|
||||
def setUp(self):
|
||||
self.mock_bot = MockBot()
|
||||
# Default paths for reboot files, can be overridden in tests
|
||||
self.reboot_claude_file = ".test_reboot_claude"
|
||||
self.reboot_file = ".test_doreboot"
|
||||
self.helper = TelegramHelper(
|
||||
self.mock_bot,
|
||||
reboot_claude_file_path=self.reboot_claude_file,
|
||||
reboot_file_path=self.reboot_file,
|
||||
chunk_message_sleep_duration=0.001 # Faster sleep for tests
|
||||
)
|
||||
# Clean up any potential leftover reboot files from previous runs
|
||||
if os.path.exists(self.reboot_claude_file):
|
||||
os.remove(self.reboot_claude_file)
|
||||
if os.path.exists(self.reboot_file):
|
||||
os.remove(self.reboot_file)
|
||||
|
||||
def tearDown(self):
|
||||
# Clean up reboot files created during tests
|
||||
if os.path.exists(self.reboot_claude_file):
|
||||
os.remove(self.reboot_claude_file)
|
||||
if os.path.exists(self.reboot_file):
|
||||
os.remove(self.reboot_file)
|
||||
|
||||
async def test_start_logic(self):
|
||||
response = await self.helper._start_logic()
|
||||
self.mock_bot.start.assert_called_once()
|
||||
self.assertEqual(response, "Hello! I\'m your AI assistant. How can I help you today?")
|
||||
|
||||
async def test_start_command(self):
|
||||
mock_update = create_mock_update(message_text="/start")
|
||||
mock_context = create_mock_context()
|
||||
|
||||
with patch.object(self.helper, \'_start_logic\', new_callable=AsyncMock) as mock_logic:
|
||||
mock_logic.return_value = "Start Logic Response"
|
||||
await self.helper.start(mock_update, mock_context)
|
||||
mock_logic.assert_called_once()
|
||||
mock_update.message.reply_text.assert_called_once_with("Start Logic Response")
|
||||
|
||||
async def test_clear_logic(self):
|
||||
user_id = 123
|
||||
response = await self.helper._clear_logic(user_id) # _clear_logic is async after refactor
|
||||
self.mock_bot.clear_conversation_history.assert_called_once_with(user_id)
|
||||
self.assertEqual(response, "Conversation history cleared. Let\'s start fresh!")
|
||||
|
||||
async def test_clear_command(self):
|
||||
mock_update = create_mock_update(message_text="/clear", user_id=123)
|
||||
mock_context = create_mock_context()
|
||||
with patch.object(self.helper, \'_clear_logic\', new_callable=AsyncMock) as mock_logic:
|
||||
mock_logic.return_value = "Clear Logic Response"
|
||||
await self.helper.clear(mock_update, mock_context)
|
||||
mock_logic.assert_called_once_with(123)
|
||||
mock_update.message.reply_text.assert_called_once_with("Clear Logic Response")
|
||||
|
||||
async def test_status_logic(self):
|
||||
self.mock_bot.get_bot_status.return_value = "Test Status"
|
||||
response = await self.helper._status_logic()
|
||||
self.mock_bot.get_bot_status.assert_called_once()
|
||||
self.assertEqual(response, "Test Status")
|
||||
|
||||
async def test_switch_logic_supported(self):
|
||||
self.mock_bot.switch_model.return_value = "Switched to Large Model"
|
||||
response = await self.helper._switch_logic()
|
||||
self.mock_bot.switch_model.assert_called_once()
|
||||
self.assertEqual(response, "Switched to Large Model")
|
||||
|
||||
async def test_switch_logic_not_supported(self):
|
||||
del self.mock_bot.switch_model # Simulate bot not having the attribute
|
||||
response = await self.helper._switch_logic()
|
||||
self.assertEqual(response, "Model switching is not supported for this bot.")
|
||||
|
||||
async def test_handle_message_logic_success(self):
|
||||
user_id = 100
|
||||
user_message = "Hello bot"
|
||||
bot_response = "Hello user <think>Thinking hard</think> Done."
|
||||
expected_processed_response = f"Hello user {self.helper.HTML_QUOTE_BLOCK_START}Thinking hard{self.helper.HTML_QUOTE_BLOCK_END} Done."
|
||||
self.mock_bot.handle_message.return_value = bot_response
|
||||
|
||||
result = await self.helper._handle_message_logic(user_id, user_message)
|
||||
|
||||
self.mock_bot.handle_message.assert_called_once_with(user_id, user_message)
|
||||
self.assertTrue(result["success"])
|
||||
self.assertEqual(result["response_text"], expected_processed_response)
|
||||
self.assertIsNone(result["error_message"])
|
||||
|
||||
async def test_handle_message_logic_bot_exception(self):
|
||||
user_id = 101
|
||||
user_message = "Trigger error"
|
||||
self.mock_bot.handle_message.side_effect = Exception("Bot Error")
|
||||
|
||||
result = await self.helper._handle_message_logic(user_id, user_message)
|
||||
|
||||
self.assertFalse(result["success"])
|
||||
self.assertIsNone(result["response_text"])
|
||||
self.assertEqual(result["error_message"], "Bot Error")
|
||||
|
||||
@patch(\'logging.error\')
|
||||
async def test_handle_message_command_success_short_message(self, mock_logging_error):
|
||||
mock_update = create_mock_update(message_text="Hi", user_id=200, chat_id=201, message_id=202)
|
||||
mock_context = create_mock_context()
|
||||
|
||||
logic_result = MessageHandlerLogicResult(success=True, response_text="Short response", error_message=None)
|
||||
|
||||
with patch.object(self.helper, \'_handle_message_logic\', new_callable=AsyncMock) as mock_message_logic:
|
||||
mock_message_logic.return_value = logic_result
|
||||
|
||||
await self.helper.handle_message(mock_update, mock_context)
|
||||
|
||||
mock_update.message.reply_text.assert_any_call("Processing your request...", reply_markup=unittest.mock.ANY)
|
||||
self.mock_bot.set_processing_status.assert_called_once_with(200, 202) # user_id, status_message_id
|
||||
mock_message_logic.assert_called_once_with(200, "Hi")
|
||||
mock_context.bot.delete_message.assert_called_once_with(chat_id=201, message_id=202)
|
||||
self.mock_bot.clear_processing_status.assert_called_once_with(200)
|
||||
mock_update.message.reply_text.assert_any_call("Short response") # Final response
|
||||
self.assertEqual(mock_update.message.reply_text.call_count, 2) # Processing + final
|
||||
|
||||
@patch(\'logging.error\')
|
||||
async def test_handle_message_command_success_long_message_chunks(self, mock_logging_error):
|
||||
mock_update = create_mock_update(message_text="Long text", user_id=200, chat_id=201, message_id=202)
|
||||
mock_context = create_mock_context()
|
||||
|
||||
long_response_text = "a" * 5000 # Longer than 4096
|
||||
chunk1 = long_response_text[:4096]
|
||||
chunk2 = long_response_text[4096:]
|
||||
|
||||
logic_result = MessageHandlerLogicResult(success=True, response_text=long_response_text, error_message=None)
|
||||
|
||||
with patch.object(self.helper, \'_handle_message_logic\', new_callable=AsyncMock) as mock_message_logic, \
|
||||
patch(\'asyncio.sleep\', new_callable=AsyncMock) as mock_sleep: # Mock sleep
|
||||
mock_message_logic.return_value = logic_result
|
||||
|
||||
await self.helper.handle_message(mock_update, mock_context)
|
||||
|
||||
mock_update.message.reply_text.assert_any_call(chunk1)
|
||||
mock_update.message.reply_text.assert_any_call(chunk2)
|
||||
mock_sleep.assert_called_once_with(self.helper.chunk_message_sleep_duration)
|
||||
self.assertEqual(mock_update.message.reply_text.call_count, 3) # Processing + 2 chunks
|
||||
|
||||
@patch(\'logging.error\')
|
||||
async def test_handle_message_command_logic_fails(self, mock_logging_error):
|
||||
mock_update = create_mock_update(message_text="Cause error in logic", user_id=200)
|
||||
mock_context = create_mock_context()
|
||||
logic_result = MessageHandlerLogicResult(success=False, response_text=None, error_message="Logic Failed")
|
||||
|
||||
with patch.object(self.helper, \'_handle_message_logic\', new_callable=AsyncMock) as mock_message_logic:
|
||||
mock_message_logic.return_value = logic_result
|
||||
await self.helper.handle_message(mock_update, mock_context)
|
||||
mock_update.message.reply_text.assert_any_call("Sorry, an error occurred while processing your request.")
|
||||
self.assertEqual(mock_update.message.reply_text.call_count, 2) # Processing + error message
|
||||
|
||||
@patch(\'logging.error\')
|
||||
async def test_handle_message_command_telegram_exception_after_logic(self, mock_logging_error):
|
||||
mock_update = create_mock_update(message_text="Test", user_id=200)
|
||||
mock_context = create_mock_context()
|
||||
logic_result = MessageHandlerLogicResult(success=True, response_text="OK", error_message=None)
|
||||
|
||||
# Make sending the final reply fail
|
||||
mock_update.message.reply_text.side_effect = [
|
||||
MagicMock(message_id=202), # For "Processing..."
|
||||
Exception("Telegram API Error") # For the actual response
|
||||
]
|
||||
|
||||
with patch.object(self.helper, \'_handle_message_logic\', new_callable=AsyncMock) as mock_message_logic:
|
||||
mock_message_logic.return_value = logic_result
|
||||
await self.helper.handle_message(mock_update, mock_context)
|
||||
|
||||
# Check if the generic error message was attempted
|
||||
# This is tricky because reply_text is already mocked with side_effect.
|
||||
# We\'d expect logs. Let\'s check logs or if processing status was cleared.
|
||||
self.mock_bot.clear_processing_status.assert_called_once_with(200)
|
||||
mock_logging_error.assert_any_call(unittest.mock.string_containing("Outer error in handle_message"))
|
||||
|
||||
|
||||
async def test_abort_processing_logic(self):
|
||||
user_id = 300
|
||||
self.mock_bot.abort_processing.return_value = "Aborted by bot"
|
||||
response = await self.helper._abort_processing_logic(user_id)
|
||||
self.mock_bot.abort_processing.assert_called_once_with(user_id)
|
||||
self.assertEqual(response, "Aborted by bot")
|
||||
|
||||
async def test_abort_processing_command(self):
|
||||
mock_update = create_mock_update(callback_query_data=\'abort\', user_id=301)
|
||||
mock_context = create_mock_context()
|
||||
with patch.object(self.helper, \'_abort_processing_logic\', new_callable=AsyncMock) as mock_logic:
|
||||
mock_logic.return_value = "Abort Logic Done"
|
||||
await self.helper.abort_processing(mock_update, mock_context)
|
||||
|
||||
mock_update.callback_query.answer.assert_called_once()
|
||||
mock_logic.assert_called_once_with(301)
|
||||
mock_update.callback_query.edit_message_text.assert_called_once_with(text="Abort Logic Done")
|
||||
|
||||
def test_reboot_logic_claude_and_main(self):
|
||||
user_message_parts = ["/reboot", "claude"]
|
||||
chat_id_to_write = "12345"
|
||||
|
||||
with patch("builtins.open", mock_open()) as mock_file:
|
||||
self.helper._reboot_logic(user_message_parts, chat_id_to_write)
|
||||
|
||||
# Check claude reboot file
|
||||
mock_file.assert_any_call(self.reboot_claude_file, \'w\')
|
||||
# Check main doreboot file
|
||||
mock_file.assert_any_call(self.reboot_file, \'w\')
|
||||
handle_claude = mock_file.return_value
|
||||
handle_main = mock_file.return_value # mock_open reuses the handle for multiple calls
|
||||
|
||||
# Check if write was called for claude file (empty write)
|
||||
# This part of assertion is tricky with single mock_file. Better to use different mocks if possible
|
||||
# or check the sequence of calls if the mock supports it well.
|
||||
# For now, assert_any_call ensures it was opened.
|
||||
|
||||
# Check content for main reboot file
|
||||
# Need to ensure the write for self.reboot_file had chat_id_to_write
|
||||
# This requires more sophisticated mock_open or patching os.path.exists and multiple open calls
|
||||
# Simpler check: was open(self.reboot_file, \'w\') called? Yes, via assert_any_call.
|
||||
# And was open(self.reboot_claude_file, \'w\') called? Yes.
|
||||
|
||||
# Verify files were created (mock_open doesn\'t actually create them)
|
||||
# This test relies on mock_open\'s behavior. To test file content, need more setup.
|
||||
# For now, assume open was called correctly.
|
||||
|
||||
def test_reboot_logic_main_only(self):
|
||||
user_message_parts = ["/reboot"]
|
||||
chat_id_to_write = "67890"
|
||||
with patch("builtins.open", mock_open()) as mock_file:
|
||||
self.helper._reboot_logic(user_message_parts, chat_id_to_write)
|
||||
# Ensure claude file was NOT opened for writing.
|
||||
# This requires asserting that a specific call didn\'t happen, or checking call_args_list
|
||||
claude_call = unittest.mock.call(self.reboot_claude_file, \'w\')
|
||||
self.assertNotIn(claude_call, mock_file.call_args_list)
|
||||
|
||||
mock_file.assert_any_call(self.reboot_file, \'w\')
|
||||
|
||||
@patch(\'sys.exit\') # Mock sys.exit to prevent test runner from exiting
|
||||
async def test_reboot_command(self, mock_sys_exit):
|
||||
mock_update = create_mock_update(message_text="/reboot claude", chat_id="chat1")
|
||||
mock_context = create_mock_context()
|
||||
|
||||
with patch.object(self.helper, \'_reboot_logic\') as mock_reboot_file_logic:
|
||||
await self.helper.reboot(mock_update, mock_context)
|
||||
|
||||
mock_reboot_file_logic.assert_called_once_with(["/reboot", "claude"], "chat1")
|
||||
mock_update.message.reply_text.assert_called_once_with("Rebooting the bot...")
|
||||
mock_sys_exit.assert_called_once_with(0)
|
||||
|
||||
@patch(\'os.path.exists\')
|
||||
@patch(\'builtins.open\', new_callable=mock_open)
|
||||
@patch(\'os.remove\')
|
||||
async def test_check_doreboot_file_logic_file_exists(self, mock_os_remove, mock_file_open, mock_os_path_exists):
|
||||
mock_os_path_exists.return_value = True
|
||||
mock_file_open.return_value.read.return_value.strip.return_value = "chat123"
|
||||
|
||||
chat_id = await self.helper._check_doreboot_file_logic()
|
||||
|
||||
mock_os_path_exists.assert_called_once_with(self.reboot_file)
|
||||
mock_file_open.assert_called_once_with(self.reboot_file, \'r\')
|
||||
mock_os_remove.assert_called_once_with(self.reboot_file)
|
||||
self.assertEqual(chat_id, "chat123")
|
||||
|
||||
@patch(\'os.path.exists\', return_value=False)
|
||||
async def test_check_doreboot_file_logic_file_not_exists(self, mock_os_path_exists):
|
||||
chat_id = await self.helper._check_doreboot_file_logic()
|
||||
mock_os_path_exists.assert_called_once_with(self.reboot_file)
|
||||
self.assertIsNone(chat_id)
|
||||
|
||||
@patch(\'logging.error\')
|
||||
@patch(\'os.path.exists\', return_value=True)
|
||||
@patch(\'builtins.open\', side_effect=IOError("Read error"))
|
||||
@patch(\'os.remove\') # To check if remove is called even on read error
|
||||
async def test_check_doreboot_file_logic_read_error(self, mock_os_remove, mock_file_open, mock_os_path_exists, mock_log_error):
|
||||
chat_id = await self.helper._check_doreboot_file_logic()
|
||||
|
||||
self.assertIsNone(chat_id)
|
||||
mock_log_error.assert_any_call(unittest.mock.string_containing("Error reading reboot file"))
|
||||
# Check if os.remove was attempted even after read error
|
||||
mock_os_remove.assert_called_once_with(self.reboot_file)
|
||||
|
||||
|
||||
async def test_check_doreboot_file_command_sends_message(self):
|
||||
mock_application = MagicMock()
|
||||
mock_application.bot.send_message = AsyncMock()
|
||||
|
||||
with patch.object(self.helper, \'_check_doreboot_file_logic\', new_callable=AsyncMock) as mock_logic:
|
||||
mock_logic.return_value = "chat789" # Simulate chat_id found
|
||||
await self.helper.check_doreboot_file(mock_application)
|
||||
|
||||
mock_logic.assert_called_once()
|
||||
mock_application.bot.send_message.assert_called_once_with(
|
||||
chat_id="chat789", text="The application has finished initializing."
|
||||
)
|
||||
|
||||
async def test_check_doreboot_file_command_no_chat_id(self):
|
||||
mock_application = MagicMock()
|
||||
mock_application.bot.send_message = AsyncMock()
|
||||
|
||||
with patch.object(self.helper, \'_check_doreboot_file_logic\', new_callable=AsyncMock) as mock_logic:
|
||||
mock_logic.return_value = None # Simulate no chat_id found
|
||||
await self.helper.check_doreboot_file(mock_application)
|
||||
|
||||
mock_logic.assert_called_once()
|
||||
mock_application.bot.send_message.assert_not_called()
|
||||
|
||||
# Note: Testing the run() method itself is more of an integration test,
|
||||
# as it involves setting up the full Application and polling loop.
|
||||
# Unit tests here focus on the helper\'s own logic methods.
|
||||
|
||||
if __name__ == \'__main__\':
|
||||
unittest.main()
|
||||
@@ -1,307 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
import os
|
||||
import base64
|
||||
import logging
|
||||
import requests # Required for spec in MagicMock
|
||||
|
||||
# Ensure tools/github_tool.py is accessible
|
||||
from tools.github_tool import GitHubTool
|
||||
|
||||
# Helper to create a mock response for requests.Session
|
||||
def create_mock_response(status_code, json_data=None, text_data=None, headers=None, links=None):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = status_code
|
||||
if json_data is not None:
|
||||
mock_resp.json = MagicMock(return_value=json_data)
|
||||
mock_resp.text = text_data if text_data is not None else str(json_data)
|
||||
mock_resp.headers = headers if headers else {}
|
||||
mock_resp.links = links if links else {} # For pagination in _list_branches
|
||||
return mock_resp
|
||||
|
||||
class TestGitHubTool(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.mock_session = MagicMock(spec=requests.Session)
|
||||
self.mock_session.headers = {} # Simulate a new session's headers
|
||||
|
||||
self.test_token = "test_github_token"
|
||||
self.test_repo = "owner/repo"
|
||||
self.test_base_url = "https://api.example.com" # Use a non-default base_url for some tests
|
||||
|
||||
# Suppress logging output during tests unless explicitly testing for it
|
||||
self.logger = logging.getLogger('tools.github_tool')
|
||||
# Ensure only one NullHandler to prevent duplicate messages if tests run multiple times in a session
|
||||
if not any(isinstance(h, logging.NullHandler) for h in self.logger.handlers):
|
||||
self.logger.addHandler(logging.NullHandler())
|
||||
self.logger.propagate = False # Prevent propagation to root logger if it has handlers
|
||||
|
||||
def test_init_with_args_and_session(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, base_url=self.test_base_url, logger=self.logger)
|
||||
self.assertEqual(tool.session, self.mock_session)
|
||||
self.assertEqual(tool._token, self.test_token)
|
||||
self.assertEqual(tool._repo, self.test_repo)
|
||||
self.assertEqual(tool.base_url, self.test_base_url)
|
||||
self.assertEqual(tool.current_branch, "main") # Default initial branch
|
||||
|
||||
@patch('requests.Session')
|
||||
def test_init_creates_session_if_not_provided(self, MockSessionConstructor):
|
||||
mock_created_session = MagicMock(spec=requests.Session)
|
||||
mock_created_session.headers = {}
|
||||
MockSessionConstructor.return_value = mock_created_session
|
||||
|
||||
# Temporarily set env vars for this test
|
||||
original_token = os.environ.get("GITHUB_TOKEN")
|
||||
original_repo = os.environ.get("GITHUB_REPOSITORY")
|
||||
os.environ["GITHUB_TOKEN"] = "env_token"
|
||||
os.environ["GITHUB_REPOSITORY"] = "env/repo"
|
||||
|
||||
tool = GitHubTool(logger=self.logger) # Use env vars
|
||||
|
||||
MockSessionConstructor.assert_called_once()
|
||||
self.assertEqual(tool.session, mock_created_session)
|
||||
self.assertEqual(tool._token, "env_token")
|
||||
self.assertEqual(tool._repo, "env/repo")
|
||||
self.assertIn("Authorization", mock_created_session.headers)
|
||||
self.assertEqual(mock_created_session.headers["Authorization"], "token env_token")
|
||||
|
||||
# Restore original env vars
|
||||
if original_token is None: del os.environ["GITHUB_TOKEN"]
|
||||
else: os.environ["GITHUB_TOKEN"] = original_token
|
||||
if original_repo is None: del os.environ["GITHUB_REPOSITORY"]
|
||||
else: os.environ["GITHUB_REPOSITORY"] = original_repo
|
||||
|
||||
def test_init_raises_value_error_if_no_token(self):
|
||||
original_token = os.environ.pop("GITHUB_TOKEN", None)
|
||||
with self.assertRaisesRegex(ValueError, "GitHub token must be provided"):
|
||||
GitHubTool(repo=self.test_repo, logger=self.logger)
|
||||
if original_token: os.environ["GITHUB_TOKEN"] = original_token
|
||||
|
||||
def test_init_raises_value_error_if_no_repo(self):
|
||||
original_repo = os.environ.pop("GITHUB_REPOSITORY", None)
|
||||
with self.assertRaisesRegex(ValueError, "GitHub repository.*must be provided"):
|
||||
GitHubTool(token=self.test_token, logger=self.logger)
|
||||
if original_repo: os.environ["GITHUB_REPOSITORY"] = original_repo
|
||||
|
||||
def test_clear_resets_branch(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, initial_branch="feature-branch", logger=self.logger)
|
||||
# Mock _get_branch_sha for _set_current_branch called by clear
|
||||
with patch.object(tool, '_get_branch_sha', return_value="sha_for_main"):
|
||||
tool.clear()
|
||||
self.assertEqual(tool.current_branch, "main")
|
||||
|
||||
def test_get_functions_returns_list(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
functions = tool.get_functions()
|
||||
self.assertIsInstance(functions, list)
|
||||
self.assertTrue(len(functions) > 0)
|
||||
self.assertIn("name", functions[0]["function"])
|
||||
|
||||
|
||||
# --- Test individual private methods ---
|
||||
|
||||
def test_read_file_success(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
file_content = "Hello World!"
|
||||
encoded_content = base64.b64encode(file_content.encode('utf-8')).decode('utf-8')
|
||||
self.mock_session.get.return_value = create_mock_response(200, json_data={"content": encoded_content})
|
||||
|
||||
result = tool._read_file(path="test.txt")
|
||||
self.assertEqual(result, file_content)
|
||||
self.mock_session.get.assert_called_once_with(
|
||||
f"{tool.base_url}/repos/{self.test_repo}/contents/test.txt",
|
||||
params={"ref": "main"}
|
||||
)
|
||||
|
||||
def test_read_file_error(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
self.mock_session.get.return_value = create_mock_response(404, text_data="Not Found")
|
||||
result = tool._read_file(path="nonexistent.txt")
|
||||
self.assertIn("Error reading file", result)
|
||||
|
||||
def test_create_branch_success(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
# Mock getting base branch SHA
|
||||
self.mock_session.get.return_value = create_mock_response(200, json_data={"object": {"sha": "base_sha123"}})
|
||||
# Mock creating new branch
|
||||
self.mock_session.post.return_value = create_mock_response(201, json_data={"ref": "refs/heads/new-feature"})
|
||||
|
||||
result = tool._create_branch(branch_name="new-feature", base_branch="main")
|
||||
self.assertIn("Branch 'new-feature' created successfully", result)
|
||||
self.assertEqual(tool.current_branch, "new-feature")
|
||||
self.mock_session.get.assert_called_once() # For base branch SHA
|
||||
self.mock_session.post.assert_called_once() # For creating branch
|
||||
|
||||
def test_create_branch_base_sha_error(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
self.mock_session.get.return_value = create_mock_response(404, text_data="Base branch not found")
|
||||
result = tool._create_branch(branch_name="new-feature", base_branch="nonexistent-base")
|
||||
self.assertIn("Error getting base branch SHA", result)
|
||||
|
||||
def test_create_branch_creation_error(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
self.mock_session.get.return_value = create_mock_response(200, json_data={"object": {"sha": "base_sha456"}})
|
||||
self.mock_session.post.return_value = create_mock_response(422, text_data="Validation failed")
|
||||
result = tool._create_branch(branch_name="bad-branch", base_branch="main")
|
||||
self.assertIn("Error creating branch", result)
|
||||
|
||||
def test_commit_file_success_new_file(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
tool.current_branch = "dev-branch" # Cannot commit to main by default
|
||||
|
||||
# Mock GET for checking file existence (404 means new file)
|
||||
self.mock_session.get.return_value = create_mock_response(404)
|
||||
# Mock PUT for committing file
|
||||
self.mock_session.put.return_value = create_mock_response(201, json_data={"commit": {"sha": "commit_sha_abc"}})
|
||||
|
||||
result = tool._commit_file(file_path="new_file.py", content="print('Hello')", commit_message="Add new_file.py")
|
||||
self.assertIn("committed successfully", result)
|
||||
self.assertIn("commit_sha_abc", result)
|
||||
self.mock_session.get.assert_called_once() # Check file existence
|
||||
self.mock_session.put.assert_called_once() # Commit file
|
||||
|
||||
def test_commit_file_success_update_file(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
tool.current_branch = "dev-branch"
|
||||
|
||||
# Mock GET for checking file existence (200 means existing file)
|
||||
self.mock_session.get.return_value = create_mock_response(200, json_data={"sha": "existing_file_sha"})
|
||||
# Mock PUT for committing file
|
||||
self.mock_session.put.return_value = create_mock_response(200, json_data={"commit": {"sha": "commit_sha_def"}})
|
||||
|
||||
result = tool._commit_file(file_path="existing_file.txt", content="Updated content", commit_message="Update existing_file.txt")
|
||||
self.assertIn("committed successfully", result)
|
||||
self.assertIn("commit_sha_def", result)
|
||||
args, kwargs = self.mock_session.put.call_args
|
||||
self.assertEqual(kwargs['json']['sha'], "existing_file_sha")
|
||||
|
||||
|
||||
def test_commit_file_to_main_branch_error(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
tool.current_branch = "main"
|
||||
result = tool._commit_file(file_path="some.txt", content="content", commit_message="msg")
|
||||
self.assertIn("Action directly to main branch is not allowed", result)
|
||||
|
||||
def test_create_pull_request_success(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
tool.current_branch = "feature-pr"
|
||||
pr_url = "https://example.com/pull/1"
|
||||
self.mock_session.post.return_value = create_mock_response(201, json_data={"html_url": pr_url, "number": 1})
|
||||
|
||||
result = tool._create_pull_request(title="New Feature PR", body="Please review.", base="main")
|
||||
self.assertIn(f"Pull request created successfully: {pr_url}", result)
|
||||
self.mock_session.post.assert_called_once()
|
||||
call_data = self.mock_session.post.call_args[1]['json']
|
||||
self.assertEqual(call_data['head'], "feature-pr")
|
||||
self.assertEqual(call_data['base'], "main")
|
||||
|
||||
def test_create_pull_request_same_branch_error(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
tool.current_branch = "main"
|
||||
result = tool._create_pull_request(title="PR to self", body="This should fail", base="main")
|
||||
self.assertIn("Cannot create a pull request from branch 'main' to itself", result)
|
||||
|
||||
|
||||
def test_list_files_success(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
mock_items = [
|
||||
{"name": "file1.txt", "type": "file", "path": "dir/file1.txt"},
|
||||
{"name": "subdir", "type": "dir", "path": "dir/subdir"}
|
||||
]
|
||||
self.mock_session.get.return_value = create_mock_response(200, json_data=mock_items)
|
||||
|
||||
result = tool._list_files(path="dir")
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertEqual(result[0]["name"], "file1.txt")
|
||||
self.assertEqual(result[1]["type"], "dir")
|
||||
|
||||
def test_search_code_success(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
mock_search_results = {
|
||||
"items": [{"path": "src/code.py", "html_url": "url1"}]
|
||||
}
|
||||
self.mock_session.get.return_value = create_mock_response(200, json_data=mock_search_results)
|
||||
|
||||
results = tool._search_code(query="my_function")
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertEqual(results[0]["path"], "src/code.py")
|
||||
|
||||
def test_get_commit_history_success(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
mock_commits = [{
|
||||
"sha": "sha1", "commit": {"message": "Msg1", "author": {"name": "Authy", "date": "Date1"}}
|
||||
}]
|
||||
self.mock_session.get.return_value = create_mock_response(200, json_data=mock_commits)
|
||||
|
||||
commits = tool._get_commit_history(file_path="file.txt", num_commits=1)
|
||||
self.assertEqual(len(commits), 1)
|
||||
self.assertEqual(commits[0]["sha"], "sha1")
|
||||
|
||||
def test_set_current_branch_success(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
# Mock _get_branch_sha to simulate branch exists
|
||||
with patch.object(tool, '_get_branch_sha', return_value="some_sha_for_dev"):
|
||||
result = tool._set_current_branch(branch_name="dev")
|
||||
self.assertEqual(tool.current_branch, "dev")
|
||||
self.assertIn("Current branch set to: dev", result)
|
||||
|
||||
def test_set_current_branch_not_exists(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
with patch.object(tool, '_get_branch_sha', return_value="Error getting SHA for branch"):
|
||||
result = tool._set_current_branch(branch_name="nonexistent-branch")
|
||||
self.assertNotEqual(tool.current_branch, "nonexistent-branch") # Should not change
|
||||
self.assertIn("Cannot set current branch", result)
|
||||
|
||||
|
||||
def test_list_branches_single_page(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
mock_branches = [{"name": "main"}, {"name": "dev"}]
|
||||
self.mock_session.get.return_value = create_mock_response(200, json_data=mock_branches, links={}) # No "next" link
|
||||
|
||||
branches = tool._list_branches(all_pages=True)
|
||||
self.assertEqual(branches, ["main", "dev"])
|
||||
self.mock_session.get.assert_called_once()
|
||||
|
||||
def test_list_branches_multiple_pages(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
|
||||
# Page 1 response
|
||||
page1_branches = [{"name": "branch1"}, {"name": "branch2"}]
|
||||
next_url = f"{tool.base_url}/repos/{self.test_repo}/branches?page=2"
|
||||
response1 = create_mock_response(200, json_data=page1_branches, links={"next": {"url": next_url}})
|
||||
|
||||
# Page 2 response
|
||||
page2_branches = [{"name": "branch3"}]
|
||||
response2 = create_mock_response(200, json_data=page2_branches, links={}) # No "next" link
|
||||
|
||||
self.mock_session.get.side_effect = [response1, response2]
|
||||
|
||||
branches = tool._list_branches(all_pages=True)
|
||||
self.assertEqual(branches, ["branch1", "branch2", "branch3"])
|
||||
self.assertEqual(self.mock_session.get.call_count, 2)
|
||||
|
||||
# Check that the second call used the next_url
|
||||
calls = self.mock_session.get.call_args_list
|
||||
self.assertEqual(calls[1][0][0], next_url) # args[0] is the URL
|
||||
|
||||
# --- Test execute dispatcher ---
|
||||
def test_execute_read_file(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
with patch.object(tool, '_read_file', return_value="file content") as mock_method:
|
||||
result = tool.execute(function_name="read_file", path="test.md")
|
||||
mock_method.assert_called_once_with(path="test.md")
|
||||
self.assertEqual(result, "file content")
|
||||
|
||||
def test_execute_unknown_function(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
result = tool.execute(function_name="non_existent_function_name", arg1="val1")
|
||||
self.assertIn("Unknown function: non_existent_function_name", result)
|
||||
|
||||
def test_execute_method_exception(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
with patch.object(tool, '_read_file', side_effect=Exception("Kaboom")) as mock_method:
|
||||
result = tool.execute(function_name="read_file", path="crash.txt")
|
||||
self.assertIn("Error during read_file execution: Kaboom", result)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,146 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import patch, mock_open, MagicMock
|
||||
import os
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Ensure tools/log_tool.py is accessible
|
||||
from tools.log_tool import LogTool
|
||||
|
||||
class TestLogTool(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.test_log_file_path = "test_dummy_log.log"
|
||||
# Suppress logging output during tests unless explicitly testing for it
|
||||
self.logger = logging.getLogger('tools.log_tool')
|
||||
# Ensure only one NullHandler to prevent duplicate messages if tests run multiple times in a session
|
||||
if not any(isinstance(h, logging.NullHandler) for h in self.logger.handlers):
|
||||
self.logger.addHandler(logging.NullHandler())
|
||||
self.logger.propagate = False # Prevent propagation to root logger if it has handlers
|
||||
|
||||
|
||||
def test_init_default_log_path(self):
|
||||
tool = LogTool(logger=self.logger)
|
||||
self.assertEqual(tool.configured_log_file_path, 'logs/output.log')
|
||||
|
||||
def test_init_custom_log_path(self):
|
||||
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
|
||||
self.assertEqual(tool.configured_log_file_path, self.test_log_file_path)
|
||||
|
||||
def test_get_functions(self):
|
||||
tool = LogTool(logger=self.logger)
|
||||
functions = tool.get_functions()
|
||||
self.assertIsInstance(functions, list)
|
||||
self.assertEqual(len(functions), 1)
|
||||
self.assertEqual(functions[0]["function"]["name"], "get_log_contents")
|
||||
|
||||
@patch("os.path.exists", return_value=False)
|
||||
def test_get_log_contents_file_not_exists(self, mock_exists):
|
||||
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
|
||||
result = tool._get_log_contents()
|
||||
self.assertIn("Log file does not exist", result)
|
||||
mock_exists.assert_called_once_with(self.test_log_file_path)
|
||||
|
||||
@patch("os.path.exists", return_value=True)
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="line1\nline2\nline3\nline4\nline5")
|
||||
def test_get_log_contents_with_line_count(self, mock_file_open, mock_exists):
|
||||
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
|
||||
|
||||
result = tool._get_log_contents(line_count=3)
|
||||
self.assertEqual(result, "line3\nline4\nline5")
|
||||
mock_exists.assert_called_once_with(self.test_log_file_path)
|
||||
mock_file_open.assert_called_once_with(self.test_log_file_path, 'r', encoding='utf-8')
|
||||
|
||||
@patch("os.path.exists", return_value=True)
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="line1\nline2\n")
|
||||
def test_get_log_contents_line_count_more_than_available(self, mock_file_open, mock_exists):
|
||||
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
|
||||
result = tool._get_log_contents(line_count=5)
|
||||
self.assertEqual(result, "line1\nline2\n")
|
||||
|
||||
@patch("os.path.exists", return_value=True)
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="line1\nline2\n")
|
||||
def test_get_log_contents_invalid_line_count_uses_default(self, mock_file_open, mock_exists):
|
||||
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
|
||||
# Test with zero, negative, and non-integer line_count (though type hint is int)
|
||||
# The code defaults to 150 if invalid. Here, we only have 2 lines.
|
||||
with patch.object(tool.logger, 'warning') as mock_log_warning:
|
||||
result_zero = tool._get_log_contents(line_count=0)
|
||||
self.assertEqual(result_zero, "line1\nline2\n")
|
||||
mock_log_warning.assert_any_call("Invalid line_count '0' provided, defaulting to fetch last 150 lines.")
|
||||
|
||||
mock_file_open.reset_mock() # Reset for next call
|
||||
result_neg = tool._get_log_contents(line_count=-5)
|
||||
self.assertEqual(result_neg, "line1\nline2\n")
|
||||
mock_log_warning.assert_any_call("Invalid line_count '-5' provided, defaulting to fetch last 150 lines.")
|
||||
|
||||
|
||||
@patch("os.path.exists", return_value=True)
|
||||
def test_get_log_contents_last_24_hours(self, mock_exists):
|
||||
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
|
||||
|
||||
now = datetime.now()
|
||||
one_hour_ago_dt = now - timedelta(hours=1)
|
||||
two_days_ago_dt = now - timedelta(days=2)
|
||||
|
||||
one_hour_ago_str = one_hour_ago_dt.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT)
|
||||
two_days_ago_str = two_days_ago_dt.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT)
|
||||
|
||||
log_data = (
|
||||
f"{two_days_ago_str} - OLD - This is an old log entry.\n"
|
||||
f"No timestamp here - this line should be skipped by time filter.\n"
|
||||
f"{one_hour_ago_str} - RECENT - This is a recent log entry.\n"
|
||||
f"Malformed Date 2023-xx-01 - Another skipped line.\n"
|
||||
f"{now.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT)} - CURRENT - This is a current log entry.\n"
|
||||
)
|
||||
|
||||
expected_output = (
|
||||
f"{one_hour_ago_str} - RECENT - This is a recent log entry.\n"
|
||||
f"{now.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT)} - CURRENT - This is a current log entry.\n"
|
||||
)
|
||||
|
||||
with patch("builtins.open", mock_open(read_data=log_data)):
|
||||
result = tool._get_log_contents(line_count=None) # Trigger 24-hour logic
|
||||
self.assertEqual(result, expected_output)
|
||||
|
||||
@patch("os.path.exists", return_value=True)
|
||||
@patch("builtins.open", side_effect=IOError("File read error!"))
|
||||
def test_get_log_contents_file_read_exception(self, mock_file_open, mock_exists):
|
||||
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
|
||||
result = tool._get_log_contents(line_count=10)
|
||||
self.assertIn("An error occurred while reading the log file: File read error!", result)
|
||||
|
||||
def test_execute_get_log_contents(self):
|
||||
tool = LogTool(logger=self.logger)
|
||||
mock_return_value = "Mocked log content"
|
||||
with patch.object(tool, '_get_log_contents', return_value=mock_return_value) as mock_method:
|
||||
result = tool.execute(function_name="get_log_contents", line_count=50)
|
||||
mock_method.assert_called_once_with(line_count=50)
|
||||
self.assertEqual(result, mock_return_value)
|
||||
|
||||
def test_execute_get_log_contents_no_line_count(self):
|
||||
tool = LogTool(logger=self.logger)
|
||||
mock_return_value = "Mocked log content for 24h"
|
||||
with patch.object(tool, '_get_log_contents', return_value=mock_return_value) as mock_method:
|
||||
result = tool.execute(function_name="get_log_contents") # No line_count
|
||||
mock_method.assert_called_once_with(line_count=None) # Expects None to trigger 24h
|
||||
self.assertEqual(result, mock_return_value)
|
||||
|
||||
|
||||
def test_execute_unknown_function(self):
|
||||
tool = LogTool(logger=self.logger)
|
||||
result = tool.execute(function_name="non_existent_log_function")
|
||||
self.assertIn("Unknown function: non_existent_log_function", result)
|
||||
|
||||
def test_clear_method(self):
|
||||
tool = LogTool(logger=self.logger)
|
||||
# Set a specific level for the logger for this test if needed to capture debug
|
||||
original_level = tool.logger.level
|
||||
tool.logger.setLevel(logging.DEBUG)
|
||||
with self.assertLogs(tool.logger, level='DEBUG') as cm:
|
||||
tool.clear()
|
||||
self.assertTrue(any("LogTool clear called" in message for message in cm.output))
|
||||
tool.logger.setLevel(original_level) # Reset level
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,217 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock, ANY
|
||||
import time
|
||||
import logging
|
||||
|
||||
# Ensure tools.metrics is accessible
|
||||
from tools.metrics import Metrics # Import the class itself for direct testing
|
||||
from tools.metrics import metrics as global_metrics_instance # Import the global instance
|
||||
|
||||
# A simple function to decorate for testing
|
||||
def sample_function_for_metrics(duration=0.01):
|
||||
# Simulate some work
|
||||
# Note: time.sleep is not always precisely profiled by cProfile in the same way as CPU-bound work.
|
||||
# For testing, we will mock the cProfile/pstats interaction rather than relying on actual sleep duration.
|
||||
if duration > 0: # Make it conditional so we can test zero-time case too
|
||||
pass # The actual work is not important when mocking cProfile results
|
||||
return "sample_output"
|
||||
|
||||
def another_sample_function(x, y):
|
||||
return x + y
|
||||
|
||||
class TestMetrics(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Create a fresh Metrics instance for most tests to avoid interference
|
||||
self.logger = logging.getLogger('tools.metrics.test')
|
||||
if not self.logger.handlers: # Avoid adding handler multiple times
|
||||
self.logger.addHandler(logging.NullHandler())
|
||||
self.metrics_instance = Metrics(logger=self.logger)
|
||||
|
||||
# Clear the global instance before each test that might use it
|
||||
global_metrics_instance.clear_metrics()
|
||||
|
||||
def test_measure_decorator_counts_calls(self):
|
||||
decorated_func = self.metrics_instance.measure(sample_function_for_metrics)
|
||||
|
||||
self.assertEqual(self.metrics_instance.call_count["sample_function_for_metrics"], 0)
|
||||
decorated_func()
|
||||
self.assertEqual(self.metrics_instance.call_count["sample_function_for_metrics"], 1)
|
||||
decorated_func()
|
||||
self.assertEqual(self.metrics_instance.call_count["sample_function_for_metrics"], 2)
|
||||
|
||||
@patch('cProfile.Profile')
|
||||
@patch('pstats.Stats')
|
||||
def test_measure_decorator_records_time(self, MockPStats, MockCProfile):
|
||||
# Mock cProfile and pstats to control the time value
|
||||
mock_profiler_instance = MockCProfile.return_value
|
||||
mock_pstats_instance = MockPStats.return_value
|
||||
|
||||
# Simulate that pstats.Stats.stats dictionary contains the function's stats
|
||||
# Key: (filename, lineno, funcname)
|
||||
# Value: (cc, nc, tt, ct, callers) where ct is cumulative_time (index 3)
|
||||
|
||||
# Get code object of the function *before* decoration for correct key
|
||||
original_func_code = sample_function_for_metrics.__code__
|
||||
func_key = (original_func_code.co_filename, original_func_code.co_firstlineno, original_func_code.co_name)
|
||||
|
||||
# Configure mock_pstats_instance.stats to return our desired time
|
||||
mock_pstats_instance.stats = {func_key: (1, 1, 0.05, 0.123, {})} # cc, nc, tt, ct=0.123
|
||||
|
||||
decorated_func = self.metrics_instance.measure(sample_function_for_metrics)
|
||||
|
||||
self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0)
|
||||
|
||||
# Call the decorated function
|
||||
decorated_func(duration=0) # Duration arg doesn't matter due to mocking
|
||||
|
||||
# Assertions
|
||||
mock_profiler_instance.enable.assert_called_once()
|
||||
mock_profiler_instance.disable.assert_called_once()
|
||||
MockPStats.assert_called_once_with(mock_profiler_instance)
|
||||
|
||||
self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0.123)
|
||||
|
||||
# Call again to see accumulation
|
||||
# Reset mock stats for a new time value if needed, or assume same time per call
|
||||
mock_pstats_instance.stats = {func_key: (1, 1, 0.05, 0.100, {})} # New ct=0.100
|
||||
decorated_func(duration=0)
|
||||
self.assertAlmostEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0.123 + 0.100)
|
||||
|
||||
|
||||
@patch('cProfile.Profile')
|
||||
@patch('pstats.Stats')
|
||||
def test_measure_decorator_fallback_time_recording_by_name(self, MockPStats, MockCProfile):
|
||||
mock_profiler_instance = MockCProfile.return_value
|
||||
mock_pstats_instance = MockPStats.return_value
|
||||
|
||||
original_func_code = sample_function_for_metrics.__code__ # func to be decorated
|
||||
# Simulate the primary key lookup fails by creating a slightly different key for what we expect
|
||||
# This is what the code will try to look up first.
|
||||
expected_primary_key = (original_func_code.co_filename, original_func_code.co_firstlineno, original_func_code.co_name)
|
||||
|
||||
# This is the key that will *actually* be in pstats.stats, simulating a mismatch for primary lookup
|
||||
# but a match for the by-name fallback.
|
||||
actual_stats_key_in_pstats = (original_func_code.co_filename,
|
||||
original_func_code.co_firstlineno + 5, # simulate a lineno difference for primary key mismatch
|
||||
original_func_code.co_name) # Name is the same for fallback
|
||||
|
||||
mock_pstats_instance.stats = {
|
||||
# expected_primary_key is NOT present
|
||||
actual_stats_key_in_pstats: (1, 1, 0.03, 0.077, {}) # ct = 0.077
|
||||
}
|
||||
|
||||
decorated_func = self.metrics_instance.measure(sample_function_for_metrics)
|
||||
|
||||
# Expecting a debug log for fallback, but assertLogs needs the logger to have a handler that captures.
|
||||
# self.logger is already set up with NullHandler. For this test, let's use a specific logger.
|
||||
metrics_internal_logger = logging.getLogger('tools.metrics') # Logger used inside Metrics class
|
||||
original_level = metrics_internal_logger.level
|
||||
metrics_internal_logger.setLevel(logging.DEBUG)
|
||||
|
||||
with self.assertLogs(metrics_internal_logger, level='DEBUG') as log_capture:
|
||||
decorated_func(duration=0)
|
||||
|
||||
metrics_internal_logger.setLevel(original_level) # Reset logger level
|
||||
|
||||
self.assertTrue(any("Found stats for sample_function_for_metrics by name" in msg for msg in log_capture.output))
|
||||
self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0.077)
|
||||
|
||||
|
||||
@patch('cProfile.Profile')
|
||||
@patch('pstats.Stats')
|
||||
def test_measure_decorator_handles_func_stats_not_found(self, MockPStats, MockCProfile):
|
||||
mock_profiler_instance = MockCProfile.return_value
|
||||
mock_pstats_instance = MockPStats.return_value
|
||||
mock_pstats_instance.stats = {} # Empty stats, function will not be found
|
||||
|
||||
decorated_func = self.metrics_instance.measure(sample_function_for_metrics)
|
||||
|
||||
metrics_internal_logger = logging.getLogger('tools.metrics')
|
||||
original_level = metrics_internal_logger.level
|
||||
metrics_internal_logger.setLevel(logging.WARNING)
|
||||
with self.assertLogs(metrics_internal_logger, level='WARNING') as log_capture:
|
||||
decorated_func(duration=0)
|
||||
metrics_internal_logger.setLevel(original_level)
|
||||
|
||||
self.assertTrue(any("Could not find exact cProfile stats" in msg for msg in log_capture.output))
|
||||
self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0)
|
||||
|
||||
|
||||
def test_get_metrics_empty(self):
|
||||
self.assertEqual(self.metrics_instance.get_metrics(), {})
|
||||
|
||||
@patch('cProfile.Profile')
|
||||
@patch('pstats.Stats')
|
||||
def test_get_metrics_with_data(self, MockPStats, MockCProfile):
|
||||
mock_pstats_instance = MockPStats.return_value
|
||||
|
||||
# Decorate two different functions
|
||||
decorated_func1 = self.metrics_instance.measure(sample_function_for_metrics)
|
||||
decorated_func2 = self.metrics_instance.measure(another_sample_function)
|
||||
|
||||
# Data for func1
|
||||
func1_code = sample_function_for_metrics.__code__
|
||||
func1_key = (func1_code.co_filename, func1_code.co_firstlineno, func1_code.co_name)
|
||||
mock_pstats_instance.stats = {func1_key: (1,1,0.1,0.1,{})}
|
||||
decorated_func1()
|
||||
|
||||
# Data for func2
|
||||
func2_code = another_sample_function.__code__
|
||||
func2_key = (func2_code.co_filename, func2_code.co_firstlineno, func2_code.co_name)
|
||||
mock_pstats_instance.stats = {func2_key: (1,1,0.2,0.2,{})} # Cumulative time 0.2
|
||||
decorated_func2(1,2)
|
||||
mock_pstats_instance.stats = {func2_key: (1,1,0.3,0.3,{})} # Cumulative time 0.3 for second call
|
||||
decorated_func2(3,4)
|
||||
|
||||
metrics_data = self.metrics_instance.get_metrics()
|
||||
|
||||
self.assertIn("sample_function_for_metrics", metrics_data)
|
||||
self.assertEqual(metrics_data["sample_function_for_metrics"]["call_count"], 1)
|
||||
self.assertEqual(metrics_data["sample_function_for_metrics"]["total_time"], 0.1)
|
||||
self.assertEqual(metrics_data["sample_function_for_metrics"]["average_time"], 0.1)
|
||||
|
||||
self.assertIn("another_sample_function", metrics_data)
|
||||
self.assertEqual(metrics_data["another_sample_function"]["call_count"], 2)
|
||||
self.assertAlmostEqual(metrics_data["another_sample_function"]["total_time"], 0.5)
|
||||
self.assertAlmostEqual(metrics_data["another_sample_function"]["average_time"], 0.25)
|
||||
|
||||
|
||||
def test_clear_metrics(self):
|
||||
# Add some data
|
||||
self.metrics_instance.call_count["test_func"] = 5
|
||||
self.metrics_instance.total_time["test_func"] = 1.234
|
||||
|
||||
self.metrics_instance.clear_metrics()
|
||||
|
||||
self.assertEqual(self.metrics_instance.call_count, {})
|
||||
self.assertEqual(self.metrics_instance.total_time, {})
|
||||
self.assertEqual(self.metrics_instance.get_metrics(), {})
|
||||
|
||||
# Test the global instance
|
||||
@patch('cProfile.Profile')
|
||||
@patch('pstats.Stats')
|
||||
def test_global_metrics_instance_usage(self, MockPStats, MockCProfile):
|
||||
mock_pstats_instance = MockPStats.return_value
|
||||
|
||||
# Decorate a function with the global instance
|
||||
@global_metrics_instance.measure
|
||||
def globally_decorated_func():
|
||||
return "global_output"
|
||||
|
||||
# Setup mock stats for the globally decorated function
|
||||
# Access __wrapped__ to get the original function if other decorators might be present or for consistency.
|
||||
original_g_func = globally_decorated_func.__wrapped__
|
||||
func_code = original_g_func.__code__
|
||||
func_key = (func_code.co_filename, func_code.co_firstlineno, func_code.co_name)
|
||||
mock_pstats_instance.stats = {func_key: (1,1,0.05,0.05,{})}
|
||||
|
||||
globally_decorated_func()
|
||||
|
||||
metrics_data = global_metrics_instance.get_metrics()
|
||||
self.assertIn("globally_decorated_func", metrics_data)
|
||||
self.assertEqual(metrics_data["globally_decorated_func"]["call_count"], 1)
|
||||
self.assertEqual(metrics_data["globally_decorated_func"]["total_time"], 0.05)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,161 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
import logging
|
||||
|
||||
# Ensure tools.metrics_tool and tools.metrics are accessible
|
||||
from tools.metrics_tool import MetricsTool
|
||||
from tools.metrics import Metrics # Used for typehinting and creating a mockable instance
|
||||
|
||||
class TestMetricsTool(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.mock_metrics_provider = MagicMock(spec=Metrics)
|
||||
self.logger = logging.getLogger('tools.metrics_tool.test')
|
||||
if not self.logger.handlers:
|
||||
self.logger.addHandler(logging.NullHandler())
|
||||
self.logger.propagate = False
|
||||
|
||||
|
||||
def test_init_with_provider(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
self.assertEqual(tool.metrics_provider, self.mock_metrics_provider)
|
||||
|
||||
@patch('tools.metrics_tool.global_metrics_instance') # Patch the global instance path
|
||||
def test_init_default_provider(self, mock_global_metrics):
|
||||
tool = MetricsTool(logger=self.logger)
|
||||
self.assertEqual(tool.metrics_provider, mock_global_metrics)
|
||||
|
||||
def test_get_functions(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
functions = tool.get_functions()
|
||||
self.assertIsInstance(functions, list)
|
||||
self.assertTrue(len(functions) == 3) # Based on current definition
|
||||
self.assertIn("get_function_metrics", [f["function"]["name"] for f in functions])
|
||||
self.assertIn("get_specific_function_metrics", [f["function"]["name"] for f in functions])
|
||||
self.assertIn("get_top_n_functions", [f["function"]["name"] for f in functions])
|
||||
|
||||
def test_execute_get_function_metrics(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
expected_metrics = {"func1": {"call_count": 1, "total_time": 0.1}}
|
||||
self.mock_metrics_provider.get_metrics.return_value = expected_metrics
|
||||
|
||||
result = tool.execute(function_name="get_function_metrics")
|
||||
|
||||
self.mock_metrics_provider.get_metrics.assert_called_once()
|
||||
self.assertEqual(result, expected_metrics)
|
||||
|
||||
def test_execute_get_specific_function_metrics_found(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
func_metrics = {"call_count": 5, "total_time": 0.5, "average_time": 0.1}
|
||||
all_metrics = {"specific_func": func_metrics, "other_func": {}}
|
||||
self.mock_metrics_provider.get_metrics.return_value = all_metrics
|
||||
|
||||
# The execute method expects kwargs that match the function parameters in get_functions.
|
||||
# So, the argument name for the function to get is 'function_name' in the tool's spec.
|
||||
result = tool.execute(function_name="get_specific_function_metrics", **{"function_name": "specific_func"})
|
||||
self.assertEqual(result, func_metrics)
|
||||
|
||||
def test_execute_get_specific_function_metrics_not_found(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
self.mock_metrics_provider.get_metrics.return_value = {"other_func": {}}
|
||||
|
||||
result = tool.execute(function_name="get_specific_function_metrics", **{"function_name": "non_existent_func"})
|
||||
self.assertEqual(result, "No metrics found for function: non_existent_func")
|
||||
|
||||
def test_execute_get_specific_function_metrics_missing_arg(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
result = tool.execute(function_name="get_specific_function_metrics") # Missing function_name kwarg
|
||||
self.assertIn("Error: Missing required argument 'function_name'", result)
|
||||
|
||||
|
||||
def test_execute_get_top_n_functions(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
metrics_data = {
|
||||
"func_a": {"call_count": 1, "total_time": 0.3},
|
||||
"func_b": {"call_count": 1, "total_time": 0.1},
|
||||
"func_c": {"call_count": 1, "total_time": 0.5},
|
||||
"func_d": {"call_count": 1, "total_time": 0.2},
|
||||
}
|
||||
self.mock_metrics_provider.get_metrics.return_value = metrics_data
|
||||
|
||||
# Test getting top 2
|
||||
result = tool.execute(function_name="get_top_n_functions", n=2)
|
||||
expected_top_2 = {"func_c": metrics_data["func_c"], "func_a": metrics_data["func_a"]}
|
||||
self.assertEqual(result, expected_top_2)
|
||||
|
||||
# Test getting top 1
|
||||
result_top_1 = tool.execute(function_name="get_top_n_functions", n=1)
|
||||
expected_top_1 = {"func_c": metrics_data["func_c"]}
|
||||
self.assertEqual(result_top_1, expected_top_1)
|
||||
|
||||
# Test N larger than available functions
|
||||
result_top_all = tool.execute(function_name="get_top_n_functions", n=10)
|
||||
# Order should be func_c, func_a, func_d, func_b
|
||||
expected_top_all_keys = ["func_c", "func_a", "func_d", "func_b"]
|
||||
self.assertEqual(list(result_top_all.keys()), expected_top_all_keys)
|
||||
|
||||
def test_execute_get_top_n_functions_malformed_metrics(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
metrics_data = {
|
||||
"func_a": {"call_count": 1, "total_time": 0.3},
|
||||
"func_b": "not a dict", # Malformed
|
||||
"func_c": {"call_count": 1}, # Missing total_time
|
||||
"func_d": {"call_count": 1, "total_time": 0.2},
|
||||
}
|
||||
self.mock_metrics_provider.get_metrics.return_value = metrics_data
|
||||
|
||||
metrics_tool_logger = logging.getLogger('tools.metrics_tool')
|
||||
original_level = metrics_tool_logger.level
|
||||
metrics_tool_logger.setLevel(logging.WARNING)
|
||||
with self.assertLogs(metrics_tool_logger, level='WARNING') as log_capture:
|
||||
result = tool.execute(function_name="get_top_n_functions", n=2)
|
||||
metrics_tool_logger.setLevel(original_level)
|
||||
|
||||
# Check that warnings were logged for malformed items
|
||||
self.assertTrue(any("Metric item for 'func_b' is not in expected format" in msg for msg in log_capture.output))
|
||||
self.assertTrue(any("Metric item for 'func_c' is not in expected format" in msg for msg in log_capture.output))
|
||||
|
||||
# Expected: func_a, func_d (as they are valid and sortable)
|
||||
expected_result = {
|
||||
"func_a": metrics_data["func_a"],
|
||||
"func_d": metrics_data["func_d"]
|
||||
}
|
||||
self.assertEqual(result, expected_result)
|
||||
|
||||
|
||||
def test_execute_get_top_n_functions_invalid_n(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
self.mock_metrics_provider.get_metrics.return_value = {} # No metrics needed for this test
|
||||
|
||||
result_zero = tool.execute(function_name="get_top_n_functions", n=0)
|
||||
self.assertIn("Error: Argument 'n' must be a positive integer.", result_zero)
|
||||
|
||||
result_negative = tool.execute(function_name="get_top_n_functions", n=-1)
|
||||
self.assertIn("Error: Argument 'n' must be a positive integer.", result_negative)
|
||||
|
||||
result_string = tool.execute(function_name="get_top_n_functions", n="abc")
|
||||
self.assertIn("Error: Argument 'n' must be an integer.", result_string)
|
||||
|
||||
def test_execute_get_top_n_functions_missing_arg_n(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
result = tool.execute(function_name="get_top_n_functions") # Missing n
|
||||
self.assertIn("Error: Missing required argument 'n'.", result)
|
||||
|
||||
|
||||
def test_execute_unknown_function(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
result = tool.execute(function_name="non_existent_metrics_function")
|
||||
self.assertIn("Unknown function: non_existent_metrics_function", result)
|
||||
|
||||
def test_clear_method(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
metrics_tool_logger = logging.getLogger('tools.metrics_tool')
|
||||
original_level = metrics_tool_logger.level
|
||||
metrics_tool_logger.setLevel(logging.DEBUG)
|
||||
with self.assertLogs(metrics_tool_logger, level='DEBUG') as cm:
|
||||
tool.clear()
|
||||
metrics_tool_logger.setLevel(original_level)
|
||||
self.assertTrue(any("MetricsTool clear method called" in message for message in cm.output))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -4,8 +4,7 @@ import zipfile
|
||||
import io
|
||||
import re
|
||||
import logging
|
||||
from .base_tool import BaseTool # Added
|
||||
from .metrics import metrics # Added
|
||||
from .base_tool import BaseTool
|
||||
|
||||
# Configure logging for the tool - This will be handled by the logger instance now
|
||||
# logger = logging.getLogger(__name__) # Commented out or removed
|
||||
@@ -70,7 +69,8 @@ class GitHubCIHelper(BaseTool): # Inherits from BaseTool
|
||||
},
|
||||
"required": ["pull_request_number"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["read"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -85,7 +85,8 @@ class GitHubCIHelper(BaseTool): # Inherits from BaseTool
|
||||
},
|
||||
"required": ["pull_request_number"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["read"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -100,7 +101,8 @@ class GitHubCIHelper(BaseTool): # Inherits from BaseTool
|
||||
},
|
||||
"required": ["run_id"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["read"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -114,7 +116,8 @@ class GitHubCIHelper(BaseTool): # Inherits from BaseTool
|
||||
},
|
||||
"required": ["log_content"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["read"]
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
+72
-74
@@ -1,6 +1,5 @@
|
||||
# tools/github_tool.py
|
||||
from .base_tool import BaseTool
|
||||
from .metrics import metrics
|
||||
import requests
|
||||
import os
|
||||
import base64
|
||||
@@ -57,7 +56,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["path"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["read"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -71,7 +71,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["path"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["read"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -85,7 +86,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["read"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -100,7 +102,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["branch_name"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["write"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -116,7 +119,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["file_path", "commit_message", "content"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["write"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -132,7 +136,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["title", "body"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["write"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -147,7 +152,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["file_path"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["read"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -162,7 +168,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["file_path"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["read"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -176,7 +183,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["branch"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["read"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -184,7 +192,8 @@ class GitHubTool(BaseTool):
|
||||
"name": "get_current_branch",
|
||||
"description": "Get the name of the current branch",
|
||||
"parameters": { "type": "object", "properties": {} }
|
||||
}
|
||||
},
|
||||
"_tags": ["read"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -198,7 +207,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["branch_name"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["read", "write"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -213,7 +223,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["file_path", "commit_sha"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["read"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -227,7 +238,8 @@ class GitHubTool(BaseTool):
|
||||
"all_pages": {"type": "boolean", "description": "Whether to fetch all pages of results", "default": True}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["read"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -241,7 +253,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["pull_number"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["write"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -255,7 +268,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["pull_number"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["write"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -277,7 +291,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["pull_number"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["write"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -291,7 +306,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["branch_name"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["write"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -305,7 +321,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["issue_number"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["read"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -325,7 +342,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["title", "body"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["communicate"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -340,7 +358,8 @@ class GitHubTool(BaseTool):
|
||||
"page": {"type": "integer", "default": 1, "description": "Page number of the results to fetch"}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["read"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -355,7 +374,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["issue_number", "comment"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["communicate"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -369,7 +389,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["issue_number"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["read"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -383,7 +404,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["pull_number"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["read"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -398,7 +420,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["name"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["communicate"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -413,7 +436,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["project_id", "column_name"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["communicate"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -428,7 +452,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["column_id", "note"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["communicate"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -444,7 +469,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["card_id", "position", "column_id"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["communicate"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -460,7 +486,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["card_id", "content_id", "content_type"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["communicate"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -468,7 +495,8 @@ class GitHubTool(BaseTool):
|
||||
"name": "list_project_boards",
|
||||
"description": "List project boards associated with the repository",
|
||||
"parameters": { "type": "object", "properties": {} }
|
||||
}
|
||||
},
|
||||
"_tags": ["communicate"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -482,7 +510,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["project_id"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["communicate"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -496,7 +525,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["pull_number"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["read"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -510,7 +540,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["pull_number"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["read"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -524,7 +555,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["pull_number"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["read"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -545,7 +577,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["pull_number", "body", "commit_id", "path", "position"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["communicate"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -559,7 +592,8 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["pull_number"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["read"]
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -575,11 +609,11 @@ class GitHubTool(BaseTool):
|
||||
},
|
||||
"required": ["pull_number", "event"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["communicate"]
|
||||
}
|
||||
]
|
||||
|
||||
@metrics.measure
|
||||
def execute(self, function_name, **kwargs):
|
||||
self.logger.info(f"Executing GitHub Tool function: {function_name} with args: {kwargs}")
|
||||
# Dispatch to the appropriate private method
|
||||
@@ -598,7 +632,6 @@ class GitHubTool(BaseTool):
|
||||
|
||||
# Private methods for each function, using self.session for HTTP requests
|
||||
|
||||
@metrics.measure
|
||||
def _read_file(self, path):
|
||||
self.logger.info(f"Reading file: {path} from branch: {self.current_branch}")
|
||||
url = f"{self.base_url}/repos/{self._repo}/contents/{path}"
|
||||
@@ -613,7 +646,6 @@ class GitHubTool(BaseTool):
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _create_branch(self, branch_name, base_branch="main"):
|
||||
self.logger.info(f"Creating branch: {branch_name} from base: {base_branch}")
|
||||
# Get SHA of base branch
|
||||
@@ -639,7 +671,6 @@ class GitHubTool(BaseTool):
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _commit_file(self, file_path, content, commit_message):
|
||||
self.logger.info(f"Committing file: {file_path} to branch: {self.current_branch} with message: '{commit_message}'")
|
||||
if self.current_branch == "main":
|
||||
@@ -679,7 +710,6 @@ class GitHubTool(BaseTool):
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _create_pull_request(self, title, body, base="main"):
|
||||
self.logger.info(f"Creating pull request: '{title}' from branch '{self.current_branch}' to '{base}'")
|
||||
if self.current_branch == base:
|
||||
@@ -701,7 +731,6 @@ class GitHubTool(BaseTool):
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _get_branch_sha(self, branch):
|
||||
self.logger.info(f"Getting SHA for branch: {branch}")
|
||||
url = f"{self.base_url}/repos/{self._repo}/git/refs/heads/{branch}"
|
||||
@@ -715,7 +744,6 @@ class GitHubTool(BaseTool):
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _list_files(self, path):
|
||||
self.logger.info(f"Listing files in path: '{path}' on branch: '{self.current_branch}'")
|
||||
url = f"{self.base_url}/repos/{self._repo}/contents/{path.strip('/')}" # Ensure no leading/trailing slashes for consistency
|
||||
@@ -738,7 +766,6 @@ class GitHubTool(BaseTool):
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _search_code(self, query):
|
||||
self.logger.info(f"Searching code with query: '{query}' in repo: '{self._repo}'")
|
||||
url = f"{self.base_url}/search/code"
|
||||
@@ -754,7 +781,6 @@ class GitHubTool(BaseTool):
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _get_commit_history(self, file_path, num_commits=10):
|
||||
self.logger.info(f"Getting last {num_commits} commit(s) for file: '{file_path}' on branch '{self.current_branch}'")
|
||||
url = f"{self.base_url}/repos/{self._repo}/commits"
|
||||
@@ -775,18 +801,15 @@ class GitHubTool(BaseTool):
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _view_commit_details_for_file(self, file_path, num_commits=10):
|
||||
# This function is essentially the same as get_commit_history based on its description.
|
||||
self.logger.info(f"Viewing commit details for file '{file_path}' (last {num_commits} commits) - using _get_commit_history.")
|
||||
return self._get_commit_history(file_path, num_commits)
|
||||
|
||||
@metrics.measure
|
||||
def _get_current_branch(self):
|
||||
self.logger.info(f"Current branch is: {self.current_branch}")
|
||||
return self.current_branch
|
||||
|
||||
@metrics.measure
|
||||
def _set_current_branch(self, branch_name):
|
||||
self.logger.info(f"Attempting to set current branch to: {branch_name}")
|
||||
# Check if branch exists by trying to get its SHA
|
||||
@@ -801,7 +824,6 @@ class GitHubTool(BaseTool):
|
||||
self.logger.info(success_message)
|
||||
return success_message
|
||||
|
||||
@metrics.measure
|
||||
def _get_file_at_commit(self, file_path, commit_sha):
|
||||
self.logger.info(f"Getting file '{file_path}' at commit SHA: {commit_sha}")
|
||||
url = f"{self.base_url}/repos/{self._repo}/contents/{file_path}"
|
||||
@@ -816,7 +838,6 @@ class GitHubTool(BaseTool):
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _list_branches(self, per_page=100, all_pages=True):
|
||||
self.logger.info(f"Listing branches for repo '{self._repo}'. Per_page={per_page}, All_pages={all_pages}")
|
||||
url = f"{self.base_url}/repos/{self._repo}/branches"
|
||||
@@ -844,7 +865,6 @@ class GitHubTool(BaseTool):
|
||||
self.logger.info(f"Successfully listed {len(branches_list)} branches.")
|
||||
return branches_list
|
||||
|
||||
@metrics.measure
|
||||
def _approve_pull_request(self, pull_number):
|
||||
self.logger.info(f"Approving pull request #{pull_number}")
|
||||
url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/reviews"
|
||||
@@ -859,7 +879,6 @@ class GitHubTool(BaseTool):
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _close_pull_request(self, pull_number):
|
||||
self.logger.info(f"Closing pull request #{pull_number}")
|
||||
url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}"
|
||||
@@ -874,7 +893,6 @@ class GitHubTool(BaseTool):
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _merge_pull_request(self, pull_number, commit_title="Merge pull request", commit_message="", merge_method="merge"):
|
||||
self.logger.info(f"Merging pull request #{pull_number} using method '{merge_method}'")
|
||||
url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/merge"
|
||||
@@ -897,7 +915,6 @@ class GitHubTool(BaseTool):
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _delete_branch(self, branch_name):
|
||||
self.logger.info(f"Deleting branch: {branch_name}")
|
||||
if branch_name == "main" or (hasattr(self, 'default_branch') and branch_name == self.default_branch) :
|
||||
@@ -920,7 +937,6 @@ class GitHubTool(BaseTool):
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _get_issue_details(self, issue_number):
|
||||
self.logger.info(f"Getting details for issue #{issue_number}")
|
||||
url = f"{self.base_url}/repos/{self._repo}/issues/{issue_number}"
|
||||
@@ -933,7 +949,6 @@ class GitHubTool(BaseTool):
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _create_issue(self, title, body, labels=None):
|
||||
self.logger.info(f"Creating new issue with title: '{title}'")
|
||||
url = f"{self.base_url}/repos/{self._repo}/issues"
|
||||
@@ -953,7 +968,6 @@ class GitHubTool(BaseTool):
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _list_issues(self, state="open", per_page=30, page=1):
|
||||
self.logger.info(f"Listing issues with state: {state}, per_page: {per_page}, page: {page}")
|
||||
url = f"{self.base_url}/repos/{self._repo}/issues"
|
||||
@@ -969,7 +983,6 @@ class GitHubTool(BaseTool):
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _add_issue_comment(self, issue_number, comment):
|
||||
self.logger.info(f"Adding comment to issue #{issue_number}: '{comment[:50]}...'")
|
||||
url = f"{self.base_url}/repos/{self._repo}/issues/{issue_number}/comments"
|
||||
@@ -985,7 +998,6 @@ class GitHubTool(BaseTool):
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _get_issue_comments(self, issue_number):
|
||||
self.logger.info(f"Getting comments for issue #{issue_number}")
|
||||
url = f"{self.base_url}/repos/{self._repo}/issues/{issue_number}/comments"
|
||||
@@ -1000,14 +1012,12 @@ class GitHubTool(BaseTool):
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _get_pull_request_general_comments(self, pull_number):
|
||||
self.logger.info(f"Getting general comments for pull request #{pull_number}")
|
||||
# In GitHub API, PR comments (general, not review comments on lines) are issue comments.
|
||||
# The PR is also an issue, so use the issue comments endpoint.
|
||||
return self._get_issue_comments(issue_number=pull_number)
|
||||
|
||||
@metrics.measure
|
||||
def _create_project_board(self, name, body=None):
|
||||
self.logger.info(f"Creating project board: '{name}'")
|
||||
url = f"{self.base_url}/repos/{self._repo}/projects"
|
||||
@@ -1026,7 +1036,6 @@ class GitHubTool(BaseTool):
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _create_project_column(self, project_id, column_name):
|
||||
self.logger.info(f"Creating column '{column_name}' for project ID: {project_id}")
|
||||
url = f"{self.base_url}/projects/{project_id}/columns"
|
||||
@@ -1044,7 +1053,6 @@ class GitHubTool(BaseTool):
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _create_project_card(self, column_id, note=None, content_id=None, content_type=None):
|
||||
self.logger.info(f"Creating card in column ID: {column_id}")
|
||||
url = f"{self.base_url}/projects/columns/{column_id}/cards"
|
||||
@@ -1075,7 +1083,6 @@ class GitHubTool(BaseTool):
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _move_project_card(self, card_id, position, column_id=None):
|
||||
self.logger.info(f"Moving card ID: {card_id} to position: {position}" + (f" in column ID: {column_id}" if column_id else ""))
|
||||
url = f"{self.base_url}/projects/columns/cards/{card_id}/moves"
|
||||
@@ -1100,7 +1107,6 @@ class GitHubTool(BaseTool):
|
||||
# For updating an existing card to link an issue, one would PATCH the card's content_id/content_type.
|
||||
# Let's assume the function intends to update an existing card if it's a separate function.
|
||||
# However, the provided API spec for `link_issue_to_project_card` uses PATCH on card_id, so let's implement that.
|
||||
@metrics.measure
|
||||
def _link_issue_to_project_card(self, card_id, content_id, content_type):
|
||||
self.logger.info(f"Linking content_id {content_id} (type: {content_type}) to card_id {card_id}")
|
||||
url = f"{self.base_url}/projects/cards/{card_id}" # Note: API docs suggest /projects/columns/cards/{card_id} or /projects/cards/{card_id}
|
||||
@@ -1120,7 +1126,6 @@ class GitHubTool(BaseTool):
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _list_project_boards(self):
|
||||
self.logger.info(f"Listing project boards for repo: {self._repo}")
|
||||
url = f"{self.base_url}/repos/{self._repo}/projects"
|
||||
@@ -1136,7 +1141,6 @@ class GitHubTool(BaseTool):
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _view_project_board_items(self, project_id):
|
||||
self.logger.info(f"Viewing items for project ID: {project_id}")
|
||||
columns_url = f"{self.base_url}/projects/{project_id}/columns"
|
||||
@@ -1165,7 +1169,6 @@ class GitHubTool(BaseTool):
|
||||
self.logger.info(f"Successfully retrieved items for project ID: {project_id}.")
|
||||
return project_items
|
||||
|
||||
@metrics.measure
|
||||
def _get_pull_request_details(self, pull_number):
|
||||
self.logger.info(f"Getting details for PR #{pull_number}")
|
||||
url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}"
|
||||
@@ -1178,7 +1181,6 @@ class GitHubTool(BaseTool):
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _get_pull_request_diff(self, pull_number):
|
||||
self.logger.info(f"Getting diff for PR #{pull_number}")
|
||||
url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}"
|
||||
@@ -1193,7 +1195,6 @@ class GitHubTool(BaseTool):
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _get_pull_request_files(self, pull_number):
|
||||
self.logger.info(f"Getting files for PR #{pull_number}")
|
||||
url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/files"
|
||||
@@ -1206,7 +1207,6 @@ class GitHubTool(BaseTool):
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _create_pull_request_review_comment(self, pull_number, body, commit_id, path, position, side="RIGHT", start_line=None, start_side=None):
|
||||
self.logger.info(f"Creating review comment on PR #{pull_number}, file '{path}', position {position}")
|
||||
url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/comments"
|
||||
@@ -1225,7 +1225,6 @@ class GitHubTool(BaseTool):
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _list_pull_request_review_comments(self, pull_number):
|
||||
self.logger.info(f"Listing review comments for PR #{pull_number}")
|
||||
url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/comments"
|
||||
@@ -1238,7 +1237,6 @@ class GitHubTool(BaseTool):
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _submit_pull_request_review(self, pull_number, event, body=None):
|
||||
self.logger.info(f"Submitting '{event}' review for PR #{pull_number}")
|
||||
url = f"{self.base_url}/repos/{self._repo}/pulls/{pull_number}/reviews"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# tools/log_tool.py
|
||||
from .base_tool import BaseTool
|
||||
from .metrics import metrics
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
@@ -44,7 +43,6 @@ class LogTool(BaseTool):
|
||||
}
|
||||
]
|
||||
|
||||
@metrics.measure
|
||||
def execute(self, function_name, **kwargs):
|
||||
self.logger.info(f"Executing LogTool function: {function_name} with args: {kwargs}")
|
||||
if function_name == "get_log_contents":
|
||||
@@ -55,7 +53,6 @@ class LogTool(BaseTool):
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
@metrics.measure
|
||||
def _get_log_contents(self, line_count=None): # Default line_count is None to trigger 24h logic if not specified
|
||||
self.logger.info(f"Attempting to get log contents from: {self.configured_log_file_path}. Line count: {line_count if line_count is not None else 'Last 24 hours'}")
|
||||
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
# tools/metrics.py
|
||||
import cProfile
|
||||
import pstats
|
||||
import io
|
||||
from functools import wraps
|
||||
from collections import defaultdict
|
||||
import logging
|
||||
|
||||
class Metrics:
|
||||
def __init__(self, logger=None):
|
||||
self.call_count = defaultdict(int)
|
||||
self.total_time = defaultdict(float)
|
||||
self.logger = logger if logger else logging.getLogger(__name__)
|
||||
if not self.logger.handlers:
|
||||
self.logger.addHandler(logging.NullHandler())
|
||||
self.logger.debug("Metrics instance initialized.")
|
||||
|
||||
def measure(self, func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
self.call_count[func.__name__] += 1
|
||||
|
||||
pr = cProfile.Profile()
|
||||
pr.enable()
|
||||
result = func(*args, **kwargs)
|
||||
pr.disable()
|
||||
|
||||
ps = pstats.Stats(pr)
|
||||
|
||||
func_code = func.__code__
|
||||
func_key_tuple = (func_code.co_filename, func_code.co_firstlineno, func_code.co_name)
|
||||
|
||||
time_spent_for_func = 0.0
|
||||
if func_key_tuple in ps.stats:
|
||||
time_spent_for_func = ps.stats[func_key_tuple][3] # [3] is cumulative time (ct)
|
||||
else:
|
||||
# Fallback: try to find by function name if exact key fails (e.g. due to decorators changing code object details slightly)
|
||||
# This is less precise and might pick up other functions if names are not unique across files.
|
||||
found_by_name = False
|
||||
for key, stat in ps.stats.items():
|
||||
if key[2] == func.__name__: # key[2] is function name
|
||||
time_spent_for_func = stat[3] # cumulative time
|
||||
self.logger.debug(f"Found stats for {func.__name__} by name {key} after primary key failed.")
|
||||
found_by_name = True
|
||||
break
|
||||
if not found_by_name:
|
||||
self.logger.warning(
|
||||
f"Could not find exact cProfile stats for {func.__name__} with key {func_key_tuple} or by name. "
|
||||
f"Time for this call will be recorded as 0. This might occur for non-Python functions or due to complex decorators."
|
||||
)
|
||||
|
||||
self.total_time[func.__name__] += time_spent_for_func
|
||||
self.logger.debug(f"Measured cumulative time for {func.__name__}: {time_spent_for_func:.6f}s")
|
||||
|
||||
return result
|
||||
return wrapper
|
||||
|
||||
def get_metrics(self):
|
||||
metrics_data = {}
|
||||
for func_name in self.call_count:
|
||||
count = self.call_count[func_name]
|
||||
total_t = self.total_time[func_name]
|
||||
metrics_data[func_name] = {
|
||||
'call_count': count,
|
||||
'total_time': round(total_t, 6),
|
||||
'average_time': round(total_t / count, 6) if count > 0 else 0
|
||||
}
|
||||
return metrics_data
|
||||
|
||||
def clear_metrics(self):
|
||||
self.call_count.clear()
|
||||
self.total_time.clear()
|
||||
self.logger.info("Metrics cleared.")
|
||||
|
||||
# Global instance for convenience
|
||||
_metrics_instance_logger = logging.getLogger(__name__ + ".global_instance")
|
||||
if not _metrics_instance_logger.handlers:
|
||||
_metrics_instance_logger.addHandler(logging.NullHandler())
|
||||
metrics = Metrics(logger=_metrics_instance_logger)
|
||||
@@ -1,128 +0,0 @@
|
||||
# tools/metrics_tool.py
|
||||
from .base_tool import BaseTool
|
||||
from .metrics import metrics as global_metrics_instance # For default and measuring execute
|
||||
from .metrics import Metrics # For type hinting and potentially creating a new one if needed
|
||||
import logging
|
||||
|
||||
class MetricsTool(BaseTool):
|
||||
def __init__(self, metrics_provider: Metrics | None = None, logger: logging.Logger | None = None):
|
||||
self.metrics_provider = metrics_provider if metrics_provider is not None else global_metrics_instance
|
||||
self.logger = logger if logger else logging.getLogger(__name__)
|
||||
if not self.logger.handlers:
|
||||
self.logger.addHandler(logging.NullHandler())
|
||||
self.logger.debug(f"MetricsTool initialized. Using metrics provider: {self.metrics_provider}")
|
||||
|
||||
def clear(self):
|
||||
# This tool itself doesn't hold state that needs clearing beyond what its metrics_provider might do.
|
||||
# If this tool were responsible for clearing the metrics it reports on, it would call:
|
||||
# self.metrics_provider.clear_metrics()
|
||||
self.logger.debug("MetricsTool clear method called. No local state to clear.")
|
||||
pass
|
||||
|
||||
def get_functions(self):
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_function_metrics",
|
||||
"description": "Get metrics for all measured functions.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_specific_function_metrics",
|
||||
"description": "Get metrics for a specific function.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"function_name": {
|
||||
"type": "string",
|
||||
"description": "Name of the function to get metrics for"
|
||||
}
|
||||
},
|
||||
"required": ["function_name"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_top_n_functions",
|
||||
"description": "Get the top N functions by total execution time.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"n": {
|
||||
"type": "integer",
|
||||
"description": "Number of top functions to retrieve"
|
||||
}
|
||||
},
|
||||
"required": ["n"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@global_metrics_instance.measure # The execute method can be measured by the global instance
|
||||
def execute(self, function_name, **kwargs):
|
||||
self.logger.info(f"Executing MetricsTool function: {function_name} with args: {kwargs}")
|
||||
if function_name == "get_function_metrics":
|
||||
return self._get_function_metrics()
|
||||
elif function_name == "get_specific_function_metrics":
|
||||
func_name_arg = kwargs.get("function_name")
|
||||
if func_name_arg is None: # Check if None, as empty string could be a valid (though unlikely) func name
|
||||
self.logger.warning("'function_name' argument is missing for get_specific_function_metrics.")
|
||||
return "Error: Missing required argument 'function_name'."
|
||||
return self._get_specific_function_metrics(str(func_name_arg)) # Ensure string
|
||||
elif function_name == "get_top_n_functions":
|
||||
n_arg = kwargs.get("n")
|
||||
if n_arg is None:
|
||||
self.logger.warning("'n' argument is missing for get_top_n_functions.")
|
||||
return "Error: Missing required argument 'n'."
|
||||
try:
|
||||
n_val = int(n_arg)
|
||||
if n_val <= 0:
|
||||
self.logger.warning(f"'n' argument must be a positive integer, got {n_val}.")
|
||||
return "Error: Argument 'n' must be a positive integer."
|
||||
return self._get_top_n_functions(n_val)
|
||||
except ValueError:
|
||||
self.logger.warning(f"'n' argument must be an integer, got '{n_arg}'.")
|
||||
return "Error: Argument 'n' must be an integer."
|
||||
else:
|
||||
error_message = f"Unknown function: {function_name}"
|
||||
self.logger.error(error_message)
|
||||
return error_message
|
||||
|
||||
def _get_function_metrics(self):
|
||||
self.logger.debug("Calling metrics_provider.get_metrics() for all functions.")
|
||||
return self.metrics_provider.get_metrics()
|
||||
|
||||
def _get_specific_function_metrics(self, function_to_get):
|
||||
self.logger.debug(f"Getting metrics for specific function: {function_to_get}")
|
||||
all_metrics = self.metrics_provider.get_metrics()
|
||||
return all_metrics.get(function_to_get, f"No metrics found for function: {function_to_get}")
|
||||
|
||||
def _get_top_n_functions(self, n):
|
||||
self.logger.debug(f"Getting top {n} functions by total execution time.")
|
||||
all_metrics = self.metrics_provider.get_metrics()
|
||||
# Ensure that the items are actual metric dicts before trying to access 'total_time'
|
||||
valid_metrics_items = []
|
||||
for name, metric_values in all_metrics.items():
|
||||
if isinstance(metric_values, dict) and 'total_time' in metric_values:
|
||||
valid_metrics_items.append((name, metric_values))
|
||||
else:
|
||||
self.logger.warning(f"Metric item for '{name}' is not in expected format: {metric_values}")
|
||||
|
||||
# Sort items by total_time. items() gives list of (func_name, metrics_dict)
|
||||
try:
|
||||
sorted_metrics = sorted(valid_metrics_items, key=lambda item: item[1]['total_time'], reverse=True)
|
||||
return dict(sorted_metrics[:n])
|
||||
except TypeError as e:
|
||||
self.logger.error(f"Error sorting metrics, possibly due to unexpected data types: {e}", exc_info=True)
|
||||
return "Error: Could not sort metrics due to unexpected data."
|
||||
@@ -28,7 +28,7 @@ class StandaloneLLMTool(BaseTool):
|
||||
"model": {
|
||||
"type": "string",
|
||||
"description": "The model to use for generating the detailed instructions. Use mini for most coding tasks, preview when needing sophisticated reasoning",
|
||||
"enum": ["o1-mini", "o1-preview"],
|
||||
"enum": ["mini", "max"],
|
||||
"default": "o1-mini"
|
||||
},
|
||||
"max_tokens": {
|
||||
@@ -38,7 +38,8 @@ class StandaloneLLMTool(BaseTool):
|
||||
},
|
||||
"required": ["prompt"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"_tags": ["llm", "external"]
|
||||
}
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user