Invariant Gradient Alignment for Robust Reasoning Distillation

2026-06-03 · Machine Learning

Machine Learning, Artificial Intelligence
AI summary

The authors address a failure mode of large language models: they break down on questions whose surface topic differs from the training data even when the underlying logical structure is the same. They propose a training method, Invariant Gradient Alignment (IGA), that pushes models to learn reasoning patterns that hold across domains such as mathematics and medicine. IGA trains on groups of problems that share the same logic but differ in topic, and it down-weights the parts of each learning update on which those domains disagree. The method improves performance on out-of-distribution data while keeping training parameter-efficient through low-rank (LoRA) updates. In experiments it outperforms prior methods in both accuracy and logical consistency.

Large Language Models, Shortcut Learning, Out-of-Distribution Generalization, Chain-of-Thought Reasoning, Logical Isomer Sets, Gradient Alignment, Continuous Gradient Conflict Mask, Low-Rank Adaptation (LoRA), Stochastic Gradient Descent, Representational Invariance
Authors
Zehua Cheng, Wei Dai, Jiahao Sun
Abstract
Large language models (LLMs) suffer from shortcut learning: they systematically fail on out-of-distribution (OOD) inputs whose semantic surface differs from the training data, even when the logical structure is identical. This undermines knowledge-distillation pipelines that transfer chain-of-thought reasoning to smaller student models. We introduce Invariant Gradient Alignment (IGA), a training framework that aligns gradient updates across semantically diverse but logically isomorphic examples through three innovations: (i) Logical Isomer Sets, groups of problems that share an identical logical structure across distinct semantic domains (mathematics, medicine, law, science); (ii) a differentiable Continuous Gradient Conflict Mask that suppresses parameter dimensions with high cross-domain gradient variance while preserving invariant directions; and (iii) a truncated SVD projection of the masked gradient back onto the LoRA low-rank manifold, which maintains parameter efficiency throughout training. Theoretically, IGA yields tighter OOD generalization bounds than empirical risk minimization (ERM), with the bound scaling with the number of isomer domains, and it converges at the standard SGD rate under mild regularity conditions. Empirically, IGA outperforms eight baselines across four benchmarks, with accuracy gains of up to 14.3 percentage points over ERM-SFT and a Logical Consistency Score of 0.031 versus 0.142, a fourfold improvement in representational invariance.
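The abstract names the two gradient-side operations of IGA, a continuous conflict mask and a truncated SVD projection, but not their exact formulas. The NumPy sketch below shows one plausible reading under stated assumptions, not the authors' implementation. It assumes that per-domain gradients for a single LoRA-adapted weight matrix are stacked along a leading axis (one slice per isomer domain). It also assumes that the conflict score is the cross-domain variance normalised by the squared mean, and that the mask is a sigmoid with hypothetical parameters `tau` and `temperature`. Finally, it assumes that the masked mean gradient is reduced to rank r by truncated SVD, which gives the best rank-r approximation in Frobenius norm (Eckart-Young). The function names `conflict_mask`, `truncated_svd_project`, and `iga_update` are illustrative only.

```python
import numpy as np


def conflict_mask(domain_grads: np.ndarray, tau: float = 1.0,
                  temperature: float = 0.1, eps: float = 1e-8) -> np.ndarray:
    """Soft per-parameter mask in (0, 1); near 0 where isomer domains disagree.

    domain_grads: shape (K, ...), one gradient per isomer domain.
    ASSUMED formula (not from the paper): normalised variance
    v = Var_k[g] / (Mean_k[g]^2 + eps), mask = sigmoid((tau - v) / temperature).
    """
    mean = domain_grads.mean(axis=0)
    var = domain_grads.var(axis=0)          # cross-domain variance, ddof=0
    v = var / (mean ** 2 + eps)
    x = (tau - v) / temperature
    return 0.5 * (1.0 + np.tanh(0.5 * x))   # overflow-free sigmoid(x)


def truncated_svd_project(grad: np.ndarray, rank: int) -> np.ndarray:
    """Best rank-`rank` approximation of a 2-D gradient matrix."""
    u, s, vt = np.linalg.svd(grad, full_matrices=False)
    return (u[:, :rank] * s[:rank]) @ vt[:rank, :]


def iga_update(domain_grads: np.ndarray, rank: int, lr: float = 1e-2,
               tau: float = 1.0, temperature: float = 0.1) -> np.ndarray:
    """Hypothetical IGA-style step: mask the mean gradient, then project to rank r."""
    mask = conflict_mask(domain_grads, tau, temperature)
    masked = mask * domain_grads.mean(axis=0)
    return -lr * truncated_svd_project(masked, rank)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    shared = rng.normal(size=(16, 8))                     # direction common to all domains
    grads = np.stack([shared + 0.05 * rng.normal(size=shared.shape) for _ in range(4)])
    grads[:, :, 0] += rng.normal(scale=5.0, size=(4, 16))  # domain-specific conflict in column 0
    mask = conflict_mask(grads)
    print(mask[:, 0].mean(), mask[:, 1:].mean())  # column 0 should be strongly suppressed
    step = iga_update(grads, rank=2)
    print(step.shape)  # (16, 8)
```

Using a sigmoid rather than a hard threshold keeps the mask differentiable, consistent with the abstract's description. Mapping the rank-r update onto LoRA factors is a further step that this sketch does not show. One option is to split the singular values between the two factors as U_r S_r^{1/2} and S_r^{1/2} V_r^T.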