Chapter 3: RAG Implementation Guide
This chapter provides a comprehensive, hands-on guide to implementing RAG systems using the patterns introduced in the earlier chapters. We'll build from basic components to production-ready systems.
Document Processing Implementation
Document Loaders
Let's start by implementing the document loading system from the retrieval.md specification: a simple Document container plus pluggable loaders for common file types.
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Iterator, Union, Callable
from pathlib import Path
import hashlib
import json
import logging
import mimetypes
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class Document:
"""Document class to represent loaded documents"""
def __init__(
self,
page_content: str,
metadata: Optional[Dict[str, Any]] = None
):
self.page_content = page_content
self.metadata = metadata or {}
def __repr__(self) -> str:
return f"Document(page_content='{self.page_content[:50]}...', metadata={self.metadata})"
class BaseLoader(ABC):
"""Abstract base class for document loaders"""
@abstractmethod
def load(self) -> List[Document]:
"""Load documents and return a list of Document objects."""
pass
def lazy_load(self) -> Iterator[Document]:
"""Lazy load documents one at a time."""
for doc in self.load():
yield doc
class TextLoader(BaseLoader):
"""Load text files with automatic encoding detection"""
def __init__(
self,
file_path: Union[str, Path],
encoding: Optional[str] = None,
autodetect_encoding: bool = True
):
self.file_path = Path(file_path)
self.encoding = encoding
self.autodetect_encoding = autodetect_encoding
def load(self) -> List[Document]:
"""Load text file with encoding detection"""
try:
encoding = self._detect_encoding() if self.autodetect_encoding else self.encoding
with open(self.file_path, 'r', encoding=encoding) as f:
text = f.read()
metadata = self._extract_metadata(encoding)
return [Document(page_content=text, metadata=metadata)]
except Exception as e:
logger.error(f"Error loading {self.file_path}: {e}")
raise
def _detect_encoding(self) -> str:
"""Detect file encoding"""
if self.encoding:
return self.encoding
try:
import chardet
with open(self.file_path, 'rb') as f:
raw_data = f.read()
result = chardet.detect(raw_data)
return result['encoding'] or 'utf-8'
except ImportError:
return 'utf-8'
def _extract_metadata(self, encoding: str) -> Dict[str, Any]:
"""Extract file metadata"""
stat = self.file_path.stat()
return {
'source': str(self.file_path),
'file_name': self.file_path.name,
'file_size': stat.st_size,
'creation_date': stat.st_ctime,
'modification_date': stat.st_mtime,
'encoding': encoding,
'file_type': 'text'
}
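
# Illustrative usage sketch (not part of the core implementation). The file name
# "notes.txt" is a placeholder and assumed to exist.
def _example_text_loader() -> List[Document]:
    loader = TextLoader("notes.txt")  # encoding is auto-detected via chardet when available
    docs = loader.load()              # a single Document carrying file metadata
    logger.info("Loaded %s (%d chars)", docs[0].metadata["file_name"], len(docs[0].page_content))
    return docs
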
class DirectoryLoader(BaseLoader):
"""Load all documents from a directory with flexible file type support"""
def __init__(
self,
path: Union[str, Path],
glob: str = "**/*",
exclude: Optional[List[str]] = None,
        loader_cls_mapping: Optional[Dict[str, Callable[..., BaseLoader]]] = None,  # maps file extensions to loader classes/factories
loader_kwargs: Optional[Dict[str, Dict]] = None,
recursive: bool = True,
show_progress: bool = False,
use_multithreading: bool = False,
max_concurrency: int = 4,
sample_size: int = 0,
randomize_sample: bool = False,
silent_errors: bool = False
):
self.path = Path(path)
self.glob = glob
self.exclude = exclude or []
self.loader_cls_mapping = loader_cls_mapping or {}
self.loader_kwargs = loader_kwargs or {}
self.recursive = recursive
self.show_progress = show_progress
self.use_multithreading = use_multithreading
self.max_concurrency = max_concurrency
self.sample_size = sample_size
self.randomize_sample = randomize_sample
self.silent_errors = silent_errors
def load(self) -> List[Document]:
"""Load all documents from directory"""
file_paths = self._get_file_paths()
if self.sample_size > 0:
file_paths = self._sample_files(file_paths)
if self.use_multithreading:
return self._load_with_threading(file_paths)
else:
return self._load_sequential(file_paths)
def _get_file_paths(self) -> List[Path]:
"""Get all file paths matching the glob pattern"""
if self.recursive:
file_paths = list(self.path.rglob(self.glob))
else:
file_paths = list(self.path.glob(self.glob))
# Filter out excluded files
filtered_paths = []
for file_path in file_paths:
if file_path.is_file():
# Check if file should be excluded
should_exclude = False
for exclude_pattern in self.exclude:
if file_path.match(exclude_pattern):
should_exclude = True
break
if not should_exclude:
filtered_paths.append(file_path)
return filtered_paths
def _sample_files(self, file_paths: List[Path]) -> List[Path]:
"""Sample files if sample_size is specified"""
if self.sample_size >= len(file_paths):
return file_paths
if self.randomize_sample:
import random
return random.sample(file_paths, self.sample_size)
else:
return file_paths[:self.sample_size]
def _load_sequential(self, file_paths: List[Path]) -> List[Document]:
"""Load files sequentially"""
documents = []
if self.show_progress:
try:
from tqdm import tqdm
file_paths = tqdm(file_paths, desc="Loading documents")
except ImportError:
pass
for file_path in file_paths:
try:
loader = self._get_loader_for_file(file_path)
docs = loader.load()
documents.extend(docs)
except Exception as e:
if not self.silent_errors:
logger.error(f"Error loading {file_path}: {e}")
continue
return documents
def _load_with_threading(self, file_paths: List[Path]) -> List[Document]:
"""Load files using multithreading"""
from concurrent.futures import ThreadPoolExecutor, as_completed
documents = []
with ThreadPoolExecutor(max_workers=self.max_concurrency) as executor:
future_to_path = {
executor.submit(self._load_single_file, path): path
for path in file_paths
}
            # Iterate over futures as they complete; optionally wrap with a progress bar
            futures_iter = as_completed(future_to_path)
            if self.show_progress:
                try:
                    from tqdm import tqdm
                    futures_iter = tqdm(futures_iter, total=len(future_to_path), desc="Loading documents")
                except ImportError:
                    pass
            for future in futures_iter:
                path = future_to_path[future]
try:
docs = future.result()
documents.extend(docs)
except Exception as e:
if not self.silent_errors:
logger.error(f"Error loading {path}: {e}")
continue
return documents
def _load_single_file(self, file_path: Path) -> List[Document]:
"""Load a single file"""
loader = self._get_loader_for_file(file_path)
return loader.load()
def _get_loader_for_file(self, file_path: Path) -> BaseLoader:
"""Get appropriate loader for file type"""
# Check for explicit file extension mapping
file_ext = file_path.suffix.lower()
if file_ext in self.loader_cls_mapping:
loader_cls = self.loader_cls_mapping[file_ext]
kwargs = self.loader_kwargs.get(file_ext, {})
return loader_cls(file_path, **kwargs)
# Default loaders based on file type
if file_ext == '.pdf':
return PDFLoader(file_path)
elif file_ext == '.csv':
return CSVLoader(file_path)
elif file_ext in ['.txt', '.md', '.py', '.js', '.html', '.xml', '.json']:
return TextLoader(file_path)
else:
# Try to determine if it's a text file
mime_type, _ = mimetypes.guess_type(str(file_path))
if mime_type and mime_type.startswith('text'):
return TextLoader(file_path)
else:
raise ValueError(f"Unsupported file type: {file_ext}")
class PDFLoader(BaseLoader):
"""Load PDF documents with advanced parsing capabilities"""
def __init__(
self,
file_path: Union[str, Path],
password: Optional[str] = None,
extract_images: bool = False,
headers: Optional[Dict] = None
):
self.file_path = Path(file_path)
self.password = password
self.extract_images = extract_images
self.headers = headers
def load(self) -> List[Document]:
"""Load PDF file with PyPDF2 or pdfplumber"""
try:
# Try pdfplumber first for better text extraction
return self._load_with_pdfplumber()
except ImportError:
try:
# Fallback to PyPDF2
return self._load_with_pypdf2()
except ImportError:
raise ImportError("Please install 'pdfplumber' or 'PyPDF2' to use PDFLoader")
def _load_with_pdfplumber(self) -> List[Document]:
"""Load PDF using pdfplumber for better text extraction"""
import pdfplumber
documents = []
with pdfplumber.open(self.file_path, password=self.password) as pdf:
for page_num, page in enumerate(pdf.pages):
text = page.extract_text() or ""
if text.strip(): # Only process pages with text
metadata = {
'source': str(self.file_path),
'file_name': self.file_path.name,
'page': page_num + 1,
'total_pages': len(pdf.pages),
'file_type': 'pdf'
}
# Extract additional page info if available
if hasattr(page, 'bbox'):
metadata['page_width'] = page.bbox[2]
metadata['page_height'] = page.bbox[3]
documents.append(Document(page_content=text, metadata=metadata))
return documents
def _load_with_pypdf2(self) -> List[Document]:
"""Load PDF using PyPDF2 as fallback"""
import PyPDF2
documents = []
with open(self.file_path, 'rb') as file:
reader = PyPDF2.PdfReader(file)
# Handle password protection
if reader.is_encrypted and self.password:
reader.decrypt(self.password)
for page_num, page in enumerate(reader.pages):
text = page.extract_text()
if text.strip():
metadata = {
'source': str(self.file_path),
'file_name': self.file_path.name,
'page': page_num + 1,
'total_pages': len(reader.pages),
'file_type': 'pdf'
}
documents.append(Document(page_content=text, metadata=metadata))
return documents
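
# Illustrative usage sketch (not part of the core implementation): PDFLoader emits
# one Document per non-empty page. "report.pdf" is a placeholder path.
def _example_pdf_loader() -> List[Document]:
    loader = PDFLoader("report.pdf")
    pages = loader.load()  # requires pdfplumber or PyPDF2 to be installed
    for page_doc in pages[:3]:
        logger.info("Page %s of %s", page_doc.metadata["page"], page_doc.metadata["total_pages"])
    return pages
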
class CSVLoader(BaseLoader):
"""Load CSV files with flexible column handling"""
def __init__(
self,
file_path: Union[str, Path],
csv_args: Optional[Dict] = None,
content_columns: Optional[List[str]] = None,
metadata_columns: Optional[List[str]] = None,
encoding: str = "utf-8",
autodetect_encoding: bool = True
):
self.file_path = Path(file_path)
self.csv_args = csv_args or {}
self.content_columns = content_columns
self.metadata_columns = metadata_columns or []
self.encoding = encoding
self.autodetect_encoding = autodetect_encoding
def load(self) -> List[Document]:
"""Load CSV file"""
try:
import pandas as pd
except ImportError:
raise ImportError("Please install 'pandas' to use CSVLoader")
encoding = self._detect_encoding() if self.autodetect_encoding else self.encoding
try:
df = pd.read_csv(self.file_path, encoding=encoding, **self.csv_args)
except Exception as e:
logger.error(f"Error reading CSV file {self.file_path}: {e}")
raise
documents = []
for index, row in df.iterrows():
# Determine content
if self.content_columns:
content_parts = []
for col in self.content_columns:
if col in df.columns:
value = row[col]
if pd.notna(value): # Skip NaN values
content_parts.append(str(value))
content = " ".join(content_parts)
else:
# Use all columns as content
content_parts = []
for col, value in row.items():
if pd.notna(value):
content_parts.append(f"{col}: {value}")
content = " | ".join(content_parts)
# Build metadata
metadata = {
'source': str(self.file_path),
'file_name': self.file_path.name,
'row': index,
'file_type': 'csv',
'encoding': encoding
}
# Add specified metadata columns
for col in self.metadata_columns:
if col in df.columns and pd.notna(row[col]):
metadata[col] = row[col]
# Add all columns as metadata if no specific columns specified
if not self.metadata_columns:
for col, value in row.items():
if pd.notna(value) and col not in (self.content_columns or []):
metadata[col] = value
documents.append(Document(page_content=content, metadata=metadata))
return documents
def _detect_encoding(self) -> str:
"""Detect CSV file encoding"""
if not self.autodetect_encoding:
return self.encoding
try:
import chardet
with open(self.file_path, 'rb') as f:
raw_data = f.read()
result = chardet.detect(raw_data)
return result['encoding'] or 'utf-8'
except ImportError:
return self.encoding
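
# Illustrative usage sketch (not part of the core implementation): CSVLoader with
# explicit content and metadata columns. File and column names are placeholders.
def _example_csv_loader() -> List[Document]:
    loader = CSVLoader(
        "products.csv",
        content_columns=["name", "description"],  # joined to form page_content
        metadata_columns=["category", "price"],   # copied into each Document's metadata
    )
    return loader.load()  # one Document per CSV row
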
class JSONLoader(BaseLoader):
"""Load JSON and JSONL files"""
def __init__(
self,
file_path: Union[str, Path],
jq_schema: Optional[str] = None,
content_key: Optional[str] = None,
metadata_func: Optional[callable] = None,
text_content: bool = True,
json_lines: bool = False
):
self.file_path = Path(file_path)
self.jq_schema = jq_schema
self.content_key = content_key
self.metadata_func = metadata_func
self.text_content = text_content
self.json_lines = json_lines
def load(self) -> List[Document]:
"""Load JSON or JSONL file"""
documents = []
with open(self.file_path, 'r', encoding='utf-8') as f:
if self.json_lines:
# Handle JSONL format
for line_num, line in enumerate(f, 1):
if line.strip():
try:
data = json.loads(line)
doc = self._process_json_data(data, line_num)
if doc:
documents.append(doc)
except json.JSONDecodeError as e:
logger.error(f"Error parsing line {line_num}: {e}")
continue
else:
# Handle regular JSON
try:
data = json.load(f)
if isinstance(data, list):
for i, item in enumerate(data):
doc = self._process_json_data(item, i)
if doc:
documents.append(doc)
else:
doc = self._process_json_data(data, 0)
if doc:
documents.append(doc)
except json.JSONDecodeError as e:
logger.error(f"Error parsing JSON file: {e}")
raise
return documents
def _process_json_data(self, data: Dict[str, Any], index: int) -> Optional[Document]:
"""Process individual JSON object into Document"""
try:
# Extract content
if self.content_key and self.content_key in data:
content = str(data[self.content_key])
elif self.text_content:
# Convert entire object to text
content = json.dumps(data, indent=2, ensure_ascii=False)
else:
content = str(data)
# Build metadata
metadata = {
'source': str(self.file_path),
'file_name': self.file_path.name,
'seq_num': index,
'file_type': 'json'
}
# Add custom metadata
if self.metadata_func:
custom_metadata = self.metadata_func(data, metadata)
metadata.update(custom_metadata)
else:
# Add all fields except content as metadata
for key, value in data.items():
if key != self.content_key:
metadata[key] = value
return Document(page_content=content, metadata=metadata)
except Exception as e:
logger.error(f"Error processing JSON data at index {index}: {e}")
return None
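
# Illustrative usage sketch (not part of the core implementation): JSONLoader for a
# JSONL file whose records carry a "text" field. File name and key are placeholders.
def _example_jsonl_loader() -> List[Document]:
    loader = JSONLoader(
        "records.jsonl",
        content_key="text",  # value becomes page_content; remaining fields become metadata
        json_lines=True,     # parse one JSON object per line
    )
    return loader.load()
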
# Usage Example: Document Loading Pipeline
class DocumentLoadingPipeline:
"""Complete document loading pipeline"""
def __init__(self, show_progress: bool = True):
self.show_progress = show_progress
self.loaded_documents = []
def load_from_directory(
self,
directory: Union[str, Path],
file_types: Optional[List[str]] = None,
recursive: bool = True,
max_concurrency: int = 4
) -> List[Document]:
"""Load documents from directory with multiple file types"""
# Define default loader mappings
loader_mapping = {
'.txt': TextLoader,
'.md': TextLoader,
'.py': TextLoader,
'.js': TextLoader,
'.html': TextLoader,
'.xml': TextLoader,
'.pdf': PDFLoader,
'.csv': CSVLoader,
'.json': JSONLoader,
'.jsonl': lambda path: JSONLoader(path, json_lines=True)
}
# Filter by specified file types
if file_types:
loader_mapping = {
ext: loader for ext, loader in loader_mapping.items()
if ext in file_types
}
# Create directory loader
directory_loader = DirectoryLoader(
path=directory,
loader_cls_mapping=loader_mapping,
recursive=recursive,
show_progress=self.show_progress,
use_multithreading=True,
max_concurrency=max_concurrency,
silent_errors=True
)
# Load documents
documents = directory_loader.load()
self.loaded_documents.extend(documents)
logger.info(f"Loaded {len(documents)} documents from {directory}")
return documents
def load_from_files(self, file_paths: List[Union[str, Path]]) -> List[Document]:
"""Load documents from specific file paths"""
documents = []
for file_path in file_paths:
file_path = Path(file_path)
try:
# Get appropriate loader
if file_path.suffix.lower() == '.pdf':
loader = PDFLoader(file_path)
elif file_path.suffix.lower() == '.csv':
loader = CSVLoader(file_path)
elif file_path.suffix.lower() in ['.json', '.jsonl']:
loader = JSONLoader(file_path, json_lines=file_path.suffix == '.jsonl')
else:
loader = TextLoader(file_path)
docs = loader.load()
documents.extend(docs)
except Exception as e:
logger.error(f"Error loading {file_path}: {e}")
continue
self.loaded_documents.extend(documents)
logger.info(f"Loaded {len(documents)} documents from {len(file_paths)} files")
return documents
def get_statistics(self) -> Dict[str, Any]:
"""Get loading statistics"""
if not self.loaded_documents:
return {"total_documents": 0}
stats = {
"total_documents": len(self.loaded_documents),
"total_characters": sum(len(doc.page_content) for doc in self.loaded_documents),
"avg_document_length": sum(len(doc.page_content) for doc in self.loaded_documents) / len(self.loaded_documents),
"file_types": {},
"sources": set()
}
# Analyze by file type
for doc in self.loaded_documents:
file_type = doc.metadata.get('file_type', 'unknown')
stats["file_types"][file_type] = stats["file_types"].get(file_type, 0) + 1
stats["sources"].add(doc.metadata.get('source', 'unknown'))
stats["unique_sources"] = len(stats["sources"])
stats["sources"] = list(stats["sources"]) # Convert set to list for JSON serialization
return stats
# Example usage
if __name__ == "__main__":
# Initialize pipeline
pipeline = DocumentLoadingPipeline()
# Load from directory
documents = pipeline.load_from_directory(
directory="./documents",
file_types=['.txt', '.pdf', '.md'],
recursive=True
)
# Print statistics
stats = pipeline.get_statistics()
print(f"Loaded {stats['total_documents']} documents")
print(f"File types: {stats['file_types']}")
print(f"Average length: {stats['avg_document_length']:.0f} characters")Text Splitting Implementation
Building on the document loaders above, here is a comprehensive text splitting implementation:
import re
import numpy as np  # needed at class-definition time by SemanticTextSplitter's type hints
import tiktoken
from typing import List, Callable, Optional, Any, Dict, Union
from abc import ABC, abstractmethod
from dataclasses import dataclass

# This module reuses the Document class and the logger defined in the document loading code above.
@dataclass
class ChunkMetadata:
"""Metadata for text chunks"""
chunk_index: int
total_chunks: int
start_index: int
end_index: int
chunk_size: int
overlap_size: int
source_metadata: Dict[str, Any]
class BaseTextSplitter(ABC):
"""Abstract base class for text splitters"""
def __init__(
self,
chunk_size: int = 4000,
chunk_overlap: int = 200,
length_function: Callable[[str], int] = len,
keep_separator: bool = False,
add_start_index: bool = False,
strip_whitespace: bool = True
):
if chunk_overlap >= chunk_size:
raise ValueError("chunk_overlap must be less than chunk_size")
self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap
self._length_function = length_function
self._keep_separator = keep_separator
self._add_start_index = add_start_index
self._strip_whitespace = strip_whitespace
@abstractmethod
def split_text(self, text: str) -> List[str]:
"""Split text into chunks"""
pass
def split_documents(self, documents: List[Document]) -> List[Document]:
"""Split a list of documents"""
texts, metadatas = [], []
for doc in documents:
texts.append(doc.page_content)
metadatas.append(doc.metadata)
return self.create_documents(texts, metadatas)
def create_documents(
self, texts: List[str], metadatas: Optional[List[Dict]] = None
) -> List[Document]:
"""Create documents from texts and metadatas"""
_metadatas = metadatas or [{}] * len(texts)
documents = []
for i, text in enumerate(texts):
index = 0
chunks = self.split_text(text)
for j, chunk in enumerate(chunks):
metadata = _metadatas[i].copy()
# Add chunk metadata
chunk_metadata = ChunkMetadata(
chunk_index=j,
total_chunks=len(chunks),
start_index=index if self._add_start_index else None,
end_index=index + len(chunk) if self._add_start_index else None,
chunk_size=len(chunk),
overlap_size=self._chunk_overlap,
source_metadata=metadata
)
metadata.update({
'chunk_index': chunk_metadata.chunk_index,
'total_chunks': chunk_metadata.total_chunks,
'chunk_size': chunk_metadata.chunk_size
})
if self._add_start_index:
metadata.update({
'start_index': chunk_metadata.start_index,
'end_index': chunk_metadata.end_index
})
documents.append(Document(page_content=chunk, metadata=metadata))
index += len(chunk) - self._chunk_overlap
return documents
@property
def chunk_size(self) -> int:
return self._chunk_size
@property
def chunk_overlap(self) -> int:
return self._chunk_overlap
class CharacterTextSplitter(BaseTextSplitter):
"""Simple character-based text splitter"""
def __init__(self, separator: str = "\n\n", **kwargs):
super().__init__(**kwargs)
self._separator = separator
def split_text(self, text: str) -> List[str]:
"""Split text by separator"""
if self._separator:
splits = text.split(self._separator)
else:
splits = [text[i:i+1] for i in range(len(text))]
return self._merge_splits(splits, self._separator)
def _merge_splits(self, splits: List[str], separator: str) -> List[str]:
"""Merge splits into appropriately sized chunks"""
separator_len = self._length_function(separator)
docs = []
current_doc = []
total = 0
for s in splits:
_len = self._length_function(s)
# Check if adding this split would exceed chunk size
if (total + _len + (separator_len * len(current_doc)) > self._chunk_size
and current_doc):
# Join current doc and add to results
doc = separator.join(current_doc)
if self._strip_whitespace:
doc = doc.strip()
if doc:
docs.append(doc)
# Keep overlap
while (total > self._chunk_overlap or
(total + _len + (separator_len * len(current_doc)) > self._chunk_size
and total > 0)):
total -= self._length_function(current_doc[0]) + separator_len
current_doc = current_doc[1:]
current_doc.append(s)
total += _len
# Add final document
if current_doc:
doc = separator.join(current_doc)
if self._strip_whitespace:
doc = doc.strip()
if doc:
docs.append(doc)
return docs
class RecursiveCharacterTextSplitter(BaseTextSplitter):
"""Recursively split text using multiple separators"""
def __init__(
self,
separators: Optional[List[str]] = None,
is_separator_regex: bool = False,
**kwargs
):
super().__init__(**kwargs)
self._separators = separators or ["\n\n", "\n", " ", ""]
self._is_separator_regex = is_separator_regex
def split_text(self, text: str) -> List[str]:
"""Split text recursively using separators"""
return self._split_text_recursive(text, self._separators)
def _split_text_recursive(self, text: str, separators: List[str]) -> List[str]:
"""Recursively split text using separators"""
final_chunks = []
separator = separators[-1]
new_separators = []
for i, _s in enumerate(separators):
_separator = _s if self._is_separator_regex else re.escape(_s)
if _s == "":
separator = _s
break
if re.search(_separator, text):
separator = _s
new_separators = separators[i + 1:]
break
_separator = separator if self._is_separator_regex else re.escape(separator)
splits = self._split_text_with_separator(text, separator, _separator)
# Merge splits appropriately
_good_splits = []
_separator = "" if self._keep_separator else separator
for s in splits:
if self._length_function(s) < self._chunk_size:
_good_splits.append(s)
else:
if _good_splits:
merged_text = self._merge_splits(_good_splits, _separator)
final_chunks.extend(merged_text)
_good_splits = []
if not new_separators:
final_chunks.append(s)
else:
other_info = self._split_text_recursive(s, new_separators)
final_chunks.extend(other_info)
if _good_splits:
merged_text = self._merge_splits(_good_splits, _separator)
final_chunks.extend(merged_text)
return final_chunks
def _split_text_with_separator(
self, text: str, separator: str, _separator: str
) -> List[str]:
"""Split text with separator handling"""
if separator:
if self._is_separator_regex:
splits = re.split(_separator, text)
else:
splits = text.split(separator)
else:
splits = list(text)
# Handle separator retention
if self._keep_separator and separator:
_splits = []
for i in range(len(splits)):
_splits.append(splits[i])
if i < len(splits) - 1:
_splits.append(separator)
splits = _splits
return [s for s in splits if s != ""]
def _merge_splits(self, splits: List[str], separator: str) -> List[str]:
"""Merge splits into appropriately sized chunks"""
separator_len = self._length_function(separator)
docs = []
current_doc = []
total = 0
for s in splits:
_len = self._length_function(s)
# Check if adding this split would exceed chunk size
if (total + _len + (separator_len * len(current_doc)) > self._chunk_size
and current_doc):
# Join current doc and add to results
doc = separator.join(current_doc)
if self._strip_whitespace:
doc = doc.strip()
if doc:
docs.append(doc)
# Keep overlap
while (total > self._chunk_overlap or
(total + _len + (separator_len * len(current_doc)) > self._chunk_size
and total > 0)):
total -= self._length_function(current_doc[0]) + separator_len
current_doc = current_doc[1:]
current_doc.append(s)
total += _len
# Add final document
if current_doc:
doc = separator.join(current_doc)
if self._strip_whitespace:
doc = doc.strip()
if doc:
docs.append(doc)
return docs
class TokenTextSplitter(BaseTextSplitter):
"""Split text based on token count using tiktoken"""
def __init__(
self,
encoding_name: str = "cl100k_base",
model_name: Optional[str] = None,
allowed_special: Union[str, set] = set(),
disallowed_special: Union[str, set] = "all",
**kwargs
):
super().__init__(**kwargs)
# Initialize tokenizer
if model_name is not None:
try:
self._tokenizer = tiktoken.encoding_for_model(model_name)
except KeyError:
logger.warning(f"Model {model_name} not found. Using {encoding_name} encoding.")
self._tokenizer = tiktoken.get_encoding(encoding_name)
else:
self._tokenizer = tiktoken.get_encoding(encoding_name)
self._allowed_special = allowed_special
self._disallowed_special = disallowed_special
def split_text(self, text: str) -> List[str]:
"""Split text based on token count"""
def _encode(_text: str) -> List[int]:
return self._tokenizer.encode(
_text,
allowed_special=self._allowed_special,
disallowed_special=self._disallowed_special,
)
tokenized_text = _encode(text)
chunks = []
start_idx = 0
cur_idx = min(start_idx + self._chunk_size, len(tokenized_text))
chunk_ids = tokenized_text[start_idx:cur_idx]
while start_idx < len(tokenized_text):
chunk_text = self._tokenizer.decode(chunk_ids)
if self._strip_whitespace:
chunk_text = chunk_text.strip()
if chunk_text:
chunks.append(chunk_text)
# Move to next chunk with overlap
start_idx += self._chunk_size - self._chunk_overlap
cur_idx = min(start_idx + self._chunk_size, len(tokenized_text))
chunk_ids = tokenized_text[start_idx:cur_idx]
return chunks
@classmethod
def from_tiktoken_encoder(
cls,
encoding: tiktoken.Encoding,
chunk_size: int = 4000,
chunk_overlap: int = 200,
**kwargs
) -> "TokenTextSplitter":
"""Create from tiktoken encoder directly"""
obj = cls(chunk_size=chunk_size, chunk_overlap=chunk_overlap, **kwargs)
obj._tokenizer = encoding
return obj
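
# Illustrative usage sketch (not part of the core implementation): token-based
# splitting sized for an embedding model's context window; 512/64 are example values.
def _example_token_splitting(text: str) -> List[str]:
    splitter = TokenTextSplitter(
        encoding_name="cl100k_base",  # tiktoken encoding used by recent OpenAI models
        chunk_size=512,               # measured in tokens, not characters
        chunk_overlap=64,
    )
    return splitter.split_text(text)
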
class MarkdownHeaderTextSplitter:
"""Split markdown text based on headers"""
def __init__(
self,
headers_to_split_on: List[tuple],
return_each_line: bool = False,
strip_headers: bool = True
):
"""
Args:
headers_to_split_on: List of tuples (header_level, header_name)
e.g., [("#", "Header 1"), ("##", "Header 2")]
return_each_line: Whether to return each line as separate chunk
strip_headers: Whether to strip header markup from content
"""
self.headers_to_split_on = headers_to_split_on
self.return_each_line = return_each_line
self.strip_headers = strip_headers
def split_text(self, text: str) -> List[Document]:
"""Split markdown text based on headers"""
lines = text.split('\n')
chunks = []
current_chunk = []
current_metadata = {}
for line in lines:
header_info = self._is_header(line)
if header_info:
# Save current chunk if it exists
if current_chunk:
content = '\n'.join(current_chunk)
if content.strip():
chunks.append(Document(
page_content=content,
metadata=current_metadata.copy()
))
# Start new chunk
current_chunk = []
                header_level, header_text = header_info
                # Record the header under its configured name (e.g. "##" -> "Header 2")
                header_name = dict(self.headers_to_split_on).get(header_level, header_level)
                current_metadata[header_name] = header_text
                # Drop metadata for deeper header levels when a new section starts
                for level, name in self.headers_to_split_on:
                    if len(level) > len(header_level):
                        current_metadata.pop(name, None)
# Add header to chunk if not stripping
if not self.strip_headers:
current_chunk.append(line)
elif self.return_each_line:
current_chunk.append(header_text)
else:
current_chunk.append(line)
# Add final chunk
if current_chunk:
content = '\n'.join(current_chunk)
if content.strip():
chunks.append(Document(
page_content=content,
metadata=current_metadata.copy()
))
return chunks
def _is_header(self, line: str) -> Optional[tuple]:
"""Check if line is a header and return (level, text)"""
stripped_line = line.strip()
for header_level, header_name in self.headers_to_split_on:
if stripped_line.startswith(header_level + " "):
header_text = stripped_line[len(header_level):].strip()
return header_level, header_text
return None
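
# Illustrative usage sketch (not part of the core implementation): header-aware
# splitting of a small markdown string; each chunk's metadata records the headers
# that were active above it.
def _example_markdown_splitting() -> List[Document]:
    md_splitter = MarkdownHeaderTextSplitter(
        headers_to_split_on=[("#", "Header 1"), ("##", "Header 2")]
    )
    sample = "# Intro\nSome overview text.\n## Details\nMore specific text."
    return md_splitter.split_text(sample)  # two Documents, one per section
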
# Advanced Splitters
class SemanticTextSplitter(BaseTextSplitter):
"""Split text based on semantic boundaries using embeddings"""
def __init__(
self,
embedding_model,
similarity_threshold: float = 0.5,
sentence_split_regex: str = r'(?<=[.!?])\s+',
**kwargs
):
super().__init__(**kwargs)
self.embedding_model = embedding_model
self.similarity_threshold = similarity_threshold
self.sentence_split_regex = sentence_split_regex
def split_text(self, text: str) -> List[str]:
"""Split text based on semantic similarity"""
# Split into sentences
sentences = self._split_into_sentences(text)
if len(sentences) <= 1:
return [text] if text.strip() else []
# Generate embeddings for sentences
try:
embeddings = self.embedding_model.embed_documents(sentences)
except Exception as e:
logger.warning(f"Error generating embeddings: {e}. Falling back to character splitting.")
# Fallback to character-based splitting
char_splitter = RecursiveCharacterTextSplitter(
chunk_size=self._chunk_size,
chunk_overlap=self._chunk_overlap
)
return char_splitter.split_text(text)
# Find semantic boundaries
boundaries = self._find_semantic_boundaries(embeddings)
# Create chunks based on boundaries
chunks = []
for i in range(len(boundaries) - 1):
start_idx = boundaries[i]
end_idx = boundaries[i + 1]
chunk_sentences = sentences[start_idx:end_idx]
chunk_text = ' '.join(chunk_sentences)
# Further split if chunk is too large
if self._length_function(chunk_text) > self._chunk_size:
sub_chunks = self._split_large_chunk(chunk_text)
chunks.extend(sub_chunks)
else:
if chunk_text.strip():
chunks.append(chunk_text)
return chunks
def _split_into_sentences(self, text: str) -> List[str]:
"""Split text into sentences using regex"""
sentences = re.split(self.sentence_split_regex, text)
return [s.strip() for s in sentences if s.strip()]
def _find_semantic_boundaries(self, embeddings: List[List[float]]) -> List[int]:
"""Find semantic boundaries based on embedding similarity"""
import numpy as np
boundaries = [0] # Always start with first sentence
for i in range(1, len(embeddings)):
similarity = self._cosine_similarity(
np.array(embeddings[i-1]),
np.array(embeddings[i])
)
if similarity < self.similarity_threshold:
boundaries.append(i)
boundaries.append(len(embeddings)) # Always end with last sentence
return boundaries
def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
"""Calculate cosine similarity between two vectors"""
import numpy as np
dot_product = np.dot(vec1, vec2)
norm_a = np.linalg.norm(vec1)
norm_b = np.linalg.norm(vec2)
if norm_a == 0 or norm_b == 0:
return 0.0
return dot_product / (norm_a * norm_b)
def _split_large_chunk(self, text: str) -> List[str]:
"""Split chunks that are too large using character-based splitting"""
char_splitter = RecursiveCharacterTextSplitter(
chunk_size=self._chunk_size,
chunk_overlap=self._chunk_overlap,
length_function=self._length_function
)
return char_splitter.split_text(text)
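
# Illustrative usage sketch (not part of the core implementation): SemanticTextSplitter
# only needs an object exposing embed_documents(texts) -> list of vectors. The toy
# embedder below is a deterministic stand-in so the example runs without a real model.
class _ToyEmbedder:
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        # Crude character statistics in place of real embeddings; replace in practice.
        return [[float(len(t)), float(sum(map(ord, t)) % 997)] for t in texts]

def _example_semantic_splitting(text: str) -> List[str]:
    splitter = SemanticTextSplitter(
        embedding_model=_ToyEmbedder(),
        similarity_threshold=0.5,
        chunk_size=1000,
        chunk_overlap=100,
    )
    return splitter.split_text(text)
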
# Text Splitting Pipeline
class TextSplittingPipeline:
"""Complete text splitting pipeline with multiple strategies"""
def __init__(self):
self.splitters = {}
self.default_splitter_name = "recursive"
self._setup_default_splitters()
def _setup_default_splitters(self):
"""Setup default text splitters"""
self.splitters = {
"character": CharacterTextSplitter(
separator="\n\n",
chunk_size=1000,
chunk_overlap=200
),
"recursive": RecursiveCharacterTextSplitter(
separators=["\n\n", "\n", " ", ""],
chunk_size=1000,
chunk_overlap=200
),
"token": TokenTextSplitter(
encoding_name="cl100k_base",
chunk_size=1000,
chunk_overlap=200
),
"markdown": MarkdownHeaderTextSplitter(
headers_to_split_on=[
("#", "Header 1"),
("##", "Header 2"),
("###", "Header 3"),
("####", "Header 4")
]
)
}
def add_splitter(self, name: str, splitter: BaseTextSplitter):
"""Add custom splitter"""
self.splitters[name] = splitter
def split_documents(
self,
documents: List[Document],
splitter_name: Optional[str] = None,
chunk_size: Optional[int] = None,
chunk_overlap: Optional[int] = None
) -> List[Document]:
"""Split documents using specified splitter"""
splitter_name = splitter_name or self.default_splitter_name
if splitter_name not in self.splitters:
raise ValueError(f"Unknown splitter: {splitter_name}")
splitter = self.splitters[splitter_name]
# Update chunk size/overlap if provided
if chunk_size is not None or chunk_overlap is not None:
if hasattr(splitter, '_chunk_size') and chunk_size is not None:
splitter._chunk_size = chunk_size
if hasattr(splitter, '_chunk_overlap') and chunk_overlap is not None:
splitter._chunk_overlap = chunk_overlap
# Handle markdown splitter differently (returns Documents)
if splitter_name == "markdown":
all_chunks = []
for doc in documents:
chunks = splitter.split_text(doc.page_content)
# Add original metadata to chunks
for chunk in chunks:
chunk.metadata.update(doc.metadata)
all_chunks.extend(chunks)
return all_chunks
else:
return splitter.split_documents(documents)
def get_optimal_chunk_size(
self,
documents: List[Document],
target_chunk_count: int = 100
) -> int:
"""Estimate optimal chunk size based on document collection"""
total_chars = sum(len(doc.page_content) for doc in documents)
avg_chunk_size = total_chars // target_chunk_count
# Round to nearest 100 and ensure minimum
optimal_size = max(500, round(avg_chunk_size / 100) * 100)
logger.info(f"Estimated optimal chunk size: {optimal_size} characters")
return optimal_size
def analyze_splitting_results(
self,
original_documents: List[Document],
split_documents: List[Document]
) -> Dict[str, Any]:
"""Analyze splitting results"""
original_lengths = [len(doc.page_content) for doc in original_documents]
split_lengths = [len(doc.page_content) for doc in split_documents]
return {
"original_count": len(original_documents),
"split_count": len(split_documents),
"expansion_ratio": len(split_documents) / len(original_documents),
"original_stats": {
"total_chars": sum(original_lengths),
"avg_length": sum(original_lengths) / len(original_lengths),
"min_length": min(original_lengths),
"max_length": max(original_lengths)
},
"split_stats": {
"total_chars": sum(split_lengths),
"avg_length": sum(split_lengths) / len(split_lengths),
"min_length": min(split_lengths),
"max_length": max(split_lengths)
}
}
# Usage Example
if __name__ == "__main__":
# Initialize pipeline
pipeline = TextSplittingPipeline()
# Load some documents (using previous example)
doc_pipeline = DocumentLoadingPipeline()
documents = doc_pipeline.load_from_directory("./documents")
# Split documents
split_docs = pipeline.split_documents(
documents,
splitter_name="recursive",
chunk_size=1000,
chunk_overlap=200
)
# Analyze results
analysis = pipeline.analyze_splitting_results(documents, split_docs)
print(f"Split {analysis['original_count']} documents into {analysis['split_count']} chunks")
print(f"Average chunk size: {analysis['split_stats']['avg_length']:.0f} characters")This implementation provides a comprehensive document processing and text splitting system based on your RAG materials. The system includes:
- Multiple document loaders for different file types (PDF, CSV, JSON, text)
- Advanced text splitting strategies including recursive, token-based, and semantic splitting
- Flexible configuration for chunk sizes, overlaps, and processing options (see the customization sketch after this list)
- Error handling and logging for production use
- Performance optimizations, including multithreaded directory loading
- Analysis tools for understanding processing results
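As a quick illustration of the configuration hooks mentioned above, the sketch below registers a custom splitter with the pipeline and adjusts the chunk size of a built-in one. It assumes the loader and splitter classes from this chapter are in scope; the ./documents path, the "section_aware" name, and the chunk sizes are placeholder choices, not requirements.
# Load documents with the pipeline from earlier in the chapter
documents = DocumentLoadingPipeline().load_from_directory("./documents")
splitting_pipeline = TextSplittingPipeline()
# Register a custom splitter alongside the built-in ones
splitting_pipeline.add_splitter(
    "section_aware",
    RecursiveCharacterTextSplitter(
        separators=["\n## ", "\n\n", "\n", " ", ""],
        chunk_size=1500,
        chunk_overlap=150
    )
)
# Built-in splitters can also be re-configured per call
small_chunks = splitting_pipeline.split_documents(documents, splitter_name="recursive", chunk_size=800, chunk_overlap=100)
section_chunks = splitting_pipeline.split_documents(documents, splitter_name="section_aware")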
Next, we'll implement the embedding and vector storage components in Chapter 4: Advanced RAG Techniques.