Many variants of attention (Vaswani et al., 2017) have become popular in recent years, for reasons related to performance and model quality. These include:
- Causal attention for autoregressive language modeling, where a token only attends to those prior;
- Sliding window attention for long-context language modeling, where a token only attends to those prior within a predefined window, reducing computational complexity of attention from \mathcal{O}(n^2 d) to \mathcal{O}(nwd) where w is the window size;
- ALiBi (Press et al., 2021), which uses a positional bias linear in distance to encode relative position without explicit embeddings, improving extrapolation to longer sequences;
- T5 bias and PrefixLM (Raffel et al., 2020), which introduce learned additive biases or prefix tokens that condition attention structure on task semantics rather than strict sequence order and allow partial bidirectional (non-causal) attention, respectively; and
- Attention sink (Xiao et al., 2023), which significantly boosts the quality of sliding window attention by adding a fixed set of cached KV tokens to which all tokens attend, preserving global context while maintaining linear computational complexity.
The PyTorch team at Meta recognized that most of these variants (including all of the above) can be unified under one elegant framework, dubbed FlexAttention (Guessous et al., 2024). This simple API allows users to define and work with a large collection of attention variants, including novel combinations of existing ones, with relatively little development overhead and decent performance.
FlexAttention adds two options for customization: a score_mod callable that modifies pre-softmax attention scores and a mask_mod callable that masks out pre-softmax attention scores. Altogether, FlexAttention takes the form
\text{FlexAttention}(Q, K, V) = \text{Softmax}\left({\color{orange}\text{mask\_mod}}\left({\color{red}\text{score\_mod}}\left(QK^T\right)\right)\right) V
Note that mask_mod is a special case of score_mod where scores are set to -inf; we keep the two separate for efficiency reasons, as will be explained when discussing block sparsity.
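To make that distinction concrete, the same causal constraint can be written either way; the following is a plain-Python sketch in the generic per-element signature introduced later in this post:
# Plain-Python sketch (not kernel code): causal attention expressed two ways.
def causal_score_mod(score, batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
    # Fold the mask into score_mod by writing -inf directly.
    return score if kv_idx <= q_idx else float("-inf")

def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
    # Return True to keep the score; this form lets the kernel skip fully-masked blocks.
    return kv_idx <= q_idx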
The original FlexAttention implementation is in Triton. While this implementation comes within 90% of FlashAttention 2 performance on Ampere GPUs, the performance on Hopper is significantly worse in comparison to FlashAttention 3.
In this blog post, we discuss our recent implementation of FlexAttention integrated into FlashAttention 3 CuTe DSL, done in collaboration with Driss Guessous (Meta) and Tri Dao (Princeton; Together AI), achieving 95% of the performance of FlashAttention 3 in the forward pass. This is a roughly 50% speedup over the Triton version in most cases. The implementation in FlashAttention 4 on Blackwell, though still in progress, exhibits similar—and in many cases, significantly greater—performance gains.
We focus on explaining the API so that developers can quickly integrate FlexAttention into their workflows.
Score Modification
The score_mod callable modifies pre-softmax attention scores based on position and optional auxiliary tensors. The generic signature is:
generic_score_mod(
    score: float,
    batch_idx: int,
    head_idx: int,
    q_idx: int,
    kv_idx: int,
    aux_tensors: Optional[list[tensor]],
) -> float
Examples
Example 1: T5 (Relative Positional) Bias
def rel_bias_score_mod(score, batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
    bias_tensor = aux_tensors[0]
    rel_pos = abs(q_idx - kv_idx)
    return score + bias_tensor[batch_idx, head_idx, rel_pos]
Example 2: ALiBi
import math

def alibi_score_mod(score, batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
    slope = math.exp2(-(head_idx + 1))
    dist = abs(q_idx - kv_idx)
    return score - slope * dist
CuTe DSL Implementation
In the CuTe DSL implementation, we require that score_mod be defined in terms of the TensorSSA abstraction (see the CUTLASS TensorSSA notebook). For example, T5 bias could take the following form:
@cute.jit
def rel_bias_score_mod_cute(
    tSrS_ssa: cute.TensorSSA,
    batch_idx: cute.TensorSSA,
    head_idx: cute.TensorSSA,
    q_idx: cute.TensorSSA,
    kv_idx: cute.TensorSSA,
    aux_tensors: Optional[list]
) -> cute.TensorSSA:
    bias_tensor = aux_tensors[0]
    rel_pos = cute.TensorSSA(
        mlir_math.absi(q_idx - kv_idx),
        q_idx.shape,
        q_idx.dtype
    )
    bias = bias_tensor[batch_idx[0], head_idx[0], rel_pos[0]].to(cutlass.Float32)
    return tSrS_ssa + bias
Application of score_mod is expensive, as it requires looping over all entries in the scores matrix; TensorSSA accordingly allows for easy vectorized and broadcasted instructions. In the score mod application in the attention mainloop, we compute modified scores in groups of vec_size, a tunable hyperparameter. We note that without further assumptions, vectorization of score_mod application is not feasible when using aux_tensors.
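As a mental model (not the kernel code), the grouped application can be pictured as follows in plain PyTorch; apply_score_mod_chunked, q0, and kv0 are illustrative names, and score_mod here must accept vectors of indices:
import torch

# Host-side reference (illustrative only): apply a score_mod to a (tile_m, tile_n)
# score tile in flat chunks of vec_size, mirroring the grouped application above.
def apply_score_mod_chunked(scores, batch_idx, head_idx, q0, kv0, score_mod, aux_tensors, vec_size=8):
    tile_m, tile_n = scores.shape
    flat = scores.reshape(-1).clone()
    q_idx = torch.arange(tile_m).repeat_interleave(tile_n) + q0   # per-element q position
    kv_idx = torch.arange(tile_n).repeat(tile_m) + kv0            # per-element kv position
    for start in range(0, flat.numel(), vec_size):
        end = min(start + vec_size, flat.numel())
        flat[start:end] = score_mod(
            flat[start:end], batch_idx, head_idx, q_idx[start:end], kv_idx[start:end], aux_tensors
        )
    return flat.reshape(tile_m, tile_n)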
Usage
Once a user has defined a score_mod function, they can easily pass it into the FlashAttention interface.
Direct CuTe DSL interface:
from flash_attn.cute.interface import _flash_attn_fwd

out, _ = _flash_attn_fwd(
    q, k, v,                   # torch.Tensor
    score_mod=rel_bias_score_mod_cute,
    aux_tensors=aux_tensors,   # Optional[list[torch.Tensor]]
)
Torch tensors are converted to cute.Tensors within the _flash_attn_fwd method. Many optional arguments have been omitted here for brevity.
PyTorch integrated interface:
The CuTe DSL implementation of FlexAttention is also integrated into PyTorch when built from source; it will be incorporated into the stable build in the near future. Instead of defining a TensorSSA-compatible score_mod function, one can define score_mod within PyTorch and rely on TorchInductor to properly generate the CuTe DSL code:
import torch
from torch.nn.attention.flex_attention import flex_attention

compiled_fn = torch.compile(flex_attention)
out = compiled_fn(
    q, k, v,
    score_mod=rel_bias_score_mod,
    kernel_options={"force_flash": True},  # Use CuTe DSL backend
)
Mask Modification
Defining mask_mod callables is almost identical to the score_mod case, with some simplifications. The mask application logic is contained in the FlashAttention forward kernel, so our mask_mod callable need only return a Boolean indicating whether a given score should be kept (True) or masked out, i.e. set to -inf (False):
generic_mask_mod(
    batch_idx: cute.TensorSSA,
    head_idx: cute.TensorSSA,
    q_idx: cute.TensorSSA,
    kv_idx: cute.TensorSSA,
    aux_tensors: Optional[list],
) -> cute.TensorSSA  # dtype == cutlass.Boolean
Note that unlike score_mod, we don’t pass in the score itself—we only need the positional information to determine whether a particular attention element should be masked.
Examples
Example 1: Causal Mask with Offset
To create a causal mask with the proper offset (seqlen_k - seqlen_q, or others as needed), do
import flash_attn.cute.utils as utils

def create_causal_mask_with_offset(offset: int):
    @cute.jit
    def _causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
        offset_ssa = utils.scalar_to_ssa(val=offset, dtype=cutlass.Int32)
        return kv_idx <= q_idx + offset_ssa
    return _causal_mask_mod
(Aside: this mask will require recompilation every time seqlen_k - seqlen_q changes; to avoid this, one could pass in offset as an additional aux_tensor.)
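A sketch of that aux-tensor variant, assuming (purely for illustration) that the first aux tensor holds one Int32 offset per batch:
@cute.jit
def causal_offset_mask_mod(batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
    # Assumed layout: aux_tensors[0] is an Int32 tensor of per-batch offsets.
    offsets = aux_tensors[0]
    offset_ssa = utils.scalar_to_ssa(offsets[batch_idx[0]], cutlass.Int32)
    return kv_idx <= q_idx + offset_ssa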
Example 2: Document Masking
When sequences from multiple documents have been concatenated, tokens should only attend within their document. To prevent information leakage across document boundaries, we do the following:
@cute.jit
def document_mask_mod(batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
    doc_ids = aux_tensors[0]
    doc_id_q = doc_ids[batch_idx[0], head_idx[0], q_idx[0]]
    doc_id_kv = doc_ids[batch_idx[0], head_idx[0], kv_idx[0]]
    q_doc = utils.scalar_to_ssa(doc_id_q, cutlass.Int32)
    kv_doc = utils.scalar_to_ssa(doc_id_kv, cutlass.Int32)
    return q_doc == kv_doc
Here, doc_ids is an Int32 tensor of shape (B, H, seqlen) indicating which document each token belongs to, under the assumption that documents are contiguous. For simplicity, we may also assume it is non-negative and non-decreasing, though this is not strictly required.
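For instance, a doc_ids tensor for three concatenated documents could be built with a small helper like the following (an illustrative sketch, not part of the library):
import torch

# Hypothetical helper: build a (1, H, seqlen) Int32 doc-id tensor from document lengths.
def make_doc_ids(doc_lens, num_heads=1):
    ids = torch.repeat_interleave(
        torch.arange(len(doc_lens), dtype=torch.int32),
        torch.tensor(doc_lens),
    )
    return ids.view(1, 1, -1).expand(1, num_heads, -1).contiguous()

doc_ids = make_doc_ids([230, 180, 230])  # three documents, seqlen = 640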
Usage
out, _ = _flash_attn_fwd(
    q, k, v,
    mask_mod=document_mask_mod,
    aux_tensors=[doc_ids],
)
Block Sparsity
The conceptual simplicity of FlexAttention belies the need for thoughtful optimizations. When large portions of the scores matrix are to be masked, we would like to intelligently avoid these regions where possible, skipping unnecessary data movement and computation. To do so, FlexAttention implements block sparsity with mask mods.
Take for example the case of causal masking. Consider the problem with batch size 1, one head, seqlen_q = 768, seqlen_kv = 896, and work tile size 128×128. There are 42 total blocks to handle:
- 6 blocks along the main diagonal (note that in causal masking, the diagonal is offset to meet the bottom-right corner, rather than the top-left) are split in half by the causal mask; these need mask_mod application
- 21 blocks below the diagonal have no masking at all; these do not need mask_mod application (though they do need score_mod), so we should skip applying mask_mod on these blocks
- The remaining 15 blocks are to be skipped entirely; it would be wasteful even to load them

Block Sparsity Tensors
Each work tile in the FlashAttention kernel corresponds to one (batch, head, q_block) coordinate. To compute only the tiles needed, we need to know the coordinates of each partially-masked tile and the coordinates of each fully-computed tile. We encapsulate these in two tensors:
- mask_block_idx: [B, H, num_q_blocks, num_kv_blocks] representing blocks that require application of mask_mod, and
- full_block_idx: [B, H, num_q_blocks, num_kv_blocks] representing fully-computed blocks.
Here, num_q_blocks = ceil_div(seqlen_q, tile_m) is the number of work tiles in the q dimension, and num_kv_blocks = ceil_div(seqlen_kv, tile_n) is the number of work tiles in the kv dimension.
To index properly into these tensors, we also keep track of two “count” tensors:
- mask_block_cnt: [B, H, num_q_blocks] representing the total number of partially-masked kv_blocks, and
- full_block_cnt: [B, H, num_q_blocks] representing the total number of fully-computed kv_blocks.
We assume that for any (b, h, q_block):
1. Setting mask_cnt = mask_block_cnt[b, h, q_block], the tensor mask_block_idx[b, h, q_block, :mask_cnt] is strictly increasing, and the remainder mask_block_idx[b, h, q_block, mask_cnt:] is identically 0.
2. Setting full_cnt = full_block_cnt[b, h, q_block], the tensor full_block_idx[b, h, q_block, :full_cnt] is strictly increasing, disjoint from mask_block_idx[b, h, q_block, :mask_cnt], and the remainder full_block_idx[b, h, q_block, full_cnt:] is identically 0.
The disjointness condition in 2 guarantees that no block is processed twice.
For cleanliness, these tensors are wrapped in a BlockSparseTensors class:
class BlockSparseTensors(NamedTuple):
    mask_block_cnt: cute.Tensor
    mask_block_idx: cute.Tensor
    full_block_cnt: Optional[cute.Tensor]
    full_block_idx: Optional[cute.Tensor]
Note that full_block_cnt and full_block_idx are optional; if they are omitted, mask_mod is applied to every block that is computed.
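As a mental model, the following host-side sketch (assuming plain PyTorch tensors laid out as above, and a hypothetical helper name) shows which kv blocks a single (b, h, q_block) work tile visits:
# Host-side sketch: which kv blocks a (b, h, q_block) work tile visits, and how.
def kv_blocks_for_tile(mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, b, h, q_block):
    mask_cnt = int(mask_block_cnt[b, h, q_block])
    full_cnt = int(full_block_cnt[b, h, q_block])
    masked = mask_block_idx[b, h, q_block, :mask_cnt].tolist()  # need mask_mod application
    full = full_block_idx[b, h, q_block, :full_cnt].tolist()    # computed without mask_mod
    return masked, full  # every other kv block is skipped entirely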
Example: Causal Masking Block Sparsity
For causal masking with the parameters above, the block sparsity tensors are:
mask_block_cnt = [[[1, 1, 1, 1, 1, 1]]]
mask_block_idx = [[[[1, 0, 0, 0, 0, 0, 0],
                    [2, 0, 0, 0, 0, 0, 0],
                    [3, 0, 0, 0, 0, 0, 0],
                    [4, 0, 0, 0, 0, 0, 0],
                    [5, 0, 0, 0, 0, 0, 0],
                    [6, 0, 0, 0, 0, 0, 0]]]]
full_block_cnt = [[[1, 2, 3, 4, 5, 6]]]
full_block_idx = [[[[0, 0, 0, 0, 0, 0, 0],
                    [0, 1, 0, 0, 0, 0, 0],
                    [0, 1, 2, 0, 0, 0, 0],
                    [0, 1, 2, 3, 0, 0, 0],
                    [0, 1, 2, 3, 4, 0, 0],
                    [0, 1, 2, 3, 4, 5, 0]]]]
Computing Block Sparsity
Computing BlockSparseTensors for a given mask_mod, sequence length, and tile size can be computationally expensive; this is unavoidable. It is, however, generally amortized across all layers of a model, so it is not too problematic in practice.
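For intuition, here is a deliberately naive reference computation in plain PyTorch; the function name and the Python-level mask_fn predicate are illustrative assumptions, not part of the library:
import torch

# Naive reference (not the library's implementation) that makes the semantics of the
# four tensors concrete. mask_fn(b, h, q_idx, kv_idx) -> bool, where True means "keep".
def reference_block_sparsity(mask_fn, B, H, seqlen_q, seqlen_kv, tile_m, tile_n):
    num_q_blocks = -(-seqlen_q // tile_m)    # ceil_div
    num_kv_blocks = -(-seqlen_kv // tile_n)  # ceil_div
    mask_cnt = torch.zeros(B, H, num_q_blocks, dtype=torch.int32)
    full_cnt = torch.zeros(B, H, num_q_blocks, dtype=torch.int32)
    mask_idx = torch.zeros(B, H, num_q_blocks, num_kv_blocks, dtype=torch.int32)
    full_idx = torch.zeros(B, H, num_q_blocks, num_kv_blocks, dtype=torch.int32)
    for b in range(B):
        for h in range(H):
            for qb in range(num_q_blocks):
                for kb in range(num_kv_blocks):
                    qs = range(qb * tile_m, min((qb + 1) * tile_m, seqlen_q))
                    ks = range(kb * tile_n, min((kb + 1) * tile_n, seqlen_kv))
                    kept = [mask_fn(b, h, qi, ki) for qi in qs for ki in ks]
                    if all(kept):      # fully-computed block: no mask_mod needed
                        full_idx[b, h, qb, full_cnt[b, h, qb]] = kb
                        full_cnt[b, h, qb] += 1
                    elif any(kept):    # partially-masked block: needs mask_mod
                        mask_idx[b, h, qb, mask_cnt[b, h, qb]] = kb
                        mask_cnt[b, h, qb] += 1
                    # else: block is entirely masked and simply skipped
    return mask_cnt, mask_idx, full_cnt, full_idx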
PyTorch possesses a similar, more robust class BlockMask that can be converted into BlockSparseTensors:
from torch.nn.attention.flex_attention import create_block_mask
block_mask_torch = create_block_mask(
    mask_mod_fn,  # PyTorch mask function
    B, H, seqlen_q, seqlen_kv,
    device="cuda",
    BLOCK_SIZE=(tile_m, tile_n),
)

# Convert to CuTe DSL format
_, _, mask_cnt, mask_idx, full_cnt, full_idx, *_ = block_mask_torch.as_tuple()
block_sparse_tensors = BlockSparseTensorsTorch(
    mask_block_cnt=mask_cnt,
    mask_block_idx=mask_idx,
    full_block_cnt=full_cnt,
    full_block_idx=full_idx,
)
Warning: The tile size used to compute block sparsity must be the same as the tile size used in the kernel.
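If the block sparsity comes from a PyTorch BlockMask, its BLOCK_SIZE attribute records the tile size it was built with, so a simple guard can catch mismatches:
# Illustrative guard: the BlockMask must have been built with the kernel's tile size.
assert block_mask_torch.BLOCK_SIZE == (tile_m, tile_n), "block sparsity tile size mismatch"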
Complete API Call
Altogether, FlexAttention can be used in FlashAttention CuTe DSL by calling
_flash_attn_fwd(
    q, k, v,                                          # torch.Tensor
    score_mod=score_mod,                              # Callable
    mask_mod=mask_mod,                                # Callable
    block_sparse_tensors_torch=block_sparse_tensors,  # BlockSparseTensorsTorch
    aux_tensors=aux_tensors,                          # Optional[list[torch.Tensor]]
)
Within _flash_attn_fwd, block_sparse_tensors_torch is converted into a BlockSparseTensors object via
sparse_tensors = flash_attn.cute.block_sparsity.to_cute_block_sparse_tensors(
    block_sparse_tensors_torch
)
Examples
Example 1: Document Masking with Relative Positional Bias
This example demonstrates a combination of a score_mod and a mask_mod that both use aux_tensors.
We assume we are given a doc_ids tensor as well as a rel_bias tensor, each of shape [B, H, max_seqlen], where max_seqlen = max(seqlen_kv, seqlen_q). For example, we may have B = 1, H = 1, max_seqlen = 640, and a doc_ids tensor where
# 3 documents at positions [0:230], [230:410], [410:640]
doc_ids = torch.zeros((1, 1, 640), dtype=torch.int32)
doc_ids[0, 0, :230] = 0
doc_ids[0, 0, 230:410] = 1
doc_ids[0, 0, 410:] = 2
The full implementation combining score_mod and mask_mod is
@cute.jit
def doc_rel_bias_score_mod(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors):
    rel_bias = aux_tensors[0]
    distance = cute.TensorSSA(
        mlir_math.absi(q_idx - kv_idx),
        q_idx.shape, q_idx.dtype
    )
    bias = rel_bias[b_idx[0], h_idx[0], distance[0]].to(cutlass.Float32)
    return tSrS_ssa + bias

@cute.jit
def document_mask_mod(b_idx, h_idx, q_idx, kv_idx, aux_tensors):
    doc_ids = aux_tensors[1]  # Second aux tensor
    q_doc = doc_ids[b_idx[0], h_idx[0], q_idx[0]]
    kv_doc = doc_ids[b_idx[0], h_idx[0], kv_idx[0]]
    q_doc_ssa = utils.scalar_to_ssa(q_doc, cutlass.Int32)
    kv_doc_ssa = utils.scalar_to_ssa(kv_doc, cutlass.Int32)
    return q_doc_ssa == kv_doc_ssa
rel_bias = torch.randn((1, 1, 640), dtype=torch.float32)
aux_tensors = [rel_bias, doc_ids]

# Compute block sparsity
block_sparse_tensors = compute_block_sparsity(...)

out, _ = _flash_attn_fwd(
    q, k, v,
    score_mod=doc_rel_bias_score_mod,
    mask_mod=document_mask_mod,
    block_sparse_tensors_torch=block_sparse_tensors,
    aux_tensors=aux_tensors,
)
The block sparsity tensors for this mask show the structure of the three document blocks clearly, with tokens only attending within their respective documents.
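For inspection only, the same block structure can be reproduced with PyTorch's create_block_mask and an equivalent plain-PyTorch mask function (the helper below is illustrative and not required for the CuTe DSL path):
from torch.nn.attention.flex_attention import create_block_mask

doc_ids_flat = doc_ids[0, 0].to("cuda")  # shape (640,)

def doc_mask_torch(b, h, q_idx, kv_idx):
    return doc_ids_flat[q_idx] == doc_ids_flat[kv_idx]

block_mask_viz = create_block_mask(
    doc_mask_torch, 1, 1, 640, 640, device="cuda", BLOCK_SIZE=(128, 128)
)
print(block_mask_viz)  # prints the block-level sparsity layout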

Example 2: PrefixLM with Per-Head Bias
PrefixLM (Raffel et al., 2020) combines causal and non-causal attention by having all tokens attend to a fixed-length prefix in addition to ordinary causal masking. This is useful for tasks where the input should be processed bidirectionally (like in an encoder) while the output remains autoregressive.
The mask_mod function is as follows:
def create_prefix_lm_mask(prefix: int, offset: int):
    @cute.jit
    def _prefix_lm_mask_mod(b_idx, h_idx, q_idx, kv_idx, aux_tensors):
        prefix_ssa = utils.scalar_to_ssa(prefix, cutlass.Int32)
        offset_ssa = utils.scalar_to_ssa(offset, cutlass.Int32)
        # Allow bidirectional attention in prefix OR causal after
        in_prefix = kv_idx < prefix_ssa
        causal = kv_idx <= q_idx + offset_ssa
        return in_prefix | causal
    return _prefix_lm_mask_mod
The score_mod function is a simple one, taking a per-head bias tensor head_bias:
@cute.jit
def head_bias_score_mod(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors):
    head_bias = aux_tensors[0]
    bias_val = head_bias[h_idx[0]].to(cutlass.Float32)
    return tSrS_ssa + bias_val
With batch size 1, 1 head, tile size 128×128, sequence length 768, and prefix 204, the block sparsity structure shows the characteristic pattern of PrefixLM: all tokens can attend bidirectionally to the first block (the prefix), while subsequent tokens follow causal masking.
head_biases = torch.randn(num_heads, dtype=torch.float32)
mask_mod = create_prefix_lm_mask(prefix=204, offset=0)

out, _ = _flash_attn_fwd(
    q, k, v,
    score_mod=head_bias_score_mod,
    mask_mod=mask_mod,
    block_sparse_tensors_torch=block_sparse_tensors,
    aux_tensors=[head_biases],
)

Quick Reference
| Feature | Type | Example |
| --- | --- | --- |
| ALiBi | score_mod | -slope * distance |
| Causal | mask_mod | kv_idx <= q_idx |
| Sliding window | mask_mod | abs(q_idx - kv_idx) <= w |
| T5 bias | score_mod | score + bias[rel_pos] |
| Document mask | mask_mod | doc[q] == doc[kv] |
| PrefixLM | mask_mod | (kv < prefix) \| (kv <= q) |
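The sliding-window entry above is the only pattern in the table not spelled out earlier; here is a sketch in the style of the previous mask_mod examples (assuming elementwise & combines TensorSSA booleans the way | does in the PrefixLM example):
def create_sliding_window_mask(w: int):
    @cute.jit
    def _sliding_window_mask_mod(batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
        w_ssa = utils.scalar_to_ssa(val=w, dtype=cutlass.Int32)
        # |q_idx - kv_idx| <= w, written without an abs on TensorSSA values
        return (q_idx - kv_idx <= w_ssa) & (kv_idx - q_idx <= w_ssa)
    return _sliding_window_mask_mod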
Getting Started
Here is a minimal working example to get started with FlexAttention:
# 1. Define mods
import cutlass
import cutlass.cute as cute
import flash_attn.cute.utils as utils

@cute.jit
def my_score_mod(score, b_idx, h_idx, q_idx, kv_idx, aux_tensors):
    scale = utils.scalar_to_ssa(1.1, cutlass.Float32)
    return score * scale  # Example: scale scores

@cute.jit
def my_mask_mod(b_idx, h_idx, q_idx, kv_idx, aux_tensors):
    return kv_idx <= q_idx  # Example: causal w/o offset

# 2. Compute block sparsity (create_block_mask expects a plain PyTorch mask function)
from torch.nn.attention.flex_attention import create_block_mask

def my_mask_mod_torch(b, h, q_idx, kv_idx):
    return kv_idx <= q_idx  # same causal predicate, in PyTorch terms

block_mask = create_block_mask(
    my_mask_mod_torch,
    B, H, seqlen_q, seqlen_kv,
    device="cuda",
    BLOCK_SIZE=(128, 128),
)

# 3. Run attention
from flash_attn.cute.interface import _flash_attn_fwd

out, lse = _flash_attn_fwd(
    q, k, v,
    score_mod=my_score_mod,
    mask_mod=my_mask_mod,
    block_sparse_tensors_torch=block_mask,
)
The key steps are: (1) define your attention modifications as callables, (2) compute block sparsity once for your mask pattern and tile size, and (3) call the forward function with your modifications. The block sparsity computation can be cached and reused across layers and iterations.
For more details on the native CuTe DSL API (without PyTorch), see the Appendix.
References
- Vaswani et al., "Attention Is All You Need", 2017. In Proceedings of the 31st International Conference on Neural Information Processing Systems (NIPS '17), Curran Associates Inc., Red Hook, NY, USA, 6000–6010.
- Press et al., “Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation”, 2021. https://arxiv.org/abs/2108.12409
- Raffel et al., "Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer", Journal of Machine Learning Research 21 (2020): 140:1–140:67.
- Xiao et al., “Efficient Streaming Language Models with Attention Sinks”, 2023. https://arxiv.org/abs/2309.17453
- Guessous et al., “FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention”, 2024. https://pytorch.org/blog/flexattention/
Appendix: CuTe DSL-native API
In this appendix, we present the API to use FlexAttention without referencing torch tensors. We provide a CuTe DSL-native block sparsity computation kernel in flash_attn.cute.compute_block_sparsity with interface compute_blocksparse_tensors, which has the signature
compute_blocksparse_tensors(
    tile_m: int,
    tile_n: int,
    batch_size: int,
    num_heads: int,
    seqlen_q: int,
    seqlen_k: int,
    mask_mod: Callable,
    aux_tensors: Optional[list[cute.Tensor]],
    device: str = "cuda",
    compute_full_blocks: bool = True,
    use_fast_sampling: bool = False,
) -> Tuple[BlockSparseTensors, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]
With this kernel in hand, we can present a complete example workflow on Hopper architecture using the native API:
from flash_attn.cute.compute_block_sparsity import compute_blocksparse_tensors
from flash_attn.cute.flash_fwd import FlashAttentionForwardSm90

tile_m, tile_n = 128, 128
batch_size, num_heads, seqlen_q, seqlen_k = 2, 8, 8192, 8192
mask_mod = user_defined_mask_mod
score_mod = user_defined_score_mod
aux_tensors = user_provided_aux_tensors
device = "cuda"

# Compute block sparsity
blocksparse_tensors, blocksparse_torch_tensors = compute_blocksparse_tensors(
    tile_m,
    tile_n,
    batch_size,
    num_heads,
    seqlen_q,
    seqlen_k,
    mask_mod,
    aux_tensors,
    device,
)
# Instantiate kernel
fa_fwd = FlashAttentionForwardSm90(
    dtype,
    head_dim,
    head_dim_v,
    qhead_per_kvhead,
    is_causal=False,
    is_local=False,
    pack_gqa=False,
    tile_m=tile_m,
    tile_n=tile_n,
    num_stages=2,
    num_threads=384,
    Q_in_regs=False,
    intra_wg_overlap=True,  # tunable hyperparameter for optimizations
    mma_pv_is_rs=True,      # tunable hyperparameter for optimizations
    mask_mod=mask_mod,
    score_mod=score_mod,
    has_aux_tensors=aux_tensors is not None,  # known at compile time
)
# Assume relevant tensors are easily accessible
q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor = get_tensors(...)
# Compile kernel; in a real use case, compiled kernels will be cached
fa_fwd_compiled = cute.compile(
    fa_fwd,
    q_tensor,
    k_tensor,
    v_tensor,
    o_tensor,
    lse_tensor,
    softmax_scale,
    current_stream,
    cu_seqlens_q_tensor,
    cu_seqlens_k_tensor,
    seqused_q_tensor,
    seqused_k_tensor,
    page_table_tensor,
    None,  # window size left
    None,  # window size right
    learnable_sink_tensor,
    blocksparse_tensors,
    aux_tensors,
)
# Run kernel with new arguments if needed
fa_fwd_compiled(
    q_tensor_new,
    k_tensor_new,
    v_tensor_new,
    o_tensor_new,
    lse_tensor_new,
    softmax_scale_new,
    current_stream_new,
    cu_seqlens_q_tensor_new,
    cu_seqlens_k_tensor_new,
    seqused_q_tensor_new,
    seqused_k_tensor_new,
    page_table_tensor_new,
    None,  # window size left
    None,  # window size right
    learnable_sink_tensor_new,
    blocksparse_tensors_new,
    aux_tensors_new,
)
We note that an example for Blackwell architecture would be entirely analogous, replacing FlashAttentionForwardSm90 with FlashAttentionForwardSm100 and modifying kernel parameters accordingly.
