106 lines
3.7 KiB
Python
106 lines
3.7 KiB
Python
import json
|
|
import os
|
|
import logging
|
|
from base_telegram_inference_bot import BaseTelegramInferenceBot
|
|
from telegram_helper import TelegramHelper
|
|
from openai import OpenAI
|
|
|
|
class ChatGPTTelegramInferenceBot(BaseTelegramInferenceBot):
    """Telegram inference bot backed by an OpenAI-compatible chat-completions server.

    By default it talks to ``http://localhost:1234/v1`` and offers the model
    the tools in ``self.functions``. Tool calls requested by the model are
    executed through ``self.call_tool``, and their results are fed back until
    the model answers without requesting more tools.

    Several names used here are not defined in this class:
    ``conversation_history``, ``processing_status``, ``system_prompt``,
    ``functions``, ``call_tool`` and ``clear_conversation``. They presumably
    come from ``BaseTelegramInferenceBot``; confirm against that class.
    """

    # Maximum number of messages kept per user after each exchange.
    max_history_messages = 20
    # Upper bound on follow-up model requests made to answer tool calls in a
    # single user turn (guards against a model that never stops calling tools).
    max_tool_rounds = 500

    def __init__(self, model="qwen3-1.7b", base_url="http://localhost:1234/v1", max_tokens=32768):
        """Create the OpenAI client and set the generation settings.

        Args:
            model: Model name sent with every request.
            base_url: Base URL of the OpenAI-compatible API server.
            max_tokens: ``max_tokens`` limit sent with every request.
        """
        super().__init__()
        # The OpenAI client refuses to construct without an API key. A local
        # server does not need a real one, so fall back to a placeholder when
        # the environment variable is unset.
        self.client = OpenAI(
            api_key=os.environ.get("OPENAI_API_KEY", "not-needed"),
            base_url=base_url,
        )
        self.model = model
        self.max_tokens = max_tokens

    def get_chat_response(self, messages):
        """Send one chat-completions request and return the raw response.

        The system prompt is prepended to ``messages`` for the request only;
        the ``messages`` list itself is not modified.

        Args:
            messages: Conversation so far. Entries are dicts and/or message
                objects taken from earlier responses.

        Returns:
            The response object returned by ``client.chat.completions.create``.
        """
        return self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "system", "content": self.system_prompt}] + messages,
            tools=self.functions,
            tool_choice="auto",
            max_tokens=self.max_tokens,
        )

    async def handle_message(self, user_id, user_message):
        """Answer ``user_message`` for ``user_id``, running any requested tools.

        Everything exchanged in this turn is appended to the user's
        conversation history: the user message, every assistant message and
        every tool result. Tool rounds repeat until the model replies without
        tool calls or ``max_tool_rounds`` is reached. The history is then
        trimmed.

        Args:
            user_id: Key into ``self.conversation_history``.
            user_message: Text sent by the user.

        Returns:
            The ``content`` of the last assistant message. This may be ``None``
            if that message only carried tool calls, e.g. when the round limit
            was hit.
        """
        # NOTE(review): get_chat_response and call_tool are synchronous, so this
        # coroutine blocks the event loop while the model generates.
        if user_id not in self.conversation_history:
            self.conversation_history[user_id] = []
        messages = self.conversation_history[user_id]
        messages.append({"role": "user", "content": user_message})

        assistant_message = self._request_assistant_message(messages)
        tool_calls = list(assistant_message.tool_calls or [])

        tool_rounds = 0
        while tool_calls and tool_rounds < self.max_tool_rounds:
            messages.extend(self._run_tool_calls(tool_calls))
            assistant_message = self._request_assistant_message(messages)
            # With the tools API, follow-up calls arrive on ``tool_calls`` just
            # like the first ones, not on the legacy ``function_call`` field.
            tool_calls = list(assistant_message.tool_calls or [])
            tool_rounds += 1

        self._trim_history(user_id)
        return assistant_message.content

    def _request_assistant_message(self, messages):
        """Query the model, append its first choice's message to ``messages``, and return it."""
        response = self.get_chat_response(messages)
        # Only the first choice is kept in the history, so tool calls are read
        # from that same message. Otherwise tool results could reference call
        # ids that are not in the history.
        assistant_message = response.choices[0].message
        messages.append(assistant_message)
        return assistant_message

    def _run_tool_calls(self, tool_calls):
        """Execute each tool call and return the matching ``role: "tool"`` messages."""
        results = []
        for tool_call in tool_calls:
            function = tool_call.function
            tool_response = self.call_tool(function.name, function.arguments)
            results.append({
                "role": "tool",
                # Links this result to the assistant's request; the chat API
                # requires it on tool messages.
                "tool_call_id": tool_call.id,
                "name": function.name,
                "content": self._serialize_tool_response(tool_response),
            })
        return results

    @staticmethod
    def _serialize_tool_response(tool_response):
        """Return ``tool_response`` as message text, JSON-encoding non-string values."""
        if isinstance(tool_response, str):
            return tool_response
        try:
            return json.dumps(tool_response, ensure_ascii=False)
        except (TypeError, ValueError) as e:
            logging.error(f"Failed to serialize tool response: {e}")
            return "Serialization error"

    @staticmethod
    def _role_of(message):
        """Return the ``role`` of a history entry, which may be a dict or a message object."""
        if isinstance(message, dict):
            return message.get("role")
        return getattr(message, "role", None)

    def _trim_history(self, user_id):
        """Keep at most ``max_history_messages`` messages, starting at a user turn.

        A plain tail slice can leave a ``tool`` message at the front without
        the assistant message that requested it. The chat API rejects that on
        the next request, so leading non-user messages are dropped as well.
        """
        history = self.conversation_history[user_id]
        if len(history) <= self.max_history_messages:
            return
        history = history[-self.max_history_messages:]
        start = next(
            (i for i, message in enumerate(history) if self._role_of(message) == "user"),
            len(history),
        )
        self.conversation_history[user_id] = history[start:]

    async def start(self):
        """Log that the bot has started."""
        logging.info("Bot started")

    async def clear(self, user_id):
        """Clear the stored conversation for ``user_id`` via the base-class helper."""
        super().clear_conversation(user_id)
        logging.info(f"Cleared conversation history for user {user_id}")

    async def status(self):
        """Return a short description of the model currently in use."""
        return f"Currently using: {self.model}"

    async def abort_processing(self, user_id):
        """Mark the user's processing as stopped and clear their conversation.

        Returns:
            A user-facing message saying whether anything was aborted.
        """
        if user_id in self.processing_status:
            self.processing_status[user_id]["processing"] = False
            await self.clear(user_id)
            return "Processing aborted."
        else:
            return "No active processing to abort."

    async def switch_model(self):
        """Toggle between ``qwen3-4b`` and ``qwen3-30b-a3b``.

        Any model other than ``qwen3-4b``, including the initial default,
        switches to ``qwen3-4b``.

        Returns:
            A user-facing message naming the new model.
        """
        if self.model == "qwen3-4b":
            self.model = "qwen3-30b-a3b"
        else:
            self.model = "qwen3-4b"
        logging.info(f"Switched to model: {self.model}")
        return f"Switched to model: {self.model}"
|
|
|
|
def main():
    """Entry point: build the inference bot and run it through the Telegram helper."""
    TelegramHelper(ChatGPTTelegramInferenceBot()).run()
|
|
|
|
# Run the bot only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()