#计算机体系结构

汉松
1周前
从零实现 vLLM 的第三篇文章,我们来了解如何加速 Attention 计算,学习 FlashAttention 的原理。 要理解 FlashAttention 的巧妙,我们必须先理解传统注意力机制的“笨拙”之处。 匹配度计算(QK):你(Query)拿着一个“科幻小说”的主题清单,去比对图书馆里成千上万本书的标签(Key),得出一个巨大的“匹配度”分数表。 权重分配(Softmax):你将这张分数表转化为百分比,告诉你应该投入多少“注意力”到每一本书上。 内容加权(AV):最后,你根据这些百分比,将所有书的内容摘要(Value)融合,得到一份为你量身定制的、关于“科幻小说”的综合信息。 这个流程在理论上无懈可击,但在实际的硬件执行中,却隐藏着一个致命的性能瓶颈。 想象一下GPU的内存结构:它有一小块速度飞快的“片上内存”(SRAM),就像你手边的工作台;也有一大块容量巨大但速度较慢的“全局内存”(DRAM/HBM),如同一个需要长途跋涉才能到达的中央仓库。 传统的注意力计算,就像一个效率极低的工匠。他在工作台(SRAM)上完成第一步,计算出那张巨大的“匹配度”分数表后,并不直接进行下一步。相反,他必须先把这张巨大的、还只是“半成品”的表,辛辛苦苦地运送到遥远的中央仓库(DRAM)存放。接着,为了进行第二步Softmax计算,他又得从仓库把这张表取回来。计算完成后,得到的“注意力权重”表,又是一个半成品,他再次将其送回遥远的仓库。最后,为了完成第三步,他需要同时取回“权重”表和所有书籍的“内容”,才能在工作台上完成最终的融合。 这个过程中,真正的计算(点积、Softmax)或许耗时并不长,但来回搬运这些巨大中间产物(匹配度矩阵和注意力矩阵)的时间,却成了无法忍受的开销。 这就是I/O瓶颈——当序列长度N增加时,这些中间矩阵的大小会以 N 平方 的速度急剧膨胀,频繁的读写操作会让GPU的大部分时间都浪费在等待数据上,而非真正的计算。 FlashAttention的革命:合并工序,一步到位 FlashAttention的作者们洞察到了问题的本质:我们需要的只是最终的结果O,中间过程的矩阵其实根本不必“留档”。 于是,他们进行了一场工作流程的革命。他们没有发明新的工具或公式,而是彻底改造了生产线,将三个独立的工序融合成一个在高速工作台(SRAM)上一气呵成的“超级工序”。 这场革命的核心武器有两个:分块(Tiling) 与 在线Softmax(Online Softmax)。 1. 分块处理: FlashAttention不再试图一次性处理所有书籍(整个K和V矩阵)。而是像一位聪明的工匠,把任务分解。他每次只从仓库中取一小批书的标签(K块)和内容(V块)到他的工作台上。 2. 在线Softmax的魔法: 这是整个流程中最精妙的部分。传统的Softmax需要“总览全局”才能计算,这也是为什么它难以被分块的原因。但FlashAttention通过一种巧妙的递推算法,实现了“在线”更新。 想象一下,工匠在处理完第一批书后,会得到一个临时的、局部的结果,并记录下两个关键的“全局统计数据”:到目前为止见过的最高匹配分(m)和当前结果的归一化因子(d')。当第二批书的数据被取到工作台上时,他不需要回头看第一批书的细节。他只需利用新一批书的数据和之前存下的那两个“全局统计数据”,就能计算出一个更新后的、融合了前两批书信息的新结果,并再次更新这两个统计数据。 这个过程不断重复,每一批新的K/V数据块都被加载到高速SRAM中,与Q进行计算,然后用来迭代更新最终的输出O以及那几个关键的统计量。自始至终,那张庞大的、完整的注意力分数矩阵从未在任何地方被完整地构建出来。 它就像一个在计算过程中短暂存在的“幽灵”,用完即逝,从而彻底消除了对慢速全局内存的读写瓶颈。 FlashAttention的成功,给我们带来了远超于算法本身的启示。它证明了,在AI的“摩天大楼”越建越高的今天,地基:计算机体系结构,它的重要性从未改变。 它的巧妙之处,不在于发明了更复杂的数学公式来拟合数据,而在于深刻理解了硬件的工作原理,并用最经典、最基础的计算优化思想:减少内存访问,去解决一个看似前沿的AI问题。 一个为 Attention 计算带来极大加速的技术,其内核只不过是一场计算机体系结构的大师级实践课。它证明了,最深刻的优化往往不是发明新事物,而是熟练掌握最基本原则,并应用于新的硬件之上。 山川地貌会变,但万有引力定律亘古不变。 对技术细节感兴趣的可以阅读公众号的原文,算法公式比较抽象,我画了很多图来辅助理解,链接在回复中。