Scheduling Thoughts: Learning the Order of Thought in Diffusion Language Models

2026-06-22 | Machine Learning

Machine Learning, Artificial Intelligence
AI summary

The authors study how the order in which tokens are unmasked in masked diffusion language models affects the quality of the generated text. They derive a tractable upper bound on the mismatch (measured by Kullback-Leibler divergence) that decoding in a given order introduces, and use it to frame order selection as a policy optimization problem. Their method, Self-Aware Scheduling (SAS), learns a lightweight policy that chooses the decoding order while the underlying denoiser stays frozen. It matches or beats heuristic schedules on Sudoku, grade-school math (GSM8K) and code generation (MBPP). SAS works with both any-order and semi-autoregressive (block-wise) decoding. On Sudoku, accuracy improves further with a second-stage fine-tuning along the learned trajectories. Overall, the approach replaces hand-picked unmasking rules with a learned, principled choice of token order that improves accuracy.
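To make "unmasking order" concrete, the Python sketch below decodes a short sequence step by step. At each step it chooses one masked position, commits the denoiser's most probable token there, and adds that token's log-probability to a running path log-likelihood. Everything in it is an illustrative assumption rather than the authors' code:

- `make_toy_denoiser`, the `MASK` id, the `WEIGHTS` table and the Sudoku-like all-different rule are hypothetical stand-ins for a real masked diffusion model.
- `left_to_right` and `max_confidence` are generic heuristic schedules, not SAS.

```python
import math
from typing import Callable, List

MASK = -1  # hypothetical id for the [MASK] token

# A denoiser maps a partially masked sequence to one probability vector per position.
Denoiser = Callable[[List[int]], List[List[float]]]
# A schedule picks the next position to unmask from the currently masked ones.
Schedule = Callable[[List[int], List[List[float]], List[int]], int]


def make_toy_denoiser(weights: List[List[float]]) -> Denoiser:
    """Toy stand-in for a frozen denoiser (hypothetical, not a real MDM).

    weights[i][v] is position i's unnormalised preference for token v. Tokens
    already placed elsewhere get probability 0 (a Sudoku-like all-different
    rule). Assumes as many positions as tokens and all weights positive, so
    every masked position always keeps at least one allowed token.
    """
    def denoiser(seq: List[int]) -> List[List[float]]:
        used = {t for t in seq if t != MASK}
        out = []
        for i, t in enumerate(seq):
            vocab = range(len(weights[i]))
            if t != MASK:  # committed positions are one-hot on their token
                out.append([1.0 if v == t else 0.0 for v in vocab])
                continue
            w = [0.0 if v in used else weights[i][v] for v in vocab]
            z = sum(w)
            out.append([x / z for x in w])
        return out
    return denoiser


def decode_with_order(denoiser: Denoiser, length: int, schedule: Schedule):
    """Unmask one position per step; return (tokens, order, path log-likelihood)."""
    seq = [MASK] * length
    order: List[int] = []
    log_lik = 0.0
    while MASK in seq:
        probs = denoiser(seq)
        masked = [i for i, t in enumerate(seq) if t == MASK]
        pos = schedule(seq, probs, masked)
        tok = max(range(len(probs[pos])), key=probs[pos].__getitem__)  # greedy token
        seq[pos] = tok
        order.append(pos)
        log_lik += math.log(probs[pos][tok])  # log-prob of the committed token
    return seq, order, log_lik


def left_to_right(seq, probs, masked):
    """Heuristic schedule: always unmask the leftmost masked position."""
    return masked[0]


def max_confidence(seq, probs, masked):
    """Heuristic schedule: unmask the position whose top token is most probable."""
    return max(masked, key=lambda i: max(probs[i]))


WEIGHTS = [
    [1.0, 1.0, 1.0],  # position 0: no preference
    [8.0, 1.0, 1.0],  # position 1: strongly prefers token 0
    [1.0, 1.0, 8.0],  # position 2: strongly prefers token 2
]

if __name__ == "__main__":
    toy = make_toy_denoiser(WEIGHTS)
    for name, schedule in [("left_to_right", left_to_right), ("max_confidence", max_confidence)]:
        tokens, order, ll = decode_with_order(toy, len(WEIGHTS), schedule)
        print(name, tokens, order, round(ll, 4))
```

The two schedules produce different results from the same weights:

- Left-to-right commits token 0 at the unconstrained position 0 first. It produces `[0, 1, 2]` with path log-likelihood log(1/6), about -1.79.
- Max-confidence resolves the strongly constrained positions 1 and 2 first. It produces `[1, 0, 2]` with log(0.8 · 8/9), about -0.34.

The model is the same and frozen in both runs; only the order changes, and the output changes with it. This is the choice the authors propose to learn rather than fix by hand.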

masked diffusion language models, unmasking order, Kullback-Leibler divergence, sequential decoding, policy optimization, Self-Aware Scheduling (SAS), Group Relative Policy Optimization, heuristic scheduling, denoiser, mathematical reasoning
Authors
Jiawei Xu, Minghui Liu, Aakriti Agrawal, Yifan Chen, Furong Huang
Abstract
Masked diffusion language models decode by iteratively unmasking tokens, and the unmasking order defines an "order of thought" that strongly influences generation quality yet is typically chosen heuristically. We derive a tractable upper bound on the sequential decoding mismatch, measured by the Kullback-Leibler divergence and expressed in terms of the model's pathwise log-likelihood; the bound is tight when the model is sufficiently expressive. This bound induces a dense self-aware reward over ordered trajectories, casting order selection as a principled policy optimization problem with a frozen denoiser. We instantiate this idea as Self-Aware Scheduling (SAS), which learns a lightweight order policy using Group Relative Policy Optimization and applies seamlessly to both any-order and semi-autoregressive decoding. On Sudoku with a 1B-parameter masked diffusion model (MDM), SAS improves puzzle accuracy from 82.0% (best heuristic schedule) to 91.8%, and reaches 97.5% with second-stage fine-tuning along the learned trajectories. With LLaDA-8B, SAS improves pass@1 from 64% to 76% on GSM8K (mathematical reasoning) and from 39.5% to 41% on MBPP (code generation), consistently matching or exceeding heuristic schedules across generation lengths and block sizes. Project page: https://jimmyxu123.github.io/SAS
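The abstract combines two pieces of machinery that can be sketched in isolation:

- restricting which positions an order policy may pick next (any-order versus semi-autoregressive decoding);
- the group-relative advantage that Group Relative Policy Optimization computes from a group of sampled trajectories.

The Python sketch below rests on several assumptions:

- `sample_position` takes a score vector from some hypothetical policy head.
- Semi-autoregressive decoding follows a left-to-right block convention.
- Rewards are normalised with the population standard deviation.

It omits GRPO's clipped ratio objective and KL term. It also does not reproduce the authors' self-aware reward; that reward would be passed in as `rewards`.

```python
import math
import random
import statistics
from typing import List, Optional, Sequence


def candidate_positions(masked: Sequence[int], block_size: Optional[int]) -> List[int]:
    """Positions an order policy may unmask next.

    block_size=None models any-order decoding: every masked position is eligible.
    Otherwise (semi-autoregressive decoding) only masked positions inside the
    leftmost block that still contains a mask are eligible (assumed convention).
    """
    if block_size is None:
        return list(masked)
    current_block = min(masked) // block_size
    return [i for i in masked if i // block_size == current_block]


def sample_position(scores: Sequence[float], candidates: Sequence[int], rng: random.Random) -> int:
    """Sample one eligible position from a softmax over hypothetical policy scores."""
    top = max(scores[i] for i in candidates)  # subtract the max for numerical stability
    weights = [math.exp(scores[i] - top) for i in candidates]
    return rng.choices(candidates, weights=weights, k=1)[0]


def group_relative_advantages(rewards: Sequence[float], eps: float = 1e-6) -> List[float]:
    """GRPO-style advantages: standardise each trajectory's reward within its group."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std; implementations differ here
    return [(r - mean) / (std + eps) for r in rewards]


if __name__ == "__main__":
    rng = random.Random(0)
    masked = [1, 2, 5, 6]
    print(candidate_positions(masked, None))  # any-order: all masked positions
    print(candidate_positions(masked, 4))     # semi-AR, block size 4: block 0 only
    scores = [0.0, 2.0, 0.5, 0.0, 0.0, 3.0, 1.0, 0.0]
    print(sample_position(scores, candidate_positions(masked, 4), rng))
    # Rewards for a group of trajectories sampled for the same prompt or puzzle.
    print(group_relative_advantages([1.0, 2.0, 3.0]))
```

Two example calls show the behaviour:

- `candidate_positions([1, 2, 5, 6], block_size=4)` returns `[1, 2]`. The policy must finish the first block before it may touch positions 5 and 6.
- `group_relative_advantages([1.0, 2.0, 3.0])` returns roughly `[-1.22, 0.0, 1.22]`.

Normalising within the group means a trajectory is rewarded for beating its siblings, whatever the absolute scale of the reward.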