
MHLA: Restoring the Expressivity of Linear Attention via Token-Level Multi-Head

Kewei Zhang, Ye Huang, Yufan Deng, Jincheng Yu, Junsong Chen, Huan Ling, Enze Xie, Daquan Zhou

Abstract

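The Transformer architecture dominates many domains, but the quadratic complexity of its self-attention limits its use in large-scale applications. Linear attention offers an efficient alternative, but applying it directly often degrades performance. Existing remedies typically add extra modules such as depthwise separable convolutions, which reintroduce computational overhead and defeat the original goal of efficiency. In this work, we identify the core failure mode of these methods: global context collapse, in which the model loses representational diversity. To address it, we propose Multi-Head Linear Attention (MHLA), which preserves representational diversity by computing attention within heads formed by partitioning the token dimension. We prove that MHLA recovers much of the expressive power of softmax attention while retaining linear complexity, and we verify its effectiveness across multiple domains. Under the same time complexity, MHLA improves performance by 3.6% on image classification (ImageNet), 6.3% on natural language processing (NLP), 12.6% on image generation, and 41% on video generation.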

One-sentence Summary

Authors from Peking University and NVIDIA propose Multi-Head Linear Attention (MHLA), a linear-complexity attention mechanism that prevents global context collapse by preserving representational diversity through token-wise head division. MHLA outperforms prior linear attention and rivals softmax attention in image classification, NLP, image generation, and video generation while maintaining efficiency.

Key Contributions

  • Linear attention mechanisms offer a scalable alternative to quadratic self-attention in Transformers but often suffer from performance degradation due to global context collapse, where the model loses representational diversity by compressing all tokens into a single shared key-value summary, limiting query-specific context retrieval.

  • The proposed Multi-Head Linear Attention (MHLA) addresses this by partitioning tokens into non-overlapping heads along the token dimension. It enables query-dependent context selection through local key-value summaries and a query-conditioned mixing mechanism, while maintaining $O(N)$ complexity without additional modules.

  • MHLA achieves state-of-the-art results across multiple domains: a 3.6% accuracy gain on ImageNet, 6.3% improvement on NLP, 12.6% boost in image generation, and a 41% enhancement in video generation, all under the same time complexity as baseline linear attention.

Introduction

The Transformer architecture dominates vision, language, and generative modeling, but its quadratic self-attention complexity hinders scalability to long sequences and high-resolution tasks. Linear attention offers a promising alternative with linear complexity. However, prior approaches often lose performance because of a fundamental issue: global context collapse. In this failure mode, all tokens are compressed into a single shared key-value summary, which reduces representational diversity and caps the rank of attention matrices. The result is uniform attention distributions and poor query-specific context selection, which worsen as sequence length grows. Existing fixes rely on auxiliary modules such as depthwise convolutions or gating mechanisms. These modules reintroduce computational overhead and still fail to fully restore expressivity.
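The rank cap can be made concrete with a minimal NumPy sketch of vanilla (global) linear attention. The sketch assumes the common feature map $\phi(x) = \mathrm{elu}(x) + 1$, because this summary does not say which $\phi$ the paper uses. Every query reads the same $d \times d$ summary, so the $N \times N$ attention matrix that linear attention implicitly applies has rank at most $d$, however long the sequence is. The function names are illustrative and do not come from the paper's code.

```python
import numpy as np

def feature_map(x):
    # Assumed positive feature map phi(x) = elu(x) + 1 (a common choice, not confirmed for this paper).
    return np.where(x > 0, x + 1.0, np.exp(x))

def global_linear_attention(Q, K, V):
    """Vanilla linear attention: every query reads one shared d x d summary."""
    Qf, Kf = feature_map(Q), feature_map(K)
    S = Kf.T @ V                 # (d, d_v) single global key-value summary
    z = Kf.sum(axis=0)           # (d,) single global normalizer
    return (Qf @ S) / (Qf @ z)[:, None]

def implied_attention_matrix(Q, K):
    """Row-normalized N x N matrix that linear attention implicitly applies to V."""
    A = feature_map(Q) @ feature_map(K).T    # rank <= d by construction
    return A / A.sum(axis=1, keepdims=True)  # row scaling does not change the rank

rng = np.random.default_rng(0)
N, d = 256, 16
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))

out = global_linear_attention(Q, K, V)
A = implied_attention_matrix(Q, K)
print(out.shape)                 # (256, 16)
print(np.linalg.matrix_rank(A))  # at most d = 16, even though A is 256 x 256
```

With $N = 256$ and $d = 16$, the implied matrix is $256 \times 256$, but its rank cannot exceed 16. This is the capped rank behind global context collapse.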

The authors propose Multi-Head Linear Attention (MHLA), a novel formulation that restores query-dependent diversity in three steps. It partitions tokens into non-overlapping heads along the token dimension, computes local key-value summaries, and applies query-conditioned mixing over these summaries. This design preserves linear $O(N)$ complexity while significantly increasing the rank of attention matrices, so it recovers much of the expressive power of softmax attention. MHLA requires only standard GEMM operations, which keeps it compatible with streaming and stateful execution. Experiments show consistent gains across domains: 3.6% higher ImageNet accuracy, 6.3% improvement on NLP, 12.6% boost in image generation, and 41% enhancement in video generation, all without additional computational cost.

Method

The authors propose Multi-Head Linear Attention (MHLA) to address the representational limitations of standard linear attention while preserving linear-time complexity. The framework first partitions the input token sequence into $M$ non-overlapping blocks and processes them in parallel. For each block $b$, it computes a local key-value summary $S_b = \sum_{j \in b} \widetilde{K}_j V_j^\top$ and a normalizer $z_b = \sum_{j \in b} \widetilde{K}_j$. Here $\widetilde{Q} = \phi(Q)$ and $\widetilde{K} = \phi(K)$ denote the feature-mapped queries and keys. This blockwise computation allows efficient, parallel processing of the sequence.
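The blockwise summaries can be sketched in NumPy as follows. The sketch assumes contiguous, equal-sized blocks of the flattened sequence and $\phi(x) = \mathrm{elu}(x) + 1$. Both are illustrative choices; for images, the paper may instead group tokens into spatial blocks. `block_summaries` is a hypothetical helper name.

```python
import numpy as np

def feature_map(x):
    # Assumed positive feature map phi(x) = elu(x) + 1.
    return np.where(x > 0, x + 1.0, np.exp(x))

def block_summaries(K, V, M):
    """Split N tokens into M contiguous, equal-sized blocks and return
    S_b = sum_{j in b} phi(K_j) V_j^T, shape (M, d, d_v), and
    z_b = sum_{j in b} phi(K_j), shape (M, d)."""
    N, d = K.shape
    assert N % M == 0, "this sketch assumes N is divisible by M"
    Nb = N // M
    Kf = feature_map(K).reshape(M, Nb, d)   # (M, N_b, d)
    Vb = V.reshape(M, Nb, V.shape[1])       # (M, N_b, d_v)
    S = np.einsum("bnd,bne->bde", Kf, Vb)   # one small GEMM per block, batched over blocks
    z = Kf.sum(axis=1)                      # per-block normalizers
    return S, z

rng = np.random.default_rng(0)
N, d, M = 64, 8, 4
K, V = rng.standard_normal((N, d)), rng.standard_normal((N, d))
S, z = block_summaries(K, V, M)
print(S.shape, z.shape)  # (4, 8, 8) (4, 8)
```

The blocks partition the sequence, so summing the $M$ local summaries gives back the single global summary used by vanilla linear attention. The extra expressivity comes from how these summaries are recombined per query block.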

To restore query-conditioned selectivity, MHLA introduces a learnable coefficient matrix $\mathcal{M}_c \in \mathbb{R}^{M \times M}$, which governs a "Multi-Head Mixing" process. Each row $m_i$ of this matrix holds the learnable, nonnegative mixing coefficients for query block $i$. These coefficients determine how block $i$ combines the $M$ local key-value summaries into a query-specific mixed summary $\widetilde{S}_i = \sum_{b=1}^{M} m_{i,b} S_b$. Each query block can therefore adaptively reweight the contributions of all blocks, including its own, which creates a query-dependent global context. The output for a query vector $\widetilde{q}$ from block $i$ is then $o = \frac{\widetilde{q}^\top \widetilde{S}_i}{\widetilde{q}^\top \widetilde{z}_i}$, where $\widetilde{z}_i = \sum_{b=1}^{M} m_{i,b} z_b$ is the corresponding mixed normalizer.
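The full forward pass described above has three stages: local summaries, query-block mixing with $\mathcal{M}_c$, and the normalized readout $o$. The compact single-head NumPy sketch below covers all three. It assumes $\phi(x) = \mathrm{elu}(x) + 1$, contiguous equal-sized blocks, and a mixing matrix that is already nonnegative. This summary does not describe how the paper enforces nonnegativity during training. `mhla` is a hypothetical name, not the authors' API.

```python
import numpy as np

def feature_map(x):
    # Assumed positive feature map phi(x) = elu(x) + 1.
    return np.where(x > 0, x + 1.0, np.exp(x))

def mhla(Q, K, V, Mc):
    """Single-head MHLA sketch. Q, K: (N, d); V: (N, d_v); Mc: (M, M), nonnegative.
    Returns outputs grouped by query block, shape (M, N_b, d_v)."""
    N, d = Q.shape
    M = Mc.shape[0]
    assert N % M == 0, "this sketch assumes equal-sized contiguous blocks"
    Nb = N // M
    Qf = feature_map(Q).reshape(M, Nb, d)
    Kf = feature_map(K).reshape(M, Nb, d)
    Vb = V.reshape(M, Nb, -1)
    S = np.einsum("bnd,bne->bde", Kf, Vb)       # local summaries S_b
    z = Kf.sum(axis=1)                          # local normalizers z_b
    S_mix = np.einsum("ib,bde->ide", Mc, S)     # S~_i = sum_b m_{i,b} S_b
    z_mix = Mc @ z                              # z~_i = sum_b m_{i,b} z_b
    num = np.einsum("ind,ide->ine", Qf, S_mix)  # q~^T S~_i for every query in block i
    den = np.einsum("ind,id->in", Qf, z_mix)    # q~^T z~_i
    return num / den[..., None]

rng = np.random.default_rng(0)
N, d, M = 64, 8, 4
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
Mc = rng.uniform(0.0, 1.0, size=(M, M))   # arbitrary nonnegative coefficients
out = mhla(Q, K, V, Mc)
print(out.shape)  # (4, 16, 8)

# Sanity check: all-ones mixing gives every block the same summary,
# which is exactly global linear attention.
Qf, Kf = feature_map(Q), feature_map(K)
glob = (Qf @ (Kf.T @ V)) / (Qf @ Kf.sum(axis=0))[:, None]
print(np.allclose(mhla(Q, K, V, np.ones((M, M))).reshape(N, d), glob))  # True
```

The all-ones check shows that uniform mixing collapses back to global linear attention. An identity matrix would instead give purely block-local attention. Learned rows of $\mathcal{M}_c$ let each query block sit anywhere between these two extremes.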

The coefficient matrix $\mathcal{M}_c$ is initialized to favor locality, with initial coefficients $m_{i,j}^{(0)} \propto 1 - \mathrm{dist}(i,j) / \max_k \mathrm{dist}(i,k)$, which promotes more stable and faster convergence. This locality-biased initialization is visualized in the figure, where the coefficients are reshaped into a 2D grid to illustrate the spatial relationship between blocks. The final output comes from an inner block matrix multiplication between the query blocks and the mixed key-value summaries. It is a tensor of size $M \times N_b \times d$, where $N_b$ is the number of tokens per block. This two-stage process has two parts: block-level selection via the mixing coefficients, then intra-block reweighting via the kernel inner product. Together they reintroduce the query-conditioned selectivity and per-token weighting that global linear attention loses. The overall complexity of MHLA remains linear in the sequence length $N$. The dominant operations are the blockwise summary computation and linear combinations of $M$ matrices of size $d \times d$.
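The locality-biased initialization can be sketched with three assumptions. The $M$ blocks sit on a `grid_h × grid_w` grid, as in the figure. $\mathrm{dist}$ is the Euclidean distance between block coordinates. Each row is normalized to sum to 1 to realize the "$\propto$". None of these details is confirmed by the text above, and `locality_init` is a hypothetical helper.

```python
import numpy as np

def locality_init(grid_h, grid_w):
    """Hypothetical locality-biased init for M = grid_h * grid_w blocks:
    m_ij ∝ 1 - dist(i, j) / max_k dist(i, k), then rows normalized to sum to 1."""
    ys, xs = np.meshgrid(np.arange(grid_h), np.arange(grid_w), indexing="ij")
    coords = np.stack([ys.ravel(), xs.ravel()], axis=1).astype(float)        # (M, 2) block positions
    dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)  # (M, M) Euclidean distances
    # 1 on the diagonal, 0 for the farthest block(s); requires M >= 2 so the max distance is nonzero.
    m0 = 1.0 - dist / dist.max(axis=1, keepdims=True)
    return m0 / m0.sum(axis=1, keepdims=True)  # assumed row normalization

Mc0 = locality_init(4, 4)
print(Mc0.shape)                          # (16, 16)
print(np.round(Mc0[5].reshape(4, 4), 3))  # block 5's initial weights laid out on the 4 x 4 grid
```

Each block starts with its largest weight on itself and zero weight on its farthest block. Early training therefore behaves like local attention until the mixing coefficients are learned.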

Experiment

  • Main experiments validate that MHLA mitigates global context collapse in linear attention by preserving query-conditioned token-level diversity while maintaining linear-time complexity.
  • On ImageNet-1K, MHLA achieves state-of-the-art accuracy in image classification across DeiT and VLT models, surpassing baselines with minimal parameter overhead.
  • In class-to-image generation, MHLA matches or exceeds self-attention performance on DiT and DiG models, achieving up to 2.1× faster inference than FlashAttention at 512 resolution while improving FID scores.
  • In video generation with 31,500-token sequences, MHLA outperforms vanilla linear attention and matches FlashAttention performance with 2.1× speedup, demonstrating robustness in ultra-long contexts.
  • In natural language processing, MHLA achieves competitive perplexity and zero-shot accuracy on commonsense reasoning and MMLU. It leads on LongBench for long-context tasks, especially Multi-Doc QA and summarization.
  • Ablation studies confirm that MHLA’s locality-biased initialization and learnable mixing coefficients enhance performance, and that M ≤ √N ensures optimal efficiency and scalability.
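A cost count gives one way to read the $M \le \sqrt{N}$ guideline. Building the summaries and computing the per-token readout each cost about $N d^2$. Mixing costs about $M \cdot M \cdot d^2$, because each of the $M$ query blocks combines $M$ summaries of size $d \times d$. Keeping $M^2 \le N$ keeps the mixing term no larger than the other two. The sketch below is back-of-the-envelope arithmetic under that assumption, not the paper's own analysis. The head dimension $d = 64$ is an assumed value, and $N = 31{,}500$ matches the video-generation setting above.

```python
import math

def mhla_cost(N, d, M):
    """Leading-order multiply-add counts for one MHLA head (constants and the
    feature map ignored). Assumes every query block mixes all M summaries."""
    summaries = N * d * d    # S_b for all blocks: N outer products of size d x d
    mixing = M * M * d * d   # S~_i = sum_b m_{i,b} S_b for each of M query blocks
    readout = N * d * d      # q~^T S~_i for each of N tokens
    return {"summaries": summaries, "mixing": mixing, "readout": readout}

N, d = 31_500, 64  # N from the video setting; d = 64 is an assumption
for M in (math.isqrt(N), 4 * math.isqrt(N)):
    c = mhla_cost(N, d, M)
    print(f"M={M}: mixing / summaries = {c['mixing'] / c['summaries']:.2f}")
```

With $M = \lfloor\sqrt{N}\rfloor = 177$, the mixing term is roughly equal to the summary term. Quadrupling $M$ makes mixing about 16 times larger, so it would dominate the layer.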

The authors evaluate their initialization strategy and the learnable parameters in Multi-Head Mixing by comparing variants on DeiT-T. Locality-biased initialization alone already performs strongly. Making the coefficients learnable on top of it further improves accuracy and gives the best result.

The authors compare MHLA with self-attention in DiT-XL/2 for class-to-image generation, showing that MHLA achieves lower FID and sFID scores while maintaining competitive IS and precision. When classifier-free guidance is applied, MHLA still performs comparably to self-attention, demonstrating its effectiveness in high-resolution image generation.

The authors evaluate the impact of additional modules on MHLA for class-to-image generation, showing that combining CPE and gating reduces the FID score to 59.8, outperforming all other variants. This indicates that the full MHLA configuration with both modules achieves the best generation quality.

The authors evaluate the proposed MHLA method on image generation tasks using DiT and DiG models, comparing it against self-attention and linear attention baselines. Results show that MHLA achieves the best FID scores across all model sizes and resolutions, consistently outperforming linear attention while maintaining high throughput, and matches or exceeds self-attention performance without relying on extra modules like CPE at larger scales.

The authors evaluate MHLA in video generation by fine-tuning a pretrained Wan2.1-1.3B model, replacing its FlashAttention with MHLA. Results show that MHLA achieves substantially stronger performance than vanilla linear attention while maintaining the same latency, recovering performance comparable to the original model and delivering a 2.1× inference speedup.

