This commit is contained in:
2025-08-13 14:32:00 -05:00
+213 -36
View File
@@ -3,6 +3,7 @@ import json
import os import os
import logging import logging
import inspect import inspect
import re
from abc import abstractmethod from abc import abstractmethod
from openai import OpenAI from openai import OpenAI
from tools.base_tool import BaseTool from tools.base_tool import BaseTool
@@ -42,8 +43,8 @@ class OpenAICompatibleInferenceBot(InferenceBot):
log_msg = f"Initialized OpenAI compatible client. Target URL: {base_url if base_url else 'OpenAI default'}." log_msg = f"Initialized OpenAI compatible client. Target URL: {base_url if base_url else 'OpenAI default'}."
logging.info(log_msg) logging.info(log_msg)
# Load inference token limits # Load inference token limits (defaults: small=16k, large=32k)
self.small_model_max_inference_tokens = int(os.getenv("_SMALL_MODEL_MAX_INFERENCE_TOKENS", "32768")) self.small_model_max_inference_tokens = int(os.getenv("_SMALL_MODEL_MAX_INFERENCE_TOKENS", "16384"))
self.large_model_max_inference_tokens = int(os.getenv("_LARGE_MODEL_MAX_INFERENCE_TOKENS", "32768")) self.large_model_max_inference_tokens = int(os.getenv("_LARGE_MODEL_MAX_INFERENCE_TOKENS", "32768"))
# Configure the actual model name and max_tokens for API calls # Configure the actual model name and max_tokens for API calls
@@ -86,26 +87,205 @@ class OpenAICompatibleInferenceBot(InferenceBot):
client_type = type(self.client).__name__ client_type = type(self.client).__name__
return f"Client: {client_type}, LLM: {self.model}, Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}" return f"Client: {client_type}, LLM: {self.model}, Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}"
def _count_tokens(self, messages, model): def _encoding_for_model(self, model: str | None):
"""Returns the number of tokens in a list of messages."""
try: try:
encoding = tiktoken.encoding_for_model(model) return tiktoken.encoding_for_model(model) if model else tiktoken.get_encoding("cl100k_base")
except KeyError: except KeyError:
encoding = tiktoken.get_encoding("cl100k_base") # Fallback for unknown models
logging.warning(f"Warning: model {model} not found. Using cl100k_base encoding.") logging.warning(f"Warning: model {model} not found. Using cl100k_base encoding.")
return tiktoken.get_encoding("cl100k_base")
def _normalize_messages(self, messages):
"""Return a list of plain dict chat messages acceptable by the API.
- Converts OpenAI SDK message objects into dicts
- Preserves tool_calls structure where present
"""
normalized = []
for m in messages:
if isinstance(m, dict):
# Ensure only known keys are present; copy shallowly
entry = {k: v for k, v in m.items() if k in {"role", "content", "name", "tool_call_id", "tool_calls"}}
normalized.append(entry)
else:
# Likely an OpenAI message object
role = getattr(m, "role", None)
content = getattr(m, "content", None)
name = getattr(m, "name", None)
tool_calls = []
tc_list = getattr(m, "tool_calls", None)
if tc_list:
for tc in tc_list:
try:
tool_calls.append({
"id": getattr(tc, "id", None),
"type": getattr(tc, "type", "function"),
"function": {
"name": getattr(getattr(tc, "function", None), "name", None),
"arguments": getattr(getattr(tc, "function", None), "arguments", "{}"),
}
})
except Exception:
# Best-effort fallback
tool_calls.append({"id": None, "type": "function", "function": {"name": "unknown", "arguments": "{}"}})
entry = {"role": role, "content": content}
if name:
entry["name"] = name
if tool_calls:
entry["tool_calls"] = tool_calls
normalized.append(entry)
return normalized
def _estimate_tokens(self, messages):
"""Estimate tokens for messages with tiktoken, including tool_calls arguments.
Based on OpenAI's chat token counting rules approximation.
"""
enc = self._encoding_for_model(self.model)
num_tokens = 0 num_tokens = 0
for message in messages: for m in messages:
num_tokens += 4 num_tokens += 4 # per-message overhead
if hasattr(message, "items"): if not isinstance(m, dict):
for key, value in message.items(): continue
if isinstance(value, str): # role/content
num_tokens += len(encoding.encode(value)) for key in ("role", "name", "content"):
if key == "name": v = m.get(key)
num_tokens += 1 if isinstance(v, str):
num_tokens += 2 num_tokens += len(enc.encode(v))
# tool calls request portion (arguments)
tcs = m.get("tool_calls")
if tcs and isinstance(tcs, list):
# approximate cost of the tool_calls JSON the model sees
for tc in tcs:
fn = tc.get("function", {}) if isinstance(tc, dict) else {}
fname = fn.get("name")
fargs = fn.get("arguments")
if isinstance(fname, str):
num_tokens += len(enc.encode(fname))
if isinstance(fargs, str):
num_tokens += len(enc.encode(fargs))
num_tokens += 2 # assistant priming
return num_tokens return num_tokens
def _get_inference_limit(self):
current_model_is_small = self.model == self.model_config["small_model_name"]
current_model_is_large = self.model == self.model_config["large_model_name"]
if current_model_is_small:
return self.small_model_max_inference_tokens
if current_model_is_large:
return self.large_model_max_inference_tokens
logging.warning(f"Could not determine inference token limit for model: {self.model}. Proceeding without check.")
return None
def _summarize_tool_args(self, args_str: str, max_chars: int = 512) -> str:
"""Summarize tool-call request arguments without altering tool responses.
- If JSON, keep keys and short previews of string values.
- If plain string, truncate with an indicator.
"""
try:
parsed = json.loads(args_str)
if isinstance(parsed, dict):
summary = {}
for k, v in parsed.items():
if isinstance(v, str):
if len(v) > 160:
summary[k] = v[:120] + f"... [len={len(v)}]"
else:
summary[k] = v
elif isinstance(v, (list, dict)):
# structural summary only
summary[k] = f"<{type(v).__name__} size={len(v)}>"
else:
summary[k] = v
s = json.dumps(summary, ensure_ascii=False)
if len(s) > max_chars:
s = s[: max_chars - 20] + "... [summarized]"
return s
except Exception:
pass
# Fallback: truncate raw string
return (args_str[: max_chars - 20] + "... [summarized]") if len(args_str) > max_chars else args_str
def _summarize_tool_call_requests_in_messages(self, messages):
changed = False
for m in messages:
if isinstance(m, dict) and m.get("tool_calls"):
new_tool_calls = []
for tc in m["tool_calls"]:
if not isinstance(tc, dict):
new_tool_calls.append(tc)
continue
fn = tc.get("function", {})
args = fn.get("arguments")
if isinstance(args, str) and args and len(args) > 700:
# summarize long request arguments only
fn = dict(fn)
fn["arguments"] = self._summarize_tool_args(args)
tc = dict(tc)
tc["function"] = fn
changed = True
new_tool_calls.append(tc)
if changed:
m["tool_calls"] = new_tool_calls
return changed
def _elide_redundant_code_blocks(self, messages):
"""As a last resort, remove large code blocks from older assistant messages.
Keep the latest assistant message intact.
"""
changed = False
# Identify indices of assistant messages
assistant_indices = [i for i, m in enumerate(messages) if isinstance(m, dict) and m.get("role") == "assistant" and m.get("content")]
if len(assistant_indices) <= 1:
return changed
# Protect the last assistant message
for i in assistant_indices[:-1]:
m = messages[i]
content = m.get("content")
if not isinstance(content, str):
continue
if "```" in content or "\n " in content:
# Replace code blocks fenced by ``` with succinct markers
orig = content
content = re.sub(r"```[\s\S]*?```", "[code block omitted]", content)
# Also collapse long indented blocks
content = re.sub(r"(?:\n\s{4,}.+)+", "\n[long block omitted]", content)
if content != orig:
m["content"] = content
changed = True
return changed
def _enforce_budget(self, messages):
"""Normalize and enforce token budget by summarizing only tool-call requests first,
then eliding redundant code blocks if still too large. Returns normalized messages.
"""
normalized = self._normalize_messages(messages)
limit = self._get_inference_limit()
if not limit:
return normalized
# Reserve space for completion tokens
reserve = self.max_tokens if isinstance(self.max_tokens, int) else 1024
budget = max(1024, limit - reserve)
tokens = self._estimate_tokens(normalized)
if tokens <= budget:
return normalized
# Step 1: summarize only tool-call request arguments
if self._summarize_tool_call_requests_in_messages(normalized):
tokens = self._estimate_tokens(normalized)
logging.info(f"Applied tool-call request summarization. tokens={tokens}/{budget}")
if tokens <= budget:
return normalized
# Step 2: elide redundant code blocks from older assistant messages
if self._elide_redundant_code_blocks(normalized):
tokens = self._estimate_tokens(normalized)
logging.info(f"Elided redundant code blocks. tokens={tokens}/{budget}")
if tokens <= budget:
return normalized
# If still over, log and proceed; the API may still reject; caller may choose to abort
logging.warning(f"Projected tokens still exceed budget after optimizations: {tokens}/{budget}")
return normalized
def get_chat_response(self, messages): def get_chat_response(self, messages):
if not self.client: if not self.client:
logging.error("OpenAI client not initialized before get_chat_response.") logging.error("OpenAI client not initialized before get_chat_response.")
@@ -128,9 +308,12 @@ class OpenAICompatibleInferenceBot(InferenceBot):
func_copy = {k: v for k, v in func.items() if k != "_tags"} func_copy = {k: v for k, v in func.items() if k != "_tags"}
cleaned_tools.append(func_copy) cleaned_tools.append(func_copy)
# Enforce token budget prior to API call
messages_for_api = self._enforce_budget(messages)
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
model=self.model, model=self.model,
messages=messages, messages=messages_for_api,
tools=cleaned_tools, tools=cleaned_tools,
tool_choice="auto" if cleaned_tools else None, tool_choice="auto" if cleaned_tools else None,
max_tokens=self.max_tokens, max_tokens=self.max_tokens,
@@ -158,27 +341,20 @@ class OpenAICompatibleInferenceBot(InferenceBot):
self.conversation_history[user_id].append({"role": "user", "content": user_message}) self.conversation_history[user_id].append({"role": "user", "content": user_message})
messages = list(self.conversation_history[user_id]) messages = list(self.conversation_history[user_id])
# Pre-inference token limit check # Pre-inference token limit check with budgeted optimizations
current_model_is_small = self.model == self.model_config["small_model_name"] limit = self._get_inference_limit()
current_model_is_large = self.model == self.model_config["large_model_name"] if limit is not None:
# Estimate on normalized messages after applying request-only summarization if needed
inference_token_limit = None provisional = self._enforce_budget(messages)
if current_model_is_small: token_count = self._estimate_tokens(provisional)
inference_token_limit = self.small_model_max_inference_tokens reserve = self.max_tokens if isinstance(self.max_tokens, int) else 1024
elif current_model_is_large: budget = max(1024, limit - reserve)
inference_token_limit = self.large_model_max_inference_tokens if token_count > budget:
else: logging.warning(f"Request for user {user_id} exceeds inference token budget even after optimizations ({token_count}/{budget}).")
logging.warning(f"Could not determine inference token limit for model: {self.model}. Proceeding without check.")
if inference_token_limit is not None:
token_count = self._count_tokens(messages, self.model)
if token_count > inference_token_limit:
logging.warning(f"Request for user {user_id} exceeds inference token limit ({token_count}/{inference_token_limit}).")
# Do not persist this message in history as it was not processed by LLM # Do not persist this message in history as it was not processed by LLM
# Remove the last user message from history before returning, to prevent accumulation
if self.conversation_history[user_id] and self.conversation_history[user_id][-1]["role"] == "user" and self.conversation_history[user_id][-1]["content"] == user_message: if self.conversation_history[user_id] and self.conversation_history[user_id][-1]["role"] == "user" and self.conversation_history[user_id][-1]["content"] == user_message:
self.conversation_history[user_id].pop() self.conversation_history[user_id].pop()
return "Request exceeds inference token limit. Please use the /clear command, or implement RAG in your application." return "Request exceeds inference token limit after optimization. Please shorten your request, use /clear, or implement RAG in your application."
response = self.get_chat_response(messages) response = self.get_chat_response(messages)
@@ -204,7 +380,7 @@ class OpenAICompatibleInferenceBot(InferenceBot):
function_name = function_to_call.name function_name = function_to_call.name
function_args_str = function_to_call.arguments function_args_str = function_to_call.arguments
logging.info(f"Attempting to call tool: {function_name} with args: {function_args_str}") logging.info(f"Attempting to call tool: {function_name} with args: [request summarized if large]")
if function_name not in [f["function"]["name"] for f in self.functions]: if function_name not in [f["function"]["name"] for f in self.functions]:
logging.warning(f"Tool function {function_name} not found in available functions.") logging.warning(f"Tool function {function_name} not found in available functions.")
tool_results_for_model.append({ tool_results_for_model.append({
@@ -232,6 +408,7 @@ class OpenAICompatibleInferenceBot(InferenceBot):
messages.extend(tool_results_for_model) messages.extend(tool_results_for_model)
# Enforce budget before next LLM call (summarize request portion only; preserve tool responses)
response = self.get_chat_response(messages) response = self.get_chat_response(messages)
if not (response.choices and response.choices[0].message): if not (response.choices and response.choices[0].message):
logging.error("No valid response choice message from LLM after tool call.") logging.error("No valid response choice message from LLM after tool call.")
@@ -251,7 +428,7 @@ class OpenAICompatibleInferenceBot(InferenceBot):
self.conversation_history[user_id] = messages self.conversation_history[user_id] = messages
final_assistant_message = messages[-1] final_assistant_message = messages[-1]
return final_assistant_message.content if final_assistant_message.role == "assistant" and final_assistant_message.content is not None else "Assistant did not provide a textual response." return final_assistant_message.content if getattr(final_assistant_message, "role", None) == "assistant" and getattr(final_assistant_message, "content", None) is not None else (final_assistant_message.get("content") if isinstance(final_assistant_message, dict) else "Assistant did not provide a textual response.")
async def start(self): async def start(self):
logging.info(f"{self.__class__.__name__} (Model: {self.model}) started.") logging.info(f"{self.__class__.__name__} (Model: {self.model}) started.")