Meta の内部で使用される FX ツールの優れた入門書: グラフ変換を使用した PyTorch モデルの最適化

特色图像

PyTorch のグラフ モードは、パフォーマンスの点で優れています。この記事では、PyTorch プログラム グラフをキャプチャして最適化できる強力なツールである Torch.FX を紹介します。

1. はじめに

PyTorch は、eager モードとグラフ モードの 2 つの実行モードをサポートします。

熱心なモードでは、モデル内の演算子が読み取られるとすぐに実行されます。使いやすく、機械学習の実践者にとってより使いやすいため、デフォルトの実行モードとして設定されています。

グラフ モードでは、演算子をグラフに合成してから全体としてコンパイルして実行するため、パフォーマンスが高く、実際の運用でよく使用されます。

具体的には、グラフ モードは演算子の融合をサポートしており、2 つの演算子を結合することで、メモリ読み取りとカーネル起動の合計オーバーヘッドを削減または局所化できます。

Fusion は水平方向に行うことができます。複数のオペランドに適用される単一の操作 (BatchNorm など) を受け取り、オペランドを配列にマージします。

融合は垂直にすることもできます。あるカーネルを別のカーネルとマージするには、最初のカーネルの出力を使用する必要があります (ReLU の後に畳み込みを行うなど)。

Torch.FX (FX と省略) は、PyTorch パッケージの一部としてグラフ モードの実行をサポートする、公開されているツールキットです。できる:

1. PyTorchプログラムからグラフを取得

2. 開発者が取得したグラフに変換を記述できるようにする

Meta はこれまで、実稼働モデルのトレーニング スループットを最適化するために内部で FX を使用してきました。この記事では、Meta によって開発された FX ベースの最適化を紹介し、グラフ変換を使用して PyTorch デプロイメント モデルのパフォーマンスを最適化する方法を示します。

2. 背景

埋め込みテーブルはレコメンデーション システムで広く使用されています。このセクションでは、FX と埋め込みテーブルに関する背景知識を提供します。

2.1.FX

図 1 は、FX を使用して PyTorch プログラムを変換する方法を示す簡単な例です。これには 3 つのステップが含まれます。

  • プログラムからグラフを取得
  • グラフを変更します (この例では、RELU の代わりに GELU を使用します)
  • 変更したグラフから新しいプログラムを生成
図 1: PyTorch モジュールで RELU の代わりに GELU を使用する FX

FX API は、PyTorch プログラム グラフを検査および変換するための多くの追加機能を提供します。

2.2. 埋め込みテーブル

図 2: バッチ サイズ = 1 のスパース特徴埋め込みテーブルの概略図

推薦制度では、スパースな特徴 (ユーザー ID、ストーリー ID など) は、埋め込みテーブルによって表されます。

埋め込みテーブル E は HxD 行列です。ここで、H はハッシュ サイズ、D は埋め込みベクトルの次元です。 E の各行は浮動小数点数のベクトルです。

特徴ハッシュの機能は、スパース特徴を E のインデックス リスト ([S1, S2,…,Sk] など) にマッピングすることです (0≤Si)。

GPU を最大限に活用するために、通常、スパース フィーチャはバッチで処理されます。バッチ内の各エンティティには独自のインデックス リストがあります。バッチに B 個のエンティティがある場合、それは B 個のインデックス リストを持つ表現として単純に理解できます。

より厳密な表現は、B のインデックス リストを 1 つのインデックス リストにマージし、インデックスの長さのリストを追加することです (バッチ内の各エンティティには長さがあります)。

たとえば、バッチに 3 つのエンティティが含まれる場合、そのインデックス リストは次のようになります。

  • エンティティ 1: インデックス = [10, 20]
  • エンティティ 2: インデックス = [5, 9, 77, 81]
  • エンティティ 3: インデックス = [15, 20, 45]

この場合、完全なバッチ サイズのインデックスと長さは次のようになります。

  • インデックス = [10、20、5、9、77、81、15、20、45]
  • 長さ = [2, 4, 3]

バッチ埋め込みテーブル クエリ全体の出力は BxD 行列です。

3. 3種類のFX変換

PyTorch は、埋め込みテーブルへのアクセスを高速化するために 3 つの FX 変換を更新しました。このセクションでは、それらを 1 つずつ紹介します。

次の 3.1 は、複数の小さな入力テンソルを 1 つの大きなテンソルに結合する変換についてです。3.2 は、複数の並列計算チェーンを 1 つの計算チェーンに結合する変換についてです。3.3 は、重複する通信と計算の変換についてです。

3.1 入力スパース特徴の結合

バッチ内の各入力スパース フィーチャは、インデックス リストと B 長さリストの 2 つのリストとして表すことができます。ここで、B はバッチ サイズを表します。

PyTorch では、両方のリストがテンソルとして存在できます。PyTorch モデルが GPU で実行される場合、通常、埋め込みテーブルは GPU メモリ (GPU に近く、CPU メモリよりも高い読み取りおよび書き込み帯域幅を持っています) に保存されます。

入力スパース特徴を使用する必要がある場合、最初に両方のテンソルを CPU から GPU にコピーする必要があります。ただし、ホストからデバイスへのメモリ コピーごとにカーネルを起動する必要があり、実際のデータ転送よりも時間がかかります。

モデルが複数の入力スパース フィーチャを使用する場合、このコピーがパフォーマンスのボトルネックになる可能性があります (たとえば、1000 の入力スパース フィーチャでは、ホストからデバイスに 2000 のテンソルをコピーする必要があります)。

ホストからデバイスへの memcpy の数を減らす最適化方法は、複数の入力スパース特徴をデバイスに送信する前に結合することです。

たとえば、次の 3 つの入力特徴があるとします。

  • 機能_A: インデックス = [106, 211, 7]、長さ = [2, 1]
  • 機能_B: インデックス = [52, 498, 616, 870, 1013]、長さ = [3, 2]
  • 機能_C: インデックス = [2011, 19, 351, 790]、長さ = [1, 3]

組み合わせた形式は次のとおりです。

機能_A_B_C: インデックス = [106, 211, 7, 52, 498, 616, 870, 1013, 2011, 19, 351, 790]、長さ = [2, 1, 3, 2, 1, 3]

したがって、3×2=6 個のテンソルをホストからデバイスにコピーする必要はなく、2 個のテンソルだけをコピーする必要があります。

図 3(b) は、この最適化の実装を示しています。これは 2 つのコンポーネントで構成されます。

  • CPU側:入力パイプラインは、すべてのスパース特徴インデックスを 1 つのテンソルに結合し、すべての長さを別のテンソルに結合するように変更されます。これら 2 つのテンソルは GPU にコピーされます。
  • GPU側:FX を使用すると、Permute_and_Split オペレーターがモデル グラフに挿入され、マージされたテンソルから個々の特徴インデックスと長さテンソルが復元され、下流の対応するノードに送信されます。
最適化前: 両方のテンソルを CPU から GPU にコピーする必要があります
最適化後: 入力スパース特徴を結合する

3.2 埋め込みテーブルへのアクセスから始まるコンピューティングチェーンの水平統合

実稼働モデルでは、GPU ごとに 10 個の埋め込みテーブルがあるのが一般的です。パフォーマンス上の理由から、これらのテーブルに対するクエリはグループ化され、その出力が 1 つの大きなテンソルに連結されます。(図 4(a) の赤い部分を参照)。

単一の特徴出力に対して計算を実行するには、Split オペレーターを使用して、大きなテンソルを N 個の小さなテンソルに分割します。(N は特徴の数です) そして、必要な計算を各テンソルに適用します。

図 4(a) に示すように、各特徴出力 O に適用される計算は Tanh(LayerNorm(O)) です。すべての計算結果は大きなテンソルに連結され、下流の演算子 (図 4(a) の Op1) に渡されます。

ここでの主なランタイム コストは、GPU カーネルの起動コストです。たとえば、図 4(a) の GPU カーネルの起動数は 2*N+3 です (図中の各楕円は GPU カーネルを表します)。 GPU での LayerNorm と Tanh の実行時間はカーネルの起動時間に比べて非常に短いため、これはパフォーマンスに影響します。

さらに、Split オペレーターは、埋め込みベクトル出力テンソルの追加コピーを作成し、追加の GPU メモリを消費する可能性があります。

FX を使用して水平フュージョンと呼ばれる最適化を実装すると、GPU カーネルの起動回数を大幅に削減できます。(この例では、最適化された GPU カーネルの開始数は 5 です。図 4(b) を参照)。

明示的な Split の代わりに Add_middle_dim オペレーターを使用して、形状 (B、NxD) の 2D 埋め込みテンソルを形状 (B、N、D) の 3D テンソルに再形成します。次に、単一の LayerNorm を最後の次元に適用します。 Tanh を LayerNorm の結果に適用します。最後に、Remove_middle_dim オペレーターを使用して Tanh の結果を 2D テンソルに復元します。

Add_middle_dim と Remove_middle_dim はテンソルを再形成するだけなので、追加のコピーが作成されないため、GPU メモリの消費も削減できます。

最適化前: すべての出力が 1 つの大きなテンソルに連結されます
水平統合最適化後

3.3 計算と通信の重複

実稼働用の推奨モデルのトレーニングは、通常、分散 GPU システム上で完了します。各 GPU のデバイス メモリ容量は、モデル内のすべての埋め込みテーブルを収容するには十分ではないため、複数の GPU に分散する必要があります。

トレーニング ステップ中に、GPU は他の GPU 上の埋め込みテーブルから特徴値を読み書きする必要があります。これは全対全通信と呼ばれ、パフォーマンスに大きく寄与する可能性があります。

FXによる変換を実装することで、全対全通信に計算を重ねることができます。図 5(a) は、埋め込みベクトル テーブル アクセス (EmbeddingAllToAll) およびその他の演算子を備えたモデル グラフ インスタンスを示しています。図 5(b) に示すように、これらは最適化なしで GPU ストリーム上で順次実行されます。

FX を使用して、EmbeddingAllToAll を EmbeddingAllToAll_Request と EmbeddingAllToAll_Wait に分割し、それらの間に独立した演算子を配置します。

図 5: 計算と通信の重複

3.4 概要

表 1: このセクションで説明する最適化と、対応する解決されたパフォーマンスのボトルネック

これらの変換からどのモデルが恩恵を受けるかを発見するために、開発者は、メタ データ センターで実行されているモデルに関して MAIProf によって収集されたパフォーマンス データを分析しました。これらの変換により、一連の実稼働モデルでは、eager モードと比較して 2 ~ 3 倍の高速化が達成されることがわかりました。

4. 結論

パフォーマンスの観点から見ると、PyTorch のグラフ モードは、運用環境で使用される Eager モードよりも人気があります。 FX は、PyTorch プログラム グラフをキャプチャして最適化するための強力なツールです。この記事では、Meta 内の実稼働推奨モデルを最適化するための 3 つの FX 変換を示します。

最後に、より多くの PyTorch 開発者がグラフ変換を使用してモデルのパフォーマンスを向上できることを願っています。

——終わり——