Understanding Parallel Samplers in Masked Diffusion via Random Walks on Graphs

2026-06-22 | Machine Learning

Machine Learning, Artificial Intelligence, Computation and Language
AI summary

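The authors use random walks on graphs as a controlled testbed for studying parallel sampling strategies, which unmask several positions at once, in masked diffusion models (MDMs). Performance is measured in two ways: checking whether each output is a valid walk on the graph, and comparing the Markov kernel estimated from the outputs against that of the original random walk. They prove that popular strategies such as lowest-entropy unmasking are not uniformly better than random parallel sampling; which one wins depends on the graph's structure. They also introduce a bisection sampler that takes a number of steps logarithmic in the sequence length and is provably exact under perfect training. Experiments on graph walk tasks show the same graph dependence in practice, and initial experiments on a pretrained OpenWebText MDM show that bisection-style samplers improve the speed-quality tradeoff for language generation.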

random walks, graphs, masked diffusion models, parallel sampling, Markov kernel, entropy, bisection sampler, language generation, OpenWebText
Authors
Vansh Bansal, Cho Cholyeon, Syamantak Kumar, Sujay Sanghavi, Purnamrita Sarkar
Abstract
In this paper, we propose using random walks on graphs as a verifiable sandbox to study different parallel sampling strategies in masked diffusion models (MDMs). We train an MDM on random walk samples from a fixed graph. Neither the graph nor the transition kernel is ever shown to the model explicitly; they play the role of latent structure in the sequences, albeit structure that is controllable and can be used for quantitative evaluation. Thus, this framework enjoys Sudoku-like verifiability: one can check that an output is a valid walk and estimate the Markov kernel from the walks to measure distribution fidelity. Using simple graphs, we theoretically prove that parallel unmasking via widely used scores like lowest entropy is not uniformly better than a random parallel sampler; the performance critically depends on the structure of the underlying graph. We develop a new bisection sampler for random walks, which takes a number of steps logarithmic in the sequence length and is provably exact under perfect training. Experiments on various graph walk tasks show that different parallel samplers are better for different graphs even in practice. Our initial experiments on a pretrained OpenWebText MDM show that bisection-style samplers improve speed-quality tradeoffs even for language generation. Together, these results position graph random walks as a mechanistic benchmark for diagnosing and designing parallel samplers for masked diffusion models.
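
The two evaluation checks described in the abstract (walk validity and kernel estimation) can be illustrated with a minimal sketch. Everything below is an assumption made for illustration and is not taken from the paper: the toy 4-cycle graph, the function names, and the use of mean total-variation distance as the fidelity score. In a real experiment the walks would come from the trained MDM's sampler rather than from the graph itself.

```python
import random
from collections import defaultdict

# Hypothetical toy graph (a 4-cycle) as an adjacency list; not from the paper.
GRAPH = {0: [1, 3], 1: [0, 2], 2: [1, 3], 3: [2, 0]}

def sample_walk(graph, length, rng):
    """Simple random walk: uniform start node, then a uniform neighbor at each step."""
    node = rng.choice(sorted(graph))
    walk = [node]
    for _ in range(length - 1):
        node = rng.choice(graph[node])
        walk.append(node)
    return walk

def is_valid_walk(graph, walk):
    """Validity check: every consecutive pair of nodes must be an edge."""
    return all(b in graph.get(a, ()) for a, b in zip(walk, walk[1:]))

def estimate_kernel(walks):
    """Empirical Markov kernel: row-normalized counts of observed transitions."""
    counts = defaultdict(lambda: defaultdict(int))
    for walk in walks:
        for a, b in zip(walk, walk[1:]):
            counts[a][b] += 1
    return {a: {b: c / sum(row.values()) for b, c in row.items()}
            for a, row in counts.items()}

def kernel_tv_error(graph, kernel):
    """Mean total-variation distance between each estimated row and the true
    uniform-neighbor row. A node with no observed transitions has an empty row."""
    total = 0.0
    for a, nbrs in graph.items():
        row = kernel.get(a, {})
        support = set(nbrs) | set(row)
        total += 0.5 * sum(
            abs(row.get(b, 0.0) - (1.0 / len(nbrs) if b in nbrs else 0.0))
            for b in support
        )
    return total / len(graph)

# Stand-in for model outputs: here the walks are drawn from the true process.
rng = random.Random(0)
walks = [sample_walk(GRAPH, 16, rng) for _ in range(2000)]
valid_fraction = sum(is_valid_walk(GRAPH, w) for w in walks) / len(walks)
tv_error = kernel_tv_error(GRAPH, estimate_kernel(walks))
print(f"valid fraction: {valid_fraction:.3f}, mean TV error: {tv_error:.4f}")
```

Because the stand-in walks are drawn from the true process, every walk is valid and the kernel error is small; a sampler that unmasks incompatible positions in parallel would lower the valid fraction, and one that skews transition frequencies would raise the kernel error.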