HyperAI
Back to Headlines

KV-Cache steigert NanoVLM-Generierung um 38%

vor 25 Tagen

Wir haben KV-Caching von Grund auf neu in unserem nanoVLM-Repository implementiert (ein kleines Codebasis, um ein eigenes Vision-Language-Modell mit reinem PyTorch zu trainieren). Dies hat uns eine Geschwindigkeitsverbesserung von 38% bei der Textgenerierung gebracht. In diesem Blogbeitrag behandeln wir KV-Caching und alle unsere Erfahrungen bei dessen Implementierung. Die gelernten Lektionen sind allgemein und können auf die Generierung aller autoregressiven Sprachmodelle angewendet werden. Die Implementierung von Grund auf neu in einem kleinen Codebasis ist eine großartige Lerngelegenheit – folgen Sie uns auf dieser Reise! Einführung Autoregressive Sprachmodelle erzeugen Text, indem sie Token für Token generieren. Während der Inferenz verarbeitet das Modell eine gegebene Eingabeabfolge, prognostiziert den nächsten Token, hängt ihn an die Abfolge an und wiederholt diesen Prozess, bis ein gewisses Stoppkriterium erreicht ist. Diese schrittweise Generierung ist inhärent sequentiell, und die Wiederholung führt zu rechnerischer Redundanz. In diesem Beitrag untersuchen wir KV-Caching, eine Optimierungstechnik, die diese Ineffizienz reduziert. Rückblick auf die Transformer-Architektur Bevor wir uns dem Caching zuwenden, werfen wir einen Blick darauf, wie die Aufmerksamkeit in Transformer-Modellen funktioniert. Ein Transformer-Sprachmodell besteht aus stapelbaren Schichten, jede aus: Um zu verstehen, wo KV-Caching hilft, konzentrieren wir uns auf den Selbst-Aufmerksamkeitsmechanismus innerhalb eines einzelnen Aufmerksamkeitskopfes. Hier ist eine einfache PyTorch-Implementierung, um die wesentlichen Berechnungen zu visualisieren: ```python import torch input_seq_length = 5 dim_model = 10 input_ids_emb = torch.randn(input_seq_length, dim_model) W_q = torch.randn(dim_model, dim_model) W_k = torch.randn(dim_model, dim_model) W_v = torch.randn(dim_model, dim_model) Q = input_ids_emb @ W_q K = input_ids_emb @ W_k V = input_ids_emb @ W_v ``` Selbst-Aufmerksamkeitsberechnung Für eine Sequenz von ( T ) Eingabe-Embeddings, dargestellt als ( X \in \mathbb{R}^{T \times D} ), wird die Selbst-Aufmerksamkeit wie folgt berechnet: Die endgültige Ausgabe lautet: [ \text{Attention}(X; Q, K, V) = \text{softmax}\left( \frac{QK^\top \cdot M}{\sqrt{d_k}} \right)V ] Hier ist eine minimale PyTorch-Implementierung unter Verwendung einer kausalen Maske: ```python import torch.nn.functional as F attention_scores = Q @ K.T Untere Dreiecksmaske, um den Zugriff auf zukünftige Token zu verhindern causal_mask = torch.tril(torch.ones(input_seq_length, input_seq_length)) masked_scores = attention_scores.masked_fill(causal_mask == 0, float('-inf')) attention_weights = F.softmax(masked_scores, dim=-1) output = attention_weights @ V ``` Wo Redundanzen auftreten Bei autoregressiver Generierung generiert das Modell einen Token nach dem anderen. Bei jedem Schritt werden ( Q ), ( K ) und ( V ) für die gesamte Sequenz neu berechnet, obwohl die früheren Token nicht verändert wurden. ```python new_token_emb = torch.randn(1, dim_model) extended_input = torch.cat([input_ids_emb, new_token_emb], dim=0) Q_ext = extended_input @ W_q K_ext = extended_input @ W_k V_ext = extended_input @ W_v ``` Um die Redundanz zu bestätigen: python torch.testing.assert_close(K, K_ext[:input_seq_length]) # Test erfolgreich torch.testing.assert_close(V, V_ext[:input_seq_length]) # Test erfolgreich Diese Überprüfungen zeigen, dass für alle Token außer dem neuesten ( K ) und ( V ) identisch mit den vorher berechneten Werten sind. Original (5×5): Erweitert (6×6): ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ □ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ □ ■ ■ ■ ■ ■ → ■ ■ ■ ■ ■ □ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ □ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ □ □ □ □ □ □ □ Die meisten Aufmerksamkeitsberechnungen werden unnötigerweise wiederholt, was teurer wird, je länger die Sequenzen werden. Wie KV-Caching das behebt Um diese Ineffizienz zu beseitigen, verwenden wir KV-Caching: Dies ändert die Generierung von einer vollständigen Sequenzneuberechnung zu einer leichten, inkrementellen Aktualisierung. ✅ In der Praxis ist dieser Cache ein pro-Schicht-Wörterbuch mit den Schlüsseln "key" und "value", jeweils von der Form ((\text{batch_size}, \text{num_heads}, \text{seq_len_cached}, \text{head_dim})). Dies bildet die Grundlage dafür, wie moderne Große Sprachmodelle (LLMs) effizient lange Ausgaben erzeugen können. KV-Caching in nanoVLM: Von Theorie zur Praxis Nun, da wir die Theorie hinter KV-Caching verstehen, sehen wir uns an, wie es in der Praxis in unserem nanoVLM-Repository implementiert wird. Dies ist eine ideale Testumgebung, da es sich um einen sehr prägnanten und selbstständigen Codebasis handelt. KV-Caching wird in drei zentralen Komponenten unseres Modells aktiviert: Der Aufmerksamkeitsblock, der KV-Cache verwendet und aktualisiert Das Sprachmodell, das KV-Cache pro Schicht verfolgt Die Generierungsschleife, die zwischen Vorfüllphase (Prefill) und sequenzieller Dekodierphase getrennt ist 1. Aktualisierung des KV-Caches im Aufmerksamkeitsblock In der Klasse LanguageModelGroupedAttention modifizieren wir die forward-Funktion, um einen Cache von Keys und Values (block_kv_cache) zu akzeptieren und zu aktualisieren. Früher berechnete das Modell ( K ) und ( V ) bei jedem Generierungsschritt neu. Jetzt berechnen wir nur ( K_{\text{new}} ) und ( V_{\text{new}} ) für den aktuellen Token und hängen sie an die zwischengespeicherten Werte an. ```python def forward(self, x, cos, sin, attention_mask=None, block_kv_cache=None): is_prefill = block_kv_cache is None B, T_curr, C = x.size() # Projektieren der Eingaben auf Q, K, V q_curr, k_curr, v_curr = project_current_tokens(x) q, k_rotated = apply_rotary_pos_embd(q_curr, k_curr, cos, sin) if not is_prefill and block_kv_cache['key'] is not None: # Neue Keys und Values an den Cache anhängen k = torch.cat([block_kv_cache['key'], k_rotated], dim=2) v = torch.cat([block_kv_cache['value'], v_curr], dim=2) else: # Erster Durchgang (Vorfüllphase) – kein Cache k, v = k_rotated, v_curr block_kv_cache = {'key': k, 'value': v} return attention_output, block_kv_cache ``` 2. Cache-Verfolgung über Schichten hinweg In der Klasse LanguageModel führen wir eine schichtweise Cache-Verfolgung ein. Das Argument start_pos hilft dem Modell, korrekte rotatorische Positionierungscodierungen für neu erzeugte Token zu berechnen. ```python def forward(self, x, kv_cache=None, start_pos=0): T_curr = x.size(1) position_ids = torch.arange(start_pos, start_pos + T_curr, device=x.device) cos, sin = self.rotary_embd(position_ids) for i, block in enumerate(self.blocks): # Pro-Schicht KV-Cache übergeben x, kv_cache[i] = block(x, cos, sin, attention_mask, kv_cache[i]) return x, kv_cache ``` 3. Vorfüllphase vs. Dekodierphase in der Generierungsschleife Die größte architektonische Änderung liegt in der generate()-Methode des VisionLanguageModel. Wir teilen die Generierung in zwei Phasen: VORFÜLLPHASE (Cache-Konstruktion): [Eingabe: "Was ist"] → [Transformer] → [Cache: ( K ) und ( V ) für alle Schichten] DEKODIERPHASE (Token-für-Token): [Token: "das"] → [( Q(\text{das}) + \text{gespeicherte } K/V )] → [nächster Token: "?"] → ... Hier ist der entsprechende Code: ```python VORFÜLL: Die Eingabe-Eingabeabfolge verarbeiten und den Cache füllen prompt_output, kv_cache_list = self.forward( inputs, kv_cache=None, start_pos=0 ) DEKODIER: Token für Token generieren, indem der gespeicherte K/V genutzt wird for i in range(max_new_tokens): next_token = sample_from(prompt_output) decode_output, kv_cache_list = self.forward( next_token, kv_cache=kv_cache_list, start_pos=current_position # mit jedem Schritt aktualisiert ) prompt_output = decode_output ``` Durch die Trennung dieser Phasen vermeiden wir redundante Berechnungen und beschleunigen die Inferenz dramatisch, insbesondere bei langen Eingaben. Zusammenfassung der Änderungen | Modul | Ursprüngliches Verhalten | Neues Verhalten | |-------|--------------------------|-----------------| | LanguageModelGroupedAttention.forward | ( Q ), ( K ), ( V ) bei jedem Schritt neu berechnen | KV-Cache verwenden und aktualisieren | | LanguageModel.forward | Keine Erinnerung an den vorherigen Zustand | Pro-Schicht KV-Cache verfolgen, start_pos verwalten | | VisionLanguageModel.generate | Einphasige Generierungsschleife | Geteilt in Vorfüll- und Dekodierphasen | Zusammenfassung: Warum KV-Caching wichtig ist | Nutzen | Erklärung | |--------|-----------| | Inkrementelles Wachstum | Der Cache wächst bei jedem neuen Token um eine Zeile. | | Positionsbezogenes Decoding | start_pos stellt sicher, dass die Berechnungen der Positionierungscodierungen korrekt sind. | | Effizienz | Reduziert die pro-Token-Inferenz von quadratisch auf linear in der Sequenzlänge. | KV-Caching eliminiert unnötige Berechnungen während der autoregressiven Generierung, was eine schnellere und effizientere Inferenz ermöglicht, besonders bei langen Sequenzen und in Echtzeit-Anwendungen. Es handelt sich dabei um einen Kompromiss zwischen Geschwindigkeit und Speicher, wobei die Nachteile komplexeres Code und Einschränkungen für fortgeschrittene Inferenzverfahren wie Beam-Search etc. sein können. KV-Caching ist eine weit verbreitete Methode zur Beschleunigung der Inferenz bei großen Sprachmodellen, was es möglich macht, sie auf verbraucherseitiger Hardware zu betreiben. Jetzt wissen Sie auch, wie es funktioniert! Bewertung durch Branchenexperten und Unternehmensprofile KV-Caching ist eine Technik, die in der Branche zunehmend an Bedeutung gewinnt, da sie eine wesentliche Rolle bei der Skalierung von Sprachmodellen spielt. Unternehmen wie Google und Meta setzen sie in ihren fortschrittlichsten Modellen ein, um die Performanz bei der Textgenerierung zu steigern. Obwohl die Implementierung zusätzliche Komplexität bringt, lohnt sich der Aufwand durch die signifikanten Leistungsverbesserungen.NanoVLM, ein kleines und übersichtliches Projekt, zeigt, dass KV-Caching auch in kleineren, lernenden Umgebungen praktisch und lehrreich eingesetzt werden kann.

Related Links