Added Gemini inference bot
This commit is contained in:
@@ -1,30 +1,51 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
from base_telegram_inference_bot import BaseTelegramInferenceBot
|
from base_telegram_inference_bot import BaseTelegramInferenceBot # Assuming this base class exists
|
||||||
from telegram_helper import TelegramHelper
|
from telegram_helper import TelegramHelper # Assuming this helper class exists
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
|
# Ensure basic logging is configured if not done elsewhere
|
||||||
|
# logging.basicConfig(level=logging.INFO) # Example: You might have a more sophisticated setup
|
||||||
|
|
||||||
class ChatGPTTelegramInferenceBot(BaseTelegramInferenceBot):
|
class ChatGPTTelegramInferenceBot(BaseTelegramInferenceBot):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"), base_url="http://localhost:1234/v1")
|
self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
||||||
self.model = "qwen3-1.7b"
|
|
||||||
self.max_tokens = 32768
|
self._configure_model_and_tokens(
|
||||||
|
os.environ.get("OPENAI_SMALL_MODEL"), # Default model
|
||||||
|
os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS") # Default tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=1000):
    """Set ``self.model`` and ``self.max_tokens`` from raw configuration values.

    Args:
        model_name: Model identifier; stored unchanged on ``self.model``.
        max_tokens_str: Token limit, normally the raw environment-variable
            string. ``None`` selects ``default_max_tokens`` silently.
        default_max_tokens: Fallback used when the limit is missing, cannot be
            parsed as an integer, or is not a positive number.
    """
    self.model = model_name

    max_tokens = default_max_tokens
    if max_tokens_str is not None:
        try:
            parsed = int(max_tokens_str)
        except (TypeError, ValueError):
            # TypeError covers non-string, non-numeric values (e.g. a list).
            parsed = None
        if parsed is not None and parsed > 0:
            max_tokens = parsed
        else:
            # A zero or negative limit is as unusable as an unparsable one.
            logging.error(
                "Invalid value for max_tokens: %s. Using default %s.",
                max_tokens_str,
                default_max_tokens,
            )
    self.max_tokens = max_tokens

    logging.info("Configured to use model: %s with max_tokens: %s", self.model, self.max_tokens)
|
||||||
|
|
||||||
def get_chat_response(self, messages):
|
def get_chat_response(self, messages):
|
||||||
response = self.client.chat.completions.create(
|
try:
|
||||||
model=self.model,
|
response = self.client.chat.completions.create(
|
||||||
messages=[{"role": "system", "content": self.system_prompt}] + messages,
|
model=self.model,
|
||||||
tools=self.functions,
|
messages=messages, # The system prompt is expected to be part of messages here
|
||||||
tool_choice = "auto",
|
tools=self.functions if hasattr(self, 'functions') and self.functions else None,
|
||||||
max_tokens=self.max_tokens
|
tool_choice="auto" if hasattr(self, 'functions') and self.functions else None,
|
||||||
)
|
max_tokens=self.max_tokens
|
||||||
return response
|
)
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"OpenAI API call failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
async def handle_message(self, user_id, user_message):
|
async def handle_message(self, user_id, user_message):
|
||||||
if user_id not in self.conversation_history:
|
if user_id not in self.conversation_history:
|
||||||
self.conversation_history[user_id] = []
|
self.conversation_history[user_id] = []
|
||||||
|
if hasattr(self, 'system_prompt') and self.system_prompt:
|
||||||
|
self.conversation_history[user_id].append({"role": "system", "content": self.system_prompt})
|
||||||
|
|
||||||
self.conversation_history[user_id].append({"role": "user", "content": user_message})
|
self.conversation_history[user_id].append({"role": "user", "content": user_message})
|
||||||
messages = self.conversation_history[user_id]
|
messages = self.conversation_history[user_id]
|
||||||
@@ -44,10 +65,12 @@ class ChatGPTTelegramInferenceBot(BaseTelegramInferenceBot):
|
|||||||
tool_use_results = []
|
tool_use_results = []
|
||||||
|
|
||||||
while len(tool_calls) > 0:
|
while len(tool_calls) > 0:
|
||||||
tool_call = tool_calls.pop(0).function
|
tool_call_message = tool_calls.pop(0)
|
||||||
|
tool_call_id = tool_call_message.id
|
||||||
|
tool_call = tool_call_message.function
|
||||||
tool_response = self.call_tool(tool_call.name, tool_call.arguments)
|
tool_response = self.call_tool(tool_call.name, tool_call.arguments)
|
||||||
try:
|
try:
|
||||||
tool_use_results.append({"role": "tool", "name": tool_call.name, "content": tool_response})
|
tool_use_results.append({"role": "tool", "tool_call_id": tool_call_id, "name":tool_call.name, "content": str(tool_response) })
|
||||||
except (TypeError, ValueError) as e:
|
except (TypeError, ValueError) as e:
|
||||||
logging.error(f"Failed to serialize tool response: {e}")
|
logging.error(f"Failed to serialize tool response: {e}")
|
||||||
tool_use_results.append({"role": "function", "name": tool_call.name, "content": "Serialization error"})
|
tool_use_results.append({"role": "function", "name": tool_call.name, "content": "Serialization error"})
|
||||||
@@ -57,8 +80,8 @@ class ChatGPTTelegramInferenceBot(BaseTelegramInferenceBot):
|
|||||||
response = self.get_chat_response(messages)
|
response = self.get_chat_response(messages)
|
||||||
|
|
||||||
for message_part in response.choices:
|
for message_part in response.choices:
|
||||||
if message_part.finish_reason == "function_call":
|
if message_part.finish_reason == "tool_calls":
|
||||||
tool_calls.append(message_part.message.function_call)
|
tool_calls.extend(message_part.message.tool_calls)
|
||||||
|
|
||||||
messages.append(response.choices[0].message)
|
messages.append(response.choices[0].message)
|
||||||
|
|
||||||
@@ -71,36 +94,50 @@ class ChatGPTTelegramInferenceBot(BaseTelegramInferenceBot):
|
|||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
logging.info("Bot started")
|
logging.info("Bot started")
|
||||||
|
# Potentially call super().start() if it exists and does something
|
||||||
|
|
||||||
async def clear(self, user_id):
|
async def clear(self, user_id):
|
||||||
super().clear_conversation(user_id)
|
super().clear_conversation(user_id)
|
||||||
logging.info(f"Cleared conversation history for user {user_id}")
|
|
||||||
|
|
||||||
async def status(self):
|
async def status(self):
|
||||||
return f"Currently using: {self.model}"
|
return f"Currently using: {self.model}, Max Tokens: {self.max_tokens}"
|
||||||
|
|
||||||
async def abort_processing(self, user_id):
|
async def abort_processing(self, user_id):
|
||||||
if user_id in self.processing_status:
|
# This depends on how processing_status is managed, likely in BaseTelegramInferenceBot
|
||||||
self.processing_status[user_id]["processing"] = False
|
if hasattr(self, 'processing_status') and user_id in self.processing_status:
|
||||||
await self.clear(user_id)
|
self.processing_status[user_id]["processing"] = False # Example
|
||||||
return "Processing aborted."
|
await self.clear(user_id) # Clearing conversation on abort might be desired
|
||||||
|
return "Processing aborted and conversation cleared."
|
||||||
else:
|
else:
|
||||||
return "No active processing to abort."
|
# If not tracking processing_status here, just clear for safety
|
||||||
|
await self.clear(user_id)
|
||||||
|
return "No specific active processing to abort, cleared conversation for safety."
|
||||||
|
|
||||||
async def switch_model(self):
|
async def switch_model(self):
|
||||||
if self.model == "qwen3-4b":
|
current_small_model = os.environ.get("OPENAI_SMALL_MODEL")
|
||||||
self.model = "qwen3-30b-a3b"
|
current_large_model = os.environ.get("OPENAI_LARGE_MODEL")
|
||||||
# self.max_tokens = 4096
|
|
||||||
|
if self.model == current_small_model:
|
||||||
|
target_model = current_large_model
|
||||||
|
target_max_tokens = os.environ.get("OPENAI_LARGE_MODEL_MAX_TOKENS")
|
||||||
else:
|
else:
|
||||||
self.model = "qwen3-4b"
|
target_model = current_small_model
|
||||||
# self.max_tokens = 16384
|
target_max_tokens = os.environ.get("OPENAI_SMALL_MODEL_MAX_TOKENS")
|
||||||
logging.info(f"Switched to model: {self.model}")
|
|
||||||
|
self._configure_model_and_tokens(target_model, target_max_tokens)
|
||||||
return f"Switched to model: {self.model}"
|
return f"Switched to model: {self.model}"
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
# Ensure OPENAI_API_KEY and other environment variables are set
|
||||||
|
if not os.environ.get("OPENAI_API_KEY"):
|
||||||
|
logging.error("FATAL: OPENAI_API_KEY environment variable not set.")
|
||||||
|
return
|
||||||
|
|
||||||
bot = ChatGPTTelegramInferenceBot()
|
bot = ChatGPTTelegramInferenceBot()
|
||||||
telegram_helper = TelegramHelper(bot)
|
telegram_helper = TelegramHelper(bot)
|
||||||
telegram_helper.run()
|
telegram_helper.run()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||||
main()
|
main()
|
||||||
@@ -0,0 +1,143 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from base_telegram_inference_bot import BaseTelegramInferenceBot # Assuming this base class exists
|
||||||
|
from telegram_helper import TelegramHelper # Assuming this helper class exists
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
# Ensure basic logging is configured if not done elsewhere
|
||||||
|
# logging.basicConfig(level=logging.INFO) # Example: You might have a more sophisticated setup
|
||||||
|
|
||||||
|
class GeminiTelegramInferenceBot(BaseTelegramInferenceBot):
    """Telegram inference bot that reaches Gemini through the ``openai`` client.

    The client is created with ``GEMINI_API_KEY`` and ``GEMINI_API_BASE_URL``.
    The active model and its token limit are read from the
    ``GEMINI_SMALL_MODEL*`` / ``GEMINI_LARGE_MODEL*`` environment variables.

    NOTE(review): ``conversation_history``, ``system_prompt``, ``functions``,
    ``processing_status``, ``call_tool`` and ``clear_conversation`` are used but
    not defined in this class; presumably ``BaseTelegramInferenceBot`` provides
    them -- confirm against the base class.
    """

    # Upper bound on model -> tools -> model round trips for one user message.
    MAX_TOOL_ROUNDS = 500
    # Maximum number of history entries kept per user.
    MAX_HISTORY_MESSAGES = 2000

    def __init__(self):
        """Create the API client and select the small model as the initial model."""
        super().__init__()
        self.client = OpenAI(
            api_key=os.environ.get("GEMINI_API_KEY"),
            base_url=os.environ.get("GEMINI_API_BASE_URL"),
        )
        self._configure_model_and_tokens(
            os.environ.get("GEMINI_SMALL_MODEL"),
            os.environ.get("GEMINI_SMALL_MODEL_MAX_TOKENS"),
        )

    def _configure_model_and_tokens(self, model_name, max_tokens_str, default_max_tokens=1000):
        """Set ``self.model`` and ``self.max_tokens`` from raw configuration values.

        Args:
            model_name: Model identifier; stored unchanged on ``self.model``.
            max_tokens_str: Token limit, normally the raw environment-variable
                string. ``None`` selects ``default_max_tokens`` silently.
            default_max_tokens: Fallback used when the limit is missing, cannot
                be parsed as an integer, or is not a positive number.
        """
        self.model = model_name

        max_tokens = default_max_tokens
        if max_tokens_str is not None:
            try:
                parsed = int(max_tokens_str)
            except (TypeError, ValueError):
                # TypeError covers non-string, non-numeric values (e.g. a list).
                parsed = None
            if parsed is not None and parsed > 0:
                max_tokens = parsed
            else:
                # A zero or negative limit is as unusable as an unparsable one.
                logging.error(
                    "Invalid value for max_tokens: %s. Using default %s.",
                    max_tokens_str,
                    default_max_tokens,
                )
        self.max_tokens = max_tokens

        logging.info("Configured to use model: %s with max_tokens: %s", self.model, self.max_tokens)

    def get_chat_response(self, messages):
        """Request a chat completion for ``messages`` and return the raw response.

        Args:
            messages: Full message list to send. No system prompt is added
                here, so the caller must already have included it.

        Returns:
            The object returned by ``self.client.chat.completions.create``.

        Raises:
            Exception: Any client error is logged and re-raised unchanged.
        """
        request = {
            "model": self.model,
            "messages": messages,
            "max_tokens": self.max_tokens,
        }
        # Only mention tools when some are registered: the two fields are left
        # out of the request entirely instead of being sent as explicit nulls.
        functions = getattr(self, "functions", None)
        if functions:
            request["tools"] = functions
            request["tool_choice"] = "auto"

        try:
            return self.client.chat.completions.create(**request)
        except Exception as e:
            logging.error(f"Gemini API call failed: {e}")
            raise

    @staticmethod
    def _collect_tool_calls(response):
        """Return the tool calls of every choice that finished with ``tool_calls``."""
        calls = []
        for choice in response.choices:
            if choice.finish_reason == "tool_calls":
                # Tolerate a choice that reports tool_calls but carries none.
                calls.extend(choice.message.tool_calls or [])
        return calls

    def _run_tool_call(self, tool_call_message):
        """Execute one requested tool call and build the ``tool`` message answering it."""
        function = tool_call_message.function
        tool_response = self.call_tool(function.name, function.arguments)
        try:
            content = str(tool_response)
        except (TypeError, ValueError) as e:
            logging.error(f"Failed to serialize tool response: {e}")
            content = "Serialization error"
        # The fallback keeps the same shape as a normal result: the reply must
        # still reference the originating call through ``tool_call_id``.
        return {
            "role": "tool",
            "tool_call_id": tool_call_message.id,
            "name": function.name,
            "content": content,
        }

    def _trim_history(self, user_id):
        """Cap the stored history for ``user_id`` at ``MAX_HISTORY_MESSAGES`` entries."""
        history = self.conversation_history[user_id]
        limit = self.MAX_HISTORY_MESSAGES
        if len(history) <= limit:
            return

        first = history[0]
        if isinstance(first, dict) and first.get("role") == "system":
            # The system prompt is only inserted when the history is created,
            # so it has to survive trimming or later requests go out without it.
            self.conversation_history[user_id] = [first] + history[len(history) - limit + 1:]
        else:
            self.conversation_history[user_id] = history[-limit:]
        # NOTE(review): the cut is by entry count only, so the kept window can
        # begin with a tool result whose requesting assistant message was dropped.

    async def handle_message(self, user_id, user_message):
        """Process one user message, running requested tools until the model stops asking.

        Args:
            user_id: Key of the conversation in ``self.conversation_history``.
            user_message: Text sent by the user.

        Returns:
            The ``content`` of the last assistant message appended to the history.
        """
        if user_id not in self.conversation_history:
            self.conversation_history[user_id] = []
            system_prompt = getattr(self, "system_prompt", None)
            if system_prompt:
                self.conversation_history[user_id].append({"role": "system", "content": system_prompt})

        messages = self.conversation_history[user_id]
        messages.append({"role": "user", "content": user_message})

        # NOTE(review): get_chat_response and call_tool are called without
        # ``await`` inside this coroutine; confirm blocking here is acceptable.
        response = self.get_chat_response(messages)
        pending_calls = self._collect_tool_calls(response)
        messages.append(response.choices[0].message)

        rounds = 0
        while pending_calls and rounds < self.MAX_TOOL_ROUNDS:
            # Build every result first so a failing tool leaves the history untouched.
            tool_results = [self._run_tool_call(call) for call in pending_calls]
            messages.extend(tool_results)

            response = self.get_chat_response(messages)
            pending_calls = self._collect_tool_calls(response)
            messages.append(response.choices[0].message)
            rounds += 1

        if pending_calls:
            logging.warning(
                "Stopped after %d tool rounds with tool calls still pending for user %s",
                rounds,
                user_id,
            )

        reply = messages[-1].content
        self._trim_history(user_id)
        return reply

    async def start(self):
        """Log that the bot has started."""
        logging.info("Bot started")
        # NOTE(review): nothing from the base class is invoked here; confirm
        # whether a base ``start`` exists that should be called.

    async def clear(self, user_id):
        """Clear the stored conversation for ``user_id`` via the base class."""
        super().clear_conversation(user_id)

    async def status(self):
        """Return a one-line description of the active model and token limit."""
        return f"Currently using: {self.model}, Max Tokens: {self.max_tokens}"

    async def abort_processing(self, user_id):
        """Flag processing for ``user_id`` as stopped and clear the conversation.

        The conversation is cleared in both cases; only the returned text
        differs depending on whether a processing entry existed for the user.
        """
        processing_status = getattr(self, "processing_status", None)
        if processing_status is not None and user_id in processing_status:
            processing_status[user_id]["processing"] = False
            await self.clear(user_id)
            return "Processing aborted and conversation cleared."

        await self.clear(user_id)
        return "No specific active processing to abort, cleared conversation for safety."

    async def switch_model(self):
        """Toggle between the small and the large model named in the environment.

        Returns:
            A message naming the model now in use, or an explanation when the
            target model's environment variable is not set (the current model
            is then left unchanged).
        """
        if self.model == os.environ.get("GEMINI_SMALL_MODEL"):
            model_var, tokens_var = "GEMINI_LARGE_MODEL", "GEMINI_LARGE_MODEL_MAX_TOKENS"
        else:
            model_var, tokens_var = "GEMINI_SMALL_MODEL", "GEMINI_SMALL_MODEL_MAX_TOKENS"

        target_model = os.environ.get(model_var)
        if not target_model:
            # Switching to an unset model would leave self.model as None and
            # break every later request, so keep the current configuration.
            logging.error("Cannot switch model: %s is not set.", model_var)
            return f"Cannot switch model: {model_var} is not set. Still using: {self.model}"

        self._configure_model_and_tokens(target_model, os.environ.get(tokens_var))
        return f"Switched to model: {self.model}"
|
||||||
|
|
||||||
|
def main():
    """Build the Gemini bot and run it through ``TelegramHelper``.

    When ``GEMINI_API_KEY`` is missing or empty, a fatal error is logged and
    the function returns without constructing anything.
    """
    if os.environ.get("GEMINI_API_KEY"):
        helper = TelegramHelper(GeminiTelegramInferenceBot())
        helper.run()
    else:
        logging.error("FATAL: GEMINI_API_KEY environment variable not set.")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Configure root logging before main() so its messages carry a timestamp and level.
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(message)s',
        level=logging.INFO,
    )
    main()
|
||||||
+4
-2
@@ -1,4 +1,4 @@
|
|||||||
requests==2.26.0
|
requests
|
||||||
python-telegram-bot==21.4
|
python-telegram-bot==21.4
|
||||||
openai==1.41.0
|
openai==1.41.0
|
||||||
python-dotenv==1.0.1
|
python-dotenv==1.0.1
|
||||||
@@ -6,4 +6,6 @@ discord.py==2.4.0
|
|||||||
anthropic==0.34.0
|
anthropic==0.34.0
|
||||||
GitPython==3.1.43
|
GitPython==3.1.43
|
||||||
pytest
|
pytest
|
||||||
pytest-cov
|
pytest-cov
|
||||||
|
google-genai
|
||||||
|
httpx==0.27.2
|
||||||
+9
-1
@@ -65,7 +65,15 @@ class TelegramHelper:
|
|||||||
del self.bot.processing_status[user_id]
|
del self.bot.processing_status[user_id]
|
||||||
response = response.replace("<think>", "<blockquote expandable><b>Thinking...</b>").replace("</think>", "</blockquote>")
|
response = response.replace("<think>", "<blockquote expandable><b>Thinking...</b>").replace("</think>", "</blockquote>")
|
||||||
# Return response as html message
|
# Return response as html message
|
||||||
await update.message.reply_html(response)
|
if len(response) > 4096:
|
||||||
|
# If the response is too long, split it into chunks
|
||||||
|
chunks = [response[i:i + 4096] for i in range(0, len(response), 4096)]
|
||||||
|
for chunk in chunks:
|
||||||
|
await update.message.reply_text(chunk)
|
||||||
|
# Add a small delay to avoid flooding
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
else:
|
||||||
|
await update.message.reply_text(response)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"An error occurred: {str(e)}")
|
logging.error(f"An error occurred: {str(e)}")
|
||||||
|
|||||||
@@ -14,27 +14,30 @@ class StandaloneLLMTool(BaseTool):
|
|||||||
def get_functions(self):
|
def get_functions(self):
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
"name": "call_external_llm",
|
"type": "function",
|
||||||
"description": "Call an external language model",
|
"function": {
|
||||||
"parameters": {
|
"name": "call_external_llm",
|
||||||
"type": "object",
|
"description": "Call an external language model",
|
||||||
"properties": {
|
"parameters": {
|
||||||
"prompt": {
|
"type": "object",
|
||||||
"type": "string",
|
"properties": {
|
||||||
"description": "The prompt you are providing"
|
"prompt": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The prompt you are providing"
|
||||||
|
},
|
||||||
|
"model": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The model to use for generating the detailed instructions. Use mini for most coding tasks, preview when needing sophisticated reasoning",
|
||||||
|
"enum": ["o1-mini", "o1-preview"],
|
||||||
|
"default": "o1-mini"
|
||||||
|
},
|
||||||
|
"max_tokens": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "The maximum number of tokens to use for generating the detailed instructions. Default is 16384.",
|
||||||
|
}
|
||||||
},
|
},
|
||||||
"model": {
|
"required": ["prompt"]
|
||||||
"type": "string",
|
}
|
||||||
"description": "The model to use for generating the detailed instructions. Use mini for most coding tasks, preview when needing sophisticated reasoning",
|
|
||||||
"enum": ["o1-mini", "o1-preview"],
|
|
||||||
"default": "o1-mini"
|
|
||||||
},
|
|
||||||
"max_tokens": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": "The maximum number of tokens to use for generating the detailed instructions. Default is 16384.",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["prompt"]
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|||||||
Reference in New Issue
Block a user