Wrappers
For ops whose operand semantics break the chain default, PTX.jl ships hand-written wrapper methods on the matching Operation{...} singleton. The user-facing call site is identical to the chain default — the wrapper just provides a typed method that takes priority over the @generated chain dispatch when its argument types match.
There are four patterns that force a wrapper:
- Mixed address-space pointer constraints. Shared-AS pointers need the r (32-bit) constraint; the chain emits l (64-bit) for any LLVMPtr. cp.async, ldmatrix, stmatrix, mbarrier use this override to satisfy ptxas.
- Multi-output return. The chain returns one value; ldmatrix x2/x4, mma fragments, and shfl-with-pred-output all return tuples.
- Special operand layout. Braced register-vector operands ({$N, $N+1, ...}), tensor-coord forms ([ptr, {c0, c1, c2}]), and tied accumulators for wgmma have no chain-default rendering.
- Compile-time-constant operand types. TMA's <N>d rank, mma shape/dtype, and fragment counts need to be pinned at the registration boundary so each variant gets its own typed method.
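The first pattern can be illustrated with a minimal sketch. The helper name constraint_letter is hypothetical and not part of PTX.jl; the only outside fact assumed is that NVPTX numbers the shared address space as 3 and the global address space as 1.

```julia
# Hypothetical helper, not a PTX.jl internal: choose an inline-asm constraint
# letter from the address space encoded in the pointer type.
const AS_SHARED = 3   # NVPTX address-space number for shared memory

constraint_letter(::Type{Core.LLVMPtr{T, AS}}) where {T, AS} =
    AS == AS_SHARED ? "r" : "l"

constraint_letter(Core.LLVMPtr{Float32, AS_SHARED})  # "r": 32-bit register
constraint_letter(Core.LLVMPtr{Float32, 1})          # "l": 64-bit register
```

Keying on the type parameter means the choice is made at dispatch time, which is the same reason the wrappers are typed methods rather than runtime branches.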
Each wrapper file follows the same shape: a small declarative table maps (shape, dtype, …) to register counts and constraints, a _register function builds the asm string + LLVM call, and a for loop emits a typed method per valid combination.
Adding a new dtype/shape combination is one entry in the table plus one line in the loop.
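The table / builder / loop shape described above might look roughly like the following. This is a sketch under assumptions: LDMATRIX_REGS, build_asm, and asm_for are illustrative names, not PTX.jl internals, and the generated methods return the asm and constraint strings instead of issuing an LLVM call so the example runs without a GPU.

```julia
# 1. Declarative table: variant => number of 32-bit output registers.
const LDMATRIX_REGS = Dict(:x1 => 1, :x2 => 2, :x4 => 4)

# 2. Builder: a braced output vector {$0, $1, ...}, then the bracketed
#    address operand, plus the matching constraint string.
function build_asm(variant::Symbol)
    n = LDMATRIX_REGS[variant]
    outs = join(("\$$i" for i in 0:n-1), ", ")
    asm = "ldmatrix.sync.aligned.m8n8.$(variant).shared.b16 {$outs}, [\$$n];"
    constraints = join(vcat(fill("=r", n), "r"), ",")
    return asm, constraints
end

# 3. Loop: emit one typed method per table entry, keyed on Val{variant}.
for variant in keys(LDMATRIX_REGS)
    asm, constraints = build_asm(variant)
    @eval asm_for(::Val{$(QuoteNode(variant))}) = ($asm, $constraints)
end

asm_for(Val(:x2))
# ("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {\$0, \$1}, [\$2];", "=r,=r,r")
```

In this sketch a new variant is one more entry in LDMATRIX_REGS; the loop picks it up without further changes.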
Family overview
| Family | File | Surface |
|---|---|---|
| cp.async (scalar) | wrappers/cp_async.jl | cp.async.{ca,cg}.shared.global [smem], [global], Val(N); needs shared r constraint and N-baked size |
| cp.async.bulk.tensor (TMA) | wrappers/tma.jl | cp.async.bulk.tensor.{1..5}d.{shared::cluster\|shared::cta}.global.tile.mbarrier::complete_tx::bytes (load) and .global.shared::cta.tile.bulk_group (store); coord vector becomes a positional Int32 argument list |
| cvt sub-byte FP packing | wrappers/cvt.jl | cvt.rn.satfinite.e2m1x2.{f32,f16x2,bf16x2} pack and cvt.rn.{f16x2,bf16x2}.e2m1x2 unpack; .b8 carrier through a mov.b16 brace-pair shim because NVPTX has no i8 constraint |
| ldmatrix | wrappers/ldmatrix.jl | ldmatrix.sync.aligned.{m8n8,m16n16}.{x1,x2,x4}[.trans].{shared,shared::cta}.{b16,b8} returning UInt32 (x1) or NTuple{N, UInt32} (x2/x4) |
| mbarrier | wrappers/mbarrier.jl | init / inval / arrive[.noComplete\|.expect_tx] / expect_tx / test_wait[.parity] / try_wait[.parity]; three return shapes (Nothing / UInt64 state / Bool pred); shared-AS r constraint |
| mma.sync.aligned | wrappers/mma.jl | mma.sync.aligned.<shape>.<layA>.<layB>.<d>.<a>.<b>.<c> for bf16/f16/tf32/FP8 (Ada) and kind::f8f6f4 5×5 sub-byte FP A/B (Blackwell sm_100a+); takes/returns NTuple{N, UInt32} for A/B and NTuple{M, Float32\|UInt32} for C/D |
| mma.sync.aligned.kind::mxf* (block-scaled) | wrappers/mma_scaled.jl | Three Blackwell-introduced kinds: mxf4, mxf4nvf4, mxf8f6f4. Operand layout (scale_data::UInt32, byte_id::UInt16, thread_id::UInt16) per side per PTX 9.2 §9.7.14.3 |
| setp.dual | wrappers/setp.jl | setp.<cmp>.<dtype> with %p\|%q dual-pred output; 6 cmps × 12 dtypes = 72 generated methods returning Tuple{Bool, Bool} |
| shfl.sync | wrappers/shfl.jl | up / down / bfly / idx × b32 × {data-only, data+pred} |
| stmatrix | wrappers/stmatrix.jl | mirror of ldmatrix: m8n8.b16 (sm_70+) and m16n8.b8 (Hopper) |
| vec_ldst | wrappers/vec_ldst.jl | ld.global.v{2,4}.{f32,b32,b16} / st.global.v{2,4}.{f32,b32,b16}; braced register-vector I/O for HBM-saturating bandwidth |
| wgmma.mma_async (Hopper sm_90a) | wrappers/wgmma.jl | wgmma.mma_async.sync.aligned.m64nNk{8,16,32}.<d>.<a>.<b>; accumulator passed by value (tied operands), N stepped by 8 from 8 to 256, 12 dtype tuples × 32 N-values = 384 methods |
| tcgen05 (Blackwell sm_100a/sm_110a) | wrappers/tcgen05.jl | shift / dealloc / cp / ld / st whose taddr operand is a 32-bit TMEM address (returned by tcgen05.alloc), NOT a memory pointer; chain default brackets LLVMPtr but treats UInt32 as a plain scalar |
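
The method counts quoted for wgmma and setp.dual follow directly from the table's own numbers. A quick check, using only those figures (the variable names are just for illustration):

```julia
# N runs from 8 to 256 in steps of 8.
wgmma_n_values = 8:8:256
n_count = length(wgmma_n_values)   # 32

wgmma_methods = 12 * n_count       # 12 dtype tuples × 32 N-values = 384
setp_methods  = 6 * 12             # 6 cmps × 12 dtypes = 72
```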
Sync ops (wgmma.fence, wgmma.commit_group, wgmma.wait_group, tcgen05.alloc, tcgen05.commit, …) flow through the chain default — their opcode prefix is in NONPURE_OPCODES, so they get ~{memory} and side_effects = true automatically.
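As a rough picture of that prefix check, here is a sketch under assumptions: the tuple below is an illustrative subset taken from the ops named above, and the real NONPURE_OPCODES table in PTX.jl may be keyed differently. The function names are hypothetical.

```julia
# Illustrative subset only; not the actual NONPURE_OPCODES contents.
const NONPURE_PREFIXES_SKETCH = ("wgmma.fence", "wgmma.commit_group",
                                 "wgmma.wait_group", "tcgen05.alloc",
                                 "tcgen05.commit")

is_nonpure(opcode::AbstractString) =
    any(prefix -> startswith(opcode, prefix), NONPURE_PREFIXES_SKETCH)

# Non-pure ops get a memory clobber and are marked side-effecting.
asm_flags(opcode) = is_nonpure(opcode) ?
    (clobber = "~{memory}", side_effects = true) :
    (clobber = "",          side_effects = false)

asm_flags("wgmma.fence.sync.aligned")  # (clobber = "~{memory}", side_effects = true)
```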
Host-side descriptor builders
Two opcodes consume packed 64-bit shared-memory descriptors plus (Blackwell only) a 32-bit instruction descriptor. PTX.jl ships pure bit-packing helpers for both:
| Helper | Used by |
|---|---|
| wgmma_descriptor | Hopper wgmma.mma_async SMEM operand encoding (14-bit field windows + swizzle + base offset) |
| tcgen05_descriptor | Blackwell tcgen05.mma SMEM operand encoding (3-bit layout vs wgmma's 2-bit; adds version and lbo_mode) |
| tcgen05_instr_desc_f16bf16_f32 | Blackwell tcgen05.mma 32-bit instruction descriptor for the dense F16/BF16/TF32 → F32 path. Mirrors CUTLASS/CuTe's UMMA::make_instr_desc. |
| smem_addr_u32 | Convert a Core.LLVMPtr{T, AS.Shared} to its 32-bit in-CTA SMEM offset (used as the smem_addr_u32 argument to the descriptor builders). |
These are not exported but are part of the documented API. Access them as PTX.wgmma_descriptor, PTX.tcgen05_descriptor, etc.
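To make "pure bit-packing" concrete, here is a sketch of a 64-bit wgmma-style descriptor packer. It is not PTX.wgmma_descriptor: the function name, keyword names, and field positions are assumptions based on the matrix-descriptor layout in the PTX ISA (14-bit address and offset fields stored with the low 4 bits dropped, a 3-bit base offset, a 2-bit swizzle mode), and should be checked against the real helper before use.

```julia
# Assumed encoding for the 14-bit fields: keep bits 4..17 of the byte value.
encode14(x::Integer) = (UInt64(x) & 0x3FFFF) >> 4

# Hypothetical packer; the field positions are an assumption, see the lead-in.
function pack_wgmma_desc(; smem_addr::Integer, lbo::Integer, sbo::Integer,
                          base_offset::Integer = 0, swizzle::Integer = 0)
    desc  = encode14(smem_addr)                 # bits 0-13: SMEM start address
    desc |= encode14(lbo) << 16                 # bits 16-29: leading byte offset
    desc |= encode14(sbo) << 32                 # bits 32-45: stride byte offset
    desc |= (UInt64(base_offset) & 0x7) << 49   # bits 49-51: base offset
    desc |= (UInt64(swizzle) & 0x3) << 62       # bits 62-63: swizzle mode
    return desc
end

pack_wgmma_desc(smem_addr = 0x1000, lbo = 16, sbo = 1024, swizzle = 1)
# 0x4000004000010100
```

Because the packer is plain integer arithmetic, it can be unit-tested on the host without a device, which is the main appeal of keeping the descriptor builders pure.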
GMMA layout helpers
For wgmma.mma_async SMEM operands, the descriptor's leading_byte_offset / stride_byte_offset / swizzle triple is fully determined by the tile geometry (dtype, M-or-N, K, major axis). The four canonical GMMA layout families (INTERLEAVE / B32 / B64 / B128) cover all wgmma-compatible SMEM tile widths.
PTX.pick_gmma_layout — Function
pick_gmma_layout(; elem_bytes, m_or_n, k, major) -> GmmaLayout
Canonical GMMA layout for a (dtype, M-or-N, K, major) SMEM tile. major is :K or :MN. See module-top comment for the full mapping.
PTX.layout_for_a — Function
layout_for_a(; dtype, m, k) -> GmmaLayout
K-major GMMA layout: natural for row-major A (MxK with K-fast), and also the right pick for operand B when the kernel lays B as col-major KxN (equivalently row-major NxK) in SMEM with K-fast (the common case across the Hopper kernels in this repo). Use this for either operand whose K-dimension is fastest-varying in SMEM.
PTX.layout_for_mn_major — Function
layout_for_mn_major(; dtype, k, n) -> GmmaLayout
MN-major GMMA layout: operand B laid as row-major KxN in SMEM with N-fast (i.e. the N-or-M dimension is fastest-varying). Required only when the SMEM tile genuinely needs MN-fastness and a transposed wgmma trans_b=1 flag is used. Most Hopper kernels in this repo use layout_for_a for both operands and let wgmma run with trans_b=0 against a K-fast B tile; reach for layout_for_mn_major only when the data layout forces the issue.
Host-side TMA descriptor encoder
Hopper TMA (cp.async.bulk.tensor.*) consumes a 128-byte CUtensorMap blob built host-side by the CUDA driver's cuTensorMapEncodeTiled. PTX.jl wraps the driver call so descriptors take Julia types / symbols instead of raw CUtensorMapDataType enums, with a thin convenience helper for the common 2D row-major case.
These methods live in ext/PTXCUDACoreExt.jl and load automatically when CUDACore is in the environment.
PTX.tensor_map_encode_tiled — Function
tensor_map_encode_tiled(dtype, global_addr, global_dim, global_strides,
                        box_dim; kwargs...) -> CuTensorMap
Build a CUtensorMap for a tiled TMA descriptor. Calls cuTensorMapEncodeTiled.
Arguments (innermost-first convention, matching the driver):
- dtype: Julia type (Float32, UInt16, …) or symbol (:bf16, :tf32, …).
- global_addr: Ptr{T}, CuPtr, or UInt. Caller-owned; the driver stores it in the descriptor (use cuTensorMapReplaceAddress to swap later).
- global_dim::NTuple{N, <:Integer}: tensor shape, innermost first.
- global_strides::NTuple{N-1, <:Integer}: outer-dim strides in bytes. Length must be N-1 (the innermost stride is elem_bytes * 1, implicit).
- box_dim::NTuple{N, <:Integer}: per-launch tile shape, innermost first.
Keyword arguments:
- elem_strides::NTuple{N, <:Integer} = (1, 1, …). Sub-tile stride.
- interleave::Symbol = :NONE. One of :NONE / :B16 / :B32.
- swizzle::Symbol = :NONE. One of :NONE / :B32 / :B64 / :B128 (plus Blackwell atom variants).
- l2_promotion::Symbol = :NONE. One of :NONE / :B64 / :B128 / :B256.
- oob_fill::Symbol = :NONE. One of :NONE / :NAN_REQUEST_ZERO_FMA.
Requires using CUDACore. Without it the package extension is inert and this function raises a clear error.
PTX.tensor_map_tile_2d — Function
tensor_map_tile_2d(dtype, global_addr, rows, cols, box_rows, box_cols;
                   swizzle=:B128, oob_fill=:NONE) -> CuTensorMap
Convenience for a 2D row-major (rows, cols) tensor + (box_rows, box_cols) tile. Innermost dim is cols. Stride is cols * elem_bytes.
Mirrors pyptx synthesize_tma_descriptor (2D path). Caller is responsible for keeping box_cols * elem_bytes consistent with the swizzle (e.g. 128B for :B128).
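The innermost-first convention and the swizzle consistency rule can be checked on the host with plain arithmetic. The helper below is hypothetical (tile_2d_args is not a PTX.jl function) and never calls the driver; it only shows how the 2D row-major case would plausibly map onto the global_dim / global_strides / box_dim tuples described above.

```julia
# Hypothetical helper: derive the tiled-descriptor tuples for a 2D row-major
# (rows, cols) tensor, following the documented innermost-first convention.
function tile_2d_args(::Type{T}, rows, cols, box_rows, box_cols) where {T}
    elem_bytes = sizeof(T)
    return (global_dim     = (cols, rows),           # innermost dim (cols) first
            global_strides = (cols * elem_bytes,),   # N-1 = 1 outer stride, in bytes
            box_dim        = (box_cols, box_rows),
            box_row_bytes  = box_cols * elem_bytes)  # compare against the swizzle span
end

args = tile_2d_args(Float16, 4096, 8192, 64, 64)
args.global_strides   # (16384,)
args.box_row_bytes    # 128, consistent with swizzle = :B128
```

A box_row_bytes value other than 128 here would be a hint that swizzle=:B128 is the wrong choice for that tile.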
When to extend
The chain default is sufficient for most ops. Reach for a wrapper when:
- ptxas rejects the chain output ("Arguments mismatch", wrong constraint letter on a shared-AS pointer, missing brackets on a memory operand);
- the op returns multiple values;
- the operand layout has a brace group, tensor-coord vector, or other shape with no $N rendering;
- the dispatch needs to key on a fragment shape or count that the chain can't see.
The pattern: copy the closest existing wrapper file, adjust the table and asm template, and add a for loop that emits one typed method per combination. A new family takes about 80 LOC in most cases.