FlashAttention from Scratch

The IO-awareness paper explained through a working Triton kernel.

Attention is the core operation inside every large language model. It runs thousands of times per forward pass, once per layer, per head, per token batch. For years it was the reason context windows were capped at 2048 to 4096 tokens: the memory cost of the N×N score matrix made longer sequences physically impossible on available hardware. FlashAttention changed that. GPT-4's 128k context window, Claude's 200k window, none of that exists without solving this problem.

I spent a week learning FlashAttention properly - not just running it, but understanding why it works at the hardware level. Then I implemented it in Triton from scratch. This is what I wish someone had explained to me before I read the paper.

The paper title is: FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. Every word in that title is load-bearing. By the end of this post you will know exactly what each one means.

Same math. Same outputs. 257x less memory. The entire trick is where the computation happens, not what the computation is.

The hardware reality first

Before touching attention at all, you need to understand one fact about modern GPUs. They have two types of memory that are not remotely equivalent:

fig 1 - memory hierarchy (A100)
fig 1b - from the FlashAttention paper (Dao et al., 2022)

HBM (High Bandwidth Memory) is what people mean when they say "GPU memory." An A100 has 80GB of it. Bandwidth: ~2 TB/s. SRAM is the on-chip cache - tiny (~20MB on A100), but bandwidth of ~19 TB/s. Nearly 10x faster than HBM.

Every time your kernel reads or writes data, it is either hitting fast SRAM or going out to slow HBM. The time spent moving data between the two is called IO - and for attention, this is the actual bottleneck, not the arithmetic.

Why more FLOPS does not mean faster

An A100 can do 312 teraFLOPS. But if your kernel spends most of its time waiting for data from HBM, the compute units sit idle. Two regimes exist:

Compute-bound - arithmetic is the bottleneck, ALUs are saturated. Adding more FLOPS helps. This is where you want to be.

Memory-bound - IO is the bottleneck, ALUs are idle waiting for data. Adding more FLOPS does absolutely nothing. You need to reduce HBM reads/writes instead.

fig 1c - roofline model: operations below the ridge point are memory-bound

Attention is memory-bound. The matrix multiplications are fast. The bottleneck is reading Q, K, V from HBM and writing the N x N score matrix back. This is the entire problem FlashAttention solves.
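A quick back-of-envelope shows where the ridge point sits. This is a sketch using the A100 figures quoted above, nothing more precise:

```python
# roofline back-of-envelope for an A100, using the numbers quoted above
peak_flops = 312e12      # ~312 teraFLOPS peak compute
hbm_bandwidth = 2e12     # ~2 TB/s HBM bandwidth

# the ridge point: FLOPs of work needed per byte of HBM traffic
# to keep the compute units busy
ridge = peak_flops / hbm_bandwidth
print(f"ridge point: ~{ridge:.0f} FLOPs per byte")

# elementwise ops like softmax do on the order of 1 FLOP per byte moved,
# far below the ridge - they are deeply memory-bound
```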


What vanilla attention actually does to memory

The attention formula is deceptively simple:

Attention(Q, K, V) = softmax(QKᵀ / √d) * V

Q        - the question each token is asking
K        - the label each token is advertising
V        - the content each token actually carries
QKᵀ      - how much every token should attend to every other token
/ √d     - keeps scores from getting so large that softmax stops working
softmax  - turns raw scores into weights that sum to 1
* V      - blends the content vectors using those weights as the recipe

A quick note on the terms. Q, K, and V are matrices of shape (N, d) where N is the sequence length and d is the head dimension. In practice, transformers split the full embedding dimension across multiple attention heads. If your model has 512-dimensional embeddings and 8 heads, each head operates on d=64. The sqrt(d) scaling exists because dot products grow in magnitude with d, and without the correction softmax would saturate into a near one-hot distribution, killing the ability to attend to multiple tokens at once.
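Both facts - the head split and the sqrt(d) scaling - are easy to check in a few lines of NumPy. A sketch, with arbitrary example sizes:

```python
import numpy as np

rng = np.random.default_rng(0)

# split a 512-dim embedding across 8 heads: each head sees d = 64
N, embed_dim, n_heads = 16, 512, 8
d = embed_dim // n_heads
x = rng.normal(size=(N, embed_dim))
heads = x.reshape(N, n_heads, d).transpose(1, 0, 2)   # (n_heads, N, d)

# dot products of unit-variance vectors have std ~ sqrt(d);
# dividing by sqrt(d) brings the scores back to unit scale
q = rng.normal(size=(10000, d))
k = rng.normal(size=(10000, d))
raw_std = (q * k).sum(axis=1).std()                   # close to sqrt(64) = 8
scaled_std = ((q * k).sum(axis=1) / np.sqrt(d)).std() # close to 1
```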

The cat sat on the mat because it was tired.

Q ("it")   what noun am I referring to?
K (cat)    0.82
K (mat)    0.14
K (sat)    0.04
V          output = 0.82 × V_cat + 0.14 × V_mat + 0.04 × V_sat + …

Here is the problem. To compute softmax(QKᵀ), you need the full N x N score matrix to exist somewhere. With N=2048 tokens and float16:

# memory just for the score matrix
2048 × 2048 × 2 bytes = 8MB

# and you need to read/write it multiple times
Read Q:             256KB
Read K:             256KB
Read V:             256KB
Write scores S:     8MB   ← the problem
Read S for softmax: 8MB
Write softmax P:    8MB
Read P for V mult:  8MB
─────────────────────
Total HBM traffic:  ~33MB

At N=8192 that score matrix alone is 128MB. And it quadruples every time sequence length doubles. This is the O(N²) memory problem. Not a theoretical concern - it is why context length was stuck at 2048-4096 tokens for so long.
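The quadrupling is easy to verify with a few lines of arithmetic:

```python
# score-matrix size in float16 (2 bytes per element) as N doubles
for N in (2048, 4096, 8192, 16384):
    mb = N * N * 2 / 2**20
    print(f"N={N:6d}  ->  {mb:6.0f} MB")   # 8, 32, 128, 512: 4x per doubling
```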


The key insight - tiling

Here is the analogy that made it click for me.

Imagine you are a student grading 1000 exam papers. You need to rank every student against every other student. The naive way: read all 1000 papers, write every score on a giant spreadsheet, then rank. The spreadsheet is huge.

The smart way: read 10 papers at a time. Keep a running note of the highest score you have seen and a running total. Update as you go. Never write the full spreadsheet. End result is identical.

FlashAttention is the smart way. Instead of materializing the full N x N matrix, you process small tiles that fit in SRAM, accumulate results incrementally, and write the final output to HBM exactly once.

fig 2 - tiling: processing the attention matrix in blocks

The outer loop goes over blocks of Q (rows). The inner loop scans all blocks of K and V (columns). Each small tile fits in SRAM. The full N x N matrix never exists anywhere.

The loop order is not arbitrary. You need to fully finish computing every output token before moving on. So the structure is: pick a chunk of queries, scan through every single K and V block completely, produce the outputs for those queries, then move to the next query chunk. If you flipped it - outer loop over K/V, inner loop over Q - you would never finish any single output in one pass. You would need to store intermediate results for all N queries simultaneously, which defeats the entire purpose.


The hard part - online softmax

Tiling sounds simple. The catch: softmax breaks it.

Normal softmax for a row of scores needs the full denominator before computing any output:

softmax(x_i) = exp(x_i) / sum(exp(x_j) for all j)

That denominator requires seeing all scores simultaneously. But we are processing in chunks - we cannot see everything at once. This is the problem online softmax solves.

Three running statistics

Instead of computing softmax all at once, you maintain three values that update each chunk:

m - the running maximum score seen so far
l - the running sum of exp scores
O - the running weighted sum of V vectors (the actual output being built)

fig 3 - online softmax: running statistics update each chunk
(m: running maximum, initialized to -inf; l: running exp sum, initialized to 0;
correction factor: exp(m_old - m_new))

Why exp(score - m)?

Every exp score is computed as exp(score - m), not exp(score). This is numerical stability - nothing else. exp() overflows to infinity very quickly: exp(100) is already 2.7×10⁴³. Subtracting the max before exponentiation keeps all values between 0 and 1, preventing overflow.

This does not change the result at all. Since softmax divides numerator by denominator, the subtracted max cancels out perfectly:

exp(score - m) / sum(exp(scores - m))
# = exp(score) / sum(exp(scores))   <- identical
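Both claims - the overflow and the exact cancellation - show up directly in NumPy. A small sketch with arbitrary large scores:

```python
import numpy as np

scores = np.array([800.0, 801.0, 802.0])

# naive: exp overflows to inf, and inf/inf gives nan
with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(scores) / np.exp(scores).sum()

# stable: subtract the max first; the shift cancels in the ratio
shifted = np.exp(scores - scores.max())
stable = shifted / shifted.sum()

print(naive)    # contains nan
print(stable)   # a valid distribution that sums to 1
```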

The correction factor

Back to the grading analogy. Imagine you gave grades relative to the highest score you had seen so far, then found a new exam with an even higher score. Every previous grade was normalized against the wrong maximum. You would need to go back and rescale all of them. The correction factor does exactly that, in a single multiplication.

Concretely: if chunk 1 gives max 4 and chunk 2 gives a score of 6, the new max is 6. All previous exp scores used max=4, so they are too large by a factor of exp(4 - 6) = exp(-2) = 0.135. Multiply everything by that. Since the new max is always >= the old max, this correction is always <= 1. Always scaling down, never up.

The full update each inner loop iteration:

correction = exp(m_old - m_new)    # always <= 1 since m_new >= m_old

m     = m_new
l     = correction * l + rowsum(P)
O     = correction * O + P @ V_block

# all three statistics get the same correction simultaneously
# this is O(1) work regardless of how many chunks came before

This feels wasteful but it is extremely cheap. It is one multiplication per step, all in SRAM. The alternative - recomputing everything from scratch each time the max changes - would be O(N) work per update. The correction factor reduces that to O(1).
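The whole online-softmax loop fits in a few lines of NumPy, and you can check it against the ordinary all-at-once softmax. Chunk size and inputs here are arbitrary:

```python
import numpy as np

def online_softmax(scores, chunk=4):
    """Compute softmax in chunks, keeping only running stats m and l."""
    m, l = -np.inf, 0.0
    for j in range(0, len(scores), chunk):
        s = scores[j:j + chunk]
        m_new = max(m, s.max())
        correction = np.exp(m - m_new)       # rescale old sum to the new max
        l = correction * l + np.exp(s - m_new).sum()
        m = m_new
    return np.exp(scores - m) / l            # final normalized weights

rng = np.random.default_rng(0)
x = rng.normal(size=17) * 5                  # length not divisible by chunk
full = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
assert np.allclose(online_softmax(x), full)  # identical to one-shot softmax
```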


The complete algorithm

Putting it all together - this is the entire FlashAttention forward pass:

# outer loop - process Q in blocks
for i in range(0, N, BLOCK_M):
    Q_block = Q[i:i+BLOCK_M]     # load from HBM -> SRAM
    O_block = zeros(BLOCK_M, d)
    l_block = zeros(BLOCK_M)
    m_block = full(BLOCK_M, -inf)

    # inner loop - scan all K/V blocks
    for j in range(0, N, BLOCK_N):
        K_block = K[j:j+BLOCK_N]     # load tile
        V_block = V[j:j+BLOCK_N]     # load tile

        S = Q_block @ K_block.T / sqrt(d)    # score tile
        m_new = max(m_block, rowmax(S))      # per-row running max
        P = exp(S - m_new[:, None])          # stable exp, broadcast per row
        correction = exp(m_block - m_new)    # rescale factor, always <= 1

        l_block = correction * l_block + rowsum(P)
        O_block = correction[:, None] * O_block + P @ V_block
        m_block = m_new

    # normalize and write back - only HBM write in the whole kernel
    O[i:i+BLOCK_M] = O_block / l_block[:, None]

At no point does an N x N matrix exist. S is only ever (BLOCK_M x BLOCK_N) - fixed size regardless of N. That is why memory is O(N).
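The pseudocode translates directly into NumPy, which makes it easy to verify against naive attention. A sketch of the single-head forward pass (not the Triton kernel itself), with arbitrary test sizes:

```python
import numpy as np

def naive_attention(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)
    return P @ V

def flash_attention(Q, K, V, BLOCK_M=16, BLOCK_N=16):
    N, d = Q.shape
    O = np.empty_like(Q)
    for i in range(0, N, BLOCK_M):                  # outer loop: Q blocks
        Qb = Q[i:i + BLOCK_M]
        Ob = np.zeros((len(Qb), d))
        lb = np.zeros(len(Qb))
        mb = np.full(len(Qb), -np.inf)
        for j in range(0, N, BLOCK_N):              # inner loop: K/V blocks
            S = Qb @ K[j:j + BLOCK_N].T / np.sqrt(d)
            m_new = np.maximum(mb, S.max(axis=1))
            P = np.exp(S - m_new[:, None])
            corr = np.exp(mb - m_new)               # rescale old stats
            lb = corr * lb + P.sum(axis=1)
            Ob = corr[:, None] * Ob + P @ V[j:j + BLOCK_N]
            mb = m_new
        O[i:i + BLOCK_M] = Ob / lb[:, None]         # single write per block
    return O

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(64, 8)) for _ in range(3))
assert np.allclose(flash_attention(Q, K, V), naive_attention(Q, K, V))
```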


Causal masking - skipping half the work

Language models use causal (autoregressive) attention - token i can only attend to tokens before it, not future ones. In standard attention you implement this by setting future scores to -inf before softmax. In FlashAttention you can do better.

fig 4 - causal masking: upper triangle skipped entirely

For any tile where every column index j is greater than every row index i - meaning the entire K chunk lies in the future relative to the Q chunk - you skip that tile completely. Do not load it, do not compute it. That is roughly half the inner loop iterations eliminated at long sequences. Real FLOPs saved, not just zeroed out.

For tiles on the diagonal where some j ≤ i and some j > i, you apply a per-element mask. Future positions get set to -inf before softmax.

The reason -inf specifically: exp(-inf) = 0 exactly, so those tokens contribute precisely zero to the softmax denominator and precisely zero to the weighted V sum. Setting future scores to zero would not work. softmax(0) is not zero, it is a small positive number (1 divided by the number of tokens), so masked positions would still bleed into the output. Negative infinity is the only value that survives the exp and comes out as exactly zero.
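A four-token example makes the difference concrete. Masking with 0 leaves the future positions with nonzero weight; masking with -inf zeroes them exactly:

```python
import numpy as np

scores = np.array([2.0, 1.0, 3.0, 0.5])
future = np.array([False, False, True, True])   # last two positions are in the future

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

w_inf = softmax(np.where(future, -np.inf, scores))   # exp(-inf) = 0 exactly
w_zero = softmax(np.where(future, 0.0, scores))      # exp(0) = 1: still contributes

print(w_inf)    # future weights are exactly 0
print(w_zero)   # future weights are small but nonzero - they bleed into the output
```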


What the numbers actually look like

I ran this on a Tesla T4 (Google Colab free tier), comparing against naive PyTorch attention and PyTorch's built-in SDPA. All float16, batch=1, head_dim=64.

The short version: at sequence length 8192, this kernel allocates 1 MB of GPU memory where standard attention allocates 257 MB. Same outputs, verified against naive attention to float16 tolerance.

Memory - the headline result

fig 5 - peak GPU memory at each sequence length
Seq Len   Naive (MB)   Flash (MB)   Ratio
512             1.06         0.06     18x
1024            4.12         0.12     34x
2048           16.25         0.25     65x
4096           64.50         0.50    129x
8192          257.00         1.00    257x

The ratio doubles every time sequence length doubles. Naive memory quadruples (O(N²)), flash memory doubles (O(N)). At 8192 tokens, naive uses 257x more memory. This matches theory exactly.

Speed - an honest read

Seq Len   Naive (ms)   Flash (ms)   Causal (ms)   SDPA (ms)
512            0.059        0.239         0.167       0.064
1024           0.080        0.441         0.264       0.123
2048           0.290        1.410         1.171       0.144
4096           1.062        3.929         2.082       0.363
8192           4.268       13.604         9.159       1.582

The kernel is slower than SDPA. That is expected and worth explaining honestly. PyTorch's SDPA calls a heavily optimized C++ FlashAttention implementation tuned by engineers over years, with production-grade pipelining and architecture-specific tricks. This is a from-scratch Triton kernel written to understand the algorithm, not to win benchmarks.

Two things are worth noting. First, causal masking gives a consistent 1.5-2x speedup over non-causal by genuinely skipping upper triangle tiles. Second, software pipelining (num_stages=3) gave a 2x speedup at seq=2048 by overlapping the next K/V tile fetch with the current tile's matrix multiply.

The memory savings are real, identical to SDPA, and match theory exactly. The speed gap is real too - and understanding why it exists is more valuable than pretending it does not.

How the pseudocode becomes Triton

Before looking at the kernel, one thing worth understanding is how to pick BLOCK_M and BLOCK_N. These control how large each tile is. Too small and you have more inner loop iterations, more HBM loads per sequence, and more kernel launch overhead per Q block. Too large and the tile does not fit in SRAM, at which point the compiler spills to HBM and you lose the entire benefit. The right answer is hardware-specific: an A100 has more SRAM per streaming multiprocessor than a T4, so it can handle larger blocks. This is why the kernel uses @triton.autotune with a list of configs. Triton runs each config on a small input and picks the fastest one for the current hardware automatically. You never hardcode the block size.
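In Triton that looks roughly like the following. The specific configs here are illustrative placeholders, not the repo's actual tuning table:

```python
# illustrative autotune sketch - config values are examples, not the repo's
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 64,  "BLOCK_N": 64},  num_warps=4, num_stages=2),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64},  num_warps=8, num_stages=3),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=3),
    ],
    key=["N", "HEAD_DIM"],   # re-benchmark the configs whenever these change
)
@triton.jit
def flash_attention_kernel(
    Q_ptr, K_ptr, V_ptr, O_ptr,
    stride_qm, stride_qd,
    N, HEAD_DIM,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
    causal: tl.constexpr,
):
    ...  # kernel body as below
```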

The Triton kernel is literally the pseudocode above with three changes in notation: tl.load instead of array indexing, tl.dot instead of @, and pointer arithmetic instead of slice notation. The logic is identical.

@triton.jit
def flash_attention_kernel(
    Q_ptr, K_ptr, V_ptr, O_ptr,
    stride_qm, stride_qd, # how to step through Q in memory
    N, HEAD_DIM,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
    causal: tl.constexpr,
):
    # which Q block does this kernel instance own?
    block_m = tl.program_id(0)
    off_m = block_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # load Q tile into SRAM
    Q = tl.load(Q_ptr + off_m[:, None] * stride_qm + ...)

    # initialize running stats - all in SRAM
    O_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    l_acc = tl.zeros([BLOCK_M], dtype=tl.float32)
    m_acc = tl.full([BLOCK_M], float('-inf'), dtype=tl.float32)

    # inner loop over K/V blocks
    # (the causal cutoff compares block_m to block_n, so it assumes BLOCK_M == BLOCK_N)
    loop_end = block_m + 1 if causal else tl.cdiv(N, BLOCK_N)
    for block_n in range(0, loop_end):
        off_n = block_n * BLOCK_N + tl.arange(0, BLOCK_N)   # column indices of this tile
        K = tl.load(K_ptr + ...)
        V = tl.load(V_ptr + ...)

        S = tl.dot(Q, tl.trans(K)) / tl.sqrt(HEAD_DIM * 1.0)   # scale by sqrt(d)

        # causal mask on diagonal tile
        if causal:
            mask = off_m[:, None] >= off_n[None, :]
            S = tl.where(mask, S, float('-inf'))

        m_new = tl.maximum(m_acc, tl.max(S, axis=1))
        P = tl.exp(S - m_new[:, None])
        correction = tl.exp(m_acc - m_new)

        l_acc = correction * l_acc + tl.sum(P, axis=1)
        O_acc = correction[:, None] * O_acc + tl.dot(P, V)
        m_acc = m_new

    # normalize and write back - one HBM write
    O_acc = O_acc / l_acc[:, None]
    tl.store(O_ptr + ..., O_acc)

One detail worth understanding: tl.constexpr means the value is fixed at compile time rather than determined at runtime. When you pass causal=True, Triton compiles a completely separate kernel with the masking logic baked in. When you pass causal=False, it compiles a second kernel with that branch entirely removed. Think of it like a C preprocessor #ifdef, not a Python if statement. There is zero runtime overhead from the branch because it does not exist in the compiled kernel. The same applies to BLOCK_M and BLOCK_N, which is why they are also constexpr. Block sizes are fixed at compile time so the compiler can generate optimal memory access patterns for those exact dimensions.

One thing the pseudocode simplifies: everything shown is for a single attention head. Real transformers have H heads running in parallel, typically 8, 16, or 32. In the Triton kernel, this is handled by adding a head index dimension to the launch grid. Each kernel instance knows which head it owns and loads the corresponding slice of Q, K, and V. The heads never communicate with each other and run fully in parallel across streaming multiprocessors. Going from 1 head to 32 heads does not change the algorithm at all. It just means 32 instances of the same kernel running simultaneously.
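Batch and head are typically packed into one grid axis and unpacked inside the kernel. A pure-Python sketch of that index math - the helper name is mine, not from the repo:

```python
# hypothetical helper mirroring a common Triton convention:
# grid = (cdiv(N, BLOCK_M), batch * n_heads), second axis flattened
def decode_program_ids(pid_m, pid_bh, n_heads):
    """Unpack the flattened batch*head grid axis into (batch, head)."""
    batch = pid_bh // n_heads
    head = pid_bh % n_heads
    return batch, head

# with batch=2 and 32 heads, the second grid axis has 64 kernel instances
assert decode_program_ids(0, 0, 32) == (0, 0)
assert decode_program_ids(0, 33, 32) == (1, 1)   # second batch, head 1
```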


What actually took time to understand

Reading the paper felt approachable. Writing the kernel revealed the gaps.

The correction factor application order matters, and getting it wrong produces the worst class of bug: outputs that look plausible but are subtly wrong. My first implementation updated m, then l, then O, but computed the correction before updating m. The outputs passed shape checks. The values were off by small amounts that grew with sequence length. I found it by diffing against naive attention at seq=128 with a fixed random seed and printing the per-element absolute error. It was uniform across the sequence, which pointed to a systematic rescaling error rather than an indexing issue. The fix was four lines: compute correction first, then update all three statistics simultaneously. Order matters.
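The bug is easy to reproduce in NumPy. Updating m before computing the correction makes the correction exp(0) = 1 every time, so earlier chunks are never rescaled - and the output still looks plausible:

```python
import numpy as np

def online_softmax(scores, chunk, buggy):
    m, l = -np.inf, 0.0
    for j in range(0, len(scores), chunk):
        s = scores[j:j + chunk]
        m_new = max(m, s.max())
        if buggy:
            m = m_new                        # bug: overwrite m first...
            correction = np.exp(m - m_new)   # ...so this is always exp(0) = 1
        else:
            correction = np.exp(m - m_new)   # correct: rescale against the OLD max
            m = m_new
        l = correction * l + np.exp(s - m_new).sum()
    return np.exp(scores - m) / l

x = np.array([1.0, 4.0, 2.0, 6.0])           # max rises between chunks
true = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
assert np.allclose(online_softmax(x, 2, buggy=False), true)
assert not np.allclose(online_softmax(x, 2, buggy=True), true)   # subtly wrong
```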

The masking at tile boundaries is fiddly in a specific way. When N is not divisible by the block size, the last tile loads out-of-bounds memory addresses. Triton will not crash. It loads garbage values. Those garbage values participate in the tl.max call that computes m_new. If any garbage value is larger than any real score, which is common since uninitialized memory is unpredictable, it becomes the running maximum and corrupts all subsequent normalizations. The fix: apply the boundary mask to S before the max operation, setting out-of-bounds positions to -inf so they cannot win the max. Masking after the max does not help. This took longer to debug than the correction factor bug.
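A NumPy sketch of the failure mode: if garbage in the out-of-bounds slots participates in the max, every real weight is skewed; masking to -inf before the max fixes it. The garbage values here are stand-ins for whatever the out-of-bounds load returns:

```python
import numpy as np

BLOCK = 4
valid = 2                                     # only 2 real scores in the last tile
S = np.array([1.0, 2.0, 50.0, 50.0])          # last two slots hold garbage loads
in_bounds = np.arange(BLOCK) < valid

# wrong: take the max first - garbage wins and corrupts normalization
m_bad = S.max()                               # 50.0, from a garbage slot

# right: mask to -inf BEFORE the max so garbage can never win
S_masked = np.where(in_bounds, S, -np.inf)
m_good = S_masked.max()                       # 2.0, from a real score

P = np.exp(S_masked - m_good)                 # garbage slots come out exactly 0
assert m_good == 2.0
assert P[2] == 0.0 and P[3] == 0.0
```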

The TFLOPS gap between my kernel and SDPA is real but not mysterious. Production FlashAttention uses persistent kernels (no kernel launch overhead per step), register-level optimizations, and architecture-specific tuning for tensor cores. A first implementation does not have any of that. Understanding why the gap exists matters more than closing it in a weekend.


The code

Full implementation, 192 correctness tests (96 non-causal + 96 causal), benchmark suite, and plots at:

github.com/Vinesh2929/FlashAttention-kernel-triton

The correctness tests compare against naive PyTorch attention (non-causal) and F.scaled_dot_product_attention(is_causal=True) (causal). All tests pass across all sequence lengths, head dims, batch sizes, and dtypes.