219 lines
9.9 KiB
Python
219 lines
9.9 KiB
Python
import logging
|
|
import chromadb
|
|
from chromadb.utils import embedding_functions
|
|
from inference_bot import InferenceBot # Correctly inherit from the ABC
|
|
import argparse
|
|
import os
|
|
import importlib
|
|
import torch
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
|
|
|
|
# --- RAG Configuration ---
# These values have to stay in sync with the settings in create_index.py.
EMBEDDING_MODEL_NAME = os.getenv("EMBEDDING_MODEL_PATH")
CHROMA_DB_PATH = os.getenv("CHROMA_DB_PATH")
CHROMA_COLLECTION_NAME = "github_repo"

# Open-source reranker model; its location is read from the environment.
RERANKER_MODEL_NAME = os.getenv("RERANKER_MODEL_PATH")

# How many candidates the vector search returns before reranking.
N_RESULTS_TO_RETRIEVE = 25
# How many of the reranked candidates are kept.
N_RESULTS_TO_KEEP_AFTER_RERANK = 5
# Minimum relevance score a candidate needs to make it into the final context.
RERANKER_SCORE_THRESHOLD = 0.5
|
|
|
|
class RAGInferenceBot(InferenceBot):
    """Retrieval-only bot: ChromaDB vector search followed by reranking.

    An incoming message is used as a search query; the reply is the text of
    the best-scoring chunks themselves rather than a generated answer.
    """

    def __init__(self):
        """
        Initializes the RAG components, including a custom implementation
        for the Qwen3-Reranker based on its model card.

        Raises:
            ValueError: If a required environment variable is unset, or the
                reranker tokenizer cannot map "yes"/"no" to usable token ids.
            Exception: Any other initialization failure is logged and
                re-raised unchanged.
        """
        logging.info("Initializing RAG components...")
        self._processing_status = {}
        try:
            # Fail early with a readable message instead of an obscure error
            # deep inside chromadb / transformers when a path is missing.
            required_settings = {
                "EMBEDDING_MODEL_PATH": EMBEDDING_MODEL_NAME,
                "CHROMA_DB_PATH": CHROMA_DB_PATH,
                "RERANKER_MODEL_PATH": RERANKER_MODEL_NAME,
            }
            missing = [name for name, value in required_settings.items() if not value]
            if missing:
                raise ValueError(
                    "Missing required environment variable(s): " + ", ".join(missing)
                )

            # One device decision shared by the embedder and the reranker, so
            # the bot also starts on hosts without a GPU.
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

            # --- Embedding and Vector DB Initialization ---
            self.chroma_client = chromadb.PersistentClient(
                path=CHROMA_DB_PATH,
                settings=chromadb.Settings(anonymized_telemetry=False),
            )
            self.embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name=EMBEDDING_MODEL_NAME, device=self.device
            )
            self.collection = self.chroma_client.get_collection(
                name=CHROMA_COLLECTION_NAME,
                embedding_function=self.embedding_function,
            )
            logging.info("Successfully connected to ChromaDB collection for RAG.")

            # --- Custom Reranker Initialization (as per model card) ---
            logging.info("Using device: %s for Reranker model.", self.device)

            self.rerank_tokenizer = AutoTokenizer.from_pretrained(
                RERANKER_MODEL_NAME, padding_side='left'
            )
            # Half precision only on the GPU; on CPU use float32.
            model_dtype = torch.float16 if self.device == "cuda" else torch.float32
            self.rerank_model = (
                AutoModelForCausalLM.from_pretrained(RERANKER_MODEL_NAME, torch_dtype=model_dtype)
                .to(self.device)
                .eval()
            )

            # Manually set the padding token if it's missing
            if self.rerank_tokenizer.pad_token is None:
                logging.warning("Reranker tokenizer has no pad_token. Setting it to the eos_token.")
                self.rerank_tokenizer.pad_token = self.rerank_tokenizer.eos_token

            # Token ids whose logits are compared to derive the relevance score.
            self.token_false_id = self.rerank_tokenizer.convert_tokens_to_ids("no")
            self.token_true_id = self.rerank_tokenizer.convert_tokens_to_ids("yes")
            unusable_ids = {None, self.rerank_tokenizer.unk_token_id}
            if self.token_true_id in unusable_ids or self.token_false_id in unusable_ids:
                # Indexing the logits with such an id would not fail loudly,
                # it would just produce meaningless scores.
                raise ValueError(
                    'Reranker tokenizer has no usable "yes"/"no" tokens; cannot compute scores.'
                )
            self.max_length = 8192

            # Define and pre-encode the special prefixes and suffixes from the model card
            prefix_text = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
            suffix_text = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
            self.prefix_tokens = self.rerank_tokenizer.encode(prefix_text, add_special_tokens=False)
            self.suffix_tokens = self.rerank_tokenizer.encode(suffix_text, add_special_tokens=False)

            logging.info("Successfully initialized custom Reranker: %s", RERANKER_MODEL_NAME)

        except Exception as e:
            logging.critical("Failed to initialize RAG components: %s", e, exc_info=True)
            raise

    # --- Implementation of all abstract methods from InferenceBot ---

    @property
    def processing_status(self):
        """Mapping of user id to that user's in-flight request record."""
        return self._processing_status

    def clear_conversation_history(self, user_id):
        """No-op: this bot keeps no conversation history."""
        pass

    async def switch_model(self):
        """Return a notice that there is nothing to switch."""
        return "This bot only performs RAG lookups and has no swappable models."

    def set_processing_status(self, user_id, message_id):
        """Record that *user_id* has a request in flight for *message_id*."""
        self._processing_status[user_id] = {"processing": True, "message_id": message_id}

    def clear_processing_status(self, user_id):
        """Forget the in-flight record of *user_id*; unknown users are ignored."""
        self._processing_status.pop(user_id, None)

    async def abort_processing(self, user_id):
        """Drop the in-flight record of *user_id* and report what happened.

        Only the bookkeeping entry is removed; a lookup that is already
        running is not interrupted by this method.
        """
        if user_id in self.processing_status:
            self.clear_processing_status(user_id)
            return "Processing aborted."
        return "No active processing found to abort."

    @staticmethod
    def _display_name(model_path):
        """Return the last component of *model_path* for status output."""
        if not model_path:
            return "not configured"
        # normpath strips a trailing separator, which would otherwise make
        # basename() return an empty string for "/models/foo/".
        return os.path.basename(os.path.normpath(model_path))

    def get_bot_status(self):
        """Return a short human-readable summary naming the two models."""
        return (
            "RAG Bot is active.\n"
            f"Embedding Model: {self._display_name(EMBEDDING_MODEL_NAME)}\n"
            f"Reranker Model: {self._display_name(RERANKER_MODEL_NAME)}"
        )

    async def start(self):
        """Log that the bot has started."""
        logging.info("%s started.", self.__class__.__name__)

    # --- Core RAG Logic ---

    def _format_rerank_instruction(self, query, doc):
        """Render one query/document pair in the reranker's prompt layout."""
        # Using the default instruction from the model card
        instruction = 'Given a web search query, retrieve relevant passages that answer the query'
        return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"

    @torch.no_grad()
    def _compute_rerank_scores(self, pairs: list[list[str]]) -> list[float]:
        """
        Custom score computation logic that follows the model card's example.

        Args:
            pairs: ``[query, document]`` pairs to score.

        Returns:
            One score per pair, in input order: the probability of "yes"
            after normalising over the "yes" and "no" logits only, so each
            value lies in [0, 1]. An empty input yields an empty list.
        """
        if not pairs:
            return []

        # Format all pairs with the required instruction format
        formatted_pairs = [self._format_rerank_instruction(query, doc) for query, doc in pairs]

        # Tokenize without padding; leave room for the fixed prefix/suffix so
        # the assembled sequence never exceeds max_length.
        inputs = self.rerank_tokenizer(
            formatted_pairs, padding=False, truncation='longest_first',
            return_attention_mask=False,
            max_length=self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens),
        )

        # Add special prefix and suffix tokens to each item in the batch
        for i, token_ids in enumerate(inputs['input_ids']):
            inputs['input_ids'][i] = self.prefix_tokens + token_ids + self.suffix_tokens

        # Pad the batch to the same length (this also creates the attention mask)
        inputs = self.rerank_tokenizer.pad(
            inputs, padding=True, return_tensors="pt", max_length=self.max_length
        )
        inputs = {key: inputs[key].to(self.device) for key in inputs}

        # Logits of the last position only: that is where the answer token
        # would be generated (the tokenizer pads on the left).
        batch_scores = self.rerank_model(**inputs).logits[:, -1, :]

        # Calculate scores based on the probability of "yes" vs "no"
        true_vector = batch_scores[:, self.token_true_id]
        false_vector = batch_scores[:, self.token_false_id]

        scores = torch.stack([false_vector, true_vector], dim=1)
        scores = torch.nn.functional.log_softmax(scores, dim=1)

        return scores[:, 1].exp().tolist()

    def _retrieve_and_rerank_context(self, query: str) -> str:
        """Search the collection for *query*, rerank the hits and format them.

        Args:
            query: The user's message, used verbatim as search text.

        Returns:
            ``""`` for an empty query; otherwise either the formatted context
            blocks (best first) or a short sentence explaining that nothing
            suitable was found or that the lookup failed. Errors are logged
            and reported as text, never raised.
        """
        if not query:
            return ""
        logging.info("RAG: Retrieving context for query: '%s'", query)
        try:
            results = self.collection.query(query_texts=[query], n_results=N_RESULTS_TO_RETRIEVE)
            initial_docs = (results.get('documents') or [[]])[0]
            if not initial_docs:
                logging.info("RAG: No initial documents found in vector search.")
                return "No relevant context found in the knowledge base."

            logging.info("RAG: Retrieved %d initial documents from ChromaDB.", len(initial_docs))

            # Metadata may be absent for the whole result or for single
            # chunks; keep one (possibly None) entry per document so that
            # zip() below does not drop any document.
            metadatas = (results.get('metadatas') or [None])[0] or [None] * len(initial_docs)

            # Create pairs of [query, document] for reranking
            rerank_pairs = [[query, doc] for doc in initial_docs]

            # Use our custom scoring function
            scores = self._compute_rerank_scores(rerank_pairs)

            # Sort on the score alone so ties never compare documents or dicts.
            scored_docs = sorted(
                zip(scores, initial_docs, metadatas), key=lambda x: x[0], reverse=True
            )

            # Take the top N results, then drop those below the score threshold
            top_k_docs = scored_docs[:N_RESULTS_TO_KEEP_AFTER_RERANK]
            final_docs = [entry for entry in top_k_docs if entry[0] >= RERANKER_SCORE_THRESHOLD]

            logging.info(
                "RAG: Reranked and filtered down to %d documents with score >= %s.",
                len(final_docs), RERANKER_SCORE_THRESHOLD,
            )

            # If no documents meet the threshold, inform the user.
            if not final_docs:
                return "No highly relevant context found after filtering."

            context_lines = []
            for position, (score, doc, metadata) in enumerate(final_docs, start=1):
                source = (metadata or {}).get('source', 'Unknown file')
                context_lines.append(
                    f"--- Context {position} (Relevance: {score:.2f}) ---\nSource: {source}\n\n{doc}\n"
                )

            return "\n".join(context_lines)
        except Exception as e:
            logging.error("RAG: Error during context retrieval/reranking: %s", e, exc_info=True)
            return "An error occurred while searching the knowledge base."

    async def handle_message(self, user_id, user_message):
        """Answer *user_message* with the retrieved context text.

        *user_id* is accepted for interface compatibility and not used here.
        """
        # NOTE(review): the lookup below is synchronous and runs directly in
        # this coroutine; if the messenger helper serves several users on one
        # event loop, consider offloading it -- confirm against the helper.
        return self._retrieve_and_rerank_context(user_message)
|
|
|
|
|
|
def main_rag():
    """Command-line entry point: build the RAG bot and run it behind a messenger helper.

    The required ``--messenger`` argument selects the integration: for a value
    such as ``telegram`` the module ``telegram_helper`` is imported and its
    ``TelegramHelper`` class is instantiated with the bot and run.

    Raises:
        SystemExit: With status 1 when creating the bot, loading the helper
            or running it raises, so that a supervising process can tell the
            failure from a clean shutdown. argparse exits on its own for
            invalid command-line arguments.
    """
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    parser = argparse.ArgumentParser(description='RAG-Only Inference Bot')
    parser.add_argument('--messenger', type=str, help='Messenger type (i.e. telegram)', required=True)
    args = parser.parse_args()

    try:
        bot = RAGInferenceBot()

        helper_module = importlib.import_module(f'{args.messenger.lower()}_helper')
        helper_class = getattr(helper_module, f"{args.messenger.capitalize()}Helper")
        helper_class(bot).run()

    except Exception as e:
        logging.critical("An unexpected error occurred during bot initialization: %s", e, exc_info=True)
        # Previously the error was only logged and the process ended with
        # status 0; report the failure through the exit code as well.
        raise SystemExit(1) from e
|
|
|
|
# Start the bot only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main_rag()
|