feat: Implement prompt/LLM status and refine tool handling (Gemini bot)

This commit is contained in:
cyclop-bot
2025-06-02 14:31:30 -05:00
parent f5b75f77ca
commit d1c8693cc4
+100 -54
View File
@@ -1,12 +1,11 @@
import json import json
import os import os
import logging import logging
from base_telegram_inference_bot import BaseTelegramInferenceBot # Assuming this base class exists from base_telegram_inference_bot import BaseTelegramInferenceBot
from telegram_helper import TelegramHelper # Assuming this helper class exists from telegram_helper import TelegramHelper # This import might be unused if main() is removed or TelegramHelper is not directly instantiated here.
from openai import OpenAI from openai import OpenAI
# Ensure basic logging is configured if not done elsewhere # logging.basicConfig(level=logging.INFO) # Usually configured in main execution script
# logging.basicConfig(level=logging.INFO) # Example: You might have a more sophisticated setup
class GeminiTelegramInferenceBot(BaseTelegramInferenceBot): class GeminiTelegramInferenceBot(BaseTelegramInferenceBot):
def __init__(self): def __init__(self):
@@ -14,12 +13,12 @@ class GeminiTelegramInferenceBot(BaseTelegramInferenceBot):
self.client = OpenAI(api_key=os.environ.get("GEMINI_API_KEY"), base_url=os.environ.get("GEMINI_API_BASE_URL")) self.client = OpenAI(api_key=os.environ.get("GEMINI_API_KEY"), base_url=os.environ.get("GEMINI_API_BASE_URL"))
self._configure_model_and_tokens( self._configure_model_and_tokens(
os.environ.get("GEMINI_SMALL_MODEL"), # Default model os.environ.get("GEMINI_SMALL_MODEL"),
os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS") # Default tokens os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS")
) )
def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=1000): def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=1000):
self.model = model_name self.model = model_name if model_name else "default-gemini-model" # Ensure model has a default
try: try:
self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens
except ValueError: except ValueError:
@@ -27,11 +26,23 @@ class GeminiTelegramInferenceBot(BaseTelegramInferenceBot):
self.max_tokens = default_max_tokens self.max_tokens = default_max_tokens
logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens}") logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens}")
def get_system_prompt_description(self) -> str:
    """Describe the system prompt file named by ``SYSTEM_PROMPT_PATH``.

    Returns:
        A one-line status string: the file's base name when the path points
        at an existing file; the base name plus the full path when the path
        is set but no file exists there; or a "not configured" notice when
        the variable is unset or empty.
    """
    configured_path = os.getenv("SYSTEM_PROMPT_PATH")

    # Unset and empty values are both reported as "not configured".
    if not configured_path:
        return "System Prompt File: Not configured (SYSTEM_PROMPT_PATH not set)."

    file_name = os.path.basename(configured_path)
    if os.path.isfile(configured_path):
        return f"System Prompt File: {file_name}"

    # Path is set but nothing usable is there; include the full path to aid debugging.
    return f"System Prompt File: {file_name} (Not found at path: {configured_path})"
def get_llm_description(self) -> str:
    """Return a one-line summary of the active model and its max-token setting."""
    model_name = self.model
    token_limit = self.max_tokens
    return f"LLM: {model_name}, Max Tokens: {token_limit}"
def get_chat_response(self, messages): def get_chat_response(self, messages):
try: try:
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
model=self.model, model=self.model,
messages=messages, # The system prompt is expected to be part of messages here messages=messages,
tools=self.functions if hasattr(self, 'functions') and self.functions else None, tools=self.functions if hasattr(self, 'functions') and self.functions else None,
tool_choice="auto" if hasattr(self, 'functions') and self.functions else None, tool_choice="auto" if hasattr(self, 'functions') and self.functions else None,
max_tokens=self.max_tokens max_tokens=self.max_tokens
@@ -39,6 +50,8 @@ class GeminiTelegramInferenceBot(BaseTelegramInferenceBot):
return response return response
except Exception as e: except Exception as e:
logging.error(f"Gemini API call failed: {e}") logging.error(f"Gemini API call failed: {e}")
# Return a more structured error or re-raise a custom exception
# For now, re-raising to be handled by the caller
raise raise
async def handle_message(self, user_id, user_message): async def handle_message(self, user_id, user_message):
@@ -52,92 +65,125 @@ class GeminiTelegramInferenceBot(BaseTelegramInferenceBot):
response = self.get_chat_response(messages) response = self.get_chat_response(messages)
tool_calls = [] # Ensure response.choices[0].message exists before appending
if response.choices and response.choices[0].message:
for message_part in response.choices: messages.append(response.choices[0].message) # Append the assistant's response message
if message_part.finish_reason == "tool_calls": else:
tool_calls.extend(message_part.message.tool_calls) logging.error("No valid response choice message from LLM.")
return "Error: Could not get a valid response from the LLM."
tool_calls_from_response = []
if response.choices[0].message.tool_calls:
tool_calls_from_response.extend(response.choices[0].message.tool_calls)
messages.append(response.choices[0].message)
tool_use_count = 0 tool_use_count = 0
while len(tool_calls) > 0 and tool_use_count < 500: MAX_TOOL_ITERATIONS = 5 # Define a max to prevent infinite loops more explicitly
tool_use_results = []
while len(tool_calls) > 0: while tool_calls_from_response and tool_use_count < MAX_TOOL_ITERATIONS:
tool_call_message = tool_calls.pop(0) tool_results_for_model = [] # Results to be sent back to the model
tool_call_id = tool_call_message.id
tool_call = tool_call_message.function for tool_call in tool_calls_from_response:
tool_response = self.call_tool(tool_call.name, tool_call.arguments) tool_call_id = tool_call.id
function_to_call = tool_call.function
logging.info(f"Attempting to call tool: {function_to_call.name} with args: {function_to_call.arguments}")
try: try:
tool_use_results.append({"role": "tool", "tool_call_id": tool_call_id, "name":tool_call.name, "content": str(tool_response) }) tool_response_content = self.call_tool(function_to_call.name, function_to_call.arguments)
except (TypeError, ValueError) as e: # Ensure tool_response_content is a string for the API
logging.error(f"Failed to serialize tool response: {e}") if not isinstance(tool_response_content, str):
tool_use_results.append({"role": "function", "name": tool_call.name, "content": "Serialization error"}) tool_response_content = json.dumps(tool_response_content)
except Exception as e:
logging.error(f"Error calling tool {function_to_call.name}: {e}")
tool_response_content = f"Error executing tool {function_to_call.name}: {str(e)}"
tool_results_for_model.append({
"role": "tool",
"tool_call_id": tool_call_id,
"name": function_to_call.name,
"content": tool_response_content
})
messages.extend(tool_use_results) messages.extend(tool_results_for_model) # Add tool responses to message history
# Get new response from model based on tool execution results
response = self.get_chat_response(messages) response = self.get_chat_response(messages)
if not (response.choices and response.choices[0].message):
for message_part in response.choices: logging.error("No valid response choice message from LLM after tool call.")
if message_part.finish_reason == "tool_calls": return "Error: Could not get a valid response from the LLM after tool call."
tool_calls.extend(message_part.message.tool_calls)
messages.append(response.choices[0].message)
messages.append(response.choices[0].message) # Append new assistant message
# Check for new tool calls
tool_calls_from_response = [] # Reset for this iteration
if response.choices[0].message.tool_calls:
tool_calls_from_response.extend(response.choices[0].message.tool_calls)
tool_use_count += 1 tool_use_count += 1
if tool_use_count >= MAX_TOOL_ITERATIONS and tool_calls_from_response:
logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.")
# May need to return a message indicating this to user
if len(self.conversation_history[user_id]) > 2000: # Conversation history management
if len(self.conversation_history[user_id]) > 2000: # Assuming this limit is for messages, not tokens
self.conversation_history[user_id] = self.conversation_history[user_id][-2000:] self.conversation_history[user_id] = self.conversation_history[user_id][-2000:]
return messages[-1].content # Return the latest assistant content
final_assistant_message = messages[-1]
return final_assistant_message.content if final_assistant_message.role == "assistant" and final_assistant_message.content else "No content in final message."
async def start(self):
    """Log that the Gemini bot has started.

    NOTE(review): the base class's ``start()`` is not invoked here; call
    ``super().start()`` if the base class turns out to carry shared startup logic.
    """
    startup_notice = "Gemini Bot started"
    logging.info(startup_notice)
async def clear(self, user_id):
    """Clear the conversation for ``user_id`` via the base class's ``clear_conversation``.

    The base-class call is made synchronously (it is not awaited).
    """
    base = super()
    base.clear_conversation(user_id)
# status() method is inherited from BaseTelegramInferenceBot
async def status(self):
return f"Currently using: {self.model}, Max Tokens: {self.max_tokens}"
async def abort_processing(self, user_id):
    """Abort any in-flight processing for ``user_id`` and clear their conversation.

    The conversation is cleared in every case, so an aborted request never
    leaves partial history behind.

    Args:
        user_id: Key identifying the user in ``self.processing_status``.

    Returns:
        A user-facing status string saying whether active processing was
        found and aborted, or only the conversation was cleared.
    """
    # processing_status is not created in this class; look it up defensively so
    # a bot without it falls through to the "nothing to abort" path instead of
    # raising AttributeError.
    statuses = getattr(self, "processing_status", None)
    was_tracked = statuses is not None and user_id in statuses

    if was_tracked:
        # Assumes each entry is a dict carrying a "processing" flag - confirm against the base class.
        statuses[user_id]["processing"] = False

    await self.clear(user_id)

    if was_tracked:
        return "Processing aborted and conversation cleared."
    return "No active processing found to abort. Conversation cleared."
async def switch_model(self):
    """Toggle between the small and large Gemini models named in the environment.

    A bot currently on ``GEMINI_SMALL_MODEL`` moves to ``GEMINI_LARGE_MODEL``;
    a bot on the large model, or on any unrecognised model, moves to the small
    one. The matching ``*_MAX_TOKENS`` variable is applied with the new model.

    Returns:
        A user-facing string naming the model now in use, or explaining that
        the switch was refused because the target model's environment
        variable is not set (the current model is then left untouched).
    """
    small_model = os.environ.get("GEMINI_SMALL_MODEL")
    large_model = os.environ.get("GEMINI_LARGE_MODEL")

    # Only "currently small (and not also large)" goes to large; everything
    # else, including an unrecognised current model, falls back to small.
    if self.model == small_model and self.model != large_model:
        size, target_model = "LARGE", large_model
    else:
        size, target_model = "SMALL", small_model

    # Refuse to switch to an unconfigured model rather than silently ending up
    # on a placeholder model name that the API cannot serve.
    if not target_model:
        logging.warning(f"Cannot switch model: GEMINI_{size}_MODEL is not set.")
        return f"Cannot switch model: GEMINI_{size}_MODEL is not set. Still using: {self.model}"

    target_max_tokens = os.environ.get(f"GEMINI_{size}_MODEL_MAX_TOKENS")
    self._configure_model_and_tokens(target_model, target_max_tokens)
    logging.info(f"Switched to model: {self.model}")
    return f"Switched to model: {self.model}"
# The main() function and if __name__ == '__main__': block are for standalone execution.
# If this bot is imported as a module, these might not be necessary or might be handled differently.
# For now, keeping them as they were.
def main():
    """Standalone entry point: configure logging, build the bot, and run it via TelegramHelper.

    Returns early, after logging a fatal error, when ``GEMINI_API_KEY`` is not
    set in the environment.
    """
    # Configure logging before the first log call; otherwise the fatal message
    # below triggers logging's implicit default configuration and bypasses
    # this format.
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    if not os.environ.get("GEMINI_API_KEY"):
        logging.error("FATAL: GEMINI_API_KEY environment variable not set.")
        return

    bot = GeminiTelegramInferenceBot()
    # Instantiating and running TelegramHelper here is what makes this file usable as a script.
    telegram_helper = TelegramHelper(bot)
    telegram_helper.run()
# Run the bot only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()