Refactor: Inject dependencies in AnthropicTelegramInferenceBot
This commit is contained in:
@@ -3,25 +3,48 @@ import json
|
||||
import logging
|
||||
from anthropic import Anthropic, APIError, RateLimitError
|
||||
from base_telegram_inference_bot import BaseTelegramInferenceBot
|
||||
from telegram_helper import TelegramHelper
|
||||
from telegram_helper import TelegramHelper # Used in main, not class
|
||||
|
||||
class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.anthropic_client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
|
||||
DEFAULT_SMALL_MODEL_NAME = "claude-3-haiku-20240307"
|
||||
DEFAULT_SMALL_MODEL_MAX_TOKENS = "2048"
|
||||
DEFAULT_LARGE_MODEL_NAME = "claude-3-opus-20240229"
|
||||
DEFAULT_LARGE_MODEL_MAX_TOKENS = "4096"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
anthropic_client: Anthropic | None = None,
|
||||
api_key: str | None = None,
|
||||
small_model_name: str | None = None,
|
||||
small_model_max_tokens: str | None = None,
|
||||
large_model_name: str | None = None,
|
||||
large_model_max_tokens: str | None = None,
|
||||
system_prompt_content: str | None = None,
|
||||
system_prompt_path: str | None = None
|
||||
):
|
||||
super().__init__(system_prompt_content=system_prompt_content, system_prompt_path=system_prompt_path)
|
||||
|
||||
if anthropic_client:
|
||||
self.anthropic_client = anthropic_client
|
||||
else:
|
||||
_api_key = api_key or os.environ.get("ANTHROPIC_API_KEY")
|
||||
if not _api_key:
|
||||
raise ValueError("Anthropic API key must be provided either via argument or ANTHROPIC_API_KEY environment variable.")
|
||||
self.anthropic_client = Anthropic(api_key=_api_key)
|
||||
|
||||
self.small_model_name = small_model_name or os.environ.get("ANTHROPIC_SMALL_MODEL") or self.DEFAULT_SMALL_MODEL_NAME
|
||||
self.small_model_max_tokens_str = small_model_max_tokens or os.environ.get("ANTHROPIC_SMALL_MODEL_MAX_TOKENS") or self.DEFAULT_SMALL_MODEL_MAX_TOKENS
|
||||
self.large_model_name = large_model_name or os.environ.get("ANTHROPIC_LARGE_MODEL") or self.DEFAULT_LARGE_MODEL_NAME
|
||||
self.large_model_max_tokens_str = large_model_max_tokens or os.environ.get("ANTHROPIC_LARGE_MODEL_MAX_TOKENS") or self.DEFAULT_LARGE_MODEL_MAX_TOKENS
|
||||
|
||||
# Initialize with the small model by default
|
||||
self.small_model_name = os.environ.get("ANTHROPIC_SMALL_MODEL", "claude-3-haiku-20240307")
|
||||
self.small_model_max_tokens = os.environ.get("ANTHROPIC_SMALL_MODEL_MAX_TOKENS", "2048")
|
||||
self.large_model_name = os.environ.get("ANTHROPIC_LARGE_MODEL", "claude-3-opus-20240229")
|
||||
self.large_model_max_tokens = os.environ.get("ANTHROPIC_LARGE_MODEL_MAX_TOKENS", "4096")
|
||||
|
||||
self._configure_model_and_tokens(
|
||||
self.small_model_name,
|
||||
self.small_model_max_tokens
|
||||
self.small_model_max_tokens_str,
|
||||
default_max_tokens=int(self.DEFAULT_SMALL_MODEL_MAX_TOKENS) # pass int for default
|
||||
)
|
||||
|
||||
def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=2048): # Default max_tokens adjusted for typical "small"
|
||||
def _configure_model_and_tokens(self, model_name: str, max_tokens_str: str, default_max_tokens: int = 2048):
|
||||
self.model = model_name
|
||||
try:
|
||||
self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens
|
||||
@@ -65,17 +88,19 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
|
||||
|
||||
def _format_tool_response_for_anthropic(self, tool_response_data):
|
||||
if isinstance(tool_response_data, str):
|
||||
# Wrap plain string in a list of text blocks if not already structured
|
||||
return [{"type": "text", "text": tool_response_data}]
|
||||
elif isinstance(tool_response_data, (dict, list)):
|
||||
try:
|
||||
is_valid_block_list = isinstance(tool_response_data, list) and all(isinstance(item, dict) and "type" in item for item in tool_response_data)
|
||||
if is_valid_block_list:
|
||||
elif isinstance(tool_response_data, list) and all(isinstance(item, dict) and "type" in item for item in tool_response_data):
|
||||
# Already a list of content blocks
|
||||
return tool_response_data
|
||||
else:
|
||||
elif isinstance(tool_response_data, (dict, list)):
|
||||
# Attempt to JSON dump other dicts/lists if not already in content block format
|
||||
try:
|
||||
return [{"type": "text", "text": json.dumps(tool_response_data)}]
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
return [{"type": "text", "text": str(tool_response_data)}]
|
||||
return [{"type": "text", "text": str(tool_response_data)}] # Fallback to string
|
||||
else:
|
||||
# Fallback for other types (int, float, etc.)
|
||||
return [{"type": "text", "text": str(tool_response_data)}]
|
||||
|
||||
async def handle_message(self, user_id, user_message):
|
||||
@@ -94,7 +119,7 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
|
||||
|
||||
if not response or not response.content:
|
||||
logging.error("No valid response content from Anthropic LLM.")
|
||||
self.conversation_history[user_id] = current_turn_messages
|
||||
self.conversation_history[user_id] = current_turn_messages # Save current state
|
||||
return "Error: Could not get a valid response from the LLM."
|
||||
|
||||
assistant_current_turn_content_blocks = response.content
|
||||
@@ -123,7 +148,6 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
|
||||
try:
|
||||
tool_response_data = self.call_tool(tool_name, tool_input)
|
||||
tool_result_content_block = self._format_tool_response_for_anthropic(tool_response_data)
|
||||
|
||||
tool_results_for_model.append({
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_use_id,
|
||||
@@ -138,11 +162,15 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
|
||||
"is_error": True
|
||||
})
|
||||
|
||||
current_turn_messages.append({"role": "user", "content": tool_results_for_model})
|
||||
current_turn_messages.append({"role": "user", "content": tool_results_for_model}) # Anthropic expects tool results as a user message
|
||||
|
||||
tool_use_count += 1
|
||||
if tool_use_count >= MAX_TOOL_ITERATIONS:
|
||||
logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached for Anthropic.")
|
||||
# Update assistant_response_content with any text from the last assistant turn before breaking
|
||||
if not assistant_response_content and text_parts_from_assistant:
|
||||
assistant_response_content = "".join(text_parts_from_assistant)
|
||||
assistant_response_content += "\n[Max tool iterations reached]"
|
||||
break
|
||||
|
||||
self.conversation_history[user_id] = current_turn_messages
|
||||
@@ -153,70 +181,89 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
|
||||
if assistant_response_content:
|
||||
return assistant_response_content
|
||||
else:
|
||||
# Fallback if no text parts were found but there was an assistant message
|
||||
if current_turn_messages:
|
||||
last_message_in_turn = current_turn_messages[-1]
|
||||
# Check if the last message content has text blocks (Anthropic specific structure)
|
||||
if last_message_in_turn.get("role") == "assistant" and isinstance(last_message_in_turn.get("content"), list):
|
||||
for block in reversed(last_message_in_turn["content"]):
|
||||
if block.type == "text":
|
||||
return block.text
|
||||
return "No textual response from assistant."
|
||||
|
||||
if block.type == "text" and hasattr(block, 'text') and block.text:
|
||||
return block.text # Return the first non-empty text found from the end
|
||||
return "No textual response generated by the assistant after processing." # More informative default
|
||||
|
||||
async def start(self):
|
||||
logging.info("Anthropic Bot started")
|
||||
|
||||
async def clear_conversation_history(self, user_id):
|
||||
super().clear_conversation_history(user_id)
|
||||
logging.info(f"Cleared conversation history for Anthropic bot, user {user_id}")
|
||||
# clear_conversation_history is inherited from BaseTelegramInferenceBot and calls super().clear_conversation_history
|
||||
# No need to override if the base implementation is sufficient, unless specific logging is needed.
|
||||
# async def clear_conversation_history(self, user_id):
|
||||
# super().clear_conversation_history(user_id)
|
||||
# logging.info(f"Cleared conversation history for Anthropic bot, user {user_id}")
|
||||
|
||||
async def abort_processing(self, user_id):
|
||||
# This abort is a soft abort, as actual Anthropic API call is synchronous within handle_message
|
||||
# It primarily clears state and prevents further processing in the bot's loop if any.
|
||||
if user_id in self.processing_status:
|
||||
self.processing_status[user_id]["processing"] = False
|
||||
await self.clear_conversation_history(user_id)
|
||||
return "Processing aborted and conversation cleared."
|
||||
else:
|
||||
await self.clear_conversation_history(user_id)
|
||||
return "No active processing found to abort. Conversation cleared."
|
||||
self.processing_status[user_id]["processing"] = False # Mark as not processing
|
||||
# self.clear_processing_status(user_id) # Use base class method to remove entry
|
||||
# Clearing history might be too aggressive for a simple abort, depends on desired UX
|
||||
# For now, let's just stop processing and clear the flag.
|
||||
# Consider if conversation history should be cleared here or if that is a separate user action.
|
||||
# super().clear_conversation_history(user_id) # Moved to be less aggressive
|
||||
logging.info(f"Abort requested for user {user_id}. Processing flag cleared.")
|
||||
return "Processing aborted. You can send a new message or /clear the conversation."
|
||||
|
||||
async def switch_model(self):
|
||||
# Ensure ANTHROPIC_SMALL_MODEL and ANTHROPIC_LARGE_MODEL related env vars are loaded in __init__
|
||||
# or ensure they are freshly checked here if they can change during runtime (less common for model names).
|
||||
# For this implementation, we rely on the values stored during __init__.
|
||||
|
||||
if not self.small_model_name or not self.large_model_name:
|
||||
logging.warning("Small or Large model names for Anthropic are not defined. Cannot switch model.")
|
||||
return f"Model switching not fully configured. Currently using {self.model}."
|
||||
|
||||
if self.model == self.small_model_name:
|
||||
current_is_small = self.model == self.small_model_name
|
||||
current_is_large = self.model == self.large_model_name
|
||||
|
||||
if current_is_small:
|
||||
target_model = self.large_model_name
|
||||
target_max_tokens = self.large_model_max_tokens
|
||||
# Use default large max_tokens if specific one isn't set or invalid
|
||||
default_max_tokens_for_large = "4096"
|
||||
elif self.model == self.large_model_name:
|
||||
target_max_tokens_str = self.large_model_max_tokens_str
|
||||
default_target_max_tokens = int(self.DEFAULT_LARGE_MODEL_MAX_TOKENS)
|
||||
elif current_is_large:
|
||||
target_model = self.small_model_name
|
||||
target_max_tokens = self.small_model_max_tokens
|
||||
# Use default small max_tokens if specific one isn't set or invalid
|
||||
default_max_tokens_for_large = "2048"
|
||||
target_max_tokens_str = self.small_model_max_tokens_str
|
||||
default_target_max_tokens = int(self.DEFAULT_SMALL_MODEL_MAX_TOKENS)
|
||||
else:
|
||||
# Current model is neither the designated small nor large, switch to small as a reset
|
||||
logging.warning(f"Current model {self.model} is neither the configured small nor large model. Switching to small model.")
|
||||
logging.warning(f"Current model {self.model} is unrecognized. Switching to default small model.")
|
||||
target_model = self.small_model_name
|
||||
target_max_tokens = self.small_model_max_tokens
|
||||
default_max_tokens_for_large = "2048"
|
||||
target_max_tokens_str = self.small_model_max_tokens_str
|
||||
default_target_max_tokens = int(self.DEFAULT_SMALL_MODEL_MAX_TOKENS)
|
||||
|
||||
|
||||
self._configure_model_and_tokens(target_model, target_max_tokens, default_max_tokens=int(default_max_tokens_for_large)) # Pass appropriate default
|
||||
self._configure_model_and_tokens(target_model, target_max_tokens_str, default_max_tokens=default_target_max_tokens)
|
||||
logging.info(f"Switched Anthropic model to: {self.model}")
|
||||
return f"Switched to Anthropic model: {self.model} (Max Tokens: {self.max_tokens})"#Provide token info
|
||||
return f"Switched to Anthropic model: {self.model} (Max Tokens: {self.max_tokens})"
|
||||
|
||||
|
||||
# The main function is for standalone execution and basic testing, not part of the class itself.
|
||||
# It's good practice to update it to reflect changes if you use it for quick tests.
|
||||
# For unit tests, we'll instantiate the class with mocked dependencies.
|
||||
def main():
|
||||
if not os.environ.get("ANTHROPIC_API_KEY"):
|
||||
logging.error("FATAL: ANTHROPIC_API_KEY environment variable not set.")
|
||||
return
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
|
||||
bot = AnthropicTelegramInferenceBot()
|
||||
# Example of how to instantiate with new constructor (assuming API key is in ENV for this example)
|
||||
# For real tests, you'd mock Anthropic() or pass a mock client.
|
||||
try:
|
||||
# These would typically come from a config file or CLI args in a real app if not ENV
|
||||
# For this example, we rely on ENV or defaults being handled by constructor if not provided.
|
||||
bot = AnthropicTelegramInferenceBot(
|
||||
api_key=os.environ.get("ANTHROPIC_API_KEY") # Explicitly pass, or let constructor handle ENV
|
||||
)
|
||||
except ValueError as e:
|
||||
logging.error(f"Failed to initialize bot: {e}")
|
||||
return
|
||||
except Exception as e: # Catch any other init errors
|
||||
logging.error(f"An unexpected error occurred during bot initialization: {e}")
|
||||
return
|
||||
|
||||
# TelegramHelper also updated, ensure it's instantiated correctly for this main context.
|
||||
# For this basic main, we might not pass all configurable paths to TelegramHelper,
|
||||
# letting them use defaults.
|
||||
telegram_helper = TelegramHelper(bot)
|
||||
telegram_helper.run()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user