Files
cyclop/create_index.py
2025-08-07 15:38:01 -05:00

186 lines
7.3 KiB
Python

import os
import logging
import torch
import gc # Import the garbage collection module
import chromadb
from chromadb.utils import embedding_functions
from dotenv import load_dotenv
# Import the GitHubTool from its location
# Assuming it's in a 'tools' directory as per your other script
from tools.github_tool import GitHubTool
# --- Configuration ---
# These settings are evaluated once, when the module is imported. Because
# load_dotenv() is only called later (inside main()), values that exist solely
# in a .env file are not visible to the os.getenv() calls below.

# Local path of an already-downloaded embedding model. If no local copy is
# provided, the model is meant to come from Hugging Face instead.
# Example: "/path/to/your/models/all-MiniLM-L6-v2"
EMBEDDING_MODEL_PATH = os.getenv("EMBEDDING_MODEL_PATH")

# Directory that backs the persistent local vector database.
CHROMA_DB_PATH = os.getenv("CHROMA_DB_PATH")

# Collection inside that database which holds the repository chunks.
CHROMA_COLLECTION_NAME = "github_repo"

# Only files whose names end with one of these suffixes get indexed; anything
# else (binaries, images, ...) is skipped. Extend as needed for other text types.
INCLUDED_EXTENSIONS = [
    '.py', '.js', '.ts', '.md', '.txt', '.html', '.css', '.go',
    '.rs', '.java', '.c', '.h', '.cpp', '.sh', '.yaml', '.json',
]

# Chunking: window length and overlap are measured in characters; the batch
# size is how many chunks are sent to the database per upsert call.
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 200
CHUNK_PROCESSING_BATCH_SIZE = 100

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
def get_all_repo_files(github_tool, repo_path=""):
    """
    Collect repository file paths depth-first, starting at ``repo_path``.

    Entries whose ``type`` is ``'dir'`` are descended into recursively. Entries
    whose ``type`` is ``'file'`` are kept only when their ``name`` ends with one
    of INCLUDED_EXTENSIONS. Entries of any other type are ignored.

    Returns:
        The matching ``path`` values, in listing order. If the tool answers with
        a string starting with "Error", that is logged and ``[]`` is returned for
        this level. Any other exception is logged, and the paths gathered so far
        at this level are returned.
    """
    matched_paths = []
    try:
        listing = github_tool.execute("list_files", path=repo_path)
        # The tool signals failure by returning an "Error..." string rather than raising.
        if isinstance(listing, str) and listing.startswith("Error"):
            logging.error(f"Could not list files at path '{repo_path}': {listing}")
            return []
        for entry in listing:
            if entry['type'] == 'dir':
                # Sub-directory: merge in whatever matches beneath it.
                matched_paths += get_all_repo_files(github_tool, repo_path=entry['path'])
            elif entry['type'] == 'file' and any(
                entry['name'].endswith(suffix) for suffix in INCLUDED_EXTENSIONS
            ):
                matched_paths.append(entry['path'])
    except Exception as err:
        logging.error(f"An unexpected error occurred while listing files at '{repo_path}': {err}")
    return matched_paths
def split_text(text: str, chunk_size=None, chunk_overlap=None) -> list[str]:
    """
    Split text into fixed-size character windows that overlap by a fixed amount.

    Args:
        text: The text to split. ``None`` or whitespace-only text yields no chunks.
        chunk_size: Maximum number of characters per chunk. Defaults to the
            module-level CHUNK_SIZE when omitted.
        chunk_overlap: Number of characters shared by consecutive chunks.
            Defaults to the module-level CHUNK_OVERLAP when omitted.

    Returns:
        The chunks in order. Every chunk except possibly the last is exactly
        ``chunk_size`` characters long, and the last chunk always ends at the
        end of ``text``. No chunk is emitted that lies entirely inside the
        previous one.

    Raises:
        ValueError: If ``chunk_size`` is not positive, or ``chunk_overlap`` is
            not in the range ``[0, chunk_size)``. Such a configuration would
            make the window never advance (an infinite loop).
    """
    if chunk_size is None:
        chunk_size = CHUNK_SIZE
    if chunk_overlap is None:
        chunk_overlap = CHUNK_OVERLAP
    if chunk_size <= 0 or not 0 <= chunk_overlap < chunk_size:
        raise ValueError(
            f"Invalid chunking config: chunk_size={chunk_size} must be positive and "
            f"chunk_overlap={chunk_overlap} must satisfy 0 <= overlap < chunk_size."
        )
    if text is None or not text.strip():
        return []

    step = chunk_size - chunk_overlap
    chunks = []
    start = 0
    while True:
        end = start + chunk_size
        chunks.append(text[start:end])
        # Once a window reaches the end of the text, any further window would
        # only repeat characters already covered by this one.
        if end >= len(text):
            break
        start += step
    return chunks
# Fallbacks used when a setting is missing both at import time and after .env is loaded.
# The model name matches the example in the configuration comments; "./chroma" is the
# directory chromadb's PersistentClient uses when no path is given.
_DEFAULT_EMBEDDING_MODEL = "all-MiniLM-L6-v2"
_DEFAULT_CHROMA_DB_PATH = "./chroma"


def _resolve_setting(env_name, import_time_value, default):
    """Return the import-time value, else the (post-.env) environment value, else ``default``."""
    # The import-time value wins so existing setups (and code that overrides the
    # module constant) behave as before; os.getenv() here sees .env values, which
    # the module-level read could not.
    value = import_time_value or os.getenv(env_name)
    if value:
        return value
    logging.warning(f"{env_name} is not set; falling back to '{default}'.")
    return default


def _init_github_tool():
    """Create a GitHubTool from GITHUB_REPOSITORY / GITHUB_TOKEN; return None (after logging) on failure."""
    try:
        github_repo = os.getenv("GITHUB_REPOSITORY")
        github_token = os.getenv("GITHUB_TOKEN")
        if not github_repo or not github_token:
            raise ValueError("GITHUB_REPOSITORY and GITHUB_TOKEN environment variables are required.")
        github_tool = GitHubTool(repo=github_repo, token=github_token)
    except Exception as e:
        logging.critical(f"Failed to initialize GitHubTool: {e}")
        return None
    logging.info(f"Successfully initialized GitHubTool for repo: {github_repo}")
    return github_tool


def _release_memory(device):
    """Run the garbage collector and, when running on CUDA, release cached GPU memory."""
    gc.collect()
    if device == 'cuda':
        torch.cuda.empty_cache()


def main():
    """
    Index the configured GitHub repository into a persistent Chroma collection.

    Steps:
      1. Load .env and build a GitHubTool (returns early if that fails).
      2. Open the Chroma database and a SentenceTransformer embedding function,
         on CUDA when available.
      3. List every repository file with an indexed extension.
      4. Read each file, split it into overlapping chunks, and upsert the chunks
         in batches of CHUNK_PROCESSING_BATCH_SIZE. Ids have the form
         "<path>_<chunk_index>"; metadata holds the source path and chunk index.

    Errors while processing an individual file are logged, and that file is
    skipped; the run continues with the next file.
    """
    load_dotenv()
    logging.info("Starting repository indexing process...")

    # 1. Initialize GitHub Tool
    github_tool = _init_github_tool()
    if github_tool is None:
        return

    # 2. Initialize ChromaDB and Embedding Model.
    # The module-level settings were read at import time, before load_dotenv()
    # above ran, so resolve them again now that .env values are in os.environ.
    db_path = _resolve_setting("CHROMA_DB_PATH", CHROMA_DB_PATH, _DEFAULT_CHROMA_DB_PATH)
    model_location = _resolve_setting("EMBEDDING_MODEL_PATH", EMBEDDING_MODEL_PATH, _DEFAULT_EMBEDDING_MODEL)

    client = chromadb.PersistentClient(path=db_path)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logging.info(f"Using device: {device} for embedding model inference.")
    logging.info(f"Loading embedding model from: {model_location}")
    sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=model_location,
        device=device,
    )
    logging.info(f"Loading or creating Chroma collection: '{CHROMA_COLLECTION_NAME}'")
    collection = client.get_or_create_collection(
        name=CHROMA_COLLECTION_NAME,
        embedding_function=sentence_transformer_ef,
        metadata={"hnsw:space": "cosine"},
    )

    # 3. Fetch all file paths from the repository
    logging.info("Fetching all file paths from the repository...")
    file_paths = get_all_repo_files(github_tool)
    if not file_paths:
        logging.warning("No files found to index. Exiting.")
        return
    total_files = len(file_paths)
    logging.info(f"Found {total_files} files to potentially index.")

    # 4. Process files and upsert to ChromaDB in chunk-based batches
    batch_documents, batch_metadatas, batch_ids = [], [], []
    for file_number, file_path in enumerate(file_paths, start=1):
        logging.info(f"Processing file {file_number}/{total_files}: {file_path}")
        try:
            content = github_tool.execute("read_file", path=file_path)
            # NOTE(review): the tool reports failures as "Error..." strings, so a
            # real file whose text happens to start with "Error" is skipped too.
            if not isinstance(content, str) or content.startswith("Error"):
                logging.warning(f"Could not read or empty content for {file_path}. Skipping.")
                continue

            # Remove this file's chunks from any previous run first. Upserting by
            # id alone would leave stale "<path>_<n>" entries behind whenever a
            # file now produces fewer chunks than it did before.
            collection.delete(where={"source": file_path})

            for chunk_index, chunk in enumerate(split_text(content)):
                batch_ids.append(f"{file_path}_{chunk_index}")
                batch_documents.append(chunk)
                batch_metadatas.append({"source": file_path, "chunk_index": chunk_index})

                if len(batch_documents) >= CHUNK_PROCESSING_BATCH_SIZE:
                    logging.info(f"Upserting batch of {len(batch_documents)} chunks...")
                    collection.upsert(
                        ids=batch_ids,
                        documents=batch_documents,
                        metadatas=batch_metadatas,
                    )
                    # Drop the references before collecting, so the old lists can be freed.
                    batch_documents, batch_metadatas, batch_ids = [], [], []
                    logging.info("Cleaning up memory...")
                    _release_memory(device)
                    logging.info("Batch upserted and memory cleared.")
        except Exception as e:
            # A failed upsert keeps the pending batch, so it is retried with the
            # next flush; logging.exception also records the traceback.
            logging.exception(f"Error processing file {file_path}: {e}")

    # 5. Upsert any remaining documents after the loop finishes
    if batch_documents:
        logging.info(f"Upserting final batch of {len(batch_documents)} chunks...")
        collection.upsert(
            ids=batch_ids,
            documents=batch_documents,
            metadatas=batch_metadatas,
        )
        _release_memory(device)
        logging.info("Final batch upserted.")

    logging.info("--- Indexing Complete ---")
    logging.info(f"Total documents in collection: {collection.count()}")


if __name__ == '__main__':
    main()