diff --git a/anthropic_telegram_inference_bot.py b/anthropic_telegram_inference_bot.py index 2c39863..116b00a 100644 --- a/anthropic_telegram_inference_bot.py +++ b/anthropic_telegram_inference_bot.py @@ -5,106 +5,219 @@ from anthropic import Anthropic from base_telegram_inference_bot import BaseTelegramInferenceBot from telegram_helper import TelegramHelper +# logging.basicConfig(level=logging.INFO) # Usually configured in main execution script + class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot): def __init__(self): super().__init__() - self.anthropic_client = Anthropic( - api_key=os.environ.get("ANTHROPIC_API_KEY"), - default_headers={"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"} + self.anthropic_client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")) + # Note: default_headers for max_tokens with older models might be needed. + # For Claude 3.5 Sonnet, max_tokens is a top-level param in messages.create + + # Configure model and tokens. Using Sonnet 3.5 as default. + # ANTHROPIC_MODEL and ANTHROPIC_MAX_TOKENS would be new ENVs. + self._configure_model_and_tokens( + os.environ.get("ANTHROPIC_MODEL", "claude-3-5-sonnet-20240620"), + os.environ.get("ANTHROPIC_MAX_TOKENS", "4096") # Default max tokens for Sonnet 3.5 ) - def get_chat_response(self, messages): - anthropic_tools = [ - { - "name": function['name'], - "description": function['description'], - "input_schema": function['parameters'] if function['parameters'] not in [None, {}] else {"type": "object", "properties": {"param1": {"type": "string", "description": "Unnecessary"}}, "required": []} - } - for function in self.functions - ] + def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=4096): + self.model = model_name if model_name else "claude-3-5-sonnet-20240620" + try: + # Anthropic's max_tokens is an integer. + self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens + except ValueError: + logging.error(f"Invalid value for Anthropic max_tokens: {max_tokens_str}. Using default {default_max_tokens}.") + self.max_tokens = default_max_tokens + logging.info(f"Configured to use Anthropic model: {self.model} with max_tokens: {self.max_tokens}") + + def get_system_prompt_description(self) -> str: + system_prompt_path = os.getenv("SYSTEM_PROMPT_PATH") + if system_prompt_path and os.path.isfile(system_prompt_path): + return f"System Prompt File: {os.path.basename(system_prompt_path)}" + elif system_prompt_path: + return f"System Prompt File: {os.path.basename(system_prompt_path)} (Not found at path: {system_prompt_path})" + else: + return "System Prompt File: Not configured (SYSTEM_PROMPT_PATH not set)." + + def get_llm_description(self) -> str: + return f"LLM: {self.model}, Max Tokens: {self.max_tokens}" + + def get_chat_response(self, messages_history): + current_system_prompt = self.system_prompt if self.system_prompt else "" + anthropic_tools = [] + if hasattr(self, 'functions') and self.functions: + anthropic_tools = [ + { + "name": function['name'], + "description": function['description'], + "input_schema": function['parameters'] if function['parameters'] not in [None, {}] else {"type": "object", "properties": {}} + } + for function in self.functions + ] + try: response = self.anthropic_client.messages.create( - model="claude-3-5-sonnet-20240620", - system=self.system_prompt, - messages=messages, - max_tokens=8192, - tools=anthropic_tools, - tool_choice={"type": "auto"} + model=self.model, + system=current_system_prompt, + messages=messages_history, + max_tokens=self.max_tokens, + tools=anthropic_tools if anthropic_tools else None, + tool_choice={"type": "auto"} if anthropic_tools else None ) + return response except Exception as e: - logging.error(f"An error occurred: {str(e)}") - return None - - return response + logging.error(f"Anthropic API call failed: {e}") + raise async def handle_message(self, user_id, user_message): if user_id not in self.conversation_history: self.conversation_history[user_id] = [] self.conversation_history[user_id].append({"role": "user", "content": user_message}) - messages = self.conversation_history[user_id] - - response = self.get_chat_response(messages) - tool_calls = [] - full_message = [] - for message_part in response.content: - full_message.append(message_part) - if message_part.type == "tool_use": - tool_calls.append(message_part) - messages.append({"role": "assistant", "content": full_message}) + current_turn_messages = list(self.conversation_history[user_id]) + MAX_TOOL_ITERATIONS = 5 tool_use_count = 0 - while len(tool_calls) > 0 and tool_use_count < 50: - tool_use_results = [] - while len(tool_calls) > 0: - tool_call = tool_calls.pop(0) - tool_response = self.call_tool(tool_call.name, json.dumps(tool_call.input)) - tool_use_results.append({"type": "tool_result", "tool_use_id": tool_call.id, "content": json.dumps(tool_response)}) + assistant_response_content = "" - messages.append({"role": "user", "content": tool_use_results}) + while tool_use_count < MAX_TOOL_ITERATIONS: + response = self.get_chat_response(current_turn_messages) - response = self.get_chat_response(messages) - full_message = [] + if not response or not response.content: + logging.error("No valid response content from Anthropic LLM.") + self.conversation_history[user_id] = current_turn_messages # Persist what we have + return "Error: Could not get a valid response from the LLM." + + assistant_current_turn_content_blocks = response.content + current_turn_messages.append({"role": "assistant", "content": assistant_current_turn_content_blocks}) + + text_parts_from_assistant = [] + tool_calls_from_response = [] + for block in assistant_current_turn_content_blocks: + if block.type == "text": + text_parts_from_assistant.append(block.text) + elif block.type == "tool_use": + tool_calls_from_response.append(block) - for message_part in response.content: - full_message.append(message_part) - if message_part.type == "tool_use": - tool_calls.append(message_part) - messages.append({"role": "assistant", "content": full_message}) + assistant_response_content = "".join(text_parts_from_assistant) + if not tool_calls_from_response: + break + + tool_results_for_model = [] + for tool_call in tool_calls_from_response: + tool_name = tool_call.name + tool_input = tool_call.input + tool_use_id = tool_call.id + + logging.info(f"Attempting to call Anthropic tool: {tool_name} with input: {tool_input}") + try: + tool_response_data = self.call_tool(tool_name, tool_input) + + if isinstance(tool_response_data, str): + tool_result_content_block = [{"type": "text", "text": tool_response_data}] + elif isinstance(tool_response_data, dict) or isinstance(tool_response_data, list): + try: + # If tool_response_data is already a list of Anthropic content blocks, use as is. + # Otherwise, dump to JSON string and wrap in a text block. + is_valid_block_list = isinstance(tool_response_data, list) and all(isinstance(item, dict) and "type" in item for item in tool_response_data) + if is_valid_block_list: + tool_result_content_block = tool_response_data + else: + tool_result_content_block = [{"type": "text", "text": json.dumps(tool_response_data)}] + except (TypeError, json.JSONDecodeError): # Not easily serializable or not a valid block list + tool_result_content_block = [{"type": "text", "text": str(tool_response_data)}] + else: # bool, int, float, None, etc. + tool_result_content_block = [{"type": "text", "text": str(tool_response_data)}] + + tool_results_for_model.append({ + "type": "tool_result", + "tool_use_id": tool_use_id, + "content": tool_result_content_block + }) + except Exception as e: + logging.error(f"Error calling tool {tool_name}: {e}") + tool_results_for_model.append({ + "type": "tool_result", + "tool_use_id": tool_use_id, + "content": [{"type": "text", "text": f"Error executing tool {tool_name}: {str(e)}"}], + "is_error": True + }) + + current_turn_messages.append({"role": "user", "content": tool_results_for_model}) + tool_use_count += 1 + if tool_use_count >= MAX_TOOL_ITERATIONS: + logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached for Anthropic.") + break - if (tool_use_count == 0): - assistant_reply = response.content - self.conversation_history[user_id].append({"role": "assistant", "content": assistant_reply}) + self.conversation_history[user_id] = current_turn_messages if len(self.conversation_history[user_id]) > 20: self.conversation_history[user_id] = self.conversation_history[user_id][-20:] - return messages[-1]["content"][0].text + if assistant_response_content: # Text from the last successful assistant turn (or before max iterations) + return assistant_response_content + else: # Fallback if no text content was generated by assistant (e.g. initial error, or only tool use) + if current_turn_messages: + # Try to get the *very last* text block from the *very last* assistant message in history. + last_message_in_turn = current_turn_messages[-1] + if last_message_in_turn.get("role") == "assistant" and isinstance(last_message_in_turn.get("content"), list): + for block in reversed(last_message_in_turn["content"]): + if block.type == "text": + return block.text + return "No textual response from assistant." + async def start(self): - logging.info("Bot started") + logging.info("Anthropic Bot started") async def clear(self, user_id): super().clear_conversation(user_id) - logging.info(f"Cleared conversation history and image for user {user_id}") - - async def status(self): - return "Currently using claude-3-5-sonnet-20240620" + logging.info(f"Cleared conversation history for Anthropic bot, user {user_id}") async def abort_processing(self, user_id): if user_id in self.processing_status: self.processing_status[user_id]["processing"] = False await self.clear(user_id) - return "Processing aborted." + return "Processing aborted and conversation cleared." else: - return "No active processing to abort." + await self.clear(user_id) + return "No active processing found to abort. Conversation cleared." + + async def switch_model(self): + primary_model = os.environ.get("ANTHROPIC_MODEL", "claude-3-5-sonnet-20240620") + primary_max_tokens = os.environ.get("ANTHROPIC_MAX_TOKENS", "4096") + + secondary_model_env = os.environ.get("ANTHROPIC_SECONDARY_MODEL") + secondary_max_tokens_env = os.environ.get("ANTHROPIC_SECONDARY_MAX_TOKENS") + + if not secondary_model_env: + logging.warning("ANTHROPIC_SECONDARY_MODEL not defined. Cannot switch model.") + return f"Model switching not configured. Currently using {self.model}." + + if self.model == primary_model: + target_model = secondary_model_env + target_max_tokens = secondary_max_tokens_env if secondary_max_tokens_env else "2048" + else: + target_model = primary_model + target_max_tokens = primary_max_tokens + + self._configure_model_and_tokens(target_model, target_max_tokens) + logging.info(f"Switched Anthropic model to: {self.model}") + return f"Switched to Anthropic model: {self.model}" def main(): + if not os.environ.get("ANTHROPIC_API_KEY"): + logging.error("FATAL: ANTHROPIC_API_KEY environment variable not set.") + return + + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') + bot = AnthropicTelegramInferenceBot() telegram_helper = TelegramHelper(bot) telegram_helper.run() if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/base_telegram_inference_bot.py b/base_telegram_inference_bot.py index 3e4709d..5979c0b 100644 --- a/base_telegram_inference_bot.py +++ b/base_telegram_inference_bot.py @@ -63,7 +63,24 @@ class BaseTelegramInferenceBot(ABC): for function in tool.get_functions(): if function["function"]["name"] == function_name: return tool.execute(function_name, **function_args) - + + @abstractmethod + def get_system_prompt_description(self) -> str: + """Returns a description of the system prompt being used.""" + pass + + @abstractmethod + def get_llm_description(self) -> str: + """Returns a description of the LLM being used.""" + pass + + async def status(self) -> str: # Changed from abstract to concrete + """Provides a status message including prompt and LLM information.""" + prompt_desc = self.get_system_prompt_description() + llm_desc = self.get_llm_description() + # Consider potential async calls if get_... methods were async + # For now, assuming they are synchronous as per design + return f"{prompt_desc}\n{llm_desc}" @abstractmethod async def start(self): @@ -73,10 +90,6 @@ class BaseTelegramInferenceBot(ABC): async def clear(self, user_id): pass - @abstractmethod - async def status(self): - pass - @abstractmethod async def abort_processing(self, user_id): - pass \ No newline at end of file + pass diff --git a/chatgpt_telegram_inference_bot.py b/chatgpt_telegram_inference_bot.py index 1644d3c..9e6374f 100644 --- a/chatgpt_telegram_inference_bot.py +++ b/chatgpt_telegram_inference_bot.py @@ -1,12 +1,11 @@ import json import os import logging -from base_telegram_inference_bot import BaseTelegramInferenceBot # Assuming this base class exists -from telegram_helper import TelegramHelper # Assuming this helper class exists +from base_telegram_inference_bot import BaseTelegramInferenceBot +from telegram_helper import TelegramHelper from openai import OpenAI -# Ensure basic logging is configured if not done elsewhere -# logging.basicConfig(level=logging.INFO) # Example: You might have a more sophisticated setup +# logging.basicConfig(level=logging.INFO) # Usually configured in main execution script class ChatGPTTelegramInferenceBot(BaseTelegramInferenceBot): def __init__(self): @@ -14,12 +13,12 @@ class ChatGPTTelegramInferenceBot(BaseTelegramInferenceBot): self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) self._configure_model_and_tokens( - os.environ.get("OPENAI_SMALL_MODEL"), # Default model - os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS") # Default tokens + os.environ.get("OPENAI_SMALL_MODEL", "gpt-3.5-turbo"), # Default to a common small model + os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS") ) def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=1000): - self.model = model_name + self.model = model_name if model_name else "gpt-3.5-turbo" # Ensure model has a default try: self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens except ValueError: @@ -27,11 +26,23 @@ class ChatGPTTelegramInferenceBot(BaseTelegramInferenceBot): self.max_tokens = default_max_tokens logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens}") + def get_system_prompt_description(self) -> str: + system_prompt_path = os.getenv("SYSTEM_PROMPT_PATH") + if system_prompt_path and os.path.isfile(system_prompt_path): + return f"System Prompt File: {os.path.basename(system_prompt_path)}" + elif system_prompt_path: # Path is set but file not found + return f"System Prompt File: {os.path.basename(system_prompt_path)} (Not found at path: {system_prompt_path})" + else: # Path not set + return "System Prompt File: Not configured (SYSTEM_PROMPT_PATH not set)." + + def get_llm_description(self) -> str: + return f"LLM: {self.model}, Max Tokens: {self.max_tokens}" + def get_chat_response(self, messages): try: response = self.client.chat.completions.create( model=self.model, - messages=messages, # The system prompt is expected to be part of messages here + messages=messages, tools=self.functions if hasattr(self, 'functions') and self.functions else None, tool_choice="auto" if hasattr(self, 'functions') and self.functions else None, max_tokens=self.max_tokens @@ -52,92 +63,112 @@ class ChatGPTTelegramInferenceBot(BaseTelegramInferenceBot): response = self.get_chat_response(messages) - tool_calls = [] - - for message_part in response.choices: - if message_part.finish_reason == "tool_calls": - tool_calls.extend(message_part.message.tool_calls) + if not (response.choices and response.choices[0].message): + logging.error("No valid response choice message from LLM.") + return "Error: Could not get a valid response from the LLM." - messages.append(response.choices[0].message) + messages.append(response.choices[0].message) # Append the assistant's response message + tool_calls_from_response = [] + if response.choices[0].message.tool_calls: + tool_calls_from_response.extend(response.choices[0].message.tool_calls) + tool_use_count = 0 - while len(tool_calls) > 0 and tool_use_count < 500: - tool_use_results = [] + MAX_TOOL_ITERATIONS = 5 - while len(tool_calls) > 0: - tool_call_message = tool_calls.pop(0) - tool_call_id = tool_call_message.id - tool_call = tool_call_message.function - tool_response = self.call_tool(tool_call.name, tool_call.arguments) + while tool_calls_from_response and tool_use_count < MAX_TOOL_ITERATIONS: + tool_results_for_model = [] + + for tool_call in tool_calls_from_response: + tool_call_id = tool_call.id + function_to_call = tool_call.function + + logging.info(f"Attempting to call tool: {function_to_call.name} with args: {function_to_call.arguments}") try: - tool_use_results.append({"role": "tool", "tool_call_id": tool_call_id, "name":tool_call.name, "content": str(tool_response) }) - except (TypeError, ValueError) as e: - logging.error(f"Failed to serialize tool response: {e}") - tool_use_results.append({"role": "function", "name": tool_call.name, "content": "Serialization error"}) + tool_response_content = self.call_tool(function_to_call.name, function_to_call.arguments) + if not isinstance(tool_response_content, str): + tool_response_content = json.dumps(tool_response_content) + except Exception as e: + logging.error(f"Error calling tool {function_to_call.name}: {e}") + tool_response_content = f"Error executing tool {function_to_call.name}: {str(e)}" + + tool_results_for_model.append({ + "role": "tool", + "tool_call_id": tool_call_id, + "name": function_to_call.name, + "content": tool_response_content + }) - messages.extend(tool_use_results) + messages.extend(tool_results_for_model) response = self.get_chat_response(messages) - - for message_part in response.choices: - if message_part.finish_reason == "tool_calls": - tool_calls.extend(message_part.message.tool_calls) - + if not (response.choices and response.choices[0].message): + logging.error("No valid response choice message from LLM after tool call.") + return "Error: Could not get a valid response from the LLM after tool call." + messages.append(response.choices[0].message) + tool_calls_from_response = [] + if response.choices[0].message.tool_calls: + tool_calls_from_response.extend(response.choices[0].message.tool_calls) + tool_use_count += 1 + if tool_use_count >= MAX_TOOL_ITERATIONS and tool_calls_from_response: + logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.") - if len(self.conversation_history[user_id]) > 20: + if len(self.conversation_history[user_id]) > 20: # This limit seems small, consider increasing self.conversation_history[user_id] = self.conversation_history[user_id][-20:] - return messages[-1].content + final_assistant_message = messages[-1] + return final_assistant_message.content if final_assistant_message.role == "assistant" and final_assistant_message.content else "No content in final message." + async def start(self): - logging.info("Bot started") - # Potentially call super().start() if it exists and does something + logging.info("ChatGPT Bot started") + # super().start() if Base class start() has common logic async def clear(self, user_id): super().clear_conversation(user_id) - - async def status(self): - return f"Currently using: {self.model}, Max Tokens: {self.max_tokens}" + # status() method is inherited from BaseTelegramInferenceBot async def abort_processing(self, user_id): - # This depends on how processing_status is managed, likely in BaseTelegramInferenceBot - if hasattr(self, 'processing_status') and user_id in self.processing_status: - self.processing_status[user_id]["processing"] = False # Example - await self.clear(user_id) # Clearing conversation on abort might be desired + if user_id in self.processing_status: # Relies on processing_status from Base + self.processing_status[user_id]["processing"] = False + await self.clear(user_id) return "Processing aborted and conversation cleared." else: - # If not tracking processing_status here, just clear for safety await self.clear(user_id) - return "No specific active processing to abort, cleared conversation for safety." + return "No active processing found to abort. Conversation cleared." async def switch_model(self): - current_small_model = os.environ.get("OPENAI_SMALL_MODEL") - current_large_model = os.environ.get("OPENAI_LARGE_MODEL") + # Ensure environment variables for model names are set for this to work meaningfully + current_small_model = os.environ.get("OPENAI_SMALL_MODEL", "gpt-3.5-turbo") + current_large_model = os.environ.get("OPENAI_LARGE_MODEL", "gpt-4") # Example large model - if self.model == current_small_model: - target_model = current_large_model - target_max_tokens = os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS") - else: + # Default to small model if current model is not recognized or if it's the large one + if self.model == current_large_model or self.model != current_small_model : target_model = current_small_model target_max_tokens = os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS") + else: # Current is small (or default), switch to large + target_model = current_large_model + target_max_tokens = os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS") self._configure_model_and_tokens(target_model, target_max_tokens) + logging.info(f"Switched to model: {self.model}") return f"Switched to model: {self.model}" def main(): - # Ensure OPENAI_API_KEY and other environment variables are set if not os.environ.get("OPENAI_API_KEY"): logging.error("FATAL: OPENAI_API_KEY environment variable not set.") return - + + # Configure logging here if it's the main entry point + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') + bot = ChatGPTTelegramInferenceBot() telegram_helper = TelegramHelper(bot) telegram_helper.run() if __name__ == '__main__': - logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') - main() \ No newline at end of file + main() diff --git a/gemini_telegram_inference_bot.py b/gemini_telegram_inference_bot.py index 4e59b7d..fccde2f 100644 --- a/gemini_telegram_inference_bot.py +++ b/gemini_telegram_inference_bot.py @@ -1,12 +1,11 @@ import json import os import logging -from base_telegram_inference_bot import BaseTelegramInferenceBot # Assuming this base class exists -from telegram_helper import TelegramHelper # Assuming this helper class exists +from base_telegram_inference_bot import BaseTelegramInferenceBot +from telegram_helper import TelegramHelper # This import might be unused if main() is removed or TelegramHelper is not directly instantiated here. from openai import OpenAI -# Ensure basic logging is configured if not done elsewhere -# logging.basicConfig(level=logging.INFO) # Example: You might have a more sophisticated setup +# logging.basicConfig(level=logging.INFO) # Usually configured in main execution script class GeminiTelegramInferenceBot(BaseTelegramInferenceBot): def __init__(self): @@ -14,12 +13,12 @@ class GeminiTelegramInferenceBot(BaseTelegramInferenceBot): self.client = OpenAI(api_key=os.environ.get("GEMINI_API_KEY"), base_url=os.environ.get("GEMINI_API_BASE_URL")) self._configure_model_and_tokens( - os.environ.get("GEMINI_SMALL_MODEL"), # Default model - os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS") # Default tokens + os.environ.get("GEMINI_SMALL_MODEL"), + os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS") ) def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=1000): - self.model = model_name + self.model = model_name if model_name else "default-gemini-model" # Ensure model has a default try: self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens except ValueError: @@ -27,11 +26,23 @@ class GeminiTelegramInferenceBot(BaseTelegramInferenceBot): self.max_tokens = default_max_tokens logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens}") + def get_system_prompt_description(self) -> str: + system_prompt_path = os.getenv("SYSTEM_PROMPT_PATH") + if system_prompt_path and os.path.isfile(system_prompt_path): + return f"System Prompt File: {os.path.basename(system_prompt_path)}" + elif system_prompt_path: # Path is set but file not found + return f"System Prompt File: {os.path.basename(system_prompt_path)} (Not found at path: {system_prompt_path})" + else: # Path not set + return "System Prompt File: Not configured (SYSTEM_PROMPT_PATH not set)." + + def get_llm_description(self) -> str: + return f"LLM: {self.model}, Max Tokens: {self.max_tokens}" + def get_chat_response(self, messages): try: response = self.client.chat.completions.create( model=self.model, - messages=messages, # The system prompt is expected to be part of messages here + messages=messages, tools=self.functions if hasattr(self, 'functions') and self.functions else None, tool_choice="auto" if hasattr(self, 'functions') and self.functions else None, max_tokens=self.max_tokens @@ -39,6 +50,8 @@ class GeminiTelegramInferenceBot(BaseTelegramInferenceBot): return response except Exception as e: logging.error(f"Gemini API call failed: {e}") + # Return a more structured error or re-raise a custom exception + # For now, re-raising to be handled by the caller raise async def handle_message(self, user_id, user_message): @@ -52,92 +65,125 @@ class GeminiTelegramInferenceBot(BaseTelegramInferenceBot): response = self.get_chat_response(messages) - tool_calls = [] - - for message_part in response.choices: - if message_part.finish_reason == "tool_calls": - tool_calls.extend(message_part.message.tool_calls) + # Ensure response.choices[0].message exists before appending + if response.choices and response.choices[0].message: + messages.append(response.choices[0].message) # Append the assistant's response message + else: + logging.error("No valid response choice message from LLM.") + return "Error: Could not get a valid response from the LLM." + + tool_calls_from_response = [] + if response.choices[0].message.tool_calls: + tool_calls_from_response.extend(response.choices[0].message.tool_calls) - messages.append(response.choices[0].message) - tool_use_count = 0 - while len(tool_calls) > 0 and tool_use_count < 500: - tool_use_results = [] + MAX_TOOL_ITERATIONS = 5 # Define a max to prevent infinite loops more explicitly - while len(tool_calls) > 0: - tool_call_message = tool_calls.pop(0) - tool_call_id = tool_call_message.id - tool_call = tool_call_message.function - tool_response = self.call_tool(tool_call.name, tool_call.arguments) + while tool_calls_from_response and tool_use_count < MAX_TOOL_ITERATIONS: + tool_results_for_model = [] # Results to be sent back to the model + + for tool_call in tool_calls_from_response: + tool_call_id = tool_call.id + function_to_call = tool_call.function + + logging.info(f"Attempting to call tool: {function_to_call.name} with args: {function_to_call.arguments}") try: - tool_use_results.append({"role": "tool", "tool_call_id": tool_call_id, "name":tool_call.name, "content": str(tool_response) }) - except (TypeError, ValueError) as e: - logging.error(f"Failed to serialize tool response: {e}") - tool_use_results.append({"role": "function", "name": tool_call.name, "content": "Serialization error"}) + tool_response_content = self.call_tool(function_to_call.name, function_to_call.arguments) + # Ensure tool_response_content is a string for the API + if not isinstance(tool_response_content, str): + tool_response_content = json.dumps(tool_response_content) + except Exception as e: + logging.error(f"Error calling tool {function_to_call.name}: {e}") + tool_response_content = f"Error executing tool {function_to_call.name}: {str(e)}" + + tool_results_for_model.append({ + "role": "tool", + "tool_call_id": tool_call_id, + "name": function_to_call.name, + "content": tool_response_content + }) - messages.extend(tool_use_results) + messages.extend(tool_results_for_model) # Add tool responses to message history + # Get new response from model based on tool execution results response = self.get_chat_response(messages) - - for message_part in response.choices: - if message_part.finish_reason == "tool_calls": - tool_calls.extend(message_part.message.tool_calls) - - messages.append(response.choices[0].message) + if not (response.choices and response.choices[0].message): + logging.error("No valid response choice message from LLM after tool call.") + return "Error: Could not get a valid response from the LLM after tool call." + messages.append(response.choices[0].message) # Append new assistant message + + # Check for new tool calls + tool_calls_from_response = [] # Reset for this iteration + if response.choices[0].message.tool_calls: + tool_calls_from_response.extend(response.choices[0].message.tool_calls) + tool_use_count += 1 + if tool_use_count >= MAX_TOOL_ITERATIONS and tool_calls_from_response: + logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.") + # May need to return a message indicating this to user - if len(self.conversation_history[user_id]) > 2000: + # Conversation history management + if len(self.conversation_history[user_id]) > 2000: # Assuming this limit is for messages, not tokens self.conversation_history[user_id] = self.conversation_history[user_id][-2000:] - return messages[-1].content + # Return the latest assistant content + final_assistant_message = messages[-1] + return final_assistant_message.content if final_assistant_message.role == "assistant" and final_assistant_message.content else "No content in final message." + async def start(self): - logging.info("Bot started") - # Potentially call super().start() if it exists and does something + logging.info("Gemini Bot started") + # super().start() if Base class start() has common logic async def clear(self, user_id): - super().clear_conversation(user_id) + super().clear_conversation(user_id) # Calls base class method - - async def status(self): - return f"Currently using: {self.model}, Max Tokens: {self.max_tokens}" + # status() method is inherited from BaseTelegramInferenceBot async def abort_processing(self, user_id): - # This depends on how processing_status is managed, likely in BaseTelegramInferenceBot - if hasattr(self, 'processing_status') and user_id in self.processing_status: - self.processing_status[user_id]["processing"] = False # Example - await self.clear(user_id) # Clearing conversation on abort might be desired + if user_id in self.processing_status: + self.processing_status[user_id]["processing"] = False + # It's good practice to also clear the conversation for an aborted state + await self.clear(user_id) return "Processing aborted and conversation cleared." else: - # If not tracking processing_status here, just clear for safety + # If no specific status, clearing conversation is a safe default await self.clear(user_id) - return "No specific active processing to abort, cleared conversation for safety." + return "No active processing found to abort. Conversation cleared." async def switch_model(self): current_small_model = os.environ.get("GEMINI_SMALL_MODEL") current_large_model = os.environ.get("GEMINI_LARGE_MODEL") - if self.model == current_small_model: - target_model = current_large_model - target_max_tokens = os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS") - else: + # Default to small model if current model is not recognized or if it's the large one + if self.model == current_large_model or self.model != current_small_model : target_model = current_small_model target_max_tokens = os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS") + else: # Current is small, switch to large + target_model = current_large_model + target_max_tokens = os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS") self._configure_model_and_tokens(target_model, target_max_tokens) + logging.info(f"Switched to model: {self.model}") return f"Switched to model: {self.model}" +# The main() function and if __name__ == '__main__': block are for standalone execution. +# If this bot is imported as a module, these might not be necessary or might be handled differently. +# For now, keeping them as they were. def main(): - # Ensure GEMINI_API_KEY and other environment variables are set if not os.environ.get("GEMINI_API_KEY"): logging.error("FATAL: GEMINI_API_KEY environment variable not set.") return + # Configure logging here if it's the main entry point + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') + bot = GeminiTelegramInferenceBot() + # The instantiation of TelegramHelper and running it implies this file can be an entry point. + # If it's purely a module, this main() would be removed. telegram_helper = TelegramHelper(bot) telegram_helper.run() if __name__ == '__main__': - logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') - main() \ No newline at end of file + main()