Wednesday, December 24, 2025

How Attention Runs on Real Hardware

 From GPU streaming multiprocessors to TPU systolic arrays

1. How Scaled Dot-Product Attention Maps onto a GPU SM (Step by Step)

A GPU Streaming Multiprocessor (SM) is optimized for:

  • Massive parallelism

  • Matrix multiply (via Tensor Cores)

  • Hiding memory latency with threads

Now map the attention equation onto that reality.

Step 1: Load tiles of Q and K

  • Q and K are chunked into tiles

  • Each tile is loaded from global memory into shared memory

  • Threads cooperate to load contiguous blocks (coalesced access)

Hardware view:

This is a classic tiling pattern — exactly what GPUs are built for.
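
To make the tiling concrete, here is a minimal NumPy sketch of chunking Q and K into row tiles. The shapes and tile size are illustrative assumptions, not values from any particular kernel; on a real GPU each tile would be staged cooperatively from global into shared memory.

import numpy as np

# Illustrative shapes (assumptions): 128 queries/keys, head dim 64, 32-row tiles.
seq_len, d_k, tile = 128, 64, 32
rng = np.random.default_rng(0)
Q = rng.standard_normal((seq_len, d_k)).astype(np.float32)
K = rng.standard_normal((seq_len, d_k)).astype(np.float32)

# Chunk Q and K into contiguous row tiles. On a GPU, threads in a block would
# copy one such contiguous chunk at a time (coalesced loads) into shared memory;
# here the slicing just models that tile-by-tile view of the matrices.
Q_tiles = [Q[i:i + tile] for i in range(0, seq_len, tile)]
K_tiles = [K[j:j + tile] for j in range(0, seq_len, tile)]

print(len(Q_tiles), Q_tiles[0].shape)   # 4 tiles, each of shape (32, 64)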


Step 2: Compute $QK^T$

  • Each warp performs partial dot products

  • Tensor Cores accelerate this as a GEMM

  • Accumulators stay in registers

This is compute-dense and efficient (see the sketch after this list):

  • High arithmetic intensity

  • Minimal memory traffic relative to FLOPs
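
A rough NumPy sketch of the tiled $QK^T$ computation, under the same illustrative shapes as above. The per-tile products stand in for the warp-level partial dot products a Tensor Core GEMM would execute, and the local block array plays the role of the register accumulator.

import numpy as np

# Illustrative shapes (assumptions), not taken from a specific kernel.
seq_len, d_k, tile = 128, 64, 32
rng = np.random.default_rng(0)
Q = rng.standard_normal((seq_len, d_k)).astype(np.float32)
K = rng.standard_normal((seq_len, d_k)).astype(np.float32)

# Tile-by-tile computation of S = Q @ K.T: each (tile x tile) block of scores
# is a small GEMM whose accumulator (S_block) stays local until it is complete.
S = np.zeros((seq_len, seq_len), dtype=np.float32)
for i in range(0, seq_len, tile):
    for j in range(0, seq_len, tile):
        S_block = Q[i:i + tile] @ K[j:j + tile].T   # partial dot products
        S[i:i + tile, j:j + tile] = S_block

# The tiled result matches the untiled GEMM.
assert np.allclose(S, Q @ K.T, atol=1e-4)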


Step 3: Apply scaling and softmax (fused)

  • Scaling by $1/\sqrt{d_k}$ is applied immediately

  • Softmax is computed row-wise

  • Max subtraction + exponentials + reduction are fused

Key hardware detail:

This fusion minimizes round-trips to memory and keeps values in registers.
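
A minimal NumPy sketch of the scaled, numerically stable row-wise softmax. The max subtraction shown here is the same trick the fused kernel applies before exponentiating, with every intermediate held in local arrays rather than written back out. Shapes are illustrative assumptions.

import numpy as np

rng = np.random.default_rng(0)
seq_len, d_k = 128, 64
Q = rng.standard_normal((seq_len, d_k)).astype(np.float32)
K = rng.standard_normal((seq_len, d_k)).astype(np.float32)

scores = (Q @ K.T) / np.sqrt(d_k)          # scaling applied immediately

# Row-wise, numerically stable softmax: subtract the row max before exp
# so the largest exponent is 0 and nothing overflows.
row_max = scores.max(axis=-1, keepdims=True)
exp_scores = np.exp(scores - row_max)
A = exp_scores / exp_scores.sum(axis=-1, keepdims=True)

assert np.allclose(A.sum(axis=-1), 1.0, atol=1e-5)   # each row is a distribution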


Step 4: Multiply by V

  • The attention weights never need to be fully written out

  • They are immediately used to weight V

  • Another GEMM-like operation

This is where the structure really shines:

  • Attention becomes two back-to-back GEMMs with a fused nonlinearity (see the sketch below)
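
Putting the pieces together, the forward pass is two matmuls wrapped around a row-wise softmax. This NumPy sketch (illustrative shapes, float32 throughout) is the reference computation that a tiled kernel reproduces without ever materializing the attention matrix in global memory.

import numpy as np

rng = np.random.default_rng(0)
seq_len, d_k, d_v = 128, 64, 64
Q = rng.standard_normal((seq_len, d_k)).astype(np.float32)
K = rng.standard_normal((seq_len, d_k)).astype(np.float32)
V = rng.standard_normal((seq_len, d_v)).astype(np.float32)

def attention(Q, K, V):
    # GEMM 1: similarity scores, scaled for numerical stability.
    S = (Q @ K.T) / np.sqrt(Q.shape[-1])
    # Fused nonlinearity: stable row-wise softmax.
    S = S - S.max(axis=-1, keepdims=True)
    A = np.exp(S)
    A = A / A.sum(axis=-1, keepdims=True)
    # GEMM 2: weighted sum of the values.
    return A @ V

out = attention(Q, K, V)
print(out.shape)   # (128, 64)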


Step 5: Backward pass (gradient w.r.t. V)

$\frac{\partial L}{\partial V} = A^T G$, where $A$ is the softmax attention matrix and $G$ is the upstream gradient of the loss with respect to the attention output.
  • Pure matrix multiply

  • No softmax

  • No exponentials

  • Easy to tile and accumulate

Hardware takeaway:

Gradients w.r.t. V are “perfect GPU food.”
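
Because $\partial L/\partial V = A^T G$ involves no softmax or exponentials, it is a single transposed GEMM. The NumPy sketch below (illustrative shapes; $G$ is a stand-in upstream gradient) checks the analytic gradient against a finite-difference estimate for one entry.

import numpy as np

rng = np.random.default_rng(0)
seq_len, d_k, d_v = 64, 32, 32
Q = rng.standard_normal((seq_len, d_k))
K = rng.standard_normal((seq_len, d_k))
V = rng.standard_normal((seq_len, d_v))
G = rng.standard_normal((seq_len, d_v))      # upstream gradient dL/d(output)

def softmax_rows(S):
    S = S - S.max(axis=-1, keepdims=True)
    E = np.exp(S)
    return E / E.sum(axis=-1, keepdims=True)

A = softmax_rows((Q @ K.T) / np.sqrt(d_k))   # attention weights

# Analytic gradient: one transposed GEMM, nothing else.
grad_V = A.T @ G

# Finite-difference check, using the scalar loss L = sum((A @ V) * G).
def loss(V):
    return np.sum((A @ V) * G)

eps = 1e-6
V_pert = V.copy()
V_pert[3, 5] += eps
fd = (loss(V_pert) - loss(V)) / eps
assert np.isclose(fd, grad_V[3, 5], rtol=1e-4)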


2. Why TPUs Love Attention Even More

TPUs are designed around systolic arrays:

  • Fixed-function matrix multiplication pipelines

  • Data flows rhythmically across the array

  • Minimal control flow

Scaled dot-product attention fits because:

a) Static shapes

  • Fixed-size matrix multiplications

  • Predictable memory access patterns

b) Linear algebra dominance

  • $QK^T$, $AV$, and $A^T G$ all map cleanly

  • Softmax is a small, local deviation

c) Low-precision stability

  • Scaling by $1/\sqrt{d_k}$ keeps logits in BF16-safe ranges (see the sketch below)

  • Prevents softmax collapse
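
A small sketch of why the $1/\sqrt{d_k}$ factor matters: for roughly unit-variance q and k, the raw dot product has variance $d_k$, so logits grow with head size and the softmax saturates toward a one-hot distribution; dividing by $\sqrt{d_k}$ brings the logits back to roughly unit scale. The head dimension below is deliberately exaggerated (an assumption for illustration), and the arithmetic uses float64 for simplicity rather than BF16.

import numpy as np

rng = np.random.default_rng(0)
seq_len, d_k = 128, 1024           # deliberately large head dim (illustrative)
Q = rng.standard_normal((seq_len, d_k))
K = rng.standard_normal((seq_len, d_k))

raw = Q @ K.T                      # logit std grows like sqrt(d_k), ~32 here
scaled = raw / np.sqrt(d_k)        # back to roughly unit std

def softmax_rows(S):
    S = S - S.max(axis=-1, keepdims=True)
    E = np.exp(S)
    return E / E.sum(axis=-1, keepdims=True)

print("raw logit std:    ", round(float(raw.std()), 1))
print("scaled logit std: ", round(float(scaled.std()), 2))
print("mean max prob, raw:   ", softmax_rows(raw).max(axis=-1).mean())    # typically near 1: collapsed
print("mean max prob, scaled:", softmax_rows(scaled).max(axis=-1).mean()) # mass spread across many keys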

TPU perspective:

Attention looks like a carefully staged sequence of systolic matrix multiplies with a small nonlinear interruption.

That’s almost exactly what TPUs were designed for.


3. Why Transformers Beat RNNs on Hardware (Not Just Modeling)

This is where the story really crystallizes.

RNNs on hardware

  • Sequential dependency

  • Time step $t$ depends on $t-1$

  • Poor parallelism

  • Underutilized compute units

Even with cuDNN optimizations:

  • SMs sit idle

  • Memory latency can’t be hidden effectively


Transformers on hardware

  • Entire sequence processed at once

  • No recurrence

  • Massive parallelism across tokens

  • Dominated by GEMMs

Hardware consequence:

GPUs and TPUs can finally operate at full occupancy.

This is not a small advantage; the gap in achieved throughput is often orders of magnitude.
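
A toy NumPy contrast between the two execution patterns: the recurrent version walks through time one dependent matvec at a time, while the attention-style version hands the hardware a single large, parallel GEMM over all tokens. The two snippets do not compute the same function; the shapes and timing harness are illustrative assumptions, and absolute numbers will vary by machine. The dependency structure is the point.

import time
import numpy as np

rng = np.random.default_rng(0)
seq_len, d = 512, 256
X = rng.standard_normal((seq_len, d)).astype(np.float32)
W = rng.standard_normal((d, d)).astype(np.float32) / np.sqrt(d)

# RNN-style: each step needs the previous hidden state, so the work is a
# chain of small, dependent matvecs that cannot be batched across time.
def recurrent(X, W):
    h = np.zeros(d, dtype=np.float32)
    out = np.empty_like(X)
    for t in range(seq_len):            # sequential dependency on step t-1
        h = np.tanh(X[t] + h @ W)
        out[t] = h
    return out

# Attention-style: every token goes through the same large GEMM at once.
def parallel(X, W):
    return np.tanh(X + X @ W)           # no dependence between rows

t0 = time.perf_counter(); recurrent(X, W); t1 = time.perf_counter()
parallel(X, W);                           t2 = time.perf_counter()
print(f"sequential: {t1 - t0:.4f}s   parallel: {t2 - t1:.4f}s")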

That’s why:

  • Transformers scale with hardware

  • RNNs plateau early

  • “Attention is all you need” turned out to be literally true for accelerators


4. Attention Kernel Pseudocode (Hardware-Oriented)

Here’s a hardware-centric sketch — not Python, not PyTorch, but how a kernel “thinks”:

for each block of queries Q_tile:
    load Q_tile into shared memory
    initialize output_tile = 0
    initialize row_max, row_sum
    for each block of keys/values K_tile, V_tile:
        load K_tile, V_tile into shared memory
        scores = Q_tile × K_tileᵀ
        scores *= scale
        update row_max
        rescale output_tile and row_sum by exp(old_max − new_max)
        exp_scores = exp(scores - row_max)
        update row_sum
        output_tile += exp_scores × V_tile
    output_tile /= row_sum
    write output_tile to global memory
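
Here is a runnable NumPy rendering of that loop, a rough stand-in for a FlashAttention-style streaming kernel: it streams over key/value tiles, maintains a running max and sum per query row, rescales the partial output whenever the max changes, and never materializes the full attention matrix. Tile size and shapes are illustrative assumptions; the final check compares against the naive two-GEMM reference.

import numpy as np

rng = np.random.default_rng(0)
seq_len, d_k, d_v, tile = 128, 64, 64, 32
Q = rng.standard_normal((seq_len, d_k)).astype(np.float32)
K = rng.standard_normal((seq_len, d_k)).astype(np.float32)
V = rng.standard_normal((seq_len, d_v)).astype(np.float32)
scale = 1.0 / np.sqrt(d_k)

out = np.zeros((seq_len, d_v), dtype=np.float32)
for i in range(0, seq_len, tile):                     # block of queries
    Q_tile = Q[i:i + tile]
    output_tile = np.zeros((Q_tile.shape[0], d_v), dtype=np.float32)
    row_max = np.full((Q_tile.shape[0], 1), -np.inf, dtype=np.float32)
    row_sum = np.zeros((Q_tile.shape[0], 1), dtype=np.float32)
    for j in range(0, seq_len, tile):                 # stream over key/value blocks
        K_tile, V_tile = K[j:j + tile], V[j:j + tile]
        scores = (Q_tile @ K_tile.T) * scale
        new_max = np.maximum(row_max, scores.max(axis=-1, keepdims=True))
        correction = np.exp(row_max - new_max)        # rescale earlier partial results
        output_tile *= correction
        row_sum *= correction
        row_max = new_max
        exp_scores = np.exp(scores - row_max)
        row_sum += exp_scores.sum(axis=-1, keepdims=True)
        output_tile += exp_scores @ V_tile
    out[i:i + tile] = output_tile / row_sum

# Reference: naive softmax(QK^T / sqrt(d_k)) @ V with the full attention matrix.
S = (Q @ K.T) * scale
A = np.exp(S - S.max(axis=-1, keepdims=True))
A /= A.sum(axis=-1, keepdims=True)
assert np.allclose(out, A @ V, atol=1e-4)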

Backward w.r.t. V:

for each block:
    load attention_tile (or recompute)
    load grad_output_tile
    grad_V += attention_tileᵀ × grad_output_tile
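
The same streaming discipline applies to the backward pass for V: grad_V is accumulated block by block from transposed GEMMs. A small NumPy sketch (illustrative shapes; the attention weights are recomputed naively here rather than per tile) verifies that the tiled accumulation matches the full $A^T G$ product.

import numpy as np

rng = np.random.default_rng(0)
seq_len, d_k, d_v, tile = 128, 64, 64, 32
Q = rng.standard_normal((seq_len, d_k)).astype(np.float32)
K = rng.standard_normal((seq_len, d_k)).astype(np.float32)
G = rng.standard_normal((seq_len, d_v)).astype(np.float32)   # grad of attention output

# Recompute the attention weights (a real kernel would recompute them tile by tile).
S = (Q @ K.T) / np.sqrt(d_k)
A = np.exp(S - S.max(axis=-1, keepdims=True))
A /= A.sum(axis=-1, keepdims=True)

# Tiled accumulation over blocks of query rows.
grad_V = np.zeros((seq_len, d_v), dtype=np.float32)
for i in range(0, seq_len, tile):
    attention_tile = A[i:i + tile]          # block of attention weights
    grad_output_tile = G[i:i + tile]
    grad_V += attention_tile.T @ grad_output_tile

assert np.allclose(grad_V, A.T @ G, atol=1e-4)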

Notice:

  • Everything is tiled

  • Everything streams

  • Memory writes are minimized

  • Gradients w.r.t. V are clean and linear

This pseudocode is the hardware story.


The Unifying Insight 

Scaled dot-product attention is not just a successful abstraction — it is an unusually good match to the physical realities of modern accelerators. Its dominance of matrix multiplication, its numerical scaling, and even the simplicity of its gradients align with how GPUs and TPUs actually execute code. Transformers did not merely benefit from better hardware; they emerged as the first architecture to fully exploit it.
