Merge pull request #230 from bucolucas/issue-229-token-limit

feat: Implement pre-inference token limit check
This commit is contained in:
2025-06-06 14:30:36 -05:00
committed by GitHub
3 changed files with 63 additions and 30 deletions
+3
View File
@@ -24,6 +24,9 @@ OPENAI_SMALL_MODEL_MAX_TOKENS=32768
OPENAI_LARGE_MODEL=gpt-4.1
OPENAI_LARGE_MODEL_MAX_TOKENS=32768
_SMALL_MODEL_MAX_INFERENCE_TOKENS=32768
_LARGE_MODEL_MAX_INFERENCE_TOKENS=32768
# Gemini API
GEMINI_API_KEY=your_gemini_api_key_here
GEMINI_API_BASE_URL=https://generativelanguage.googleapis.com/v1beta/openai/
+58 -29
View File
@@ -9,6 +9,7 @@ from tools.base_tool import BaseTool
from telegram_helper import TelegramHelper
import argparse
from inference_bot import InferenceBot
import tiktoken # Added this import
class OpenAICompatibleInferenceBot(InferenceBot):
def __init__(
@@ -21,7 +22,7 @@ class OpenAICompatibleInferenceBot(InferenceBot):
large_model_max_tokens: str | None = None,
allowed_function_tags: list[str] | None = None,
system_prompt_path: str | None = None,
use_large_model: bool = False # New argument
use_large_model: bool = False
):
self.model_config = {
"small_model_name": small_model_name,
@@ -32,8 +33,7 @@ class OpenAICompatibleInferenceBot(InferenceBot):
self.allowed_function_tags = allowed_function_tags if allowed_function_tags else None
self.conversation_history = {}
self._processing_status = {}
self.system_prompt_path = system_prompt_path # Store the prompt path for status
# MODIFIED to pass arguments
self.system_prompt_path = system_prompt_path
self.system_prompt = self.load_system_prompt(
file_path=system_prompt_path
)
@@ -42,6 +42,10 @@ class OpenAICompatibleInferenceBot(InferenceBot):
log_msg = f"Initialized OpenAI compatible client. Target URL: {base_url if base_url else 'OpenAI default'}."
logging.info(log_msg)
# Load inference token limits
self.small_model_max_inference_tokens = int(os.getenv("_SMALL_MODEL_MAX_INFERENCE_TOKENS", "32768"))
self.large_model_max_inference_tokens = int(os.getenv("_LARGE_MODEL_MAX_INFERENCE_TOKENS", "32768"))
# Configure the actual model name and max_tokens for API calls
if use_large_model:
self._configure_model_and_tokens(
@@ -53,12 +57,9 @@ class OpenAICompatibleInferenceBot(InferenceBot):
self.model_config["small_model_name"],
self.model_config["small_model_max_tokens"]
)
@property
def processing_status(self):
    """
    Expose the store that tracks processing status, looked up by user id.
    Typical access from a subclass: self.processing_status.get(user_id)
    """
    status_store = self._processing_status
    return status_store
def clear_conversation_history(self, user_id):
@@ -71,14 +72,13 @@ class OpenAICompatibleInferenceBot(InferenceBot):
def _configure_model_and_tokens(self, model_name: str | None, max_tokens_str: str | None):
self.model = model_name
try:
# If max_tokens_str is explicitly "None" or empty, treat as None for API default
if max_tokens_str and max_tokens_str.lower() not in ["none", "", "null"]:
self.max_tokens = int(max_tokens_str)
else:
self.max_tokens = None # Use API default by not sending the parameter or sending null
self.max_tokens = None
except ValueError:
logging.warning(f"Invalid value for max_tokens: {max_tokens_str}. Using API default (None)")
self.max_tokens = None # Use API default
self.max_tokens = None
logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}")
@@ -86,26 +86,39 @@ class OpenAICompatibleInferenceBot(InferenceBot):
client_type = type(self.client).__name__
return f"Client: {client_type}, LLM: {self.model}, Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}"
def _count_tokens(self, messages, model):
    """Estimate how many tokens a list of chat messages occupies.

    Args:
        messages: Iterable of message dicts; only string values are counted.
        model: Model name used to pick a tiktoken encoding. Unknown or
            missing names fall back to ``cl100k_base``.

    Returns:
        The estimated token count as an int. Non-string values (anything
        that is not a ``str``) contribute nothing, so this is an estimate,
        not an exact figure.
    """
    encoding = None
    if model:
        # Guarded because encoding_for_model() raises AttributeError rather
        # than KeyError when handed None, which would skip the fallback.
        try:
            encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            encoding = None
    if encoding is None:
        encoding = tiktoken.get_encoding("cl100k_base")  # Fallback for unknown models
        logging.warning(f"Warning: model {model} not found. Using cl100k_base encoding.")

    num_tokens = 0
    for message in messages:
        # Flat per-message allowance (presumably chat-format framing; confirm
        # against the target model's documentation).
        num_tokens += 4
        for key, value in message.items():
            if isinstance(value, str):
                # disallowed_special=() makes text that merely looks like a
                # special token (e.g. "<|endoftext|>") encode as ordinary
                # text; the default setting raises ValueError on it.
                num_tokens += len(encoding.encode(value, disallowed_special=()))
            if key == "name":
                num_tokens += 1
    # Flat allowance added once per request (presumably the reply priming).
    num_tokens += 2
    return num_tokens
def get_chat_response(self, messages):
if not self.client:
# This should ideally not be hit if __init__ is successful
logging.error("OpenAI client not initialized before get_chat_response.")
raise ValueError("OpenAI client not initialized.")
try:
# Pass self.max_tokens directly. If None, OpenAI library omits it or handles it.
# Initialize tools filtering based on allowed tags
cleaned_tools = None
if hasattr(self, 'functions') and self.functions:
# Create a copy of functions without "_tags" field
cleaned_tools = []
for func in self.functions:
include_function = False
if not hasattr(self, 'allowed_function_tags') or self.allowed_function_tags is None:
# Include all functions if no tag filtering is specified
include_function = True
else:
# Only include if function has matching tags
tags = func.get("_tags", [])
if any(tag in self.allowed_function_tags for tag in tags):
include_function = True
@@ -137,17 +150,38 @@ class OpenAICompatibleInferenceBot(InferenceBot):
async def handle_message(self, user_id, user_message):
if user_id not in self.conversation_history or not self.conversation_history[user_id]:
self.conversation_history[user_id] = []
if self.system_prompt: # Use the loaded system_prompt
if self.system_prompt:
self.conversation_history[user_id].append({"role": "system", "content": self.system_prompt})
self.conversation_history[user_id].append({"role": "user", "content": user_message})
messages = list(self.conversation_history[user_id]) # Work with a copy for this turn
messages = list(self.conversation_history[user_id])
# Pre-inference token limit check
current_model_is_small = self.model == self.model_config["small_model_name"]
current_model_is_large = self.model == self.model_config["large_model_name"]
inference_token_limit = None
if current_model_is_small:
inference_token_limit = self.small_model_max_inference_tokens
elif current_model_is_large:
inference_token_limit = self.large_model_max_inference_tokens
else:
logging.warning(f"Could not determine inference token limit for model: {self.model}. Proceeding without check.")
if inference_token_limit is not None:
token_count = self._count_tokens(messages, self.model)
if token_count > inference_token_limit:
logging.warning(f"Request for user {user_id} exceeds inference token limit ({token_count}/{inference_token_limit}).")
# Do not persist this message in history as it was not processed by LLM
# Remove the last user message from history before returning, to prevent accumulation
if self.conversation_history[user_id] and self.conversation_history[user_id][-1]["role"] == "user" and self.conversation_history[user_id][-1]["content"] == user_message:
self.conversation_history[user_id].pop()
return "Request exceeds inference token limit. Please use the /clear command, or implement RAG in your application."
response = self.get_chat_response(messages)
if not (response.choices and response.choices[0].message):
logging.error("No valid response choice message from LLM.")
# Persist the user message in history even if LLM fails this turn
self.conversation_history[user_id] = messages
return "Error: Could not get a valid response from the LLM."
@@ -180,9 +214,7 @@ class OpenAICompatibleInferenceBot(InferenceBot):
continue
try:
# Arguments are already a string from the API, self.call_tool expects dict or string
tool_response_content = self.call_tool(function_name, function_args_str)
# Ensure content is string for OpenAI tool role
if not isinstance(tool_response_content, str):
tool_response_content = json.dumps(tool_response_content)
except Exception as e:
@@ -201,7 +233,7 @@ class OpenAICompatibleInferenceBot(InferenceBot):
response = self.get_chat_response(messages)
if not (response.choices and response.choices[0].message):
logging.error("No valid response choice message from LLM after tool call.")
self.conversation_history[user_id] = messages # Persist state before error
self.conversation_history[user_id] = messages
return "Error: Could not get a valid response from the LLM after tool call."
assistant_message = response.choices[0].message
@@ -212,7 +244,6 @@ class OpenAICompatibleInferenceBot(InferenceBot):
tool_use_count += 1
if tool_use_count >= MAX_TOOL_ITERATIONS and tool_calls_from_response:
logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.")
# Ensure final content is returned even if max iterations hit with pending tool calls
break
self.conversation_history[user_id] = messages
@@ -224,9 +255,8 @@ class OpenAICompatibleInferenceBot(InferenceBot):
logging.info(f"{self.__class__.__name__} (Model: {self.model}) started.")
async def abort_processing(self, user_id):
# This is a soft abort for OpenAI compatible bots as API calls are synchronous within handle_message
if user_id in self.processing_status:
self.clear_processing_status(user_id) # Use base class method
self.clear_processing_status(user_id)
logging.info(f"Processing aborted for user {user_id}.")
return "Processing aborted. You can send a new message or /clear the conversation."
else:
@@ -278,7 +308,6 @@ class OpenAICompatibleInferenceBot(InferenceBot):
logging.warning(f"Could not read system prompt file {prompt_path_to_try}: {e}. Using default.")
return default_prompt
else:
# This condition now also covers if 'file_path' argument was given but invalid
logging.warning(f"System prompt file {prompt_path_to_try} not found. Using default system prompt.")
return default_prompt
else:
@@ -357,7 +386,7 @@ def main():
parser.add_argument('--messenger', type=str, help='Messenger type (i.e. telegram)', required=True)
parser.add_argument('--persona', type=str, help='Path to system prompt file', required=False)
parser.add_argument('--tools', nargs='+', help='List of allowed function tags', required=False)
parser.add_argument('--use-large-model', action='store_true', help='Use the large model instead of the small model') # New argument
parser.add_argument('--use-large-model', action='store_true', help='Use the large model instead of the small model')
# Add these to launch.json arguments if you want to limit the toolset available: "--tools", "read", "communicate"
# Parse command line arguments
args = parser.parse_args()
@@ -369,7 +398,7 @@ def main():
allowed_function_tags=args.tools if args.tools else None
config_prepend = args.config if args.config else None
messenger = args.messenger if args.messenger else None
use_large_model = args.use_large_model # Get the value of the new argument
use_large_model = args.use_large_model
# Initialize model and max tokens based on the config prepend
if config_prepend:
@@ -389,7 +418,7 @@ def main():
large_model_max_tokens=large_model_max_tokens,
system_prompt_path=system_prompt_path,
allowed_function_tags=allowed_function_tags,
use_large_model=use_large_model # Pass the new argument
use_large_model=use_large_model
)
full_code_file = importlib.import_module(f'{messenger.lower()}_helper')
messenger_helper_class_name = f"{messenger.capitalize()}Helper"
+2 -1
View File
@@ -8,4 +8,5 @@ GitPython==3.1.43
pytest
pytest-cov
google-genai
httpx==0.27.2
httpx==0.27.2
tiktoken