Flex Attention for Inference
Lightning Talk: Flex Attention for Inference - Boyuan Feng & Driss Guessous, Meta

FlexAttention is a novel compiler-driven programming model that allows implementing the majority of attention variants in a few lines of idiomatic PyTorch code. Since its release in PyTorch 2.5.0, many ML researchers have used it to customize their attention kernels without writing kernel code. In this talk, we present recent advances in FlexAttention for inference: FlexDecoding, a decoding backend optimized for inference; Grouped Query Attention (GQA) support; torch.compile-native Paged Attention; and feature updates including nested jagged tensor support, performance tuning guides, and trainable bias support. More details are available in our MLSys'25 paper (https://arxiv.org/pdf/2412.05496) and on the PyTorch blog (https://pytorch.org/blog/flexattention-for-inference/)!
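To illustrate the "few lines of idiomatic PyTorch" claim, below is a minimal sketch using the public flex_attention API available since PyTorch 2.5. The causal mask, tensor shapes, and head counts are illustrative choices, not details from the talk; the GQA layout (more query heads than KV heads) is enabled via the enable_gqa flag.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def causal(b, h, q_idx, kv_idx):
    # mask_mod: a query position may only attend to earlier or equal key positions
    return q_idx >= kv_idx

# Illustrative grouped-query layout: 8 query heads share 2 KV heads
B, Hq, Hkv, S, D = 2, 8, 2, 1024, 64
q = torch.randn(B, Hq, S, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, Hkv, S, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, Hkv, S, D, device="cuda", dtype=torch.float16)

# Precompute the block-sparse mask once; None broadcasts over batch and heads
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")

# torch.compile lowers the mask/score modifications into a fused attention kernel
compiled_flex = torch.compile(flex_attention)
out = compiled_flex(q, k, v, block_mask=block_mask, enable_gqa=True)
```

The same pattern extends to other variants discussed in the talk: swapping the mask_mod or adding a score_mod changes the attention semantics without writing any kernel code.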