SAE-FD: Sparse Autoencoder Feature Distillation for Continual Learning of Large Language Models

2026-05-25

Machine Learning
AI summary

The authors study how large language models can learn new tasks over time without losing what they already know, a failure known as catastrophic forgetting. They focus on regularization methods, which try to protect old knowledge but struggle because dense model representations encode multiple concepts in overlapping dimensions. To address this, they propose Sparse Autoencoder Feature Distillation (SAE-FD), which anchors the model's representations in the sparse feature space of a pre-trained sparse autoencoder so that features are separated more clearly. Their experiments show that the method helps models retain old tasks better while still learning new ones effectively.

Continual Learning, Catastrophic Forgetting, Regularization, Sparse Autoencoder, Feature Space, Representation Learning, Backward Transfer, Overcomplete Basis, Large Language Models, Dense vs Sparse Representations
Authors
Mingxu Zhang, Yuhan Li, Lujundong Li, Dazhong Shen, Hui Xiong, Ying Sun
Abstract
Continual learning enables large language models to adapt to evolving tasks without retraining from scratch, yet catastrophic forgetting remains a central obstacle. Among continual learning methods, regularization-based approaches are widely used to constrain model updates and reduce forgetting, operating in weight space, gradient space, or output space. However, these dense representation spaces suffer from feature superposition, where multiple concepts are encoded in overlapping dimensions, making it difficult to selectively protect previously learned knowledge without impeding new-task learning. To address this issue, we propose SAE-FD (Sparse Autoencoder Feature Distillation), which anchors model representations in the sparse feature space of a pre-trained Sparse Autoencoder. There, dense activations are decomposed into a sparse overcomplete basis that reduces representational entanglement, enabling more targeted regularization with less interference to new-task learning. Experiments on two continual learning benchmarks across three model architectures show that SAE-FD consistently outperforms existing regularization-based methods, achieving up to 52.70% average accuracy with only -0.46 backward transfer.
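The abstract does not state the exact loss, so the following is only a toy sketch of what regularizing in an SAE's sparse feature space could look like. It assumes a frozen ReLU encoder and a mean-squared penalty between the features of the pre-update and current activations. The function names, the weight `lam`, and the tiny hand-written encoder are all hypothetical and not taken from the paper.

```python
# Toy sketch under stated assumptions; NOT the paper's implementation.
# Idea: measure how far the current model's hidden activation has drifted
# from the pre-update model's activation, but do so in the sparse,
# overcomplete feature space of a frozen sparse autoencoder (SAE).

def sae_encode(h, w_enc, b_enc):
    """Map a dense activation h to SAE features via ReLU(W_enc @ h + b_enc)."""
    return [
        max(0.0, sum(w * x for w, x in zip(row, h)) + b)
        for row, b in zip(w_enc, b_enc)
    ]


def feature_distillation_loss(h_old, h_new, w_enc, b_enc):
    """Mean squared difference between SAE features of old and new activations."""
    z_old = sae_encode(h_old, w_enc, b_enc)
    z_new = sae_encode(h_new, w_enc, b_enc)
    return sum((a - b) ** 2 for a, b in zip(z_old, z_new)) / len(z_old)


def total_loss(task_loss, distill_loss, lam):
    """New-task loss plus the feature-space penalty, weighted by lam (assumed)."""
    return task_loss + lam * distill_loss


# Hypothetical frozen SAE: 2-dimensional activations -> 4 features (overcomplete).
W_ENC = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
B_ENC = [0.0, 0.0, 0.0, 0.0]

h_old = [1.0, 0.0]    # activation from the model before the new task
h_same = [1.0, 0.0]   # current model, unchanged on this input
h_drift = [0.0, 1.0]  # current model after drifting to a different feature

loss_same = feature_distillation_loss(h_old, h_same, W_ENC, B_ENC)
loss_drift = feature_distillation_loss(h_old, h_drift, W_ENC, B_ENC)
combined = total_loss(task_loss=1.0, distill_loss=loss_drift, lam=0.5)
```

In this toy setup the penalty is zero when the activation is unchanged and positive when it moves onto a different SAE feature. A real setup would apply the same comparison to batched transformer activations with a trained SAE.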