Lecture 13: AI Workloads & Dataflow
Part IV: Applications — SCE Futures
Learning Objectives¶
By the end of this lecture, you will be able to:
- Explain the self-attention mechanism and why it revolutionized sequence modeling
- Trace the data flow through a transformer block and identify all matrix operations
- Calculate the exact dimensions of Q, K, V matrices for any model configuration
- Quantify the computational complexity (FLOPs) for attention and MLP layers
- Analyze memory bandwidth requirements and identify bottlenecks
- Understand why transformers are both computationally intensive and memory-bound
- Connect these requirements to the need for novel hardware architectures
1. Why Transformers Matter for Hardware¶
The transformer architecture, introduced in "Attention Is All You Need" (Vaswani et al., 2017), has become the foundation of modern AI. Understanding its computational structure is essential for hardware designers because:
- Transformers dominate AI compute: GPT-4, Claude, Llama, Gemini, and virtually every frontier model uses transformers
- Predictable structure: Unlike earlier RNNs, transformers are built from repeated, regular matrix operations
- Extreme scale: Modern models have billions to trillions of parameters, demanding specialized hardware
- Memory-bound inference: The bottleneck is often memory bandwidth, not raw compute
The Scale of Modern LLMs¶
| Model | Parameters | Training Compute (FLOPs) | Est. Training Cost |
|---|---|---|---|
| GPT-2 (2019) | 1.5B | ~10²² | ~$50K |
| GPT-3 (2020) | 175B | ~3×10²³ | ~$5M |
| GPT-4 (2023) | ~1.8T (est.) | ~10²⁵ | ~$100M |
| Llama 3 405B (2024) | 405B | ~4×10²⁵ | ~$100M+ |
Note: Llama 3 405B has ~4.4× fewer parameters than GPT-4 but used ~4× more training compute and cost as much or more. This reflects the "Chinchilla scaling" approach: instead of making the model bigger, Meta trained a smaller model on far more data (15T+ tokens). Training cost is driven by compute (FLOPs), not parameter count — more passes over more data means more GPU-hours regardless of model size.
The exponential growth in model size has outpaced Moore's Law, creating an urgent need for more efficient compute architectures.
# Setup: Import libraries for visualizations
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.patheffects as path_effects
import numpy as np
COLORS = {
'primary': '#2196F3', # Blue
'secondary': '#FF9800', # Orange
'success': '#4CAF50', # Green
'danger': '#f44336', # Red
'dark': '#1a1a2e', # Dark navy
'light': '#f5f5f5', # Light gray
'purple': '#9C27B0', # Purple
'cyan': '#00BCD4', # Cyan
'attention': '#E91E63', # Pink for attention
'mlp': '#3F51B5', # Indigo for MLP
'embedding': '#009688', # Teal for embeddings
}
plt.rcParams['figure.facecolor'] = 'white'
plt.rcParams['axes.facecolor'] = 'white'
plt.rcParams['font.size'] = 11
print("Setup complete.")
Setup complete.
# Visualize: Growth of LLM compute requirements over time
fig, ax = plt.subplots(figsize=(14, 6))
# Data: model releases with training compute estimates (in FLOPs)
models = {
'BERT\n(2018)': 3e19,
'GPT-2\n(2019)': 1e22,
'GPT-3\n(2020)': 3e23,
'PaLM\n(2022)': 3e24,
'GPT-4\n(2023)': 1e25,
'Llama 3\n405B (2024)': 4e25,
}
years = [2018, 2019, 2020, 2022, 2023, 2024]
flops = list(models.values())
names = list(models.keys())
# Plot with log scale
bars = ax.bar(range(len(models)), flops, color=COLORS['primary'], alpha=0.8, edgecolor='black')
ax.set_yscale('log')
ax.set_xticks(range(len(models)))
ax.set_xticklabels(names, fontsize=10)
ax.set_ylabel('Training Compute (FLOPs)', fontsize=12)
ax.set_title('Exponential Growth in LLM Training Compute', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')
# Add Moore's Law reference line (2x every 2 years from 2018)
# x positions are model indices 0-5 spanning 2018-2024, i.e. ~1.2 years per index
moore_x = np.linspace(0, 5, 100)
moore_y = 3e19 * (2 ** (moore_x * 1.2 / 2))  # 2x every 2 years
ax.plot(moore_x, moore_y, '--', color=COLORS['danger'], linewidth=2, label="Moore's Law (2x/2yr)")
# Actual scaling line
actual_x = np.array([0, 5])
actual_y = np.array([3e19, 4e25])
ax.plot(actual_x, actual_y, '--', color=COLORS['success'], linewidth=2, label='Actual LLM scaling (~10x/yr)')
ax.legend(loc='upper left', fontsize=10)
ax.set_ylim(1e19, 1e26)
plt.tight_layout()
plt.show()
print("LLM compute demand is growing ~10x per year, far exceeding Moore's Law.")
print("This gap is driving the need for novel architectures like superconducting accelerators.")
LLM compute demand is growing ~10x per year, far exceeding Moore's Law. This gap is driving the need for novel architectures like superconducting accelerators.
2. The Attention Mechanism¶
The core innovation of transformers is self-attention, which allows every token in a sequence to "attend to" every other token. This replaces the sequential processing of RNNs with parallel computation.
The Attention Formula¶
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$
Where:
- $Q$ (Query): What am I looking for?
- $K$ (Key): What do I contain?
- $V$ (Value): What information do I provide?
- $d_k$: Dimension of the keys; dividing by $\sqrt{d_k}$ keeps the dot products small enough for a numerically stable softmax
Intuition: The Library Analogy¶
Think of attention like searching a library:
- Query: Your search question ("books about physics")
- Key: The catalog entry for each book (title, subject tags)
- Value: The actual content of each book
- Attention scores: How relevant each book is to your query
- Output: A weighted combination of book contents based on relevance
Step-by-Step Computation¶
For a sequence of $n$ tokens with embedding dimension $d_{model}$:
1. Linear projections: Create Q, K, V from input embeddings
   - $Q = XW_Q$ where $W_Q \in \mathbb{R}^{d_{model} \times d_k}$
   - $K = XW_K$ where $W_K \in \mathbb{R}^{d_{model} \times d_k}$
   - $V = XW_V$ where $W_V \in \mathbb{R}^{d_{model} \times d_v}$
2. Attention scores: Compute similarity between queries and keys
   - $\text{scores} = QK^T / \sqrt{d_k}$ gives an $n \times n$ matrix
3. Softmax normalization: Convert scores to probabilities (row-wise)
   - Each row sums to 1, representing attention weights
4. Weighted sum: Combine values according to attention weights
   - Output = attention_weights × V
# Visualize: The attention computation flow
fig, axes = plt.subplots(1, 4, figsize=(16, 5))
# Example dimensions
n_tokens = 6
d_model = 8
d_k = 4
def draw_matrix(ax, rows, cols, label, color, sublabel=''):
"""Draw a matrix representation."""
rect = patches.FancyBboxPatch((0, 0), cols, rows,
boxstyle="round,pad=0.02",
facecolor=color, alpha=0.7,
edgecolor='black', linewidth=2)
ax.add_patch(rect)
ax.set_xlim(-0.5, cols + 0.5)
ax.set_ylim(-0.5, rows + 1)
ax.set_aspect('equal')
ax.axis('off')
ax.text(cols/2, rows + 0.3, label, ha='center', va='bottom',
fontsize=14, fontweight='bold')
ax.text(cols/2, -0.3, sublabel, ha='center', va='top', fontsize=10)
# Step 1: Q, K, V matrices
draw_matrix(axes[0], n_tokens, d_k, 'Q, K, V', COLORS['primary'], f'{n_tokens}×{d_k}')
axes[0].set_title('1. Project to Q, K, V', fontsize=11)
# Step 2: QK^T attention scores
draw_matrix(axes[1], n_tokens, n_tokens, 'QK$^T$', COLORS['attention'], f'{n_tokens}×{n_tokens}')
axes[1].set_title('2. Compute Attention Scores', fontsize=11)
# Step 3: Softmax (same shape, different values)
draw_matrix(axes[2], n_tokens, n_tokens, 'softmax(QK$^T$/√d$_k$)', COLORS['secondary'], f'{n_tokens}×{n_tokens}')
axes[2].set_title('3. Apply Softmax', fontsize=11)
# Step 4: Output
draw_matrix(axes[3], n_tokens, d_k, 'Output', COLORS['success'], f'{n_tokens}×{d_k}')
axes[3].set_title('4. Weighted Sum with V', fontsize=11)
# Add arrows between steps
for i in range(3):
fig.text(0.25 + i*0.23, 0.5, '→', fontsize=30, ha='center', va='center')
plt.suptitle('Self-Attention: Data Flow Through Matrix Operations', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()
print(f"Key insight: The attention matrix is n×n = {n_tokens}×{n_tokens} = {n_tokens**2} elements.")
print(f"This n² scaling is why long sequences are expensive!")
Key insight: The attention matrix is n×n = 6×6 = 36 elements. This n² scaling is why long sequences are expensive!
# Demonstrate: Actual attention computation with numbers
np.random.seed(42)
# Small example for clarity
seq_len = 4
d_k = 3
# Random Q, K, V matrices
Q = np.random.randn(seq_len, d_k).round(2)
K = np.random.randn(seq_len, d_k).round(2)
V = np.random.randn(seq_len, d_k).round(2)
print("="*60)
print("ATTENTION COMPUTATION EXAMPLE")
print("="*60)
print(f"\nSequence length: {seq_len}, Key dimension: {d_k}")
print(f"\nQ (Query) matrix [{seq_len}×{d_k}]:")
print(Q)
print(f"\nK (Key) matrix [{seq_len}×{d_k}]:")
print(K)
print(f"\nV (Value) matrix [{seq_len}×{d_k}]:")
print(V)
# Step 1: QK^T
scores = Q @ K.T
print(f"\n--- Step 1: QK^T [{seq_len}×{seq_len}] ---")
print(scores.round(2))
# Step 2: Scale
scaled_scores = scores / np.sqrt(d_k)
print(f"\n--- Step 2: Scale by √d_k = √{d_k} = {np.sqrt(d_k):.2f} ---")
print(scaled_scores.round(2))
# Step 3: Softmax
def softmax(x):
exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
return exp_x / exp_x.sum(axis=-1, keepdims=True)
attention_weights = softmax(scaled_scores)
print(f"\n--- Step 3: Softmax (each row sums to 1) ---")
print(attention_weights.round(3))
print(f"Row sums: {attention_weights.sum(axis=1).round(3)}")
# Step 4: Output
output = attention_weights @ V
print(f"\n--- Step 4: Attention × V = Output [{seq_len}×{d_k}] ---")
print(output.round(2))
print("\n" + "="*60)
print("Each output row is a weighted combination of all V rows,")
print("where weights come from the attention scores.")
============================================================ ATTENTION COMPUTATION EXAMPLE ============================================================ Sequence length: 4, Key dimension: 3 Q (Query) matrix [4×3]: [[ 0.5 -0.14 0.65] [ 1.52 -0.23 -0.23] [ 1.58 0.77 -0.47] [ 0.54 -0.46 -0.47]] K (Key) matrix [4×3]: [[ 0.24 -1.91 -1.72] [-0.56 -1.01 0.31] [-0.91 -1.41 1.47] [-0.23 0.07 -1.42]] V (Value) matrix [4×3]: [[-0.54 0.11 -1.15] [ 0.38 -0.6 -0.29] [-0.6 1.85 -0.01] [-1.06 0.82 -1.22]] --- Step 1: QK^T [4×4] --- [[-0.73 0.06 0.7 -1.05] [ 1.2 -0.69 -1.4 -0.04] [-0.28 -1.81 -3.21 0.36] [ 1.82 0.02 -0.53 0.51]] --- Step 2: Scale by √d_k = √3 = 1.73 --- [[-0.42 0.04 0.4 -0.6 ] [ 0.69 -0.4 -0.81 -0.02] [-0.16 -1.04 -1.86 0.21] [ 1.05 0.01 -0.31 0.3 ]] --- Step 3: Softmax (each row sums to 1) --- [[0.176 0.278 0.401 0.146] [0.488 0.164 0.109 0.239] [0.328 0.136 0.06 0.475] [0.48 0.17 0.124 0.226]] Row sums: [1. 1. 1. 1.] --- Step 4: Attention × V = Output [4×3] --- [[-0.38 0.71 -0.46] [-0.52 0.35 -0.9 ] [-0.67 0.46 -1. ] [-0.51 0.37 -0.88]] ============================================================ Each output row is a weighted combination of all V rows, where weights come from the attention scores.
Multi-Head Attention¶
Real transformers use multi-head attention, running multiple attention operations in parallel with different learned projections:
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O$$
Where each head has its own projection matrices: $$\text{head}_i = \text{Attention}(XW_i^Q, XW_i^K, XW_i^V)$$
Why multiple heads?
- Each head can learn to attend to different types of relationships
- Head 1 might learn syntactic relationships (subject-verb)
- Head 2 might learn semantic relationships (synonyms)
- Head 3 might learn positional patterns
Typical configurations:
| Model | d_model | n_heads | d_head = d_model/n_heads |
|---|---|---|---|
| BERT-base | 768 | 12 | 64 |
| GPT-3 | 12,288 | 96 | 128 |
| Llama 2 70B | 8,192 | 64 | 128 |
| Llama 3 405B | 16,384 | 128 | 128 |
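To make the head bookkeeping concrete, here is a minimal NumPy sketch of multi-head self-attention (the function and variable names are illustrative, not from any library): d_model is split into n_heads slices of d_head, scaled dot-product attention runs independently per head, and the concatenated results go through the output projection.
# Sketch: minimal multi-head self-attention in NumPy (illustrative; no masking, no batching)
import numpy as np

def softmax_rows(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention(X, W_q, W_k, W_v, W_o, n_heads):
    """Split d_model into n_heads slices, attend per head, concatenate, project."""
    n, d_model = X.shape
    d_head = d_model // n_heads
    Q, K, V = X @ W_q, X @ W_k, X @ W_v                   # each (n, d_model)
    head_outputs = []
    for h in range(n_heads):                               # real kernels batch this over heads
        sl = slice(h * d_head, (h + 1) * d_head)
        scores = Q[:, sl] @ K[:, sl].T / np.sqrt(d_head)   # (n, n) per head
        head_outputs.append(softmax_rows(scores) @ V[:, sl])
    return np.concatenate(head_outputs, axis=-1) @ W_o     # (n, d_model)

# Tiny example: 6 tokens, d_model=32, 4 heads of size 8
rng = np.random.default_rng(0)
n, d_model, n_heads = 6, 32, 4
X = rng.standard_normal((n, d_model))
W_q, W_k, W_v, W_o = (rng.standard_normal((d_model, d_model)) * 0.1 for _ in range(4))
print(multi_head_attention(X, W_q, W_k, W_v, W_o, n_heads).shape)  # (6, 32)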
# Visualize: Multi-head attention architecture
fig, ax = plt.subplots(figsize=(14, 8))
# Configuration
n_heads = 4
d_model = 512
d_head = d_model // n_heads
# Draw input
input_rect = patches.FancyBboxPatch((1, 7), 3, 0.8,
boxstyle="round,pad=0.02",
facecolor=COLORS['embedding'], alpha=0.8,
edgecolor='black', linewidth=2)
ax.add_patch(input_rect)
ax.text(2.5, 7.4, f'Input X\n(n × {d_model})', ha='center', va='center', fontsize=10, fontweight='bold')
# Draw projection layer
proj_rect = patches.FancyBboxPatch((0.5, 5.5), 4, 0.8,
boxstyle="round,pad=0.02",
facecolor=COLORS['light'], alpha=0.8,
edgecolor='black', linewidth=2)
ax.add_patch(proj_rect)
ax.text(2.5, 5.9, 'Linear Projections → Q, K, V', ha='center', va='center', fontsize=10)
# Arrow from input to projection
ax.annotate('', xy=(2.5, 6.3), xytext=(2.5, 7),
arrowprops=dict(arrowstyle='->', color='black', lw=2))
# Draw attention heads
head_colors = [COLORS['primary'], COLORS['secondary'], COLORS['success'], COLORS['purple']]
for i in range(n_heads):
x = 0.5 + i * 1.1
head_rect = patches.FancyBboxPatch((x, 3.5), 0.9, 1.5,
boxstyle="round,pad=0.02",
facecolor=head_colors[i], alpha=0.7,
edgecolor='black', linewidth=2)
ax.add_patch(head_rect)
ax.text(x + 0.45, 4.25, f'Head {i+1}\n({d_head})', ha='center', va='center',
fontsize=9, fontweight='bold', color='white')
# Arrows from projection to heads
ax.annotate('', xy=(x + 0.45, 5), xytext=(2.5, 5.5),
arrowprops=dict(arrowstyle='->', color='gray', lw=1))
# Draw concatenation
concat_rect = patches.FancyBboxPatch((0.5, 2), 4, 0.8,
boxstyle="round,pad=0.02",
facecolor=COLORS['light'], alpha=0.8,
edgecolor='black', linewidth=2)
ax.add_patch(concat_rect)
ax.text(2.5, 2.4, f'Concatenate ({n_heads} × {d_head} = {d_model})', ha='center', va='center', fontsize=10)
# Arrows from heads to concat
for i in range(n_heads):
x = 0.5 + i * 1.1 + 0.45
ax.annotate('', xy=(x, 2.8), xytext=(x, 3.5),
arrowprops=dict(arrowstyle='->', color='gray', lw=1))
# Draw output projection
output_proj_rect = patches.FancyBboxPatch((1, 0.8), 3, 0.8,
boxstyle="round,pad=0.02",
facecolor=COLORS['cyan'], alpha=0.8,
edgecolor='black', linewidth=2)
ax.add_patch(output_proj_rect)
ax.text(2.5, 1.2, f'Output Projection W_O\n(n × {d_model})', ha='center', va='center', fontsize=10, fontweight='bold')
# Arrow from concat to output
ax.annotate('', xy=(2.5, 1.6), xytext=(2.5, 2),
arrowprops=dict(arrowstyle='->', color='black', lw=2))
ax.set_xlim(-0.5, 5.5)
ax.set_ylim(0, 8.5)
ax.axis('off')
ax.set_title(f'Multi-Head Attention with {n_heads} Heads', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
print(f"Each head operates on d_head = d_model/n_heads = {d_model}/{n_heads} = {d_head} dimensions")
print(f"Total compute is the same as single-head with d_model = {d_model}")
print(f"But we get {n_heads} different 'views' of the relationships!")
Each head operates on d_head = d_model/n_heads = 512/4 = 128 dimensions Total compute is the same as single-head with d_model = 512 But we get 4 different 'views' of the relationships!
3. Transformer Block Architecture¶
A complete transformer is built from stacked transformer blocks. Each block contains two main sub-layers:
- Multi-Head Self-Attention (MHSA)
- Feed-Forward Network (FFN/MLP)
Both sub-layers use:
- Residual connections: Output = SubLayer(x) + x
- Layer normalization: Stabilizes training
The Feed-Forward Network¶
The FFN is deceptively simple but accounts for ~2/3 of parameters in most models:
$$\text{FFN}(x) = \text{GELU}(xW_1 + b_1)W_2 + b_2$$
Where:
- $W_1 \in \mathbb{R}^{d_{model} \times d_{ff}}$ (expansion)
- $W_2 \in \mathbb{R}^{d_{ff} \times d_{model}}$ (projection)
- $d_{ff}$ is typically 4× $d_{model}$ (the "expansion ratio")
Modern Variations: SwiGLU¶
Many modern models (Llama, PaLM) use SwiGLU instead of standard FFN:
$$\text{SwiGLU}(x) = (\text{Swish}(xW_{gate}) \odot xW_{up})W_{down}$$
This has 3 weight matrices instead of 2, with $d_{ff} = \frac{8}{3} d_{model}$ to match parameter count.
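The following is a rough NumPy sketch of a pre-norm block with a SwiGLU FFN, matching the structure described above (names are illustrative; the learned LayerNorm scale/shift and the attention weights are omitted for brevity):
# Sketch: pre-norm transformer block with a SwiGLU FFN (illustrative; learned norm params omitted)
import numpy as np

def layer_norm(x, eps=1e-5):
    return (x - x.mean(axis=-1, keepdims=True)) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def swiglu_ffn(x, W_gate, W_up, W_down):
    """FFN(x) = (Swish(x W_gate) ⊙ x W_up) W_down"""
    gate = x @ W_gate
    swish = gate / (1.0 + np.exp(-gate))             # Swish/SiLU: x * sigmoid(x)
    return (swish * (x @ W_up)) @ W_down

def transformer_block(x, attn_fn, ffn_weights):
    """Pre-norm wiring: x + Attn(LN(x)), then x + FFN(LN(x))."""
    x = x + attn_fn(layer_norm(x))                    # attention sub-layer + residual
    x = x + swiglu_ffn(layer_norm(x), *ffn_weights)   # FFN sub-layer + residual
    return x

# Tiny example: d_model=16, d_ff=42 (~8/3 × d_model), identity stand-in for attention
rng = np.random.default_rng(0)
d_model, d_ff, n = 16, 42, 4
ffn_weights = (rng.standard_normal((d_model, d_ff)) * 0.1,   # W_gate
               rng.standard_normal((d_model, d_ff)) * 0.1,   # W_up
               rng.standard_normal((d_ff, d_model)) * 0.1)   # W_down
x = rng.standard_normal((n, d_model))
print(transformer_block(x, attn_fn=lambda h: h, ffn_weights=ffn_weights).shape)  # (4, 16)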
# Visualize: Complete transformer block architecture
fig, ax = plt.subplots(figsize=(10, 12))
def draw_block(ax, x, y, w, h, label, color, sublabel=''):
rect = patches.FancyBboxPatch((x, y), w, h,
boxstyle="round,pad=0.02",
facecolor=color, alpha=0.8,
edgecolor='black', linewidth=2)
ax.add_patch(rect)
ax.text(x + w/2, y + h/2, label, ha='center', va='center',
fontsize=10, fontweight='bold')
if sublabel:
ax.text(x + w/2, y + h/2 - 0.3, sublabel, ha='center', va='center', fontsize=8)
# Spacing
cx = 2.5 # center x
w = 3 # block width
# Input
draw_block(ax, cx - w/2, 0, w, 0.6, 'Input Embeddings', COLORS['embedding'])
# Layer norm 1
draw_block(ax, cx - w/2, 1, w, 0.5, 'Layer Norm', COLORS['light'])
# Multi-head attention
draw_block(ax, cx - w/2, 1.8, w, 1.2, 'Multi-Head\nSelf-Attention', COLORS['attention'])
# Add symbol
circle1 = plt.Circle((cx, 3.3), 0.2, color='white', ec='black', linewidth=2)
ax.add_patch(circle1)
ax.text(cx, 3.3, '+', ha='center', va='center', fontsize=14, fontweight='bold')
# Residual arrow 1
ax.annotate('', xy=(cx + 1.7, 3.3), xytext=(cx + 1.7, 0.3),
arrowprops=dict(arrowstyle='->', color=COLORS['primary'], lw=2))
ax.text(cx + 2, 1.8, 'Residual', ha='left', va='center', fontsize=9, rotation=90)
# Layer norm 2
draw_block(ax, cx - w/2, 3.8, w, 0.5, 'Layer Norm', COLORS['light'])
# FFN
draw_block(ax, cx - w/2, 4.6, w, 1.2, 'Feed-Forward\nNetwork (MLP)', COLORS['mlp'])
# Add symbol 2
circle2 = plt.Circle((cx, 6.1), 0.2, color='white', ec='black', linewidth=2)
ax.add_patch(circle2)
ax.text(cx, 6.1, '+', ha='center', va='center', fontsize=14, fontweight='bold')
# Residual arrow 2
ax.annotate('', xy=(cx + 1.7, 6.1), xytext=(cx + 1.7, 3.5),
arrowprops=dict(arrowstyle='->', color=COLORS['primary'], lw=2))
ax.text(cx + 2, 4.8, 'Residual', ha='left', va='center', fontsize=9, rotation=90)
# Output
draw_block(ax, cx - w/2, 6.6, w, 0.6, 'Output', COLORS['success'])
# Vertical arrows
arrow_positions = [(0.6, 1), (1.5, 1.8), (3, 3.8), (4.3, 4.6), (5.8, 6.6)]
for y1, y2 in arrow_positions:
ax.annotate('', xy=(cx, y2), xytext=(cx, y1),
arrowprops=dict(arrowstyle='->', color='black', lw=1.5))
# Annotations
ax.text(-0.5, 2.4, 'Attention\nSub-layer', ha='center', va='center', fontsize=10,
bbox=dict(boxstyle='round', facecolor='white', edgecolor='gray'))
ax.text(-0.5, 5.2, 'FFN\nSub-layer', ha='center', va='center', fontsize=10,
bbox=dict(boxstyle='round', facecolor='white', edgecolor='gray'))
# "× N layers" annotation
ax.text(5.5, 3.5, '× N\nlayers', ha='center', va='center', fontsize=14, fontweight='bold',
bbox=dict(boxstyle='round', facecolor=COLORS['secondary'], alpha=0.5))
ax.set_xlim(-1.5, 6.5)
ax.set_ylim(-0.5, 7.5)
ax.axis('off')
ax.set_title('Transformer Block (Decoder-only, Pre-Norm Style)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
print("Modern LLMs (GPT, Llama, Claude) use 'pre-norm' style: LayerNorm before each sub-layer")
print("This improves training stability compared to original 'post-norm' design")
Modern LLMs (GPT, Llama, Claude) use 'pre-norm' style: LayerNorm before each sub-layer This improves training stability compared to original 'post-norm' design
4. Matrix Dimensions and Shapes¶
Understanding the exact tensor shapes is crucial for hardware design. Let's trace through a concrete example.
Configuration: Llama 2 7B¶
| Parameter | Value |
|---|---|
| d_model (hidden size) | 4,096 |
| n_layers | 32 |
| n_heads | 32 |
| d_head | 128 |
| d_ff (FFN intermediate) | 11,008 |
| vocab_size | 32,000 |
| context_length | 4,096 |
Weight Matrix Shapes (Per Layer)¶
Attention weights:
- $W_Q$: (4096, 4096) = 16.8M parameters
- $W_K$: (4096, 4096) = 16.8M parameters
- $W_V$: (4096, 4096) = 16.8M parameters
- $W_O$: (4096, 4096) = 16.8M parameters
- Total attention: 67.1M parameters per layer
FFN weights (SwiGLU):
- $W_{gate}$: (4096, 11008) = 45.1M parameters
- $W_{up}$: (4096, 11008) = 45.1M parameters
- $W_{down}$: (11008, 4096) = 45.1M parameters
- Total FFN: 135.3M parameters per layer
Total per layer: ~202M parameters
× 32 layers: ~6.5B parameters (+ embeddings ≈ 7B total)
# Calculate: Exact parameter counts for common model sizes
def count_parameters(d_model, n_layers, n_heads, d_ff, vocab_size, use_swiglu=True):
"""Count parameters in a transformer model."""
params = {}
# Embedding layers
params['embedding'] = vocab_size * d_model
# Per-layer attention
params['attn_qkv'] = 3 * d_model * d_model # W_Q, W_K, W_V
params['attn_out'] = d_model * d_model # W_O
params['attn_total'] = params['attn_qkv'] + params['attn_out']
# Per-layer FFN
if use_swiglu:
params['ffn_total'] = 3 * d_model * d_ff # gate, up, down
else:
params['ffn_total'] = 2 * d_model * d_ff # up, down
# Layer norms (small)
params['layer_norms'] = 2 * d_model # 2 per layer
# Total per layer
params['per_layer'] = params['attn_total'] + params['ffn_total'] + params['layer_norms']
# All layers
params['all_layers'] = params['per_layer'] * n_layers
# Final layer norm + output projection (often tied to embedding)
params['final_norm'] = d_model
params['output_proj'] = vocab_size * d_model # or tied
# Total
params['total'] = params['embedding'] + params['all_layers'] + params['final_norm']
params['total_with_output'] = params['total'] + params['output_proj']
return params
# Common model configurations
models = {
'Llama 2 7B': {'d_model': 4096, 'n_layers': 32, 'n_heads': 32, 'd_ff': 11008, 'vocab_size': 32000},
'Llama 2 13B': {'d_model': 5120, 'n_layers': 40, 'n_heads': 40, 'd_ff': 13824, 'vocab_size': 32000},
'Llama 2 70B': {'d_model': 8192, 'n_layers': 80, 'n_heads': 64, 'd_ff': 28672, 'vocab_size': 32000},
'Llama 3 405B': {'d_model': 16384, 'n_layers': 126, 'n_heads': 128, 'd_ff': 53248, 'vocab_size': 128000},
}
print("="*80)
print("PARAMETER BREAKDOWN BY MODEL")
print("="*80)
for name, config in models.items():
params = count_parameters(**config)
total_b = params['total_with_output'] / 1e9
attn_pct = (params['attn_total'] * config['n_layers']) / params['total_with_output'] * 100
ffn_pct = (params['ffn_total'] * config['n_layers']) / params['total_with_output'] * 100
print(f"\n{name}:")
print(f" d_model: {config['d_model']:,} | layers: {config['n_layers']} | d_ff: {config['d_ff']:,}")
print(f" Total parameters: {total_b:.1f}B")
print(f" Attention: {attn_pct:.1f}% | FFN: {ffn_pct:.1f}% | Embeddings: {100-attn_pct-ffn_pct:.1f}%")
print(f" Memory (FP16): {total_b * 2:.1f} GB | Memory (INT8): {total_b:.1f} GB")
================================================================================ PARAMETER BREAKDOWN BY MODEL ================================================================================ Llama 2 7B: d_model: 4,096 | layers: 32 | d_ff: 11,008 Total parameters: 6.7B Attention: 31.9% | FFN: 64.2% | Embeddings: 3.9% Memory (FP16): 13.5 GB | Memory (INT8): 6.7 GB Llama 2 13B: d_model: 5,120 | layers: 40 | d_ff: 13,824 Total parameters: 13.0B Attention: 32.2% | FFN: 65.3% | Embeddings: 2.5% Memory (FP16): 26.0 GB | Memory (INT8): 13.0 GB Llama 2 70B: d_model: 8,192 | layers: 80 | d_ff: 28,672 Total parameters: 78.4B Attention: 27.4% | FFN: 71.9% | Embeddings: 0.7% Memory (FP16): 156.7 GB | Memory (INT8): 78.4 GB Llama 3 405B: d_model: 16,384 | layers: 126 | d_ff: 53,248 Total parameters: 469.3B Attention: 28.8% | FFN: 70.3% | Embeddings: 0.9% Memory (FP16): 938.5 GB | Memory (INT8): 469.3 GB
# Visualize: Parameter distribution in a transformer
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# Llama 2 7B breakdown
config = models['Llama 2 7B']
params = count_parameters(**config)
# Pie chart of parameter distribution
attn_params = params['attn_total'] * config['n_layers']
ffn_params = params['ffn_total'] * config['n_layers']
other_params = params['total_with_output'] - attn_params - ffn_params
sizes = [attn_params, ffn_params, other_params]
labels = ['Attention\n(QKV + Out)', 'FFN\n(SwiGLU)', 'Embeddings\n& Norms']
colors_pie = [COLORS['attention'], COLORS['mlp'], COLORS['embedding']]
explode = (0.02, 0.02, 0.02)
axes[0].pie(sizes, labels=labels, colors=colors_pie, explode=explode,
autopct='%1.1f%%', startangle=90, textprops={'fontsize': 11})
axes[0].set_title('Llama 2 7B: Parameter Distribution', fontsize=12, fontweight='bold')
# Bar chart comparing models
model_names = list(models.keys())
attn_pcts = []
ffn_pcts = []
other_pcts = []
for config in models.values():
p = count_parameters(**config)
total = p['total_with_output']
attn_pcts.append((p['attn_total'] * config['n_layers']) / total * 100)
ffn_pcts.append((p['ffn_total'] * config['n_layers']) / total * 100)
other_pcts.append(100 - attn_pcts[-1] - ffn_pcts[-1])
x = np.arange(len(model_names))
width = 0.6
axes[1].bar(x, attn_pcts, width, label='Attention', color=COLORS['attention'])
axes[1].bar(x, ffn_pcts, width, bottom=attn_pcts, label='FFN', color=COLORS['mlp'])
axes[1].bar(x, other_pcts, width, bottom=np.array(attn_pcts)+np.array(ffn_pcts),
label='Embeddings', color=COLORS['embedding'])
axes[1].set_ylabel('Percentage of Parameters', fontsize=11)
axes[1].set_xticks(x)
axes[1].set_xticklabels([n.replace(' ', '\n') for n in model_names], fontsize=9)
axes[1].legend(loc='upper right', fontsize=10)
axes[1].set_title('Parameter Distribution Across Model Sizes', fontsize=12, fontweight='bold')
axes[1].set_ylim(0, 100)
axes[1].grid(True, alpha=0.3, axis='y')
plt.tight_layout()
plt.show()
print("Key insight: FFN layers contain ~2/3 of parameters!")
print("Attention gets the fame, but the MLP does the heavy lifting.")
Key insight: FFN layers contain ~2/3 of parameters! Attention gets the fame, but the MLP does the heavy lifting.
5. Computational Complexity Analysis¶
Understanding FLOPs (floating-point operations) is essential for hardware sizing.
Matrix Multiplication Complexity¶
For a matrix multiply $C = AB$ where $A \in \mathbb{R}^{m \times k}$ and $B \in \mathbb{R}^{k \times n}$:
- Each output element requires $k$ multiplications and $k-1$ additions
- Total: $2mnk$ FLOPs (counting multiply-add as 2 ops)
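A quick numerical check of the $2mnk$ rule (the dimensions below are just examples):
# Sketch: counting FLOPs for C = A @ B with A (m×k) and B (k×n)
m, k, n = 1, 4096, 4096             # one token through a d_model=4096 projection
flops = 2 * m * n * k               # m*n outputs, each a length-k chain of multiply-adds
print(f"Vector-matrix: {flops/1e6:.1f} MFLOPs")     # ~33.6 MFLOPs
m = 2048                            # a full 2048-token prompt processed at once
print(f"Matrix-matrix: {2*m*n*k/1e9:.1f} GFLOPs")   # ~68.7 GFLOPs, reusing the same weights 2048×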
Attention FLOPs (per layer, per token for inference)¶
For sequence length $n$, model dimension $d$:
| Operation | Shape | FLOPs |
|---|---|---|
| Q projection | (1, d) × (d, d) | $2d^2$ |
| K projection | (1, d) × (d, d) | $2d^2$ |
| V projection | (1, d) × (d, d) | $2d^2$ |
| QK^T scores | (1, d) × (n, d)^T | $2nd$ |
| Attention × V | (1, n) × (n, d) | $2nd$ |
| Output projection | (1, d) × (d, d) | $2d^2$ |
| Total Attention | | $8d^2 + 4nd$ |
FFN FLOPs (per layer, per token)¶
For FFN dimension $d_{ff} = 4d$ (or $\frac{8}{3}d$ for SwiGLU):
| Operation | Shape | FLOPs |
|---|---|---|
| Up projection | (1, d) × (d, d_ff) | $2d \cdot d_{ff}$ |
| Gate projection (SwiGLU) | (1, d) × (d, d_ff) | $2d \cdot d_{ff}$ |
| Down projection | (1, d_ff) × (d_ff, d) | $2d \cdot d_{ff}$ |
| Total FFN (SwiGLU) | | $6d \cdot d_{ff}$ |
Total FLOPs per Token¶
$$\text{FLOPs/token} = L \times (8d^2 + 4nd + 6d \cdot d_{ff})$$
Where $L$ is the number of layers.
# Calculate: FLOPs for different model sizes and sequence lengths
def compute_flops_per_token(d_model, n_layers, d_ff, seq_len, use_swiglu=True):
"""Compute FLOPs per output token during inference."""
# Attention FLOPs
qkv_proj = 3 * 2 * d_model * d_model # Q, K, V projections
qk_scores = 2 * seq_len * d_model # QK^T
attn_v = 2 * seq_len * d_model # attention × V
out_proj = 2 * d_model * d_model # output projection
attn_flops = qkv_proj + qk_scores + attn_v + out_proj
# FFN FLOPs
if use_swiglu:
ffn_flops = 3 * 2 * d_model * d_ff # gate, up, down
else:
ffn_flops = 2 * 2 * d_model * d_ff # up, down
# Total per layer
layer_flops = attn_flops + ffn_flops
# All layers
total_flops = layer_flops * n_layers
return {
'attention': attn_flops * n_layers,
'ffn': ffn_flops * n_layers,
'total': total_flops
}
print("="*80)
print("FLOPs PER OUTPUT TOKEN (Inference, varies with sequence length)")
print("="*80)
seq_lengths = [512, 2048, 8192, 32768]
for name, config in models.items():
print(f"\n{name}:")
print("-" * 60)
for seq_len in seq_lengths:
flops = compute_flops_per_token(
config['d_model'], config['n_layers'], config['d_ff'], seq_len
)
total_tflops = flops['total'] / 1e12
attn_pct = flops['attention'] / flops['total'] * 100
print(f" seq_len={seq_len:>5}: {total_tflops:>7.2f} TFLOPs/token "
f"(Attention: {attn_pct:>4.1f}%, FFN: {100-attn_pct:>4.1f}%)")
================================================================================ FLOPs PER OUTPUT TOKEN (Inference, varies with sequence length) ================================================================================ Llama 2 7B: ------------------------------------------------------------ seq_len= 512: 0.01 TFLOPs/token (Attention: 34.5%, FFN: 65.5%) seq_len= 2048: 0.01 TFLOPs/token (Attention: 38.3%, FFN: 61.7%) seq_len= 8192: 0.02 TFLOPs/token (Attention: 49.8%, FFN: 50.2%) seq_len=32768: 0.03 TFLOPs/token (Attention: 71.3%, FFN: 28.7%) Llama 2 13B: ------------------------------------------------------------ seq_len= 512: 0.03 TFLOPs/token (Attention: 34.1%, FFN: 65.9%) seq_len= 2048: 0.03 TFLOPs/token (Attention: 37.2%, FFN: 62.8%) seq_len= 8192: 0.03 TFLOPs/token (Attention: 47.1%, FFN: 52.9%) seq_len=32768: 0.05 TFLOPs/token (Attention: 67.5%, FFN: 32.5%) Llama 2 70B: ------------------------------------------------------------ seq_len= 512: 0.16 TFLOPs/token (Attention: 28.2%, FFN: 71.8%) seq_len= 2048: 0.16 TFLOPs/token (Attention: 30.0%, FFN: 70.0%) seq_len= 8192: 0.18 TFLOPs/token (Attention: 36.4%, FFN: 63.6%) seq_len=32768: 0.24 TFLOPs/token (Attention: 53.3%, FFN: 46.7%) Llama 3 405B: ------------------------------------------------------------ seq_len= 512: 0.93 TFLOPs/token (Attention: 29.4%, FFN: 70.6%) seq_len= 2048: 0.95 TFLOPs/token (Attention: 30.4%, FFN: 69.6%) seq_len= 8192: 1.00 TFLOPs/token (Attention: 33.9%, FFN: 66.1%) seq_len=32768: 1.20 TFLOPs/token (Attention: 45.1%, FFN: 54.9%)
# Visualize: How FLOPs scale with sequence length
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# Llama 2 7B example
config = models['Llama 2 7B']
seq_lengths = np.logspace(2, 5, 50) # 100 to 100,000
attn_flops = []
ffn_flops = []
for sl in seq_lengths:
f = compute_flops_per_token(config['d_model'], config['n_layers'], config['d_ff'], int(sl))
attn_flops.append(f['attention'] / 1e12)
ffn_flops.append(f['ffn'] / 1e12)
# Left plot: Stacked area
axes[0].fill_between(seq_lengths, 0, ffn_flops, alpha=0.7, color=COLORS['mlp'], label='FFN')
axes[0].fill_between(seq_lengths, ffn_flops, np.array(ffn_flops)+np.array(attn_flops),
alpha=0.7, color=COLORS['attention'], label='Attention')
axes[0].set_xscale('log')
axes[0].set_xlabel('Sequence Length', fontsize=12)
axes[0].set_ylabel('TFLOPs per Token', fontsize=12)
axes[0].set_title('Llama 2 7B: FLOPs vs Sequence Length', fontsize=12, fontweight='bold')
axes[0].legend(loc='upper left', fontsize=10)
axes[0].grid(True, alpha=0.3)
axes[0].set_xlim(100, 100000)
# Right plot: Attention percentage
attn_pct = np.array(attn_flops) / (np.array(attn_flops) + np.array(ffn_flops)) * 100
axes[1].plot(seq_lengths, attn_pct, color=COLORS['attention'], linewidth=2.5)
axes[1].axhline(y=50, color='gray', linestyle='--', alpha=0.5)
axes[1].fill_between(seq_lengths, attn_pct, 50, where=(attn_pct > 50),
color=COLORS['attention'], alpha=0.3, label='Attention-dominated')
axes[1].fill_between(seq_lengths, attn_pct, 50, where=(attn_pct <= 50),
color=COLORS['mlp'], alpha=0.3, label='FFN-dominated')
axes[1].set_xscale('log')
axes[1].set_xlabel('Sequence Length', fontsize=12)
axes[1].set_ylabel('Attention % of Total FLOPs', fontsize=12)
axes[1].set_title('Crossover Point: When Attention Dominates', fontsize=12, fontweight='bold')
axes[1].legend(loc='lower right', fontsize=10)
axes[1].grid(True, alpha=0.3)
axes[1].set_xlim(100, 100000)
axes[1].set_ylim(0, 100)
# Find crossover point
crossover_idx = np.argmin(np.abs(np.array(attn_pct) - 50))
crossover_seq = seq_lengths[crossover_idx]
axes[1].axvline(x=crossover_seq, color=COLORS['danger'], linestyle=':', linewidth=2)
axes[1].annotate(f'Crossover\n~{int(crossover_seq):,}',
xy=(crossover_seq, 50), xytext=(crossover_seq*2, 60),
fontsize=10, arrowprops=dict(arrowstyle='->', color='black'))
plt.tight_layout()
plt.show()
print(f"\nFor Llama 2 7B, attention dominates when sequence length > ~{int(crossover_seq):,}")
print("Modern context windows (128K+) put us firmly in the attention-dominated regime!")
For Llama 2 7B, attention dominates when sequence length > ~7,906 Modern context windows (128K+) put us firmly in the attention-dominated regime!
6. Memory Requirements and the Memory Wall¶
The memory wall is the fundamental bottleneck for LLM inference: we can compute faster than we can feed data to the compute units.
Arithmetic Intensity¶
Arithmetic intensity = FLOPs / Bytes transferred
- High intensity (matrix-matrix multiply): Data reuse is high, compute-bound
- Low intensity (vector operations): Little reuse, memory-bound
LLM Inference: The Memory-Bound Problem¶
During inference with batch size = 1:
- Each matrix multiply is (1, d) × (d, d) → vector-matrix multiply
- We must load the entire weight matrix (d² elements)
- But we only do 2d² FLOPs, i.e. roughly 2 FLOPs per parameter loaded
- With FP16 weights (2 bytes per parameter), arithmetic intensity ≈ 1 FLOP/byte
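A back-of-the-envelope sketch of that arithmetic intensity, under the simplifying assumption that only the FP16 weights need to be streamed from memory (activations are comparatively tiny):
# Sketch: arithmetic intensity of one FP16 weight matrix, assuming only weights are streamed
d = 4096
weight_bytes = d * d * 2                 # FP16: 2 bytes per parameter
for batch in [1, 32, 256]:
    flops = 2 * batch * d * d            # 2 FLOPs per parameter, per row in the batch
    print(f"batch={batch:>3}: intensity ≈ {flops / weight_bytes:.0f} FLOP/byte")
# batch=1 sits deep in the memory-bound regime; batch≈256 approaches the H100 ridge point (~300)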
Compare to hardware capabilities:
| Hardware | Compute (TFLOPS) | Memory BW (TB/s) | Compute/BW Ratio |
|---|---|---|---|
| NVIDIA A100 | 312 (FP16) | 2.0 | 156 |
| NVIDIA H100 | 990 (FP16) | 3.35 | 296 |
| Google TPU v4 | 275 (BF16) | 1.2 | 229 |
The hardware wants 150-300 FLOPs per byte, but batch-1 inference provides only about 1!
Result: LLM inference at batch=1 achieves well under 1% of peak compute utilization.
# Visualize: The roofline model for LLM inference
fig, ax = plt.subplots(figsize=(12, 7))
# Hardware specs (H100 as example)
peak_compute = 990 # TFLOPS (FP16)
mem_bw = 3.35 # TB/s
# Roofline: achievable = min(peak_compute, intensity * mem_bw)
# Units: TFLOPS = FLOPs/Byte * TB/s (the "Tera" cancels)
intensity = np.logspace(-1, 3, 1000)
achievable = np.minimum(peak_compute, intensity * mem_bw)
ax.loglog(intensity, achievable, 'b-', linewidth=3, label='Roofline (H100)')
# Ridge point: where memory-bound meets compute-bound
ridge_point = peak_compute / mem_bw # ~296 FLOPs/Byte
# Fill regions
ax.fill_between(intensity, achievable, 0.1, where=(intensity < ridge_point),
alpha=0.2, color='red', label='Memory-bound region')
ax.fill_between(intensity, achievable, 0.1, where=(intensity >= ridge_point),
alpha=0.2, color='green', label='Compute-bound region')
# Mark key points
workloads = {
'LLM Inference\n(batch=1)': (1, 1 * mem_bw),    # intensity ≈ batch size with FP16 weights
'LLM Inference\n(batch=32)': (32, 32 * mem_bw),
'LLM Training': (200, min(200 * mem_bw, peak_compute * 0.9)),
'Dense GEMM': (500, peak_compute * 0.95),
}
for name, (ai, perf) in workloads.items():
perf = min(perf, peak_compute) # Cap at roofline
ax.scatter([ai], [perf], s=150, zorder=5, edgecolors='black', linewidth=2)
ax.annotate(name, xy=(ai, perf), xytext=(ai*1.5, perf*1.3),
fontsize=10, ha='left',
arrowprops=dict(arrowstyle='->', color='black', lw=1))
# Ridge point annotation
ax.axvline(x=ridge_point, color='purple', linestyle='--', alpha=0.7)
ax.annotate(f'Ridge point\n({ridge_point:.0f} FLOP/B)', xy=(ridge_point, peak_compute/2),
fontsize=10, ha='center', color='purple')
ax.set_xlabel('Arithmetic Intensity (FLOPs/Byte)', fontsize=12)
ax.set_ylabel('Achievable Performance (TFLOPS)', fontsize=12)
ax.set_title('Roofline Model: Why LLM Inference is Memory-Bound', fontsize=14, fontweight='bold')
ax.legend(loc='lower right', fontsize=10)
ax.grid(True, alpha=0.3, which='both')
ax.set_xlim(0.5, 1000)
ax.set_ylim(0.1, 2000)
plt.tight_layout()
plt.show()
# Quick check of the batch-1 operating point
batch1_perf = 1 * mem_bw  # ~3.4 TFLOPS
print(f"\nWith batch=1: Arithmetic intensity ≈ 1 FLOP/Byte (FP16 weights)")
print(f"H100 achieves: {batch1_perf:.1f} TFLOPS = {batch1_perf / peak_compute * 100:.1f}% of peak")
print(f"\nTo reach peak compute, we need intensity > {ridge_point:.0f} FLOP/Byte")
print(f"Since intensity grows roughly with batch size, this requires batch ≈ {ridge_point:.0f} or more!")
With batch=1: Arithmetic intensity ≈ 1 FLOP/Byte (FP16 weights) H100 achieves: 3.4 TFLOPS = 0.3% of peak To reach peak compute, we need intensity > 296 FLOP/Byte Since intensity grows roughly with batch size, this requires batch ≈ 296 or more!
# Calculate: Memory requirements for KV cache
def kv_cache_size(n_layers, n_heads, d_head, seq_len, batch_size=1, dtype_bytes=2):
"""Calculate KV cache size in bytes."""
# K and V for each layer, for each head
# Shape: [batch, n_layers, 2 (K+V), n_heads, seq_len, d_head]
size = batch_size * n_layers * 2 * n_heads * seq_len * d_head * dtype_bytes
return size
print("="*80)
print("KV CACHE MEMORY REQUIREMENTS")
print("="*80)
seq_lengths = [2048, 8192, 32768, 131072]
kv_configs = {
'Llama 2 7B': {'n_layers': 32, 'n_heads': 32, 'd_head': 128},
'Llama 2 70B': {'n_layers': 80, 'n_heads': 64, 'd_head': 128},
'Llama 3 405B': {'n_layers': 126, 'n_heads': 128, 'd_head': 128},
}
print("\nKV Cache Size (FP16, batch=1):")
print("-" * 70)
print(f"{'Model':<15} | " + " | ".join([f"{sl:>10,} ctx" for sl in seq_lengths]))
print("-" * 70)
for name, config in kv_configs.items():
sizes = []
for sl in seq_lengths:
size_gb = kv_cache_size(**config, seq_len=sl) / 1e9
sizes.append(f"{size_gb:>10.1f} GB")
print(f"{name:<15} | " + " | ".join(sizes))
print("\n" + "="*80)
print("TOTAL MEMORY = Model Weights + KV Cache + Activations")
print("="*80)
# Example for Llama 2 70B at 32K context
model_size = 70 * 2 # 70B params × 2 bytes (FP16) = 140 GB
kv_size = kv_cache_size(**kv_configs['Llama 2 70B'], seq_len=32768) / 1e9
activation_size = 5 # Rough estimate for intermediate activations
print(f"\nLlama 2 70B at 32K context:")
print(f" Model weights: {model_size:.0f} GB")
print(f" KV cache: {kv_size:.1f} GB")
print(f" Activations: ~{activation_size} GB")
print(f" ─────────────────────")
print(f" Total: ~{model_size + kv_size + activation_size:.0f} GB")
print(f"\n → Requires multiple GPUs (H100 has 80GB each)")
================================================================================ KV CACHE MEMORY REQUIREMENTS ================================================================================ KV Cache Size (FP16, batch=1): ---------------------------------------------------------------------- Model | 2,048 ctx | 8,192 ctx | 32,768 ctx | 131,072 ctx ---------------------------------------------------------------------- Llama 2 7B | 1.1 GB | 4.3 GB | 17.2 GB | 68.7 GB Llama 2 70B | 5.4 GB | 21.5 GB | 85.9 GB | 343.6 GB Llama 3 405B | 16.9 GB | 67.6 GB | 270.6 GB | 1082.3 GB ================================================================================ TOTAL MEMORY = Model Weights + KV Cache + Activations ================================================================================ Llama 2 70B at 32K context: Model weights: 140 GB KV cache: 85.9 GB Activations: ~5 GB ───────────────────── Total: ~231 GB → Requires multiple GPUs (H100 has 80GB each)
The Memory Bandwidth Crisis¶
Let's calculate how fast we can actually generate tokens:
Token generation rate = Memory Bandwidth / Bytes per Token
For Llama 2 70B (FP16):
- Model size: 140 GB
- Each token requires reading all weights: 140 GB
- H100 bandwidth: 3.35 TB/s
- Max tokens/second: 3350 / 140 ≈ 24 tokens/s
This is the hard limit from memory bandwidth alone, regardless of compute capability!
Batch processing helps: With batch size B, we read weights once but generate B tokens:
- Batch=8: 192 tokens/s total (24 per sequence)
- Batch=32: 768 tokens/s total (24 per sequence)
But: larger batches require more KV cache memory, eventually hitting memory capacity limits.
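The same arithmetic as a small sketch (bandwidth and model-size numbers as above; real deployments lose some headroom to the KV cache and inter-GPU communication):
# Sketch: bandwidth-limited decoding rate (ignores KV-cache reads and communication costs)
mem_bw_gb_s = 3350                  # H100 HBM bandwidth in GB/s
weights_gb = 140                    # Llama 2 70B in FP16
for batch in [1, 8, 32]:
    # Each decoding step streams all weights once and yields one token per sequence in the batch
    print(f"batch={batch:>2}: ~{mem_bw_gb_s / weights_gb * batch:.0f} tokens/s total "
          f"(~{mem_bw_gb_s / weights_gb:.0f} per sequence)")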
# Visualize: Memory bandwidth limited throughput
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# Hardware specs
hardware = {
'A100 80GB': {'mem_bw': 2.0, 'capacity': 80, 'color': COLORS['primary']},
'H100 80GB': {'mem_bw': 3.35, 'capacity': 80, 'color': COLORS['success']},
'H100 NVL 188GB': {'mem_bw': 7.8, 'capacity': 188, 'color': COLORS['secondary']},
}
# Model sizes (FP16)
model_sizes = {
'7B': 14,
'13B': 26,
'34B': 68,
'70B': 140,
'405B': 810,
}
# Left plot: Tokens/second for different models on H100
h100_bw = 3.35 * 1000 # GB/s
model_names = list(model_sizes.keys())
tokens_per_sec = [h100_bw / size for size in model_sizes.values()]
bars = axes[0].bar(model_names, tokens_per_sec, color=COLORS['primary'], edgecolor='black')
axes[0].set_ylabel('Tokens per Second (batch=1)', fontsize=11)
axes[0].set_xlabel('Model Size', fontsize=11)
axes[0].set_title('H100: Bandwidth-Limited Generation Speed', fontsize=12, fontweight='bold')
axes[0].grid(True, alpha=0.3, axis='y')
# Add value labels
for bar, tps in zip(bars, tokens_per_sec):
height = bar.get_height()
axes[0].text(bar.get_x() + bar.get_width()/2., height,
f'{tps:.0f}', ha='center', va='bottom', fontsize=10, fontweight='bold')
# Human reading speed reference
axes[0].axhline(y=4, color=COLORS['danger'], linestyle='--', linewidth=2)
axes[0].text(0.5, 10, 'Human reading speed (~250 wpm ≈ 4 tok/s)', fontsize=9, color=COLORS['danger'])
# Right plot: How batch size affects throughput (Llama 2 70B sharded across 4 H100s,
# since the FP16 weights alone exceed a single 80 GB GPU)
n_gpus = 4
model_size_gb = 140                  # Llama 2 70B FP16
kv_per_token_gb = 0.0026             # ~2.6 MB per token (full MHA KV, see kv_cache_size above)
memory_capacity = 80 * n_gpus        # aggregate HBM capacity
agg_bw = h100_bw * n_gpus            # aggregate memory bandwidth (GB/s), assuming ideal scaling
batch_sizes = np.arange(1, 50)
throughputs = []
fits_in_memory = []
for bs in batch_sizes:
    # Check if weights + KV cache fit (simplified - assume 2K context)
    kv_cache = bs * 2048 * kv_per_token_gb
    fits = (model_size_gb + kv_cache) <= memory_capacity
    fits_in_memory.append(fits)
    if fits:
        # Weights are streamed once per decoding step and serve all bs sequences
        throughputs.append(agg_bw / model_size_gb * bs)
    else:
        throughputs.append(np.nan)
axes[1].plot(batch_sizes, throughputs, 'o-', color=COLORS['primary'], linewidth=2, markersize=4)
axes[1].axvline(x=batch_sizes[np.array(fits_in_memory).sum()-1], color=COLORS['danger'],
linestyle='--', linewidth=2, label='Memory limit')
axes[1].set_xlabel('Batch Size', fontsize=11)
axes[1].set_ylabel('Total Tokens per Second', fontsize=11)
axes[1].set_title('Llama 2 70B on 4×H100: Batch vs Throughput', fontsize=12, fontweight='bold')
axes[1].grid(True, alpha=0.3)
axes[1].legend(fontsize=10)
# Annotate memory limit
max_batch = batch_sizes[np.array(fits_in_memory).sum()-1]
axes[1].annotate(f'Max batch ≈{max_batch}\n(memory limit)',
xy=(max_batch, throughputs[max_batch-1]),
xytext=(max_batch+5, throughputs[max_batch-1]*0.7),
fontsize=10, arrowprops=dict(arrowstyle='->', color='black'))
plt.tight_layout()
plt.show()
print("\nKey insight: Larger models have lower per-user latency.")
print("Batching improves throughput but is limited by memory capacity.")
print("This is why datacenters deploy many GPUs - to serve many users in parallel!")
Key insight: Larger models generate fewer tokens per second per user (higher latency). Batching improves throughput but is limited by memory capacity. This is why datacenters deploy many GPUs - to serve many users in parallel!
7. Modern Architectural Variations¶
Modern LLMs have evolved significantly from the original transformer. Key innovations:
Positional Encoding: RoPE¶
Rotary Position Embedding (RoPE) encodes position by rotating query/key vectors:
- No learned position embeddings
- Naturally handles extrapolation to longer sequences
- Used by Llama, Mistral, and most modern models
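A minimal sketch of the RoPE rotation (the function name and the half-split pairing convention here are illustrative; production kernels fuse this into the attention computation):
# Sketch: rotary position embedding (half-split pairing convention; illustrative)
import numpy as np

def rope(x, base=10000.0):
    """Rotate pairs of dimensions of x (shape: seq_len × d, d even) by position-dependent angles."""
    seq_len, d = x.shape
    half = d // 2
    freqs = base ** (-np.arange(half) / half)       # one rotation frequency per dimension pair
    angles = np.outer(np.arange(seq_len), freqs)    # angle = position × frequency
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, :half], x[:, half:]               # dimension i is paired with dimension i+half
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

# Applying the same rotation to Q and K makes q·k depend only on relative position
q = np.ones((8, 4))
print(rope(q).round(2))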
Attention Variants¶
| Variant | Description | KV Cache Memory | Used By |
|---|---|---|---|
| Multi-Head Attention (MHA) | Original; one K,V pair per query head | O(n × H × d_head) | GPT-3 |
| Multi-Query Attention (MQA) | Share a single K,V pair across all heads | O(n × d_head) | PaLM, Falcon |
| Grouped-Query Attention (GQA) | Share K,V within groups of heads | O(n × G × d_head) | Llama 2 70B+ |
(n = sequence length, H = query heads, G = KV-head groups, per layer)
GQA is a middle ground: fewer KV heads than query heads, reducing memory while preserving quality.
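A small sketch of what GQA buys: with H query heads sharing G KV heads, only G K,V pairs are cached per layer (the head counts below follow Llama 2 70B, but the code itself is illustrative):
# Sketch: KV cache with grouped-query attention (head counts follow Llama 2 70B)
import numpy as np

n_layers, seq_len, d_head = 80, 4096, 128
n_q_heads, n_kv_heads = 64, 8            # 8 query heads share each KV head

def kv_cache_gb(kv_heads):
    # K and V per layer, FP16 (2 bytes), batch=1
    return n_layers * 2 * kv_heads * seq_len * d_head * 2 / 1e9

print(f"MHA-style cache ({n_q_heads} KV heads): {kv_cache_gb(n_q_heads):.1f} GB")    # ~10.7 GB
print(f"GQA cache       ({n_kv_heads} KV heads): {kv_cache_gb(n_kv_heads):.1f} GB")  # ~1.3 GB

# At attention time, each query head simply indexes its group's shared KV head
group_size = n_q_heads // n_kv_heads
kv_head_for_query = np.arange(n_q_heads) // group_size   # query heads 0-7 -> KV head 0, etc.
print(kv_head_for_query[:10])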
Mixture of Experts (MoE)¶
MoE models activate only a subset of parameters per token:
- Total parameters: Huge (e.g., 1.8T for GPT-4 rumored)
- Active parameters: Much smaller (e.g., ~300B)
- Benefit: More capacity without proportional compute increase
- Challenge: Memory still scales with total parameters
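A minimal sketch of top-k routing with k=2 of 8 experts (the router and expert shapes are illustrative stand-ins; real experts are full SwiGLU FFNs):
# Sketch: top-k expert routing (k=2 of 8; experts here are plain linear maps standing in for FFNs)
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def moe_ffn(x, router_W, experts, k=2):
    """Send token x to the top-k experts and mix their outputs by the router's weights."""
    logits = x @ router_W                        # one routing score per expert
    top_k = np.argsort(logits)[-k:]              # indices of the k highest-scoring experts
    gates = softmax(logits[top_k])               # renormalize over the chosen experts only
    # Only k experts run for this token: compute scales with k, parameters with len(experts)
    return sum(g * experts[i](x) for g, i in zip(gates, top_k))

rng = np.random.default_rng(0)
d_model, n_experts = 16, 8
router_W = rng.standard_normal((d_model, n_experts)) * 0.1
expert_mats = [rng.standard_normal((d_model, d_model)) * 0.1 for _ in range(n_experts)]
experts = [lambda x, W=W: x @ W for W in expert_mats]
x = rng.standard_normal(d_model)
print(moe_ffn(x, router_W, experts).shape)       # (16,)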
Flash Attention¶
Not an architectural change, but a crucial implementation optimization:
- Fuses attention operations to avoid materializing the n×n attention matrix
- Reduces memory from O(n²) to O(n)
- Enables much longer context lengths
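The core trick is an online (streaming) softmax: keys and values are processed block by block while carrying a running maximum and running denominator, so the full n×n score matrix never exists in memory. A minimal single-query sketch (block size and shapes are illustrative):
# Sketch: the online-softmax idea behind Flash Attention (single query, illustrative block size)
import numpy as np

def attention_online(q, K, V, block=4):
    """Compute softmax(qK^T/sqrt(d)) V block-by-block, never storing the full score row."""
    d = q.shape[0]
    running_max = -np.inf          # max score seen so far
    denom = 0.0                    # running softmax denominator
    acc = np.zeros_like(V[0])      # running weighted sum of values
    for start in range(0, K.shape[0], block):
        s = K[start:start+block] @ q / np.sqrt(d)    # scores for this block of keys
        new_max = max(running_max, s.max())
        rescale = np.exp(running_max - new_max)      # fix up previously accumulated terms
        p = np.exp(s - new_max)
        denom = denom * rescale + p.sum()
        acc = acc * rescale + p @ V[start:start+block]
        running_max = new_max
    return acc / denom

rng = np.random.default_rng(0)
n, d = 16, 8
q, K, V = rng.standard_normal(d), rng.standard_normal((n, d)), rng.standard_normal((n, d))
scores = K @ q / np.sqrt(d)
naive = (np.exp(scores - scores.max()) / np.exp(scores - scores.max()).sum()) @ V
print(np.allclose(attention_online(q, K, V), naive))   # True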
# Visualize: Attention variants and their KV cache impact
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# Left: Attention variant comparison
variants = ['MHA\n(Original)', 'GQA\n(Llama 2)', 'MQA\n(PaLM)']
n_heads = 64      # query heads (Llama 2 70B)
gqa_groups = 8    # KV-head groups (Llama 2 70B)
kv_heads = [64, 8, 1] # MHA: all heads, GQA: groups, MQA: single
relative_memory = [1.0, gqa_groups/n_heads, 1/n_heads]
bars = axes[0].bar(variants, relative_memory, color=[COLORS['primary'], COLORS['secondary'], COLORS['success']],
edgecolor='black', linewidth=2)
axes[0].set_ylabel('Relative KV Cache Size', fontsize=11)
axes[0].set_title('KV Cache Memory by Attention Variant', fontsize=12, fontweight='bold')
axes[0].set_ylim(0, 1.2)
axes[0].grid(True, alpha=0.3, axis='y')
for bar, mem, kvh in zip(bars, relative_memory, kv_heads):
axes[0].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.03,
f'{mem*100:.0f}%\n({kvh} KV heads)', ha='center', va='bottom', fontsize=10)
# Right: Context length vs memory with Flash Attention comparison
context_lengths = np.logspace(2, 6, 100) # 100 to 1M
d_model = 4096
# Standard attention: O(n²) for attention matrix
standard_memory = context_lengths ** 2 * 4 / 1e9 # float32, GB
# Flash attention: O(n)
flash_memory = context_lengths * d_model * 4 / 1e9 # float32, GB
axes[1].loglog(context_lengths, standard_memory, '-', color=COLORS['danger'],
linewidth=2.5, label='Standard Attention O(n²)')
axes[1].loglog(context_lengths, flash_memory, '-', color=COLORS['success'],
linewidth=2.5, label='Flash Attention O(n)')
# Memory limits
axes[1].axhline(y=80, color='gray', linestyle='--', alpha=0.7)
axes[1].text(200, 100, 'H100 80GB limit', fontsize=9, color='gray')
axes[1].set_xlabel('Context Length (tokens)', fontsize=11)
axes[1].set_ylabel('Memory for Attention (GB)', fontsize=11)
axes[1].set_title('Why Flash Attention Enables Long Context', fontsize=12, fontweight='bold')
axes[1].legend(loc='upper left', fontsize=10)
axes[1].grid(True, alpha=0.3, which='both')
axes[1].set_xlim(100, 1e6)
axes[1].set_ylim(1e-4, 1e6)
# Annotate crossover
axes[1].annotate('Without Flash Attention,\n1M context = 4 TB\nfor a single attention matrix!',
xy=(1e6, standard_memory[-1]), xytext=(1e5, 1e5),
fontsize=10, arrowprops=dict(arrowstyle='->', color='black'))
plt.tight_layout()
plt.show()
print("GQA reduces KV cache by 4x (Llama 2 70B) or more, enabling longer contexts.")
print("Flash Attention avoids materializing the n×n attention matrix entirely.")
GQA reduces the KV cache 8x in Llama 2 70B (64 query heads share 8 KV heads), enabling longer contexts. Flash Attention avoids materializing the n×n attention matrix entirely.
# Visualize: MoE architecture concept
fig, ax = plt.subplots(figsize=(14, 7))
# Draw input
input_rect = patches.FancyBboxPatch((5.5, 6.5), 2, 0.6,
boxstyle="round,pad=0.02",
facecolor=COLORS['embedding'], alpha=0.8,
edgecolor='black', linewidth=2)
ax.add_patch(input_rect)
ax.text(6.5, 6.8, 'Input Token', ha='center', va='center', fontsize=10, fontweight='bold')
# Router
router_rect = patches.FancyBboxPatch((5.5, 5), 2, 0.8,
boxstyle="round,pad=0.02",
facecolor=COLORS['attention'], alpha=0.8,
edgecolor='black', linewidth=2)
ax.add_patch(router_rect)
ax.text(6.5, 5.4, 'Router\n(learned)', ha='center', va='center', fontsize=10, fontweight='bold')
# Arrow from input to router
ax.annotate('', xy=(6.5, 5.8), xytext=(6.5, 6.5),
arrowprops=dict(arrowstyle='->', color='black', lw=2))
# Expert FFNs
n_experts = 8
active_experts = [2, 5] # Which experts are activated
for i in range(n_experts):
x = 1 + i * 1.5
is_active = i in active_experts
color = COLORS['success'] if is_active else COLORS['light']
alpha = 0.9 if is_active else 0.4
expert_rect = patches.FancyBboxPatch((x, 2.5), 1.2, 1.5,
boxstyle="round,pad=0.02",
facecolor=color, alpha=alpha,
edgecolor='black', linewidth=2)
ax.add_patch(expert_rect)
ax.text(x + 0.6, 3.25, f'Expert\n{i+1}', ha='center', va='center', fontsize=9,
fontweight='bold' if is_active else 'normal')
# Arrows from router
if is_active:
ax.annotate('', xy=(x + 0.6, 4), xytext=(6.5, 5),
arrowprops=dict(arrowstyle='->', color=COLORS['success'], lw=2))
else:
ax.plot([6.5, x + 0.6], [5, 4], ':', color='gray', alpha=0.3, lw=1)
# Weighted sum
sum_rect = patches.FancyBboxPatch((5.5, 0.8), 2, 0.8,
boxstyle="round,pad=0.02",
facecolor=COLORS['secondary'], alpha=0.8,
edgecolor='black', linewidth=2)
ax.add_patch(sum_rect)
ax.text(6.5, 1.2, 'Weighted\nSum', ha='center', va='center', fontsize=10, fontweight='bold')
# Arrows from active experts to sum
for i in active_experts:
x = 1 + i * 1.5 + 0.6
ax.annotate('', xy=(6.5, 1.6), xytext=(x, 2.5),
arrowprops=dict(arrowstyle='->', color=COLORS['success'], lw=2))
# Annotations
ax.text(0.5, 3.25, 'Top-K\nselection\n(K=2)', ha='center', va='center', fontsize=10,
bbox=dict(boxstyle='round', facecolor='white', edgecolor=COLORS['success']))
ax.text(13, 3.25, f'8 experts total\n2 active per token\n\nCompute: 2/8 = 25%\nParams: 8× of dense',
ha='left', va='center', fontsize=10,
bbox=dict(boxstyle='round', facecolor='white', edgecolor='gray'))
ax.set_xlim(-0.5, 14)
ax.set_ylim(0, 7.5)
ax.axis('off')
ax.set_title('Mixture of Experts (MoE): Sparse Activation', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
print("MoE models like Mixtral 8x7B have 8 experts but only use 2 per token.")
print("Result: 47B total params, but only ~13B active compute per token.")
print("Challenge: Still need to store all 47B params in memory!")
MoE models like Mixtral 8x7B have 8 experts but only use 2 per token. Result: 47B total params, but only ~13B active compute per token. Challenge: Still need to store all 47B params in memory!
8. Summary¶
Key Concepts¶
Transformer Architecture:
- Self-attention allows all-to-all token interactions: $\text{Attention} = \text{softmax}(QK^T/\sqrt{d_k})V$
- Multi-head attention provides multiple "views" of relationships
- FFN/MLP layers contain ~2/3 of parameters
- Modern variants: RoPE, GQA, SwiGLU, Flash Attention
Computational Profile:
- Per-token FLOPs scale as O(Ld²) for the FFN and projections, plus O(Lnd) for the attention scores
- At long context (tens of thousands of tokens, depending on model size), attention dominates compute
- Most operations are matrix multiplications → ideal for systolic arrays
Memory Bottleneck:
- Inference at batch=1 has arithmetic intensity ≈ 1 FLOP/byte (about 2 FLOPs per FP16 parameter)
- Hardware wants 150-300 FLOP/byte → achieves well under 1% of peak utilization at batch=1
- Memory bandwidth, not compute, limits token generation speed
Key Numbers to Remember¶
| Metric | Value |
|---|---|
| Llama 2 70B parameters | 70 billion |
| FP16 model memory | 140 GB |
| H100 memory bandwidth | 3.35 TB/s |
| Max tokens/sec (70B, batch=1) | ~24 |
| Arithmetic intensity (batch=1) | ~1 FLOP/byte |
| H100 ridge point | ~300 FLOP/byte |