JAX/MaxText NVFP4 Blackwell
NVIDIA has released the NVFP4 training recipe, a high-throughput four-bit mixed-precision pretraining framework for JAX and the MaxText library, optimized for NVIDIA Blackwell accelerators. The release targets the computational bottlenecks of training large language models, delivering up to a 1.73x throughput increase over FP8 baselines without compromising model accuracy.
NVFP4 is a sub-byte precision format backed by two-level microscaling, which minimizes quantization error at ultra-low bit depths. The pretraining recipe applies NVFP4 quantization only to the multi-layer perceptron feed-forward layers. This targeted approach preserves higher precision within the attention mechanisms, preventing the softmax function from amplifying quantization noise. To maintain convergence stability, the recipe combines two-dimensional block quantization for weight consistency, a Random Hadamard Transform to flatten input outliers, and stochastic rounding to preserve gradient accuracy during optimizer updates. All compute operations consume NVFP4 inputs and produce outputs in higher-precision formats, so the recipe slots into existing training pipelines.
Developers can deploy the framework immediately through the MaxText JAX-Toolbox repository. By enabling the dedicated quantization flag, users can launch pretraining jobs on NVIDIA GB200 and GB300 Grace Blackwell Superchips.
Benchmarks of Llama 3 8B and Llama 3.1 405B show performance gains across all tested configurations. On the GB200 platform, per-GPU throughput increased by 1.35x for the 8B model and 1.44x for the 405B variant. On GB300, the speedups were 1.31x and 1.73x respectively, making the 405B result the largest gain measured.
These gains translate directly into shorter wall-clock training time, allowing AI research teams to accelerate model development cycles or scale to larger architectures within a fixed compute budget. The speedups do not come at the expense of convergence: training runs of ten thousand steps show that NVFP4 pretraining closely tracks the loss trajectory of standard FP8 implementations, with a mean difference of approximately 0.026 nats in the converged regime. The recipe gives developers a production-ready path to maximizing Blackwell hardware utilization for next-generation language model training.
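
To make the two-level microscaling and stochastic rounding mentioned above more concrete, here is a minimal pure-Python sketch. It is not the MaxText or NVIDIA implementation. The FP4 value grid (E2M1), the block size of 16, and the FP8 E4M3 bound of 448 on block scales are assumptions about the format that the article does not state, and the function names are hypothetical. Block scales are kept as ordinary floats rather than being rounded to 8 bits.

```python
import random

# Assumed magnitudes of a 4-bit E2M1 float (sign handled separately).
FP4_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
FP4_MAX = 6.0
FP8_E4M3_MAX = 448.0  # assumed upper bound for per-block scales
BLOCK = 16            # assumed block size, for illustration only


def round_to_grid(mag, rng=None):
    """Round a non-negative magnitude to the FP4 grid.

    rng=None gives round-to-nearest. With an rng, the upper neighbour is
    chosen with probability equal to the fractional distance from the lower
    neighbour, so the result is unbiased in expectation.
    """
    mag = min(mag, FP4_MAX)
    lo = max(g for g in FP4_GRID if g <= mag)
    hi = min(g for g in FP4_GRID if g >= mag)
    if hi == lo:
        return lo
    p_up = (mag - lo) / (hi - lo)
    if rng is None:
        return hi if p_up >= 0.5 else lo
    return hi if rng.random() < p_up else lo


def quantize(values, rng=None):
    """Return (codes, block_scales, tensor_scale) for a flat list of floats."""
    amax = max(abs(v) for v in values) or 1.0
    # Level 1: one scale per tensor, chosen so block scales stay <= 448.
    tensor_scale = amax / (FP4_MAX * FP8_E4M3_MAX)
    codes, block_scales = [], []
    for start in range(0, len(values), BLOCK):
        block = values[start:start + BLOCK]
        block_amax = max(abs(v) for v in block) or 1.0
        # Level 2: one scale per block, relative to the tensor scale.
        block_scale = block_amax / FP4_MAX / tensor_scale
        block_scales.append(block_scale)
        step = block_scale * tensor_scale
        for v in block:
            q = round_to_grid(abs(v) / step, rng)
            codes.append(-q if v < 0 else q)
    return codes, block_scales, tensor_scale


def dequantize(codes, block_scales, tensor_scale):
    """Reconstruct higher-precision values from codes and both scale levels."""
    return [q * block_scales[i // BLOCK] * tensor_scale
            for i, q in enumerate(codes)]
```

Calling `quantize(values)` uses round-to-nearest, while `quantize(values, rng=random.Random(0))` switches to stochastic rounding. Because each block carries its own scale, a large value in one block does not coarsen the resolution of every other block in the tensor.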

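The outlier-flattening role of the Random Hadamard Transform can be illustrated the same way. The sketch below is a generic randomized Walsh-Hadamard rotation in pure Python, assuming a power-of-two vector length. It does not reflect the transform size or the tensors the recipe actually applies it to, and the function names are hypothetical.

```python
import math
import random


def fwht(vec):
    """Orthonormal fast Walsh-Hadamard transform (length must be a power of two)."""
    out = list(vec)
    n = len(out)
    if n == 0 or n & (n - 1) != 0:
        raise ValueError("length must be a power of two")
    h = 1
    while h < n:
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = out[i], out[i + h]
                out[i], out[i + h] = a + b, a - b
        h *= 2
    scale = 1.0 / math.sqrt(n)
    return [v * scale for v in out]


def random_signs(n, seed=0):
    """Random +/-1 diagonal; the seed is an illustrative choice."""
    rng = random.Random(seed)
    return [rng.choice((-1.0, 1.0)) for _ in range(n)]


def rht(vec, signs):
    """Forward transform: flip signs, then rotate with the Hadamard matrix."""
    return fwht([s * v for s, v in zip(signs, vec)])


def rht_inverse(vec, signs):
    """Inverse: the orthonormal Hadamard matrix is its own inverse."""
    return [s * v for s, v in zip(signs, fwht(vec))]
```

For a 16-element vector of 0.1 values with a single 8.0 outlier, `rht` spreads the outlier's energy across all 16 outputs, so the largest magnitude drops well below 8.0 while the vector's norm is unchanged. A narrower value range per block is what makes the subsequent four-bit quantization less lossy.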