"""RAG-only inference bot: ChromaDB vector search followed by Qwen3-Reranker scoring.

Incoming messages are treated as search queries. The bot returns the most relevant
indexed chunks (after reranking and score filtering) as a formatted text block.
"""
import argparse
import importlib
import logging
import os
import sys

import chromadb
from chromadb.utils import embedding_functions
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from inference_bot import InferenceBot  # Correctly inherit from the ABC

# --- RAG Configuration ---
# Must match the settings in create_index.py
EMBEDDING_MODEL_NAME = os.environ.get("EMBEDDING_MODEL_PATH")
CHROMA_DB_PATH = os.environ.get("CHROMA_DB_PATH")
CHROMA_COLLECTION_NAME = "github_repo"
# Using a powerful open-source reranker model
RERANKER_MODEL_NAME = os.environ.get("RERANKER_MODEL_PATH")
# Number of initial results to fetch from the database before reranking
N_RESULTS_TO_RETRIEVE = 25
# Number of final results to keep after reranking
N_RESULTS_TO_KEEP_AFTER_RERANK = 5
# The minimum relevance score a result must have to be included in the final context
RERANKER_SCORE_THRESHOLD = 0.5
# Number of (query, document) pairs scored per reranker forward pass. The causal LM
# produces logits for every position over the whole vocabulary, so scoring all
# retrieved documents at once can exhaust GPU memory; small batches bound the peak.
RERANKER_BATCH_SIZE = 8
# Default task instruction, taken from the Qwen3-Reranker model card.
RERANK_INSTRUCTION = 'Given a web search query, retrieve relevant passages that answer the query'


def _missing_env_settings():
    """Return the names of required environment variables that are unset or empty."""
    required = {
        "EMBEDDING_MODEL_PATH": EMBEDDING_MODEL_NAME,
        "CHROMA_DB_PATH": CHROMA_DB_PATH,
        "RERANKER_MODEL_PATH": RERANKER_MODEL_NAME,
    }
    return [name for name, value in required.items() if not value]


class RAGInferenceBot(InferenceBot):
    """InferenceBot that answers every message with reranked knowledge-base context.

    It holds no conversation state and has no generative model. ``handle_message``
    returns the retrieved context (or an explanatory message) as a string.
    """

    def __init__(self):
        """
        Initializes the RAG components, including a custom implementation
        for the Qwen3-Reranker based on its model card.

        Raises:
            RuntimeError: if a required environment variable
                (EMBEDDING_MODEL_PATH, CHROMA_DB_PATH, RERANKER_MODEL_PATH) is unset.
            Exception: any failure while opening ChromaDB or loading the models is
                logged as critical and re-raised unchanged.
        """
        logging.info("Initializing RAG components...")
        self._processing_status = {}
        try:
            missing = _missing_env_settings()
            if missing:
                # Fail early with a clear message instead of passing None into
                # ChromaDB / transformers, where the resulting error is obscure.
                raise RuntimeError(
                    f"Missing required environment variable(s): {', '.join(missing)}"
                )

            # Choose the device once and share it, so the embedder and the
            # reranker both work on CPU-only hosts.
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logging.info(f"Using device: {self.device} for embedding and Reranker models.")

            # --- Embedding and Vector DB Initialization ---
            self.chroma_client = chromadb.PersistentClient(
                path=CHROMA_DB_PATH,
                settings=chromadb.Settings(anonymized_telemetry=False),
            )
            self.embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name=EMBEDDING_MODEL_NAME, device=self.device
            )
            self.collection = self.chroma_client.get_collection(
                name=CHROMA_COLLECTION_NAME, embedding_function=self.embedding_function
            )
            logging.info("Successfully connected to ChromaDB collection for RAG.")

            # --- Custom Reranker Initialization (as per model card) ---
            # Left padding keeps the last real token at position -1 for every row,
            # which _compute_rerank_scores relies on when reading the final logits.
            self.rerank_tokenizer = AutoTokenizer.from_pretrained(RERANKER_MODEL_NAME, padding_side='left')
            # fp16 is only reliable/fast on GPU; fall back to fp32 on CPU.
            model_dtype = torch.float16 if self.device == "cuda" else torch.float32
            self.rerank_model = AutoModelForCausalLM.from_pretrained(
                RERANKER_MODEL_NAME, torch_dtype=model_dtype
            ).to(self.device).eval()

            # Manually set the padding token if it's missing
            if self.rerank_tokenizer.pad_token is None:
                logging.warning("Reranker tokenizer has no pad_token. Setting it to the eos_token.")
                self.rerank_tokenizer.pad_token = self.rerank_tokenizer.eos_token

            # Token IDs whose final-position logits are compared to produce the score.
            self.token_false_id = self.rerank_tokenizer.convert_tokens_to_ids("no")
            self.token_true_id = self.rerank_tokenizer.convert_tokens_to_ids("yes")
            self.max_length = 8192

            # Prompt prefix/suffix exactly as in the model card. The empty
            # <think></think> block tells the model to answer without reasoning.
            prefix_text = (
                "<|im_start|>system\nJudge whether the Document meets the requirements based on "
                "the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\"."
                "<|im_end|>\n<|im_start|>user\n"
            )
            suffix_text = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
            self.prefix_tokens = self.rerank_tokenizer.encode(prefix_text, add_special_tokens=False)
            self.suffix_tokens = self.rerank_tokenizer.encode(suffix_text, add_special_tokens=False)

            logging.info(f"Successfully initialized custom Reranker: {RERANKER_MODEL_NAME}")

        except Exception as e:
            logging.critical(f"Failed to initialize RAG components: {e}", exc_info=True)
            raise

    # --- Implementation of all abstract methods from InferenceBot ---

    @property
    def processing_status(self):
        """Mapping of user_id -> {"processing": True, "message_id": ...} for in-flight requests."""
        return self._processing_status

    def clear_conversation_history(self, user_id):
        """No-op: this bot keeps no conversation history."""
        pass

    async def switch_model(self):
        """Return a message explaining that there is no model to switch."""
        return "This bot only performs RAG lookups and has no swappable models."

    def set_processing_status(self, user_id, message_id):
        """Mark ``user_id`` as having a request in progress for ``message_id``."""
        self._processing_status[user_id] = {"processing": True, "message_id": message_id}

    def clear_processing_status(self, user_id):
        """Forget any in-progress marker for ``user_id`` (no error if absent)."""
        if user_id in self._processing_status:
            del self._processing_status[user_id]

    async def abort_processing(self, user_id):
        """Clear the processing marker for ``user_id`` and report whether one existed.

        Only the status entry is removed; a retrieval that is already running is
        not interrupted.
        """
        if user_id in self.processing_status:
            self.clear_processing_status(user_id)
            return "Processing aborted."
        else:
            return "No active processing found to abort."

    def get_bot_status(self):
        """Return a human-readable summary naming the embedding and reranker models."""
        return f"RAG Bot is active.\nEmbedding Model: {os.path.basename(EMBEDDING_MODEL_NAME)}\nReranker Model: {os.path.basename(RERANKER_MODEL_NAME)}"

    async def start(self):
        """Log startup; there is no background work to launch."""
        logging.info(f"{self.__class__.__name__} started.")

    # --- Core RAG Logic ---

    def _format_rerank_instruction(self, query, doc, instruction=None):
        """Build the user-turn body in the model card's <Instruct>/<Query>/<Document> format.

        Args:
            query: The search query.
            doc: The candidate document text.
            instruction: Task instruction; defaults to RERANK_INSTRUCTION.
        """
        if instruction is None:
            instruction = RERANK_INSTRUCTION
        return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"

    @torch.no_grad()
    def _compute_rerank_scores(self, pairs: list[list[str]], batch_size: int = RERANKER_BATCH_SIZE):
        """
        Custom score computation logic that follows the model card's example.

        Args:
            pairs: Sequence of [query, document] pairs.
            batch_size: Pairs per forward pass. This bounds peak memory, since the
                model computes full-vocabulary logits for every position.

        Returns:
            One float per pair, in input order: P("yes") computed from a softmax
            over the "no"/"yes" logits at the last position.

        Raises:
            ValueError: if ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1")

        formatted_pairs = [self._format_rerank_instruction(query, doc) for query, doc in pairs]
        # Leave room for the fixed prefix/suffix tokens within max_length.
        body_budget = self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)

        all_scores = []
        for start in range(0, len(formatted_pairs), batch_size):
            chunk = formatted_pairs[start:start + batch_size]
            encoded = self.rerank_tokenizer(
                chunk,
                padding=False,
                truncation='longest_first',
                return_attention_mask=False,
                max_length=body_budget,
            )
            input_ids = [self.prefix_tokens + ids + self.suffix_tokens for ids in encoded['input_ids']]

            # Left-pad (tokenizer configured in __init__) and build the attention mask.
            batch = self.rerank_tokenizer.pad({'input_ids': input_ids}, padding=True, return_tensors="pt")
            batch = {key: value.to(self.device) for key, value in batch.items()}

            last_logits = self.rerank_model(**batch).logits[:, -1, :]
            # Upcast to fp32 before the softmax for numerically stable probabilities.
            yes_no_logits = torch.stack(
                [last_logits[:, self.token_false_id], last_logits[:, self.token_true_id]], dim=1
            ).float()
            log_probs = torch.nn.functional.log_softmax(yes_no_logits, dim=1)
            all_scores.extend(log_probs[:, 1].exp().tolist())

        return all_scores

    def _retrieve_and_rerank_context(self, query: str):
        """Vector-search ``query``, rerank the hits and format the best ones as text.

        Returns:
            - "" for an empty query.
            - A "no context" message if nothing is retrieved or nothing passes
              RERANKER_SCORE_THRESHOLD.
            - Otherwise, one section per kept document with its score, source and text.

        Errors during retrieval or reranking are logged and turned into a
        user-facing error string; they are not raised.
        """
        logging.info(f"RAG: Retrieving context for query: '{query}'")
        if not query:
            return ""

        try:
            results = self.collection.query(query_texts=[query], n_results=N_RESULTS_TO_RETRIEVE)
            # `or` also covers keys that are present but None/empty.
            initial_docs = (results.get('documents') or [[]])[0] or []

            if not initial_docs:
                logging.info("RAG: No initial documents found in vector search.")
                return "No relevant context found in the knowledge base."

            logging.info(f"RAG: Retrieved {len(initial_docs)} initial documents from ChromaDB.")

            # Metadata may be missing entirely or contain None entries. Normalize to
            # one dict per document so zip() below never drops documents.
            metadata_rows = (results.get('metadatas') or [[]])[0] or []
            metadatas = [row or {} for row in metadata_rows]
            metadatas += [{}] * (len(initial_docs) - len(metadatas))

            # Create pairs of [query, document] for reranking
            rerank_pairs = [[query, doc] for doc in initial_docs]
            scores = self._compute_rerank_scores(rerank_pairs)

            scored_docs = sorted(zip(scores, initial_docs, metadatas), key=lambda item: item[0], reverse=True)

            # Keep the top N, then drop any below the relevance threshold.
            top_k_docs = scored_docs[:N_RESULTS_TO_KEEP_AFTER_RERANK]
            final_docs = [entry for entry in top_k_docs if entry[0] >= RERANKER_SCORE_THRESHOLD]
            logging.info(f"RAG: Reranked and filtered down to {len(final_docs)} documents with score >= {RERANKER_SCORE_THRESHOLD}.")

            # If no documents meet the threshold, inform the user.
            if not final_docs:
                return "No highly relevant context found after filtering."

            context_lines = []
            for i, (score, doc, metadata) in enumerate(final_docs):
                source = metadata.get('source', 'Unknown file')
                context_lines.append(f"--- Context {i+1} (Relevance: {score:.2f}) ---\nSource: {source}\n\n{doc}\n")

            return "\n".join(context_lines)

        except Exception as e:
            logging.error(f"RAG: Error during context retrieval/reranking: {e}", exc_info=True)
            return "An error occurred while searching the knowledge base."

    async def handle_message(self, user_id, user_message):
        """Answer ``user_message`` with reranked knowledge-base context.

        The retrieval and reranking run synchronously inside this coroutine, so
        the event loop is blocked until they finish.
        """
        context_response = self._retrieve_and_rerank_context(user_message)
        return context_response


def main_rag():
    """CLI entry point: build the bot and run it through ``<messenger>_helper``.

    Loads module ``<messenger>_helper`` and class ``<Messenger>Helper``, constructs
    the helper with the bot and calls ``run()``. Exits with status 1 if any step fails.
    """
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    parser = argparse.ArgumentParser(description='RAG-Only Inference Bot')
    parser.add_argument('--messenger', type=str, help='Messenger type (e.g. telegram)', required=True)
    args = parser.parse_args()

    try:
        bot = RAGInferenceBot()
        helper_module = importlib.import_module(f'{args.messenger.lower()}_helper')
        helper_class = getattr(helper_module, f"{args.messenger.capitalize()}Helper")
        helper = helper_class(bot)
        helper.run()
    except Exception as e:
        logging.critical(f"An unexpected error occurred during bot initialization: {e}", exc_info=True)
        # A non-zero exit status lets supervisors/shells detect the failure.
        sys.exit(1)


if __name__ == '__main__':
    main_rag()