榨干每块显存:LLM 底层显存优化
作者:toy ---
作者:toy
GPU 显存是 LLM 推理与训练的硬约束,不是软性资源。当一个 70B 参数模型以 BF16 格式加载时,光是参数本身就需要约 140GB,单张 H100 的 80GB 显存根本装不下。工程师的任务不是抱怨硬件贵,而是理解显存被什么占用、每种占用能压缩多少、压缩的代价是什么。
这篇文章从显存的物理结构出发,拆解四大占用来源,逐一分析 KV Cache 量化、CPU Offload、FlashAttention、算子融合、显存碎片管理和 OOM 治理这六条优化路径,每条路径都附带可落地的配置方法和量化数字。不是教你"用哪个参数",而是讲清楚每个设计决策背后的物理约束和工程权衡。只有理解了原理,才能在自己的系统里做出正确的组合。
一、显存为什么这么贵
GPU 显存与算力的不对称增长
过去五年,GPU 的计算峰值(FP16 TFLOPs/s)增长幅度远超显存带宽和显存容量。A100 的 FP16 峰值算力是 312 TFLOPs/s,H100 SXM5 达到 989 TFLOPs/s,提升了约 3.2×;而显存带宽从 A100 的 2 TB/s 到 H100 的 3.35 TB/s,仅增长 1.7×;显存容量从 80GB 到 80GB(SXM5 版本),完全没有增长。算力越来越强,但显存装不进更多数据,这个剪刀差是 LLM 工程优化的根本张力。
这不是偶然的工程疏忽,而是物理限制。GPU 显存(HBM,High Bandwidth Memory)是一种叠层封装的 DRAM,制造复杂度和成本远超普通 DRAM,容量密度提升缓慢。而晶体管数量每两年翻倍的摩尔定律在算力层面(通过并行化)仍然有效,但在存储密度层面已接近极限。这导致 H100 的"算力/显存比"远高于 A100,高算力 GPU 上越来越容易出现"计算等数据"的情况,而不是"数据等计算"。
矩阵乘法是 LLM 中最密集的计算,在 Hopper 架构上 GEMM 可以将算力利用率推到 75% 以上。但模型越大,越多时间花在把权重从 HBM(高带宽内存)搬到片上 SRAM,而不是真正做计算。这就是"内存墙"(Memory Wall):计算单元空等数据,显存带宽成为瓶颈。
判断一个操作是"计算密集"还是"内存密集",用算术强度(Arithmetic Intensity)衡量:FLOPs 除以内存访问字节数。H100 的算术强度峰值约 139 FLOPs/byte(989 TFLOPs/s ÷ 3.35 TB/s)。GEMM 在大矩阵时算术强度高(数据复用充分),是计算密集;单个 token 的 prefill 阶段 attention 计算、KV Cache 读取则是内存密集。理解这个区别,决定了你该优化哪一端。
LLM 显存占用的四大来源
一个运行中的 LLM 系统的显存可以被拆成四块:
模型参数是最直观的部分。每个参数在 BF16 下占 2 字节,FP32 占 4 字节。7B 参数模型 BF16 精度需要约 14GB,70B 需要约 140GB。这部分在推理阶段是静态的,加载完就不再变化。
优化器状态只出现在训练阶段。AdamW 为每个参数维护一阶矩和二阶矩,各占一份参数大小。FP32 精度下,优化器状态是参数大小的 8 倍(参数 4B + 一阶矩 4B + 二阶矩 4B,共 12B/参数,mixed precision 训练下参数本身 FP16,优化器状态 FP32)。7B 模型完整训练态需要约 7B × 16 字节 ≈ 112GB,这是从事训练的工程师最头疼的数字。
激活值(Activations)是 forward pass 的中间计算结果,backward pass 时需要它来计算梯度。激活值的大小与 batch size 和序列长度成正比,但与模型参数无关。一个 7B 模型、batch size=8、序列长度=2048 的配置,激活值可以高达数十 GB。这是 Gradient Checkpointing 要解决的目标。
KV Cache 是推理阶段特有的显存消耗,也是最难控制的部分。它随序列长度线性增长,在长上下文场景下可以超过模型参数本身的占用。第二章会详细展开。
四块里,模型参数相对固定(量化可以压缩,但要在精度和大小间取舍);优化器状态是训练专用(推理不存在);激活值通过 Gradient Checkpointing 可以大幅压缩;KV Cache 是推理场景最值得深挖的目标,因为它是动态增长的,最不容易控制,也最容易被忽视。
7B 模型全精度推理的显存计算
以 LLaMA-3-7B 为例进行精确计算:
- 参数量:7.24B
- 精度:BF16(2 字节/参数)
- 参数显存:7.24B × 2 字节 ≈ 14.5 GB
KV Cache 部分(假设 batch_size=1,seq_len=4096): - 层数:32 层 - 注意力头数:32 头(KV 头数也是 32,非 GQA 版本) - head_dim:128 - KV Cache = 2(K+V)× 32(层)× 32(头)× 128(head_dim)× 4096(seq_len)× 2(BF16) = 1,073,741,824 字节 ≈ 1 GB
单 batch、4K 序列长度的 KV Cache 约 1GB,占参数显存的 7%。但如果 batch_size=32、seq_len=32768,KV Cache 就变成约 256GB,超出一切单卡显存。这就是 KV Cache 优化为什么重要。
另一个维度是 GQA(Grouped Query Attention)的影响。LLaMA-3 采用了 GQA,n_kv_heads=8,而 n_heads=32,KV Cache 大小变成标准多头注意力的 8/32 = 25%。同样的计算图,GQA 把 KV Cache 压缩到四分之一。这是架构层面的显存优化,不需要量化就能大幅降低 KV Cache 占用。未来更激进的做法是 MQA(Multi-Query Attention),n_kv_heads=1,极端情况下所有 heads 共享一套 KV,KV Cache 接近于零,代价是模型表达能力下降,适用于不依赖多头多样性的场景。
二、KV Cache:最占显存的推理中间产物
KV Cache 的结构
Transformer 的注意力机制每层都需要计算 Q(Query)、K(Key)、V(Value)三个矩阵。在自回归生成时,每个新 token 生成时,前面所有历史 token 的 K 和 V 都需要重新计算,或者可以把它们缓存起来,下次直接读取。KV Cache 就是这个缓存。
为什么只缓存 K 和 V,不缓存 Q?因为 Q(Query)是由当前生成的新 token 计算出来的,每个位置的 Q 不同,无法复用。K 和 V 代表"历史上下文",在 Decoder-only 架构(单向注意力)下,历史 token 的 K/V 一旦计算完就固定不变,新 token 的出现不影响旧 token 的 K/V。这是 KV Cache 成立的数学前提:因果性(causality)保证了历史 KV 的不变性。
KV Cache 是无损的。历史 token 的 K、V 是确定性计算结果,缓存和现场重算完全等价(不考虑数值误差时)。换言之,KV Cache 是用显存换算力的交易,且在标准推理流程中不存在精度损失。只有引入量化时才会有精度权衡。
每层 Transformer 都维护独立的 K/V 缓存。一个 32 层的模型,KV Cache 就是 32 对矩阵。每个矩阵的形状是 [batch_size, seq_len, n_heads, head_dim](或 GQA 下的 n_kv_heads),随着 prefill 阶段完成而初始化,随着生成阶段持续增长。prefill 阶段是"一次性"的:输入 prompt 的所有 token 并行处理,生成完整 KV Cache。decode 阶段是"逐 token"的:每次生成一个新 token,KV Cache 增长一行。这两个阶段的计算特征截然不同,对显存的影响方式也不同。
KV Cache 大小公式
KV_Cache_Bytes = 2 × batch_size × seq_len × n_layers × n_kv_heads × head_dim × dtype_bytes
其中:
- 2 代表 K 和 V 各一份
- dtype_bytes 在 BF16 下为 2,FP8 下为 1,INT4 下为 0.5
用 LLaMA-3-8B 的实际参数验证(GQA 版本,n_kv_heads=8):
KV = 2 × 1 × 8192 × 32 × 8 × 128 × 2
= 2 × 1 × 8192 × 32768
= 2 × 268,435,456 字节
≈ 0.5 GB(batch=1, 8K context, BF16)
当 batch_size=32 时,这个数字线性扩展到 16GB。在高并发服务场景下,100 个并发请求、8K 上下文就需要约 50GB,接近一张 H100 的全部显存。
KV Cache 随序列长度的增长
KV Cache 的增长是线性的(O(N) with N=seq_len),但"线性"在长上下文场景下依然是毁灭性的。
| seq_len | LLaMA-3-8B KV Cache(batch=1, BF16) |
|---|---|
| 4K | ~0.25 GB |
| 8K | ~0.5 GB |
| 32K | ~2 GB |
| 128K | ~8 GB |
| 1M | ~64 GB |
1M context 下光 KV Cache 就占满整张 H100。这是 KVQuant(NeurIPS 2024)论文演示单张 A100-80GB 服务 LLaMA-7B 1M context 时必须引入 3-bit 量化的原因。
Prefix Caching(前缀缓存)是另一个值得单独讲的优化点。在实际服务场景中,大量请求共享相同的系统提示(System Prompt)或 few-shot 示例。如果这部分 KV Cache 能在多个请求间共享,就可以避免重复的 prefill 计算和 KV 存储。vLLM 和 SGLang 都支持基于 radix tree 的 KV Cache 共享:对相同前缀的请求,共享同一份 KV Cache 副本,只为不同的后缀部分维护独立存储。在 System Prompt 占总长度 80% 以上的场景(如 RAG、long system prompt 场景),这可以减少 60-80% 的 prefill 计算和 KV Cache 存储。
PagedAttention 对 KV Cache 的改造
标准 KV Cache 的另一个问题不是大小,而是碎片化。传统实现为每个请求预分配一段连续显存,按最大可能序列长度分配(否则序列增长时可能无法扩展)。这导致两种浪费:
- 内部碎片:预分配 2K 但请求只生成了 500 个 token,1500 个 slot 空置
- 外部碎片:不同长度的请求释放后留下大小不一的碎片,无法组合成大块
PagedAttention(vLLM,SOSP 2023)借鉴操作系统的虚拟内存分页思想,把 KV Cache 拆成固定大小的 block(page),block 在物理显存上可以不连续,通过 block table 做地址翻译。请求增长时按需分配新 block,释放时归还整块 block。实测相比 FasterTransformer/Orca 实现 2-4× 的吞吐量提升,KV Cache 浪费接近零。
这与虚拟内存的类比非常精准:OS 用页表让进程看到连续地址空间,PagedAttention 让 attention 算子看到连续 KV 序列,物理上却允许分散存储。
PagedAttention 的 block 设计也带来了 KV Cache 共享的可能性:两个请求如果有相同的前缀,可以让它们的 block table 指向同一批物理 block(引用计数管理),实现 Copy-on-Write 式的共享。这就是 Prefix Caching 的底层机制。一个 block 同时被多个请求引用时,标记为只读;某个请求需要修改时,先复制一份再写。这把 vLLM 的 KV Cache 设计从"每请求独立缓冲区"升级为"细粒度共享内存池",在 System Prompt 大量重复的场景下效果明显。
三、KV 量化:用精度换显存
INT8 KV Cache 量化
KV Cache 量化的核心逻辑是:把 BF16(2 字节)的 K/V 压缩成 INT8(1 字节)或 FP8(1 字节),显存减半,代价是接受一定的精度损失。量化粒度直接决定精度损失的大小。
per-tensor 量化是整个 KV Cache 共享一个 scale 值。实现最简单,vLLM 默认 FP8 量化就是 per-tensor(scale=1.0,无需校准数据)。精度损失最大,适合精度不敏感的场景或快速实验。
per-token 量化是每个 token 的 K/V 向量各有独立 scale。捕获了 token 间的动态范围差异,不同 token 在不同 attention head 上的数值分布差异很大,per-token 可以针对每个 token 选择最优的量化范围。精度明显优于 per-tensor,额外元数据开销极小(每个 token 一个 float32,与 token embedding 大小相比可忽略)。
per-head 量化是每个 attention head 维护独立 scale。精度最优,需要离线校准(跑一组代表性输入收集各 head 的数值分布),vLLM 通过 --quantization-param-path 支持传入预计算好的静态 per-head scales。适合精度敏感的生产部署。
FP8 KV Cache:H100 的原生支持
H100 的 Hopper 架构原生支持 FP8(e4m3 和 e5m2 两种格式)运算,这使 FP8 KV Cache 不只是"存储压缩",还可以直接参与 FP8 矩阵乘法,避免量化-反量化的额外开销。
根据 vLLM 官方博客(2026-04-22)的实测数据:在 H100 上推理 Llama-3.1-8B,将 KV Cache 从 BF16 换成 FP8 之后:
- 每 token KV Cache 存储减半(BF16 → FP8)
- inter-token latency(ITL)的斜率降至 BF16 的 54%,即在相同序列长度下,每生成一个额外 token 的延迟是原来的一半多一点
- 在并发 8、约 2 万个输入 token 的条件下,获得 14.9% 更高的输出吞吐量
精度方面,Qwen3-30B-A3B-Thinking 在 AIME25/GPQA/MATH500 上的准确率变化仅 1-2 个百分点;Llama-3.3-70B 在 128K context 下的 AUC 恢复到 BF16 基线的 97-98%。
需要注意一个 break-even 点:FP8 KV Cache 在 context 长度超过 7K token 时才比 BF16 快。在短 context(< 7K)场景下,量化和反量化的额外计算开销反而会增加延迟。对于短文本问答服务,FP8 KV 可能得不偿失。
这个 break-even 点背后的逻辑值得深究:短 context 下,KV Cache 本身就很小,显存压力不大,FP8 压缩带来的 HBM 访问减少量不足以覆盖反量化的计算开销。长 context 下,KV Cache 巨大,每次 decode 步骤都要遍历全部 KV,HBM 读取量是决定性因素,FP8 减半了 HBM 读取量,效益明显。选择 KV 量化精度时,先确认自己的实际 context 长度分布。如果 P50 context 只有 3K,开 FP8 反而会让中位延迟变差。
KV 量化的精度损失机制
量化精度损失不是随机的,有规律可循。KV Cache 中有两类"难量化"的值:
第一类是异常值(Outliers)。LLM 的注意力权重在某些 token、某些 head 上会出现极端大值,与周围值相差数十倍。这些异常值会"撑大"量化范围,导致正常值被压缩到很少的量化级别里,精度损失集中在正常范围内。INT8 有 256 个量化级,如果量化范围被异常值撑到 [-100, 100],而大部分值集中在 [-1, 1],实际上只有约 2-3 个量化级在服务大多数值,等于退化成 INT2 精度。
第二类是 RoPE 后数值分布的偏移。RoPE(旋转位置编码)对 Key 向量做旋转变换,不同位置的 Key 经过旋转后数值范围可能扩大。如果在 RoPE 之后量化 Key,量化范围需要覆盖旋转后的全部分布,精度下降。在 RoPE 之前量化可以避免这个问题,这是 KVQuant 的 Pre-RoPE Key Quantization 的设计动机。
KVQuant:把量化做到极致
KVQuant(arxiv 2401.18079,NeurIPS 2024)提出了四种互补技术,把 KV Cache 压缩到 3-bit:
Per-Channel Key Quantization:Key 矩阵在 channel(head_dim)维度上量化,不同 channel 各自一套 scale 和 zero-point。Key 的分布在 channel 维度比 token 维度更稳定,per-channel 比 per-token 更适合 Key。
Pre-RoPE Key Quantization:RoPE(旋转位置编码)会对 Key 做旋转变换,量化后的 Key 经过 RoPE 变换后数值范围可能改变,导致量化误差放大。KVQuant 在 RoPE 之前对 Key 量化,绕过这个问题。
Non-Uniform KV Cache Quantization:对不同层使用不同精度。根据 per-layer 灵敏度分析,精度敏感的层用更多 bit,不敏感的层用更少 bit,全局 bit 预算固定的情况下把精度损失降到最低。
Per-Vector Dense-and-Sparse Quantization:识别并单独存储每个向量中的异常值(outliers),剩余值用低精度量化。KV Cache 中异常值比较集中,隔离后可以用更低 bit 量化主体部分。
实验结果:在 LLaMA/Llama-2/Llama-3/Mistral 系列上,3-bit 量化的 perplexity 下降小于 0.1(Wikitext-2 和 C4 数据集)。单张 A100-80GB 可以服务 LLaMA-7B 的 1M context length,并非因为 1M KV Cache 本来能装进 80GB,而是通过 3-bit 量化把它压缩到能装下的体积。
3-bit KV Cache 是目前已知可以在精度损失(perplexity < 0.1)约束下达到的最低 bit 宽,低于此精度损失就会快速上升,无法实用。这个 3-bit 边界并非任意的,它对应 KV 分布中异常值处理完后剩余分布用均匀量化能覆盖的最低精度,低于 3-bit 就开始触碰异常值处理的底线。
量化后的反量化时机
KV Cache 量化存储,但 attention 计算时需要用原始精度(或至少高精度)。反量化的时机有两种方案:
计算前反量化:从 KV Cache 读出量化值后,立即反量化回 BF16,再做矩阵乘。好处是与标准 attention 计算兼容;坏处是 HBM 带宽省了(量化后数据小),但 SRAM 里要做额外反量化计算。
融合反量化与计算:在 FlashAttention kernel 内部,读入量化 KV、在 SRAM 里反量化、立即做 attention 运算,不把中间值写回 HBM。这是 H100 FP8 FlashAttention-3 的做法,读 FP8、片上做 FP8 矩阵乘(WGMMA 支持 FP8 输入)、累加器 FP32 输出,整个流程在 SRAM 内完成。
四、KV Offload:把 CPU 当 GPU 的扩展内存
CPU Offload 的基本原理
当显存装不下 KV Cache 时,一个自然的想法是把"不活跃"的 KV 块搬到 CPU 内存(DRAM)。现代服务器通常有 512GB 到 2TB 的 DRAM,是 GPU 显存的 10-25 倍。理论上,可以把历史 KV Cache 放在 CPU 内存里,等需要时再搬回 GPU。
调度策略决定了 offload 的效率。最朴素的策略是 LRU(Least Recently Used):把最近没被访问的 KV blocks 换出去。但 LLM 推理的访问模式与 LRU 的假设不完全匹配,一个长序列请求会在每个 decode 步骤访问所有历史 KV,不存在"最近未访问"的 block;而短请求每次请求结束后就可以丢弃所有 KV。更好的策略是基于请求状态的感知调度:当一个请求完成了大部分 decode(剩余生成 token 数量少),它的优先级低,可以被抢先换出;而刚到来的请求 prefill 刚完成,正在 decode,优先级高,KV 留在 GPU 上。
PagedAttention 的 block 机制天然支持这一点:block 是最小的管理单元,既可以在 GPU 显存里,也可以 offload 到 CPU 内存。vLLM 的 swap_space 参数就是配置 CPU 侧 KV Cache 空间的。
PCIe 带宽瓶颈
现实情况比理论要悲观得多。H100 SXM5 的 HBM 带宽约 3.35 TB/s,而 PCIe 5.0 x16 仅约 63 GB/s,差距约 53×。
这个数字意味着什么?假设生成一个 token 需要从 32 层 KV Cache 里读取 32 × 2(K+V)= 64 个矩阵块,每块 128 × head_dim 字节。如果这些 KV 在 GPU 上,3.35 TB/s 的带宽可以毫秒级完成;如果在 CPU 上,63 GB/s 的 PCIe 传输速度成为硬瓶颈。
实测数据更直观:一项研究发现,CPU offload 场景下 99% 的延迟花费在 GPU-CPU 传输上,GPU 算力利用率仅有额定 TDP 的 28%。GPU 大部分时间在空等数据从 CPU 送来,计算资源浪费,延迟却急剧上升。
Offload 的适用场景
CPU offload 有明确的适用边界:
CPU offload 适合低 QPS(每秒请求数少)、高吞吐要求、延迟不敏感的离线处理。典型场景是批量文档处理,每个请求几分钟也可接受,但显存不够装所有请求的 KV Cache。这时 offload 让你能跑更大的 effective batch size,总吞吐提升。
不适合低延迟实时服务。如果 P99 延迟要求是 500ms,PCIe 传输可能单次就要 100ms 以上,无论如何调度都无法满足。
GH200 Grace Hopper 超芯片是个例外:NVLink-C2C 接口提供约 900 GB/s 的 CPU-GPU 带宽,相比 PCIe H100 提升约 7×。在 GH200 上,CPU offload 的延迟代价从 53× 下降到约 3.7×,场景覆盖范围明显扩大。NVLink-C2C 的带宽让 KV Cache CPU Offload 从"只适合低 QPS 离线场景"扩展到"可以在中等 QPS 在线场景下使用",前提是你有 GH200,这不是所有团队都能具备的硬件条件。
一个经常被忽视的 offload 替代方案是磁盘级 KV Cache:把更冷的 KV Cache 放到 NVMe SSD(带宽约 7 GB/s)。PCIe 带宽只有 63 GB/s 还被嫌弃,SSD 的 7 GB/s 更是只适合极冷数据(比如 prefill 结果的跨请求缓存),而不是 decode 阶段的实时访问。但对于"同一份 System Prompt 被频繁使用"的场景,把 prefill 的 KV Cache 持久化到 SSD,下次同样请求直接加载,可以完全跳过 prefill 计算。这是 LLM 服务中的"disk-backed KV Cache"方向,目前属于前沿研究,部分框架已有实验性支持。
vLLM 的 CPU Offload 配置
# vLLM 启动时配置 CPU KV Cache 交换空间
from vllm import LLM, SamplingParams
llm = LLM(
model="meta-llama/Llama-3.1-8B-Instruct",
# swap_space: CPU 侧 KV Cache 空间,单位 GB
# 当 GPU 显存不足时,vLLM 会把低优先级请求的 KV block 换出到此空间
swap_space=16, # 16GB CPU DRAM 作为 swap
gpu_memory_utilization=0.90, # GPU 显存利用率上限
max_model_len=32768, # 最大序列长度
)
swap_space 太小:频繁 OOM,无法服务长序列请求。swap_space 太大:占用大量 CPU DRAM,影响同机其他服务。实际部署建议通过压测确定 P99 latency 可接受的上限,再反推 swap_space 配置。
五、FlashAttention:从算法层消除显存瓶颈
标准注意力的显存复杂度问题
标准自注意力计算 Attention(Q, K, V) = softmax(QKᵀ/√d)V,中间必须生成 N×N 的注意力矩阵(N 为序列长度)。这个矩阵在 BF16 下占 N² × 2 字节。当 N=8192 时,这是 8192² × 2 ≈ 134MB;当 N=65536 时,这是 65536² × 2 ≈ 8.6GB,单单一个注意力矩阵就把显存撑爆。
更深层的问题是 HBM 访问模式。标准实现需要把完整的 N×N 矩阵写到 HBM,softmax 运算完再读回。这涉及大量 HBM 读写,而 HBM 带宽(3.35 TB/s on H100)是整个 GPU 计算最贵的资源之一。FlashAttention 原始论文证明,标准注意力的 HBM 访问复杂度是 Θ(Nd + N²),大部分时间花在那个 N² 项。
FlashAttention 的核心洞察
FlashAttention(arxiv 2205.14135,2022)的关键洞察是:不需要把整个 N×N 矩阵物化出来。注意力计算的输出 O = softmax(QKᵀ/√d) × V 可以通过分块(Tiling)逐步累积计算,每次只把一小块 Q 和对应的所有 K、V 装进 SRAM,在 SRAM 内完成这一块的 attention 运算,然后以"在线 softmax"的方式把结果合并到最终输出。
GPU 的片上 SRAM(Shared Memory)速度极快,但容量极小(H100 SXM5 的 L2 cache 约 50MB,SRAM 更小)。HBM 容量大(80GB),但相对慢(3.35 TB/s,听起来很快,但 SRAM 带宽在 TB/s 量级内,L1/L2 延迟远低于 HBM)。FlashAttention 的思路是:把计算搬进 SRAM,让 HBM 只做"搬运原始数据"和"写出最终结果"的角色,消除中间矩阵的 HBM 往返。
Tiling 策略与在线 Softmax
分块计算注意力的关键挑战是 softmax:softmax 需要全局的归一化因子(分母),而分块计算时无法提前知道全局最大值和全局和。
在线 Softmax(Online Softmax)解决了这个问题,原理如下:
对于 softmax(x),分母是 sum(exp(x_i)),需要先遍历所有 x_i 得到全局 sum,再计算每个 exp(x_i)/sum。两遍遍历意味着必须存下所有 x_i,分块计算要绕开的就是这个问题。
在线 Softmax 的技巧是维护两个 running 统计量:
- m(running max):到目前为止见过的最大值
- l(running sum):以 m 为基准的 exp 之和,即 sum(exp(x_i - m))
每次处理一个新块(含 k 个 key)时:
1. 计算新块的局部最大值 m_new = max(m_old, m_block)
2. 用 exp(m_old - m_new) 校正旧的 l(因为基准从 m_old 换到了 m_new)
3. 把新块的 sum(exp(x_i - m_new)) 加进 l
4. 同时用相同的校正因子更新累积的输出 O
数学上可以证明,当所有块处理完毕,m 等于全局最大值,l 等于全局归一化因子,O 等于标准 softmax 的输出。整个过程不需要把任何中间矩阵写回 HBM,是真正的流式计算。
FlashAttention 的 HBM 访问复杂度降至 O(N²d²M⁻¹),其中 M 是 SRAM 大小。在典型参数下(M 为几十 MB,d=128),HBM 读写次数最多减少 9×,FLOPs 与标准注意力完全相同(O(N²d))。没有牺牲计算量,节省的是 IO。IO 优化和计算优化是正交的,你可以在不改变 FLOPs 的前提下减少 9× 的 HBM 访问。
# PyTorch 2.0+ 的 scaled_dot_product_attention 自动调度到 FlashAttention
import torch
import torch.nn.functional as F
def flash_attention_example():
batch, heads, seq_len, head_dim = 2, 8, 4096, 128
device = 'cuda'
dtype = torch.bfloat16
q = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=dtype)
k = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=dtype)
v = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=dtype)
# 明确启用 FlashAttention,禁用标准数学路径
# 前提:CUDA 设备,dtype=float16 或 bfloat16,head_dim ∈ {64, 128}
with torch.backends.cuda.sdp_kernel(
enable_flash=True,
enable_math=False,
enable_mem_efficient=False
):
# causal=True 对应因果掩码(只看左侧 token),节省约一半 FLOPs
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert out.shape == (batch, heads, seq_len, head_dim)
return out
PyTorch 2.0 以后的 scaled_dot_product_attention 会在满足条件时自动路由到 FlashAttention。条件包括:输入在 CUDA 设备、dtype 为 float16 或 bfloat16、head_dim 为 {64, 128} 之一。
Tiling 的分块大小如何选择
FlashAttention 把 Q、K、V 切成若干 block,每个 block 必须能完整装进 SRAM。分块大小(BLOCK_M × BLOCK_N,分别对应 Q 和 K/V 的序列维度)需要在三个约束之间取平衡。
约束一是 SRAM 容量。H100 每个 SM 的 shared memory 约 228 KB(可配置上限)。一个 block 在 SRAM 里需要同时存放 Q_block(BLOCK_M × head_dim)、K_block(BLOCK_N × head_dim)、V_block(BLOCK_N × head_dim)和输出累积 O_block(BLOCK_M × head_dim),共 (BLOCK_M + 3 × BLOCK_N) × head_dim × dtype_bytes 字节。当 head_dim=128、dtype=BF16 时,BLOCK_M=BLOCK_N=64 需要约 64KB,可以容纳;BLOCK_M=BLOCK_N=128 需要约 128KB,仍在限制内,但不同 GPU 架构的 shared memory 大小各异,必须校验。
约束二是 GPU Occupancy。每个 CUDA thread block 占用的 shared memory 越大,SM 上能同时调度的 thread block 就越少,GPU 利用率下降。FA-2 通过在序列长度维度并行化来补偿这一损失,但 BLOCK_M 和 BLOCK_N 设得太大仍然会导致线程利用率不足。
约束三是 Tensor Core 效率。BLOCK 越大,矩阵乘法越接近"大矩阵 GEMM",Tensor Core 利用率越高;BLOCK 太小,GEMM 退化成小矩阵乘,算力利用率低。head_dim=128 是个"天然分块大小",很多实现直接以一个 head 作为 K/V 的分块单位,只在序列长度维度做 tiling。
FA-2 的 Triton 实现默认 BLOCK_M=128、BLOCK_N=64(A100 上的搜索结果),FA-3 在 H100 上用更大的 block(BLOCK_M=192、BLOCK_N=128)配合 WGMMA 和 TMA 指令。这些数字不是手工调参的结果,而是通过 Triton 的 @triton.autotune 装饰器在实际硬件上搜索得到的。不同 GPU、不同 head_dim、不同 dtype 都对应不同的最优分块配置,跨硬件移植时不能直接照搬。
FlashAttention-2 和 FlashAttention-3 的演进
FlashAttention-2(arxiv 2307.08691,2023)在 FA-1 的基础上做了三项并行优化。
第一,新增序列长度维度的并行。FA-1 只在 batch_size × num_heads 维度分配 CUDA 线程块,当 batch_size 很小(服务场景常见)时 GPU 利用率低;FA-2 在序列长度维度也并行化,即使 batch=1 也能打满 GPU。
第二,Warp 分区从 sliced-K 改为 Q 分片。FA-1 的 sliced-K 需要 warp 间通过 shared memory 同步中间结果;FA-2 改为每个 warp 处理不同的 Q 块、共享 K/V,消除 warp 间同步开销。
第三,Causal mask 优化:因果注意力的下三角矩阵中,约一半计算是不需要的(上三角全为 -∞)。FA-2 对此做了专门的 bound-checking 和 masking 优化,把 FLOPs 降到理论下界的约一半。
实测性能:A100 80GB SXM4 上达到 230 TFLOPs/s(FP16/BF16),相比 FA-1 约 2× 提升,相比 PyTorch 标准注意力最高 9× 提升。H100 SXM5 上无额外优化达到 335 TFLOPs/s。
FlashAttention-3(arxiv 2407.08608,2024)专为 Hopper 架构设计,利用 H100 的三项独特硬件能力:
- WGMMA(Warpgroup Matrix Multiply-Accumulate):H100 新增的异步矩阵运算指令,可以同时发射多个 GEMM 操作
- TMA(Tensor Memory Accelerator):专用数据搬运单元,可以异步批量搬运 tensor,与计算并行进行
- Warp specialization:把 warp 分成专门做计算的和专门做数据搬运的,两类 warp 异步重叠,隐藏内存延迟
结果:H100 FP16 达到 740 TFLOPs/s(75% 硬件利用率),相比 FA-2 在 H100 上(35% 利用率,335 TFLOPs/s)提升 2.2×;FP8 接近 1.2 PFLOPs/s。FA-3 还支持 block quantization 的 FP8 FlashAttention,数值误差比朴素 FP8 Attention 低 2.6×。
| 版本 | 年份 | A100 (FP16) | H100 (FP16) | 适用架构 |
|---|---|---|---|---|
| FA-1 | 2022 | ~125 TFLOPs/s | - | Ampere+ |
| FA-2 | 2023 | 230 TFLOPs/s | 335 TFLOPs/s | Ampere+ |
| FA-3 | 2024 | - | 740 TFLOPs/s | Hopper only |
六、算子融合:把多个 Kernel 合并为一个
Kernel 启动的隐性开销
GPU 上的每个操作(矩阵乘、激活函数、LayerNorm 等)都是一个 CUDA Kernel。每次 Kernel 启动都有固定开销:CUDA runtime 需要调度线程块、初始化寄存器、建立执行上下文。这个 latency 通常在 5-50 微秒量级,听起来很小,但对于 pointwise 操作(如 ReLU、LayerNorm 的逐元素加减乘除),Kernel 本身的运算时间可能只有几微秒,启动开销反而是计算的主体。
更严重的问题是 HBM 往返。每个 Kernel 通常从 HBM 读入输入、写出输出。如果 LayerNorm → Linear → SwiGLU → Linear 这四个操作各自是独立 Kernel,中间产物要在 HBM 里写了又读三次。把它们合并成一个 Kernel,只需从 HBM 读一次输入、写一次最终输出,中间值全在 SRAM 内流转。
典型融合模式
FlashAttention 本身就是融合算子:把 Q·Kᵀ GEMM + 逐行在线 softmax + P·V GEMM 合并为单一 CUDA Kernel,彻底消除 N×N 中间矩阵的 HBM 往返。这是迄今为止 LLM 推理中影响最大的单一算子融合。
LayerNorm / RMSNorm 融合把 LayerNorm 的均值计算、方差计算、归一化、缩放四步合并为一个 Kernel,中间统计量不写 HBM。Megatron Core 将 LayerNorm 融进单一 CUDA Kernel;Mirage 系统可自动发现 RMSNorm + MatMul 的融合实现。
SwiGLU 激活融合:SwiGLU(x) = x × sigmoid(x) × gate,是 LLaMA 等模型广泛使用的激活函数。Gate 和激活的计算可以融合成一个 Kernel,避免把 gate 向量写回 HBM 再读取。
Bias + Activation 融合把 Linear 层的偏置加法和后续激活函数合并,减少一次 HBM 读写。
# torch.compile 可以自动触发部分算子融合
import torch
model = MyTransformerModel().cuda().bfloat16()
# torch.compile 会分析计算图,对 pointwise 操作自动融合
# fullgraph=True 要求整个函数被编译(不允许 graph breaks)
# mode='max-autotune' 让编译器搜索最优 kernel 配置(编译时间更长,运行时更快)
compiled_model = torch.compile(model, fullgraph=True, mode='max-autotune')
torch.compile 通过 TorchInductor 后端自动识别融合机会,但对动态形状(dynamic shapes)和自定义 CUDA extension 支持有限。理解 torch.compile 的工作机制有助于更好地利用它:TorchDynamo(前端)把 Python 代码转换成 FX Graph,TorchInductor(后端)把 FX Graph 编译成优化后的 kernel,Triton 是 TorchInductor 的主要代码生成目标。当 TorchInductor 发现连续的 pointwise 操作时,会把它们融合成一个 Triton kernel,消除中间 tensor 的 HBM 往返。
对于性能关键路径,手写 Triton kernel 可以提供更精细的控制:
import triton
import triton.language as tl
@triton.jit
def fused_rmsnorm_kernel(
X_ptr, W_ptr, Out_ptr,
N: tl.constexpr,
eps: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
融合 RMSNorm Kernel:一次 HBM 读取完成均方根归一化
与分步实现相比,节省 2 次 HBM 读写(中间 norm_x 向量)
"""
row_idx = tl.program_id(0)
offsets = tl.arange(0, BLOCK_SIZE)
mask = offsets < N
# 一次读入整行
x = tl.load(X_ptr + row_idx * N + offsets, mask=mask, other=0.0)
w = tl.load(W_ptr + offsets, mask=mask, other=0.0)
# SRAM 内计算均方根
x_sq = x * x
mean_sq = tl.sum(x_sq, axis=0) / N
rrms = 1.0 / tl.sqrt(mean_sq + eps)
# 原地缩放并写出
out = x * rrms * w
tl.store(Out_ptr + row_idx * N + offsets, out, mask=mask)
Triton kernel 在 SRAM 内完成整个 RMSNorm 计算,与 PyTorch 的分步实现相比节省了中间向量的 HBM 往返。
torch.compile 的融合能力与限制
torch.compile 对 pointwise 操作的融合效果最好(element-wise 操作序列几乎总能被合并)。对 reduction 操作(sum、max 等)次之。对 matmul 的融合有限(GEMM 通常单独执行,用高度优化的 cuBLAS/CUTLASS 库)。
主要限制:
- 动态控制流:if/for 依赖 tensor 值时会产生 graph break,融合在断点处终止
- 自定义 CUDA 扩展:无法穿透进 C++ 扩展内部做融合
- 动态形状:序列长度在每次推理时变化,torch.compile 需要用
dynamic=True模式,融合效果有所下降 - 编译时间:第一次运行时间长(10-120 秒),不适合频繁重启的场景
七、显存碎片与显存池
PyTorch Caching Allocator 的工作原理
PyTorch 不直接调用 cudaMalloc 分配每个 tensor,那样太慢(cudaMalloc 需要与 CUDA driver 同步,延迟在几百微秒量级)。PyTorch 维护一个 Caching Allocator,从 CUDA 申请大块内存后自行管理分配与释放。
Caching Allocator 的核心数据结构是两个 pool:
- Small Pool(< 1MB 请求):维护若干 2MB 的固定大小 block,从 CUDA 按 2MB 申请,内部切分给小请求。分裂阈值:请求 size < 1MB 且剩余空间 > 512B 时才切分,避免产生 < 512B 的极小碎片
- Large Pool(≥ 1MB 请求):每次从 CUDA 申请最小 20MB 的 block,按需切分。分裂阈值:请求 size ≥ 1MB 且剩余 > 1MB 时才切分
分配策略是 best-fit:遍历 free list,找到大小最接近请求的 block 复用。如果没有合适的 free block,才向 CUDA 申请新的。best-fit 的好处是减少内部碎片;代价是 free list 遍历时间,但通常 free list 不大,不是性能瓶颈。
释放时自动合并相邻的 free block(coalescing),把小碎片拼回大块,下次可以服务更大的请求。合并是双向的:检查左邻和右邻是否也是 free,若是则合并成更大 block,递归进行直到邻居都是 active。正是这个合并机制,使得持有"活跃邻居"的 inactive block 成为真正意义上的碎片,它无法被合并,也无法参与分配,只能等邻居释放。
显存碎片的产生
即使有 coalescing,碎片仍然会出现。根本原因是:被拆分的 block,如果两侧邻居中有一个仍然被活跃 tensor 占用,这个 block 就无法被合并,也无法归还给 CUDA。PyTorch 用 inactive_split_bytes 指标追踪这类碎片量。
这类碎片有一个形象的比喻:想象显存是一条停车场,每辆车(tensor)占一个车位。一辆大卡车(大 tensor)被分成了两段停在车位 5-8 和 10-14。后来车位 9 被另一辆车占据,车位 5-8 的段落想与 10-14 合并已经不可能了,中间有别人挡着。PyTorch 的 coalescing 只能合并相邻的空车位,不能跨过已占用车位合并。
典型触发场景:推理服务在处理一批短请求后,遗留了大量小 block 的碎片;接下来来了一个长请求,需要一块连续的大 block,但 free list 里全是碎片,无法合并成足够大的块,触发 OOM,尽管此时 nvidia-smi 显示的 free memory 可能还有几 GB。这种情况下 torch.cuda.empty_cache() 完全无效,因为碎片 block 的邻居还活着。
另一个常见的碎片来源是 CUDA Graph 与动态显存分配的冲突。CUDA Graph 录制时会"固化"内存地址,后续执行必须用相同地址。如果 CUDA Graph 使用的 block 与动态分配的 block 交错,就会产生无法移动的内存孤岛。在使用 CUDA Graph 加速推理的框架里(如 TRT-LLM 的 paged KV cache 模式),需要特别注意这一点,expandable_segments 与 CUDA Graph 同时使用时需要做兼容性验证。
import torch
def print_memory_stats(device=0):
"""监控 GPU 显存碎片状态"""
stats = torch.cuda.memory_stats(device)
reserved = stats['reserved_bytes.all.current'] / 1e9 # 已向 CUDA 申请的总量
active = stats['active_bytes.all.current'] / 1e9 # 实际被张量占用的量
inactive_split = stats['inactive_split_bytes.all.current'] / 1e9 # 碎片量
fragmentation_ratio = inactive_split / reserved if reserved > 0 else 0
print(f'Reserved: {reserved:.2f} GB')
print(f'Active: {active:.2f} GB')
print(f'Inactive split: {inactive_split:.2f} GB # 碎片')
print(f'Fragmentation: {fragmentation_ratio:.1%}') # 超过 20% 应介入
inactive_split_bytes 是碎片的直接度量。如果 fragmentation ratio 超过 20%,应该介入处理。
碎片整理与显存池
torch.cuda.empty_cache() 的作用是把 Caching Allocator 的 free list 里的所有 block 归还给 CUDA。它能降低 reserved_bytes,但无法消除 inactive_split_bytes,被邻居占用的 block 归还不了,碎片仍在。empty_cache() 对碎片问题的作用有限,只在释放大块空闲内存时有效。
真正的防碎片方案是通过 PYTORCH_CUDA_ALLOC_CONF 配置 Caching Allocator 行为:
import os
# 在 import torch 之前设置(否则不生效)
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = (
'expandable_segments:True,' # 允许 segment 扩展而非申请新 segment
'garbage_collection_threshold:0.9,' # 90% 显存占用时主动回收
'max_split_size_mb:512' # 禁止拆分超过 512MB 的 block
)
expandable_segments:True 是最有效的防碎片参数:分配器扩展已有 segment 而非申请新 segment,内部碎片大幅降低,推荐在现代 GPU 上默认开启。但与 CUDA Graph 结合时有兼容性问题(CUDA Graph 要求固定的内存地址,可扩展 segment 地址可能变化),在使用 CUDA Graph 的推理框架里(如 TRT-LLM)需要确认兼容性。
max_split_size_mb:512 阻止大 block 被拆分成小 block,减少碎片产生,代价是小请求可能占用大 block,内存利用率下降。对 batch size 变化大的场景(如在线推理服务)有明显效果。
garbage_collection_threshold:0.9 在显存占用超过 90% 时主动触发回收,避免接近 OOM 时的突发性 sync-and-reclaim 操作(那会导致服务抖动)。对延迟敏感的服务有价值。
八、OOM 治理:不崩溃的工程实践
OOM 的常见触发场景
LLM 服务 OOM 不总是因为模型太大,很多时候是系统行为与显存约束之间的意外交互:
Batch size 过大:推理服务在低延迟模式下通常用动态 batching,当请求积压时 batch size 瞬间增大,KV Cache 线性增长,触发 OOM。应对方法是在显存水位预警时限制 batch size,而不是等到 OOM。
序列长度突变:服务大部分时间处理 2K 以内的请求,偶尔来一个 100K 的长文档请求,KV Cache 需求骤增 50 倍。防御措施是显式限制 max_model_len,超长请求直接拒绝或降级处理,不让它挤占其他请求的 KV Cache 空间。
多并发加显存碎片:多个并发请求的 KV Cache 交替分配和释放,产生碎片。下一个大请求到来时找不到足够大的连续块。这不是"内存不够",而是"内存碎片化"。用 expandable_segments:True 可以有效缓解。
激活值峰值:训练时每个 mini-batch 的 forward pass 产生大量激活,某些模型结构在特定输入形状下激活值特别大(如长序列 + 宽 FFN)。梯度累积(gradient accumulation)会延长激活值的生存周期,放大峰值。
模型加载与推理之间的峰值:一个容易被忽视的场景是模型加载的瞬间。加载 70B 模型时,磁盘读取 → CPU DRAM → GPU HBM 的传输过程中,可能短暂同时在 CPU 和 GPU 上各有一份模型权重,峰值需求是稳态的两倍。torch.load() 加 map_location='cpu' 先加载到 CPU、再按层逐步移到 GPU,可以避免 GPU 侧的双份峰值。实际上 vLLM 的模型加载器就是这么做的。
Gradient Checkpointing:以算力换显存
训练时,标准做法是保存所有 forward pass 的激活值以备 backward 使用。Gradient Checkpointing(也叫 Activation Checkpointing)改变这一策略:在 forward pass 时只保存"检查点"(通常是每层的输入),在 backward pass 时重新计算中间激活。
代价:额外计算约增加 33%(因为每层 forward 计算了两次)。收益:激活值显存从 O(n)(n 为层数)降至 O(√n),整体可节省 50-70% 的激活显存。
PyTorch 2.5 引入了 Selective Activation Checkpointing(SAC),允许通过 policy_fn 精细控制哪些算子需要重算、哪些要保存:
import torch
from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts
from functools import partial
aten = torch.ops.aten
# 策略:保留昂贵算子(matmul、FlashAttention),重算廉价算子(ReLU、Dropout 等)
compute_intensive_ops = [
aten.mm.default,
aten.bmm,
aten.addmm,
aten._scaled_dot_product_flash_attention, # FlashAttention 本身也贵,保留
]
def sac_policy_fn(ctx, op, *args, **kwargs):
"""SAC policy:只重算非计算密集型算子"""
if op in compute_intensive_ops:
return torch.utils.checkpoint.CheckpointPolicy.MUST_SAVE
else:
return torch.utils.checkpoint.CheckpointPolicy.PREFER_RECOMPUTE
def forward_with_sac(module, *inputs):
"""使用 SAC 包裹 transformer block
相比全量 checkpointing,减少约 30-40% 的重算开销,显存节省相近"""
return checkpoint(
module,
*inputs,
use_reentrant=False, # 新 API,推荐使用(旧 API 有内存泄漏风险)
context_fn=partial(create_selective_checkpoint_contexts, sac_policy_fn),
)
SAC 的直觉是:matmul 很贵(需要调用 cuBLAS,启动大量线程),应该保存其结果;而 pointwise 操作(ReLU、Dropout、element-wise 加法)极廉价,宁可重算也要节省存储这些中间值的显存。
PyTorch 2.4 还引入了 Memory Budget API:
# 设置显存预算(0=全量 checkpointing,1=不 checkpoint)
# 编译器自动寻找 Pareto 最优的 checkpoint 策略
torch._dynamo.config.activation_memory_budget = 0.5
model = MyTransformerModel().cuda()
compiled_model = torch.compile(model)
# 运行时自动应用最优 SAC 策略
out = compiled_model(inputs)
这把 SAC 的策略搜索自动化:用 activation_memory_budget 指定愿意花多少显存,编译器找到满足预算约束下计算开销最小的 checkpoint 方案。
Activation Offload 与 Gradient Checkpointing 的权衡
Activation Offload 是另一种降低激活显存的方案:forward 时把激活值搬到 CPU 内存,backward 时通过 PCIe 取回。理论上显存节省更彻底(只受 CPU DRAM 大小限制)。
现实中 Activation Offload 受限于 PCIe 带宽。H100 的计算速度极快,完成一层 forward 只需几毫秒;而通过 PCIe 5.0(63 GB/s)搬运一个 large activation tensor 的时间通常超过 GPU 重算时间。在高性能 GPU 上,Activation Offload 的延迟瓶颈是传输,Gradient Checkpointing 的代价是计算。GPU 算力便宜,PCIe 带宽贵,所以优先选 checkpointing。
实践建议: - 高性能 GPU(A100/H100):优先全量 Gradient Checkpointing,再考虑 SAC 优化计算开销 - 低性能 GPU(T4/V100):GPU 算力慢,重算代价高,Activation Offload 值得考虑 - 显存极度紧张(checkpointing 后仍然 OOM):两者叠加使用,checkpointing 先压一波,offload 作保底
Gradient Checkpointing 的实现细节
全量 Gradient Checkpointing 的实现比表面上复杂:PyTorch 的旧 API(use_reentrant=True,默认)使用 autograd 的 saved_tensors 机制,把检查点处的 tensor 标记为可重算,backward 时调用重算函数。旧 API 有一个已知问题:重算过程中产生的中间 tensor 与主 backward pass 的 tensor 可能发生内存重叠,在某些边缘情况下导致结果错误或内存泄漏。
新 API(use_reentrant=False)使用 functional 风格,不依赖 autograd 内部的重入机制,更安全且更兼容 torch.compile。2024 年以后的代码建议统一切换到 use_reentrant=False:
from torch.utils.checkpoint import checkpoint
# 旧 API(有潜在问题,不推荐新代码使用)
# out = checkpoint(fn, *args)
# 新 API(推荐)
out = checkpoint(fn, *args, use_reentrant=False)
全量 checkpointing 的 33% 额外计算开销是平均值,实际开销取决于模型结构。对于有大量 attention 层的 transformer,attention 计算是最贵的算子,全量 checkpointing 会重算所有 attention,开销接近 50%。SAC 在这里的价值就体现出来了:保留 FlashAttention 结果(避免最贵的重算),只重算廉价的 pointwise 算子,整体开销可以降到 10-15% 的额外计算,显存节省保持在 60-70%。
显存监控与 OOM 预警
在生产环境中,不能等 OOM 崩溃后再排查,要在 OOM 发生前预警。
nvidia-smi 提供粗粒度显存监控:
# 每秒刷新,监控显存使用率
nvidia-smi dmon -s m -d 1
# 查看各进程显存占用(按显存排序)
nvidia-smi --query-compute-apps=pid,used_memory --format=csv,noheader | sort -t, -k2 -rn
torch.cuda.memory_stats() 提供细粒度 PyTorch 层面指标:
import torch
def check_memory_health(device=0, warn_threshold=0.85):
"""
显存健康检查:
- fragmentation_ratio > 20%:碎片化严重,考虑 expandable_segments
- utilization > warn_threshold:接近 OOM,触发预警
"""
stats = torch.cuda.memory_stats(device)
props = torch.cuda.get_device_properties(device)
total_memory = props.total_memory
reserved = stats['reserved_bytes.all.current']
active = stats['active_bytes.all.current']
inactive_split = stats['inactive_split_bytes.all.current']
utilization = reserved / total_memory
fragmentation = inactive_split / reserved if reserved > 0 else 0
if utilization > warn_threshold:
print(f'[WARN] 显存利用率 {utilization:.1%},接近 OOM 阈值')
if fragmentation > 0.20:
print(f'[WARN] 碎片率 {fragmentation:.1%},建议开启 expandable_segments')
return {
'utilization': utilization,
'fragmentation': fragmentation,
'reserved_gb': reserved / 1e9,
'active_gb': active / 1e9,
}
在推理服务里,可以在每个请求处理完成后调用上面的 check_memory_health(),当 utilization 超过 85% 时向监控系统发送告警,当超过 90% 时拒绝新的长序列请求。这比等到 OOM 后重启服务代价小得多。
训练 vs 推理的显存优化差异
显存组成对比
训练和推理的显存结构有本质区别,优化重点也截然不同。
推理阶段的显存构成相对简单:模型参数(静态,加载后不变)+ KV Cache(动态,随 batch size 和序列长度线性增长)+ 少量运行时 buffer。激活值只需保留当前正在计算的一层,不需要为 backward 留存,开销在 MB 量级。优化重点集中在 KV Cache 管理(PagedAttention、量化、offload)和算子融合(减少 HBM 往返)。
训练阶段的显存构成复杂得多:模型参数 + 梯度 + 优化器状态 + 全部 forward 激活值(backward 时需要)。对于 7B 模型 AdamW 混合精度训练,光是参数 + 梯度 + 优化器就需要约 56GB,激活值峰值再叠加上去,单卡 H100 80GB 也捉襟见肘。
| 组成部分 | 推理 | 训练(AdamW mixed precision) |
|---|---|---|
| 模型参数 | BF16(2B/参数) | BF16(2B/参数)+ FP32 主权重(4B) |
| 梯度 | 无 | BF16 或 FP32(2-4B/参数) |
| 优化器状态 | 无 | FP32 一阶矩 + 二阶矩(8B/参数) |
| 激活值 | 当前层(MB 级) | 全部层(GB 级,视 batch) |
| KV Cache | 核心开销(GB 级) | 无(训练不缓存 KV) |
梯度检查点的使用
Gradient Checkpointing 是训练阶段最重要的显存优化手段之一,核心思路是用计算换显存:forward pass 时只保留关键检查点(通常是每个 Transformer block 的输入 tensor),丢弃中间激活;backward pass 时从最近的检查点重新执行 forward,重新生成需要的激活。代价是每个被检查点覆盖的 block 会额外执行一次 forward,理论计算开销增加约 33%。
from torch.utils.checkpoint import checkpoint_sequential
import torch.nn as nn
class TransformerWithCheckpoint(nn.Module):
def __init__(self, layers):
super().__init__()
self.layers = nn.ModuleList(layers)
def forward(self, x):
# segments=len(self.layers) 表示每层各作一个检查点
# use_reentrant=False 是 PyTorch 推荐的新 API,避免旧 API 的内存泄漏风险
return checkpoint_sequential(
self.layers, segments=len(self.layers),
input=x, use_reentrant=False
)
全量检查点能把激活显存从 O(层数 × batch × seq_len) 降至 O(batch × seq_len),激活显存几乎与层数无关,代价是 33% 的额外计算。在 H100 上计算很快,这个代价通常可以接受。如果想进一步降低计算代价,使用前文介绍的 SAC 策略,保留 matmul 结果、只重算 pointwise 操作,可以在接近全量 checkpoint 的显存节省的同时,把计算开销控制在 10-15%。
训练时的混合精度(BF16/FP16)配置
混合精度训练是训练阶段降低显存的标配方案。核心规则:forward 和 backward 用 BF16,梯度更新时将梯度转成 FP32 与 FP32 主权重合并,写回参数时再转成 BF16。这样参数的"工作副本"是 BF16(2 字节),optimizer 状态是 FP32(4 字节),数值稳定性与纯 FP32 接近。
BF16 比 FP16 更适合训练:BF16 指数位更宽(8 位 vs 5 位),动态范围与 FP32 相同,不容易溢出;FP16 精度更高(尾数 10 位 vs 7 位),但更容易出现梯度溢出,需要配合 GradScaler 做 loss scaling。Hopper/Ampere 架构对 BF16 矩阵运算有原生 Tensor Core 加速,当代训练几乎都选 BF16。
# BF16 训练:不需要 GradScaler(BF16 动态范围足够,无需 loss scaling)
import torch
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
output = model(inputs)
loss = criterion(output, targets)
loss.backward()
optimizer.step()
optimizer.zero_grad()
# FP16 训练:需要 GradScaler 防止梯度下溢
from torch.cuda.amp import GradScaler
scaler = GradScaler()
with torch.autocast(device_type='cuda', dtype=torch.float16):
output = model(inputs)
loss = criterion(output, targets)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
相比纯 FP32 训练,混合精度把参数和激活的存储减半(FP32→BF16),优化器状态保持 FP32 不变,综合下来整体显存减少约 30-40%。显存极度紧张时,可以叠加 optimizer state offload(ZeRO-2/3 风格,把 FP32 优化器状态搬到 CPU 内存),进一步释放 GPU 显存用于更大的 batch size。
一个 70B 模型的显存规划案例
参数量 × dtype_bytes = 显存基线
以 LLaMA-3-70B(实际参数量约 70.6B)为基准,做精确的显存规划。
推理场景(BF16 全精度推理):
参数显存 = 70.6B × 2 字节 = 141.2 GB
单张 H100(80GB)装不下,需要至少 2 张(tensor parallel 或 pipeline parallel 切分)。如果用 INT4 量化(GPTQ/AWQ):
参数显存 = 70.6B × 0.5 字节 = 35.3 GB
单张 H100 装得下参数,还剩约 44GB 用于 KV Cache 和 runtime buffer。
KV Cache 估算(LLaMA-3-70B 架构参数:80 层、GQA n_kv_heads=8、head_dim=128):
KV_Cache = 2 × batch × seq_len × n_layers × n_kv_heads × head_dim × dtype_bytes
= 2 × 1 × 4096 × 80 × 8 × 128 × 2 # BF16, batch=1, seq=4K
≈ 1.34 GB(单请求 4K 上下文)
batch=32 时 KV Cache 约 43 GB,叠加 INT4 参数的 35GB,正好接近 80GB 上限,已经没有余量应对碎片和激活峰值。实际部署需要用 FP8 KV Cache 把 KV 再减半,或者限制 batch size 上限。
训练场景(mixed precision,AdamW):
BF16 参数 = 70.6B × 2 ≈ 141 GB
BF16 梯度 = 70.6B × 2 ≈ 141 GB
FP32 优化器状态 = 70.6B × 8 ≈ 565 GB
激活值(估算) ≈ 数十至数百 GB(依 batch × seq_len)
--------------------------------------------------
合计(不含激活)≈ 847 GB
即使不算激活,70B 全参数训练需要约 850GB 显存,约 11 张 H100。70B 级别的训练必须用 ZeRO-3 把参数、梯度、优化器状态分片到多个 GPU,每卡只持有全局参数的 1/N。
选择方案的决策流程
① 确定任务类型
├─ 推理 → ② 确定精度容忍度
└─ 训练 → ⑤ 确定训练规模
② 精度容忍度
├─ 高精度要求(金融/医疗)→ BF16 全精度 → ③ 确定卡数
└─ 可接受 1-2% 精度损失 → INT4/INT8 量化 → ③ 确定卡数
③ 确定最少卡数
所需显存 ÷ (单卡显存 × 70%) = 最少卡数(保留 30% 给 KV Cache 和碎片)
├─ 单卡能装 → ④ KV Cache 优化
└─ 多卡 → Tensor Parallel(同节点,NVLink 低延迟)
或 Pipeline Parallel(跨节点,带宽受限时)
④ KV Cache 优化
├─ context P50 > 7K → FP8 KV Cache(H100)/ INT8 KV Cache(A100)
├─ 高并发共享系统提示 → Prefix Caching(radix tree 共享 KV block)
└─ 低 QPS 离线处理 → CPU Offload(swap_space 配置)
⑤ 训练规模
├─ 全参数微调 → ZeRO-3 + Gradient Checkpointing + BF16 混合精度
├─ 显存仍不足 → 叠加 Activation Offload + 减小 micro-batch + 增大 grad accumulation
└─ 参数高效方法 → LoRA/QLoRA(冻结主干,只训 adapter,显存降至 1/10 量级)
这个决策流程不是银弹,每个分支都需要压测验证。显存规划的最终依据是在真实负载下采集的 torch.cuda.memory_stats(),而不是理论计算。理论给你下界,实测给你安全水位。
九、把优化组合起来
每种优化技术都有自己的适用范围和代价。实际部署时,需要根据硬件配置、服务场景和精度要求组合使用。
推理服务的典型组合
H100 生产推理:
- FlashAttention-3(或 vLLM/TRT-LLM 内置,无需自行集成)
- FP8 KV Cache(context 中位数 > 7K 时启用,否则先做压测验证)
- PagedAttention + Prefix Caching(共享系统提示的场景效果突出)
- expandable_segments:True,禁止 max_split_size_mb 过小
- 显存水位预警(85% 告警,90% 限流),动态限制新请求的 max_tokens
A100 生产推理:
- FlashAttention-2(PyTorch SDPA 自动路由,确认 head_dim 满足条件)
- INT8 KV Cache(per-token 粒度,无需校准)
- PagedAttention,block_size=16 或 32(太小 block table 开销大,太大碎片大)
- max_split_size_mb:512 防碎片
T4 / 低端 GPU 推理: - FlashAttention-2(T4 是 Turing 架构,FA-2 兼容) - INT4 KV Cache(激进量化,适合精度不敏感场景) - KV Cache CPU Offload(带宽约 16 GB/s PCIe 3.0,仅在低 QPS 离线场景使用) - 限制最大序列长度(T4 只有 16GB,不限长度必定 OOM)
训练场景的典型组合
单卡训练(A100/H100):
- Gradient Checkpointing(SAC 模式,保留 matmul 和 FlashAttention 结果)
- FlashAttention-2
- 混合精度(BF16 参数 + FP32 优化器状态,adamw_bf16 可进一步压缩优化器状态)
- torch.compile 开启算子融合,配合 fullgraph=True 避免 graph break
显存极度紧张时(显存 < 40GB 训练 13B+): - Gradient Checkpointing 全量(先保证能跑起来) - 再用 SAC 把额外计算开销降下来 - 减小 micro-batch size + 增大 gradient accumulation steps 到等效 batch size - CPU Offload 优化器状态(ZeRO-2:分布优化器状态和梯度到 CPU;ZeRO-3:连参数也分布) - 最后手段:Activation Offload,预期会有较大延迟代价
不要过度优化的警告
每种优化都有隐藏代价。FP8 KV Cache 在 7K 以下 context 反而变慢;Activation Offload 在 H100 上因 PCIe 瓶颈反而比重算慢;max_split_size_mb 设置过小会导致大 block 无法被拆分给多个小请求,显存利用率下降;expandable_segments:True 与 CUDA Graph 存在兼容性风险。
优化的顺序应该是:先测量,找到真正的瓶颈,再针对瓶颈选择对应的优化手段。用 FlashAttention 解决显存量问题是正确的;用 FlashAttention 解决显存碎片问题则是错配,它们解决的根本不是同一个问题。
一个常见的误区是把所有优化一起打开,然后在出问题时无法判断是哪个优化引入的。推荐的工程做法是逐项引入、逐项验证:建立 baseline,引入一个优化,压测验证吞吐和延迟,确认没有退化,再引入下一个。这比一次性全上要慢,但遇到问题时回滚的代价小得多。
量化各路优化的本质
最后整体看一眼:这些优化节省的究竟是什么?
FlashAttention 节省的是 HBM 带宽:减少中间矩阵的 HBM 读写,让相同的算力能处理更多的 token。它不减少显存容量需求(中间矩阵不需要长期存储),减少的是瞬时 IO。
KV Cache 量化节省的是显存容量:把 KV Cache 压缩到原来的 50%(FP8)或 25%(INT4),让相同显存能装下更多并发请求或更长的 context。
Gradient Checkpointing 节省的是激活值显存:用重算换存储,让训练可以在更小的显存上跑更大的 batch。
算子融合节省的是启动开销和 HBM 往返:把多个 kernel 合成一个,中间值在 SRAM 内流转,不写 HBM。
显存碎片治理节省的是显存可用率:reserved memory 不变,但 usable block 更多。
这五条路径针对的是五种不同的物理瓶颈,互不替代。理解了"节省的是什么",就知道在自己的场景里该优先选哪条路。
基准测量与工具链
在引入任何优化之前,建立清晰的基准测量体系:
import torch
import time
class MemoryAndLatencyBenchmark:
"""
推理服务基准测量工具
在部署任何显存优化之前,先用这个工具建立基线
"""
def __init__(self, device=0):
self.device = device
self.baseline = {}
def capture_baseline(self, tag='before_optimization'):
"""捕获当前状态作为基线"""
torch.cuda.synchronize(self.device)
stats = torch.cuda.memory_stats(self.device)
props = torch.cuda.get_device_properties(self.device)
self.baseline[tag] = {
'reserved_gb': stats['reserved_bytes.all.current'] / 1e9,
'active_gb': stats['active_bytes.all.current'] / 1e9,
'inactive_split_gb': stats['inactive_split_bytes.all.current'] / 1e9,
'total_gb': props.total_memory / 1e9,
}
return self.baseline[tag]
def compare(self, before_tag, after_tag):
"""对比优化前后的显存变化"""
before = self.baseline[before_tag]
after = self.baseline[after_tag]
delta_reserved = after['reserved_gb'] - before['reserved_gb']
delta_active = after['active_gb'] - before['active_gb']
print(f'Reserved 变化: {delta_reserved:+.2f} GB')
print(f'Active 变化: {delta_active:+.2f} GB')
print(f'碎片率 before: {before["inactive_split_gb"]/before["reserved_gb"]:.1%}')
print(f'碎片率 after: {after["inactive_split_gb"]/after["reserved_gb"]:.1%}')
def latency_profile(self, fn, n_warmup=10, n_measure=100):
"""
测量函数延迟,去除 warmup
使用 CUDA event 计时,比 Python time.time() 精确
"""
# Warmup(让 GPU 预热,JIT 编译完成)
for _ in range(n_warmup):
with torch.no_grad():
fn()
torch.cuda.synchronize(self.device)
# 正式测量
start_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_measure)]
end_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_measure)]
for i in range(n_measure):
start_events[i].record()
with torch.no_grad():
fn()
end_events[i].record()
torch.cuda.synchronize(self.device)
latencies = [s.elapsed_time(e) for s, e in zip(start_events, end_events)]
p50 = sorted(latencies)[n_measure // 2]
p99 = sorted(latencies)[int(n_measure * 0.99)]
print(f'P50 latency: {p50:.2f} ms')
print(f'P99 latency: {p99:.2f} ms')
return {'p50': p50, 'p99': p99, 'all': latencies}
使用 CUDA event 计时比 Python 的 time.time() 精确得多,因为 GPU 操作是异步的,time.time() 可能在 GPU 完成之前就返回,导致测量结果偏低。torch.cuda.synchronize() 强制等待 GPU 完成,但如果插在每次测量前后,本身也会引入延迟。CUDA event 是最正确的做法:在 GPU 流上插入 event 标记,事后查询两个 event 之间的真实 GPU 时间。
从今天起可以执行的第一步:在你的推理服务里加上 check_memory_health() 的监控端点,每隔 5 分钟记录一次 reserved_bytes、active_bytes、inactive_split_bytes。一周数据之后,你会清楚地看到碎片率是不是真正的问题、显存峰值在什么场景下出现、KV Cache 和激活值各自占多大比例。没有度量,就没有优化,只有猜测。
作者:toy