feat: Implement pre-inference token limit check in openai_compatible_inference_bot.py
This commit is contained in:
@@ -9,6 +9,7 @@ from tools.base_tool import BaseTool
|
|||||||
from telegram_helper import TelegramHelper
|
from telegram_helper import TelegramHelper
|
||||||
import argparse
|
import argparse
|
||||||
from inference_bot import InferenceBot
|
from inference_bot import InferenceBot
|
||||||
|
import tiktoken # Added this import
|
||||||
|
|
||||||
class OpenAICompatibleInferenceBot(InferenceBot):
|
class OpenAICompatibleInferenceBot(InferenceBot):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -21,7 +22,7 @@ class OpenAICompatibleInferenceBot(InferenceBot):
|
|||||||
large_model_max_tokens: str | None = None,
|
large_model_max_tokens: str | None = None,
|
||||||
allowed_function_tags: list[str] | None = None,
|
allowed_function_tags: list[str] | None = None,
|
||||||
system_prompt_path: str | None = None,
|
system_prompt_path: str | None = None,
|
||||||
use_large_model: bool = False # New argument
|
use_large_model: bool = False
|
||||||
):
|
):
|
||||||
self.model_config = {
|
self.model_config = {
|
||||||
"small_model_name": small_model_name,
|
"small_model_name": small_model_name,
|
||||||
@@ -32,8 +33,7 @@ class OpenAICompatibleInferenceBot(InferenceBot):
|
|||||||
self.allowed_function_tags = allowed_function_tags if allowed_function_tags else None
|
self.allowed_function_tags = allowed_function_tags if allowed_function_tags else None
|
||||||
self.conversation_history = {}
|
self.conversation_history = {}
|
||||||
self._processing_status = {}
|
self._processing_status = {}
|
||||||
self.system_prompt_path = system_prompt_path # Store the prompt path for status
|
self.system_prompt_path = system_prompt_path
|
||||||
# MODIFIED to pass arguments
|
|
||||||
self.system_prompt = self.load_system_prompt(
|
self.system_prompt = self.load_system_prompt(
|
||||||
file_path=system_prompt_path
|
file_path=system_prompt_path
|
||||||
)
|
)
|
||||||
@@ -42,6 +42,10 @@ class OpenAICompatibleInferenceBot(InferenceBot):
|
|||||||
log_msg = f"Initialized OpenAI compatible client. Target URL: {base_url if base_url else 'OpenAI default'}."
|
log_msg = f"Initialized OpenAI compatible client. Target URL: {base_url if base_url else 'OpenAI default'}."
|
||||||
logging.info(log_msg)
|
logging.info(log_msg)
|
||||||
|
|
||||||
|
# Load inference token limits
|
||||||
|
self.small_model_max_inference_tokens = int(os.getenv("_SMALL_MODEL_MAX_INFERENCE_TOKENS", "32768"))
|
||||||
|
self.large_model_max_inference_tokens = int(os.getenv("_LARGE_MODEL_MAX_INFERENCE_TOKENS", "32768"))
|
||||||
|
|
||||||
# Configure the actual model name and max_tokens for API calls
|
# Configure the actual model name and max_tokens for API calls
|
||||||
if use_large_model:
|
if use_large_model:
|
||||||
self._configure_model_and_tokens(
|
self._configure_model_and_tokens(
|
||||||
@@ -53,12 +57,9 @@ class OpenAICompatibleInferenceBot(InferenceBot):
|
|||||||
self.model_config["small_model_name"],
|
self.model_config["small_model_name"],
|
||||||
self.model_config["small_model_max_tokens"]
|
self.model_config["small_model_max_tokens"]
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def processing_status(self):
|
def processing_status(self):
|
||||||
"""
|
|
||||||
An attribute to store the processing status for users.
|
|
||||||
Example usage in subclass: self.processing_status.get(user_id)
|
|
||||||
"""
|
|
||||||
return self._processing_status
|
return self._processing_status
|
||||||
|
|
||||||
def clear_conversation_history(self, user_id):
|
def clear_conversation_history(self, user_id):
|
||||||
@@ -71,14 +72,13 @@ class OpenAICompatibleInferenceBot(InferenceBot):
|
|||||||
def _configure_model_and_tokens(self, model_name: str | None, max_tokens_str: str | None):
|
def _configure_model_and_tokens(self, model_name: str | None, max_tokens_str: str | None):
|
||||||
self.model = model_name
|
self.model = model_name
|
||||||
try:
|
try:
|
||||||
# If max_tokens_str is explicitly "None" or empty, treat as None for API default
|
|
||||||
if max_tokens_str and max_tokens_str.lower() not in ["none", "", "null"]:
|
if max_tokens_str and max_tokens_str.lower() not in ["none", "", "null"]:
|
||||||
self.max_tokens = int(max_tokens_str)
|
self.max_tokens = int(max_tokens_str)
|
||||||
else:
|
else:
|
||||||
self.max_tokens = None # Use API default by not sending the parameter or sending null
|
self.max_tokens = None
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logging.warning(f"Invalid value for max_tokens: {max_tokens_str}. Using API default (None)")
|
logging.warning(f"Invalid value for max_tokens: {max_tokens_str}. Using API default (None)")
|
||||||
self.max_tokens = None # Use API default
|
self.max_tokens = None
|
||||||
|
|
||||||
logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}")
|
logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}")
|
||||||
|
|
||||||
@@ -86,26 +86,39 @@ class OpenAICompatibleInferenceBot(InferenceBot):
|
|||||||
client_type = type(self.client).__name__
|
client_type = type(self.client).__name__
|
||||||
return f"Client: {client_type}, LLM: {self.model}, Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}"
|
return f"Client: {client_type}, LLM: {self.model}, Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}"
|
||||||
|
|
||||||
|
def _count_tokens(self, messages, model):
    """Estimate the number of prompt tokens ``messages`` will consume.

    Uses the chat-format heuristic from the OpenAI cookbook: a fixed
    overhead of 4 tokens per message, the encoded length of every field
    value, one extra token when a ``name`` field is present, and 2 tokens
    for priming the reply. The result is an approximation, not an exact
    count of what the server will bill.

    Args:
        messages: Iterable of chat messages, normally dicts such as
            ``{"role": ..., "content": ...}``.
        model: Model name used to pick the tokenizer. Unknown or missing
            names fall back to the ``cl100k_base`` encoding.

    Returns:
        The estimated token count as an int.
    """
    import json  # local import keeps this block self-contained

    encoding = None
    if model:
        try:
            encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            # Unknown model name: handled by the fallback below.
            pass
    if encoding is None:
        # Also reached when no model name is configured at all; the lookup
        # above is skipped in that case because it expects a string.
        encoding = tiktoken.get_encoding("cl100k_base")  # Fallback for unknown models
        logging.warning(f"Warning: model {model} not found. Using cl100k_base encoding.")

    num_tokens = 0
    for message in messages:
        if isinstance(message, dict):
            fields = message
        elif hasattr(message, "model_dump"):
            # NOTE(review): presumably SDK response objects (pydantic models)
            # can end up in the history - confirm against handle_message.
            fields = message.model_dump(exclude_none=True)
        else:
            fields = {}

        num_tokens += 4  # fixed per-message framing overhead
        for key, value in fields.items():
            if key == "name":
                num_tokens += 1
            if value is None:
                continue
            if not isinstance(value, str):
                # Structured values (e.g. tool calls, multi-part content) are
                # sent to the model too; count their serialised form as an
                # approximation instead of ignoring them.
                value = json.dumps(value, ensure_ascii=False, default=str)
            # disallowed_special=() makes special-token literals typed by a
            # user (e.g. "<|endoftext|>") count as plain text instead of
            # raising ValueError.
            num_tokens += len(encoding.encode(value, disallowed_special=()))

    num_tokens += 2  # every reply is primed with the assistant header
    return num_tokens
|
||||||
|
|
||||||
def get_chat_response(self, messages):
|
def get_chat_response(self, messages):
|
||||||
if not self.client:
|
if not self.client:
|
||||||
# This should ideally not be hit if __init__ is successful
|
|
||||||
logging.error("OpenAI client not initialized before get_chat_response.")
|
logging.error("OpenAI client not initialized before get_chat_response.")
|
||||||
raise ValueError("OpenAI client not initialized.")
|
raise ValueError("OpenAI client not initialized.")
|
||||||
try:
|
try:
|
||||||
# Pass self.max_tokens directly. If None, OpenAI library omits it or handles it.
|
|
||||||
# Initialize tools filtering based on allowed tags
|
|
||||||
cleaned_tools = None
|
cleaned_tools = None
|
||||||
if hasattr(self, 'functions') and self.functions:
|
if hasattr(self, 'functions') and self.functions:
|
||||||
# Create a copy of functions without "_tags" field
|
|
||||||
cleaned_tools = []
|
cleaned_tools = []
|
||||||
for func in self.functions:
|
for func in self.functions:
|
||||||
include_function = False
|
include_function = False
|
||||||
|
|
||||||
if not hasattr(self, 'allowed_function_tags') or self.allowed_function_tags is None:
|
if not hasattr(self, 'allowed_function_tags') or self.allowed_function_tags is None:
|
||||||
# Include all functions if no tag filtering is specified
|
|
||||||
include_function = True
|
include_function = True
|
||||||
else:
|
else:
|
||||||
# Only include if function has matching tags
|
|
||||||
tags = func.get("_tags", [])
|
tags = func.get("_tags", [])
|
||||||
if any(tag in self.allowed_function_tags for tag in tags):
|
if any(tag in self.allowed_function_tags for tag in tags):
|
||||||
include_function = True
|
include_function = True
|
||||||
@@ -137,17 +150,38 @@ class OpenAICompatibleInferenceBot(InferenceBot):
|
|||||||
async def handle_message(self, user_id, user_message):
|
async def handle_message(self, user_id, user_message):
|
||||||
if user_id not in self.conversation_history or not self.conversation_history[user_id]:
|
if user_id not in self.conversation_history or not self.conversation_history[user_id]:
|
||||||
self.conversation_history[user_id] = []
|
self.conversation_history[user_id] = []
|
||||||
if self.system_prompt: # Use the loaded system_prompt
|
if self.system_prompt:
|
||||||
self.conversation_history[user_id].append({"role": "system", "content": self.system_prompt})
|
self.conversation_history[user_id].append({"role": "system", "content": self.system_prompt})
|
||||||
|
|
||||||
self.conversation_history[user_id].append({"role": "user", "content": user_message})
|
self.conversation_history[user_id].append({"role": "user", "content": user_message})
|
||||||
messages = list(self.conversation_history[user_id]) # Work with a copy for this turn
|
messages = list(self.conversation_history[user_id])
|
||||||
|
|
||||||
|
# Pre-inference token limit check
|
||||||
|
current_model_is_small = self.model == self.model_config["small_model_name"]
|
||||||
|
current_model_is_large = self.model == self.model_config["large_model_name"]
|
||||||
|
|
||||||
|
inference_token_limit = None
|
||||||
|
if current_model_is_small:
|
||||||
|
inference_token_limit = self.small_model_max_inference_tokens
|
||||||
|
elif current_model_is_large:
|
||||||
|
inference_token_limit = self.large_model_max_inference_tokens
|
||||||
|
else:
|
||||||
|
logging.warning(f"Could not determine inference token limit for model: {self.model}. Proceeding without check.")
|
||||||
|
|
||||||
|
if inference_token_limit is not None:
|
||||||
|
token_count = self._count_tokens(messages, self.model)
|
||||||
|
if token_count > inference_token_limit:
|
||||||
|
logging.warning(f"Request for user {user_id} exceeds inference token limit ({token_count}/{inference_token_limit}).")
|
||||||
|
# Do not persist this message in history as it was not processed by LLM
|
||||||
|
# Remove the last user message from history before returning, to prevent accumulation
|
||||||
|
if self.conversation_history[user_id] and self.conversation_history[user_id][-1]["role"] == "user" and self.conversation_history[user_id][-1]["content"] == user_message:
|
||||||
|
self.conversation_history[user_id].pop()
|
||||||
|
return "Request exceeds inference token limit. Please use the /clear command, or implement RAG in your application."
|
||||||
|
|
||||||
response = self.get_chat_response(messages)
|
response = self.get_chat_response(messages)
|
||||||
|
|
||||||
if not (response.choices and response.choices[0].message):
|
if not (response.choices and response.choices[0].message):
|
||||||
logging.error("No valid response choice message from LLM.")
|
logging.error("No valid response choice message from LLM.")
|
||||||
# Persist the user message in history even if LLM fails this turn
|
|
||||||
self.conversation_history[user_id] = messages
|
self.conversation_history[user_id] = messages
|
||||||
return "Error: Could not get a valid response from the LLM."
|
return "Error: Could not get a valid response from the LLM."
|
||||||
|
|
||||||
@@ -180,9 +214,7 @@ class OpenAICompatibleInferenceBot(InferenceBot):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Arguments are already a string from the API, self.call_tool expects dict or string
|
|
||||||
tool_response_content = self.call_tool(function_name, function_args_str)
|
tool_response_content = self.call_tool(function_name, function_args_str)
|
||||||
# Ensure content is string for OpenAI tool role
|
|
||||||
if not isinstance(tool_response_content, str):
|
if not isinstance(tool_response_content, str):
|
||||||
tool_response_content = json.dumps(tool_response_content)
|
tool_response_content = json.dumps(tool_response_content)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -201,7 +233,7 @@ class OpenAICompatibleInferenceBot(InferenceBot):
|
|||||||
response = self.get_chat_response(messages)
|
response = self.get_chat_response(messages)
|
||||||
if not (response.choices and response.choices[0].message):
|
if not (response.choices and response.choices[0].message):
|
||||||
logging.error("No valid response choice message from LLM after tool call.")
|
logging.error("No valid response choice message from LLM after tool call.")
|
||||||
self.conversation_history[user_id] = messages # Persist state before error
|
self.conversation_history[user_id] = messages
|
||||||
return "Error: Could not get a valid response from the LLM after tool call."
|
return "Error: Could not get a valid response from the LLM after tool call."
|
||||||
|
|
||||||
assistant_message = response.choices[0].message
|
assistant_message = response.choices[0].message
|
||||||
@@ -212,7 +244,6 @@ class OpenAICompatibleInferenceBot(InferenceBot):
|
|||||||
tool_use_count += 1
|
tool_use_count += 1
|
||||||
if tool_use_count >= MAX_TOOL_ITERATIONS and tool_calls_from_response:
|
if tool_use_count >= MAX_TOOL_ITERATIONS and tool_calls_from_response:
|
||||||
logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.")
|
logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.")
|
||||||
# Ensure final content is returned even if max iterations hit with pending tool calls
|
|
||||||
break
|
break
|
||||||
|
|
||||||
self.conversation_history[user_id] = messages
|
self.conversation_history[user_id] = messages
|
||||||
@@ -224,9 +255,8 @@ class OpenAICompatibleInferenceBot(InferenceBot):
|
|||||||
logging.info(f"{self.__class__.__name__} (Model: {self.model}) started.")
|
logging.info(f"{self.__class__.__name__} (Model: {self.model}) started.")
|
||||||
|
|
||||||
async def abort_processing(self, user_id):
|
async def abort_processing(self, user_id):
|
||||||
# This is a soft abort for OpenAI compatible bots as API calls are synchronous within handle_message
|
|
||||||
if user_id in self.processing_status:
|
if user_id in self.processing_status:
|
||||||
self.clear_processing_status(user_id) # Use base class method
|
self.clear_processing_status(user_id)
|
||||||
logging.info(f"Processing aborted for user {user_id}.")
|
logging.info(f"Processing aborted for user {user_id}.")
|
||||||
return "Processing aborted. You can send a new message or /clear the conversation."
|
return "Processing aborted. You can send a new message or /clear the conversation."
|
||||||
else:
|
else:
|
||||||
@@ -278,7 +308,6 @@ class OpenAICompatibleInferenceBot(InferenceBot):
|
|||||||
logging.warning(f"Could not read system prompt file {prompt_path_to_try}: {e}. Using default.")
|
logging.warning(f"Could not read system prompt file {prompt_path_to_try}: {e}. Using default.")
|
||||||
return default_prompt
|
return default_prompt
|
||||||
else:
|
else:
|
||||||
# This condition now also covers if 'file_path' argument was given but invalid
|
|
||||||
logging.warning(f"System prompt file {prompt_path_to_try} not found. Using default system prompt.")
|
logging.warning(f"System prompt file {prompt_path_to_try} not found. Using default system prompt.")
|
||||||
return default_prompt
|
return default_prompt
|
||||||
else:
|
else:
|
||||||
@@ -357,7 +386,7 @@ def main():
|
|||||||
parser.add_argument('--messenger', type=str, help='Messenger type (i.e. telegram)', required=True)
|
parser.add_argument('--messenger', type=str, help='Messenger type (i.e. telegram)', required=True)
|
||||||
parser.add_argument('--persona', type=str, help='Path to system prompt file', required=False)
|
parser.add_argument('--persona', type=str, help='Path to system prompt file', required=False)
|
||||||
parser.add_argument('--tools', nargs='+', help='List of allowed function tags', required=False)
|
parser.add_argument('--tools', nargs='+', help='List of allowed function tags', required=False)
|
||||||
parser.add_argument('--use-large-model', action='store_true', help='Use the large model instead of the small model') # New argument
|
parser.add_argument('--use-large-model', action='store_true', help='Use the large model instead of the small model')
|
||||||
# Add these to launch.json arguments if you want to limit the toolset available: "--tools", "read", "communicate"
|
# Add these to launch.json arguments if you want to limit the toolset available: "--tools", "read", "communicate"
|
||||||
# Parse command line arguments
|
# Parse command line arguments
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@@ -369,7 +398,7 @@ def main():
|
|||||||
allowed_function_tags=args.tools if args.tools else None
|
allowed_function_tags=args.tools if args.tools else None
|
||||||
config_prepend = args.config if args.config else None
|
config_prepend = args.config if args.config else None
|
||||||
messenger = args.messenger if args.messenger else None
|
messenger = args.messenger if args.messenger else None
|
||||||
use_large_model = args.use_large_model # Get the value of the new argument
|
use_large_model = args.use_large_model
|
||||||
|
|
||||||
# Initialize model and max tokens based on the config prepend
|
# Initialize model and max tokens based on the config prepend
|
||||||
if config_prepend:
|
if config_prepend:
|
||||||
@@ -389,7 +418,7 @@ def main():
|
|||||||
large_model_max_tokens=large_model_max_tokens,
|
large_model_max_tokens=large_model_max_tokens,
|
||||||
system_prompt_path=system_prompt_path,
|
system_prompt_path=system_prompt_path,
|
||||||
allowed_function_tags=allowed_function_tags,
|
allowed_function_tags=allowed_function_tags,
|
||||||
use_large_model=use_large_model # Pass the new argument
|
use_large_model=use_large_model
|
||||||
)
|
)
|
||||||
full_code_file = importlib.import_module(f'{messenger.lower()}_helper')
|
full_code_file = importlib.import_module(f'{messenger.lower()}_helper')
|
||||||
messenger_helper_class_name = f"{messenger.capitalize()}Helper"
|
messenger_helper_class_name = f"{messenger.capitalize()}Helper"
|
||||||
|
|||||||
Reference in New Issue
Block a user