
MSA: Memory Sparse Attention for Efficiently Scaling End-to-End Memory Models to 100 Million Tokens

Abstract

Long-term memory is a foundational element of human intelligence, and enabling AI to process lifelong-scale information has been a long-standing challenge in the field. Owing to the constraints of full-attention architectures, the effective context length of large language models (LLMs) is typically capped at around 1M tokens. Existing approaches, namely hybrid linear attention, fixed-size memory states (e.g., RNNs), and external memory methods such as RAG and agent systems, attempt to extend this limit, but they suffer from sharp accuracy degradation and surging latency as context length grows, an inability to dynamically modify memory contents, or a lack of end-to-end optimization. These bottlenecks constrain memory capacity, slow inference, and block complex scenarios such as large-corpus summarization, Digital Twins, and agent reasoning over long histories. This work proposes Memory Sparse Attention (MSA), an end-to-end trainable, efficient, and massively scalable memory model framework. Through core innovations such as scalable sparse attention and document-wise RoPE, MSA achieves linear complexity in both training and inference while maintaining exceptional stability, exhibiting less than 9% accuracy degradation when scaling from 16K to 100M tokens. Furthermore, by combining KV cache compression with Memory Parallel, it realizes 100M-token inference on two A800 GPUs. The work also proposes Memory Interleaving to facilitate complex multi-hop reasoning across scattered memory segments. MSA substantially outperforms frontier LLMs, state-of-the-art RAG systems, and leading memory agents on long-context benchmarks. These results show that, by decoupling memory capacity from reasoning capability, MSA provides a scalable foundation for endowing general-purpose models with intrinsic, lifelong-scale memory.

One-sentence Summary

Shanda Group and Peking University researchers propose Memory Sparse Attention (MSA), an end-to-end framework using scalable sparse attention and document-wise RoPE to achieve linear complexity. This approach enables stable 100M token inference on limited GPUs, outperforming RAG systems in long-corpus summarization and multi-hop reasoning tasks.

Key Contributions

  • The paper introduces Memory Sparse Attention (MSA), an end-to-end trainable framework that combines a scalable sparse attention architecture with document-wise RoPE to achieve linear complexity in training and inference while maintaining precision stability with less than 9% degradation when scaling from 16K to 100M tokens.
  • A Memory Interleaving mechanism is proposed to facilitate complex multi-hop reasoning across scattered memory segments, addressing the challenge of retrieving and utilizing information from non-contiguous parts of a massive context.
  • Experiments demonstrate that MSA significantly outperforms frontier language models, state-of-the-art RAG systems, and leading memory agents on long-context benchmarks, while KV cache compression and Memory Parallel enable 100M token inference on 2x A800 GPUs.

Introduction

Long-term memory is essential for AI applications like Digital Twins and complex agent reasoning, yet current Large Language Models are typically capped at 1M tokens due to the computational constraints of full-attention architectures. Prior approaches such as RAG systems suffer from precision loss and lack end-to-end trainability, while latent-state methods like linear attention or KV cache compression either degrade rapidly at scale or incur prohibitive costs. To overcome these barriers, the authors introduce Memory Sparse Attention (MSA), an end-to-end trainable framework that combines sparse attention with document-wise RoPE to achieve linear complexity and maintain high precision across 100M tokens. This approach enables efficient inference on standard hardware while supporting advanced mechanisms like Memory Interleaving for robust multi-hop reasoning across massive context windows.

Dataset

  • The authors constructed a diverse pre-training corpus containing 158.95 billion tokens across 17.9 million queries to balance robust retrieval capabilities with broad general knowledge.
  • The dataset spans multiple domains ranging from scientific literature to general community Q&A.
  • To ensure a balanced distribution, any dataset outside the KALM suite exceeding 0.5 million queries is downsampled to a maximum of 0.5 million queries.
  • The KALM instruction data is retained in its entirety without downsampling.
  • The full corpus serves as the foundation for pre-training the model to achieve both specialized retrieval performance and general language understanding.

Method

The authors introduce MSA (Memory Sparse Attention), a unified, end-to-end trainable latent memory framework designed for massive memory Question-Answering. The core principle of MSA is to seamlessly integrate the processes of memory sparse retrieval and answer generation into a single, jointly-optimized architecture, moving beyond the limitations of conventional decoupled "retrieve-then-read" pipelines while preserving the ability to handle long-context memory.

Sparse Attention Mechanism

To efficiently process massive memory at the latent state level, MSA replaces the standard dense self-attention with a document-based retrieval sparse attention mechanism. Refer to the framework diagram below for the internal structure of the MSA layer.

Formally, let the memory bank consist of a set of documents $\mathcal{D} = \{d_1, d_2, \ldots, d_N\}$. For each document $d_i$, the model generates standard key $K_{i,h}$ and value $V_{i,h}$ matrices via the backbone model's projection weights. In parallel, a Router K Projector generates a specialized routing key matrix $K_{i,h}^R$:

$$K_{i,h} = H_i W_K^h, \quad V_{i,h} = H_i W_V^h, \quad K_{i,h}^R = H_i W_{K^R}^h.$$

To significantly reduce the memory footprint and retrieval complexity, the authors segment each document into multiple fixed-length chunks and perform chunk-wise mean pooling, denoted as $\phi(\cdot)$, to compress these states into latent representations. This yields the compressed matrices $\bar{K}_{i,h}$, $\bar{V}_{i,h}$, and $\bar{K}_{i,h}^R$.
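The projection and chunk-wise mean pooling above can be sketched in NumPy as follows. This is a minimal toy sketch: the dimensions, the single-head projections, and the dropping of a ragged final chunk are illustrative assumptions, not the authors' implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_head, chunk = 64, 16, 4             # toy sizes (assumptions)
W_K = rng.standard_normal((d_model, d_head))   # backbone key projection
W_V = rng.standard_normal((d_model, d_head))   # backbone value projection
W_KR = rng.standard_normal((d_model, d_head))  # Router K Projector

def phi(X):
    """Chunk-wise mean pooling: compress (T, d) states to (T // chunk, d)."""
    T = X.shape[0] - X.shape[0] % chunk        # drop ragged tail (assumption)
    return X[:T].reshape(-1, chunk, X.shape[1]).mean(axis=1)

def encode_document(H):
    """Project a document's hidden states H (T, d_model) into K, V and the
    routing key K_R, then compress each via chunk-wise mean pooling."""
    K, V, K_R = H @ W_K, H @ W_V, H @ W_KR
    return phi(K), phi(V), phi(K_R)

H_doc = rng.standard_normal((12, d_model))     # one 12-token document
K_bar, V_bar, KR_bar = encode_document(H_doc)  # each (3, 16): 3 chunks of 4
```

In a real system the compressed matrices would be precomputed once per document and cached, which is exactly what the offline encoding stage described later does.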

During inference, given a user query with hidden state $H_q$, the model computes standard states $Q_{q,h}, K_{q,h}, V_{q,h}$ and a specialized routing query $Q_{q,h}^R$ via a Router Q Projector. The relevance score $S_{ij}$ for the $j$-th chunk of the $i$-th document is computed as the cosine similarity between the query's routing vector and the memory's compressed routing keys, aggregated across attention heads. To identify the most relevant memory segments, max pooling is applied over the query-token-level relevance scores:

$$S_{ij} = \max_{\text{token } t}\Big(\operatorname*{mean}_{\text{head } h}\big(\cos\big((Q_{q,h}^R)_t,\ \bar{K}_{ij,h}^R\big)\big)\Big).$$
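Following the prose reading of this score (mean over heads, max over query tokens), a minimal NumPy sketch might look like this; the tensor layout is an assumption:

```python
import numpy as np

def chunk_scores(Q_R, KR_bar):
    """Score each memory chunk against the query's routing states.
    Q_R: (H, Tq, d) per-head routing queries; KR_bar: (H, C, d) compressed
    routing keys. Cosine similarity per (head, token, chunk), averaged over
    heads, then max-pooled over query tokens -> one score per chunk."""
    Qn = Q_R / np.linalg.norm(Q_R, axis=-1, keepdims=True)
    Kn = KR_bar / np.linalg.norm(KR_bar, axis=-1, keepdims=True)
    cos = np.einsum('htd,hcd->htc', Qn, Kn)   # (H, Tq, C)
    return cos.mean(axis=0).max(axis=0)       # (C,) one score per chunk
```

A chunk whose routing key aligns with any query token receives a high score, so needle-like evidence anywhere in the query can drive retrieval.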

Based on these scores, the system selects the index set $\mathcal{I}$ of the Top-$k$ documents. Finally, generation is performed by concatenating the compressed key and value matrices of the selected documents before the query's local cache. The model then performs autoregressive generation in which the query $Q_a$ of the active tokens attends to this aggregated, sparsity-aware context:

$$K_{\mathrm{ctx}} = [\{\bar{K}_i\}_{i \in \mathcal{I}};\ K_q], \quad V_{\mathrm{ctx}} = [\{\bar{V}_i\}_{i \in \mathcal{I}};\ V_q], \qquad \mathrm{Output} = \mathrm{Attention}(Q_a, K_{\mathrm{ctx}}, V_{\mathrm{ctx}}).$$
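The Top-$k$ selection and context assembly can be sketched as a single attention step (a toy single-head sketch; the softmax scaling and function interface are assumptions):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def sparse_generate_step(Q_a, K_bars, V_bars, scores, K_q, V_q, top_k=2):
    """Pick the Top-k documents by routing score, prepend their compressed
    KV matrices to the query's local cache, and run one attention step."""
    idx = np.argsort(scores)[::-1][:top_k]     # index set I of Top-k docs
    K_ctx = np.concatenate([K_bars[i] for i in idx] + [K_q], axis=0)
    V_ctx = np.concatenate([V_bars[i] for i in idx] + [V_q], axis=0)
    attn = softmax(Q_a @ K_ctx.T / np.sqrt(Q_a.shape[-1]))
    return attn @ V_ctx                        # one sparse-generation step
```

Because only $k$ compressed documents enter the attention context, the per-step cost is independent of the total memory size.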

The authors implement the MSA routing strategy selectively, applying it exclusively to the latter half of the model's layers. Empirical analysis reveals that hidden states in the initial layers fail to capture the high-level semantic abstractions necessary for effective retrieval, rendering the routing mechanism inefficient at these depths.

Positional Encoding

To ensure robust generalization across varying memory scales, MSA employs independent RoPE for each document. Standard global positional encodings would assign monotonically increasing position IDs across the concatenated sequence, causing position indices to shift drastically as the number of documents grows. By assigning independent position IDs (starting from 0) to each document, MSA decouples the positional semantics from the total number of documents in memory. Consequently, the model can effectively extrapolate, maintaining high retrieval and reasoning accuracy on massive memory contexts even after being trained only on smaller subsets.

Complementing this parallel strategy, the authors employ Global RoPE for the active context, which includes the user query and the subsequent autoregressive generation. The position IDs for these tokens are offset by the number of retrieved documents. Specifically, the position indices for the query start from $k$ (corresponding to the Top-$k$ retrieved compressed KVs). This strategic offset ensures that the model perceives the active context as a logical continuation of the retrieved background information.
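The position-ID assignment described above can be sketched as a small helper; the function interface is an assumption, but the numbering scheme follows the text (each document restarts at 0, the active context starts at $k$):

```python
def position_ids(doc_lengths, query_length, top_k):
    """Document-wise RoPE sketch: each retrieved document restarts at
    position 0, so indices never grow with the number of documents in
    memory; the active context (query + generation) takes global positions
    offset by k, the number of retrieved compressed documents."""
    doc_ids = [list(range(n)) for n in doc_lengths]       # per-document IDs
    query_ids = list(range(top_k, top_k + query_length))  # global, offset by k
    return doc_ids, query_ids

docs, query = position_ids([5, 3, 7], query_length=4, top_k=3)
# docs[0] covers positions 0..4; query positions start at k = 3
```

Since no position index depends on the total number of documents, the same trained encoding extrapolates from thousands to millions of memory entries.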

Training Process

To endow the model with robust retrieval capabilities, the authors perform continuous pre-training on a deduplicated corpus. The overarching objective of this stage is to train the model to perform Generative Retrieval, where the model autoregressively generates the unique document IDs of relevant documents.

To explicitly guide the internal sparse attention mechanism beyond the supervision provided by the standard generation loss $\mathcal{L}_{\text{LLM}}$, an auxiliary loss $\mathcal{L}_{\text{aux}}$ is introduced to supervise the layer-wise routing process. Within each MSA layer, the Router Projector is responsible for selecting the Top-$k$ most relevant documents. The auxiliary loss is defined as a supervised contrastive objective:

$$\mathcal{L}_{\mathrm{aux}} = -\frac{1}{|\mathcal{P}|}\sum_{i=1}^{|\mathcal{P}|} \log \frac{\exp\!\big(s_i^+/\tau\big)}{\exp\!\big(s_i^+/\tau\big) + \sum_{j=1}^{|\mathcal{N}|}\exp\!\big(s_{i,j}^-/\tau\big)},$$

where $\tau$ is the temperature parameter. This objective explicitly enforces separation between relevant and irrelevant document chunks in the latent routing space.
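The contrastive objective can be sketched directly from the formula (a minimal sketch; the array interface and default temperature are assumptions):

```python
import numpy as np

def aux_contrastive_loss(s_pos, s_neg, tau=0.07):
    """Supervised contrastive routing objective: for each positive pair i,
    contrast its routing score s_pos[i] against that query's negative chunk
    scores s_neg[i, :]. s_pos: (|P|,), s_neg: (|P|, |N|); tau = temperature."""
    log_pos = s_pos / tau
    log_neg = s_neg / tau
    denom = np.exp(log_pos) + np.exp(log_neg).sum(axis=1)  # per-positive denom
    return float(-np.mean(log_pos - np.log(denom)))        # mean neg-log-lik
```

The loss shrinks as positive chunk scores rise above negatives, which is exactly the separation the router needs for reliable Top-$k$ selection.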

To ensure stability, a two-phase optimization schedule is adopted. In the initial warm-up phase, the focus is on aligning the internal Router Projectors with a loss of $\mathcal{L} = 0.1\,\mathcal{L}_{\text{LLM}} + \mathcal{L}_{\text{aux}}$. Upon completion of the warm-up, the system transitions to the main pre-training phase, where the loss weights are adjusted to $\mathcal{L} = \mathcal{L}_{\text{LLM}} + 0.1\,\mathcal{L}_{\text{aux}}$. Following continuous pre-training, a two-stage curriculum learning strategy is implemented for SFT on Question Answering tasks, extending the memory context length from 8k to 64k tokens to enhance data quality and length extrapolation.
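The two-phase weighting amounts to a tiny scheduler. The weight values follow the text; the step-based interface is an assumption:

```python
def total_loss(l_llm, l_aux, step, warmup_steps):
    """Two-phase schedule: the warm-up aligns the Router Projectors with
    L = 0.1 * L_LLM + L_aux; the main phase flips the weights to
    L = L_LLM + 0.1 * L_aux."""
    if step < warmup_steps:
        return 0.1 * l_llm + l_aux   # warm-up: routing dominates
    return l_llm + 0.1 * l_aux       # main phase: generation dominates
```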

Inference Pipeline

The inference pipeline is designed to handle the large-scale memory bank efficiently through three distinct stages, as shown in the figure below.

Stage 1: Global Memory Encoding (Offline). This stage is a one-time, offline pre-computation over the entire document corpus. For every document, the model performs a forward pass to generate the standard $K$ and $V$ matrices, and the specialized Router K Projector generates the routing key matrix $K^R$. All three matrices are partitioned into chunks and compressed via mean pooling. The resulting compact representations are cached in the memory bank.

Stage 2: Routing and Context Assembly (Online). This stage is initiated upon receiving a user question. The model computes the question's hidden states and projects them via the Router Q Projector to obtain the routing query $Q_q^R$. This query is matched against the cached global routing keys $\bar{K}^R$ to calculate relevance scores and identify the Top-$k$ documents. Crucially, only the compact key and value matrices of these selected documents are loaded and concatenated with the question's local $K_q$ and $V_q$ to form the final sparse context.

Stage 3: Sparse Generation (Online). In the final stage, the model operates autoregressively on the assembled sparse context. The standard attention mechanism computes the interaction between the active token's query $Q_q$ and the concatenated KV pairs, generating the final answer token by token.

To address complex queries requiring multi-hop reasoning, MSA incorporates an adaptive Memory Interleave Mechanism, which runs the routing-and-context-assembly and sparse-generation stages iteratively. The inference process alternates between Generative Retrieval and Context Expansion, where retrieved documents are treated as part of the query for the next iteration. This cycle repeats adaptively until the model determines that the accumulated documents are sufficient, at which point it transitions to generating the final answer.
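The interleaving loop can be sketched as follows; `retrieve`, `is_sufficient`, and `generate` are hypothetical stand-ins for the model's generative retrieval, sufficiency judgment, and answer generation, and the hop cap is an added safeguard not described in the text:

```python
def answer_with_interleaving(question, memory, retrieve, is_sufficient,
                             generate, max_hops=4):
    """Adaptive Memory Interleave sketch: alternate generative retrieval and
    context expansion; retrieved documents join the query context for the
    next hop, until the evidence is deemed sufficient."""
    context = [question]
    for _ in range(max_hops):
        docs = retrieve(context, memory)                  # generative retrieval
        context.extend(d for d in docs if d not in context)  # context expansion
        if is_sufficient(context):
            break
    return generate(context)                              # final answer
```

Each hop lets evidence retrieved in one step steer retrieval in the next, which is what enables reasoning chains across scattered memory segments.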

Efficiency and Scaling

MSA achieves linear complexity with respect to memory size $L$ in both training and inference regimes. To enable extreme-length memory inference on a standard single node, the authors implement a specialized inference engine called Memory Parallel. This engine supports inference over a massive memory context of up to 100 million tokens with limited GPU resources.

A Tiered Memory Storage Strategy is employed to address capacity constraints. Routing keys ($\bar{K}^R$) are distributed across the VRAM of multiple GPUs to ensure low-latency retrieval, while the bulk of the memory bank, the Content KVs ($\bar{K}$, $\bar{V}$), is stored in the host DRAM (CPU memory). Upon identifying the Top-$k$ relevant chunks via GPU scoring, only the corresponding Content KVs are asynchronously fetched from the host to the GPU. Additionally, a Memory-Parallel Retrieval strategy is used where the query hidden states are broadcast to all GPUs, and each GPU independently calculates similarity scores against its local shard of routing keys before a global reduction identifies the Top-$k$ indices.
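A single-process simulation of the Memory-Parallel Retrieval step is sketched below; the shard layout and the reduction are illustrative assumptions (in the real engine each shard of routing keys lives in a different GPU's VRAM):

```python
import numpy as np

def parallel_topk(query_key, shards, k=2):
    """Memory-Parallel Retrieval, simulated in one process: the routing
    query is 'broadcast' to every shard, each shard scores its local
    routing keys, then a global reduction over (score, global index)
    yields the Top-k chunk indices."""
    qn = query_key / np.linalg.norm(query_key)
    candidates, offset = [], 0
    for shard in shards:                          # each "GPU" scores locally
        sn = shard / np.linalg.norm(shard, axis=1, keepdims=True)
        scores = sn @ qn                          # cosine scores on this shard
        candidates += [(float(s), offset + j) for j, s in enumerate(scores)]
        offset += len(shard)
    candidates.sort(key=lambda t: -t[0])          # global reduction
    return [idx for _, idx in candidates[:k]]
```

Only the small routing keys participate in this all-shard scoring; the much larger Content KVs are fetched from host memory afterwards, and only for the winning indices.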

Experiment

  • MSA was evaluated on nine diverse QA benchmarks and the RULER Needle In A Haystack task to validate its efficacy against same-backbone RAG systems, best-of-breed RAG configurations, and long-context memory architectures.
  • The model demonstrates consistent superiority over standard RAG baselines and achieves competitive or top performance against systems using significantly larger generators, proving that its architectural design effectively isolates and enhances retrieval and reasoning capabilities.
  • Experiments on context scaling from 32K to 1M tokens confirm that MSA maintains exceptional stability and high retrieval accuracy, avoiding the catastrophic degradation observed in unmodified backbones and other long-context models.
  • Ablation studies validate that the two-stage curriculum learning strategy, memory interleave mechanism for multi-hop reasoning, continual pre-training for router precision, and integration of original document text are all critical components for the system's overall performance.
  • Analysis of context degradation up to 100M tokens reveals that MSA sustains high generation quality with minimal performance loss, successfully decoupling reasoning capabilities from massive memory capacity.
