Toward Compiler World Models: Learning Latent Dynamics for Efficient Tensor Program Search
2026-06-08 • Machine Learning
Machine Learning, Programming Languages
AI summary
The authors address tensor program optimization, which is key to speeding up machine learning workloads but involves an enormous search space. They propose a world-model-inspired evaluator that predicts how a sequence of scheduling actions affects program performance by rolling the actions out in a latent space, rather than repeatedly mutating and re-encoding the full program. Integrated into the TVM AutoScheduler, the method improves representative-subgraph latency over Ansor by 1.37× on GPU and 1.54× on CPU under the same 64-trial budget, and comes within 2.2% (geometric mean) of Ansor-10K while using 10× fewer measurements.
tensor program optimization, auto-scheduler, cost models, schedule trajectory, latent dynamics, TVM AutoScheduler, AST mutation, GPU, CPU, program scheduling
Authors
Haolin Pan, Lianghong Huang, Xvlin Zhou, Mingjie Xing, Yanjun Wu
Abstract
Tensor program optimization is essential for modern machine learning systems, but its search space is enormous. Existing auto-schedulers reduce measurement cost with learned cost models, yet they usually evaluate each candidate as a static code snapshot, ignoring the schedule trajectory that produced it. This makes them insensitive to action dependencies and vulnerable to superficial code variations. We propose a world-model-inspired evaluator that models schedule evaluation as action-conditioned latent dynamics over program states. Starting from the initial program, it rolls out scheduling actions in a continuous latent space with a lightweight transition model, avoiding expensive AST mutation and repeated code encoding. The final dynamic representation is combined with action and hardware features to rank candidates. Implemented in TVM AutoScheduler, our method improves representative-subgraph latency over Ansor by 1.37$\times$ on GPU and 1.54$\times$ on CPU under the same 64-trial budget. It also matches Ansor-10K within 2.2% geometric mean using 10$\times$ fewer measurements, and accelerates full-model inference over PyTorch/PyTorch-opt(cuDNN) by 4.61$\times$/3.67$\times$ geometric mean.
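To make the rollout idea concrete, the following is a minimal sketch of action-conditioned latent dynamics followed by candidate ranking. It is not the authors' implementation: the tanh transition, the mean-pooled action summary, the linear scoring head, the random untrained weights, and all names (`transition`, `rollout`, `score`, `rank_candidates`, the dimensions `DZ`, `DA`, `DH`) are assumptions chosen for illustration. It only shows the shape of the computation: encode the initial program once, apply one cheap latent step per scheduling action, then score the final state together with action and hardware features.

```python
import math
import random


def matvec(matrix, vector):
    """Dense matrix-vector product on plain Python lists."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def transition(z, action, W_z, W_a):
    """One latent step, z' = tanh(W_z z + W_a a). Hypothetical form."""
    hz = matvec(W_z, z)
    ha = matvec(W_a, action)
    return [math.tanh(p + q) for p, q in zip(hz, ha)]


def rollout(z0, actions, W_z, W_a):
    """Apply the scheduling actions in order, entirely in latent space."""
    z = list(z0)
    for action in actions:
        z = transition(z, action, W_z, W_a)
    return z


def score(z_final, actions, hw, w_out):
    """Linear head over [final latent, mean action features, hardware].

    Assumes each candidate has at least one action (illustrative only).
    """
    n = len(actions)
    a_mean = [sum(a[i] for a in actions) / n for i in range(len(actions[0]))]
    features = z_final + a_mean + hw
    return sum(w * f for w, f in zip(w_out, features))


def rank_candidates(z0, candidates, hw, params):
    """Return candidate indices sorted from highest to lowest score."""
    W_z, W_a, w_out = params
    scored = [
        (score(rollout(z0, acts, W_z, W_a), acts, hw, w_out), idx)
        for idx, acts in enumerate(candidates)
    ]
    return [idx for _, idx in sorted(scored, reverse=True)]


def random_matrix(rng, rows, cols):
    """Small random weights standing in for a trained model."""
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


# Toy setup: latent, action, and hardware feature sizes are arbitrary.
rng = random.Random(0)
DZ, DA, DH = 4, 3, 2
W_z = random_matrix(rng, DZ, DZ)
W_a = random_matrix(rng, DZ, DA)
w_out = [rng.uniform(-0.5, 0.5) for _ in range(DZ + DA + DH)]

z0 = [rng.uniform(-1, 1) for _ in range(DZ)]  # encoded once per program
hw = [1.0, 0.0]                               # e.g. a one-hot GPU/CPU flag
candidates = [
    [[rng.uniform(-1, 1) for _ in range(DA)] for _ in range(steps)]
    for steps in (2, 3, 4)                    # three schedules of varying length
]

ranking = rank_candidates(z0, candidates, hw, (W_z, W_a, w_out))
print(ranking)
```

In this sketch the initial state `z0` is computed once and reused across all candidates, so the per-candidate cost is a handful of small matrix-vector products rather than a code transformation plus a fresh encoding, which is the efficiency argument the abstract makes.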