Refactored gemini, openai and claude into one file and removed logic from the base class, also made helper class definable from command line

This commit is contained in:
2025-06-03 13:04:42 -05:00
parent bd0ce3e340
commit f15228fa58
36 changed files with 487 additions and 3847 deletions
View File
View File
@@ -1,33 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
from anthropic_telegram_inference_bot import AnthropicTelegramInferenceBot
class TestAnthropicTelegramInferenceBot(unittest.TestCase):
def setUp(self):
self.bot = AnthropicTelegramInferenceBot()
@patch('anthropic_telegram_inference_bot.Anthropic')
def test_get_chat_response(self, MockAnthropic):
mock_anthropic = MockAnthropic.return_value
mock_anthropic.messages.create.return_value = MagicMock()
messages = [{"role": "user", "content": "Hello"}]
response = self.bot.get_chat_response(messages)
self.assertIsNotNone(response)
@patch('anthropic_telegram_inference_bot.Anthropic')
def test_handle_message(self, MockAnthropic):
mock_anthropic = MockAnthropic.return_value
mock_anthropic.messages.create.return_value = MagicMock(content=[MagicMock(type="message", text="response content")])
user_id = "user123"
user_message = "Hello"
response = self.bot.handle_message(user_id, user_message)
self.assertIsNotNone(response)
# Additional testing for error cases and edge cases
if __name__ == '__main__':
unittest.main()
@@ -1,33 +0,0 @@
import unittest
from base_telegram_inference_bot import BaseTelegramInferenceBot
class TestBaseTelegramInferenceBot(unittest.TestCase):
def setUp(self):
# Initialize the bot or mock any dependencies here
self.bot = BaseTelegramInferenceBot()
def test_load_system_prompt(self):
# Example test case for load_system_prompt method
result = self.bot.load_system_prompt()
self.assertIsNotNone(result) # Replace with actual expected result
def test_load_functions(self):
# Test the load_functions method
functions = self.bot.load_functions()
self.assertIsInstance(functions, list) # Replace with actual expected result
self.assertTrue(len(functions) > 0) # Assuming it should load some functions
def test_clear_conversation(self):
# Test the clear_conversation method
self.bot.clear_conversation()
self.assertEqual(self.bot.conversations, {}) # Assuming conversations is a dictionary
def test_call_tool(self):
# Test the call_tool method
tool_name = "some_tool"
params = {"param1": "value1"}
result = self.bot.call_tool(tool_name, params)
self.assertIsNotNone(result) # Replace with actual expected result
if __name__ == '__main__':
unittest.main()
@@ -1,38 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
from chatgpt_telegram_inference_bot import ChatGPTTelegramInferenceBot
class TestChatGPTTelegramInferenceBot(unittest.TestCase):
def setUp(self):
self.bot = ChatGPTTelegramInferenceBot()
@patch('chatgpt_telegram_inference_bot.OpenAI')
def test_get_chat_response(self, MockOpenAI):
mock_ai = MockOpenAI.return_value
mock_ai.chat.completions.create.return_value = MagicMock()
messages = [{"role": "user", "content": "Hello"}]
response = self.bot.get_chat_response(messages)
self.assertIsNotNone(response)
@patch('chatgpt_telegram_inference_bot.OpenAI')
def test_handle_message(self, MockOpenAI):
mock_ai = MockOpenAI.return_value
mock_ai.chat.completions.create.return_value = MagicMock(choices=[MagicMock(message={"content": "response content"}, finish_reason='stop')])
user_id = "user123"
user_message = "Hello"
response = self.bot.handle_message(user_id, user_message)
self.assertIsNotNone(response)
def test_switch_model(self):
initial_model = self.bot.model
self.bot.switch_model()
self.assertNotEqual(initial_model, self.bot.model)
# Additional testing for error cases and edge cases
if __name__ == '__main__':
unittest.main()
View File
@@ -1,280 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch, AsyncMock, ANY
import os
# Assuming anthropic_telegram_inference_bot.py is in the parent directory or PYTHONPATH is set
from anthropic_telegram_inference_bot import AnthropicTelegramInferenceBot
# Mock response from Anthropic client's messages.create
def create_mock_anthropic_response(content_text=None, stop_reason="end_turn", tool_use_parts=None):
mock_response = MagicMock()
mock_response.stop_reason = stop_reason
content_blocks = []
if content_text:
text_block = MagicMock()
text_block.type = "text"
text_block.text = content_text
content_blocks.append(text_block)
if tool_use_parts:
for tu_part in tool_use_parts: # tu_part = {"id": "toolu_123", "name": "get_weather", "input": {}}
tool_block = MagicMock()
tool_block.type = "tool_use"
tool_block.id = tu_part["id"]
tool_block.name = tu_part["name"]
tool_block.input = tu_part["input"]
content_blocks.append(tool_block)
mock_response.content = content_blocks
return mock_response
class TestAnthropicTelegramInferenceBot(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.original_anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY")
self.original_small_model = os.environ.get("ANTHROPIC_SMALL_MODEL")
self.original_large_model = os.environ.get("ANTHROPIC_LARGE_MODEL")
self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH")
for key in ["ANTHROPIC_API_KEY", "ANTHROPIC_SMALL_MODEL", "ANTHROPIC_LARGE_MODEL", "SYSTEM_PROMPT_PATH"]:
if os.environ.get(key):
del os.environ[key]
self.mock_anthropic_client_instance = MagicMock()
self.mock_anthropic_client_instance.messages.create = MagicMock()
def tearDown(self):
if self.original_anthropic_api_key: os.environ["ANTHROPIC_API_KEY"] = self.original_anthropic_api_key
if self.original_small_model: os.environ["ANTHROPIC_SMALL_MODEL"] = self.original_small_model
if self.original_large_model: os.environ["ANTHROPIC_LARGE_MODEL"] = self.original_large_model
if self.original_system_prompt_path: os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path
@patch('anthropic.Anthropic')
def test_init_with_anthropic_defaults_env_key(self, MockAnthropicConstructor):
MockAnthropicConstructor.return_value = self.mock_anthropic_client_instance
os.environ["ANTHROPIC_API_KEY"] = "test_anthropic_key"
bot = AnthropicTelegramInferenceBot()
MockAnthropicConstructor.assert_called_once_with(api_key="test_anthropic_key")
self.assertEqual(bot.anthropic_client, self.mock_anthropic_client_instance)
self.assertEqual(bot.model, os.environ.get("ANTHROPIC_SMALL_MODEL", "claude-3-haiku-20240307"))
self.assertEqual(bot.max_tokens, int(os.environ.get("ANTHROPIC_SMALL_MODEL_MAX_TOKENS", 2000)))
@patch('anthropic.Anthropic')
def test_init_with_provided_client_and_models(self, MockAnthropicConstructor):
preconfigured_client = MagicMock()
bot = AnthropicTelegramInferenceBot(
anthropic_client=preconfigured_client,
small_model_name="custom-small",
small_model_max_tokens=100,
large_model_name="custom-large",
large_model_max_tokens=200
)
MockAnthropicConstructor.assert_not_called()
self.assertEqual(bot.anthropic_client, preconfigured_client)
self.assertEqual(bot.model, "custom-small")
self.assertEqual(bot.max_tokens, 100)
self.assertEqual(bot.small_model_name, "custom-small")
self.assertEqual(bot.large_model_name, "custom-large")
def test_get_llm_description(self):
bot = AnthropicTelegramInferenceBot(small_model_name="claude-test", small_model_max_tokens=500)
self.assertEqual(bot.get_llm_description(), "LLM: claude-test, Max Tokens: 500")
async def test_switch_model(self):
bot = AnthropicTelegramInferenceBot(
small_model_name="claude-small", small_model_max_tokens=10,
large_model_name="claude-large", large_model_max_tokens=20
)
self.assertEqual(bot.model, "claude-small")
self.assertEqual(bot.max_tokens, 10)
status = await bot.switch_model()
self.assertEqual(bot.model, "claude-large")
self.assertEqual(bot.max_tokens, 20)
self.assertEqual(status, "Switched to model: claude-large")
status = await bot.switch_model()
self.assertEqual(bot.model, "claude-small")
self.assertEqual(bot.max_tokens, 10)
self.assertEqual(status, "Switched to model: claude-small")
def test_get_chat_response_success_text_only(self):
bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
bot.model = "test-claude"
bot.max_tokens = 150
mock_api_response = create_mock_anthropic_response(content_text="Hello from Anthropic API")
self.mock_anthropic_client_instance.messages.create.return_value = mock_api_response
messages = [{"role": "user", "content": "Hi"}] # Anthropic format
response = bot.get_chat_response(messages, []) # tools = empty list
self.mock_anthropic_client_instance.messages.create.assert_called_once_with(
model="test-claude",
max_tokens=150,
messages=messages,
system=bot.system_prompt, # Ensure system prompt is passed
tools=None, # No tools passed to API if empty list or None
tool_choice=None
)
self.assertEqual(response, mock_api_response)
def test_get_chat_response_with_tools(self):
bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
bot.model = "claude-toolmaster"
bot.max_tokens = 300
mock_tools_spec = [{"name": "get_weather", "description": "Gets weather", "input_schema": {"type": "object", "properties": {}}}]
mock_api_response = create_mock_anthropic_response(content_text="Thinking...", tool_use_parts=[
{"id": "tool1", "name": "get_weather", "input": {"location": "here"}}
])
self.mock_anthropic_client_instance.messages.create.return_value = mock_api_response
messages = [{"role": "user", "content": "Weather?"}]
response = bot.get_chat_response(messages, mock_tools_spec)
self.mock_anthropic_client_instance.messages.create.assert_called_once_with(
model="claude-toolmaster",
max_tokens=300,
messages=messages,
system=bot.system_prompt,
tools=mock_tools_spec,
tool_choice={"type": "auto"}
)
self.assertEqual(response.content[0].type, "text") # First part can be text
self.assertEqual(response.content[1].type, "tool_use")
def test_get_chat_response_api_error(self):
bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
self.mock_anthropic_client_instance.messages.create.side_effect = Exception("Anthropic API Down")
with self.assertRaisesRegex(Exception, "Anthropic API Down"):
bot.get_chat_response([{"role": "user", "content": "trigger"}], [])
async def test_handle_message_simple_response_no_tools(self):
# This test is more involved as it touches BaseTelegramInferenceBot's handle_message structure
# which then calls the overridden get_chat_response.
bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
bot.system_prompt = "System prompt for Anthropic"
# Mock get_chat_response directly to isolate its behavior from full handle_message logic of base
# However, the point of this bot is its get_chat_response and subsequent processing.
# So, let's mock the API call within get_chat_response.
api_response = create_mock_anthropic_response(content_text="Anthropic says hello.")
self.mock_anthropic_client_instance.messages.create.return_value = api_response
# Ensure functions are empty for this test, so no tool logic is triggered
bot.functions = []
bot.tools = []
response_content = await bot.handle_message(user_id=101, user_message="Hello Anthropic")
self.assertEqual(response_content, "Anthropic says hello.")
self.assertIn(101, bot.conversation_history)
# Anthropic's handle_message structure:
# 1. User message added to history.
# 2. get_chat_response is called.
# 3. Response content (text) is extracted.
# 4. Assistant text response is added to history.
# Expected history: [User, Assistant_Text_Response] (system prompt handled by get_chat_response)
# The base class handle_message adds system prompt if not present.
# Anthropic handle_message modifies history format before calling get_chat_response.
# Let's trace Base.handle_message -> Anthropic.handle_message -> Anthropic.get_chat_response
# Base.handle_message:
# - Adds system prompt to history if first turn: `self.conversation_history[user_id] = [{"role": "system", "content": self.system_prompt}]` (OpenAI style)
# - Appends user message: `{"role": "user", "content": user_message}`
# - Calls self.get_chat_response(messages, self.functions) -> This is Anthropic's get_chat_response
# Anthropic.get_chat_response:
# - Takes OpenAI style `messages` and `self.functions` (tool specs).
# - Calls `anthropic_client.messages.create` with Anthropic style messages and system prompt.
# Anthropic.handle_message (overridden):
# - Prepares Anthropic-style messages from conversation_history (which is OpenAI style from Base)
# - Calls get_chat_response with these Anthropic messages and self.functions (tool_specs)
# - Processes response, extracts text, handles tool calls.
# - Appends *user* message (original) and *assistant* text response to self.conversation_history (OpenAI style).
# For this test, we are calling AnthropicBot.handle_message directly.
# 1. `user_id` not in `self.conversation_history`: `system_prompt` not added yet by Base logic.
# Anthropic's `handle_message` will create `anthropic_messages` from this.
# If `conversation_history` is empty, `anthropic_messages` = `[{"role": "user", "content": user_message}]`
# 2. `get_chat_response` called with `anthropic_messages` and `bot.system_prompt` passed to API.
# 3. Response "Anthropic says hello."
# 4. Original `user_message` and "Anthropic says hello." (as assistant) added to `self.conversation_history`.
history = bot.conversation_history[101]
self.assertEqual(len(history), 2) # User, Assistant
self.assertEqual(history[0]["role"], "user")
self.assertEqual(history[0]["content"], "Hello Anthropic")
self.assertEqual(history[1]["role"], "assistant")
self.assertEqual(history[1]["content"], "Anthropic says hello.")
# Check API call (made by the mocked get_chat_response indirectly)
self.mock_anthropic_client_instance.messages.create.assert_called_once()
call_args = self.mock_anthropic_client_instance.messages.create.call_args
self.assertEqual(call_args.kwargs["system"], "System prompt for Anthropic")
# Initial messages for API should just be the user message for first turn
self.assertEqual(call_args.kwargs["messages"], [{"role": "user", "content": "Hello Anthropic"}])
async def test_handle_message_with_tool_calls(self):
bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
bot.system_prompt = "You are a helpful, tool-using assistant."
# Define a tool for the bot (OpenAI format, will be converted by Anthropic bot for API)
mock_tool_oai_format = {"type": "function", "function": {"name": "get_weather", "description": "Get weather", "parameters": {}}}
bot.functions = [mock_tool_oai_format] # This is used to generate anthropic_tools for API
# API Response 1: Request for tool call
tool_use_part = {"id": "toolu_xyz", "name": "get_weather", "input": {"location": "paris"}}
api_response_1 = create_mock_anthropic_response(tool_use_parts=[tool_use_part])
# API Response 2: Final text response after tool execution
api_response_2 = create_mock_anthropic_response(content_text="The weather in Paris is nice.")
self.mock_anthropic_client_instance.messages.create.side_effect = [api_response_1, api_response_2]
# Mock the bot's call_tool method (from BaseTelegramInferenceBot)
bot.call_tool = MagicMock(return_value='''{"weather": "sunny"}''') # Tool execution result
user_id = 102
user_message = "What's the weather in Paris?"
final_text_response = await bot.handle_message(user_id, user_message)
self.assertEqual(final_text_response, "The weather in Paris is nice.")
self.assertEqual(self.mock_anthropic_client_instance.messages.create.call_count, 2)
bot.call_tool.assert_called_once_with("get_weather", {"location": "paris"}) # Anthropic passes input as dict
# Check conversation history (OpenAI style)
history = bot.conversation_history[user_id]
self.assertEqual(history[0]["role"], "user")
self.assertEqual(history[0]["content"], user_message)
# Assistant message that requested tool call (Anthropic-specific format stored by its handle_message)
# Anthropic's handle_message appends the raw tool_use block and then the tool_result
self.assertEqual(history[1]["role"], "assistant")
self.assertTrue(isinstance(history[1]["content"], list)) # Anthropic content is a list
self.assertEqual(history[1]["content"][0]["type"], "tool_use")
self.assertEqual(history[1]["content"][0]["id"], "toolu_xyz")
self.assertEqual(history[2]["role"], "tool")
self.assertEqual(history[2]["tool_call_id"], "toolu_xyz")
self.assertEqual(history[2]["name"], "get_weather")
self.assertEqual(history[2]["content"], '''{"weather": "sunny"}''') # call_tool result
self.assertEqual(history[3]["role"], "assistant") # Final text response
self.assertTrue(isinstance(history[3]["content"], str)) # simple text
self.assertEqual(history[3]["content"], "The weather in Paris is nice.")
if __name__ == '__main__':
unittest.main()
-310
View File
@@ -1,310 +0,0 @@
import unittest
from unittest.mock import patch, mock_open, MagicMock
import os
import json
# Ensure the path includes the directory where base_telegram_inference_bot is located
# This might require adjustment based on actual project structure if tests are run from root
# For now, assuming it can be imported directly or via PYTHONPATH
from base_telegram_inference_bot import BaseTelegramInferenceBot
from tools.base_tool import BaseTool # For mocking tool structure
# Create a concrete subclass for testing, as BaseTelegramInferenceBot is abstract
class ConcreteTestBot(BaseTelegramInferenceBot):
def __init__(self, system_prompt_content=None, system_prompt_path=None, mock_tools=None, mock_functions=None):
# Mock load_functions during super().__init__ if needed, or control tools/functions directly
self._mock_tools = mock_tools if mock_tools is not None else []
self._mock_functions = mock_functions if mock_functions is not None else []
super().__init__(system_prompt_content=system_prompt_content, system_prompt_path=system_prompt_path)
# Override load_functions to use mocks
def load_functions(self):
return self._mock_tools, self._mock_functions
def get_chat_response(self, messages):
pass # Abstract method, not tested here directly
async def handle_message(self, user_id, user_message):
pass # Abstract method
def get_llm_description(self) -> str:
return "Mock LLM Description" # Concrete implementation for testing get_bot_status
async def start(self):
pass # Abstract method
async def abort_processing(self, user_id):
pass # Abstract method
async def switch_model(self):
pass # Abstract method
class TestBaseTelegramInferenceBot(unittest.TestCase):
def setUp(self):
# Reset relevant environment variables before each test
self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH")
if "SYSTEM_PROMPT_PATH" in os.environ:
del os.environ["SYSTEM_PROMPT_PATH"]
def tearDown(self):
# Restore environment variables
if self.original_system_prompt_path:
os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path
elif "SYSTEM_PROMPT_PATH" in os.environ: # Ensure it's removed if test set it and it wasn't there before
del os.environ["SYSTEM_PROMPT_PATH"]
def test_init_with_direct_system_prompt(self):
bot = ConcreteTestBot(system_prompt_content="Direct prompt content")
self.assertEqual(bot.system_prompt, "Direct prompt content")
@patch("os.path.isfile")
@patch("builtins.open", new_callable=mock_open, read_data="File prompt content")
def test_init_with_system_prompt_path_argument(self, mock_file_open, mock_isfile):
mock_isfile.return_value = True
bot = ConcreteTestBot(system_prompt_path="dummy/path.txt")
self.assertEqual(bot.system_prompt, "File prompt content")
mock_isfile.assert_called_once_with("dummy/path.txt")
mock_file_open.assert_called_once_with("dummy/path.txt", "r", encoding="utf-8")
@patch("os.path.isfile")
@patch("builtins.open", new_callable=mock_open, read_data="Env prompt content")
def test_init_with_env_system_prompt_path(self, mock_file_open, mock_isfile):
mock_isfile.return_value = True
os.environ["SYSTEM_PROMPT_PATH"] = "env/path.txt"
bot = ConcreteTestBot()
self.assertEqual(bot.system_prompt, "Env prompt content")
mock_isfile.assert_called_once_with("env/path.txt")
mock_file_open.assert_called_once_with("env/path.txt", "r", encoding="utf-8")
def test_init_with_default_system_prompt(self):
# Ensure ENV var is not set for this test
if "SYSTEM_PROMPT_PATH" in os.environ:
del os.environ["SYSTEM_PROMPT_PATH"]
bot = ConcreteTestBot()
self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.")
@patch("os.path.isfile", return_value=False)
def test_init_with_invalid_system_prompt_path(self, mock_isfile):
bot = ConcreteTestBot(system_prompt_path="invalid/path.txt")
self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.")
mock_isfile.assert_called_once_with("invalid/path.txt")
@patch("os.path.isfile")
@patch("builtins.open", side_effect=IOError("File read error"))
def test_init_with_system_prompt_file_read_error(self, mock_file_open, mock_isfile):
mock_isfile.return_value = True
bot = ConcreteTestBot(system_prompt_path="dummy/path.txt")
self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.")
def test_clear_conversation_history(self):
mock_tool_instance = MagicMock(spec=BaseTool)
bot = ConcreteTestBot(mock_tools=[mock_tool_instance])
bot.conversation_history[123] = [{"role": "user", "content": "Hello"}]
bot.clear_conversation_history(123)
self.assertNotIn(123, bot.conversation_history)
mock_tool_instance.clear.assert_called_once()
def test_clear_conversation_history_user_not_found(self):
mock_tool_instance = MagicMock(spec=BaseTool)
bot = ConcreteTestBot(mock_tools=[mock_tool_instance])
bot.clear_conversation_history(404)
self.assertNotIn(404, bot.conversation_history)
mock_tool_instance.clear.assert_called_once()
def test_processing_status(self):
bot = ConcreteTestBot()
self.assertEqual(bot.processing_status, {})
bot.set_processing_status(123, 789)
self.assertEqual(bot.processing_status[123], {"processing": True, "message_id": 789})
bot.clear_processing_status(123)
self.assertNotIn(123, bot.processing_status)
def test_clear_processing_status_user_not_found(self):
bot = ConcreteTestBot()
bot.clear_processing_status(404)
self.assertNotIn(404, bot.processing_status)
def test_call_tool_success_dict_args(self):
mock_tool = MagicMock(spec=BaseTool)
mock_tool.get_functions.return_value = [
{"function": {"name": "test_tool", "description": "A test tool", "parameters": {}}}
]
mock_tool.execute.return_value = "Tool executed successfully"
bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions())
result = bot.call_tool("test_tool", {"arg1": "value1"})
self.assertEqual(result, "Tool executed successfully")
mock_tool.execute.assert_called_once_with("test_tool", arg1="value1")
def test_call_tool_success_json_string_args(self):
mock_tool = MagicMock(spec=BaseTool)
mock_tool.get_functions.return_value = [
{"function": {"name": "test_tool_json", "parameters": {}}}
]
mock_tool.execute.return_value = "Tool JSON OK"
bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions())
args_json_str = '''{"param": "value"}'''
result = bot.call_tool("test_tool_json", args_json_str)
self.assertEqual(result, "Tool JSON OK")
mock_tool.execute.assert_called_once_with("test_tool_json", param="value")
def test_call_tool_malformed_json_string_args(self):
bot = ConcreteTestBot(mock_tools=[])
args_malformed_json_str = '''{"param": "value"'''
result = bot.call_tool("some_tool", args_malformed_json_str)
self.assertTrue("Error: Malformed arguments for tool call" in result)
def test_call_tool_unexpected_arg_type(self):
bot = ConcreteTestBot(mock_tools=[])
result = bot.call_tool("some_tool", 12345) # Integer instead of dict/str
self.assertTrue("Error: Invalid argument type for tool call" in result)
def test_call_tool_none_args(self):
mock_tool = MagicMock(spec=BaseTool)
mock_tool.get_functions.return_value = [
{"function": {"name": "test_tool_none", "parameters": {}}}
]
mock_tool.execute.return_value = "Tool None OK"
bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions())
result = bot.call_tool("test_tool_none", None)
self.assertEqual(result, "Tool None OK")
mock_tool.execute.assert_called_once_with("test_tool_none") # No kwargs if None
def test_call_tool_not_found(self):
bot = ConcreteTestBot(mock_tools=[])
result = bot.call_tool("non_existent_tool", {})
self.assertEqual(result, "Error: Tool function non_existent_tool not found.")
def test_call_tool_execute_exception(self):
mock_tool = MagicMock(spec=BaseTool)
mock_tool.get_functions.return_value = [{"function": {"name": "error_tool", "parameters": {}}}]
mock_tool.execute.side_effect = Exception("Execution failed")
bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions())
result = bot.call_tool("error_tool", {})
self.assertEqual(result, "Error executing tool error_tool: Execution failed")
def test_get_system_prompt_description(self):
if "SYSTEM_PROMPT_PATH" in os.environ: # Ensure clean state
del os.environ["SYSTEM_PROMPT_PATH"]
bot_default = ConcreteTestBot()
self.assertEqual(bot_default.get_system_prompt_description(), "System Prompt: Default")
bot_custom_content = ConcreteTestBot(system_prompt_content="Custom content here")
self.assertEqual(bot_custom_content.get_system_prompt_description(), "System Prompt: Custom")
os.environ["SYSTEM_PROMPT_PATH"] = "some/path.txt"
bot_env_default_prompt = ConcreteTestBot() # system_prompt itself is default
self.assertEqual(bot_env_default_prompt.get_system_prompt_description(), "System Prompt: Custom (via ENV)")
with patch("os.path.isfile", return_value=True), \
patch("builtins.open", mock_open(read_data="File prompt from ENV")):
bot_env_file_prompt = ConcreteTestBot() # system_prompt gets loaded from ENV path
self.assertEqual(bot_env_file_prompt.get_system_prompt_description(), "System Prompt: Custom")
del os.environ["SYSTEM_PROMPT_PATH"]
with patch("os.path.isfile", return_value=True), \
patch("builtins.open", mock_open(read_data="File prompt from arg")):
bot_custom_file_arg = ConcreteTestBot(system_prompt_path="custom/file.txt")
self.assertEqual(bot_custom_file_arg.get_system_prompt_description(), "System Prompt: Custom")
@patch.object(ConcreteTestBot, 'get_llm_description', return_value="Test LLM Description")
@patch.object(ConcreteTestBot, 'get_system_prompt_description', return_value="Test Prompt Description")
async def test_get_bot_status(self, mock_prompt_desc, mock_llm_desc):
bot = ConcreteTestBot()
status = await bot.get_bot_status()
self.assertEqual(status, "Test Prompt Description\nTest LLM Description")
mock_prompt_desc.assert_called_once()
mock_llm_desc.assert_called_once()
@patch('os.path.dirname', return_value='/mock/path')
@patch('os.path.join')
@patch('os.path.exists')
@patch('os.listdir')
@patch('importlib.import_module')
def test_load_functions_no_tools_dir(self, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname):
mock_join.return_value = '/mock/path/tools'
mock_exists.return_value = False
class BotForLoadTest(BaseTelegramInferenceBot):
load_system_prompt = MagicMock(return_value="Default")
get_chat_response = MagicMock(); handle_message = MagicMock(); get_llm_description = MagicMock(return_value="mock")
start = MagicMock(); abort_processing = MagicMock(); switch_model = MagicMock()
bot = BotForLoadTest()
self.assertEqual(bot.tools, [])
self.assertEqual(bot.functions, [])
mock_listdir.assert_not_called()
@patch('os.path.dirname', return_value='/mock/base_bot_dir')
@patch('os.path.join', side_effect=lambda *args: os.path.normpath(os.path.join(*args)))
@patch('os.path.exists', return_value=True)
@patch('os.listdir', return_value=['my_tool.py', '__init__.py', 'base_tool.py'])
@patch('importlib.import_module')
def test_load_functions_with_one_tool(self, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname):
mock_tool_class = MagicMock(spec=BaseTool) # This is the class itself
mock_tool_instance = MagicMock(spec=BaseTool) # This is the instance
mock_tool_class.return_value = mock_tool_instance # mock_tool_class() creates mock_tool_instance
mock_tool_instance.get_functions.return_value = [{"function": {"name": "sample_function"}}]
mock_my_tool_module = MagicMock()
# Simulate inspect.getmembers behavior: returns list of (name, member) tuples
# Only include members that are classes, derive from BaseTool, and are not BaseTool itself.
mock_my_tool_module.ValidToolClass = mock_tool_class
mock_my_tool_module.NotATool = object()
mock_my_tool_module.BaseTool = BaseTool # This should be skipped by the loader
def import_side_effect(module_name):
if module_name == 'tools.my_tool':
return mock_my_tool_module
raise ImportError(f"Unexpected import: {module_name}")
mock_import_module.side_effect = import_side_effect
class BotForLoadTest(BaseTelegramInferenceBot):
load_system_prompt = MagicMock(return_value="Default")
get_chat_response = MagicMock(); handle_message = MagicMock(); get_llm_description = MagicMock(return_value="mock")
start = MagicMock(); abort_processing = MagicMock(); switch_model = MagicMock()
bot = BotForLoadTest()
self.assertEqual(len(bot.tools), 1)
self.assertIs(bot.tools[0], mock_tool_instance)
self.assertEqual(len(bot.functions), 1)
self.assertEqual(bot.functions[0]['function']['name'], "sample_function")
mock_import_module.assert_called_once_with('tools.my_tool')
mock_tool_class.assert_called_once_with() # Tool class was instantiated
mock_tool_instance.get_functions.assert_called_once_with()
@patch('os.path.dirname', return_value='/mock/base_bot_dir')
@patch('os.path.join', side_effect=lambda *args: os.path.normpath(os.path.join(*args)))
@patch('os.path.exists', return_value=True)
@patch('os.listdir', return_value=['tool_with_init_error.py'])
@patch('importlib.import_module')
@patch('logging.error') # Mock logging to check for error messages
def test_load_functions_tool_instantiation_error(self, mock_logging_error, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname):
mock_tool_class_init_error = MagicMock(spec=BaseTool)
mock_tool_class_init_error.side_effect = Exception("Failed to init tool") # Error on instantiation
mock_error_tool_module = MagicMock()
mock_error_tool_module.ToolWithInitError = mock_tool_class_init_error
mock_import_module.return_value = mock_error_tool_module
class BotForLoadTest(BaseTelegramInferenceBot):
load_system_prompt = MagicMock(return_value="Default")
get_chat_response = MagicMock(); handle_message = MagicMock(); get_llm_description = MagicMock(return_value="mock")
start = MagicMock(); abort_processing = MagicMock(); switch_model = MagicMock()
bot = BotForLoadTest()
self.assertEqual(len(bot.tools), 0)
self.assertEqual(len(bot.functions), 0)
mock_logging_error.assert_any_call("Error instantiating tool ToolWithInitError from tool_with_init_error.py: Failed to init tool")
if __name__ == '__main__':
unittest.main(闂傚лен䦗婢у埊鍓解劓姣)
@@ -1,158 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch, ANY
import os
# Assuming chatgpt_telegram_inference_bot.py and its parent are accessible
from chatgpt_telegram_inference_bot import ChatGPTTelegramInferenceBot
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot # For patching super
class TestChatGPTTelegramInferenceBot(unittest.IsolatedAsyncioTestCase):
def setUp(self):
# Store and clear relevant environment variables
self.original_openai_key = os.environ.get("OPENAI_API_KEY")
self.original_small_model = os.environ.get("OPENAI_SMALL_MODEL")
self.original_large_model = os.environ.get("OPENAI_LARGE_MODEL")
self.original_small_tokens = os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS")
self.original_large_tokens = os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS")
self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH")
for key in ["OPENAI_API_KEY", "OPENAI_SMALL_MODEL", "OPENAI_LARGE_MODEL",
"OPENAI_SMALL_MODEL_MAX_TOKENS", "OPENAI_LARGE_MODEL_MAX_TOKENS", "SYSTEM_PROMPT_PATH"]:
if os.environ.get(key):
del os.environ[key]
# Mock the OpenAI client that OpenAICompatibleInferenceBot's __init__ might create
self.mock_openai_client = MagicMock()
def tearDown(self):
# Restore environment variables
if self.original_openai_key: os.environ["OPENAI_API_KEY"] = self.original_openai_key
if self.original_small_model: os.environ["OPENAI_SMALL_MODEL"] = self.original_small_model
if self.original_large_model: os.environ["OPENAI_LARGE_MODEL"] = self.original_large_model
if self.original_small_tokens: os.environ["OPENAI_SMALL_MODEL_MAX_TOKENS"] = self.original_small_tokens
if self.original_large_tokens: os.environ["OPENAI_LARGE_MODEL_MAX_TOKENS"] = self.original_large_tokens
if self.original_system_prompt_path: os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path
@patch.object(OpenAICompatibleInferenceBot, '__init__') # Mock the superclass's __init__
def test_init_defaults_and_super_call(self, mock_super_init):
os.environ["OPENAI_API_KEY"] = "test_key_chatgpt"
os.environ["OPENAI_SMALL_MODEL"] = "gpt-3.5-turbo-env"
os.environ["OPENAI_SMALL_MODEL_MAX_TOKENS"] = "350"
bot = ChatGPTTelegramInferenceBot()
mock_super_init.assert_called_once_with(
client=None, # ChatGPT bot will let superclass create it
api_key="test_key_chatgpt", # Passed to super
base_url=None,
api_version=None,
azure_deployment=None,
model_name="gpt-3.5-turbo-env", # Default small model from env
max_tokens_str="350", # Default small model tokens from env
small_model_name="gpt-3.5-turbo-env",
small_model_max_tokens_str="350",
large_model_name=os.environ.get("OPENAI_LARGE_MODEL", "gpt-4-turbo-preview"), # Default large
large_model_max_tokens_str=os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS"),
system_prompt_content=None,
system_prompt_path=None,
is_gemini=False,
max_history_length=20 # Default from OpenAICompatibleInferenceBot
)
@patch.object(OpenAICompatibleInferenceBot, '__init__')
def test_init_with_arguments(self, mock_super_init):
mock_client_arg = MagicMock()
bot = ChatGPTTelegramInferenceBot(
openai_client=mock_client_arg,
api_key="arg_key",
small_model_name="arg_small_model",
small_model_max_tokens="123",
large_model_name="arg_large_model",
large_model_max_tokens="456",
system_prompt_content="Arg prompt"
)
mock_super_init.assert_called_once_with(
client=mock_client_arg,
api_key="arg_key",
base_url=None,
api_version=None,
azure_deployment=None,
model_name="arg_small_model", # Initially configured with small model
max_tokens_str="123",
small_model_name="arg_small_model",
small_model_max_tokens_str="123",
large_model_name="arg_large_model",
large_model_max_tokens_str="456",
system_prompt_content="Arg prompt",
system_prompt_path=None,
is_gemini=False,
max_history_length=20
)
# Test switch_model - this method is part of ChatGPTTelegramInferenceBot
# It calls _configure_model_and_tokens which is in the superclass.
# We need a bot instance where _configure_model_and_tokens can be called.
@patch('openai.OpenAI') # To allow instantiation of the bot by mocking client creation
async def test_switch_model_logic(self, mock_openai_constructor):
mock_openai_constructor.return_value = self.mock_openai_client # Mock client creation in super
# Set env vars for model names that switch_model will use as fallback
os.environ["OPENAI_SMALL_MODEL"] = "env-small-gpt"
os.environ["OPENAI_SMALL_MODEL_MAX_TOKENS"] = "100"
os.environ["OPENAI_LARGE_MODEL"] = "env-large-gpt"
os.environ["OPENAI_LARGE_MODEL_MAX_TOKENS"] = "200"
# Instantiate with initial model (small)
bot = ChatGPTTelegramInferenceBot()
self.assertEqual(bot.model, "env-small-gpt")
self.assertEqual(bot.max_tokens, 100)
# Switch to large
status = await bot.switch_model()
self.assertEqual(bot.model, "env-large-gpt")
self.assertEqual(bot.max_tokens, 200)
self.assertEqual(status, "Switched to model: env-large-gpt")
# Switch back to small
status = await bot.switch_model()
self.assertEqual(bot.model, "env-small-gpt")
self.assertEqual(bot.max_tokens, 100)
self.assertEqual(status, "Switched to model: env-small-gpt")
@patch('openai.OpenAI')
async def test_switch_model_uses_instance_configs_if_provided(self, mock_openai_constructor):
mock_openai_constructor.return_value = self.mock_openai_client
# Instantiate with specific model names, overriding potential env vars
bot = ChatGPTTelegramInferenceBot(
small_model_name="init-small", small_model_max_tokens="50",
large_model_name="init-large", large_model_max_tokens="150"
)
self.assertEqual(bot.model, "init-small") # Starts with small
self.assertEqual(bot.max_tokens, 50)
# Switch to large
status = await bot.switch_model()
self.assertEqual(bot.model, "init-large")
self.assertEqual(bot.max_tokens, 150)
self.assertEqual(status, "Switched to model: init-large")
# Switch back to small
status = await bot.switch_model()
self.assertEqual(bot.model, "init-small")
self.assertEqual(bot.max_tokens, 50)
self.assertEqual(status, "Switched to model: init-small")
# get_llm_description is inherited from OpenAICompatibleInferenceBot.
# Test just to ensure it works in the context of a ChatGPTBot instance
@patch('openai.OpenAI')
def test_get_llm_description_for_chatgpt_bot(self, mock_openai_constructor):
mock_openai_constructor.return_value = self.mock_openai_client
bot = ChatGPTTelegramInferenceBot(small_model_name="gpt-3.5-desc", small_model_max_tokens="777")
# Initially configured with small model
self.assertEqual(bot.get_llm_description(), "LLM: gpt-3.5-desc, Max Tokens: 777, Azure: False")
if __name__ == '__main__':
unittest.main()
-154
View File
@@ -1,154 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch, ANY
import os
# Assuming gemini_telegram_inference_bot.py and its parent are accessible
from gemini_telegram_inference_bot import GeminiTelegramInferenceBot
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot # For patching super
class TestGeminiTelegramInferenceBot(unittest.IsolatedAsyncioTestCase):
def setUp(self):
# Store and clear relevant environment variables
self.original_gemini_key = os.environ.get("GEMINI_API_KEY")
self.original_gemini_base_url = os.environ.get("GEMINI_API_BASE_URL")
self.original_small_model = os.environ.get("GEMINI_SMALL_MODEL")
self.original_large_model = os.environ.get("GEMINI_LARGE_MODEL")
self.original_small_tokens = os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS")
self.original_large_tokens = os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS")
self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH")
for key in ["GEMINI_API_KEY", "GEMINI_API_BASE_URL", "GEMINI_SMALL_MODEL", "GEMINI_LARGE_MODEL",
"GEMINI_SMALL_MODEL_MAX_TOKENS", "GEMINI_LARGE_MODEL_MAX_TOKENS", "SYSTEM_PROMPT_PATH"]:
if os.environ.get(key):
del os.environ[key]
self.mock_openai_client = MagicMock() # Used if superclass creates an OpenAI client
def tearDown(self):
# Restore environment variables
if self.original_gemini_key: os.environ["GEMINI_API_KEY"] = self.original_gemini_key
if self.original_gemini_base_url: os.environ["GEMINI_API_BASE_URL"] = self.original_gemini_base_url
if self.original_small_model: os.environ["GEMINI_SMALL_MODEL"] = self.original_small_model
if self.original_large_model: os.environ["GEMINI_LARGE_MODEL"] = self.original_large_model
if self.original_small_tokens: os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = self.original_small_tokens
if self.original_large_tokens: os.environ["GEMINI_LARGE_MODEL_MAX_TOKENS"] = self.original_large_tokens
if self.original_system_prompt_path: os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path
@patch.object(OpenAICompatibleInferenceBot, '__init__') # Mock the superclass's __init__
def test_init_defaults_and_super_call(self, mock_super_init):
os.environ["GEMINI_API_KEY"] = "test_key_gemini"
os.environ["GEMINI_API_BASE_URL"] = "https://gemini.env.com"
os.environ["GEMINI_SMALL_MODEL"] = "gemini-pro-env"
os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = "360"
bot = GeminiTelegramInferenceBot()
mock_super_init.assert_called_once_with(
client=None,
api_key="test_key_gemini",
base_url="https://gemini.env.com", # Passed to super
api_version=None,
azure_deployment=None,
model_name="gemini-pro-env",
max_tokens_str="360",
small_model_name="gemini-pro-env",
small_model_max_tokens_str="360",
large_model_name=os.environ.get("GEMINI_LARGE_MODEL", "gemini-1.5-pro-latest"), # Default large
large_model_max_tokens_str=os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS"),
system_prompt_content=None,
system_prompt_path=None,
is_gemini=True, # Important for Gemini bot
max_history_length=20
)
@patch.object(OpenAICompatibleInferenceBot, '__init__')
def test_init_with_arguments(self, mock_super_init):
mock_client_arg = MagicMock()
bot = GeminiTelegramInferenceBot(
openai_client=mock_client_arg, # Name in Gemini bot is openai_client for consistency
api_key="arg_gem_key",
base_url="https://arg.gemini.com",
small_model_name="arg_gem_small",
small_model_max_tokens="124",
large_model_name="arg_gem_large",
large_model_max_tokens="457",
system_prompt_content="Gemini prompt"
)
mock_super_init.assert_called_once_with(
client=mock_client_arg,
api_key="arg_gem_key",
base_url="https://arg.gemini.com",
api_version=None,
azure_deployment=None,
model_name="arg_gem_small",
max_tokens_str="124",
small_model_name="arg_gem_small",
small_model_max_tokens_str="124",
large_model_name="arg_gem_large",
large_model_max_tokens_str="457",
system_prompt_content="Gemini prompt",
system_prompt_path=None,
is_gemini=True,
max_history_length=20
)
@patch('openai.OpenAI') # Gemini bot uses OpenAI client configured for Gemini endpoint
async def test_switch_model_logic(self, mock_openai_constructor):
mock_openai_constructor.return_value = self.mock_openai_client
os.environ["GEMINI_SMALL_MODEL"] = "env-gemini-small"
os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = "110"
os.environ["GEMINI_LARGE_MODEL"] = "env-gemini-large"
os.environ["GEMINI_LARGE_MODEL_MAX_TOKENS"] = "220"
bot = GeminiTelegramInferenceBot() # Uses env vars by default
self.assertEqual(bot.model, "env-gemini-small")
self.assertEqual(bot.max_tokens, 110)
status = await bot.switch_model()
self.assertEqual(bot.model, "env-gemini-large")
self.assertEqual(bot.max_tokens, 220)
self.assertEqual(status, "Switched to model: env-gemini-large")
status = await bot.switch_model()
self.assertEqual(bot.model, "env-gemini-small")
self.assertEqual(bot.max_tokens, 110)
self.assertEqual(status, "Switched to model: env-gemini-small")
@patch('openai.OpenAI')
async def test_switch_model_uses_instance_configs_if_provided(self, mock_openai_constructor):
mock_openai_constructor.return_value = self.mock_openai_client
bot = GeminiTelegramInferenceBot(
small_model_name="init-gem-small", small_model_max_tokens="55",
large_model_name="init-gem-large", large_model_max_tokens="155"
)
self.assertEqual(bot.model, "init-gem-small")
self.assertEqual(bot.max_tokens, 55)
status = await bot.switch_model()
self.assertEqual(bot.model, "init-gem-large")
self.assertEqual(bot.max_tokens, 155)
self.assertEqual(status, "Switched to model: init-gem-large")
status = await bot.switch_model()
self.assertEqual(bot.model, "init-gem-small")
self.assertEqual(bot.max_tokens, 55)
self.assertEqual(status, "Switched to model: init-gem-small")
@patch('openai.OpenAI')
def test_get_llm_description_for_gemini_bot(self, mock_openai_constructor):
mock_openai_constructor.return_value = self.mock_openai_client
bot = GeminiTelegramInferenceBot(
small_model_name="gemini-pro-desc",
small_model_max_tokens="888",
# is_gemini is True by default in constructor call to super
)
# LLM description should indicate not Azure, even though it uses OpenAICompatible... base
# The is_gemini flag primarily affects client instantiation logic in the superclass.
# The azure_openai flag in superclass is based on azure_endpoint presence.
self.assertEqual(bot.get_llm_description(), "LLM: gemini-pro-desc, Max Tokens: 888, Azure: False")
if __name__ == '__main__':
unittest.main()
-81
View File
@@ -1,81 +0,0 @@
# tests/test_github_tool.py
import unittest
from unittest.mock import patch, MagicMock
from tools.github_tool import GitHubTool
class TestGitHubTool(unittest.TestCase):
def setUp(self):
self.github_tool = GitHubTool()
def test_get_functions(self):
functions = self.github_tool.get_functions()
self.assertEqual(len(functions), 4)
function_names = [f["name"] for f in functions]
expected_names = ["read_file", "create_branch", "commit_file", "create_pull_request"]
self.assertListEqual(function_names, expected_names)
@patch('tools.github_tool.requests.get')
def test_read_file(self, mock_get):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {"content": "file content"}
mock_get.return_value = mock_response
result = self.github_tool.execute("read_file", path="test.txt")
self.assertEqual(result, "file content")
mock_get.assert_called_once()
@patch('tools.github_tool.requests.get')
@patch('tools.github_tool.requests.post')
def test_create_branch(self, mock_post, mock_get):
mock_get_response = MagicMock()
mock_get_response.status_code = 200
mock_get_response.json.return_value = {"object": {"sha": "test_sha"}}
mock_get.return_value = mock_get_response
mock_post_response = MagicMock()
mock_post_response.status_code = 201
mock_post.return_value = mock_post_response
result = self.github_tool.execute("create_branch", branch_name="test-branch")
self.assertEqual(result, "Branch 'test-branch' created successfully")
mock_get.assert_called_once()
mock_post.assert_called_once()
@patch('tools.github_tool.requests.put')
def test_commit_file(self, mock_put):
mock_response = MagicMock()
mock_response.status_code = 200
mock_put.return_value = mock_response
result = self.github_tool.execute("commit_file", branch_name="test-branch", file_path="test.txt", content="test content", commit_message="Test commit")
self.assertEqual(result, "File committed successfully to branch 'test-branch'")
mock_put.assert_called_once()
def test_commit_file_to_main(self):
result = self.github_tool.execute("commit_file", branch_name="main", file_path="test.txt", content="test content", commit_message="Test commit")
self.assertEqual(result, "Cannot commit directly to main branch")
@patch('tools.github_tool.requests.post')
def test_create_pull_request(self, mock_post):
mock_response = MagicMock()
mock_response.status_code = 201
mock_response.json.return_value = {"html_url": "https://github.com/test/test/pull/1"}
mock_post.return_value = mock_response
result = self.github_tool.execute("create_pull_request", title="Test PR", body="Test body", head="test-branch")
self.assertEqual(result, "Pull request created successfully: https://github.com/test/test/pull/1")
mock_post.assert_called_once()
def test_unknown_function(self):
result = self.github_tool.execute("unknown_function")
self.assertEqual(result, "Unknown function: unknown_function")
if __name__ == '__main__':
unittest.main()
@@ -1,332 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch, AsyncMock, ANY
import os
import json
# Assuming openai_compatible_inference_bot.py is in the parent directory or PYTHONPATH is set
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot
# Mock response from OpenAI client's chat.completions.create
def create_mock_openai_response(content=None, tool_calls=None):
mock_message = MagicMock()
mock_message.role = "assistant"
mock_message.content = content
if tool_calls:
# tool_calls should be a list of objects with id and function (name, arguments)
mock_tool_calls = []
for tc in tool_calls:
mock_tc = MagicMock()
mock_tc.id = tc["id"]
mock_tc.function.name = tc["function"]["name"]
mock_tc.function.arguments = tc["function"]["arguments"]
mock_tool_calls.append(mock_tc)
mock_message.tool_calls = mock_tool_calls
else:
mock_message.tool_calls = None
mock_choice = MagicMock()
mock_choice.message = mock_message
mock_response = MagicMock()
mock_response.choices = [mock_choice]
return mock_response
# Concrete class for testing
class ConcreteOpenAICompatibleBot(OpenAICompatibleInferenceBot):
# Implement abstract methods for instantiation
async def switch_model(self):
# Simple switch for testing if needed, or just pass
if self.model == self.small_model_name:
self._configure_model_and_tokens(self.large_model_name, self.large_model_max_tokens_str)
else:
self._configure_model_and_tokens(self.small_model_name, self.small_model_max_tokens_str)
return f"Switched to {self.model}"
# Override load_functions if it's called by parent and needs mocking for these tests
# (OpenAICompatibleInferenceBot's __init__ calls BaseTelegramInferenceBot's __init__, which calls load_functions)
def load_functions(self):
# For these tests, assume no tools unless specifically added
self.tools = []
self.functions = []
return self.tools, self.functions
class TestOpenAICompatibleInferenceBot(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.original_openai_api_key = os.environ.get("OPENAI_API_KEY")
self.original_azure_openai_key = os.environ.get("AZURE_OPENAI_KEY")
self.original_azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
self.original_api_version = os.environ.get("AZURE_OPENAI_API_VERSION")
self.original_azure_deployment = os.environ.get("AZURE_DEPLOYMENT_NAME")
# Clear relevant env vars before each test
for key in ["OPENAI_API_KEY", "AZURE_OPENAI_KEY", "AZURE_OPENAI_ENDPOINT",
"AZURE_OPENAI_API_VERSION", "AZURE_DEPLOYMENT_NAME", "SYSTEM_PROMPT_PATH"]:
if os.environ.get(key):
del os.environ[key]
self.mock_openai_client_instance = MagicMock()
self.mock_openai_client_instance.chat.completions.create = MagicMock()
def tearDown(self):
# Restore environment variables
if self.original_openai_api_key: os.environ["OPENAI_API_KEY"] = self.original_openai_api_key
if self.original_azure_openai_key: os.environ["AZURE_OPENAI_KEY"] = self.original_azure_openai_key
if self.original_azure_endpoint: os.environ["AZURE_OPENAI_ENDPOINT"] = self.original_azure_endpoint
if self.original_api_version: os.environ["AZURE_OPENAI_API_VERSION"] = self.original_api_version
if self.original_azure_deployment: os.environ["AZURE_DEPLOYMENT_NAME"] = self.original_azure_deployment
@patch('openai.OpenAI')
def test_init_with_openai_defaults(self, MockOpenAIConstructor):
MockOpenAIConstructor.return_value = self.mock_openai_client_instance
os.environ["OPENAI_API_KEY"] = "test_openai_key"
bot = ConcreteOpenAICompatibleBot(model_name="gpt-4")
MockOpenAIConstructor.assert_called_once_with(api_key="test_openai_key", base_url=None)
self.assertEqual(bot.client, self.mock_openai_client_instance)
self.assertEqual(bot.model, "gpt-4")
self.assertEqual(bot.max_tokens, 1000) # Default from _configure_model_and_tokens
self.assertEqual(bot.azure_openai, False)
@patch('openai.OpenAI')
def test_init_with_provided_client(self, MockOpenAIConstructor):
preconfigured_client = MagicMock()
bot = ConcreteOpenAICompatibleBot(client=preconfigured_client, model_name="gpt-3.5")
MockOpenAIConstructor.assert_not_called()
self.assertEqual(bot.client, preconfigured_client)
self.assertEqual(bot.model, "gpt-3.5")
@patch('openai.AzureOpenAI')
def test_init_with_azure_config_args(self, MockAzureOpenAIConstructor):
MockAzureOpenAIConstructor.return_value = self.mock_openai_client_instance
bot = ConcreteOpenAICompatibleBot(
api_key="azure_key",
azure_endpoint="https://myenv.openai.azure.com",
api_version="2023-05-15",
azure_deployment="my-gpt-4", # This should be used as model_name for API call
model_name="should_be_overridden_by_azure_deployment_for_api"
# model_name is passed to _configure_model_and_tokens, which sets self.model for display/logging
# but for Azure, the client needs the deployment name.
)
MockAzureOpenAIConstructor.assert_called_once_with(
api_key="azure_key",
azure_endpoint="https://myenv.openai.azure.com",
api_version="2023-05-15"
)
self.assertEqual(bot.client, self.mock_openai_client_instance)
self.assertEqual(bot.model, "my-gpt-4") # Azure deployment name becomes the model for API calls
self.assertEqual(bot.azure_openai, True)
@patch('openai.AzureOpenAI')
def test_init_with_azure_env_vars(self, MockAzureOpenAIConstructor):
MockAzureOpenAIConstructor.return_value = self.mock_openai_client_instance
os.environ["AZURE_OPENAI_KEY"] = "env_azure_key"
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://env.openai.azure.com"
os.environ["AZURE_OPENAI_API_VERSION"] = "2023-06-01"
os.environ["AZURE_DEPLOYMENT_NAME"] = "env-gpt-35" # Used as model_name
bot = ConcreteOpenAICompatibleBot(model_name="ignored_if_azure_deployment_env_is_set")
MockAzureOpenAIConstructor.assert_called_once_with(
api_key="env_azure_key",
azure_endpoint="https://env.openai.azure.com",
api_version="2023-06-01"
)
self.assertEqual(bot.model, "env-gpt-35")
self.assertTrue(bot.azure_openai)
@patch('openai.OpenAI')
def test_init_with_gemini_config_args(self, MockOpenAIConstructor):
MockOpenAIConstructor.return_value = self.mock_openai_client_instance
bot = ConcreteOpenAICompatibleBot(
api_key="gemini_key",
base_url="https://gemini.example.com",
model_name="gemini-pro",
is_gemini=True
)
MockOpenAIConstructor.assert_called_once_with(api_key="gemini_key", base_url="https://gemini.example.com")
self.assertEqual(bot.model, "gemini-pro")
self.assertFalse(bot.azure_openai) # is_gemini doesn't mean azure_openai
def test_configure_model_and_tokens(self):
bot = ConcreteOpenAICompatibleBot(model_name="initial_model") # init calls _configure
bot._configure_model_and_tokens("test-model", "500")
self.assertEqual(bot.model, "test-model")
self.assertEqual(bot.max_tokens, 500)
bot._configure_model_and_tokens("test-model-2", None, default_max_tokens=150)
self.assertEqual(bot.max_tokens, 150)
bot._configure_model_and_tokens("test-model-3", "invalid_token_val")
self.assertEqual(bot.max_tokens, 1000) # Default fallback
def test_get_llm_description(self):
bot = ConcreteOpenAICompatibleBot(model_name="desc-model", max_tokens_str="256")
self.assertEqual(bot.get_llm_description(), "LLM: desc-model, Max Tokens: 256, Azure: False")
bot_azure = ConcreteOpenAICompatibleBot(azure_deployment="azure-model", azure_endpoint="x", api_key="y", api_version="z")
self.assertEqual(bot_azure.get_llm_description(), "LLM: azure-model, Max Tokens: 1000, Azure: True")
def test_get_chat_response_success(self):
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="test-gpt")
bot.max_tokens = 50 # Ensure this is set
mock_api_response = create_mock_openai_response(content="Hello from API")
self.mock_openai_client_instance.chat.completions.create.return_value = mock_api_response
messages = [{"role": "user", "content": "Hi"}]
response = bot.get_chat_response(messages)
self.mock_openai_client_instance.chat.completions.create.assert_called_once_with(
model="test-gpt",
messages=messages,
tools=ANY, # Assuming functions can be None or empty list
tool_choice=ANY,
max_tokens=50
)
self.assertEqual(response, mock_api_response)
def test_get_chat_response_api_error(self):
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="error-gpt")
self.mock_openai_client_instance.chat.completions.create.side_effect = Exception("API Down")
with self.assertRaisesRegex(Exception, "API Down"):
bot.get_chat_response([{"role": "user", "content": "trigger"}])
async def test_handle_message_simple_response(self):
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="chatty")
bot.system_prompt = "You are a test bot." # Set directly for simplicity
mock_api_response = create_mock_openai_response(content="Test reply")
self.mock_openai_client_instance.chat.completions.create.return_value = mock_api_response
response_content = await bot.handle_message(user_id=1, user_message="Hello")
self.assertEqual(response_content, "Test reply")
self.assertIn(1, bot.conversation_history)
self.assertEqual(len(bot.conversation_history[1]), 3) # System, User, Assistant
self.assertEqual(bot.conversation_history[1][0]["content"], "You are a test bot.")
self.assertEqual(bot.conversation_history[1][2]["content"], "Test reply")
async def test_handle_message_with_tool_call_and_response(self):
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="tool-user")
# Mock functions/tools setup on the bot
mock_tool_def = {"function": {"name": "get_weather", "description": "Gets weather", "parameters": {}}}
bot.functions = [mock_tool_def] # Simulate tools are loaded
# API response 1: Request to call tool
tool_call_request = [{"id": "call123", "function": {"name": "get_weather", "arguments": '''{"location": "moon"}'''}}]
api_response_1 = create_mock_openai_response(tool_calls=tool_call_request)
# API response 2: Final answer after tool execution
api_response_2 = create_mock_openai_response(content="The weather on the moon is chilly.")
self.mock_openai_client_instance.chat.completions.create.side_effect = [api_response_1, api_response_2]
# Mock self.call_tool
bot.call_tool = MagicMock(return_value='''{"temperature": "-100 C"}''')
final_response = await bot.handle_message(user_id=2, user_message="Weather on moon?")
self.assertEqual(final_response, "The weather on the moon is chilly.")
bot.call_tool.assert_called_once_with("get_weather", '''{"location": "moon"}''')
# Check conversation history includes tool messages
history = bot.conversation_history[2]
self.assertTrue(any(msg["role"] == "assistant" and msg.tool_calls is not None for msg in history))
self.assertTrue(any(msg["role"] == "tool" and msg["name"] == "get_weather" for msg in history))
self.assertEqual(self.mock_openai_client_instance.chat.completions.create.call_count, 2)
async def test_handle_message_max_history_length(self):
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="hist-test", max_history_length=3)
self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Ok")
await bot.handle_message(1, "Msg1") # Sys, User, Assist (3)
self.assertEqual(len(bot.conversation_history[1]), 3)
await bot.handle_message(1, "Msg2") # User, Assist. Should be 3 (prev User, prev Assist, new User) -> then adds new Assist.
# Before new call: [Sys, U1, A1]. New U2. Call with [Sys,U1,A1,U2]. Resp A2.
# History: [Sys,U1,A1,U2,A2]. Limit 3. -> [A1,U2,A2] (if system is not preserved specially)
# The current code appends to history then truncates if over limit.
# So after Msg1: [S, U1, A1]. len=3.
# For Msg2: History is [S, U1, A1]. Append U2. Call with [S,U1,A1,U2]. Append A2.
# History now [S,U1,A1,U2,A2]. len=5. Truncate to 3.
# Expected: [A1, U2, A2] or [U1,A1,U2] or [U2,A2,S] depending on how system prompt is handled in truncation.
# The code is: self.conversation_history[user_id][-self.max_history_length:]
# And system prompt is only added IF user_id not in self.conversation_history.
# So, for Msg2, system prompt is not re-added.
# History before Msg2 call: [S, U1, A1]
# Messages for Msg2 call: [S, U1, A1, U2]
# History after Msg2 response A2: [S, U1, A1, U2, A2]. Len 5.
# Truncated to self.max_history_length=3: [A1, U2, A2]
# Call 1
self.mock_openai_client_instance.chat.completions.create.reset_mock()
self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply1")
await bot.handle_message(user_id=7, user_message="First message")
self.assertEqual(len(bot.conversation_history[7]), 3) # System, User1, Assistant1
# Call 2
self.mock_openai_client_instance.chat.completions.create.reset_mock()
self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply2")
await bot.handle_message(user_id=7, user_message="Second message")
# History before call: [S, U1, A1]. Messages for call: [S, U1, A1, U2]. History after: [S, U1, A1, U2, A2].
# Truncated to 3: [A1, U2, A2]
self.assertEqual(len(bot.conversation_history[7]), 3)
self.assertEqual(bot.conversation_history[7][0]["content"], "Reply1") # A1
self.assertEqual(bot.conversation_history[7][1]["content"], "Second message") # U2
self.assertEqual(bot.conversation_history[7][2]["content"], "Reply2") # A2
# Call 3
self.mock_openai_client_instance.chat.completions.create.reset_mock()
self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply3")
await bot.handle_message(user_id=7, user_message="Third message")
# History before call: [A1, U2, A2]. Messages for call: [A1, U2, A2, U3]. History after: [A1, U2, A2, U3, A3].
# Truncated to 3: [A2, U3, A3]
self.assertEqual(len(bot.conversation_history[7]), 3)
self.assertEqual(bot.conversation_history[7][0]["content"], "Reply2") # A2
self.assertEqual(bot.conversation_history[7][1]["content"], "Third message") # U3
self.assertEqual(bot.conversation_history[7][2]["content"], "Reply3") # A3
async def test_abort_processing(self):
bot = ConcreteOpenAICompatibleBot(model_name="test")
user_id = 123
bot.processing_status[user_id] = {"processing": True, "message_id": 456}
bot.conversation_history[user_id] = [{"role": "user", "content": "stuff"}]
with patch.object(bot, 'clear_conversation_history') as mock_clear_hist: # Patching the method from Base class
result = await bot.abort_processing(user_id)
self.assertEqual(result, "Processing aborted and conversation cleared.")
self.assertFalse(bot.processing_status[user_id]["processing"])
mock_clear_hist.assert_called_once_with(user_id)
async def test_abort_processing_no_active_processing(self):
bot = ConcreteOpenAICompatibleBot(model_name="test")
user_id = 404 # Not in processing_status
with patch.object(bot, 'clear_conversation_history') as mock_clear_hist:
result = await bot.abort_processing(user_id)
self.assertEqual(result, "No active processing found to abort. Conversation cleared.")
mock_clear_hist.assert_called_once_with(user_id)
# Test for the abstract switch_model (basic call, actual logic in concrete class for this test)
async def test_switch_model_concrete_implementation(self):
bot = ConcreteOpenAICompatibleBot(model_name="model1", small_model_name="model1", large_model_name="model2", max_tokens_str="100")
self.assertEqual(bot.model, "model1")
await bot.switch_model() # Calls the concrete implementation
self.assertEqual(bot.model, "model2")
await bot.switch_model()
self.assertEqual(bot.model, "model1")
if __name__ == '__main__':
unittest.main()
-356
View File
@@ -1,356 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch, mock_open, AsyncMock
import asyncio
import os
import sys
# Assuming telegram_helper.py is in the parent directory or PYTHONPATH is set
from telegram_helper import TelegramHelper, MessageHandlerLogicResult
# Mock for the bot passed to TelegramHelper
class MockBot:
def __init__(self):
self.start = AsyncMock()
self.clear_conversation_history = MagicMock()
self.get_bot_status = AsyncMock(return_value="Bot Status OK")
self.switch_model = AsyncMock(return_value="Model Switched OK")
self.handle_message = AsyncMock() # Needs to return a string
self.abort_processing = AsyncMock(return_value="Abort OK")
self.set_processing_status = MagicMock()
self.clear_processing_status = MagicMock()
self.processing_status = {} # Add the attribute
# Mock for telegram.Update and related objects
def create_mock_update(message_text=None, user_id=123, chat_id=456, message_id=789, callback_query_data=None):
update = MagicMock()
update.effective_user.id = user_id
update.effective_chat.id = chat_id
if message_text:
update.message.text = message_text
update.message.reply_text = AsyncMock(return_value=MagicMock(message_id=message_id)) # reply_text returns a Message obj
if callback_query_data:
update.callback_query.data = callback_query_data
update.callback_query.from_user.id = user_id
update.callback_query.answer = AsyncMock()
update.callback_query.edit_message_text = AsyncMock()
return update
# Mock for telegram.ext.ContextTypes.DEFAULT_TYPE
def create_mock_context():
context = MagicMock()
context.bot.delete_message = AsyncMock()
context.bot.edit_message_text = AsyncMock() # For update_status_message
return context
class TestTelegramHelper(unittest.IsolatedAsyncioTestCase): # Use IsolatedAsyncioTestCase for async methods
def setUp(self):
self.mock_bot = MockBot()
# Default paths for reboot files, can be overridden in tests
self.reboot_claude_file = ".test_reboot_claude"
self.reboot_file = ".test_doreboot"
self.helper = TelegramHelper(
self.mock_bot,
reboot_claude_file_path=self.reboot_claude_file,
reboot_file_path=self.reboot_file,
chunk_message_sleep_duration=0.001 # Faster sleep for tests
)
# Clean up any potential leftover reboot files from previous runs
if os.path.exists(self.reboot_claude_file):
os.remove(self.reboot_claude_file)
if os.path.exists(self.reboot_file):
os.remove(self.reboot_file)
def tearDown(self):
# Clean up reboot files created during tests
if os.path.exists(self.reboot_claude_file):
os.remove(self.reboot_claude_file)
if os.path.exists(self.reboot_file):
os.remove(self.reboot_file)
async def test_start_logic(self):
response = await self.helper._start_logic()
self.mock_bot.start.assert_called_once()
self.assertEqual(response, "Hello! I\'m your AI assistant. How can I help you today?")
async def test_start_command(self):
mock_update = create_mock_update(message_text="/start")
mock_context = create_mock_context()
with patch.object(self.helper, \'_start_logic\', new_callable=AsyncMock) as mock_logic:
mock_logic.return_value = "Start Logic Response"
await self.helper.start(mock_update, mock_context)
mock_logic.assert_called_once()
mock_update.message.reply_text.assert_called_once_with("Start Logic Response")
async def test_clear_logic(self):
user_id = 123
response = await self.helper._clear_logic(user_id) # _clear_logic is async after refactor
self.mock_bot.clear_conversation_history.assert_called_once_with(user_id)
self.assertEqual(response, "Conversation history cleared. Let\'s start fresh!")
async def test_clear_command(self):
mock_update = create_mock_update(message_text="/clear", user_id=123)
mock_context = create_mock_context()
with patch.object(self.helper, \'_clear_logic\', new_callable=AsyncMock) as mock_logic:
mock_logic.return_value = "Clear Logic Response"
await self.helper.clear(mock_update, mock_context)
mock_logic.assert_called_once_with(123)
mock_update.message.reply_text.assert_called_once_with("Clear Logic Response")
async def test_status_logic(self):
self.mock_bot.get_bot_status.return_value = "Test Status"
response = await self.helper._status_logic()
self.mock_bot.get_bot_status.assert_called_once()
self.assertEqual(response, "Test Status")
async def test_switch_logic_supported(self):
self.mock_bot.switch_model.return_value = "Switched to Large Model"
response = await self.helper._switch_logic()
self.mock_bot.switch_model.assert_called_once()
self.assertEqual(response, "Switched to Large Model")
async def test_switch_logic_not_supported(self):
del self.mock_bot.switch_model # Simulate bot not having the attribute
response = await self.helper._switch_logic()
self.assertEqual(response, "Model switching is not supported for this bot.")
async def test_handle_message_logic_success(self):
user_id = 100
user_message = "Hello bot"
bot_response = "Hello user <think>Thinking hard</think> Done."
expected_processed_response = f"Hello user {self.helper.HTML_QUOTE_BLOCK_START}Thinking hard{self.helper.HTML_QUOTE_BLOCK_END} Done."
self.mock_bot.handle_message.return_value = bot_response
result = await self.helper._handle_message_logic(user_id, user_message)
self.mock_bot.handle_message.assert_called_once_with(user_id, user_message)
self.assertTrue(result["success"])
self.assertEqual(result["response_text"], expected_processed_response)
self.assertIsNone(result["error_message"])
async def test_handle_message_logic_bot_exception(self):
user_id = 101
user_message = "Trigger error"
self.mock_bot.handle_message.side_effect = Exception("Bot Error")
result = await self.helper._handle_message_logic(user_id, user_message)
self.assertFalse(result["success"])
self.assertIsNone(result["response_text"])
self.assertEqual(result["error_message"], "Bot Error")
@patch(\'logging.error\')
async def test_handle_message_command_success_short_message(self, mock_logging_error):
mock_update = create_mock_update(message_text="Hi", user_id=200, chat_id=201, message_id=202)
mock_context = create_mock_context()
logic_result = MessageHandlerLogicResult(success=True, response_text="Short response", error_message=None)
with patch.object(self.helper, \'_handle_message_logic\', new_callable=AsyncMock) as mock_message_logic:
mock_message_logic.return_value = logic_result
await self.helper.handle_message(mock_update, mock_context)
mock_update.message.reply_text.assert_any_call("Processing your request...", reply_markup=unittest.mock.ANY)
self.mock_bot.set_processing_status.assert_called_once_with(200, 202) # user_id, status_message_id
mock_message_logic.assert_called_once_with(200, "Hi")
mock_context.bot.delete_message.assert_called_once_with(chat_id=201, message_id=202)
self.mock_bot.clear_processing_status.assert_called_once_with(200)
mock_update.message.reply_text.assert_any_call("Short response") # Final response
self.assertEqual(mock_update.message.reply_text.call_count, 2) # Processing + final
@patch(\'logging.error\')
async def test_handle_message_command_success_long_message_chunks(self, mock_logging_error):
mock_update = create_mock_update(message_text="Long text", user_id=200, chat_id=201, message_id=202)
mock_context = create_mock_context()
long_response_text = "a" * 5000 # Longer than 4096
chunk1 = long_response_text[:4096]
chunk2 = long_response_text[4096:]
logic_result = MessageHandlerLogicResult(success=True, response_text=long_response_text, error_message=None)
with patch.object(self.helper, \'_handle_message_logic\', new_callable=AsyncMock) as mock_message_logic, \
patch(\'asyncio.sleep\', new_callable=AsyncMock) as mock_sleep: # Mock sleep
mock_message_logic.return_value = logic_result
await self.helper.handle_message(mock_update, mock_context)
mock_update.message.reply_text.assert_any_call(chunk1)
mock_update.message.reply_text.assert_any_call(chunk2)
mock_sleep.assert_called_once_with(self.helper.chunk_message_sleep_duration)
self.assertEqual(mock_update.message.reply_text.call_count, 3) # Processing + 2 chunks
@patch(\'logging.error\')
async def test_handle_message_command_logic_fails(self, mock_logging_error):
mock_update = create_mock_update(message_text="Cause error in logic", user_id=200)
mock_context = create_mock_context()
logic_result = MessageHandlerLogicResult(success=False, response_text=None, error_message="Logic Failed")
with patch.object(self.helper, \'_handle_message_logic\', new_callable=AsyncMock) as mock_message_logic:
mock_message_logic.return_value = logic_result
await self.helper.handle_message(mock_update, mock_context)
mock_update.message.reply_text.assert_any_call("Sorry, an error occurred while processing your request.")
self.assertEqual(mock_update.message.reply_text.call_count, 2) # Processing + error message
@patch(\'logging.error\')
async def test_handle_message_command_telegram_exception_after_logic(self, mock_logging_error):
mock_update = create_mock_update(message_text="Test", user_id=200)
mock_context = create_mock_context()
logic_result = MessageHandlerLogicResult(success=True, response_text="OK", error_message=None)
# Make sending the final reply fail
mock_update.message.reply_text.side_effect = [
MagicMock(message_id=202), # For "Processing..."
Exception("Telegram API Error") # For the actual response
]
with patch.object(self.helper, \'_handle_message_logic\', new_callable=AsyncMock) as mock_message_logic:
mock_message_logic.return_value = logic_result
await self.helper.handle_message(mock_update, mock_context)
# Check if the generic error message was attempted
# This is tricky because reply_text is already mocked with side_effect.
# We\'d expect logs. Let\'s check logs or if processing status was cleared.
self.mock_bot.clear_processing_status.assert_called_once_with(200)
mock_logging_error.assert_any_call(unittest.mock.string_containing("Outer error in handle_message"))
async def test_abort_processing_logic(self):
user_id = 300
self.mock_bot.abort_processing.return_value = "Aborted by bot"
response = await self.helper._abort_processing_logic(user_id)
self.mock_bot.abort_processing.assert_called_once_with(user_id)
self.assertEqual(response, "Aborted by bot")
async def test_abort_processing_command(self):
mock_update = create_mock_update(callback_query_data=\'abort\', user_id=301)
mock_context = create_mock_context()
with patch.object(self.helper, \'_abort_processing_logic\', new_callable=AsyncMock) as mock_logic:
mock_logic.return_value = "Abort Logic Done"
await self.helper.abort_processing(mock_update, mock_context)
mock_update.callback_query.answer.assert_called_once()
mock_logic.assert_called_once_with(301)
mock_update.callback_query.edit_message_text.assert_called_once_with(text="Abort Logic Done")
def test_reboot_logic_claude_and_main(self):
user_message_parts = ["/reboot", "claude"]
chat_id_to_write = "12345"
with patch("builtins.open", mock_open()) as mock_file:
self.helper._reboot_logic(user_message_parts, chat_id_to_write)
# Check claude reboot file
mock_file.assert_any_call(self.reboot_claude_file, \'w\')
# Check main doreboot file
mock_file.assert_any_call(self.reboot_file, \'w\')
handle_claude = mock_file.return_value
handle_main = mock_file.return_value # mock_open reuses the handle for multiple calls
# Check if write was called for claude file (empty write)
# This part of assertion is tricky with single mock_file. Better to use different mocks if possible
# or check the sequence of calls if the mock supports it well.
# For now, assert_any_call ensures it was opened.
# Check content for main reboot file
# Need to ensure the write for self.reboot_file had chat_id_to_write
# This requires more sophisticated mock_open or patching os.path.exists and multiple open calls
# Simpler check: was open(self.reboot_file, \'w\') called? Yes, via assert_any_call.
# And was open(self.reboot_claude_file, \'w\') called? Yes.
# Verify files were created (mock_open doesn\'t actually create them)
# This test relies on mock_open\'s behavior. To test file content, need more setup.
# For now, assume open was called correctly.
def test_reboot_logic_main_only(self):
user_message_parts = ["/reboot"]
chat_id_to_write = "67890"
with patch("builtins.open", mock_open()) as mock_file:
self.helper._reboot_logic(user_message_parts, chat_id_to_write)
# Ensure claude file was NOT opened for writing.
# This requires asserting that a specific call didn\'t happen, or checking call_args_list
claude_call = unittest.mock.call(self.reboot_claude_file, \'w\')
self.assertNotIn(claude_call, mock_file.call_args_list)
mock_file.assert_any_call(self.reboot_file, \'w\')
@patch(\'sys.exit\') # Mock sys.exit to prevent test runner from exiting
async def test_reboot_command(self, mock_sys_exit):
mock_update = create_mock_update(message_text="/reboot claude", chat_id="chat1")
mock_context = create_mock_context()
with patch.object(self.helper, \'_reboot_logic\') as mock_reboot_file_logic:
await self.helper.reboot(mock_update, mock_context)
mock_reboot_file_logic.assert_called_once_with(["/reboot", "claude"], "chat1")
mock_update.message.reply_text.assert_called_once_with("Rebooting the bot...")
mock_sys_exit.assert_called_once_with(0)
@patch(\'os.path.exists\')
@patch(\'builtins.open\', new_callable=mock_open)
@patch(\'os.remove\')
async def test_check_doreboot_file_logic_file_exists(self, mock_os_remove, mock_file_open, mock_os_path_exists):
mock_os_path_exists.return_value = True
mock_file_open.return_value.read.return_value.strip.return_value = "chat123"
chat_id = await self.helper._check_doreboot_file_logic()
mock_os_path_exists.assert_called_once_with(self.reboot_file)
mock_file_open.assert_called_once_with(self.reboot_file, \'r\')
mock_os_remove.assert_called_once_with(self.reboot_file)
self.assertEqual(chat_id, "chat123")
@patch(\'os.path.exists\', return_value=False)
async def test_check_doreboot_file_logic_file_not_exists(self, mock_os_path_exists):
chat_id = await self.helper._check_doreboot_file_logic()
mock_os_path_exists.assert_called_once_with(self.reboot_file)
self.assertIsNone(chat_id)
@patch(\'logging.error\')
@patch(\'os.path.exists\', return_value=True)
@patch(\'builtins.open\', side_effect=IOError("Read error"))
@patch(\'os.remove\') # To check if remove is called even on read error
async def test_check_doreboot_file_logic_read_error(self, mock_os_remove, mock_file_open, mock_os_path_exists, mock_log_error):
chat_id = await self.helper._check_doreboot_file_logic()
self.assertIsNone(chat_id)
mock_log_error.assert_any_call(unittest.mock.string_containing("Error reading reboot file"))
# Check if os.remove was attempted even after read error
mock_os_remove.assert_called_once_with(self.reboot_file)
async def test_check_doreboot_file_command_sends_message(self):
mock_application = MagicMock()
mock_application.bot.send_message = AsyncMock()
with patch.object(self.helper, \'_check_doreboot_file_logic\', new_callable=AsyncMock) as mock_logic:
mock_logic.return_value = "chat789" # Simulate chat_id found
await self.helper.check_doreboot_file(mock_application)
mock_logic.assert_called_once()
mock_application.bot.send_message.assert_called_once_with(
chat_id="chat789", text="The application has finished initializing."
)
async def test_check_doreboot_file_command_no_chat_id(self):
mock_application = MagicMock()
mock_application.bot.send_message = AsyncMock()
with patch.object(self.helper, \'_check_doreboot_file_logic\', new_callable=AsyncMock) as mock_logic:
mock_logic.return_value = None # Simulate no chat_id found
await self.helper.check_doreboot_file(mock_application)
mock_logic.assert_called_once()
mock_application.bot.send_message.assert_not_called()
# Note: Testing the run() method itself is more of an integration test,
# as it involves setting up the full Application and polling loop.
# Unit tests here focus on the helper\'s own logic methods.
if __name__ == \'__main__\':
unittest.main()
-307
View File
@@ -1,307 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch
import os
import base64
import logging
import requests # Required for spec in MagicMock
# Ensure tools/github_tool.py is accessible
from tools.github_tool import GitHubTool
# Helper to create a mock response for requests.Session
def create_mock_response(status_code, json_data=None, text_data=None, headers=None, links=None):
mock_resp = MagicMock()
mock_resp.status_code = status_code
if json_data is not None:
mock_resp.json = MagicMock(return_value=json_data)
mock_resp.text = text_data if text_data is not None else str(json_data)
mock_resp.headers = headers if headers else {}
mock_resp.links = links if links else {} # For pagination in _list_branches
return mock_resp
class TestGitHubTool(unittest.TestCase):
def setUp(self):
self.mock_session = MagicMock(spec=requests.Session)
self.mock_session.headers = {} # Simulate a new session's headers
self.test_token = "test_github_token"
self.test_repo = "owner/repo"
self.test_base_url = "https://api.example.com" # Use a non-default base_url for some tests
# Suppress logging output during tests unless explicitly testing for it
self.logger = logging.getLogger('tools.github_tool')
# Ensure only one NullHandler to prevent duplicate messages if tests run multiple times in a session
if not any(isinstance(h, logging.NullHandler) for h in self.logger.handlers):
self.logger.addHandler(logging.NullHandler())
self.logger.propagate = False # Prevent propagation to root logger if it has handlers
def test_init_with_args_and_session(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, base_url=self.test_base_url, logger=self.logger)
self.assertEqual(tool.session, self.mock_session)
self.assertEqual(tool._token, self.test_token)
self.assertEqual(tool._repo, self.test_repo)
self.assertEqual(tool.base_url, self.test_base_url)
self.assertEqual(tool.current_branch, "main") # Default initial branch
@patch('requests.Session')
def test_init_creates_session_if_not_provided(self, MockSessionConstructor):
mock_created_session = MagicMock(spec=requests.Session)
mock_created_session.headers = {}
MockSessionConstructor.return_value = mock_created_session
# Temporarily set env vars for this test
original_token = os.environ.get("GITHUB_TOKEN")
original_repo = os.environ.get("GITHUB_REPOSITORY")
os.environ["GITHUB_TOKEN"] = "env_token"
os.environ["GITHUB_REPOSITORY"] = "env/repo"
tool = GitHubTool(logger=self.logger) # Use env vars
MockSessionConstructor.assert_called_once()
self.assertEqual(tool.session, mock_created_session)
self.assertEqual(tool._token, "env_token")
self.assertEqual(tool._repo, "env/repo")
self.assertIn("Authorization", mock_created_session.headers)
self.assertEqual(mock_created_session.headers["Authorization"], "token env_token")
# Restore original env vars
if original_token is None: del os.environ["GITHUB_TOKEN"]
else: os.environ["GITHUB_TOKEN"] = original_token
if original_repo is None: del os.environ["GITHUB_REPOSITORY"]
else: os.environ["GITHUB_REPOSITORY"] = original_repo
def test_init_raises_value_error_if_no_token(self):
original_token = os.environ.pop("GITHUB_TOKEN", None)
with self.assertRaisesRegex(ValueError, "GitHub token must be provided"):
GitHubTool(repo=self.test_repo, logger=self.logger)
if original_token: os.environ["GITHUB_TOKEN"] = original_token
def test_init_raises_value_error_if_no_repo(self):
original_repo = os.environ.pop("GITHUB_REPOSITORY", None)
with self.assertRaisesRegex(ValueError, "GitHub repository.*must be provided"):
GitHubTool(token=self.test_token, logger=self.logger)
if original_repo: os.environ["GITHUB_REPOSITORY"] = original_repo
def test_clear_resets_branch(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, initial_branch="feature-branch", logger=self.logger)
# Mock _get_branch_sha for _set_current_branch called by clear
with patch.object(tool, '_get_branch_sha', return_value="sha_for_main"):
tool.clear()
self.assertEqual(tool.current_branch, "main")
def test_get_functions_returns_list(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
functions = tool.get_functions()
self.assertIsInstance(functions, list)
self.assertTrue(len(functions) > 0)
self.assertIn("name", functions[0]["function"])
# --- Test individual private methods ---
def test_read_file_success(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
file_content = "Hello World!"
encoded_content = base64.b64encode(file_content.encode('utf-8')).decode('utf-8')
self.mock_session.get.return_value = create_mock_response(200, json_data={"content": encoded_content})
result = tool._read_file(path="test.txt")
self.assertEqual(result, file_content)
self.mock_session.get.assert_called_once_with(
f"{tool.base_url}/repos/{self.test_repo}/contents/test.txt",
params={"ref": "main"}
)
def test_read_file_error(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
self.mock_session.get.return_value = create_mock_response(404, text_data="Not Found")
result = tool._read_file(path="nonexistent.txt")
self.assertIn("Error reading file", result)
def test_create_branch_success(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
# Mock getting base branch SHA
self.mock_session.get.return_value = create_mock_response(200, json_data={"object": {"sha": "base_sha123"}})
# Mock creating new branch
self.mock_session.post.return_value = create_mock_response(201, json_data={"ref": "refs/heads/new-feature"})
result = tool._create_branch(branch_name="new-feature", base_branch="main")
self.assertIn("Branch 'new-feature' created successfully", result)
self.assertEqual(tool.current_branch, "new-feature")
self.mock_session.get.assert_called_once() # For base branch SHA
self.mock_session.post.assert_called_once() # For creating branch
def test_create_branch_base_sha_error(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
self.mock_session.get.return_value = create_mock_response(404, text_data="Base branch not found")
result = tool._create_branch(branch_name="new-feature", base_branch="nonexistent-base")
self.assertIn("Error getting base branch SHA", result)
def test_create_branch_creation_error(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
self.mock_session.get.return_value = create_mock_response(200, json_data={"object": {"sha": "base_sha456"}})
self.mock_session.post.return_value = create_mock_response(422, text_data="Validation failed")
result = tool._create_branch(branch_name="bad-branch", base_branch="main")
self.assertIn("Error creating branch", result)
def test_commit_file_success_new_file(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
tool.current_branch = "dev-branch" # Cannot commit to main by default
# Mock GET for checking file existence (404 means new file)
self.mock_session.get.return_value = create_mock_response(404)
# Mock PUT for committing file
self.mock_session.put.return_value = create_mock_response(201, json_data={"commit": {"sha": "commit_sha_abc"}})
result = tool._commit_file(file_path="new_file.py", content="print('Hello')", commit_message="Add new_file.py")
self.assertIn("committed successfully", result)
self.assertIn("commit_sha_abc", result)
self.mock_session.get.assert_called_once() # Check file existence
self.mock_session.put.assert_called_once() # Commit file
def test_commit_file_success_update_file(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
tool.current_branch = "dev-branch"
# Mock GET for checking file existence (200 means existing file)
self.mock_session.get.return_value = create_mock_response(200, json_data={"sha": "existing_file_sha"})
# Mock PUT for committing file
self.mock_session.put.return_value = create_mock_response(200, json_data={"commit": {"sha": "commit_sha_def"}})
result = tool._commit_file(file_path="existing_file.txt", content="Updated content", commit_message="Update existing_file.txt")
self.assertIn("committed successfully", result)
self.assertIn("commit_sha_def", result)
args, kwargs = self.mock_session.put.call_args
self.assertEqual(kwargs['json']['sha'], "existing_file_sha")
def test_commit_file_to_main_branch_error(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
tool.current_branch = "main"
result = tool._commit_file(file_path="some.txt", content="content", commit_message="msg")
self.assertIn("Action directly to main branch is not allowed", result)
def test_create_pull_request_success(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
tool.current_branch = "feature-pr"
pr_url = "https://example.com/pull/1"
self.mock_session.post.return_value = create_mock_response(201, json_data={"html_url": pr_url, "number": 1})
result = tool._create_pull_request(title="New Feature PR", body="Please review.", base="main")
self.assertIn(f"Pull request created successfully: {pr_url}", result)
self.mock_session.post.assert_called_once()
call_data = self.mock_session.post.call_args[1]['json']
self.assertEqual(call_data['head'], "feature-pr")
self.assertEqual(call_data['base'], "main")
def test_create_pull_request_same_branch_error(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
tool.current_branch = "main"
result = tool._create_pull_request(title="PR to self", body="This should fail", base="main")
self.assertIn("Cannot create a pull request from branch 'main' to itself", result)
def test_list_files_success(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
mock_items = [
{"name": "file1.txt", "type": "file", "path": "dir/file1.txt"},
{"name": "subdir", "type": "dir", "path": "dir/subdir"}
]
self.mock_session.get.return_value = create_mock_response(200, json_data=mock_items)
result = tool._list_files(path="dir")
self.assertEqual(len(result), 2)
self.assertEqual(result[0]["name"], "file1.txt")
self.assertEqual(result[1]["type"], "dir")
def test_search_code_success(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
mock_search_results = {
"items": [{"path": "src/code.py", "html_url": "url1"}]
}
self.mock_session.get.return_value = create_mock_response(200, json_data=mock_search_results)
results = tool._search_code(query="my_function")
self.assertEqual(len(results), 1)
self.assertEqual(results[0]["path"], "src/code.py")
def test_get_commit_history_success(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
mock_commits = [{
"sha": "sha1", "commit": {"message": "Msg1", "author": {"name": "Authy", "date": "Date1"}}
}]
self.mock_session.get.return_value = create_mock_response(200, json_data=mock_commits)
commits = tool._get_commit_history(file_path="file.txt", num_commits=1)
self.assertEqual(len(commits), 1)
self.assertEqual(commits[0]["sha"], "sha1")
def test_set_current_branch_success(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
# Mock _get_branch_sha to simulate branch exists
with patch.object(tool, '_get_branch_sha', return_value="some_sha_for_dev"):
result = tool._set_current_branch(branch_name="dev")
self.assertEqual(tool.current_branch, "dev")
self.assertIn("Current branch set to: dev", result)
def test_set_current_branch_not_exists(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
with patch.object(tool, '_get_branch_sha', return_value="Error getting SHA for branch"):
result = tool._set_current_branch(branch_name="nonexistent-branch")
self.assertNotEqual(tool.current_branch, "nonexistent-branch") # Should not change
self.assertIn("Cannot set current branch", result)
def test_list_branches_single_page(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
mock_branches = [{"name": "main"}, {"name": "dev"}]
self.mock_session.get.return_value = create_mock_response(200, json_data=mock_branches, links={}) # No "next" link
branches = tool._list_branches(all_pages=True)
self.assertEqual(branches, ["main", "dev"])
self.mock_session.get.assert_called_once()
def test_list_branches_multiple_pages(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
# Page 1 response
page1_branches = [{"name": "branch1"}, {"name": "branch2"}]
next_url = f"{tool.base_url}/repos/{self.test_repo}/branches?page=2"
response1 = create_mock_response(200, json_data=page1_branches, links={"next": {"url": next_url}})
# Page 2 response
page2_branches = [{"name": "branch3"}]
response2 = create_mock_response(200, json_data=page2_branches, links={}) # No "next" link
self.mock_session.get.side_effect = [response1, response2]
branches = tool._list_branches(all_pages=True)
self.assertEqual(branches, ["branch1", "branch2", "branch3"])
self.assertEqual(self.mock_session.get.call_count, 2)
# Check that the second call used the next_url
calls = self.mock_session.get.call_args_list
self.assertEqual(calls[1][0][0], next_url) # args[0] is the URL
# --- Test execute dispatcher ---
def test_execute_read_file(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
with patch.object(tool, '_read_file', return_value="file content") as mock_method:
result = tool.execute(function_name="read_file", path="test.md")
mock_method.assert_called_once_with(path="test.md")
self.assertEqual(result, "file content")
def test_execute_unknown_function(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
result = tool.execute(function_name="non_existent_function_name", arg1="val1")
self.assertIn("Unknown function: non_existent_function_name", result)
def test_execute_method_exception(self):
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
with patch.object(tool, '_read_file', side_effect=Exception("Kaboom")) as mock_method:
result = tool.execute(function_name="read_file", path="crash.txt")
self.assertIn("Error during read_file execution: Kaboom", result)
if __name__ == '__main__':
unittest.main()
-146
View File
@@ -1,146 +0,0 @@
import unittest
from unittest.mock import patch, mock_open, MagicMock
import os
import logging
from datetime import datetime, timedelta
# Ensure tools/log_tool.py is accessible
from tools.log_tool import LogTool
class TestLogTool(unittest.TestCase):
def setUp(self):
self.test_log_file_path = "test_dummy_log.log"
# Suppress logging output during tests unless explicitly testing for it
self.logger = logging.getLogger('tools.log_tool')
# Ensure only one NullHandler to prevent duplicate messages if tests run multiple times in a session
if not any(isinstance(h, logging.NullHandler) for h in self.logger.handlers):
self.logger.addHandler(logging.NullHandler())
self.logger.propagate = False # Prevent propagation to root logger if it has handlers
def test_init_default_log_path(self):
tool = LogTool(logger=self.logger)
self.assertEqual(tool.configured_log_file_path, 'logs/output.log')
def test_init_custom_log_path(self):
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
self.assertEqual(tool.configured_log_file_path, self.test_log_file_path)
def test_get_functions(self):
tool = LogTool(logger=self.logger)
functions = tool.get_functions()
self.assertIsInstance(functions, list)
self.assertEqual(len(functions), 1)
self.assertEqual(functions[0]["function"]["name"], "get_log_contents")
@patch("os.path.exists", return_value=False)
def test_get_log_contents_file_not_exists(self, mock_exists):
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
result = tool._get_log_contents()
self.assertIn("Log file does not exist", result)
mock_exists.assert_called_once_with(self.test_log_file_path)
@patch("os.path.exists", return_value=True)
@patch("builtins.open", new_callable=mock_open, read_data="line1\nline2\nline3\nline4\nline5")
def test_get_log_contents_with_line_count(self, mock_file_open, mock_exists):
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
result = tool._get_log_contents(line_count=3)
self.assertEqual(result, "line3\nline4\nline5")
mock_exists.assert_called_once_with(self.test_log_file_path)
mock_file_open.assert_called_once_with(self.test_log_file_path, 'r', encoding='utf-8')
@patch("os.path.exists", return_value=True)
@patch("builtins.open", new_callable=mock_open, read_data="line1\nline2\n")
def test_get_log_contents_line_count_more_than_available(self, mock_file_open, mock_exists):
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
result = tool._get_log_contents(line_count=5)
self.assertEqual(result, "line1\nline2\n")
@patch("os.path.exists", return_value=True)
@patch("builtins.open", new_callable=mock_open, read_data="line1\nline2\n")
def test_get_log_contents_invalid_line_count_uses_default(self, mock_file_open, mock_exists):
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
# Test with zero, negative, and non-integer line_count (though type hint is int)
# The code defaults to 150 if invalid. Here, we only have 2 lines.
with patch.object(tool.logger, 'warning') as mock_log_warning:
result_zero = tool._get_log_contents(line_count=0)
self.assertEqual(result_zero, "line1\nline2\n")
mock_log_warning.assert_any_call("Invalid line_count '0' provided, defaulting to fetch last 150 lines.")
mock_file_open.reset_mock() # Reset for next call
result_neg = tool._get_log_contents(line_count=-5)
self.assertEqual(result_neg, "line1\nline2\n")
mock_log_warning.assert_any_call("Invalid line_count '-5' provided, defaulting to fetch last 150 lines.")
@patch("os.path.exists", return_value=True)
def test_get_log_contents_last_24_hours(self, mock_exists):
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
now = datetime.now()
one_hour_ago_dt = now - timedelta(hours=1)
two_days_ago_dt = now - timedelta(days=2)
one_hour_ago_str = one_hour_ago_dt.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT)
two_days_ago_str = two_days_ago_dt.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT)
log_data = (
f"{two_days_ago_str} - OLD - This is an old log entry.\n"
f"No timestamp here - this line should be skipped by time filter.\n"
f"{one_hour_ago_str} - RECENT - This is a recent log entry.\n"
f"Malformed Date 2023-xx-01 - Another skipped line.\n"
f"{now.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT)} - CURRENT - This is a current log entry.\n"
)
expected_output = (
f"{one_hour_ago_str} - RECENT - This is a recent log entry.\n"
f"{now.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT)} - CURRENT - This is a current log entry.\n"
)
with patch("builtins.open", mock_open(read_data=log_data)):
result = tool._get_log_contents(line_count=None) # Trigger 24-hour logic
self.assertEqual(result, expected_output)
@patch("os.path.exists", return_value=True)
@patch("builtins.open", side_effect=IOError("File read error!"))
def test_get_log_contents_file_read_exception(self, mock_file_open, mock_exists):
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
result = tool._get_log_contents(line_count=10)
self.assertIn("An error occurred while reading the log file: File read error!", result)
def test_execute_get_log_contents(self):
tool = LogTool(logger=self.logger)
mock_return_value = "Mocked log content"
with patch.object(tool, '_get_log_contents', return_value=mock_return_value) as mock_method:
result = tool.execute(function_name="get_log_contents", line_count=50)
mock_method.assert_called_once_with(line_count=50)
self.assertEqual(result, mock_return_value)
def test_execute_get_log_contents_no_line_count(self):
tool = LogTool(logger=self.logger)
mock_return_value = "Mocked log content for 24h"
with patch.object(tool, '_get_log_contents', return_value=mock_return_value) as mock_method:
result = tool.execute(function_name="get_log_contents") # No line_count
mock_method.assert_called_once_with(line_count=None) # Expects None to trigger 24h
self.assertEqual(result, mock_return_value)
def test_execute_unknown_function(self):
tool = LogTool(logger=self.logger)
result = tool.execute(function_name="non_existent_log_function")
self.assertIn("Unknown function: non_existent_log_function", result)
def test_clear_method(self):
tool = LogTool(logger=self.logger)
# Set a specific level for the logger for this test if needed to capture debug
original_level = tool.logger.level
tool.logger.setLevel(logging.DEBUG)
with self.assertLogs(tool.logger, level='DEBUG') as cm:
tool.clear()
self.assertTrue(any("LogTool clear called" in message for message in cm.output))
tool.logger.setLevel(original_level) # Reset level
if __name__ == '__main__':
unittest.main()
-217
View File
@@ -1,217 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock, ANY
import time
import logging
# Ensure tools.metrics is accessible
from tools.metrics import Metrics # Import the class itself for direct testing
from tools.metrics import metrics as global_metrics_instance # Import the global instance
# A simple function to decorate for testing
def sample_function_for_metrics(duration=0.01):
# Simulate some work
# Note: time.sleep is not always precisely profiled by cProfile in the same way as CPU-bound work.
# For testing, we will mock the cProfile/pstats interaction rather than relying on actual sleep duration.
if duration > 0: # Make it conditional so we can test zero-time case too
pass # The actual work is not important when mocking cProfile results
return "sample_output"
def another_sample_function(x, y):
return x + y
class TestMetrics(unittest.TestCase):
def setUp(self):
# Create a fresh Metrics instance for most tests to avoid interference
self.logger = logging.getLogger('tools.metrics.test')
if not self.logger.handlers: # Avoid adding handler multiple times
self.logger.addHandler(logging.NullHandler())
self.metrics_instance = Metrics(logger=self.logger)
# Clear the global instance before each test that might use it
global_metrics_instance.clear_metrics()
def test_measure_decorator_counts_calls(self):
decorated_func = self.metrics_instance.measure(sample_function_for_metrics)
self.assertEqual(self.metrics_instance.call_count["sample_function_for_metrics"], 0)
decorated_func()
self.assertEqual(self.metrics_instance.call_count["sample_function_for_metrics"], 1)
decorated_func()
self.assertEqual(self.metrics_instance.call_count["sample_function_for_metrics"], 2)
@patch('cProfile.Profile')
@patch('pstats.Stats')
def test_measure_decorator_records_time(self, MockPStats, MockCProfile):
# Mock cProfile and pstats to control the time value
mock_profiler_instance = MockCProfile.return_value
mock_pstats_instance = MockPStats.return_value
# Simulate that pstats.Stats.stats dictionary contains the function's stats
# Key: (filename, lineno, funcname)
# Value: (cc, nc, tt, ct, callers) where ct is cumulative_time (index 3)
# Get code object of the function *before* decoration for correct key
original_func_code = sample_function_for_metrics.__code__
func_key = (original_func_code.co_filename, original_func_code.co_firstlineno, original_func_code.co_name)
# Configure mock_pstats_instance.stats to return our desired time
mock_pstats_instance.stats = {func_key: (1, 1, 0.05, 0.123, {})} # cc, nc, tt, ct=0.123
decorated_func = self.metrics_instance.measure(sample_function_for_metrics)
self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0)
# Call the decorated function
decorated_func(duration=0) # Duration arg doesn't matter due to mocking
# Assertions
mock_profiler_instance.enable.assert_called_once()
mock_profiler_instance.disable.assert_called_once()
MockPStats.assert_called_once_with(mock_profiler_instance)
self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0.123)
# Call again to see accumulation
# Reset mock stats for a new time value if needed, or assume same time per call
mock_pstats_instance.stats = {func_key: (1, 1, 0.05, 0.100, {})} # New ct=0.100
decorated_func(duration=0)
self.assertAlmostEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0.123 + 0.100)
@patch('cProfile.Profile')
@patch('pstats.Stats')
def test_measure_decorator_fallback_time_recording_by_name(self, MockPStats, MockCProfile):
mock_profiler_instance = MockCProfile.return_value
mock_pstats_instance = MockPStats.return_value
original_func_code = sample_function_for_metrics.__code__ # func to be decorated
# Simulate the primary key lookup fails by creating a slightly different key for what we expect
# This is what the code will try to look up first.
expected_primary_key = (original_func_code.co_filename, original_func_code.co_firstlineno, original_func_code.co_name)
# This is the key that will *actually* be in pstats.stats, simulating a mismatch for primary lookup
# but a match for the by-name fallback.
actual_stats_key_in_pstats = (original_func_code.co_filename,
original_func_code.co_firstlineno + 5, # simulate a lineno difference for primary key mismatch
original_func_code.co_name) # Name is the same for fallback
mock_pstats_instance.stats = {
# expected_primary_key is NOT present
actual_stats_key_in_pstats: (1, 1, 0.03, 0.077, {}) # ct = 0.077
}
decorated_func = self.metrics_instance.measure(sample_function_for_metrics)
# Expecting a debug log for fallback, but assertLogs needs the logger to have a handler that captures.
# self.logger is already set up with NullHandler. For this test, let's use a specific logger.
metrics_internal_logger = logging.getLogger('tools.metrics') # Logger used inside Metrics class
original_level = metrics_internal_logger.level
metrics_internal_logger.setLevel(logging.DEBUG)
with self.assertLogs(metrics_internal_logger, level='DEBUG') as log_capture:
decorated_func(duration=0)
metrics_internal_logger.setLevel(original_level) # Reset logger level
self.assertTrue(any("Found stats for sample_function_for_metrics by name" in msg for msg in log_capture.output))
self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0.077)
@patch('cProfile.Profile')
@patch('pstats.Stats')
def test_measure_decorator_handles_func_stats_not_found(self, MockPStats, MockCProfile):
mock_profiler_instance = MockCProfile.return_value
mock_pstats_instance = MockPStats.return_value
mock_pstats_instance.stats = {} # Empty stats, function will not be found
decorated_func = self.metrics_instance.measure(sample_function_for_metrics)
metrics_internal_logger = logging.getLogger('tools.metrics')
original_level = metrics_internal_logger.level
metrics_internal_logger.setLevel(logging.WARNING)
with self.assertLogs(metrics_internal_logger, level='WARNING') as log_capture:
decorated_func(duration=0)
metrics_internal_logger.setLevel(original_level)
self.assertTrue(any("Could not find exact cProfile stats" in msg for msg in log_capture.output))
self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0)
def test_get_metrics_empty(self):
self.assertEqual(self.metrics_instance.get_metrics(), {})
@patch('cProfile.Profile')
@patch('pstats.Stats')
def test_get_metrics_with_data(self, MockPStats, MockCProfile):
mock_pstats_instance = MockPStats.return_value
# Decorate two different functions
decorated_func1 = self.metrics_instance.measure(sample_function_for_metrics)
decorated_func2 = self.metrics_instance.measure(another_sample_function)
# Data for func1
func1_code = sample_function_for_metrics.__code__
func1_key = (func1_code.co_filename, func1_code.co_firstlineno, func1_code.co_name)
mock_pstats_instance.stats = {func1_key: (1,1,0.1,0.1,{})}
decorated_func1()
# Data for func2
func2_code = another_sample_function.__code__
func2_key = (func2_code.co_filename, func2_code.co_firstlineno, func2_code.co_name)
mock_pstats_instance.stats = {func2_key: (1,1,0.2,0.2,{})} # Cumulative time 0.2
decorated_func2(1,2)
mock_pstats_instance.stats = {func2_key: (1,1,0.3,0.3,{})} # Cumulative time 0.3 for second call
decorated_func2(3,4)
metrics_data = self.metrics_instance.get_metrics()
self.assertIn("sample_function_for_metrics", metrics_data)
self.assertEqual(metrics_data["sample_function_for_metrics"]["call_count"], 1)
self.assertEqual(metrics_data["sample_function_for_metrics"]["total_time"], 0.1)
self.assertEqual(metrics_data["sample_function_for_metrics"]["average_time"], 0.1)
self.assertIn("another_sample_function", metrics_data)
self.assertEqual(metrics_data["another_sample_function"]["call_count"], 2)
self.assertAlmostEqual(metrics_data["another_sample_function"]["total_time"], 0.5)
self.assertAlmostEqual(metrics_data["another_sample_function"]["average_time"], 0.25)
def test_clear_metrics(self):
# Add some data
self.metrics_instance.call_count["test_func"] = 5
self.metrics_instance.total_time["test_func"] = 1.234
self.metrics_instance.clear_metrics()
self.assertEqual(self.metrics_instance.call_count, {})
self.assertEqual(self.metrics_instance.total_time, {})
self.assertEqual(self.metrics_instance.get_metrics(), {})
# Test the global instance
@patch('cProfile.Profile')
@patch('pstats.Stats')
def test_global_metrics_instance_usage(self, MockPStats, MockCProfile):
mock_pstats_instance = MockPStats.return_value
# Decorate a function with the global instance
@global_metrics_instance.measure
def globally_decorated_func():
return "global_output"
# Setup mock stats for the globally decorated function
# Access __wrapped__ to get the original function if other decorators might be present or for consistency.
original_g_func = globally_decorated_func.__wrapped__
func_code = original_g_func.__code__
func_key = (func_code.co_filename, func_code.co_firstlineno, func_code.co_name)
mock_pstats_instance.stats = {func_key: (1,1,0.05,0.05,{})}
globally_decorated_func()
metrics_data = global_metrics_instance.get_metrics()
self.assertIn("globally_decorated_func", metrics_data)
self.assertEqual(metrics_data["globally_decorated_func"]["call_count"], 1)
self.assertEqual(metrics_data["globally_decorated_func"]["total_time"], 0.05)
if __name__ == '__main__':
unittest.main()
-161
View File
@@ -1,161 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch
import logging
# Ensure tools.metrics_tool and tools.metrics are accessible
from tools.metrics_tool import MetricsTool
from tools.metrics import Metrics # Used for typehinting and creating a mockable instance
class TestMetricsTool(unittest.TestCase):
def setUp(self):
self.mock_metrics_provider = MagicMock(spec=Metrics)
self.logger = logging.getLogger('tools.metrics_tool.test')
if not self.logger.handlers:
self.logger.addHandler(logging.NullHandler())
self.logger.propagate = False
def test_init_with_provider(self):
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
self.assertEqual(tool.metrics_provider, self.mock_metrics_provider)
@patch('tools.metrics_tool.global_metrics_instance') # Patch the global instance path
def test_init_default_provider(self, mock_global_metrics):
tool = MetricsTool(logger=self.logger)
self.assertEqual(tool.metrics_provider, mock_global_metrics)
def test_get_functions(self):
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
functions = tool.get_functions()
self.assertIsInstance(functions, list)
self.assertTrue(len(functions) == 3) # Based on current definition
self.assertIn("get_function_metrics", [f["function"]["name"] for f in functions])
self.assertIn("get_specific_function_metrics", [f["function"]["name"] for f in functions])
self.assertIn("get_top_n_functions", [f["function"]["name"] for f in functions])
def test_execute_get_function_metrics(self):
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
expected_metrics = {"func1": {"call_count": 1, "total_time": 0.1}}
self.mock_metrics_provider.get_metrics.return_value = expected_metrics
result = tool.execute(function_name="get_function_metrics")
self.mock_metrics_provider.get_metrics.assert_called_once()
self.assertEqual(result, expected_metrics)
def test_execute_get_specific_function_metrics_found(self):
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
func_metrics = {"call_count": 5, "total_time": 0.5, "average_time": 0.1}
all_metrics = {"specific_func": func_metrics, "other_func": {}}
self.mock_metrics_provider.get_metrics.return_value = all_metrics
# The execute method expects kwargs that match the function parameters in get_functions.
# So, the argument name for the function to get is 'function_name' in the tool's spec.
result = tool.execute(function_name="get_specific_function_metrics", **{"function_name": "specific_func"})
self.assertEqual(result, func_metrics)
def test_execute_get_specific_function_metrics_not_found(self):
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
self.mock_metrics_provider.get_metrics.return_value = {"other_func": {}}
result = tool.execute(function_name="get_specific_function_metrics", **{"function_name": "non_existent_func"})
self.assertEqual(result, "No metrics found for function: non_existent_func")
def test_execute_get_specific_function_metrics_missing_arg(self):
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
result = tool.execute(function_name="get_specific_function_metrics") # Missing function_name kwarg
self.assertIn("Error: Missing required argument 'function_name'", result)
def test_execute_get_top_n_functions(self):
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
metrics_data = {
"func_a": {"call_count": 1, "total_time": 0.3},
"func_b": {"call_count": 1, "total_time": 0.1},
"func_c": {"call_count": 1, "total_time": 0.5},
"func_d": {"call_count": 1, "total_time": 0.2},
}
self.mock_metrics_provider.get_metrics.return_value = metrics_data
# Test getting top 2
result = tool.execute(function_name="get_top_n_functions", n=2)
expected_top_2 = {"func_c": metrics_data["func_c"], "func_a": metrics_data["func_a"]}
self.assertEqual(result, expected_top_2)
# Test getting top 1
result_top_1 = tool.execute(function_name="get_top_n_functions", n=1)
expected_top_1 = {"func_c": metrics_data["func_c"]}
self.assertEqual(result_top_1, expected_top_1)
# Test N larger than available functions
result_top_all = tool.execute(function_name="get_top_n_functions", n=10)
# Order should be func_c, func_a, func_d, func_b
expected_top_all_keys = ["func_c", "func_a", "func_d", "func_b"]
self.assertEqual(list(result_top_all.keys()), expected_top_all_keys)
def test_execute_get_top_n_functions_malformed_metrics(self):
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
metrics_data = {
"func_a": {"call_count": 1, "total_time": 0.3},
"func_b": "not a dict", # Malformed
"func_c": {"call_count": 1}, # Missing total_time
"func_d": {"call_count": 1, "total_time": 0.2},
}
self.mock_metrics_provider.get_metrics.return_value = metrics_data
metrics_tool_logger = logging.getLogger('tools.metrics_tool')
original_level = metrics_tool_logger.level
metrics_tool_logger.setLevel(logging.WARNING)
with self.assertLogs(metrics_tool_logger, level='WARNING') as log_capture:
result = tool.execute(function_name="get_top_n_functions", n=2)
metrics_tool_logger.setLevel(original_level)
# Check that warnings were logged for malformed items
self.assertTrue(any("Metric item for 'func_b' is not in expected format" in msg for msg in log_capture.output))
self.assertTrue(any("Metric item for 'func_c' is not in expected format" in msg for msg in log_capture.output))
# Expected: func_a, func_d (as they are valid and sortable)
expected_result = {
"func_a": metrics_data["func_a"],
"func_d": metrics_data["func_d"]
}
self.assertEqual(result, expected_result)
def test_execute_get_top_n_functions_invalid_n(self):
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
self.mock_metrics_provider.get_metrics.return_value = {} # No metrics needed for this test
result_zero = tool.execute(function_name="get_top_n_functions", n=0)
self.assertIn("Error: Argument 'n' must be a positive integer.", result_zero)
result_negative = tool.execute(function_name="get_top_n_functions", n=-1)
self.assertIn("Error: Argument 'n' must be a positive integer.", result_negative)
result_string = tool.execute(function_name="get_top_n_functions", n="abc")
self.assertIn("Error: Argument 'n' must be an integer.", result_string)
def test_execute_get_top_n_functions_missing_arg_n(self):
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
result = tool.execute(function_name="get_top_n_functions") # Missing n
self.assertIn("Error: Missing required argument 'n'.", result)
def test_execute_unknown_function(self):
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
result = tool.execute(function_name="non_existent_metrics_function")
self.assertIn("Unknown function: non_existent_metrics_function", result)
def test_clear_method(self):
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
metrics_tool_logger = logging.getLogger('tools.metrics_tool')
original_level = metrics_tool_logger.level
metrics_tool_logger.setLevel(logging.DEBUG)
with self.assertLogs(metrics_tool_logger, level='DEBUG') as cm:
tool.clear()
metrics_tool_logger.setLevel(original_level)
self.assertTrue(any("MetricsTool clear method called" in message for message in cm.output))
if __name__ == '__main__':
unittest.main()