Refactored gemini, openai and claude into one file and removed logic from the base class, also made helper class definable from command line
This commit is contained in:
@@ -1,33 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from anthropic_telegram_inference_bot import AnthropicTelegramInferenceBot
|
||||
|
||||
class TestAnthropicTelegramInferenceBot(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.bot = AnthropicTelegramInferenceBot()
|
||||
|
||||
@patch('anthropic_telegram_inference_bot.Anthropic')
|
||||
def test_get_chat_response(self, MockAnthropic):
|
||||
mock_anthropic = MockAnthropic.return_value
|
||||
mock_anthropic.messages.create.return_value = MagicMock()
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
response = self.bot.get_chat_response(messages)
|
||||
|
||||
self.assertIsNotNone(response)
|
||||
|
||||
@patch('anthropic_telegram_inference_bot.Anthropic')
|
||||
def test_handle_message(self, MockAnthropic):
|
||||
mock_anthropic = MockAnthropic.return_value
|
||||
mock_anthropic.messages.create.return_value = MagicMock(content=[MagicMock(type="message", text="response content")])
|
||||
|
||||
user_id = "user123"
|
||||
user_message = "Hello"
|
||||
response = self.bot.handle_message(user_id, user_message)
|
||||
|
||||
self.assertIsNotNone(response)
|
||||
|
||||
# Additional testing for error cases and edge cases
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,33 +0,0 @@
|
||||
import unittest
|
||||
from base_telegram_inference_bot import BaseTelegramInferenceBot
|
||||
|
||||
class TestBaseTelegramInferenceBot(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# Initialize the bot or mock any dependencies here
|
||||
self.bot = BaseTelegramInferenceBot()
|
||||
|
||||
def test_load_system_prompt(self):
|
||||
# Example test case for load_system_prompt method
|
||||
result = self.bot.load_system_prompt()
|
||||
self.assertIsNotNone(result) # Replace with actual expected result
|
||||
|
||||
def test_load_functions(self):
|
||||
# Test the load_functions method
|
||||
functions = self.bot.load_functions()
|
||||
self.assertIsInstance(functions, list) # Replace with actual expected result
|
||||
self.assertTrue(len(functions) > 0) # Assuming it should load some functions
|
||||
|
||||
def test_clear_conversation(self):
|
||||
# Test the clear_conversation method
|
||||
self.bot.clear_conversation()
|
||||
self.assertEqual(self.bot.conversations, {}) # Assuming conversations is a dictionary
|
||||
|
||||
def test_call_tool(self):
|
||||
# Test the call_tool method
|
||||
tool_name = "some_tool"
|
||||
params = {"param1": "value1"}
|
||||
result = self.bot.call_tool(tool_name, params)
|
||||
self.assertIsNotNone(result) # Replace with actual expected result
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,38 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from chatgpt_telegram_inference_bot import ChatGPTTelegramInferenceBot
|
||||
|
||||
class TestChatGPTTelegramInferenceBot(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.bot = ChatGPTTelegramInferenceBot()
|
||||
|
||||
@patch('chatgpt_telegram_inference_bot.OpenAI')
|
||||
def test_get_chat_response(self, MockOpenAI):
|
||||
mock_ai = MockOpenAI.return_value
|
||||
mock_ai.chat.completions.create.return_value = MagicMock()
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
response = self.bot.get_chat_response(messages)
|
||||
|
||||
self.assertIsNotNone(response)
|
||||
|
||||
@patch('chatgpt_telegram_inference_bot.OpenAI')
|
||||
def test_handle_message(self, MockOpenAI):
|
||||
mock_ai = MockOpenAI.return_value
|
||||
mock_ai.chat.completions.create.return_value = MagicMock(choices=[MagicMock(message={"content": "response content"}, finish_reason='stop')])
|
||||
|
||||
user_id = "user123"
|
||||
user_message = "Hello"
|
||||
response = self.bot.handle_message(user_id, user_message)
|
||||
|
||||
self.assertIsNotNone(response)
|
||||
|
||||
def test_switch_model(self):
|
||||
initial_model = self.bot.model
|
||||
self.bot.switch_model()
|
||||
self.assertNotEqual(initial_model, self.bot.model)
|
||||
|
||||
# Additional testing for error cases and edge cases
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,280 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock, ANY
|
||||
import os
|
||||
|
||||
# Assuming anthropic_telegram_inference_bot.py is in the parent directory or PYTHONPATH is set
|
||||
from anthropic_telegram_inference_bot import AnthropicTelegramInferenceBot
|
||||
|
||||
# Mock response from Anthropic client's messages.create
|
||||
def create_mock_anthropic_response(content_text=None, stop_reason="end_turn", tool_use_parts=None):
|
||||
mock_response = MagicMock()
|
||||
mock_response.stop_reason = stop_reason
|
||||
|
||||
content_blocks = []
|
||||
if content_text:
|
||||
text_block = MagicMock()
|
||||
text_block.type = "text"
|
||||
text_block.text = content_text
|
||||
content_blocks.append(text_block)
|
||||
|
||||
if tool_use_parts:
|
||||
for tu_part in tool_use_parts: # tu_part = {"id": "toolu_123", "name": "get_weather", "input": {}}
|
||||
tool_block = MagicMock()
|
||||
tool_block.type = "tool_use"
|
||||
tool_block.id = tu_part["id"]
|
||||
tool_block.name = tu_part["name"]
|
||||
tool_block.input = tu_part["input"]
|
||||
content_blocks.append(tool_block)
|
||||
|
||||
mock_response.content = content_blocks
|
||||
return mock_response
|
||||
|
||||
class TestAnthropicTelegramInferenceBot(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.original_anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY")
|
||||
self.original_small_model = os.environ.get("ANTHROPIC_SMALL_MODEL")
|
||||
self.original_large_model = os.environ.get("ANTHROPIC_LARGE_MODEL")
|
||||
self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH")
|
||||
|
||||
for key in ["ANTHROPIC_API_KEY", "ANTHROPIC_SMALL_MODEL", "ANTHROPIC_LARGE_MODEL", "SYSTEM_PROMPT_PATH"]:
|
||||
if os.environ.get(key):
|
||||
del os.environ[key]
|
||||
|
||||
self.mock_anthropic_client_instance = MagicMock()
|
||||
self.mock_anthropic_client_instance.messages.create = MagicMock()
|
||||
|
||||
def tearDown(self):
|
||||
if self.original_anthropic_api_key: os.environ["ANTHROPIC_API_KEY"] = self.original_anthropic_api_key
|
||||
if self.original_small_model: os.environ["ANTHROPIC_SMALL_MODEL"] = self.original_small_model
|
||||
if self.original_large_model: os.environ["ANTHROPIC_LARGE_MODEL"] = self.original_large_model
|
||||
if self.original_system_prompt_path: os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path
|
||||
|
||||
@patch('anthropic.Anthropic')
|
||||
def test_init_with_anthropic_defaults_env_key(self, MockAnthropicConstructor):
|
||||
MockAnthropicConstructor.return_value = self.mock_anthropic_client_instance
|
||||
os.environ["ANTHROPIC_API_KEY"] = "test_anthropic_key"
|
||||
|
||||
bot = AnthropicTelegramInferenceBot()
|
||||
|
||||
MockAnthropicConstructor.assert_called_once_with(api_key="test_anthropic_key")
|
||||
self.assertEqual(bot.anthropic_client, self.mock_anthropic_client_instance)
|
||||
self.assertEqual(bot.model, os.environ.get("ANTHROPIC_SMALL_MODEL", "claude-3-haiku-20240307"))
|
||||
self.assertEqual(bot.max_tokens, int(os.environ.get("ANTHROPIC_SMALL_MODEL_MAX_TOKENS", 2000)))
|
||||
|
||||
@patch('anthropic.Anthropic')
|
||||
def test_init_with_provided_client_and_models(self, MockAnthropicConstructor):
|
||||
preconfigured_client = MagicMock()
|
||||
bot = AnthropicTelegramInferenceBot(
|
||||
anthropic_client=preconfigured_client,
|
||||
small_model_name="custom-small",
|
||||
small_model_max_tokens=100,
|
||||
large_model_name="custom-large",
|
||||
large_model_max_tokens=200
|
||||
)
|
||||
|
||||
MockAnthropicConstructor.assert_not_called()
|
||||
self.assertEqual(bot.anthropic_client, preconfigured_client)
|
||||
self.assertEqual(bot.model, "custom-small")
|
||||
self.assertEqual(bot.max_tokens, 100)
|
||||
self.assertEqual(bot.small_model_name, "custom-small")
|
||||
self.assertEqual(bot.large_model_name, "custom-large")
|
||||
|
||||
|
||||
def test_get_llm_description(self):
|
||||
bot = AnthropicTelegramInferenceBot(small_model_name="claude-test", small_model_max_tokens=500)
|
||||
self.assertEqual(bot.get_llm_description(), "LLM: claude-test, Max Tokens: 500")
|
||||
|
||||
async def test_switch_model(self):
|
||||
bot = AnthropicTelegramInferenceBot(
|
||||
small_model_name="claude-small", small_model_max_tokens=10,
|
||||
large_model_name="claude-large", large_model_max_tokens=20
|
||||
)
|
||||
self.assertEqual(bot.model, "claude-small")
|
||||
self.assertEqual(bot.max_tokens, 10)
|
||||
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "claude-large")
|
||||
self.assertEqual(bot.max_tokens, 20)
|
||||
self.assertEqual(status, "Switched to model: claude-large")
|
||||
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "claude-small")
|
||||
self.assertEqual(bot.max_tokens, 10)
|
||||
self.assertEqual(status, "Switched to model: claude-small")
|
||||
|
||||
def test_get_chat_response_success_text_only(self):
|
||||
bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
|
||||
bot.model = "test-claude"
|
||||
bot.max_tokens = 150
|
||||
|
||||
mock_api_response = create_mock_anthropic_response(content_text="Hello from Anthropic API")
|
||||
self.mock_anthropic_client_instance.messages.create.return_value = mock_api_response
|
||||
|
||||
messages = [{"role": "user", "content": "Hi"}] # Anthropic format
|
||||
response = bot.get_chat_response(messages, []) # tools = empty list
|
||||
|
||||
self.mock_anthropic_client_instance.messages.create.assert_called_once_with(
|
||||
model="test-claude",
|
||||
max_tokens=150,
|
||||
messages=messages,
|
||||
system=bot.system_prompt, # Ensure system prompt is passed
|
||||
tools=None, # No tools passed to API if empty list or None
|
||||
tool_choice=None
|
||||
)
|
||||
self.assertEqual(response, mock_api_response)
|
||||
|
||||
def test_get_chat_response_with_tools(self):
|
||||
bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
|
||||
bot.model = "claude-toolmaster"
|
||||
bot.max_tokens = 300
|
||||
|
||||
mock_tools_spec = [{"name": "get_weather", "description": "Gets weather", "input_schema": {"type": "object", "properties": {}}}]
|
||||
|
||||
mock_api_response = create_mock_anthropic_response(content_text="Thinking...", tool_use_parts=[
|
||||
{"id": "tool1", "name": "get_weather", "input": {"location": "here"}}
|
||||
])
|
||||
self.mock_anthropic_client_instance.messages.create.return_value = mock_api_response
|
||||
|
||||
messages = [{"role": "user", "content": "Weather?"}]
|
||||
response = bot.get_chat_response(messages, mock_tools_spec)
|
||||
|
||||
self.mock_anthropic_client_instance.messages.create.assert_called_once_with(
|
||||
model="claude-toolmaster",
|
||||
max_tokens=300,
|
||||
messages=messages,
|
||||
system=bot.system_prompt,
|
||||
tools=mock_tools_spec,
|
||||
tool_choice={"type": "auto"}
|
||||
)
|
||||
self.assertEqual(response.content[0].type, "text") # First part can be text
|
||||
self.assertEqual(response.content[1].type, "tool_use")
|
||||
|
||||
|
||||
def test_get_chat_response_api_error(self):
|
||||
bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
|
||||
self.mock_anthropic_client_instance.messages.create.side_effect = Exception("Anthropic API Down")
|
||||
|
||||
with self.assertRaisesRegex(Exception, "Anthropic API Down"):
|
||||
bot.get_chat_response([{"role": "user", "content": "trigger"}], [])
|
||||
|
||||
|
||||
async def test_handle_message_simple_response_no_tools(self):
|
||||
# This test is more involved as it touches BaseTelegramInferenceBot's handle_message structure
|
||||
# which then calls the overridden get_chat_response.
|
||||
bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
|
||||
bot.system_prompt = "System prompt for Anthropic"
|
||||
|
||||
# Mock get_chat_response directly to isolate its behavior from full handle_message logic of base
|
||||
# However, the point of this bot is its get_chat_response and subsequent processing.
|
||||
# So, let's mock the API call within get_chat_response.
|
||||
|
||||
api_response = create_mock_anthropic_response(content_text="Anthropic says hello.")
|
||||
self.mock_anthropic_client_instance.messages.create.return_value = api_response
|
||||
|
||||
# Ensure functions are empty for this test, so no tool logic is triggered
|
||||
bot.functions = []
|
||||
bot.tools = []
|
||||
|
||||
response_content = await bot.handle_message(user_id=101, user_message="Hello Anthropic")
|
||||
|
||||
self.assertEqual(response_content, "Anthropic says hello.")
|
||||
self.assertIn(101, bot.conversation_history)
|
||||
# Anthropic's handle_message structure:
|
||||
# 1. User message added to history.
|
||||
# 2. get_chat_response is called.
|
||||
# 3. Response content (text) is extracted.
|
||||
# 4. Assistant text response is added to history.
|
||||
# Expected history: [User, Assistant_Text_Response] (system prompt handled by get_chat_response)
|
||||
# The base class handle_message adds system prompt if not present.
|
||||
# Anthropic handle_message modifies history format before calling get_chat_response.
|
||||
|
||||
# Let's trace Base.handle_message -> Anthropic.handle_message -> Anthropic.get_chat_response
|
||||
# Base.handle_message:
|
||||
# - Adds system prompt to history if first turn: `self.conversation_history[user_id] = [{"role": "system", "content": self.system_prompt}]` (OpenAI style)
|
||||
# - Appends user message: `{"role": "user", "content": user_message}`
|
||||
# - Calls self.get_chat_response(messages, self.functions) -> This is Anthropic's get_chat_response
|
||||
# Anthropic.get_chat_response:
|
||||
# - Takes OpenAI style `messages` and `self.functions` (tool specs).
|
||||
# - Calls `anthropic_client.messages.create` with Anthropic style messages and system prompt.
|
||||
# Anthropic.handle_message (overridden):
|
||||
# - Prepares Anthropic-style messages from conversation_history (which is OpenAI style from Base)
|
||||
# - Calls get_chat_response with these Anthropic messages and self.functions (tool_specs)
|
||||
# - Processes response, extracts text, handles tool calls.
|
||||
# - Appends *user* message (original) and *assistant* text response to self.conversation_history (OpenAI style).
|
||||
|
||||
# For this test, we are calling AnthropicBot.handle_message directly.
|
||||
# 1. `user_id` not in `self.conversation_history`: `system_prompt` not added yet by Base logic.
|
||||
# Anthropic's `handle_message` will create `anthropic_messages` from this.
|
||||
# If `conversation_history` is empty, `anthropic_messages` = `[{"role": "user", "content": user_message}]`
|
||||
# 2. `get_chat_response` called with `anthropic_messages` and `bot.system_prompt` passed to API.
|
||||
# 3. Response "Anthropic says hello."
|
||||
# 4. Original `user_message` and "Anthropic says hello." (as assistant) added to `self.conversation_history`.
|
||||
|
||||
history = bot.conversation_history[101]
|
||||
self.assertEqual(len(history), 2) # User, Assistant
|
||||
self.assertEqual(history[0]["role"], "user")
|
||||
self.assertEqual(history[0]["content"], "Hello Anthropic")
|
||||
self.assertEqual(history[1]["role"], "assistant")
|
||||
self.assertEqual(history[1]["content"], "Anthropic says hello.")
|
||||
|
||||
# Check API call (made by the mocked get_chat_response indirectly)
|
||||
self.mock_anthropic_client_instance.messages.create.assert_called_once()
|
||||
call_args = self.mock_anthropic_client_instance.messages.create.call_args
|
||||
self.assertEqual(call_args.kwargs["system"], "System prompt for Anthropic")
|
||||
# Initial messages for API should just be the user message for first turn
|
||||
self.assertEqual(call_args.kwargs["messages"], [{"role": "user", "content": "Hello Anthropic"}])
|
||||
|
||||
|
||||
async def test_handle_message_with_tool_calls(self):
|
||||
bot = AnthropicTelegramInferenceBot(anthropic_client=self.mock_anthropic_client_instance)
|
||||
bot.system_prompt = "You are a helpful, tool-using assistant."
|
||||
|
||||
# Define a tool for the bot (OpenAI format, will be converted by Anthropic bot for API)
|
||||
mock_tool_oai_format = {"type": "function", "function": {"name": "get_weather", "description": "Get weather", "parameters": {}}}
|
||||
bot.functions = [mock_tool_oai_format] # This is used to generate anthropic_tools for API
|
||||
|
||||
# API Response 1: Request for tool call
|
||||
tool_use_part = {"id": "toolu_xyz", "name": "get_weather", "input": {"location": "paris"}}
|
||||
api_response_1 = create_mock_anthropic_response(tool_use_parts=[tool_use_part])
|
||||
|
||||
# API Response 2: Final text response after tool execution
|
||||
api_response_2 = create_mock_anthropic_response(content_text="The weather in Paris is nice.")
|
||||
|
||||
self.mock_anthropic_client_instance.messages.create.side_effect = [api_response_1, api_response_2]
|
||||
|
||||
# Mock the bot's call_tool method (from BaseTelegramInferenceBot)
|
||||
bot.call_tool = MagicMock(return_value='''{"weather": "sunny"}''') # Tool execution result
|
||||
|
||||
user_id = 102
|
||||
user_message = "What's the weather in Paris?"
|
||||
final_text_response = await bot.handle_message(user_id, user_message)
|
||||
|
||||
self.assertEqual(final_text_response, "The weather in Paris is nice.")
|
||||
self.assertEqual(self.mock_anthropic_client_instance.messages.create.call_count, 2)
|
||||
|
||||
bot.call_tool.assert_called_once_with("get_weather", {"location": "paris"}) # Anthropic passes input as dict
|
||||
|
||||
# Check conversation history (OpenAI style)
|
||||
history = bot.conversation_history[user_id]
|
||||
self.assertEqual(history[0]["role"], "user")
|
||||
self.assertEqual(history[0]["content"], user_message)
|
||||
|
||||
# Assistant message that requested tool call (Anthropic-specific format stored by its handle_message)
|
||||
# Anthropic's handle_message appends the raw tool_use block and then the tool_result
|
||||
self.assertEqual(history[1]["role"], "assistant")
|
||||
self.assertTrue(isinstance(history[1]["content"], list)) # Anthropic content is a list
|
||||
self.assertEqual(history[1]["content"][0]["type"], "tool_use")
|
||||
self.assertEqual(history[1]["content"][0]["id"], "toolu_xyz")
|
||||
|
||||
self.assertEqual(history[2]["role"], "tool")
|
||||
self.assertEqual(history[2]["tool_call_id"], "toolu_xyz")
|
||||
self.assertEqual(history[2]["name"], "get_weather")
|
||||
self.assertEqual(history[2]["content"], '''{"weather": "sunny"}''') # call_tool result
|
||||
|
||||
self.assertEqual(history[3]["role"], "assistant") # Final text response
|
||||
self.assertTrue(isinstance(history[3]["content"], str)) # simple text
|
||||
self.assertEqual(history[3]["content"], "The weather in Paris is nice.")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,310 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import patch, mock_open, MagicMock
|
||||
import os
|
||||
import json
|
||||
|
||||
# Ensure the path includes the directory where base_telegram_inference_bot is located
|
||||
# This might require adjustment based on actual project structure if tests are run from root
|
||||
# For now, assuming it can be imported directly or via PYTHONPATH
|
||||
from base_telegram_inference_bot import BaseTelegramInferenceBot
|
||||
from tools.base_tool import BaseTool # For mocking tool structure
|
||||
|
||||
# Create a concrete subclass for testing, as BaseTelegramInferenceBot is abstract
|
||||
class ConcreteTestBot(BaseTelegramInferenceBot):
|
||||
def __init__(self, system_prompt_content=None, system_prompt_path=None, mock_tools=None, mock_functions=None):
|
||||
# Mock load_functions during super().__init__ if needed, or control tools/functions directly
|
||||
self._mock_tools = mock_tools if mock_tools is not None else []
|
||||
self._mock_functions = mock_functions if mock_functions is not None else []
|
||||
super().__init__(system_prompt_content=system_prompt_content, system_prompt_path=system_prompt_path)
|
||||
|
||||
# Override load_functions to use mocks
|
||||
def load_functions(self):
|
||||
return self._mock_tools, self._mock_functions
|
||||
|
||||
def get_chat_response(self, messages):
|
||||
pass # Abstract method, not tested here directly
|
||||
|
||||
async def handle_message(self, user_id, user_message):
|
||||
pass # Abstract method
|
||||
|
||||
def get_llm_description(self) -> str:
|
||||
return "Mock LLM Description" # Concrete implementation for testing get_bot_status
|
||||
|
||||
async def start(self):
|
||||
pass # Abstract method
|
||||
|
||||
async def abort_processing(self, user_id):
|
||||
pass # Abstract method
|
||||
|
||||
async def switch_model(self):
|
||||
pass # Abstract method
|
||||
|
||||
class TestBaseTelegramInferenceBot(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Reset relevant environment variables before each test
|
||||
self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH")
|
||||
if "SYSTEM_PROMPT_PATH" in os.environ:
|
||||
del os.environ["SYSTEM_PROMPT_PATH"]
|
||||
|
||||
def tearDown(self):
|
||||
# Restore environment variables
|
||||
if self.original_system_prompt_path:
|
||||
os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path
|
||||
elif "SYSTEM_PROMPT_PATH" in os.environ: # Ensure it's removed if test set it and it wasn't there before
|
||||
del os.environ["SYSTEM_PROMPT_PATH"]
|
||||
|
||||
def test_init_with_direct_system_prompt(self):
|
||||
bot = ConcreteTestBot(system_prompt_content="Direct prompt content")
|
||||
self.assertEqual(bot.system_prompt, "Direct prompt content")
|
||||
|
||||
@patch("os.path.isfile")
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="File prompt content")
|
||||
def test_init_with_system_prompt_path_argument(self, mock_file_open, mock_isfile):
|
||||
mock_isfile.return_value = True
|
||||
bot = ConcreteTestBot(system_prompt_path="dummy/path.txt")
|
||||
self.assertEqual(bot.system_prompt, "File prompt content")
|
||||
mock_isfile.assert_called_once_with("dummy/path.txt")
|
||||
mock_file_open.assert_called_once_with("dummy/path.txt", "r", encoding="utf-8")
|
||||
|
||||
@patch("os.path.isfile")
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="Env prompt content")
|
||||
def test_init_with_env_system_prompt_path(self, mock_file_open, mock_isfile):
|
||||
mock_isfile.return_value = True
|
||||
os.environ["SYSTEM_PROMPT_PATH"] = "env/path.txt"
|
||||
bot = ConcreteTestBot()
|
||||
self.assertEqual(bot.system_prompt, "Env prompt content")
|
||||
mock_isfile.assert_called_once_with("env/path.txt")
|
||||
mock_file_open.assert_called_once_with("env/path.txt", "r", encoding="utf-8")
|
||||
|
||||
def test_init_with_default_system_prompt(self):
|
||||
# Ensure ENV var is not set for this test
|
||||
if "SYSTEM_PROMPT_PATH" in os.environ:
|
||||
del os.environ["SYSTEM_PROMPT_PATH"]
|
||||
bot = ConcreteTestBot()
|
||||
self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.")
|
||||
|
||||
@patch("os.path.isfile", return_value=False)
|
||||
def test_init_with_invalid_system_prompt_path(self, mock_isfile):
|
||||
bot = ConcreteTestBot(system_prompt_path="invalid/path.txt")
|
||||
self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.")
|
||||
mock_isfile.assert_called_once_with("invalid/path.txt")
|
||||
|
||||
@patch("os.path.isfile")
|
||||
@patch("builtins.open", side_effect=IOError("File read error"))
|
||||
def test_init_with_system_prompt_file_read_error(self, mock_file_open, mock_isfile):
|
||||
mock_isfile.return_value = True
|
||||
bot = ConcreteTestBot(system_prompt_path="dummy/path.txt")
|
||||
self.assertEqual(bot.system_prompt, "You are a helpful AI assistant.")
|
||||
|
||||
def test_clear_conversation_history(self):
|
||||
mock_tool_instance = MagicMock(spec=BaseTool)
|
||||
bot = ConcreteTestBot(mock_tools=[mock_tool_instance])
|
||||
bot.conversation_history[123] = [{"role": "user", "content": "Hello"}]
|
||||
|
||||
bot.clear_conversation_history(123)
|
||||
self.assertNotIn(123, bot.conversation_history)
|
||||
mock_tool_instance.clear.assert_called_once()
|
||||
|
||||
def test_clear_conversation_history_user_not_found(self):
|
||||
mock_tool_instance = MagicMock(spec=BaseTool)
|
||||
bot = ConcreteTestBot(mock_tools=[mock_tool_instance])
|
||||
bot.clear_conversation_history(404)
|
||||
self.assertNotIn(404, bot.conversation_history)
|
||||
mock_tool_instance.clear.assert_called_once()
|
||||
|
||||
def test_processing_status(self):
|
||||
bot = ConcreteTestBot()
|
||||
self.assertEqual(bot.processing_status, {})
|
||||
bot.set_processing_status(123, 789)
|
||||
self.assertEqual(bot.processing_status[123], {"processing": True, "message_id": 789})
|
||||
bot.clear_processing_status(123)
|
||||
self.assertNotIn(123, bot.processing_status)
|
||||
|
||||
def test_clear_processing_status_user_not_found(self):
|
||||
bot = ConcreteTestBot()
|
||||
bot.clear_processing_status(404)
|
||||
self.assertNotIn(404, bot.processing_status)
|
||||
|
||||
def test_call_tool_success_dict_args(self):
|
||||
mock_tool = MagicMock(spec=BaseTool)
|
||||
mock_tool.get_functions.return_value = [
|
||||
{"function": {"name": "test_tool", "description": "A test tool", "parameters": {}}}
|
||||
]
|
||||
mock_tool.execute.return_value = "Tool executed successfully"
|
||||
|
||||
bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions())
|
||||
|
||||
result = bot.call_tool("test_tool", {"arg1": "value1"})
|
||||
self.assertEqual(result, "Tool executed successfully")
|
||||
mock_tool.execute.assert_called_once_with("test_tool", arg1="value1")
|
||||
|
||||
def test_call_tool_success_json_string_args(self):
|
||||
mock_tool = MagicMock(spec=BaseTool)
|
||||
mock_tool.get_functions.return_value = [
|
||||
{"function": {"name": "test_tool_json", "parameters": {}}}
|
||||
]
|
||||
mock_tool.execute.return_value = "Tool JSON OK"
|
||||
bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions())
|
||||
|
||||
args_json_str = '''{"param": "value"}'''
|
||||
result = bot.call_tool("test_tool_json", args_json_str)
|
||||
self.assertEqual(result, "Tool JSON OK")
|
||||
mock_tool.execute.assert_called_once_with("test_tool_json", param="value")
|
||||
|
||||
def test_call_tool_malformed_json_string_args(self):
|
||||
bot = ConcreteTestBot(mock_tools=[])
|
||||
args_malformed_json_str = '''{"param": "value"'''
|
||||
result = bot.call_tool("some_tool", args_malformed_json_str)
|
||||
self.assertTrue("Error: Malformed arguments for tool call" in result)
|
||||
|
||||
def test_call_tool_unexpected_arg_type(self):
|
||||
bot = ConcreteTestBot(mock_tools=[])
|
||||
result = bot.call_tool("some_tool", 12345) # Integer instead of dict/str
|
||||
self.assertTrue("Error: Invalid argument type for tool call" in result)
|
||||
|
||||
def test_call_tool_none_args(self):
|
||||
mock_tool = MagicMock(spec=BaseTool)
|
||||
mock_tool.get_functions.return_value = [
|
||||
{"function": {"name": "test_tool_none", "parameters": {}}}
|
||||
]
|
||||
mock_tool.execute.return_value = "Tool None OK"
|
||||
bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions())
|
||||
|
||||
result = bot.call_tool("test_tool_none", None)
|
||||
self.assertEqual(result, "Tool None OK")
|
||||
mock_tool.execute.assert_called_once_with("test_tool_none") # No kwargs if None
|
||||
|
||||
def test_call_tool_not_found(self):
|
||||
bot = ConcreteTestBot(mock_tools=[])
|
||||
result = bot.call_tool("non_existent_tool", {})
|
||||
self.assertEqual(result, "Error: Tool function non_existent_tool not found.")
|
||||
|
||||
def test_call_tool_execute_exception(self):
|
||||
mock_tool = MagicMock(spec=BaseTool)
|
||||
mock_tool.get_functions.return_value = [{"function": {"name": "error_tool", "parameters": {}}}]
|
||||
mock_tool.execute.side_effect = Exception("Execution failed")
|
||||
bot = ConcreteTestBot(mock_tools=[mock_tool], mock_functions=mock_tool.get_functions())
|
||||
|
||||
result = bot.call_tool("error_tool", {})
|
||||
self.assertEqual(result, "Error executing tool error_tool: Execution failed")
|
||||
|
||||
def test_get_system_prompt_description(self):
|
||||
if "SYSTEM_PROMPT_PATH" in os.environ: # Ensure clean state
|
||||
del os.environ["SYSTEM_PROMPT_PATH"]
|
||||
|
||||
bot_default = ConcreteTestBot()
|
||||
self.assertEqual(bot_default.get_system_prompt_description(), "System Prompt: Default")
|
||||
|
||||
bot_custom_content = ConcreteTestBot(system_prompt_content="Custom content here")
|
||||
self.assertEqual(bot_custom_content.get_system_prompt_description(), "System Prompt: Custom")
|
||||
|
||||
os.environ["SYSTEM_PROMPT_PATH"] = "some/path.txt"
|
||||
bot_env_default_prompt = ConcreteTestBot() # system_prompt itself is default
|
||||
self.assertEqual(bot_env_default_prompt.get_system_prompt_description(), "System Prompt: Custom (via ENV)")
|
||||
|
||||
with patch("os.path.isfile", return_value=True), \
|
||||
patch("builtins.open", mock_open(read_data="File prompt from ENV")):
|
||||
bot_env_file_prompt = ConcreteTestBot() # system_prompt gets loaded from ENV path
|
||||
self.assertEqual(bot_env_file_prompt.get_system_prompt_description(), "System Prompt: Custom")
|
||||
del os.environ["SYSTEM_PROMPT_PATH"]
|
||||
|
||||
with patch("os.path.isfile", return_value=True), \
|
||||
patch("builtins.open", mock_open(read_data="File prompt from arg")):
|
||||
bot_custom_file_arg = ConcreteTestBot(system_prompt_path="custom/file.txt")
|
||||
self.assertEqual(bot_custom_file_arg.get_system_prompt_description(), "System Prompt: Custom")
|
||||
|
||||
@patch.object(ConcreteTestBot, 'get_llm_description', return_value="Test LLM Description")
|
||||
@patch.object(ConcreteTestBot, 'get_system_prompt_description', return_value="Test Prompt Description")
|
||||
async def test_get_bot_status(self, mock_prompt_desc, mock_llm_desc):
|
||||
bot = ConcreteTestBot()
|
||||
status = await bot.get_bot_status()
|
||||
self.assertEqual(status, "Test Prompt Description\nTest LLM Description")
|
||||
mock_prompt_desc.assert_called_once()
|
||||
mock_llm_desc.assert_called_once()
|
||||
|
||||
@patch('os.path.dirname', return_value='/mock/path')
|
||||
@patch('os.path.join')
|
||||
@patch('os.path.exists')
|
||||
@patch('os.listdir')
|
||||
@patch('importlib.import_module')
|
||||
def test_load_functions_no_tools_dir(self, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname):
|
||||
mock_join.return_value = '/mock/path/tools'
|
||||
mock_exists.return_value = False
|
||||
|
||||
class BotForLoadTest(BaseTelegramInferenceBot):
|
||||
load_system_prompt = MagicMock(return_value="Default")
|
||||
get_chat_response = MagicMock(); handle_message = MagicMock(); get_llm_description = MagicMock(return_value="mock")
|
||||
start = MagicMock(); abort_processing = MagicMock(); switch_model = MagicMock()
|
||||
|
||||
bot = BotForLoadTest()
|
||||
self.assertEqual(bot.tools, [])
|
||||
self.assertEqual(bot.functions, [])
|
||||
mock_listdir.assert_not_called()
|
||||
|
||||
@patch('os.path.dirname', return_value='/mock/base_bot_dir')
|
||||
@patch('os.path.join', side_effect=lambda *args: os.path.normpath(os.path.join(*args)))
|
||||
@patch('os.path.exists', return_value=True)
|
||||
@patch('os.listdir', return_value=['my_tool.py', '__init__.py', 'base_tool.py'])
|
||||
@patch('importlib.import_module')
|
||||
def test_load_functions_with_one_tool(self, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname):
|
||||
|
||||
mock_tool_class = MagicMock(spec=BaseTool) # This is the class itself
|
||||
mock_tool_instance = MagicMock(spec=BaseTool) # This is the instance
|
||||
mock_tool_class.return_value = mock_tool_instance # mock_tool_class() creates mock_tool_instance
|
||||
mock_tool_instance.get_functions.return_value = [{"function": {"name": "sample_function"}}]
|
||||
|
||||
mock_my_tool_module = MagicMock()
|
||||
# Simulate inspect.getmembers behavior: returns list of (name, member) tuples
|
||||
# Only include members that are classes, derive from BaseTool, and are not BaseTool itself.
|
||||
mock_my_tool_module.ValidToolClass = mock_tool_class
|
||||
mock_my_tool_module.NotATool = object()
|
||||
mock_my_tool_module.BaseTool = BaseTool # This should be skipped by the loader
|
||||
|
||||
def import_side_effect(module_name):
|
||||
if module_name == 'tools.my_tool':
|
||||
return mock_my_tool_module
|
||||
raise ImportError(f"Unexpected import: {module_name}")
|
||||
mock_import_module.side_effect = import_side_effect
|
||||
|
||||
class BotForLoadTest(BaseTelegramInferenceBot):
|
||||
load_system_prompt = MagicMock(return_value="Default")
|
||||
get_chat_response = MagicMock(); handle_message = MagicMock(); get_llm_description = MagicMock(return_value="mock")
|
||||
start = MagicMock(); abort_processing = MagicMock(); switch_model = MagicMock()
|
||||
|
||||
bot = BotForLoadTest()
|
||||
self.assertEqual(len(bot.tools), 1)
|
||||
self.assertIs(bot.tools[0], mock_tool_instance)
|
||||
self.assertEqual(len(bot.functions), 1)
|
||||
self.assertEqual(bot.functions[0]['function']['name'], "sample_function")
|
||||
mock_import_module.assert_called_once_with('tools.my_tool')
|
||||
mock_tool_class.assert_called_once_with() # Tool class was instantiated
|
||||
mock_tool_instance.get_functions.assert_called_once_with()
|
||||
|
||||
@patch('os.path.dirname', return_value='/mock/base_bot_dir')
|
||||
@patch('os.path.join', side_effect=lambda *args: os.path.normpath(os.path.join(*args)))
|
||||
@patch('os.path.exists', return_value=True)
|
||||
@patch('os.listdir', return_value=['tool_with_init_error.py'])
|
||||
@patch('importlib.import_module')
|
||||
@patch('logging.error') # Mock logging to check for error messages
|
||||
def test_load_functions_tool_instantiation_error(self, mock_logging_error, mock_import_module, mock_listdir, mock_exists, mock_join, mock_dirname):
|
||||
mock_tool_class_init_error = MagicMock(spec=BaseTool)
|
||||
mock_tool_class_init_error.side_effect = Exception("Failed to init tool") # Error on instantiation
|
||||
|
||||
mock_error_tool_module = MagicMock()
|
||||
mock_error_tool_module.ToolWithInitError = mock_tool_class_init_error
|
||||
|
||||
mock_import_module.return_value = mock_error_tool_module
|
||||
|
||||
class BotForLoadTest(BaseTelegramInferenceBot):
|
||||
load_system_prompt = MagicMock(return_value="Default")
|
||||
get_chat_response = MagicMock(); handle_message = MagicMock(); get_llm_description = MagicMock(return_value="mock")
|
||||
start = MagicMock(); abort_processing = MagicMock(); switch_model = MagicMock()
|
||||
|
||||
bot = BotForLoadTest()
|
||||
self.assertEqual(len(bot.tools), 0)
|
||||
self.assertEqual(len(bot.functions), 0)
|
||||
mock_logging_error.assert_any_call("Error instantiating tool ToolWithInitError from tool_with_init_error.py: Failed to init tool")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(闂傚лен䦗婢у〃埊鍓解劓姣)
|
||||
@@ -1,158 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch, ANY
|
||||
import os
|
||||
|
||||
# Assuming chatgpt_telegram_inference_bot.py and its parent are accessible
|
||||
from chatgpt_telegram_inference_bot import ChatGPTTelegramInferenceBot
|
||||
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot # For patching super
|
||||
|
||||
class TestChatGPTTelegramInferenceBot(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Store and clear relevant environment variables
|
||||
self.original_openai_key = os.environ.get("OPENAI_API_KEY")
|
||||
self.original_small_model = os.environ.get("OPENAI_SMALL_MODEL")
|
||||
self.original_large_model = os.environ.get("OPENAI_LARGE_MODEL")
|
||||
self.original_small_tokens = os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS")
|
||||
self.original_large_tokens = os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS")
|
||||
self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH")
|
||||
|
||||
for key in ["OPENAI_API_KEY", "OPENAI_SMALL_MODEL", "OPENAI_LARGE_MODEL",
|
||||
"OPENAI_SMALL_MODEL_MAX_TOKENS", "OPENAI_LARGE_MODEL_MAX_TOKENS", "SYSTEM_PROMPT_PATH"]:
|
||||
if os.environ.get(key):
|
||||
del os.environ[key]
|
||||
|
||||
# Mock the OpenAI client that OpenAICompatibleInferenceBot's __init__ might create
|
||||
self.mock_openai_client = MagicMock()
|
||||
|
||||
def tearDown(self):
|
||||
# Restore environment variables
|
||||
if self.original_openai_key: os.environ["OPENAI_API_KEY"] = self.original_openai_key
|
||||
if self.original_small_model: os.environ["OPENAI_SMALL_MODEL"] = self.original_small_model
|
||||
if self.original_large_model: os.environ["OPENAI_LARGE_MODEL"] = self.original_large_model
|
||||
if self.original_small_tokens: os.environ["OPENAI_SMALL_MODEL_MAX_TOKENS"] = self.original_small_tokens
|
||||
if self.original_large_tokens: os.environ["OPENAI_LARGE_MODEL_MAX_TOKENS"] = self.original_large_tokens
|
||||
if self.original_system_prompt_path: os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path
|
||||
|
||||
|
||||
@patch.object(OpenAICompatibleInferenceBot, '__init__') # Mock the superclass's __init__
|
||||
def test_init_defaults_and_super_call(self, mock_super_init):
|
||||
os.environ["OPENAI_API_KEY"] = "test_key_chatgpt"
|
||||
os.environ["OPENAI_SMALL_MODEL"] = "gpt-3.5-turbo-env"
|
||||
os.environ["OPENAI_SMALL_MODEL_MAX_TOKENS"] = "350"
|
||||
|
||||
bot = ChatGPTTelegramInferenceBot()
|
||||
|
||||
mock_super_init.assert_called_once_with(
|
||||
client=None, # ChatGPT bot will let superclass create it
|
||||
api_key="test_key_chatgpt", # Passed to super
|
||||
base_url=None,
|
||||
api_version=None,
|
||||
azure_deployment=None,
|
||||
model_name="gpt-3.5-turbo-env", # Default small model from env
|
||||
max_tokens_str="350", # Default small model tokens from env
|
||||
small_model_name="gpt-3.5-turbo-env",
|
||||
small_model_max_tokens_str="350",
|
||||
large_model_name=os.environ.get("OPENAI_LARGE_MODEL", "gpt-4-turbo-preview"), # Default large
|
||||
large_model_max_tokens_str=os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS"),
|
||||
system_prompt_content=None,
|
||||
system_prompt_path=None,
|
||||
is_gemini=False,
|
||||
max_history_length=20 # Default from OpenAICompatibleInferenceBot
|
||||
)
|
||||
|
||||
@patch.object(OpenAICompatibleInferenceBot, '__init__')
|
||||
def test_init_with_arguments(self, mock_super_init):
|
||||
mock_client_arg = MagicMock()
|
||||
bot = ChatGPTTelegramInferenceBot(
|
||||
openai_client=mock_client_arg,
|
||||
api_key="arg_key",
|
||||
small_model_name="arg_small_model",
|
||||
small_model_max_tokens="123",
|
||||
large_model_name="arg_large_model",
|
||||
large_model_max_tokens="456",
|
||||
system_prompt_content="Arg prompt"
|
||||
)
|
||||
mock_super_init.assert_called_once_with(
|
||||
client=mock_client_arg,
|
||||
api_key="arg_key",
|
||||
base_url=None,
|
||||
api_version=None,
|
||||
azure_deployment=None,
|
||||
model_name="arg_small_model", # Initially configured with small model
|
||||
max_tokens_str="123",
|
||||
small_model_name="arg_small_model",
|
||||
small_model_max_tokens_str="123",
|
||||
large_model_name="arg_large_model",
|
||||
large_model_max_tokens_str="456",
|
||||
system_prompt_content="Arg prompt",
|
||||
system_prompt_path=None,
|
||||
is_gemini=False,
|
||||
max_history_length=20
|
||||
)
|
||||
|
||||
# Test switch_model - this method is part of ChatGPTTelegramInferenceBot
|
||||
# It calls _configure_model_and_tokens which is in the superclass.
|
||||
# We need a bot instance where _configure_model_and_tokens can be called.
|
||||
@patch('openai.OpenAI') # To allow instantiation of the bot by mocking client creation
|
||||
async def test_switch_model_logic(self, mock_openai_constructor):
|
||||
mock_openai_constructor.return_value = self.mock_openai_client # Mock client creation in super
|
||||
|
||||
# Set env vars for model names that switch_model will use as fallback
|
||||
os.environ["OPENAI_SMALL_MODEL"] = "env-small-gpt"
|
||||
os.environ["OPENAI_SMALL_MODEL_MAX_TOKENS"] = "100"
|
||||
os.environ["OPENAI_LARGE_MODEL"] = "env-large-gpt"
|
||||
os.environ["OPENAI_LARGE_MODEL_MAX_TOKENS"] = "200"
|
||||
|
||||
# Instantiate with initial model (small)
|
||||
bot = ChatGPTTelegramInferenceBot()
|
||||
self.assertEqual(bot.model, "env-small-gpt")
|
||||
self.assertEqual(bot.max_tokens, 100)
|
||||
|
||||
# Switch to large
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "env-large-gpt")
|
||||
self.assertEqual(bot.max_tokens, 200)
|
||||
self.assertEqual(status, "Switched to model: env-large-gpt")
|
||||
|
||||
# Switch back to small
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "env-small-gpt")
|
||||
self.assertEqual(bot.max_tokens, 100)
|
||||
self.assertEqual(status, "Switched to model: env-small-gpt")
|
||||
|
||||
@patch('openai.OpenAI')
|
||||
async def test_switch_model_uses_instance_configs_if_provided(self, mock_openai_constructor):
|
||||
mock_openai_constructor.return_value = self.mock_openai_client
|
||||
|
||||
# Instantiate with specific model names, overriding potential env vars
|
||||
bot = ChatGPTTelegramInferenceBot(
|
||||
small_model_name="init-small", small_model_max_tokens="50",
|
||||
large_model_name="init-large", large_model_max_tokens="150"
|
||||
)
|
||||
self.assertEqual(bot.model, "init-small") # Starts with small
|
||||
self.assertEqual(bot.max_tokens, 50)
|
||||
|
||||
# Switch to large
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "init-large")
|
||||
self.assertEqual(bot.max_tokens, 150)
|
||||
self.assertEqual(status, "Switched to model: init-large")
|
||||
|
||||
# Switch back to small
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "init-small")
|
||||
self.assertEqual(bot.max_tokens, 50)
|
||||
self.assertEqual(status, "Switched to model: init-small")
|
||||
|
||||
# get_llm_description is inherited from OpenAICompatibleInferenceBot.
|
||||
# Test just to ensure it works in the context of a ChatGPTBot instance
|
||||
@patch('openai.OpenAI')
|
||||
def test_get_llm_description_for_chatgpt_bot(self, mock_openai_constructor):
|
||||
mock_openai_constructor.return_value = self.mock_openai_client
|
||||
bot = ChatGPTTelegramInferenceBot(small_model_name="gpt-3.5-desc", small_model_max_tokens="777")
|
||||
# Initially configured with small model
|
||||
self.assertEqual(bot.get_llm_description(), "LLM: gpt-3.5-desc, Max Tokens: 777, Azure: False")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,154 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch, ANY
|
||||
import os
|
||||
|
||||
# Assuming gemini_telegram_inference_bot.py and its parent are accessible
|
||||
from gemini_telegram_inference_bot import GeminiTelegramInferenceBot
|
||||
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot # For patching super
|
||||
|
||||
class TestGeminiTelegramInferenceBot(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Store and clear relevant environment variables
|
||||
self.original_gemini_key = os.environ.get("GEMINI_API_KEY")
|
||||
self.original_gemini_base_url = os.environ.get("GEMINI_API_BASE_URL")
|
||||
self.original_small_model = os.environ.get("GEMINI_SMALL_MODEL")
|
||||
self.original_large_model = os.environ.get("GEMINI_LARGE_MODEL")
|
||||
self.original_small_tokens = os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS")
|
||||
self.original_large_tokens = os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS")
|
||||
self.original_system_prompt_path = os.environ.get("SYSTEM_PROMPT_PATH")
|
||||
|
||||
for key in ["GEMINI_API_KEY", "GEMINI_API_BASE_URL", "GEMINI_SMALL_MODEL", "GEMINI_LARGE_MODEL",
|
||||
"GEMINI_SMALL_MODEL_MAX_TOKENS", "GEMINI_LARGE_MODEL_MAX_TOKENS", "SYSTEM_PROMPT_PATH"]:
|
||||
if os.environ.get(key):
|
||||
del os.environ[key]
|
||||
|
||||
self.mock_openai_client = MagicMock() # Used if superclass creates an OpenAI client
|
||||
|
||||
def tearDown(self):
|
||||
# Restore environment variables
|
||||
if self.original_gemini_key: os.environ["GEMINI_API_KEY"] = self.original_gemini_key
|
||||
if self.original_gemini_base_url: os.environ["GEMINI_API_BASE_URL"] = self.original_gemini_base_url
|
||||
if self.original_small_model: os.environ["GEMINI_SMALL_MODEL"] = self.original_small_model
|
||||
if self.original_large_model: os.environ["GEMINI_LARGE_MODEL"] = self.original_large_model
|
||||
if self.original_small_tokens: os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = self.original_small_tokens
|
||||
if self.original_large_tokens: os.environ["GEMINI_LARGE_MODEL_MAX_TOKENS"] = self.original_large_tokens
|
||||
if self.original_system_prompt_path: os.environ["SYSTEM_PROMPT_PATH"] = self.original_system_prompt_path
|
||||
|
||||
@patch.object(OpenAICompatibleInferenceBot, '__init__') # Mock the superclass's __init__
|
||||
def test_init_defaults_and_super_call(self, mock_super_init):
|
||||
os.environ["GEMINI_API_KEY"] = "test_key_gemini"
|
||||
os.environ["GEMINI_API_BASE_URL"] = "https://gemini.env.com"
|
||||
os.environ["GEMINI_SMALL_MODEL"] = "gemini-pro-env"
|
||||
os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = "360"
|
||||
|
||||
bot = GeminiTelegramInferenceBot()
|
||||
|
||||
mock_super_init.assert_called_once_with(
|
||||
client=None,
|
||||
api_key="test_key_gemini",
|
||||
base_url="https://gemini.env.com", # Passed to super
|
||||
api_version=None,
|
||||
azure_deployment=None,
|
||||
model_name="gemini-pro-env",
|
||||
max_tokens_str="360",
|
||||
small_model_name="gemini-pro-env",
|
||||
small_model_max_tokens_str="360",
|
||||
large_model_name=os.environ.get("GEMINI_LARGE_MODEL", "gemini-1.5-pro-latest"), # Default large
|
||||
large_model_max_tokens_str=os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS"),
|
||||
system_prompt_content=None,
|
||||
system_prompt_path=None,
|
||||
is_gemini=True, # Important for Gemini bot
|
||||
max_history_length=20
|
||||
)
|
||||
|
||||
@patch.object(OpenAICompatibleInferenceBot, '__init__')
|
||||
def test_init_with_arguments(self, mock_super_init):
|
||||
mock_client_arg = MagicMock()
|
||||
bot = GeminiTelegramInferenceBot(
|
||||
openai_client=mock_client_arg, # Name in Gemini bot is openai_client for consistency
|
||||
api_key="arg_gem_key",
|
||||
base_url="https://arg.gemini.com",
|
||||
small_model_name="arg_gem_small",
|
||||
small_model_max_tokens="124",
|
||||
large_model_name="arg_gem_large",
|
||||
large_model_max_tokens="457",
|
||||
system_prompt_content="Gemini prompt"
|
||||
)
|
||||
mock_super_init.assert_called_once_with(
|
||||
client=mock_client_arg,
|
||||
api_key="arg_gem_key",
|
||||
base_url="https://arg.gemini.com",
|
||||
api_version=None,
|
||||
azure_deployment=None,
|
||||
model_name="arg_gem_small",
|
||||
max_tokens_str="124",
|
||||
small_model_name="arg_gem_small",
|
||||
small_model_max_tokens_str="124",
|
||||
large_model_name="arg_gem_large",
|
||||
large_model_max_tokens_str="457",
|
||||
system_prompt_content="Gemini prompt",
|
||||
system_prompt_path=None,
|
||||
is_gemini=True,
|
||||
max_history_length=20
|
||||
)
|
||||
|
||||
@patch('openai.OpenAI') # Gemini bot uses OpenAI client configured for Gemini endpoint
|
||||
async def test_switch_model_logic(self, mock_openai_constructor):
|
||||
mock_openai_constructor.return_value = self.mock_openai_client
|
||||
|
||||
os.environ["GEMINI_SMALL_MODEL"] = "env-gemini-small"
|
||||
os.environ["GEMINI_SMALL_MODEL_MAX_TOKENS"] = "110"
|
||||
os.environ["GEMINI_LARGE_MODEL"] = "env-gemini-large"
|
||||
os.environ["GEMINI_LARGE_MODEL_MAX_TOKENS"] = "220"
|
||||
|
||||
bot = GeminiTelegramInferenceBot() # Uses env vars by default
|
||||
self.assertEqual(bot.model, "env-gemini-small")
|
||||
self.assertEqual(bot.max_tokens, 110)
|
||||
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "env-gemini-large")
|
||||
self.assertEqual(bot.max_tokens, 220)
|
||||
self.assertEqual(status, "Switched to model: env-gemini-large")
|
||||
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "env-gemini-small")
|
||||
self.assertEqual(bot.max_tokens, 110)
|
||||
self.assertEqual(status, "Switched to model: env-gemini-small")
|
||||
|
||||
@patch('openai.OpenAI')
|
||||
async def test_switch_model_uses_instance_configs_if_provided(self, mock_openai_constructor):
|
||||
mock_openai_constructor.return_value = self.mock_openai_client
|
||||
|
||||
bot = GeminiTelegramInferenceBot(
|
||||
small_model_name="init-gem-small", small_model_max_tokens="55",
|
||||
large_model_name="init-gem-large", large_model_max_tokens="155"
|
||||
)
|
||||
self.assertEqual(bot.model, "init-gem-small")
|
||||
self.assertEqual(bot.max_tokens, 55)
|
||||
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "init-gem-large")
|
||||
self.assertEqual(bot.max_tokens, 155)
|
||||
self.assertEqual(status, "Switched to model: init-gem-large")
|
||||
|
||||
status = await bot.switch_model()
|
||||
self.assertEqual(bot.model, "init-gem-small")
|
||||
self.assertEqual(bot.max_tokens, 55)
|
||||
self.assertEqual(status, "Switched to model: init-gem-small")
|
||||
|
||||
@patch('openai.OpenAI')
|
||||
def test_get_llm_description_for_gemini_bot(self, mock_openai_constructor):
|
||||
mock_openai_constructor.return_value = self.mock_openai_client
|
||||
bot = GeminiTelegramInferenceBot(
|
||||
small_model_name="gemini-pro-desc",
|
||||
small_model_max_tokens="888",
|
||||
# is_gemini is True by default in constructor call to super
|
||||
)
|
||||
# LLM description should indicate not Azure, even though it uses OpenAICompatible... base
|
||||
# The is_gemini flag primarily affects client instantiation logic in the superclass.
|
||||
# The azure_openai flag in superclass is based on azure_endpoint presence.
|
||||
self.assertEqual(bot.get_llm_description(), "LLM: gemini-pro-desc, Max Tokens: 888, Azure: False")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,81 +0,0 @@
|
||||
# tests/test_github_tool.py
|
||||
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from tools.github_tool import GitHubTool
|
||||
|
||||
class TestGitHubTool(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.github_tool = GitHubTool()
|
||||
|
||||
def test_get_functions(self):
|
||||
functions = self.github_tool.get_functions()
|
||||
self.assertEqual(len(functions), 4)
|
||||
function_names = [f["name"] for f in functions]
|
||||
expected_names = ["read_file", "create_branch", "commit_file", "create_pull_request"]
|
||||
self.assertListEqual(function_names, expected_names)
|
||||
|
||||
@patch('tools.github_tool.requests.get')
|
||||
def test_read_file(self, mock_get):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"content": "file content"}
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = self.github_tool.execute("read_file", path="test.txt")
|
||||
self.assertEqual(result, "file content")
|
||||
|
||||
mock_get.assert_called_once()
|
||||
|
||||
@patch('tools.github_tool.requests.get')
|
||||
@patch('tools.github_tool.requests.post')
|
||||
def test_create_branch(self, mock_post, mock_get):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 200
|
||||
mock_get_response.json.return_value = {"object": {"sha": "test_sha"}}
|
||||
mock_get.return_value = mock_get_response
|
||||
|
||||
mock_post_response = MagicMock()
|
||||
mock_post_response.status_code = 201
|
||||
mock_post.return_value = mock_post_response
|
||||
|
||||
result = self.github_tool.execute("create_branch", branch_name="test-branch")
|
||||
self.assertEqual(result, "Branch 'test-branch' created successfully")
|
||||
|
||||
mock_get.assert_called_once()
|
||||
mock_post.assert_called_once()
|
||||
|
||||
@patch('tools.github_tool.requests.put')
|
||||
def test_commit_file(self, mock_put):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_put.return_value = mock_response
|
||||
|
||||
result = self.github_tool.execute("commit_file", branch_name="test-branch", file_path="test.txt", content="test content", commit_message="Test commit")
|
||||
self.assertEqual(result, "File committed successfully to branch 'test-branch'")
|
||||
|
||||
mock_put.assert_called_once()
|
||||
|
||||
def test_commit_file_to_main(self):
|
||||
result = self.github_tool.execute("commit_file", branch_name="main", file_path="test.txt", content="test content", commit_message="Test commit")
|
||||
self.assertEqual(result, "Cannot commit directly to main branch")
|
||||
|
||||
@patch('tools.github_tool.requests.post')
|
||||
def test_create_pull_request(self, mock_post):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 201
|
||||
mock_response.json.return_value = {"html_url": "https://github.com/test/test/pull/1"}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = self.github_tool.execute("create_pull_request", title="Test PR", body="Test body", head="test-branch")
|
||||
self.assertEqual(result, "Pull request created successfully: https://github.com/test/test/pull/1")
|
||||
|
||||
mock_post.assert_called_once()
|
||||
|
||||
def test_unknown_function(self):
|
||||
result = self.github_tool.execute("unknown_function")
|
||||
self.assertEqual(result, "Unknown function: unknown_function")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,332 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock, ANY
|
||||
import os
|
||||
import json
|
||||
|
||||
# Assuming openai_compatible_inference_bot.py is in the parent directory or PYTHONPATH is set
|
||||
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot
|
||||
|
||||
# Mock response from OpenAI client's chat.completions.create
|
||||
def create_mock_openai_response(content=None, tool_calls=None):
|
||||
mock_message = MagicMock()
|
||||
mock_message.role = "assistant"
|
||||
mock_message.content = content
|
||||
if tool_calls:
|
||||
# tool_calls should be a list of objects with id and function (name, arguments)
|
||||
mock_tool_calls = []
|
||||
for tc in tool_calls:
|
||||
mock_tc = MagicMock()
|
||||
mock_tc.id = tc["id"]
|
||||
mock_tc.function.name = tc["function"]["name"]
|
||||
mock_tc.function.arguments = tc["function"]["arguments"]
|
||||
mock_tool_calls.append(mock_tc)
|
||||
mock_message.tool_calls = mock_tool_calls
|
||||
else:
|
||||
mock_message.tool_calls = None
|
||||
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message = mock_message
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [mock_choice]
|
||||
return mock_response
|
||||
|
||||
# Concrete class for testing
|
||||
class ConcreteOpenAICompatibleBot(OpenAICompatibleInferenceBot):
|
||||
# Implement abstract methods for instantiation
|
||||
async def switch_model(self):
|
||||
# Simple switch for testing if needed, or just pass
|
||||
if self.model == self.small_model_name:
|
||||
self._configure_model_and_tokens(self.large_model_name, self.large_model_max_tokens_str)
|
||||
else:
|
||||
self._configure_model_and_tokens(self.small_model_name, self.small_model_max_tokens_str)
|
||||
return f"Switched to {self.model}"
|
||||
|
||||
# Override load_functions if it's called by parent and needs mocking for these tests
|
||||
# (OpenAICompatibleInferenceBot's __init__ calls BaseTelegramInferenceBot's __init__, which calls load_functions)
|
||||
def load_functions(self):
|
||||
# For these tests, assume no tools unless specifically added
|
||||
self.tools = []
|
||||
self.functions = []
|
||||
return self.tools, self.functions
|
||||
|
||||
|
||||
class TestOpenAICompatibleInferenceBot(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.original_openai_api_key = os.environ.get("OPENAI_API_KEY")
|
||||
self.original_azure_openai_key = os.environ.get("AZURE_OPENAI_KEY")
|
||||
self.original_azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
|
||||
self.original_api_version = os.environ.get("AZURE_OPENAI_API_VERSION")
|
||||
self.original_azure_deployment = os.environ.get("AZURE_DEPLOYMENT_NAME")
|
||||
|
||||
# Clear relevant env vars before each test
|
||||
for key in ["OPENAI_API_KEY", "AZURE_OPENAI_KEY", "AZURE_OPENAI_ENDPOINT",
|
||||
"AZURE_OPENAI_API_VERSION", "AZURE_DEPLOYMENT_NAME", "SYSTEM_PROMPT_PATH"]:
|
||||
if os.environ.get(key):
|
||||
del os.environ[key]
|
||||
|
||||
self.mock_openai_client_instance = MagicMock()
|
||||
self.mock_openai_client_instance.chat.completions.create = MagicMock()
|
||||
|
||||
def tearDown(self):
|
||||
# Restore environment variables
|
||||
if self.original_openai_api_key: os.environ["OPENAI_API_KEY"] = self.original_openai_api_key
|
||||
if self.original_azure_openai_key: os.environ["AZURE_OPENAI_KEY"] = self.original_azure_openai_key
|
||||
if self.original_azure_endpoint: os.environ["AZURE_OPENAI_ENDPOINT"] = self.original_azure_endpoint
|
||||
if self.original_api_version: os.environ["AZURE_OPENAI_API_VERSION"] = self.original_api_version
|
||||
if self.original_azure_deployment: os.environ["AZURE_DEPLOYMENT_NAME"] = self.original_azure_deployment
|
||||
|
||||
|
||||
@patch('openai.OpenAI')
|
||||
def test_init_with_openai_defaults(self, MockOpenAIConstructor):
|
||||
MockOpenAIConstructor.return_value = self.mock_openai_client_instance
|
||||
os.environ["OPENAI_API_KEY"] = "test_openai_key"
|
||||
|
||||
bot = ConcreteOpenAICompatibleBot(model_name="gpt-4")
|
||||
|
||||
MockOpenAIConstructor.assert_called_once_with(api_key="test_openai_key", base_url=None)
|
||||
self.assertEqual(bot.client, self.mock_openai_client_instance)
|
||||
self.assertEqual(bot.model, "gpt-4")
|
||||
self.assertEqual(bot.max_tokens, 1000) # Default from _configure_model_and_tokens
|
||||
self.assertEqual(bot.azure_openai, False)
|
||||
|
||||
@patch('openai.OpenAI')
|
||||
def test_init_with_provided_client(self, MockOpenAIConstructor):
|
||||
preconfigured_client = MagicMock()
|
||||
bot = ConcreteOpenAICompatibleBot(client=preconfigured_client, model_name="gpt-3.5")
|
||||
|
||||
MockOpenAIConstructor.assert_not_called()
|
||||
self.assertEqual(bot.client, preconfigured_client)
|
||||
self.assertEqual(bot.model, "gpt-3.5")
|
||||
|
||||
@patch('openai.AzureOpenAI')
|
||||
def test_init_with_azure_config_args(self, MockAzureOpenAIConstructor):
|
||||
MockAzureOpenAIConstructor.return_value = self.mock_openai_client_instance
|
||||
|
||||
bot = ConcreteOpenAICompatibleBot(
|
||||
api_key="azure_key",
|
||||
azure_endpoint="https://myenv.openai.azure.com",
|
||||
api_version="2023-05-15",
|
||||
azure_deployment="my-gpt-4", # This should be used as model_name for API call
|
||||
model_name="should_be_overridden_by_azure_deployment_for_api"
|
||||
# model_name is passed to _configure_model_and_tokens, which sets self.model for display/logging
|
||||
# but for Azure, the client needs the deployment name.
|
||||
)
|
||||
|
||||
MockAzureOpenAIConstructor.assert_called_once_with(
|
||||
api_key="azure_key",
|
||||
azure_endpoint="https://myenv.openai.azure.com",
|
||||
api_version="2023-05-15"
|
||||
)
|
||||
self.assertEqual(bot.client, self.mock_openai_client_instance)
|
||||
self.assertEqual(bot.model, "my-gpt-4") # Azure deployment name becomes the model for API calls
|
||||
self.assertEqual(bot.azure_openai, True)
|
||||
|
||||
|
||||
@patch('openai.AzureOpenAI')
|
||||
def test_init_with_azure_env_vars(self, MockAzureOpenAIConstructor):
|
||||
MockAzureOpenAIConstructor.return_value = self.mock_openai_client_instance
|
||||
os.environ["AZURE_OPENAI_KEY"] = "env_azure_key"
|
||||
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://env.openai.azure.com"
|
||||
os.environ["AZURE_OPENAI_API_VERSION"] = "2023-06-01"
|
||||
os.environ["AZURE_DEPLOYMENT_NAME"] = "env-gpt-35" # Used as model_name
|
||||
|
||||
bot = ConcreteOpenAICompatibleBot(model_name="ignored_if_azure_deployment_env_is_set")
|
||||
|
||||
MockAzureOpenAIConstructor.assert_called_once_with(
|
||||
api_key="env_azure_key",
|
||||
azure_endpoint="https://env.openai.azure.com",
|
||||
api_version="2023-06-01"
|
||||
)
|
||||
self.assertEqual(bot.model, "env-gpt-35")
|
||||
self.assertTrue(bot.azure_openai)
|
||||
|
||||
@patch('openai.OpenAI')
|
||||
def test_init_with_gemini_config_args(self, MockOpenAIConstructor):
|
||||
MockOpenAIConstructor.return_value = self.mock_openai_client_instance
|
||||
|
||||
bot = ConcreteOpenAICompatibleBot(
|
||||
api_key="gemini_key",
|
||||
base_url="https://gemini.example.com",
|
||||
model_name="gemini-pro",
|
||||
is_gemini=True
|
||||
)
|
||||
MockOpenAIConstructor.assert_called_once_with(api_key="gemini_key", base_url="https://gemini.example.com")
|
||||
self.assertEqual(bot.model, "gemini-pro")
|
||||
self.assertFalse(bot.azure_openai) # is_gemini doesn't mean azure_openai
|
||||
|
||||
def test_configure_model_and_tokens(self):
|
||||
bot = ConcreteOpenAICompatibleBot(model_name="initial_model") # init calls _configure
|
||||
bot._configure_model_and_tokens("test-model", "500")
|
||||
self.assertEqual(bot.model, "test-model")
|
||||
self.assertEqual(bot.max_tokens, 500)
|
||||
|
||||
bot._configure_model_and_tokens("test-model-2", None, default_max_tokens=150)
|
||||
self.assertEqual(bot.max_tokens, 150)
|
||||
|
||||
bot._configure_model_and_tokens("test-model-3", "invalid_token_val")
|
||||
self.assertEqual(bot.max_tokens, 1000) # Default fallback
|
||||
|
||||
def test_get_llm_description(self):
|
||||
bot = ConcreteOpenAICompatibleBot(model_name="desc-model", max_tokens_str="256")
|
||||
self.assertEqual(bot.get_llm_description(), "LLM: desc-model, Max Tokens: 256, Azure: False")
|
||||
|
||||
bot_azure = ConcreteOpenAICompatibleBot(azure_deployment="azure-model", azure_endpoint="x", api_key="y", api_version="z")
|
||||
self.assertEqual(bot_azure.get_llm_description(), "LLM: azure-model, Max Tokens: 1000, Azure: True")
|
||||
|
||||
|
||||
def test_get_chat_response_success(self):
|
||||
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="test-gpt")
|
||||
bot.max_tokens = 50 # Ensure this is set
|
||||
mock_api_response = create_mock_openai_response(content="Hello from API")
|
||||
self.mock_openai_client_instance.chat.completions.create.return_value = mock_api_response
|
||||
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
response = bot.get_chat_response(messages)
|
||||
|
||||
self.mock_openai_client_instance.chat.completions.create.assert_called_once_with(
|
||||
model="test-gpt",
|
||||
messages=messages,
|
||||
tools=ANY, # Assuming functions can be None or empty list
|
||||
tool_choice=ANY,
|
||||
max_tokens=50
|
||||
)
|
||||
self.assertEqual(response, mock_api_response)
|
||||
|
||||
def test_get_chat_response_api_error(self):
|
||||
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="error-gpt")
|
||||
self.mock_openai_client_instance.chat.completions.create.side_effect = Exception("API Down")
|
||||
|
||||
with self.assertRaisesRegex(Exception, "API Down"):
|
||||
bot.get_chat_response([{"role": "user", "content": "trigger"}])
|
||||
|
||||
async def test_handle_message_simple_response(self):
|
||||
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="chatty")
|
||||
bot.system_prompt = "You are a test bot." # Set directly for simplicity
|
||||
mock_api_response = create_mock_openai_response(content="Test reply")
|
||||
self.mock_openai_client_instance.chat.completions.create.return_value = mock_api_response
|
||||
|
||||
response_content = await bot.handle_message(user_id=1, user_message="Hello")
|
||||
|
||||
self.assertEqual(response_content, "Test reply")
|
||||
self.assertIn(1, bot.conversation_history)
|
||||
self.assertEqual(len(bot.conversation_history[1]), 3) # System, User, Assistant
|
||||
self.assertEqual(bot.conversation_history[1][0]["content"], "You are a test bot.")
|
||||
self.assertEqual(bot.conversation_history[1][2]["content"], "Test reply")
|
||||
|
||||
async def test_handle_message_with_tool_call_and_response(self):
|
||||
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="tool-user")
|
||||
|
||||
# Mock functions/tools setup on the bot
|
||||
mock_tool_def = {"function": {"name": "get_weather", "description": "Gets weather", "parameters": {}}}
|
||||
bot.functions = [mock_tool_def] # Simulate tools are loaded
|
||||
|
||||
# API response 1: Request to call tool
|
||||
tool_call_request = [{"id": "call123", "function": {"name": "get_weather", "arguments": '''{"location": "moon"}'''}}]
|
||||
api_response_1 = create_mock_openai_response(tool_calls=tool_call_request)
|
||||
|
||||
# API response 2: Final answer after tool execution
|
||||
api_response_2 = create_mock_openai_response(content="The weather on the moon is chilly.")
|
||||
|
||||
self.mock_openai_client_instance.chat.completions.create.side_effect = [api_response_1, api_response_2]
|
||||
|
||||
# Mock self.call_tool
|
||||
bot.call_tool = MagicMock(return_value='''{"temperature": "-100 C"}''')
|
||||
|
||||
final_response = await bot.handle_message(user_id=2, user_message="Weather on moon?")
|
||||
|
||||
self.assertEqual(final_response, "The weather on the moon is chilly.")
|
||||
bot.call_tool.assert_called_once_with("get_weather", '''{"location": "moon"}''')
|
||||
|
||||
# Check conversation history includes tool messages
|
||||
history = bot.conversation_history[2]
|
||||
self.assertTrue(any(msg["role"] == "assistant" and msg.tool_calls is not None for msg in history))
|
||||
self.assertTrue(any(msg["role"] == "tool" and msg["name"] == "get_weather" for msg in history))
|
||||
self.assertEqual(self.mock_openai_client_instance.chat.completions.create.call_count, 2)
|
||||
|
||||
async def test_handle_message_max_history_length(self):
|
||||
bot = ConcreteOpenAICompatibleBot(client=self.mock_openai_client_instance, model_name="hist-test", max_history_length=3)
|
||||
self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Ok")
|
||||
|
||||
await bot.handle_message(1, "Msg1") # Sys, User, Assist (3)
|
||||
self.assertEqual(len(bot.conversation_history[1]), 3)
|
||||
|
||||
await bot.handle_message(1, "Msg2") # User, Assist. Should be 3 (prev User, prev Assist, new User) -> then adds new Assist.
|
||||
# Before new call: [Sys, U1, A1]. New U2. Call with [Sys,U1,A1,U2]. Resp A2.
|
||||
# History: [Sys,U1,A1,U2,A2]. Limit 3. -> [A1,U2,A2] (if system is not preserved specially)
|
||||
# The current code appends to history then truncates if over limit.
|
||||
# So after Msg1: [S, U1, A1]. len=3.
|
||||
# For Msg2: History is [S, U1, A1]. Append U2. Call with [S,U1,A1,U2]. Append A2.
|
||||
# History now [S,U1,A1,U2,A2]. len=5. Truncate to 3.
|
||||
# Expected: [A1, U2, A2] or [U1,A1,U2] or [U2,A2,S] depending on how system prompt is handled in truncation.
|
||||
# The code is: self.conversation_history[user_id][-self.max_history_length:]
|
||||
# And system prompt is only added IF user_id not in self.conversation_history.
|
||||
# So, for Msg2, system prompt is not re-added.
|
||||
# History before Msg2 call: [S, U1, A1]
|
||||
# Messages for Msg2 call: [S, U1, A1, U2]
|
||||
# History after Msg2 response A2: [S, U1, A1, U2, A2]. Len 5.
|
||||
# Truncated to self.max_history_length=3: [A1, U2, A2]
|
||||
|
||||
# Call 1
|
||||
self.mock_openai_client_instance.chat.completions.create.reset_mock()
|
||||
self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply1")
|
||||
await bot.handle_message(user_id=7, user_message="First message")
|
||||
self.assertEqual(len(bot.conversation_history[7]), 3) # System, User1, Assistant1
|
||||
|
||||
# Call 2
|
||||
self.mock_openai_client_instance.chat.completions.create.reset_mock()
|
||||
self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply2")
|
||||
await bot.handle_message(user_id=7, user_message="Second message")
|
||||
# History before call: [S, U1, A1]. Messages for call: [S, U1, A1, U2]. History after: [S, U1, A1, U2, A2].
|
||||
# Truncated to 3: [A1, U2, A2]
|
||||
self.assertEqual(len(bot.conversation_history[7]), 3)
|
||||
self.assertEqual(bot.conversation_history[7][0]["content"], "Reply1") # A1
|
||||
self.assertEqual(bot.conversation_history[7][1]["content"], "Second message") # U2
|
||||
self.assertEqual(bot.conversation_history[7][2]["content"], "Reply2") # A2
|
||||
|
||||
# Call 3
|
||||
self.mock_openai_client_instance.chat.completions.create.reset_mock()
|
||||
self.mock_openai_client_instance.chat.completions.create.return_value = create_mock_openai_response(content="Reply3")
|
||||
await bot.handle_message(user_id=7, user_message="Third message")
|
||||
# History before call: [A1, U2, A2]. Messages for call: [A1, U2, A2, U3]. History after: [A1, U2, A2, U3, A3].
|
||||
# Truncated to 3: [A2, U3, A3]
|
||||
self.assertEqual(len(bot.conversation_history[7]), 3)
|
||||
self.assertEqual(bot.conversation_history[7][0]["content"], "Reply2") # A2
|
||||
self.assertEqual(bot.conversation_history[7][1]["content"], "Third message") # U3
|
||||
self.assertEqual(bot.conversation_history[7][2]["content"], "Reply3") # A3
|
||||
|
||||
|
||||
async def test_abort_processing(self):
|
||||
bot = ConcreteOpenAICompatibleBot(model_name="test")
|
||||
user_id = 123
|
||||
bot.processing_status[user_id] = {"processing": True, "message_id": 456}
|
||||
bot.conversation_history[user_id] = [{"role": "user", "content": "stuff"}]
|
||||
|
||||
with patch.object(bot, 'clear_conversation_history') as mock_clear_hist: # Patching the method from Base class
|
||||
result = await bot.abort_processing(user_id)
|
||||
|
||||
self.assertEqual(result, "Processing aborted and conversation cleared.")
|
||||
self.assertFalse(bot.processing_status[user_id]["processing"])
|
||||
mock_clear_hist.assert_called_once_with(user_id)
|
||||
|
||||
async def test_abort_processing_no_active_processing(self):
|
||||
bot = ConcreteOpenAICompatibleBot(model_name="test")
|
||||
user_id = 404 # Not in processing_status
|
||||
with patch.object(bot, 'clear_conversation_history') as mock_clear_hist:
|
||||
result = await bot.abort_processing(user_id)
|
||||
self.assertEqual(result, "No active processing found to abort. Conversation cleared.")
|
||||
mock_clear_hist.assert_called_once_with(user_id)
|
||||
|
||||
# Test for the abstract switch_model (basic call, actual logic in concrete class for this test)
|
||||
async def test_switch_model_concrete_implementation(self):
|
||||
bot = ConcreteOpenAICompatibleBot(model_name="model1", small_model_name="model1", large_model_name="model2", max_tokens_str="100")
|
||||
self.assertEqual(bot.model, "model1")
|
||||
await bot.switch_model() # Calls the concrete implementation
|
||||
self.assertEqual(bot.model, "model2")
|
||||
await bot.switch_model()
|
||||
self.assertEqual(bot.model, "model1")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,356 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch, mock_open, AsyncMock
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Assuming telegram_helper.py is in the parent directory or PYTHONPATH is set
|
||||
from telegram_helper import TelegramHelper, MessageHandlerLogicResult
|
||||
|
||||
# Mock for the bot passed to TelegramHelper
|
||||
class MockBot:
|
||||
def __init__(self):
|
||||
self.start = AsyncMock()
|
||||
self.clear_conversation_history = MagicMock()
|
||||
self.get_bot_status = AsyncMock(return_value="Bot Status OK")
|
||||
self.switch_model = AsyncMock(return_value="Model Switched OK")
|
||||
self.handle_message = AsyncMock() # Needs to return a string
|
||||
self.abort_processing = AsyncMock(return_value="Abort OK")
|
||||
self.set_processing_status = MagicMock()
|
||||
self.clear_processing_status = MagicMock()
|
||||
self.processing_status = {} # Add the attribute
|
||||
|
||||
# Mock for telegram.Update and related objects
|
||||
def create_mock_update(message_text=None, user_id=123, chat_id=456, message_id=789, callback_query_data=None):
|
||||
update = MagicMock()
|
||||
update.effective_user.id = user_id
|
||||
update.effective_chat.id = chat_id
|
||||
|
||||
if message_text:
|
||||
update.message.text = message_text
|
||||
update.message.reply_text = AsyncMock(return_value=MagicMock(message_id=message_id)) # reply_text returns a Message obj
|
||||
|
||||
if callback_query_data:
|
||||
update.callback_query.data = callback_query_data
|
||||
update.callback_query.from_user.id = user_id
|
||||
update.callback_query.answer = AsyncMock()
|
||||
update.callback_query.edit_message_text = AsyncMock()
|
||||
|
||||
return update
|
||||
|
||||
# Mock for telegram.ext.ContextTypes.DEFAULT_TYPE
|
||||
def create_mock_context():
|
||||
context = MagicMock()
|
||||
context.bot.delete_message = AsyncMock()
|
||||
context.bot.edit_message_text = AsyncMock() # For update_status_message
|
||||
return context
|
||||
|
||||
class TestTelegramHelper(unittest.IsolatedAsyncioTestCase): # Use IsolatedAsyncioTestCase for async methods
|
||||
|
||||
def setUp(self):
|
||||
self.mock_bot = MockBot()
|
||||
# Default paths for reboot files, can be overridden in tests
|
||||
self.reboot_claude_file = ".test_reboot_claude"
|
||||
self.reboot_file = ".test_doreboot"
|
||||
self.helper = TelegramHelper(
|
||||
self.mock_bot,
|
||||
reboot_claude_file_path=self.reboot_claude_file,
|
||||
reboot_file_path=self.reboot_file,
|
||||
chunk_message_sleep_duration=0.001 # Faster sleep for tests
|
||||
)
|
||||
# Clean up any potential leftover reboot files from previous runs
|
||||
if os.path.exists(self.reboot_claude_file):
|
||||
os.remove(self.reboot_claude_file)
|
||||
if os.path.exists(self.reboot_file):
|
||||
os.remove(self.reboot_file)
|
||||
|
||||
def tearDown(self):
|
||||
# Clean up reboot files created during tests
|
||||
if os.path.exists(self.reboot_claude_file):
|
||||
os.remove(self.reboot_claude_file)
|
||||
if os.path.exists(self.reboot_file):
|
||||
os.remove(self.reboot_file)
|
||||
|
||||
async def test_start_logic(self):
|
||||
response = await self.helper._start_logic()
|
||||
self.mock_bot.start.assert_called_once()
|
||||
self.assertEqual(response, "Hello! I\'m your AI assistant. How can I help you today?")
|
||||
|
||||
async def test_start_command(self):
|
||||
mock_update = create_mock_update(message_text="/start")
|
||||
mock_context = create_mock_context()
|
||||
|
||||
with patch.object(self.helper, \'_start_logic\', new_callable=AsyncMock) as mock_logic:
|
||||
mock_logic.return_value = "Start Logic Response"
|
||||
await self.helper.start(mock_update, mock_context)
|
||||
mock_logic.assert_called_once()
|
||||
mock_update.message.reply_text.assert_called_once_with("Start Logic Response")
|
||||
|
||||
async def test_clear_logic(self):
|
||||
user_id = 123
|
||||
response = await self.helper._clear_logic(user_id) # _clear_logic is async after refactor
|
||||
self.mock_bot.clear_conversation_history.assert_called_once_with(user_id)
|
||||
self.assertEqual(response, "Conversation history cleared. Let\'s start fresh!")
|
||||
|
||||
async def test_clear_command(self):
|
||||
mock_update = create_mock_update(message_text="/clear", user_id=123)
|
||||
mock_context = create_mock_context()
|
||||
with patch.object(self.helper, \'_clear_logic\', new_callable=AsyncMock) as mock_logic:
|
||||
mock_logic.return_value = "Clear Logic Response"
|
||||
await self.helper.clear(mock_update, mock_context)
|
||||
mock_logic.assert_called_once_with(123)
|
||||
mock_update.message.reply_text.assert_called_once_with("Clear Logic Response")
|
||||
|
||||
async def test_status_logic(self):
|
||||
self.mock_bot.get_bot_status.return_value = "Test Status"
|
||||
response = await self.helper._status_logic()
|
||||
self.mock_bot.get_bot_status.assert_called_once()
|
||||
self.assertEqual(response, "Test Status")
|
||||
|
||||
async def test_switch_logic_supported(self):
|
||||
self.mock_bot.switch_model.return_value = "Switched to Large Model"
|
||||
response = await self.helper._switch_logic()
|
||||
self.mock_bot.switch_model.assert_called_once()
|
||||
self.assertEqual(response, "Switched to Large Model")
|
||||
|
||||
async def test_switch_logic_not_supported(self):
|
||||
del self.mock_bot.switch_model # Simulate bot not having the attribute
|
||||
response = await self.helper._switch_logic()
|
||||
self.assertEqual(response, "Model switching is not supported for this bot.")
|
||||
|
||||
async def test_handle_message_logic_success(self):
|
||||
user_id = 100
|
||||
user_message = "Hello bot"
|
||||
bot_response = "Hello user <think>Thinking hard</think> Done."
|
||||
expected_processed_response = f"Hello user {self.helper.HTML_QUOTE_BLOCK_START}Thinking hard{self.helper.HTML_QUOTE_BLOCK_END} Done."
|
||||
self.mock_bot.handle_message.return_value = bot_response
|
||||
|
||||
result = await self.helper._handle_message_logic(user_id, user_message)
|
||||
|
||||
self.mock_bot.handle_message.assert_called_once_with(user_id, user_message)
|
||||
self.assertTrue(result["success"])
|
||||
self.assertEqual(result["response_text"], expected_processed_response)
|
||||
self.assertIsNone(result["error_message"])
|
||||
|
||||
async def test_handle_message_logic_bot_exception(self):
|
||||
user_id = 101
|
||||
user_message = "Trigger error"
|
||||
self.mock_bot.handle_message.side_effect = Exception("Bot Error")
|
||||
|
||||
result = await self.helper._handle_message_logic(user_id, user_message)
|
||||
|
||||
self.assertFalse(result["success"])
|
||||
self.assertIsNone(result["response_text"])
|
||||
self.assertEqual(result["error_message"], "Bot Error")
|
||||
|
||||
@patch(\'logging.error\')
|
||||
async def test_handle_message_command_success_short_message(self, mock_logging_error):
|
||||
mock_update = create_mock_update(message_text="Hi", user_id=200, chat_id=201, message_id=202)
|
||||
mock_context = create_mock_context()
|
||||
|
||||
logic_result = MessageHandlerLogicResult(success=True, response_text="Short response", error_message=None)
|
||||
|
||||
with patch.object(self.helper, \'_handle_message_logic\', new_callable=AsyncMock) as mock_message_logic:
|
||||
mock_message_logic.return_value = logic_result
|
||||
|
||||
await self.helper.handle_message(mock_update, mock_context)
|
||||
|
||||
mock_update.message.reply_text.assert_any_call("Processing your request...", reply_markup=unittest.mock.ANY)
|
||||
self.mock_bot.set_processing_status.assert_called_once_with(200, 202) # user_id, status_message_id
|
||||
mock_message_logic.assert_called_once_with(200, "Hi")
|
||||
mock_context.bot.delete_message.assert_called_once_with(chat_id=201, message_id=202)
|
||||
self.mock_bot.clear_processing_status.assert_called_once_with(200)
|
||||
mock_update.message.reply_text.assert_any_call("Short response") # Final response
|
||||
self.assertEqual(mock_update.message.reply_text.call_count, 2) # Processing + final
|
||||
|
||||
@patch(\'logging.error\')
|
||||
async def test_handle_message_command_success_long_message_chunks(self, mock_logging_error):
|
||||
mock_update = create_mock_update(message_text="Long text", user_id=200, chat_id=201, message_id=202)
|
||||
mock_context = create_mock_context()
|
||||
|
||||
long_response_text = "a" * 5000 # Longer than 4096
|
||||
chunk1 = long_response_text[:4096]
|
||||
chunk2 = long_response_text[4096:]
|
||||
|
||||
logic_result = MessageHandlerLogicResult(success=True, response_text=long_response_text, error_message=None)
|
||||
|
||||
with patch.object(self.helper, \'_handle_message_logic\', new_callable=AsyncMock) as mock_message_logic, \
|
||||
patch(\'asyncio.sleep\', new_callable=AsyncMock) as mock_sleep: # Mock sleep
|
||||
mock_message_logic.return_value = logic_result
|
||||
|
||||
await self.helper.handle_message(mock_update, mock_context)
|
||||
|
||||
mock_update.message.reply_text.assert_any_call(chunk1)
|
||||
mock_update.message.reply_text.assert_any_call(chunk2)
|
||||
mock_sleep.assert_called_once_with(self.helper.chunk_message_sleep_duration)
|
||||
self.assertEqual(mock_update.message.reply_text.call_count, 3) # Processing + 2 chunks
|
||||
|
||||
@patch(\'logging.error\')
|
||||
async def test_handle_message_command_logic_fails(self, mock_logging_error):
|
||||
mock_update = create_mock_update(message_text="Cause error in logic", user_id=200)
|
||||
mock_context = create_mock_context()
|
||||
logic_result = MessageHandlerLogicResult(success=False, response_text=None, error_message="Logic Failed")
|
||||
|
||||
with patch.object(self.helper, \'_handle_message_logic\', new_callable=AsyncMock) as mock_message_logic:
|
||||
mock_message_logic.return_value = logic_result
|
||||
await self.helper.handle_message(mock_update, mock_context)
|
||||
mock_update.message.reply_text.assert_any_call("Sorry, an error occurred while processing your request.")
|
||||
self.assertEqual(mock_update.message.reply_text.call_count, 2) # Processing + error message
|
||||
|
||||
@patch(\'logging.error\')
|
||||
async def test_handle_message_command_telegram_exception_after_logic(self, mock_logging_error):
|
||||
mock_update = create_mock_update(message_text="Test", user_id=200)
|
||||
mock_context = create_mock_context()
|
||||
logic_result = MessageHandlerLogicResult(success=True, response_text="OK", error_message=None)
|
||||
|
||||
# Make sending the final reply fail
|
||||
mock_update.message.reply_text.side_effect = [
|
||||
MagicMock(message_id=202), # For "Processing..."
|
||||
Exception("Telegram API Error") # For the actual response
|
||||
]
|
||||
|
||||
with patch.object(self.helper, \'_handle_message_logic\', new_callable=AsyncMock) as mock_message_logic:
|
||||
mock_message_logic.return_value = logic_result
|
||||
await self.helper.handle_message(mock_update, mock_context)
|
||||
|
||||
# Check if the generic error message was attempted
|
||||
# This is tricky because reply_text is already mocked with side_effect.
|
||||
# We\'d expect logs. Let\'s check logs or if processing status was cleared.
|
||||
self.mock_bot.clear_processing_status.assert_called_once_with(200)
|
||||
mock_logging_error.assert_any_call(unittest.mock.string_containing("Outer error in handle_message"))
|
||||
|
||||
|
||||
async def test_abort_processing_logic(self):
|
||||
user_id = 300
|
||||
self.mock_bot.abort_processing.return_value = "Aborted by bot"
|
||||
response = await self.helper._abort_processing_logic(user_id)
|
||||
self.mock_bot.abort_processing.assert_called_once_with(user_id)
|
||||
self.assertEqual(response, "Aborted by bot")
|
||||
|
||||
async def test_abort_processing_command(self):
|
||||
mock_update = create_mock_update(callback_query_data=\'abort\', user_id=301)
|
||||
mock_context = create_mock_context()
|
||||
with patch.object(self.helper, \'_abort_processing_logic\', new_callable=AsyncMock) as mock_logic:
|
||||
mock_logic.return_value = "Abort Logic Done"
|
||||
await self.helper.abort_processing(mock_update, mock_context)
|
||||
|
||||
mock_update.callback_query.answer.assert_called_once()
|
||||
mock_logic.assert_called_once_with(301)
|
||||
mock_update.callback_query.edit_message_text.assert_called_once_with(text="Abort Logic Done")
|
||||
|
||||
def test_reboot_logic_claude_and_main(self):
|
||||
user_message_parts = ["/reboot", "claude"]
|
||||
chat_id_to_write = "12345"
|
||||
|
||||
with patch("builtins.open", mock_open()) as mock_file:
|
||||
self.helper._reboot_logic(user_message_parts, chat_id_to_write)
|
||||
|
||||
# Check claude reboot file
|
||||
mock_file.assert_any_call(self.reboot_claude_file, \'w\')
|
||||
# Check main doreboot file
|
||||
mock_file.assert_any_call(self.reboot_file, \'w\')
|
||||
handle_claude = mock_file.return_value
|
||||
handle_main = mock_file.return_value # mock_open reuses the handle for multiple calls
|
||||
|
||||
# Check if write was called for claude file (empty write)
|
||||
# This part of assertion is tricky with single mock_file. Better to use different mocks if possible
|
||||
# or check the sequence of calls if the mock supports it well.
|
||||
# For now, assert_any_call ensures it was opened.
|
||||
|
||||
# Check content for main reboot file
|
||||
# Need to ensure the write for self.reboot_file had chat_id_to_write
|
||||
# This requires more sophisticated mock_open or patching os.path.exists and multiple open calls
|
||||
# Simpler check: was open(self.reboot_file, \'w\') called? Yes, via assert_any_call.
|
||||
# And was open(self.reboot_claude_file, \'w\') called? Yes.
|
||||
|
||||
# Verify files were created (mock_open doesn\'t actually create them)
|
||||
# This test relies on mock_open\'s behavior. To test file content, need more setup.
|
||||
# For now, assume open was called correctly.
|
||||
|
||||
def test_reboot_logic_main_only(self):
|
||||
user_message_parts = ["/reboot"]
|
||||
chat_id_to_write = "67890"
|
||||
with patch("builtins.open", mock_open()) as mock_file:
|
||||
self.helper._reboot_logic(user_message_parts, chat_id_to_write)
|
||||
# Ensure claude file was NOT opened for writing.
|
||||
# This requires asserting that a specific call didn\'t happen, or checking call_args_list
|
||||
claude_call = unittest.mock.call(self.reboot_claude_file, \'w\')
|
||||
self.assertNotIn(claude_call, mock_file.call_args_list)
|
||||
|
||||
mock_file.assert_any_call(self.reboot_file, \'w\')
|
||||
|
||||
@patch(\'sys.exit\') # Mock sys.exit to prevent test runner from exiting
|
||||
async def test_reboot_command(self, mock_sys_exit):
|
||||
mock_update = create_mock_update(message_text="/reboot claude", chat_id="chat1")
|
||||
mock_context = create_mock_context()
|
||||
|
||||
with patch.object(self.helper, \'_reboot_logic\') as mock_reboot_file_logic:
|
||||
await self.helper.reboot(mock_update, mock_context)
|
||||
|
||||
mock_reboot_file_logic.assert_called_once_with(["/reboot", "claude"], "chat1")
|
||||
mock_update.message.reply_text.assert_called_once_with("Rebooting the bot...")
|
||||
mock_sys_exit.assert_called_once_with(0)
|
||||
|
||||
@patch(\'os.path.exists\')
|
||||
@patch(\'builtins.open\', new_callable=mock_open)
|
||||
@patch(\'os.remove\')
|
||||
async def test_check_doreboot_file_logic_file_exists(self, mock_os_remove, mock_file_open, mock_os_path_exists):
|
||||
mock_os_path_exists.return_value = True
|
||||
mock_file_open.return_value.read.return_value.strip.return_value = "chat123"
|
||||
|
||||
chat_id = await self.helper._check_doreboot_file_logic()
|
||||
|
||||
mock_os_path_exists.assert_called_once_with(self.reboot_file)
|
||||
mock_file_open.assert_called_once_with(self.reboot_file, \'r\')
|
||||
mock_os_remove.assert_called_once_with(self.reboot_file)
|
||||
self.assertEqual(chat_id, "chat123")
|
||||
|
||||
@patch(\'os.path.exists\', return_value=False)
|
||||
async def test_check_doreboot_file_logic_file_not_exists(self, mock_os_path_exists):
|
||||
chat_id = await self.helper._check_doreboot_file_logic()
|
||||
mock_os_path_exists.assert_called_once_with(self.reboot_file)
|
||||
self.assertIsNone(chat_id)
|
||||
|
||||
@patch(\'logging.error\')
|
||||
@patch(\'os.path.exists\', return_value=True)
|
||||
@patch(\'builtins.open\', side_effect=IOError("Read error"))
|
||||
@patch(\'os.remove\') # To check if remove is called even on read error
|
||||
async def test_check_doreboot_file_logic_read_error(self, mock_os_remove, mock_file_open, mock_os_path_exists, mock_log_error):
|
||||
chat_id = await self.helper._check_doreboot_file_logic()
|
||||
|
||||
self.assertIsNone(chat_id)
|
||||
mock_log_error.assert_any_call(unittest.mock.string_containing("Error reading reboot file"))
|
||||
# Check if os.remove was attempted even after read error
|
||||
mock_os_remove.assert_called_once_with(self.reboot_file)
|
||||
|
||||
|
||||
async def test_check_doreboot_file_command_sends_message(self):
|
||||
mock_application = MagicMock()
|
||||
mock_application.bot.send_message = AsyncMock()
|
||||
|
||||
with patch.object(self.helper, \'_check_doreboot_file_logic\', new_callable=AsyncMock) as mock_logic:
|
||||
mock_logic.return_value = "chat789" # Simulate chat_id found
|
||||
await self.helper.check_doreboot_file(mock_application)
|
||||
|
||||
mock_logic.assert_called_once()
|
||||
mock_application.bot.send_message.assert_called_once_with(
|
||||
chat_id="chat789", text="The application has finished initializing."
|
||||
)
|
||||
|
||||
async def test_check_doreboot_file_command_no_chat_id(self):
|
||||
mock_application = MagicMock()
|
||||
mock_application.bot.send_message = AsyncMock()
|
||||
|
||||
with patch.object(self.helper, \'_check_doreboot_file_logic\', new_callable=AsyncMock) as mock_logic:
|
||||
mock_logic.return_value = None # Simulate no chat_id found
|
||||
await self.helper.check_doreboot_file(mock_application)
|
||||
|
||||
mock_logic.assert_called_once()
|
||||
mock_application.bot.send_message.assert_not_called()
|
||||
|
||||
# Note: Testing the run() method itself is more of an integration test,
|
||||
# as it involves setting up the full Application and polling loop.
|
||||
# Unit tests here focus on the helper\'s own logic methods.
|
||||
|
||||
if __name__ == \'__main__\':
|
||||
unittest.main()
|
||||
@@ -1,307 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
import os
|
||||
import base64
|
||||
import logging
|
||||
import requests # Required for spec in MagicMock
|
||||
|
||||
# Ensure tools/github_tool.py is accessible
|
||||
from tools.github_tool import GitHubTool
|
||||
|
||||
# Helper to create a mock response for requests.Session
|
||||
def create_mock_response(status_code, json_data=None, text_data=None, headers=None, links=None):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = status_code
|
||||
if json_data is not None:
|
||||
mock_resp.json = MagicMock(return_value=json_data)
|
||||
mock_resp.text = text_data if text_data is not None else str(json_data)
|
||||
mock_resp.headers = headers if headers else {}
|
||||
mock_resp.links = links if links else {} # For pagination in _list_branches
|
||||
return mock_resp
|
||||
|
||||
class TestGitHubTool(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.mock_session = MagicMock(spec=requests.Session)
|
||||
self.mock_session.headers = {} # Simulate a new session's headers
|
||||
|
||||
self.test_token = "test_github_token"
|
||||
self.test_repo = "owner/repo"
|
||||
self.test_base_url = "https://api.example.com" # Use a non-default base_url for some tests
|
||||
|
||||
# Suppress logging output during tests unless explicitly testing for it
|
||||
self.logger = logging.getLogger('tools.github_tool')
|
||||
# Ensure only one NullHandler to prevent duplicate messages if tests run multiple times in a session
|
||||
if not any(isinstance(h, logging.NullHandler) for h in self.logger.handlers):
|
||||
self.logger.addHandler(logging.NullHandler())
|
||||
self.logger.propagate = False # Prevent propagation to root logger if it has handlers
|
||||
|
||||
def test_init_with_args_and_session(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, base_url=self.test_base_url, logger=self.logger)
|
||||
self.assertEqual(tool.session, self.mock_session)
|
||||
self.assertEqual(tool._token, self.test_token)
|
||||
self.assertEqual(tool._repo, self.test_repo)
|
||||
self.assertEqual(tool.base_url, self.test_base_url)
|
||||
self.assertEqual(tool.current_branch, "main") # Default initial branch
|
||||
|
||||
@patch('requests.Session')
|
||||
def test_init_creates_session_if_not_provided(self, MockSessionConstructor):
|
||||
mock_created_session = MagicMock(spec=requests.Session)
|
||||
mock_created_session.headers = {}
|
||||
MockSessionConstructor.return_value = mock_created_session
|
||||
|
||||
# Temporarily set env vars for this test
|
||||
original_token = os.environ.get("GITHUB_TOKEN")
|
||||
original_repo = os.environ.get("GITHUB_REPOSITORY")
|
||||
os.environ["GITHUB_TOKEN"] = "env_token"
|
||||
os.environ["GITHUB_REPOSITORY"] = "env/repo"
|
||||
|
||||
tool = GitHubTool(logger=self.logger) # Use env vars
|
||||
|
||||
MockSessionConstructor.assert_called_once()
|
||||
self.assertEqual(tool.session, mock_created_session)
|
||||
self.assertEqual(tool._token, "env_token")
|
||||
self.assertEqual(tool._repo, "env/repo")
|
||||
self.assertIn("Authorization", mock_created_session.headers)
|
||||
self.assertEqual(mock_created_session.headers["Authorization"], "token env_token")
|
||||
|
||||
# Restore original env vars
|
||||
if original_token is None: del os.environ["GITHUB_TOKEN"]
|
||||
else: os.environ["GITHUB_TOKEN"] = original_token
|
||||
if original_repo is None: del os.environ["GITHUB_REPOSITORY"]
|
||||
else: os.environ["GITHUB_REPOSITORY"] = original_repo
|
||||
|
||||
def test_init_raises_value_error_if_no_token(self):
|
||||
original_token = os.environ.pop("GITHUB_TOKEN", None)
|
||||
with self.assertRaisesRegex(ValueError, "GitHub token must be provided"):
|
||||
GitHubTool(repo=self.test_repo, logger=self.logger)
|
||||
if original_token: os.environ["GITHUB_TOKEN"] = original_token
|
||||
|
||||
def test_init_raises_value_error_if_no_repo(self):
|
||||
original_repo = os.environ.pop("GITHUB_REPOSITORY", None)
|
||||
with self.assertRaisesRegex(ValueError, "GitHub repository.*must be provided"):
|
||||
GitHubTool(token=self.test_token, logger=self.logger)
|
||||
if original_repo: os.environ["GITHUB_REPOSITORY"] = original_repo
|
||||
|
||||
def test_clear_resets_branch(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, initial_branch="feature-branch", logger=self.logger)
|
||||
# Mock _get_branch_sha for _set_current_branch called by clear
|
||||
with patch.object(tool, '_get_branch_sha', return_value="sha_for_main"):
|
||||
tool.clear()
|
||||
self.assertEqual(tool.current_branch, "main")
|
||||
|
||||
def test_get_functions_returns_list(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
functions = tool.get_functions()
|
||||
self.assertIsInstance(functions, list)
|
||||
self.assertTrue(len(functions) > 0)
|
||||
self.assertIn("name", functions[0]["function"])
|
||||
|
||||
|
||||
# --- Test individual private methods ---
|
||||
|
||||
def test_read_file_success(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
file_content = "Hello World!"
|
||||
encoded_content = base64.b64encode(file_content.encode('utf-8')).decode('utf-8')
|
||||
self.mock_session.get.return_value = create_mock_response(200, json_data={"content": encoded_content})
|
||||
|
||||
result = tool._read_file(path="test.txt")
|
||||
self.assertEqual(result, file_content)
|
||||
self.mock_session.get.assert_called_once_with(
|
||||
f"{tool.base_url}/repos/{self.test_repo}/contents/test.txt",
|
||||
params={"ref": "main"}
|
||||
)
|
||||
|
||||
def test_read_file_error(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
self.mock_session.get.return_value = create_mock_response(404, text_data="Not Found")
|
||||
result = tool._read_file(path="nonexistent.txt")
|
||||
self.assertIn("Error reading file", result)
|
||||
|
||||
def test_create_branch_success(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
# Mock getting base branch SHA
|
||||
self.mock_session.get.return_value = create_mock_response(200, json_data={"object": {"sha": "base_sha123"}})
|
||||
# Mock creating new branch
|
||||
self.mock_session.post.return_value = create_mock_response(201, json_data={"ref": "refs/heads/new-feature"})
|
||||
|
||||
result = tool._create_branch(branch_name="new-feature", base_branch="main")
|
||||
self.assertIn("Branch 'new-feature' created successfully", result)
|
||||
self.assertEqual(tool.current_branch, "new-feature")
|
||||
self.mock_session.get.assert_called_once() # For base branch SHA
|
||||
self.mock_session.post.assert_called_once() # For creating branch
|
||||
|
||||
def test_create_branch_base_sha_error(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
self.mock_session.get.return_value = create_mock_response(404, text_data="Base branch not found")
|
||||
result = tool._create_branch(branch_name="new-feature", base_branch="nonexistent-base")
|
||||
self.assertIn("Error getting base branch SHA", result)
|
||||
|
||||
def test_create_branch_creation_error(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
self.mock_session.get.return_value = create_mock_response(200, json_data={"object": {"sha": "base_sha456"}})
|
||||
self.mock_session.post.return_value = create_mock_response(422, text_data="Validation failed")
|
||||
result = tool._create_branch(branch_name="bad-branch", base_branch="main")
|
||||
self.assertIn("Error creating branch", result)
|
||||
|
||||
def test_commit_file_success_new_file(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
tool.current_branch = "dev-branch" # Cannot commit to main by default
|
||||
|
||||
# Mock GET for checking file existence (404 means new file)
|
||||
self.mock_session.get.return_value = create_mock_response(404)
|
||||
# Mock PUT for committing file
|
||||
self.mock_session.put.return_value = create_mock_response(201, json_data={"commit": {"sha": "commit_sha_abc"}})
|
||||
|
||||
result = tool._commit_file(file_path="new_file.py", content="print('Hello')", commit_message="Add new_file.py")
|
||||
self.assertIn("committed successfully", result)
|
||||
self.assertIn("commit_sha_abc", result)
|
||||
self.mock_session.get.assert_called_once() # Check file existence
|
||||
self.mock_session.put.assert_called_once() # Commit file
|
||||
|
||||
def test_commit_file_success_update_file(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
tool.current_branch = "dev-branch"
|
||||
|
||||
# Mock GET for checking file existence (200 means existing file)
|
||||
self.mock_session.get.return_value = create_mock_response(200, json_data={"sha": "existing_file_sha"})
|
||||
# Mock PUT for committing file
|
||||
self.mock_session.put.return_value = create_mock_response(200, json_data={"commit": {"sha": "commit_sha_def"}})
|
||||
|
||||
result = tool._commit_file(file_path="existing_file.txt", content="Updated content", commit_message="Update existing_file.txt")
|
||||
self.assertIn("committed successfully", result)
|
||||
self.assertIn("commit_sha_def", result)
|
||||
args, kwargs = self.mock_session.put.call_args
|
||||
self.assertEqual(kwargs['json']['sha'], "existing_file_sha")
|
||||
|
||||
|
||||
def test_commit_file_to_main_branch_error(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
tool.current_branch = "main"
|
||||
result = tool._commit_file(file_path="some.txt", content="content", commit_message="msg")
|
||||
self.assertIn("Action directly to main branch is not allowed", result)
|
||||
|
||||
def test_create_pull_request_success(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
tool.current_branch = "feature-pr"
|
||||
pr_url = "https://example.com/pull/1"
|
||||
self.mock_session.post.return_value = create_mock_response(201, json_data={"html_url": pr_url, "number": 1})
|
||||
|
||||
result = tool._create_pull_request(title="New Feature PR", body="Please review.", base="main")
|
||||
self.assertIn(f"Pull request created successfully: {pr_url}", result)
|
||||
self.mock_session.post.assert_called_once()
|
||||
call_data = self.mock_session.post.call_args[1]['json']
|
||||
self.assertEqual(call_data['head'], "feature-pr")
|
||||
self.assertEqual(call_data['base'], "main")
|
||||
|
||||
def test_create_pull_request_same_branch_error(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
tool.current_branch = "main"
|
||||
result = tool._create_pull_request(title="PR to self", body="This should fail", base="main")
|
||||
self.assertIn("Cannot create a pull request from branch 'main' to itself", result)
|
||||
|
||||
|
||||
def test_list_files_success(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
mock_items = [
|
||||
{"name": "file1.txt", "type": "file", "path": "dir/file1.txt"},
|
||||
{"name": "subdir", "type": "dir", "path": "dir/subdir"}
|
||||
]
|
||||
self.mock_session.get.return_value = create_mock_response(200, json_data=mock_items)
|
||||
|
||||
result = tool._list_files(path="dir")
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertEqual(result[0]["name"], "file1.txt")
|
||||
self.assertEqual(result[1]["type"], "dir")
|
||||
|
||||
def test_search_code_success(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
mock_search_results = {
|
||||
"items": [{"path": "src/code.py", "html_url": "url1"}]
|
||||
}
|
||||
self.mock_session.get.return_value = create_mock_response(200, json_data=mock_search_results)
|
||||
|
||||
results = tool._search_code(query="my_function")
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertEqual(results[0]["path"], "src/code.py")
|
||||
|
||||
def test_get_commit_history_success(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
mock_commits = [{
|
||||
"sha": "sha1", "commit": {"message": "Msg1", "author": {"name": "Authy", "date": "Date1"}}
|
||||
}]
|
||||
self.mock_session.get.return_value = create_mock_response(200, json_data=mock_commits)
|
||||
|
||||
commits = tool._get_commit_history(file_path="file.txt", num_commits=1)
|
||||
self.assertEqual(len(commits), 1)
|
||||
self.assertEqual(commits[0]["sha"], "sha1")
|
||||
|
||||
def test_set_current_branch_success(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
# Mock _get_branch_sha to simulate branch exists
|
||||
with patch.object(tool, '_get_branch_sha', return_value="some_sha_for_dev"):
|
||||
result = tool._set_current_branch(branch_name="dev")
|
||||
self.assertEqual(tool.current_branch, "dev")
|
||||
self.assertIn("Current branch set to: dev", result)
|
||||
|
||||
def test_set_current_branch_not_exists(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
with patch.object(tool, '_get_branch_sha', return_value="Error getting SHA for branch"):
|
||||
result = tool._set_current_branch(branch_name="nonexistent-branch")
|
||||
self.assertNotEqual(tool.current_branch, "nonexistent-branch") # Should not change
|
||||
self.assertIn("Cannot set current branch", result)
|
||||
|
||||
|
||||
def test_list_branches_single_page(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
mock_branches = [{"name": "main"}, {"name": "dev"}]
|
||||
self.mock_session.get.return_value = create_mock_response(200, json_data=mock_branches, links={}) # No "next" link
|
||||
|
||||
branches = tool._list_branches(all_pages=True)
|
||||
self.assertEqual(branches, ["main", "dev"])
|
||||
self.mock_session.get.assert_called_once()
|
||||
|
||||
def test_list_branches_multiple_pages(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
|
||||
# Page 1 response
|
||||
page1_branches = [{"name": "branch1"}, {"name": "branch2"}]
|
||||
next_url = f"{tool.base_url}/repos/{self.test_repo}/branches?page=2"
|
||||
response1 = create_mock_response(200, json_data=page1_branches, links={"next": {"url": next_url}})
|
||||
|
||||
# Page 2 response
|
||||
page2_branches = [{"name": "branch3"}]
|
||||
response2 = create_mock_response(200, json_data=page2_branches, links={}) # No "next" link
|
||||
|
||||
self.mock_session.get.side_effect = [response1, response2]
|
||||
|
||||
branches = tool._list_branches(all_pages=True)
|
||||
self.assertEqual(branches, ["branch1", "branch2", "branch3"])
|
||||
self.assertEqual(self.mock_session.get.call_count, 2)
|
||||
|
||||
# Check that the second call used the next_url
|
||||
calls = self.mock_session.get.call_args_list
|
||||
self.assertEqual(calls[1][0][0], next_url) # args[0] is the URL
|
||||
|
||||
# --- Test execute dispatcher ---
|
||||
def test_execute_read_file(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
with patch.object(tool, '_read_file', return_value="file content") as mock_method:
|
||||
result = tool.execute(function_name="read_file", path="test.md")
|
||||
mock_method.assert_called_once_with(path="test.md")
|
||||
self.assertEqual(result, "file content")
|
||||
|
||||
def test_execute_unknown_function(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
result = tool.execute(function_name="non_existent_function_name", arg1="val1")
|
||||
self.assertIn("Unknown function: non_existent_function_name", result)
|
||||
|
||||
def test_execute_method_exception(self):
|
||||
tool = GitHubTool(session=self.mock_session, token=self.test_token, repo=self.test_repo, logger=self.logger)
|
||||
with patch.object(tool, '_read_file', side_effect=Exception("Kaboom")) as mock_method:
|
||||
result = tool.execute(function_name="read_file", path="crash.txt")
|
||||
self.assertIn("Error during read_file execution: Kaboom", result)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,146 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import patch, mock_open, MagicMock
|
||||
import os
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Ensure tools/log_tool.py is accessible
|
||||
from tools.log_tool import LogTool
|
||||
|
||||
class TestLogTool(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.test_log_file_path = "test_dummy_log.log"
|
||||
# Suppress logging output during tests unless explicitly testing for it
|
||||
self.logger = logging.getLogger('tools.log_tool')
|
||||
# Ensure only one NullHandler to prevent duplicate messages if tests run multiple times in a session
|
||||
if not any(isinstance(h, logging.NullHandler) for h in self.logger.handlers):
|
||||
self.logger.addHandler(logging.NullHandler())
|
||||
self.logger.propagate = False # Prevent propagation to root logger if it has handlers
|
||||
|
||||
|
||||
def test_init_default_log_path(self):
|
||||
tool = LogTool(logger=self.logger)
|
||||
self.assertEqual(tool.configured_log_file_path, 'logs/output.log')
|
||||
|
||||
def test_init_custom_log_path(self):
|
||||
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
|
||||
self.assertEqual(tool.configured_log_file_path, self.test_log_file_path)
|
||||
|
||||
def test_get_functions(self):
|
||||
tool = LogTool(logger=self.logger)
|
||||
functions = tool.get_functions()
|
||||
self.assertIsInstance(functions, list)
|
||||
self.assertEqual(len(functions), 1)
|
||||
self.assertEqual(functions[0]["function"]["name"], "get_log_contents")
|
||||
|
||||
@patch("os.path.exists", return_value=False)
|
||||
def test_get_log_contents_file_not_exists(self, mock_exists):
|
||||
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
|
||||
result = tool._get_log_contents()
|
||||
self.assertIn("Log file does not exist", result)
|
||||
mock_exists.assert_called_once_with(self.test_log_file_path)
|
||||
|
||||
@patch("os.path.exists", return_value=True)
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="line1\nline2\nline3\nline4\nline5")
|
||||
def test_get_log_contents_with_line_count(self, mock_file_open, mock_exists):
|
||||
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
|
||||
|
||||
result = tool._get_log_contents(line_count=3)
|
||||
self.assertEqual(result, "line3\nline4\nline5")
|
||||
mock_exists.assert_called_once_with(self.test_log_file_path)
|
||||
mock_file_open.assert_called_once_with(self.test_log_file_path, 'r', encoding='utf-8')
|
||||
|
||||
@patch("os.path.exists", return_value=True)
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="line1\nline2\n")
|
||||
def test_get_log_contents_line_count_more_than_available(self, mock_file_open, mock_exists):
|
||||
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
|
||||
result = tool._get_log_contents(line_count=5)
|
||||
self.assertEqual(result, "line1\nline2\n")
|
||||
|
||||
@patch("os.path.exists", return_value=True)
|
||||
@patch("builtins.open", new_callable=mock_open, read_data="line1\nline2\n")
|
||||
def test_get_log_contents_invalid_line_count_uses_default(self, mock_file_open, mock_exists):
|
||||
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
|
||||
# Test with zero, negative, and non-integer line_count (though type hint is int)
|
||||
# The code defaults to 150 if invalid. Here, we only have 2 lines.
|
||||
with patch.object(tool.logger, 'warning') as mock_log_warning:
|
||||
result_zero = tool._get_log_contents(line_count=0)
|
||||
self.assertEqual(result_zero, "line1\nline2\n")
|
||||
mock_log_warning.assert_any_call("Invalid line_count '0' provided, defaulting to fetch last 150 lines.")
|
||||
|
||||
mock_file_open.reset_mock() # Reset for next call
|
||||
result_neg = tool._get_log_contents(line_count=-5)
|
||||
self.assertEqual(result_neg, "line1\nline2\n")
|
||||
mock_log_warning.assert_any_call("Invalid line_count '-5' provided, defaulting to fetch last 150 lines.")
|
||||
|
||||
|
||||
@patch("os.path.exists", return_value=True)
|
||||
def test_get_log_contents_last_24_hours(self, mock_exists):
|
||||
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
|
||||
|
||||
now = datetime.now()
|
||||
one_hour_ago_dt = now - timedelta(hours=1)
|
||||
two_days_ago_dt = now - timedelta(days=2)
|
||||
|
||||
one_hour_ago_str = one_hour_ago_dt.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT)
|
||||
two_days_ago_str = two_days_ago_dt.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT)
|
||||
|
||||
log_data = (
|
||||
f"{two_days_ago_str} - OLD - This is an old log entry.\n"
|
||||
f"No timestamp here - this line should be skipped by time filter.\n"
|
||||
f"{one_hour_ago_str} - RECENT - This is a recent log entry.\n"
|
||||
f"Malformed Date 2023-xx-01 - Another skipped line.\n"
|
||||
f"{now.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT)} - CURRENT - This is a current log entry.\n"
|
||||
)
|
||||
|
||||
expected_output = (
|
||||
f"{one_hour_ago_str} - RECENT - This is a recent log entry.\n"
|
||||
f"{now.strftime(LogTool.EXPECTED_LOG_TIMESTAMP_FORMAT)} - CURRENT - This is a current log entry.\n"
|
||||
)
|
||||
|
||||
with patch("builtins.open", mock_open(read_data=log_data)):
|
||||
result = tool._get_log_contents(line_count=None) # Trigger 24-hour logic
|
||||
self.assertEqual(result, expected_output)
|
||||
|
||||
@patch("os.path.exists", return_value=True)
|
||||
@patch("builtins.open", side_effect=IOError("File read error!"))
|
||||
def test_get_log_contents_file_read_exception(self, mock_file_open, mock_exists):
|
||||
tool = LogTool(log_file_path=self.test_log_file_path, logger=self.logger)
|
||||
result = tool._get_log_contents(line_count=10)
|
||||
self.assertIn("An error occurred while reading the log file: File read error!", result)
|
||||
|
||||
def test_execute_get_log_contents(self):
|
||||
tool = LogTool(logger=self.logger)
|
||||
mock_return_value = "Mocked log content"
|
||||
with patch.object(tool, '_get_log_contents', return_value=mock_return_value) as mock_method:
|
||||
result = tool.execute(function_name="get_log_contents", line_count=50)
|
||||
mock_method.assert_called_once_with(line_count=50)
|
||||
self.assertEqual(result, mock_return_value)
|
||||
|
||||
def test_execute_get_log_contents_no_line_count(self):
|
||||
tool = LogTool(logger=self.logger)
|
||||
mock_return_value = "Mocked log content for 24h"
|
||||
with patch.object(tool, '_get_log_contents', return_value=mock_return_value) as mock_method:
|
||||
result = tool.execute(function_name="get_log_contents") # No line_count
|
||||
mock_method.assert_called_once_with(line_count=None) # Expects None to trigger 24h
|
||||
self.assertEqual(result, mock_return_value)
|
||||
|
||||
|
||||
def test_execute_unknown_function(self):
|
||||
tool = LogTool(logger=self.logger)
|
||||
result = tool.execute(function_name="non_existent_log_function")
|
||||
self.assertIn("Unknown function: non_existent_log_function", result)
|
||||
|
||||
def test_clear_method(self):
|
||||
tool = LogTool(logger=self.logger)
|
||||
# Set a specific level for the logger for this test if needed to capture debug
|
||||
original_level = tool.logger.level
|
||||
tool.logger.setLevel(logging.DEBUG)
|
||||
with self.assertLogs(tool.logger, level='DEBUG') as cm:
|
||||
tool.clear()
|
||||
self.assertTrue(any("LogTool clear called" in message for message in cm.output))
|
||||
tool.logger.setLevel(original_level) # Reset level
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,217 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock, ANY
|
||||
import time
|
||||
import logging
|
||||
|
||||
# Ensure tools.metrics is accessible
|
||||
from tools.metrics import Metrics # Import the class itself for direct testing
|
||||
from tools.metrics import metrics as global_metrics_instance # Import the global instance
|
||||
|
||||
# A simple function to decorate for testing
|
||||
def sample_function_for_metrics(duration=0.01):
|
||||
# Simulate some work
|
||||
# Note: time.sleep is not always precisely profiled by cProfile in the same way as CPU-bound work.
|
||||
# For testing, we will mock the cProfile/pstats interaction rather than relying on actual sleep duration.
|
||||
if duration > 0: # Make it conditional so we can test zero-time case too
|
||||
pass # The actual work is not important when mocking cProfile results
|
||||
return "sample_output"
|
||||
|
||||
def another_sample_function(x, y):
|
||||
return x + y
|
||||
|
||||
class TestMetrics(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Create a fresh Metrics instance for most tests to avoid interference
|
||||
self.logger = logging.getLogger('tools.metrics.test')
|
||||
if not self.logger.handlers: # Avoid adding handler multiple times
|
||||
self.logger.addHandler(logging.NullHandler())
|
||||
self.metrics_instance = Metrics(logger=self.logger)
|
||||
|
||||
# Clear the global instance before each test that might use it
|
||||
global_metrics_instance.clear_metrics()
|
||||
|
||||
def test_measure_decorator_counts_calls(self):
|
||||
decorated_func = self.metrics_instance.measure(sample_function_for_metrics)
|
||||
|
||||
self.assertEqual(self.metrics_instance.call_count["sample_function_for_metrics"], 0)
|
||||
decorated_func()
|
||||
self.assertEqual(self.metrics_instance.call_count["sample_function_for_metrics"], 1)
|
||||
decorated_func()
|
||||
self.assertEqual(self.metrics_instance.call_count["sample_function_for_metrics"], 2)
|
||||
|
||||
@patch('cProfile.Profile')
|
||||
@patch('pstats.Stats')
|
||||
def test_measure_decorator_records_time(self, MockPStats, MockCProfile):
|
||||
# Mock cProfile and pstats to control the time value
|
||||
mock_profiler_instance = MockCProfile.return_value
|
||||
mock_pstats_instance = MockPStats.return_value
|
||||
|
||||
# Simulate that pstats.Stats.stats dictionary contains the function's stats
|
||||
# Key: (filename, lineno, funcname)
|
||||
# Value: (cc, nc, tt, ct, callers) where ct is cumulative_time (index 3)
|
||||
|
||||
# Get code object of the function *before* decoration for correct key
|
||||
original_func_code = sample_function_for_metrics.__code__
|
||||
func_key = (original_func_code.co_filename, original_func_code.co_firstlineno, original_func_code.co_name)
|
||||
|
||||
# Configure mock_pstats_instance.stats to return our desired time
|
||||
mock_pstats_instance.stats = {func_key: (1, 1, 0.05, 0.123, {})} # cc, nc, tt, ct=0.123
|
||||
|
||||
decorated_func = self.metrics_instance.measure(sample_function_for_metrics)
|
||||
|
||||
self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0)
|
||||
|
||||
# Call the decorated function
|
||||
decorated_func(duration=0) # Duration arg doesn't matter due to mocking
|
||||
|
||||
# Assertions
|
||||
mock_profiler_instance.enable.assert_called_once()
|
||||
mock_profiler_instance.disable.assert_called_once()
|
||||
MockPStats.assert_called_once_with(mock_profiler_instance)
|
||||
|
||||
self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0.123)
|
||||
|
||||
# Call again to see accumulation
|
||||
# Reset mock stats for a new time value if needed, or assume same time per call
|
||||
mock_pstats_instance.stats = {func_key: (1, 1, 0.05, 0.100, {})} # New ct=0.100
|
||||
decorated_func(duration=0)
|
||||
self.assertAlmostEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0.123 + 0.100)
|
||||
|
||||
|
||||
@patch('cProfile.Profile')
|
||||
@patch('pstats.Stats')
|
||||
def test_measure_decorator_fallback_time_recording_by_name(self, MockPStats, MockCProfile):
|
||||
mock_profiler_instance = MockCProfile.return_value
|
||||
mock_pstats_instance = MockPStats.return_value
|
||||
|
||||
original_func_code = sample_function_for_metrics.__code__ # func to be decorated
|
||||
# Simulate the primary key lookup fails by creating a slightly different key for what we expect
|
||||
# This is what the code will try to look up first.
|
||||
expected_primary_key = (original_func_code.co_filename, original_func_code.co_firstlineno, original_func_code.co_name)
|
||||
|
||||
# This is the key that will *actually* be in pstats.stats, simulating a mismatch for primary lookup
|
||||
# but a match for the by-name fallback.
|
||||
actual_stats_key_in_pstats = (original_func_code.co_filename,
|
||||
original_func_code.co_firstlineno + 5, # simulate a lineno difference for primary key mismatch
|
||||
original_func_code.co_name) # Name is the same for fallback
|
||||
|
||||
mock_pstats_instance.stats = {
|
||||
# expected_primary_key is NOT present
|
||||
actual_stats_key_in_pstats: (1, 1, 0.03, 0.077, {}) # ct = 0.077
|
||||
}
|
||||
|
||||
decorated_func = self.metrics_instance.measure(sample_function_for_metrics)
|
||||
|
||||
# Expecting a debug log for fallback, but assertLogs needs the logger to have a handler that captures.
|
||||
# self.logger is already set up with NullHandler. For this test, let's use a specific logger.
|
||||
metrics_internal_logger = logging.getLogger('tools.metrics') # Logger used inside Metrics class
|
||||
original_level = metrics_internal_logger.level
|
||||
metrics_internal_logger.setLevel(logging.DEBUG)
|
||||
|
||||
with self.assertLogs(metrics_internal_logger, level='DEBUG') as log_capture:
|
||||
decorated_func(duration=0)
|
||||
|
||||
metrics_internal_logger.setLevel(original_level) # Reset logger level
|
||||
|
||||
self.assertTrue(any("Found stats for sample_function_for_metrics by name" in msg for msg in log_capture.output))
|
||||
self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0.077)
|
||||
|
||||
|
||||
@patch('cProfile.Profile')
|
||||
@patch('pstats.Stats')
|
||||
def test_measure_decorator_handles_func_stats_not_found(self, MockPStats, MockCProfile):
|
||||
mock_profiler_instance = MockCProfile.return_value
|
||||
mock_pstats_instance = MockPStats.return_value
|
||||
mock_pstats_instance.stats = {} # Empty stats, function will not be found
|
||||
|
||||
decorated_func = self.metrics_instance.measure(sample_function_for_metrics)
|
||||
|
||||
metrics_internal_logger = logging.getLogger('tools.metrics')
|
||||
original_level = metrics_internal_logger.level
|
||||
metrics_internal_logger.setLevel(logging.WARNING)
|
||||
with self.assertLogs(metrics_internal_logger, level='WARNING') as log_capture:
|
||||
decorated_func(duration=0)
|
||||
metrics_internal_logger.setLevel(original_level)
|
||||
|
||||
self.assertTrue(any("Could not find exact cProfile stats" in msg for msg in log_capture.output))
|
||||
self.assertEqual(self.metrics_instance.total_time["sample_function_for_metrics"], 0)
|
||||
|
||||
|
||||
def test_get_metrics_empty(self):
|
||||
self.assertEqual(self.metrics_instance.get_metrics(), {})
|
||||
|
||||
@patch('cProfile.Profile')
|
||||
@patch('pstats.Stats')
|
||||
def test_get_metrics_with_data(self, MockPStats, MockCProfile):
|
||||
mock_pstats_instance = MockPStats.return_value
|
||||
|
||||
# Decorate two different functions
|
||||
decorated_func1 = self.metrics_instance.measure(sample_function_for_metrics)
|
||||
decorated_func2 = self.metrics_instance.measure(another_sample_function)
|
||||
|
||||
# Data for func1
|
||||
func1_code = sample_function_for_metrics.__code__
|
||||
func1_key = (func1_code.co_filename, func1_code.co_firstlineno, func1_code.co_name)
|
||||
mock_pstats_instance.stats = {func1_key: (1,1,0.1,0.1,{})}
|
||||
decorated_func1()
|
||||
|
||||
# Data for func2
|
||||
func2_code = another_sample_function.__code__
|
||||
func2_key = (func2_code.co_filename, func2_code.co_firstlineno, func2_code.co_name)
|
||||
mock_pstats_instance.stats = {func2_key: (1,1,0.2,0.2,{})} # Cumulative time 0.2
|
||||
decorated_func2(1,2)
|
||||
mock_pstats_instance.stats = {func2_key: (1,1,0.3,0.3,{})} # Cumulative time 0.3 for second call
|
||||
decorated_func2(3,4)
|
||||
|
||||
metrics_data = self.metrics_instance.get_metrics()
|
||||
|
||||
self.assertIn("sample_function_for_metrics", metrics_data)
|
||||
self.assertEqual(metrics_data["sample_function_for_metrics"]["call_count"], 1)
|
||||
self.assertEqual(metrics_data["sample_function_for_metrics"]["total_time"], 0.1)
|
||||
self.assertEqual(metrics_data["sample_function_for_metrics"]["average_time"], 0.1)
|
||||
|
||||
self.assertIn("another_sample_function", metrics_data)
|
||||
self.assertEqual(metrics_data["another_sample_function"]["call_count"], 2)
|
||||
self.assertAlmostEqual(metrics_data["another_sample_function"]["total_time"], 0.5)
|
||||
self.assertAlmostEqual(metrics_data["another_sample_function"]["average_time"], 0.25)
|
||||
|
||||
|
||||
def test_clear_metrics(self):
|
||||
# Add some data
|
||||
self.metrics_instance.call_count["test_func"] = 5
|
||||
self.metrics_instance.total_time["test_func"] = 1.234
|
||||
|
||||
self.metrics_instance.clear_metrics()
|
||||
|
||||
self.assertEqual(self.metrics_instance.call_count, {})
|
||||
self.assertEqual(self.metrics_instance.total_time, {})
|
||||
self.assertEqual(self.metrics_instance.get_metrics(), {})
|
||||
|
||||
# Test the global instance
|
||||
@patch('cProfile.Profile')
|
||||
@patch('pstats.Stats')
|
||||
def test_global_metrics_instance_usage(self, MockPStats, MockCProfile):
|
||||
mock_pstats_instance = MockPStats.return_value
|
||||
|
||||
# Decorate a function with the global instance
|
||||
@global_metrics_instance.measure
|
||||
def globally_decorated_func():
|
||||
return "global_output"
|
||||
|
||||
# Setup mock stats for the globally decorated function
|
||||
# Access __wrapped__ to get the original function if other decorators might be present or for consistency.
|
||||
original_g_func = globally_decorated_func.__wrapped__
|
||||
func_code = original_g_func.__code__
|
||||
func_key = (func_code.co_filename, func_code.co_firstlineno, func_code.co_name)
|
||||
mock_pstats_instance.stats = {func_key: (1,1,0.05,0.05,{})}
|
||||
|
||||
globally_decorated_func()
|
||||
|
||||
metrics_data = global_metrics_instance.get_metrics()
|
||||
self.assertIn("globally_decorated_func", metrics_data)
|
||||
self.assertEqual(metrics_data["globally_decorated_func"]["call_count"], 1)
|
||||
self.assertEqual(metrics_data["globally_decorated_func"]["total_time"], 0.05)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,161 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
import logging
|
||||
|
||||
# Ensure tools.metrics_tool and tools.metrics are accessible
|
||||
from tools.metrics_tool import MetricsTool
|
||||
from tools.metrics import Metrics # Used for typehinting and creating a mockable instance
|
||||
|
||||
class TestMetricsTool(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.mock_metrics_provider = MagicMock(spec=Metrics)
|
||||
self.logger = logging.getLogger('tools.metrics_tool.test')
|
||||
if not self.logger.handlers:
|
||||
self.logger.addHandler(logging.NullHandler())
|
||||
self.logger.propagate = False
|
||||
|
||||
|
||||
def test_init_with_provider(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
self.assertEqual(tool.metrics_provider, self.mock_metrics_provider)
|
||||
|
||||
@patch('tools.metrics_tool.global_metrics_instance') # Patch the global instance path
|
||||
def test_init_default_provider(self, mock_global_metrics):
|
||||
tool = MetricsTool(logger=self.logger)
|
||||
self.assertEqual(tool.metrics_provider, mock_global_metrics)
|
||||
|
||||
def test_get_functions(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
functions = tool.get_functions()
|
||||
self.assertIsInstance(functions, list)
|
||||
self.assertTrue(len(functions) == 3) # Based on current definition
|
||||
self.assertIn("get_function_metrics", [f["function"]["name"] for f in functions])
|
||||
self.assertIn("get_specific_function_metrics", [f["function"]["name"] for f in functions])
|
||||
self.assertIn("get_top_n_functions", [f["function"]["name"] for f in functions])
|
||||
|
||||
def test_execute_get_function_metrics(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
expected_metrics = {"func1": {"call_count": 1, "total_time": 0.1}}
|
||||
self.mock_metrics_provider.get_metrics.return_value = expected_metrics
|
||||
|
||||
result = tool.execute(function_name="get_function_metrics")
|
||||
|
||||
self.mock_metrics_provider.get_metrics.assert_called_once()
|
||||
self.assertEqual(result, expected_metrics)
|
||||
|
||||
def test_execute_get_specific_function_metrics_found(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
func_metrics = {"call_count": 5, "total_time": 0.5, "average_time": 0.1}
|
||||
all_metrics = {"specific_func": func_metrics, "other_func": {}}
|
||||
self.mock_metrics_provider.get_metrics.return_value = all_metrics
|
||||
|
||||
# The execute method expects kwargs that match the function parameters in get_functions.
|
||||
# So, the argument name for the function to get is 'function_name' in the tool's spec.
|
||||
result = tool.execute(function_name="get_specific_function_metrics", **{"function_name": "specific_func"})
|
||||
self.assertEqual(result, func_metrics)
|
||||
|
||||
def test_execute_get_specific_function_metrics_not_found(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
self.mock_metrics_provider.get_metrics.return_value = {"other_func": {}}
|
||||
|
||||
result = tool.execute(function_name="get_specific_function_metrics", **{"function_name": "non_existent_func"})
|
||||
self.assertEqual(result, "No metrics found for function: non_existent_func")
|
||||
|
||||
def test_execute_get_specific_function_metrics_missing_arg(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
result = tool.execute(function_name="get_specific_function_metrics") # Missing function_name kwarg
|
||||
self.assertIn("Error: Missing required argument 'function_name'", result)
|
||||
|
||||
|
||||
def test_execute_get_top_n_functions(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
metrics_data = {
|
||||
"func_a": {"call_count": 1, "total_time": 0.3},
|
||||
"func_b": {"call_count": 1, "total_time": 0.1},
|
||||
"func_c": {"call_count": 1, "total_time": 0.5},
|
||||
"func_d": {"call_count": 1, "total_time": 0.2},
|
||||
}
|
||||
self.mock_metrics_provider.get_metrics.return_value = metrics_data
|
||||
|
||||
# Test getting top 2
|
||||
result = tool.execute(function_name="get_top_n_functions", n=2)
|
||||
expected_top_2 = {"func_c": metrics_data["func_c"], "func_a": metrics_data["func_a"]}
|
||||
self.assertEqual(result, expected_top_2)
|
||||
|
||||
# Test getting top 1
|
||||
result_top_1 = tool.execute(function_name="get_top_n_functions", n=1)
|
||||
expected_top_1 = {"func_c": metrics_data["func_c"]}
|
||||
self.assertEqual(result_top_1, expected_top_1)
|
||||
|
||||
# Test N larger than available functions
|
||||
result_top_all = tool.execute(function_name="get_top_n_functions", n=10)
|
||||
# Order should be func_c, func_a, func_d, func_b
|
||||
expected_top_all_keys = ["func_c", "func_a", "func_d", "func_b"]
|
||||
self.assertEqual(list(result_top_all.keys()), expected_top_all_keys)
|
||||
|
||||
def test_execute_get_top_n_functions_malformed_metrics(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
metrics_data = {
|
||||
"func_a": {"call_count": 1, "total_time": 0.3},
|
||||
"func_b": "not a dict", # Malformed
|
||||
"func_c": {"call_count": 1}, # Missing total_time
|
||||
"func_d": {"call_count": 1, "total_time": 0.2},
|
||||
}
|
||||
self.mock_metrics_provider.get_metrics.return_value = metrics_data
|
||||
|
||||
metrics_tool_logger = logging.getLogger('tools.metrics_tool')
|
||||
original_level = metrics_tool_logger.level
|
||||
metrics_tool_logger.setLevel(logging.WARNING)
|
||||
with self.assertLogs(metrics_tool_logger, level='WARNING') as log_capture:
|
||||
result = tool.execute(function_name="get_top_n_functions", n=2)
|
||||
metrics_tool_logger.setLevel(original_level)
|
||||
|
||||
# Check that warnings were logged for malformed items
|
||||
self.assertTrue(any("Metric item for 'func_b' is not in expected format" in msg for msg in log_capture.output))
|
||||
self.assertTrue(any("Metric item for 'func_c' is not in expected format" in msg for msg in log_capture.output))
|
||||
|
||||
# Expected: func_a, func_d (as they are valid and sortable)
|
||||
expected_result = {
|
||||
"func_a": metrics_data["func_a"],
|
||||
"func_d": metrics_data["func_d"]
|
||||
}
|
||||
self.assertEqual(result, expected_result)
|
||||
|
||||
|
||||
def test_execute_get_top_n_functions_invalid_n(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
self.mock_metrics_provider.get_metrics.return_value = {} # No metrics needed for this test
|
||||
|
||||
result_zero = tool.execute(function_name="get_top_n_functions", n=0)
|
||||
self.assertIn("Error: Argument 'n' must be a positive integer.", result_zero)
|
||||
|
||||
result_negative = tool.execute(function_name="get_top_n_functions", n=-1)
|
||||
self.assertIn("Error: Argument 'n' must be a positive integer.", result_negative)
|
||||
|
||||
result_string = tool.execute(function_name="get_top_n_functions", n="abc")
|
||||
self.assertIn("Error: Argument 'n' must be an integer.", result_string)
|
||||
|
||||
def test_execute_get_top_n_functions_missing_arg_n(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
result = tool.execute(function_name="get_top_n_functions") # Missing n
|
||||
self.assertIn("Error: Missing required argument 'n'.", result)
|
||||
|
||||
|
||||
def test_execute_unknown_function(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
result = tool.execute(function_name="non_existent_metrics_function")
|
||||
self.assertIn("Unknown function: non_existent_metrics_function", result)
|
||||
|
||||
def test_clear_method(self):
|
||||
tool = MetricsTool(metrics_provider=self.mock_metrics_provider, logger=self.logger)
|
||||
metrics_tool_logger = logging.getLogger('tools.metrics_tool')
|
||||
original_level = metrics_tool_logger.level
|
||||
metrics_tool_logger.setLevel(logging.DEBUG)
|
||||
with self.assertLogs(metrics_tool_logger, level='DEBUG') as cm:
|
||||
tool.clear()
|
||||
metrics_tool_logger.setLevel(original_level)
|
||||
self.assertTrue(any("MetricsTool clear method called" in message for message in cm.output))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user