Inside NVIDIA GPUs: Anatomy of high performance matmul kernels
In GPU programming, a kernel is a function that runs in parallel on many GPU threads.
More precisely:
- A kernel is launched from the CPU
- It runs on the GPU
- Thousands to millions of lightweight threads execute the same code on different data
- Each thread has a unique index (thread/block IDs)
Example intuition:
- CPU function → runs once
- GPU kernel → runs many times in parallel, once per element / tile / token (see the minimal sketch below)
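To make that concrete, here is a minimal sketch of a kernel, assuming single-precision data already on the device; the names (`scale_kernel`, `alpha`, `n`) are illustrative, not from any library:

```cuda
#include <cuda_runtime.h>

// Minimal sketch: each thread scales one element of the array.
// Names (scale_kernel, alpha, n) are illustrative, not from any library.
__global__ void scale_kernel(float* data, float alpha, int n) {
    // Unique global index built from block and thread IDs.
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {                 // guard: the grid may overshoot n
        data[i] *= alpha;        // same code, different data per thread
    }
}

int main() {
    const int n = 1 << 20;
    float* d_data;
    cudaMalloc(&d_data, n * sizeof(float));

    // Launched from the CPU, executed by many GPU threads in parallel.
    int threads = 256;
    int blocks  = (n + threads - 1) / threads;
    scale_kernel<<<blocks, threads>>>(d_data, 2.0f, n);
    cudaDeviceSynchronize();

    cudaFree(d_data);
    return 0;
}
```

Each of the `blocks × threads` threads computes one index and handles one element; that is the whole programming model in miniature.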
In deep learning:
- Every major operation (matmul, attention, convolution, softmax) is implemented as one or more GPU kernels.
Why matmul kernels matter so much
Matrix multiplication (GEMM) is the canonical GPU kernel because it exercises every hard part of GPU programming:
- Tiling and blocking (see the sketch after this list)
- Shared memory usage
- Register pressure
- Memory coalescing
- Instruction-level parallelism
- Synchronization
- Numerical precision tradeoffs
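Here is a minimal sketch of a shared-memory-tiled matmul in that spirit, assuming row-major layout and dimensions that are multiples of the tile size; it is a teaching sketch, nowhere near cuBLAS performance:

```cuda
#define TILE 16

// Sketch of a tiled GEMM: C = A * B, all row-major; M, N, K are assumed
// to be multiples of TILE. Illustrative only.
__global__ void matmul_tiled(const float* A, const float* B, float* C,
                             int M, int N, int K) {
    __shared__ float As[TILE][TILE];   // tile of A staged in shared memory
    __shared__ float Bs[TILE][TILE];   // tile of B staged in shared memory

    int row = blockIdx.y * TILE + threadIdx.y;
    int col = blockIdx.x * TILE + threadIdx.x;
    float acc = 0.0f;                  // accumulator lives in a register

    for (int t = 0; t < K / TILE; ++t) {
        // Coalesced loads: adjacent threads read adjacent addresses.
        As[threadIdx.y][threadIdx.x] = A[row * K + t * TILE + threadIdx.x];
        Bs[threadIdx.y][threadIdx.x] = B[(t * TILE + threadIdx.y) * N + col];
        __syncthreads();               // tile fully loaded before use

        for (int k = 0; k < TILE; ++k)
            acc += As[threadIdx.y][k] * Bs[k][threadIdx.x];
        __syncthreads();               // done with this tile before overwriting
    }
    C[row * N + col] = acc;
}
```

Notice how many of the concepts above appear in ~25 lines: tiles sized to fit shared memory, coalesced loads, a register accumulator, and two `__syncthreads()` calls for correctness.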
If you understand how high-performance matmul works, you understand the GPU.
That’s why the following claim holds:
Understanding matmul kernels gives you the toolkit to design nearly any other high-performance GPU kernel.
What “other kernels” does this refer to?
Here are the major classes of GPU kernels that reuse matmul ideas.
1. Convolution kernels (CNNs)
Convolutions can be:
- lowered to matrix multiplication (im2col; see the sketch below)
- or implemented as tiled stencil operations
Either way, they rely on:
- blocking
- shared memory reuse
- careful memory layout
Mental model:
Convolution = structured matmul with spatial reuse.
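A rough sketch of the im2col lowering (single image, stride 1, no padding, CHW layout assumed; names are illustrative): each thread writes one element of the unrolled matrix, after which the convolution is a plain matmul against the flattened filter weights.

```cuda
// Sketch of im2col for one image, stride 1, no padding, CHW layout.
// Afterwards the convolution is a plain matmul:
//   weights (OC x C*KH*KW)  x  cols (C*KH*KW x OH*OW)  =  output (OC x OH*OW)
__global__ void im2col_kernel(const float* img, float* cols,
                              int C, int H, int W, int KH, int KW) {
    int OH = H - KH + 1, OW = W - KW + 1;
    int total = C * KH * KW * OH * OW;
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= total) return;

    // Decode which (channel, kernel offset, output position) this thread owns.
    int ow = idx % OW;
    int oh = (idx / OW) % OH;
    int kw = (idx / (OW * OH)) % KW;
    int kh = (idx / (OW * OH * KW)) % KH;
    int c  =  idx / (OW * OH * KW * KH);

    // Row = (c, kh, kw), column = (oh, ow) of the unrolled matrix.
    int row = (c * KH + kh) * KW + kw;
    int col = oh * OW + ow;
    cols[row * (OH * OW) + col] = img[(c * H + oh + kh) * W + (ow + kw)];
}
```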
2. Attention kernels (Transformers)
Modern attention kernels (FlashAttention, etc.) are:
- tiled matrix multiplications
- with softmax and normalization fused in
They rely on:
- tiling QKᵀ
- minimizing memory traffic
- keeping intermediate values on-chip (see the online-softmax sketch below)
Mental model:
Attention = matmul + reductions + numerically stable softmax.
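The piece that makes the fusion work is the online (streaming) softmax update, which processes scores tile by tile without ever materializing the full attention matrix. Below is a deliberately simplified per-row sketch, one thread per query row and serial over keys; real kernels tile K and V through shared memory and parallelize within a block. All names are illustrative, and a head dimension of at most 128 is assumed.

```cuda
// Minimal sketch of the online-softmax accumulation used by tiled attention
// kernels: one thread owns one query row and streams over keys, keeping a
// running max (m), running denominator (l), and a rescaled accumulator (acc).
// Illustrative only; assumes scores = scaled q·kᵀ are precomputed and d <= 128.
__global__ void attention_row_sketch(const float* scores,  // [rows x n_keys]
                                     const float* V,       // [n_keys x d]
                                     float* out,           // [rows x d]
                                     int rows, int n_keys, int d) {
    int r = blockIdx.x * blockDim.x + threadIdx.x;
    if (r >= rows || d > 128) return;

    float m = -3.4e38f;                  // effectively -infinity
    float l = 0.0f;
    float acc[128];                      // per-row output accumulator
    for (int j = 0; j < d; ++j) acc[j] = 0.0f;

    for (int k = 0; k < n_keys; ++k) {
        float s = scores[r * n_keys + k];
        float m_new = fmaxf(m, s);
        float scale = expf(m - m_new);   // rescale the old sum/accumulator
        float p = expf(s - m_new);       // weight of the new key
        l = l * scale + p;
        for (int j = 0; j < d; ++j)
            acc[j] = acc[j] * scale + p * V[k * d + j];
        m = m_new;
    }
    for (int j = 0; j < d; ++j)
        out[r * d + j] = acc[j] / l;     // normalize once at the end
}
```

The rescaling by `expf(m - m_new)` is what keeps the partial sums numerically stable as new tiles of keys arrive.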
3. Reduction kernels (sum, max, layer norm)
Reductions look simpler, but:
- require careful synchronization
- use warp-level primitives (see the warp-shuffle sketch below)
- must minimize memory divergence
Matmul teaches:
- hierarchical reduction patterns
- warp/block coordination
Mental model:
Reduction = matmul without the multiply.
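A minimal sketch of a hierarchical sum reduction, assuming a block size of 256: warp shuffles first, then shared memory across warps, then one atomic per block. Names are illustrative, and `out` must be zeroed before launch.

```cuda
// Sketch of a hierarchical sum: warp-level shuffle, then shared memory
// across warps, then one atomic per block. Block size of 256 assumed;
// out must be zero-initialized before the launch.
__global__ void reduce_sum(const float* in, float* out, int n) {
    __shared__ float warp_sums[8];                 // 256 / 32 warps per block
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    float v = (i < n) ? in[i] : 0.0f;

    // Warp-level reduction: no __syncthreads needed within a warp.
    for (int offset = 16; offset > 0; offset >>= 1)
        v += __shfl_down_sync(0xffffffff, v, offset);

    int lane = threadIdx.x % 32, warp = threadIdx.x / 32;
    if (lane == 0) warp_sums[warp] = v;            // one partial per warp
    __syncthreads();

    // First warp reduces the per-warp partials.
    if (warp == 0) {
        v = (lane < 8) ? warp_sums[lane] : 0.0f;
        for (int offset = 4; offset > 0; offset >>= 1)
            v += __shfl_down_sync(0xffffffff, v, offset);
        if (lane == 0) atomicAdd(out, v);          // block-level combine
    }
}
```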
4. Elementwise & fused kernels
Things like:
- activation functions (GELU, ReLU)
- bias add
- dropout
- residual connections
High performance comes from:
- fusing many small ops (see the fused bias + GELU sketch below)
- minimizing memory reads/writes
Matmul teaches:
- why memory bandwidth dominates
- why fusion matters
Mental model:
Elementwise kernels are memory-bound matmuls.
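A sketch of what fusion buys: a fused bias-add + GELU (tanh approximation) reads each element once and writes it once, instead of making a round trip to global memory between the two ops. The names and layout (row-major, bias broadcast along rows) are illustrative.

```cuda
// Sketch of a fused elementwise kernel: bias add + GELU (tanh approximation)
// in one pass, so each element is read once and written once.
__global__ void bias_gelu_fused(const float* x, const float* bias,
                                float* y, int rows, int cols) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= rows * cols) return;

    float v = x[i] + bias[i % cols];               // bias broadcast along rows
    float c = 0.7978845608f;                       // sqrt(2/pi)
    float t = tanhf(c * (v + 0.044715f * v * v * v));
    y[i] = 0.5f * v * (1.0f + t);                  // GELU, tanh form
}
```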
5. Embedding & lookup kernels
Used in:
- token embeddings
- sparse updates
- recommender systems
Performance depends on:
- memory coalescing
- cache behavior
- avoiding random access penalties
Matmul intuition helps you:
- think in terms of contiguous tiles
- restructure access patterns (see the lookup sketch below)
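A minimal sketch of a coalesced embedding gather, one block per token: the row index is random, but the threads of a block read that row contiguously, so each warp's accesses coalesce. Names are illustrative.

```cuda
// Sketch of an embedding gather: one block per token, threads stride across
// the embedding dimension so each warp reads a contiguous (coalesced) chunk
// of the table row, even though the row indices themselves are random.
__global__ void embedding_lookup(const float* table,   // [vocab x dim]
                                 const int* token_ids, // [n_tokens]
                                 float* out,           // [n_tokens x dim]
                                 int dim) {
    int token = blockIdx.x;
    int row = token_ids[token];
    for (int j = threadIdx.x; j < dim; j += blockDim.x)
        out[token * dim + j] = table[row * dim + j];
}
```

Launched as `embedding_lookup<<<n_tokens, 256>>>(table, token_ids, out, dim);`.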
6. Communication kernels (all-reduce, all-gather)
In multi-GPU training:
- bandwidth and overlap dominate
- computation must hide communication latency
Matmul teaches:
- overlapping compute and memory
- pipelining work across blocks (see the stream-overlap sketch below)
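The overlap pattern can be sketched on a single GPU with CUDA streams: alternate chunks between two streams so one chunk's transfer runs while the previous chunk computes. Real all-reduce/all-gather implementations (e.g. NCCL) pipeline across GPUs, but the idea is the same. The chunking, stream count, and reuse of the earlier `scale_kernel` are illustrative, and the host buffer is assumed to be pinned so the copies are truly asynchronous.

```cuda
#include <cuda_runtime.h>

// Declared here; any compute kernel works, e.g. the scale_kernel sketch above.
__global__ void scale_kernel(float* data, float alpha, int n);

// Sketch of compute/transfer overlap: even chunks go on stream 0, odd chunks
// on stream 1, so the copy for chunk i+1 runs while chunk i computes.
// Assumes `host` is pinned memory (cudaMallocHost) for true async copies.
void pipelined_process(const float* host, float* dev, int n, int chunk) {
    cudaStream_t s[2];
    cudaStreamCreate(&s[0]);
    cudaStreamCreate(&s[1]);

    int c = 0;
    for (int off = 0; off < n; off += chunk, ++c) {
        int len = (n - off < chunk) ? (n - off) : chunk;
        cudaStream_t st = s[c % 2];
        // Copy then compute in the same stream: ordering within a stream is
        // guaranteed, while the other stream's work runs concurrently.
        cudaMemcpyAsync(dev + off, host + off, len * sizeof(float),
                        cudaMemcpyHostToDevice, st);
        scale_kernel<<<(len + 255) / 256, 256, 0, st>>>(dev + off, 2.0f, len);
    }
    cudaDeviceSynchronize();
    cudaStreamDestroy(s[0]);
    cudaStreamDestroy(s[1]);
}
```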
Why matmul is the “grammar” of GPU programming
Matmul kernels force you to learn:
| Concept | Why it matters |
|---|---|
| Tiling | Fits data into fast memory |
| Shared memory | Reuse data across threads |
| Registers | Avoid spilling |
| Memory coalescing | Maximize bandwidth |
| Synchronization | Correctness + performance |
| Precision formats | FP16, BF16, FP8 tradeoffs |
Once you internalize these, every other kernel is a variation.
Big picture (very Shape of Scale)
Transformers, RL, sampling, and scaling all depend on:
- executing massive numbers of matmuls
- keeping hardware saturated
- minimizing memory movement
That’s why:
- hardware is designed around matmul
- software stacks optimize matmul first
- learning matmul unlocks everything else
One-sentence takeaway
A kernel is a massively parallel GPU function, and matmul kernels are the canonical example that teaches every core idea needed to build high-performance GPU code.