Merge pull request #187 from bucolucas/feature/chat-status-update
feat: Enhance status command with system prompt and LLM details
This commit is contained in:
@@ -5,103 +5,216 @@ from anthropic import Anthropic
|
|||||||
from base_telegram_inference_bot import BaseTelegramInferenceBot
|
from base_telegram_inference_bot import BaseTelegramInferenceBot
|
||||||
from telegram_helper import TelegramHelper
|
from telegram_helper import TelegramHelper
|
||||||
|
|
||||||
|
# logging.basicConfig(level=logging.INFO) # Usually configured in main execution script
|
||||||
|
|
||||||
class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
|
class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.anthropic_client = Anthropic(
|
self.anthropic_client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
|
||||||
api_key=os.environ.get("ANTHROPIC_API_KEY"),
|
# Note: default_headers for max_tokens with older models might be needed.
|
||||||
default_headers={"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"}
|
# For Claude 3.5 Sonnet, max_tokens is a top-level param in messages.create
|
||||||
|
|
||||||
|
# Configure model and tokens. Using Sonnet 3.5 as default.
|
||||||
|
# ANTHROPIC_MODEL and ANTHROPIC_MAX_TOKENS would be new ENVs.
|
||||||
|
self._configure_model_and_tokens(
|
||||||
|
os.environ.get("ANTHROPIC_MODEL", "claude-3-5-sonnet-20240620"),
|
||||||
|
os.environ.get("ANTHROPIC_MAX_TOKENS", "4096") # Default max tokens for Sonnet 3.5
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_chat_response(self, messages):
|
def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=4096):
|
||||||
|
self.model = model_name if model_name else "claude-3-5-sonnet-20240620"
|
||||||
|
try:
|
||||||
|
# Anthropic's max_tokens is an integer.
|
||||||
|
self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens
|
||||||
|
except ValueError:
|
||||||
|
logging.error(f"Invalid value for Anthropic max_tokens: {max_tokens_str}. Using default {default_max_tokens}.")
|
||||||
|
self.max_tokens = default_max_tokens
|
||||||
|
logging.info(f"Configured to use Anthropic model: {self.model} with max_tokens: {self.max_tokens}")
|
||||||
|
|
||||||
|
def get_system_prompt_description(self) -> str:
|
||||||
|
system_prompt_path = os.getenv("SYSTEM_PROMPT_PATH")
|
||||||
|
if system_prompt_path and os.path.isfile(system_prompt_path):
|
||||||
|
return f"System Prompt File: {os.path.basename(system_prompt_path)}"
|
||||||
|
elif system_prompt_path:
|
||||||
|
return f"System Prompt File: {os.path.basename(system_prompt_path)} (Not found at path: {system_prompt_path})"
|
||||||
|
else:
|
||||||
|
return "System Prompt File: Not configured (SYSTEM_PROMPT_PATH not set)."
|
||||||
|
|
||||||
|
def get_llm_description(self) -> str:
|
||||||
|
return f"LLM: {self.model}, Max Tokens: {self.max_tokens}"
|
||||||
|
|
||||||
|
def get_chat_response(self, messages_history):
|
||||||
|
current_system_prompt = self.system_prompt if self.system_prompt else ""
|
||||||
|
anthropic_tools = []
|
||||||
|
if hasattr(self, 'functions') and self.functions:
|
||||||
anthropic_tools = [
|
anthropic_tools = [
|
||||||
{
|
{
|
||||||
"name": function['name'],
|
"name": function['name'],
|
||||||
"description": function['description'],
|
"description": function['description'],
|
||||||
"input_schema": function['parameters'] if function['parameters'] not in [None, {}] else {"type": "object", "properties": {"param1": {"type": "string", "description": "Unnecessary"}}, "required": []}
|
"input_schema": function['parameters'] if function['parameters'] not in [None, {}] else {"type": "object", "properties": {}}
|
||||||
}
|
}
|
||||||
for function in self.functions
|
for function in self.functions
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = self.anthropic_client.messages.create(
|
response = self.anthropic_client.messages.create(
|
||||||
model="claude-3-5-sonnet-20240620",
|
model=self.model,
|
||||||
system=self.system_prompt,
|
system=current_system_prompt,
|
||||||
messages=messages,
|
messages=messages_history,
|
||||||
max_tokens=8192,
|
max_tokens=self.max_tokens,
|
||||||
tools=anthropic_tools,
|
tools=anthropic_tools if anthropic_tools else None,
|
||||||
tool_choice={"type": "auto"}
|
tool_choice={"type": "auto"} if anthropic_tools else None
|
||||||
)
|
)
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"An error occurred: {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Anthropic API call failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
async def handle_message(self, user_id, user_message):
|
async def handle_message(self, user_id, user_message):
|
||||||
if user_id not in self.conversation_history:
|
if user_id not in self.conversation_history:
|
||||||
self.conversation_history[user_id] = []
|
self.conversation_history[user_id] = []
|
||||||
|
|
||||||
self.conversation_history[user_id].append({"role": "user", "content": user_message})
|
self.conversation_history[user_id].append({"role": "user", "content": user_message})
|
||||||
messages = self.conversation_history[user_id]
|
current_turn_messages = list(self.conversation_history[user_id])
|
||||||
|
|
||||||
response = self.get_chat_response(messages)
|
|
||||||
tool_calls = []
|
|
||||||
full_message = []
|
|
||||||
for message_part in response.content:
|
|
||||||
full_message.append(message_part)
|
|
||||||
if message_part.type == "tool_use":
|
|
||||||
tool_calls.append(message_part)
|
|
||||||
messages.append({"role": "assistant", "content": full_message})
|
|
||||||
|
|
||||||
|
MAX_TOOL_ITERATIONS = 5
|
||||||
tool_use_count = 0
|
tool_use_count = 0
|
||||||
while len(tool_calls) > 0 and tool_use_count < 50:
|
assistant_response_content = ""
|
||||||
tool_use_results = []
|
|
||||||
while len(tool_calls) > 0:
|
|
||||||
tool_call = tool_calls.pop(0)
|
|
||||||
tool_response = self.call_tool(tool_call.name, json.dumps(tool_call.input))
|
|
||||||
tool_use_results.append({"type": "tool_result", "tool_use_id": tool_call.id, "content": json.dumps(tool_response)})
|
|
||||||
|
|
||||||
messages.append({"role": "user", "content": tool_use_results})
|
while tool_use_count < MAX_TOOL_ITERATIONS:
|
||||||
|
response = self.get_chat_response(current_turn_messages)
|
||||||
|
|
||||||
response = self.get_chat_response(messages)
|
if not response or not response.content:
|
||||||
full_message = []
|
logging.error("No valid response content from Anthropic LLM.")
|
||||||
|
self.conversation_history[user_id] = current_turn_messages # Persist what we have
|
||||||
|
return "Error: Could not get a valid response from the LLM."
|
||||||
|
|
||||||
for message_part in response.content:
|
assistant_current_turn_content_blocks = response.content
|
||||||
full_message.append(message_part)
|
current_turn_messages.append({"role": "assistant", "content": assistant_current_turn_content_blocks})
|
||||||
if message_part.type == "tool_use":
|
|
||||||
tool_calls.append(message_part)
|
text_parts_from_assistant = []
|
||||||
messages.append({"role": "assistant", "content": full_message})
|
tool_calls_from_response = []
|
||||||
|
for block in assistant_current_turn_content_blocks:
|
||||||
|
if block.type == "text":
|
||||||
|
text_parts_from_assistant.append(block.text)
|
||||||
|
elif block.type == "tool_use":
|
||||||
|
tool_calls_from_response.append(block)
|
||||||
|
|
||||||
|
assistant_response_content = "".join(text_parts_from_assistant)
|
||||||
|
|
||||||
|
if not tool_calls_from_response:
|
||||||
|
break
|
||||||
|
|
||||||
|
tool_results_for_model = []
|
||||||
|
for tool_call in tool_calls_from_response:
|
||||||
|
tool_name = tool_call.name
|
||||||
|
tool_input = tool_call.input
|
||||||
|
tool_use_id = tool_call.id
|
||||||
|
|
||||||
|
logging.info(f"Attempting to call Anthropic tool: {tool_name} with input: {tool_input}")
|
||||||
|
try:
|
||||||
|
tool_response_data = self.call_tool(tool_name, tool_input)
|
||||||
|
|
||||||
|
if isinstance(tool_response_data, str):
|
||||||
|
tool_result_content_block = [{"type": "text", "text": tool_response_data}]
|
||||||
|
elif isinstance(tool_response_data, dict) or isinstance(tool_response_data, list):
|
||||||
|
try:
|
||||||
|
# If tool_response_data is already a list of Anthropic content blocks, use as is.
|
||||||
|
# Otherwise, dump to JSON string and wrap in a text block.
|
||||||
|
is_valid_block_list = isinstance(tool_response_data, list) and all(isinstance(item, dict) and "type" in item for item in tool_response_data)
|
||||||
|
if is_valid_block_list:
|
||||||
|
tool_result_content_block = tool_response_data
|
||||||
|
else:
|
||||||
|
tool_result_content_block = [{"type": "text", "text": json.dumps(tool_response_data)}]
|
||||||
|
except (TypeError, json.JSONDecodeError): # Not easily serializable or not a valid block list
|
||||||
|
tool_result_content_block = [{"type": "text", "text": str(tool_response_data)}]
|
||||||
|
else: # bool, int, float, None, etc.
|
||||||
|
tool_result_content_block = [{"type": "text", "text": str(tool_response_data)}]
|
||||||
|
|
||||||
|
tool_results_for_model.append({
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": tool_use_id,
|
||||||
|
"content": tool_result_content_block
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error calling tool {tool_name}: {e}")
|
||||||
|
tool_results_for_model.append({
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": tool_use_id,
|
||||||
|
"content": [{"type": "text", "text": f"Error executing tool {tool_name}: {str(e)}"}],
|
||||||
|
"is_error": True
|
||||||
|
})
|
||||||
|
|
||||||
|
current_turn_messages.append({"role": "user", "content": tool_results_for_model})
|
||||||
|
|
||||||
tool_use_count += 1
|
tool_use_count += 1
|
||||||
|
if tool_use_count >= MAX_TOOL_ITERATIONS:
|
||||||
|
logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached for Anthropic.")
|
||||||
|
break
|
||||||
|
|
||||||
if (tool_use_count == 0):
|
self.conversation_history[user_id] = current_turn_messages
|
||||||
assistant_reply = response.content
|
|
||||||
self.conversation_history[user_id].append({"role": "assistant", "content": assistant_reply})
|
|
||||||
|
|
||||||
if len(self.conversation_history[user_id]) > 20:
|
if len(self.conversation_history[user_id]) > 20:
|
||||||
self.conversation_history[user_id] = self.conversation_history[user_id][-20:]
|
self.conversation_history[user_id] = self.conversation_history[user_id][-20:]
|
||||||
|
|
||||||
return messages[-1]["content"][0].text
|
if assistant_response_content: # Text from the last successful assistant turn (or before max iterations)
|
||||||
|
return assistant_response_content
|
||||||
|
else: # Fallback if no text content was generated by assistant (e.g. initial error, or only tool use)
|
||||||
|
if current_turn_messages:
|
||||||
|
# Try to get the *very last* text block from the *very last* assistant message in history.
|
||||||
|
last_message_in_turn = current_turn_messages[-1]
|
||||||
|
if last_message_in_turn.get("role") == "assistant" and isinstance(last_message_in_turn.get("content"), list):
|
||||||
|
for block in reversed(last_message_in_turn["content"]):
|
||||||
|
if block.type == "text":
|
||||||
|
return block.text
|
||||||
|
return "No textual response from assistant."
|
||||||
|
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
logging.info("Bot started")
|
logging.info("Anthropic Bot started")
|
||||||
|
|
||||||
async def clear(self, user_id):
|
async def clear(self, user_id):
|
||||||
super().clear_conversation(user_id)
|
super().clear_conversation(user_id)
|
||||||
logging.info(f"Cleared conversation history and image for user {user_id}")
|
logging.info(f"Cleared conversation history for Anthropic bot, user {user_id}")
|
||||||
|
|
||||||
async def status(self):
|
|
||||||
return "Currently using claude-3-5-sonnet-20240620"
|
|
||||||
|
|
||||||
async def abort_processing(self, user_id):
|
async def abort_processing(self, user_id):
|
||||||
if user_id in self.processing_status:
|
if user_id in self.processing_status:
|
||||||
self.processing_status[user_id]["processing"] = False
|
self.processing_status[user_id]["processing"] = False
|
||||||
await self.clear(user_id)
|
await self.clear(user_id)
|
||||||
return "Processing aborted."
|
return "Processing aborted and conversation cleared."
|
||||||
else:
|
else:
|
||||||
return "No active processing to abort."
|
await self.clear(user_id)
|
||||||
|
return "No active processing found to abort. Conversation cleared."
|
||||||
|
|
||||||
|
async def switch_model(self):
|
||||||
|
primary_model = os.environ.get("ANTHROPIC_MODEL", "claude-3-5-sonnet-20240620")
|
||||||
|
primary_max_tokens = os.environ.get("ANTHROPIC_MAX_TOKENS", "4096")
|
||||||
|
|
||||||
|
secondary_model_env = os.environ.get("ANTHROPIC_SECONDARY_MODEL")
|
||||||
|
secondary_max_tokens_env = os.environ.get("ANTHROPIC_SECONDARY_MAX_TOKENS")
|
||||||
|
|
||||||
|
if not secondary_model_env:
|
||||||
|
logging.warning("ANTHROPIC_SECONDARY_MODEL not defined. Cannot switch model.")
|
||||||
|
return f"Model switching not configured. Currently using {self.model}."
|
||||||
|
|
||||||
|
if self.model == primary_model:
|
||||||
|
target_model = secondary_model_env
|
||||||
|
target_max_tokens = secondary_max_tokens_env if secondary_max_tokens_env else "2048"
|
||||||
|
else:
|
||||||
|
target_model = primary_model
|
||||||
|
target_max_tokens = primary_max_tokens
|
||||||
|
|
||||||
|
self._configure_model_and_tokens(target_model, target_max_tokens)
|
||||||
|
logging.info(f"Switched Anthropic model to: {self.model}")
|
||||||
|
return f"Switched to Anthropic model: {self.model}"
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
if not os.environ.get("ANTHROPIC_API_KEY"):
|
||||||
|
logging.error("FATAL: ANTHROPIC_API_KEY environment variable not set.")
|
||||||
|
return
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||||
|
|
||||||
bot = AnthropicTelegramInferenceBot()
|
bot = AnthropicTelegramInferenceBot()
|
||||||
telegram_helper = TelegramHelper(bot)
|
telegram_helper = TelegramHelper(bot)
|
||||||
telegram_helper.run()
|
telegram_helper.run()
|
||||||
|
|||||||
@@ -64,6 +64,23 @@ class BaseTelegramInferenceBot(ABC):
|
|||||||
if function["function"]["name"] == function_name:
|
if function["function"]["name"] == function_name:
|
||||||
return tool.execute(function_name, **function_args)
|
return tool.execute(function_name, **function_args)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_system_prompt_description(self) -> str:
|
||||||
|
"""Returns a description of the system prompt being used."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_llm_description(self) -> str:
|
||||||
|
"""Returns a description of the LLM being used."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def status(self) -> str: # Changed from abstract to concrete
|
||||||
|
"""Provides a status message including prompt and LLM information."""
|
||||||
|
prompt_desc = self.get_system_prompt_description()
|
||||||
|
llm_desc = self.get_llm_description()
|
||||||
|
# Consider potential async calls if get_... methods were async
|
||||||
|
# For now, assuming they are synchronous as per design
|
||||||
|
return f"{prompt_desc}\n{llm_desc}"
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def start(self):
|
async def start(self):
|
||||||
@@ -73,10 +90,6 @@ class BaseTelegramInferenceBot(ABC):
|
|||||||
async def clear(self, user_id):
|
async def clear(self, user_id):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def status(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def abort_processing(self, user_id):
|
async def abort_processing(self, user_id):
|
||||||
pass
|
pass
|
||||||
@@ -1,12 +1,11 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
from base_telegram_inference_bot import BaseTelegramInferenceBot # Assuming this base class exists
|
from base_telegram_inference_bot import BaseTelegramInferenceBot
|
||||||
from telegram_helper import TelegramHelper # Assuming this helper class exists
|
from telegram_helper import TelegramHelper
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
# Ensure basic logging is configured if not done elsewhere
|
# logging.basicConfig(level=logging.INFO) # Usually configured in main execution script
|
||||||
# logging.basicConfig(level=logging.INFO) # Example: You might have a more sophisticated setup
|
|
||||||
|
|
||||||
class ChatGPTTelegramInferenceBot(BaseTelegramInferenceBot):
|
class ChatGPTTelegramInferenceBot(BaseTelegramInferenceBot):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -14,12 +13,12 @@ class ChatGPTTelegramInferenceBot(BaseTelegramInferenceBot):
|
|||||||
self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
||||||
|
|
||||||
self._configure_model_and_tokens(
|
self._configure_model_and_tokens(
|
||||||
os.environ.get("OPENAI_SMALL_MODEL"), # Default model
|
os.environ.get("OPENAI_SMALL_MODEL", "gpt-3.5-turbo"), # Default to a common small model
|
||||||
os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS") # Default tokens
|
os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS")
|
||||||
)
|
)
|
||||||
|
|
||||||
def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=1000):
|
def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=1000):
|
||||||
self.model = model_name
|
self.model = model_name if model_name else "gpt-3.5-turbo" # Ensure model has a default
|
||||||
try:
|
try:
|
||||||
self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens
|
self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@@ -27,11 +26,23 @@ class ChatGPTTelegramInferenceBot(BaseTelegramInferenceBot):
|
|||||||
self.max_tokens = default_max_tokens
|
self.max_tokens = default_max_tokens
|
||||||
logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens}")
|
logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens}")
|
||||||
|
|
||||||
|
def get_system_prompt_description(self) -> str:
|
||||||
|
system_prompt_path = os.getenv("SYSTEM_PROMPT_PATH")
|
||||||
|
if system_prompt_path and os.path.isfile(system_prompt_path):
|
||||||
|
return f"System Prompt File: {os.path.basename(system_prompt_path)}"
|
||||||
|
elif system_prompt_path: # Path is set but file not found
|
||||||
|
return f"System Prompt File: {os.path.basename(system_prompt_path)} (Not found at path: {system_prompt_path})"
|
||||||
|
else: # Path not set
|
||||||
|
return "System Prompt File: Not configured (SYSTEM_PROMPT_PATH not set)."
|
||||||
|
|
||||||
|
def get_llm_description(self) -> str:
|
||||||
|
return f"LLM: {self.model}, Max Tokens: {self.max_tokens}"
|
||||||
|
|
||||||
def get_chat_response(self, messages):
|
def get_chat_response(self, messages):
|
||||||
try:
|
try:
|
||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages, # The system prompt is expected to be part of messages here
|
messages=messages,
|
||||||
tools=self.functions if hasattr(self, 'functions') and self.functions else None,
|
tools=self.functions if hasattr(self, 'functions') and self.functions else None,
|
||||||
tool_choice="auto" if hasattr(self, 'functions') and self.functions else None,
|
tool_choice="auto" if hasattr(self, 'functions') and self.functions else None,
|
||||||
max_tokens=self.max_tokens
|
max_tokens=self.max_tokens
|
||||||
@@ -52,92 +63,112 @@ class ChatGPTTelegramInferenceBot(BaseTelegramInferenceBot):
|
|||||||
|
|
||||||
response = self.get_chat_response(messages)
|
response = self.get_chat_response(messages)
|
||||||
|
|
||||||
tool_calls = []
|
if not (response.choices and response.choices[0].message):
|
||||||
|
logging.error("No valid response choice message from LLM.")
|
||||||
|
return "Error: Could not get a valid response from the LLM."
|
||||||
|
|
||||||
for message_part in response.choices:
|
messages.append(response.choices[0].message) # Append the assistant's response message
|
||||||
if message_part.finish_reason == "tool_calls":
|
|
||||||
tool_calls.extend(message_part.message.tool_calls)
|
|
||||||
|
|
||||||
messages.append(response.choices[0].message)
|
tool_calls_from_response = []
|
||||||
|
if response.choices[0].message.tool_calls:
|
||||||
|
tool_calls_from_response.extend(response.choices[0].message.tool_calls)
|
||||||
|
|
||||||
tool_use_count = 0
|
tool_use_count = 0
|
||||||
while len(tool_calls) > 0 and tool_use_count < 500:
|
MAX_TOOL_ITERATIONS = 5
|
||||||
tool_use_results = []
|
|
||||||
|
|
||||||
while len(tool_calls) > 0:
|
while tool_calls_from_response and tool_use_count < MAX_TOOL_ITERATIONS:
|
||||||
tool_call_message = tool_calls.pop(0)
|
tool_results_for_model = []
|
||||||
tool_call_id = tool_call_message.id
|
|
||||||
tool_call = tool_call_message.function
|
for tool_call in tool_calls_from_response:
|
||||||
tool_response = self.call_tool(tool_call.name, tool_call.arguments)
|
tool_call_id = tool_call.id
|
||||||
|
function_to_call = tool_call.function
|
||||||
|
|
||||||
|
logging.info(f"Attempting to call tool: {function_to_call.name} with args: {function_to_call.arguments}")
|
||||||
try:
|
try:
|
||||||
tool_use_results.append({"role": "tool", "tool_call_id": tool_call_id, "name":tool_call.name, "content": str(tool_response) })
|
tool_response_content = self.call_tool(function_to_call.name, function_to_call.arguments)
|
||||||
except (TypeError, ValueError) as e:
|
if not isinstance(tool_response_content, str):
|
||||||
logging.error(f"Failed to serialize tool response: {e}")
|
tool_response_content = json.dumps(tool_response_content)
|
||||||
tool_use_results.append({"role": "function", "name": tool_call.name, "content": "Serialization error"})
|
except Exception as e:
|
||||||
|
logging.error(f"Error calling tool {function_to_call.name}: {e}")
|
||||||
|
tool_response_content = f"Error executing tool {function_to_call.name}: {str(e)}"
|
||||||
|
|
||||||
messages.extend(tool_use_results)
|
tool_results_for_model.append({
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": tool_call_id,
|
||||||
|
"name": function_to_call.name,
|
||||||
|
"content": tool_response_content
|
||||||
|
})
|
||||||
|
|
||||||
|
messages.extend(tool_results_for_model)
|
||||||
|
|
||||||
response = self.get_chat_response(messages)
|
response = self.get_chat_response(messages)
|
||||||
|
if not (response.choices and response.choices[0].message):
|
||||||
for message_part in response.choices:
|
logging.error("No valid response choice message from LLM after tool call.")
|
||||||
if message_part.finish_reason == "tool_calls":
|
return "Error: Could not get a valid response from the LLM after tool call."
|
||||||
tool_calls.extend(message_part.message.tool_calls)
|
|
||||||
|
|
||||||
messages.append(response.choices[0].message)
|
messages.append(response.choices[0].message)
|
||||||
|
|
||||||
tool_use_count += 1
|
tool_calls_from_response = []
|
||||||
|
if response.choices[0].message.tool_calls:
|
||||||
|
tool_calls_from_response.extend(response.choices[0].message.tool_calls)
|
||||||
|
|
||||||
if len(self.conversation_history[user_id]) > 20:
|
tool_use_count += 1
|
||||||
|
if tool_use_count >= MAX_TOOL_ITERATIONS and tool_calls_from_response:
|
||||||
|
logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.")
|
||||||
|
|
||||||
|
if len(self.conversation_history[user_id]) > 20: # This limit seems small, consider increasing
|
||||||
self.conversation_history[user_id] = self.conversation_history[user_id][-20:]
|
self.conversation_history[user_id] = self.conversation_history[user_id][-20:]
|
||||||
|
|
||||||
return messages[-1].content
|
final_assistant_message = messages[-1]
|
||||||
|
return final_assistant_message.content if final_assistant_message.role == "assistant" and final_assistant_message.content else "No content in final message."
|
||||||
|
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
logging.info("Bot started")
|
logging.info("ChatGPT Bot started")
|
||||||
# Potentially call super().start() if it exists and does something
|
# super().start() if Base class start() has common logic
|
||||||
|
|
||||||
async def clear(self, user_id):
|
async def clear(self, user_id):
|
||||||
super().clear_conversation(user_id)
|
super().clear_conversation(user_id)
|
||||||
|
|
||||||
|
# status() method is inherited from BaseTelegramInferenceBot
|
||||||
async def status(self):
|
|
||||||
return f"Currently using: {self.model}, Max Tokens: {self.max_tokens}"
|
|
||||||
|
|
||||||
async def abort_processing(self, user_id):
|
async def abort_processing(self, user_id):
|
||||||
# This depends on how processing_status is managed, likely in BaseTelegramInferenceBot
|
if user_id in self.processing_status: # Relies on processing_status from Base
|
||||||
if hasattr(self, 'processing_status') and user_id in self.processing_status:
|
self.processing_status[user_id]["processing"] = False
|
||||||
self.processing_status[user_id]["processing"] = False # Example
|
await self.clear(user_id)
|
||||||
await self.clear(user_id) # Clearing conversation on abort might be desired
|
|
||||||
return "Processing aborted and conversation cleared."
|
return "Processing aborted and conversation cleared."
|
||||||
else:
|
else:
|
||||||
# If not tracking processing_status here, just clear for safety
|
|
||||||
await self.clear(user_id)
|
await self.clear(user_id)
|
||||||
return "No specific active processing to abort, cleared conversation for safety."
|
return "No active processing found to abort. Conversation cleared."
|
||||||
|
|
||||||
async def switch_model(self):
|
async def switch_model(self):
|
||||||
current_small_model = os.environ.get("OPENAI_SMALL_MODEL")
|
# Ensure environment variables for model names are set for this to work meaningfully
|
||||||
current_large_model = os.environ.get("OPENAI_LARGE_MODEL")
|
current_small_model = os.environ.get("OPENAI_SMALL_MODEL", "gpt-3.5-turbo")
|
||||||
|
current_large_model = os.environ.get("OPENAI_LARGE_MODEL", "gpt-4") # Example large model
|
||||||
|
|
||||||
if self.model == current_small_model:
|
# Default to small model if current model is not recognized or if it's the large one
|
||||||
target_model = current_large_model
|
if self.model == current_large_model or self.model != current_small_model :
|
||||||
target_max_tokens = os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS")
|
|
||||||
else:
|
|
||||||
target_model = current_small_model
|
target_model = current_small_model
|
||||||
target_max_tokens = os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS")
|
target_max_tokens = os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS")
|
||||||
|
else: # Current is small (or default), switch to large
|
||||||
|
target_model = current_large_model
|
||||||
|
target_max_tokens = os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS")
|
||||||
|
|
||||||
self._configure_model_and_tokens(target_model, target_max_tokens)
|
self._configure_model_and_tokens(target_model, target_max_tokens)
|
||||||
|
logging.info(f"Switched to model: {self.model}")
|
||||||
return f"Switched to model: {self.model}"
|
return f"Switched to model: {self.model}"
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# Ensure OPENAI_API_KEY and other environment variables are set
|
|
||||||
if not os.environ.get("OPENAI_API_KEY"):
|
if not os.environ.get("OPENAI_API_KEY"):
|
||||||
logging.error("FATAL: OPENAI_API_KEY environment variable not set.")
|
logging.error("FATAL: OPENAI_API_KEY environment variable not set.")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Configure logging here if it's the main entry point
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||||
|
|
||||||
bot = ChatGPTTelegramInferenceBot()
|
bot = ChatGPTTelegramInferenceBot()
|
||||||
telegram_helper = TelegramHelper(bot)
|
telegram_helper = TelegramHelper(bot)
|
||||||
telegram_helper.run()
|
telegram_helper.run()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
||||||
main()
|
main()
|
||||||
@@ -1,12 +1,11 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
from base_telegram_inference_bot import BaseTelegramInferenceBot # Assuming this base class exists
|
from base_telegram_inference_bot import BaseTelegramInferenceBot
|
||||||
from telegram_helper import TelegramHelper # Assuming this helper class exists
|
from telegram_helper import TelegramHelper # This import might be unused if main() is removed or TelegramHelper is not directly instantiated here.
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
# Ensure basic logging is configured if not done elsewhere
|
# logging.basicConfig(level=logging.INFO) # Usually configured in main execution script
|
||||||
# logging.basicConfig(level=logging.INFO) # Example: You might have a more sophisticated setup
|
|
||||||
|
|
||||||
class GeminiTelegramInferenceBot(BaseTelegramInferenceBot):
|
class GeminiTelegramInferenceBot(BaseTelegramInferenceBot):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -14,12 +13,12 @@ class GeminiTelegramInferenceBot(BaseTelegramInferenceBot):
|
|||||||
self.client = OpenAI(api_key=os.environ.get("GEMINI_API_KEY"), base_url=os.environ.get("GEMINI_API_BASE_URL"))
|
self.client = OpenAI(api_key=os.environ.get("GEMINI_API_KEY"), base_url=os.environ.get("GEMINI_API_BASE_URL"))
|
||||||
|
|
||||||
self._configure_model_and_tokens(
|
self._configure_model_and_tokens(
|
||||||
os.environ.get("GEMINI_SMALL_MODEL"), # Default model
|
os.environ.get("GEMINI_SMALL_MODEL"),
|
||||||
os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS") # Default tokens
|
os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS")
|
||||||
)
|
)
|
||||||
|
|
||||||
def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=1000):
|
def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=1000):
|
||||||
self.model = model_name
|
self.model = model_name if model_name else "default-gemini-model" # Ensure model has a default
|
||||||
try:
|
try:
|
||||||
self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens
|
self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@@ -27,11 +26,23 @@ class GeminiTelegramInferenceBot(BaseTelegramInferenceBot):
|
|||||||
self.max_tokens = default_max_tokens
|
self.max_tokens = default_max_tokens
|
||||||
logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens}")
|
logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens}")
|
||||||
|
|
||||||
|
def get_system_prompt_description(self) -> str:
|
||||||
|
system_prompt_path = os.getenv("SYSTEM_PROMPT_PATH")
|
||||||
|
if system_prompt_path and os.path.isfile(system_prompt_path):
|
||||||
|
return f"System Prompt File: {os.path.basename(system_prompt_path)}"
|
||||||
|
elif system_prompt_path: # Path is set but file not found
|
||||||
|
return f"System Prompt File: {os.path.basename(system_prompt_path)} (Not found at path: {system_prompt_path})"
|
||||||
|
else: # Path not set
|
||||||
|
return "System Prompt File: Not configured (SYSTEM_PROMPT_PATH not set)."
|
||||||
|
|
||||||
|
def get_llm_description(self) -> str:
|
||||||
|
return f"LLM: {self.model}, Max Tokens: {self.max_tokens}"
|
||||||
|
|
||||||
def get_chat_response(self, messages):
|
def get_chat_response(self, messages):
|
||||||
try:
|
try:
|
||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages, # The system prompt is expected to be part of messages here
|
messages=messages,
|
||||||
tools=self.functions if hasattr(self, 'functions') and self.functions else None,
|
tools=self.functions if hasattr(self, 'functions') and self.functions else None,
|
||||||
tool_choice="auto" if hasattr(self, 'functions') and self.functions else None,
|
tool_choice="auto" if hasattr(self, 'functions') and self.functions else None,
|
||||||
max_tokens=self.max_tokens
|
max_tokens=self.max_tokens
|
||||||
@@ -39,6 +50,8 @@ class GeminiTelegramInferenceBot(BaseTelegramInferenceBot):
|
|||||||
return response
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Gemini API call failed: {e}")
|
logging.error(f"Gemini API call failed: {e}")
|
||||||
|
# Return a more structured error or re-raise a custom exception
|
||||||
|
# For now, re-raising to be handled by the caller
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def handle_message(self, user_id, user_message):
|
async def handle_message(self, user_id, user_message):
|
||||||
@@ -52,92 +65,125 @@ class GeminiTelegramInferenceBot(BaseTelegramInferenceBot):
|
|||||||
|
|
||||||
response = self.get_chat_response(messages)
|
response = self.get_chat_response(messages)
|
||||||
|
|
||||||
tool_calls = []
|
# Ensure response.choices[0].message exists before appending
|
||||||
|
if response.choices and response.choices[0].message:
|
||||||
|
messages.append(response.choices[0].message) # Append the assistant's response message
|
||||||
|
else:
|
||||||
|
logging.error("No valid response choice message from LLM.")
|
||||||
|
return "Error: Could not get a valid response from the LLM."
|
||||||
|
|
||||||
for message_part in response.choices:
|
tool_calls_from_response = []
|
||||||
if message_part.finish_reason == "tool_calls":
|
if response.choices[0].message.tool_calls:
|
||||||
tool_calls.extend(message_part.message.tool_calls)
|
tool_calls_from_response.extend(response.choices[0].message.tool_calls)
|
||||||
|
|
||||||
messages.append(response.choices[0].message)
|
|
||||||
|
|
||||||
tool_use_count = 0
|
tool_use_count = 0
|
||||||
while len(tool_calls) > 0 and tool_use_count < 500:
|
MAX_TOOL_ITERATIONS = 5 # Define a max to prevent infinite loops more explicitly
|
||||||
tool_use_results = []
|
|
||||||
|
|
||||||
while len(tool_calls) > 0:
|
while tool_calls_from_response and tool_use_count < MAX_TOOL_ITERATIONS:
|
||||||
tool_call_message = tool_calls.pop(0)
|
tool_results_for_model = [] # Results to be sent back to the model
|
||||||
tool_call_id = tool_call_message.id
|
|
||||||
tool_call = tool_call_message.function
|
for tool_call in tool_calls_from_response:
|
||||||
tool_response = self.call_tool(tool_call.name, tool_call.arguments)
|
tool_call_id = tool_call.id
|
||||||
|
function_to_call = tool_call.function
|
||||||
|
|
||||||
|
logging.info(f"Attempting to call tool: {function_to_call.name} with args: {function_to_call.arguments}")
|
||||||
try:
|
try:
|
||||||
tool_use_results.append({"role": "tool", "tool_call_id": tool_call_id, "name":tool_call.name, "content": str(tool_response) })
|
tool_response_content = self.call_tool(function_to_call.name, function_to_call.arguments)
|
||||||
except (TypeError, ValueError) as e:
|
# Ensure tool_response_content is a string for the API
|
||||||
logging.error(f"Failed to serialize tool response: {e}")
|
if not isinstance(tool_response_content, str):
|
||||||
tool_use_results.append({"role": "function", "name": tool_call.name, "content": "Serialization error"})
|
tool_response_content = json.dumps(tool_response_content)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error calling tool {function_to_call.name}: {e}")
|
||||||
|
tool_response_content = f"Error executing tool {function_to_call.name}: {str(e)}"
|
||||||
|
|
||||||
messages.extend(tool_use_results)
|
tool_results_for_model.append({
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": tool_call_id,
|
||||||
|
"name": function_to_call.name,
|
||||||
|
"content": tool_response_content
|
||||||
|
})
|
||||||
|
|
||||||
|
messages.extend(tool_results_for_model) # Add tool responses to message history
|
||||||
|
|
||||||
|
# Get new response from model based on tool execution results
|
||||||
response = self.get_chat_response(messages)
|
response = self.get_chat_response(messages)
|
||||||
|
if not (response.choices and response.choices[0].message):
|
||||||
|
logging.error("No valid response choice message from LLM after tool call.")
|
||||||
|
return "Error: Could not get a valid response from the LLM after tool call."
|
||||||
|
|
||||||
for message_part in response.choices:
|
messages.append(response.choices[0].message) # Append new assistant message
|
||||||
if message_part.finish_reason == "tool_calls":
|
|
||||||
tool_calls.extend(message_part.message.tool_calls)
|
|
||||||
|
|
||||||
messages.append(response.choices[0].message)
|
# Check for new tool calls
|
||||||
|
tool_calls_from_response = [] # Reset for this iteration
|
||||||
|
if response.choices[0].message.tool_calls:
|
||||||
|
tool_calls_from_response.extend(response.choices[0].message.tool_calls)
|
||||||
|
|
||||||
tool_use_count += 1
|
tool_use_count += 1
|
||||||
|
if tool_use_count >= MAX_TOOL_ITERATIONS and tool_calls_from_response:
|
||||||
|
logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.")
|
||||||
|
# May need to return a message indicating this to user
|
||||||
|
|
||||||
if len(self.conversation_history[user_id]) > 2000:
|
# Conversation history management
|
||||||
|
if len(self.conversation_history[user_id]) > 2000: # Assuming this limit is for messages, not tokens
|
||||||
self.conversation_history[user_id] = self.conversation_history[user_id][-2000:]
|
self.conversation_history[user_id] = self.conversation_history[user_id][-2000:]
|
||||||
|
|
||||||
return messages[-1].content
|
# Return the latest assistant content
|
||||||
|
final_assistant_message = messages[-1]
|
||||||
|
return final_assistant_message.content if final_assistant_message.role == "assistant" and final_assistant_message.content else "No content in final message."
|
||||||
|
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
logging.info("Bot started")
|
logging.info("Gemini Bot started")
|
||||||
# Potentially call super().start() if it exists and does something
|
# super().start() if Base class start() has common logic
|
||||||
|
|
||||||
async def clear(self, user_id):
|
async def clear(self, user_id):
|
||||||
super().clear_conversation(user_id)
|
super().clear_conversation(user_id) # Calls base class method
|
||||||
|
|
||||||
|
# status() method is inherited from BaseTelegramInferenceBot
|
||||||
async def status(self):
|
|
||||||
return f"Currently using: {self.model}, Max Tokens: {self.max_tokens}"
|
|
||||||
|
|
||||||
async def abort_processing(self, user_id):
|
async def abort_processing(self, user_id):
|
||||||
# This depends on how processing_status is managed, likely in BaseTelegramInferenceBot
|
if user_id in self.processing_status:
|
||||||
if hasattr(self, 'processing_status') and user_id in self.processing_status:
|
self.processing_status[user_id]["processing"] = False
|
||||||
self.processing_status[user_id]["processing"] = False # Example
|
# It's good practice to also clear the conversation for an aborted state
|
||||||
await self.clear(user_id) # Clearing conversation on abort might be desired
|
await self.clear(user_id)
|
||||||
return "Processing aborted and conversation cleared."
|
return "Processing aborted and conversation cleared."
|
||||||
else:
|
else:
|
||||||
# If not tracking processing_status here, just clear for safety
|
# If no specific status, clearing conversation is a safe default
|
||||||
await self.clear(user_id)
|
await self.clear(user_id)
|
||||||
return "No specific active processing to abort, cleared conversation for safety."
|
return "No active processing found to abort. Conversation cleared."
|
||||||
|
|
||||||
async def switch_model(self):
|
async def switch_model(self):
|
||||||
current_small_model = os.environ.get("GEMINI_SMALL_MODEL")
|
current_small_model = os.environ.get("GEMINI_SMALL_MODEL")
|
||||||
current_large_model = os.environ.get("GEMINI_LARGE_MODEL")
|
current_large_model = os.environ.get("GEMINI_LARGE_MODEL")
|
||||||
|
|
||||||
if self.model == current_small_model:
|
# Default to small model if current model is not recognized or if it's the large one
|
||||||
target_model = current_large_model
|
if self.model == current_large_model or self.model != current_small_model :
|
||||||
target_max_tokens = os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS")
|
|
||||||
else:
|
|
||||||
target_model = current_small_model
|
target_model = current_small_model
|
||||||
target_max_tokens = os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS")
|
target_max_tokens = os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS")
|
||||||
|
else: # Current is small, switch to large
|
||||||
|
target_model = current_large_model
|
||||||
|
target_max_tokens = os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS")
|
||||||
|
|
||||||
self._configure_model_and_tokens(target_model, target_max_tokens)
|
self._configure_model_and_tokens(target_model, target_max_tokens)
|
||||||
|
logging.info(f"Switched to model: {self.model}")
|
||||||
return f"Switched to model: {self.model}"
|
return f"Switched to model: {self.model}"
|
||||||
|
|
||||||
|
# The main() function and if __name__ == '__main__': block are for standalone execution.
|
||||||
|
# If this bot is imported as a module, these might not be necessary or might be handled differently.
|
||||||
|
# For now, keeping them as they were.
|
||||||
def main():
|
def main():
|
||||||
# Ensure GEMINI_API_KEY and other environment variables are set
|
|
||||||
if not os.environ.get("GEMINI_API_KEY"):
|
if not os.environ.get("GEMINI_API_KEY"):
|
||||||
logging.error("FATAL: GEMINI_API_KEY environment variable not set.")
|
logging.error("FATAL: GEMINI_API_KEY environment variable not set.")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Configure logging here if it's the main entry point
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||||
|
|
||||||
bot = GeminiTelegramInferenceBot()
|
bot = GeminiTelegramInferenceBot()
|
||||||
|
# The instantiation of TelegramHelper and running it implies this file can be an entry point.
|
||||||
|
# If it's purely a module, this main() would be removed.
|
||||||
telegram_helper = TelegramHelper(bot)
|
telegram_helper = TelegramHelper(bot)
|
||||||
telegram_helper.run()
|
telegram_helper.run()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
||||||
main()
|
main()
|
||||||
Reference in New Issue
Block a user