feat: robust 16k/32k context management with request-only tool-call summarization and budget enforcement
- Add normalization of messages before API calls - Implement token projection and enforce budget for 16k/32k windows - Summarize only tool-call request arguments (not responses) when over budget - Optionally elide redundant code blocks in old assistant messages as last-resort trimming - Default small-model limit to 16k, large to 32k; reserve space for response tokens - Keep core behavior and tool execution unchanged
This commit is contained in:
@@ -3,6 +3,7 @@ import json
|
||||
import os
|
||||
import logging
|
||||
import inspect
|
||||
import re
|
||||
from abc import abstractmethod
|
||||
from openai import OpenAI
|
||||
from tools.base_tool import BaseTool
|
||||
@@ -42,8 +43,8 @@ class OpenAICompatibleInferenceBot(InferenceBot):
|
||||
log_msg = f"Initialized OpenAI compatible client. Target URL: {base_url if base_url else 'OpenAI default'}."
|
||||
logging.info(log_msg)
|
||||
|
||||
# Load inference token limits
|
||||
self.small_model_max_inference_tokens = int(os.getenv("_SMALL_MODEL_MAX_INFERENCE_TOKENS", "32768"))
|
||||
# Load inference token limits (defaults: small=16k, large=32k)
|
||||
self.small_model_max_inference_tokens = int(os.getenv("_SMALL_MODEL_MAX_INFERENCE_TOKENS", "16384"))
|
||||
self.large_model_max_inference_tokens = int(os.getenv("_LARGE_MODEL_MAX_INFERENCE_TOKENS", "32768"))
|
||||
|
||||
# Configure the actual model name and max_tokens for API calls
|
||||
@@ -86,26 +87,205 @@ class OpenAICompatibleInferenceBot(InferenceBot):
|
||||
client_type = type(self.client).__name__
|
||||
return f"Client: {client_type}, LLM: {self.model}, Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}"
|
||||
|
||||
def _count_tokens(self, messages, model):
|
||||
"""Returns the number of tokens in a list of messages."""
|
||||
def _encoding_for_model(self, model: str | None):
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
return tiktoken.encoding_for_model(model) if model else tiktoken.get_encoding("cl100k_base")
|
||||
except KeyError:
|
||||
encoding = tiktoken.get_encoding("cl100k_base") # Fallback for unknown models
|
||||
logging.warning(f"Warning: model {model} not found. Using cl100k_base encoding.")
|
||||
return tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
def _normalize_messages(self, messages):
|
||||
"""Return a list of plain dict chat messages acceptable by the API.
|
||||
- Converts OpenAI SDK message objects into dicts
|
||||
- Preserves tool_calls structure where present
|
||||
"""
|
||||
normalized = []
|
||||
for m in messages:
|
||||
if isinstance(m, dict):
|
||||
# Ensure only known keys are present; copy shallowly
|
||||
entry = {k: v for k, v in m.items() if k in {"role", "content", "name", "tool_call_id", "tool_calls"}}
|
||||
normalized.append(entry)
|
||||
else:
|
||||
# Likely an OpenAI message object
|
||||
role = getattr(m, "role", None)
|
||||
content = getattr(m, "content", None)
|
||||
name = getattr(m, "name", None)
|
||||
tool_calls = []
|
||||
tc_list = getattr(m, "tool_calls", None)
|
||||
if tc_list:
|
||||
for tc in tc_list:
|
||||
try:
|
||||
tool_calls.append({
|
||||
"id": getattr(tc, "id", None),
|
||||
"type": getattr(tc, "type", "function"),
|
||||
"function": {
|
||||
"name": getattr(getattr(tc, "function", None), "name", None),
|
||||
"arguments": getattr(getattr(tc, "function", None), "arguments", "{}"),
|
||||
}
|
||||
})
|
||||
except Exception:
|
||||
# Best-effort fallback
|
||||
tool_calls.append({"id": None, "type": "function", "function": {"name": "unknown", "arguments": "{}"}})
|
||||
entry = {"role": role, "content": content}
|
||||
if name:
|
||||
entry["name"] = name
|
||||
if tool_calls:
|
||||
entry["tool_calls"] = tool_calls
|
||||
normalized.append(entry)
|
||||
return normalized
|
||||
|
||||
def _estimate_tokens(self, messages):
|
||||
"""Estimate tokens for messages with tiktoken, including tool_calls arguments.
|
||||
Based on OpenAI's chat token counting rules approximation.
|
||||
"""
|
||||
enc = self._encoding_for_model(self.model)
|
||||
num_tokens = 0
|
||||
for message in messages:
|
||||
num_tokens += 4
|
||||
if hasattr(message, "items"):
|
||||
for key, value in message.items():
|
||||
if isinstance(value, str):
|
||||
num_tokens += len(encoding.encode(value))
|
||||
if key == "name":
|
||||
num_tokens += 1
|
||||
num_tokens += 2
|
||||
for m in messages:
|
||||
num_tokens += 4 # per-message overhead
|
||||
if not isinstance(m, dict):
|
||||
continue
|
||||
# role/content
|
||||
for key in ("role", "name", "content"):
|
||||
v = m.get(key)
|
||||
if isinstance(v, str):
|
||||
num_tokens += len(enc.encode(v))
|
||||
# tool calls request portion (arguments)
|
||||
tcs = m.get("tool_calls")
|
||||
if tcs and isinstance(tcs, list):
|
||||
# approximate cost of the tool_calls JSON the model sees
|
||||
for tc in tcs:
|
||||
fn = tc.get("function", {}) if isinstance(tc, dict) else {}
|
||||
fname = fn.get("name")
|
||||
fargs = fn.get("arguments")
|
||||
if isinstance(fname, str):
|
||||
num_tokens += len(enc.encode(fname))
|
||||
if isinstance(fargs, str):
|
||||
num_tokens += len(enc.encode(fargs))
|
||||
num_tokens += 2 # assistant priming
|
||||
return num_tokens
|
||||
|
||||
def _get_inference_limit(self):
|
||||
current_model_is_small = self.model == self.model_config["small_model_name"]
|
||||
current_model_is_large = self.model == self.model_config["large_model_name"]
|
||||
if current_model_is_small:
|
||||
return self.small_model_max_inference_tokens
|
||||
if current_model_is_large:
|
||||
return self.large_model_max_inference_tokens
|
||||
logging.warning(f"Could not determine inference token limit for model: {self.model}. Proceeding without check.")
|
||||
return None
|
||||
|
||||
def _summarize_tool_args(self, args_str: str, max_chars: int = 512) -> str:
|
||||
"""Summarize tool-call request arguments without altering tool responses.
|
||||
- If JSON, keep keys and short previews of string values.
|
||||
- If plain string, truncate with an indicator.
|
||||
"""
|
||||
try:
|
||||
parsed = json.loads(args_str)
|
||||
if isinstance(parsed, dict):
|
||||
summary = {}
|
||||
for k, v in parsed.items():
|
||||
if isinstance(v, str):
|
||||
if len(v) > 160:
|
||||
summary[k] = v[:120] + f"... [len={len(v)}]"
|
||||
else:
|
||||
summary[k] = v
|
||||
elif isinstance(v, (list, dict)):
|
||||
# structural summary only
|
||||
summary[k] = f"<{type(v).__name__} size={len(v)}>"
|
||||
else:
|
||||
summary[k] = v
|
||||
s = json.dumps(summary, ensure_ascii=False)
|
||||
if len(s) > max_chars:
|
||||
s = s[: max_chars - 20] + "... [summarized]"
|
||||
return s
|
||||
except Exception:
|
||||
pass
|
||||
# Fallback: truncate raw string
|
||||
return (args_str[: max_chars - 20] + "... [summarized]") if len(args_str) > max_chars else args_str
|
||||
|
||||
def _summarize_tool_call_requests_in_messages(self, messages):
|
||||
changed = False
|
||||
for m in messages:
|
||||
if isinstance(m, dict) and m.get("tool_calls"):
|
||||
new_tool_calls = []
|
||||
for tc in m["tool_calls"]:
|
||||
if not isinstance(tc, dict):
|
||||
new_tool_calls.append(tc)
|
||||
continue
|
||||
fn = tc.get("function", {})
|
||||
args = fn.get("arguments")
|
||||
if isinstance(args, str) and args and len(args) > 700:
|
||||
# summarize long request arguments only
|
||||
fn = dict(fn)
|
||||
fn["arguments"] = self._summarize_tool_args(args)
|
||||
tc = dict(tc)
|
||||
tc["function"] = fn
|
||||
changed = True
|
||||
new_tool_calls.append(tc)
|
||||
if changed:
|
||||
m["tool_calls"] = new_tool_calls
|
||||
return changed
|
||||
|
||||
def _elide_redundant_code_blocks(self, messages):
|
||||
"""As a last resort, remove large code blocks from older assistant messages.
|
||||
Keep the latest assistant message intact.
|
||||
"""
|
||||
changed = False
|
||||
# Identify indices of assistant messages
|
||||
assistant_indices = [i for i, m in enumerate(messages) if isinstance(m, dict) and m.get("role") == "assistant" and m.get("content")]
|
||||
if len(assistant_indices) <= 1:
|
||||
return changed
|
||||
# Protect the last assistant message
|
||||
for i in assistant_indices[:-1]:
|
||||
m = messages[i]
|
||||
content = m.get("content")
|
||||
if not isinstance(content, str):
|
||||
continue
|
||||
if "```" in content or "\n " in content:
|
||||
# Replace code blocks fenced by ``` with succinct markers
|
||||
orig = content
|
||||
content = re.sub(r"```[\s\S]*?```", "[code block omitted]", content)
|
||||
# Also collapse long indented blocks
|
||||
content = re.sub(r"(?:\n\s{4,}.+)+", "\n[long block omitted]", content)
|
||||
if content != orig:
|
||||
m["content"] = content
|
||||
changed = True
|
||||
return changed
|
||||
|
||||
def _enforce_budget(self, messages):
|
||||
"""Normalize and enforce token budget by summarizing only tool-call requests first,
|
||||
then eliding redundant code blocks if still too large. Returns normalized messages.
|
||||
"""
|
||||
normalized = self._normalize_messages(messages)
|
||||
limit = self._get_inference_limit()
|
||||
if not limit:
|
||||
return normalized
|
||||
# Reserve space for completion tokens
|
||||
reserve = self.max_tokens if isinstance(self.max_tokens, int) else 1024
|
||||
budget = max(1024, limit - reserve)
|
||||
|
||||
tokens = self._estimate_tokens(normalized)
|
||||
if tokens <= budget:
|
||||
return normalized
|
||||
|
||||
# Step 1: summarize only tool-call request arguments
|
||||
if self._summarize_tool_call_requests_in_messages(normalized):
|
||||
tokens = self._estimate_tokens(normalized)
|
||||
logging.info(f"Applied tool-call request summarization. tokens={tokens}/{budget}")
|
||||
if tokens <= budget:
|
||||
return normalized
|
||||
|
||||
# Step 2: elide redundant code blocks from older assistant messages
|
||||
if self._elide_redundant_code_blocks(normalized):
|
||||
tokens = self._estimate_tokens(normalized)
|
||||
logging.info(f"Elided redundant code blocks. tokens={tokens}/{budget}")
|
||||
if tokens <= budget:
|
||||
return normalized
|
||||
|
||||
# If still over, log and proceed; the API may still reject; caller may choose to abort
|
||||
logging.warning(f"Projected tokens still exceed budget after optimizations: {tokens}/{budget}")
|
||||
return normalized
|
||||
|
||||
def get_chat_response(self, messages):
|
||||
if not self.client:
|
||||
logging.error("OpenAI client not initialized before get_chat_response.")
|
||||
@@ -128,9 +308,12 @@ class OpenAICompatibleInferenceBot(InferenceBot):
|
||||
func_copy = {k: v for k, v in func.items() if k != "_tags"}
|
||||
cleaned_tools.append(func_copy)
|
||||
|
||||
# Enforce token budget prior to API call
|
||||
messages_for_api = self._enforce_budget(messages)
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
messages=messages_for_api,
|
||||
tools=cleaned_tools,
|
||||
tool_choice="auto" if cleaned_tools else None,
|
||||
max_tokens=self.max_tokens,
|
||||
@@ -158,27 +341,20 @@ class OpenAICompatibleInferenceBot(InferenceBot):
|
||||
self.conversation_history[user_id].append({"role": "user", "content": user_message})
|
||||
messages = list(self.conversation_history[user_id])
|
||||
|
||||
# Pre-inference token limit check
|
||||
current_model_is_small = self.model == self.model_config["small_model_name"]
|
||||
current_model_is_large = self.model == self.model_config["large_model_name"]
|
||||
|
||||
inference_token_limit = None
|
||||
if current_model_is_small:
|
||||
inference_token_limit = self.small_model_max_inference_tokens
|
||||
elif current_model_is_large:
|
||||
inference_token_limit = self.large_model_max_inference_tokens
|
||||
else:
|
||||
logging.warning(f"Could not determine inference token limit for model: {self.model}. Proceeding without check.")
|
||||
|
||||
if inference_token_limit is not None:
|
||||
token_count = self._count_tokens(messages, self.model)
|
||||
if token_count > inference_token_limit:
|
||||
logging.warning(f"Request for user {user_id} exceeds inference token limit ({token_count}/{inference_token_limit}).")
|
||||
# Pre-inference token limit check with budgeted optimizations
|
||||
limit = self._get_inference_limit()
|
||||
if limit is not None:
|
||||
# Estimate on normalized messages after applying request-only summarization if needed
|
||||
provisional = self._enforce_budget(messages)
|
||||
token_count = self._estimate_tokens(provisional)
|
||||
reserve = self.max_tokens if isinstance(self.max_tokens, int) else 1024
|
||||
budget = max(1024, limit - reserve)
|
||||
if token_count > budget:
|
||||
logging.warning(f"Request for user {user_id} exceeds inference token budget even after optimizations ({token_count}/{budget}).")
|
||||
# Do not persist this message in history as it was not processed by LLM
|
||||
# Remove the last user message from history before returning, to prevent accumulation
|
||||
if self.conversation_history[user_id] and self.conversation_history[user_id][-1]["role"] == "user" and self.conversation_history[user_id][-1]["content"] == user_message:
|
||||
self.conversation_history[user_id].pop()
|
||||
return "Request exceeds inference token limit. Please use the /clear command, or implement RAG in your application."
|
||||
return "Request exceeds inference token limit after optimization. Please shorten your request, use /clear, or implement RAG in your application."
|
||||
|
||||
response = self.get_chat_response(messages)
|
||||
|
||||
@@ -204,7 +380,7 @@ class OpenAICompatibleInferenceBot(InferenceBot):
|
||||
function_name = function_to_call.name
|
||||
function_args_str = function_to_call.arguments
|
||||
|
||||
logging.info(f"Attempting to call tool: {function_name} with args: {function_args_str}")
|
||||
logging.info(f"Attempting to call tool: {function_name} with args: [request summarized if large]")
|
||||
if function_name not in [f["function"]["name"] for f in self.functions]:
|
||||
logging.warning(f"Tool function {function_name} not found in available functions.")
|
||||
tool_results_for_model.append({
|
||||
@@ -232,6 +408,7 @@ class OpenAICompatibleInferenceBot(InferenceBot):
|
||||
|
||||
messages.extend(tool_results_for_model)
|
||||
|
||||
# Enforce budget before next LLM call (summarize request portion only; preserve tool responses)
|
||||
response = self.get_chat_response(messages)
|
||||
if not (response.choices and response.choices[0].message):
|
||||
logging.error("No valid response choice message from LLM after tool call.")
|
||||
@@ -251,7 +428,7 @@ class OpenAICompatibleInferenceBot(InferenceBot):
|
||||
self.conversation_history[user_id] = messages
|
||||
|
||||
final_assistant_message = messages[-1]
|
||||
return final_assistant_message.content if final_assistant_message.role == "assistant" and final_assistant_message.content is not None else "Assistant did not provide a textual response."
|
||||
return final_assistant_message.content if getattr(final_assistant_message, "role", None) == "assistant" and getattr(final_assistant_message, "content", None) is not None else (final_assistant_message.get("content") if isinstance(final_assistant_message, dict) else "Assistant did not provide a textual response.")
|
||||
|
||||
async def start(self):
|
||||
logging.info(f"{self.__class__.__name__} (Model: {self.model}) started.")
|
||||
|
||||
Reference in New Issue
Block a user