From d1c8693cc40af1af0c71d7463a7ed33a66ed2793 Mon Sep 17 00:00:00 2001 From: cyclop-bot <178948048+cyclop-bot@users.noreply.github.com> Date: Mon, 2 Jun 2025 14:31:30 -0500 Subject: [PATCH] feat: Implement prompt/LLM status and refine tool handling (Gemini bot) --- gemini_telegram_inference_bot.py | 154 ++++++++++++++++++++----------- 1 file changed, 100 insertions(+), 54 deletions(-) diff --git a/gemini_telegram_inference_bot.py b/gemini_telegram_inference_bot.py index 4e59b7d..fccde2f 100644 --- a/gemini_telegram_inference_bot.py +++ b/gemini_telegram_inference_bot.py @@ -1,12 +1,11 @@ import json import os import logging -from base_telegram_inference_bot import BaseTelegramInferenceBot # Assuming this base class exists -from telegram_helper import TelegramHelper # Assuming this helper class exists +from base_telegram_inference_bot import BaseTelegramInferenceBot +from telegram_helper import TelegramHelper # This import might be unused if main() is removed or TelegramHelper is not directly instantiated here. from openai import OpenAI -# Ensure basic logging is configured if not done elsewhere -# logging.basicConfig(level=logging.INFO) # Example: You might have a more sophisticated setup +# logging.basicConfig(level=logging.INFO) # Usually configured in main execution script class GeminiTelegramInferenceBot(BaseTelegramInferenceBot): def __init__(self): @@ -14,12 +13,12 @@ class GeminiTelegramInferenceBot(BaseTelegramInferenceBot): self.client = OpenAI(api_key=os.environ.get("GEMINI_API_KEY"), base_url=os.environ.get("GEMINI_API_BASE_URL")) self._configure_model_and_tokens( - os.environ.get("GEMINI_SMALL_MODEL"), # Default model - os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS") # Default tokens + os.environ.get("GEMINI_SMALL_MODEL"), + os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS") ) def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=1000): - self.model = model_name + self.model = model_name if model_name else "default-gemini-model" # Ensure model has a default try: self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens except ValueError: @@ -27,11 +26,23 @@ class GeminiTelegramInferenceBot(BaseTelegramInferenceBot): self.max_tokens = default_max_tokens logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens}") + def get_system_prompt_description(self) -> str: + system_prompt_path = os.getenv("SYSTEM_PROMPT_PATH") + if system_prompt_path and os.path.isfile(system_prompt_path): + return f"System Prompt File: {os.path.basename(system_prompt_path)}" + elif system_prompt_path: # Path is set but file not found + return f"System Prompt File: {os.path.basename(system_prompt_path)} (Not found at path: {system_prompt_path})" + else: # Path not set + return "System Prompt File: Not configured (SYSTEM_PROMPT_PATH not set)." + + def get_llm_description(self) -> str: + return f"LLM: {self.model}, Max Tokens: {self.max_tokens}" + def get_chat_response(self, messages): try: response = self.client.chat.completions.create( model=self.model, - messages=messages, # The system prompt is expected to be part of messages here + messages=messages, tools=self.functions if hasattr(self, 'functions') and self.functions else None, tool_choice="auto" if hasattr(self, 'functions') and self.functions else None, max_tokens=self.max_tokens @@ -39,6 +50,8 @@ class GeminiTelegramInferenceBot(BaseTelegramInferenceBot): return response except Exception as e: logging.error(f"Gemini API call failed: {e}") + # Return a more structured error or re-raise a custom exception + # For now, re-raising to be handled by the caller raise async def handle_message(self, user_id, user_message): @@ -52,92 +65,125 @@ class GeminiTelegramInferenceBot(BaseTelegramInferenceBot): response = self.get_chat_response(messages) - tool_calls = [] - - for message_part in response.choices: - if message_part.finish_reason == "tool_calls": - tool_calls.extend(message_part.message.tool_calls) + # Ensure response.choices[0].message exists before appending + if response.choices and response.choices[0].message: + messages.append(response.choices[0].message) # Append the assistant's response message + else: + logging.error("No valid response choice message from LLM.") + return "Error: Could not get a valid response from the LLM." + + tool_calls_from_response = [] + if response.choices[0].message.tool_calls: + tool_calls_from_response.extend(response.choices[0].message.tool_calls) - messages.append(response.choices[0].message) - tool_use_count = 0 - while len(tool_calls) > 0 and tool_use_count < 500: - tool_use_results = [] + MAX_TOOL_ITERATIONS = 5 # Define a max to prevent infinite loops more explicitly - while len(tool_calls) > 0: - tool_call_message = tool_calls.pop(0) - tool_call_id = tool_call_message.id - tool_call = tool_call_message.function - tool_response = self.call_tool(tool_call.name, tool_call.arguments) + while tool_calls_from_response and tool_use_count < MAX_TOOL_ITERATIONS: + tool_results_for_model = [] # Results to be sent back to the model + + for tool_call in tool_calls_from_response: + tool_call_id = tool_call.id + function_to_call = tool_call.function + + logging.info(f"Attempting to call tool: {function_to_call.name} with args: {function_to_call.arguments}") try: - tool_use_results.append({"role": "tool", "tool_call_id": tool_call_id, "name":tool_call.name, "content": str(tool_response) }) - except (TypeError, ValueError) as e: - logging.error(f"Failed to serialize tool response: {e}") - tool_use_results.append({"role": "function", "name": tool_call.name, "content": "Serialization error"}) + tool_response_content = self.call_tool(function_to_call.name, function_to_call.arguments) + # Ensure tool_response_content is a string for the API + if not isinstance(tool_response_content, str): + tool_response_content = json.dumps(tool_response_content) + except Exception as e: + logging.error(f"Error calling tool {function_to_call.name}: {e}") + tool_response_content = f"Error executing tool {function_to_call.name}: {str(e)}" + + tool_results_for_model.append({ + "role": "tool", + "tool_call_id": tool_call_id, + "name": function_to_call.name, + "content": tool_response_content + }) - messages.extend(tool_use_results) + messages.extend(tool_results_for_model) # Add tool responses to message history + # Get new response from model based on tool execution results response = self.get_chat_response(messages) - - for message_part in response.choices: - if message_part.finish_reason == "tool_calls": - tool_calls.extend(message_part.message.tool_calls) - - messages.append(response.choices[0].message) + if not (response.choices and response.choices[0].message): + logging.error("No valid response choice message from LLM after tool call.") + return "Error: Could not get a valid response from the LLM after tool call." + messages.append(response.choices[0].message) # Append new assistant message + + # Check for new tool calls + tool_calls_from_response = [] # Reset for this iteration + if response.choices[0].message.tool_calls: + tool_calls_from_response.extend(response.choices[0].message.tool_calls) + tool_use_count += 1 + if tool_use_count >= MAX_TOOL_ITERATIONS and tool_calls_from_response: + logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.") + # May need to return a message indicating this to user - if len(self.conversation_history[user_id]) > 2000: + # Conversation history management + if len(self.conversation_history[user_id]) > 2000: # Assuming this limit is for messages, not tokens self.conversation_history[user_id] = self.conversation_history[user_id][-2000:] - return messages[-1].content + # Return the latest assistant content + final_assistant_message = messages[-1] + return final_assistant_message.content if final_assistant_message.role == "assistant" and final_assistant_message.content else "No content in final message." + async def start(self): - logging.info("Bot started") - # Potentially call super().start() if it exists and does something + logging.info("Gemini Bot started") + # super().start() if Base class start() has common logic async def clear(self, user_id): - super().clear_conversation(user_id) + super().clear_conversation(user_id) # Calls base class method - - async def status(self): - return f"Currently using: {self.model}, Max Tokens: {self.max_tokens}" + # status() method is inherited from BaseTelegramInferenceBot async def abort_processing(self, user_id): - # This depends on how processing_status is managed, likely in BaseTelegramInferenceBot - if hasattr(self, 'processing_status') and user_id in self.processing_status: - self.processing_status[user_id]["processing"] = False # Example - await self.clear(user_id) # Clearing conversation on abort might be desired + if user_id in self.processing_status: + self.processing_status[user_id]["processing"] = False + # It's good practice to also clear the conversation for an aborted state + await self.clear(user_id) return "Processing aborted and conversation cleared." else: - # If not tracking processing_status here, just clear for safety + # If no specific status, clearing conversation is a safe default await self.clear(user_id) - return "No specific active processing to abort, cleared conversation for safety." + return "No active processing found to abort. Conversation cleared." async def switch_model(self): current_small_model = os.environ.get("GEMINI_SMALL_MODEL") current_large_model = os.environ.get("GEMINI_LARGE_MODEL") - if self.model == current_small_model: - target_model = current_large_model - target_max_tokens = os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS") - else: + # Default to small model if current model is not recognized or if it's the large one + if self.model == current_large_model or self.model != current_small_model : target_model = current_small_model target_max_tokens = os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS") + else: # Current is small, switch to large + target_model = current_large_model + target_max_tokens = os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS") self._configure_model_and_tokens(target_model, target_max_tokens) + logging.info(f"Switched to model: {self.model}") return f"Switched to model: {self.model}" +# The main() function and if __name__ == '__main__': block are for standalone execution. +# If this bot is imported as a module, these might not be necessary or might be handled differently. +# For now, keeping them as they were. def main(): - # Ensure GEMINI_API_KEY and other environment variables are set if not os.environ.get("GEMINI_API_KEY"): logging.error("FATAL: GEMINI_API_KEY environment variable not set.") return + # Configure logging here if it's the main entry point + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') + bot = GeminiTelegramInferenceBot() + # The instantiation of TelegramHelper and running it implies this file can be an entry point. + # If it's purely a module, this main() would be removed. telegram_helper = TelegramHelper(bot) telegram_helper.run() if __name__ == '__main__': - logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') - main() \ No newline at end of file + main()