UWM-JEPA: Predictive World Models That Imagine in Belief Space

2026-05-25 · Machine Learning

Machine Learning, Artificial Intelligence, Robotics
AI summary

The authors propose a new type of world model, UWM-JEPA, that represents its latent state as a density matrix (a mathematical object encoding a probability-weighted mixture of states) and advances it with a unitary predictor, so that it can better imagine hidden futures in partially observed environments. Because a unitary update preserves the density matrix's spectrum exactly, the uncertainty the model represents stays intact across imagination steps, whereas standard vector-valued latents have no built-in structure for carrying that uncertainty forward. On a task requiring five-step prediction with the target observation hidden, the model reached 0.77 accuracy, while a parameter-matched LSTM-based baseline fell to majority-class accuracy (0.53). The authors also found that training against counterfactual targets, rather than teacher-forced ones, is key to making the model sensitive to different possible actions. Overall, the work shows that imagining under uncertainty depends on how the model represents and propagates uncertainty, not only on how well it encodes the current situation.
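The claim that a unitary step keeps uncertainty intact can be checked numerically. The sketch below is a minimal illustration, assuming NumPy and using random matrices in place of anything learned; the helper names are hypothetical and not taken from the paper. It builds a mixed density matrix ρ, applies a unitary step ρ → UρU†, and confirms that the eigenvalues (and hence the von Neumann entropy) are unchanged, while a generic non-unitary map followed by renormalisation is free to change them.

```python
import numpy as np

def random_density_matrix(dim, rank, rng):
    """Mixed state rho = G G^dagger / tr(G G^dagger): Hermitian, PSD, unit trace."""
    g = rng.normal(size=(dim, rank)) + 1j * rng.normal(size=(dim, rank))
    rho = g @ g.conj().T
    return rho / np.trace(rho).real

def random_unitary(dim, rng):
    """Unitary Q from the QR decomposition of a complex Gaussian matrix."""
    z = rng.normal(size=(dim, dim)) + 1j * rng.normal(size=(dim, dim))
    q, _ = np.linalg.qr(z)
    return q

def von_neumann_entropy(rho):
    """S(rho) = -sum_k lambda_k log lambda_k over the (numerically) nonzero eigenvalues."""
    evals = np.linalg.eigvalsh(rho)
    evals = evals[evals > 1e-12]
    return float(-np.sum(evals * np.log(evals)))

rng = np.random.default_rng(0)
rho = random_density_matrix(dim=8, rank=3, rng=rng)
u = random_unitary(8, rng)

# Unitary step: the spectrum of rho is carried over unchanged.
rho_next = u @ rho @ u.conj().T

# Hypothetical non-unitary contrast: a generic real linear map, then renormalise.
m = rng.normal(size=(8, 8))
rho_m = m @ rho @ m.T
rho_m /= np.trace(rho_m).real

print(np.allclose(np.linalg.eigvalsh(rho), np.linalg.eigvalsh(rho_next)))
print(von_neumann_entropy(rho), von_neumann_entropy(rho_next), von_neumann_entropy(rho_m))
```

The first two entropies coincide by construction; the third generally does not, which is the sense in which a non-unitary update can dissipate (or distort) the represented uncertainty.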

World Models, Partially Observed Environments, Joint Embedding Predictive Architectures (JEPA), Density Matrix, Unitary Predictor, Counterfactual Actions, Hidden State, Latent Space, Rollout, Encoder
Authors
Santosh Kumar Radha, Oktay Goktas
Abstract
World models for partially observed environments must imagine multiple compatible hidden futures and steer between them under counterfactual actions. Joint Embedding Predictive Architectures (JEPAs) do this in latent space, but a vector-valued latent has no internal structure for carrying the belief over hidden continuations through blind rollout. We introduce the Unitary World Model JEPA (UWM-JEPA), a JEPA world model with a density-matrix latent on a joint system-environment space and a learned unitary predictor. The construction preserves the joint-state spectrum exactly during rollout, so the predictor itself cannot dissipate the represented uncertainty. On a hidden-velocity indicator task requiring five-step forward simulation under a given action sequence with the target observation masked, UWM-JEPA reaches 0.77 accuracy and degrades monotonically as actions are perturbed; a parameter-matched LSTM-JEPA trained under the same counterfactual-target objective and action head collapses to majority-class accuracy (0.53) under every action condition. Under blind rollout, UWM-JEPA loses fewer than ten points of probe R² at short horizons while vector-latent baselines lose forty-one and sixty-eight; the two model families nevertheless tie on a held-out context probe, locating the separation in the predictor rather than the encoder. Action sensitivity itself requires training against counterfactual rather than teacher-forced targets, a finding that applies beyond the unitary parameterisation. For JEPA world models to imagine under partial observability, latent geometry and predictor dynamics matter, not frozen context-encoding capacity alone.
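The mechanism described in the abstract (a density-matrix latent on a joint system-environment space, rolled forward blind by an action-conditioned unitary) can be sketched as a toy. This is an illustration under stated assumptions, not the authors' architecture: NumPy is assumed; the factor sizes, the generator H(a) = H0 + a·H1, and all helper names are hypothetical; and random matrices stand in for learned parameters. The sketch rolls a joint state forward under a five-step action sequence and a perturbed counterfactual sequence, checks that the joint spectrum is unchanged, and compares the system marginals obtained by tracing out the environment.

```python
import numpy as np

DIM_SYS, DIM_ENV = 2, 4          # hypothetical sizes of the system and environment factors
DIM = DIM_SYS * DIM_ENV
rng = np.random.default_rng(1)

def random_hermitian(dim):
    a = rng.normal(size=(dim, dim)) + 1j * rng.normal(size=(dim, dim))
    return (a + a.conj().T) / 2

# Hypothetical action-conditioned generator H(a) = H0 + a * H1 (stand-in for a learned predictor).
H0, H1 = random_hermitian(DIM), random_hermitian(DIM)

def unitary_step(action):
    """U(a) = exp(-i H(a)), built from the eigendecomposition of the Hermitian H(a)."""
    evals, evecs = np.linalg.eigh(H0 + action * H1)
    return (evecs * np.exp(-1j * evals)) @ evecs.conj().T

def partial_trace_env(rho):
    """Reduce a joint (system x environment) state to the system factor."""
    r = rho.reshape(DIM_SYS, DIM_ENV, DIM_SYS, DIM_ENV)
    return np.einsum("iaja->ij", r)

def rollout(rho, actions):
    """Blind rollout: no observations, only rho <- U(a_t) rho U(a_t)^dagger per action."""
    for a in actions:
        u = unitary_step(a)
        rho = u @ rho @ u.conj().T
    return rho

# Initial joint belief: a 0.7 / 0.3 mixture of two normalised pure joint states.
psi = rng.normal(size=(DIM, 2)) + 1j * rng.normal(size=(DIM, 2))
psi /= np.linalg.norm(psi, axis=0)
rho0 = 0.7 * np.outer(psi[:, 0], psi[:, 0].conj()) + 0.3 * np.outer(psi[:, 1], psi[:, 1].conj())

actions = [0.5, -0.2, 1.0, 0.0, 0.3]      # five-step action sequence
perturbed = [a + 0.4 for a in actions]    # counterfactual variant

rho_a = rollout(rho0, actions)
rho_b = rollout(rho0, perturbed)

same_spectrum = np.allclose(np.linalg.eigvalsh(rho0), np.linalg.eigvalsh(rho_a))
sys_gap = np.abs(partial_trace_env(rho_a) - partial_trace_env(rho_b)).max()
print(same_spectrum, sys_gap)
```

Only the joint spectrum is conserved here; the system marginal produced by the partial trace can still differ between the two action sequences. In this toy, that is one way for the belief over hidden continuations to respond to actions without the predictor itself discarding uncertainty.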