added gemini inference bot

This commit is contained in:
2025-06-02 13:23:02 -05:00
parent 179718595b
commit a24f56531e
5 changed files with 247 additions and 54 deletions
+67 -30
View File
@@ -1,30 +1,51 @@
import json import json
import os import os
import logging import logging
from base_telegram_inference_bot import BaseTelegramInferenceBot from base_telegram_inference_bot import BaseTelegramInferenceBot # Assuming this base class exists
from telegram_helper import TelegramHelper from telegram_helper import TelegramHelper # Assuming this helper class exists
from openai import OpenAI from openai import OpenAI
# Ensure basic logging is configured if not done elsewhere
# logging.basicConfig(level=logging.INFO) # Example: You might have a more sophisticated setup
class ChatGPTTelegramInferenceBot(BaseTelegramInferenceBot): class ChatGPTTelegramInferenceBot(BaseTelegramInferenceBot):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"), base_url="http://localhost:1234/v1") self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
self.model = "qwen3-1.7b"
self.max_tokens = 32768 self._configure_model_and_tokens(
os.environ.get("OPENAI_SMALL_MODEL"), # Default model
os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS") # Default tokens
)
def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=1000):
self.model = model_name
try:
self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens
except ValueError:
logging.error(f"Invalid value for max_tokens: {max_tokens_str}. Using default {default_max_tokens}.")
self.max_tokens = default_max_tokens
logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens}")
def get_chat_response(self, messages): def get_chat_response(self, messages):
response = self.client.chat.completions.create( try:
model=self.model, response = self.client.chat.completions.create(
messages=[{"role": "system", "content": self.system_prompt}] + messages, model=self.model,
tools=self.functions, messages=messages, # The system prompt is expected to be part of messages here
tool_choice = "auto", tools=self.functions if hasattr(self, 'functions') and self.functions else None,
max_tokens=self.max_tokens tool_choice="auto" if hasattr(self, 'functions') and self.functions else None,
) max_tokens=self.max_tokens
return response )
return response
except Exception as e:
logging.error(f"OpenAI API call failed: {e}")
raise
async def handle_message(self, user_id, user_message): async def handle_message(self, user_id, user_message):
if user_id not in self.conversation_history: if user_id not in self.conversation_history:
self.conversation_history[user_id] = [] self.conversation_history[user_id] = []
if hasattr(self, 'system_prompt') and self.system_prompt:
self.conversation_history[user_id].append({"role": "system", "content": self.system_prompt})
self.conversation_history[user_id].append({"role": "user", "content": user_message}) self.conversation_history[user_id].append({"role": "user", "content": user_message})
messages = self.conversation_history[user_id] messages = self.conversation_history[user_id]
@@ -44,10 +65,12 @@ class ChatGPTTelegramInferenceBot(BaseTelegramInferenceBot):
tool_use_results = [] tool_use_results = []
while len(tool_calls) > 0: while len(tool_calls) > 0:
tool_call = tool_calls.pop(0).function tool_call_message = tool_calls.pop(0)
tool_call_id = tool_call_message.id
tool_call = tool_call_message.function
tool_response = self.call_tool(tool_call.name, tool_call.arguments) tool_response = self.call_tool(tool_call.name, tool_call.arguments)
try: try:
tool_use_results.append({"role": "tool", "name": tool_call.name, "content": tool_response}) tool_use_results.append({"role": "tool", "tool_call_id": tool_call_id, "name":tool_call.name, "content": str(tool_response) })
except (TypeError, ValueError) as e: except (TypeError, ValueError) as e:
logging.error(f"Failed to serialize tool response: {e}") logging.error(f"Failed to serialize tool response: {e}")
tool_use_results.append({"role": "function", "name": tool_call.name, "content": "Serialization error"}) tool_use_results.append({"role": "function", "name": tool_call.name, "content": "Serialization error"})
@@ -57,8 +80,8 @@ class ChatGPTTelegramInferenceBot(BaseTelegramInferenceBot):
response = self.get_chat_response(messages) response = self.get_chat_response(messages)
for message_part in response.choices: for message_part in response.choices:
if message_part.finish_reason == "function_call": if message_part.finish_reason == "tool_calls":
tool_calls.append(message_part.message.function_call) tool_calls.extend(message_part.message.tool_calls)
messages.append(response.choices[0].message) messages.append(response.choices[0].message)
@@ -71,36 +94,50 @@ class ChatGPTTelegramInferenceBot(BaseTelegramInferenceBot):
async def start(self): async def start(self):
logging.info("Bot started") logging.info("Bot started")
# Potentially call super().start() if it exists and does something
async def clear(self, user_id): async def clear(self, user_id):
super().clear_conversation(user_id) super().clear_conversation(user_id)
logging.info(f"Cleared conversation history for user {user_id}")
async def status(self): async def status(self):
return f"Currently using: {self.model}" return f"Currently using: {self.model}, Max Tokens: {self.max_tokens}"
async def abort_processing(self, user_id): async def abort_processing(self, user_id):
if user_id in self.processing_status: # This depends on how processing_status is managed, likely in BaseTelegramInferenceBot
self.processing_status[user_id]["processing"] = False if hasattr(self, 'processing_status') and user_id in self.processing_status:
await self.clear(user_id) self.processing_status[user_id]["processing"] = False # Example
return "Processing aborted." await self.clear(user_id) # Clearing conversation on abort might be desired
return "Processing aborted and conversation cleared."
else: else:
return "No active processing to abort." # If not tracking processing_status here, just clear for safety
await self.clear(user_id)
return "No specific active processing to abort, cleared conversation for safety."
async def switch_model(self): async def switch_model(self):
if self.model == "qwen3-4b": current_small_model = os.environ.get("OPENAI_SMALL_MODEL")
self.model = "qwen3-30b-a3b" current_large_model = os.environ.get("OPENAI_LARGE_MODEL")
# self.max_tokens = 4096
if self.model == current_small_model:
target_model = current_large_model
target_max_tokens = os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS")
else: else:
self.model = "qwen3-4b" target_model = current_small_model
# self.max_tokens = 16384 target_max_tokens = os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS")
logging.info(f"Switched to model: {self.model}")
self._configure_model_and_tokens(target_model, target_max_tokens)
return f"Switched to model: {self.model}" return f"Switched to model: {self.model}"
def main(): def main():
# Ensure OPENAI_API_KEY and other environment variables are set
if not os.environ.get("OPENAI_API_KEY"):
logging.error("FATAL: OPENAI_API_KEY environment variable not set.")
return
bot = ChatGPTTelegramInferenceBot() bot = ChatGPTTelegramInferenceBot()
telegram_helper = TelegramHelper(bot) telegram_helper = TelegramHelper(bot)
telegram_helper.run() telegram_helper.run()
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
main() main()
+143
View File
@@ -0,0 +1,143 @@
import json
import os
import logging
from base_telegram_inference_bot import BaseTelegramInferenceBot # Assuming this base class exists
from telegram_helper import TelegramHelper # Assuming this helper class exists
from openai import OpenAI
# Ensure basic logging is configured if not done elsewhere
# logging.basicConfig(level=logging.INFO) # Example: You might have a more sophisticated setup
class GeminiTelegramInferenceBot(BaseTelegramInferenceBot):
def __init__(self):
super().__init__()
self.client = OpenAI(api_key=os.environ.get("GEMINI_API_KEY"), base_url=os.environ.get("GEMINI_API_BASE_URL"))
self._configure_model_and_tokens(
os.environ.get("GEMINI_SMALL_MODEL"), # Default model
os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS") # Default tokens
)
def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=1000):
self.model = model_name
try:
self.max_tokens = int(max_tokens_str) if max_tokens_str is not None else default_max_tokens
except ValueError:
logging.error(f"Invalid value for max_tokens: {max_tokens_str}. Using default {default_max_tokens}.")
self.max_tokens = default_max_tokens
logging.info(f"Configured to use model: {self.model} with max_tokens: {self.max_tokens}")
def get_chat_response(self, messages):
try:
response = self.client.chat.completions.create(
model=self.model,
messages=messages, # The system prompt is expected to be part of messages here
tools=self.functions if hasattr(self, 'functions') and self.functions else None,
tool_choice="auto" if hasattr(self, 'functions') and self.functions else None,
max_tokens=self.max_tokens
)
return response
except Exception as e:
logging.error(f"Gemini API call failed: {e}")
raise
async def handle_message(self, user_id, user_message):
if user_id not in self.conversation_history:
self.conversation_history[user_id] = []
if hasattr(self, 'system_prompt') and self.system_prompt:
self.conversation_history[user_id].append({"role": "system", "content": self.system_prompt})
self.conversation_history[user_id].append({"role": "user", "content": user_message})
messages = self.conversation_history[user_id]
response = self.get_chat_response(messages)
tool_calls = []
for message_part in response.choices:
if message_part.finish_reason == "tool_calls":
tool_calls.extend(message_part.message.tool_calls)
messages.append(response.choices[0].message)
tool_use_count = 0
while len(tool_calls) > 0 and tool_use_count < 500:
tool_use_results = []
while len(tool_calls) > 0:
tool_call_message = tool_calls.pop(0)
tool_call_id = tool_call_message.id
tool_call = tool_call_message.function
tool_response = self.call_tool(tool_call.name, tool_call.arguments)
try:
tool_use_results.append({"role": "tool", "tool_call_id": tool_call_id, "name":tool_call.name, "content": str(tool_response) })
except (TypeError, ValueError) as e:
logging.error(f"Failed to serialize tool response: {e}")
tool_use_results.append({"role": "function", "name": tool_call.name, "content": "Serialization error"})
messages.extend(tool_use_results)
response = self.get_chat_response(messages)
for message_part in response.choices:
if message_part.finish_reason == "tool_calls":
tool_calls.extend(message_part.message.tool_calls)
messages.append(response.choices[0].message)
tool_use_count += 1
if len(self.conversation_history[user_id]) > 2000:
self.conversation_history[user_id] = self.conversation_history[user_id][-2000:]
return messages[-1].content
async def start(self):
logging.info("Bot started")
# Potentially call super().start() if it exists and does something
async def clear(self, user_id):
super().clear_conversation(user_id)
async def status(self):
return f"Currently using: {self.model}, Max Tokens: {self.max_tokens}"
async def abort_processing(self, user_id):
# This depends on how processing_status is managed, likely in BaseTelegramInferenceBot
if hasattr(self, 'processing_status') and user_id in self.processing_status:
self.processing_status[user_id]["processing"] = False # Example
await self.clear(user_id) # Clearing conversation on abort might be desired
return "Processing aborted and conversation cleared."
else:
# If not tracking processing_status here, just clear for safety
await self.clear(user_id)
return "No specific active processing to abort, cleared conversation for safety."
async def switch_model(self):
current_small_model = os.environ.get("GEMINI_SMALL_MODEL")
current_large_model = os.environ.get("GEMINI_LARGE_MODEL")
if self.model == current_small_model:
target_model = current_large_model
target_max_tokens = os.environ.get("GEMINI_LARGE_MODEL_MAX_TOKENS")
else:
target_model = current_small_model
target_max_tokens = os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS")
self._configure_model_and_tokens(target_model, target_max_tokens)
return f"Switched to model: {self.model}"
def main():
# Ensure GEMINI_API_KEY and other environment variables are set
if not os.environ.get("GEMINI_API_KEY"):
logging.error("FATAL: GEMINI_API_KEY environment variable not set.")
return
bot = GeminiTelegramInferenceBot()
telegram_helper = TelegramHelper(bot)
telegram_helper.run()
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
main()
+3 -1
View File
@@ -1,4 +1,4 @@
requests==2.26.0 requests
python-telegram-bot==21.4 python-telegram-bot==21.4
openai==1.41.0 openai==1.41.0
python-dotenv==1.0.1 python-dotenv==1.0.1
@@ -7,3 +7,5 @@ anthropic==0.34.0
GitPython==3.1.43 GitPython==3.1.43
pytest pytest
pytest-cov pytest-cov
google-genai
httpx==0.27.2
+9 -1
View File
@@ -65,7 +65,15 @@ class TelegramHelper:
del self.bot.processing_status[user_id] del self.bot.processing_status[user_id]
response = response.replace("<think>", "<blockquote expandable><b>Thinking...</b>").replace("</think>", "</blockquote>") response = response.replace("<think>", "<blockquote expandable><b>Thinking...</b>").replace("</think>", "</blockquote>")
# Return response as html message # Return response as html message
await update.message.reply_html(response) if len(response) > 4096:
# If the response is too long, split it into chunks
chunks = [response[i:i + 4096] for i in range(0, len(response), 4096)]
for chunk in chunks:
await update.message.reply_text(chunk)
# Add a small delay to avoid flooding
await asyncio.sleep(0.1)
else:
await update.message.reply_text(response)
except Exception as e: except Exception as e:
logging.error(f"An error occurred: {str(e)}") logging.error(f"An error occurred: {str(e)}")
+23 -20
View File
@@ -14,27 +14,30 @@ class StandaloneLLMTool(BaseTool):
def get_functions(self): def get_functions(self):
return [ return [
{ {
"name": "call_external_llm", "type": "function",
"description": "Call an external language model", "function": {
"parameters": { "name": "call_external_llm",
"type": "object", "description": "Call an external language model",
"properties": { "parameters": {
"prompt": { "type": "object",
"type": "string", "properties": {
"description": "The prompt you are providing" "prompt": {
"type": "string",
"description": "The prompt you are providing"
},
"model": {
"type": "string",
"description": "The model to use for generating the detailed instructions. Use mini for most coding tasks, preview when needing sophisticated reasoning",
"enum": ["o1-mini", "o1-preview"],
"default": "o1-mini"
},
"max_tokens": {
"type": "integer",
"description": "The maximum number of tokens to use for generating the detailed instructions. Default is 16384.",
}
}, },
"model": { "required": ["prompt"]
"type": "string", }
"description": "The model to use for generating the detailed instructions. Use mini for most coding tasks, preview when needing sophisticated reasoning",
"enum": ["o1-mini", "o1-preview"],
"default": "o1-mini"
},
"max_tokens": {
"type": "integer",
"description": "The maximum number of tokens to use for generating the detailed instructions. Default is 16384.",
}
},
"required": ["prompt"]
} }
} }
] ]