From 8ee7df789e56e1d6471fc544bdaa48cd3e5e712d Mon Sep 17 00:00:00 2001 From: Jonathan Lucas Date: Sun, 18 Aug 2024 12:16:03 -0500 Subject: [PATCH] Fixed anthropic inference --- telegram_inference_bot.py | 57 +++++++++++++++++++++++---------------- 1 file changed, 34 insertions(+), 23 deletions(-) diff --git a/telegram_inference_bot.py b/telegram_inference_bot.py index 3bc6c8f..85fc61b 100644 --- a/telegram_inference_bot.py +++ b/telegram_inference_bot.py @@ -16,7 +16,8 @@ load_dotenv() openai_client = OpenAI() anthropic_client = anthropic.Anthropic( - api_key=os.environ.get("ANTHROPIC_API_KEY") + api_key=os.environ.get("ANTHROPIC_API_KEY"), + default_headers={"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"} ) GPT_4O = "gpt-4o" @@ -28,7 +29,7 @@ model_max_tokens = { } use_smart_model = True -use_anthropic = False +use_anthropic = True # Set up logging to console and file logging.basicConfig(level=logging.WARNING, handlers=[ @@ -96,10 +97,8 @@ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) -> for message in response.content: if message.type == "tool_use": tool_calls.append(message) - assistant_message = "" else: - assistant_message = message - tool_calls = None + messages.append({"role": "assistant", "content": response.content}) else: assistant_message = response.choices[0].message if hasattr(assistant_message, 'function_call') and assistant_message.function_call is not None: @@ -113,37 +112,48 @@ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) -> function_name = tool_call.name tool_response = call_tool(tool_call) - - conversation_history[user_id].append({"role": "function", "name": function_name, "content": json.dumps(tool_response)}) - messages.append({ - "role": "function", - "name": function_name, - "content": json.dumps(tool_response) - }) + + formatted_result = {} + + if use_anthropic: + formatted_result = {"role": "user", "content":[{"type": "tool_result", "tool_use_id": tool_call.id, "content": json.dumps(tool_response)}]} + else: + formatted_result = {"role": "function", "name": function_name, "content": json.dumps(tool_response)} + + messages.append(formatted_result) response = get_chat_response(messages) + assistant_message = "" if use_anthropic: for message in response.content: if message.type == "tool_use": tool_calls.append(message) - assistant_message = "" else: - assistant_message = message - tool_calls = None + messages.append({"role": "assistant", "content": response.content}) else: assistant_message = response.choices[0].message + conversation_history[user_id].append({"role": "assistant", "content": assistant_message}) if hasattr(assistant_message, 'function_call') and assistant_message.function_call is not None: tool_calls.append(assistant_message.function_call) + assistant_reply = assistant_message toolUseCount += 1 + - assistant_reply = assistant_message - conversation_history[user_id].append({"role": "assistant", "content": assistant_reply}) - + if (toolUseCount == 0): + if use_anthropic: + assistant_reply = response.content + else: + assistant_reply = assistant_message + conversation_history[user_id].append({"role": "assistant", "content": assistant_reply}) + if len(conversation_history[user_id]) > 20: conversation_history[user_id] = conversation_history[user_id][-20:] - await update.message.reply_text(assistant_reply.content if not use_anthropic else assistant_reply) + if use_anthropic: + await update.message.reply_text(messages[-1]["content"][0].text) + else: + await update.message.reply_text(assistant_reply.content) except Exception as e: logging.error(f"An error occurred: {str(e)}") @@ -151,7 +161,7 @@ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) -> def call_tool(function_call): function_name = function_call.name if use_anthropic else function_call.name - function_args = "{}" if use_anthropic else function_call.arguments + function_args = json.dumps(function_call.input) if use_anthropic else function_call.arguments for tool in tools: if function_name in [f["name"] for f in tool.get_functions()]: return tool.execute(function_name, **json.loads(function_args)) @@ -181,10 +191,10 @@ def get_claude_response(messages): ] try: response = anthropic_client.messages.create( - model="claude-3-sonnet-20240229", + model="claude-3-5-sonnet-20240620", system=system_prompt, - messages=[{"role": m["role"], "content": m["content"]} for m in messages], - max_tokens=4096, + messages=messages, + max_tokens=8192, tools=anthropic_tools ) except Exception as e: @@ -201,6 +211,7 @@ async def switch(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: await update.message.reply_text(f"Switched to model: {model}") async def switch_providers(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + await clear(update, context) global use_anthropic use_anthropic = not use_anthropic logging.info("Using Anthropic" if use_anthropic else "Using OpenAI")