diff --git a/openai_compatible_inference_bot.py b/openai_compatible_inference_bot.py index 63b14e4..01d3ca7 100644 --- a/openai_compatible_inference_bot.py +++ b/openai_compatible_inference_bot.py @@ -3,6 +3,7 @@ import json import os import logging import inspect +import re from abc import abstractmethod from openai import OpenAI from tools.base_tool import BaseTool @@ -42,8 +43,8 @@ class OpenAICompatibleInferenceBot(InferenceBot): log_msg = f"Initialized OpenAI compatible client. Target URL: {base_url if base_url else 'OpenAI default'}." logging.info(log_msg) - # Load inference token limits - self.small_model_max_inference_tokens = int(os.getenv("_SMALL_MODEL_MAX_INFERENCE_TOKENS", "32768")) + # Load inference token limits (defaults: small=16k, large=32k) + self.small_model_max_inference_tokens = int(os.getenv("_SMALL_MODEL_MAX_INFERENCE_TOKENS", "16384")) self.large_model_max_inference_tokens = int(os.getenv("_LARGE_MODEL_MAX_INFERENCE_TOKENS", "32768")) # Configure the actual model name and max_tokens for API calls @@ -86,26 +87,205 @@ class OpenAICompatibleInferenceBot(InferenceBot): client_type = type(self.client).__name__ return f"Client: {client_type}, LLM: {self.model}, Max Tokens: {self.max_tokens if self.max_tokens is not None else 'API default'}" - def _count_tokens(self, messages, model): - """Returns the number of tokens in a list of messages.""" + def _encoding_for_model(self, model: str | None): try: - encoding = tiktoken.encoding_for_model(model) + return tiktoken.encoding_for_model(model) if model else tiktoken.get_encoding("cl100k_base") except KeyError: - encoding = tiktoken.get_encoding("cl100k_base") # Fallback for unknown models logging.warning(f"Warning: model {model} not found. Using cl100k_base encoding.") + return tiktoken.get_encoding("cl100k_base") + def _normalize_messages(self, messages): + """Return a list of plain dict chat messages acceptable by the API. + - Converts OpenAI SDK message objects into dicts + - Preserves tool_calls structure where present + """ + normalized = [] + for m in messages: + if isinstance(m, dict): + # Ensure only known keys are present; copy shallowly + entry = {k: v for k, v in m.items() if k in {"role", "content", "name", "tool_call_id", "tool_calls"}} + normalized.append(entry) + else: + # Likely an OpenAI message object + role = getattr(m, "role", None) + content = getattr(m, "content", None) + name = getattr(m, "name", None) + tool_calls = [] + tc_list = getattr(m, "tool_calls", None) + if tc_list: + for tc in tc_list: + try: + tool_calls.append({ + "id": getattr(tc, "id", None), + "type": getattr(tc, "type", "function"), + "function": { + "name": getattr(getattr(tc, "function", None), "name", None), + "arguments": getattr(getattr(tc, "function", None), "arguments", "{}"), + } + }) + except Exception: + # Best-effort fallback + tool_calls.append({"id": None, "type": "function", "function": {"name": "unknown", "arguments": "{}"}}) + entry = {"role": role, "content": content} + if name: + entry["name"] = name + if tool_calls: + entry["tool_calls"] = tool_calls + normalized.append(entry) + return normalized + + def _estimate_tokens(self, messages): + """Estimate tokens for messages with tiktoken, including tool_calls arguments. + Based on OpenAI's chat token counting rules approximation. + """ + enc = self._encoding_for_model(self.model) num_tokens = 0 - for message in messages: - num_tokens += 4 - if hasattr(message, "items"): - for key, value in message.items(): - if isinstance(value, str): - num_tokens += len(encoding.encode(value)) - if key == "name": - num_tokens += 1 - num_tokens += 2 + for m in messages: + num_tokens += 4 # per-message overhead + if not isinstance(m, dict): + continue + # role/content + for key in ("role", "name", "content"): + v = m.get(key) + if isinstance(v, str): + num_tokens += len(enc.encode(v)) + # tool calls request portion (arguments) + tcs = m.get("tool_calls") + if tcs and isinstance(tcs, list): + # approximate cost of the tool_calls JSON the model sees + for tc in tcs: + fn = tc.get("function", {}) if isinstance(tc, dict) else {} + fname = fn.get("name") + fargs = fn.get("arguments") + if isinstance(fname, str): + num_tokens += len(enc.encode(fname)) + if isinstance(fargs, str): + num_tokens += len(enc.encode(fargs)) + num_tokens += 2 # assistant priming return num_tokens + def _get_inference_limit(self): + current_model_is_small = self.model == self.model_config["small_model_name"] + current_model_is_large = self.model == self.model_config["large_model_name"] + if current_model_is_small: + return self.small_model_max_inference_tokens + if current_model_is_large: + return self.large_model_max_inference_tokens + logging.warning(f"Could not determine inference token limit for model: {self.model}. Proceeding without check.") + return None + + def _summarize_tool_args(self, args_str: str, max_chars: int = 512) -> str: + """Summarize tool-call request arguments without altering tool responses. + - If JSON, keep keys and short previews of string values. + - If plain string, truncate with an indicator. + """ + try: + parsed = json.loads(args_str) + if isinstance(parsed, dict): + summary = {} + for k, v in parsed.items(): + if isinstance(v, str): + if len(v) > 160: + summary[k] = v[:120] + f"... [len={len(v)}]" + else: + summary[k] = v + elif isinstance(v, (list, dict)): + # structural summary only + summary[k] = f"<{type(v).__name__} size={len(v)}>" + else: + summary[k] = v + s = json.dumps(summary, ensure_ascii=False) + if len(s) > max_chars: + s = s[: max_chars - 20] + "... [summarized]" + return s + except Exception: + pass + # Fallback: truncate raw string + return (args_str[: max_chars - 20] + "... [summarized]") if len(args_str) > max_chars else args_str + + def _summarize_tool_call_requests_in_messages(self, messages): + changed = False + for m in messages: + if isinstance(m, dict) and m.get("tool_calls"): + new_tool_calls = [] + for tc in m["tool_calls"]: + if not isinstance(tc, dict): + new_tool_calls.append(tc) + continue + fn = tc.get("function", {}) + args = fn.get("arguments") + if isinstance(args, str) and args and len(args) > 700: + # summarize long request arguments only + fn = dict(fn) + fn["arguments"] = self._summarize_tool_args(args) + tc = dict(tc) + tc["function"] = fn + changed = True + new_tool_calls.append(tc) + if changed: + m["tool_calls"] = new_tool_calls + return changed + + def _elide_redundant_code_blocks(self, messages): + """As a last resort, remove large code blocks from older assistant messages. + Keep the latest assistant message intact. + """ + changed = False + # Identify indices of assistant messages + assistant_indices = [i for i, m in enumerate(messages) if isinstance(m, dict) and m.get("role") == "assistant" and m.get("content")] + if len(assistant_indices) <= 1: + return changed + # Protect the last assistant message + for i in assistant_indices[:-1]: + m = messages[i] + content = m.get("content") + if not isinstance(content, str): + continue + if "```" in content or "\n " in content: + # Replace code blocks fenced by ``` with succinct markers + orig = content + content = re.sub(r"```[\s\S]*?```", "[code block omitted]", content) + # Also collapse long indented blocks + content = re.sub(r"(?:\n\s{4,}.+)+", "\n[long block omitted]", content) + if content != orig: + m["content"] = content + changed = True + return changed + + def _enforce_budget(self, messages): + """Normalize and enforce token budget by summarizing only tool-call requests first, + then eliding redundant code blocks if still too large. Returns normalized messages. + """ + normalized = self._normalize_messages(messages) + limit = self._get_inference_limit() + if not limit: + return normalized + # Reserve space for completion tokens + reserve = self.max_tokens if isinstance(self.max_tokens, int) else 1024 + budget = max(1024, limit - reserve) + + tokens = self._estimate_tokens(normalized) + if tokens <= budget: + return normalized + + # Step 1: summarize only tool-call request arguments + if self._summarize_tool_call_requests_in_messages(normalized): + tokens = self._estimate_tokens(normalized) + logging.info(f"Applied tool-call request summarization. tokens={tokens}/{budget}") + if tokens <= budget: + return normalized + + # Step 2: elide redundant code blocks from older assistant messages + if self._elide_redundant_code_blocks(normalized): + tokens = self._estimate_tokens(normalized) + logging.info(f"Elided redundant code blocks. tokens={tokens}/{budget}") + if tokens <= budget: + return normalized + + # If still over, log and proceed; the API may still reject; caller may choose to abort + logging.warning(f"Projected tokens still exceed budget after optimizations: {tokens}/{budget}") + return normalized + def get_chat_response(self, messages): if not self.client: logging.error("OpenAI client not initialized before get_chat_response.") @@ -128,9 +308,12 @@ class OpenAICompatibleInferenceBot(InferenceBot): func_copy = {k: v for k, v in func.items() if k != "_tags"} cleaned_tools.append(func_copy) + # Enforce token budget prior to API call + messages_for_api = self._enforce_budget(messages) + response = self.client.chat.completions.create( model=self.model, - messages=messages, + messages=messages_for_api, tools=cleaned_tools, tool_choice="auto" if cleaned_tools else None, max_tokens=self.max_tokens, @@ -158,27 +341,20 @@ class OpenAICompatibleInferenceBot(InferenceBot): self.conversation_history[user_id].append({"role": "user", "content": user_message}) messages = list(self.conversation_history[user_id]) - # Pre-inference token limit check - current_model_is_small = self.model == self.model_config["small_model_name"] - current_model_is_large = self.model == self.model_config["large_model_name"] - - inference_token_limit = None - if current_model_is_small: - inference_token_limit = self.small_model_max_inference_tokens - elif current_model_is_large: - inference_token_limit = self.large_model_max_inference_tokens - else: - logging.warning(f"Could not determine inference token limit for model: {self.model}. Proceeding without check.") - - if inference_token_limit is not None: - token_count = self._count_tokens(messages, self.model) - if token_count > inference_token_limit: - logging.warning(f"Request for user {user_id} exceeds inference token limit ({token_count}/{inference_token_limit}).") + # Pre-inference token limit check with budgeted optimizations + limit = self._get_inference_limit() + if limit is not None: + # Estimate on normalized messages after applying request-only summarization if needed + provisional = self._enforce_budget(messages) + token_count = self._estimate_tokens(provisional) + reserve = self.max_tokens if isinstance(self.max_tokens, int) else 1024 + budget = max(1024, limit - reserve) + if token_count > budget: + logging.warning(f"Request for user {user_id} exceeds inference token budget even after optimizations ({token_count}/{budget}).") # Do not persist this message in history as it was not processed by LLM - # Remove the last user message from history before returning, to prevent accumulation if self.conversation_history[user_id] and self.conversation_history[user_id][-1]["role"] == "user" and self.conversation_history[user_id][-1]["content"] == user_message: self.conversation_history[user_id].pop() - return "Request exceeds inference token limit. Please use the /clear command, or implement RAG in your application." + return "Request exceeds inference token limit after optimization. Please shorten your request, use /clear, or implement RAG in your application." response = self.get_chat_response(messages) @@ -204,7 +380,7 @@ class OpenAICompatibleInferenceBot(InferenceBot): function_name = function_to_call.name function_args_str = function_to_call.arguments - logging.info(f"Attempting to call tool: {function_name} with args: {function_args_str}") + logging.info(f"Attempting to call tool: {function_name} with args: [request summarized if large]") if function_name not in [f["function"]["name"] for f in self.functions]: logging.warning(f"Tool function {function_name} not found in available functions.") tool_results_for_model.append({ @@ -232,6 +408,7 @@ class OpenAICompatibleInferenceBot(InferenceBot): messages.extend(tool_results_for_model) + # Enforce budget before next LLM call (summarize request portion only; preserve tool responses) response = self.get_chat_response(messages) if not (response.choices and response.choices[0].message): logging.error("No valid response choice message from LLM after tool call.") @@ -251,7 +428,7 @@ class OpenAICompatibleInferenceBot(InferenceBot): self.conversation_history[user_id] = messages final_assistant_message = messages[-1] - return final_assistant_message.content if final_assistant_message.role == "assistant" and final_assistant_message.content is not None else "Assistant did not provide a textual response." + return final_assistant_message.content if getattr(final_assistant_message, "role", None) == "assistant" and getattr(final_assistant_message, "content", None) is not None else (final_assistant_message.get("content") if isinstance(final_assistant_message, dict) else "Assistant did not provide a textual response.") async def start(self): logging.info(f"{self.__class__.__name__} (Model: {self.model}) started.")