diff --git a/discord_helper.py b/discord_helper.py index 36d6717..4ecbc7b 100644 --- a/discord_helper.py +++ b/discord_helper.py @@ -5,7 +5,6 @@ import time from typing import TypedDict, Union, TypeAlias, List import discord from discord.ext import commands -from browse_command import browse_command from inference_bot import InferenceBot class MessageHandlerLogicResult(TypedDict): @@ -29,33 +28,37 @@ class DiscordHelper(commands.Cog): self.start_time = time.time() self.chunk_message_sleep_duration = chunk_message_sleep_duration if chunk_message_sleep_duration is not None else self.CHUNK_MESSAGE_SLEEP_DURATION self.logger = logger or logging.getLogger(__name__) + self.bot = None # This will be set in run() async def _start_logic(self) -> str: await self.bot_logic.start() return "Hello! I'm your AI assistant for Discord. How can I help you today?" - @commands.command() - async def start(self, ctx): + @discord.app_commands.command(name="start", description="Starts the bot and initializes the conversation.") + async def start(self, interaction: discord.Interaction): + await interaction.response.defer() # Defer the response as the logic might take time response_message = await self._start_logic() - await ctx.send(response_message) + await interaction.followup.send(response_message) async def _clear_logic(self, user_id: int) -> str: self.bot_logic.clear_conversation_history(user_id) return "Conversation history cleared. Let's start fresh!" - @commands.command() - async def clear(self, ctx): - user_id = ctx.author.id + @discord.app_commands.command(name="clear", description="Clears your conversation history with the bot.") + async def clear(self, interaction: discord.Interaction): + await interaction.response.defer() + user_id = interaction.user.id response_message = await self._clear_logic(user_id) - await ctx.send(response_message) + await interaction.followup.send(response_message) async def _status_logic(self) -> str: return await self.bot_logic.get_bot_status() - @commands.command() - async def status(self, ctx): + @discord.app_commands.command(name="status", description="Checks the current status of the bot.") + async def status(self, interaction: discord.Interaction): + await interaction.response.defer() response_message = await self._status_logic() - await ctx.send(response_message) + await interaction.followup.send(response_message) async def _switch_logic(self) -> str: if hasattr(self.bot_logic, 'switch_model'): @@ -63,10 +66,11 @@ class DiscordHelper(commands.Cog): else: return "Model switching is not supported for this bot." - @commands.command() - async def switch(self, ctx): + @discord.app_commands.command(name="switch", description="Switches the underlying model (if supported).") + async def switch(self, interaction: discord.Interaction): + await interaction.response.defer() response_message = await self._switch_logic() - await ctx.send(response_message) + await interaction.followup.send(response_message) async def _handle_message_logic(self, user_id: int, user_message: str) -> LogicResult: try: @@ -78,51 +82,58 @@ class DiscordHelper(commands.Cog): return LogicResult(success=False, response_text=None, error_message=str(e)) @commands.Cog.listener() - async def on_message(self, message): - if message.author.bot: - return - ctx = await self.bot.get_context(message) - if ctx.valid: - await self.bot.process_commands(message) + async def on_message(self, message: discord.Message): + if message.author.bot or message.is_command(): # Ignore bot messages and commands handled by slash commands return + user_id = message.author.id user_message = message.content - try: - logic_result = await self._handle_message_logic(user_id, user_message) - if logic_result["success"]: - response_text = logic_result["response_text"] - if response_text: - if len(response_text) > 2000: - chunks = [response_text[i:i + 2000] for i in range(0, len(response_text), 2000)] - for chunk in chunks: - await message.channel.send(chunk) - await asyncio.sleep(self.chunk_message_sleep_duration) - else: - await message.channel.send(response_text) - else: - self.logger.warning("Successful logic result but no response text.") - await message.channel.send("Something went unexpectedly well, but I have nothing to say.") - else: - await message.channel.send("Sorry, an error occurred while processing your request.") - except Exception as e: - self.logger.error(f"Outer error in handle_message for user {user_id}: {str(e)}") + # Only process messages that mention the bot, or are in a DM + if self.bot.user in message.mentions or isinstance(message.channel, discord.DMChannel): try: - await message.channel.send("Sorry, an unexpected error occurred with the bot.") - except Exception as e_reply: - self.logger.error(f"Failed to send error reply: {e_reply}") + logic_result = await self._handle_message_logic(user_id, user_message) + if logic_result["success"]: + response_text = logic_result["response_text"] + if response_text: + if len(response_text) > 2000: + chunks = [response_text[i:i + 2000] for i in range(0, len(response_text), 2000)] + for chunk in chunks: + await message.channel.send(chunk) + await asyncio.sleep(self.chunk_message_sleep_duration) + else: + await message.channel.send(response_text) + else: + self.logger.warning("Successful logic result but no response text.") + await message.channel.send("Something went unexpectedly well, but I have nothing to say.") + else: + await message.channel.send("Sorry, an error occurred while processing your request.") - @commands.command() - async def browse(self, ctx): - # You may need to adapt browse_command for Discord or ensure compatibility - await browse_command(ctx, self.bot_logic) + except Exception as e: + self.logger.error(f"Outer error in handle_message for user {user_id}: {str(e)}") + try: + await message.channel.send("Sorry, an unexpected error occurred with the bot.") + except Exception as e_reply: + self.logger.error(f"Failed to send error reply: {e_reply}") def run(self): intents = discord.Intents.default() + intents.message_content = True # Required for accessing message.content intents.messages = True + intents.guilds = True # Required for synchronizing slash commands globally bot = commands.Bot(command_prefix="!", intents=intents) - bot.add_cog(self) - self.bot = bot # Save instance for on_message lookup - self.logger.info("Discord bot is running...") + self.bot = bot # Save instance for on_message lookup and command tree + + @bot.event + async def on_ready(): + self.logger.info(f"Logged in as {bot.user} (ID: {bot.user.id})") + await bot.add_cog(self) + try: + synced = await bot.tree.sync() + self.logger.info(f"Synced {len(synced)} command(s).") + except Exception as e: + self.logger.error(f"Error syncing commands: {e}") + self.logger.info("Discord bot is running...") + bot.run(self.discord_bot_token)