HyperAI


MSA: Memory Sparse Attention for Efficiently Scaling End-to-End Memory Models to 100M Tokens

Abstract

Long-term memory is a cornerstone of human intelligence, and enabling AI to process information at the scale of a human lifetime remains a long-standing goal of the field. Due to the constraints of full-attention architectures, the effective context length of Large Language Models (LLMs) is typically limited to about 1 million tokens. Existing approaches attempt to move past this limit, including hybrid linear attention, fixed-size memory states (e.g., RNNs), and external storage methods such as Retrieval-Augmented Generation (RAG) or agent systems. However, they often suffer from sharp accuracy degradation, rapidly growing latency as context length increases, an inability to modify memory contents dynamically, or a lack of end-to-end optimization. These bottlenecks hinder complex scenarios such as large-corpus summarization, Digital Twins, and long-history reasoning in agent systems, while constraining memory capacity and slowing inference.

In this work, we present Memory Sparse Attention (MSA), an end-to-end trainable, efficient, and massively scalable framework for memory models. Through core innovations including scalable sparse attention and document-wise RoPE, MSA achieves linear complexity in both training and inference while remaining exceptionally stable, exhibiting less than 9% degradation when scaling from 16K to 100M tokens. Moreover, KV cache compression combined with Memory Parallel enables 100M-token inference on two A800 GPUs. We also propose a Memory Interleaving technique to facilitate complex multi-hop reasoning across scattered memory segments.

MSA significantly outperforms frontier LLMs, state-of-the-art RAG systems, and leading memory agents on long-context benchmarks. These results show that decoupling memory capacity from the reasoning process allows MSA to provide a scalable foundation for endowing general-purpose models with lifetime-scale internal memory.

One-sentence Summary

Shanda Group and Peking University researchers propose Memory Sparse Attention (MSA), an end-to-end framework using scalable sparse attention and document-wise RoPE to achieve linear complexity. This approach enables stable 100M token inference on limited GPUs, outperforming RAG systems in long-corpus summarization and multi-hop reasoning tasks.

Key Contributions

  • The paper introduces Memory Sparse Attention (MSA), an end-to-end trainable framework that combines a scalable sparse attention architecture with document-wise RoPE to achieve linear complexity in training and inference while maintaining precision stability with less than 9% degradation when scaling from 16K to 100M tokens.
  • A Memory Interleaving mechanism is proposed to facilitate complex multi-hop reasoning across scattered memory segments, addressing the challenge of retrieving and utilizing information from non-contiguous parts of a massive context.
  • Experiments demonstrate that MSA significantly outperforms frontier language models, state-of-the-art RAG systems, and leading memory agents on long-context benchmarks, while KV cache compression and Memory Parallel enable 100M token inference on 2x A800 GPUs.

Introduction

Long-term memory is essential for AI applications like Digital Twins and complex agent reasoning, yet current Large Language Models are typically capped at 1M tokens due to the computational constraints of full-attention architectures. Prior approaches such as RAG systems suffer from precision loss and lack end-to-end trainability, while latent state methods like linear attention or KV cache compression either degrade rapidly at scale or incur prohibitive costs. To overcome these barriers, the authors introduce Memory Sparse Attention (MSA), an end-to-end trainable framework that combines sparse attention with document-wise RoPE to achieve linear complexity and maintain high precision across 100M tokens. This approach enables efficient inference on standard hardware while supporting advanced mechanisms like Memory Interleaving for robust multi-hop reasoning across massive context windows.

Dataset

  • The authors constructed a diverse pre-training corpus containing 158.95 billion tokens across 17.9 million queries to balance robust retrieval capabilities with broad general knowledge.
  • The dataset spans multiple domains ranging from scientific literature to general community Q&A.
  • To ensure a balanced distribution, any dataset outside the KALM suite exceeding 0.5 million queries is downsampled to a maximum of 0.5 million queries.
  • The KALM instruction data is retained in its entirety without downsampling.
  • The full corpus serves as the foundation for pre-training the model to achieve both specialized retrieval performance and general language understanding.

Method

The authors introduce MSA (Memory Sparse Attention), a unified, end-to-end trainable latent memory framework designed for massive memory Question-Answering. The core principle of MSA is to seamlessly integrate the processes of memory sparse retrieval and answer generation into a single, jointly-optimized architecture, moving beyond the limitations of conventional decoupled "retrieve-then-read" pipelines while preserving the ability to handle long-context memory.

Sparse Attention Mechanism

To efficiently process massive memory at the latent state level, MSA replaces the standard dense self-attention with a document-based retrieval sparse attention mechanism. Refer to the framework diagram below for the internal structure of the MSA layer.

Formally, let the memory bank consist of a set of documents $\mathcal{D} = \{d_1, d_2, \ldots, d_N\}$. For each document $d_i$, the model generates standard Key $K_{i,h}$ and Value $V_{i,h}$ matrices via the backbone model's projection weights. In parallel, a Router K Projector generates a specialized routing key matrix $K_{i,h}^R$:

$$K_{i,h} = H_i W_K^h, \quad V_{i,h} = H_i W_V^h, \quad K_{i,h}^R = H_i W_{K^R}^h.$$

To significantly reduce the memory footprint and retrieval complexity, the authors segment each document into multiple fixed-length chunks and perform chunk-wise mean pooling, denoted as $\phi(\cdot)$, to compress these states into latent representations. This yields the compressed matrices $\bar{K}_{i,h}$, $\bar{V}_{i,h}$, and $\bar{K}_{i,h}^R$.
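The chunk-wise compression $\phi(\cdot)$ can be sketched as follows; this is a minimal NumPy illustration of the pooling step, not the paper's implementation, and the chunk length and toy document are made-up values:

```python
import numpy as np

def compress_states(K: np.ndarray, chunk_len: int) -> np.ndarray:
    """Chunk-wise mean pooling phi(.): split a document's per-token
    states (T, d) into fixed-length chunks and average each chunk.
    The tail chunk may be shorter and is averaged over its real length."""
    T, d = K.shape
    n_chunks = (T + chunk_len - 1) // chunk_len
    out = np.empty((n_chunks, d), dtype=K.dtype)
    for j in range(n_chunks):
        out[j] = K[j * chunk_len:(j + 1) * chunk_len].mean(axis=0)
    return out

# Toy example: a 10-token document with d=2, chunk length 4 -> 3 compressed rows.
K_doc = np.arange(10, dtype=np.float64)[:, None].repeat(2, axis=1)
K_bar = compress_states(K_doc, chunk_len=4)
```

The same pooling is applied to $K$, $V$, and $K^R$ alike, so the memory bank stores one compressed row per chunk instead of one row per token.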

During inference, given a user query with hidden state $H_q$, the model computes standard states $Q_{q,h}, K_{q,h}, V_{q,h}$ and a specific routing query $Q_{q,h}^R$ via a Router Q Projector. The relevance score $S_{ij}$ for the $j$-th chunk of the $i$-th document is computed as the cosine similarity between the query's routing vector and the memory's compressed routing keys, averaged across attention heads. To identify the most relevant memory segments, max pooling is applied over the query-token-level relevance scores:

$$S_{ij} = \max_{\text{token } t}\ \operatorname{mean}_{\text{head } h}\left( \cos\!\left( (Q_{q,h}^R)_t,\ \bar{K}_{ij,h}^R \right) \right).$$
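The scoring rule above (cosine similarity, mean over heads, max over query tokens) can be sketched in a few lines of NumPy; the tensor layout `(tokens, heads, dim)` and the toy inputs are illustrative assumptions:

```python
import numpy as np

def relevance_scores(Q_r: np.ndarray, K_r_bar: np.ndarray) -> np.ndarray:
    """Routing score per chunk: cosine similarity between each query
    token's routing vector and each compressed chunk key, averaged
    over heads, then max-pooled over query tokens.
    Q_r:     (T, H, d) routing queries for T query tokens, H heads
    K_r_bar: (C, H, d) compressed routing keys for C chunks
    returns: (C,)      one score per chunk
    """
    qn = Q_r / np.linalg.norm(Q_r, axis=-1, keepdims=True)
    kn = K_r_bar / np.linalg.norm(K_r_bar, axis=-1, keepdims=True)
    cos = np.einsum("thd,chd->tch", qn, kn)   # (T, C, H) cosine similarities
    return cos.mean(axis=-1).max(axis=0)      # mean over heads, max over tokens

# Toy example: 2 query tokens, 1 head, 2 chunks.
Q_r = np.array([[[1.0, 0.0]], [[0.0, 1.0]]])
K_r_bar = np.array([[[1.0, 0.0]], [[-1.0, 0.0]]])
scores = relevance_scores(Q_r, K_r_bar)
```

Max-pooling over tokens means a chunk only needs to match one query token strongly to score high, which suits multi-part questions.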

Based on these scores, the system selects the indices $\mathcal{I}$ of the Top-$k$ documents. Finally, generation is performed by concatenating the compressed Key and Value matrices of the selected documents in front of the query's local cache. The model then performs autoregressive generation in which the query $Q_a$ of the active tokens attends to this aggregated, sparsity-aware context:

$$K_{\text{ctx}} = [\{\bar{K}_i\}_{i \in \mathcal{I}};\ K_q], \quad V_{\text{ctx}} = [\{\bar{V}_i\}_{i \in \mathcal{I}};\ V_q], \qquad \text{Output} = \operatorname{Attention}(Q_a, K_{\text{ctx}}, V_{\text{ctx}}).$$
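A minimal sketch of the selection-and-attention step, assuming single-head attention over flat `(rows, d)` matrices and randomly generated toy data; the real system operates per head over the compressed cache:

```python
import numpy as np

def sparse_generate_step(scores, K_bars, V_bars, K_q, V_q, Q_a, k=2):
    """Select the Top-k documents by routing score, prepend their
    compressed KV to the query's local cache, and run one step of
    standard softmax attention for the active token.
    scores: (N,) one score per document; K_bars/V_bars: lists of (C_i, d)
    K_q, V_q: (Tq, d) local cache; Q_a: (d,) active-token query."""
    top = np.argsort(scores)[::-1][:k]                 # Top-k document indices
    K_ctx = np.concatenate([K_bars[i] for i in top] + [K_q])
    V_ctx = np.concatenate([V_bars[i] for i in top] + [V_q])
    logits = K_ctx @ Q_a / np.sqrt(Q_a.shape[0])       # scaled dot products
    w = np.exp(logits - logits.max())
    w /= w.sum()                                       # softmax weights
    return w @ V_ctx                                   # attention output

# Toy example: 3 documents, 2 compressed chunks each, d=4, keep Top-2.
rng = np.random.default_rng(0)
K_bars = [rng.normal(size=(2, 4)) for _ in range(3)]
V_bars = [rng.normal(size=(2, 4)) for _ in range(3)]
out = sparse_generate_step(np.array([0.9, 0.1, 0.5]), K_bars, V_bars,
                           rng.normal(size=(1, 4)), rng.normal(size=(1, 4)),
                           rng.normal(size=4))
```

Because only the $k$ selected documents' compressed KVs enter the context, the attention cost per step is bounded by $k$ rather than by the full memory size.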

The authors implement the MSA routing strategy selectively, applying it exclusively to the latter half of the model's layers. Empirical analysis reveals that hidden states in the initial layers fail to capture the high-level semantic abstractions necessary for effective retrieval, rendering the routing mechanism inefficient at these depths.

Positional Encoding

To ensure robust generalization across varying memory scales, MSA employs independent RoPE for each document. Standard global positional encodings would assign monotonically increasing position IDs across the concatenated sequence, causing position indices to shift drastically as the number of documents grows. By assigning independent position IDs (starting from 0) to each document, MSA decouples the positional semantics from the total number of documents in memory. Consequently, the model can effectively extrapolate, maintaining high retrieval and reasoning accuracy on massive memory contexts even after being trained only on smaller subsets.

Complementing this per-document strategy, the authors employ Global RoPE for the active context, which includes the user query and the subsequent autoregressive generation. The position IDs for these tokens are offset by the number of retrieved documents: the position indices for the query start from $k$, corresponding to the Top-$k$ retrieved compressed KVs. This offset ensures that the model perceives the active context as a logical continuation of the retrieved background information.
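The two positional schemes can be summarized in a small sketch: every memory document restarts its RoPE indices at 0, while the active context starts at $k$. The helper name and toy lengths are illustrative:

```python
def position_ids(doc_lengths, k, active_len):
    """Document-wise position IDs: every memory document restarts at 0,
    so indices never grow with the number of documents in memory; the
    active context (query + generation) starts at k, one slot per
    retrieved compressed document, so it reads as a continuation."""
    memory_ids = [list(range(n)) for n in doc_lengths]
    active_ids = list(range(k, k + active_len))
    return memory_ids, active_ids

# Toy example: three memory documents, Top-3 retrieval, 4 active tokens.
mem, act = position_ids([5, 3, 7], k=3, active_len=4)
```

Since no position index ever depends on the total document count, the model trained on a small memory extrapolates to a much larger one without encountering unseen positions.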

Training Process

To endow the model with robust retrieval capabilities, the authors perform continuous pre-training on a deduplicated corpus. The overarching objective of this stage is to train the model to perform Generative Retrieval, where the model autoregressively generates the unique document IDs of relevant documents.

To explicitly guide the internal sparse attention mechanism beyond the supervision provided by the standard generation loss $\mathcal{L}_{\text{LLM}}$, an auxiliary loss $\mathcal{L}_{\text{aux}}$ is introduced to supervise the layer-wise routing process. Within each MSA layer, the Router Projector is responsible for selecting the Top-$k$ most relevant documents. The auxiliary loss is defined as a supervised contrastive objective:

$$\mathcal{L}_{\text{aux}} = -\frac{1}{|\mathcal{P}|} \sum_{i=1}^{|\mathcal{P}|} \log \frac{\exp\!\big(s_i^{+}/\tau\big)}{\exp\!\big(s_i^{+}/\tau\big) + \sum_{j=1}^{|\mathcal{N}|} \exp\!\big(s_{i,j}^{-}/\tau\big)},$$

where $\tau$ is the temperature parameter. This objective explicitly enforces separation between relevant and irrelevant document chunks in the latent routing space.
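The contrastive objective can be written out directly; this is a minimal scalar sketch under the assumption that `pos_scores` holds one routing score per positive chunk and `neg_scores[i]` holds that positive's negative scores:

```python
import math

def aux_loss(pos_scores, neg_scores, tau=0.07):
    """Supervised contrastive routing loss: each positive chunk's
    score s_i^+ competes against its negative-chunk scores s_{i,j}^-
    at temperature tau; the loss averages over the positives."""
    total = 0.0
    for i, s_pos in enumerate(pos_scores):
        denom = math.exp(s_pos / tau) + sum(
            math.exp(s / tau) for s in neg_scores[i])
        total += -math.log(math.exp(s_pos / tau) / denom)
    return total / len(pos_scores)
```

With one positive and one equally scored negative at $\tau = 1$, the loss is $\log 2$, the familiar chance-level value of an InfoNCE-style objective.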

To ensure stability, a two-phase optimization schedule is adopted. In the initial warm-up phase, the focus is on aligning the internal Router Projectors with a loss of $\mathcal{L} = 0.1\,\mathcal{L}_{\text{LLM}} + \mathcal{L}_{\text{aux}}$. Upon completion of the warm-up, the system transitions to the main pre-training phase, where the loss weights are adjusted to $\mathcal{L} = \mathcal{L}_{\text{LLM}} + 0.1\,\mathcal{L}_{\text{aux}}$. Following continuous pre-training, a two-stage curriculum learning strategy is implemented for SFT on Question Answering tasks, extending the memory context length from 8K to 64K tokens to enhance data quality and length extrapolation.

Inference Pipeline

The inference pipeline is designed to handle the large-scale memory bank efficiently through three distinct stages, as shown in the figure below.

Stage 1: Global Memory Encoding (Offline). This stage is a one-time, offline pre-computation over the entire document corpus. For every document, the model performs a forward pass to generate the standard $K$ and $V$ matrices, and the specialized Router K Projector generates the routing key matrix $K^R$. All three matrices are partitioned into chunks and compressed via mean pooling. The resulting compact representations are cached in the memory bank.

Stage 2: Routing and Context Assembly (Online). This stage is initiated upon receiving a user question. The model computes the question's hidden states and projects them via the Router Q Projector to obtain the routing query $Q_q^R$. This query is matched against the cached global routing keys $\bar{K}^R$ to calculate relevance scores and identify the Top-$k$ documents. Crucially, only the compact Key and Value matrices of these selected documents are loaded and concatenated with the question's local $K_q$ and $V_q$ to form the final sparse context.

Stage 3: Sparse Generation (Online). In the final stage, the model operates autoregressively on the assembled sparse context. The standard attention mechanism computes the interaction between the active token's query $Q_a$ and the concatenated KV pairs, generating the final answer token by token.

To address complex queries requiring multi-hop reasoning, MSA incorporates an adaptive Memory Interleaving mechanism, which performs routing, context assembly, and sparse generation iteratively. The inference process alternates between Generative Retrieval and Context Expansion, where retrieved documents are treated as part of the query for the next iteration. This cycle repeats adaptively until the model determines that the accumulated documents are sufficient, at which point it transitions to generating the final answer.
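The interleaved loop can be sketched as a control-flow skeleton; here `retrieve`, `sufficient`, and `generate` are hypothetical stand-ins for the model's generative-retrieval step, its sufficiency judgment, and final answer generation:

```python
def interleaved_answer(question, retrieve, sufficient, generate, max_hops=4):
    """Adaptive Memory Interleaving sketch: alternate generative
    retrieval and context expansion, folding each round's retrieved
    documents back into the query, until the accumulated evidence is
    judged sufficient (or a hop budget is exhausted)."""
    context = [question]
    for _ in range(max_hops):
        docs = retrieve(context)   # generative retrieval step
        context.extend(docs)       # context expansion: docs join the query
        if sufficient(context):
            break
    return generate(context)

# Toy run with stub callbacks standing in for model calls.
answer = interleaved_answer(
    "q",
    retrieve=lambda ctx: [f"doc{len(ctx)}"],   # stub retriever
    sufficient=lambda ctx: len(ctx) >= 3,      # stub sufficiency check
    generate=lambda ctx: "|".join(ctx),        # stub generator
)
```

Each hop can route to chunks that only become relevant once the previous hop's documents are in the query, which is what enables reasoning across scattered memory segments.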

Efficiency and Scaling

MSA achieves linear complexity with respect to memory size $L$ in both training and inference. To enable extreme-length memory inference on a standard single node, the authors implement a specialized inference engine called Memory Parallel, which supports inference over a massive memory context of up to 100 million tokens with limited GPU resources.

A Tiered Memory Storage strategy addresses capacity constraints. Routing Keys ($\bar{K}^R$) are distributed across the VRAM of multiple GPUs to ensure low-latency retrieval, while the bulk of the memory bank, the Content KVs ($\bar{K}$, $\bar{V}$), is stored in host DRAM (CPU memory). After the Top-$k$ relevant chunks are identified via GPU scoring, only the corresponding Content KVs are asynchronously fetched from the host to the GPU. Additionally, a Memory-Parallel Retrieval strategy broadcasts the query hidden states to all GPUs; each GPU independently computes similarity scores against its local shard of Routing Keys before a global reduction identifies the Top-$k$ indices.
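The shard-then-reduce retrieval pattern can be sketched with plain arrays standing in for per-GPU shards; the dot-product scoring and shard layout are simplifying assumptions (the paper scores with cosine similarity over routing keys):

```python
import numpy as np

def memory_parallel_topk(query, key_shards, k=2):
    """Memory-Parallel Retrieval sketch: the routing query is
    'broadcast' to every shard (one per GPU in the paper; plain
    arrays here), each shard scores its local routing keys
    independently, and a global reduction over (score, global index)
    pairs picks the Top-k."""
    candidates = []
    offset = 0
    for shard in key_shards:            # each GPU scores its local shard
        scores = shard @ query
        for j, s in enumerate(scores):
            candidates.append((s, offset + j))
        offset += len(shard)
    candidates.sort(reverse=True)       # global reduction: merge, pick Top-k
    return [idx for _, idx in candidates[:k]]

# Toy run: two shards of routing keys, d=2.
shards = [np.array([[1.0, 0.0], [0.0, 1.0]]), np.array([[2.0, 0.0]])]
top = memory_parallel_topk(np.array([1.0, 0.0]), shards, k=2)
```

Only the winning indices' Content KVs then need to cross the host-to-GPU boundary, which keeps the per-query transfer volume proportional to $k$ rather than to the memory size.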

Experiment

  • MSA was evaluated on nine diverse QA benchmarks and the RULER Needle In A Haystack task to validate its efficacy against same-backbone RAG systems, best-of-breed RAG configurations, and long-context memory architectures.
  • The model demonstrates consistent superiority over standard RAG baselines and achieves competitive or top performance against systems using significantly larger generators, proving that its architectural design effectively isolates and enhances retrieval and reasoning capabilities.
  • Experiments on context scaling from 32K to 1M tokens confirm that MSA maintains exceptional stability and high retrieval accuracy, avoiding the catastrophic degradation observed in unmodified backbones and other long-context models.
  • Ablation studies validate that the two-stage curriculum learning strategy, memory interleave mechanism for multi-hop reasoning, continual pre-training for router precision, and integration of original document text are all critical components for the system's overall performance.
  • Analysis of context degradation up to 100M tokens reveals that MSA sustains high generation quality with minimal performance loss, successfully decoupling reasoning capabilities from massive memory capacity.
