Discovering Latent Groups for Robust Classification
2026-06-22 • Machine Learning
Machine Learning • Artificial Intelligence • Computer Vision and Pattern Recognition
AI summary
The authors present neural classification trees (NCT), a new way for machine learning models to understand and handle hidden groups in data without needing extra group labels. Instead of just giving a final prediction, their model builds a tree structure that sorts data points into 'easy' or 'hard' categories based on how well they are predicted. This process helps the model separate tricky, underrepresented groups, making the model’s behavior easier to understand. They tested NCT on several datasets and found it works well compared to current top methods while also showing clear links between the model's structure and group differences in the data.
spurious correlations, underrepresented subgroups, neural classification trees, model robustness, pseudo-labels, latent subgroup structure, interpretability, multi-class classification, binary classification, machine learning fairness
Authors
Ankur Garg, Ulrich Aïvodji, Samira Ebrahimi Kahou, Vincent Michalski
Abstract
Machine learning models exploit spurious correlations, achieving high average accuracy but failing disproportionately on underrepresented subgroups. Existing methods address this by adjusting network parameters, guided either by subgroup annotations or by inferred pseudo-group labels. Yet at inference, these methods produce only a class prediction, with no insight into a sample's latent subgroup. We propose neural classification trees (NCT), a framework that achieves robustness by encoding subgroup structure in its tree-shaped architecture. By routing each sample to an "easy" or "hard" node of this tree, based on prediction correctness, and reusing these routes as pseudo-labels for the next iteration, NCT disentangles conflicting subgroups without requiring subgroup supervision. We evaluate NCT on five benchmarks spanning binary and multi-class spurious correlations. Our experiments show that the learned tree topology consistently isolates minority subgroups, giving a transparent mapping between the model architecture and the data's latent group structure, while achieving robustness competitive with state-of-the-art methods.
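The routing idea described in the abstract can be illustrated with a small toy sketch. This is not the authors' implementation: it trains no networks, and the function names (`route_by_correctness`, `extend_paths`, `group_by_path`), the labels, and the per-round predictions are all hypothetical, chosen only to show how "easy"/"hard" routes based on prediction correctness could accumulate into tree paths that serve as pseudo-group labels.

```python
from collections import defaultdict


def route_by_correctness(predictions, labels):
    """Route a sample to 'easy' if it was predicted correctly, else 'hard'."""
    return ["easy" if p == y else "hard" for p, y in zip(predictions, labels)]


def extend_paths(paths, routes):
    """Append this iteration's route to each sample's path through the tree."""
    return [path + (route,) for path, route in zip(paths, routes)]


def group_by_path(paths):
    """Collect sample indices per path; each path acts as a pseudo-group label."""
    groups = defaultdict(list)
    for index, path in enumerate(paths):
        groups[path].append(index)
    return dict(groups)


# Hypothetical toy data: true labels and predictions from two rounds.
labels = [0, 0, 1, 1, 1]
round_predictions = [
    [0, 1, 1, 1, 0],  # round 1: samples 1 and 4 are misclassified
    [0, 0, 1, 1, 0],  # round 2: only sample 4 is still misclassified
]

# Every sample starts at the root (empty path) and gains one route per round.
paths = [() for _ in labels]
for predictions in round_predictions:
    paths = extend_paths(paths, route_by_correctness(predictions, labels))

pseudo_groups = group_by_path(paths)
for path, members in pseudo_groups.items():
    print(" -> ".join(path), members)
```

In this toy run, the sample that stays misclassified in both rounds ends up alone on the `hard -> hard` path, which loosely mirrors the abstract's claim that the tree isolates minority subgroups. How the actual method trains the nodes and uses the pseudo-labels in the next iteration is not specified in the abstract and is left out here.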