
MHLA: Restoring Expressivity of Linear Attention via Token-Level Multi-Head

Kewei Zhang Ye Huang Yufan Deng Jincheng Yu Junsong Chen Huan Ling Enze Xie Daquan Zhou

Abstract

While the Transformer architecture dominates many fields, its quadratic self-attention complexity hinders its use in large-scale applications. Linear attention offers an efficient alternative, but its direct application often degrades performance, with existing fixes typically re-introducing computational overhead through extra modules (e.g., depthwise separable convolution) that defeat the original purpose. In this work, we identify a key failure mode in these methods: global context collapse, where the model loses representational diversity. To address this, we propose Multi-Head Linear Attention (MHLA), which preserves this diversity by computing attention within divided heads along the token dimension. We prove that MHLA maintains linear complexity while recovering much of the expressive power of softmax attention, and verify its effectiveness across multiple domains, achieving a 3.6% improvement on ImageNet classification, a 6.3% gain on NLP, a 12.6% improvement on image generation, and a 41% enhancement on video generation under the same time complexity.

One-sentence Summary

The authors from Peking University and NVIDIA propose Multi-Head Linear Attention (MHLA), a linear-complexity attention mechanism that prevents global context collapse by preserving representational diversity through token-wise head division. It improves consistently over linear attention baselines and matches or approaches softmax attention in image classification, NLP, image generation, and video generation while maintaining efficiency.

Key Contributions

  • Linear attention mechanisms offer a scalable alternative to quadratic self-attention in Transformers but often suffer from performance degradation due to global context collapse, where the model loses representational diversity by compressing all tokens into a single shared key-value summary, limiting query-specific context retrieval.

  • The proposed Multi-Head Linear Attention (MHLA) addresses this by partitioning tokens into non-overlapping heads along the token dimension, enabling query-dependent context selection through local key-value summaries and a query-conditioned mixing mechanism, while maintaining $O(N)$ complexity without additional modules.

  • MHLA achieves state-of-the-art results across multiple domains: a 3.6% accuracy gain on ImageNet, 6.3% improvement on NLP, 12.6% boost in image generation, and a 41% enhancement in video generation, all under the same time complexity as baseline linear attention.

Introduction

The Transformer architecture's dominance in vision, language, and generative modeling is limited by its quadratic self-attention complexity, which hinders scalability to long sequences and high-resolution tasks. Linear attention offers a promising alternative with linear complexity, but prior approaches often degrade in performance due to a fundamental issue: global context collapse, where all tokens are compressed into a single shared key-value summary, reducing representational diversity and capping the rank of attention matrices. This leads to uniform attention distributions and poor query-specific context selection, especially as sequence length grows. Existing fixes rely on auxiliary modules like depthwise convolutions or gating mechanisms, which reintroduce computational overhead and fail to fully restore expressivity.
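To make the rank argument concrete, here is a small NumPy sketch (an illustration, not the authors' code). It assumes the positive feature map $\phi(x) = \mathrm{elu}(x) + 1$, a common choice in the linear attention literature; the paper's exact $\phi$ is not given here, and the helper names are hypothetical. Because every query reads the same $d \times d$ summary, the implied $N \times N$ attention matrix factors through $d$-dimensional features, so its rank cannot exceed $d$ however long the sequence is.

```python
import numpy as np

def feature_map(x):
    # Assumed positive feature map phi(x) = elu(x) + 1 (not specified in the source).
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention(Q, K, V):
    """Global (vanilla) linear attention: every query reads the same d x d summary."""
    Qt, Kt = feature_map(Q), feature_map(K)
    S = Kt.T @ V            # (d, d_v) shared key-value summary
    z = Kt.sum(axis=0)      # (d,) shared normalizer
    return (Qt @ S) / (Qt @ z)[:, None]

def attention_matrix(Q, K):
    """Row-normalized N x N attention matrix implied by linear attention."""
    A = feature_map(Q) @ feature_map(K).T
    return A / A.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
N, d = 256, 16
Q, K, V = rng.standard_normal((3, N, d))
A = attention_matrix(Q, K)
print(np.linalg.matrix_rank(A))  # at most d = 16, although A is 256 x 256
```

Row normalization does not change the rank, since it only rescales each row by a positive factor; the cap comes entirely from the shared $d$-dimensional summary.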

The authors propose Multi-Head Linear Attention (MHLA), a novel formulation that restores query-dependent diversity by partitioning tokens into non-overlapping heads along the token dimension, computing local key-value summaries, and enabling query-conditioned mixing over these summaries. This design preserves linear $O(N)$ complexity while significantly increasing the rank of attention matrices, effectively recovering much of the expressive power of softmax attention. MHLA requires only standard GEMM operations, ensuring compatibility with streaming and stateful execution. Experiments show consistent gains across domains: 3.6% higher ImageNet accuracy, a 6.3% improvement on NLP, a 12.6% boost in image generation, and a 41% enhancement in video generation, all under the same time complexity.

Method

The authors propose Multi-Head Linear Attention (MHLA) to address the representational limitations of standard linear attention while preserving linear-time complexity. The framework begins by partitioning the input token sequence into $M$ non-overlapping blocks, which are processed in parallel. For each block $b$, a local key-value summary $S_b = \sum_{j \in b} \widetilde{K}_j V_j^\top$ and a normalizer $z_b = \sum_{j \in b} \widetilde{K}_j$ are computed, where $\widetilde{Q} = \phi(Q)$ and $\widetilde{K} = \phi(K)$ denote the feature-mapped queries and keys. This blockwise computation allows the sequence to be processed efficiently and in parallel.
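The blockwise summaries can be sketched in NumPy as follows. This is a minimal sketch under stated assumptions: the feature map $\phi$ is taken to be $\mathrm{elu}(x) + 1$, blocks are equal-sized and contiguous in a 1D token order (for images, tokens would presumably be grouped by spatial window first), and `block_summaries` is a hypothetical helper name.

```python
import numpy as np

def feature_map(x):
    # Assumption: phi(x) = elu(x) + 1 keeps features positive; the source does not specify phi.
    return np.where(x > 0, x + 1.0, np.exp(x))

def block_summaries(K, V, M):
    """Compute per-block key-value summaries S_b and normalizers z_b.

    K: (N, d) keys, V: (N, d_v) values; N must be divisible by M.
    Returns S: (M, d, d_v) and z: (M, d).
    """
    N, d = K.shape
    assert N % M == 0, "this sketch assumes equal-sized, contiguous blocks"
    Nb = N // M
    Kt = feature_map(K).reshape(M, Nb, d)      # phi(K) grouped into M blocks
    Vb = V.reshape(M, Nb, V.shape[-1])         # values grouped the same way
    S = np.einsum("bnd,bne->bde", Kt, Vb)      # S_b = sum_{j in b} phi(K_j) V_j^T
    z = Kt.sum(axis=1)                         # z_b = sum_{j in b} phi(K_j)
    return S, z

rng = np.random.default_rng(0)
K, V = rng.standard_normal((2, 64, 8))
S, z = block_summaries(K, V, M=4)
# Summing the local summaries recovers the single global summary of vanilla linear attention.
print(np.allclose(S.sum(axis=0), feature_map(K).T @ V))  # True
```

Because the blocks partition the tokens, the $M$ local summaries add up to the one global summary that vanilla linear attention uses; what changes in MHLA is how these pieces are recombined for each query block.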

To restore query-conditioned selectivity, MHLA introduces a learnable coefficient matrix $\mathcal{M}_c \in \mathbb{R}^{M \times M}$, which governs a "Multi-Head Mixing" process. Each row $m_i$ of this matrix holds the learnable, nonnegative mixing coefficients for query block $i$, determining how it combines the $M$ local key-value summaries into a query-specific mixed summary $\widetilde{S}_i = \sum_{b=1}^{M} m_{i,b} S_b$. This mechanism lets each query block adaptively reweight the contributions of all $M$ blocks, including its own, effectively creating a query-dependent global context. The output for a query vector $\widetilde{q}$ from block $i$ is then computed as $o = \frac{\widetilde{q}^\top \widetilde{S}_i}{\widetilde{q}^\top \widetilde{z}_i}$, where $\widetilde{z}_i = \sum_{b=1}^{M} m_{i,b} z_b$ is the corresponding mixed normalizer.
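Putting the summaries, the mixing step, and the normalized output together gives the following sketch of one MHLA head in NumPy. It is a hedged illustration, not the authors' implementation: $\phi(x) = \mathrm{elu}(x) + 1$, equal contiguous blocks, and the function name `mhla` are assumptions, and the nonnegative mixing matrix `Mc` is passed in rather than learned.

```python
import numpy as np

def feature_map(x):
    # Assumed positive feature map phi(x) = elu(x) + 1; the source does not fix phi.
    return np.where(x > 0, x + 1.0, np.exp(x))

def mhla(Q, K, V, Mc):
    """One MHLA head (sketch). Q, K: (N, d); V: (N, d_v); Mc: (M, M), nonnegative.

    Assumes N is divisible by M and that block b holds tokens b*Nb through (b+1)*Nb - 1.
    """
    M = Mc.shape[0]
    N, d = Q.shape
    assert N % M == 0, "this sketch assumes equal-sized blocks"
    Nb = N // M
    Qt = feature_map(Q).reshape(M, Nb, d)        # phi(Q), grouped by block
    Kt = feature_map(K).reshape(M, Nb, d)        # phi(K), grouped by block
    Vb = V.reshape(M, Nb, -1)
    S = np.einsum("bnd,bne->bde", Kt, Vb)        # local summaries S_b: (M, d, d_v)
    z = Kt.sum(axis=1)                           # local normalizers z_b: (M, d)
    S_mix = np.einsum("ib,bde->ide", Mc, S)      # S~_i = sum_b m_{i,b} S_b
    z_mix = Mc @ z                               # z~_i = sum_b m_{i,b} z_b
    num = np.einsum("ind,ide->ine", Qt, S_mix)   # q~^T S~_i for every query in block i
    den = np.einsum("ind,id->in", Qt, z_mix)     # q~^T z~_i
    return (num / den[..., None]).reshape(N, -1)

rng = np.random.default_rng(0)
Q, K, V = rng.standard_normal((3, 64, 8))
out = mhla(Q, K, V, Mc=np.eye(4) + 0.1)  # 4 token-level heads of 16 tokens each
print(out.shape)  # (64, 8)
```

Two special cases help sanity-check the formulation. An all-ones `Mc` gives every block the same global summary and recovers vanilla linear attention, while the identity matrix restricts each block to its own tokens. Rescaling a row $m_i$ by a positive constant leaves the output unchanged, because it multiplies both $\widetilde{q}^\top \widetilde{S}_i$ and $\widetilde{q}^\top \widetilde{z}_i$ by the same factor.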

The initialization of the coefficient matrix $\mathcal{M}_c$ is designed to favor locality, with initial coefficients $m_{i,j}^{(0)} \propto 1 - \text{dist}(i,j) / \max_k \text{dist}(i,k)$, which promotes stable and faster convergence. This locality-biased initialization is visualized in the figure, where the coefficients are reshaped into a 2D grid to illustrate the spatial relationship between blocks. The final output is produced by an inner-block matrix multiplication between the query blocks and the mixed key-value summaries, yielding an output tensor of size $M \times N_b \times d$, where $N_b = N/M$ is the number of tokens per block. This two-stage process (block-level selection via the mixing coefficients, followed by intra-block reweighting via the kernel inner product) reintroduces the query-conditioned selectivity and per-token weighting that global linear attention loses. The overall complexity of MHLA remains linear in the sequence length $N$, since the dominant operations are the blockwise summary computation and linear combinations of $M$ matrices of size $d \times d$.
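The locality-biased initialization can be sketched as below for blocks laid out on a 2D grid. The source does not specify the distance or the normalization implied by $\propto$; this sketch assumes Euclidean distance between block grid coordinates and rows normalized to sum to 1, and `locality_init` is a hypothetical name.

```python
import numpy as np

def locality_init(grid_h, grid_w):
    """Locality-biased initial mixing coefficients for a grid_h x grid_w grid of blocks.

    m_ij^(0) is proportional to 1 - dist(i, j) / max_k dist(i, k).
    Assumptions: Euclidean distance between block grid coordinates; rows sum to 1.
    """
    assert grid_h * grid_w > 1, "need at least two blocks for a nonzero max distance"
    ys, xs = np.meshgrid(np.arange(grid_h), np.arange(grid_w), indexing="ij")
    coords = np.stack([ys.ravel(), xs.ravel()], axis=1).astype(float)       # (M, 2)
    dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)  # (M, M)
    m0 = 1.0 - dist / dist.max(axis=1, keepdims=True)  # 1 on the diagonal, 0 for the farthest block
    return m0 / m0.sum(axis=1, keepdims=True)          # normalize each row

m0 = locality_init(4, 4)
print(m0.shape)                       # (16, 16)
print(m0[0].reshape(4, 4).round(3))   # block 0's coefficients laid out on the 4 x 4 grid
```

Each block starts with its largest weight on itself and zero weight on the blocks farthest from it. The row normalization is only a convention: since $\widetilde{S}_i$ and $\widetilde{z}_i$ use the same coefficients, any positive per-row scale cancels in the output $o$.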

Experiment

  • Main experiments validate that MHLA mitigates global context collapse in linear attention by preserving query-conditioned token-level diversity while maintaining linear-time complexity.
  • On ImageNet-1K, MHLA achieves state-of-the-art accuracy in image classification across DeiT and VLT models, surpassing baselines with minimal parameter overhead.
  • In class-to-image generation, MHLA matches or exceeds self-attention performance on DiT and DiG models, achieving up to 2.1× faster inference than FlashAttention at 512 resolution while improving FID scores.
  • In video generation with 31,500-token sequences, MHLA outperforms vanilla linear attention and matches FlashAttention performance with 2.1× speedup, demonstrating robustness in ultra-long contexts.
  • In natural language processing, MHLA achieves competitive perplexity and zero-shot accuracy on commonsense reasoning and MMLU, and leads on LongBench for long-context tasks, especially in Multi-Doc QA and summarization.
  • Ablation studies confirm that MHLA’s locality-biased initialization and learnable mixing coefficients enhance performance, and that M ≤ √N ensures optimal efficiency and scalability.
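A rough operation count shows why keeping $M \le \sqrt{N}$ preserves linear complexity. This is an illustrative accounting under assumptions, not the authors' analysis: equal blocks of size $N_b = N/M$, value dimension equal to $d$, and multiply-adds counted for the three stages described in the Method section.

```latex
\begin{aligned}
\text{local summaries } \{S_b, z_b\}_{b=1}^{M} &: \quad M \cdot N_b \cdot d^2 = N d^2 \\
\text{mixing } \widetilde{S}_i = \textstyle\sum_{b=1}^{M} m_{i,b} S_b,\ i = 1,\dots,M &: \quad M \cdot M \cdot d^2 = M^2 d^2 \\
\text{outputs } \widetilde{q}^\top \widetilde{S}_i \text{ for all } N \text{ queries} &: \quad N d^2 \\
\text{total} &: \quad O\big((2N + M^2)\, d^2\big) \le O(3 N d^2) = O(N d^2) \quad \text{if } M \le \sqrt{N}
\end{aligned}
```

Without a cap on $M$, the mixing term $M^2 d^2$ would grow quadratically in $N$ (for example, $M = N$ reduces to one token per block), which is consistent with the ablation's guidance to keep $M \le \sqrt{N}$.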

The authors evaluate the impact of their initialization strategy and of the learnable parameters in Multi-Head Mixing by comparing variants on DeiT-T. Locality-biased initialization alone already performs strongly, and making the coefficients learnable on top of it yields the best accuracy.

The authors compare MHLA with self-attention in DiT-XL/2 for class-to-image generation, showing that MHLA achieves lower FID and sFID scores while maintaining competitive IS and precision. When classifier-free guidance is applied, MHLA still performs comparably to self-attention, demonstrating its effectiveness in high-resolution image generation.

The authors evaluate the impact of additional modules on MHLA for class-to-image generation, showing that combining CPE and gating reduces the FID score to 59.8, outperforming all other variants. This indicates that the full MHLA configuration with both modules achieves the best generation quality.

The authors evaluate the proposed MHLA method on image generation tasks using DiT and DiG models, comparing it against self-attention and linear attention baselines. Results show that MHLA achieves the best FID scores across all model sizes and resolutions, consistently outperforming linear attention while maintaining high throughput, and matches or exceeds self-attention performance without relying on extra modules like CPE at larger scales.

The authors evaluate MHLA in video generation by fine-tuning a pretrained Wan2.1-1.3B model, replacing its FlashAttention with MHLA. Results show that MHLA achieves substantially stronger performance than vanilla linear attention while maintaining the same latency, recovering performance comparable to the original model and delivering a 2.1× inference speedup.

