
MSA: A Memory Sparse Attention Mechanism for Efficiently Scaling End-to-End Memory Models to 100M Tokens

Abstract

Long-term memory is a cornerstone of human intelligence, and enabling artificial intelligence systems to process lifetime-scale information remains a long-standing goal of the field. Due to the inherent constraints of full-attention architectures, the effective context length of large language models (LLMs) is typically limited to 1 million tokens. Existing approaches, such as hybrid linear attention, fixed-size memory states (e.g., recurrent neural networks, RNNs), and external storage methods like Retrieval-Augmented Generation (RAG) or agent systems, attempt to push past this limit. However, they often suffer from severe accuracy degradation and rapidly increasing latency as the context length grows, from an inability to dynamically modify memory contents, or from a lack of end-to-end optimization. These bottlenecks hamper complex scenarios such as large-corpus summarization, Digital Twins, and agent reasoning over long histories, while limiting memory capacity and slowing inference. We present Memory Sparse Attention (MSA), an efficient, massively scalable, end-to-end trainable memory-model framework. Through core innovations including scalable sparse attention and document-wise RoPE (Rotary Positional Embedding), MSA achieves linear complexity for both training and inference while maintaining exceptional stability, with less than 9% degradation when scaling from 16K to 100M tokens. In addition, KV cache compression combined with Memory Parallel enables 100M-token inference on 2 A800 GPUs.
We also propose Memory Interleaving to facilitate complex multi-hop reasoning across scattered memory segments. MSA significantly outperforms frontier LLMs, state-of-the-art RAG systems, and leading memory agents on long-context benchmarks. These results demonstrate that, by decoupling memory capacity from reasoning, MSA provides a scalable foundation for endowing general-purpose models with intrinsic lifetime-scale memory.

One-sentence Summary

Shanda Group and Peking University researchers propose Memory Sparse Attention (MSA), an end-to-end framework using scalable sparse attention and document-wise RoPE to achieve linear complexity. This approach enables stable 100M token inference on limited GPUs, outperforming RAG systems in long-corpus summarization and multi-hop reasoning tasks.

Key Contributions

  • The paper introduces Memory Sparse Attention (MSA), an end-to-end trainable framework that combines a scalable sparse attention architecture with document-wise RoPE to achieve linear complexity in training and inference while maintaining precision stability with less than 9% degradation when scaling from 16K to 100M tokens.
  • A Memory Interleaving mechanism is proposed to facilitate complex multi-hop reasoning across scattered memory segments, addressing the challenge of retrieving and utilizing information from non-contiguous parts of a massive context.
  • Experiments demonstrate that MSA significantly outperforms frontier language models, state-of-the-art RAG systems, and leading memory agents on long-context benchmarks, while KV cache compression and Memory Parallel enable 100M token inference on 2x A800 GPUs.

Introduction

Long-term memory is essential for AI applications like Digital Twins and complex agent reasoning, yet current Large Language Models are typically capped at 1M tokens due to the computational constraints of full-attention architectures. Prior approaches such as RAG systems suffer from precision loss and lack end-to-end trainability, while latent state methods like linear attention or KV cache compression either degrade rapidly at scale or incur prohibitive costs. To overcome these barriers, the authors introduce Memory Sparse Attention (MSA), an end-to-end trainable framework that combines sparse attention with document-wise RoPE to achieve linear complexity and maintain high precision across 100M tokens. This approach enables efficient inference on standard hardware while supporting advanced mechanisms like Memory Interleaving for robust multi-hop reasoning across massive context windows.

Dataset

  • The authors constructed a diverse pre-training corpus containing 158.95 billion tokens across 17.9 million queries to balance robust retrieval capabilities with broad general knowledge.
  • The dataset spans multiple domains ranging from scientific literature to general community Q&A.
  • To ensure a balanced distribution, any dataset outside the KALM suite exceeding 0.5 million queries is downsampled to a maximum of 0.5 million queries.
  • The KALM instruction data is retained in its entirety without downsampling.
  • The full corpus serves as the foundation for pre-training the model to achieve both specialized retrieval performance and general language understanding.

Method

The authors introduce MSA (Memory Sparse Attention), a unified, end-to-end trainable latent memory framework designed for massive memory Question-Answering. The core principle of MSA is to seamlessly integrate the processes of memory sparse retrieval and answer generation into a single, jointly-optimized architecture, moving beyond the limitations of conventional decoupled "retrieve-then-read" pipelines while preserving the ability to handle long-context memory.

Sparse Attention Mechanism

To efficiently process massive memory at the latent state level, MSA replaces the standard dense self-attention with a document-based retrieval sparse attention mechanism. Refer to the framework diagram below for the internal structure of the MSA layer.

Formally, let the memory bank consist of a set of documents $\mathcal{D} = \{d_1, d_2, \ldots, d_N\}$. For each document $d_i$, the model generates standard Key $K_{i,h}$ and Value $V_{i,h}$ matrices via the backbone model's projection weights. In parallel, a Router K Projector generates a specialized routing key matrix $K_{i,h}^{R}$:

$$K_{i,h} = H_i W_K^h, \quad V_{i,h} = H_i W_V^h, \quad K_{i,h}^{R} = H_i W_{K^R}^h.$$

To significantly reduce the memory footprint and retrieval complexity, the authors segment each document into multiple fixed-length chunks and perform chunk-wise mean pooling, denoted as $\phi(\cdot)$, to compress these states into latent representations. This yields the compressed matrices $\bar{K}_{i,h}$, $\bar{V}_{i,h}$, and $\bar{K}_{i,h}^{R}$.
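As a concrete illustration, the chunk-wise mean pooling $\phi(\cdot)$ can be sketched in a few lines. The function name, chunk length, and use of NumPy are illustrative assumptions, not the authors' implementation:

```python
import numpy as np

def chunk_mean_pool(states: np.ndarray, chunk_len: int) -> np.ndarray:
    """Compress per-token states (T, d) into per-chunk latents by
    chunk-wise mean pooling -- a sketch of the phi(.) operator."""
    T, d = states.shape
    pad = (-T) % chunk_len                      # zero-pad so T divides evenly
    if pad:
        states = np.concatenate([states, np.zeros((pad, d))], axis=0)
    chunks = states.reshape(-1, chunk_len, d)
    # Average only over the real tokens in the (possibly padded) last chunk.
    counts = np.full(chunks.shape[0], chunk_len, dtype=float)
    if pad:
        counts[-1] = chunk_len - pad
    return chunks.sum(axis=1) / counts[:, None]

# Example: a document's K matrix with 10 tokens, chunk length 4 -> 3 chunk keys.
K_i = np.arange(10, dtype=float)[:, None].repeat(2, axis=1)  # (10, 2)
K_bar = chunk_mean_pool(K_i, chunk_len=4)
```

The same pooling would apply to the $K$, $V$, and $K^R$ matrices alike, shrinking each document to one latent vector per chunk.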

During inference, given a user query with hidden state $H_q$, the model computes standard states $Q_{q,h}, K_{q,h}, V_{q,h}$ and a specific routing query $Q_{q,h}^{R}$ via a Router Q Projector. The relevance score $S_{ij}$ for the $j$-th chunk of the $i$-th document is computed as the cosine similarity between the query's routing vector and the memory's compressed routing keys, aggregated across attention heads. To identify the most relevant memory segments, max pooling is applied over the query-token-level relevance scores:

$$S_{ij} = \max_{\text{token } t,\ \text{head } h} \Big( \operatorname{mean}\big( \cos\big( (Q_{q,h}^{R})_t,\ \bar{K}_{ij,h}^{R} \big) \big) \Big).$$

Based on these scores, the system selects the indices of the Top-$k$ documents. Finally, generation is performed by concatenating the compressed Key and Value matrices of the selected documents before the query's local cache. The model then performs autoregressive generation, where the query $Q_q$ from active tokens attends to this aggregated, sparsity-aware context:

$$K_{\mathrm{ctx}} = [\{\bar{K}_i\}_{i \in \mathcal{I}};\ K_q], \quad V_{\mathrm{ctx}} = [\{\bar{V}_i\}_{i \in \mathcal{I}};\ V_q],$$
$$\mathrm{Output} = \mathrm{Attention}(Q_a, K_{\mathrm{ctx}}, V_{\mathrm{ctx}}).$$
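The routing step above can be sketched as follows, under one plausible reading of the score (cosine per head, mean over heads, max over query tokens, as the prose suggests). All names and shapes are illustrative assumptions:

```python
import numpy as np

def routing_scores(Q_r: np.ndarray, K_r_bar: np.ndarray) -> np.ndarray:
    """Per-chunk relevance scores.
    Q_r:     routing queries, shape (H, T, d) -- H heads, T query tokens.
    K_r_bar: compressed routing keys, shape (H, J, d) -- J chunks.
    Cosine per (head, token, chunk) -> mean over heads -> max over tokens."""
    qn = Q_r / np.linalg.norm(Q_r, axis=-1, keepdims=True)
    kn = K_r_bar / np.linalg.norm(K_r_bar, axis=-1, keepdims=True)
    cos = np.einsum("htd,hjd->htj", qn, kn)   # (H, T, J)
    return cos.mean(axis=0).max(axis=0)       # -> (J,)

def select_top_k(scores: np.ndarray, k: int) -> np.ndarray:
    """Indices of the Top-k highest-scoring chunks."""
    return np.argsort(-scores)[:k]

# Toy example: one head, one query token, two memory chunks.
Q_r = np.array([[[1.0, 0.0]]])                   # (1, 1, 2)
K_r_bar = np.array([[[1.0, 0.0], [0.0, 1.0]]])   # (1, 2, 2)
scores = routing_scores(Q_r, K_r_bar)
picked = select_top_k(scores, k=1)               # chunk 0 aligns with the query
```

The compressed KVs of the selected chunks would then be concatenated before the query's local cache, as in the equations above.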

The authors implement the MSA routing strategy selectively, applying it exclusively to the latter half of the model's layers. Empirical analysis reveals that hidden states in the initial layers fail to capture the high-level semantic abstractions necessary for effective retrieval, rendering the routing mechanism inefficient at these depths.

Positional Encoding

To ensure robust generalization across varying memory scales, MSA employs independent RoPE for each document. Standard global positional encodings would assign monotonically increasing position IDs across the concatenated sequence, causing position indices to shift drastically as the number of documents grows. By assigning independent position IDs (starting from 0) to each document, MSA decouples the positional semantics from the total number of documents in memory. Consequently, the model can effectively extrapolate, maintaining high retrieval and reasoning accuracy on massive memory contexts even after being trained only on smaller subsets.

Complementing this per-document strategy, the authors employ global RoPE for the active context, which includes the user query and the subsequent autoregressive generation. The position IDs for these tokens are offset by the number of retrieved documents: the position indices for the query start from $k$ (corresponding to the Top-$k$ retrieved compressed KVs). This offset ensures that the model perceives the active context as a logical continuation of the retrieved background information.
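A minimal sketch of this positional scheme (document-wise IDs for memory, offset global IDs for the active context); the function and variable names are illustrative assumptions:

```python
def build_position_ids(doc_lens, k, active_len):
    """Document-wise RoPE: every retrieved document gets position IDs
    restarting at 0, decoupling positions from the total memory size.
    The active context (query + generation) uses global RoPE offset by k,
    as if it followed the Top-k retrieved compressed KVs."""
    doc_ids = [list(range(n)) for n in doc_lens]      # per-document positions
    active_ids = list(range(k, k + active_len))       # offset global positions
    return doc_ids, active_ids

# Three retrieved documents of lengths 4, 2, 5; Top-k = 3; 6 active tokens.
doc_ids, active_ids = build_position_ids([4, 2, 5], k=3, active_len=6)
```

Because every document's indices start at 0, adding more documents to memory never shifts the position range the model was trained on.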

Training Process

To endow the model with robust retrieval capabilities, the authors perform continuous pre-training on a deduplicated corpus. The overarching objective of this stage is to train the model to perform Generative Retrieval, where the model autoregressively generates the unique document IDs of relevant documents.

To explicitly guide the internal sparse attention mechanism beyond the supervision provided by the standard generation loss $\mathcal{L}_{\text{LLM}}$, an auxiliary loss $\mathcal{L}_{\text{aux}}$ is introduced to supervise the layer-wise routing process. Within each MSA layer, the Router Projector is responsible for selecting the Top-$k$ most relevant documents. The auxiliary loss is defined as a supervised contrastive objective:

$$\mathcal{L}_{\mathrm{aux}} = -\frac{1}{|\mathcal{P}|} \sum_{i=1}^{|\mathcal{P}|} \log \frac{\exp\big(s_i^{+}/\tau\big)}{\exp\big(s_i^{+}/\tau\big) + \sum_{j=1}^{|\mathcal{N}|} \exp\big(s_{i,j}^{-}/\tau\big)},$$

where $\tau$ is the temperature parameter. This objective explicitly enforces separation between relevant and irrelevant document chunks in the latent routing space.
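Assuming the routing scores $s_i^{+}$ and $s_{i,j}^{-}$ are already computed, the objective can be sketched as follows (function name and default temperature are hypothetical):

```python
import math

def aux_contrastive_loss(pos_scores, neg_scores, tau=0.1):
    """Supervised contrastive routing loss: for each positive chunk score
    s_i+, contrast against that example's negative chunk scores s_ij-,
    and average the resulting -log softmax terms over the positives."""
    total = 0.0
    for s_pos, negs in zip(pos_scores, neg_scores):
        denom = math.exp(s_pos / tau) + sum(math.exp(s / tau) for s in negs)
        total += -math.log(math.exp(s_pos / tau) / denom)
    return total / len(pos_scores)

# One positive chunk scored 1.0 against one negative scored -1.0 (tau = 1.0).
loss = aux_contrastive_loss([1.0], [[-1.0]], tau=1.0)
```

Raising the positive score or lowering the negative scores drives the loss toward zero, which is what pushes relevant and irrelevant chunks apart in the routing space.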

To ensure stability, a two-phase optimization schedule is adopted. In the initial warm-up phase, the focus is on aligning the internal Router Projectors with a loss of $\mathcal{L} = 0.1\,\mathcal{L}_{\text{LLM}} + \mathcal{L}_{\text{aux}}$. Upon completion of the warm-up, the system transitions to the main pre-training phase, where the loss weights are adjusted to $\mathcal{L} = \mathcal{L}_{\text{LLM}} + 0.1\,\mathcal{L}_{\text{aux}}$. Following continuous pre-training, a two-stage curriculum learning strategy is implemented for SFT on question-answering tasks, extending the memory context length from 8K to 64K tokens to enhance data quality and length extrapolation.

Inference Pipeline

The inference pipeline is designed to handle the large-scale memory bank efficiently through three distinct stages, as shown in the figure below.

Stage 1: Global Memory Encoding (Offline). This stage is a one-time, offline pre-computation over the entire document corpus. For every document, the model performs a forward pass to generate the standard $K$ and $V$ matrices, and the specialized Router K Projector generates the routing key matrix $K^{R}$. All three matrices are partitioned into chunks and compressed via mean pooling. The resulting compact representations are cached in the memory bank.

Stage 2: Routing and Context Assembly (Online). This stage is initiated upon receiving a user question. The model computes the question's hidden states and projects them via the Router Q Projector to obtain the routing query $Q_{q}^{R}$. This query is matched against the cached global routing keys $\bar{K}^{R}$ to calculate relevance scores and identify the Top-$k$ documents. Crucially, only the compact Key and Value matrices of these selected documents are loaded and concatenated with the question's local $K_q$ and $V_q$ to form the final sparse context.

Stage 3: Sparse Generation (Online). In the final stage, the model operates autoregressively on the assembled sparse context. The standard attention mechanism computes the interaction between the active token's query $Q_q$ and the concatenated KV pairs, generating the final answer token by token.

To address complex queries requiring multi-hop reasoning, MSA incorporates an adaptive Memory Interleave Mechanism, which runs routing, context assembly, and sparse generation iteratively. The inference process alternates between Generative Retrieval and Context Expansion, where retrieved documents are treated as part of the query for the next iteration. This cycle repeats adaptively until the model determines that the accumulated documents are sufficient, at which point it transitions to generating the final answer.
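The interleaving loop above can be sketched as a simple control flow; `retrieve`, `is_sufficient`, and `generate` are hypothetical callables standing in for MSA's routing, stopping decision, and decoding steps, and the hop cap is an assumed safeguard:

```python
def memory_interleave(query, retrieve, is_sufficient, generate, max_hops=4):
    """Adaptive Memory Interleave sketch: alternate generative retrieval
    and context expansion, folding retrieved documents into the next
    query, until the accumulated evidence is deemed sufficient."""
    context = []
    for _ in range(max_hops):
        docs = retrieve(query, context)       # routing + context assembly
        context.extend(docs)                  # context expansion
        if is_sufficient(query, context):     # model-decided stopping point
            break
        query = (query, tuple(context))       # retrieved docs join the query
    return generate(query, context)           # final sparse generation

# Toy stubs: each hop retrieves one document; two documents suffice.
calls = []
def retrieve(q, c):
    calls.append(q)
    return [f"doc{len(c)}"]
def is_sufficient(q, c):
    return len(c) >= 2
def generate(q, c):
    return "answer:" + ",".join(c)

out = memory_interleave("q0", retrieve, is_sufficient, generate)
```

Each hop lets evidence retrieved so far condition the next retrieval, which is what enables hops across scattered, non-contiguous memory segments.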

Efficiency and Scaling

MSA achieves linear complexity with respect to memory size $L$ in both training and inference regimes. To enable extreme-length memory inference on a standard single node, the authors implement a specialized inference engine called Memory Parallel. This engine supports inference over a massive memory context of up to 100 million tokens with limited GPU resources.

A Tiered Memory Storage Strategy is employed to address capacity constraints. Routing Keys ($\bar{K}^{R}$) are distributed across the VRAM of multiple GPUs to ensure low-latency retrieval, while the bulk of the memory bank, the Content KVs ($\bar{K}$, $\bar{V}$), is stored in host DRAM (CPU memory). Upon identifying the Top-$k$ relevant chunks via GPU scoring, only the corresponding Content KVs are asynchronously fetched from the host to the GPU. Additionally, a Memory-Parallel Retrieval strategy is used: the query hidden states are broadcast to all GPUs, and each GPU independently calculates similarity scores against its local shard of Routing Keys before a global reduction identifies the Top-$k$ indices.
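The sharded scoring and global reduction can be sketched in pure Python, with a plain list standing in for each GPU's shard of Routing Keys; all names are illustrative assumptions, not the engine's API:

```python
import heapq

def memory_parallel_topk(score_fn, shards, k):
    """Memory-Parallel Retrieval sketch: the query is 'broadcast' to every
    shard, each shard scores its local routing keys independently, and a
    global reduction over (score, global_id) pairs yields the Top-k ids."""
    local_results = []
    for shard_id, keys in enumerate(shards):          # one loop per 'GPU'
        for local_id, key in enumerate(keys):
            local_results.append((score_fn(key), (shard_id, local_id)))
    top = heapq.nlargest(k, local_results)            # global reduction
    return [idx for _, idx in top]

# Two 'GPU' shards holding scalar routing-key scores for simplicity.
shards = [[0.1, 0.9, 0.3], [0.8, 0.2]]
top2 = memory_parallel_topk(lambda key: key, shards, k=2)
```

In the real engine the per-shard scoring would run concurrently on each GPU and only the winning chunks' Content KVs would then be fetched from host DRAM.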

Experiment

  • MSA was evaluated on nine diverse QA benchmarks and the RULER Needle In A Haystack task to validate its efficacy against same-backbone RAG systems, best-of-breed RAG configurations, and long-context memory architectures.
  • The model demonstrates consistent superiority over standard RAG baselines and achieves competitive or top performance against systems using significantly larger generators, proving that its architectural design effectively isolates and enhances retrieval and reasoning capabilities.
  • Experiments on context scaling from 32K to 1M tokens confirm that MSA maintains exceptional stability and high retrieval accuracy, avoiding the catastrophic degradation observed in unmodified backbones and other long-context models.
  • Ablation studies validate that the two-stage curriculum learning strategy, memory interleave mechanism for multi-hop reasoning, continual pre-training for router precision, and integration of original document text are all critical components for the system's overall performance.
  • Analysis of context degradation up to 100M tokens reveals that MSA sustains high generation quality with minimal performance loss, successfully decoupling reasoning capabilities from massive memory capacity.
