diff --git a/openai_compatible_inference_bot.py b/openai_compatible_inference_bot.py new file mode 100644 index 0000000..eae60fe --- /dev/null +++ b/openai_compatible_inference_bot.py @@ -0,0 +1,133 @@ +import json +import os +import logging +from abc import abstractmethod +from base_telegram_inference_bot import BaseTelegramInferenceBot +from openai import OpenAI + +class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot): + def __init__(self): + super().__init__() + # Client and model configuration will be handled by subclasses + self.client = None + self.model = None + self.max_tokens = None + + def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=1000): + self.model = model_name if model_name else "default-model" + try: + self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens + except ValueError: + logging.error(f"Invalid value for max_tokens: {max_tokens_str}. Using default {default_max_tokens}.") + self.max_tokens = default_max_tokens + logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens}") + + def get_llm_description(self) -> str: + return f"LLM: {self.model}, Max Tokens: {self.max_tokens}" + + def get_chat_response(self, messages): + if not self.client: + raise ValueError("OpenAI client not initialized. Subclasses must initialize it.") + try: + response = self.client.chat.completions.create( + model=self.model, + messages=messages, + tools=self.functions if hasattr(self, 'functions') and self.functions else None, + tool_choice="auto" if hasattr(self, 'functions') and self.functions else None, + max_tokens=self.max_tokens + ) + return response + except Exception as e: + logging.error(f"API call failed: {e}") + raise + + async def handle_message(self, user_id, user_message): + if user_id not in self.conversation_history: + self.conversation_history[user_id] = [] + if hasattr(self, 'system_prompt') and self.system_prompt: + self.conversation_history[user_id].append({"role": "system", "content": self.system_prompt}) + + self.conversation_history[user_id].append({"role": "user", "content": user_message}) + messages = self.conversation_history[user_id] + + response = self.get_chat_response(messages) + + if not (response.choices and response.choices[0].message): + logging.error("No valid response choice message from LLM.") + return "Error: Could not get a valid response from the LLM." + + messages.append(response.choices[0].message) # Append the assistant's response message + + tool_calls_from_response = [] + if response.choices[0].message.tool_calls: + tool_calls_from_response.extend(response.choices[0].message.tool_calls) + + tool_use_count = 0 + MAX_TOOL_ITERATIONS = 5 + + while tool_calls_from_response and tool_use_count < MAX_TOOL_ITERATIONS: + tool_results_for_model = [] + + for tool_call in tool_calls_from_response: + tool_call_id = tool_call.id + function_to_call = tool_call.function + + logging.info(f"Attempting to call tool: {function_to_call.name} with args: {function_to_call.arguments}") + try: + tool_response_content = self.call_tool(function_to_call.name, function_to_call.arguments) + if not isinstance(tool_response_content, str): + tool_response_content = json.dumps(tool_response_content) + except Exception as e: + logging.error(f"Error calling tool {function_to_call.name}: {e}") + tool_response_content = f"Error executing tool {function_to_call.name}: {str(e)}" + + tool_results_for_model.append({ + "role": "tool", + "tool_call_id": tool_call_id, + "name": function_to_call.name, + "content": tool_response_content + }) + + messages.extend(tool_results_for_model) + + response = self.get_chat_response(messages) + if not (response.choices and response.choices[0].message): + logging.error("No valid response choice message from LLM after tool call.") + return "Error: Could not get a valid response from the LLM after tool call." + + messages.append(response.choices[0].message) + + tool_calls_from_response = [] + if response.choices[0].message.tool_calls: + tool_calls_from_response.extend(response.choices[0].message.tool_calls) + + tool_use_count += 1 + if tool_use_count >= MAX_TOOL_ITERATIONS and tool_calls_from_response: + logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.") + + # Conversation history management + # This limit should be reviewed and potentially made configurable + if len(self.conversation_history[user_id]) > 20: # Example limit, adjust as needed + self.conversation_history[user_id] = self.conversation_history[user_id][-20:] + + final_assistant_message = messages[-1] + return final_assistant_message.content if final_assistant_message.role == "assistant" and final_assistant_message.content else "No content in final message." + + async def start(self): + logging.info(f"{self.__class__.__name__} started.") + + async def clear(self, user_id): + super().clear_conversation_history(user_id) + + async def abort_processing(self, user_id): + if user_id in self.processing_status: + self.processing_status[user_id]["processing"] = False + await self.clear(user_id) + return "Processing aborted and conversation cleared." + else: + await self.clear(user_id) + return "No active processing found to abort. Conversation cleared." + + @abstractmethod + async def switch_model(self): + pass