import os
import logging
import torch
import gc  # Import the garbage collection module
import chromadb
from chromadb.utils import embedding_functions
from dotenv import load_dotenv

# Import the GitHubTool from its location
# Assuming it's in a 'tools' directory as per your other script
from tools.github_tool import GitHubTool

# --- Configuration ---
# You can adjust these settings.
# NOTE: the environment-based values below are captured at import time, i.e.
# before main() calls load_dotenv(). main() therefore re-resolves them after
# loading .env, so values defined only in a .env file are honoured.

# If you have downloaded a model, provide the local path here.
# Otherwise, DEFAULT_EMBEDDING_MODEL will be downloaded from Hugging Face.
# Example: EMBEDDING_MODEL_PATH = "/path/to/your/models/all-MiniLM-L6-v2"
EMBEDDING_MODEL_PATH = os.environ.get("EMBEDDING_MODEL_PATH")
# Model name used when no EMBEDDING_MODEL_PATH is configured anywhere.
DEFAULT_EMBEDDING_MODEL = "all-MiniLM-L6-v2"

# Path to store the local vector database
CHROMA_DB_PATH = os.environ.get("CHROMA_DB_PATH")
# Directory used when no CHROMA_DB_PATH is configured anywhere.
DEFAULT_CHROMA_DB_PATH = "./chroma"

# Name of the collection within the database
CHROMA_COLLECTION_NAME = "github_repo"

# Files with these extensions will be indexed. Add any other text-based files you need.
# Excludes common binary/unwanted files.
INCLUDED_EXTENSIONS = ['.py', '.js', '.ts', '.md', '.txt', '.html', '.css', '.go', '.rs',
                       '.java', '.c', '.h', '.cpp', '.sh', '.yaml', '.json']

# --- Chunking configuration ---
CHUNK_SIZE = 1000  # The target size for each text chunk in characters
CHUNK_OVERLAP = 200  # The number of characters to overlap between chunks
CHUNK_PROCESSING_BATCH_SIZE = 100  # The number of chunks to upsert in a single batch

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


def get_all_repo_files(github_tool, repo_path=""):
    """
    Recursively fetch the paths of all indexable files in the GitHub repository.

    Args:
        github_tool: Tool whose ``execute("list_files", path=...)`` returns either a
            list of item dicts (read here via the ``'type'``, ``'name'`` and ``'path'``
            keys) or an error string starting with "Error".
        repo_path: Directory to list; "" is the repository root.

    Returns:
        Paths of files whose names end with one of INCLUDED_EXTENSIONS. If listing
        ``repo_path`` itself reports an error, returns []. On an unexpected exception,
        returns the paths collected so far.
    """
    all_files = []
    # str.endswith accepts a tuple, so one call checks every allowed extension.
    extensions = tuple(INCLUDED_EXTENSIONS)
    try:
        items = github_tool.execute("list_files", path=repo_path)
        if isinstance(items, str) and items.startswith("Error"):
            logging.error(f"Could not list files at path '{repo_path}': {items}")
            return []

        for item in items:
            # Only index files with allowed extensions
            if item['type'] == 'file' and item['name'].endswith(extensions):
                all_files.append(item['path'])
            elif item['type'] == 'dir':
                # It's a directory, so recurse into it
                all_files.extend(get_all_repo_files(github_tool, repo_path=item['path']))
        return all_files
    except Exception:
        logging.exception(f"An unexpected error occurred while listing files at '{repo_path}'")
        return all_files


def split_text(text: str, chunk_size: int = CHUNK_SIZE, chunk_overlap: int = CHUNK_OVERLAP) -> list[str]:
    """
    Split text into fixed-size, overlapping character chunks.

    Each chunk has at most ``chunk_size`` characters and starts
    ``chunk_size - chunk_overlap`` characters after the previous one. Splitting
    stops as soon as a chunk reaches the end of the text, so no trailing chunk is
    produced that is already fully contained in its predecessor.

    Args:
        text: The text to split. None or whitespace-only text yields [].
        chunk_size: Maximum chunk length in characters (must be > 0).
        chunk_overlap: Characters shared by consecutive chunks (0 <= overlap < size).

    Returns:
        The list of chunks, in order.

    Raises:
        ValueError: If chunk_size or chunk_overlap is out of range. Such values
            would otherwise make the loop never advance.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError(
            f"chunk_overlap must be in [0, chunk_size), got {chunk_overlap} with chunk_size {chunk_size}"
        )
    if text is None or not text.strip():
        return []

    step = chunk_size - chunk_overlap
    chunks = []
    start = 0
    while True:
        end = start + chunk_size
        chunks.append(text[start:end])
        if end >= len(text):
            break
        start += step
    return chunks


def _flush_batch(collection, ids, documents, metadatas, device, label="batch"):
    """
    Upsert one batch of chunks into the collection, then release memory.

    Returns True on success. On failure, logs the traceback and returns False
    instead of raising. The caller always discards the batch afterwards, so a
    failing batch is not carried forward and retried with ever more chunks.
    """
    logging.info(f"Upserting {label} of {len(documents)} chunks...")
    try:
        # Only documents are passed, so the collection's embedding function
        # computes the embeddings during this call.
        collection.upsert(ids=ids, documents=documents, metadatas=metadatas)
    except Exception:
        logging.exception(f"Failed to upsert {label}; {len(ids)} chunks were not indexed.")
        return False
    finally:
        # Force garbage collection and empty CUDA cache
        logging.info("Cleaning up memory...")
        gc.collect()
        if device == 'cuda':
            torch.cuda.empty_cache()
    return True


def main():
    """
    Initialize the vector database and index the configured GitHub repository.

    Steps:
      1. Build a GitHubTool from GITHUB_REPOSITORY / GITHUB_TOKEN (after loading .env).
      2. Open a persistent Chroma client and get or create CHROMA_COLLECTION_NAME,
         using a SentenceTransformer embedding function on CUDA if available.
      3. List all indexable files in the repository.
      4. Read each file, replace its previously stored chunks, and upsert the new
         chunks in batches of CHUNK_PROCESSING_BATCH_SIZE.
    Per-file and per-batch failures are logged and skipped; a GitHubTool
    initialization failure aborts the run.
    """
    load_dotenv()
    logging.info("Starting repository indexing process...")

    # 1. Initialize GitHub Tool
    try:
        github_repo = os.getenv("GITHUB_REPOSITORY")
        github_token = os.getenv("GITHUB_TOKEN")
        if not github_repo or not github_token:
            raise ValueError("GITHUB_REPOSITORY and GITHUB_TOKEN environment variables are required.")
        github_tool = GitHubTool(repo=github_repo, token=github_token)
        logging.info(f"Successfully initialized GitHubTool for repo: {github_repo}")
    except Exception as e:
        logging.critical(f"Failed to initialize GitHubTool: {e}")
        return

    # 2. Initialize ChromaDB and Embedding Model
    # Precedence: value set in code / process env at import time, then .env
    # (only visible after load_dotenv() above), then the built-in default.
    chroma_db_path = CHROMA_DB_PATH or os.getenv("CHROMA_DB_PATH") or DEFAULT_CHROMA_DB_PATH
    model_location = EMBEDDING_MODEL_PATH or os.getenv("EMBEDDING_MODEL_PATH") or DEFAULT_EMBEDDING_MODEL

    client = chromadb.PersistentClient(path=chroma_db_path)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    logging.info(f"Using device: {device} for embedding model inference.")

    logging.info(f"Loading embedding model from: {model_location}")
    sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=model_location,
        device=device,
    )

    logging.info(f"Loading or creating Chroma collection: '{CHROMA_COLLECTION_NAME}'")
    collection = client.get_or_create_collection(
        name=CHROMA_COLLECTION_NAME,
        embedding_function=sentence_transformer_ef,
        metadata={"hnsw:space": "cosine"},
    )

    # 3. Fetch all file paths from the repository
    logging.info("Fetching all file paths from the repository...")
    file_paths = get_all_repo_files(github_tool)
    if not file_paths:
        logging.warning("No files found to index. Exiting.")
        return
    logging.info(f"Found {len(file_paths)} files to potentially index.")

    # 4. Process files and upsert to ChromaDB in chunk-based batches
    batch_documents, batch_metadatas, batch_ids = [], [], []
    total_files = len(file_paths)
    for i, file_path in enumerate(file_paths, start=1):
        logging.info(f"Processing file {i}/{total_files}: {file_path}")
        try:
            content = github_tool.execute("read_file", path=file_path)
            # NOTE(review): a file whose real content begins with "Error" is also
            # treated as a read failure here; confirm how GitHubTool signals errors.
            if not isinstance(content, str) or content.startswith("Error"):
                logging.warning(f"Could not read or empty content for {file_path}. Skipping.")
                continue
            chunks = split_text(content)
            # Remove chunks stored for this file by earlier runs. Otherwise a file
            # that now yields fewer chunks would keep stale high-index chunks.
            collection.delete(where={"source": file_path})
        except Exception:
            logging.exception(f"Error processing file {file_path}")
            continue

        for chunk_index, chunk in enumerate(chunks):
            # Add the processed chunk to the current batch
            batch_ids.append(f"{file_path}_{chunk_index}")
            batch_documents.append(chunk)
            batch_metadatas.append({"source": file_path, "chunk_index": chunk_index})

            # If the batch reaches the desired size, upsert it to the database
            if len(batch_documents) >= CHUNK_PROCESSING_BATCH_SIZE:
                if _flush_batch(collection, batch_ids, batch_documents, batch_metadatas, device):
                    logging.info("Batch upserted and memory cleared.")
                # Always start a fresh batch, even after a failed upsert.
                batch_documents, batch_metadatas, batch_ids = [], [], []

    # 5. Upsert any remaining documents after the loop finishes
    if batch_documents:
        if _flush_batch(collection, batch_ids, batch_documents, batch_metadatas, device,
                        label="final batch"):
            logging.info("Final batch upserted.")

    logging.info("--- Indexing Complete ---")
    logging.info(f"Total documents in collection: {collection.count()}")


if __name__ == '__main__':
    main()