Fixed github_tool.py, added RAG bot
This commit is contained in:
@@ -0,0 +1,42 @@
|
||||
# .github/workflows/reindex_on_merge.yml
|
||||
|
||||
name: Re-index Repository on Merge (Self-Hosted)
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [closed]
|
||||
branches:
|
||||
- main
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
reindex:
|
||||
# This condition ensures the job only runs if the pull request was actually merged.
|
||||
if: github.event.pull_request.merged == true || github.event_name == 'workflow_dispatch'
|
||||
|
||||
# *** KEY CHANGE ***
|
||||
# This tells GitHub to run this job on one of your self-hosted runners.
|
||||
# You can also add labels to target specific servers, e.g., [self-hosted, linux, x64, my-app]
|
||||
runs-on: inference-server
|
||||
|
||||
steps:
|
||||
# Step 1: Check out the repository's code
|
||||
# This downloads the latest version of your 'main' branch into the runner's working directory.
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
# Step 2: Run the indexing script
|
||||
# This executes your 'create_index.py' script using the Python environment on your server.
|
||||
# It assumes Python and all dependencies from requirements.txt are already installed on the server.
|
||||
# The GITHUB_TOKEN is still passed securely to the script.
|
||||
- name: Run indexing script
|
||||
run: python create_index.py
|
||||
env:
|
||||
GITHUB_REPOSITORY: ${{ github.repository }}
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
# Optional: Specify the working directory if your bot lives in a subfolder
|
||||
# working-directory: ./path/to/your/bot
|
||||
|
||||
# The "Upload database artifact" step is no longer needed, as the database
|
||||
# is now being written directly to a persistent location on your server.
|
||||
|
||||
+185
@@ -0,0 +1,185 @@
|
||||
import os
|
||||
import logging
|
||||
import torch
|
||||
import gc # Import the garbage collection module
|
||||
import chromadb
|
||||
from chromadb.utils import embedding_functions
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Import the GitHubTool from its location
|
||||
# Assuming it's in a 'tools' directory as per your other script
|
||||
from tools.github_tool import GitHubTool
|
||||
|
||||
# --- Configuration ---
|
||||
# You can adjust these settings
|
||||
|
||||
# If you have downloaded a model, provide the local path here.
|
||||
# Otherwise, the model will be downloaded from Hugging Face.
|
||||
# Example: EMBEDDING_MODEL_PATH = "/path/to/your/models/all-MiniLM-L6-v2"
|
||||
EMBEDDING_MODEL_PATH = """C:\Models\embeddings\Qwen3-Embedding-0.6B"""
|
||||
|
||||
# Path to store the local vector database
|
||||
CHROMA_DB_PATH = "C:\Models\embeddings\embedding_result\chroma_db"
|
||||
# Name of the collection within the database
|
||||
CHROMA_COLLECTION_NAME = "github_repo"
|
||||
# Files with these extensions will be indexed. Add any other text-based files you need.
|
||||
# Excludes common binary/unwanted files.
|
||||
INCLUDED_EXTENSIONS = ['.py', '.js', '.ts', '.md', '.txt', '.html', '.css', '.go', '.rs', '.java', '.c', '.h', '.cpp', '.sh', '.yaml', '.json']
|
||||
|
||||
# *** NEW: Intelligent Chunking Configuration ***
|
||||
CHUNK_SIZE = 1000 # The target size for each text chunk in characters
|
||||
CHUNK_OVERLAP = 200 # The number of characters to overlap between chunks
|
||||
CHUNK_PROCESSING_BATCH_SIZE = 100 # The number of chunks to process in a single batch
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
def get_all_repo_files(github_tool, repo_path=""):
|
||||
"""
|
||||
Recursively fetches all file paths from the GitHub repository.
|
||||
"""
|
||||
all_files = []
|
||||
try:
|
||||
items = github_tool.execute("list_files", path=repo_path)
|
||||
if isinstance(items, str) and items.startswith("Error"):
|
||||
logging.error(f"Could not list files at path '{repo_path}': {items}")
|
||||
return []
|
||||
|
||||
for item in items:
|
||||
# Only index files with allowed extensions
|
||||
if item['type'] == 'file' and any(item['name'].endswith(ext) for ext in INCLUDED_EXTENSIONS):
|
||||
all_files.append(item['path'])
|
||||
elif item['type'] == 'dir':
|
||||
# It's a directory, so recurse into it
|
||||
all_files.extend(get_all_repo_files(github_tool, repo_path=item['path']))
|
||||
|
||||
return all_files
|
||||
except Exception as e:
|
||||
logging.error(f"An unexpected error occurred while listing files at '{repo_path}': {e}")
|
||||
return all_files
|
||||
|
||||
def split_text(text: str) -> list[str]:
|
||||
"""
|
||||
Splits text into chunks of a specified size with overlap.
|
||||
"""
|
||||
chunks = []
|
||||
if text is None or not text.strip():
|
||||
return []
|
||||
|
||||
start = 0
|
||||
while start < len(text):
|
||||
end = start + CHUNK_SIZE
|
||||
chunks.append(text[start:end])
|
||||
start += CHUNK_SIZE - CHUNK_OVERLAP
|
||||
return chunks
|
||||
|
||||
def main():
|
||||
"""
|
||||
Main function to initialize the database and index the GitHub repository.
|
||||
"""
|
||||
load_dotenv()
|
||||
logging.info("Starting repository indexing process...")
|
||||
|
||||
# 1. Initialize GitHub Tool
|
||||
try:
|
||||
github_repo = os.getenv("GITHUB_REPOSITORY")
|
||||
github_token = os.getenv("GITHUB_TOKEN")
|
||||
if not github_repo or not github_token:
|
||||
raise ValueError("GITHUB_REPOSITORY and GITHUB_TOKEN environment variables are required.")
|
||||
|
||||
github_tool = GitHubTool(repo=github_repo, token=github_token)
|
||||
logging.info(f"Successfully initialized GitHubTool for repo: {github_repo}")
|
||||
except Exception as e:
|
||||
logging.fatal(f"Failed to initialize GitHubTool: {e}")
|
||||
return
|
||||
|
||||
# 2. Initialize ChromaDB and Embedding Model
|
||||
client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
logging.info(f"Using device: {device} for embedding model inference.")
|
||||
model_location = EMBEDDING_MODEL_PATH
|
||||
logging.info(f"Loading embedding model from: {model_location}")
|
||||
|
||||
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||
model_name=model_location,
|
||||
device=device
|
||||
)
|
||||
|
||||
logging.info(f"Loading or creating Chroma collection: '{CHROMA_COLLECTION_NAME}'")
|
||||
collection = client.get_or_create_collection(
|
||||
name=CHROMA_COLLECTION_NAME,
|
||||
embedding_function=sentence_transformer_ef,
|
||||
metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
|
||||
# 3. Fetch all file paths from the repository
|
||||
logging.info("Fetching all file paths from the repository...")
|
||||
file_paths = get_all_repo_files(github_tool)
|
||||
if not file_paths:
|
||||
logging.warning("No files found to index. Exiting.")
|
||||
return
|
||||
logging.info(f"Found {len(file_paths)} files to potentially index.")
|
||||
|
||||
# 4. Process files and upsert to ChromaDB in chunk-based batches
|
||||
|
||||
batch_documents = []
|
||||
batch_metadatas = []
|
||||
batch_ids = []
|
||||
|
||||
for i, file_path in enumerate(file_paths):
|
||||
logging.info(f"Processing file {i+1}/{len(file_paths)}: {file_path}")
|
||||
try:
|
||||
content = github_tool.execute("read_file", path=file_path)
|
||||
if not isinstance(content, str) or content.startswith("Error"):
|
||||
logging.warning(f"Could not read or empty content for {file_path}. Skipping.")
|
||||
continue
|
||||
|
||||
# *** USE THE NEW, ROBUST CHUNKING METHOD ***
|
||||
chunks = split_text(content)
|
||||
|
||||
for chunk_index, chunk in enumerate(chunks):
|
||||
# Add the processed chunk to the current batch
|
||||
unique_id = f"{file_path}_{chunk_index}"
|
||||
batch_documents.append(chunk)
|
||||
batch_metadatas.append({"source": file_path, "chunk_index": chunk_index})
|
||||
batch_ids.append(unique_id)
|
||||
|
||||
# If the batch reaches the desired size, upsert it to the database
|
||||
if len(batch_documents) >= CHUNK_PROCESSING_BATCH_SIZE:
|
||||
logging.info(f"Upserting batch of {len(batch_documents)} chunks...")
|
||||
collection.upsert(
|
||||
ids=batch_ids,
|
||||
documents=batch_documents,
|
||||
metadatas=batch_metadatas
|
||||
)
|
||||
# Clear the batch lists to free up memory
|
||||
batch_documents, batch_metadatas, batch_ids = [], [], []
|
||||
|
||||
# Force garbage collection and empty CUDA cache
|
||||
logging.info("Cleaning up memory...")
|
||||
gc.collect()
|
||||
if device == 'cuda':
|
||||
torch.cuda.empty_cache()
|
||||
logging.info("Batch upserted and memory cleared.")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing file {file_path}: {e}")
|
||||
|
||||
# 5. Upsert any remaining documents after the loop finishes
|
||||
if batch_documents:
|
||||
logging.info(f"Upserting final batch of {len(batch_documents)} chunks...")
|
||||
collection.upsert(
|
||||
ids=batch_ids,
|
||||
documents=batch_documents,
|
||||
metadatas=batch_metadatas
|
||||
)
|
||||
# Final cleanup
|
||||
gc.collect()
|
||||
if device == 'cuda':
|
||||
torch.cuda.empty_cache()
|
||||
logging.info("Final batch upserted.")
|
||||
|
||||
logging.info("--- Indexing Complete ---")
|
||||
logging.info(f"Total documents in collection: {collection.count()}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -0,0 +1,219 @@
|
||||
import logging
|
||||
import chromadb
|
||||
from chromadb.utils import embedding_functions
|
||||
from inference_bot import InferenceBot # Correctly inherit from the ABC
|
||||
from FlagEmbedding import FlagReranker
|
||||
import argparse
|
||||
import os
|
||||
import importlib
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
|
||||
# --- RAG Configuration ---
|
||||
# Must match the settings in create_index.py
|
||||
EMBEDDING_MODEL_NAME = """C:\Models\embeddings\Qwen3-Embedding-0.6B"""
|
||||
CHROMA_DB_PATH = "C:\Models\embeddings\embedding_result\chroma_db"
|
||||
CHROMA_COLLECTION_NAME = "github_repo"
|
||||
|
||||
# Using a powerful open-source reranker model
|
||||
RERANKER_MODEL_NAME = """C:\Models\embeddings\Qwen3-Reranker-0.6B"""
|
||||
|
||||
# Number of initial results to fetch from the database before reranking
|
||||
N_RESULTS_TO_RETRIEVE = 25
|
||||
# Number of final results to keep after reranking
|
||||
N_RESULTS_TO_KEEP_AFTER_RERANK = 5
|
||||
# The minimum relevance score a result must have to be included in the final context
|
||||
RERANKER_SCORE_THRESHOLD = 0.5
|
||||
|
||||
class RAGInferenceBot(InferenceBot):
|
||||
def __init__(self):
|
||||
"""
|
||||
Initializes the RAG components, including a custom implementation
|
||||
for the Qwen3-Reranker based on its model card.
|
||||
"""
|
||||
logging.info("Initializing RAG components...")
|
||||
self._processing_status = {}
|
||||
try:
|
||||
# --- Embedding and Vector DB Initialization ---
|
||||
self.chroma_client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
|
||||
self.embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||
model_name=EMBEDDING_MODEL_NAME
|
||||
)
|
||||
self.collection = self.chroma_client.get_collection(
|
||||
name=CHROMA_COLLECTION_NAME,
|
||||
embedding_function=self.embedding_function
|
||||
)
|
||||
logging.info("Successfully connected to ChromaDB collection for RAG.")
|
||||
|
||||
# --- Custom Reranker Initialization (as per model card) ---
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
logging.info(f"Using device: {self.device} for Reranker model.")
|
||||
|
||||
self.rerank_tokenizer = AutoTokenizer.from_pretrained(RERANKER_MODEL_NAME, padding_side='left')
|
||||
self.rerank_model = AutoModelForCausalLM.from_pretrained(RERANKER_MODEL_NAME, torch_dtype=torch.float16).to(self.device).eval()
|
||||
|
||||
# Manually set the padding token if it's missing
|
||||
if self.rerank_tokenizer.pad_token is None:
|
||||
logging.warning("Reranker tokenizer has no pad_token. Setting it to the eos_token.")
|
||||
self.rerank_tokenizer.pad_token = self.rerank_tokenizer.eos_token
|
||||
|
||||
# Get token IDs for score calculation
|
||||
self.token_false_id = self.rerank_tokenizer.convert_tokens_to_ids("no")
|
||||
self.token_true_id = self.rerank_tokenizer.convert_tokens_to_ids("yes")
|
||||
self.max_length = 8192
|
||||
|
||||
# Define and pre-encode the special prefixes and suffixes from the model card
|
||||
prefix_text = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
|
||||
suffix_text = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
|
||||
self.prefix_tokens = self.rerank_tokenizer.encode(prefix_text, add_special_tokens=False)
|
||||
self.suffix_tokens = self.rerank_tokenizer.encode(suffix_text, add_special_tokens=False)
|
||||
|
||||
logging.info(f"Successfully initialized custom Reranker: {RERANKER_MODEL_NAME}")
|
||||
|
||||
except Exception as e:
|
||||
logging.fatal(f"Failed to initialize RAG components: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
# --- Implementation of all abstract methods from InferenceBot ---
|
||||
|
||||
@property
|
||||
def processing_status(self):
|
||||
return self._processing_status
|
||||
|
||||
def clear_conversation_history(self, user_id):
|
||||
pass
|
||||
|
||||
async def switch_model(self):
|
||||
return "This bot only performs RAG lookups and has no swappable models."
|
||||
|
||||
def set_processing_status(self, user_id, message_id):
|
||||
self._processing_status[user_id] = {"processing": True, "message_id": message_id}
|
||||
|
||||
def clear_processing_status(self, user_id):
|
||||
if user_id in self._processing_status:
|
||||
del self._processing_status[user_id]
|
||||
|
||||
async def abort_processing(self, user_id):
|
||||
if user_id in self.processing_status:
|
||||
self.clear_processing_status(user_id)
|
||||
return "Processing aborted."
|
||||
else:
|
||||
return "No active processing found to abort."
|
||||
|
||||
def get_bot_status(self):
|
||||
return f"RAG Bot is active.\nEmbedding Model: {os.path.basename(EMBEDDING_MODEL_NAME)}\nReranker Model: {os.path.basename(RERANKER_MODEL_NAME)}"
|
||||
|
||||
async def start(self):
|
||||
logging.info(f"{self.__class__.__name__} started.")
|
||||
|
||||
# --- Core RAG Logic ---
|
||||
|
||||
def _format_rerank_instruction(self, query, doc):
|
||||
# Using the default instruction from the model card
|
||||
instruction = 'Given a web search query, retrieve relevant passages that answer the query'
|
||||
return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"
|
||||
|
||||
@torch.no_grad()
|
||||
def _compute_rerank_scores(self, pairs: list[list[str, str]]):
|
||||
"""
|
||||
Custom score computation logic that follows the model card's example.
|
||||
"""
|
||||
# Format all pairs with the required instruction format
|
||||
formatted_pairs = [self._format_rerank_instruction(query, doc) for query, doc in pairs]
|
||||
|
||||
# Tokenize the formatted pairs
|
||||
inputs = self.rerank_tokenizer(
|
||||
formatted_pairs, padding=False, truncation='longest_first',
|
||||
return_attention_mask=False, max_length=self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
|
||||
)
|
||||
|
||||
# Add special prefix and suffix tokens to each item in the batch
|
||||
for i in range(len(inputs['input_ids'])):
|
||||
inputs['input_ids'][i] = self.prefix_tokens + inputs['input_ids'][i] + self.suffix_tokens
|
||||
|
||||
# Pad the batch to the same length
|
||||
inputs = self.rerank_tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=self.max_length)
|
||||
inputs = {key: inputs[key].to(self.device) for key in inputs}
|
||||
|
||||
# Get model outputs (logits)
|
||||
batch_scores = self.rerank_model(**inputs).logits[:, -1, :]
|
||||
|
||||
# Calculate scores based on the probability of "yes" vs "no"
|
||||
true_vector = batch_scores[:, self.token_true_id]
|
||||
false_vector = batch_scores[:, self.token_false_id]
|
||||
|
||||
scores = torch.stack([false_vector, true_vector], dim=1)
|
||||
scores = torch.nn.functional.log_softmax(scores, dim=1)
|
||||
|
||||
final_scores = scores[:, 1].exp().tolist()
|
||||
return final_scores
|
||||
|
||||
def _retrieve_and_rerank_context(self, query: str):
|
||||
logging.info(f"RAG: Retrieving context for query: '{query}'")
|
||||
if not query: return ""
|
||||
try:
|
||||
results = self.collection.query(query_texts=[query], n_results=N_RESULTS_TO_RETRIEVE)
|
||||
initial_docs = results.get('documents', [[]])[0]
|
||||
if not initial_docs:
|
||||
logging.info("RAG: No initial documents found in vector search.")
|
||||
return "No relevant context found in the knowledge base."
|
||||
|
||||
logging.info(f"RAG: Retrieved {len(initial_docs)} initial documents from ChromaDB.")
|
||||
|
||||
# Create pairs of [query, document] for reranking
|
||||
rerank_pairs = [[query, doc] for doc in initial_docs]
|
||||
|
||||
# Use our custom scoring function
|
||||
scores = self._compute_rerank_scores(rerank_pairs)
|
||||
|
||||
scored_docs = sorted(zip(scores, initial_docs, results['metadatas'][0]), key=lambda x: x[0], reverse=True)
|
||||
|
||||
# Take the top N results
|
||||
top_k_docs = scored_docs[:N_RESULTS_TO_KEEP_AFTER_RERANK]
|
||||
|
||||
# *** NEW: Filter the top N results by the score threshold ***
|
||||
final_docs = [doc for doc in top_k_docs if doc[0] >= RERANKER_SCORE_THRESHOLD]
|
||||
|
||||
logging.info(f"RAG: Reranked and filtered down to {len(final_docs)} documents with score >= {RERANKER_SCORE_THRESHOLD}.")
|
||||
|
||||
# If no documents meet the threshold, inform the user.
|
||||
if not final_docs:
|
||||
return "No highly relevant context found after filtering."
|
||||
|
||||
context_lines = []
|
||||
for i, (score, doc, metadata) in enumerate(final_docs):
|
||||
source = metadata.get('source', 'Unknown file')
|
||||
chunk_index = metadata.get('chunk_index', 'N/A')
|
||||
context_lines.append(f"--- Context {i+1} (Relevance: {score:.2f}) ---\nSource: {source}\n\n{doc}\n")
|
||||
|
||||
return "\n".join(context_lines)
|
||||
except Exception as e:
|
||||
logging.error(f"RAG: Error during context retrieval/reranking: {e}", exc_info=True)
|
||||
return "An error occurred while searching the knowledge base."
|
||||
|
||||
async def handle_message(self, user_id, user_message):
|
||||
context_response = self._retrieve_and_rerank_context(user_message)
|
||||
return context_response
|
||||
|
||||
|
||||
def main_rag():
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
|
||||
parser = argparse.ArgumentParser(description='RAG-Only Inference Bot')
|
||||
parser.add_argument('--messenger', type=str, help='Messenger type (i.e. telegram)', required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
bot = RAGInferenceBot()
|
||||
|
||||
full_code_file = importlib.import_module(f'{args.messenger.lower()}_helper')
|
||||
helper_class = getattr(full_code_file, f"{args.messenger.capitalize()}Helper")
|
||||
helper = helper_class(bot)
|
||||
helper.run()
|
||||
|
||||
except Exception as e:
|
||||
logging.fatal(f"An unexpected error occurred during bot initialization: {e}", exc_info=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main_rag()
|
||||
@@ -10,3 +10,9 @@ pytest-cov
|
||||
google-genai
|
||||
httpx==0.27.2
|
||||
tiktoken
|
||||
chromadb
|
||||
sentence-transformers
|
||||
transformers>=4.38
|
||||
torch --index-url https://download.pytorch.org/whl/cu121
|
||||
torchvision --index-url https://download.pytorch.org/whl/cu121
|
||||
torchaudio --index-url https://download.pytorch.org/whl/cu121
|
||||
+1275
-4
File diff suppressed because one or more lines are too long
@@ -3,6 +3,7 @@ import os
|
||||
import json
|
||||
import logging
|
||||
from openai import OpenAI
|
||||
import re
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
|
||||
@@ -82,10 +83,12 @@ class StandaloneLLMTool(BaseTool):
|
||||
headers={'Content-Type': 'text/plain; charset=utf-8', 'User-Agent': 'DualAICopilot/0.1'},
|
||||
method='POST'
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=500) as response:
|
||||
with urllib.request.urlopen(req, timeout=3600) as response:
|
||||
if response.status == 200:
|
||||
response_data = response.read().decode('utf-8')
|
||||
logging.info(f"Received response from external copilot: {response_data[:100]}...")
|
||||
# Remove content within <think> tags
|
||||
response_data = re.sub(r"<think>.*?</think>", "", response_data, flags=re.DOTALL)
|
||||
return response_data
|
||||
else:
|
||||
error_message = f"External copilot at {self.copilot_url} returned an error: {response.status} {response.reason}"
|
||||
|
||||
Reference in New Issue
Block a user