# Source listing metadata: 186 lines, 7.3 KiB, Python.
import os
|
|
import logging
|
|
import torch
|
|
import gc # Import the garbage collection module
|
|
import chromadb
|
|
from chromadb.utils import embedding_functions
|
|
from dotenv import load_dotenv
|
|
|
|
# Import the GitHubTool from its location.
# Assumes it lives in the local `tools.github_tool` module, matching the layout used by the project's other scripts.
|
|
from tools.github_tool import GitHubTool
|
|
|
|
# --- Configuration ---
# Adjust these settings here; the ones marked below can also be overridden with
# environment variables (values from a .env file are honoured as well).

# Load .env *before* reading any setting. These constants are evaluated at import
# time, so relying only on the load_dotenv() call inside main() was too late:
# values defined solely in .env were silently ignored. load_dotenv() does not
# override variables that are already set in the real environment.
load_dotenv()

# Embedding model: either a local directory holding a downloaded model or a
# Hugging Face model name (downloaded on first use).
# Override with EMBEDDING_MODEL_PATH, e.g. "/path/to/your/models/all-MiniLM-L6-v2".
# An unset or empty variable falls back to DEFAULT_EMBEDDING_MODEL, so the
# embedding function is never handed model_name=None.
DEFAULT_EMBEDDING_MODEL = "all-MiniLM-L6-v2"
EMBEDDING_MODEL_PATH = os.environ.get("EMBEDDING_MODEL_PATH") or DEFAULT_EMBEDDING_MODEL

# Directory of the persistent local vector database (override with CHROMA_DB_PATH).
DEFAULT_CHROMA_DB_PATH = "./chroma_db"
CHROMA_DB_PATH = os.environ.get("CHROMA_DB_PATH") or DEFAULT_CHROMA_DB_PATH

# Name of the collection within the database.
CHROMA_COLLECTION_NAME = "github_repo"

# Only files whose names end with one of these extensions are indexed, which keeps
# binary and other unwanted files out. Add any other text-based extensions you need.
INCLUDED_EXTENSIONS = ['.py', '.js', '.ts', '.md', '.txt', '.html', '.css', '.go', '.rs', '.java', '.c', '.h', '.cpp', '.sh', '.yaml', '.json']

# Chunking configuration (character-based sliding window).
CHUNK_SIZE = 1000  # Target size of each text chunk, in characters.
CHUNK_OVERLAP = 200  # Characters shared by consecutive chunks; must be smaller than CHUNK_SIZE.
CHUNK_PROCESSING_BATCH_SIZE = 100  # Number of chunks upserted to Chroma in a single batch.

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
|
def get_all_repo_files(github_tool, repo_path="", extensions=None):
    """
    Recursively collect the paths of indexable files in the GitHub repository.

    Args:
        github_tool: Object whose ``execute("list_files", path=...)`` returns a
            list of entries (dicts with ``type``, ``name`` and ``path`` keys),
            or an error message string on failure.
        repo_path: Directory to list; "" starts at the repository root.
        extensions: A filename suffix, or an iterable of suffixes, to keep.
            Defaults to ``INCLUDED_EXTENSIONS``.

    Returns:
        list[str]: Paths of matching files in depth-first listing order.
        Errors are logged rather than raised. A directory whose listing fails
        midway still contributes the files collected before the failure
        (best effort).
    """
    if extensions is None:
        extensions = INCLUDED_EXTENSIONS
    # str.endswith accepts a tuple and tests every suffix in a single call.
    # A lone string is wrapped so it is not split into single characters.
    suffixes = (extensions,) if isinstance(extensions, str) else tuple(extensions)

    all_files = []
    try:
        items = github_tool.execute("list_files", path=repo_path)
        if isinstance(items, str):
            # A string result is an error message, never a listing. Iterating it
            # would only yield characters and then fail on item['type'].
            logging.error(f"Could not list files at path '{repo_path}': {items}")
            return []

        for item in items:
            if item['type'] == 'file' and item['name'].endswith(suffixes):
                all_files.append(item['path'])
            elif item['type'] == 'dir':
                # Sub-directories catch and log their own errors, so one bad
                # directory does not discard its siblings.
                all_files.extend(
                    get_all_repo_files(github_tool, repo_path=item['path'], extensions=suffixes)
                )
    except Exception:
        # Best effort: record the traceback and keep whatever was collected so far.
        logging.exception(f"An unexpected error occurred while listing files at '{repo_path}'")
    return all_files
|
|
|
|
def split_text(text: str, chunk_size=None, chunk_overlap=None) -> list[str]:
    """
    Split text into fixed-size character chunks that overlap their neighbours.

    Args:
        text: Text to split. ``None`` or whitespace-only text yields no chunks.
        chunk_size: Maximum number of characters per chunk. Defaults to
            ``CHUNK_SIZE``.
        chunk_overlap: Number of characters shared by consecutive chunks.
            Defaults to ``CHUNK_OVERLAP``.

    Returns:
        list[str]: The chunks in order. Every chunk except possibly the last is
        exactly ``chunk_size`` characters long, and the last chunk always ends
        at the end of ``text``. No chunk is a pure repeat of the previous one's tail.

    Raises:
        ValueError: If ``chunk_size`` is not positive, or ``chunk_overlap`` is
            negative or not smaller than ``chunk_size``. With such values the
            window would never advance (an infinite loop) or would skip text.
    """
    if text is None or not text.strip():
        return []

    size = CHUNK_SIZE if chunk_size is None else chunk_size
    overlap = CHUNK_OVERLAP if chunk_overlap is None else chunk_overlap
    if size <= 0:
        raise ValueError(f"chunk_size must be positive, got {size}")
    if not 0 <= overlap < size:
        raise ValueError(f"chunk_overlap must satisfy 0 <= overlap < chunk_size ({size}), got {overlap}")

    step = size - overlap
    chunks = []
    start = 0
    while True:
        end = start + size
        chunks.append(text[start:end])
        if end >= len(text):
            # This chunk already reaches the end of the text. Any further window
            # would lie entirely inside it and only duplicate content.
            break
        start += step
    return chunks
|
|
|
|
def _init_github_tool():
    """
    Create the GitHubTool from the GITHUB_REPOSITORY and GITHUB_TOKEN environment variables.

    Returns:
        The initialized GitHubTool, or None when either variable is missing or
        construction fails. The reason is logged at CRITICAL level.
    """
    try:
        github_repo = os.getenv("GITHUB_REPOSITORY")
        github_token = os.getenv("GITHUB_TOKEN")
        if not github_repo or not github_token:
            raise ValueError("GITHUB_REPOSITORY and GITHUB_TOKEN environment variables are required.")
        github_tool = GitHubTool(repo=github_repo, token=github_token)
    except Exception as e:
        logging.critical(f"Failed to initialize GitHubTool: {e}")
        return None
    logging.info(f"Successfully initialized GitHubTool for repo: {github_repo}")
    return github_tool


def _open_collection(device):
    """
    Open (or create) the Chroma collection that embeds documents with a sentence-transformer.

    Args:
        device: Torch device string ("cuda" or "cpu") used for embedding inference.

    Returns:
        The persistent collection named CHROMA_COLLECTION_NAME, configured for
        cosine distance.
    """
    client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
    logging.info(f"Loading embedding model from: {EMBEDDING_MODEL_PATH}")
    sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=EMBEDDING_MODEL_PATH,
        device=device,
    )
    logging.info(f"Loading or creating Chroma collection: '{CHROMA_COLLECTION_NAME}'")
    return client.get_or_create_collection(
        name=CHROMA_COLLECTION_NAME,
        embedding_function=sentence_transformer_ef,
        metadata={"hnsw:space": "cosine"},
    )


def _flush_batch(collection, ids, documents, metadatas, device, final=False):
    """
    Upsert one batch of chunks into the collection, then free cached memory.

    An upsert failure is logged with its traceback instead of raised, so one bad
    batch does not abort the run. The chunks of a failed batch are not retried.

    Returns:
        bool: True if the upsert succeeded.
    """
    label = "final batch" if final else "batch"
    logging.info(f"Upserting {label} of {len(documents)} chunks...")
    succeeded = True
    try:
        collection.upsert(ids=ids, documents=documents, metadatas=metadatas)
    except Exception:
        logging.exception(f"Failed to upsert {label} of {len(documents)} chunks; they were not indexed.")
        succeeded = False

    # Force garbage collection and empty the CUDA cache after every batch,
    # whether or not the upsert succeeded.
    if not final:
        logging.info("Cleaning up memory...")
    gc.collect()
    if device == 'cuda':
        torch.cuda.empty_cache()

    if succeeded:
        logging.info("Final batch upserted." if final else "Batch upserted and memory cleared.")
    return succeeded


def main():
    """
    Index the configured GitHub repository into the local Chroma collection.

    Steps: initialize the GitHub tool, open the collection with its embedding
    model, list the indexable files, then read, chunk and upsert them in
    batches of CHUNK_PROCESSING_BATCH_SIZE chunks. Returns early, after
    logging, if the GitHub tool cannot be created or no files are found.

    Before a file's new chunks are queued, the chunks stored for it by earlier
    runs are deleted. Upsert only overwrites matching ids, so a file that now
    yields fewer chunks would otherwise keep stale trailing chunks.

    NOTE(review): files removed from the repository are not purged from the
    collection.
    """
    load_dotenv()
    logging.info("Starting repository indexing process...")

    # 1. Initialize GitHub Tool
    github_tool = _init_github_tool()
    if github_tool is None:
        return

    # 2. Initialize ChromaDB and Embedding Model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logging.info(f"Using device: {device} for embedding model inference.")
    collection = _open_collection(device)

    # 3. Fetch all file paths from the repository
    logging.info("Fetching all file paths from the repository...")
    file_paths = get_all_repo_files(github_tool)
    if not file_paths:
        logging.warning("No files found to index. Exiting.")
        return
    total_files = len(file_paths)
    logging.info(f"Found {total_files} files to potentially index.")

    # 4. Process files and upsert to ChromaDB in chunk-based batches
    batch_documents, batch_metadatas, batch_ids = [], [], []
    for file_number, file_path in enumerate(file_paths, start=1):
        logging.info(f"Processing file {file_number}/{total_files}: {file_path}")
        # Only per-file work lives in this try. Batch upserts are handled by
        # _flush_batch, so their failures are no longer blamed on this file.
        try:
            content = github_tool.execute("read_file", path=file_path)
            if not isinstance(content, str) or content.startswith("Error"):
                logging.warning(f"Could not read or empty content for {file_path}. Skipping.")
                continue
            chunks = split_text(content)
            # Remove this file's chunks from previous runs. Its old content is
            # unavailable until the pending batch holding its new chunks is upserted.
            collection.delete(where={"source": file_path})
        except Exception as e:
            logging.error(f"Error processing file {file_path}: {e}")
            continue

        for chunk_index, chunk in enumerate(chunks):
            batch_ids.append(f"{file_path}_{chunk_index}")
            batch_documents.append(chunk)
            batch_metadatas.append({"source": file_path, "chunk_index": chunk_index})

            if len(batch_documents) >= CHUNK_PROCESSING_BATCH_SIZE:
                _flush_batch(collection, batch_ids, batch_documents, batch_metadatas, device)
                # Always start a fresh batch. Previously a failed batch was kept
                # and kept growing, so every later upsert retried (and failed on) it.
                batch_documents, batch_metadatas, batch_ids = [], [], []

    # 5. Upsert any remaining documents after the loop finishes
    if batch_documents:
        _flush_batch(collection, batch_ids, batch_documents, batch_metadatas, device, final=True)

    logging.info("--- Indexing Complete ---")
    logging.info(f"Total documents in collection: {collection.count()}")
|
|
|
|
# Run the indexer only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
|