2805 Bowers Ave, Santa Clara, CA 95051 | 408-730-2275
research@colfax-intl.com

CUTLASS Tutorial: Mastering the NVIDIA® Tensor Memory Accelerator (TMA)

TMA (Tensor Memory Accelerator) is a new feature introduced in the NVIDIA Hopper™ architecture for asynchronous memory copies between a GPU’s global memory (GMEM) and the shared memory (SMEM) of its threadblocks (i.e., CTAs). Compared to prior approaches, TMA offers a number of advantages, such as (1) improving GPU utilization through facilitating warp-specialized kernel schedules via asynchrony, and (2) handling the computation of auxiliary copy data such as addresses and strides in a single-threaded manner via the TMA copy descriptor, which is both more register-efficient and necessarily handles predication (e.g., out-of-bounds checks). These advantages are well articulated in NVIDIA’s technical blog and Hopper tuning guide, which we highly recommend to readers for understanding the rationales behind the design of TMA.

In contrast to those sources, this blog post is focused on achieving an operational understanding of how to write kernels that use TMA. Throughout, we rely on the CuTe library, in which TMA is exposed through APIs wrapping lower-level GPU instructions. These include the PTX instructions cp.async.bulk.tensor and cp.reduce.async.bulk.tensor, as well as the CUtensorMap operand, which we will also discuss in this post.

We organize this blog post into three main sections: the first about TMA load, the second about TMA store, and lastly the third covering more advanced operations such as TMA store reduce and TMA load multicast. In essence, TMA load copies (“loads”) data from the GPU’s GMEM into the SMEM of one of its CTAs, while TMA store copies (“stores”) data from a CTA’s SMEM to the GPU’s GMEM. Since TMA load, TMA store, and the more advanced variants share many concepts, we will introduce the bulk of the necessary concepts in the TMA load section and only focus on the remaining differences in the subsequent sections.

Also, given that TMA is an asynchronous operation (executed in the async proxy), we will need to use certain memory consistency enforcement tools, such as async memory barrier (i.e., mbarrier) and async memory fence (i.e., fence.proxy.async), to ensure correct behavior of the kernel. Synchronization is a vast topic of discussion by itself, so we will only cover these concepts to the degree needed for their practical use.

Finally, for readers looking for a resource that covers many of the same points with no reference to CUTLASS or CuTe concepts, we recommend the treatment of TMA in the CUDA® programming guide.

TMA Load

TMA load copies data from GMEM into SMEM. In this section, we demonstrate how to write a kernel that uses TMA load for this goal. A kernel that uses TMA load is quite different from a kernel that uses other memory copy methods, so we will first show how to write such a kernel for a simple example task. Then, we will explain the involved concepts.

Example task

To demonstrate the usage of TMA load, we consider a simple task of tiling a 2D row-major matrix. We are given a matrix A of shape [m,n] and two positive integers CTA_M and CTA_N. Note that CTA_M and CTA_N are known at compilation time, while m and n are given to us at runtime via the matrix A. For simplicity, let’s also assume that m % CTA_M == n % CTA_N == 0, though we will see later that this requirement can be relaxed.

We launch a grid of CTAs with size {m/CTA_M, n/CTA_N, 1}, where the SMEM of the (i,j)-th CTA holds the (i,j)-th tile with shape [CTA_M, CTA_N] from A. We can depict this assignment in numpy pseudocode as:

import numpy as np

A = np.random.uniform(size=(m, n))
for i in range(m // CTA_M):
  for j in range(n // CTA_N):
    cta_i_j = A.reshape(m // CTA_M, CTA_M, n // CTA_N, CTA_N)[i, :, j, :]
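To check a kernel’s output on the host, it can help to have the same tile assignment as a small C++ reference. The sketch below is a minimal sketch under stated assumptions. The function name reference_tile is hypothetical. It assumes that A has been copied back into a row-major std::vector and that m % CTA_M == n % CTA_N == 0 as above. It returns what the SMEM of CTA (i, j) should hold.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side reference: returns the (i, j)-th [CTA_M, CTA_N] tile
// of a row-major m x n matrix A, i.e. what CTA (i, j) should hold in SMEM.
template <typename T>
std::vector<T> reference_tile(const std::vector<T>& A, int m, int n,
                              int CTA_M, int CTA_N, int i, int j) {
  assert(m % CTA_M == 0 && n % CTA_N == 0);
  std::vector<T> tile(static_cast<std::size_t>(CTA_M) * CTA_N);
  for (int r = 0; r < CTA_M; ++r) {
    for (int c = 0; c < CTA_N; ++c) {
      int gr = i * CTA_M + r;  // global row of tile element (r, c)
      int gc = j * CTA_N + c;  // global column of tile element (r, c)
      tile[static_cast<std::size_t>(r) * CTA_N + c] =
          A[static_cast<std::size_t>(gr) * n + gc];
    }
  }
  return tile;
}
```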

The two-step process. To perform this task, we use TMA load. In CuTe, a TMA load operation is implemented in two steps. The first step is the construction of the TMA copy descriptor in the host code, while the second step is the execution of the actual TMA load using this descriptor inside the kernel code. Note that this two-step process is different from what we normally do with CuTe’s TiledCopy — where all the copy steps are written in the kernel code — as shown in this tutorial.

Host code

On the host, we create three objects: the GMEM tensor which we copy from, the layout of the SMEM tensor on each of the CTAs that we copy into, and a tma_load object that takes these two as arguments. Note that since we create the SMEM layout on the host, all CTAs will share the same SMEM layout for the purposes of the TMA load. Once we have these objects, they can be passed to the kernel on device, inside of which the TMA load operation is invoked.

The entire code block on the host is:

template <typename T, int CTA_M, int CTA_N>
void host_fn(T* data, int M, int N) {
  using namespace cute;

  // create the GMEM tensor (row-major M x N, wrapping the device pointer data)
  auto gmem_layout = make_layout(make_shape(M, N), LayoutRight{});
  auto gmem_tensor = make_tensor(make_gmem_ptr(data), gmem_layout);

  // create the SMEM layout with a static (compile-time) shape
  auto smem_layout = make_layout(make_shape(Int<CTA_M>{}, Int<CTA_N>{}), LayoutRight{});

  // create the TMA object
  auto tma_load = make_tma_copy(SM90_TMA_LOAD{}, gmem_tensor, smem_layout);

  // invoke the kernel on a grid of (M / CTA_M) x (N / CTA_N) CTAs; since only
  // thread 0 issues the TMA load, the block size (here 128) is a free choice
  tma_load_kernel<T, CTA_M, CTA_N>
                 <<<dim3(M / CTA_M, N / CTA_N, 1), 128>>>
                 (tma_load, gmem_tensor);
}

The lines that create gmem_layout, gmem_tensor, and smem_layout simply use basic CuTe concepts, so we refer readers to these CuTe tutorials for a refresher. Here, we focus on explaining the tma_load object. This object is an instance of cute::TiledCopy, which holds the information and implements the methods to perform a CTA-wide copy operation. In the code snippet, the tma_load object is created via this explicit default of the cute::make_tma_copy function. This function’s full implementation has some nuances, which we will dive into when we discuss MULTICAST later in this blog post, but the explicit default suffices for most use cases, such as our example task. We recommend using the explicit default to avoid unnecessary complications (and bugs).

Let’s look into the signature that we used for make_tma_copy:

  • Its last two arguments are gmem_tensor and smem_layout. Under the hood, make_tma_copy uses this information to create a TmaDescriptor, which is just an alias for CUtensorMap. This descriptor object is used inside the TMA kernel.
  • Its first argument is an instance of SM90_TMA_LOAD. This object dispatches the copy operation to the desired cp.async.bulk.tensor PTX call, which we discuss in more depth in the third section below.
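Conceptually, the descriptor that make_tma_copy builds carries the same information that the CUDA driver function cuTensorMapEncodeTiled takes. That includes the global dimensions of the tensor, its global strides in bytes (for every dimension except the contiguous one), and the box (tile) dimensions, with dimension 0 being the contiguous one. The sketch below fills in these values for our row-major [M, N] tensor. The struct TiledMapParams and the function describe_row_major are hypothetical. They deliberately omit the data type, swizzle, L2 promotion, and out-of-bounds fill arguments.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical mirror of the main arguments of cuTensorMapEncodeTiled for a
// rank-2 tensor. Dimension 0 is the contiguous (stride-1) dimension.
struct TiledMapParams {
  std::array<std::uint64_t, 2> global_dim;     // extent of each dimension
  std::array<std::uint64_t, 1> global_stride;  // bytes, dimensions 1..rank-1
  std::array<std::uint32_t, 2> box_dim;        // SMEM tile extent per dimension
};

// Fill the parameters for a row-major [M, N] tensor copied in [CTA_M, CTA_N] tiles.
TiledMapParams describe_row_major(std::uint64_t M, std::uint64_t N,
                                  std::uint32_t CTA_M, std::uint32_t CTA_N,
                                  std::size_t elem_bytes) {
  TiledMapParams p{};
  p.global_dim = {N, M};               // contiguous dimension (columns) first
  p.global_stride = {N * elem_bytes};  // byte distance between consecutive rows
  p.box_dim = {CTA_N, CTA_M};          // tile extent, also contiguous first
  return p;
}
```

Note the reversed (column-first) order. It reappears below in the coordinates that CuTe prints for the TMA tensor.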

Kernel code

The relevant kernel code snippet looks like this. These lines pack many important TMA concepts, which we will explain below.

template <typename T, int CTA_M, int CTA_N, class TmaLoad, class GmemTensor>
__global__ void tma_load_kernel(__grid_constant__ const TmaLoad tma_load, GmemTensor gmem_tensor) {
  using namespace cute;
  constexpr int tma_transaction_bytes = CTA_M * CTA_N * sizeof(T);

  alignas(128) __shared__ T smem_data[CTA_M * CTA_N];  // TMA destination must be aligned
  __shared__ uint64_t tma_load_mbar;

  auto smem_layout = make_layout(make_shape(Int<CTA_M>{}, Int<CTA_N>{}), LayoutRight{});
  auto smem_tensor = make_tensor(make_smem_ptr(smem_data), smem_layout);

  if (threadIdx.x == 0) {
    auto gmem_tensor_coord = tma_load.get_tma_tensor(shape(gmem_tensor));

    auto gmem_tensor_coord_cta = local_tile(
        gmem_tensor_coord,
        Tile<Int<CTA_M>, Int<CTA_N>>{},
        make_coord(blockIdx.x, blockIdx.y));

    initialize_barrier(tma_load_mbar, /* arrival count */ 1);

    set_barrier_transaction_bytes(tma_load_mbar, tma_transaction_bytes);

    auto tma_load_per_cta = tma_load.get_slice(0);
    copy(tma_load.with(tma_load_mbar),
         tma_load_per_cta.partition_S(gmem_tensor_coord_cta),
         tma_load_per_cta.partition_D(smem_tensor));
  }
  __syncthreads();
  wait_barrier(tma_load_mbar, /* phase */ 0);

  // after this line, the TMA load is finished
}

First, at line 2, the tma_load argument for the kernel must be annotated with __grid_constant__ const. If we have two tensors that we want to copy from GMEM into SMEM, each of them must have its own TiledCopy instance, and each instance must be __grid_constant__ const. This is a requirement for passing a CUtensorMap from host to device as documented here, for instance.

The next important point is that for a TMA copy, only one thread will be responsible for issuing the TMA operation. In the code snippet, all the TMA-related variables and instructions are contained in the if block starting at line 12, which is only executed by thread 0. On the other hand, line 30 contains an instruction for all threads in the CTA to wait for the TMA operations to finish.

Coordinates and Arithmetic tuples

For now, let’s look into the TMA load logic. This starts at line 13, where we create a gmem_tensor_coord object that holds the coordinates of the GMEM tensor to be copied. If we try the following:

if (cute::thread(0)) { cute::print(gmem_tensor_coord); }

then we see the output like so (for M=N=1024):

ArithTuple(_0,_0) o (1024,1024):(_1@1,_1@0)

Lines 15-18 are self-explanatory for readers familiar with the way tiled copy works in CuTe, where a GMEM tensor is tiled into smaller partitions, and each CTA slices into the tiled tensor according to the block coordinate to obtain its view of GMEM. Note however that the partitioning applies to the aforementioned ArithTuple representing the coordinates of gmem_tensor, instead of to gmem_tensor itself. In particular, the ArithTuple is partitioned into tiles of shape [CTA_M,CTA_N], and then each CTA takes its tile.

If we print gmem_tensor_coord_cta using print_tensor as follows:

if (cute::block(7)) { cute::print_tensor(gmem_tensor_coord_cta); }

then for CTA_M == CTA_N == 16, we see:

ArithTuple(0,112) o (_16,_16):(_1@1,_1@0):
  (0,112)  (1,112)  (2,112)  (3,112)  (4,112)  (5,112)  (6,112)  (7,112)  (8,112)  (9,112)  (10,112)  (11,112)  (12,112)  (13,112)  (14,112)  (15,112)
  (0,113)  (1,113)  (2,113)  (3,113)  (4,113)  (5,113)  (6,113)  (7,113)  (8,113)  (9,113)  (10,113)  (11,113)  (12,113)  (13,113)  (14,113)  (15,113)
  // more lines
  (0,127)  (1,127)  (2,127)  (3,127)  (4,127)  (5,127)  (6,127)  (7,127)  (8,127)  (9,127)  (10,127)  (11,127)  (12,127)  (13,127)  (14,127)  (15,127)

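These numbers are the coordinates in gmem_tensor whose values will be copied into the smem_tensor of CTA 7. Note that each pair is listed as (column, row). The TMA descriptor orders its dimensions starting from the contiguous (stride-1) one, which is also why the strides above read _1@1 and _1@0. We encourage readers to try running this code snippet while replacing cute::block(7) with other indices to understand which CTAs copy from which coordinates in gmem_tensor.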
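To predict this printout without a GPU, the mapping can be reproduced on the host. In the sketch below, the TmaCoord struct and cta_tile_coord function are hypothetical. The sketch assumes that cute::block(7) selects the CTA whose linearized index blockIdx.x + blockIdx.y * gridDim.x equals 7, which matches the rows 112 to 127 printed above. Each pair is returned in the printed (column, row) order.

```cpp
#include <cassert>

// Coordinate pair in the order printed by CuTe above: column first, then row.
struct TmaCoord {
  int col;
  int row;
};

// Hypothetical helper: GMEM coordinate of element (r, c) of the SMEM tile owned
// by the CTA with linear block index block_id, for a grid of size
// {M / CTA_M, N / CTA_N, 1} (grid_x == M / CTA_M).
TmaCoord cta_tile_coord(int block_id, int grid_x, int CTA_M, int CTA_N,
                        int r, int c) {
  int bx = block_id % grid_x;  // blockIdx.x indexes tiles along the rows
  int by = block_id / grid_x;  // blockIdx.y indexes tiles along the columns
  return TmaCoord{by * CTA_N + c, bx * CTA_M + r};
}
```

For M = N = 1024 and 16 x 16 tiles, grid_x is 64. Element (0, 0) of CTA 7 then maps to (0,112) and element (15, 15) maps to (15,127), as in the output.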

Next, the copy operation itself, issued in lines 25-27, has the usual signature of a TiledCopy operation, where the source tensor is replaced by the partitioned coordinates.

Memory barrier

We have left out lines 20, 22, and 30, all of which involve the uint64_t variable tma_load_mbar which lives in SMEM. This is the asynchronous transaction barrier that we use to synchronize the TMA load with the rest of the kernel that consumes the resulting data loaded into SMEM. A high-level description of this type of barrier is given in the NVIDIA technical blog on Hopper architecture. In terms of our kernel, the important points are as follows:

  1. We initialize the mbarrier object in shared memory on line 20. The CuTe method initialize_barrier wraps the PTX instruction mbarrier.init.shared.b64, which takes in an additional arrival count parameter. In our context, since a single thread will initiate the TMA load, we should set the arrival count to 1. Moreover, the starting phase of the mbarrier will always be set to 0.
  2. We both perform an arrive-on operation and set the expected transaction count for the mbarrier object on line 22 with the CuTe method set_barrier_transaction_bytes, which wraps the PTX instruction mbarrier.arrive.expect_tx.shared::cta.b64. The transaction count is set to equal the number of bytes transferred by the TMA load, which we compute on line 4.
  3. On lines 25-27, the copy instruction, which dispatches to the desired flavor of cp.async.bulk.tensor, always has for its completion mechanism mbarrier::complete_tx::bytes with the provided mbarrier object.
  4. On line 30, we execute the wait operation on the mbarrier object. Note that all threads wait on the mbarrier, in contrast to only thread 0 arriving at the mbarrier. The invocation of __syncthreads() prior to wait_barrier is necessary so that no thread waits on the mbarrier before thread 0 has initialized it. Here, wait_barrier wraps the PTX instruction mbarrier.try_wait.parity.shared::cta.b64. The try_wait qualifier (as opposed to test_wait) indicates that the wait is a potentially blocking instruction. The parity qualifier, whose use entails providing a phase bit, indicates that the thread waits until the phase of the mbarrier with that parity has completed. Because this is the first use of the mbarrier post-initialization to track completion, we supply 0 as the phase. If we were doing another TMA load, we would then have to flip the phase to reuse the mbarrier. In general, the CUTLASS Pipeline APIs offer a higher-level way to handle the lifecycle of the mbarrier objects when doing a series of TMA loads, as one might do in a software pipelining scheme.
  5. After wait_barrier, the memory consistency model furnishes us the following guarantee: the write to SMEM done by the TMA load is made visible to all threads that invoked the mbarrier wait (so in our example kernel, all threads in the CTA).
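A toy, single-threaded host model can illustrate the arrival and transaction bookkeeping in points 1 to 4, together with the phase flip needed for reuse. This is a sketch for intuition only. MbarrierModel is a hypothetical class, not a CuTe or CUDA API. It ignores concurrency, timeouts, and memory ordering entirely.

```cpp
#include <cassert>
#include <cstdint>

// Simplified model of an mbarrier's phase bookkeeping: it tracks the pending
// arrival count, the transaction (byte) count, and the current phase.
class MbarrierModel {
 public:
  // Models mbarrier.init with the given arrival count; the phase starts at 0.
  explicit MbarrierModel(int arrival_count)
      : expected_arrivals_(arrival_count), pending_arrivals_(arrival_count) {}

  // Models mbarrier.arrive.expect_tx: raise the tx-count, then arrive once.
  void arrive_expect_tx(std::int64_t bytes) {
    tx_count_ += bytes;
    --pending_arrivals_;
    maybe_complete_phase();
  }

  // Models the complete_tx signal issued as bytes of a TMA load land in SMEM.
  void complete_tx(std::int64_t bytes) {
    tx_count_ -= bytes;
    maybe_complete_phase();
  }

  // Models mbarrier.try_wait.parity: true once the phase with this parity
  // has completed.
  bool try_wait_parity(int parity) const {
    return static_cast<int>(phase_ & 1u) != parity;
  }

 private:
  void maybe_complete_phase() {
    if (pending_arrivals_ == 0 && tx_count_ == 0) {
      ++phase_;                                // move to the next phase
      pending_arrivals_ = expected_arrivals_;  // re-arm for reuse
    }
  }

  int expected_arrivals_;
  int pending_arrivals_;
  std::int64_t tx_count_ = 0;
  std::uint32_t phase_ = 0;
};
```

In the model, waiting on parity 0 succeeds only after thread 0 has arrived and all expected bytes have landed. A second TMA load through the same barrier must then wait on parity 1.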

REMAINDER TILES WITH TMA AND STRIDE REQUIREMENTS

In our above example, we supposed that m%CTA_M==0 and n%CTA_N==0. However, for the purposes of doing a TMA load, we can dispense with this assumption entirely. Instead of needing to handle the out-of-bounds logic ourselves when loading in remainder tiles from GMEM to SMEM, the TMA copy unit will necessarily predicate the memory copy to not read out-of-bounds. This is consistent with the use of special “implicit” CuTe tensors with ArithTuple as described above in the TMA load: if we used ordinary CuTe tensors instead, then they could be sliced to produce new CuTe tensors with possibly out-of-bounds pointers to GMEM, an easy way to introduce bugs.
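Once the divisibility assumption is dropped, the grid would be launched with ceil_div(m, CTA_M) x ceil_div(n, CTA_N) CTAs. The last row and column of CTAs then receive partial tiles. The host model below sketches that behavior. The helper names are hypothetical. It assumes the default out-of-bounds fill mode of the descriptor, under which a TMA load writes zeros into the SMEM positions that fall outside the tensor.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Number of tiles of size `tile` needed to cover `extent` elements.
int ceil_div(int extent, int tile) { return (extent + tile - 1) / tile; }

// Hypothetical host model of a predicated tile load: elements of tile (i, j)
// that fall outside the m x n row-major matrix are never read from A and are
// left as zero in the destination tile.
template <typename T>
std::vector<T> predicated_tile(const std::vector<T>& A, int m, int n,
                               int CTA_M, int CTA_N, int i, int j) {
  std::vector<T> tile(static_cast<std::size_t>(CTA_M) * CTA_N, T(0));
  for (int r = 0; r < CTA_M; ++r) {
    for (int c = 0; c < CTA_N; ++c) {
      int gr = i * CTA_M + r;
      int gc = j * CTA_N + c;
      if (gr < m && gc < n) {  // the bounds check that the TMA unit performs
        tile[static_cast<std::size_t>(r) * CTA_N + c] =
            A[static_cast<std::size_t>(gr) * n + gc];
      }
    }
  }
  return tile;
}
```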

However, there is one important requirement on the strides of the GMEM tensor itself to bear in mind for TMA, which is the 16-byte boundary requirement. As one might expect, TMA doesn’t support copying arbitrarily strided regions of GMEM. Rather, we need to assume that the tile being copied has (i) a contiguous direction (stride 1), and (ii) other strides as multiples of 16 bytes. This is asserted in the CUTLASS codebase.

For example, for our row-major GMEM tensor of floats, with shape (m, n) and stride (n, 1), this imposes the requirement that n%4==0. If this isn’t satisfied, then one can pad the input tensors to be of the right extent before invoking the kernel.
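The stride rule and the padding remedy can be written as two small host-side helpers. Both names are hypothetical. Strides are given in elements. The first stride-1 mode is treated as the contiguous direction, and every other stride must amount to a multiple of 16 bytes.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical check of the 16-byte stride requirement described above.
bool tma_strides_ok(const std::int64_t* strides, int rank,
                    std::int64_t elem_bytes) {
  bool found_contiguous = false;
  for (int d = 0; d < rank; ++d) {
    if (strides[d] == 1 && !found_contiguous) {
      found_contiguous = true;  // the contiguous (stride-1) direction
    } else if ((strides[d] * elem_bytes) % 16 != 0) {
      return false;  // a non-contiguous stride that is not 16-byte aligned
    }
  }
  return found_contiguous;
}

// Smallest extent >= n whose row pitch n * elem_bytes is a multiple of 16
// bytes; assumes elem_bytes divides 16 (1, 2, 4, 8 or 16).
std::int64_t padded_extent(std::int64_t n, std::int64_t elem_bytes) {
  assert(16 % elem_bytes == 0);
  std::int64_t align = 16 / elem_bytes;  // elements per 16 bytes
  return (n + align - 1) / align * align;
}
```

For a float tensor with n = 1022, the row stride is 4088 bytes and the check fails. Padding the row extent to padded_extent(1022, 4) = 1024 fixes it.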

TMA Store

Equipped with the basics of TMA load, studying TMA store is a lot easier thanks to the many similarities between the two operations. Similar to TMA load, implementing TMA store is a two-step process: defining the TMA copy descriptor on the host, and then issuing the TMA store operation inside the kernel.

Example task and code

For illustration purposes, let’s consider the reverse example of TMA load, where we copy from the SMEM in multiple CTAs to corresponding tiles in a partitioned GMEM tensor. A difference here is that we will fill the SMEM tiles in the CTAs with a simple pattern of numbers before copying them to GMEM (otherwise, we would be copying undefined values). A functional code snippet is as follows:

template <typename T, int CTA_M=32, int CTA_N=32>
void host_fn(T* data, int M, int N) {
  using namespace cute;

  // create the GMEM tensor
  auto gmem_layout = make_layout(make_shape(M, N), LayoutRight{});
  auto gmem_tensor = make_tensor(make_gmem_ptr(data), gmem_layout);

  // create the SMEM layout
  auto smem_layout = make_layout(make_shape(Int<CTA_M>{}, Int<CTA_N>{}), LayoutRight{});

  // create the TMA object
  auto tma_store = make_tma_copy(SM90_TMA_STORE{}, gmem_tensor, smem_layout);

  // invoke the kernel with CTA_M threads per CTA (one thread per tile row)
  tma_store_kernel<T, CTA_M, CTA_N>
                  <<<dim3(M / CTA_M, N / CTA_N, 1), CTA_M>>>
                  (tma_store, gmem_tensor);
}

template <typename T, int CTA_M, int CTA_N, class TmaStore, class GmemTensor>
__global__ void tma_store_kernel(__grid_constant__ const TmaStore tma_store, GmemTensor gmem_tensor) {
  using namespace cute;
  alignas(128) __shared__ T smem_data[CTA_M * CTA_N];

  auto smem_layout = make_layout(make_shape(Int<CTA_M>{}, Int<CTA_N>{}), LayoutRight{});
  auto smem_tensor = make_tensor(make_smem_ptr(smem_data), smem_layout);

  // fill the rows of smem_tensor: thread i writes the value i into row i
  for (int j = 0; j < CTA_N; ++j) {
    smem_tensor(threadIdx.x, j) = T(threadIdx.x);
  }

  // each writer fences its SMEM writes for the async proxy, then all sync
  tma_store_fence();
  __syncthreads();
  if (threadIdx.x == 0) {
    auto gmem_tensor_coord = tma_store.get_tma_tensor(shape(gmem_tensor));

    auto gmem_tensor_coord_cta = local_tile(
      gmem_tensor_coord,
      Tile<Int<CTA_M>, Int<CTA_N>>{},
      make_coord(blockIdx.x, blockIdx.y));

    auto tma_store_per_cta = tma_store.get_slice(0);
    copy(tma_store,
         tma_store_per_cta.partition_S(smem_tensor),
         tma_store_per_cta.partition_D(gmem_tensor_coord_cta));
    // tma_store_arrive();
  }
  // tma_store_wait<0>();
}

The host code looks almost identical to that of TMA load, except for the call to tma_store_kernel. Note that we have arranged for each CTA to have CTA_M threads. Our example then has each CTA hold a [CTA_M, CTA_N] tile in SMEM such that in lines 29-32, thread i fills row i with the value i.

In the kernel code, the if block in lines 37-50 is similar to the if block in the tma_load_kernel. In particular, only thread 0 issues the TMA store operation. All of the tensor tiling logic is conceptually the same. However, the copying direction is reversed: for TMA store, the tma_store_per_cta.partition_S method is applied to smem_tensor, while the tma_store_per_cta.partition_D method is applied to the coordinates of the GMEM tensor. Note that the coordinates are also represented as an ArithTuple, similar to TMA load.
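Thread i writes the value i into row i of every CTA’s tile. After the kernel, every element of global row r should therefore equal r mod CTA_M. A host-side check along these lines can validate the result. The function name is hypothetical, and the check assumes the output has been copied back into a row-major std::vector.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host check for the TMA store example: every element in global
// row r should equal r % CTA_M, since thread i of each CTA filled row i of its
// tile with the value i before the tile was stored.
template <typename T>
bool check_store_result(const std::vector<T>& out, int M, int N, int CTA_M) {
  for (int r = 0; r < M; ++r) {
    for (int c = 0; c < N; ++c) {
      if (out[static_cast<std::size_t>(r) * N + c] != T(r % CTA_M)) {
        return false;
      }
    }
  }
  return true;
}
```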

Memory fence

The most important difference between the code for TMA load and store is that we no longer see any mbarrier object being used with TMA store. This is because TMA store uses another mechanism to enforce memory consistency: a memory fence.

The intention of a memory fence is to establish a guaranteed ordering between memory accesses requested by the executing thread before and after the fence. In our example, we need to ensure that all the writes to SMEM done in lines 29-32 are visible to the TMA store executed by thread 0. To this end, on line 35 we have the CuTe method tma_store_fence() that wraps the PTX instruction fence.proxy.async.shared::cta.

The qualifiers of this instruction describe the effect of the fence. The proxykind async indicates that, in addition to the generic proxy, the async proxy participates in the ordering enforced by the fence. This is what we need, since the TMA store is executed in the async proxy. The shared::cta suffix restricts the fence to the shared memory of the executing CTA, which is where the writes in lines 29-32 landed. If we replaced the async fence by a different memory fence primitive such as __threadfence_block() that doesn’t involve the async proxy, we would destroy the guarantee needed for correct behavior of the kernel, leading to race conditions in practice.

TMA STORE ARRIVE AND WAIT

In lines 49 and 51, we have tma_store_arrive(), which commits the TMA store operation (technically, as a cp.async.bulk-group), and tma_store_wait<Count>(), which waits until at most Count many of the committed TMA store operations are pending (e.g., if all should be completed, then set Count to be 0). These operations are useful when one has other in-kernel work waiting on the completion of the TMA store — for example, this would be needed to reuse the freed SMEM made available after writing out. However, because our kernel simply exits after the TMA store is done, we don’t need the TMA store arrive and wait pattern here, so we comment out those lines.
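The commit/wait bookkeeping can be pictured with a toy host model. BulkGroupModel is hypothetical, not a CuTe API. It assumes that groups finish in commit order and ignores the asynchronous hardware entirely. As we understand it, CuTe implements tma_store_wait with the .read variant of cp.async.bulk.wait_group. That variant waits only until the source SMEM of the pending stores has been read, which is what SMEM reuse requires.

```cpp
#include <cassert>
#include <deque>

// Toy model of the bulk async-groups behind tma_store_arrive() and
// tma_store_wait<Count>(). In this model, groups complete in commit order.
class BulkGroupModel {
 public:
  // Like tma_store_arrive(): commit the stores issued so far as one group.
  int commit() {
    pending_.push_back(next_id_);
    return next_id_++;
  }
  // Stands in for the hardware finishing the oldest pending group.
  void complete_oldest() {
    if (!pending_.empty()) pending_.pop_front();
  }
  // Like tma_store_wait<Count>(): may proceed once at most `count` groups
  // are still pending.
  bool may_proceed(int count) const {
    return static_cast<int>(pending_.size()) <= count;
  }

 private:
  std::deque<int> pending_;
  int next_id_ = 0;
};
```

Consider a double-buffered epilogue with stores in flight from buffers 0 and 1. Before buffer 0 is refilled, waiting with Count = 1 suffices, because the newer store from buffer 1 may remain pending.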

A Deeper Look at TMA Operations

               TMA LOAD               TMA STORE
Direction      GMEM -> SMEM           SMEM -> GMEM
Sync method    Memory barrier         Proxy fence
When to sync   After the operation    Before the operation

Summary of TMA operations.

Thus far, we have learned how to invoke the TMA load and TMA store operations. The above table compares and contrasts these operations. To invoke either operation, we need to create an object akin to TiledCopy via the cute::make_tma_copy method in the host code, and then pass this object into a kernel function, where we use it in cute::copy to actually invoke the operation. In this section, we take a deeper dive into what really happens when we call these TiledCopy objects in the kernel function. From this deep dive, we discuss two extensions: TMA store reduce and TMA load multicast.

PTX Instructions of TMA Load and Store

PTX (Parallel Thread Execution) is a low-level intermediate language for NVIDIA GPUs. For our discussion, the relevant part of PTX comprises a set of instructions that can be inserted into CUDA code via inline assembly blocks introduced by the asm volatile keywords. In particular, when we call cute::copy(tma_load, ...) or cute::copy(tma_store, ...) as described in previous sections, certain PTX instructions are called to perform these operations. By studying the PTX, we can better understand TMA load and TMA store.

Let us start with TMA load. Recall that when we create the tma_load object in the host code, we must provide the GMEM tensor (which contains the source data to copy from) and the SMEM layout (which describes how the data will be laid out inside each CTA). Using this tensor and layout, CuTe determines the underlying PTX instruction to be executed when cute::copy(tma_load, ...) is invoked in the kernel. The PTX instruction is chosen depending on the rank of the GMEM tensor (note that rank here means the number of dimensions of the tensor, as opposed to matrix rank/nullity in linear algebra). In our example, the GMEM tensor has rank two, so the following PTX instruction will be executed:

    asm volatile (
      "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes"
      " [%0], [%1, {%3, %4}], [%2];"
      :
      : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
        "r"(crd0), "r"(crd1)
      : "memory");

Looking at this PTX instruction, we see many familiar concepts. For instance, gmem_int_desc is the address of the TMA descriptor, crd0 and crd1 are the coordinates of the tile to copy (taken from the partitioned ArithTuple), and mbarrier::complete_tx::bytes and smem_int_mbar refer to the memory barrier. Note also that tensor.2d refers to the fact that we are copying a rank-2 tensor, i.e., a 2D matrix.

It turns out that not only TMA load but all TMA operations are wrappers around certain cp.async.bulk instructions. The NVIDIA PTX documentation dedicates an entire section to discuss cp.async.bulk instructions, specifically their syntaxes and operands. We encourage readers to read that section and the references therein for a more thorough study of TMA operations, whose scope is much larger than what this blog post is intended to cover. Here, we will discuss two extensions of TMA that are exposed via these cp.async.bulk instructions.

TMA Store Reduce

Recall that TMA store copies data from the SMEM of multiple CTAs into the corresponding tiles in a GMEM tensor. We can interpret TMA store as an assignment operation illustrated by the following Python pseudocode:

for cta_idx in range(number_of_ctas):
  gmem_dst[cta_idx] = smem_src[cta_idx]

What if we want to do the following instead?

for cta_idx in range(number_of_ctas):
  gmem_dst[cta_idx] += smem_src[cta_idx]
  # or this:
  gmem_dst[cta_idx] = max(gmem_dst[cta_idx], smem_src[cta_idx])
  # or this:
  gmem_dst[cta_idx] = min(gmem_dst[cta_idx], smem_src[cta_idx])

All of these operations (namely reduce sum, reduce max, and reduce min) are fairly common in tensor programs. In particular, reduce sum is an essential subroutine in Split-K GEMM, while reduce max and reduce min are often used in attention. As simple as these operations look, implementing them in CUDA kernels is not very straightforward. We invite readers to briefly think through how many rounds of data movements between GMEM and SMEM must be carried out to achieve these goals before reading the next paragraph.

The vanilla implementation of a reduce operation that “accumulates” values from a CTA’s SMEM into a tile in a GMEM tensor consists of one GMEM read, one processing step, and one GMEM write. First, the original value from the GMEM is loaded into the CTA’s SMEM or registers, then the reduce operation happens, and finally the result is written back out. This round trip through GMEM is slow.

Making a slight modification to the constructor of the TMA store TiledCopy object allows us to condense this three-step procedure to just one PTX instruction, namely cp.reduce.async.bulk instead of cp.async.bulk. Specifically, we can make the following one-line change in the host code:

// original: create a TMA store object
auto tma_store = make_tma_copy(SM90_TMA_STORE{}, gmem_tensor, smem_layout);

// to create a TMA reduce sum object
auto tma_reduce_sum = make_tma_copy(SM90_TMA_REDUCE_ADD{}, gmem_tensor, smem_layout);

and then use tma_reduce_sum instead, which now calls cp.reduce.async.bulk instead of cp.async.bulk under the hood.
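The difference between a plain TMA store and the reduce variants can be summarized by their effect on a GMEM tile. The host model below is a sketch of that end result only. TileOp and apply_tile are hypothetical names. On the GPU, the read-modify-write happens inside the single cp.reduce.async.bulk operation, and the CTAs run concurrently rather than in the fixed order used here.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical model of how a TMA store vs. a TMA reduce updates a GMEM tile.
enum class TileOp { Store, Add, Max, Min };

template <typename T>
void apply_tile(std::vector<T>& gmem_tile, const std::vector<T>& smem_tile,
                TileOp op) {
  assert(gmem_tile.size() == smem_tile.size());
  for (std::size_t k = 0; k < gmem_tile.size(); ++k) {
    switch (op) {
      case TileOp::Store: gmem_tile[k] = smem_tile[k]; break;  // overwrite
      case TileOp::Add:   gmem_tile[k] += smem_tile[k]; break;  // accumulate
      case TileOp::Max:   gmem_tile[k] = std::max(gmem_tile[k], smem_tile[k]); break;
      case TileOp::Min:   gmem_tile[k] = std::min(gmem_tile[k], smem_tile[k]); break;
    }
  }
}
```

In a Split-K setting, each CTA applies its partial tile with TileOp::Add and the GMEM tile ends up with the sum. With TileOp::Store, only the last writer’s partial tile would survive.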

As an aside, the PTX instruction cp.reduce.async.bulk has been available since the release of CUDA 12.0, but was not exposed through CUTLASS and CuTe until the CUTLASS 3.5 release. We hope other reduction operations will be exposed in future releases, but if they are not, it’s fairly simple to adapt the CuTe code for TMA reduce add to perform the max and min reductions, as well as other bitwise reductions that cp.reduce.async.bulk offers: and, or, xor, inc, and dec.

TMA Load Multicast

In the previous section, we have seen that studying PTX instructions allows us to discover TMA reduce operations, which can be used instead of TMA store for certain applications. In this section, we will study the multicast extension of TMA load.

To aid our understanding, we first take a look at the full syntax of cp.async.bulk.tensor:

// global -> shared::cluster:
cp.async.bulk.tensor.dim.dst.src{.load_mode}.completion_mechanism
{.multicast}{.level::cache_hint}
  [dstMem],
  [tensorMap, tensorCoords],
  [mbar]
  {, im2colOffsets}
  {, ctaMask}
  {, cache-policy}

.dst =                  { .shared::cluster }
.src =                  { .global }
.dim =                  { .1d, .2d, .3d, .4d, .5d }
.completion_mechanism = { .mbarrier::complete_tx::bytes }
.load_mode =            { .tile, .im2col }
.level::cache_hint =    { .L2::cache_hint }
.multicast =            { .multicast::cluster  }

Again, without the need to completely understand the syntax of PTX instructions, we see many familiar concepts such as .dim, .global for src, and .mbarrier for completion_mechanism. This section focuses on the .multicast qualifier and its associated ctaMask operand.

Multicast refers to a situation where we have a tile in a GMEM tensor that we want to copy to multiple SMEM locations in multiple CTAs. This is typically the case in GEMM kernels (i.e., matrix multiplication), where an input matrix column tile is needed for multiple row tiles or vice versa. In such cases, TMA load is still perfectly functional (we simply provide the same TMA descriptor to the multiple CTAs that need it). However, the .multicast qualifier lets the tile be read from L2 once and delivered to the SMEM of all participating CTAs, rather than relying on each CTA’s separate load hitting in the L2 cache.

Let’s consider an extension of the above TMA load example to one with multicast. To begin with, we need to define the cluster dimensions of our kernel to be non-trivial, since a requirement for a subset of CTAs to collectively participate in a TMA load multicast operation is that they belong to the same (threadblock) cluster. In order to keep things simple, we will just change the grid dimensions as follows:

// old grid dimensions and implicit trivial cluster dimensions
dim3 grid_dims = dim3(M / CTA_M, N / CTA_N, 1);
dim3 cluster_dims = dim3(1, 1, 1);

// new grid dimensions and cluster dimensions
dim3 grid_dims = dim3(M / CTA_M, N / CTA_N, 2);
dim3 cluster_dims = dim3(1, 1, 2);

Note that when using clusters, the cluster dimensions must evenly divide into the grid dimensions, or the kernel will not launch. In our new kernel, we will then arrange for the same tile of GMEM to be loaded into each CTA’s SMEM for every pair of CTAs in the same cluster, which occurs if and only if the two CTAs have the same blockIdx.x and blockIdx.y.
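The launch requirement and the rank of a CTA inside its cluster can be checked on the host. In the sketch below, Dim3 stands in for CUDA’s dim3, and both helpers are hypothetical. The rank assumes the usual linearization of %cluster_ctarank, with the x coordinate within the cluster varying fastest. With cluster dimensions {1, 1, 2}, the two CTAs of a cluster share blockIdx.x and blockIdx.y and have ranks 0 and 1.

```cpp
#include <cassert>

// Host stand-in for CUDA's dim3.
struct Dim3 {
  unsigned x, y, z;
};

// Every cluster dimension must divide the corresponding grid dimension.
bool cluster_divides_grid(Dim3 grid, Dim3 cluster) {
  return grid.x % cluster.x == 0 && grid.y % cluster.y == 0 &&
         grid.z % cluster.z == 0;
}

// Rank of a CTA within its cluster, linearized with x varying fastest.
unsigned cluster_rank(Dim3 block_idx, Dim3 cluster) {
  unsigned cx = block_idx.x % cluster.x;
  unsigned cy = block_idx.y % cluster.y;
  unsigned cz = block_idx.z % cluster.z;
  return cx + cy * cluster.x + cz * cluster.x * cluster.y;
}
```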

First, in the host code we make the following change to the definition of the TMA load TiledCopy object:

// original: create a TMA load object
auto tma_load = make_tma_copy(SM90_TMA_LOAD{}, gmem_tensor, smem_layout);

// new: create a TMA load multicast object for the given cluster size
auto tma_load = make_tma_copy(SM90_TMA_LOAD_MULTICAST{},
      gmem_tensor, smem_layout, cute::_2{});

We write _2{} for the last parameter (the cluster size) to pass it as a compile-time constant, using the CuTe integer types provided for this purpose. In practice and more idiomatically, we would have defined the ClusterShape type beforehand (in our case, to be Shape<_1,_1,_2>) and then written size<2>(ClusterShape{}) for that parameter.

We then change the kernel code as follows:

template <typename T, int CTA_M, int CTA_N, class ClusterShape,
          class TmaLoad, class GmemTensor>
__global__ void tma_load_kernel(__grid_constant__ const TmaLoad tma_load,
                                GmemTensor gmem_tensor) {
  using namespace cute;
  uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
  constexpr uint32_t cluster_size = size<2>(ClusterShape{});
  constexpr uint16_t tma_mcast_mask = (uint16_t(1) << cluster_size) - 1;
  constexpr int tma_transaction_bytes = CTA_M * CTA_N * sizeof(T);

  alignas(128) __shared__ T smem_data[CTA_M * CTA_N];
  __shared__ uint64_t tma_load_mbar;

  auto smem_layout = make_layout(make_shape(Int<CTA_M>{}, Int<CTA_N>{}), LayoutRight{});
  auto smem_tensor = make_tensor(make_smem_ptr(smem_data), smem_layout);
  auto gmem_tensor_coord = tma_load.get_tma_tensor(shape(gmem_tensor));
  auto gmem_tensor_coord_cta = local_tile(
        gmem_tensor_coord,
        Tile<Int<CTA_M>, Int<CTA_N>>{},
        make_coord(blockIdx.x, blockIdx.y));

  if (threadIdx.x == 0) {
    initialize_barrier(tma_load_mbar, /* arrival count */ 1);
  }
  __syncthreads();
  cutlass::arch::fence_barrier_init();  // release the init cluster-wide...
  cute::cluster_sync();                 // ...before any CTA in the cluster proceeds

  if (threadIdx.x == 0) {
    set_barrier_transaction_bytes(tma_load_mbar, tma_transaction_bytes);
    auto tma_load_per_cta = tma_load.get_slice(block_rank_in_cluster);
    copy(tma_load.with(tma_load_mbar, tma_mcast_mask),
         tma_load_per_cta.partition_S(gmem_tensor_coord_cta),
         tma_load_per_cta.partition_D(smem_tensor));
  }
  __syncthreads();
  wait_barrier(tma_load_mbar, /* phase */ 0);

  // after this line, the TMA load is finished

  cute::cluster_sync();
}

The relevant changes relative to the first TMA load kernel are as follows. First, we now need to track the internal index of the CTA within its cluster, which we fetch via the CuTe method block_rank_in_cluster(). This returns the value of the special register %cluster_ctarank, which will take on values 0 and 1 in our example. For brevity, let us refer to this as the ctaid. We then have the following three modifications to the code to unpack:

  1. Additional cluster synchronization primitives.
  2. Use of the uint16 bitmask in the multicast operation.
  3. Use of the ctaid in determining the slice of the TiledCopy object used to partition the GMEM and SMEM tensors.
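Points (2) and (3) can be illustrated with a host-side model. In this sketch, multicast_mask and multicast_tile are hypothetical helpers, not CuTe APIs. The model makes three assumptions, consistent with the 16 x 16 example output shown below. First, each slice index covers an equal, contiguous block of rows of the tile. Second, every CTA in the mask receives every slice. Third, SMEM starts out zeroed.

```cpp
#include <cassert>
#include <cstdint>
#include <initializer_list>
#include <vector>

// Multicast mask: bit k set means the CTA with cluster rank k participates.
std::uint16_t multicast_mask(std::initializer_list<int> ranks) {
  std::uint16_t mask = 0;
  for (int r : ranks) {
    assert(0 <= r && r < 16);  // at most 16 CTAs per cluster
    mask = static_cast<std::uint16_t>(mask | (1u << r));
  }
  return mask;
}

// Hypothetical model of a multicast load of a [rows, cols] tile by a cluster.
// slice_of[k] is the slice index that CTA k passes to get_slice; slice s covers
// rows [s * rows / cluster_size, (s + 1) * rows / cluster_size) and is delivered
// to all CTAs, so one vector represents every CTA's (identical) SMEM.
std::vector<int> multicast_tile(const std::vector<int>& gmem_tile, int rows,
                                int cols, const std::vector<int>& slice_of) {
  int cluster_size = static_cast<int>(slice_of.size());
  int rows_per_slice = rows / cluster_size;
  std::vector<int> smem(gmem_tile.size(), 0);
  for (int s : slice_of) {
    for (int r = s * rows_per_slice; r < (s + 1) * rows_per_slice; ++r)
      for (int c = 0; c < cols; ++c)
        smem[r * cols + c] = gmem_tile[r * cols + c];
  }
  return smem;
}
```

With slices {0, 1}, the whole tile arrives in both CTAs. With slices {0, 0}, rows 8 to 15 are never loaded.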

For (1), we use the CuTe method cluster_sync(), which does both a cluster barrier arrive and wait operation in sequence. We insert this in two places. In lines 26-27, we use cluster_sync() together with a fence to ensure cluster-wide visibility of the mbarrier initialization. On line 41, we use another cluster_sync() to ensure that one of the two CTAs in the cluster doesn’t exit prematurely while the other is still waiting for the multicast load to complete. In general, there would be compute done on the data loaded into SMEM, and the last cluster_sync() would appear at the very end of the kernel code.

For (2), we pass a uint16 bitmask to the copy operation to specify which CTAs participate in the TMA multicast load. Bit k of the mask corresponds to the CTA with ctaid k, and a bit set to 1 marks that CTA as a destination; 16 bits suffice because a cluster contains at most 16 CTAs (the maximum non-portable cluster size). Thus, in our example, setting tma_mcast_mask to 0b11 specifies that both CTAs in the cluster participate.
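As a small illustration of the mask encoding, the following C++ sketch builds the uint16_t mask from a list of ctaids. The helper names make_mcast_mask and full_cluster_mask are hypothetical and not part of CuTe; they simply set bit k for each participating ctaid k.

```cpp
#include <cassert>
#include <cstdint>
#include <initializer_list>

// Hypothetical helper: set bit k for every participating ctaid k
// (the value of %cluster_ctarank). Valid ctaids are 0..15.
constexpr std::uint16_t make_mcast_mask(std::initializer_list<int> ctaids) {
  std::uint16_t mask = 0;
  for (int id : ctaids) {
    mask = static_cast<std::uint16_t>(mask | (1u << id));
  }
  return mask;
}

// Hypothetical helper: mask selecting every CTA in a cluster of
// `cluster_size` CTAs (1..16).
constexpr std::uint16_t full_cluster_mask(int cluster_size) {
  return static_cast<std::uint16_t>((1u << cluster_size) - 1u);
}

static_assert(make_mcast_mask({0, 1}) == 0b11, "both CTAs of a 2-CTA cluster");
static_assert(full_cluster_mask(2) == 0b11, "same mask, built differently");
static_assert(full_cluster_mask(16) == 0xFFFF, "largest non-portable cluster");
```

In the kernel above, a value built this way would be what is passed as tma_mcast_mask to tma_load.with(...).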

Finally, for (3), the ctaid determines the offset used when slicing into GMEM for the TMA multicast load issued by the given CTA. To illustrate, consider loading a 16 x 16 tile of integers, initialized to 0-255 in ascending row-major order, from GMEM into the SMEM of two CTAs in a cluster. Suppose we mistakenly passed 0 as the argument to tma_load.get_slice in both CTAs. Then, after the load completes, both CTAs' SMEM contain the following:

    0    1    2    3    4    5    6    7    8    9   10   11   12   13   14   15
   16   17   18   19   20   21   22   23   24   25   26   27   28   29   30   31
   32   33   34   35   36   37   38   39   40   41   42   43   44   45   46   47
   48   49   50   51   52   53   54   55   56   57   58   59   60   61   62   63
   64   65   66   67   68   69   70   71   72   73   74   75   76   77   78   79
   80   81   82   83   84   85   86   87   88   89   90   91   92   93   94   95
   96   97   98   99  100  101  102  103  104  105  106  107  108  109  110  111
  112  113  114  115  116  117  118  119  120  121  122  123  124  125  126  127
    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0

In contrast, if both CTAs pass 1, then both CTAs' SMEM contain the following:

    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0
  128  129  130  131  132  133  134  135  136  137  138  139  140  141  142  143
  144  145  146  147  148  149  150  151  152  153  154  155  156  157  158  159
  160  161  162  163  164  165  166  167  168  169  170  171  172  173  174  175
  176  177  178  179  180  181  182  183  184  185  186  187  188  189  190  191
  192  193  194  195  196  197  198  199  200  201  202  203  204  205  206  207
  208  209  210  211  212  213  214  215  216  217  218  219  220  221  222  223
  224  225  226  227  228  229  230  231  232  233  234  235  236  237  238  239
  240  241  242  243  244  245  246  247  248  249  250  251  252  253  254  255

Finally, passing 0 from one CTA and 1 from the other (in either assignment) correctly loads the entire tile into both CTAs' SMEM. These printouts illustrate that the multicast operation issued from one CTA in the cluster loads half of the GMEM tile into each of the two CTAs' SMEM, with the slice of the TiledCopy determining which half. This is consistent with the description of multicast for cp.async.bulk.tensor in the PTX documentation:

The source data is multicast to the same CTA-relative offset as dstMem in the shared memory of each destination CTA.
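The printouts above can be reproduced with a host-side C++ model of the multicast semantics quoted from the PTX documentation. The model is an assumption for illustration only: the hypothetical function multicast_half stands in for one CTA's multicast copy, which reads the 8-row half of the GMEM tile selected by its get_slice argument and writes it to the same offset in every destination CTA's SMEM.

```cpp
#include <array>
#include <cassert>

constexpr int kRows = 16, kCols = 16, kCtas = 2;
constexpr int kRowsPerSlice = kRows / kCtas;  // 8 rows per TiledCopy slice
using Tile = std::array<std::array<int, kCols>, kRows>;

// Hypothetical model of one CTA's multicast copy: slice `s` reads rows
// [8*s, 8*s + 8) of the GMEM tile and writes them to the SAME rows of
// every destination CTA's SMEM (mask 0b11, so both CTAs receive it).
void multicast_half(const Tile& gmem, int s, std::array<Tile, kCtas>& smem) {
  for (Tile& dst : smem)
    for (int r = s * kRowsPerSlice; r < (s + 1) * kRowsPerSlice; ++r)
      dst[r] = gmem[r];
}

// slice_of[ctaid] is the argument that CTA `ctaid` passes to get_slice.
std::array<Tile, kCtas> run(const std::array<int, kCtas>& slice_of) {
  Tile gmem{};
  for (int r = 0; r < kRows; ++r)
    for (int c = 0; c < kCols; ++c) gmem[r][c] = r * kCols + c;  // 0..255, row-major
  std::array<Tile, kCtas> smem{};  // zero-initialized, as in the printouts
  for (int ctaid = 0; ctaid < kCtas; ++ctaid) multicast_half(gmem, slice_of[ctaid], smem);
  return smem;
}
```

For example, run({0, 0}) leaves rows 8-15 of both SMEM tiles at zero, matching the first printout, while run({0, 1}) and run({1, 0}) both fill the entire tile in both CTAs.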

In terms of the TiledCopy object, whose layout TiledLayout_TV maps (thread, value) tuples to logical coordinates of the tile, CuTe treats the ctaid as the thread index for the purposes of slicing. For example, printing the TiledCopy in our 16 x 16 example yields the following:

TiledCopy
  Tiler_MN:       (_16,_16)
  TiledLayout_TV: (_2,((_16,_16))):(_8,((_16,_1)))
Copy_Atom
  ThrID:        _1:_0
  ValLayoutSrc: (_1,_256):(_0,_1)
  ValLayoutDst: (_1,_256):(_0,_1)
  ValLayoutRef: (_1,_256):(_0,_1)
  ValueType:    32b

which has two “threads” corresponding to the two CTAs in the cluster. The thread mode of TiledLayout_TV has stride 8, so for ctaid 1 the offset is given by the logical coordinate (8,0) in the (16,16) tile.
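To see where the coordinate (8,0) comes from, the sketch below evaluates only the thread mode of the printed TiledLayout_TV (shape _2, stride _8) and converts the resulting linear offset into an (m, n) coordinate of the (16,16) tile. It assumes CuTe's colexicographic (first-mode-fastest) index-to-coordinate convention; the function names are hypothetical.

```cpp
#include <cassert>
#include <utility>

// Thread ("CTA") mode of TiledLayout_TV = (_2,((_16,_16))):(_8,((_16,_1))):
// shape 2, stride 8. Evaluating it gives each CTA's starting offset.
constexpr int kThrStride = 8;
constexpr int kTileM = 16;  // first extent of Tiler_MN = (_16,_16)

constexpr int thread_offset(int ctaid) { return ctaid * kThrStride; }

// Colexicographic index -> (m, n): m varies fastest, i.e. column-major
// for a (16,16) tile.
constexpr std::pair<int, int> idx_to_mn(int idx) {
  return {idx % kTileM, idx / kTileM};
}
```

Since m indexes rows in the printouts above, the slice for ctaid 1 begins at row 8, which is the lower half that appears when both CTAs pass 1 to get_slice.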

Conclusion

In this blog post, we walked through a few simplified examples of using TMA load, store, store reduce, and load multicast to perform memory copy between GMEM and SMEM in a CUDA kernel, using the methods provided by the CUTLASS library.

We started by providing an overview of TMA and showed how a user can invoke these operations in a GPU kernel. Then, we dove deeper into the underlying PTX instructions to build a fuller understanding of TMA. We hope this blog post is helpful for readers who want to understand TMA, refresh their knowledge of the topic, or debug existing projects that use TMA.

We left out a few important topics, such as the swizzling modes supported by TMA and TMA's ability to copy GMEM to SMEM in an interleaved format, permuting strides outside the contiguous dimension. These matter when using TMA together with the Warpgroup Matrix-Multiply-Accumulate (WGMMA) instructions, also new in the Hopper architecture, to load tensor data in a memory format compatible with WGMMA. We will explain these points when we discuss Hopper-based GEMM in a future post.

Lastly, fully worked-out examples of the kernels discussed in this blog post can be found in our Colfax Research GitHub repository.
