Description
Scaled dot product attention (SDPA) significantly accelerates the transformer layer by fusing the multi-head attention layer. There are several algorithms to achieve this, but tiled attention via Flash Attention is a very popular approach. In PyTorch on the ROCm platform, this is currently achieved through ahead-of-time (AOT) compiled Triton kernels shipped in a linkable archive. AMD’s work to enable and package these kernels is done through AOTriton, which aims to use Triton’s compiler and GPU kernels for faster development. AOTriton maintains a tuned set of tiling sizes and other parameters to provide optimized, pre-compiled Triton kernels. The differences between just-in-time (JIT) and AOT compilation are few but important. Even so, prototyping kernels in Triton is much faster than writing them with template-based C++ libraries. In this presentation we will go into detail on the interaction layer between PyTorch and AOTriton, the structure of AOTriton, and how to add new Triton kernels to AOTriton.
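
For readers unfamiliar with tiled attention, the following is a minimal pure-Python sketch of the idea behind Flash Attention-style kernels: keys and values are processed one block at a time while a running maximum and running sum keep the softmax exact. It is an illustration only, not AOTriton's or PyTorch's implementation; the function names (`naive_attention`, `tiled_attention`) and the `block` parameter are hypothetical, and a real kernel would tile the queries as well and run on the GPU.

```python
import math


def naive_attention(q, k, v):
    """Reference: softmax(q k^T / sqrt(d)) v for a single head, on plain lists."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / total
                    for c in range(len(v[0]))])
    return out


def tiled_attention(q, k, v, block=2):
    """Same result, but keys/values are visited in blocks of size `block`
    (hypothetical parameter) using an online softmax."""
    scale = 1.0 / math.sqrt(len(q[0]))
    dv = len(v[0])
    out = []
    for qi in q:
        running_max = -math.inf   # largest score seen so far
        running_sum = 0.0         # softmax denominator, relative to running_max
        acc = [0.0] * dv          # unnormalized output, relative to running_max
        for start in range(0, len(k), block):
            k_blk = k[start:start + block]
            v_blk = v[start:start + block]
            scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k_blk]
            new_max = max(running_max, max(scores))
            # Rescale earlier partial results to the new maximum.
            correction = math.exp(running_max - new_max)
            weights = [math.exp(s - new_max) for s in scores]
            running_sum = running_sum * correction + sum(weights)
            acc = [a * correction + sum(w * vj[c] for w, vj in zip(weights, v_blk))
                   for c, a in enumerate(acc)]
            running_max = new_max
        out.append([a / running_sum for a in acc])
    return out
```

The block size plays the role of the tiling parameters mentioned above: it changes how the work is partitioned, not the result, which is why such sizes can be tuned ahead of time per GPU.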