PyTorch APIs for High Performance MoE Training and Inference
Daniel Vega-Myhre, Ke Wen & Natalia Gimelshein, Meta

With models like DeepSeekV3 and Llama4 rising in popularity, there has been increasing demand for PyTorch-native APIs and tailored performance optimizations for MoE architectures. This is a joint talk between the PyTorch Core, Distributed, and Performance teams, focusing on features we’ve developed to better support and accelerate both MoE training and inference. The talk is broadly divided into 3 categories:

1. Computation: grouped GEMMs in PyTorch Core
2. Communication: all-to-all-v dispatch/combine APIs in PyTorch Distributed
3. Low-precision training and inference optimizations: torchao API for MoE float8 training, low-precision comms kernels, and a differentiable scaled grouped GEMM with dynamic quantization
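To make the first category concrete: a grouped GEMM multiplies each contiguous group of routed tokens by its own expert's weight matrix, where group sizes vary per batch. The sketch below (an illustrative NumPy reference, not the PyTorch kernel API; the function name and signature are assumptions) shows the semantics that a fused grouped-GEMM kernel computes in a single launch:

```python
import numpy as np

def grouped_gemm(tokens, expert_weights, group_sizes):
    """Reference semantics of a grouped GEMM: the i-th contiguous group
    of `group_sizes[i]` tokens is multiplied by expert i's weight matrix.
    A fused kernel does all groups in one launch; this loop is only a
    readable specification of the result."""
    outputs = []
    start = 0
    for expert_id, size in enumerate(group_sizes):
        group = tokens[start:start + size]                  # (size, d_model)
        outputs.append(group @ expert_weights[expert_id])   # (size, d_ff)
        start += size
    return np.concatenate(outputs, axis=0)

# 3 experts with uneven token counts per expert, as produced by MoE routing.
rng = np.random.default_rng(0)
tokens = rng.standard_normal((6, 4))     # 6 tokens, d_model = 4
weights = rng.standard_normal((3, 4, 8)) # 3 experts, each 4 -> 8
out = grouped_gemm(tokens, weights, group_sizes=[3, 1, 2])
print(out.shape)  # (6, 8)
```

The key property is that group sizes are data-dependent and ragged, which is why a plain batched matmul (which assumes equal group sizes) does not fit and a dedicated grouped kernel is needed.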
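For the second category, the dispatch/combine pattern exchanges a variable number of tokens between ranks: dispatch routes each token to the rank hosting its expert, and combine is the inverse exchange that brings results back. A minimal single-process model of all-to-all-v (illustrative only; the function name is an assumption, not the PyTorch Distributed API) is:

```python
def all_to_all_v(send_buffers):
    """Single-process model of all-to-all-v: rank `src` sends the
    variable-length buffer send_buffers[src][dst] to rank `dst`.
    Returns recv_buffers, where recv_buffers[dst][src] holds what
    rank `dst` received from rank `src`."""
    world = len(send_buffers)
    return [[send_buffers[src][dst] for src in range(world)]
            for dst in range(world)]

# Dispatch: rank 0 keeps one token locally and routes two to rank 1;
# rank 1 routes one token to rank 0 and keeps none.
send = [
    [["t0"], ["t1", "t2"]],  # rank 0's outgoing buffers, per destination
    [["t3"], []],            # rank 1's outgoing buffers, per destination
]
recv = all_to_all_v(send)
print(recv[0])  # [['t0'], ['t3']]  -- rank 0's tokens, grouped by source

# Combine is the inverse exchange: a second all-to-all-v round-trips.
assert all_to_all_v(recv) == send
```

Because per-destination counts are data-dependent (they follow the router's expert assignments), the "-v" variant with variable split sizes is required rather than a fixed-size all-to-all.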