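Fix RAG inference: read the embedding model, reranker model, and Chroma DB paths from environment variables instead of hard-coded Windows paths, disable Chroma anonymized telemetry, and run the embedding model on CUDA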
This commit is contained in:
@@ -2,7 +2,6 @@ import logging
|
||||
import chromadb
|
||||
from chromadb.utils import embedding_functions
|
||||
from inference_bot import InferenceBot # Correctly inherit from the ABC
|
||||
from FlagEmbedding import FlagReranker
|
||||
import argparse
|
||||
import os
|
||||
import importlib
|
||||
@@ -12,12 +11,12 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
# --- RAG Configuration ---
|
||||
# Must match the settings in create_index.py
|
||||
EMBEDDING_MODEL_NAME = """C:\Models\embeddings\Qwen3-Embedding-0.6B"""
|
||||
CHROMA_DB_PATH = "C:\Models\embeddings\embedding_result\chroma_db"
|
||||
EMBEDDING_MODEL_NAME = os.environ.get("EMBEDDING_MODEL_PATH")
|
||||
CHROMA_DB_PATH = os.environ.get("CHROMA_DB_PATH")
|
||||
CHROMA_COLLECTION_NAME = "github_repo"
|
||||
|
||||
# Using a powerful open-source reranker model
|
||||
RERANKER_MODEL_NAME = """C:\Models\embeddings\Qwen3-Reranker-0.6B"""
|
||||
RERANKER_MODEL_NAME = os.environ.get("RERANKER_MODEL_PATH")
|
||||
|
||||
# Number of initial results to fetch from the database before reranking
|
||||
N_RESULTS_TO_RETRIEVE = 25
|
||||
@@ -36,9 +35,9 @@ class RAGInferenceBot(InferenceBot):
|
||||
self._processing_status = {}
|
||||
try:
|
||||
# --- Embedding and Vector DB Initialization ---
|
||||
self.chroma_client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
|
||||
self.chroma_client = chromadb.PersistentClient(path=CHROMA_DB_PATH, settings=chromadb.Settings(anonymized_telemetry=False))
|
||||
self.embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||
model_name=EMBEDDING_MODEL_NAME
|
||||
model_name=EMBEDDING_MODEL_NAME, device="cuda"
|
||||
)
|
||||
self.collection = self.chroma_client.get_collection(
|
||||
name=CHROMA_COLLECTION_NAME,
|
||||
|
||||
Reference in New Issue
Block a user