Blackwell NVFP4-JAX-Training
NVIDIA hat mit dem NVFP4-Trainingsschema ein neues Verfahren für die präzisionsreduzierte Vortraining von Large Language Models vorgestellt. Entwickelt in Zusammenarbeit mit dem JAX-Ökosystem und der MaxText-Bibliothek, zielt die Methode darauf ab, den Durchsatz beim Pretraining auf NVIDIA-Blackwell-Architekturen signifikant zu steigern. Da jeder Prozentpunkt geringerer Schrittzeit bei Milliardenparametermodellen Tage an Rechenzeit und erhebliche Kosten einspart, positioniert sich NVFP4 als entscheidender Hebel für die Skalierung von KI-Fabriken. Das NVFP4-Format nutzt eine Unterbyte-Präzision von vier Bit, um die GEMM-Berechnungen innerhalb der Feed-Forward-Netzwerke der Transformer zu beschleunigen. Zur Wahrung der Konvergenz bleibt die Attention-Projektion sowie die Ausgabeprojektion in höherer Präzision, da hier Quantisierungsrauschen durch die Softmax-Funktion exponentiell verstärkt wird. Stattdessen kommen drei präzise abgestimmte Techniken zum Einsatz: eine 2D-Block-Quantisierung der Gewichte, eine Random-Hadamard-Transformation zur Glättung von Ausreißern vor der Gradientenakkumulation sowie stochastisches Runden zur Vermeidung von Verzerrungen bei kleinen Updates. Native Hardwareunterstützung im GB300-Grace-Blackwell-Ultra-Superchip ermöglicht dabei eine theoretische GEMM-Durchsatzsteigerung um den Faktor sieben gegenüber FP8. Benchmark-Messungen auf GB200- und GB300-Systemen mit den Modellen Llama 3 8B und Llama 3.1 405B bestätigen die theoretischen Vorteile. Im Vergleich zu einer FP8-Baseline mit identischer Hyperparameter- und Parallelisierungskonfiguration liegen die gemessenen TFLOP/s pro GPU bei NVFP4 zwischen 500 und 700 Punkten höher. Die daraus resultierende Beschleunigung variiert je nach Modellgröße und Hardware zwischen 1,31- und 1,73-fach. Besonders das 405-Milliarden-Parameter-Modell profitiert auf dem GB300 mit einer Steigerung um den Faktor 1,73, da der reine Präzisionsgewinn hier direkt in kürzere Wall-Clock-Zeiten übersetzt wird, ohne dass Collective-Overhead von FSDP den Effekt zunichtemacht. Die Genauigkeit bleibt trotz der reduzierten Bitbreite vollständig erhalten. Trainingskurven zeigen über 10.000 Schritte hinweg nahezu identische Loss-Verläufe von etwa 12,2 auf 3,9 Nats. Die konvergierten Regime weisen nur eine minimale Abweichung von 0,026 Nats auf, was innerhalb der natürlichen Trainingsvarianz liegt. Für die praktische Nutzung steht das NVFP4-Recipe im JAX-Toolbox-Repository von MaxText bereit. Die Aktivierung erfolgt ausschließlich über ein Quantisierungsflag im Launch-Skript, empfohlen wird der offizielle NVIDIA-Container mit vorinstallierten JAX-, Transformer-Engine- und CUDA-Abhängigkeiten. Durch die Kombination aus hardwarenaher Optimierung und bewährten Konvergenzstrategien ermöglicht NVFP4 es Forschungsteams und KI-Fabriken, größere Modelle in kürzeren Zeitfenstern zu trainieren oder innerhalb bestehender Budgets höhere Modelleffizienz zu erzielen.
