FORGE: Fused On-Register Gradient Elimination for Memory-Efficient LLM Training
2026-06-22 • Machine Learning
AI summary
The authors explain that conventional neural-network training writes every weight gradient to memory before the optimizer update, which limits how large a model can be trained. They introduce FORGE, a method that fuses the optimizer step into the gradient computation, so gradients never need to be stored as separate large tensors. This saves substantial memory while keeping the update exact in full precision and faithful, with bounded deviation, in lower-precision formats. Their experiments show that FORGE more than halves the memory of an optimizer step, runs about 1.5x faster at small batch sizes, and fits larger micro-batches on the same GPUs.
reverse-mode differentiation, weight gradients, optimizer step, memory usage, gradient folding, registers, bf16, 8-bit optimization, tensor parallelism, Megatron-LM
Authors
Dikshant Kukreja, Kritarth Prasad, Avinash Anand, Zhengkui Wang, Erik Cambria, Timothy Liu, Aik Beng Ng, Simon See, Bapi Chatterjee
Abstract
Reverse-mode differentiation computes every weight gradient, writes it to memory, and only then lets the optimizer read it back. This two-phase schedule sets the memory ceiling of modern training: at the seam between the phases, every layer's gradient is live at once. We argue that this materialized gradient is an artifact of how differentiation is staged, not a quantity that learning requires -- and we eliminate it. FORGE folds the optimizer step into the backward pass and applies it one tile at a time, entirely in registers, so each gradient tile is consumed the instant it is produced and never becomes a tensor. The fusion changes only when the update happens, not what it computes: in full precision the fused step is provably exact -- the identical optimizer update, for every element-wise rule -- and that exactness survives tensor- and sequence-parallel sharding; in the bf16 and 8-bit regimes used in practice it is faithful rather than bit-identical, its deviation bounded and, for the weight store, rendered unbiased by stochastic rounding. Because each gradient tile is born and consumed in the same registers, it is never converted down to bf16 to be stored and read back; FORGE thus preserves the full-precision fidelity that both bf16 and 8-bit optimizers lose to that conversion. Nor is the method tied to one architecture or one optimizer: linear layers are ubiquitous, and FORGE reclaims the gradient memory of any of them under any element-wise rule. Empirically FORGE more than halves the memory of an optimizer step and, at the small batch sizes typical of fine-tuning and continued pretraining, runs about 1.5x faster; integrated into tensor-parallel Megatron-LM it fits 8B training at four times the micro-batch a standard optimizer allows on the same GPUs.
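The claim that fusion changes only when the update happens, not what it computes, can be illustrated with a small NumPy sketch. This is not the authors' kernel: it runs on the CPU in float64, uses Adam as one example of an element-wise rule, and every name in it (`adam_update`, `two_phase_step`, `fused_tiled_step`, the `tile` size) is hypothetical. It assumes a linear layer Y = XW, whose weight gradient is X^T dY, and compares the two-phase schedule against a loop that forms one gradient tile at a time and applies the update to that tile immediately.

```python
import numpy as np


def adam_update(w, g, m, v, step, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """Element-wise Adam rule, applied in place to w, m and v.

    w, m and v may be views into larger arrays, so updating a tile
    updates the parent weight and optimizer-state tensors.
    """
    m *= b1
    m += (1.0 - b1) * g
    v *= b2
    v += (1.0 - b2) * g * g
    m_hat = m / (1.0 - b1 ** step)
    v_hat = v / (1.0 - b2 ** step)
    w -= lr * m_hat / (np.sqrt(v_hat) + eps)


def two_phase_step(w, x, dy, m, v, step):
    """Baseline: materialize the whole weight gradient, then update."""
    g = x.T @ dy  # full (d_in, d_out) gradient is live here
    adam_update(w, g, m, v, step)
    return g.size  # number of gradient elements held at once


def fused_tiled_step(w, x, dy, m, v, step, tile=16):
    """Sketch of the fused schedule: one gradient tile exists at a time.

    Each tile of X^T dY is computed and consumed by the optimizer before
    the next tile is formed, so the full gradient is never assembled.
    """
    d_in, d_out = w.shape
    peak = 0
    for i in range(0, d_in, tile):
        for j in range(0, d_out, tile):
            rows, cols = slice(i, i + tile), slice(j, j + tile)
            g_tile = x[:, rows].T @ dy[:, cols]  # gradient of this tile only
            peak = max(peak, g_tile.size)
            adam_update(w[rows, cols], g_tile, m[rows, cols], v[rows, cols], step)
    return peak  # largest number of gradient elements held at once


# Usage: the same layer updated both ways from identical starting state.
rng = np.random.default_rng(0)
x = rng.standard_normal((8, 48))    # layer input saved from the forward pass
dy = rng.standard_normal((8, 40))   # gradient arriving from the layer above
w0 = rng.standard_normal((48, 40))  # weights before the step

w_a, m_a, v_a = w0.copy(), np.zeros_like(w0), np.zeros_like(w0)
w_b, m_b, v_b = w0.copy(), np.zeros_like(w0), np.zeros_like(w0)
live_a = two_phase_step(w_a, x, dy, m_a, v_a, step=1)
live_b = fused_tiled_step(w_b, x, dy, m_b, v_b, step=1, tile=16)
print("max weight difference:", np.abs(w_a - w_b).max())
print("gradient elements live at once:", live_a, "vs", live_b)
```

Because the rule is element-wise, each weight sees the same gradient value and the same optimizer state under either schedule, which is why the two results agree up to floating-point summation order. The sketch covers only the weight gradient; the input gradient dY W^T needs the pre-update weights, and how the ordering of that product against the in-place update is handled is outside what this example shows.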