HSAP: A Hierarchical Sequence-aware Parallelism for Hybrid-Context Generative Models
2026-06-29 • Machine Learning
Machine Learning • Distributed, Parallel, and Cluster Computing
AI summary
The authors address a problem in training large language models: when sequences from different contexts are packed together, current sequence parallelism methods cannot correctly compute causal attention, so tokens from one sequence can attend to tokens from another. They introduce a Sequence-Aware Parallelism algorithm that uses just-in-time compilation to optimize communication and computation across multiple device groups. This algorithm is incorporated into a Hierarchical Sequence-Aware Parallelism framework that manages memory and communication overhead more efficiently. Their experiments show that this approach outperforms existing sequence parallelism methods on multiple metrics.
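To make the cross-contamination problem concrete, here is a minimal sketch in plain Python. It assumes each packed row is described by a list of per-document token counts (`seq_lens`, a name chosen for this example) and contrasts a plain lower-triangular causal mask with a document-aware one, in which a token may attend only to earlier tokens from its own document. The helpers are illustrative and are not taken from the paper.

```python
from typing import List


def packed_causal_mask(seq_lens: List[int]) -> List[List[bool]]:
    """Document-aware causal mask (illustrative helper, not from the paper).

    mask[q][k] is True iff query position q may attend to key position k:
    k must not lie in the future and must belong to the same packed document.
    """
    # Label every packed position with the index of the document it came from.
    doc_ids = [d for d, n in enumerate(seq_lens) for _ in range(n)]
    total = len(doc_ids)
    return [
        [k <= q and doc_ids[k] == doc_ids[q] for k in range(total)]
        for q in range(total)
    ]


def naive_causal_mask(total: int) -> List[List[bool]]:
    """Plain lower-triangular causal mask that ignores document boundaries."""
    return [[k <= q for k in range(total)] for q in range(total)]


if __name__ == "__main__":
    lens = [3, 2]  # two documents packed into one 5-token row
    packed = packed_causal_mask(lens)
    naive = naive_causal_mask(sum(lens))
    for p_row, n_row in zip(packed, naive):
        print("".join("1" if m else "." for m in p_row), "vs",
              "".join("1" if m else "." for m in n_row))
```

For `lens = [3, 2]` the first three rows agree, but in the last two rows the naive mask lets the second document's tokens attend to all three tokens of the first document; those six entries are exactly the contamination that the document-aware mask removes.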
sequence parallelism, causal attention, packed sequences, large language models, JIT compilation, NCCL, tensor communication, memory management, hierarchical framework, fine-tuning
Authors
Songxin Zhang, Zejian Xie, Zhuoyang Song, Cong Lin, Junyu Lu, Jiaxing Zhang, Bingyi Jing
Abstract
In this paper, we aim to combine the advantages of existing sequence parallelism paradigms in a stronger sequence parallelism framework while overcoming their drawbacks, the most serious of which is the inability to correctly compute causal attention on hybrid-context packed sequences. Packing sequences is a practical technique for efficiently pretraining and fine-tuning large language models, but it causes a cross-contamination problem in attention computation, where tokens from one packed sequence can attend to tokens from another. This problem can be solved effectively when no parallelism is applied along the sequence-length dimension. Under sequence parallelism, however, existing approaches either ignore hybrid-context sequences or support them only by sacrificing and limiting the degree of parallelism. To address this, we propose an efficient Sequence-Aware Parallelism algorithm that overcomes the obstacles of intensive tensor transmission and partial attention computation across multiple device groups. Our algorithm uses JIT (Just-In-Time) compilation to optimize the communication strategy of all device groups at the NCCL level. Further, we integrate existing sequence parallelism paradigms into a Hierarchical Sequence-Aware Parallelism framework that benefits from our sequence-aware algorithm. We additionally describe how the memory and communication overhead of the hierarchical framework is managed to optimize its performance. Through multiple experiments, we demonstrate that our proposed approach outperforms other state-of-the-art sequence parallelism approaches on multiple metrics.
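The abstract does not specify how the Sequence-Aware Parallelism algorithm decides which tensor transfers and partial attention computations to perform. As a hypothetical illustration only (not the authors' algorithm), the sketch below assumes the packed row is split into contiguous, near-equal shards across `world_size` ranks, as in ring-style sequence parallelism, and enumerates which (query rank, key rank) block pairs contain at least one unmasked causal score once document boundaries are respected. The helper names `shard_bounds` and `needed_pairs` are invented for this example.

```python
from typing import List, Set, Tuple


def shard_bounds(total: int, world_size: int) -> List[Tuple[int, int]]:
    """Split positions [0, total) into world_size contiguous, near-equal shards.

    Hypothetical helper: assumes a simple contiguous sharding of the packed row.
    """
    base, extra = divmod(total, world_size)
    bounds, start = [], 0
    for rank in range(world_size):
        end = start + base + (1 if rank < extra else 0)
        bounds.append((start, end))
        start = end
    return bounds


def needed_pairs(seq_lens: List[int], world_size: int) -> Set[Tuple[int, int]]:
    """Return the (query_rank, key_rank) pairs that contribute any unmasked score.

    Hypothetical helper: a pair is needed iff some query q on query_rank and
    some key k on key_rank satisfy k <= q and belong to the same packed document.
    """
    doc_ids = [d for d, n in enumerate(seq_lens) for _ in range(n)]
    bounds = shard_bounds(len(doc_ids), world_size)
    pairs = set()
    for qr, (qs, qe) in enumerate(bounds):
        # Causal attention: only key shards at or before the query shard matter.
        for kr, (ks, ke) in enumerate(bounds[: qr + 1]):
            if any(
                doc_ids[k] == doc_ids[q]
                for q in range(qs, qe)
                for k in range(ks, min(ke, q + 1))
            ):
                pairs.add((qr, kr))
    return pairs


if __name__ == "__main__":
    # Four equal-length documents that happen to align with the shard boundaries.
    print(sorted(needed_pairs([4, 4, 4, 4], 4)))  # [(0, 0), (1, 1), (2, 2), (3, 3)]
    # One long document: every causal block pair is needed.
    print(len(needed_pairs([16], 4)))  # 10
```

In the aligned case only the diagonal blocks carry work, so a sequence-aware scheme could skip 6 of the 10 block pairs that a document-unaware causal scheme would process; with a single long document nothing can be skipped. The abstract does not say whether or how the paper's JIT-compiled NCCL strategy exploits this kind of sparsity.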