
Transformer Attention Mechanisms: From Self-Attention to Flash Attention 3

A deep dive into transformer attention — the math, the memory bottleneck, and how Flash Attention 3 achieves 1.5–2x speedups through hardware-aware algorithm design.

Rohit Raj · 4 min read

Introduction

Attention is the heart of every modern language model. Understanding it deeply — from the original formulation to the hardware-aware optimizations that make 100B-parameter models practical — is essential for any serious AI engineer.

The Self-Attention Mechanism

The original scaled dot-product attention from "Attention Is All You Need":

\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

Where:

  • Q \in \mathbb{R}^{n \times d_k} — queries (what am I looking for?)
  • K \in \mathbb{R}^{n \times d_k} — keys (what do I contain?)
  • V \in \mathbb{R}^{n \times d_v} — values (what do I return?)
  • n = sequence length, d_k = key dimension

In code:

python
import torch
import torch.nn.functional as F
import math
 
def scaled_dot_product_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    mask: torch.Tensor | None = None
) -> torch.Tensor:
    d_k = Q.shape[-1]
    
    # Compute attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    
    # Apply causal mask (decoder self-attention)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    
    # Softmax + weighted sum
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, V)
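
As a sanity check, here is a short usage sketch with a causal mask built from `torch.tril`, compared against PyTorch's built-in `F.scaled_dot_product_attention` (available since PyTorch 2.0). The attention math is inlined so the snippet runs standalone; the shapes are arbitrary illustrative values:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, N, d_k = 2, 4, 8, 16  # batch, heads, sequence length, head dim
Q, K, V = (torch.randn(B, H, N, d_k) for _ in range(3))

# Causal mask: position i may attend only to positions j <= i
mask = torch.tril(torch.ones(N, N, dtype=torch.bool))

# Same math as scaled_dot_product_attention above, inlined here
scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)
scores = scores.masked_fill(~mask, float('-inf'))
out = F.softmax(scores, dim=-1) @ V

# PyTorch >= 2.0 ships a fused built-in with the same semantics
ref = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
print(torch.allclose(out, ref, atol=1e-5))  # True
```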

The Memory Bottleneck

Standard attention is O(n^2) in both time and memory. For a sequence of 4096 tokens in FP16:

\text{Memory} = n^2 \times \text{heads} \times 2\ \text{bytes} = 4096^2 \times 32 \times 2 \approx 1\ \text{GB}

This is just the attention matrix, before counting the FFN, embeddings, or other activations. At 100K sequence length, the same formula gives roughly 640 GB.
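
To make the arithmetic easy to replay, here is a tiny helper (my own, not from any library) that evaluates the formula above for any sequence length:

```python
def attn_matrix_bytes(seq_len: int, heads: int = 32, bytes_per_el: int = 2) -> float:
    """FP16 attention-score memory per layer, in GB: n^2 * heads * 2 bytes."""
    return seq_len ** 2 * heads * bytes_per_el / 1e9

print(f"{attn_matrix_bytes(4096):.2f} GB")     # ~1.07 GB
print(f"{attn_matrix_bytes(100_000):.0f} GB")  # ~640 GB
```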

Multi-Head Attention

Rather than a single attention operation, use h parallel heads, each projecting to a lower dimension:

\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O

Where \text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)

python
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.heads = num_heads
 
        self.W_q = torch.nn.Linear(d_model, d_model)
        self.W_k = torch.nn.Linear(d_model, d_model)
        self.W_v = torch.nn.Linear(d_model, d_model)
        self.W_o = torch.nn.Linear(d_model, d_model)
 
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, _ = x.shape
 
        Q = self.W_q(x).view(B, N, self.heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, N, self.heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, N, self.heads, self.d_k).transpose(1, 2)
 
        attn = scaled_dot_product_attention(Q, K, V)
        attn = attn.transpose(1, 2).contiguous().view(B, N, -1)
        return self.W_o(attn)
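
The `view`/`transpose` dance in `forward` is where most shape bugs hide, so here is a quick walk-through with plain tensors (illustrative sizes) showing that the split and merge are exact inverses:

```python
import torch

B, N, d_model, heads = 2, 10, 64, 8
d_k = d_model // heads  # 8
x = torch.randn(B, N, d_model)

# Split d_model into (heads, d_k), then move heads ahead of the sequence axis
q = x.view(B, N, heads, d_k).transpose(1, 2)
print(q.shape)  # torch.Size([2, 8, 10, 8])

# After attention, reverse the dance to recover (B, N, d_model)
merged = q.transpose(1, 2).contiguous().view(B, N, d_model)
print(torch.equal(merged, x))  # True: split + merge is lossless
```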

Flash Attention: Hardware-Aware Algorithm Design

Flash Attention (Dao et al., 2022) achieves the same mathematical result as standard attention but with O(n) memory by exploiting the GPU memory hierarchy:

Key insight: Avoid materializing the full n \times n attention matrix in HBM (slow). Instead, process attention in tiles that fit in SRAM (fast):

Standard: HBM → read Q,K,V | compute S=QKᵀ | write S | read S | softmax | write P | P@V → HBM
           (Many slow HBM reads/writes)

FlashAttn: HBM → tile of Q,K,V → SRAM → compute full attention tile → accumulate output → HBM
           (Far fewer HBM accesses)
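
The rescaling trick that makes this tiling exact is the online softmax: running row maxima and row sums are updated as each tile streams through. Here is a minimal PyTorch sketch of the idea (illustrative only, nothing like the real fused CUDA kernel) that never builds the full n × n score matrix:

```python
import math
import torch
import torch.nn.functional as F

def tiled_attention(Q, K, V, tile: int = 32):
    """Streaming attention over key/value tiles via online softmax.

    Mathematically identical to softmax(QK^T / sqrt(d_k)) V, but only a
    (q_len x tile) score block exists at any time -- the core Flash Attention idea.
    """
    d_k = Q.shape[-1]
    n = K.shape[-2]
    out = torch.zeros_like(Q)
    row_max = torch.full(Q.shape[:-1] + (1,), float('-inf'))  # running max per query row
    row_sum = torch.zeros(Q.shape[:-1] + (1,))                # running softmax denominator

    for start in range(0, n, tile):
        k_t = K[..., start:start + tile, :]
        v_t = V[..., start:start + tile, :]
        s = Q @ k_t.transpose(-2, -1) / math.sqrt(d_k)        # scores for this tile only

        new_max = torch.maximum(row_max, s.max(dim=-1, keepdim=True).values)
        scale = torch.exp(row_max - new_max)                  # rescale old accumulators
        p = torch.exp(s - new_max)                            # tile softmax numerator
        row_sum = row_sum * scale + p.sum(dim=-1, keepdim=True)
        out = out * scale + p @ v_t
        row_max = new_max
    return out / row_sum

Q, K, V = (torch.randn(2, 4, 64, 16) for _ in range(3))
ref = F.scaled_dot_product_attention(Q, K, V)
print(torch.allclose(tiled_attention(Q, K, V), ref, atol=1e-5))  # True
```

The real kernel fuses this loop into SRAM-resident tiles; the Python version just demonstrates why the result is exact, not approximate.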

Flash Attention 3 (2024) adds:

  • Warp specialization: Separate CUDA warps for matrix multiply vs. softmax
  • Pingpong scheduling: Overlaps compute and memory ops
  • FP8 support: 1.5–2x speedup on H100 GPUs

python
# In practice — just use the library
import math
from flash_attn import flash_attn_func
 
# Drop-in replacement; flash_attn_func expects (batch, seqlen, heads, head_dim)
output = flash_attn_func(
    q, k, v,
    causal=True,                         # causal mask for autoregressive generation
    softmax_scale=1.0 / math.sqrt(d_k),  # also the default when omitted
)

GQA and MQA: Reducing KV Cache

Modern LLMs use Grouped Query Attention (GQA) or Multi-Query Attention (MQA) to shrink the KV cache:

| Method | Query heads | KV heads | KV cache size | Quality        |
|--------|-------------|----------|---------------|----------------|
| MHA    | 32          | 32       | 100%          | Best           |
| GQA    | 32          | 8        | 25%           | Near-best      |
| MQA    | 32          | 1        | 3.1%          | Slightly lower |

Llama 3 and Mistral use GQA — the best tradeoff in practice.
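
Mechanically, GQA just shares each cached KV head across a group of query heads. A small sketch (illustrative shapes, using `repeat_interleave` to expand the KV heads before ordinary attention):

```python
import torch
import torch.nn.functional as F

B, N, d_k = 2, 10, 16
q_heads, kv_heads = 32, 8          # GQA: 4 query heads share each KV head
group = q_heads // kv_heads

q = torch.randn(B, q_heads, N, d_k)
k = torch.randn(B, kv_heads, N, d_k)  # cached at 8/32 = 25% of MHA size
v = torch.randn(B, kv_heads, N, d_k)

# Expand KV heads to match query heads, then run ordinary attention
k_exp = k.repeat_interleave(group, dim=1)  # (B, 32, N, d_k)
v_exp = v.repeat_interleave(group, dim=1)
out = F.scaled_dot_product_attention(q, k_exp, v_exp)
print(out.shape)  # torch.Size([2, 32, 10, 16])
```

Only the small `k`/`v` tensors live in the KV cache; the expansion happens on the fly at attention time (recent PyTorch versions can even skip it via `enable_gqa=True`).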

Key Takeaways

  1. Self-attention is O(n^2) — this is the fundamental constraint on context length
  2. Flash Attention makes long contexts practical — use it universally for training and inference
  3. GQA is the production standard for KV cache efficiency
  4. Understanding the memory math is critical for capacity planning

References

  • Vaswani et al., "Attention Is All You Need" (2017)
  • Dao et al., "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (2022)
  • Dao, "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (2023)
  • Ainslie et al., "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" (2023)

Written by

Rohit Raj

Senior AI Engineer @ American Express
