210 lines
9.3 KiB
Python
210 lines
9.3 KiB
Python
import inspect
import json
import logging
import os

from anthropic import Anthropic, APIError, RateLimitError

from base_telegram_inference_bot import BaseTelegramInferenceBot
from telegram_helper import TelegramHelper
|
|
|
|
class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
    """Telegram inference bot that answers users through Anthropic's Messages API.

    Settings are read from environment variables:

    * ``ANTHROPIC_API_KEY`` -- passed to the ``Anthropic`` client.
    * ``ANTHROPIC_MODEL`` / ``ANTHROPIC_MAX_TOKENS`` -- primary model settings.
    * ``ANTHROPIC_SECONDARY_MODEL`` / ``ANTHROPIC_SECONDARY_MAX_TOKENS`` --
      optional alternative model that ``switch_model`` toggles to.

    ``conversation_history``, ``processing_status``, ``system_prompt``,
    ``functions`` and ``call_tool`` are used but not defined here; presumably
    they come from ``BaseTelegramInferenceBot`` (verify against the base class).
    """

    # Model used when ANTHROPIC_MODEL is unset or empty.
    DEFAULT_MODEL = "claude-3-5-sonnet-20240620"
    # max_tokens used when ANTHROPIC_MAX_TOKENS is unset.
    DEFAULT_MAX_TOKENS = 4096
    # max_tokens for the secondary model when ANTHROPIC_SECONDARY_MAX_TOKENS is unset.
    DEFAULT_SECONDARY_MAX_TOKENS = 2048
    # Maximum number of model calls made for a single user message.
    MAX_TOOL_ITERATIONS = 5
    # Number of most recent messages kept per user between turns.
    MAX_HISTORY_MESSAGES = 20

    def __init__(self):
        super().__init__()
        self.anthropic_client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
        self._configure_model_and_tokens(
            os.environ.get("ANTHROPIC_MODEL", self.DEFAULT_MODEL),
            os.environ.get("ANTHROPIC_MAX_TOKENS", str(self.DEFAULT_MAX_TOKENS)),
        )

    def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=4096):
        """Set ``self.model`` and ``self.max_tokens`` from raw configuration values.

        Args:
            model_name: Model identifier; a falsy value falls back to DEFAULT_MODEL.
            max_tokens_str: Desired max_tokens as a string (or int). ``None``
                selects ``default_max_tokens`` silently.
            default_max_tokens: Used when ``max_tokens_str`` is ``None``, not an
                integer, or not positive (an error is logged for the latter two).
        """
        self.model = model_name or self.DEFAULT_MODEL
        self.max_tokens = default_max_tokens
        if max_tokens_str is not None:
            try:
                parsed = int(max_tokens_str)
            except (TypeError, ValueError):
                parsed = None
            # The Messages API requires max_tokens >= 1, so reject 0/negatives here.
            if parsed is not None and parsed > 0:
                self.max_tokens = parsed
            else:
                logging.error(f"Invalid value for Anthropic max_tokens: {max_tokens_str}. Using default {default_max_tokens}.")
        logging.info(f"Configured to use Anthropic model: {self.model} with max_tokens: {self.max_tokens}")

    def get_llm_description(self) -> str:
        """Return a short human-readable description of the active model settings."""
        return f"LLM: {self.model}, Max Tokens: {self.max_tokens}"

    def _build_anthropic_tools(self):
        """Convert ``self.functions`` into Anthropic tool definitions.

        Each function spec is a dict with ``name`` and optionally ``description``
        and ``parameters`` (a JSON schema). Returns an empty list when no
        functions are configured.
        """
        tools = []
        for function in getattr(self, "functions", None) or []:
            tool = {
                "name": function["name"],
                # Tools without parameters still need an object input schema.
                "input_schema": function.get("parameters") or {"type": "object", "properties": {}},
            }
            if function.get("description"):
                tool["description"] = function["description"]
            tools.append(tool)
        return tools

    def get_chat_response(self, messages_history):
        """Send one Messages API request and return the raw response.

        Args:
            messages_history: Message list in Anthropic format (role/content dicts).

        Returns:
            The object returned by ``self.anthropic_client.messages.create``.

        Raises:
            APIError: logged and re-raised (``RateLimitError`` is a subclass).
            Exception: any other error is logged with its traceback and re-raised.
        """
        request = {
            "model": self.model,
            "messages": messages_history,
            "max_tokens": self.max_tokens,
        }
        # Optional fields are only included when they carry a value: the SDK
        # treats omitted parameters as NOT_GIVEN, whereas an explicit None is
        # serialized as JSON null.
        if self.system_prompt:
            request["system"] = self.system_prompt
        tools = self._build_anthropic_tools()
        if tools:
            request["tools"] = tools
            request["tool_choice"] = {"type": "auto"}

        try:
            return self.anthropic_client.messages.create(**request)
        except (APIError, RateLimitError) as e:
            logging.error(f"Anthropic API error: {e}")
            raise
        except Exception as e:
            logging.exception(f"An unexpected error occurred during Anthropic API call: {e}")
            raise

    def _format_tool_response_for_anthropic(self, tool_response_data):
        """Convert a tool's return value into content blocks for a ``tool_result``.

        * A list whose items are all dicts with a ``"type"`` key is treated as
          ready-made content blocks and returned unchanged.
        * A string becomes one text block.
        * Other dicts/lists are JSON-encoded; values that cannot be encoded, and
          any other type, fall back to ``str()``.
        """
        if isinstance(tool_response_data, list) and all(
            isinstance(item, dict) and "type" in item for item in tool_response_data
        ):
            return tool_response_data

        if isinstance(tool_response_data, str):
            text = tool_response_data
        elif isinstance(tool_response_data, (dict, list)):
            try:
                text = json.dumps(tool_response_data, ensure_ascii=False)
            except (TypeError, ValueError):
                # TypeError: non-serializable value; ValueError: circular reference.
                text = str(tool_response_data)
        else:
            text = str(tool_response_data)
        # The Messages API rejects empty text blocks, so never emit one.
        return [{"type": "text", "text": text or "(empty result)"}]

    def _run_tool_calls(self, tool_calls):
        """Execute ``tool_use`` blocks via ``self.call_tool`` and return ``tool_result`` blocks.

        A failing tool does not abort the turn: the exception is reported back
        to the model as a ``tool_result`` with ``is_error`` set.
        """
        results = []
        for call in tool_calls:
            logging.info(f"Attempting to call Anthropic tool: {call.name} with input: {call.input}")
            try:
                content = self._format_tool_response_for_anthropic(self.call_tool(call.name, call.input))
            except Exception as e:
                logging.exception(f"Error calling tool {call.name}: {e}")
                results.append({
                    "type": "tool_result",
                    "tool_use_id": call.id,
                    "content": [{"type": "text", "text": f"Error executing tool {call.name}: {str(e)}"}],
                    "is_error": True,
                })
            else:
                results.append({"type": "tool_result", "tool_use_id": call.id, "content": content})
        return results

    @staticmethod
    def _is_plain_user_message(message):
        """Return True for a user turn that contains no ``tool_result`` blocks."""
        if message.get("role") != "user":
            return False
        content = message.get("content")
        if not isinstance(content, list):
            return True
        return not any(isinstance(block, dict) and block.get("type") == "tool_result" for block in content)

    def _store_history(self, user_id, messages):
        """Save ``messages`` as ``user_id``'s history, keeping at most MAX_HISTORY_MESSAGES.

        When trimming, messages are dropped up to the first user turn without
        ``tool_result`` blocks, so the stored history never begins in the middle
        of a tool exchange.
        """
        keep = max(self.MAX_HISTORY_MESSAGES, 0)
        if len(messages) > keep:
            # A plain cut can leave a leading assistant turn, or tool_result
            # blocks whose matching tool_use was cut off (the API rejects those).
            messages = messages[len(messages) - keep:]
            start = next(
                (index for index, message in enumerate(messages) if self._is_plain_user_message(message)),
                len(messages),
            )
            messages = messages[start:]
        self.conversation_history[user_id] = messages

    async def handle_message(self, user_id, user_message):
        """Answer ``user_message`` from ``user_id``, executing any tools the model requests.

        The user turn is appended to the stored history. The model is then
        called up to MAX_TOOL_ITERATIONS times: whenever a reply contains
        ``tool_use`` blocks, the tools are run and their ``tool_result`` blocks
        are sent back in a new user turn. The whole exchange is saved as the
        user's history via ``_store_history``.

        Returns:
            The concatenated text of the last assistant reply, an error string
            if the model returned no content, or a placeholder if the reply
            contained no text.

        Raises:
            Whatever ``get_chat_response`` raises; in that case only the new
            user turn has been added to the stored history.
        """
        if user_id not in self.conversation_history:
            self.conversation_history[user_id] = []
        self.conversation_history[user_id].append({"role": "user", "content": user_message})
        messages = list(self.conversation_history[user_id])

        # NOTE(review): get_chat_response uses the synchronous Anthropic client,
        # so each API call blocks the event loop while this coroutine runs.
        reply_text = ""
        for _ in range(self.MAX_TOOL_ITERATIONS):
            response = self.get_chat_response(messages)
            if not response or not response.content:
                logging.error("No valid response content from Anthropic LLM.")
                self._store_history(user_id, messages)
                return "Error: Could not get a valid response from the LLM."

            messages.append({"role": "assistant", "content": response.content})
            reply_text = "".join(block.text for block in response.content if block.type == "text")
            tool_calls = [block for block in response.content if block.type == "tool_use"]
            if not tool_calls:
                break
            # Each tool_use must be answered by a tool_result in the next user turn.
            messages.append({"role": "user", "content": self._run_tool_calls(tool_calls)})
        else:
            # Every model call asked for tools; stop here instead of looping on.
            logging.warning(f"Max tool iterations ({self.MAX_TOOL_ITERATIONS}) reached for Anthropic.")

        self._store_history(user_id, messages)
        return reply_text or "No textual response from assistant."

    async def start(self):
        """Log that the bot has started."""
        logging.info("Anthropic Bot started")

    async def clear_conversation_history(self, user_id):
        """Clear ``user_id``'s history using the base-class implementation."""
        result = super().clear_conversation_history(user_id)
        # The base implementation is not visible here: if it is a coroutine it
        # must be awaited, otherwise the history is never actually cleared.
        if inspect.isawaitable(result):
            await result
        logging.info(f"Cleared conversation history for Anthropic bot, user {user_id}")

    async def abort_processing(self, user_id):
        """Mark ``user_id``'s processing as stopped (if tracked) and clear their history.

        Returns a user-facing status message.
        """
        was_tracked = user_id in self.processing_status
        if was_tracked:
            self.processing_status[user_id]["processing"] = False
        await self.clear_conversation_history(user_id)
        if was_tracked:
            return "Processing aborted and conversation cleared."
        return "No active processing found to abort. Conversation cleared."

    async def switch_model(self):
        """Toggle between the primary and the secondary Anthropic model.

        The secondary model is configured with ANTHROPIC_SECONDARY_MODEL (and
        optionally ANTHROPIC_SECONDARY_MAX_TOKENS). Returns a user-facing
        status message.
        """
        secondary_model = os.environ.get("ANTHROPIC_SECONDARY_MODEL")
        if not secondary_model:
            logging.warning("ANTHROPIC_SECONDARY_MODEL not defined. Cannot switch model.")
            return f"Model switching not configured. Currently using {self.model}."

        # Mirror _configure_model_and_tokens: an empty ANTHROPIC_MODEL means the default.
        primary_model = os.environ.get("ANTHROPIC_MODEL") or self.DEFAULT_MODEL
        if self.model == primary_model:
            target_model = secondary_model
            target_max_tokens = (
                os.environ.get("ANTHROPIC_SECONDARY_MAX_TOKENS") or str(self.DEFAULT_SECONDARY_MAX_TOKENS)
            )
        else:
            target_model = primary_model
            target_max_tokens = os.environ.get("ANTHROPIC_MAX_TOKENS", str(self.DEFAULT_MAX_TOKENS))

        self._configure_model_and_tokens(target_model, target_max_tokens)
        logging.info(f"Switched Anthropic model to: {self.model}")
        return f"Switched to Anthropic model: {self.model}"
|
|
|
|
def main():
    """Configure logging, check for the API key, and run the Telegram bot.

    Logs a fatal error and returns without starting anything when
    ``ANTHROPIC_API_KEY`` is not set.
    """
    # Configure logging before the first log call: module-level logging.error()
    # installs a default handler when the root logger has none, which would
    # bypass the format below.
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    if not os.environ.get("ANTHROPIC_API_KEY"):
        logging.error("FATAL: ANTHROPIC_API_KEY environment variable not set.")
        return

    bot = AnthropicTelegramInferenceBot()
    telegram_helper = TelegramHelper(bot)
    telegram_helper.run()
|
|
|
|
if __name__ == "__main__":
    # Start the bot only when this file is executed as a script, not on import.
    main()
|