nanoVLMでKVキャッシュを実装:生成速度が38%向上 nanoVLMリポジトリ(純粋なPyTorchを使用して独自のビジョン言語モデルを学習するための小さなコードベース)でKVキャッシュを一から実装し、生成速度が38%向上しました。このブログポストでは、KVキャッシュの仕組みと実装経験を詳細に説明します。学んだ教訓は一般的であり、すべての自己回帰言語モデルの生成に応用できます。小さなコードベースから一から実装することで得られる学びの体験をお楽しみください!
KVキャッシュを用いたnanoVLMの生成最適化 我々は、純粋なPyTorchを使用して独自のビジョン言語モデル(VLM)を訓練するための小さなコードベースであるnanoVLMリポジトリにおいて、KVキャッシュを自前で実装しました。この最適化により、テキスト生成の速度が約38%向上しました。このブログでは、KVキャッシュの原理と実装経験について詳しく説明します。 自帰生成モデルの基本概念 自帰生成モデルは、順次的にトークンを一つずつ生成します。各ステップで、モデルは与えられた入力シーケンスを処理し、次のトークンを予測し、それをシーケンスに追加するというプロセスを繰り返します。この繰り返しは計算の冗長性を引き起こします。特に、新しいトークン以外の以前のトークンが変更されていないにもかかわらず、每次都心のQ、K、Vベクトルを再計算していたのが問題でした。 トランザフォーマーのアーキテクチャの見直し トランザフォーマー言語モデルでは、各レイヤーが以下のように構成されています: 1. 自己注意機構:入力シーケンスの中でトークン間の関係を学習します。 2. 位置エンコーディング:各トークンの位置情報を提供します。 3. フィードフォワードネットワーク:さらに高度な特徴量を抽出します。 自己注意機構では、入力埋め込み X ∈ R^{T×D} の自己注意は以下の式で計算されます: Attention(X; Q, K, V) = softmax((QK^⊤) × (1/√d_k))V ここで Q, K, V は順方向伝播中に複数回再計算されるため、計算コストが増大します。 KVキャッシュの効果 KVキャッシュは、この冗長性を解消するために使用されます。各レイヤーごとに K と V をキャッシュし、新しいトークンの生成時にのみ更新します。これにより、シーケンス全体の再計算から軽量なインクリメンタルアップデートに変換されます。具体的には、以下の手順が行われます: ** AttentionブロックにおけるKVキャッシュの更新**: 入力を Q, K, V に射影します。 キャッシュが存在する場合は、新しい K と V をキャッシュに追加します。 キャッシュがない初期段階(プリフィル)では、直接 K と V を計算します。 ```python def forward(self, x, cos, sin, attention_mask=None, block_kv_cache=None): is_prefill = block_kv_cache is None B, T_curr, C = x.size() q_curr, k_curr, v_curr = project_current_tokens(x) q, k_rotated = apply_rotary_pos_embd(q_curr, k_curr, cos, sin) if not is_prefill and block_kv_cache['key'] is not None: k = torch.cat([block_kv_cache['key'], k_rotated], dim=2) v = torch.cat([block_kv_cache['value'], v_curr], dim=2) else: k, v = k_rotated, v_curr block_kv_cache = {'key': k, 'value': v} return attention_output, block_kv_cache ``` レイヤー間にわたるキャッシュのトラッキング: 各レイヤーごとにKVキャッシュを追跡します。 start_pos 引数を使用して、新規生成トークンの正解の位置エンコーディングを計算します。 ```python def forward(self, x, kv_cache=None, start_pos=0): T_curr = x.size(1) position_ids = torch.arange(start_pos, start_pos + T_curr, device=x.device) cos, sin = self.rotary_embd(position_ids) for i, block in enumerate(self.blocks): x, kv_cache[i] = block(x, cos, sin, attention_mask, kv_cache[i]) return x, kv_cache ``` 生成ループにおけるプリフィルとデコードの分離: 初期段階(プリフィル)では、入力プロンプトを処理し、各レイヤーのΚVキャッシュを構築します。 その後、デコード段階では、新規トークンを逐一生成するためにキャッシュされた K と V を使用します。 ```python # プリフィル: 入力プロンプトを処理し、キャッシュを満たす prompt_output, kv_cache_list = self.forward(inputs, kv_cache=None, start_pos=0) # デコード: 新規トークンを逐一生成 for i in range(max_new_tokens): next_token = sample_from(prompt_output) decode_output, kv_cache_list = self.forward( next_token, kv_cache=kv_cache_list, start_pos=current_position # 各ステップで更新 ) prompt_output = decode_output ``` 実装による変更点要約 LanguageModelGroupedAttention.forward: 都度Q、K、Vの再計算から、KVキャッシュの使用と更新へ。 LanguageModel.forward: 前状態の記憶なしから、レイヤーごとのKVキャッシュのトラッキングと start_pos の管理へ。 VisionLanguageModel.generate: 一つの生成ループから、プリフィルとデコードの二段階に分割。 KVキャッシュの重要性 インクリメンタル成長: 新規トークンごとにキャッシュが一行増える。 位置認識のデコーディング: start_pos が位置エンコーディングの正しさを保証する。 効率性: トークンあたりの推論時間をO(シーケンス長)から二次的なものに削減。 KVキャッシュは、自帰生成時の不要な計算を排除することで、特に長尺のシーケンスやリアルタイムアプリケーションでの高速な推論を可能にします。ただし、速度とメモリのトレードオフがあり、コードがより複雑になる可能性がある点や、ビームサーチなどの高度な推論手法が制限される可能性があります。現在、KVキャッシュは大規模言語モデル(LLM)の高速推論を支える有力な手法として広く採用され、消費者向けハードウェアでも動作可能にする主な技術の一つとなっています。 この技術革新は、nanoVLMプロジェクトの重要な一歩となっています。KVキャッシュの実装により、計算効率が大幅に改善され、大規模言語モデルの本格的な普及に貢献しています。 nanoVLMは、PyTorchを使用してビョルン言語モデルを訓練するためのコンパクトで自己完結型のコードベースであり、研究者や開発者の教育と理解に役立つリソースとなっています。