Transformer Attention Mechanisms: From Self-Attention to Flash Attention 3
A deep dive into transformer attention — the math, the memory bottleneck, and how Flash Attention 3 achieves 1.5–2x speedups through hardware-aware algorithm design.
Introduction
Attention is the heart of every modern language model. Understanding it deeply — from the original formulation to the hardware-aware optimizations that make 100B-parameter models practical — is essential for any serious AI engineer.
The Self-Attention Mechanism
The original scaled dot-product attention from "Attention Is All You Need":

Attention(Q, K, V) = softmax(QKᵀ / √d_k) V

Where:
- Q — queries (what am I looking for?)
- K — keys (what do I contain?)
- V — values (what do I return?)
- N = sequence length, d_k = key dimension
In code:
```python
import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    mask: torch.Tensor | None = None,
) -> torch.Tensor:
    d_k = Q.shape[-1]
    # Compute attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    # Apply causal mask (decoder self-attention)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    # Softmax + weighted sum
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, V)
```

The Memory Bottleneck
Standard attention is O(N²) in both time and memory. For a sequence of 4096 tokens with FP16, the N × N score matrix alone is 4096² × 2 bytes ≈ 32 MB per head, per layer.
This is just for the attention matrix — before the FFN, embeddings, or activations. At 100K sequence length, this explodes to 800 GB.
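To make the quadratic scaling concrete, here is a quick back-of-the-envelope sketch. The 32-head, FP16, single-layer configuration is an illustrative assumption; exact totals depend on head count, layer count, and batch size.

```python
def attn_scores_bytes(seq_len: int, num_heads: int, bytes_per_el: int = 2) -> int:
    """Bytes for one layer's full N x N attention-score matrices (FP16 by default)."""
    return seq_len * seq_len * num_heads * bytes_per_el

# 4096 tokens, 32 heads: exactly 1 GiB of scores per layer
assert attn_scores_bytes(4096, 32) == 2**30
# Quadratic scaling: doubling the context quadruples the memory
assert attn_scores_bytes(8192, 32) == 4 * attn_scores_bytes(4096, 32)
```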
Multi-Head Attention
Rather than a single attention, use h parallel heads, each projecting to a lower dimension:

MultiHead(Q, K, V) = Concat(head_1, …, head_h) W_O

Where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V), and each W_i projects down to d_k = d_model / h dimensions.
```python
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.heads = num_heads
        self.W_q = torch.nn.Linear(d_model, d_model)
        self.W_k = torch.nn.Linear(d_model, d_model)
        self.W_v = torch.nn.Linear(d_model, d_model)
        self.W_o = torch.nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, _ = x.shape
        # Project, then split into heads: (B, heads, N, d_k)
        Q = self.W_q(x).view(B, N, self.heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, N, self.heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, N, self.heads, self.d_k).transpose(1, 2)
        attn = scaled_dot_product_attention(Q, K, V)
        # Merge heads back: (B, N, d_model)
        attn = attn.transpose(1, 2).contiguous().view(B, N, -1)
        return self.W_o(attn)
```

Flash Attention: Hardware-Aware Algorithm Design
Flash Attention (Dao et al., 2022) computes exactly the same result as standard attention but in O(N) extra memory by exploiting the GPU memory hierarchy:
Key insight: Avoid materializing the full attention matrix in HBM (slow). Instead, process attention in tiles that fit in SRAM (fast):
```
Standard:  HBM → read Q,K,V | compute S=QKᵀ | write S | read S | softmax | write P | P@V → HBM
           (many slow HBM reads/writes)

FlashAttn: HBM → tile of Q,K,V → SRAM → compute attention for the tile → accumulate output → HBM
           (far fewer HBM accesses)
```
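For intuition, the tiling idea can be sketched in pure PyTorch using the online-softmax trick: process K and V in tiles, keeping a running max and running sum so the full N × N matrix is never materialized. This is a simplified single-head, unbatched sketch, not the fused CUDA kernel the real library ships.

```python
import math
import torch

def tiled_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                    tile: int = 128) -> torch.Tensor:
    """Numerically stable tiled attention over (N, d_k) inputs."""
    N, d_k = Q.shape
    scale = 1.0 / math.sqrt(d_k)
    out = torch.zeros_like(Q)                     # unnormalized output accumulator
    row_max = torch.full((N, 1), float('-inf'))   # running softmax max per query
    row_sum = torch.zeros(N, 1)                   # running softmax denominator
    for j in range(0, K.shape[0], tile):
        Kj, Vj = K[j:j + tile], V[j:j + tile]
        S = (Q @ Kj.T) * scale                    # (N, tile): one tile of scores
        new_max = torch.maximum(row_max, S.max(dim=-1, keepdim=True).values)
        P = torch.exp(S - new_max)                # shifted exponentials for stability
        correction = torch.exp(row_max - new_max) # rescale previous accumulators
        row_sum = row_sum * correction + P.sum(dim=-1, keepdim=True)
        out = out * correction + P @ Vj
        row_max = new_max
    return out / row_sum                          # normalize once at the end
```

The per-tile `correction` factor is the key trick: it lets earlier tiles be rescaled when a later tile raises the running max, so the result matches the standard softmax exactly.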
Flash Attention 3 (2024) adds:
- Warp specialization: Separate CUDA warps for matrix multiply vs. softmax
- Pingpong scheduling: overlaps one warpgroup's softmax with another's matrix multiplies
- FP8 support: 1.5–2x speedup on H100 GPUs
```python
# In practice — just use the library
from flash_attn import flash_attn_func

# Drop-in replacement for standard attention
# (q, k, v have shape [batch, seqlen, num_heads, head_dim])
output = flash_attn_func(
    q, k, v,
    causal=True,  # causal mask for autoregressive generation
    softmax_scale=1.0 / math.sqrt(d_k),
)
```

GQA and MQA: Reducing KV Cache
Modern LLMs use Grouped Query Attention (GQA) or Multi-Query Attention (MQA) to shrink the KV cache:
| Method | Query heads | KV heads | KV cache size | Quality |
|---|---|---|---|---|
| MHA | 32 | 32 | 100% | Best |
| GQA | 32 | 8 | 25% | Near-best |
| MQA | 32 | 1 | 3.1% | Slightly lower |
Llama 3 and Mistral use GQA — the best tradeoff in practice.
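The KV-sharing idea can be sketched as follows: each group of query heads attends to the same key/value head. The function name here is mine, and the `repeat_interleave` expansion is the simplest way to express the grouping; production kernels index into the shared KV heads rather than copying them.

```python
import torch

def gqa_expand_kv(kv: torch.Tensor, num_q_heads: int) -> torch.Tensor:
    """Expand a (B, kv_heads, N, d) K or V tensor so every query head gets a KV head."""
    B, kv_heads, N, d = kv.shape
    group = num_q_heads // kv_heads            # query heads per KV head
    return kv.repeat_interleave(group, dim=1)  # (B, num_q_heads, N, d)

# 32 query heads sharing 8 KV heads: the cache stores 8 heads, attention sees 32
k_cache = torch.randn(1, 8, 1024, 128)
k_full = gqa_expand_kv(k_cache, num_q_heads=32)  # shape (1, 32, 1024, 128)
```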
Key Takeaways
- Self-attention is O(N²) — this is the fundamental constraint on context length
- Flash Attention makes long contexts practical — use it universally for training and inference
- GQA is the production standard for KV cache efficiency
- Understanding the memory math is critical for capacity planning
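As a capacity-planning sketch, per-sequence KV-cache size follows directly from the model config. The 32-layer, 8-KV-head, 128-dim numbers below are a Llama-3-8B-style illustration I am assuming for the example, not a quoted spec.

```python
def kv_cache_bytes(seq_len: int, num_layers: int, kv_heads: int,
                   head_dim: int, bytes_per_el: int = 2) -> int:
    """Per-sequence KV cache: K and V tensors for every layer (FP16 by default)."""
    return 2 * num_layers * kv_heads * head_dim * seq_len * bytes_per_el

# GQA config (8 KV heads): exactly 1 GiB of cache at 8192 tokens
assert kv_cache_bytes(8192, num_layers=32, kv_heads=8, head_dim=128) == 2**30
# The same model with full MHA (32 KV heads) would need 4x the cache
assert kv_cache_bytes(8192, 32, 32, 128) == 4 * kv_cache_bytes(8192, 32, 8, 128)
```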
References
- Vaswani et al., "Attention Is All You Need" (2017)
- Dao et al., "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (2022)
- Dao, "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (2023)
- Ainslie et al., "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" (2023)
Written by
Rohit Raj
Senior AI Engineer @ American Express