Module 1: LLM
Chapter 5: Performance Optimization

Chapter 5: Performance Optimization

Introduction to LLM Performance Optimization

Performance optimization for Large Language Models involves maximizing throughput, minimizing latency, and reducing computational costs while maintaining output quality. This chapter covers advanced techniques and strategies for achieving optimal performance in production environments.

Performance Metrics and Benchmarking

Key Performance Indicators

import time
import psutil
import torch
from collections import defaultdict
 
class PerformanceMonitor:
    """Collect wall-clock timings and memory snapshots for inference benchmarks.

    Timings are stored per metric name in ``self.metrics`` (a list of seconds
    per name) and can be summarised with :meth:`get_summary`.
    """

    def __init__(self):
        self.metrics = defaultdict(list)
        self.start_time = None

    @staticmethod
    def _synchronize():
        """Wait for queued CUDA work so it is included in the timing.

        This is a no-op when CUDA is unavailable. Calling
        ``torch.cuda.synchronize`` unconditionally raises on CPU-only builds.
        """
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def start_measurement(self):
        """Start the clock, after any pending GPU work has finished."""
        self._synchronize()
        self.start_time = time.perf_counter()

    def end_measurement(self, metric_name):
        """Stop the clock, record the elapsed seconds under *metric_name*, return them.

        Raises:
            RuntimeError: if :meth:`start_measurement` was never called.
        """
        if self.start_time is None:
            raise RuntimeError("end_measurement() called before start_measurement()")
        self._synchronize()
        duration = time.perf_counter() - self.start_time
        self.metrics[metric_name].append(duration)
        return duration

    def measure_memory(self):
        """Return current memory usage in GiB.

        Keys:
            gpu_allocated: memory allocated by tensors (0 without CUDA).
            gpu_cached: memory reserved by the caching allocator (0 without CUDA).
            cpu_memory: resident set size of this process.
        """
        if torch.cuda.is_available():
            gpu_memory = torch.cuda.memory_allocated() / 1024**3  # GB
            gpu_cached = torch.cuda.memory_reserved() / 1024**3   # GB
        else:
            gpu_memory = gpu_cached = 0

        cpu_memory = psutil.Process().memory_info().rss / 1024**3  # GB

        return {
            "gpu_allocated": gpu_memory,
            "gpu_cached": gpu_cached,
            "cpu_memory": cpu_memory
        }

    def calculate_throughput(self, num_tokens, duration):
        """Return tokens per second for *num_tokens* produced in *duration* seconds."""
        return num_tokens / duration  # tokens per second

    def get_summary(self):
        """Return mean/min/max/count of the recorded durations for every metric."""
        summary = {}
        for metric, values in self.metrics.items():
            summary[metric] = {
                "mean": sum(values) / len(values),
                "min": min(values),
                "max": max(values),
                "count": len(values)
            }
        return summary
 
# Usage example: one shared monitor used by the benchmark helpers.
monitor = PerformanceMonitor()


def benchmark_generation(model, tokenizer, prompts, max_length=100):
    """Generate one completion per prompt and record latency, throughput and memory.

    Args:
        model: object exposing ``generate(**inputs, ...)``. Presumably a
            Hugging Face causal LM whose output includes the prompt tokens;
            ``output_tokens`` below relies on that (TODO confirm).
        tokenizer: callable returning a mapping with an ``"input_ids"`` tensor.
        prompts: iterable of prompt strings, processed one at a time.
        max_length: total length limit passed to ``generate`` (prompt included).

    Returns:
        A list with one dict per prompt. Keys: ``prompt``, ``input_tokens``,
        ``output_tokens`` (newly generated), ``generation_time`` (seconds),
        ``throughput`` (new tokens per second) and ``memory_usage``.
    """
    results = []
    device = getattr(model, "device", None)

    for prompt in prompts:
        inputs = tokenizer(prompt, return_tensors="pt")
        if device is not None:
            # Inputs must live on the same device as the model's weights.
            inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        input_length = inputs["input_ids"].shape[1]

        # Start timing only after tokenization and transfer, so that
        # generation_time and throughput measure generation alone.
        monitor.start_measurement()
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                do_sample=True,
                temperature=0.7
            )
        generation_time = monitor.end_measurement("generation")

        output_length = outputs.shape[1]
        new_tokens = output_length - input_length

        throughput = monitor.calculate_throughput(new_tokens, generation_time)
        memory_usage = monitor.measure_memory()

        results.append({
            "prompt": prompt,
            "input_tokens": input_length,
            "output_tokens": new_tokens,
            "generation_time": generation_time,
            "throughput": throughput,
            "memory_usage": memory_usage
        })

    return results

Benchmarking Framework

class LLMBenchmark:
    """Run a fixed prompt set through a model configuration in several modes.

    NOTE(review): ``run_benchmark`` calls ``load_model``, ``load_tokenizer``,
    ``warmup`` and ``benchmark_streaming``, but this class does not define
    them. Presumably a subclass provides them (TODO confirm).
    Timings use the module-level ``monitor`` (a ``PerformanceMonitor``).
    """

    def __init__(self, model_configs):
        self.model_configs = model_configs
        self.test_prompts = [
            "Explain artificial intelligence in simple terms.",
            "Write a short story about a robot.",
            "What are the benefits of renewable energy?",
            "Describe the process of photosynthesis.",
            "How does machine learning work?"
        ]

    def run_benchmark(self, config_name):
        """Load the model for *config_name*, warm it up and run every benchmark mode."""
        config = self.model_configs[config_name]
        model = self.load_model(config)
        tokenizer = self.load_tokenizer(config)

        # Warmup
        self.warmup(model, tokenizer)

        # Benchmark different scenarios
        results = {
            "single_generation": self.benchmark_single(model, tokenizer),
            "batch_generation": self.benchmark_batch(model, tokenizer),
            "streaming_generation": self.benchmark_streaming(model, tokenizer)
        }

        return results

    def benchmark_single(self, model, tokenizer):
        """Benchmark the prompts one at a time via ``benchmark_generation``."""
        return benchmark_generation(model, tokenizer, self.test_prompts)

    def benchmark_batch(self, model, tokenizer, batch_size=4):
        """Benchmark padded batches of up to *batch_size* prompts.

        ``total_tokens`` is the number of newly generated, non-padding tokens
        in the batch. This assumes a decoder-only model whose ``generate``
        output starts with the (padded) prompt (TODO confirm). If the pad id
        equals the EOS id, EOS tokens are not counted either.
        """
        device = getattr(model, "device", None)
        results = []

        for start in range(0, len(self.test_prompts), batch_size):
            batch = self.test_prompts[start:start + batch_size]

            inputs = tokenizer(
                batch,
                return_tensors="pt",
                padding=True,
                truncation=True
            )
            if device is not None:
                inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

            # Time generation only (tokenization and transfer excluded).
            monitor.start_measurement()
            with torch.no_grad():
                outputs = model.generate(**inputs, max_length=100)
            batch_time = monitor.end_measurement("batch_generation")

            # Count generated token ids, not decoded characters. The prompt
            # part and any padding must not be counted as throughput.
            generated = outputs[:, inputs["input_ids"].shape[1]:]
            pad_id = tokenizer.pad_token_id
            if pad_id is None:
                total_tokens = generated.numel()
            else:
                total_tokens = int((generated != pad_id).sum())

            results.append({
                "batch_size": len(batch),
                "total_tokens": total_tokens,
                "batch_time": batch_time,
                "throughput": total_tokens / batch_time
            })

        return results

Inference Optimization Techniques

Speculative Decoding

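Accelerate generation by letting a small "draft" model cheaply propose several tokens and having the large target model verify all of them in a single forward pass. With the standard accept/reject rule, the output follows exactly the same distribution as sampling from the target model alone.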

class SpeculativeDecoding:
    """Speculative sampling with a small draft model and a large target model.

    Each iteration works as follows:
    - The draft model proposes up to ``gamma`` tokens, one at a time.
    - The target model scores the whole proposal in one forward pass.
    - A proposed token ``x`` is kept with probability
      ``min(1, p(x) / q(x))``, where ``p`` is the target distribution and
      ``q`` the draft distribution at that position.
    - At the first rejection a replacement is drawn from
      ``normalize(max(0, p - q))`` and the rest of the proposal is discarded.
    - If every proposal is kept, one extra "bonus" token is drawn from ``p``.

    This rule makes the generated text distributed exactly as if it had
    been sampled from the target model.

    Assumptions (not checked here; confirm for the models in use):
      * Both models are causal LMs whose forward pass returns ``.logits`` of
        shape (batch, seq_len, vocab_size) over the tokenizer's vocabulary.
      * Both models live on the same device.
      * A single prompt (batch size 1) is decoded at a time.
    """

    def __init__(self, draft_model, target_model, tokenizer):
        self.draft_model = draft_model
        self.target_model = target_model
        self.tokenizer = tokenizer

    @staticmethod
    def _distribution(logits, temperature):
        """Return ``softmax(logits / temperature)`` over the last dim, in float32."""
        return torch.softmax(logits.float() / temperature, dim=-1)

    @torch.no_grad()
    def _propose(self, sequence, num_tokens, temperature):
        """Sample up to *num_tokens* draft tokens after *sequence*, shape (1, k).

        Sampling uses the plain temperature-scaled softmax, so the proposal
        distribution is exactly the ``q`` used in :meth:`accept_reject_tokens`.
        ``generate()`` could also apply top-k/top-p or other configured logits
        processors. Proposing stops early once the EOS token is drawn.
        """
        eos_id = self.tokenizer.eos_token_id
        current = sequence
        for _ in range(num_tokens):
            logits = self.draft_model(current).logits[0, -1]
            token = torch.multinomial(self._distribution(logits, temperature), 1)
            current = torch.cat([current, token.view(1, 1)], dim=1)
            if eos_id is not None and token.item() == eos_id:
                break
        return current[:, sequence.shape[1]:]

    @torch.no_grad()
    def speculative_generate(self, prompt, max_length=100, gamma=5, temperature=0.7):
        """
        Generate text using speculative decoding.

        Args:
            prompt: Input prompt
            max_length: Maximum sequence length (prompt included)
            gamma: Maximum number of tokens proposed per iteration
            temperature: Sampling temperature, applied to both models (> 0)

        Returns:
            A (1, L) LongTensor containing the prompt ids followed by the
            generated ids. L <= max(max_length, prompt length). Generation
            also stops right after an EOS token.
        """
        if temperature <= 0:
            raise ValueError("temperature must be > 0 for speculative sampling")

        inputs = self.tokenizer(prompt, return_tensors="pt")
        sequence = inputs["input_ids"]
        device = getattr(self.target_model, "device", None)
        if device is not None:
            sequence = sequence.to(device)
        eos_id = self.tokenizer.eos_token_id

        while sequence.shape[1] < max_length:
            prefix_len = sequence.shape[1]

            # Draft phase: never propose past max_length. The bonus token can
            # overshoot by one; it is trimmed below.
            proposed = self._propose(
                sequence, min(gamma, max_length - prefix_len), temperature
            )

            # Target phase: score every proposal with one pass per model.
            extended_sequence = torch.cat([sequence, proposed], dim=1)
            target_logits = self.target_model(extended_sequence).logits
            draft_logits = self.draft_model(extended_sequence).logits

            # Always yields at least one token, so the loop makes progress.
            accepted_tokens = self.accept_reject_tokens(
                target_logits,
                draft_logits,
                proposed,
                prefix_len,
                temperature=temperature,
            )
            sequence = torch.cat(
                [sequence, accepted_tokens.to(sequence.device)], dim=1
            )[:, :max_length]

            if eos_id is not None:
                hits = (sequence[0, prefix_len:] == eos_id).nonzero()
                if hits.numel() > 0:
                    return sequence[:, :prefix_len + int(hits[0]) + 1]

        return sequence

    def accept_reject_tokens(self, target_logits, draft_logits, proposed_tokens,
                             start_idx, temperature=1.0):
        """Apply the speculative-sampling accept/reject test to one proposal.

        Args:
            target_logits, draft_logits: logits from a forward pass over
                ``prefix + proposed_tokens``, shape (1, start_idx + k, vocab).
            proposed_tokens: (1, k) tokens proposed by the draft model.
            start_idx: prefix length, i.e. the position of the first proposal.
            temperature: temperature the proposals were sampled with.

        Returns:
            A (1, m) LongTensor with 1 <= m <= k + 1. It holds the accepted
            prefix of the proposal plus one token drawn from the target model:
            the residual sample after a rejection, or the bonus token when all
            k proposals were accepted.

        Raises:
            ValueError: if ``start_idx < 1``, because the first proposal is
                scored by the logits of the token before it.
        """
        if start_idx < 1:
            raise ValueError("start_idx must be >= 1")

        accepted = []
        for i, token in enumerate(proposed_tokens[0].tolist()):
            # Causal-LM logits at position t predict the token at t + 1, so the
            # proposal at position start_idx + i is scored at start_idx + i - 1.
            pos = start_idx + i - 1
            p = self._distribution(target_logits[0, pos], temperature)
            q = self._distribution(draft_logits[0, pos], temperature)

            # u < min(1, p/q)  <=>  u * q < p  for u in [0, 1); avoids dividing by q.
            if torch.rand(()).item() * q[token] < p[token]:
                accepted.append(token)
                continue

            # Rejected: resample from the residual max(0, p - q), renormalised.
            residual = torch.clamp(p - q, min=0.0)
            total = residual.sum()
            distribution = residual / total if total > 0 else p
            accepted.append(int(torch.multinomial(distribution, 1)))
            break
        else:
            # Every proposal accepted: draw a bonus token from the target model
            # at the last position.
            pos = start_idx + proposed_tokens.shape[1] - 1
            p = self._distribution(target_logits[0, pos], temperature)
            accepted.append(int(torch.multinomial(p, 1)))

        return torch.tensor([accepted], dtype=torch.long, device=proposed_tokens.device)

Key-Value Cache Optimization

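Pre-allocate the attention key-value cache once and reuse it. Each decoding step then computes keys and values only for the new tokens, instead of recomputing them for the whole sequence.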

import torch.nn as nn
 
class OptimizedKVCache:
    """Pre-allocated key/value cache for incremental decoding of ONE sequence.

    Storage per tensor is (num_layers, max_seq_len, num_heads, head_dim).
    ``cache_lengths[layer]`` is the number of valid positions for that layer.
    Keys and values are passed in and returned with a leading batch dimension
    of size 1.
    """

    def __init__(self, max_seq_len, num_layers, num_heads, head_dim,
                 dtype=torch.float16, device=None):
        self.max_seq_len = max_seq_len
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dtype = dtype

        # Pre-allocate cache tensors; place them next to the model via `device`.
        self.key_cache = torch.zeros(
            num_layers, max_seq_len, num_heads, head_dim, dtype=dtype, device=device
        )
        self.value_cache = torch.zeros(
            num_layers, max_seq_len, num_heads, head_dim, dtype=dtype, device=device
        )

        self.cache_lengths = torch.zeros(num_layers, dtype=torch.long)

    def update_cache(self, layer_idx, keys, values, position):
        """Store new keys/values of shape (1, seq_len, heads, head_dim) at *position*.

        Positions ``[position, position + seq_len)`` are overwritten, and the
        layer's valid length becomes ``position + seq_len``.

        Raises:
            ValueError: if the batch size is not 1 (the cache holds one
                sequence), or if the write falls outside ``[0, max_seq_len)``.
        """
        if keys.shape[0] != 1 or values.shape[0] != 1:
            raise ValueError(
                f"OptimizedKVCache holds a single sequence; got batch size {keys.shape[0]}"
            )
        seq_len = keys.shape[1]
        end = position + seq_len
        if position < 0 or end > self.max_seq_len:
            raise ValueError(
                f"cache write [{position}, {end}) exceeds max_seq_len={self.max_seq_len}"
            )

        # Slice assignment copies across dtype/device as needed.
        self.key_cache[layer_idx, position:end] = keys[0]
        self.value_cache[layer_idx, position:end] = values[0]
        self.cache_lengths[layer_idx] = end

    def get_cache(self, layer_idx):
        """Return (keys, values) for the valid prefix, each (1, length, heads, head_dim)."""
        cache_len = int(self.cache_lengths[layer_idx])
        return (
            self.key_cache[layer_idx, :cache_len].unsqueeze(0),
            self.value_cache[layer_idx, :cache_len].unsqueeze(0)
        )

    def clear_cache(self):
        """Mark every layer empty. Stored values are overwritten by later writes."""
        self.cache_lengths.zero_()
 
class OptimizedAttention(nn.Module):
    """Causal multi-head self-attention with optional incremental KV caching.

    ``config`` must provide ``hidden_size`` and ``num_attention_heads``.
    Tensors are handled as (batch, seq_len, heads, head_dim) internally.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads

        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, hidden_states, kv_cache=None, position=0, use_cache=False):
        """Attend over *hidden_states* (batch, seq_len, hidden_size).

        With ``use_cache`` and a ``kv_cache`` (an object providing
        ``update_cache(layer_idx, keys, values, position)`` and
        ``get_cache(layer_idx)``), the new keys/values are written at
        *position*. When ``position > 0``, attention then runs over the whole
        cached prefix, so only the new tokens need to be passed in.
        """
        batch_size, seq_len, hidden_size = hidden_states.shape
        head_shape = (batch_size, seq_len, self.num_heads, self.head_dim)

        # Compute queries, keys, values and split into heads.
        queries = self.q_proj(hidden_states).view(head_shape)
        keys = self.k_proj(hidden_states).view(head_shape)
        values = self.v_proj(hidden_states).view(head_shape)

        if use_cache and kv_cache is not None:
            # Write ONLY the new positions, then read back the full prefix.
            # Writing the concatenated [cached, new] keys at `position` would
            # put the old entries at the wrong offsets.
            kv_cache.update_cache(self.layer_idx, keys, values, position)
            if position > 0:
                cached_keys, cached_values = kv_cache.get_cache(self.layer_idx)
                keys = cached_keys.to(queries)
                values = cached_values.to(queries)

        # Compute attention
        attention_output = self.scaled_dot_product_attention(queries, keys, values)

        # Merge heads and project output
        attention_output = attention_output.contiguous().view(
            batch_size, seq_len, hidden_size
        )
        return self.out_proj(attention_output)

    @staticmethod
    def _causal_mask(q_len, kv_len, device):
        """Boolean (q_len, kv_len) mask, True where attention is allowed.

        The queries are the LAST q_len positions of the kv_len-long sequence,
        so query i may see keys j <= i + (kv_len - q_len) (bottom-right aligned).
        """
        return torch.ones(q_len, kv_len, dtype=torch.bool, device=device).tril(
            diagonal=kv_len - q_len
        )

    @staticmethod
    def _manual_attention(q, k, v, mask):
        """Reference attention on (batch, heads, len, head_dim) tensors with a bool mask."""
        scale = 1.0 / (q.shape[-1] ** 0.5)
        scores = torch.matmul(q, k.transpose(-2, -1)) * scale
        scores = scores.masked_fill(~mask, float('-inf'))
        return torch.matmul(torch.softmax(scores, dim=-1), v)

    def scaled_dot_product_attention(self, queries, keys, values):
        """Causal attention on (batch, len, heads, head_dim) tensors.

        Keys/values may be longer than queries, e.g. when decoding with a cache.
        """
        q_len, kv_len = queries.shape[1], keys.shape[1]
        # Move heads in front of the sequence axis: (batch, heads, len, head_dim).
        q, k, v = queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)

        sdpa = getattr(torch.nn.functional, 'scaled_dot_product_attention', None)
        if sdpa is not None and q_len == kv_len:
            # Square case: the built-in causal mask is correct and fastest.
            out = sdpa(q, k, v, is_causal=True)
        else:
            # is_causal=True uses a top-left aligned mask, which is wrong when
            # q_len != kv_len, so an explicit bottom-right mask is built instead.
            mask = self._causal_mask(q_len, kv_len, q.device)
            if sdpa is not None:
                out = sdpa(q, k, v, attn_mask=mask)
            else:
                out = self._manual_attention(q, k, v, mask)
        return out.transpose(1, 2)

Continuous Batching

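Implement continuous (iteration-level) batching to improve throughput: new requests join the running batch as soon as a slot frees up, instead of waiting for the whole batch to finish.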

import asyncio
from dataclasses import dataclass
from typing import List, Optional
import heapq
 
@dataclass
class GenerationRequest:
    """One generation job and its progress.

    Fields:
        request_id: key under which the finished request is stored.
        prompt: text to continue.
        max_tokens: maximum number of tokens to generate.
        temperature: sampling temperature; values <= 0 mean greedy decoding.
        created_at: scheduling priority; smaller values are admitted first.
        tokens_generated / is_finished / output_tokens: progress state.
    """
    request_id: str
    prompt: str
    max_tokens: int
    temperature: float = 0.7
    created_at: float = 0.0
    tokens_generated: int = 0
    is_finished: bool = False
    # None means "start empty". __post_init__ gives each instance its own list,
    # because a shared mutable default would leak tokens across requests.
    output_tokens: Optional[List[int]] = None

    def __post_init__(self):
        if self.output_tokens is None:
            self.output_tokens = []
 
class ContinuousBatcher:
    def __init__(self, model, tokenizer, max_batch_size=32, max_wait_ms=10):
        self.model = model
        self.tokenizer = tokenizer
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
 
        # Active requests being processed
        self.active_requests: List[GenerationRequest] = []
        # Waiting requests queue
        self.waiting_queue = []
        # Completed requests
        self.completed_requests = {}
 
        self.kv_cache = OptimizedKVCache(
            max_seq_len=2048,
            num_layers=self.model.config.num_hidden_layers,
            num_heads=self.model.config.num_attention_heads,
            head_dim=self.model.config.hidden_size // self.model.config.num_attention_heads
        )
 
    async def add_request(self, request: GenerationRequest):
        heapq.heappush(self.waiting_queue, (request.created_at, request))
 
        # Start processing if not already running
        if not hasattr(self, '_processing_task') or self._processing_task.done():
            self._processing_task = asyncio.create_task(self._process_loop())
 
    async def _process_loop(self):
        while self.active_requests or self.waiting_queue:
            # Add waiting requests to active batch
            while (len(self.active_requests) < self.max_batch_size and
                   self.waiting_queue):
                _, request = heapq.heappop(self.waiting_queue)
                self.active_requests.append(request)
 
            if not self.active_requests:
                await asyncio.sleep(0.001)  # 1ms
                continue
 
            # Process one iteration for all active requests
            await self._process_batch_iteration()
 
            # Remove completed requests
            self.active_requests = [
                req for req in self.active_requests
                if not req.is_finished
            ]
 
            # Small delay to prevent excessive CPU usage
            await asyncio.sleep(0.001)
 
    async def _process_batch_iteration(self):
        if not self.active_requests:
            return
 
        # Prepare batch inputs
        batch_inputs = []
        batch_positions = []
 
        for request in self.active_requests:
            if request.tokens_generated == 0:
                # First iteration - use full prompt
                inputs = self.tokenizer(request.prompt, return_tensors="pt")
                input_ids = inputs["input_ids"][0]
                position = 0
            else:
                # Subsequent iterations - use last generated token
                input_ids = torch.tensor([request.output_tokens[-1]])
                position = len(self.tokenizer(request.prompt)["input_ids"]) + request.tokens_generated - 1
 
            batch_inputs.append(input_ids)
            batch_positions.append(position)
 
        # Pad inputs to same length
        max_len = max(len(inputs) for inputs in batch_inputs)
        padded_inputs = torch.zeros(len(batch_inputs), max_len, dtype=torch.long)
 
        for i, inputs in enumerate(batch_inputs):
            padded_inputs[i, :len(inputs)] = inputs
 
        # Forward pass
        with torch.no_grad():
            outputs = self.model(
                input_ids=padded_inputs,
                use_cache=True,
                kv_cache=self.kv_cache
            )
 
        # Sample next tokens for each request
        for i, request in enumerate(self.active_requests):
            logits = outputs.logits[i, -1, :]  # Last token logits
 
            if request.temperature > 0:
                probs = torch.softmax(logits / request.temperature, dim=-1)
                next_token = torch.multinomial(probs, 1)[0]
            else:
                next_token = torch.argmax(logits)
 
            request.output_tokens.append(next_token.item())
            request.tokens_generated += 1
 
            # Check if generation is complete
            if (next_token == self.tokenizer.eos_token_id or
                request.tokens_generated >= request.max_tokens):
                request.is_finished = True
                self.completed_requests[request.request_id] = request
 
    def get_result(self, request_id: str) -> Optional[GenerationRequest]:
        return self.completed_requests.get(request_id)

Hardware Acceleration

Mixed Precision Training and Inference

from torch.cuda.amp import autocast, GradScaler
 
class MixedPrecisionInference:
    """Call ``model.generate`` either with fp16 weights under autocast, or unchanged.

    Args:
        model: module exposing ``generate``. When ``use_fp16`` is set, it is
            cast to float16 with ``.half()`` at construction time.
        use_fp16: enable the half-precision path.
    """

    def __init__(self, model, use_fp16=True):
        self.use_fp16 = use_fp16
        # nn.Module.half() converts in place and returns the same module.
        self.model = model.half() if use_fp16 else model

    @torch.no_grad()
    def generate(self, inputs, **generation_kwargs):
        """Return ``self.model.generate(**inputs, **generation_kwargs)``, gradient-free."""
        if not self.use_fp16:
            return self.model.generate(**inputs, **generation_kwargs)
        with autocast():
            return self.model.generate(**inputs, **generation_kwargs)
 
# BFloat16 for better numerical stability
def convert_to_bfloat16(model):
    """Cast every floating-point parameter and buffer of *model* to bfloat16, in place.

    ``nn.Module.to(dtype)`` converts only floating-point tensors; integer
    buffers are left alone. The previous version converted only
    Linear/Embedding children. That left e.g. LayerNorm weights in float32
    (a dtype mismatch at runtime) and ignored a model that is itself a
    Linear/Embedding.

    Returns:
        The same module object, for chaining.
    """
    return model.to(torch.bfloat16)

Custom CUDA Kernels

# Example: Custom fused attention kernel
import triton
import triton.language as tl
 
# Illustrative Triton kernel for causal attention. Each program handles one
# BLOCK_M tile of query rows (program axis 0) for one slice selected by
# program axis 1.
# NOTE(review): several correctness problems are flagged inline below. Treat
# this as a sketch, not as a drop-in replacement for
# torch.nn.functional.scaled_dot_product_attention.
@triton.jit
def fused_attention_kernel(
    Q, K, V, Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
    # NOTE(review): axis 1 presumably enumerates (batch, head) pairs. Only
    # stride_*h is applied to it (stride_*z is never used). That equals
    # z * stride_z + h * stride_h only when stride_z == H * stride_h, e.g. for
    # contiguous (Z, H, N_CTX, D) inputs. Confirm the inputs are contiguous.
    off_hz = tl.program_id(1)

    # Initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)  # query rows of this tile
    offs_n = tl.arange(0, BLOCK_N)  # key/value offsets within one tile
    offs_k = tl.arange(0, BLOCK_DMODEL)  # head-dimension indices

    # Load Q block
    # NOTE(review): none of the loads use a bounds mask, so N_CTX must be a
    # multiple of BLOCK_M and BLOCK_N, or reads run past the tensors.
    q_ptrs = Q + off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk
    q = tl.load(q_ptrs)

    # Initialize accumulator
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    # Loop over K, V blocks
    # NOTE(review): every K/V tile is visited, including tiles that lie
    # entirely above the causal diagonal (see the softmax note below).
    for start_n in range(0, N_CTX, BLOCK_N):
        # Load K, V blocks
        # The k tile is laid out (BLOCK_N, BLOCK_DMODEL).
        k_ptrs = K + off_hz * stride_kh + (start_n + offs_n)[:, None] * stride_kn + offs_k[None, :] * stride_kk
        # NOTE(review): the head-dim index offs_k is multiplied by stride_vk and
        # the key index by stride_vn. This gives a (BLOCK_DMODEL, BLOCK_N) tile.
        # If the launcher passes (sequence, head-dim) strides as
        # (stride_vk, stride_vn), the indexing is also swapped. tl.dot(p, v)
        # below needs a (BLOCK_N, BLOCK_DMODEL) tile. Verify.
        v_ptrs = V + off_hz * stride_vh + offs_k[:, None] * stride_vk + (start_n + offs_n)[None, :] * stride_vn

        k = tl.load(k_ptrs)
        v = tl.load(v_ptrs)

        # Compute attention scores
        # NOTE(review): this multiplies (BLOCK_M, BLOCK_DMODEL) by
        # (BLOCK_N, BLOCK_DMODEL). A Q @ K^T score needs the k tile transposed.
        qk = tl.dot(q, k)
        # NOTE(review): BLOCK_DMODEL is a constexpr integer, not a tensor, so
        # `.to(...)` here will likely fail to compile. Confirm.
        qk = qk * (1.0 / tl.sqrt(BLOCK_DMODEL.to(tl.float32)))

        # Apply causal mask: query row m may only see key columns n <= m.
        mask = offs_m[:, None] >= (start_n + offs_n)[None, :]
        qk = tl.where(mask, qk, float('-inf'))

        # Softmax
        # NOTE(review): this is a per-tile softmax, not an online softmax:
        # - m_i and l_i are recomputed from scratch for each tile;
        # - acc is never rescaled when the running max changes;
        # - l_i is computed but never used to normalise acc;
        # so the result is not softmax(QK^T)V. In a tile where a row is fully
        # masked (start_n greater than every offs_m), m_i is -inf, qk - m_i is
        # NaN, and the NaN propagates into acc.
        m_i = tl.max(qk, 1)
        qk = qk - m_i[:, None]
        p = tl.exp(qk)
        l_i = tl.sum(p, 1)

        # Update accumulator
        acc = acc + tl.dot(p.to(tl.float16), v)

    # Store output
    # NOTE(review): acc is float32; presumably tl.store converts it to Out's
    # element type. Confirm.
    out_ptrs = Out + off_hz * stride_oh + offs_m[:, None] * stride_om + offs_k[None, :] * stride_on
    tl.store(out_ptrs, acc)
 
class TritonFusedAttention(nn.Module):
    """Launch ``fused_attention_kernel`` on q, k, v of shape (batch, heads, seq_len, head_dim).

    The kernel loads full BLOCK_M/BLOCK_N tiles with no bounds masks, and
    tiles the head dimension with ``tl.arange``/``tl.dot``. Inputs are
    therefore validated before launch, so bad shapes raise instead of reading
    out of bounds.
    """

    def forward(self, q, k, v):
        """Return the attention output, allocated like *q*.

        Raises:
            ValueError: if q, k and v do not share one 4-D shape, if seq_len is
                not a multiple of the 64-wide tiles, or if head_dim is not a
                power of two >= 16.
        """
        BLOCK_M, BLOCK_N = 64, 64

        if q.dim() != 4 or q.shape != k.shape or q.shape != v.shape:
            raise ValueError(
                "q, k and v must share one 4-D shape (batch, heads, seq_len, head_dim)"
            )
        seq_len, head_dim = q.shape[2], q.shape[3]
        if seq_len % BLOCK_M or seq_len % BLOCK_N:
            raise ValueError(
                f"seq_len ({seq_len}) must be a multiple of {BLOCK_M}: "
                "the kernel loads whole tiles without bounds masks"
            )
        if head_dim < 16 or head_dim & (head_dim - 1):
            raise ValueError(f"head_dim ({head_dim}) must be a power of two >= 16")
        BLOCK_DMODEL = head_dim

        # Allocate output
        output = torch.empty_like(q)

        # One program per (query tile, batch*head) pair.
        grid = (triton.cdiv(seq_len, BLOCK_M), q.shape[0] * q.shape[1])
        fused_attention_kernel[grid](
            q, k, v, output,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            output.stride(0), output.stride(1), output.stride(2), output.stride(3),
            q.shape[0], q.shape[1], q.shape[2],
            BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_N=BLOCK_N
        )

        return output

Memory Management

Gradient Checkpointing

from torch.utils.checkpoint import checkpoint
 
class MemoryEfficientTransformerBlock(nn.Module):
    """Pre-norm transformer block that can recompute activations during backward.

    ``MultiHeadAttention`` and ``FeedForward`` are defined elsewhere.
    Presumably both map (..., hidden_size) to the same shape (TODO confirm).
    """

    def __init__(self, config):
        super().__init__()
        self.attention = MultiHeadAttention(config)
        self.feed_forward = FeedForward(config)
        self.norm1 = nn.LayerNorm(config.hidden_size)
        self.norm2 = nn.LayerNorm(config.hidden_size)

    def forward(self, x, use_checkpointing=True):
        """Apply ``x + attn(norm1(x))`` and then ``x + ff(norm2(x))``.

        In training mode with ``use_checkpointing``, each sub-block's
        activations are discarded after forward and recomputed during
        backward, which trades compute for memory.
        """
        if use_checkpointing and self.training:
            def attention_block(x):
                return x + self.attention(self.norm1(x))

            def ff_block(x):
                return x + self.feed_forward(self.norm2(x))

            # use_reentrant=False: with the reentrant implementation, the
            # output does not require grad when the checkpointed input does
            # not, so the block's parameters would silently get no gradients.
            # Recent PyTorch also warns when this flag is left implicit.
            x = checkpoint(attention_block, x, use_reentrant=False)
            x = checkpoint(ff_block, x, use_reentrant=False)
        else:
            x = x + self.attention(self.norm1(x))
            x = x + self.feed_forward(self.norm2(x))

        return x
 
# Dynamic memory management
class MemoryManager:
    """Track CUDA tensor memory over time and apply simple memory-saving settings."""

    def __init__(self):
        self.peak_memory = 0  # largest value seen by monitor_memory(), in bytes
        self.memory_history = []  # every value recorded by monitor_memory()

    def monitor_memory(self):
        """Record and return the bytes currently allocated by CUDA tensors.

        Returns 0 (and records nothing) when CUDA is unavailable. Previously
        that case raised UnboundLocalError.
        """
        if not torch.cuda.is_available():
            return 0

        current_memory = torch.cuda.memory_allocated()
        self.peak_memory = max(self.peak_memory, current_memory)
        self.memory_history.append(current_memory)
        return current_memory

    def clear_cache(self):
        """Release unused cached CUDA memory back to the driver, if CUDA is present."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def optimize_memory_usage(self, model):
        """Enable gradient checkpointing when *model* supports it, then clear the cache.

        Returns:
            The same *model* object.
        """
        # Enable gradient checkpointing
        if hasattr(model, 'gradient_checkpointing_enable'):
            model.gradient_checkpointing_enable()

        # Clear unnecessary cached data
        self.clear_cache()

        return model

Model Sharding and Parallelism

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
 
class ModelSharding:
    """Naive model parallelism: place consecutive top-level children on different devices.

    The model's children are assumed to run strictly in sequence, like an
    ``nn.Sequential`` (TODO confirm for the wrapped model).
    """

    def __init__(self, model, num_shards):
        self.model = model
        self.num_shards = num_shards
        self.shards = []  # list of (nn.Sequential shard, device)

    def shard_by_layer(self, devices=None):
        """Shard model by distributing its top-level layers across devices.

        Each shard gets ``len(children) // num_shards`` consecutive children;
        the last shard also takes the remainder. Calling this again replaces
        the previous sharding.

        Args:
            devices: optional sequence with one device per shard. Defaults to
                ``cuda:0`` ... ``cuda:{num_shards - 1}``.

        Raises:
            ValueError: if ``num_shards < 1`` or fewer devices than shards are given.
        """
        if self.num_shards < 1:
            raise ValueError("num_shards must be >= 1")
        if devices is None:
            devices = [f'cuda:{i}' for i in range(self.num_shards)]
        if len(devices) < self.num_shards:
            raise ValueError(f"need {self.num_shards} devices, got {len(devices)}")

        # children(), not modules(). modules() also yields the model itself and
        # every nested submodule, so the same parameters would land in
        # several shards and run more than once.
        layers = list(self.model.children())
        layers_per_shard = len(layers) // self.num_shards

        self.shards = []
        for i in range(self.num_shards):
            start_idx = i * layers_per_shard
            end_idx = (i + 1) * layers_per_shard if i < self.num_shards - 1 else len(layers)

            shard = nn.Sequential(*layers[start_idx:end_idx])
            device = devices[i]
            shard.to(device)

            self.shards.append((shard, device))

    def forward_sharded(self, x):
        """Run *x* through every shard in order and return the result on x's device."""
        original_device = x.device

        for shard, device in self.shards:
            # One transfer per shard boundary.
            x = shard(x.to(device))

        return x.to(original_device)
 
# Pipeline parallelism implementation
class PipelineParallel(nn.Module):
    """Split a model's top-level children into stages on ``device_ids`` and run microbatches.

    NOTE(review): microbatches are pushed through all stages one after another
    in a plain Python loop. Any overlap between stages comes only from
    asynchronous device execution, with no explicit pipeline schedule.
    """

    def __init__(self, model, num_stages, device_ids):
        super().__init__()
        if num_stages < 1:
            raise ValueError("num_stages must be >= 1")
        if len(device_ids) < num_stages:
            raise ValueError(f"need {num_stages} device ids, got {len(device_ids)}")
        self.num_stages = num_stages
        self.device_ids = device_ids

        # Split model into stages
        self.stages = self._split_model(model)

    def _split_model(self, model):
        """Group consecutive children into ``num_stages`` stages; the last takes the remainder."""
        stages = []
        layers = list(model.children())
        layers_per_stage = len(layers) // self.num_stages

        for i in range(self.num_stages):
            start = i * layers_per_stage
            end = (i + 1) * layers_per_stage if i < self.num_stages - 1 else len(layers)

            stage = nn.Sequential(*layers[start:end])
            stage.to(self.device_ids[i])
            stages.append(stage)

        return nn.ModuleList(stages)

    def forward(self, x, num_microbatches=4):
        """Run *x* through all stages in microbatches along dim 0 and concatenate.

        The output stays on the last stage's device.
        """
        # tensor_split keeps every row (the first chunks absorb the remainder).
        # Clamping the count avoids empty microbatches when the batch is small.
        num_chunks = max(1, min(num_microbatches, x.size(0)))
        microbatches = x.tensor_split(num_chunks)

        outputs = []
        for microbatch in microbatches:
            current_input = microbatch
            for device, stage in zip(self.device_ids, self.stages):
                current_input = stage(current_input.to(device))
            outputs.append(current_input)

        return torch.cat(outputs, dim=0)

Advanced Optimization Strategies

Dynamic Sequence Length

class DynamicBatching:
    """Group requests into batches bounded by total prompt tokens, not request count."""

    def __init__(self, tokenizer, max_tokens_per_batch=8192):
        self.tokenizer = tokenizer
        self.max_tokens_per_batch = max_tokens_per_batch

    def create_dynamic_batch(self, requests):
        """Create batches optimized for total token count rather than sequence count.

        Requests (objects with a ``.prompt``) are sorted by prompt length, with
        a stable sort, and packed greedily. A request whose prompt alone
        exceeds ``max_tokens_per_batch`` still gets a batch of its own.
        """
        # Tokenize each prompt exactly once. The length is needed both for
        # sorting and for packing.
        sized_requests = sorted(
            ((len(self.tokenizer(request.prompt)["input_ids"]), request) for request in requests),
            key=lambda pair: pair[0],
        )

        batches = []
        current_batch = []
        current_tokens = 0

        for request_tokens, request in sized_requests:
            # Check if adding this request would exceed token limit
            if current_tokens + request_tokens > self.max_tokens_per_batch and current_batch:
                batches.append(current_batch)
                current_batch = [request]
                current_tokens = request_tokens
            else:
                current_batch.append(request)
                current_tokens += request_tokens

        if current_batch:
            batches.append(current_batch)

        return batches

    def pad_batch_efficiently(self, batch_inputs, pad_token_id=0):
        """Pad a batch of id sequences only up to its own longest sequence.

        Returns:
            (padded_batch LongTensor, attention_mask BoolTensor), both of shape
            (len(batch_inputs), longest). An empty batch yields (0, 0) tensors.
        """
        max_length = max((len(inputs) for inputs in batch_inputs), default=0)

        padded_batch = torch.full(
            (len(batch_inputs), max_length), pad_token_id, dtype=torch.long
        )
        attention_mask = torch.zeros(len(batch_inputs), max_length, dtype=torch.bool)

        for i, inputs in enumerate(batch_inputs):
            length = len(inputs)
            padded_batch[i, :length] = inputs
            attention_mask[i, :length] = True

        return padded_batch, attention_mask

Adaptive Computation

class AdaptiveComputationTransformer(nn.Module):
    """Layer stack with a learned per-token halting signal (adaptive computation).

    After every layer, a linear "halter" produces a halting probability per
    token, and these probabilities are summed. Once a token's sum reaches
    ``halt_threshold``, its hidden state stops being updated. The loop exits
    early once every token has reached the threshold.

    ``TransformerBlock`` is defined elsewhere; presumably it maps
    (batch, seq_len, hidden_size) to the same shape (TODO confirm).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        depth = config.num_layers
        # Blocks are created before the halting heads, which fixes the order
        # in which random initialisation is drawn.
        self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(depth))
        self.halt_probability = nn.ModuleList(
            nn.Linear(config.hidden_size, 1) for _ in range(depth)
        )
        self.halt_threshold = 0.95

    def forward(self, x, max_computation_steps=None):
        """Run at most ``self.layers[:max_computation_steps]`` (all layers by default).

        Args:
            x: tensor unpacked as (batch, seq_len, hidden_size).
            max_computation_steps: slice bound applied to ``self.layers``.

        Returns:
            (hidden states, accumulated halting probability of shape (batch, seq_len)).
        """
        if max_computation_steps is None:
            max_computation_steps = len(self.layers)

        batch_size, seq_len, _ = x.shape
        halted_mass = torch.zeros(batch_size, seq_len, device=x.device)
        state = x

        for block, halter in zip(self.layers[:max_computation_steps], self.halt_probability):
            candidate = block(state)
            halted_mass += torch.sigmoid(halter(candidate).squeeze(-1))

            # 1.0 for tokens still running, 0.0 for tokens that have halted.
            keep = (halted_mass < self.halt_threshold).float().unsqueeze(-1)
            state = keep * candidate + (1 - keep) * state

            if (halted_mass >= self.halt_threshold).all():
                break

        return state, halted_mass

Profiling and Debugging

Performance Profiling Tools

import torch.profiler
from torch.profiler import profile, ProfilerActivity
 
class PerformanceProfiler:
    """Wrap ``torch.profiler`` to time and memory-profile model forward passes.

    Chrome traces are written to ``output_dir``, which is created on demand.
    """

    def __init__(self, output_dir="./profiler_logs"):
        self.output_dir = output_dir

    def _setup(self):
        """Create ``output_dir`` and return (profiler activities, cuda_available).

        CUDA activity is requested only when a GPU is present, so the profiler
        also works on CPU-only machines.
        """
        import os

        os.makedirs(self.output_dir, exist_ok=True)
        cuda_available = torch.cuda.is_available()
        activities = [ProfilerActivity.CPU]
        if cuda_available:
            activities.append(ProfilerActivity.CUDA)
        return activities, cuda_available

    def profile_model_execution(self, model, inputs, num_iterations=10):
        """Profile model execution with detailed metrics.

        Runs 3 warm-up passes, then profiles ``num_iterations`` passes of
        ``model(**inputs)``. Writes ``trace.json`` and prints the 10 most
        expensive ops (by CUDA time on GPU, CPU time otherwise).

        Returns:
            The finished ``torch.profiler.profile`` object.
        """
        activities, cuda_available = self._setup()

        # Warmup
        with torch.no_grad():
            for _ in range(3):
                model(**inputs)

        # Profile with PyTorch Profiler
        with profile(
            activities=activities,
            record_shapes=True,
            profile_memory=True,
            with_stack=True
        ) as prof:
            with torch.no_grad():
                for _ in range(num_iterations):
                    model(**inputs)

        # Export profiling results
        prof.export_chrome_trace(f"{self.output_dir}/trace.json")

        # Print summary
        sort_key = "cuda_time_total" if cuda_available else "cpu_time_total"
        print(prof.key_averages().table(sort_by=sort_key, row_limit=10))

        return prof

    def memory_profile(self, model, inputs):
        """Profile memory usage during model execution.

        Runs one ``model(**inputs)`` pass. Prints the 10 ops with the highest
        self memory usage (CUDA on GPU, CPU otherwise) and writes
        ``memory_trace.json``.
        """
        activities, cuda_available = self._setup()
        sort_key = "self_cuda_memory_usage" if cuda_available else "self_cpu_memory_usage"

        def trace_handler(p):
            output = p.key_averages().table(sort_by=sort_key, row_limit=10)
            print(output)
            p.export_chrome_trace(f"{self.output_dir}/memory_trace.json")

        with profile(
            activities=activities,
            profile_memory=True,
            record_shapes=True,
            on_trace_ready=trace_handler
        ) as prof:
            with torch.no_grad():
                model(**inputs)

        return prof
 
# Usage example: one shared profiler writing to the default directory.
profiler = PerformanceProfiler()


def benchmark_with_profiling(model, tokenizer, test_prompts):
    """Profile one padded batch of *test_prompts* for speed and for memory.

    Returns:
        (execution_prof, memory_prof), the profiler objects produced by
        ``profile_model_execution`` and ``memory_profile``.
    """
    # Prepare inputs
    inputs = tokenizer(test_prompts, return_tensors="pt", padding=True, truncation=True)
    device = getattr(model, "device", None)
    if device is not None:
        # Inputs must live on the same device as the model's weights.
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    # Profile execution
    execution_prof = profiler.profile_model_execution(model, inputs)
    memory_prof = profiler.memory_profile(model, inputs)

    return execution_prof, memory_prof

Bottleneck Analysis

class BottleneckAnalyzer:
    """Per-layer wall-clock timing plus rule-of-thumb optimization suggestions.

    Attributes:
        layer_times: Module name -> seconds attributed to that module by the
            most recent ``analyze_model_layers`` call.
        memory_usage: Initialized empty; not populated by the methods below.
    """

    def __init__(self):
        self.layer_times = {}
        self.memory_usage = {}

    @staticmethod
    def _sync():
        """Wait for queued CUDA work so timings include it; no-op without CUDA."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def analyze_model_layers(self, model, inputs):
        """Time each Linear / LayerNorm / MultiheadAttention module in one forward pass.

        Each hooked module is charged the wall-clock time elapsed since the
        previous hooked module finished (or since the pass started). Un-hooked
        work in between is therefore attributed to the next hooked module.
        Prints the ten most expensive modules.

        Args:
            model: ``nn.Module`` called as ``model(**inputs)``.
            inputs: Mapping of keyword arguments for the forward pass.

        Returns:
            Dict mapping module name to accumulated seconds.
        """
        # Start fresh so entries from a previously analyzed model don't leak in.
        self.layer_times = {}

        def forward_hook(name):
            def hook(module, args, output):
                self._sync()
                now = time.perf_counter()
                # Accumulate: a module invoked several times must be fully counted.
                self.layer_times[name] = self.layer_times.get(name, 0.0) + (now - self.start_time)
                self.start_time = now
            return hook

        hooks = [
            module.register_forward_hook(forward_hook(name))
            for name, module in model.named_modules()
            if isinstance(module, (nn.Linear, nn.LayerNorm, nn.MultiheadAttention))
        ]

        try:
            self._sync()
            self.start_time = time.perf_counter()
            with torch.no_grad():
                model(**inputs)
        finally:
            # Always detach hooks, even if the forward pass raised.
            for handle in hooks:
                handle.remove()

        total_time = sum(self.layer_times.values())

        print("Layer Performance Analysis:")
        print("-" * 60)
        ranked = sorted(self.layer_times.items(), key=lambda kv: kv[1], reverse=True)[:10]
        for layer_name, layer_time in ranked:
            percentage = (layer_time / total_time) * 100 if total_time > 0 else 0.0
            print(f"{layer_name:<40} {layer_time:.4f}s ({percentage:.1f}%)")

        return self.layer_times

    def suggest_optimizations(self, layer_times):
        """Return heuristic suggestions from a module-name -> seconds mapping.

        Layers are bucketed by substrings of their lowercased names. A bucket
        triggers a suggestion when its share of the total exceeds a threshold:
        attention > 60%, linear/dense > 40%, norm > 10%. An empty or all-zero
        mapping yields no suggestions.
        """
        suggestions = []

        total_time = sum(layer_times.values())
        if total_time <= 0:
            return suggestions

        def share(*keywords):
            matched = sum(
                seconds for name, seconds in layer_times.items()
                if any(keyword in name.lower() for keyword in keywords)
            )
            return matched / total_time

        # 'attn' covers common names such as "self_attn" that lack "attention".
        if share("attention", "attn") > 0.6:
            suggestions.append("Consider using Flash Attention or other optimized attention implementations")

        if share("linear", "dense") > 0.4:
            suggestions.append("Consider quantization or pruning for linear layers")

        if share("norm") > 0.1:
            suggestions.append("Consider fusing layer normalization with adjacent operations")

        return suggestions

Production Optimization Checklist

class OptimizationChecklist:
    """Categorized checklist of production LLM optimizations with a text report."""

    def __init__(self):
        # category name -> ordered list of checklist item descriptions
        self.checks = {
            "model_optimization": [
                "Quantization applied (INT8/FP16)",
                "Pruning applied where appropriate",
                "Knowledge distillation considered",
                "Model architecture optimized for inference"
            ],
            "memory_optimization": [
                "Gradient checkpointing enabled",
                "KV-cache optimization implemented",
                "Memory pooling configured",
                "Batch size optimized for hardware"
            ],
            "compute_optimization": [
                "Mixed precision enabled",
                "Flash Attention or equivalent used",
                "Fused operations implemented",
                "CUDA kernels optimized"
            ],
            "serving_optimization": [
                "Continuous batching implemented",
                "Response caching configured",
                "Load balancing optimized",
                "Auto-scaling configured"
            ],
            "monitoring": [
                "Latency monitoring in place",
                "Throughput tracking enabled",
                "Resource utilization monitored",
                "Error rate tracking configured"
            ]
        }

    def run_checklist(self, model, serving_config):
        """Evaluate every checklist item.

        Args:
            model: Model under review; forwarded to ``check_optimization``.
            serving_config: Configuration forwarded to ``check_optimization``.

        Returns:
            Dict of category -> {item description: status symbol}, in the
            order the items are declared in ``self.checks``.
        """
        return {
            category: {
                item: self.check_optimization(item, model, serving_config)
                for item in items
            }
            for category, items in self.checks.items()
        }

    def check_optimization(self, item, model, config):
        """Return the status symbol for a single checklist item.

        Automatic inspection of ``model`` is not implemented. If ``config`` is a
        mapping that contains ``item`` as a key, the value's truthiness decides
        the status: "✓" for truthy, "✗" for falsy. Otherwise "?" is returned,
        meaning the item was not verified.
        """
        from collections.abc import Mapping

        # Status must be deterministic. hash(str) is salted per process
        # (PYTHONHASHSEED), so deriving a status from it is random between runs.
        if isinstance(config, Mapping) and item in config:
            return "✓" if config[item] else "✗"
        return "?"

    def generate_report(self, results):
        """Print ``results`` (as returned by ``run_checklist``) grouped by category."""
        print("Optimization Status Report")
        print("=" * 50)

        for category, checks in results.items():
            print(f"\n{category.upper()}:")
            for item, status in checks.items():
                print(f"  {status} {item}")

Key Takeaways

  • Performance optimization requires systematic measurement and analysis
  • Speculative decoding reduces per-token generation latency, and continuous batching significantly improves throughput
  • Memory optimization through caching and checkpointing enables larger models
  • Hardware acceleration with mixed precision and custom kernels reduces latency
  • Profiling tools are essential for identifying bottlenecks
  • Production optimization requires balancing multiple factors: latency, throughput, cost, and quality

This comprehensive approach to performance optimization enables deploying LLMs that meet production requirements for speed, efficiency, and cost-effectiveness.


Navigation