Stochastic Estimation of the Layer-wise Hessian Trace for Monitoring Neural-network Training
2026-05-25 • Machine Learning
AI summary
The authors developed a method to efficiently estimate the traces of the per-layer (diagonal) blocks of the Hessian matrix. The Hessian describes how the training loss of a neural network curves around its parameters. Their method combines a known stochastic trace estimator with a single Hessian-vector product, giving unbiased per-layer curvature estimates in one backward pass, even for very large networks. They also identified a source of bias that arises when weights are shared. In addition, they derived formulas for the estimator's variance that guide the choice of the number of random probes. Applying the method, they detected when ResNet-18, ResNet-34 and VGG-11 networks start memorizing training labels on CIFAR-10 and CIFAR-100. Detection succeeded in 179 of 180 cases, with a false-alarm rate of 16 of 120.
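This is a minimal sketch of how per-layer trace estimates can be read off a single full-vector Hessian-vector product. The abstract names the stochastic trace estimator as Hutchinson's. Several choices below are assumptions, not taken from the paper:

- the probes are Rademacher;
- the function name `layerwise_hutchinson` is hypothetical;
- an explicit symmetric matrix `H` stands in for the Hessian, so `hvp(v)` is simply `H @ v` rather than a double-backward pass through a real network.

```python
import numpy as np


def layerwise_hutchinson(hvp, layer_slices, num_params, num_probes, rng):
    """Estimate tr(H_ll) for every layer l from full-vector Hessian-vector products.

    hvp: callable v -> H @ v over the whole parameter vector (an assumed
        stand-in for a double-backward Hessian-vector product).
    layer_slices: list of slices that partition the parameter vector by layer.
    rng: numpy Generator used to draw Rademacher probes.
    """
    totals = np.zeros(len(layer_slices))
    for _ in range(num_probes):
        # One Rademacher probe over the *whole* parameter vector ...
        v = rng.choice([-1.0, 1.0], size=num_params)
        # ... and one Hessian-vector product shared by all layers.
        hv = hvp(v)
        for l, sl in enumerate(layer_slices):
            # v_l^T (H v)_l = v_l^T H_ll v_l + sum_{m != l} v_l^T H_lm v_m.
            # The cross-layer terms have zero mean because E[v_l v_m^T] = 0,
            # so each term is an unbiased estimate of tr(H_ll).
            totals[l] += v[sl] @ hv[sl]
    return totals / num_probes


# Toy check on an explicit 12-parameter "network" split into three layers.
rng = np.random.default_rng(0)
A = rng.standard_normal((12, 12))
H = (A + A.T) / 2  # symmetric stand-in for the Hessian of the empirical risk
slices = [slice(0, 4), slice(4, 7), slice(7, 12)]
exact = np.array([np.trace(H[s, s]) for s in slices])
est = layerwise_hutchinson(lambda v: H @ v, slices, 12, 20000, np.random.default_rng(1))
print(exact)
print(est)
```

The per-layer estimates for one probe always sum to the full-vector Hutchinson estimate v^T H v. The layer-wise split therefore costs no extra Hessian-vector products.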
Hessian matrix, empirical risk, trace estimator, Hessian-vector product, weight sharing, stochastic estimation, label memorization, ResNet, VGG, CIFAR-10
Authors
Maxim Bolshim, Alexander Kugaevskikh
Abstract
The loss and the norm of its gradient separate the healthy and the pathological regimes of neural-network training only weakly. The curvature of the empirical risk, by contrast, differs qualitatively between them but cannot be computed explicitly at parameter counts of $P\sim 10^{6}$–$10^{8}$. We present a stochastic estimator of the traces of the diagonal blocks of the Hessian matrix of the empirical risk of a neural network. The procedure combines the Hutchinson stochastic trace estimator with a single Hessian-vector product over the whole parameter vector. It recovers unbiased estimates of every per-layer trace in one backward pass through the computational graph. We show that correctness under weight sharing requires the layer-wise Hessian to be assembled before the second differentiation. Unrolling shared weights into independent coordinates introduces a systematic bias whose sign and magnitude are governed by the cross-instance blocks of the unrolled Hessian. We derive a closed-form expression for the variance of the estimator at a fixed Hessian, together with a decomposition of the total variance under the mini-batch sampling distribution. This decomposition yields a critical probe count $K^{\star}$ that balances the two sources of randomness and supports the practical recommendation $K\in[5,10]$ in the on-line monitoring regime. We apply the estimator to detecting the label-memorisation regime of ResNet-18, ResNet-34, and VGG-11 on CIFAR-10 and CIFAR-100. There, a calibrated cumulative-sum (CUSUM) decision rule attains an empirical detection power of $179/180$ at a false-alarm rate of $16/120$.
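The weight-sharing bias can be illustrated with the chain rule on a toy loss. Suppose a weight vector θ is reused by n instances u_1, ..., u_n, all equal to θ. The Hessian with respect to θ is then the sum of all n × n blocks of the unrolled Hessian. Treating the copies as independent coordinates keeps only the diagonal blocks, so the gap is minus the sum of the cross-instance block traces. The sketch below is our own illustration of that statement, not the paper's derivation. The helper `shared_weight_traces` and the two hand-built quadratic losses are hypothetical.

```python
import numpy as np


def shared_weight_traces(H_unrolled, n_instances, dim):
    """Contrast the correct and the unrolled Hessian trace of one shared weight.

    H_unrolled: Hessian of the loss w.r.t. unrolled coordinates (u_1, ..., u_n),
        where each u_a is a copy of the shared weight vector theta of size dim.
    Returns (correct, unrolled, bias) with bias = unrolled - correct.
    """
    def block(a, b):
        # The (a, b) block of size dim x dim: d^2 L / (du_a du_b).
        return H_unrolled[a * dim:(a + 1) * dim, b * dim:(b + 1) * dim]

    pairs = [(a, b) for a in range(n_instances) for b in range(n_instances)]
    # Chain rule with u_a = theta for every a: the Hessian w.r.t. theta is the
    # sum of all n x n blocks, so its trace is the sum of all block traces.
    correct = sum(np.trace(block(a, b)) for a, b in pairs)
    # Treating the copies as independent parameters keeps only diagonal blocks.
    unrolled = sum(np.trace(block(a, a)) for a in range(n_instances))
    # The difference equals minus the sum of the cross-instance block traces.
    bias = unrolled - correct
    return float(correct), float(unrolled), float(bias)


# L = 0.5 * (u_1 + u_2)^2 with u_1 = u_2 = theta, so d^2L/dtheta^2 = 4.
print(shared_weight_traces(np.array([[1.0, 1.0], [1.0, 1.0]]), 2, 1))
# L = 0.5 * (u_1 - u_2)^2 with u_1 = u_2 = theta, so d^2L/dtheta^2 = 0.
print(shared_weight_traces(np.array([[1.0, -1.0], [-1.0, 1.0]]), 2, 1))
```

The two toy losses have identical diagonal blocks but cross-instance blocks of opposite sign. The unrolled trace therefore overestimates the curvature in one case and underestimates it in the other. This is consistent with the abstract's statement that the sign and magnitude of the bias are governed by those blocks.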