Attention is the core operation inside every large language model. It runs thousands of times per forward pass, once per layer, per head, per token batch. For years it was the reason context windows were capped at 2048 to 4096 tokens: the memory cost of the N×N score matrix made longer sequences physically impossible on available hardware. FlashAttention changed that. GPT-4's 128k context window, Claude's 200k window, none of that exists without solving this problem.
I spent a week learning FlashAttention properly - not just running it, but understanding why it works at the hardware level. Then I implemented it in Triton from scratch. This is what I wish someone had explained to me before I read the paper.
The paper title is: FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. Every word in that title is load-bearing. By the end of this post you will know exactly what each one means.
Before touching attention at all, you need to understand one fact about modern GPUs. They have two types of memory that are not remotely equivalent:

HBM (High Bandwidth Memory) is what people mean when they say "GPU memory." An A100 has 80GB of it. Bandwidth: ~2 TB/s. SRAM is the on-chip cache - tiny (~20MB on A100), but bandwidth of ~19 TB/s. Nearly 10x faster than HBM.
Every time your kernel reads or writes data, it is either hitting fast SRAM or going out to slow HBM. The time spent moving data between the two is called IO - and for attention, this is the actual bottleneck, not the arithmetic.
An A100 can do 312 teraFLOPS. But if your kernel spends most of its time waiting for data from HBM, the compute units sit idle. Two regimes exist:
Compute-bound - arithmetic is the bottleneck, ALUs are saturated. Adding more FLOPS helps. This is where you want to be.
Memory-bound - IO is the bottleneck, ALUs are idle waiting for data. Adding more FLOPS does absolutely nothing. You need to reduce HBM reads/writes instead.

Attention is memory-bound. The matrix multiplications are fast. The bottleneck is reading Q, K, V from HBM and writing the N x N score matrix back. This is the entire problem FlashAttention solves.
The attention formula is deceptively simple:
Attention(Q, K, V) = softmax(QKᵀ / √d) * V

A quick note on the terms. Q, K, and V are matrices of shape (N, d), where N is the sequence length and d is the head dimension. In practice, transformers split the full embedding dimension across multiple attention heads: if your model has 512-dimensional embeddings and 8 heads, each head operates on d = 64. The √d scaling exists because dot products grow in magnitude with d; without the correction, softmax would saturate into a near one-hot distribution, killing the ability to attend to multiple tokens at once.
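As a reference point, here is that formula as a direct NumPy sketch (the function name is just illustrative). It materializes the full N × N score matrix, which is exactly what FlashAttention will avoid:

```python
import numpy as np

def naive_attention(Q, K, V):
    """Direct translation of the formula. Q, K, V: (N, d) arrays."""
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                  # (N, N) score matrix -- the O(N^2) problem
    S = S - S.max(axis=-1, keepdims=True)     # standard numerical-stability trick
    P = np.exp(S)
    P = P / P.sum(axis=-1, keepdims=True)     # row-wise softmax
    return P @ V                              # (N, d) output
```

Four lines of math, and the middle two allocate an array that grows quadratically with sequence length.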
Here is the problem. To compute softmax(QKᵀ), you need the full N x N score matrix to exist somewhere. With N=2048 tokens and float16:
```
# memory just for the score matrix (N = 2048, float16)
2048 × 2048 × 2 bytes = 8 MB

# and you need to read/write it multiple times
Read Q:               256 KB
Read K:               256 KB
Read V:               256 KB
Write scores S:         8 MB  ← the problem
Read S for softmax:     8 MB
Write softmax P:        8 MB
Read P for V mult:      8 MB
────────────────────────────
Total HBM traffic:    ~33 MB
```

At N=8192 that score matrix alone is 128MB, and it quadruples every time the sequence length doubles. This is the O(N²) memory problem. Not a theoretical concern - it is why context length was stuck at 2048-4096 tokens for so long.
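You can check the quadratic growth directly (a throwaway sketch; `score_matrix_mb` is just an illustrative name):

```python
def score_matrix_mb(n, bytes_per_elem=2):
    """Memory for the N x N float16 score matrix, in MB."""
    return n * n * bytes_per_elem / 2**20

for n in (2048, 4096, 8192):
    print(n, score_matrix_mb(n))   # 8.0, 32.0, 128.0 -- 4x per doubling
```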
Here is the analogy that made it click for me.
Imagine you are a student grading 1000 exam papers. You need to rank every student against every other student. The naive way: read all 1000 papers, write every score on a giant spreadsheet, then rank. The spreadsheet is huge.
The smart way: read 10 papers at a time. Keep a running note of the highest score you have seen and a running total. Update as you go. Never write the full spreadsheet. End result is identical.
FlashAttention is the smart way. Instead of materializing the full N x N matrix, you process small tiles that fit in SRAM, accumulate results incrementally, and write the final output to HBM exactly once.
The outer loop goes over blocks of Q (rows). The inner loop scans all blocks of K and V (columns). Each small tile fits in SRAM. The full N x N matrix never exists anywhere.
The loop order is not arbitrary. You need to fully finish computing every output token before moving on. So the structure is: pick a chunk of queries, scan through every single K and V block completely, produce the outputs for those queries, then move to the next query chunk. If you flipped it, outer loop over K/V and inner loop over Q, you would never finish any single output in one pass. You would need to store intermediate results for all N queries simultaneously, which defeats the entire purpose.
Tiling sounds simple. The catch: softmax breaks it.
Normal softmax for a row of scores needs the full denominator before computing any output:
softmax(x_i) = exp(x_i) / sum(exp(x_j) for all j)

That denominator requires seeing all scores simultaneously. But we are processing in chunks - we cannot see everything at once. This is the problem online softmax solves.
Instead of computing softmax all at once, you maintain three values that update each chunk:
m - the running maximum score seen so far
l - the running sum of exp scores
O - the running weighted sum of V vectors (the actual output being built)
Every exp score is computed as exp(score - m), not exp(score). This is numerical stability - nothing else. exp() overflows to infinity very quickly: exp(100) is already about 2.7×10⁴³. Subtracting the max before exponentiation keeps every value in (0, 1], preventing overflow.
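A three-line demonstration of why the max-subtraction is needed (NumPy, float64; the specific scores are made up):

```python
import numpy as np

x = np.array([1000.0, 1001.0, 1002.0])
with np.errstate(over='ignore'):
    naive = np.exp(x)                 # overflows: all three become inf
stable = np.exp(x - x.max())          # [exp(-2), exp(-1), exp(0)] -- all in (0, 1]
probs = stable / stable.sum()         # a valid softmax, no inf anywhere
```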
This does not change the result at all. Since softmax divides numerator by denominator, the subtracted max cancels out perfectly:
```
exp(score - m) / sum(exp(scores - m))
# = exp(score) / sum(exp(scores))   <- identical
```

Back to the grading analogy. Imagine you gave grades relative to the highest score you had seen so far, then found a new exam with an even higher score. Every previous grade was normalized against the wrong maximum. You would need to go back and rescale all of them. The correction factor does exactly that, in a single multiplication.
Concretely: if chunk 1 gives max 4 and chunk 2 gives a score of 6, the new max is 6. All previous exp scores used max = 4, so they must be rescaled by exp(4 - 6) = exp(-2) ≈ 0.135. Multiply everything by that. Since the new max is always >= the old max, this correction is always <= 1: always scaling down, never up.
The full update each inner loop iteration:

```
correction = exp(m_old - m_new)   # always <= 1 since m_new >= m_old
m = m_new
l = correction * l + rowsum(P)
O = correction * O + P @ V_block
# all three statistics get the same correction simultaneously
# this is O(1) work regardless of how many chunks came before
```

This feels wasteful but it is extremely cheap. It is one multiplication per step, all in SRAM. The alternative - recomputing everything from scratch each time the max changes - would be O(N) work per update. The correction factor reduces that to O(1).
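To convince yourself the running update really is exact, here is a small NumPy check for a single query row, with K/V scanned in chunks (a sketch; names are illustrative):

```python
import numpy as np

def online_attention_row(q, K, V, block=4):
    """One query's attention output via running m, l, O -- never sees all scores at once."""
    m, l = -np.inf, 0.0
    o = np.zeros(V.shape[-1])
    for j in range(0, len(K), block):
        s = K[j:j+block] @ q / np.sqrt(len(q))   # scores for this chunk only
        m_new = max(m, s.max())
        corr = np.exp(m - m_new)                 # rescale everything seen so far
        p = np.exp(s - m_new)
        l = corr * l + p.sum()
        o = corr * o + p @ V[j:j+block]
        m = m_new
    return o / l
```

Comparing this against a full softmax over all scores gives the same answer to floating-point precision, regardless of the chunk size.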
Putting it all together - this is the entire FlashAttention forward pass:
```
# outer loop - process Q in blocks
for i in range(0, N, BLOCK_M):
    Q_block = Q[i:i+BLOCK_M]             # load from HBM -> SRAM
    O_block = zeros(BLOCK_M, d)
    l_block = zeros(BLOCK_M)
    m_block = full(BLOCK_M, -inf)

    # inner loop - scan all K/V blocks
    for j in range(0, N, BLOCK_N):
        K_block = K[j:j+BLOCK_N]         # load tile
        V_block = V[j:j+BLOCK_N]         # load tile
        S = Q_block @ K_block.T / sqrt(d)      # score tile
        m_new = max(m_block, rowmax(S))        # update running max
        P = exp(S - m_new)                     # stable exp
        correction = exp(m_block - m_new)      # rescale factor
        l_block = correction * l_block + rowsum(P)
        O_block = correction * O_block + P @ V_block
        m_block = m_new

    # normalize and write back - only HBM write in the whole kernel
    O[i:i+BLOCK_M] = O_block / l_block
```

At no point does an N x N matrix exist. S is only ever (BLOCK_M × BLOCK_N) - fixed size regardless of N. That is why memory is O(N).
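The pseudocode translates line for line into runnable NumPy, which is handy for testing against a naive reference (a sketch for verification, not the Triton kernel):

```python
import numpy as np

def flash_forward(Q, K, V, block_m=16, block_n=16):
    """Tiled attention forward pass; no (N, N) array is ever allocated."""
    N, d = Q.shape
    O = np.empty_like(Q)
    for i in range(0, N, block_m):
        Qb = Q[i:i+block_m]
        Ob = np.zeros((Qb.shape[0], d))
        lb = np.zeros(Qb.shape[0])
        mb = np.full(Qb.shape[0], -np.inf)
        for j in range(0, N, block_n):
            S = Qb @ K[j:j+block_n].T / np.sqrt(d)   # (block_m, block_n) tile only
            m_new = np.maximum(mb, S.max(axis=1))
            P = np.exp(S - m_new[:, None])
            corr = np.exp(mb - m_new)
            lb = corr * lb + P.sum(axis=1)
            Ob = corr[:, None] * Ob + P @ V[j:j+block_n]
            mb = m_new
        O[i:i+block_m] = Ob / lb[:, None]
    return O
```

Diffing this against a full-matrix softmax is exactly how the correctness tests mentioned later work.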
Language models use causal (autoregressive) attention - token i can only attend to tokens before it, not future ones. In standard attention you implement this by setting future scores to -inf before softmax. In FlashAttention you can do better.
For any tile where all column indices j > row indices i - meaning the entire K chunk is from the future relative to the Q chunk - you skip that tile completely. Do not load it, do not compute it. That is roughly half the inner loop iterations eliminated at long sequences. Real FLOPs saved, not just zeroed out.
For tiles on the diagonal where some j ≤ i and some j > i, you apply a per-element mask. Future positions get set to -inf before softmax.
The reason -inf specifically: exp(-inf) = 0 exactly, so those tokens contribute precisely zero to the softmax denominator and precisely zero to the weighted V sum. Setting future scores to zero would not work. softmax(0) is not zero, it is a small positive number (1 divided by the number of tokens), so masked positions would still bleed into the output. Negative infinity is the only value that survives the exp and comes out as exactly zero.
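The difference is easy to see numerically (toy scores, made up for illustration):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

scores = np.array([2.0, 1.0, 3.0])        # pretend the last token is in the future
keep = np.array([True, True, False])

zeroed = softmax(np.where(keep, scores, 0.0))      # masked slot still gets weight
masked = softmax(np.where(keep, scores, -np.inf))  # masked slot gets exactly 0
```

With the zero mask, the "future" token still receives a positive softmax weight; with -inf it contributes exactly nothing.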
I ran this on a Tesla T4 (Google Colab free tier), comparing against naive PyTorch attention and PyTorch's built-in SDPA. All float16, batch=1, head_dim=64.
| Seq Len | Naive (MB) | Flash (MB) | Ratio |
|---|---|---|---|
| 512 | 1.06 | 0.06 | 18x |
| 1024 | 4.12 | 0.12 | 34x |
| 2048 | 16.25 | 0.25 | 65x |
| 4096 | 64.50 | 0.50 | 129x |
| 8192 | 257.00 | 1.00 | 257x |
The ratio doubles every time sequence length doubles. Naive memory quadruples (O(N²)), flash memory doubles (O(N)). At 8192 tokens, naive uses 257x more memory. This matches theory exactly.
| Seq Len | Naive (ms) | Flash (ms) | Causal (ms) | SDPA (ms) |
|---|---|---|---|---|
| 512 | 0.059 | 0.239 | 0.167 | 0.064 |
| 1024 | 0.080 | 0.441 | 0.264 | 0.123 |
| 2048 | 0.290 | 1.410 | 1.171 | 0.144 |
| 4096 | 1.062 | 3.929 | 2.082 | 0.363 |
| 8192 | 4.268 | 13.604 | 9.159 | 1.582 |
The kernel is slower than SDPA. That is expected and worth explaining honestly. PyTorch's SDPA calls a heavily optimized C++ FlashAttention implementation tuned by engineers over years, with production-grade pipelining and architecture-specific tricks. This is a from-scratch Triton kernel written to understand the algorithm, not to win benchmarks.
Two things are worth noting. First, causal masking gives a consistent 1.5-2x speedup over non-causal by genuinely skipping upper triangle tiles. Second, software pipelining (num_stages=3) gave a 2x speedup at seq=2048 by overlapping the next K/V tile fetch with the current tile's matrix multiply.
Before looking at the kernel, one thing worth understanding is how to pick BLOCK_M and BLOCK_N. These control how large each tile is. Too small and you have more inner loop iterations, more HBM loads per sequence, and more kernel launch overhead per Q block. Too large and the tile does not fit in SRAM, at which point the compiler spills to HBM and you lose the entire benefit. The right answer is hardware-specific: an A100 has more SRAM per streaming multiprocessor than a T4, so it can handle larger blocks. This is why the kernel uses @triton.autotune with a list of configs. Triton runs each config on a small input and picks the fastest one for the current hardware automatically. You never hardcode the block size.
The Triton kernel is literally the pseudocode above with three changes in notation: tl.load instead of array indexing, tl.dot instead of @, and pointer arithmetic instead of slice notation. The logic is identical.
```python
@triton.jit
def flash_attention_kernel(
    Q_ptr, K_ptr, V_ptr, O_ptr,
    stride_qm, stride_qd,                  # how to step through Q in memory
    N,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
    causal: tl.constexpr,
):
    # which Q block does this kernel instance own?
    block_m = tl.program_id(0)
    off_m = block_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # load Q tile into SRAM
    Q = tl.load(Q_ptr + off_m[:, None] * stride_qm + ...)

    # initialize running stats - all in SRAM
    O_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    l_acc = tl.zeros([BLOCK_M], dtype=tl.float32)
    m_acc = tl.full([BLOCK_M], float('-inf'), dtype=tl.float32)

    # inner loop over K/V blocks
    # causal: stop after the diagonal tile; otherwise scan every K/V block
    loop_end = tl.cdiv((block_m + 1) * BLOCK_M, BLOCK_N) if causal else tl.cdiv(N, BLOCK_N)
    for block_n in range(0, loop_end):
        off_n = block_n * BLOCK_N + tl.arange(0, BLOCK_N)
        K = tl.load(K_ptr + ...)
        V = tl.load(V_ptr + ...)
        S = tl.dot(Q, tl.trans(K)) / (HEAD_DIM ** 0.5)
        # causal mask on diagonal tiles
        if causal:
            mask = off_m[:, None] >= off_n[None, :]
            S = tl.where(mask, S, float('-inf'))
        m_new = tl.maximum(m_acc, tl.max(S, axis=1))
        P = tl.exp(S - m_new[:, None])
        correction = tl.exp(m_acc - m_new)
        l_acc = correction * l_acc + tl.sum(P, axis=1)
        O_acc = correction[:, None] * O_acc + tl.dot(P, V)
        m_acc = m_new

    # normalize and write back - one HBM write
    O_acc = O_acc / l_acc[:, None]
    tl.store(O_ptr + ..., O_acc)
```

One detail worth understanding: tl.constexpr means the value is fixed at compile time rather than determined at runtime. When you pass causal=True, Triton compiles a completely separate kernel with the masking logic baked in. When you pass causal=False, it compiles a second kernel with that branch entirely removed. Think of it like a C preprocessor #ifdef, not a Python if statement. There is zero runtime overhead from the branch because it does not exist in the compiled kernel. The same applies to BLOCK_M and BLOCK_N, which is why they are also constexpr: block sizes are fixed at compile time so the compiler can generate optimal memory access patterns for those exact dimensions.
One thing the pseudocode simplifies: everything shown is for a single attention head. Real transformers have H heads running in parallel, typically 8, 16, or 32. In the Triton kernel, this is handled by adding a head index dimension to the launch grid. Each kernel instance knows which head it owns and loads the corresponding slice of Q, K, and V. The heads never communicate with each other and run fully in parallel across streaming multiprocessors. Going from 1 head to 32 heads does not change the algorithm at all. It just means 32 instances of the same kernel running simultaneously.
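To make the head independence concrete, here is a NumPy sketch that splits the embedding into heads and runs each one in isolation (in the real kernel, the loop below becomes an extra grid axis, so its "iterations" run in parallel across streaming multiprocessors):

```python
import numpy as np

def multi_head_attention(Q, K, V, num_heads):
    """Q, K, V: (N, D). Each head attends over its own d = D // num_heads slice."""
    N, D = Q.shape
    d = D // num_heads
    out = np.empty_like(Q)
    for h in range(num_heads):                 # the grid's head axis in the Triton version
        sl = slice(h * d, (h + 1) * d)
        S = Q[:, sl] @ K[:, sl].T / np.sqrt(d)
        P = np.exp(S - S.max(axis=1, keepdims=True))
        P = P / P.sum(axis=1, keepdims=True)
        out[:, sl] = P @ V[:, sl]              # heads never read each other's slices
    return out
```

Nothing inside the loop depends on any other iteration, which is exactly why the heads parallelize for free.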
Reading the paper felt approachable. Writing the kernel revealed the gaps.
The correction factor application order matters, and getting it wrong produces the worst class of bug: outputs that look plausible but are subtly wrong. My first implementation updated m, then l, then O, but computed the correction before updating m. The outputs passed shape checks. The values were off by small amounts that grew with sequence length. I found it by diffing against naive attention at seq=128 with a fixed random seed and printing the per-element absolute error. It was uniform across the sequence, which pointed to a systematic rescaling error rather than an indexing issue. The fix was four lines: compute correction first, then update all three statistics simultaneously. Order matters.
The masking at tile boundaries is fiddly in a specific way. When N is not divisible by the block size, the last tile loads out-of-bounds memory addresses. Triton will not crash. It loads garbage values. Those garbage values participate in the tl.max call that computes m_new. If any garbage value is larger than any real score, which is common since uninitialized memory is unpredictable, it becomes the running maximum and corrupts all subsequent normalizations. The fix: apply the boundary mask to S before the max operation, setting out-of-bounds positions to -inf so they cannot win the max. Masking after the max does not help. This took longer to debug than the correction factor bug.
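A toy reproduction of that failure mode (the garbage value is simulated; in the real kernel it is whatever happens to sit past the end of the buffer):

```python
import numpy as np

# a partial tile: 3 real scores plus one out-of-bounds slot full of "garbage"
tile  = np.array([1.0, 2.0, 0.5, 9e9])
valid = np.array([True, True, True, False])

bad_max  = tile.max()                            # 9e9 -- garbage wins the max
good_max = np.where(valid, tile, -np.inf).max()  # 2.0 -- mask first, then max
```

With `bad_max` as the running maximum, every real score's exp underflows toward zero and the normalization is ruined; masking to -inf before the max keeps the garbage out of every downstream statistic.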
The TFLOPS gap between my kernel and SDPA is real but not mysterious. Production FlashAttention uses persistent kernels (no kernel launch overhead per step), register-level optimizations, and architecture-specific tuning for tensor cores. A first implementation does not have any of that. Understanding why the gap exists matters more than closing it in a weekend.
Full implementation, 192 correctness tests (96 non-causal + 96 causal), benchmark suite, and plots at:

github.com/Vinesh2929/FlashAttention-kernel-triton

The correctness tests compare against naive PyTorch attention (non-causal) and F.scaled_dot_product_attention(is_causal=True) (causal). All 192 pass across all sequence lengths, head dims, batch sizes, and dtypes.