Blackwell : JAX accélère l'IA
NVIDIA présente NVFP4, une méthode de pré-entraînement de grands modèles de langage utilisant une précision mixte 4 bits, optimisée pour JAX et le framework MaxText sur l'architecture Blackwell. Cette approche cible un défi majeur de l'IA moderne : réduire drastiquement le temps d'entraînement et les coûts de calcul sans compromettre la précision des modèles. Le format NVFP4 repose sur un encodage par micro-échelle à deux niveaux, permettant de préserver les signaux importants avec moins d'erreurs de quantification qu'autres formats basse précision. Pour garantir la stabilité de l'apprentissage, NVIDIA applique cette quantification uniquement aux couches d'alimentation avant, qui consomment la majeure partie des ressources. Les couches d'attention restent en précision supérieure, car leur mécanisme amplifierait les bruits liés à la quantification. L'intégration au sein de TransformerEngine combine également une quantification par blocs 2D, une transformation de Hadamard aléatoire pour lisser les valeurs aberrantes, et un arrondi stochastique pour maintenir les gradients non biaisés. Les performances ont été évaluées sur les superpuces GB200 et GB300 Grace Blackwell, en pré-entraînant les modèles Llama 3 8B et Llama 3.1 405B. Comparée au standard FP8 actuel, la recette NVFP4 génère un gain de débit compris entre 1,31 et 1,73 fois. Les modèles les plus volumineux tirent le meilleur parti de cette optimisation, avec un gain de 1,73 fois pour le Llama 3.1 405B sur GB300. Ces améliorations proviennent exclusivement du passage à une précision plus basse, les hyperparamètres et la configuration de parallélisme restant identiques. Sur le plan de la précision, les courbes de perte d'entraînement se superposent parfaitement. Après dix mille pas d'entraînement, l'écart moyen de précision entre NVFP4 et FP8 se limite à 0,026 nat, une différence totalement absorbée par le bruit naturel des étapes. Cette démonstration prouve que la réduction de la précision de calcul n'altère pas la convergence du modèle. L'adoption de NVFP4 s'effectue directement via MaxText en activant un simple paramètre de quantification. Un conteneur public fourni par NVIDIA intègre déjà l'environnement JAX, TransformerEngine et les bibliothèques CUDA nécessaires. Les outils de traçage permettent de monitorer le temps par étape et le débit par GPU en temps réel. Cette innovation permet aux centres de calcul de former des modèles plus vastes dans des délais réduits, ou d'optimiser leurs budgets informatiques existants. En sécurisant un gain de performance significatif tout en conservant la fiabilité des résultats, NVFP4 positionne la précision 4 bits comme une norme industrielle viable pour le pré-entraînement à l'échelle du pétatoken.
