Diffusion Fine-tuning with Rewarded Moment Matching Distillation

2026-06-29 · Machine Learning

AI summary

The authors introduce Rewarded Moment Matching Distillation (RMMD), a method that merges two common steps in diffusion post-training, distillation and reinforcement learning fine-tuning, into a single process. RMMD preserves sample quality by adapting the sampling loop for on-policy training and by using the distillation loss as a regularizer on model behavior. On ImageNet, RMMD achieves better trade-offs between FID and reward than DI++, DRaFT, and HyperNoise. Applied to the GenCast weather forecasting model, RMMD yields a distilled model that runs 7.5x faster, outperforms the teacher on 93% of target weather variables, and is better calibrated. The authors take this as evidence that RMMD scales to complex, high-dimensional scientific data.

Diffusion models, Model distillation, Reinforcement learning fine-tuning, Moment matching, KL regularization, FID score, Pareto front, ImageNet, Continuous Ranked Probability Score (CRPS), Weather forecasting models
Authors
Alexis Jacq, Guillaume Couairon, Valentin De Bortoli, Quentin Berthet, Arnaud Doucet, Romuald Elie
Abstract
Distillation and Reinforcement Learning (RL) fine-tuning are the primary pillars of diffusion post-training. These two phases have traditionally been studied in isolation, and their interaction remains poorly understood, in particular how fine-tuning impacts the generative quality of distilled models. We introduce Rewarded Moment Matching Distillation (RMMD), a novel framework that simultaneously distills diffusion models and maximizes a reward function. RMMD preserves the high-fidelity "naturalness" characteristic of advanced distillation (such as 8-step Moment Matching) by adapting the sampling loop for on-policy training and repurposing the distillation loss as a proxy for integral KL regularization. By evaluating the FID-Reward Pareto fronts on ImageNet, we demonstrate that RMMD achieves superior trade-offs compared to single-step baselines (DI++) and multi-step competitors (DRaFT, HyperNoise). Finally, we apply RMMD to GenCast, a state-of-the-art weather forecasting model, to distill it while optimizing the Continuous Ranked Probability Score (CRPS) metric. The resulting distilled model achieves a 7.5x speedup while outperforming the teacher model on 93% of target weather variables, and it is better calibrated. These results indicate that RMMD scales to complex, high-dimensional scientific domains.
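For readers unfamiliar with CRPS, the following is a minimal sketch of the standard ensemble estimator, CRPS ≈ E|X − y| − ½ E|X − X′|, where X and X′ are independent draws from the forecast ensemble and y is the observation. It is an illustration only, not the authors' training objective or the GenCast evaluation code; the function name `crps_ensemble` and the scalar, single-variable setting are assumptions made for this example. Lower values are better, and for a one-member ensemble the score reduces to the absolute error.

```python
import numpy as np


def crps_ensemble(samples, observation):
    """Empirical CRPS of a 1-D forecast ensemble against a scalar observation.

    Hypothetical helper for illustration; uses the simple (biased) estimator
    that averages over all ensemble pairs, including identical ones.
    """
    samples = np.asarray(samples, dtype=float)
    # Accuracy term: mean absolute error of ensemble members, E|X - y|.
    accuracy = np.mean(np.abs(samples - observation))
    # Spread term: half the mean absolute pairwise difference, 0.5 * E|X - X'|.
    pairwise = np.abs(samples[:, None] - samples[None, :])
    spread = 0.5 * np.mean(pairwise)
    return accuracy - spread


if __name__ == "__main__":
    # A two-member ensemble straddling the observation.
    print(crps_ensemble([0.0, 2.0], 1.0))
    # A single-member ensemble: CRPS equals the absolute error.
    print(crps_ensemble([3.0], 1.0))
```

The spread term rewards ensembles whose dispersion matches their error, which is why CRPS is sensitive to calibration as well as accuracy.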