MemoryLLM: Interpretable Plug-and-Play Feed-Forward Memory for Transformers

Ajay Jaiswal Lauren Hannah Han-Byul Kim Duc Hoang Arnav Kundu Mehrdad Farajtabar Minsik Cho

Abstract

Understanding how transformer components operate in large language models (LLMs) is essential, as it underpins recent advances in artificial intelligence. In this work, we revisit the challenges of interpreting feed-forward network (FFN) modules and propose MemoryLLM, an architecture that decouples FFNs from self-attention so that they can be studied as an independent, token-level, context-free neural memory. Specifically, we analyze how input tokens access memory slots within the FFN parameters, and how important this FFN memory is across various downstream tasks. MemoryLLM obtains context-free FFNs by training them in isolation from self-attention, directly on token embeddings. This makes it possible to pre-compute the FFNs as token-wise lookups (ToLs), enabling on-demand transfer between VRAM and storage while improving inference efficiency. We also introduce Flex-MemoryLLM, an architecture that sits between the conventional transformer design and MemoryLLM and closes the performance gap that results from training FFNs on context-free token embeddings.

One-sentence Summary

Apple researchers propose MemoryLLM, a transformer architecture that decouples feed-forward networks (FFNs) from self-attention to treat them as context-free, token-indexed neural memory, enabling interpretable analysis and pre-computed lookups that reduce VRAM usage while maintaining performance through Flex-MemoryLLM variants.

Key Contributions

  • MemoryLLM decouples feed-forward networks (FFNs) from self-attention by training them in isolation on context-free token embeddings, enabling interpretable token-wise neural retrieval memory without relying on residual stream interactions.
  • The architecture supports pre-computed token lookups (ToLs) for FFNs, allowing plug-n-play memory offloading to storage and improving inference efficiency, validated across 250M, 750M, and 1B parameter scales.
  • Flex-MemoryLLM bridges performance gaps by splitting FFN parameters into context-aware and context-free components, maintaining competitiveness with conventional transformers while preserving interpretability and efficiency benefits.

Introduction

The authors leverage the observation that feed-forward networks (FFNs) in transformers—though parameter-heavy—are poorly understood due to their tight coupling with self-attention modules, which obscure their function as interpretable memory systems. Prior work attempted to reverse-engineer FFNs as key-value memories but relied on post-hoc analysis of pretrained models, requiring calibration datasets and offering only indirect, non-discrete query mappings. MemoryLLM addresses this by decoupling FFNs entirely from self-attention during training, treating them as context-free, token-indexed neural retrieval memories that can be precomputed and stored. This enables both interpretable token-level memory access and efficient inference via plug-n-play memory offloading. To mitigate performance loss from full decoupling, they also introduce Flex-MemoryLLM, which blends context-free and context-aware FFNs to bridge the gap with conventional transformers.

Method

The authors leverage a novel transformer architecture called MemoryLLM to decouple feed-forward networks (FFNs) from the residual stream and self-attention modules, enabling a deterministic and interpretable analysis of FFN functionality. In conventional LLMs, FFNs operate on a dynamically evolving residual stream that combines contextual information from prior layers, making their internal mechanisms opaque. MemoryLLM addresses this by training all FFN modules across transformer layers directly on the initial token embeddings $X_0$, which are static and derived solely from token IDs via the tokenizer. This design isolates FFNs from contextual dependencies, allowing them to function as context-free token-indexed memory reservoirs.

Refer to the framework diagram, which contrasts the conventional dense LLM with MemoryLLM. In the conventional architecture, each transformer layer $L$ computes self-attention on the residual stream $X_L$, adds the result to $X_L$, and then applies the FFN to this sum. In MemoryLLM, the self-attention module operates as usual, but the FFN at layer $L$ receives the original input embedding $X_0$ instead of the residual stream. The output of layer $L$ is then computed as:

$$X_{L+1} = X_L + \mathrm{Attn}(X_L) + \mathrm{FFN}(X_0)$$

This parallelization of context-aware self-attention and context-free FFN computation preserves the residual flow while enabling FFNs to be interpreted as neural key-value memories over a finite, human-interpretable query space—the vocabulary.
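To make the parallel update concrete, here is a minimal PyTorch-style sketch of one such layer, assuming pre-norm sub-modules and a standard attention implementation; the module names, normalization placement, and inner FFN are illustrative choices, not the authors' code.

```python
import torch
import torch.nn as nn

class MemoryLLMLayer(nn.Module):
    """One transformer layer whose FFN reads the static token embeddings X_0
    instead of the residual stream (sketch; causal masking omitted for brevity)."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ffn = nn.Sequential(  # stand-in for the gated FFN formalized below
            nn.Linear(d_model, d_ff), nn.SiLU(), nn.Linear(d_ff, d_model)
        )
        self.norm_attn = nn.LayerNorm(d_model)
        self.norm_ffn = nn.LayerNorm(d_model)

    def forward(self, x_l: torch.Tensor, x0: torch.Tensor) -> torch.Tensor:
        # Context-aware branch: self-attention over the residual stream X_L.
        h = self.norm_attn(x_l)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        # Context-free branch: FFN applied to the original embeddings X_0.
        ffn_out = self.ffn(self.norm_ffn(x0))
        # X_{L+1} = X_L + Attn(X_L) + FFN(X_0)
        return x_l + attn_out + ffn_out
```

Because `x0` never changes during decoding, the FFN branch can be replaced by a table lookup, which is what the ToL machinery below exploits.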

To formalize this interpretation, the authors introduce the TKV (token-key-value) framework. Within each FFN, the up-projection matrix $W_{Up}$ is treated as a set of $K$ key vectors, and the down-projection matrix $W_{Down}$ as corresponding value vectors. The gate-projection matrix $W_{Gate}$ acts as a learned reweighting function that modulates the contribution of each key. For a query vector $q = x_i$ corresponding to token $t_i$, the memory retrieval process involves two steps. First, the memory cell coefficients $c_{k_i}$ are computed via dot product between $q$ and the columns of $W_{Up}^\top$, then reweighted element-wise by the gate vector $g_{k_i}$:

$$\tilde{c}_{k_i} = \left(q^{1 \times d} \cdot W_{Up_{[:,k_i]}}^\top\right) \times g_{k_i}$$

Second, the retrieved output is a weighted sum of the value vectors $v_{k_i}$:

$$\mathrm{FFN}(X_0^q) = \sigma_q^{1 \times d} = \sum_{i}^{K} \tilde{c}_{k_i} \cdot v_{k_i}$$

This framework eliminates the need for laborious reverse-engineering of input prefixes and provides a direct mapping from token IDs to memory cells.
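The retrieval itself reduces to a few matrix-vector products. The sketch below reads a gated FFN as a token-indexed key-value memory for a single context-free embedding; the shapes and the SiLU gate activation are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

def ffn_memory_lookup(q: torch.Tensor, w_up: torch.Tensor,
                      w_gate: torch.Tensor, w_down: torch.Tensor) -> torch.Tensor:
    """TKV view of a gated FFN for one token embedding q of shape (d,).

    w_up:   (d, K)  columns are the K key vectors
    w_gate: (d, K)  learned re-weighting of each memory cell
    w_down: (K, d)  rows are the K value vectors
    """
    key_scores = q @ w_up            # c_k: how strongly q matches each key
    gate = F.silu(q @ w_gate)        # g_k: learned per-cell re-weighting
    coeffs = key_scores * gate       # c~_k
    return coeffs @ w_down           # sum_k c~_k * v_k
```

Because the query space is the finite vocabulary, the keys a given token activates can be tabulated directly, which is what the clustering analysis in the Experiment section relies on.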

The static nature of FFN inputs in MemoryLLM enables a significant efficiency gain: FFN outputs for all vocabulary tokens can be pre-computed offline and stored as token-wise lookup tables (ToLs). Each ToL for token $t_i$ is a concatenation of FFN outputs across all $N$ layers:

$$\mathrm{ToL}_{x_{t_i}}^{1 \times (N \times d)} = \mathbf{Concat}_{k=0}^{N-1} \left\{ \mathrm{FFN}_{L_k}(x_{t_i}),\ \dim = 1 \right\}$$

These ToLs can be offloaded to storage and asynchronously prefetched during inference, reducing both computational load and VRAM usage. The authors further propose an on-demand plug-n-play policy that caches ToLs for frequent tokens (following Zipf’s law) and loads less frequent ones as needed.
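A hedged sketch of the offline precomputation and the frequency-based cache follows; the function names, the memory-mapped storage, and the hot/cold split are illustrative assumptions rather than the authors' implementation.

```python
import torch

@torch.no_grad()
def precompute_tols(embed: torch.Tensor, ffn_layers, device: str = "cuda") -> torch.Tensor:
    """Pre-compute ToL_x = Concat_k FFN_{L_k}(x) for every vocabulary token.

    embed:      (V, d) token-embedding matrix
    ffn_layers: the N context-free FFN modules, in layer order
    Returns a (V, N*d) lookup table that can be written to storage.
    """
    x0 = embed.to(device)
    per_layer = [ffn(x0).cpu() for ffn in ffn_layers]   # N tensors of shape (V, d)
    return torch.cat(per_layer, dim=1)

class ToLCache:
    """Keep ToLs of the most frequent tokens resident in VRAM (Zipf's law)
    and fetch the rest from the offloaded table on demand."""

    def __init__(self, tol_table: torch.Tensor, hot_token_ids, device: str = "cuda"):
        self.cold = tol_table                               # e.g. a memory-mapped tensor on disk
        self.hot = {int(t): tol_table[int(t)].to(device) for t in hot_token_ids}
        self.device = device

    def fetch(self, token_id: int) -> torch.Tensor:
        if token_id in self.hot:
            return self.hot[token_id]
        return self.cold[token_id].to(self.device, non_blocking=True)
```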

To bridge the performance gap between MemoryLLM and conventional dense LLMs, the authors introduce Flex-MemoryLLM. This hybrid architecture splits the FFN parameters in each layer into two components: FFN Compute (FFN-C), a dense module operating on the residual stream to enhance computational capacity, and FFN Memory (FFN-M), a context-free memory module trained on $X_0$ as in MemoryLLM. The total parameter count remains identical to the base model, but a portion of the FFN parameters (e.g., $5h^2$ for $\beta = 3$) can be offloaded as static ToLs during inference. The output of layer $L$ in Flex-MemoryLLM is:

$$X_{L+1} = X_L + \mathrm{Attn}(X_L) + \mathrm{FFN\text{-}C}(X_L) + \mathrm{FFN\text{-}M}(X_0)$$

This design allows for a smooth trade-off between performance and efficiency, enabling models with significantly fewer active parameters to match or even exceed the performance of dense counterparts.
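Building on the layer sketch above, a Flex-MemoryLLM layer simply adds a second, context-aware FFN branch; the width split between the two branches (governed by $\beta$) and the normalization placement are again assumptions, not the authors' code.

```python
import torch
import torch.nn as nn

def _mlp(d_model: int, d_ff: int) -> nn.Sequential:
    return nn.Sequential(nn.Linear(d_model, d_ff), nn.SiLU(), nn.Linear(d_ff, d_model))

class FlexMemoryLLMLayer(nn.Module):
    """Hybrid layer: FFN-C on the residual stream plus FFN-M on X_0 (sketch)."""

    def __init__(self, d_model: int, n_heads: int, d_ff_compute: int, d_ff_memory: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ffn_c = _mlp(d_model, d_ff_compute)   # context-aware compute branch
        self.ffn_m = _mlp(d_model, d_ff_memory)    # context-free memory branch (ToL-able)
        self.norm_attn, self.norm_c, self.norm_m = (
            nn.LayerNorm(d_model), nn.LayerNorm(d_model), nn.LayerNorm(d_model))

    def forward(self, x_l: torch.Tensor, x0: torch.Tensor) -> torch.Tensor:
        h = self.norm_attn(x_l)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        # X_{L+1} = X_L + Attn(X_L) + FFN-C(X_L) + FFN-M(X_0)
        return x_l + attn_out + self.ffn_c(self.norm_c(x_l)) + self.ffn_m(self.norm_m(x0))
```

Only `ffn_m` reads the static `x0`, so only its outputs can be replaced by precomputed ToLs; `ffn_c` stays resident as active parameters.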

The authors also explore storage optimization for ToLs, estimating the total storage size as:

$$\text{Storage Size} = \text{vocab size} \times \text{num layers} \times \text{hidden dim} \times \text{bits per param}$$

For a 1B parameter MemoryLLM with 24 layers and 2048 hidden dimension, this amounts to approximately 12.6 GB in F16 precision. They suggest quantization, low-rank compression, and layer-wise compression as strategies to reduce this footprint.
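Plugging the stated configuration into this formula reproduces the 12.6 GB figure; the roughly 128K-entry vocabulary below is an assumption chosen to match that number, since the tokenizer size is not given here.

```python
# ToL storage for MemoryLLM-1B: 24 layers, hidden dim 2048, F16 (2 bytes/param).
vocab_size = 128_256        # assumed Llama-3-style vocabulary; not stated in the text above
num_layers, hidden_dim, bytes_per_param = 24, 2048, 2

storage_bytes = vocab_size * num_layers * hidden_dim * bytes_per_param
print(f"{storage_bytes / 1e9:.1f} GB")   # -> 12.6 GB
```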

Experiment

  • Semantically similar tokens activate similar keys in FFN memory, forming interpretable clusters (e.g., punctuation, names, locations), enabling targeted memory editing or toxicity control.
  • Clustering strength remains high across all layers, with terminal layers showing more outlier keys, suggesting focused token-level information convergence.
  • Reducing FFN contribution via interpolation harms recall-heavy tasks more than reasoning tasks, confirming FFNs act as token-indexed retrieval memory.
  • MemoryLLM underperforms base LLMs when comparing total parameters but outperforms them when comparing active parameters, validating its efficiency via precomputed ToLs.
  • MemoryLLM and Flex-MemoryLLM outperform pruned base models at matching active parameter counts, offering a viable alternative to pruning techniques.
  • ToLs exhibit strong low-rank properties, especially in terminal layers, enabling ~2x storage reduction via SVD with minimal performance loss; a low-rank compression sketch follows this list.
  • Dropping ToLs from middle layers causes minimal performance degradation, indicating high redundancy and offering a practical compression strategy.
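
As a minimal sketch of that low-rank idea, the snippet below factorizes one layer's $V \times d$ slice of the ToL table with a truncated SVD; the rank choice and per-layer factorization are assumptions, not the paper's exact procedure.

```python
import torch

def compress_tol_layer(tol_layer: torch.Tensor, rank: int):
    """Factor a (V, d) slice of the ToL table into low-rank pieces.

    Storage drops from V*d to rank*(V + d) values, i.e. roughly 2x
    for rank ~ d/2 when the vocabulary size V is much larger than d.
    """
    u, s, vh = torch.linalg.svd(tol_layer, full_matrices=False)
    left = u[:, :rank] * s[:rank]    # (V, rank)
    right = vh[:rank]                # (rank, d)
    return left, right               # reconstruct a row on demand: left[i] @ right
```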

The authors use a consistent 24-layer architecture across models with varying total parameter counts, adjusting intermediate dimensions to control active parameter counts while keeping hidden dimensions and attention heads fixed. MemoryLLM and Flex-MemoryLLM variants achieve lower active parameter counts than their base counterparts by design, enabling more efficient deployment without altering layer count or core structure. These configurations support comparisons of performance versus active parameter efficiency, particularly when evaluating memory-based architectures against conventional pruning methods.

The authors evaluate MemoryLLM-1B under varying precision levels for token-wise lookup tables, observing that reducing precision from 16-bit to 4-bit incurs only marginal performance changes across multiple downstream tasks. Results show that even at 4-bit precision, the model maintains competitive scores, suggesting strong resilience to low-precision quantization. This supports the feasibility of deploying MemoryLLM with significantly reduced storage requirements without substantial degradation in task performance.
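As an illustration of what such low-precision storage could look like, the snippet below applies simple symmetric per-token 4-bit quantization to a ToL table; the actual quantization scheme used in the paper is not described here.

```python
import torch

def quantize_tol_4bit(tol: torch.Tensor):
    """Symmetric per-row 4-bit quantization of a (V, N*d) ToL table (illustrative only)."""
    scale = tol.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) / 7.0   # int4 range [-8, 7]
    q = torch.clamp(torch.round(tol / scale), -8, 7).to(torch.int8)
    return q, scale   # packing two int4 values per byte would halve storage again

def dequantize_tol(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.float() * scale
```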

The authors use a controlled interpolation parameter to reduce FFN contribution and observe that tasks relying on recall or retrieval suffer greater performance degradation compared to reasoning tasks. Results show that as FFN influence decreases, models retain stronger logical inference capabilities while losing accuracy on fact-based tasks. This suggests FFNs in MemoryLLM serve as token-indexed memory stores more critical for knowledge retrieval than for reasoning.
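The exact interpolation scheme is not reproduced here; one plausible form, assumed for illustration, is a scalar $\alpha$ that scales the context-free FFN branch in the layer update, with $\alpha = 1$ recovering MemoryLLM and $\alpha = 0$ removing the token-indexed memory entirely.

```python
def layer_update_with_scaled_memory(x_l, x0, attn, ffn, alpha: float):
    """Layer update with the memory branch scaled by an interpolation factor alpha.
    This scalar-scaling form is an assumption about the ablation, not the authors' code."""
    return x_l + attn(x_l) + alpha * ffn(x0)
```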

The authors evaluate MemoryLLM and its variants against a base model, finding that MemoryLLM achieves the lowest inference memory usage while also delivering the fastest decoding speed per token. Flex-MemoryLLM variants show trade-offs between memory and speed, with higher hidden dimension scaling increasing memory but not consistently improving speed. Results indicate MemoryLLM’s architecture enables more efficient inference without compromising throughput.

The authors use a parameter allocation strategy to split feed-forward network components into context-dependent and context-free memory modules, showing that shifting more parameters to the memory module reduces the context-dependent component while maintaining total parameter count. Results show that architectures like Flex-MemoryLLM can flexibly balance these components, enabling trade-offs between computational efficiency and memory capacity without altering overall model size.

