AI 和 LLM 的进步通常归因于三个方面的持续改进:模型、数据、计算。三者互相关联。要跑起那些参数量庞大的模型,就需要足够的计算资源来支撑。Llama 3 最大的模型超过 4000 亿参数在 16000 块 GPU 上训练了数周乃至数月,优化计算意味着在更低的成本下训练更大的模型。
本文将介绍 GPU 的核心特性,并据此讨论如何设计更快的算法。
GPU 与 CPU 的区别CPU 的优化目标是单任务延迟,尽可能快地完成一个任务然后转向下一个,这对通用计算是非常合理的。但是GPU 则不同,它优化的是吞吐量追求的是同时完成多个并行任务。打个比方:CPU 像一个能力极强的工人,GPU 像一群普通工人同时干活。在 LLM 训练这种大规模并行处理场景下GPU 的架构天然占优。
继续用工厂来打比方。GPU 可以看作一个庞大的工厂城镇。城镇中有多个"工厂集群"(技术上叫流式多处理器,SM),每个集群包含多个工厂(流式处理器,SP)和一个小仓库(共享内存)。整个城镇里还有一个全局仓库(DRAM),离各集群更远但容量大得多。

类比虽然简化但说明了 GPU 中一条核心:集群内的小仓库访问速度远快于全局仓库,代价是容量小得多。
全局仓库的运输通道到底有多慢?过去 20 年间,硬件浮点运算能力(对应工厂车间的加工速度)提升了 60000 倍,DRAM 带宽只提升了 100 倍,互连带宽更是只有 30 倍。

过去的瓶颈在计算,但是现在的瓶颈在内存带宽。既然数据搬运才是真正的瓶颈,减少搬运次数和搬运量就是让 GPU 跑得更快的关键。以下五个技巧,都围绕这一思路展开,来自 CS336 课程。
技巧 1:低精度计算矩阵乘法中,数字精度是可以选择的。精度越高,存储一个数字所需的字节越多:9.327595 比 9.33 占的空间大。用低精度数字意味着搬运的"货物"更少,在拥堵的"道路"上花费的时间也更短。
这意味着用 fp16 代替 fp32,但并非训练的所有阶段都需要低精度,只需在数据搬运阶段降到 fp16 即可。具体做法是:输入以 fp16 格式传入,矩阵乘法在 32 位精度下完成(计算并非瓶颈,而且高精度可以防止舍入误差的逐步累积),输出再降回 fp16 用于传输。

回到工厂类比:道路拥堵(下图红线),所以进出工厂的箱子越小越好。工厂内部空间充裕,可以在大空间中完成加工,加工完成后再打包成小箱子运出。
技巧 2:算子融合假设工厂有三步操作:正方形变圆形,圆形变三角形,三角形变星形。如果每完成一步就把半成品送回仓库再取回来做下一步,那来回搬运的次数非常多。

算子融合的做法是把多步操作在工厂内一次性完成,省去中间产品的反复搬运。

实现方式有两种:手写低级代码控制融合细节,或者直接用 torch.compile 自动完成优化。
技巧 3:重计算这个场景稍微复杂一些,假设工厂从仓库取了一个正方形,依次加工为圆形、三角形、星形。星形被送回仓库供后续使用。但到了最终步骤,四种形状全部要用——正方形、圆形、三角形、星形。工厂内部存不下东西,所有存储必须依赖仓库。
安排生产线有两条路:
选项 1:加工过程中把圆形和三角形也送回仓库保管。需要的时候直接取回。
选项 2:不保存中间形状,丢掉就丢掉。需要的时候从正方形重新加工一轮。
选项 1 省了重新加工的电力,但仓库搬运量增大。选项 2 搬运量小,但要额外消耗算力。这是一个内存与计算之间的权衡。
既然瓶颈在道路拥堵而非车间产能,重计算(选项 2)是更合理的选择:重新加工成本低,但从仓库搬运的成本可能高出几个数量级。用算力换内存带宽,划算。
技巧 4:内存合并访问仓库有个特点:货物按板条箱整箱发出。工厂请求任何一件物品,仓库都会把整个板条箱送过来。优化的要点在于:把需要的物品尽量集中在同一个箱子里。
假设每箱 4 件,工厂需要 8 件。如果这 8 件集中在 2 个箱子里,取 2 箱就够了。如果散落在 8 个箱子中,就得搬 8 箱——搬运成本翻了四倍。
技术上,DRAM 以"突发模式"读取,每次读取返回一段连续字节。即使处理器只需要其中一个地址的数据,整个突发段也会被送过来。当所有线程的访问地址落在同一个突发段内时,只需一次 DRAM 请求,这种情况称为完全合并访问。
一个直接的推论:把维度(比如词汇表大小)对齐到 64 的倍数会带来可观的速度提升。

原因很简单:分块操作(见下一个技巧)需要沿突发段的边界读取数据,如果分块边界与突发段不对齐,读取次数会急剧增加。
技巧 5:分块分块的核心思想:把大矩阵切成小块,加载到共享内存(集群内的小仓库)中,避免反复访问全局内存。
以两个 4x4 矩阵 A 和 B 的乘法为例,结果是 4x4 矩阵 C。计算 C 的某几个元素时,需要在 A 和 B 矩阵上多次跨行/跨列读取,每次读取都要访问全局内存。

分块的做法是将 A 和 B 各切成四块。小块可以整块加载到共享内存中。先加载红色块,计算部分和(图中橙色部分):

接着加载下一组块,继续累加部分和。总计算量不变,但每一步都在共享内存中完成而非反复访问全局内存,节省的时间相当可观。
FlashAttention有了上面五个技巧做铺垫,可以来看 FlashAttention 了。
先简要回顾注意力机制。权重矩阵将隐藏向量投影为 Q、K、V,然后对每个词的 q 和 k 向量求点积(等价于 Q × K.T 的矩阵乘法),得到原始注意力分数——即每个查询词对各个键词的关注程度。对原始分数做 softmax 归一化,使其加和为 1。
数值稳定性方面,取指数之前先减去最大值。e¹² 已经是 162,755,超出 fp16 的上限 65,504,直接计算会溢出。减去最大值不改变 softmax 结果,但规避了溢出(详见附录)。

归一化后的 softmax 分数与每个词的"值"向量相乘、求和,得到最终的注意力输出。

回到 FlashAttention。Q 和 K 相乘产生一个 N × N 矩阵(N 为序列长度)。当上下文窗口很大时,这个矩阵无法整个放入共享内存。
解决方案是沿 N 维度分块。比如上下文窗口 1028,按 64 切块,每块可以载入共享内存。这样仍有完整的点积结果(无需计算部分和),只是逐块填充结果矩阵。

分块本身是标准操作,棘手的部分在于 softmax 和后续的值向量加权求和。计算 softmax 通常需要整行数据来做归一化,而访问整行意味着要回全局内存取数据。FlashAttention 的突破在于"在线 softmax"——softmax 计算和值向量加权求和可以在块内一次性完成,无需看到全行数据。关键条件是最终操作是加权求和,这给了逐块修正的数学余地。
下面用一个例子来说明。假设 QK 矩阵乘法产生了六个原始分数,表示某个查询词对六个其他词的关注度。常规做法是一次性对六个分数做 softmax 再与六个值向量加权求和,得到 A:

但遍历整个长度 N 的序列在块内放不下。于是按"在线"方式进行:将六个分数分为三个块(每块 2 个元素),逐块处理。第一个块中只有两个原始分数,先基于这两个值做计算:

这一步不做归一化。虽然可以用当前的和(1+0.0082)归一化,但后续块会改变总和,到头来还得修正。所以更好的做法是记录归一化分母的累积值最后一步统一归一化。
进入第二个块。目标是得到与一次性看到所有四个值相同的结果。四个值的全局最大值是 12,第一个块需要把自己的最大值传递过来。累积的加权和与归一化分母也要一并传递。

到目前为止,如果只有四个值,取 A_(1+2) 除以归一化总和 1.3098 就能得到最终结果。
最后一个边界情况是:新块出现了更大的最大值。第三个块的最大值从 12 变成了 13,但之前的 A_(1+2) 是按 max=12 算的。要让结果与一次性看到全部六个值一致,就需要修正之前的计算——将所有旧指数乘以 e^(-1)(即 e^(12-13)),补偿最大值的变化。
不需要逐个回去修正每个指数值,只需将 A_(1+2) 和归一化分母整体乘以 e^(-1) 即可:

最后用累积分子 A_(1+2+3) 除以更新后的分母 1.4955,得到结果。整个过程从未回访之前的块:只要跟踪最大值和归一化分母,就能逐块完成 softmax。这些操作都在共享内存中进行,不必频繁访问全局内存。
效果如何?FlashAttention 原始论文显示,在 GPT-2 上注意力计算的耗时减少了数倍。

处理大规模模型时,内存放置策略,比如尽量在共享内存中完成计算对整体性能的影响远超我们的想象。
并行计算简介以上讨论都局限在单 GPU 上,小模型没问题,但现代大型 LLM 根本装不进一块 GPU。Llama 3 用了 16K 块 GPU,核心问题变成了:如何将训练计算分配到多台机器上,再将结果汇总起来。
在展开不同的并行策略之前,先回顾训练流程。以一个 2 层神经网络为例,batch size 16,使用 Adam 优化器(为每个参数维护一阶和二阶矩估计)。

拆分计算的第一种方式是拆分数据。假设有效 batch size 为 16,但每块 GPU 内存只够放 4 个样本。单 GPU 下需要跑 4 轮前向传播来累积梯度,再做一次反向传播,即梯度累积。

数据并行的做法:把 16 个样本分给 4 块 GPU,每块拿 4 个样本,各自并行执行前向传播。问题在于如何聚合梯度。
一种方式是汇集所有激活值来计算平均 loss 再求梯度,但更聪明的做法是在每块 GPU 上各自计算 4 个样本的梯度再求和——搬运的数据量更小,数学上完全等价。梯度求和后传回各机器,分别更新本地模型。

这个操作的技术术语是 all-reduce:每台机器贡献各自的梯度,合并后每台机器都拿到结果。虽然图示中画了一个"聚合器"(灰色方框),实际的 all-reduce 实现通常是环形传递——GPU 之间互相传梯度,最终全部拿到平均值。
4 块 GPU 并行,有效 batch size 仍然是 16,速度却快了很多。但有一个效率问题:每块 GPU 都在做完整模型的更新,要维护所有参数的 Adam 状态(一阶矩和二阶矩)。
内存充裕时这不成问题。但实际上每块 GPU 里复制了完整的模型参数、梯度、主权重以及 Adam 优化器状态。Adam 的状态量是模型参数量的两倍,内存占用很大。
对于大模型,内存成为硬瓶颈。ZeRO(Zero Redundancy Optimizer)针对的就是这个问题:一组内存优化技术,在保持数据并行的前提下大幅减少每块 GPU 的内存占用。
ZeRO Stage 1
核心思想是让每块 GPU 只负责更新一部分参数。比如将每层参数分成四份,GPU 1 负责 Part 1,GPU 2 负责 Part 2,以此类推。
走一遍流程:数据仍然拆分到四块 GPU 上,每块 GPU 基于自己看到的 4 个样本计算完整的梯度——到这里还是标准的数据并行。但梯度汇总后不再发回给所有人,而是按参数分片发送:每块 GPU 只收到自己负责那部分参数的梯度。术语上叫 reduce-scatter——每人只拿到合并结果的一个切片。
各 GPU 只更新自己负责的那部分参数,也只需要保留该部分的优化器状态。更新完成后,各 GPU 把自己的参数切片分享出去,拼接成完整模型。术语上叫 all-gather——每人贡献一个切片,每人拿到完整拼接结果。

ZeRO Stage 1
整个过程可以概括为两阶段:第一阶段按数据维度拆分,各 GPU 算全参数梯度再汇总;第二阶段按参数维度拆分,各 GPU 只更新自己负责的参数切片,最后拼接出完整模型。
效果是每块 GPU 只保留一小部分优化器状态,内存节省很可观。计算量方面,reduce-scatter 加 all-gather 的总通信量与朴素数据并行中的 all-reduce 等价,没有额外开销。
ZeRO Stage 2
ZeRO Stage 2 更进一步——不仅优化器状态分片,梯度本身也要分片。
关键在于反向传播是逐层进行的。每一层的梯度算完后,立刻将不属于自己管辖的部分发送给对应 GPU 并丢弃。不需要在任何时刻存储全部层的完整梯度。
在 ZeRO Stage 1 的流程中,要改变的是这一部分:

改为逐层处理梯度,红框中的部分变成如下流程:

第 2 层的梯度算完,把不负责的部分发出去、丢弃,然后处理第 1 层,重复同样的步骤。层数多的 LLM 从中获益明显——不需要同时存储所有层的梯度。代价是逐层通信带来少量额外开销。
ZeRO Stage 3,也称为完全分片数据并行(FSDP)
ZeRO Stage 3 把分片推到了极致——连模型权重都只存各自负责的那部分。这意味着前向传播也会受到影响。
流程同样是逐层进行的。到第 1 层时,执行 all-gather,各 GPU 各出自己的权重切片,拼出完整的第 1 层。每块 GPU 用完整的第 1 层权重和各自的数据计算激活值,算完后立刻丢弃不属于自己的权重切片。第 2 层同理。
反向传播与 ZeRO Stage 2 类似,但多了一步:每层计算梯度前要先 all-gather 把完整权重拼出来(因为本地没有完整权重),算完后再丢弃非本地切片。
本质上是按需逐层从各 GPU 拼出模型,任何时候都没有一块 GPU 持有全部权重。通信开销增加了,但内存节省巨大。对于给定的 GPU 配置,ZeRO Stage 3 能训练的模型规模远超前两个阶段。
CS336 课程给出的数据:8 块 A100 80GB GPU 上,不同策略可训练的最大模型尺寸差异很大。

同样的硬件配置下,ZeRO Stage 3 能训练的模型大了很多。
不过数据并行有一个约束条件:batch size。batch size 不能小于 GPU 数量:没法给一台机器半个样本。而 batch size 越大收益越低:大 batch 降低数据噪声方差,但超过一定阈值后边际收益接近于零。batch size 的"自然上限"直接限制了数据并行的扩展规模。
模型并行除了按数据维度拆分,还可以按模型维度切分,即模型并行。这里介绍两种形式:流水线并行和张量并行。
模型并行:流水线并行
流水线并行沿深度方向切分模型,一层分配给一块 GPU。问题在于前向和反向传播都是逐层串行的——每层需要前一层的输出才能开始计算,GPU 在等待输入时空闲,形成"气泡"。

缩小气泡的方法是引入 mini-batch 级别的流水线:第二块 GPU 处理某个 mini-batch 的第二层时,第一块 GPU 可以开始处理下一个 mini-batch 的第一层。

流水线并行的优势在于内存节省,每个设备只存一层的参数以及通信模式简单,只需将激活值从一层传到下一层。这种简单的通信特性使它适合部署在跨集群等带宽较低的网络链路上。
张量并行张量并行沿宽度方向切分,把单层内的矩阵乘法分配到多块 GPU 上并行执行,各自得到部分结果后再跨 GPU 求和。概念上类似于分块运算,区别在于分块是串行处理各块,张量并行是并发处理。
通信量很大——每层都要同步激活值。节点内部 NVLink 带宽在 600-900 GB/s,跨节点互连慢 10-20 倍。实践经验表明:张量并行扩展到 8 块 GPU 以上时,收益会急剧衰减。所以通常将张量并行限制在单个节点(最多 8 块 GPU)内。
张量并行有一个独特优势:不依赖 batch size。batch size 是数据并行和流水线并行共享的约束资源,张量并行与之正交,可以叠加使用而不消耗这项资源。
组合不同形式的并行几种并行策略分别沿不同维度拆分计算:数据维度、模型深度维度、模型宽度维度。实际训练中通常是多种策略的组合。
经验法则很简单:先解决内存问题,确保模型能装进 GPU。装不下就用流水线并行、张量并行、ZeRO Stage 3 等节省内存的技术。模型能装下之后,再用数据并行等手段堆算力,加快每个 batch 的处理速度。
附录:Softmax 解释softmax 将一组原始分数变换为加和为 1 的概率分布:对每个分数取指数,然后除以所有指数之和。

以三个分数(12, 7.2, 9.1)为例:

问题在于指数值增长极快。e¹² 已经是 162,755,超过 fp16 的最大值 65,504。理论值虽然正确,但计算过程中会溢出。解决办法是将分子和分母同时除以 e^(max),等价于从所有原始分数中减去最大值:

数学上结果完全一致,但避免了溢出。可能出现下溢(值太接近零),不过下溢时 0 已经是足够好的近似。这一数学技巧被几乎所有 LLM 的 softmax 实现采用。
总结这篇文章从 GPU 架构讲到并行策略,涉及的是把模型从玩具规模拉到生产规模所必须面对的工程问题。在专业团队中,训练一个无法放入单块 GPU 的 LLM 是常态,优化训练成本也是日常工作的一部分。理解底层硬件和并行机制,是做好这些工作的前提。
https://avoid.overfit.cn/post/8b2888b82d7c40c3b60e7e8847dafc9f
by Joseph