
LeWorldModel: A Stable Pixel-Level End-to-End Joint-Embedding Predictive Architecture

Lucas Maes Quentin Le Lidec Damien Scieur Yann LeCun Randall Balestriero

Abstract

Joint Embedding Predictive Architectures (JEPAs) offer a promising framework for learning world models in compact latent spaces. Yet existing methods remain brittle, relying on complex multi-term loss functions, exponential moving averages, pretrained encoders, or additional supervision to avoid representation collapse. In this work we introduce LeWorldModel (LeWM), the first JEPA trained stably and end-to-end from raw pixels alone using only two loss terms: a next-embedding prediction loss and a regularizer that enforces a Gaussian structure on the latent embeddings. Compared to the only existing end-to-end alternative, this reduces the number of tunable loss hyperparameters from six to one. With 15 million parameters, trainable in a few hours on a single GPU, LeWM plans up to 48x faster than foundation-model-based world models while remaining competitive across diverse 2D and 3D control tasks. Beyond control, we show that LeWM's latent space encodes meaningful physical structure by probing for physical quantities. A surprise evaluation confirms that the model reliably detects physically implausible events.

One-sentence Summary

Researchers from Mila, NYU, and Samsung SAIL propose LeWorldModel, a Joint-Embedding Predictive Architecture that trains stably end-to-end from raw pixels using only two loss terms, eliminating the need for complex losses and pretrained encoders while achieving up to 48x faster planning on a single GPU and reliably detecting physically implausible events across diverse control tasks.

Key Contributions

  • The paper introduces LeWorldModel (LeWM), a Joint Embedding Predictive Architecture that trains stably end-to-end from raw pixels using only two loss terms. This approach reduces tunable loss hyperparameters from six to one compared to the only existing end-to-end alternative.
  • The method plans up to 48 times faster than foundation-model-based world models while remaining competitive across diverse 2D and 3D control tasks. This efficiency is achieved with 15M parameters trainable on a single GPU in a few hours without relying on exponential moving averages or pretrained encoders.
  • The latent space encodes meaningful physical structure, as demonstrated by probing of physical quantities within the model. Surprise evaluation confirms that the system reliably detects physically implausible events beyond standard control benchmarks.

Introduction

World Models enable agents to plan using internal simulations learned directly from raw sensory inputs like camera pixels. Although Joint Embedding Predictive Architectures (JEPAs) provide a compact framework for this task, prior methods often suffer from representation collapse and require complex stabilization heuristics or pretrained encoders. The authors introduce LeWorldModel, the first stable JEPA trained end-to-end from pixels using only a prediction loss and a Gaussian regularizer. This approach eliminates heuristic training tricks, reduces hyperparameter tuning, and enables efficient planning on a single GPU while capturing meaningful physical structures.

Dataset

Dataset overview

The authors evaluate their world models across four continuous control environments using the following datasets:

  • TwoRoom: A 2D navigation task from Sobal et al. where an agent moves between two rooms. The dataset includes 10,000 episodes averaging 92 steps, generated by a noisy heuristic policy guiding the agent through a door to the target.

  • PushT: A 2D manipulation task from Zhou et al. requiring an agent to push a T-shaped block to a target configuration. This subset contains 20,000 expert episodes with an average length of 196 steps.

  • OGBench-Cube: A 3D robotic manipulation task from Park et al. limited to the single-cube variant for pick-and-place operations. The authors collect 10,000 episodes of 200 steps using the benchmark library heuristic.

  • Reacher: A continuous control environment from the DeepMind Control Suite involving a two-joint arm reaching a target in a 2D plane. The dataset consists of 10,000 episodes of 200 steps collected via a Soft Actor-Critic policy.

  • Training and Processing: Each world model is trained for 10 epochs on these datasets. For predictor rollouts, three context frames are encoded into latent representations to autoregressively generate future states conditioned on actions, with predictions decoded by a separate decoder not used during training.
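The predictor rollout described above can be sketched in a few lines: encode the context, then feed each predicted latent back into the predictor conditioned on the next action. The `encode` and `predict` stand-ins below are hypothetical toy functions, not the paper's trained modules.

```python
import numpy as np

def rollout(encode, predict, context_frames, actions):
    """Autoregressive latent rollout: encode the latest context frame,
    then repeatedly predict the next latent conditioned on each action.
    `encode` and `predict` are stand-ins for the trained modules."""
    z = encode(context_frames[-1])          # latest latent state
    latents = [z]
    for a in actions:
        z = predict(z, a)                   # z_hat_{t+1} = P(z_t, a_t)
        latents.append(z)
    return np.stack(latents)

# toy stand-ins (hypothetical): flatten-and-truncate encoder, linear dynamics
encode = lambda frame: frame.reshape(-1)[:4]
predict = lambda z, a: 0.9 * z + 0.1 * a
traj = rollout(encode, predict, np.ones((3, 2, 2)), [np.zeros(4)] * 5)
```

In the actual pipeline the decoded frames come from a separately trained decoder, which is only used for visualization, never during training.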

Method

LeWorldModel (LeWM) operates as a Joint-Embedding Predictive Architecture designed to learn task-agnostic world models from offline, reward-free data. The framework consists of two primary components: an encoder that maps raw pixel observations into a compact latent space, and a predictor that models environment dynamics by forecasting future latent embeddings conditioned on actions.

The training procedure processes sequences of observations and actions to update the model parameters end-to-end. The encoder, implemented as a Vision Transformer (ViT), maps an input frame $o_t$ to a latent embedding $z_t$. This embedding is derived from the [CLS] token of the final layer, passed through a projection MLP with Batch Normalization to facilitate optimization. The predictor, a transformer utilizing Adaptive Layer Normalization (AdaLN) for action conditioning, takes the current latent state $z_t$ and action $a_t$ and predicts the next state embedding $\hat{z}_{t+1}$.
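The AdaLN conditioning mechanism can be illustrated with a minimal numpy sketch: the hidden state is layer-normalized, then modulated by a scale and shift regressed from the action. The weight names (`W_scale`, `W_shift`) are illustrative, not the paper's.

```python
import numpy as np

def adaln(h, a, W_scale, W_shift, eps=1e-5):
    """Adaptive Layer Normalization sketch: normalize hidden states over
    the feature dimension, then apply a per-example scale and shift
    computed from the conditioning action a."""
    mu = h.mean(-1, keepdims=True)
    var = h.var(-1, keepdims=True)
    h_norm = (h - mu) / np.sqrt(var + eps)      # standard layer norm
    return (1.0 + a @ W_scale) * h_norm + a @ W_shift
```

With zero conditioning weights this reduces to plain layer normalization, which is why the `1.0 +` parameterization is a common default: the modulation starts as the identity.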

To ensure the learned representations are informative and stable, the training objective combines a prediction loss with a regularization term. The prediction loss $\mathcal{L}_{\text{pred}}$ minimizes the mean squared error between the predicted embedding and the ground-truth embedding of the next time step:

$$\mathcal{L}_{\text{pred}} \triangleq \| \hat{z}_{t+1} - z_{t+1} \|_2^2$$

However, relying solely on prediction can lead to representation collapse. To prevent this, the authors introduce the Sketched-Isotropic-Gaussian Regularizer (SIGReg), illustrated in the figure below.

[Figure: Training framework showing Encoder, Predictor, and SIGReg module details]

This module encourages the latent embeddings to match an isotropic Gaussian target distribution. Directly testing for normality in high dimensions is difficult, so SIGReg projects the embeddings onto random unit-norm directions and optimizes the univariate Epps-Pulley test statistic along these projections. By the Cramér–Wold theorem, matching these one-dimensional marginals ensures the full joint distribution matches the target. The total loss is:

$$\mathcal{L}_{\text{LeWM}} \triangleq \mathcal{L}_{\text{pred}} + \lambda \, \text{SIGReg}(\mathbf{Z})$$

where $\lambda$ is the regularization weight and $\mathbf{Z}$ is the tensor of latent embeddings.
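The projection-and-test idea behind SIGReg can be sketched numerically: project embeddings onto random unit directions and penalize the gap between each projection's empirical characteristic function and that of N(0, 1), weighted by a standard-normal density. This is a simplified numeric stand-in that captures the spirit of the Epps-Pulley statistic, not the paper's exact closed form.

```python
import numpy as np

def sigreg(Z, num_dirs=16, rng=None):
    """Simplified SIGReg sketch: a weighted L2 gap between the empirical
    characteristic function of random 1-D projections of Z and the
    characteristic function of a standard Gaussian."""
    rng = rng or np.random.default_rng(0)
    n, d = Z.shape
    t = np.linspace(-5.0, 5.0, 101)
    dt = t[1] - t[0]
    dirs = rng.normal(size=(num_dirs, d))
    dirs /= np.linalg.norm(dirs, axis=1, keepdims=True)    # unit-norm sketches
    proj = Z @ dirs.T                                      # (n, num_dirs)
    # empirical characteristic function on the grid, one column per direction
    ecf = np.exp(1j * t[:, None, None] * proj[None]).mean(axis=1)
    target = np.exp(-t ** 2 / 2.0)[:, None]                # CF of N(0, 1)
    weight = np.exp(-t ** 2 / 2.0)[:, None]                # integration weight
    gap = np.abs(ecf - target) ** 2
    return float((gap * weight).sum(axis=0).mean() * dt)
```

Collapsed embeddings (e.g. all zeros) have a constant characteristic function and incur a large penalty, while samples from an isotropic Gaussian score near zero, which is exactly the anti-collapse behavior the regularizer is meant to provide.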

At inference time, LeWM leverages the learned dynamics for decision-making through latent planning with Model Predictive Control (MPC), illustrated in the figure below.

[Figure: Latent planning pipeline with MPC and goal conditioning]

Given an initial observation, the system encodes it into a latent state and iteratively rolls out predicted latent states over a horizon $H$ using the predictor. A cost is computed as the distance between the final predicted latent state $\hat{z}_H$ and the latent embedding $z_g$ of a goal observation. An optimization solver, specifically the Cross-Entropy Method (CEM), searches for the action sequence that minimizes this terminal cost. To mitigate error accumulation over long horizons, a receding-horizon strategy executes only the first few planned actions before replanning from the updated observation.
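The CEM planning loop described above can be sketched as follows: sample action sequences from a Gaussian, roll each out with the predictor, keep the elites whose terminal latent is closest to the goal embedding, and refit the sampling distribution. Dimensions, defaults, and the toy dynamics are illustrative, not the paper's configuration.

```python
import numpy as np

def cem_plan(predict, z0, z_goal, horizon=5, act_dim=2,
             pop=64, n_elite=8, iters=10, rng=None):
    """Latent MPC with the Cross-Entropy Method: iteratively refit a
    Gaussian over action sequences toward those minimizing the terminal
    distance to the goal latent."""
    rng = rng or np.random.default_rng(0)
    mu = np.zeros((horizon, act_dim))
    sigma = np.ones((horizon, act_dim))
    for _ in range(iters):
        acts = mu + sigma * rng.normal(size=(pop, horizon, act_dim))
        costs = np.empty(pop)
        for i in range(pop):
            z = z0
            for a in acts[i]:
                z = predict(z, a)                 # latent rollout
            costs[i] = np.sum((z - z_goal) ** 2)  # terminal goal distance
        elite = acts[np.argsort(costs)[:n_elite]]
        mu, sigma = elite.mean(axis=0), elite.std(axis=0) + 1e-6
    return mu  # receding horizon: execute the first action(s), then replan

# toy latent dynamics (hypothetical): the action is added to the state
predict = lambda z, a: z + a
plan = cem_plan(predict, np.zeros(2), np.full(2, 5.0))
```

In a receding-horizon deployment only the first action (or first few) of `plan` would be executed before encoding the new observation and replanning.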

Experiment

LeWM is evaluated against baselines such as PLDM and DINO-WM across diverse navigation and manipulation tasks in both two and three-dimensional environments. Results indicate that the method achieves significant planning speedups and more stable training convergence compared to complex multi-term objectives, while effectively capturing underlying physical quantities through latent representations. Although performance dips in low-complexity settings due to regularization mismatches, the model demonstrates robust physical understanding by detecting violations of expected dynamics without explicit temporal regularization.

The authors analyze the impact of replacing the default Vision Transformer encoder with a ResNet-18 backbone. Results indicate that while the ViT architecture yields slightly better performance, the ResNet-18 variant remains competitive, suggesting the method is robust to the choice of vision encoder.

  • The ViT encoder achieves higher success rates than ResNet-18.
  • The ResNet-18 backbone maintains competitive planning performance.
  • Method performance is largely agnostic to the encoder architecture.

Encoder architecture comparison on Push-T planning task

The authors evaluate planning performance on the Push-T task using success rate as the metric. LeWM achieves the top performance, surpassing both DINO-WM and PLDM, highlighting the proposed method's ability to capture task-relevant quantities effectively.

  • LeWM achieves the highest success rate among all tested models.
  • PLDM shows lower performance and greater variance than the other methods.
  • The proposed method outperforms DINO-WM using only pixel observations.

Planning success rate comparison on Push-T

The authors analyze the impact of predictor capacity on planning performance in the Push-T environment. The results show that the small predictor configuration outperforms both smaller and larger model variants, indicating that this scale offers the best trade-off between capacity and optimization stability.

  • The small predictor size achieves the highest success rate.
  • Smaller model variants yield lower performance.
  • Larger model configurations provide no additional gains.

Effect of predictor size on Push-T planning

The table compares the ability of different models to encode agent position information using linear and non-linear probes. LeWM and PLDM demonstrate significantly better linear probing performance than DINO-WM, while all models achieve near-perfect results with non-linear probes.

  • LeWM and PLDM achieve similar performance on linear probes.
  • Both methods significantly outperform DINO-WM on linear metrics.
  • Non-linear probes yield near-perfect results for all models.

Agent position probing results across models

The authors evaluate the physical understanding of LeWM by probing latent representations for quantities like position and velocity. Results indicate that LeWM generally outperforms the PLDM baseline and remains competitive with the pretrained DINO-WM model, particularly for positional properties.

  • LeWM achieves lower prediction error than PLDM across most physical properties.
  • Non-linear MLP probes consistently outperform linear probes in recovering physical quantities.
  • Positional properties like block position are recovered with significantly higher accuracy than rotational properties.

Physical latent probing results on Push-T environment
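The probing protocol used in these experiments amounts to fitting a small readout model from frozen latents to a physical quantity. A minimal stand-in for the linear case is a closed-form ridge-regression probe scored with R^2; in practice a held-out split (and an MLP for the non-linear probe) would be used.

```python
import numpy as np

def linear_probe_r2(Z, y, reg=1e-3):
    """Ridge-regression linear probe from latents Z to a physical
    quantity y, reporting the coefficient of determination R^2."""
    Zb = np.hstack([Z, np.ones((len(Z), 1))])   # append a bias column
    w = np.linalg.solve(Zb.T @ Zb + reg * np.eye(Zb.shape[1]), Zb.T @ y)
    pred = Zb @ w
    ss_res = np.sum((y - pred) ** 2)
    ss_tot = np.sum((y - y.mean()) ** 2)
    return 1.0 - ss_res / ss_tot
```

A high linear-probe score indicates the quantity is encoded along accessible linear directions of the latent space, which is the distinction the table above draws between LeWM/PLDM and DINO-WM.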

Experiments evaluate the proposed method's robustness and planning capabilities on the Push-T task, demonstrating that performance remains competitive across different vision encoder architectures. LeWM achieves superior planning success rates compared to baselines, while ablation studies reveal that a small predictor configuration provides the optimal balance between capacity and optimization stability. Furthermore, probing analyses confirm that the model effectively encodes agent position and physical quantities, outperforming competitors on linear probes and recovering positional properties with high accuracy.

