diff --git a/anthropic_telegram_inference_bot.py b/anthropic_telegram_inference_bot.py index 116b00a..3838d4f 100644 --- a/anthropic_telegram_inference_bot.py +++ b/anthropic_telegram_inference_bot.py @@ -1,45 +1,29 @@ import os import json import logging -from anthropic import Anthropic +from anthropic import Anthropic, APIError, RateLimitError from base_telegram_inference_bot import BaseTelegramInferenceBot from telegram_helper import TelegramHelper -# logging.basicConfig(level=logging.INFO) # Usually configured in main execution script - class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot): def __init__(self): super().__init__() self.anthropic_client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")) - # Note: default_headers for max_tokens with older models might be needed. - # For Claude 3.5 Sonnet, max_tokens is a top-level param in messages.create - # Configure model and tokens. Using Sonnet 3.5 as default. - # ANTHROPIC_MODEL and ANTHROPIC_MAX_TOKENS would be new ENVs. self._configure_model_and_tokens( os.environ.get("ANTHROPIC_MODEL", "claude-3-5-sonnet-20240620"), - os.environ.get("ANTHROPIC_MAX_TOKENS", "4096") # Default max tokens for Sonnet 3.5 + os.environ.get("ANTHROPIC_MAX_TOKENS", "4096") ) def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=4096): self.model = model_name if model_name else "claude-3-5-sonnet-20240620" try: - # Anthropic's max_tokens is an integer. self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens except ValueError: logging.error(f"Invalid value for Anthropic max_tokens: {max_tokens_str}. Using default {default_max_tokens}.") self.max_tokens = default_max_tokens logging.info(f"Configured to use Anthropic model: {self.model} with max_tokens: {self.max_tokens}") - def get_system_prompt_description(self) -> str: - system_prompt_path = os.getenv("SYSTEM_PROMPT_PATH") - if system_prompt_path and os.path.isfile(system_prompt_path): - return f"System Prompt File: {os.path.basename(system_prompt_path)}" - elif system_prompt_path: - return f"System Prompt File: {os.path.basename(system_prompt_path)} (Not found at path: {system_prompt_path})" - else: - return "System Prompt File: Not configured (SYSTEM_PROMPT_PATH not set)." - def get_llm_description(self) -> str: return f"LLM: {self.model}, Max Tokens: {self.max_tokens}" @@ -66,9 +50,27 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot): tool_choice={"type": "auto"} if anthropic_tools else None ) return response - except Exception as e: - logging.error(f"Anthropic API call failed: {e}") + except (APIError, RateLimitError) as e: + logging.error(f"Anthropic API error: {e}") raise + except Exception as e: + logging.error(f"An unexpected error occurred during Anthropic API call: {e}") + raise + + def _format_tool_response_for_anthropic(self, tool_response_data): + if isinstance(tool_response_data, str): + return [{"type": "text", "text": tool_response_data}] + elif isinstance(tool_response_data, (dict, list)): + try: + is_valid_block_list = isinstance(tool_response_data, list) and all(isinstance(item, dict) and "type" in item for item in tool_response_data) + if is_valid_block_list: + return tool_response_data + else: + return [{"type": "text", "text": json.dumps(tool_response_data)}] + except (TypeError, json.JSONDecodeError): + return [{"type": "text", "text": str(tool_response_data)}] + else: + return [{"type": "text", "text": str(tool_response_data)}] async def handle_message(self, user_id, user_message): if user_id not in self.conversation_history: @@ -86,7 +88,7 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot): if not response or not response.content: logging.error("No valid response content from Anthropic LLM.") - self.conversation_history[user_id] = current_turn_messages # Persist what we have + self.conversation_history[user_id] = current_turn_messages return "Error: Could not get a valid response from the LLM." assistant_current_turn_content_blocks = response.content @@ -114,22 +116,7 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot): logging.info(f"Attempting to call Anthropic tool: {tool_name} with input: {tool_input}") try: tool_response_data = self.call_tool(tool_name, tool_input) - - if isinstance(tool_response_data, str): - tool_result_content_block = [{"type": "text", "text": tool_response_data}] - elif isinstance(tool_response_data, dict) or isinstance(tool_response_data, list): - try: - # If tool_response_data is already a list of Anthropic content blocks, use as is. - # Otherwise, dump to JSON string and wrap in a text block. - is_valid_block_list = isinstance(tool_response_data, list) and all(isinstance(item, dict) and "type" in item for item in tool_response_data) - if is_valid_block_list: - tool_result_content_block = tool_response_data - else: - tool_result_content_block = [{"type": "text", "text": json.dumps(tool_response_data)}] - except (TypeError, json.JSONDecodeError): # Not easily serializable or not a valid block list - tool_result_content_block = [{"type": "text", "text": str(tool_response_data)}] - else: # bool, int, float, None, etc. - tool_result_content_block = [{"type": "text", "text": str(tool_response_data)}] + tool_result_content_block = self._format_tool_response_for_anthropic(tool_response_data) tool_results_for_model.append({ "type": "tool_result", @@ -157,11 +144,10 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot): if len(self.conversation_history[user_id]) > 20: self.conversation_history[user_id] = self.conversation_history[user_id][-20:] - if assistant_response_content: # Text from the last successful assistant turn (or before max iterations) + if assistant_response_content: return assistant_response_content - else: # Fallback if no text content was generated by assistant (e.g. initial error, or only tool use) + else: if current_turn_messages: - # Try to get the *very last* text block from the *very last* assistant message in history. last_message_in_turn = current_turn_messages[-1] if last_message_in_turn.get("role") == "assistant" and isinstance(last_message_in_turn.get("content"), list): for block in reversed(last_message_in_turn["content"]): @@ -173,17 +159,17 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot): async def start(self): logging.info("Anthropic Bot started") - async def clear(self, user_id): - super().clear_conversation(user_id) + async def clear_conversation_history(self, user_id): + super().clear_conversation_history(user_id) logging.info(f"Cleared conversation history for Anthropic bot, user {user_id}") async def abort_processing(self, user_id): if user_id in self.processing_status: self.processing_status[user_id]["processing"] = False - await self.clear(user_id) + await self.clear_conversation_history(user_id) return "Processing aborted and conversation cleared." else: - await self.clear(user_id) + await self.clear_conversation_history(user_id) return "No active processing found to abort. Conversation cleared." async def switch_model(self): @@ -200,7 +186,7 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot): if self.model == primary_model: target_model = secondary_model_env target_max_tokens = secondary_max_tokens_env if secondary_max_tokens_env else "2048" - else: + else: target_model = primary_model target_max_tokens = primary_max_tokens diff --git a/base_telegram_inference_bot.py b/base_telegram_inference_bot.py index 5979c0b..410ce18 100644 --- a/base_telegram_inference_bot.py +++ b/base_telegram_inference_bot.py @@ -2,6 +2,7 @@ import importlib import os import json import inspect +import logging from abc import ABC, abstractmethod from tools.base_tool import BaseTool @@ -11,32 +12,44 @@ class BaseTelegramInferenceBot(ABC): self.processing_status = {} self.system_prompt = self.load_system_prompt() self.tools, self.functions = self.load_functions() - print(f'System Prompt: {os.environ.get("SYSTEM_PROMPT_PATH")}') - print(f'Github Repository: {os.environ.get("GITHUB_REPOSITORY")}') + logging.info(f'System Prompt: {os.environ.get("SYSTEM_PROMPT_PATH")}') + logging.info(f'Github Repository: {os.environ.get("GITHUB_REPOSITORY")}') - @staticmethod - def load_system_prompt(): + def load_system_prompt(self): system_prompt_path = os.getenv("SYSTEM_PROMPT_PATH") if system_prompt_path and os.path.isfile(system_prompt_path): - with open(system_prompt_path, "r", encoding="utf-8") as file: - return file.read().strip() + try: + with open(system_prompt_path, "r", encoding="utf-8") as file: + return file.read().strip() + except IOError as e: + logging.warning(f"Could not read system prompt file {system_prompt_path}: {e}") + return "You are a helpful AI assistant." else: - raise FileNotFoundError("SYSTEM_PROMPT_PATH is not set or file does not exist.") + logging.warning("SYSTEM_PROMPT_PATH is not set or file does not exist. Using default system prompt.") + return "You are a helpful AI assistant." - @staticmethod - def load_functions(): + def load_functions(self): tools = [] + functions = [] tools_dir = os.path.join(os.path.dirname(__file__), 'tools') + if not os.path.exists(tools_dir): + logging.warning(f"Tools directory not found: {tools_dir}") + return [], [] + for filename in os.listdir(tools_dir): if filename.endswith('.py') and filename != '__init__.py' and filename != 'base_tool.py': module_name = f'tools.{filename[:-3]}' - module = importlib.import_module(module_name) - for name, obj in inspect.getmembers(module): - if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool: - tools.append(obj()) + try: + module = importlib.import_module(module_name) + for name, obj in inspect.getmembers(module): + if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool: + try: + tools.append(obj()) + except Exception as e: + logging.error(f"Error instantiating tool {name} from {filename}: {e}") + except Exception as e: + logging.error(f"Error importing module {module_name}: {e}") - # Collect all function definitions - functions = [] for tool in tools: functions.extend(tool.get_functions()) return tools, functions @@ -49,47 +62,64 @@ class BaseTelegramInferenceBot(ABC): async def handle_message(self, user_id, user_message): pass - def clear_conversation(self, user_id): + def clear_conversation_history(self, user_id): if user_id in self.conversation_history: del self.conversation_history[user_id] + # Assuming tool.clear() is for global state or doesn't need user_id for tool in self.tools: tool.clear() + def set_processing_status(self, user_id: int, message_id: int): + self.processing_status[user_id] = {"processing": True, "message_id": message_id} + + def clear_processing_status(self, user_id: int): + if user_id in self.processing_status: + del self.processing_status[user_id] + def call_tool(self, function_call_name, function_call_arguments): function_name = function_call_name - function_args = json.loads(function_call_arguments if function_call_arguments is not None else "{}") + try: + function_args = json.loads(function_call_arguments if function_call_arguments is not None else "{}") + except json.JSONDecodeError as e: + logging.error(f"Error decoding function call arguments for {function_call_name}: {e}. Arguments: {function_call_arguments}") + return f"Error: Malformed arguments for tool call: {e}" + for tool in self.tools: for function in tool.get_functions(): if function["function"]["name"] == function_name: - return tool.execute(function_name, **function_args) + try: + return tool.execute(function_name, **function_args) + except Exception as e: + logging.error(f"Error executing tool {function_name} with args {function_args}: {e}") + return f"Error executing tool {function_name}: {e}" + logging.warning(f"Tool function {function_name} not found.") + return f"Error: Tool function {function_name} not found." - @abstractmethod def get_system_prompt_description(self) -> str: """Returns a description of the system prompt being used.""" - pass + return f"System Prompt: {'Custom' if os.getenv('SYSTEM_PROMPT_PATH') else 'Default'}" @abstractmethod def get_llm_description(self) -> str: """Returns a description of the LLM being used.""" pass - async def status(self) -> str: # Changed from abstract to concrete + async def get_bot_status(self) -> str: """Provides a status message including prompt and LLM information.""" prompt_desc = self.get_system_prompt_description() llm_desc = self.get_llm_description() - # Consider potential async calls if get_... methods were async - # For now, assuming they are synchronous as per design return f"{prompt_desc}\n{llm_desc}" @abstractmethod async def start(self): pass - @abstractmethod - async def clear(self, user_id): - pass - @abstractmethod async def abort_processing(self, user_id): pass + + @abstractmethod + async def switch_model(self): + """Switches the underlying model if supported by the bot.""" + pass diff --git a/chatgpt_telegram_inference_bot.py b/chatgpt_telegram_inference_bot.py index 9e6374f..086c49f 100644 --- a/chatgpt_telegram_inference_bot.py +++ b/chatgpt_telegram_inference_bot.py @@ -1,156 +1,27 @@ -import json import os import logging -from base_telegram_inference_bot import BaseTelegramInferenceBot -from telegram_helper import TelegramHelper from openai import OpenAI +from openai_compatible_inference_bot import OpenAICompatibleInferenceBot +from telegram_helper import TelegramHelper -# logging.basicConfig(level=logging.INFO) # Usually configured in main execution script - -class ChatGPTTelegramInferenceBot(BaseTelegramInferenceBot): +class ChatGPTTelegramInferenceBot(OpenAICompatibleInferenceBot): def __init__(self): super().__init__() self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) self._configure_model_and_tokens( - os.environ.get("OPENAI_SMALL_MODEL", "gpt-3.5-turbo"), # Default to a common small model + os.environ.get("OPENAI_SMALL_MODEL", "gpt-3.5-turbo"), os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS") ) - def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=1000): - self.model = model_name if model_name else "gpt-3.5-turbo" # Ensure model has a default - try: - self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens - except ValueError: - logging.error(f"Invalid value for max_tokens: {max_tokens_str}. Using default {default_max_tokens}.") - self.max_tokens = default_max_tokens - logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens}") - - def get_system_prompt_description(self) -> str: - system_prompt_path = os.getenv("SYSTEM_PROMPT_PATH") - if system_prompt_path and os.path.isfile(system_prompt_path): - return f"System Prompt File: {os.path.basename(system_prompt_path)}" - elif system_prompt_path: # Path is set but file not found - return f"System Prompt File: {os.path.basename(system_prompt_path)} (Not found at path: {system_prompt_path})" - else: # Path not set - return "System Prompt File: Not configured (SYSTEM_PROMPT_PATH not set)." - - def get_llm_description(self) -> str: - return f"LLM: {self.model}, Max Tokens: {self.max_tokens}" - - def get_chat_response(self, messages): - try: - response = self.client.chat.completions.create( - model=self.model, - messages=messages, - tools=self.functions if hasattr(self, 'functions') and self.functions else None, - tool_choice="auto" if hasattr(self, 'functions') and self.functions else None, - max_tokens=self.max_tokens - ) - return response - except Exception as e: - logging.error(f"OpenAI API call failed: {e}") - raise - - async def handle_message(self, user_id, user_message): - if user_id not in self.conversation_history: - self.conversation_history[user_id] = [] - if hasattr(self, 'system_prompt') and self.system_prompt: - self.conversation_history[user_id].append({"role": "system", "content": self.system_prompt}) - - self.conversation_history[user_id].append({"role": "user", "content": user_message}) - messages = self.conversation_history[user_id] - - response = self.get_chat_response(messages) - - if not (response.choices and response.choices[0].message): - logging.error("No valid response choice message from LLM.") - return "Error: Could not get a valid response from the LLM." - - messages.append(response.choices[0].message) # Append the assistant's response message - - tool_calls_from_response = [] - if response.choices[0].message.tool_calls: - tool_calls_from_response.extend(response.choices[0].message.tool_calls) - - tool_use_count = 0 - MAX_TOOL_ITERATIONS = 5 - - while tool_calls_from_response and tool_use_count < MAX_TOOL_ITERATIONS: - tool_results_for_model = [] - - for tool_call in tool_calls_from_response: - tool_call_id = tool_call.id - function_to_call = tool_call.function - - logging.info(f"Attempting to call tool: {function_to_call.name} with args: {function_to_call.arguments}") - try: - tool_response_content = self.call_tool(function_to_call.name, function_to_call.arguments) - if not isinstance(tool_response_content, str): - tool_response_content = json.dumps(tool_response_content) - except Exception as e: - logging.error(f"Error calling tool {function_to_call.name}: {e}") - tool_response_content = f"Error executing tool {function_to_call.name}: {str(e)}" - - tool_results_for_model.append({ - "role": "tool", - "tool_call_id": tool_call_id, - "name": function_to_call.name, - "content": tool_response_content - }) - - messages.extend(tool_results_for_model) - - response = self.get_chat_response(messages) - if not (response.choices and response.choices[0].message): - logging.error("No valid response choice message from LLM after tool call.") - return "Error: Could not get a valid response from the LLM after tool call." - - messages.append(response.choices[0].message) - - tool_calls_from_response = [] - if response.choices[0].message.tool_calls: - tool_calls_from_response.extend(response.choices[0].message.tool_calls) - - tool_use_count += 1 - if tool_use_count >= MAX_TOOL_ITERATIONS and tool_calls_from_response: - logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.") - - if len(self.conversation_history[user_id]) > 20: # This limit seems small, consider increasing - self.conversation_history[user_id] = self.conversation_history[user_id][-20:] - - final_assistant_message = messages[-1] - return final_assistant_message.content if final_assistant_message.role == "assistant" and final_assistant_message.content else "No content in final message." - - - async def start(self): - logging.info("ChatGPT Bot started") - # super().start() if Base class start() has common logic - - async def clear(self, user_id): - super().clear_conversation(user_id) - - # status() method is inherited from BaseTelegramInferenceBot - - async def abort_processing(self, user_id): - if user_id in self.processing_status: # Relies on processing_status from Base - self.processing_status[user_id]["processing"] = False - await self.clear(user_id) - return "Processing aborted and conversation cleared." - else: - await self.clear(user_id) - return "No active processing found to abort. Conversation cleared." - async def switch_model(self): - # Ensure environment variables for model names are set for this to work meaningfully current_small_model = os.environ.get("OPENAI_SMALL_MODEL", "gpt-3.5-turbo") - current_large_model = os.environ.get("OPENAI_LARGE_MODEL", "gpt-4") # Example large model + current_large_model = os.environ.get("OPENAI_LARGE_MODEL", "gpt-4") - # Default to small model if current model is not recognized or if it's the large one - if self.model == current_large_model or self.model != current_small_model : + if self.model == current_large_model or self.model != current_small_model: target_model = current_small_model target_max_tokens = os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS") - else: # Current is small (or default), switch to large + else: target_model = current_large_model target_max_tokens = os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS") @@ -163,7 +34,6 @@ def main(): logging.error("FATAL: OPENAI_API_KEY environment variable not set.") return - # Configure logging here if it's the main entry point logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') bot = ChatGPTTelegramInferenceBot() diff --git a/gemini_telegram_inference_bot.py b/gemini_telegram_inference_bot.py index fccde2f..5c5f549 100644 --- a/gemini_telegram_inference_bot.py +++ b/gemini_telegram_inference_bot.py @@ -1,166 +1,27 @@ -import json import os import logging -from base_telegram_inference_bot import BaseTelegramInferenceBot -from telegram_helper import TelegramHelper # This import might be unused if main() is removed or TelegramHelper is not directly instantiated here. from openai import OpenAI +from openai_compatible_inference_bot import OpenAICompatibleInferenceBot +from telegram_helper import TelegramHelper -# logging.basicConfig(level=logging.INFO) # Usually configured in main execution script - -class GeminiTelegramInferenceBot(BaseTelegramInferenceBot): +class GeminiTelegramInferenceBot(OpenAICompatibleInferenceBot): def __init__(self): super().__init__() self.client = OpenAI(api_key=os.environ.get("GEMINI_API_KEY"), base_url=os.environ.get("GEMINI_API_BASE_URL")) self._configure_model_and_tokens( - os.environ.get("GEMINI_SMALL_MODEL"), + os.environ.get("GEMINI_SMALL_MODEL", "gemini-pro"), os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS") ) - def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=1000): - self.model = model_name if model_name else "default-gemini-model" # Ensure model has a default - try: - self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens - except ValueError: - logging.error(f"Invalid value for max_tokens: {max_tokens_str}. Using default {default_max_tokens}.") - self.max_tokens = default_max_tokens - logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens}") - - def get_system_prompt_description(self) -> str: - system_prompt_path = os.getenv("SYSTEM_PROMPT_PATH") - if system_prompt_path and os.path.isfile(system_prompt_path): - return f"System Prompt File: {os.path.basename(system_prompt_path)}" - elif system_prompt_path: # Path is set but file not found - return f"System Prompt File: {os.path.basename(system_prompt_path)} (Not found at path: {system_prompt_path})" - else: # Path not set - return "System Prompt File: Not configured (SYSTEM_PROMPT_PATH not set)." - - def get_llm_description(self) -> str: - return f"LLM: {self.model}, Max Tokens: {self.max_tokens}" - - def get_chat_response(self, messages): - try: - response = self.client.chat.completions.create( - model=self.model, - messages=messages, - tools=self.functions if hasattr(self, 'functions') and self.functions else None, - tool_choice="auto" if hasattr(self, 'functions') and self.functions else None, - max_tokens=self.max_tokens - ) - return response - except Exception as e: - logging.error(f"Gemini API call failed: {e}") - # Return a more structured error or re-raise a custom exception - # For now, re-raising to be handled by the caller - raise - - async def handle_message(self, user_id, user_message): - if user_id not in self.conversation_history: - self.conversation_history[user_id] = [] - if hasattr(self, 'system_prompt') and self.system_prompt: - self.conversation_history[user_id].append({"role": "system", "content": self.system_prompt}) - - self.conversation_history[user_id].append({"role": "user", "content": user_message}) - messages = self.conversation_history[user_id] - - response = self.get_chat_response(messages) - - # Ensure response.choices[0].message exists before appending - if response.choices and response.choices[0].message: - messages.append(response.choices[0].message) # Append the assistant's response message - else: - logging.error("No valid response choice message from LLM.") - return "Error: Could not get a valid response from the LLM." - - tool_calls_from_response = [] - if response.choices[0].message.tool_calls: - tool_calls_from_response.extend(response.choices[0].message.tool_calls) - - tool_use_count = 0 - MAX_TOOL_ITERATIONS = 5 # Define a max to prevent infinite loops more explicitly - - while tool_calls_from_response and tool_use_count < MAX_TOOL_ITERATIONS: - tool_results_for_model = [] # Results to be sent back to the model - - for tool_call in tool_calls_from_response: - tool_call_id = tool_call.id - function_to_call = tool_call.function - - logging.info(f"Attempting to call tool: {function_to_call.name} with args: {function_to_call.arguments}") - try: - tool_response_content = self.call_tool(function_to_call.name, function_to_call.arguments) - # Ensure tool_response_content is a string for the API - if not isinstance(tool_response_content, str): - tool_response_content = json.dumps(tool_response_content) - except Exception as e: - logging.error(f"Error calling tool {function_to_call.name}: {e}") - tool_response_content = f"Error executing tool {function_to_call.name}: {str(e)}" - - tool_results_for_model.append({ - "role": "tool", - "tool_call_id": tool_call_id, - "name": function_to_call.name, - "content": tool_response_content - }) - - messages.extend(tool_results_for_model) # Add tool responses to message history - - # Get new response from model based on tool execution results - response = self.get_chat_response(messages) - if not (response.choices and response.choices[0].message): - logging.error("No valid response choice message from LLM after tool call.") - return "Error: Could not get a valid response from the LLM after tool call." - - messages.append(response.choices[0].message) # Append new assistant message - - # Check for new tool calls - tool_calls_from_response = [] # Reset for this iteration - if response.choices[0].message.tool_calls: - tool_calls_from_response.extend(response.choices[0].message.tool_calls) - - tool_use_count += 1 - if tool_use_count >= MAX_TOOL_ITERATIONS and tool_calls_from_response: - logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.") - # May need to return a message indicating this to user - - # Conversation history management - if len(self.conversation_history[user_id]) > 2000: # Assuming this limit is for messages, not tokens - self.conversation_history[user_id] = self.conversation_history[user_id][-2000:] - - # Return the latest assistant content - final_assistant_message = messages[-1] - return final_assistant_message.content if final_assistant_message.role == "assistant" and final_assistant_message.content else "No content in final message." - - - async def start(self): - logging.info("Gemini Bot started") - # super().start() if Base class start() has common logic - - async def clear(self, user_id): - super().clear_conversation(user_id) # Calls base class method - - # status() method is inherited from BaseTelegramInferenceBot - - async def abort_processing(self, user_id): - if user_id in self.processing_status: - self.processing_status[user_id]["processing"] = False - # It's good practice to also clear the conversation for an aborted state - await self.clear(user_id) - return "Processing aborted and conversation cleared." - else: - # If no specific status, clearing conversation is a safe default - await self.clear(user_id) - return "No active processing found to abort. Conversation cleared." - async def switch_model(self): - current_small_model = os.environ.get("GEMINI_SMALL_MODEL") - current_large_model = os.environ.get("GEMINI_LARGE_MODEL") + current_small_model = os.environ.get("GEMINI_SMALL_MODEL", "gemini-pro") + current_large_model = os.environ.get("GEMINI_LARGE_MODEL", "gemini-1.5-pro-latest") - # Default to small model if current model is not recognized or if it's the large one if self.model == current_large_model or self.model != current_small_model : target_model = current_small_model target_max_tokens = os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS") - else: # Current is small, switch to large + else: target_model = current_large_model target_max_tokens = os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS") @@ -168,20 +29,14 @@ class GeminiTelegramInferenceBot(BaseTelegramInferenceBot): logging.info(f"Switched to model: {self.model}") return f"Switched to model: {self.model}" -# The main() function and if __name__ == '__main__': block are for standalone execution. -# If this bot is imported as a module, these might not be necessary or might be handled differently. -# For now, keeping them as they were. def main(): if not os.environ.get("GEMINI_API_KEY"): logging.error("FATAL: GEMINI_API_KEY environment variable not set.") return - # Configure logging here if it's the main entry point logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') bot = GeminiTelegramInferenceBot() - # The instantiation of TelegramHelper and running it implies this file can be an entry point. - # If it's purely a module, this main() would be removed. telegram_helper = TelegramHelper(bot) telegram_helper.run() diff --git a/openai_compatible_inference_bot.py b/openai_compatible_inference_bot.py new file mode 100644 index 0000000..eae60fe --- /dev/null +++ b/openai_compatible_inference_bot.py @@ -0,0 +1,133 @@ +import json +import os +import logging +from abc import abstractmethod +from base_telegram_inference_bot import BaseTelegramInferenceBot +from openai import OpenAI + +class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot): + def __init__(self): + super().__init__() + # Client and model configuration will be handled by subclasses + self.client = None + self.model = None + self.max_tokens = None + + def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=1000): + self.model = model_name if model_name else "default-model" + try: + self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens + except ValueError: + logging.error(f"Invalid value for max_tokens: {max_tokens_str}. Using default {default_max_tokens}.") + self.max_tokens = default_max_tokens + logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens}") + + def get_llm_description(self) -> str: + return f"LLM: {self.model}, Max Tokens: {self.max_tokens}" + + def get_chat_response(self, messages): + if not self.client: + raise ValueError("OpenAI client not initialized. Subclasses must initialize it.") + try: + response = self.client.chat.completions.create( + model=self.model, + messages=messages, + tools=self.functions if hasattr(self, 'functions') and self.functions else None, + tool_choice="auto" if hasattr(self, 'functions') and self.functions else None, + max_tokens=self.max_tokens + ) + return response + except Exception as e: + logging.error(f"API call failed: {e}") + raise + + async def handle_message(self, user_id, user_message): + if user_id not in self.conversation_history: + self.conversation_history[user_id] = [] + if hasattr(self, 'system_prompt') and self.system_prompt: + self.conversation_history[user_id].append({"role": "system", "content": self.system_prompt}) + + self.conversation_history[user_id].append({"role": "user", "content": user_message}) + messages = self.conversation_history[user_id] + + response = self.get_chat_response(messages) + + if not (response.choices and response.choices[0].message): + logging.error("No valid response choice message from LLM.") + return "Error: Could not get a valid response from the LLM." + + messages.append(response.choices[0].message) # Append the assistant's response message + + tool_calls_from_response = [] + if response.choices[0].message.tool_calls: + tool_calls_from_response.extend(response.choices[0].message.tool_calls) + + tool_use_count = 0 + MAX_TOOL_ITERATIONS = 5 + + while tool_calls_from_response and tool_use_count < MAX_TOOL_ITERATIONS: + tool_results_for_model = [] + + for tool_call in tool_calls_from_response: + tool_call_id = tool_call.id + function_to_call = tool_call.function + + logging.info(f"Attempting to call tool: {function_to_call.name} with args: {function_to_call.arguments}") + try: + tool_response_content = self.call_tool(function_to_call.name, function_to_call.arguments) + if not isinstance(tool_response_content, str): + tool_response_content = json.dumps(tool_response_content) + except Exception as e: + logging.error(f"Error calling tool {function_to_call.name}: {e}") + tool_response_content = f"Error executing tool {function_to_call.name}: {str(e)}" + + tool_results_for_model.append({ + "role": "tool", + "tool_call_id": tool_call_id, + "name": function_to_call.name, + "content": tool_response_content + }) + + messages.extend(tool_results_for_model) + + response = self.get_chat_response(messages) + if not (response.choices and response.choices[0].message): + logging.error("No valid response choice message from LLM after tool call.") + return "Error: Could not get a valid response from the LLM after tool call." + + messages.append(response.choices[0].message) + + tool_calls_from_response = [] + if response.choices[0].message.tool_calls: + tool_calls_from_response.extend(response.choices[0].message.tool_calls) + + tool_use_count += 1 + if tool_use_count >= MAX_TOOL_ITERATIONS and tool_calls_from_response: + logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.") + + # Conversation history management + # This limit should be reviewed and potentially made configurable + if len(self.conversation_history[user_id]) > 20: # Example limit, adjust as needed + self.conversation_history[user_id] = self.conversation_history[user_id][-20:] + + final_assistant_message = messages[-1] + return final_assistant_message.content if final_assistant_message.role == "assistant" and final_assistant_message.content else "No content in final message." + + async def start(self): + logging.info(f"{self.__class__.__name__} started.") + + async def clear(self, user_id): + super().clear_conversation_history(user_id) + + async def abort_processing(self, user_id): + if user_id in self.processing_status: + self.processing_status[user_id]["processing"] = False + await self.clear(user_id) + return "Processing aborted and conversation cleared." + else: + await self.clear(user_id) + return "No active processing found to abort. Conversation cleared." + + @abstractmethod + async def switch_model(self): + pass diff --git a/telegram_helper.py b/telegram_helper.py index 53d55b7..e3c83cc 100644 --- a/telegram_helper.py +++ b/telegram_helper.py @@ -3,16 +3,21 @@ import logging import sys import asyncio import time -import git from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes, CallbackQueryHandler from browse_command import browse_command, button_callback class TelegramHelper: + # --- Constants for configurable paths and magic strings --- + REBOOT_CLAUDE_FILE = '.reboot_claude' + REBOOT_FILE = '.doreboot' + CLAUDE_REBOOT_TARGET = 'claude' + HTML_QUOTE_BLOCK_START = '
Thinking...' + HTML_QUOTE_BLOCK_END = '
' + def __init__(self, bot): self.bot = bot self.telegram_bot_token = os.getenv('TELEGRAM_BOT_TOKEN') - self.repo = git.Repo(".") self.start_time = time.time() async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: @@ -23,11 +28,11 @@ class TelegramHelper: async def clear(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: user_id = update.effective_user.id - await self.bot.clear(user_id) + await self.bot.clear_conversation_history(user_id) await update.message.reply_text("Conversation history cleared. Let's start fresh!") async def status(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - status_message = await self.bot.status() + status_message = await self.bot.get_bot_status() await update.message.reply_text(status_message) async def switch(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: @@ -56,23 +61,22 @@ class TelegramHelper: logging.info(f"Message from user {user_id}: {user_message}") - status_message = await update.message.reply_text("Processing your request...", reply_markup=InlineKeyboardMarkup([[InlineKeyboardButton("Abort", callback_data='abort')]])) - self.bot.processing_status[user_id] = {"processing": True, "message_id": status_message.message_id} + status_message = await update.message.reply_text("Processing your request...", reply_markup=InlineKeyboardMarkup([[InlineKeyboardButton("Abort", callback_data='abort')]]))\ + await self.bot.set_processing_status(user_id, status_message.message_id) response = await self.bot.handle_message(user_id, user_message) await context.bot.delete_message(chat_id=update.effective_chat.id, message_id=status_message.message_id) - del self.bot.processing_status[user_id] - response = response.replace("", "
Thinking...").replace("", "
") - # Return response as html message + await self.bot.clear_processing_status(user_id) + + response = response.replace("", self.HTML_QUOTE_BLOCK_START).replace("", self.HTML_QUOTE_BLOCK_END) + if len(response) > 4096: - # If the response is too long, split it into chunks chunks = [response[i:i + 4096] for i in range(0, len(response), 4096)] for chunk in chunks: await update.message.reply_text(chunk) - # Add a small delay to avoid flooding await asyncio.sleep(0.1) - else: + else: await update.message.reply_text(response) except Exception as e: @@ -82,27 +86,27 @@ class TelegramHelper: async def abort_processing(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: query = update.callback_query await query.answer() - + user_id = query.from_user.id result = await self.bot.abort_processing(user_id) await query.edit_message_text(text=result) async def reboot(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - user_message = update.message.text.split() # Split the message to check for 'claude' - if len(user_message) > 1 and user_message[1].lower() == 'claude': - open('./.reboot_claude', 'w').close() # Create an empty file + user_message = update.message.text.split() + if len(user_message) > 1 and user_message[1].lower() == self.CLAUDE_REBOOT_TARGET: + open(self.REBOOT_CLAUDE_FILE, 'w').close() if update: await update.message.reply_text("Rebooting the bot...") logging.info("Received reboot command. Exiting process...") - reboot_file_path = "./.doreboot" + reboot_file_path = self.REBOOT_FILE if not os.path.exists(reboot_file_path): with open(reboot_file_path, 'w') as f: f.write(str(update.effective_chat.id) if update else "") sys.exit(0) async def check_doreboot_file(self, application: Application): - reboot_file_path = "./.doreboot" + reboot_file_path = self.REBOOT_FILE if os.path.exists(reboot_file_path): with open(reboot_file_path, 'r') as f: chat_id = f.read().strip() @@ -122,16 +126,12 @@ class TelegramHelper: application.add_handler(CommandHandler("status", self.status)) application.add_handler(CommandHandler("reboot", self.reboot)) application.add_handler(CommandHandler("browse", self.browse)) - application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, self.handle_message)) + application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, self.handle_message))\ application.add_handler(CallbackQueryHandler(self.abort_processing, pattern='^abort$')) application.add_handler(CallbackQueryHandler(button_callback, pattern='^(browse|file):')) - + logging.info("Bot is running...") - # Check for .doreboot file and send message if it exists asyncio.get_event_loop().create_task(self.check_doreboot_file(application)) - - # Commenting out the commit checking task - # asyncio.get_event_loop().create_task(self.check_for_new_commits()) - - application.run_polling() \ No newline at end of file + + application.run_polling()