Taming Curvature: Architecture Warm-Up for Stable Transformer Training
2026-06-15 • Machine Learning
AI summary
The authors address unstable training in very large Transformer models, where the loss can suddenly spike or diverge and waste compute. They develop a fast online method to estimate the largest eigenvalue of the (preconditioned) Hessian, a measure of loss curvature, during training, even at billion-parameter scale. Using this estimator, they find that training instabilities coincide with surges in this curvature and that curvature grows with network depth. To address this, they propose gradually increasing the model's depth during training to keep the curvature under control. Their experiments indicate that the approach makes curvature tracking efficient and reduces instabilities without slowing convergence.
Transformer, Edge of Stability, Curvature, Hessian, Power iteration, Hessian-vector products, Optimization stability, Billion-parameter models, Network depth, Training instabilities
Authors
Sameera Ramasinghe, Ajanthan Thalaiyasingam, Hadi Mohaghegh Dolatabadi, Chamin Hewa Koneputugodage, Gil Avraham, Violetta Shevchenko, Yan Zuo, Karol Pajak, Alexander Long
Abstract
Training billion-parameter Transformers is often brittle, with transient loss spikes and divergence that waste compute. Although the recently developed Edge of Stability (EoS) theory provides a powerful tool for understanding and controlling the stability of optimization methods via the (preconditioned) curvature, curvature-controlling methods are rarely used in large-scale Transformer training because curvature estimation is complex. To this end, we first introduce a fast online estimator of the largest (preconditioned) Hessian eigenvalue (i.e., curvature) based on a warm-started variant of power iteration with Hessian-vector products. We show theoretically, and verify empirically, that the proposed method makes per-iteration curvature tracking feasible at billion-parameter scale while being more accurate. Using this tool, we find that training instabilities coincide with surges in preconditioned curvature and that curvature grows with depth. Motivated by these observations, we propose architecture warm-up: progressively growing network depth to carefully control the preconditioned Hessian and stabilize training. Experiments on large Transformers validate that our approach enables efficient curvature tracking and reduces instabilities compared to existing state-of-the-art stabilization techniques without slowing down convergence.
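To make the idea of warm-started power iteration with Hessian-vector products concrete, here is a minimal sketch in plain Python. It is not the authors' estimator. It assumes a hypothetical three-parameter toy loss whose Hessian `toy_hessian(step)` drifts slowly, it uses the plain Hessian rather than the preconditioned one, and it approximates each Hessian-vector product with a central difference of gradients instead of automatic differentiation. All function names are illustrative. The only point it demonstrates is that reusing the previous step's vector lets a couple of iterations per step track the top eigenvalue, whereas restarting from scratch with the same budget does not.

```python
import math

def toy_hessian(step):
    """Hypothetical slowly drifting Hessian: diag(4 + 0.01*step, 2, 1)."""
    return [[4.0 + 0.01 * step, 0.0, 0.0],
            [0.0, 2.0, 0.0],
            [0.0, 0.0, 1.0]]

def grad(w, A):
    """Gradient of the toy loss 0.5 * w^T A w, i.e. A @ w."""
    return [sum(a * x for a, x in zip(row, w)) for row in A]

def hvp(w, v, A, eps=1e-3):
    """Hessian-vector product from two gradient calls (central difference)."""
    g_plus = grad([wi + eps * vi for wi, vi in zip(w, v)], A)
    g_minus = grad([wi - eps * vi for wi, vi in zip(w, v)], A)
    return [(p - m) / (2.0 * eps) for p, m in zip(g_plus, g_minus)]

def normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def power_iteration(w, A, v0, num_iters):
    """Return (Rayleigh-quotient estimate of the top eigenvalue, last vector)."""
    v = normalize(v0)
    lam = 0.0
    for _ in range(num_iters):
        hv = hvp(w, v, A)
        lam = sum(a * b for a, b in zip(v, hv))  # v^T H v with ||v|| = 1
        v = normalize(hv)
    return lam, v

def track_curvature(num_steps, iters_per_step, warm_start):
    """Estimate the top eigenvalue at every step of a mock training run."""
    w = [0.5, -0.3, 0.2]   # parameters, held fixed in this toy
    v = [1.0, 1.0, 1.0]    # initial probe vector
    estimates = []
    for step in range(num_steps):
        # Warm start: reuse the vector from the previous training step.
        start = v if warm_start else [1.0, 1.0, 1.0]
        lam, v = power_iteration(w, toy_hessian(step), start, iters_per_step)
        estimates.append(lam)
    return estimates

warm = track_curvature(50, 2, warm_start=True)
cold = track_curvature(50, 2, warm_start=False)
true_top = 4.0 + 0.01 * 49  # top eigenvalue of the toy Hessian at the last step
print("warm-start error:", abs(warm[-1] - true_top))
print("cold-start error:", abs(cold[-1] - true_top))
```

In this toy, both runs spend two Hessian-vector products per step, but the warm-started run carries its accumulated progress across steps, so its error at the last step is far smaller than that of the cold-started run. In a real training setup, the finite-difference `hvp` would typically be replaced by an autodiff-based Hessian-vector product, and the preconditioner would need to be folded into the operator being iterated.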