HyperAI
Back to Headlines

nanoVLM : Optimiser la Génération de Texte avec KV Caching DIY

il y a un mois

Résumé de l'implémentation de KV Caching dans nanoVLM Introduction Les modèles de langage auto-régréssifs génèrent du texte de manière séquentielle, en prédiction token par token. Ce processus, bien qu'efficace pour la création de textes, devient redondant et coûteux en calcul lors de l'inference, car il recompute systématiquement les matrices de clés (K) et de valeurs (V) pour l'ensemble de la séquence en cours, même si les tokens antérieurs n'ont pas changé. Dans cet article, nous explorons la technique de KV Caching, qui permet d'optimiser ce processus en conservant les valeurs de K et V déjà calculées. Nous avons mis en œuvre cette technique dans notre dépôt nanoVLM, un code base simplifié pour l'entraînement de modèles Vision Language Model (VLM) avec PyTorch pur. Cette mise en œuvre a engendré une accélération de 38% lors de la génération de texte. Revisiter l'architecture des Transformers Le mécanisme d'attention Dans les modèles Transformer, chaque couche est composée de blocs d'attention auto-référencée. Pour une suite de ( T ) tokens d'entrée représentés par ( X \in \mathbb{R}^{T \times D} ), l'attention auto-référencée est calculée ainsi : [ \text{Attention}(X; Q, K, V) = \text{softmax}\left( \frac{QK^\top \cdot M}{\sqrt{d_k}} \right)V ] Voici un exemple minimal en PyTorch : ```python import torch import torch.nn.functional as F input_seq_length = 5 dim_model = 10 input_ids_emb = torch.randn(input_seq_length, dim_model) W_q = torch.randn(dim_model, dim_model) W_k = torch.randn(dim_model, dim_model) W_v = torch.randn(dim_model, dim_model) Q = input_ids_emb @ W_q K = input_ids_emb @ W_k V = input_ids_emb @ W_v attention_scores = Q @ K.T causal_mask = torch.tril(torch.ones(input_seq_length, input_seq_length)) masked_scores = attention_scores.masked_fill(causal_mask == 0, float('-inf')) attention_weights = F.softmax(masked_scores, dim=-1) output = attention_weights @ V ``` Où se manifeste la redondance Lors de la génération auto-régressive, le modèle genère un token à la fois et recompute les matrices ( Q ), ( K ) et ( V ) pour l'ensemble de la séquence à chaque étape, même si les tokens précédents restent inchangés. Par exemple, si nous ajoutons un nouveau token à une séquence de 5 tokens : ```python new_token_emb = torch.randn(1, dim_model) extended_input = torch.cat([input_ids_emb, new_token_emb], dim=0) Q_ext = extended_input @ W_q K_ext = extended_input @ W_k V_ext = extended_input @ W_v La vérification montre que les matrices \( K \) et \( V \) des 5 premiers tokens sont identiques aux valeurs précédemment calculées :python torch.testing.assert_close(K, K_ext[:input_seq_length]) # test pass torch.testing.assert_close(V, V_ext[:input_seq_length]) # test pass ``` Cette redondance devient de plus en plus coûteuse mesure que les séquences s'allongent. Comment le KV Caching résout ce problème Pour éliminer cette inefficacité, le KV Caching permet de conserver les matrices ( K ) et ( V ) dans un cache. Ce cache, d'une forme générale ((\text{batch_size}, \text{num_heads}, \text{seq_len_cached}, \text{head_dim})), est mis à jour incrémentalement au fur et à mesure que de nouveaux tokens sont générés. Cette technique transforme la génération complète de la séquence en une mise à jour légère et incrémentale. KV Caching dans nanoVLM : De la théorie à la pratique 1. Mise à jour du KV Cache dans le bloc d'attention Dans la classe LanguageModelGroupedAttention, la fonction forward a été modifiée pour accepter et mettre à jour un cache de clés et de valeurs. Au lieu de recomputer ( Q ), ( K ) et ( V ) à chaque étape, le modèle ne compute que ( K_{\text{new}} ) et ( V_{\text{new}} ) pour le token en cours, avant de les ajouter au cache existant : ```python def forward(self, x, cos, sin, attention_mask=None, block_kv_cache=None): is_prefill = block_kv_cache is None B, T_curr, C = x.size() # Project inputs to Q, K, V q_curr, k_curr, v_curr = project_current_tokens(x) q, k_rotated = apply_rotary_pos_embd(q_curr, k_curr, cos, sin) if not is_prefill and block_kv_cache['key'] is not None: # Append new keys and values to the cache k = torch.cat([block_kv_cache['key'], k_rotated], dim=2) v = torch.cat([block_kv_cache['value'], v_curr], dim=2) else: # First pass (prefill) — no cache k, v = k_rotated, v_curr block_kv_cache = {'key': k, 'value': v} return attention_output, block_kv_cache ``` 2. Suivi du cache entre les couches Dans la classe LanguageModel, nous introduisons un suivi du cache au niveau des couches. L'argument start_pos aide le modèle à calculer les encodages positionnels rotatoires pour les nouveaux tokens : ```python def forward(self, x, kv_cache=None, start_pos=0): T_curr = x.size(1) position_ids = torch.arange(start_pos, start_pos + T_curr, device=x.device) cos, sin = self.rotary_embd(position_ids) for i, block in enumerate(self.blocks): # Pass per-layer KV cache x, kv_cache[i] = block(x, cos, sin, attention_mask, kv_cache[i]) return x, kv_cache ``` 3. Phase de pré remplissage et décodage séquentiel La méthode generate() de la classe VisionLanguageModel a été divisée en deux étapes : Phase de pré-remplissage (prefill) : Le modèle traite l'input initial pour construire le cache. Phase de décodage séquentiel : Le modèle génère un token à la fois en utilisant les clés et valeurs en cache. Résumé des modifications | Module | Comportement original | Nouveau comportement | |---------------------------------------------|----------------------------------------------|----------------------------------------------| | LanguageModelGroupedAttention.forward | Recompute ( Q ), ( K ), ( V ) à chaque étape | Utilise et met à jour le KV cache | | LanguageModel.forward | Aucune mémoire de l'état précédent | Suit le cache par couche, gère start_pos | | VisionLanguageModel.generate | Boucle de génération en une phase | Divisé en phases de pré-remplissage et décodage| Pourquoi le KV Caching est important | Bénéfice | Explication | |--------------------------------------------|--------------------------------------------------------------------------| | Croissance incrémentale | Le cache s'agrandit d'une ligne par token nouvellement généré. | | Décodage conscient de la position | start_pos assure la correction des calculs d'encodage positionnel. | | Efficacité | Réduit l'inférence par token à ( O(\text{seq len}) ) au lieu de quadratique. | Le KV Caching élimine les calculs inutiles pendant la génération auto-régressive, ce qui permet une inférence plus rapide et plus efficace, particulièrement utile pour les longues séquences et les applications en temps réel. Cette technique, bien qu'elle implique un compromis entre rapidité et utilisation de la mémoire, peut rendre les LLMs (Large Language Models) plus accessibles sur des appareils grand public. Évaluation de la technique par des professionnels de l'industrie Les experts de l'industrie saluent l'approche de nanoVLM pour l'implémentation de KV Caching. Selon eux, le fait d'avoir un code base compact et bien documenté facilite la compréhension et l'amélioration continue de la technique. NanoGPT, la société derrière nanoVLM, est connue pour proposer des solutions minimalistes et efficaces pour l'apprentissage automatique, ce qui en fait un excellent point de départ pour les développeurs souhaitant explorer des optimisations complexes comme le KV Caching sans être submergés par la taille du code. Cette implémentation contribue significativement à démocratiser l'accès aux modèles de langage avancés, rendant leur utilisation plus accessible et performante.

Related Links