HyperAI

nanoVLM Implements KV Caching for 38% Faster Autoregressive Text Generation

3 months ago

Implementation of KV Caching in nanoVLM: A 38% Speedup in Text Generation

Introduction to Autoregressive Language Models

Autoregressive language models generate text by predicting one token at a time and appending it to the existing sequence. They repeat this until a stopping criterion is met, such as an end-of-sequence token or a length limit. Sequential generation works well, but it can be computationally expensive, largely because each step repeats calculations that earlier steps have already performed.

Revisiting the Transformer Architecture

Transformers consist of stacked layers, each containing a self-attention mechanism. The self-attention module projects the input embeddings \( X \) into query, key, and value matrices \( Q = XW_Q \), \( K = XW_K \), and \( V = XW_V \). For a sequence of length \( T \) and embeddings of dimension \( D \), attention is defined as:

\[ \text{Attention}(Q, K, V) = \text{softmax}\left( \frac{QK^\top}{\sqrt{d_k}} + M \right)V \]

Here, \( d_k \) is the key dimension and \( M \) is a causal mask with \( M_{ij} = 0 \) for \( j \le i \) and \( M_{ij} = -\infty \) for \( j > i \). The mask prevents each token from attending to future tokens. In PyTorch, this can be computed as follows:

```python
import torch
import torch.nn.functional as F

input_seq_length = 5
dim_model = 10

# Token embeddings X (T x D) and projection weights
input_ids_emb = torch.randn(input_seq_length, dim_model)
W_q = torch.randn(dim_model, dim_model)
W_k = torch.randn(dim_model, dim_model)
W_v = torch.randn(dim_model, dim_model)

Q = input_ids_emb @ W_q
K = input_ids_emb @ W_k
V = input_ids_emb @ W_v

# Scaled dot-product scores (single head, so d_k = dim_model)
attention_scores = Q @ K.T / dim_model ** 0.5

# Causal mask: token i may only attend to positions j <= i
causal_mask = torch.tril(torch.ones(input_seq_length, input_seq_length))
masked_scores = attention_scores.masked_fill(causal_mask == 0, float('-inf'))

attention_weights = F.softmax(masked_scores, dim=-1)
output = attention_weights @ V
```

Where Redundancy Occurs

During autoregressive generation, the model produces one token at a time. At each step it recomputes the \( Q \), \( K \), and \( V \) matrices for the entire sequence, even though the earlier tokens, and therefore their rows of \( K \) and \( V \), have not changed.
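The redundancy can be checked numerically. The sketch below is illustrative and not taken from nanoVLM. It reuses the single-head setup from the snippet above, but adds a fixed seed and scales the weights by \( 1/\sqrt{D} \) to keep the values well conditioned. It then compares attention over a 5-token sequence with attention over the same tokens plus one new token. The helper `causal_attention` is a hypothetical name used only for this example. The first five rows of \( K \), \( V \), and the output come out unchanged. The new token's output can be computed from its own query alone, using the old \( K \) and \( V \) rows extended by its own key and value. A KV cache exploits exactly this.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # fixed seed so the example is reproducible
dim_model = 10

# Single-head projections; scaling by 1/sqrt(D) keeps values well conditioned
W_q = torch.randn(dim_model, dim_model) / dim_model ** 0.5
W_k = torch.randn(dim_model, dim_model) / dim_model ** 0.5
W_v = torch.randn(dim_model, dim_model) / dim_model ** 0.5


def causal_attention(x):
    """Hypothetical helper: full causal self-attention over x of shape (T, D).

    Returns the attention output together with the K and V matrices it computed.
    """
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    scores = Q @ K.T / dim_model ** 0.5
    T = x.size(0)
    allowed = torch.tril(torch.ones(T, T, dtype=torch.bool))
    scores = scores.masked_fill(~allowed, float('-inf'))
    return F.softmax(scores, dim=-1) @ V, K, V


# Step t sees 5 tokens; step t+1 sees the same 5 tokens plus one new token
seq5 = torch.randn(5, dim_model)
new_tok = torch.randn(1, dim_model)
seq6 = torch.cat([seq5, new_tok], dim=0)

out5, K5, V5 = causal_attention(seq5)
out6, K6, V6 = causal_attention(seq6)

# Recomputing from scratch only reproduces the old K/V rows and old outputs
assert torch.allclose(K6[:5], K5, atol=1e-6)
assert torch.allclose(V6[:5], V5, atol=1e-6)
assert torch.allclose(out6[:5], out5, atol=1e-6)

# The one genuinely new row needs only the new token's query, plus the old
# K/V extended by the new token's key and value (the idea behind a KV cache)
q_new = new_tok @ W_q
K_ext = torch.cat([K5, new_tok @ W_k], dim=0)
V_ext = torch.cat([V5, new_tok @ W_v], dim=0)
row_new = F.softmax(q_new @ K_ext.T / dim_model ** 0.5, dim=-1) @ V_ext
assert torch.allclose(row_new, out6[5:], atol=1e-6)
```

Because the mask is causal, earlier rows never depend on later tokens. Reusing cached keys and values therefore gives the same result as full recomputation, not an approximation of it.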
For example, consider a sequence that initially has length 5. When one token is appended, the attention score matrix \( QK^\top \) grows from 5×5 to 6×6:

```
Original (5×5):      Extended (6×6):
■ ■ ■ ■ ■            ■ ■ ■ ■ ■ □
■ ■ ■ ■ ■            ■ ■ ■ ■ ■ □
■ ■ ■ ■ ■            ■ ■ ■ ■ ■ □
■ ■ ■ ■ ■            ■ ■ ■ ■ ■ □
■ ■ ■ ■ ■            ■ ■ ■ ■ ■ □
                     □ □ □ □ □ □
```

The filled squares (■) were already computed at the previous step. Only the empty squares (□) in the new row and column involve the new token. The redundancy is evident: the \( K \) and \( V \) rows for the first five tokens are identical to their previous values, so only the new token's projections actually need to be computed.

How KV Caching Mitigates Redundancy

KV caching removes this inefficiency by storing the \( K \) and \( V \) matrices from previous steps and extending them incrementally. In practice it works in two stages:

- Initialization (prefill): The model processes the whole input prompt in one pass and stores the \( K \) and \( V \) matrices of every layer in a cache.
- Incremental updates (decode): For each subsequent token, the model computes \( Q \), \( K \), and \( V \) only for that token. It appends the new \( K \) and \( V \) to the cache and attends over the cached keys and values, so the full \( K \) and \( V \) matrices are never recomputed.

Implementation in nanoVLM

The nanoVLM repository, written in pure PyTorch, offers a clean and concise setting for implementing KV caching. The implementation modifies three components.

Attention Block:

- Original behaviour: Recomputes \( Q \), \( K \), and \( V \) on every step.
- New behaviour: Uses and updates a per-layer KV cache.

```python
import torch
import torch.nn.functional as F

def forward(self, x, cos, sin, attention_mask=None, block_kv_cache=None):
    is_prefill = block_kv_cache is None  # no cache yet: process the whole prompt
    B, T_curr, C = x.size()

    # project_current_tokens / apply_rotary_pos_embd stand for the model's Q/K/V projection and RoPE steps
    q_curr, k_curr, v_curr = project_current_tokens(x)
    q, k_rotated = apply_rotary_pos_embd(q_curr, k_curr, cos, sin)

    if not is_prefill and block_kv_cache['key'] is not None:
        # Decode: append new keys/values to the cache along the sequence axis
        k = torch.cat([block_kv_cache['key'], k_rotated], dim=2)
        v = torch.cat([block_kv_cache['value'], v_curr], dim=2)
    else:
        # Prefill: nothing cached yet
        k, v = k_rotated, v_curr
    block_kv_cache = {'key': k, 'value': v}

    # Causal mask only in prefill; a decode query may see every cached position.
    # (Padding mask, grouped-query head expansion and output projection omitted.)
    y = F.scaled_dot_product_attention(q, k, v, is_causal=is_prefill)
    attention_output = y.transpose(1, 2).reshape(B, T_curr, C)
    return attention_output, block_kv_cache
```

Language Model:

- Original behaviour: No memory of previous state.
- New behaviour: Tracks a per-layer KV cache and accepts a start_pos argument so that position encodings stay correct.

```python
import torch

def forward(self, x, kv_cache=None, start_pos=0, attention_mask=None):
    T_curr = x.size(1)
    # Positions continue from where the cached sequence left off
    position_ids = torch.arange(start_pos, start_pos + T_curr, device=x.device)
    cos, sin = self.rotary_embd(position_ids)
    if kv_cache is None:
        # Prefill: start with one empty cache slot per transformer block
        kv_cache = [None] * len(self.blocks)
    for i, block in enumerate(self.blocks):
        # Pass in and update this layer's cache
        x, kv_cache[i] = block(x, cos, sin, attention_mask, kv_cache[i])
    return x, kv_cache
```

Generation Loop:

- Original behaviour: A single-phase generation loop.
- New behaviour: Split into prefill and decode phases.

```python
# PREFILL: process the full prompt once and fill the cache
prompt_output, kv_cache_list = self.forward(inputs, kv_cache=None, start_pos=0)
current_position = inputs.size(1)  # next token's position = prompt length (sequence on dim 1)
generated = []

# DECODE: generate one token at a time using the cached K/V
for _ in range(max_new_tokens):
    # sample_from stands for the sampling step applied to the last position's output
    next_token = sample_from(prompt_output)
    generated.append(next_token)
    decode_output, kv_cache_list = self.forward(next_token, kv_cache=kv_cache_list, start_pos=current_position)
    current_position += 1  # each decoded token occupies the next position
    prompt_output = decode_output
```

Summary of Changes

- LanguageModelGroupedAttention.forward: Moved from recomputing \( Q \), \( K \), and \( V \) on every step to using and updating a KV cache.
- LanguageModel.forward: Added per-layer KV cache tracking and the start_pos argument for correct position encodings.
- VisionLanguageModel.generate: Split the generation loop into prefill and decode phases to avoid redundant computation.

Key Benefits of KV Caching

- Incremental growth: The cache grows by one row per new token instead of being rebuilt at every step.
- Position-aware decoding: The start_pos argument keeps position encodings correct.
- Efficiency: The per-token cost of attention drops from quadratic to linear in the sequence length, \( O(T) \) instead of \( O(T^2) \). Each new query is scored against the \( T \) cached keys rather than recomputing the full \( T \times T \) score matrix, which leads to significant speedups.

Evaluation and Impact

KV caching is a standard technique for improving the inference performance of large language models (LLMs).
In nanoVLM, it yields a 38% speedup in text generation. By removing redundant computation, KV caching helps make it practical to run these models on consumer-grade hardware and in real-time applications. It does come with trade-offs. The cache consumes extra memory that grows with sequence length, it adds complexity to the code, and it can constrain more sophisticated inference methods such as beam search. The nanoVLM project stands out for its educational value: its clear and concise implementation demystifies the details of KV caching for aspiring developers and researchers.

Company Profile

nanoVLM is a small, open-source project that aims to make state-of-the-art vision-language models accessible to a broader audience through a minimal, understandable codebase. It is built entirely in PyTorch, with an emphasis on simplicity and modularity to make learning and experimentation easy.
