Wrappers

For ops whose operand semantics break the chain default, PTX.jl ships hand-written wrapper methods on the matching Operation{...} singleton. The user-facing call site is identical to the chain default — the wrapper just provides a typed method that takes priority over the @generated chain dispatch when its argument types match.

There are four patterns that force a wrapper:

  1. Mixed address-space pointer constraints. Shared-AS pointers need the r (32-bit) constraint; the chain emits l (64-bit) for any LLVMPtr. cp.async, ldmatrix, stmatrix, and mbarrier use this override to satisfy ptxas.
  2. Multi-output return. The chain returns one value; ldmatrix x2/x4, mma fragments, and shfl-with-pred-output all return tuples.
  3. Special operand layout. Braced register-vector operands ({$N, $N+1, ...}), tensor-coord forms ([ptr, {c0, c1, c2}]), tied accumulators for wgmma — none have a chain-default rendering.
  4. Compile-time-constant operand types. TMA's <N>d rank, mma shape/dtype, fragment counts — these need to be pinned at the registration boundary so each variant gets its own typed method.
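
To make the first pattern concrete, here is a minimal sketch of choosing the constraint letter by address space through ordinary dispatch. The helper name constraint_letter and the constant SHARED_AS are hypothetical and not part of PTX.jl; the sketch assumes the NVPTX convention that address space 3 is shared memory.

```julia
# Assumption: NVPTX numbers the shared address space as 3.
const SHARED_AS = 3

# Hypothetical helper: inline-asm constraint letter for a pointer type.
# Shared-AS pointers fit a 32-bit register ("r"); every other LLVMPtr
# falls back to a 64-bit register ("l"), as the chain default does.
constraint_letter(::Type{Core.LLVMPtr{T, SHARED_AS}}) where {T} = "r"
constraint_letter(::Type{<:Core.LLVMPtr}) = "l"

shared_letter = constraint_letter(Core.LLVMPtr{Float32, 3})
global_letter = constraint_letter(Core.LLVMPtr{Float32, 1})
```

The more specific method wins when the pointer type matches, which is the same mechanism that lets a typed wrapper method take priority over the generic chain dispatch.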

Each wrapper file follows the same shape: a small declarative table maps (shape, dtype, …) to register counts and constraints, a _register function builds the asm string + LLVM call, and a for loop emits a typed method per valid combination.
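
The table / builder / loop shape described above can be sketched roughly as follows. Every name here (Op, FRAGMENT_REGS, build_asm, asm_for) is a hypothetical stand-in rather than the package's actual code, and the sketch only builds the asm string where a real wrapper would also emit the LLVM inline-asm call.

```julia
# Hypothetical stand-in for the package's Operation{...} singleton.
struct Op{name} end

# Declarative table: (shape, count) => number of 32-bit output registers.
const FRAGMENT_REGS = Dict(
    (:m8n8, 1) => 1,
    (:m8n8, 2) => 2,
    (:m8n8, 4) => 4,
)

# Build the asm template with a braced output group followed by the
# bracketed address operand. String only; no LLVM call in this sketch.
function build_asm(shape::Symbol, count::Int, nregs::Int)
    outs = join(("\$$i" for i in 0:nregs-1), ", ")
    return "ldmatrix.sync.aligned.$(shape).x$(count).shared.b16 {$outs}, [\$$nregs];"
end

# Emit one typed method per valid table entry.
for ((shape, count), nregs) in FRAGMENT_REGS
    asm = build_asm(shape, count, nregs)
    @eval asm_for(::Op{:ldmatrix}, ::Val{$(QuoteNode(shape))}, ::Val{$count}) = $asm
end

x2_asm = asm_for(Op{:ldmatrix}(), Val(:m8n8), Val(2))
```

Because methods exist only for combinations listed in the table, an unsupported variant such as Val(3) fails with a MethodError instead of producing an asm string that ptxas would reject later.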

Adding a new dtype/shape combination is one entry in the table plus one line in the loop.

Family overview

| Family | File | Surface |
| --- | --- | --- |
| cp.async (scalar) | wrappers/cp_async.jl | cp.async.{ca,cg}.shared.global [smem], [global], Val(N); needs shared r constraint and N-baked size |
| cp.async.bulk.tensor (TMA) | wrappers/tma.jl | cp.async.bulk.tensor.{1..5}d.{shared::cluster\|shared::cta}.global.tile.mbarrier::complete_tx::bytes (load) and .global.shared::cta.tile.bulk_group (store); coord vector becomes a positional Int32 argument list |
| cvt sub-byte FP packing | wrappers/cvt.jl | cvt.rn.satfinite.e2m1x2.{f32,f16x2,bf16x2} pack and cvt.rn.{f16x2,bf16x2}.e2m1x2 unpack; .b8 carrier through a mov.b16 brace-pair shim because NVPTX has no i8 constraint |
| ldmatrix | wrappers/ldmatrix.jl | ldmatrix.sync.aligned.{m8n8,m16n16}.{x1,x2,x4}[.trans].{shared,shared::cta}.{b16,b8} returning UInt32 (x1) or NTuple{N, UInt32} (x2/x4) |
| mbarrier | wrappers/mbarrier.jl | init / inval / arrive[.noComplete\|.expecttx] / expecttx / testwait[.parity] / trywait[.parity]; three return shapes (Nothing / UInt64 state / Bool pred); shared-AS r constraint |
| mma.sync.aligned | wrappers/mma.jl | mma.sync.aligned.<shape>.<layA>.<layB>.<d>.<a>.<b>.<c> for bf16/f16/tf32/FP8 (Ada) and kind::f8f6f4 5×5 sub-byte FP A/B (Blackwell sm_100a+); takes/returns NTuple{N, UInt32} for A/B and NTuple{M, Float32\|UInt32} for C/D |
| mma.sync.aligned.kind::mxf* (block-scaled) | wrappers/mma_scaled.jl | Three Blackwell-introduced kinds: mxf4, mxf4nvf4, mxf8f6f4. Operand layout (scale_data::UInt32, byte_id::UInt16, thread_id::UInt16) per side per PTX 9.2 §9.7.14.3 |
| setp.dual | wrappers/setp.jl | setp.<cmp>.<dtype> with %p\|%q dual-pred output; 6 cmps × 12 dtypes = 72 generated methods returning Tuple{Bool, Bool} |
| shfl.sync | wrappers/shfl.jl | up / down / bfly / idx × b32 × {data-only, data+pred} |
| stmatrix | wrappers/stmatrix.jl | mirror of ldmatrix: m8n8.b16 (sm_70+) and m16n8.b8 (Hopper) |
| vec_ldst | wrappers/vec_ldst.jl | ld.global.v{2,4}.{f32,b32,b16} / st.global.v{2,4}.{f32,b32,b16}; braced register-vector I/O for HBM-saturating bandwidth |
| wgmma.mma_async (Hopper sm_90a) | wrappers/wgmma.jl | wgmma.mma_async.sync.aligned.m64nNk{8,16,32}.<d>.<a>.<b>; accumulator passed by value (tied operands), N stepped by 8 from 8 to 256, 12 dtype tuples × 32 N-values = 384 methods |
| tcgen05 (Blackwell sm100a/sm110a) | wrappers/tcgen05.jl | shift / dealloc / cp / ld / st whose taddr operand is a 32-bit TMEM address (returned by tcgen05.alloc), NOT a memory pointer; chain default brackets LLVMPtr but treats UInt32 as a plain scalar |

Sync ops (wgmma.fence, wgmma.commit_group, wgmma.wait_group, tcgen05.alloc, tcgen05.commit, …) flow through the chain default — their opcode prefix is in NONPURE_OPCODES, so they get ~{memory} and side_effects = true automatically.
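
As a rough illustration of the prefix rule, the sketch below tags an opcode as non-pure when it starts with a listed prefix and then adds the memory clobber. The prefix list is an illustrative subset taken from the sentence above, and the names NONPURE_PREFIXES_SKETCH, is_nonpure, and asm_attributes are hypothetical; the package's actual NONPURE_OPCODES contents and lookup logic may differ.

```julia
# Illustrative subset only; not the package's real NONPURE_OPCODES.
const NONPURE_PREFIXES_SKETCH = (
    "wgmma.fence", "wgmma.commit_group", "wgmma.wait_group",
    "tcgen05.alloc", "tcgen05.commit",
)

# An opcode is treated as non-pure when it starts with a listed prefix.
is_nonpure(opcode::AbstractString) =
    any(prefix -> startswith(opcode, prefix), NONPURE_PREFIXES_SKETCH)

# Append the memory clobber and set the side-effect flag for non-pure ops.
function asm_attributes(opcode::AbstractString, constraints::AbstractString)
    nonpure = is_nonpure(opcode)
    full = if !nonpure
        String(constraints)
    elseif isempty(constraints)
        "~{memory}"
    else
        string(constraints, ",~{memory}")
    end
    return (constraints = full, side_effects = nonpure)
end

fence_attrs = asm_attributes("wgmma.fence.sync.aligned", "")
add_attrs = asm_attributes("add.s32", "=r,r,r")
```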

Host-side descriptor builders

Two opcodes consume packed 64-bit shared-memory descriptors plus (Blackwell only) a 32-bit instruction descriptor. PTX.jl ships pure bit-packing helpers for both:

| Helper | Used by |
| --- | --- |
| wgmma_descriptor | Hopper wgmma.mma_async SMEM operand encoding (14-bit field windows + swizzle + base offset) |
| tcgen05_descriptor | Blackwell tcgen05.mma SMEM operand encoding (3-bit layout vs wgmma's 2-bit; adds version and lbo_mode) |
| tcgen05_instr_desc_f16bf16_f32 | Blackwell tcgen05.mma 32-bit instruction descriptor for the dense F16/BF16/TF32 → F32 path. Mirrors CUTLASS/CuTe's UMMA::make_instr_desc. |
| smem_addr_u32 | Convert a Core.LLVMPtr{T, AS.Shared} to its 32-bit in-CTA SMEM offset (used as the smem_addr_u32 argument to the descriptor builders). |

These are not exported but are part of the documented API. Access them as PTX.wgmma_descriptor, PTX.tcgen05_descriptor, etc.
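
For orientation, the kind of bit packing these helpers perform can be sketched as below. The function pack_wgmma_desc, its keyword names, and the helper encode14 are hypothetical; they are not the signature of PTX.wgmma_descriptor. The field positions follow one reading of the PTX ISA matrix-descriptor format for wgmma (three 14-bit windows holding byte quantities shifted right by 4, a 3-bit base offset, and a 2-bit swizzle mode) and should be checked against the ISA before use.

```julia
# Encode a byte quantity into a 14-bit field: keep 18 bits, drop the low 4.
encode14(x::Integer) = (UInt64(x) & 0x3FFFF) >> 4

# Hypothetical packer; swizzle is the raw 2-bit mode code.
function pack_wgmma_desc(; smem_addr::Integer, lbo_bytes::Integer,
                         sbo_bytes::Integer, base_offset::Integer = 0,
                         swizzle::Integer = 0)
    desc  = encode14(smem_addr)                    # bits 0..13: start address
    desc |= encode14(lbo_bytes) << 16              # bits 16..29: leading byte offset
    desc |= encode14(sbo_bytes) << 32              # bits 32..45: stride byte offset
    desc |= (UInt64(base_offset) & 0x7) << 49      # bits 49..51: base offset
    desc |= (UInt64(swizzle) & 0x3) << 62          # bits 62..63: swizzle mode
    return desc
end

desc = pack_wgmma_desc(smem_addr = 0x1000, lbo_bytes = 16,
                       sbo_bytes = 1024, swizzle = 1)
```

Keeping the packer free of any GPU state means it can be unit-tested on the host against known-good descriptor values.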

GMMA layout helpers

For wgmma.mma_async SMEM operands, the descriptor's leading_byte_offset / stride_byte_offset / swizzle triple is fully determined by the tile geometry (dtype, M-or-N, K, major axis). The four canonical GMMA layout families (INTERLEAVE / B32 / B64 / B128) cover all wgmma-compatible SMEM tile widths.
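
One plausible way to derive the family from tile geometry is a widest-swizzle-first rule on the byte width of the contiguous dimension, sketched below. This rule is an assumption made for illustration; the authoritative mapping is the one in the module-top comment referenced by pick_gmma_layout, and GmmaFamily and pick_family are hypothetical names.

```julia
# Hypothetical enum naming the four layout families from the text.
@enum GmmaFamily INTERLEAVE B32 B64 B128

# Assumed heuristic: choose the widest swizzle span that evenly divides
# the byte width of the fastest-varying (contiguous) dimension.
function pick_family(; elem_bytes::Integer, contiguous_elems::Integer)
    row_bytes = elem_bytes * contiguous_elems
    row_bytes % 128 == 0 && return B128
    row_bytes % 64 == 0 && return B64
    row_bytes % 32 == 0 && return B32
    return INTERLEAVE
end

# A bf16 tile with 64 contiguous elements spans 128 bytes per row.
family = pick_family(elem_bytes = 2, contiguous_elems = 64)
```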

PTX.pick_gmma_layout (function)
pick_gmma_layout(; elem_bytes, m_or_n, k, major) -> GmmaLayout

Canonical GMMA layout for a (dtype, M-or-N, K, major) SMEM tile. major is :K or :MN. See module-top comment for the full mapping.

PTX.layout_for_a (function)
layout_for_a(; dtype, m, k) -> GmmaLayout

K-major GMMA layout — natural for row-major A (MxK with K-fast), and also the right pick for operand B when the kernel lays B as row-major KxN in SMEM with K-fast (the common case across the Hopper kernels in this repo). Use this for either operand whose K-dimension is fastest-varying in SMEM.

PTX.layout_for_mn_major (function)
layout_for_mn_major(; dtype, k, n) -> GmmaLayout

MN-major GMMA layout: operand B laid as col-major KxN in SMEM with N-fast (i.e. the N-or-M dimension is fastest-varying). Required only when the SMEM tile genuinely needs MN-fastness and wgmma is issued with the transpose flag trans_b=1. Most Hopper kernels in this repo use layout_for_a for both operands and let wgmma run with trans_b=0 against a K-fast B tile; reach for layout_for_mn_major only when the data layout forces the issue.

Host-side TMA descriptor encoder

Hopper TMA (cp.async.bulk.tensor.*) consumes a 128-byte CUtensorMap blob built host-side by the CUDA driver's cuTensorMapEncodeTiled. PTX.jl wraps the driver call so descriptors take Julia types / symbols instead of raw CUtensorMapDataType enums, with a thin convenience helper for the common 2D row-major case.

These methods live in ext/PTXCUDACoreExt.jl and load automatically when CUDACore is in the environment.

PTX.tensor_map_encode_tiled (function)
tensor_map_encode_tiled(dtype, global_addr, global_dim, global_strides,
                        box_dim; kwargs...) -> CuTensorMap

Build a CUtensorMap for a tiled TMA descriptor. Calls cuTensorMapEncodeTiled.

Arguments (innermost-first convention, matching the driver):

  • dtype: Julia type (Float32, UInt16, …) or symbol (:bf16, :tf32, …).
  • global_addr: Ptr{T}, CuPtr, or UInt. Caller-owned; the driver stores it in the descriptor (use cuTensorMapReplaceAddress to swap later).
  • global_dim::NTuple{N, <:Integer}: tensor shape, innermost first.
  • global_strides::NTuple{N-1, <:Integer}: outer-dim strides in bytes. Length must be N-1 (the innermost stride is elem_bytes * 1, implicit).
  • box_dim::NTuple{N, <:Integer}: per-launch tile shape, innermost first.

Keyword arguments:

  • elem_strides::NTuple{N, <:Integer} = (1, 1, …). Sub-tile stride.
  • interleave::Symbol = :NONE. One of :NONE / :B16 / :B32.
  • swizzle::Symbol = :NONE. One of :NONE / :B32 / :B64 / :B128 (plus Blackwell atom variants).
  • l2_promotion::Symbol = :NONE. One of :NONE / :B64 / :B128 / :B256.
  • oob_fill::Symbol = :NONE. One of :NONE / :NAN_REQUEST_ZERO_FMA.

Requires using CUDACore. Without it the package extension is inert and this function raises a clear error.

PTX.tensor_map_tile_2d (function)
tensor_map_tile_2d(dtype, global_addr, rows, cols, box_rows, box_cols;
                   swizzle=:B128, oob_fill=:NONE) -> CuTensorMap

Convenience for a 2D row-major (rows, cols) tensor + (box_rows, box_cols) tile. Innermost dim is cols. Stride is cols * elem_bytes.

Mirrors pyptx synthesize_tma_descriptor (2D path). Caller is responsible for keeping box_cols * elem_bytes consistent with the swizzle (e.g. 128B for :B128).
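
The innermost-first argument mapping for the 2D row-major case can be sketched in plain Julia, without touching the driver. The helpers tile_2d_args and swizzle_span_ok are hypothetical and only compute the tuples that would be handed to tensor_map_encode_tiled under the conventions stated above; they do not build a CUtensorMap.

```julia
# Hypothetical helper: derive innermost-first arguments for a row-major
# (rows, cols) tensor and a (box_rows, box_cols) tile.
function tile_2d_args(elem_bytes::Integer, rows::Integer, cols::Integer,
                      box_rows::Integer, box_cols::Integer)
    global_dim     = (cols, rows)           # innermost dim is cols
    global_strides = (cols * elem_bytes,)   # one outer stride, in bytes
    box_dim        = (box_cols, box_rows)
    return (; global_dim, global_strides, box_dim)
end

# Caller-side check: the tile's inner extent in bytes should match the
# swizzle span (128 bytes for :B128, per the note above).
swizzle_span_ok(box_cols, elem_bytes; span_bytes = 128) =
    box_cols * elem_bytes == span_bytes

# A 4096 x 8192 bf16 tensor (2 bytes per element) with a 64 x 64 tile.
args = tile_2d_args(2, 4096, 8192, 64, 64)
span_ok = swizzle_span_ok(64, 2)
```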

When to extend

The chain default is sufficient for most ops. Reach for a wrapper when:

  • ptxas rejects the chain output ("Arguments mismatch", wrong constraint letter on a shared-AS pointer, missing brackets on a memory operand);
  • the op returns multiple values;
  • the operand layout has a brace group, tensor-coord vector, or other shape with no $N rendering;
  • the dispatch needs to key on a fragment shape or count that the chain can't see.

The pattern: copy the closest existing wrapper file, adjust the table and asm template, and add a for loop that emits one typed method per combination. A new family takes about 80 LOC in most cases.