In this PyTorch blog post, on which we collaborated, we explain the FlexAttention extension to FlashAttention-4 (or, from another point of view, the incorporation of FA-4 as an attention backend for the PyTorch FlexAttention API).

FlexAttention + FlashAttention-4: Fast and Flexible – PyTorch
On Hopper and Blackwell GPUs, FlexAttention now has a FlashAttention-4 backend.
We added support in PyTorch to automatically generate CuTeDSL score/mask modification functions, and to JIT-instantiate FlashAttention-4 for custom attention variants.
This leads to performance gains of 1.2× to 3.2× over the existing Triton implementation on compute-bound workloads.
