From GPU streaming multiprocessors to TPU systolic arrays
1. How Scaled Dot-Product Attention Maps onto a GPU SM (Step by Step)
A GPU Streaming Multiprocessor (SM) is optimized for:
- Massive parallelism
- Matrix multiply (via Tensor Cores)
- Hiding memory latency with threads
Now map the attention equation, Attention(Q, K, V) = softmax(QKᵀ / √d_k) V, onto that reality.
Step 1: Load tiles of Q and K
- Q and K are chunked into tiles
- Each tile is loaded from global memory into shared memory
- Threads cooperate to load contiguous blocks (coalesced access)
Hardware view:
This is a classic tiling pattern — exactly what GPUs are built for.
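A minimal sketch of that chunking, in NumPy only to keep it concrete (tile size and shapes are assumptions): each tile is a contiguous block of rows, which is exactly what lets cooperating threads issue coalesced loads on the real hardware.

```python
# Illustrative only: Q and K split into contiguous row-block tiles.
import numpy as np

seq_len, d_k, TILE = 1024, 64, 64          # assumed sizes, not tuned values
Q = np.random.randn(seq_len, d_k)
K = np.random.randn(seq_len, d_k)

# In row-major storage each row block is one contiguous range of addresses,
# which is what makes the cooperative per-thread loads coalesce on a GPU.
q_tiles = [Q[i:i + TILE] for i in range(0, seq_len, TILE)]
k_tiles = [K[j:j + TILE] for j in range(0, seq_len, TILE)]
assert q_tiles[0].flags["C_CONTIGUOUS"]
```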
Step 2: Compute QKᵀ
- Each warp performs partial dot products
- Tensor Cores accelerate this as a GEMM
- Accumulators stay in registers
This is compute-dense and efficient:
- High arithmetic intensity
- Minimal memory traffic relative to FLOPs
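A sketch of this step on one pair of tiles, carrying over the assumed 64×64 tiling: the reduction over d_k is split into chunks whose partial products accumulate in place, mirroring Tensor Core fragments accumulating in registers.

```python
# Illustrative only: one score tile of Q Kᵀ built from partial dot products.
import numpy as np

TILE, d_k = 64, 64
q_tile = np.random.randn(TILE, d_k)
k_tile = np.random.randn(TILE, d_k)

acc = np.zeros((TILE, TILE))                             # the accumulator ("registers")
for c in range(0, d_k, 16):                              # split the reduction dimension
    acc += q_tile[:, c:c + 16] @ k_tile[:, c:c + 16].T   # partial dot products per chunk
assert np.allclose(acc, q_tile @ k_tile.T)               # equals the single full GEMM
```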
Step 3: Apply scaling and softmax (fused)
- Scaling by 1/√d_k is applied immediately
- Softmax is computed row-wise
- Max subtraction + exponentials + reduction are fused
Key hardware detail:
This fusion minimizes round-trips to memory and keeps values in registers.
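On one tile of scores, the fused sequence looks like this (a NumPy sketch with an assumed tile shape); in the kernel these four lines happen back to back without any intermediate leaving registers.

```python
# Illustrative only: scale + numerically stable row-wise softmax on one tile.
import numpy as np

d_k = 64
s_tile = np.random.randn(64, 64)               # one tile of Q Kᵀ scores

scaled = s_tile / np.sqrt(d_k)                 # scaling by 1/sqrt(d_k), applied immediately
row_max = scaled.max(axis=1, keepdims=True)    # max subtraction for stability
p = np.exp(scaled - row_max)                   # exponentials
p /= p.sum(axis=1, keepdims=True)              # row-wise reduction / normalization
```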
Step 4: Multiply by V
- The attention weights never need to be fully written out
- They are immediately used to weight V
- Another GEMM-like operation
This is where the structure really shines:
- Attention becomes two back-to-back GEMMs with a fused nonlinearity
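Written end to end for a single head (shapes are illustrative), the structure is exactly two GEMMs with a row-wise nonlinearity between them; in a fused kernel the intermediate P never leaves on-chip storage.

```python
# Illustrative only: attention as two back-to-back GEMMs plus a fused softmax.
import numpy as np

n, d_k, d_v = 128, 64, 64
Q = np.random.randn(n, d_k)
K = np.random.randn(n, d_k)
V = np.random.randn(n, d_v)

S = Q @ K.T / np.sqrt(d_k)                     # GEMM #1
P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)              # fused nonlinearity (softmax)
O = P @ V                                      # GEMM #2
```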
Step 5: Backward pass (gradient w.r.t. V)
- Pure matrix multiply
- No softmax
- No exponentials
- Easy to tile and accumulate
Hardware takeaway:
Gradients w.r.t. V are “perfect GPU food.”
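A quick numerical check of that claim, using dO as my (assumed) name for the upstream gradient of O = P V: the gradient w.r.t. V is the single transposed GEMM Pᵀ dO.

```python
# Illustrative only: dV = Pᵀ @ dO, verified with a finite-difference spot check.
import numpy as np

rng = np.random.default_rng(0)
P = rng.random((16, 16)); P /= P.sum(axis=1, keepdims=True)   # attention weights
V = rng.standard_normal((16, 8))
dO = rng.standard_normal((16, 8))                             # upstream gradient of O = P @ V

dV = P.T @ dO                      # pure matmul: no softmax, no exponentials anywhere

eps, (i, j) = 1e-6, (3, 5)
V_pert = V.copy(); V_pert[i, j] += eps
numerical = ((P @ V_pert - P @ V) * dO).sum() / eps
assert np.isclose(numerical, dV[i, j], atol=1e-6)
```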
2. Why TPUs Love Attention Even More
TPUs are designed around systolic arrays:
- Fixed-function matrix multiplication pipelines
- Data flows rhythmically across the array
- Minimal control flow
Scaled dot-product attention fits because:
a) Static shapes
- Fixed-size matrix multiplications
- Predictable memory access patterns
b) Linear algebra dominance
- QKᵀ, the multiplication by V, and the Q/K/V projections all map cleanly
- Softmax is a small, local deviation
c) Low-precision stability
- Scaling keeps values in BF16-safe ranges
- Prevents softmax collapse
TPU perspective:
Attention looks like a carefully staged sequence of systolic matrix multiplies with a small nonlinear interruption.
That’s almost exactly what TPUs were designed for.
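A small numeric illustration of point (c): without the 1/√d_k scaling, dot products of high-dimensional vectors produce logits large enough to saturate the softmax into a near one-hot distribution, which is the collapse mentioned above; with scaling, the logits stay in a modest, low-precision-friendly range (sizes and seed are arbitrary).

```python
# Illustrative only: unscaled vs. scaled logits going into a softmax.
import numpy as np

rng = np.random.default_rng(0)
d_k = 512
q = rng.standard_normal(d_k)
K = rng.standard_normal((8, d_k))

logits = K @ q                     # std grows like sqrt(d_k) ~ 22: softmax saturates
scaled = logits / np.sqrt(d_k)     # std ~ 1: softmax stays well spread

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

print(softmax(logits).round(3))    # typically near one-hot ("collapse")
print(softmax(scaled).round(3))    # smooth distribution over the 8 keys
```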
3. Why Transformers Beat RNNs on Hardware (Not Just Modeling)
This is where the story really crystallizes.
RNNs on hardware
- Sequential dependency
- Time step t depends on time step t-1
- Poor parallelism
- Underutilized compute units
Even with cuDNN optimizations:
- SMs sit idle
- Memory latency can't be hidden effectively
Transformers on hardware
- Entire sequence processed at once
- No recurrence
- Massive parallelism across tokens
- Dominated by GEMMs
Hardware consequence:
GPUs and TPUs can finally operate at full occupancy.
This is not a small advantage; in practice it often translates into orders of magnitude more throughput from the same silicon.
That’s why:
- Transformers scale with hardware
- RNNs plateau early
- "Attention is all you need" turned out to be literally true for accelerators
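A minimal sketch of the structural difference (the tanh recurrence and all shapes are assumptions for illustration): the RNN update cannot escape a step-by-step loop because h_t needs h_{t-1}, while attention touches the whole sequence with a few large matmuls that map directly onto GEMM hardware.

```python
# Illustrative only: sequential RNN steps vs. whole-sequence attention.
import numpy as np

T, d = 512, 256
x = np.random.randn(T, d)
W_h = np.random.randn(d, d) * 0.01
W_x = np.random.randn(d, d) * 0.01

# RNN: T dependent steps -- each one a small matmul that must wait for the last
h = np.zeros(d)
for t in range(T):
    h = np.tanh(h @ W_h + x[t] @ W_x)

# Attention: the entire sequence in a few large GEMMs plus a row-wise softmax
Q = K = V = x
S = Q @ K.T / np.sqrt(d)
P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)
O = P @ V
```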
4. Attention Kernel Pseudocode (Hardware-Oriented)
Here’s a hardware-centric sketch — not Python, not PyTorch, but how a kernel “thinks”:
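One way to make that sketch concrete and runnable is plain NumPy; the tile sizes, names, and single-head layout below are assumptions, and what matters is the loop structure: a tile of Q stays resident while tiles of K and V stream past it, running softmax statistics are carried along, and each output tile is written exactly once.

```python
# Illustrative only: a tiled, streaming forward pass with an online softmax.
import numpy as np

def tiled_attention_forward(Q, K, V, tile_q=64, tile_k=64):
    """softmax(Q Kᵀ / sqrt(d_k)) V computed tile by tile, single head."""
    n_q, d_k = Q.shape
    n_k, d_v = V.shape
    scale = 1.0 / np.sqrt(d_k)
    O = np.empty((n_q, d_v))

    for qs in range(0, n_q, tile_q):               # Step 1: a tile of Q stays resident
        q = Q[qs:qs + tile_q]
        m = np.full(q.shape[0], -np.inf)           # running row maxima
        l = np.zeros(q.shape[0])                   # running softmax denominators
        acc = np.zeros((q.shape[0], d_v))          # unnormalized output accumulator

        for ks in range(0, n_k, tile_k):           # stream tiles of K and V past it
            s = (q @ K[ks:ks + tile_k].T) * scale  # Step 2: one small GEMM
            m_new = np.maximum(m, s.max(axis=1))   # Step 3: fused, stable softmax
            p = np.exp(s - m_new[:, None])
            corr = np.exp(m - m_new)               # rescale the old running statistics
            l = l * corr + p.sum(axis=1)
            acc = acc * corr[:, None] + p @ V[ks:ks + tile_k]  # Step 4: weight V immediately
            m = m_new

        O[qs:qs + tile_q] = acc / l[:, None]       # one write per output tile
    return O

# quick check against the unfused formula
rng = np.random.default_rng(0)
Q = rng.standard_normal((256, 64)); K = rng.standard_normal((256, 64)); V = rng.standard_normal((256, 64))
S = Q @ K.T / np.sqrt(64)
P = np.exp(S - S.max(axis=1, keepdims=True)); P /= P.sum(axis=1, keepdims=True)
assert np.allclose(tiled_attention_forward(Q, K, V), P @ V)
```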
Backward w.r.t. V:
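And the matching sketch for this gradient (NumPy again standing in for the kernel; dO and dV are my names for the upstream and resulting gradients): given the attention weights P, it is a single transposed GEMM accumulated one row block at a time.

```python
# Illustrative only: dV = Pᵀ @ dO, tiled and accumulated.
import numpy as np

def tiled_grad_v(P, dO, tile=64):
    n_q, n_k = P.shape
    dV = np.zeros((n_k, dO.shape[1]))
    for qs in range(0, n_q, tile):                   # stream row blocks of P and dO
        dV += P[qs:qs + tile].T @ dO[qs:qs + tile]   # pure GEMM accumulation, nothing else
    return dV
```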
Notice:
- Everything is tiled
- Everything streams
- Memory writes are minimized
- Gradients w.r.t. V are clean and linear
This pseudocode is the hardware story.
The Unifying Insight
Scaled dot-product attention is not just a successful abstraction — it is an unusually good match to the physical realities of modern accelerators. Its dominance of matrix multiplication, its numerical scaling, and even the simplicity of its gradients align with how GPUs and TPUs actually execute code. Transformers did not merely benefit from better hardware; they emerged as the first architecture to fully exploit it.