diff --git a/.env.example b/.env.example index 2b977f2..db30dd2 100644 --- a/.env.example +++ b/.env.example @@ -24,6 +24,9 @@ OPENAI_SMALL_MODEL_MAX_TOKENS=32768 OPENAI_LARGE_MODEL=gpt-4.1 OPENAI_LARGE_MODEL_MAX_TOKENS=32768 +_SMALL_MODEL_MAX_INFERENCE_TOKENS=32768 +_LARGE_MODEL_MAX_INFERENCE_TOKENS=32768 + # Gemini API GEMINI_API_KEY=your_gemini_api_key_here GEMINI_API_BASE_URL=https://generativelanguage.googleapis.com/v1beta/openai/ diff --git a/openai_compatible_inference_bot.py b/openai_compatible_inference_bot.py index 9e34141..9a64ee9 100644 --- a/openai_compatible_inference_bot.py +++ b/openai_compatible_inference_bot.py @@ -9,6 +9,7 @@ from tools.base_tool import BaseTool from telegram_helper import TelegramHelper import argparse from inference_bot import InferenceBot +import tiktoken # Added this import class OpenAICompatibleInferenceBot(InferenceBot): def __init__( @@ -21,7 +22,7 @@ class OpenAICompatibleInferenceBot(InferenceBot): large_model_max_tokens: str | None = None, allowed_function_tags: list[str] | None = None, system_prompt_path: str | None = None, - use_large_model: bool = False # New argument + use_large_model: bool = False ): self.model_config = { "small_model_name": small_model_name, @@ -32,8 +33,7 @@ class OpenAICompatibleInferenceBot(InferenceBot): self.allowed_function_tags = allowed_function_tags if allowed_function_tags else None self.conversation_history = {} self._processing_status = {} - self.system_prompt_path = system_prompt_path # Store the prompt path for status - # MODIFIED to pass arguments + self.system_prompt_path = system_prompt_path self.system_prompt = self.load_system_prompt( file_path=system_prompt_path ) @@ -42,6 +42,10 @@ class OpenAICompatibleInferenceBot(InferenceBot): log_msg = f"Initialized OpenAI compatible client. Target URL: {base_url if base_url else 'OpenAI default'}." logging.info(log_msg) + # Load inference token limits + self.small_model_max_inference_tokens = int(os.getenv("_SMALL_MODEL_MAX_INFERENCE_TOKENS", "32768")) + self.large_model_max_inference_tokens = int(os.getenv("_LARGE_MODEL_MAX_INFERENCE_TOKENS", "32768")) + # Configure the actual model name and max_tokens for API calls if use_large_model: self._configure_model_and_tokens( @@ -53,12 +57,9 @@ class OpenAICompatibleInferenceBot(InferenceBot): self.model_config["small_model_name"], self.model_config["small_model_max_tokens"] ) + @property def processing_status(self): - """ - An attribute to store the processing status for users. - Example usage in subclass: self.processing_status.get(user_id) - """ return self._processing_status def clear_conversation_history(self, user_id): @@ -71,14 +72,13 @@ class OpenAICompatibleInferenceBot(InferenceBot): def _configure_model_and_tokens(self, model_name: str | None, max_tokens_str: str | None): self.model = model_name try: - # If max_tokens_str is explicitly "None" or empty, treat as None for API default if max_tokens_str and max_tokens_str.lower() not in ["none", "", "null"]: self.max_tokens = int(max_tokens_str) else: - self.max_tokens = None # Use API default by not sending the parameter or sending null + self.max_tokens = None except ValueError: logging.warning(f"Invalid value for max_tokens: {max_tokens_str}. Using API default (None)") - self.max_tokens = None # Use API default + self.max_tokens = None logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}") @@ -86,26 +86,39 @@ class OpenAICompatibleInferenceBot(InferenceBot): client_type = type(self.client).__name__ return f"Client: {client_type}, LLM: {self.model}, Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}" + def _count_tokens(self, messages, model): + """Returns the number of tokens in a list of messages.""" + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + encoding = tiktoken.get_encoding("cl100k_base") # Fallback for unknown models + logging.warning(f"Warning: model {model} not found. Using cl100k_base encoding.") + + num_tokens = 0 + for message in messages: + num_tokens += 4 + for key, value in message.items(): + if isinstance(value, str): + num_tokens += len(encoding.encode(value)) + if key == "name": + num_tokens += 1 + num_tokens += 2 + return num_tokens + def get_chat_response(self, messages): if not self.client: - # This should ideally not be hit if __init__ is successful logging.error("OpenAI client not initialized before get_chat_response.") raise ValueError("OpenAI client not initialized.") try: - # Pass self.max_tokens directly. If None, OpenAI library omits it or handles it. - # Initialize tools filtering based on allowed tags cleaned_tools = None if hasattr(self, 'functions') and self.functions: - # Create a copy of functions without "_tags" field cleaned_tools = [] for func in self.functions: include_function = False if not hasattr(self, 'allowed_function_tags') or self.allowed_function_tags is None: - # Include all functions if no tag filtering is specified include_function = True else: - # Only include if function has matching tags tags = func.get("_tags", []) if any(tag in self.allowed_function_tags for tag in tags): include_function = True @@ -137,17 +150,38 @@ class OpenAICompatibleInferenceBot(InferenceBot): async def handle_message(self, user_id, user_message): if user_id not in self.conversation_history or not self.conversation_history[user_id]: self.conversation_history[user_id] = [] - if self.system_prompt: # Use the loaded system_prompt + if self.system_prompt: self.conversation_history[user_id].append({"role": "system", "content": self.system_prompt}) self.conversation_history[user_id].append({"role": "user", "content": user_message}) - messages = list(self.conversation_history[user_id]) # Work with a copy for this turn + messages = list(self.conversation_history[user_id]) + + # Pre-inference token limit check + current_model_is_small = self.model == self.model_config["small_model_name"] + current_model_is_large = self.model == self.model_config["large_model_name"] + + inference_token_limit = None + if current_model_is_small: + inference_token_limit = self.small_model_max_inference_tokens + elif current_model_is_large: + inference_token_limit = self.large_model_max_inference_tokens + else: + logging.warning(f"Could not determine inference token limit for model: {self.model}. Proceeding without check.") + + if inference_token_limit is not None: + token_count = self._count_tokens(messages, self.model) + if token_count > inference_token_limit: + logging.warning(f"Request for user {user_id} exceeds inference token limit ({token_count}/{inference_token_limit}).") + # Do not persist this message in history as it was not processed by LLM + # Remove the last user message from history before returning, to prevent accumulation + if self.conversation_history[user_id] and self.conversation_history[user_id][-1]["role"] == "user" and self.conversation_history[user_id][-1]["content"] == user_message: + self.conversation_history[user_id].pop() + return "Request exceeds inference token limit. Please use the /clear command, or implement RAG in your application." response = self.get_chat_response(messages) if not (response.choices and response.choices[0].message): logging.error("No valid response choice message from LLM.") - # Persist the user message in history even if LLM fails this turn self.conversation_history[user_id] = messages return "Error: Could not get a valid response from the LLM." @@ -180,9 +214,7 @@ class OpenAICompatibleInferenceBot(InferenceBot): continue try: - # Arguments are already a string from the API, self.call_tool expects dict or string tool_response_content = self.call_tool(function_name, function_args_str) - # Ensure content is string for OpenAI tool role if not isinstance(tool_response_content, str): tool_response_content = json.dumps(tool_response_content) except Exception as e: @@ -201,7 +233,7 @@ class OpenAICompatibleInferenceBot(InferenceBot): response = self.get_chat_response(messages) if not (response.choices and response.choices[0].message): logging.error("No valid response choice message from LLM after tool call.") - self.conversation_history[user_id] = messages # Persist state before error + self.conversation_history[user_id] = messages return "Error: Could not get a valid response from the LLM after tool call." assistant_message = response.choices[0].message @@ -212,7 +244,6 @@ class OpenAICompatibleInferenceBot(InferenceBot): tool_use_count += 1 if tool_use_count >= MAX_TOOL_ITERATIONS and tool_calls_from_response: logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.") - # Ensure final content is returned even if max iterations hit with pending tool calls break self.conversation_history[user_id] = messages @@ -224,9 +255,8 @@ class OpenAICompatibleInferenceBot(InferenceBot): logging.info(f"{self.__class__.__name__} (Model: {self.model}) started.") async def abort_processing(self, user_id): - # This is a soft abort for OpenAI compatible bots as API calls are synchronous within handle_message if user_id in self.processing_status: - self.clear_processing_status(user_id) # Use base class method + self.clear_processing_status(user_id) logging.info(f"Processing aborted for user {user_id}.") return "Processing aborted. You can send a new message or /clear the conversation." else: @@ -278,7 +308,6 @@ class OpenAICompatibleInferenceBot(InferenceBot): logging.warning(f"Could not read system prompt file {prompt_path_to_try}: {e}. Using default.") return default_prompt else: - # This condition now also covers if 'file_path' argument was given but invalid logging.warning(f"System prompt file {prompt_path_to_try} not found. Using default system prompt.") return default_prompt else: @@ -357,7 +386,7 @@ def main(): parser.add_argument('--messenger', type=str, help='Messenger type (i.e. telegram)', required=True) parser.add_argument('--persona', type=str, help='Path to system prompt file', required=False) parser.add_argument('--tools', nargs='+', help='List of allowed function tags', required=False) - parser.add_argument('--use-large-model', action='store_true', help='Use the large model instead of the small model') # New argument + parser.add_argument('--use-large-model', action='store_true', help='Use the large model instead of the small model') # Add these to launch.json arguments if you want to limit the toolset available: "--tools", "read", "communicate" # Parse command line arguments args = parser.parse_args() @@ -369,7 +398,7 @@ def main(): allowed_function_tags=args.tools if args.tools else None config_prepend = args.config if args.config else None messenger = args.messenger if args.messenger else None - use_large_model = args.use_large_model # Get the value of the new argument + use_large_model = args.use_large_model # Initialize model and max tokens based on the config prepend if config_prepend: @@ -389,7 +418,7 @@ def main(): large_model_max_tokens=large_model_max_tokens, system_prompt_path=system_prompt_path, allowed_function_tags=allowed_function_tags, - use_large_model=use_large_model # Pass the new argument + use_large_model=use_large_model ) full_code_file = importlib.import_module(f'{messenger.lower()}_helper') messenger_helper_class_name = f"{messenger.capitalize()}Helper" diff --git a/requirements.txt b/requirements.txt index 8feeb13..d76bbcd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,5 @@ GitPython==3.1.43 pytest pytest-cov google-genai -httpx==0.27.2 \ No newline at end of file +httpx==0.27.2 +tiktoken