腾讯混元Stem稀疏注意力算法:长文推理加速新SOTA
摘要
面对一份数万字的长文档,当用户将全部内容输入大模型并按下发送键后,光标迟迟不动的
面对一份数万字的长文档,当用户将全部内容输入大模型并按下发送键后,光标迟迟不动的等待期,正是预填充(Prefill)阶段。
此阶段的性能瓶颈,根源仍在于Transformer自注意力机制的二次方计算复杂度:输入序列越长,预填充耗时呈平方级攀升。块级稀疏注意力(Sparse Attention)是目前主流的破解路径——仅聚焦“关键”token,大幅削减冗余计算。
然而,从算法到算子,现有方案存在明显缺陷。算法层面,要么对所有位置“一刀切”分配相同的稀疏预算,忽略因果架构中初始token的递归依赖特性;要么单纯依赖注意力分数筛选token,忽视Value向量的实际信息承载量。算子层面则存在隐性门槛:即便稀疏模式再精巧,若底层算子无法高效跳过被丢弃的块,加速效果必然大打折扣。
腾讯混元AI Infra团队提出的Stem方案,已被机器学习顶会ICML-26收录。该方法从“因果信息流”视角重构块级稀疏逻辑,凭借Token位置衰减(TPD)与输出感知度量(OAM)两项创新,仅消耗25%的算力即可逼近稠密注意力精度。配套的HPC算子库则将理论加速比真正转化为端到端实测性能。

一、Stem算法:从“信息流”重新定义稀疏注意力
1. 核心洞察:初始token是信息流的“树干”
Stem这一命名,取自“树干”的隐喻。在因果注意力架构中,序列初始位置的token如同树的主干,支撑着后续所有信息传递的根基。




2. 实战验证:Stem全栈加速效果
理解了“初始token为何关键”,你或许会追问:这一发现投入真实生产环境后,能交出怎样的数据?
我们并未止步于学术基准测试,而是直接将Stem嵌入腾讯混元Hy3 preview(W8A8-FP8)的vLLM推理框架,配合HPC团队优化的Stem算子,端到端测量首字延迟(TTFT)与模型精度。这意味着Stem不仅在BF16学术基线上可运行,更需在量化后的工业级模型中保持稳定表现。
至于Stem与其他稀疏算法在开源模型上的速度与精度对比,论文已给出完整结论,此处不再复述。下面直接展示Stem在Hy3 preview(W8A8-FP8)上更贴近生产环境的落地数据。
2.1 首字加速比

2.2 模型精度

3. 揭秘:Stem凭什么又快又准?
数据已然佐证,核心疑问在于:Stem如何在削减75%计算量的同时,仍维持精度不降?
答案隐藏在两个看似简单、却被以往研究长期忽略的细节里:预算在何处倾斜?筛选token时依据什么?
Stem的两大核心创新——Token位置衰减(TPD)与输出感知度量(OAM)——分别给出了回答。
3.1 Token位置衰减策略(TPD):重新分配预算,而非增加预算


关键在于:TPD并未提升总计算量,而是在相同预算下重新调配资源,将算力聚焦于信息流的关键节点。


3.2 输出感知度量(OAM):不只看“路由分数”,更看“信息贡献”
预算分配策略确定后,下一步是:在预算范围内,选择哪些token?
传统方法仅以注意力分数(Query与Key的点积)为依据,但Stem揭示了一个常被忽略的事实:分数高并不等于实际贡献大。某个token可能与Query高度相关,但若其Value向量幅值趋近于零,则对输出几乎无贡献;反之,分数中等但Value幅值巨大的token,才是真正的“信息富矿”。
注意力的本质是加权求和,token的真实贡献等于路由概率与信号幅值的乘积。Stem据此提出输出感知度量(OAM):
![]()

至此,Stem算法层面的全貌已然清晰:TPD决定“在何处倾注预算”,OAM决定“在该处选择哪些token”,一前一后,将稀疏注意力从“一刀切”升级为“依信息流动态分配稀疏预算”。
但算法选型再精准,最终能否在GPU上实现3.6倍加速,仍取决于底层算子的配合度。下一节,我们将介绍HPC算子。
二、算子优化:HPC-Stem + HPC-BSA
1. 现有算子实现的挑战
块级稀疏注意力的计算分为两个阶段:评估选块(评估每个block的重要性并记录结果)与稀疏执行(按评估结果跳过不重要的block,仅对选中block执行注意力计算)。在现有生态中,两个阶段均存在瓶颈。评估选块流程中,主流方法依赖softmax归一化,导致额外的gather操作和FP8 GEMM精度误差,需维护大型中间张量(128K下可达16 GB)。稀疏注意力计算过程中,现有开源BSA算子需动态判断跳过哪些块,引入显著的跳块开销。
2. HPC优化:我们做了什么
针对上述两个瓶颈,我们分别设计了两大核心算子:HPC-Stem将评估选块流程加速数十倍,HPC-BSA面向Hopper架构,将稀疏算法的理论加速比真正落地。
2.1 Stem算子优化
HPC-Stem的评估流程分为OAM评分与TPD选块两步。OAM评分方面,我们发现其纯加法的度量结构(不依赖softmax和gather)支持一项关键数学简化:原始实现中对采样后的Q/K做全量矩阵乘法产生的中间张量,可等价转化为先预计算Q和K各自的紧凑block级别表示,再利用一次标准GEMM直接获得全部block评分,计算量降低约64倍,中间张量完全消除。TPD选块方面,我们将预算生成与选块排序的逻辑融合为单一算子,显著提升评估速度。
2.2 BSA算子优化
HPC-BSA针对Hopper架构从零设计,采用数据搬运与计算并行的流水线架构,原生支持vLLM的Paged KV Cache与FP8量化(计算吞吐翻倍)。

在此基础之上,块级稀疏的支持自然融入。如上图所示,经过分页的KV Cache本就非连续,需逐页加载。对于稀疏场景,kernel在处理每个Q分块时,先将对应的block mask缓存至片上高速存储,并在线构建出需要计算的KV分块列表。内层循环仅遍历这些有效分块,完全避免逐次判断跳过带来的额外开销。被跳过的分块在数学上等价于注意力分数全为负无穷,不影响softmax的正确性,计算逻辑无需任何修改。

如上图所示,相较于MIT原版算子中逐步查询索引、计算偏移的跳块流程,HPC-BSA将稀疏的判断与筛选完全前置至循环之外,内层计算路径与稠密Attention几乎一致,实现近零开销的块级跳过。
3. Benchmark:算子级性能
我们对HPC-BSA算子进行了性能测试,以HPC-Dense (FP8) 与 FlashAttention V3 (FP8) 作为稠密基线,以MIT-BSA (BF16) 和 FlashPrefill-BSA (BF16) 两个主流开源实现作为稀疏对照。
结果揭示了三个关键发现。第一,HPC-BSA的延迟与计算密度呈近乎完美的线性关系:50%稀疏度下延迟约为稠密基线的一半,80%稀疏度下仅为约五分之一,跳块机制的额外开销控制在2.5%以内,算法的理论加速比几乎被完整转化为实测性能。第二,HPC-BSA相比MIT开源的BSA算子(MIT-BSA)在全稀疏度范围内稳定保持约3倍加速,该收益来自FP8计算吞吐优势与Hopper架构优化的叠加。第三,上述优势在8K到256K的全序列长度范围内保持稳定,展现出良好的长序列扩展性。
不同稀疏度(block sparse ratio)下,BSA算子延时的变化:

不同稀疏度下HPC-BSA (FP8) 的延迟与稠密基线对比。标注百分比为HPC-BSA延迟占HPC-Dense的比例。淡红虚线为理论的最快速度上限(HPC-Dense × density)。
64K输入下,不同稀疏度(block sparse ratio)下HPC-BSA算子与主流开源BSA算子加速比。HPC-BSA相比MIT-BSA (BF16) 在全稀疏度范围内稳定保持约3.1倍加速,相较于FlashPrefill (BF16) 稳定在2.1倍加速左右。

75%稀疏度下,不同序列长度下HPC-BSA相比主流开源BSA算子的加速比,相较于两个基线同样保持了较稳定的效果。

三、总结与展望
来源:互联网
本网站新闻资讯均来自公开渠道,力求准确但不保证绝对无误,内容观点仅代表作者本人,与本站无关。若涉及侵权,请联系我们处理。本站保留对声明的修改权,最终解释权归本站所有。