GRAIN: Group Aggregation via Min-Norm Objective
2026-06-22 • Machine Learning
AI summary
The authors address training instability in very large deep learning models trained or fine-tuned on small datasets, where random factors can lead to a different final solution on each run. They propose GRAIN, a training method that changes how gradients from different groups of data are combined, using a min-norm convex combination instead of the mean. The aggregated update never conflicts with any group's gradient, which makes training more stable, and the method adds no training-time or storage cost beyond a single backward pass. Experiments on generation, classification, and regression tasks with large pretrained models show higher mean performance and lower run-to-run variance than standard training.
overparameterized models, gradient aggregation, mini-batch optimization, stochastic gradient descent (SGD), loss landscape, uniform stability, gradient conflict, pretrained models, training variance
Authors
Nghia Bui, Jiarui Yao, Lijing Wang
Abstract
Learning instability is a long-standing problem across machine learning, but it is especially acute in the overparameterized regime that defines modern deep learning: large models fine-tuned or trained on limited data traverse flat loss landscapes with many nearly equivalent minima, and stochastic factors (initialization, data order, dropout, hardware non-determinism) can route optimization to very different solutions. The rise of large pretrained models (LPMs) makes the problem more urgent: training cost is high, downstream data is often small, and repeated runs for variance reduction are prohibitively expensive. We introduce GRAIN (Group Aggregation via mIN-norm objective), a lightweight training algorithm that replaces the mean aggregation used in mini-batch optimization (both across mini-batches and within a mini-batch) with a min-norm convex combination of group-wise gradients. GRAIN guarantees a non-negative inner product between the aggregated update and every group gradient, which resolves intra- and inter-batch gradient conflict, and it retains an $\mathcal{O}(1/T)$ convergence rate comparable to that of SGD. Under mild smoothness and absolute-continuity assumptions, the min-norm solution differs almost surely from the arithmetic mean, which yields a uniform-stability bound for GRAIN that is strictly tighter than the standard bound for SGD. Empirically, across generation, classification, and regression at LPM scale, GRAIN delivers consistent improvements in mean performance and reductions in run-to-run variance over a broad suite of tasks, with no extra training-time or storage cost beyond a single backward pass.
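To make the aggregation step concrete, here is a minimal sketch of a min-norm convex combination of group gradients, i.e. the point of smallest Euclidean norm in the convex hull of the gradients. It uses a generic Frank-Wolfe solver over the simplex with exact line search. The abstract does not say how GRAIN forms its groups or solves this problem, so the solver choice and the names `min_norm_aggregate`, `combine`, and `dot` are assumptions made for illustration, not the authors' implementation.

```python
# Illustrative sketch only: a generic min-norm solver, not the authors' code.

def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def combine(weights, grads):
    """Weighted sum of the group gradients."""
    dim = len(grads[0])
    return [sum(w * g[k] for w, g in zip(weights, grads)) for k in range(dim)]


def min_norm_aggregate(grads, iters=200, tol=1e-12):
    """Minimize ||sum_i w_i * g_i||^2 over the simplex (w_i >= 0, sum_i w_i = 1).

    Returns (weights, aggregated_gradient). Frank-Wolfe is iterative, so with
    more than two groups the result is approximate after `iters` steps.
    """
    n = len(grads)
    weights = [1.0 / n] * n  # start from the arithmetic mean
    for _ in range(iters):
        d = combine(weights, grads)
        # Frank-Wolfe vertex: the group whose gradient is least aligned with d.
        scores = [dot(g, d) for g in grads]
        j = min(range(n), key=scores.__getitem__)
        # Exact line search on the segment from d to g_j.
        diff = [gj - dk for gj, dk in zip(grads[j], d)]
        denom = dot(diff, diff)
        if denom <= tol:
            break
        t = max(0.0, min(1.0, -dot(d, diff) / denom))
        if t <= tol:
            break  # no further decrease in norm is possible
        weights = [(1.0 - t) * w for w in weights]
        weights[j] += t
    return weights, combine(weights, grads)


# Two conflicting group gradients (hypothetical values).
g1, g2 = [1.0, 0.0], [-3.0, 1.0]
mean = combine([0.5, 0.5], [g1, g2])          # [-1.0, 0.5]
weights, update = min_norm_aggregate([g1, g2])
print("mean . g1 =", dot(mean, g1))           # negative: the mean conflicts with g1
print("update . g1 =", dot(update, g1))
print("update . g2 =", dot(update, g2))
```

In this two-group example the mean has a negative inner product with `g1`, while the min-norm combination has a non-negative inner product with both gradients, which is the property the abstract attributes to GRAIN. At the exact minimizer each inner product is at least the squared norm of the update; with more groups the Frank-Wolfe iterate satisfies this only up to solver tolerance.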