2025-06-06 17:41:30 -05:00
import argparse
import asyncio
import importlib
import logging
import os
import threading

import chromadb
import torch
from chromadb.utils import embedding_functions
from transformers import AutoTokenizer, AutoModelForCausalLM

from inference_bot import InferenceBot # Correctly inherit from the ABC
# --- RAG Configuration ---
# Must match the settings in create_index.py.
# Model and database locations come from environment variables; each of
# these constants is None when its variable is unset (os.environ.get).
EMBEDDING_MODEL_NAME = os.environ.get("EMBEDDING_MODEL_PATH")
CHROMA_DB_PATH = os.environ.get("CHROMA_DB_PATH")
CHROMA_COLLECTION_NAME = "github_repo"
# Reranker model (local path or hub id) used to rescore vector-search hits.
RERANKER_MODEL_NAME = os.environ.get("RERANKER_MODEL_PATH")
# Number of initial results to fetch from the database before reranking
N_RESULTS_TO_RETRIEVE = 25
# Number of final results to keep after reranking
N_RESULTS_TO_KEEP_AFTER_RERANK = 5
# The minimum relevance score a result must have to be included in the final context
RERANKER_SCORE_THRESHOLD = 0.5
class RAGInferenceBot(InferenceBot):
    """Bot that answers each message with retrieved, reranked knowledge-base context.

    The incoming message is used as a query against a ChromaDB collection; the
    hits are rescored with a causal-LM reranker using the Qwen3-Reranker prompt
    format, and the best ones are returned as formatted text. No generative
    answer is produced.
    """

    def __init__(self):
        """
        Initialize the RAG components, including a custom implementation
        for the Qwen3-Reranker based on its model card.

        Raises:
            ValueError: if EMBEDDING_MODEL_PATH, CHROMA_DB_PATH or
                RERANKER_MODEL_PATH is unset or empty.
            Exception: any error while opening the vector store or loading the
                models is logged and re-raised unchanged.
        """
        logging.info("Initializing RAG components...")
        self._processing_status = {}
        # handle_message runs retrieval in a worker thread; this lock keeps the
        # embedder, reranker and tokenizers used by one request at a time.
        self._rag_lock = threading.Lock()
        try:
            self._check_required_settings()
            # Choose the device once so the embedder and the reranker agree
            # (previously the embedder was hard-wired to "cuda").
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

            # --- Embedding and Vector DB Initialization ---
            self.chroma_client = chromadb.PersistentClient(
                path=CHROMA_DB_PATH,
                settings=chromadb.Settings(anonymized_telemetry=False),
            )
            self.embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name=EMBEDDING_MODEL_NAME, device=self.device
            )
            self.collection = self.chroma_client.get_collection(
                name=CHROMA_COLLECTION_NAME,
                embedding_function=self.embedding_function,
            )
            logging.info("Successfully connected to ChromaDB collection for RAG.")

            # --- Custom Reranker Initialization (as per model card) ---
            logging.info(f"Using device: {self.device} for Reranker model.")
            self.rerank_tokenizer = AutoTokenizer.from_pretrained(RERANKER_MODEL_NAME, padding_side='left')
            # Half precision only on GPU; depending on the PyTorch version,
            # float16 on CPU is slow or unsupported for some ops.
            model_dtype = torch.float16 if self.device == "cuda" else torch.float32
            self.rerank_model = AutoModelForCausalLM.from_pretrained(
                RERANKER_MODEL_NAME, torch_dtype=model_dtype
            ).to(self.device).eval()
            # Manually set the padding token if it's missing
            if self.rerank_tokenizer.pad_token is None:
                logging.warning("Reranker tokenizer has no pad_token. Setting it to the eos_token.")
                self.rerank_tokenizer.pad_token = self.rerank_tokenizer.eos_token
            # Token ids whose final-position logits are compared to produce a score
            self.token_false_id = self.rerank_tokenizer.convert_tokens_to_ids("no")
            self.token_true_id = self.rerank_tokenizer.convert_tokens_to_ids("yes")
            self.max_length = 8192
            # Chat-template prefix/suffix from the model card, encoded once here
            prefix_text = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
            suffix_text = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
            self.prefix_tokens = self.rerank_tokenizer.encode(prefix_text, add_special_tokens=False)
            self.suffix_tokens = self.rerank_tokenizer.encode(suffix_text, add_special_tokens=False)
            logging.info(f"Successfully initialized custom Reranker: {RERANKER_MODEL_NAME}")
        except Exception as e:
            logging.fatal(f"Failed to initialize RAG components: {e}", exc_info=True)
            raise

    @staticmethod
    def _check_required_settings():
        """Raise ValueError naming every required environment variable that is unset or empty."""
        required = {
            "EMBEDDING_MODEL_PATH": EMBEDDING_MODEL_NAME,
            "CHROMA_DB_PATH": CHROMA_DB_PATH,
            "RERANKER_MODEL_PATH": RERANKER_MODEL_NAME,
        }
        missing = [env_var for env_var, value in required.items() if not value]
        if missing:
            raise ValueError(f"Missing required environment variable(s): {', '.join(missing)}")

    # --- Implementation of all abstract methods from InferenceBot ---
    @property
    def processing_status(self):
        """Dict of user_id -> {"processing": True, "message_id": ...} for in-flight requests."""
        return self._processing_status

    def clear_conversation_history(self, user_id):
        """No-op: this bot stores no conversation history."""
        pass

    async def switch_model(self):
        """Report that there are no alternative models to switch to."""
        return "This bot only performs RAG lookups and has no swappable models."

    def set_processing_status(self, user_id, message_id):
        """Mark *user_id* as having a request in progress for *message_id*."""
        self._processing_status[user_id] = {"processing": True, "message_id": message_id}

    def clear_processing_status(self, user_id):
        """Forget the in-progress marker for *user_id*, if any."""
        if user_id in self._processing_status:
            del self._processing_status[user_id]

    async def abort_processing(self, user_id):
        """Clear *user_id*'s processing flag.

        Only the status entry is removed; a retrieval already running is not interrupted.
        """
        if user_id in self.processing_status:
            self.clear_processing_status(user_id)
            return "Processing aborted."
        else:
            return "No active processing found to abort."

    def get_bot_status(self):
        """Return a short human-readable summary naming the configured models."""
        return f"RAG Bot is active.\nEmbedding Model: {os.path.basename(EMBEDDING_MODEL_NAME)}\nReranker Model: {os.path.basename(RERANKER_MODEL_NAME)}"

    async def start(self):
        """Log that the bot has started; there is no other start-up work."""
        logging.info(f"{self.__class__.__name__} started.")

    # --- Core RAG Logic ---
    def _format_rerank_instruction(self, query, doc):
        """Build the Instruct/Query/Document text the reranker scores for one pair."""
        # Using the default instruction from the model card
        instruction = 'Given a web search query, retrieve relevant passages that answer the query'
        return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"

    @torch.no_grad()
    def _compute_rerank_scores(self, pairs: list[tuple[str, str]]) -> list[float]:
        """
        Score (query, document) pairs, following the model card's example.

        Args:
            pairs: (query, document) pairs; assumed non-empty (the caller
                returns early when retrieval finds nothing).

        Returns:
            One float per pair: the softmax probability of "yes" over the
            {"no", "yes"} logits at the final position.
        """
        # Format all pairs with the required instruction format
        formatted_pairs = [self._format_rerank_instruction(query, doc) for query, doc in pairs]
        # Tokenize unpadded, leaving room for the prefix/suffix added below
        inputs = self.rerank_tokenizer(
            formatted_pairs, padding=False, truncation='longest_first',
            return_attention_mask=False,
            max_length=self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens),
        )
        # Wrap every sequence in the chat-template prefix and suffix
        for i, token_ids in enumerate(inputs['input_ids']):
            inputs['input_ids'][i] = self.prefix_tokens + token_ids + self.suffix_tokens
        # The tokenizer was loaded with padding_side='left', so position -1 is
        # the last real token of every sequence after padding.
        inputs = self.rerank_tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=self.max_length)
        inputs = {key: inputs[key].to(self.device) for key in inputs}
        last_logits = self.rerank_model(**inputs).logits[:, -1, :]
        # Probability of "yes" vs "no"
        true_vector = last_logits[:, self.token_true_id]
        false_vector = last_logits[:, self.token_false_id]
        scores = torch.stack([false_vector, true_vector], dim=1)
        scores = torch.nn.functional.log_softmax(scores, dim=1)
        return scores[:, 1].exp().tolist()

    def _retrieve_and_rerank_context(self, query: str) -> str:
        """
        Retrieve candidate chunks for *query*, rerank them, and format the best.

        Fetches N_RESULTS_TO_RETRIEVE hits from the collection and keeps the top
        N_RESULTS_TO_KEEP_AFTER_RERANK by reranker score. Among those, it keeps
        the ones scoring at least RERANKER_SCORE_THRESHOLD.

        Returns:
            The formatted context, "" for an empty query, or one of the fixed
            user-facing messages when nothing relevant is found or an error
            occurs (errors are logged, not raised).
        """
        logging.info(f"RAG: Retrieving context for query: '{query}'")
        if not query:
            return ""
        try:
            results = self.collection.query(query_texts=[query], n_results=N_RESULTS_TO_RETRIEVE)
            initial_docs = (results.get('documents') or [[]])[0]
            if not initial_docs:
                logging.info("RAG: No initial documents found in vector search.")
                return "No relevant context found in the knowledge base."
            logging.info(f"RAG: Retrieved {len(initial_docs)} initial documents from ChromaDB.")
            # Metadata may be absent altogether or None for individual chunks.
            metadatas = (results.get('metadatas') or [None])[0] or [None] * len(initial_docs)

            scores = self._compute_rerank_scores([(query, doc) for doc in initial_docs])
            scored_docs = sorted(
                zip(scores, initial_docs, metadatas), key=lambda item: item[0], reverse=True
            )
            # Keep the top N, then drop anything below the relevance threshold
            top_k_docs = scored_docs[:N_RESULTS_TO_KEEP_AFTER_RERANK]
            final_docs = [item for item in top_k_docs if item[0] >= RERANKER_SCORE_THRESHOLD]
            logging.info(f"RAG: Reranked and filtered down to {len(final_docs)} documents with score >= {RERANKER_SCORE_THRESHOLD}.")
            # If no documents meet the threshold, inform the user.
            if not final_docs:
                return "No highly relevant context found after filtering."

            context_lines = []
            for i, (score, doc, metadata) in enumerate(final_docs):
                source = (metadata or {}).get('source', 'Unknown file')
                context_lines.append(f"--- Context {i + 1} (Relevance: {score:.2f}) ---\nSource: {source}\n\n{doc}\n")
            return "\n".join(context_lines)
        except Exception as e:
            logging.error(f"RAG: Error during context retrieval/reranking: {e}", exc_info=True)
            return "An error occurred while searching the knowledge base."

    def _retrieve_serialized(self, query):
        """Run retrieval + reranking while holding the RAG lock (one request at a time)."""
        with self._rag_lock:
            return self._retrieve_and_rerank_context(query)

    async def handle_message(self, user_id, user_message):
        """Return the reranked context for *user_message*.

        The blocking embedding/reranking work runs in a worker thread so the
        event loop stays responsive (e.g. to abort requests) meanwhile.
        """
        return await asyncio.to_thread(self._retrieve_serialized, user_message)
def main_rag():
    """Parse CLI arguments, build the RAG bot, and run it under the chosen messenger helper.

    ``--messenger X`` imports module ``x_helper`` and instantiates its
    ``XHelper`` class (name capitalized) with the bot, then calls ``run()``.
    Any start-up failure is logged and the process exits with status 1.
    """
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    parser = argparse.ArgumentParser(description='RAG-Only Inference Bot')
    parser.add_argument('--messenger', type=str, help='Messenger type (i.e. telegram)', required=True)
    args = parser.parse_args()
    try:
        # Resolve the helper first so a bad --messenger fails before the
        # (slow) model loading in RAGInferenceBot().
        helper_module = importlib.import_module(f'{args.messenger.lower()}_helper')
        helper_class = getattr(helper_module, f"{args.messenger.capitalize()}Helper")
        bot = RAGInferenceBot()
        helper = helper_class(bot)
        helper.run()
    except Exception as e:
        logging.fatal(f"An unexpected error occurred during bot initialization: {e}", exc_info=True)
        # Exit non-zero so process supervisors can detect the failure.
        raise SystemExit(1) from e


if __name__ == '__main__':
    main_rag()