Mesh-TensorFlow: スーパーコンピュータ用の深層学習

バッチ分割(データ並列処理)は、その普遍的な適用可能性とシングルプログラム・マルチデータ(SPMD)プログラミングへの適合性により、分散ディープニューラルネットワーク(DNN)の学習戦略として主流となっています。しかし、バッチ分割にはメモリ制約による非常に大きなモデルの学習不能、高いレイテンシー、および小さなバッチサイズでの効率の低下などの問題があります。これらの問題は、より一般的な分散戦略(モデル並列処理)によって解決できます。残念ながら、効率的なモデル並列アルゴリズムは発見しやすくなく、説明や実装も特に大規模クラスタでは複雑になります。そこで、我々はMesh-TensorFlowを導入します。これは、一般的なクラスの分散テンソル計算を指定するための言語です。データ並列処理が「バッチ」次元に沿ってテンソルと操作を分割することと見なされる一方で、Mesh-TensorFlowではユーザーが任意のテンソル次元を任意の多次元プロセッサメッシュの次元に分割することができます。Mesh-TensorFlowグラフは、Allreduceなどの集団通信プリミティブと組み合わされた並列操作からなるSPMDプログラムにコンパイルされます。我々はMesh-TensorFlowを使用して、効率的なデータ並列およびモデル並列版のTransformer系列対系列モデルを実装しました。最大512コアまでのTPUメッシュを使用して最大50億パラメータを持つTransformerモデルを学習させることで、WMT'14英仏翻訳タスクと10億単語言語モデリングベンチマークにおいて最先端の結果を超えることができました。Mesh-TensorFlowは https://github.com/tensorflow/mesh で利用可能です。