Transpose
Naive transpose is deceptively bad.
import triton
import triton.language as tl

@triton.jit
def matrix_transpose_kernel(
    input, output, rows, cols, stride_ir, stride_ic, stride_or, stride_oc
):
    # One program per element: program (row_pid, col_pid) copies input[r, c] to output[c, r].
    row_pid = tl.program_id(0)
    col_pid = tl.program_id(1)
    src_offset = input + row_pid * stride_ir + col_pid * stride_ic
    tar_offset = output + col_pid * stride_or + row_pid * stride_oc
    val = tl.load(src_offset)
    tl.store(tar_offset, val)
For row-major input:
- loads are contiguous across lanes -> coalesced
- stores are strided by rows -> not coalesced
That means a warp can read with ~1 transaction but may need up to ~32 transactions to write, one per lane. A transpose does no arithmetic, so its speed is set entirely by memory traffic, and here that traffic is dominated by wasted transactions and poor store efficiency. (CUDA Best Practices Guide: Coalesced Access)
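To make the 1-versus-32 figure concrete, the sketch below counts how many memory segments one warp touches under each access pattern. It is a plain-Python model rather than GPU code, and it assumes 32 lanes, 4-byte elements, 128-byte segments, and a hypothetical `rows = 1024`; `transactions` is an illustrative helper, not a library function.

```python
def transactions(addresses, segment_bytes=128):
    """Count the distinct memory segments touched by one warp's byte addresses."""
    return len({addr // segment_bytes for addr in addresses})

WARP = 32    # lanes per warp (assumption)
ELEM = 4     # bytes per float32 element (assumption)
rows = 1024  # hypothetical matrix height; the output row stride in elements

# Loads: adjacent lanes read adjacent elements of one input row.
load_addrs = [lane * ELEM for lane in range(WARP)]

# Stores: adjacent lanes write elements that are `rows` elements apart.
store_addrs = [lane * rows * ELEM for lane in range(WARP)]

load_tx = transactions(load_addrs)    # 1: all 32 lanes share one segment
store_tx = transactions(store_addrs)  # 32: every lane lands in its own segment
print(load_tx, store_tx)
```

Under these assumptions the store side moves 32 segments to deliver the same 128 bytes of useful data that the load side gets from one.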
The tiled kernel fixes this by moving work from scalar addresses to blocked addresses:
@triton.jit
def matrix_transpose_kernel(
    input,
    output,
    rows,
    cols,
    stride_ir,
    stride_ic,
    stride_or,
    stride_oc,
    ROW_BLOCK: tl.constexpr,
    COL_BLOCK: tl.constexpr,
):
    row_pid = tl.program_id(0)
    col_pid = tl.program_id(1)
    # get the tile offsets
    row_offsets = row_pid * ROW_BLOCK + tl.arange(0, ROW_BLOCK)
    col_offsets = col_pid * COL_BLOCK + tl.arange(0, COL_BLOCK)
    input_offsets = (
        input + row_offsets[:, None] * stride_ir + col_offsets[None, :] * stride_ic
    )
    in_mask = (row_offsets[:, None] < rows) & (col_offsets[None, :] < cols)
    val = tl.load(input_offsets, mask=in_mask, other=0.0)
    output_offsets = (
        output + col_offsets[:, None] * stride_or + row_offsets[None, :] * stride_oc
    )
    out_mask = tl.trans(in_mask)
    tl.store(output_offsets, tl.trans(val), mask=out_mask)
Key idea:
- load a [ROW_BLOCK, COL_BLOCK] tile from input
- transpose inside the tile
- store it so output lanes write contiguous rows
Now both sides are friendly to the memory system: input loads and output stores are both coalesced, so the compiler can lower this into vectorized global ops and a shared-memory transpose path. Every cache line is better utilized.
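The load, transpose, store sequence can be modeled on the CPU to check the indexing and the edge handling. This is a minimal plain-Python sketch, not the Triton kernel: `tiled_transpose` is a hypothetical helper, the matrix is a flat row-major list, and clamping the tile bounds with `min` stands in for the kernel's masks.

```python
def tiled_transpose(src, rows, cols, row_block=4, col_block=4):
    """Transpose a row-major rows x cols matrix stored as a flat list, tile by tile."""
    dst = [None] * (rows * cols)
    for r0 in range(0, rows, row_block):
        for c0 in range(0, cols, col_block):
            r1 = min(r0 + row_block, rows)  # clamp at the edges (plays the role of the mask)
            c1 = min(c0 + col_block, cols)
            # Load: each tile row is a contiguous run of the input.
            tile = [[src[r * cols + c] for c in range(c0, c1)] for r in range(r0, r1)]
            # Transpose inside the tile.
            tile_t = [list(column) for column in zip(*tile)]
            # Store: each transposed row is a contiguous run of the output.
            for i, out_row in enumerate(tile_t):
                base = (c0 + i) * rows + r0
                dst[base:base + len(out_row)] = out_row
    return dst

rows, cols = 5, 7  # deliberately not multiples of the tile size
src = list(range(rows * cols))
result = tiled_transpose(src, rows, cols)
expected = [src[r * cols + c] for c in range(cols) for r in range(rows)]
print(result == expected)
```

The slice assignment in the store step is the point of the exercise: every write covers a contiguous range of the output, which is the pattern the tiled kernel exposes to the hardware.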
This feels pretty similar to the common ORM dataloader pattern, where data with similar locality is batched and queried together for better overall throughput.
Reduction
data = tl.load(input + offsets, mask=mask, other=0.0)
sum = tl.sum(data)
tl.atomic_add(output, sum, sem="relaxed")
Each program loads one contiguous chunk, reduces it locally on-chip with tl.sum, and emits one global atomic to merge its partial sum into output. Atomicity is required because multiple programs update the same output. Note that the kernel issues one atomic per program, not per element, so it’s still efficient.
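The one-atomic-per-program structure can be sketched on the CPU with threads standing in for programs and a lock standing in for the hardware atomic. This is an illustration under those assumptions, not the kernel itself; `block_reduce` and its counters are hypothetical names.

```python
import threading

def block_reduce(data, block_size):
    """Sum `data` with one locked update per block, mimicking one atomic per program."""
    total = 0
    atomic_count = 0
    lock = threading.Lock()

    def program(pid):
        nonlocal total, atomic_count
        start = pid * block_size
        # Local reduction; the slice truncates at the end of `data`, like the mask.
        partial = sum(data[start:start + block_size])
        with lock:  # stand-in for tl.atomic_add
            total += partial
            atomic_count += 1

    num_programs = (len(data) + block_size - 1) // block_size
    threads = [threading.Thread(target=program, args=(pid,)) for pid in range(num_programs)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return total, atomic_count

total, atomic_count = block_reduce(list(range(1000)), block_size=128)
print(total, atomic_count)  # 499500 8
```

With 1000 elements and a block size of 128, only 8 updates contend for the shared output instead of 1000.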
Memory semantics for atomic_add
sem controls ordering, not atomicity itself.
| Semantic | Ordering Guarantee | Use Cases |
|---|---|---|
| relaxed | atomic, but no extra ordering guarantees | counters, reductions, statistics |
| acquire | later memory ops cannot move before it | read a flag, then consume guarded data |
| release | earlier memory ops cannot move after it | publish data, then set a flag |
| acq_rel | both directions; strongest and usually slowest | locks, queues, synchronization primitives |
Stronger semantics can force additional ordering/fence behavior, reducing scheduling and memory-system freedom.
Under the hood: how atomicity is implemented
Atomicity is pretty interesting here because it seems to run against the intuition that kernel programs execute fully in parallel.
Does atomic_add require cross-warp communication?
No. Threads in different warps (or different SMs) never coordinate directly with each other. The serialization point is the L2 cache controller, not the warp scheduler. There is no barrier, no shared-memory exchange, and no warp shuffle involved. (PTX ISA — atom instruction)
So where does the buffering happen?
When multiple SMs issue atom.global.add to the same address simultaneously, the requests land in the L2 cache’s atomic queue — an internal FIFO per L2 partition. The controller:
- Receives incoming atomic requests into the buffer.
- Locks the cache line.
- Applies each read-modify-write one at a time, in (nondeterministic) arrival order.
- Releases the cache line and moves to the next queued request.
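The queue behaviour above can be modeled as applying a list of pending deltas one at a time in some arrival order. The sketch below is a plain-Python model, not a description of real hardware internals; `apply_in_arrival_order` and the request lists are hypothetical. It shows one practical consequence of nondeterministic ordering: integer sums are unaffected, while floating-point sums can differ because floating-point addition is not associative.

```python
import itertools

def apply_in_arrival_order(requests, initial=0.0):
    """Apply each queued read-modify-write to a single value, one at a time."""
    value = initial
    for delta in requests:
        value = value + delta  # read, modify, write back
    return value

# Integers: every arrival order gives the same final value.
int_requests = [3, 5, 7, 11]
int_results = {
    apply_in_arrival_order(order, initial=0)
    for order in itertools.permutations(int_requests)
}

# Floats: rounding depends on the order in which the deltas are applied.
float_requests = [1e16, 1.0, -1e16]
float_results = {
    apply_in_arrival_order(order)
    for order in itertools.permutations(float_requests)
}

print(int_results)            # {26}
print(sorted(float_results))  # [0.0, 1.0]
```

For a reduction this means a float atomic accumulation is correct up to rounding, but not necessarily bitwise reproducible from run to run.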
On Ampere/Hopper there is also a reduction network at the memory controller level that can coalesce some atomics before they even reach L2, further hiding latency.
Does issuing an atomic stall the thread? What about the rest of the warp?
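It can stall the issuing warp until the RMW is acknowledged, and because the threads of a warp issue instructions together, the stall applies to the whole warp rather than to a single thread. The SM is not blocked, though: the scheduler switches to other ready warps during the stall, so the latency is amortized as long as enough warps are resident.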
What does relaxed actually skip at the hardware level?
relaxed tells the compiler and hardware not to insert any memory fence around the atomic. Stronger semantics (acquire, release, acq_rel) force the hardware to flush or drain store buffers before or after the atomic, synchronizing the broader memory subsystem — that is the extra cost they pay. For a pure accumulation like this reduction, none of that flushing is necessary, so relaxed is the right choice.