Merge pull request #189 from bucolucas/refactor/bot-core-logic

Refactor Core Bot Logic: telegram_helper.py and *_inference_bot.py
This commit is contained in:
2025-06-02 15:04:14 -05:00
committed by GitHub
6 changed files with 261 additions and 387 deletions
+30 -44
View File
@@ -1,45 +1,29 @@
import os import os
import json import json
import logging import logging
from anthropic import Anthropic from anthropic import Anthropic, APIError, RateLimitError
from base_telegram_inference_bot import BaseTelegramInferenceBot from base_telegram_inference_bot import BaseTelegramInferenceBot
from telegram_helper import TelegramHelper from telegram_helper import TelegramHelper
# logging.basicConfig(level=logging.INFO) # Usually configured in main execution script
class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot): class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.anthropic_client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")) self.anthropic_client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
# Note: default_headers for max_tokens with older models might be needed.
# For Claude 3.5 Sonnet, max_tokens is a top-level param in messages.create
# Configure model and tokens. Using Sonnet 3.5 as default.
# ANTHROPIC_MODEL and ANTHROPIC_MAX_TOKENS would be new ENVs.
self._configure_model_and_tokens( self._configure_model_and_tokens(
os.environ.get("ANTHROPIC_MODEL", "claude-3-5-sonnet-20240620"), os.environ.get("ANTHROPIC_MODEL", "claude-3-5-sonnet-20240620"),
os.environ.get("ANTHROPIC_MAX_TOKENS", "4096") # Default max tokens for Sonnet 3.5 os.environ.get("ANTHROPIC_MAX_TOKENS", "4096")
) )
def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=4096): def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=4096):
self.model = model_name if model_name else "claude-3-5-sonnet-20240620" self.model = model_name if model_name else "claude-3-5-sonnet-20240620"
try: try:
# Anthropic's max_tokens is an integer.
self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens
except ValueError: except ValueError:
logging.error(f"Invalid value for Anthropic max_tokens: {max_tokens_str}. Using default {default_max_tokens}.") logging.error(f"Invalid value for Anthropic max_tokens: {max_tokens_str}. Using default {default_max_tokens}.")
self.max_tokens = default_max_tokens self.max_tokens = default_max_tokens
logging.info(f"Configured to use Anthropic model: {self.model} with max_tokens: {self.max_tokens}") logging.info(f"Configured to use Anthropic model: {self.model} with max_tokens: {self.max_tokens}")
def get_system_prompt_description(self) -> str:
system_prompt_path = os.getenv("SYSTEM_PROMPT_PATH")
if system_prompt_path and os.path.isfile(system_prompt_path):
return f"System Prompt File: {os.path.basename(system_prompt_path)}"
elif system_prompt_path:
return f"System Prompt File: {os.path.basename(system_prompt_path)} (Not found at path: {system_prompt_path})"
else:
return "System Prompt File: Not configured (SYSTEM_PROMPT_PATH not set)."
def get_llm_description(self) -> str: def get_llm_description(self) -> str:
return f"LLM: {self.model}, Max Tokens: {self.max_tokens}" return f"LLM: {self.model}, Max Tokens: {self.max_tokens}"
@@ -66,9 +50,27 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
tool_choice={"type": "auto"} if anthropic_tools else None tool_choice={"type": "auto"} if anthropic_tools else None
) )
return response return response
except Exception as e: except (APIError, RateLimitError) as e:
logging.error(f"Anthropic API call failed: {e}") logging.error(f"Anthropic API error: {e}")
raise raise
except Exception as e:
logging.error(f"An unexpected error occurred during Anthropic API call: {e}")
raise
def _format_tool_response_for_anthropic(self, tool_response_data):
    """Convert a tool's return value into a list of content-block dicts.

    Args:
        tool_response_data: The value returned by the tool; any type.

    Returns:
        A list of dicts. A list whose items are all dicts carrying a
        ``"type"`` key (an empty list included) is returned unchanged.
        Every other value is wrapped in a single
        ``{"type": "text", "text": ...}`` block: strings as-is, other
        dicts/lists as JSON when serialisable, everything else via
        ``str()``.
    """
    if isinstance(tool_response_data, str):
        text = tool_response_data
    elif isinstance(tool_response_data, (dict, list)):
        already_blocks = isinstance(tool_response_data, list) and all(
            isinstance(item, dict) and "type" in item
            for item in tool_response_data
        )
        if already_blocks:
            return tool_response_data
        try:
            text = json.dumps(tool_response_data)
        except (TypeError, ValueError):
            # json.dumps raises TypeError for unserialisable members and
            # ValueError for circular references; it never raises
            # JSONDecodeError (which is itself a ValueError subclass, so
            # anything the old clause caught is still caught here).
            text = str(tool_response_data)
    else:
        text = str(tool_response_data)
    return [{"type": "text", "text": text}]
async def handle_message(self, user_id, user_message): async def handle_message(self, user_id, user_message):
if user_id not in self.conversation_history: if user_id not in self.conversation_history:
@@ -86,7 +88,7 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
if not response or not response.content: if not response or not response.content:
logging.error("No valid response content from Anthropic LLM.") logging.error("No valid response content from Anthropic LLM.")
self.conversation_history[user_id] = current_turn_messages # Persist what we have self.conversation_history[user_id] = current_turn_messages
return "Error: Could not get a valid response from the LLM." return "Error: Could not get a valid response from the LLM."
assistant_current_turn_content_blocks = response.content assistant_current_turn_content_blocks = response.content
@@ -114,22 +116,7 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
logging.info(f"Attempting to call Anthropic tool: {tool_name} with input: {tool_input}") logging.info(f"Attempting to call Anthropic tool: {tool_name} with input: {tool_input}")
try: try:
tool_response_data = self.call_tool(tool_name, tool_input) tool_response_data = self.call_tool(tool_name, tool_input)
tool_result_content_block = self._format_tool_response_for_anthropic(tool_response_data)
if isinstance(tool_response_data, str):
tool_result_content_block = [{"type": "text", "text": tool_response_data}]
elif isinstance(tool_response_data, dict) or isinstance(tool_response_data, list):
try:
# If tool_response_data is already a list of Anthropic content blocks, use as is.
# Otherwise, dump to JSON string and wrap in a text block.
is_valid_block_list = isinstance(tool_response_data, list) and all(isinstance(item, dict) and "type" in item for item in tool_response_data)
if is_valid_block_list:
tool_result_content_block = tool_response_data
else:
tool_result_content_block = [{"type": "text", "text": json.dumps(tool_response_data)}]
except (TypeError, json.JSONDecodeError): # Not easily serializable or not a valid block list
tool_result_content_block = [{"type": "text", "text": str(tool_response_data)}]
else: # bool, int, float, None, etc.
tool_result_content_block = [{"type": "text", "text": str(tool_response_data)}]
tool_results_for_model.append({ tool_results_for_model.append({
"type": "tool_result", "type": "tool_result",
@@ -157,11 +144,10 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
if len(self.conversation_history[user_id]) > 20: if len(self.conversation_history[user_id]) > 20:
self.conversation_history[user_id] = self.conversation_history[user_id][-20:] self.conversation_history[user_id] = self.conversation_history[user_id][-20:]
if assistant_response_content: # Text from the last successful assistant turn (or before max iterations) if assistant_response_content:
return assistant_response_content return assistant_response_content
else: # Fallback if no text content was generated by assistant (e.g. initial error, or only tool use) else:
if current_turn_messages: if current_turn_messages:
# Try to get the *very last* text block from the *very last* assistant message in history.
last_message_in_turn = current_turn_messages[-1] last_message_in_turn = current_turn_messages[-1]
if last_message_in_turn.get("role") == "assistant" and isinstance(last_message_in_turn.get("content"), list): if last_message_in_turn.get("role") == "assistant" and isinstance(last_message_in_turn.get("content"), list):
for block in reversed(last_message_in_turn["content"]): for block in reversed(last_message_in_turn["content"]):
@@ -173,17 +159,17 @@ class AnthropicTelegramInferenceBot(BaseTelegramInferenceBot):
async def start(self): async def start(self):
logging.info("Anthropic Bot started") logging.info("Anthropic Bot started")
async def clear(self, user_id): async def clear_conversation_history(self, user_id):
super().clear_conversation(user_id) super().clear_conversation_history(user_id)
logging.info(f"Cleared conversation history for Anthropic bot, user {user_id}") logging.info(f"Cleared conversation history for Anthropic bot, user {user_id}")
async def abort_processing(self, user_id): async def abort_processing(self, user_id):
if user_id in self.processing_status: if user_id in self.processing_status:
self.processing_status[user_id]["processing"] = False self.processing_status[user_id]["processing"] = False
await self.clear(user_id) await self.clear_conversation_history(user_id)
return "Processing aborted and conversation cleared." return "Processing aborted and conversation cleared."
else: else:
await self.clear(user_id) await self.clear_conversation_history(user_id)
return "No active processing found to abort. Conversation cleared." return "No active processing found to abort. Conversation cleared."
async def switch_model(self): async def switch_model(self):
+55 -25
View File
@@ -2,6 +2,7 @@ import importlib
import os import os
import json import json
import inspect import inspect
import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from tools.base_tool import BaseTool from tools.base_tool import BaseTool
@@ -11,32 +12,44 @@ class BaseTelegramInferenceBot(ABC):
self.processing_status = {} self.processing_status = {}
self.system_prompt = self.load_system_prompt() self.system_prompt = self.load_system_prompt()
self.tools, self.functions = self.load_functions() self.tools, self.functions = self.load_functions()
print(f'System Prompt: {os.environ.get("SYSTEM_PROMPT_PATH")}') logging.info(f'System Prompt: {os.environ.get("SYSTEM_PROMPT_PATH")}')
print(f'Github Repository: {os.environ.get("GITHUB_REPOSITORY")}') logging.info(f'Github Repository: {os.environ.get("GITHUB_REPOSITORY")}')
@staticmethod def load_system_prompt(self):
def load_system_prompt():
system_prompt_path = os.getenv("SYSTEM_PROMPT_PATH") system_prompt_path = os.getenv("SYSTEM_PROMPT_PATH")
if system_prompt_path and os.path.isfile(system_prompt_path): if system_prompt_path and os.path.isfile(system_prompt_path):
with open(system_prompt_path, "r", encoding="utf-8") as file: try:
return file.read().strip() with open(system_prompt_path, "r", encoding="utf-8") as file:
return file.read().strip()
except IOError as e:
logging.warning(f"Could not read system prompt file {system_prompt_path}: {e}")
return "You are a helpful AI assistant."
else: else:
raise FileNotFoundError("SYSTEM_PROMPT_PATH is not set or file does not exist.") logging.warning("SYSTEM_PROMPT_PATH is not set or file does not exist. Using default system prompt.")
return "You are a helpful AI assistant."
@staticmethod def load_functions(self):
def load_functions():
tools = [] tools = []
functions = []
tools_dir = os.path.join(os.path.dirname(__file__), 'tools') tools_dir = os.path.join(os.path.dirname(__file__), 'tools')
if not os.path.exists(tools_dir):
logging.warning(f"Tools directory not found: {tools_dir}")
return [], []
for filename in os.listdir(tools_dir): for filename in os.listdir(tools_dir):
if filename.endswith('.py') and filename != '__init__.py' and filename != 'base_tool.py': if filename.endswith('.py') and filename != '__init__.py' and filename != 'base_tool.py':
module_name = f'tools.{filename[:-3]}' module_name = f'tools.{filename[:-3]}'
module = importlib.import_module(module_name) try:
for name, obj in inspect.getmembers(module): module = importlib.import_module(module_name)
if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool: for name, obj in inspect.getmembers(module):
tools.append(obj()) if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool:
try:
tools.append(obj())
except Exception as e:
logging.error(f"Error instantiating tool {name} from {filename}: {e}")
except Exception as e:
logging.error(f"Error importing module {module_name}: {e}")
# Collect all function definitions
functions = []
for tool in tools: for tool in tools:
functions.extend(tool.get_functions()) functions.extend(tool.get_functions())
return tools, functions return tools, functions
@@ -49,37 +62,53 @@ class BaseTelegramInferenceBot(ABC):
async def handle_message(self, user_id, user_message): async def handle_message(self, user_id, user_message):
pass pass
def clear_conversation(self, user_id): def clear_conversation_history(self, user_id):
if user_id in self.conversation_history: if user_id in self.conversation_history:
del self.conversation_history[user_id] del self.conversation_history[user_id]
# Assuming tool.clear() is for global state or doesn't need user_id
for tool in self.tools: for tool in self.tools:
tool.clear() tool.clear()
def set_processing_status(self, user_id: int, message_id: int):
    """Record an in-progress status entry for ``user_id``.

    The entry stored in ``self.processing_status`` has ``"processing"``
    set to ``True`` and keeps ``message_id`` under ``"message_id"``;
    any previous entry for the same user is overwritten.
    """
    entry = {"processing": True, "message_id": message_id}
    self.processing_status[user_id] = entry
def clear_processing_status(self, user_id: int):
    """Remove the processing-status entry for ``user_id``, if one exists.

    Calling this for a user with no recorded entry is a no-op.
    """
    # pop() with a default removes the key in a single lookup and never
    # raises, replacing the check-then-delete pair.
    self.processing_status.pop(user_id, None)
def call_tool(self, function_call_name, function_call_arguments): def call_tool(self, function_call_name, function_call_arguments):
function_name = function_call_name function_name = function_call_name
function_args = json.loads(function_call_arguments if function_call_arguments is not None else "{}") try:
function_args = json.loads(function_call_arguments if function_call_arguments is not None else "{}")
except json.JSONDecodeError as e:
logging.error(f"Error decoding function call arguments for {function_call_name}: {e}. Arguments: {function_call_arguments}")
return f"Error: Malformed arguments for tool call: {e}"
for tool in self.tools: for tool in self.tools:
for function in tool.get_functions(): for function in tool.get_functions():
if function["function"]["name"] == function_name: if function["function"]["name"] == function_name:
return tool.execute(function_name, **function_args) try:
return tool.execute(function_name, **function_args)
except Exception as e:
logging.error(f"Error executing tool {function_name} with args {function_args}: {e}")
return f"Error executing tool {function_name}: {e}"
logging.warning(f"Tool function {function_name} not found.")
return f"Error: Tool function {function_name} not found."
@abstractmethod
def get_system_prompt_description(self) -> str: def get_system_prompt_description(self) -> str:
"""Returns a description of the system prompt being used.""" """Returns a description of the system prompt being used."""
pass return f"System Prompt: {'Custom' if os.getenv('SYSTEM_PROMPT_PATH') else 'Default'}"
@abstractmethod @abstractmethod
def get_llm_description(self) -> str: def get_llm_description(self) -> str:
"""Returns a description of the LLM being used.""" """Returns a description of the LLM being used."""
pass pass
async def status(self) -> str: # Changed from abstract to concrete async def get_bot_status(self) -> str:
"""Provides a status message including prompt and LLM information.""" """Provides a status message including prompt and LLM information."""
prompt_desc = self.get_system_prompt_description() prompt_desc = self.get_system_prompt_description()
llm_desc = self.get_llm_description() llm_desc = self.get_llm_description()
# Consider potential async calls if get_... methods were async
# For now, assuming they are synchronous as per design
return f"{prompt_desc}\n{llm_desc}" return f"{prompt_desc}\n{llm_desc}"
@abstractmethod @abstractmethod
@@ -87,9 +116,10 @@ class BaseTelegramInferenceBot(ABC):
pass pass
@abstractmethod @abstractmethod
async def clear(self, user_id): async def abort_processing(self, user_id):
pass pass
@abstractmethod @abstractmethod
async def abort_processing(self, user_id): async def switch_model(self):
"""Switches the underlying model if supported by the bot."""
pass pass
+7 -137
View File
@@ -1,156 +1,27 @@
import json
import os import os
import logging import logging
from base_telegram_inference_bot import BaseTelegramInferenceBot
from telegram_helper import TelegramHelper
from openai import OpenAI from openai import OpenAI
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot
from telegram_helper import TelegramHelper
# logging.basicConfig(level=logging.INFO) # Usually configured in main execution script class ChatGPTTelegramInferenceBot(OpenAICompatibleInferenceBot):
class ChatGPTTelegramInferenceBot(BaseTelegramInferenceBot):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
self._configure_model_and_tokens( self._configure_model_and_tokens(
os.environ.get("OPENAI_SMALL_MODEL", "gpt-3.5-turbo"), # Default to a common small model os.environ.get("OPENAI_SMALL_MODEL", "gpt-3.5-turbo"),
os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS") os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS")
) )
def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=1000):
self.model = model_name if model_name else "gpt-3.5-turbo" # Ensure model has a default
try:
self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens
except ValueError:
logging.error(f"Invalid value for max_tokens: {max_tokens_str}. Using default {default_max_tokens}.")
self.max_tokens = default_max_tokens
logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens}")
def get_system_prompt_description(self) -> str:
system_prompt_path = os.getenv("SYSTEM_PROMPT_PATH")
if system_prompt_path and os.path.isfile(system_prompt_path):
return f"System Prompt File: {os.path.basename(system_prompt_path)}"
elif system_prompt_path: # Path is set but file not found
return f"System Prompt File: {os.path.basename(system_prompt_path)} (Not found at path: {system_prompt_path})"
else: # Path not set
return "System Prompt File: Not configured (SYSTEM_PROMPT_PATH not set)."
def get_llm_description(self) -> str:
return f"LLM: {self.model}, Max Tokens: {self.max_tokens}"
def get_chat_response(self, messages):
try:
response = self.client.chat.completions.create(
model=self.model,
messages=messages,
tools=self.functions if hasattr(self, 'functions') and self.functions else None,
tool_choice="auto" if hasattr(self, 'functions') and self.functions else None,
max_tokens=self.max_tokens
)
return response
except Exception as e:
logging.error(f"OpenAI API call failed: {e}")
raise
async def handle_message(self, user_id, user_message):
if user_id not in self.conversation_history:
self.conversation_history[user_id] = []
if hasattr(self, 'system_prompt') and self.system_prompt:
self.conversation_history[user_id].append({"role": "system", "content": self.system_prompt})
self.conversation_history[user_id].append({"role": "user", "content": user_message})
messages = self.conversation_history[user_id]
response = self.get_chat_response(messages)
if not (response.choices and response.choices[0].message):
logging.error("No valid response choice message from LLM.")
return "Error: Could not get a valid response from the LLM."
messages.append(response.choices[0].message) # Append the assistant's response message
tool_calls_from_response = []
if response.choices[0].message.tool_calls:
tool_calls_from_response.extend(response.choices[0].message.tool_calls)
tool_use_count = 0
MAX_TOOL_ITERATIONS = 5
while tool_calls_from_response and tool_use_count < MAX_TOOL_ITERATIONS:
tool_results_for_model = []
for tool_call in tool_calls_from_response:
tool_call_id = tool_call.id
function_to_call = tool_call.function
logging.info(f"Attempting to call tool: {function_to_call.name} with args: {function_to_call.arguments}")
try:
tool_response_content = self.call_tool(function_to_call.name, function_to_call.arguments)
if not isinstance(tool_response_content, str):
tool_response_content = json.dumps(tool_response_content)
except Exception as e:
logging.error(f"Error calling tool {function_to_call.name}: {e}")
tool_response_content = f"Error executing tool {function_to_call.name}: {str(e)}"
tool_results_for_model.append({
"role": "tool",
"tool_call_id": tool_call_id,
"name": function_to_call.name,
"content": tool_response_content
})
messages.extend(tool_results_for_model)
response = self.get_chat_response(messages)
if not (response.choices and response.choices[0].message):
logging.error("No valid response choice message from LLM after tool call.")
return "Error: Could not get a valid response from the LLM after tool call."
messages.append(response.choices[0].message)
tool_calls_from_response = []
if response.choices[0].message.tool_calls:
tool_calls_from_response.extend(response.choices[0].message.tool_calls)
tool_use_count += 1
if tool_use_count >= MAX_TOOL_ITERATIONS and tool_calls_from_response:
logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.")
if len(self.conversation_history[user_id]) > 20: # This limit seems small, consider increasing
self.conversation_history[user_id] = self.conversation_history[user_id][-20:]
final_assistant_message = messages[-1]
return final_assistant_message.content if final_assistant_message.role == "assistant" and final_assistant_message.content else "No content in final message."
async def start(self):
logging.info("ChatGPT Bot started")
# super().start() if Base class start() has common logic
async def clear(self, user_id):
super().clear_conversation(user_id)
# status() method is inherited from BaseTelegramInferenceBot
async def abort_processing(self, user_id):
if user_id in self.processing_status: # Relies on processing_status from Base
self.processing_status[user_id]["processing"] = False
await self.clear(user_id)
return "Processing aborted and conversation cleared."
else:
await self.clear(user_id)
return "No active processing found to abort. Conversation cleared."
async def switch_model(self): async def switch_model(self):
# Ensure environment variables for model names are set for this to work meaningfully
current_small_model = os.environ.get("OPENAI_SMALL_MODEL", "gpt-3.5-turbo") current_small_model = os.environ.get("OPENAI_SMALL_MODEL", "gpt-3.5-turbo")
current_large_model = os.environ.get("OPENAI_LARGE_MODEL", "gpt-4") # Example large model current_large_model = os.environ.get("OPENAI_LARGE_MODEL", "gpt-4")
# Default to small model if current model is not recognized or if it's the large one if self.model == current_large_model or self.model != current_small_model:
if self.model == current_large_model or self.model != current_small_model :
target_model = current_small_model target_model = current_small_model
target_max_tokens = os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS") target_max_tokens = os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS")
else: # Current is small (or default), switch to large else:
target_model = current_large_model target_model = current_large_model
target_max_tokens = os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS") target_max_tokens = os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS")
@@ -163,7 +34,6 @@ def main():
logging.error("FATAL: OPENAI_API_KEY environment variable not set.") logging.error("FATAL: OPENAI_API_KEY environment variable not set.")
return return
# Configure logging here if it's the main entry point
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
bot = ChatGPTTelegramInferenceBot() bot = ChatGPTTelegramInferenceBot()
+7 -152
View File
@@ -1,166 +1,27 @@
import json
import os import os
import logging import logging
from base_telegram_inference_bot import BaseTelegramInferenceBot
from telegram_helper import TelegramHelper # This import might be unused if main() is removed or TelegramHelper is not directly instantiated here.
from openai import OpenAI from openai import OpenAI
from openai_compatible_inference_bot import OpenAICompatibleInferenceBot
from telegram_helper import TelegramHelper
# logging.basicConfig(level=logging.INFO) # Usually configured in main execution script class GeminiTelegramInferenceBot(OpenAICompatibleInferenceBot):
class GeminiTelegramInferenceBot(BaseTelegramInferenceBot):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.client = OpenAI(api_key=os.environ.get("GEMINI_API_KEY"), base_url=os.environ.get("GEMINI_API_BASE_URL")) self.client = OpenAI(api_key=os.environ.get("GEMINI_API_KEY"), base_url=os.environ.get("GEMINI_API_BASE_URL"))
self._configure_model_and_tokens( self._configure_model_and_tokens(
os.environ.get("GEMINI_SMALL_MODEL"), os.environ.get("GEMINI_SMALL_MODEL", "gemini-pro"),
os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS") os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS")
) )
def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=1000):
self.model = model_name if model_name else "default-gemini-model" # Ensure model has a default
try:
self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens
except ValueError:
logging.error(f"Invalid value for max_tokens: {max_tokens_str}. Using default {default_max_tokens}.")
self.max_tokens = default_max_tokens
logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens}")
def get_system_prompt_description(self) -> str:
system_prompt_path = os.getenv("SYSTEM_PROMPT_PATH")
if system_prompt_path and os.path.isfile(system_prompt_path):
return f"System Prompt File: {os.path.basename(system_prompt_path)}"
elif system_prompt_path: # Path is set but file not found
return f"System Prompt File: {os.path.basename(system_prompt_path)} (Not found at path: {system_prompt_path})"
else: # Path not set
return "System Prompt File: Not configured (SYSTEM_PROMPT_PATH not set)."
def get_llm_description(self) -> str:
return f"LLM: {self.model}, Max Tokens: {self.max_tokens}"
def get_chat_response(self, messages):
try:
response = self.client.chat.completions.create(
model=self.model,
messages=messages,
tools=self.functions if hasattr(self, 'functions') and self.functions else None,
tool_choice="auto" if hasattr(self, 'functions') and self.functions else None,
max_tokens=self.max_tokens
)
return response
except Exception as e:
logging.error(f"Gemini API call failed: {e}")
# Return a more structured error or re-raise a custom exception
# For now, re-raising to be handled by the caller
raise
async def handle_message(self, user_id, user_message):
if user_id not in self.conversation_history:
self.conversation_history[user_id] = []
if hasattr(self, 'system_prompt') and self.system_prompt:
self.conversation_history[user_id].append({"role": "system", "content": self.system_prompt})
self.conversation_history[user_id].append({"role": "user", "content": user_message})
messages = self.conversation_history[user_id]
response = self.get_chat_response(messages)
# Ensure response.choices[0].message exists before appending
if response.choices and response.choices[0].message:
messages.append(response.choices[0].message) # Append the assistant's response message
else:
logging.error("No valid response choice message from LLM.")
return "Error: Could not get a valid response from the LLM."
tool_calls_from_response = []
if response.choices[0].message.tool_calls:
tool_calls_from_response.extend(response.choices[0].message.tool_calls)
tool_use_count = 0
MAX_TOOL_ITERATIONS = 5 # Define a max to prevent infinite loops more explicitly
while tool_calls_from_response and tool_use_count < MAX_TOOL_ITERATIONS:
tool_results_for_model = [] # Results to be sent back to the model
for tool_call in tool_calls_from_response:
tool_call_id = tool_call.id
function_to_call = tool_call.function
logging.info(f"Attempting to call tool: {function_to_call.name} with args: {function_to_call.arguments}")
try:
tool_response_content = self.call_tool(function_to_call.name, function_to_call.arguments)
# Ensure tool_response_content is a string for the API
if not isinstance(tool_response_content, str):
tool_response_content = json.dumps(tool_response_content)
except Exception as e:
logging.error(f"Error calling tool {function_to_call.name}: {e}")
tool_response_content = f"Error executing tool {function_to_call.name}: {str(e)}"
tool_results_for_model.append({
"role": "tool",
"tool_call_id": tool_call_id,
"name": function_to_call.name,
"content": tool_response_content
})
messages.extend(tool_results_for_model) # Add tool responses to message history
# Get new response from model based on tool execution results
response = self.get_chat_response(messages)
if not (response.choices and response.choices[0].message):
logging.error("No valid response choice message from LLM after tool call.")
return "Error: Could not get a valid response from the LLM after tool call."
messages.append(response.choices[0].message) # Append new assistant message
# Check for new tool calls
tool_calls_from_response = [] # Reset for this iteration
if response.choices[0].message.tool_calls:
tool_calls_from_response.extend(response.choices[0].message.tool_calls)
tool_use_count += 1
if tool_use_count >= MAX_TOOL_ITERATIONS and tool_calls_from_response:
logging.warning(f"Max tool iterations ({MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.")
# May need to return a message indicating this to user
# Conversation history management
if len(self.conversation_history[user_id]) > 2000: # Assuming this limit is for messages, not tokens
self.conversation_history[user_id] = self.conversation_history[user_id][-2000:]
# Return the latest assistant content
final_assistant_message = messages[-1]
return final_assistant_message.content if final_assistant_message.role == "assistant" and final_assistant_message.content else "No content in final message."
async def start(self):
logging.info("Gemini Bot started")
# super().start() if Base class start() has common logic
async def clear(self, user_id):
super().clear_conversation(user_id) # Calls base class method
# status() method is inherited from BaseTelegramInferenceBot
async def abort_processing(self, user_id):
if user_id in self.processing_status:
self.processing_status[user_id]["processing"] = False
# It's good practice to also clear the conversation for an aborted state
await self.clear(user_id)
return "Processing aborted and conversation cleared."
else:
# If no specific status, clearing conversation is a safe default
await self.clear(user_id)
return "No active processing found to abort. Conversation cleared."
async def switch_model(self): async def switch_model(self):
current_small_model = os.environ.get("GEMINI_SMALL_MODEL") current_small_model = os.environ.get("GEMINI_SMALL_MODEL", "gemini-pro")
current_large_model = os.environ.get("GEMINI_LARGE_MODEL") current_large_model = os.environ.get("GEMINI_LARGE_MODEL", "gemini-1.5-pro-latest")
# Default to small model if current model is not recognized or if it's the large one
if self.model == current_large_model or self.model != current_small_model : if self.model == current_large_model or self.model != current_small_model :
target_model = current_small_model target_model = current_small_model
target_max_tokens = os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS") target_max_tokens = os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS")
else: # Current is small, switch to large else:
target_model = current_large_model target_model = current_large_model
target_max_tokens = os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS") target_max_tokens = os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS")
@@ -168,20 +29,14 @@ class GeminiTelegramInferenceBot(BaseTelegramInferenceBot):
logging.info(f"Switched to model: {self.model}") logging.info(f"Switched to model: {self.model}")
return f"Switched to model: {self.model}" return f"Switched to model: {self.model}"
# The main() function and if __name__ == '__main__': block are for standalone execution.
# If this bot is imported as a module, these might not be necessary or might be handled differently.
# For now, keeping them as they were.
def main():
    """Run the Gemini bot as a standalone process.

    Configures root logging, verifies that ``GEMINI_API_KEY`` is set, then
    builds the bot and hands it to ``TelegramHelper.run()``. Returns ``None``
    without starting anything when the API key is missing.
    """
    # Configure logging before the first log call: a module-level
    # logging.error() on an unconfigured root logger installs the default
    # handler/format, so the fatal message would not use the format below.
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    if not os.environ.get("GEMINI_API_KEY"):
        logging.error("FATAL: GEMINI_API_KEY environment variable not set.")
        return

    bot = GeminiTelegramInferenceBot()
    # Instantiating and running TelegramHelper here is what makes this file
    # usable as an entry point rather than only as an importable module.
    telegram_helper = TelegramHelper(bot)
    telegram_helper.run()
+133
View File
@@ -0,0 +1,133 @@
import json
import os
import logging
from abc import abstractmethod
from base_telegram_inference_bot import BaseTelegramInferenceBot
from openai import OpenAI
class OpenAICompatibleInferenceBot(BaseTelegramInferenceBot):
    """Shared chat and tool-calling loop for bots backed by an OpenAI-style API.

    This class leaves ``client``, ``model`` and ``max_tokens`` unset; subclasses
    must assign ``self.client`` and are expected to call
    ``_configure_model_and_tokens`` and to implement ``switch_model``.

    The code relies on ``self.conversation_history``, ``self.processing_status``
    and ``self.call_tool`` being supplied elsewhere (presumably by
    ``BaseTelegramInferenceBot`` -- verify against the base class), and uses
    ``self.functions`` / ``self.system_prompt`` only when they exist.
    """

    # Upper bound on model -> tools -> model round trips for one user message.
    MAX_TOOL_ITERATIONS = 5
    # Number of most recent non-system messages kept per user.
    # This limit should be reviewed and potentially made configurable.
    MAX_HISTORY_MESSAGES = 20

    def __init__(self):
        """Initialise base state; client and model are configured by subclasses."""
        super().__init__()
        self.client = None
        self.model = None
        self.max_tokens = None

    def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=1000):
        """Store the model name and parse the token limit.

        Args:
            model_name: Model identifier; a falsy value selects "default-model".
            max_tokens_str: Token limit as a string (or ``None`` to use the default).
            default_max_tokens: Fallback used when the value is missing or invalid.
        """
        self.model = model_name if model_name else "default-model"
        try:
            parsed = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens
            if parsed <= 0:
                # A non-positive limit cannot produce a completion.
                raise ValueError("max_tokens must be positive")
            self.max_tokens = parsed
        except (TypeError, ValueError):
            # TypeError covers non-string, non-numeric values such as a list.
            logging.error(f"Invalid value for max_tokens: {max_tokens_str}. Using default {default_max_tokens}.")
            self.max_tokens = default_max_tokens
        logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens}")

    def get_llm_description(self) -> str:
        """Return a one-line description of the configured model and token limit."""
        return f"LLM: {self.model}, Max Tokens: {self.max_tokens}"

    def get_chat_response(self, messages):
        """Send ``messages`` to the chat-completions endpoint and return the raw response.

        Raises:
            ValueError: If no client has been assigned.
            Exception: Any error from the client call is logged and re-raised.
        """
        if not self.client:
            raise ValueError("OpenAI client not initialized. Subclasses must initialize it.")

        # Optional parameters are omitted rather than sent as None, so that
        # "tools"/"tool_choice" never appear in a request without tools.
        request = {"model": self.model, "messages": messages}
        if self.max_tokens is not None:
            request["max_tokens"] = self.max_tokens
        tools = getattr(self, "functions", None)
        if tools:
            request["tools"] = tools
            request["tool_choice"] = "auto"

        try:
            return self.client.chat.completions.create(**request)
        except Exception as e:
            logging.exception(f"API call failed: {e}")
            raise

    @staticmethod
    def _message_role(message):
        """Return the role of a history entry, which may be a dict or a response object."""
        if isinstance(message, dict):
            return message.get("role")
        return getattr(message, "role", None)

    def _request_assistant_message(self, messages):
        """Query the model, append its message to ``messages``, and return it (or None)."""
        response = self.get_chat_response(messages)
        choices = getattr(response, "choices", None)
        message = choices[0].message if choices else None
        if not message:
            return None
        messages.append(message)
        return message

    def _execute_tool_calls(self, tool_calls):
        """Run each requested tool and return the matching ``role: tool`` messages."""
        results = []
        for tool_call in tool_calls:
            function_to_call = tool_call.function
            logging.info(f"Attempting to call tool: {function_to_call.name} with args: {function_to_call.arguments}")
            try:
                tool_response_content = self.call_tool(function_to_call.name, function_to_call.arguments)
                if not isinstance(tool_response_content, str):
                    tool_response_content = json.dumps(tool_response_content)
            except Exception as e:
                # A failing tool is reported back to the model instead of
                # aborting the whole conversation turn.
                logging.error(f"Error calling tool {function_to_call.name}: {e}")
                tool_response_content = f"Error executing tool {function_to_call.name}: {str(e)}"
            results.append({
                "role": "tool",
                "tool_call_id": tool_call.id,
                "name": function_to_call.name,
                "content": tool_response_content,
            })
        return results

    def _trim_history(self, user_id):
        """Cap stored history while keeping it acceptable to the API."""
        history = self.conversation_history.get(user_id)
        limit = self.MAX_HISTORY_MESSAGES
        if not history or len(history) <= limit:
            return
        # Keep the leading system prompt; a plain tail slice would drop it.
        head = [history[0]] if self._message_role(history[0]) == "system" else []
        tail = history[-limit:]
        # The window must not open with tool results whose originating
        # assistant tool_calls message was cut off.
        start = 0
        while start < len(tail) and self._message_role(tail[start]) == "tool":
            start += 1
        history[:] = head + tail[start:]

    async def handle_message(self, user_id, user_message):
        """Process one user message, running tool calls until the model answers.

        Returns the assistant's final text, or an error / placeholder string
        when no usable content was produced.

        NOTE(review): ``get_chat_response`` and ``call_tool`` are invoked
        synchronously from this coroutine, so nothing else on the event loop
        (including an abort request) can run until it returns.
        """
        if not self.conversation_history.get(user_id):
            # Seed a fresh (or emptied) history with the system prompt, if any.
            self.conversation_history[user_id] = []
            system_prompt = getattr(self, "system_prompt", None)
            if system_prompt:
                self.conversation_history[user_id].append({"role": "system", "content": system_prompt})

        messages = self.conversation_history[user_id]
        messages.append({"role": "user", "content": user_message})

        assistant = self._request_assistant_message(messages)
        if assistant is None:
            logging.error("No valid response choice message from LLM.")
            return "Error: Could not get a valid response from the LLM."

        tool_use_count = 0
        while assistant.tool_calls and tool_use_count < self.MAX_TOOL_ITERATIONS:
            messages.extend(self._execute_tool_calls(assistant.tool_calls))
            assistant = self._request_assistant_message(messages)
            if assistant is None:
                logging.error("No valid response choice message from LLM after tool call.")
                return "Error: Could not get a valid response from the LLM after tool call."
            tool_use_count += 1

        if assistant.tool_calls:
            logging.warning(f"Max tool iterations ({self.MAX_TOOL_ITERATIONS}) reached. Returning last assistant message.")
            # Answer the pending calls with a placeholder: an assistant message
            # carrying tool_calls without matching tool results would make the
            # next request for this user invalid.
            messages.extend(
                {
                    "role": "tool",
                    "tool_call_id": pending.id,
                    "name": pending.function.name,
                    "content": "Tool call skipped: maximum tool iterations reached.",
                }
                for pending in assistant.tool_calls
            )

        self._trim_history(user_id)

        if self._message_role(assistant) == "assistant" and assistant.content:
            return assistant.content
        return "No content in final message."

    async def start(self):
        """Log that this bot class has started."""
        logging.info(f"{self.__class__.__name__} started.")

    async def clear(self, user_id):
        """Clear the stored conversation for ``user_id`` via the base class."""
        import inspect

        outcome = super().clear_conversation_history(user_id)
        # The base helper may be a coroutine function; if so it must be
        # awaited, otherwise the clear would silently never run.
        if inspect.isawaitable(outcome):
            await outcome

    async def abort_processing(self, user_id):
        """Mark the user's request as stopped, clear their conversation, and report."""
        if user_id in self.processing_status:
            self.processing_status[user_id]["processing"] = False
            await self.clear(user_id)
            return "Processing aborted and conversation cleared."
        else:
            await self.clear(user_id)
            return "No active processing found to abort. Conversation cleared."

    # NOTE(review): @abstractmethod is only enforced when the class hierarchy
    # uses ABCMeta -- confirm that BaseTelegramInferenceBot derives from ABC.
    @abstractmethod
    async def switch_model(self):
        """Switch the active model; must be implemented by subclasses."""
        pass
+21 -21
View File
@@ -3,16 +3,21 @@ import logging
import sys import sys
import asyncio import asyncio
import time import time
import git
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes, CallbackQueryHandler from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes, CallbackQueryHandler
from browse_command import browse_command, button_callback from browse_command import browse_command, button_callback
class TelegramHelper: class TelegramHelper:
# --- Constants for configurable paths and magic strings ---
REBOOT_CLAUDE_FILE = '.reboot_claude'
REBOOT_FILE = '.doreboot'
CLAUDE_REBOOT_TARGET = 'claude'
HTML_QUOTE_BLOCK_START = '<blockquote expandable><b>Thinking...</b>'
HTML_QUOTE_BLOCK_END = '</blockquote>'
def __init__(self, bot):
    """Bind the helper to ``bot`` and capture startup state.

    Args:
        bot: The inference bot whose methods the Telegram handlers call.
    """
    self.bot = bot
    # Token is read from the environment; it is None when the variable is unset.
    self.telegram_bot_token = os.getenv('TELEGRAM_BOT_TOKEN')
    # No git repository handle is created here: the `git` import is removed
    # from this file, so `git.Repo(".")` would no longer resolve.
    self.start_time = time.time()
async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
@@ -23,11 +28,11 @@ class TelegramHelper:
async def clear(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Handle the clear command: reset the sender's conversation and confirm."""
    user_id = update.effective_user.id
    await self.bot.clear_conversation_history(user_id)
    await update.message.reply_text("Conversation history cleared. Let's start fresh!")
async def status(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Handle the status command: reply with the status text reported by the bot."""
    status_message = await self.bot.get_bot_status()
    await update.message.reply_text(status_message)
async def switch(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def switch(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
@@ -56,21 +61,20 @@ class TelegramHelper:
logging.info(f"Message from user {user_id}: {user_message}") logging.info(f"Message from user {user_id}: {user_message}")
status_message = await update.message.reply_text("Processing your request...", reply_markup=InlineKeyboardMarkup([[InlineKeyboardButton("Abort", callback_data='abort')]])) status_message = await update.message.reply_text("Processing your request...", reply_markup=InlineKeyboardMarkup([[InlineKeyboardButton("Abort", callback_data='abort')]]))\
self.bot.processing_status[user_id] = {"processing": True, "message_id": status_message.message_id} await self.bot.set_processing_status(user_id, status_message.message_id)
response = await self.bot.handle_message(user_id, user_message) response = await self.bot.handle_message(user_id, user_message)
await context.bot.delete_message(chat_id=update.effective_chat.id, message_id=status_message.message_id) await context.bot.delete_message(chat_id=update.effective_chat.id, message_id=status_message.message_id)
del self.bot.processing_status[user_id] await self.bot.clear_processing_status(user_id)
response = response.replace("<think>", "<blockquote expandable><b>Thinking...</b>").replace("</think>", "</blockquote>")
# Return response as html message response = response.replace("<think>", self.HTML_QUOTE_BLOCK_START).replace("</think>", self.HTML_QUOTE_BLOCK_END)
if len(response) > 4096: if len(response) > 4096:
# If the response is too long, split it into chunks
chunks = [response[i:i + 4096] for i in range(0, len(response), 4096)] chunks = [response[i:i + 4096] for i in range(0, len(response), 4096)]
for chunk in chunks: for chunk in chunks:
await update.message.reply_text(chunk) await update.message.reply_text(chunk)
# Add a small delay to avoid flooding
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
else: else:
await update.message.reply_text(response) await update.message.reply_text(response)
@@ -88,21 +92,21 @@ class TelegramHelper:
await query.edit_message_text(text=result) await query.edit_message_text(text=result)
async def reboot(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Handle the reboot command: drop marker files and exit the process.

    If the command's first argument equals ``CLAUDE_REBOOT_TARGET``
    (case-insensitive), an empty ``REBOOT_CLAUDE_FILE`` is created. The
    requesting chat id is written to ``REBOOT_FILE`` unless that file already
    exists, then the process exits with status 0.

    Raises:
        SystemExit: Always, via ``sys.exit(0)``.
    """
    # Resolve the message once so a missing update/message is handled the
    # same way everywhere, instead of being dereferenced before it is checked.
    message = update.message if update else None
    words = message.text.split() if message and message.text else []

    if len(words) > 1 and words[1].lower() == self.CLAUDE_REBOOT_TARGET:
        # Create (or truncate to) an empty marker file.
        with open(self.REBOOT_CLAUDE_FILE, 'w'):
            pass

    if message:
        await message.reply_text("Rebooting the bot...")
    logging.info("Received reboot command. Exiting process...")

    reboot_file_path = self.REBOOT_FILE
    if not os.path.exists(reboot_file_path):
        chat = update.effective_chat if update else None
        with open(reboot_file_path, 'w') as f:
            f.write(str(chat.id) if chat else "")
    sys.exit(0)
async def check_doreboot_file(self, application: Application): async def check_doreboot_file(self, application: Application):
reboot_file_path = "./.doreboot" reboot_file_path = self.REBOOT_FILE
if os.path.exists(reboot_file_path): if os.path.exists(reboot_file_path):
with open(reboot_file_path, 'r') as f: with open(reboot_file_path, 'r') as f:
chat_id = f.read().strip() chat_id = f.read().strip()
@@ -122,16 +126,12 @@ class TelegramHelper:
application.add_handler(CommandHandler("status", self.status)) application.add_handler(CommandHandler("status", self.status))
application.add_handler(CommandHandler("reboot", self.reboot)) application.add_handler(CommandHandler("reboot", self.reboot))
application.add_handler(CommandHandler("browse", self.browse)) application.add_handler(CommandHandler("browse", self.browse))
application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, self.handle_message)) application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, self.handle_message))\
application.add_handler(CallbackQueryHandler(self.abort_processing, pattern='^abort$')) application.add_handler(CallbackQueryHandler(self.abort_processing, pattern='^abort$'))
application.add_handler(CallbackQueryHandler(button_callback, pattern='^(browse|file):')) application.add_handler(CallbackQueryHandler(button_callback, pattern='^(browse|file):'))
logging.info("Bot is running...") logging.info("Bot is running...")
# Check for .doreboot file and send message if it exists
asyncio.get_event_loop().create_task(self.check_doreboot_file(application)) asyncio.get_event_loop().create_task(self.check_doreboot_file(application))
# Commenting out the commit checking task
# asyncio.get_event_loop().create_task(self.check_for_new_commits())
application.run_polling() application.run_polling()