feat: robust 16k/32k context management with request-only tool-call summarization and budget enforcement

- Add normalization of messages before API calls
- Implement token projection and enforce budget for 16k/32k windows
- Summarize only tool-call request arguments (not responses) when over budget
- Optionally elide redundant code blocks in old assistant messages as last-resort trimming
- Default small-model limit to 16k, large to 32k; reserve space for response tokens
- Keep core behavior and tool execution unchanged
This commit is contained in:
2025-08-13 14:25:13 -05:00
parent 7bd1bcf82b
commit b9b07320bc
+213 -36
View File
@@ -3,6 +3,7 @@ import json
import os
import logging
import inspect
import re
from abc import abstractmethod
from openai import OpenAI
from tools.base_tool import BaseTool
@@ -42,8 +43,8 @@ class OpenAICompatibleInferenceBot(InferenceBot):
log_msg = f"Initialized OpenAI compatible client. Target URL: {base_url if base_url else 'OpenAI default'}."
logging.info(log_msg)
# Load inference token limits
self.small_model_max_inference_tokens = int(os.getenv("_SMALL_MODEL_MAX_INFERENCE_TOKENS", "32768"))
# Load inference token limits (defaults: small=16k, large=32k)
self.small_model_max_inference_tokens = int(os.getenv("_SMALL_MODEL_MAX_INFERENCE_TOKENS", "16384"))
self.large_model_max_inference_tokens = int(os.getenv("_LARGE_MODEL_MAX_INFERENCE_TOKENS", "32768"))
# Configure the actual model name and max_tokens for API calls
@@ -86,26 +87,205 @@ class OpenAICompatibleInferenceBot(InferenceBot):
client_type = type(self.client).__name__
return f"Client: {client_type}, LLM: {self.model}, Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}"
def _count_tokens(self, messages, model):
"""Returns the number of tokens in a list of messages."""
def _encoding_for_model(self, model: str | None):
try:
encoding = tiktoken.encoding_for_model(model)
return tiktoken.encoding_for_model(model) if model else tiktoken.get_encoding("cl100k_base")
except KeyError:
encoding = tiktoken.get_encoding("cl100k_base") # Fallback for unknown models
logging.warning(f"Warning: model {model} not found. Using cl100k_base encoding.")
return tiktoken.get_encoding("cl100k_base")
def _normalize_messages(self, messages):
"""Return a list of plain dict chat messages acceptable by the API.
- Converts OpenAI SDK message objects into dicts
- Preserves tool_calls structure where present
"""
normalized = []
for m in messages:
if isinstance(m, dict):
# Ensure only known keys are present; copy shallowly
entry = {k: v for k, v in m.items() if k in {"role", "content", "name", "tool_call_id", "tool_calls"}}
normalized.append(entry)
else:
# Likely an OpenAI message object
role = getattr(m, "role", None)
content = getattr(m, "content", None)
name = getattr(m, "name", None)
tool_calls = []
tc_list = getattr(m, "tool_calls", None)
if tc_list:
for tc in tc_list:
try:
tool_calls.append({
"id": getattr(tc, "id", None),
"type": getattr(tc, "type", "function"),
"function": {
"name": getattr(getattr(tc, "function", None), "name", None),
"arguments": getattr(getattr(tc, "function", None), "arguments", "{}"),
}
})
except Exception:
# Best-effort fallback
tool_calls.append({"id": None, "type": "function", "function": {"name": "unknown", "arguments": "{}"}})
entry = {"role": role, "content": content}
if name:
entry["name"] = name
if tool_calls:
entry["tool_calls"] = tool_calls
normalized.append(entry)
return normalized
def _estimate_tokens(self, messages):
"""Estimate tokens for messages with tiktoken, including tool_calls arguments.
Based on OpenAI's chat token counting rules approximation.
"""
enc = self._encoding_for_model(self.model)
num_tokens = 0
for message in messages:
num_tokens += 4
if hasattr(message, "items"):
for key, value in message.items():
if isinstance(value, str):
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += 1
num_tokens += 2
for m in messages:
num_tokens += 4 # per-message overhead
if not isinstance(m, dict):
continue
# role/content
for key in ("role", "name", "content"):
v = m.get(key)
if isinstance(v, str):
num_tokens += len(enc.encode(v))
# tool calls request portion (arguments)
tcs = m.get("tool_calls")
if tcs and isinstance(tcs, list):
# approximate cost of the tool_calls JSON the model sees
for tc in tcs:
fn = tc.get("function", {}) if isinstance(tc, dict) else {}
fname = fn.get("name")
fargs = fn.get("arguments")
if isinstance(fname, str):
num_tokens += len(enc.encode(fname))
if isinstance(fargs, str):
num_tokens += len(enc.encode(fargs))
num_tokens += 2 # assistant priming
return num_tokens
def _get_inference_limit(self):
current_model_is_small = self.model == self.model_config["small_model_name"]
current_model_is_large = self.model == self.model_config["large_model_name"]
if current_model_is_small:
return self.small_model_max_inference_tokens
if current_model_is_large:
return self.large_model_max_inference_tokens
logging.warning(f"Could not determine inference token limit for model: {self.model}. Proceeding without check.")
return None
def _summarize_tool_args(self, args_str: str, max_chars: int = 512) -> str:
"""Summarize tool-call request arguments without altering tool responses.
- If JSON, keep keys and short previews of string values.
- If plain string, truncate with an indicator.
"""
try:
parsed = json.loads(args_str)
if isinstance(parsed, dict):
summary = {}
for k, v in parsed.items():
if isinstance(v, str):
if len(v) > 160:
summary[k] = v[:120] + f"... [len={len(v)}]"
else:
summary[k] = v
elif isinstance(v, (list, dict)):
# structural summary only
summary[k] = f"<{type(v).__name__} size={len(v)}>"
else:
summary[k] = v
s = json.dumps(summary, ensure_ascii=False)
if len(s) > max_chars:
s = s[: max_chars - 20] + "... [summarized]"
return s
except Exception:
pass
# Fallback: truncate raw string
return (args_str[: max_chars - 20] + "... [summarized]") if len(args_str) > max_chars else args_str
def _summarize_tool_call_requests_in_messages(self, messages):
changed = False
for m in messages:
if isinstance(m, dict) and m.get("tool_calls"):
new_tool_calls = []
for tc in m["tool_calls"]:
if not isinstance(tc, dict):
new_tool_calls.append(tc)
continue
fn = tc.get("function", {})
args = fn.get("arguments")
if isinstance(args, str) and args and len(args) > 700:
# summarize long request arguments only
fn = dict(fn)
fn["arguments"] = self._summarize_tool_args(args)
tc = dict(tc)
tc["function"] = fn
changed = True
new_tool_calls.append(tc)
if changed:
m["tool_calls"] = new_tool_calls
return changed
def _elide_redundant_code_blocks(self, messages):
"""As a last resort, remove large code blocks from older assistant messages.
Keep the latest assistant message intact.
"""
changed = False
# Identify indices of assistant messages
assistant_indices = [i for i, m in enumerate(messages) if isinstance(m, dict) and m.get("role") == "assistant" and m.get("content")]
if len(assistant_indices) <= 1:
return changed
# Protect the last assistant message
for i in assistant_indices[:-1]:
m = messages[i]
content = m.get("content")
if not isinstance(content, str):
continue
if "```" in content or "\n " in content:
# Replace code blocks fenced by ``` with succinct markers
orig = content
content = re.sub(r"```[\s\S]*?```", "[code block omitted]", content)
# Also collapse long indented blocks
content = re.sub(r"(?:\n\s{4,}.+)+", "\n[long block omitted]", content)
if content != orig:
m["content"] = content
changed = True
return changed
def _enforce_budget(self, messages):
"""Normalize and enforce token budget by summarizing only tool-call requests first,
then eliding redundant code blocks if still too large. Returns normalized messages.
"""
normalized = self._normalize_messages(messages)
limit = self._get_inference_limit()
if not limit:
return normalized
# Reserve space for completion tokens
reserve = self.max_tokens if isinstance(self.max_tokens, int) else 1024
budget = max(1024, limit - reserve)
tokens = self._estimate_tokens(normalized)
if tokens <= budget:
return normalized
# Step 1: summarize only tool-call request arguments
if self._summarize_tool_call_requests_in_messages(normalized):
tokens = self._estimate_tokens(normalized)
logging.info(f"Applied tool-call request summarization. tokens={tokens}/{budget}")
if tokens <= budget:
return normalized
# Step 2: elide redundant code blocks from older assistant messages
if self._elide_redundant_code_blocks(normalized):
tokens = self._estimate_tokens(normalized)
logging.info(f"Elided redundant code blocks. tokens={tokens}/{budget}")
if tokens <= budget:
return normalized
# If still over, log and proceed; the API may still reject; caller may choose to abort
logging.warning(f"Projected tokens still exceed budget after optimizations: {tokens}/{budget}")
return normalized
def get_chat_response(self, messages):
if not self.client:
logging.error("OpenAI client not initialized before get_chat_response.")
@@ -128,9 +308,12 @@ class OpenAICompatibleInferenceBot(InferenceBot):
func_copy = {k: v for k, v in func.items() if k != "_tags"}
cleaned_tools.append(func_copy)
# Enforce token budget prior to API call
messages_for_api = self._enforce_budget(messages)
response = self.client.chat.completions.create(
model=self.model,
messages=messages,
messages=messages_for_api,
tools=cleaned_tools,
tool_choice="auto" if cleaned_tools else None,
max_tokens=self.max_tokens,
@@ -158,27 +341,20 @@ class OpenAICompatibleInferenceBot(InferenceBot):
self.conversation_history[user_id].append({"role": "user", "content": user_message})
messages = list(self.conversation_history[user_id])
# Pre-inference token limit check
current_model_is_small = self.model == self.model_config["small_model_name"]
current_model_is_large = self.model == self.model_config["large_model_name"]
inference_token_limit = None
if current_model_is_small:
inference_token_limit = self.small_model_max_inference_tokens
elif current_model_is_large:
inference_token_limit = self.large_model_max_inference_tokens
else:
logging.warning(f"Could not determine inference token limit for model: {self.model}. Proceeding without check.")
if inference_token_limit is not None:
token_count = self._count_tokens(messages, self.model)
if token_count > inference_token_limit:
logging.warning(f"Request for user {user_id} exceeds inference token limit ({token_count}/{inference_token_limit}).")
# Pre-inference token limit check with budgeted optimizations
limit = self._get_inference_limit()
if limit is not None:
# Estimate on normalized messages after applying request-only summarization if needed
provisional = self._enforce_budget(messages)
token_count = self._estimate_tokens(provisional)
reserve = self.max_tokens if isinstance(self.max_tokens, int) else 1024
budget = max(1024, limit - reserve)
if token_count > budget:
logging.warning(f"Request for user {user_id} exceeds inference token budget even after optimizations ({token_count}/{budget}).")
# Do not persist this message in history as it was not processed by LLM
# Remove the last user message from history before returning, to prevent accumulation
if self.conversation_history[user_id] and self.conversation_history[user_id][-1]["role"] == "user" and self.conversation_history[user_id][-1]["content"] == user_message:
self.conversation_history[user_id].pop()
return "Request exceeds inference token limit. Please use the /clear command, or implement RAG in your application."
return "Request exceeds inference token limit after optimization. Please shorten your request, use /clear, or implement RAG in your application."
response = self.get_chat_response(messages)
@@ -204,7 +380,7 @@ class OpenAICompatibleInferenceBot(InferenceBot):
function_name = function_to_call.name
function_args_str = function_to_call.arguments
logging.info(f"Attempting to call tool: {function_name} with args: {function_args_str}")
logging.info(f"Attempting to call tool: {function_name} with args: [request summarized if large]")
if function_name not in [f["function"]["name"] for f in self.functions]:
logging.warning(f"Tool function {function_name} not found in available functions.")
tool_results_for_model.append({
@@ -232,6 +408,7 @@ class OpenAICompatibleInferenceBot(InferenceBot):
messages.extend(tool_results_for_model)
# Enforce budget before next LLM call (summarize request portion only; preserve tool responses)
response = self.get_chat_response(messages)
if not (response.choices and response.choices[0].message):
logging.error("No valid response choice message from LLM after tool call.")
@@ -251,7 +428,7 @@ class OpenAICompatibleInferenceBot(InferenceBot):
self.conversation_history[user_id] = messages
final_assistant_message = messages[-1]
return final_assistant_message.content if final_assistant_message.role == "assistant" and final_assistant_message.content is not None else "Assistant did not provide a textual response."
return final_assistant_message.content if getattr(final_assistant_message, "role", None) == "assistant" and getattr(final_assistant_message, "content", None) is not None else (final_assistant_message.get("content") if isinstance(final_assistant_message, dict) else "Assistant did not provide a textual response.")
async def start(self):
logging.info(f"{self.__class__.__name__} (Model: {self.model}) started.")