Refactor: Inject dependencies in AnthropicTelegramInferenceBot

This commit is contained in:
cyclop-bot
2025-06-02 16:41:53 -05:00
parent 0c1a0d1e5b
commit e203afa493
+111 -64
View File
@@ -3,25 +3,48 @@ import json
import logging import logging
from anthropic import Anthropic, APIError, RateLimitError from anthropic import Anthropic, APIError, RateLimitError
from base_telegram_inference_bot import BaseTelegramInferenceBot from base_telegram_inference_bot import BaseTelegramInferenceBot
from telegram_helper import TelegramHelper from telegram_helper import TelegramHelper # Used in main, not class
class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot): class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
def __init__(self): DEFAULT_SMALL_MODEL_NAME = "claude-3-haiku-20240307"
super().__init__() DEFAULT_SMALL_MODEL_MAX_TOKENS = "2048"
self.anthropic_client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")) DEFAULT_LARGE_MODEL_NAME = "claude-3-opus-20240229"
DEFAULT_LARGE_MODEL_MAX_TOKENS = "4096"
# Initialize with the small model by default
self.small_model_name = os.environ.get("ANTHROPIC_SMALL_MODEL", "claude-3-haiku-20240307")
self.small_model_max_tokens = os.environ.get("ANTHROPIC_SMALL_MODEL_MAX_TOKENS", "2048")
self.large_model_name = os.environ.get("ANTHROPIC_LARGE_MODEL", "claude-3-opus-20240229")
self.large_model_max_tokens = os.environ.get("ANTHROPIC_LARGE_MODEL_MAX_TOKENS", "4096")
def __init__(
self,
anthropic_client: Anthropic | None = None,
api_key: str | None = None,
small_model_name: str | None = None,
small_model_max_tokens: str | None = None,
large_model_name: str | None = None,
large_model_max_tokens: str | None = None,
system_prompt_content: str | None = None,
system_prompt_path: str | None = None
):
super().__init__(system_prompt_content=system_prompt_content, system_prompt_path=system_prompt_path)
if anthropic_client:
self.anthropic_client = anthropic_client
else:
_api_key = api_key or os.environ.get("ANTHROPIC_API_KEY")
if not _api_key:
raise ValueError("Anthropic API key must be provided either via argument or ANTHROPIC_API_KEY environment variable.")
self.anthropic_client = Anthropic(api_key=_api_key)
self.small_model_name = small_model_name or os.environ.get("ANTHROPIC_SMALL_MODEL") or self.DEFAULT_SMALL_MODEL_NAME
self.small_model_max_tokens_str = small_model_max_tokens or os.environ.get("ANTHROPIC_SMALL_MODEL_MAX_TOKENS") or self.DEFAULT_SMALL_MODEL_MAX_TOKENS
self.large_model_name = large_model_name or os.environ.get("ANTHROPIC_LARGE_MODEL") or self.DEFAULT_LARGE_MODEL_NAME
self.large_model_max_tokens_str = large_model_max_tokens or os.environ.get("ANTHROPIC_LARGE_MODEL_MAX_TOKENS") or self.DEFAULT_LARGE_MODEL_MAX_TOKENS
# Initialize with the small model by default
self._configure_model_and_tokens( self._configure_model_and_tokens(
self.small_model_name, self.small_model_name,
self.small_model_max_tokens self.small_model_max_tokens_str,
default_max_tokens=int(self.DEFAULT_SMALL_MODEL_MAX_TOKENS) # pass int for default
) )
def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=2048): # Default max_tokens adjusted for typical "small" def _configure_model_and_tokens(self, model_name: str, max_tokens_str: str, default_max_tokens: int = 2048):
self.model = model_name self.model = model_name
try: try:
self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens
@@ -65,17 +88,19 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
def _format_tool_response_for_anthropic(self, tool_response_data): def _format_tool_response_for_anthropic(self, tool_response_data):
if isinstance(tool_response_data, str): if isinstance(tool_response_data, str):
# Wrap plain string in a list of text blocks if not already structured
return [{"type": "text", "text": tool_response_data}] return [{"type": "text", "text": tool_response_data}]
elif isinstance(tool_response_data, list) and all(isinstance(item, dict) and "type" in item for item in tool_response_data):
# Already a list of content blocks
return tool_response_data
elif isinstance(tool_response_data, (dict, list)): elif isinstance(tool_response_data, (dict, list)):
# Attempt to JSON dump other dicts/lists if not already in content block format
try: try:
is_valid_block_list = isinstance(tool_response_data, list) and all(isinstance(item, dict) and "type" in item for item in tool_response_data) return [{"type": "text", "text": json.dumps(tool_response_data)}]
if is_valid_block_list:
return tool_response_data
else:
return [{"type": "text", "text": json.dumps(tool_response_data)}]
except (TypeError, json.JSONDecodeError): except (TypeError, json.JSONDecodeError):
return [{"type": "text", "text": str(tool_response_data)}] return [{"type": "text", "text": str(tool_response_data)}] # Fallback to string
else: else:
# Fallback for other types (int, float, etc.)
return [{"type": "text", "text": str(tool_response_data)}] return [{"type": "text", "text": str(tool_response_data)}]
async def handle_message(self, user_id, user_message): async def handle_message(self, user_id, user_message):
@@ -87,14 +112,14 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
MAX_TOOL_ITERATIONS = 5 MAX_TOOL_ITERATIONS = 5
tool_use_count = 0 tool_use_count = 0
assistant_response_content = "" assistant_response_content = ""
while tool_use_count < MAX_TOOL_ITERATIONS: while tool_use_count < MAX_TOOL_ITERATIONS:
response = self.get_chat_response(current_turn_messages) response = self.get_chat_response(current_turn_messages)
if not response or not response.content: if not response or not response.content:
logging.error("No valid response content from Anthropic LLM.") logging.error("No valid response content from Anthropic LLM.")
self.conversation_history[user_id] = current_turn_messages self.conversation_history[user_id] = current_turn_messages # Save current state
return "Error: Could not get a valid response from the LLM." return "Error: Could not get a valid response from the LLM."
assistant_current_turn_content_blocks = response.content assistant_current_turn_content_blocks = response.content
@@ -111,23 +136,22 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
assistant_response_content = "".join(text_parts_from_assistant) assistant_response_content = "".join(text_parts_from_assistant)
if not tool_calls_from_response: if not tool_calls_from_response:
break break
tool_results_for_model = [] tool_results_for_model = []
for tool_call in tool_calls_from_response: for tool_call in tool_calls_from_response:
tool_name = tool_call.name tool_name = tool_call.name
tool_input = tool_call.input tool_input = tool_call.input
tool_use_id = tool_call.id tool_use_id = tool_call.id
logging.info(f"Attempting to call Anthropic tool: {tool_name} with input: {tool_input}") logging.info(f"Attempting to call Anthropic tool: {tool_name} with input: {tool_input}")
try: try:
tool_response_data = self.call_tool(tool_name, tool_input) tool_response_data = self.call_tool(tool_name, tool_input)
tool_result_content_block = self._format_tool_response_for_anthropic(tool_response_data) tool_result_content_block = self._format_tool_response_for_anthropic(tool_response_data)
tool_results_for_model.append({ tool_results_for_model.append({
"type": "tool_result", "type": "tool_result",
"tool_use_id": tool_use_id, "tool_use_id": tool_use_id,
"content": tool_result_content_block "content": tool_result_content_block
}) })
except Exception as e: except Exception as e:
logging.error(f"Error calling tool {tool_name}: {e}") logging.error(f"Error calling tool {tool_name}: {e}")
@@ -135,14 +159,18 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
"type": "tool_result", "type": "tool_result",
"tool_use_id": tool_use_id, "tool_use_id": tool_use_id,
"content": [{"type": "text", "text": f"Error executing tool {tool_name}: {str(e)}"}], "content": [{"type": "text", "text": f"Error executing tool {tool_name}: {str(e)}"}],
"is_error": True "is_error": True
}) })
current_turn_messages.append({"role": "user", "content": tool_results_for_model}) current_turn_messages.append({"role": "user", "content": tool_results_for_model}) # Anthropic expects tool results as a user message
tool_use_count += 1 tool_use_count += 1
if tool_use_count >= MAX_TOOL_ITERATIONS: if tool_use_count >= MAX_TOOL_ITERATIONS:
logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached for Anthropic.") logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached for Anthropic.")
# Update assistant_response_content with any text from the last assistant turn before breaking
if not assistant_response_content and text_parts_from_assistant:
assistant_response_content = "".join(text_parts_from_assistant)
assistant_response_content += "\n[Max tool iterations reached]"
break break
self.conversation_history[user_id] = current_turn_messages self.conversation_history[user_id] = current_turn_messages
@@ -153,70 +181,89 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
if assistant_response_content: if assistant_response_content:
return assistant_response_content return assistant_response_content
else: else:
# Fallback if no text parts were found but there was an assistant message
if current_turn_messages: if current_turn_messages:
last_message_in_turn = current_turn_messages[-1] last_message_in_turn = current_turn_messages[-1]
# Check if the last message content has text blocks (Anthropic specific structure)
if last_message_in_turn.get("role") == "assistant" and isinstance(last_message_in_turn.get("content"), list): if last_message_in_turn.get("role") == "assistant" and isinstance(last_message_in_turn.get("content"), list):
for block in reversed(last_message_in_turn["content"]): for block in reversed(last_message_in_turn["content"]):
if block.type == "text": if block.type == "text" and hasattr(block, 'text') and block.text:
return block.text return block.text # Return the first non-empty text found from the end
return "No textual response from assistant." return "No textual response generated by the assistant after processing." # More informative default
async def start(self): async def start(self):
logging.info("Anthropic Bot started") logging.info("Anthropic Bot started")
async def clear_conversation_history(self, user_id): # clear_conversation_history is inherited from BaseTelegramInferenceBot and calls super().clear_conversation_history
super().clear_conversation_history(user_id) # No need to override if the base implementation is sufficient, unless specific logging is needed.
logging.info(f"Cleared conversation history for Anthropic bot, user {user_id}") # async def clear_conversation_history(self, user_id):
# super().clear_conversation_history(user_id)
# logging.info(f"Cleared conversation history for Anthropic bot, user {user_id}")
async def abort_processing(self, user_id): async def abort_processing(self, user_id):
# This abort is a soft abort, as actual Anthropic API call is synchronous within handle_message
# It primarily clears state and prevents further processing in the bot's loop if any.
if user_id in self.processing_status: if user_id in self.processing_status:
self.processing_status[user_id]["processing"] = False self.processing_status[user_id]["processing"] = False # Mark as not processing
await self.clear_conversation_history(user_id) # self.clear_processing_status(user_id) # Use base class method to remove entry
return "Processing aborted and conversation cleared." # Clearing history might be too aggressive for a simple abort, depends on desired UX
else: # For now, let's just stop processing and clear the flag.
await self.clear_conversation_history(user_id) # Consider if conversation history should be cleared here or if that is a separate user action.
return "No active processing found to abort. Conversation cleared." # super().clear_conversation_history(user_id) # Moved to be less aggressive
logging.info(f"Abort requested for user {user_id}. Processing flag cleared.")
return "Processing aborted. You can send a new message or /clear the conversation."
async def switch_model(self): async def switch_model(self):
# Ensure ANTHROPIC_SMALL_MODEL and ANTHROPIC_LARGE_MODEL related env vars are loaded in __init__
# or ensure they are freshly checked here if they can change during runtime (less common for model names).
# For this implementation, we rely on the values stored during __init__.
if not self.small_model_name or not self.large_model_name: if not self.small_model_name or not self.large_model_name:
logging.warning("Small or Large model names for Anthropic are not defined. Cannot switch model.") logging.warning("Small or Large model names for Anthropic are not defined. Cannot switch model.")
return f"Model switching not fully configured. Currently using {self.model}." return f"Model switching not fully configured. Currently using {self.model}."
if self.model == self.small_model_name: current_is_small = self.model == self.small_model_name
current_is_large = self.model == self.large_model_name
if current_is_small:
target_model = self.large_model_name target_model = self.large_model_name
target_max_tokens = self.large_model_max_tokens target_max_tokens_str = self.large_model_max_tokens_str
# Use default large max_tokens if specific one isn't set or invalid default_target_max_tokens = int(self.DEFAULT_LARGE_MODEL_MAX_TOKENS)
default_max_tokens_for_large = "4096" elif current_is_large:
elif self.model == self.large_model_name:
target_model = self.small_model_name target_model = self.small_model_name
target_max_tokens = self.small_model_max_tokens target_max_tokens_str = self.small_model_max_tokens_str
# Use default small max_tokens if specific one isn't set or invalid default_target_max_tokens = int(self.DEFAULT_SMALL_MODEL_MAX_TOKENS)
default_max_tokens_for_large = "2048"
else: else:
# Current model is neither the designated small nor large, switch to small as a reset logging.warning(f"Current model {self.model} is unrecognized. Switching to default small model.")
logging.warning(f"Current model {self.model} is neither the configured small nor large model. Switching to small model.")
target_model = self.small_model_name target_model = self.small_model_name
target_max_tokens = self.small_model_max_tokens target_max_tokens_str = self.small_model_max_tokens_str
default_max_tokens_for_large = "2048" default_target_max_tokens = int(self.DEFAULT_SMALL_MODEL_MAX_TOKENS)
self._configure_model_and_tokens(target_model, target_max_tokens_str, default_max_tokens=default_target_max_tokens)
self._configure_model_and_tokens(target_model, target_max_tokens, default_max_tokens=int(default_max_tokens_for_large)) # Pass appropriate default
logging.info(f"Switched Anthropic model to: {self.model}") logging.info(f"Switched Anthropic model to: {self.model}")
return f"Switched to Anthropic model: {self.model} (Max Tokens: {self.max_tokens})"#Provide token info return f"Switched to Anthropic model: {self.model} (Max Tokens: {self.max_tokens})"
# The main function is for standalone execution and basic testing, not part of the class itself.
# It's good practice to update it to reflect changes if you use it for quick tests.
# For unit tests, we'll instantiate the class with mocked dependencies.
def main(): def main():
if not os.environ.get("ANTHROPIC_API_KEY"):
logging.error("FATAL: ANTHROPIC_API_KEY environment variable not set.")
return
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
bot = AnthropicTelegramInferenceBot() # Example of how to instantiate with new constructor (assuming API key is in ENV for this example)
# For real tests, you'd mock Anthropic() or pass a mock client.
try:
# These would typically come from a config file or CLI args in a real app if not ENV
# For this example, we rely on ENV or defaults being handled by constructor if not provided.
bot = AnthropicTelegramInferenceBot(
api_key=os.environ.get("ANTHROPIC_API_KEY") # Explicitly pass, or let constructor handle ENV
)
except ValueError as e:
logging.error(f"Failed to initialize bot: {e}")
return
except Exception as e: # Catch any other init errors
logging.error(f"An unexpected error occurred during bot initialization: {e}")
return
# TelegramHelper also updated, ensure it's instantiated correctly for this main context.
# For this basic main, we might not pass all configurable paths to TelegramHelper,
# letting them use defaults.
telegram_helper = TelegramHelper(bot) telegram_helper = TelegramHelper(bot)
telegram_helper.run() telegram_helper.run()