diff --git a/discord_helper.py b/discord_helper.py new file mode 100644 index 0000000..36d6717 --- /dev/null +++ b/discord_helper.py @@ -0,0 +1,128 @@ +import os +import logging +import asyncio +import time +from typing import TypedDict, Union, TypeAlias, List +import discord +from discord.ext import commands +from browse_command import browse_command +from inference_bot import InferenceBot + +class MessageHandlerLogicResult(TypedDict): + success: bool + response_text: Union[str, None] + error_message: Union[str, None] + +LogicResult: TypeAlias = MessageHandlerLogicResult + +class DiscordHelper(commands.Cog): + QUOTE_BLOCK_START = '>>> '\n**Thinking...**' + QUOTE_BLOCK_END = '' + CHUNK_MESSAGE_SLEEP_DURATION = 0.1 + + def __init__(self, bot: InferenceBot, + bot_config: dict = None, + chunk_message_sleep_duration: float | None = None, + logger=None): + self.bot_logic = bot # Avoid confusion with Discord's bot + self.discord_bot_token = os.getenv('DISCORD_BOT_TOKEN') + self.start_time = time.time() + self.chunk_message_sleep_duration = chunk_message_sleep_duration if chunk_message_sleep_duration is not None else self.CHUNK_MESSAGE_SLEEP_DURATION + self.logger = logger or logging.getLogger(__name__) + + async def _start_logic(self) -> str: + await self.bot_logic.start() + return "Hello! I'm your AI assistant for Discord. How can I help you today?" + + @commands.command() + async def start(self, ctx): + response_message = await self._start_logic() + await ctx.send(response_message) + + async def _clear_logic(self, user_id: int) -> str: + self.bot_logic.clear_conversation_history(user_id) + return "Conversation history cleared. Let's start fresh!" + + @commands.command() + async def clear(self, ctx): + user_id = ctx.author.id + response_message = await self._clear_logic(user_id) + await ctx.send(response_message) + + async def _status_logic(self) -> str: + return await self.bot_logic.get_bot_status() + + @commands.command() + async def status(self, ctx): + response_message = await self._status_logic() + await ctx.send(response_message) + + async def _switch_logic(self) -> str: + if hasattr(self.bot_logic, 'switch_model'): + return await self.bot_logic.switch_model() + else: + return "Model switching is not supported for this bot." + + @commands.command() + async def switch(self, ctx): + response_message = await self._switch_logic() + await ctx.send(response_message) + + async def _handle_message_logic(self, user_id: int, user_message: str) -> LogicResult: + try: + response = await self.bot_logic.handle_message(user_id, user_message) + processed_response = response.replace("", self.QUOTE_BLOCK_START).replace("", self.QUOTE_BLOCK_END) + return LogicResult(success=True, response_text=processed_response, error_message=None) + except Exception as e: + self.logger.error(f"Error in _handle_message_logic for user {user_id}: {str(e)}") + return LogicResult(success=False, response_text=None, error_message=str(e)) + + @commands.Cog.listener() + async def on_message(self, message): + if message.author.bot: + return + ctx = await self.bot.get_context(message) + if ctx.valid: + await self.bot.process_commands(message) + return + user_id = message.author.id + user_message = message.content + try: + logic_result = await self._handle_message_logic(user_id, user_message) + if logic_result["success"]: + response_text = logic_result["response_text"] + if response_text: + if len(response_text) > 2000: + chunks = [response_text[i:i + 2000] for i in range(0, len(response_text), 2000)] + for chunk in chunks: + await message.channel.send(chunk) + await asyncio.sleep(self.chunk_message_sleep_duration) + else: + await message.channel.send(response_text) + else: + self.logger.warning("Successful logic result but no response text.") + await message.channel.send("Something went unexpectedly well, but I have nothing to say.") + else: + await message.channel.send("Sorry, an error occurred while processing your request.") + + except Exception as e: + self.logger.error(f"Outer error in handle_message for user {user_id}: {str(e)}") + try: + await message.channel.send("Sorry, an unexpected error occurred with the bot.") + except Exception as e_reply: + self.logger.error(f"Failed to send error reply: {e_reply}") + + @commands.command() + async def browse(self, ctx): + # You may need to adapt browse_command for Discord or ensure compatibility + await browse_command(ctx, self.bot_logic) + + def run(self): + intents = discord.Intents.default() + intents.messages = True + + bot = commands.Bot(command_prefix="!", intents=intents) + bot.add_cog(self) + self.bot = bot # Save instance for on_message lookup + self.logger.info("Discord bot is running...") + bot.run(self.discord_bot_token)