QK-Normed MLA: QK normalization without full key caching
2026-06-15 • Machine Learning
Machine Learning • Computation and Language
AI summary
The authors show how to make QK normalization, a technique that stabilizes attention, work with Multi-head Latent Attention (MLA), where combining the two is not straightforward. They find that the apparent incompatibility comes from how QK normalization is usually implemented, not from the architecture itself. By splitting RMSNorm into a fixed per-dimension weight and a per-token scaling factor, they make QK normalization compatible with MLA's compact latent cache. In experiments with 400M-scale models trained on up to 100B tokens, this approach gives lower training loss and better downstream accuracy than QK clipping while adding less than 2% decoding latency. This makes QK normalization a practical option for MLA models.
Query-Key Normalization, Multi-head Latent Attention, RMSNorm, Attention Mechanism, Latent States, Model Decoding, Training Loss, Context Length, Cache Mechanism
Authors
Yizhou Han, Yao Zhao, Jun Zhou, Longfei Li, Ruoyu Sun
Abstract
Query-key (QK) normalization stabilizes attention by controlling the scale of queries and keys before the dot product, but is not immediately compatible with Multi-head Latent Attention (MLA). MLA achieves efficient decoding by caching low-dimensional latent states instead of full keys, whereas post-projection QK RMSNorm appears to require the fully projected key for every cached token. We show this apparent incompatibility is an implementation artifact, not an architectural constraint. RMSNorm decomposes into a static affine weight and a dynamic scalar RMS statistic. The static key-side weight can be absorbed into the MLA query-side projection; the dynamic key statistic reduces to one inverse-RMS scalar per token and KV group. The resulting formulation is exactly equivalent to explicit post-projection QK RMSNorm in exact arithmetic and preserves MLA's latent decode path. In our 400M runs trained for up to 100B tokens, QK-Normed MLA achieves lower training loss and better downstream accuracy than QK clipping, while H800 decode benchmarks show less than 2% latency overhead up to 256k context. These results make QK normalization a practical stabilization option for MLA models without requiring full-key caching.
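The decomposition described above can be checked numerically. The following is a minimal NumPy sketch, not the authors' code. It assumes a simplified MLA layer with the decoupled RoPE branch, softmax, and score scaling omitted, per-head RMSNorm (so each head acts as its own KV group), and random weights. All names (`W_dkv`, `W_uk`, `W_q`, `g_q`, `g_k`, `rms_norm`, `inv_rms`) are hypothetical. The sketch compares attention logits from explicit post-projection QK RMSNorm against a latent path that caches only the latent c_s plus one inverse-RMS scalar per token and head. It relies on the identity q̂ᵀ k̂_s = inv_rms(k_s) · (W_ukᵀ (g_k ⊙ q̂))ᵀ c_s, where k_s = W_uk c_s.

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    """Standard RMSNorm over the last axis: weight * x / sqrt(mean(x^2) + eps)."""
    return weight * x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def inv_rms(x, eps=1e-6):
    """Dynamic part of RMSNorm only: one inverse-RMS scalar per vector."""
    return 1.0 / np.sqrt(np.mean(x * x, axis=-1) + eps)

rng = np.random.default_rng(0)
n_heads, d_model, d_latent, d_head, seq_len = 4, 64, 16, 8, 10

# Hypothetical MLA weights (decoupled RoPE branch omitted for simplicity).
W_dkv = rng.standard_normal((d_latent, d_model)) / np.sqrt(d_model)          # down-projection to latent
W_uk = rng.standard_normal((n_heads, d_head, d_latent)) / np.sqrt(d_latent)  # per-head key up-projection
W_q = rng.standard_normal((n_heads, d_head, d_model)) / np.sqrt(d_model)     # per-head query projection
g_q = rng.uniform(0.5, 1.5, (n_heads, d_head))  # static query-side RMSNorm weight
g_k = rng.uniform(0.5, 1.5, (n_heads, d_head))  # static key-side RMSNorm weight

h_ctx = rng.standard_normal((seq_len, d_model))  # hidden states of already-cached tokens
h_new = rng.standard_normal(d_model)             # hidden state of the token being decoded

# --- Reference: explicit post-projection QK RMSNorm with full keys materialized ---
latent = h_ctx @ W_dkv.T                              # (seq, d_latent)
keys = np.einsum("hdc,sc->hsd", W_uk, latent)         # (heads, seq, d_head)
q = np.einsum("hdm,m->hd", W_q, h_new)                # (heads, d_head)
q_hat = rms_norm(q, g_q)                              # query is normalized as usual
k_hat = rms_norm(keys, g_k[:, None, :])
ref_logits = np.einsum("hd,hsd->hs", q_hat, k_hat)    # (heads, seq)

# --- Latent path: cache the latent plus one inverse-RMS scalar per token and head ---
# In a real decoder the scalar is computed once when a token is appended and the
# full key is discarded; here it is taken from the reference keys for brevity.
cache_latent = latent                                 # (seq, d_latent)
cache_inv_rms = inv_rms(keys)                         # (heads, seq)

# Absorb the static key weight g_k into the query-side up-projection (done once):
# q_hat . (g_k * W_uk c) = (W_uk^T (g_k * q_hat)) . c
W_uk_absorbed = g_k[:, :, None] * W_uk                # (heads, d_head, d_latent)
q_latent = np.einsum("hdc,hd->hc", W_uk_absorbed, q_hat)   # (heads, d_latent)
mla_logits = (q_latent @ cache_latent.T) * cache_inv_rms   # (heads, seq)

print(np.allclose(ref_logits, mla_logits))  # True
```

Folding `g_k` into `W_uk` is a one-time weight transformation, so it adds no decode-time work. The per-token cache grows by only `n_heads` scalars next to the `d_latent` latent values, instead of the `n_heads * d_head` values that full keys would require.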