ZeRO vs FSDP:大模型分布式训练的显存算账与通信拆解

从 16N GB 显存公式出发,逐级拆解 ZeRO-1/2/3 的切分逻辑,再对比 PyTorch 原生 FSDP 在工程实现上的核心差异。

1. 引言:传统 DP 的显存瓶颈与切分逻辑

在训练大模型时,单卡 OOM(Out of Memory)是常态。为了弄清楚如何切分模型,必须先明确训练时的显存究竟被什么占据。

假设训练一个参数量为 NN B 的模型(例如 N=70N = 70,即 70B),采用混合精度训练(Mixed Precision,FP16/BF16),模型状态(Model States)的显存占用如下:

总计:仅维持模型状态的基础显存就需要 16N16N GB。对于 70B 模型,这高达 1.12 TB,远超单张 80G A100/H100 的容量。

传统的 Data Parallelism (DP) 要求每张卡保留完整的 16N16N GB 状态,产生极大冗余。ZeRO 和 FSDP 的核心思路就是打破 DP 的限制,将这 16N16N GB 的显存开销切片并分摊到集群的多个 GPU 上。


2. ZeRO 1/2/3:从优化器到权重的逐级切分

DeepSpeed 提出的 ZeRO(Zero Redundancy Optimizer)通过三个阶段逐步切分模型状态,核心在于显存与通信带宽的 Trade-off。

ZeRO-1:切分优化器状态(Optimizer States Partitioning)

机制:权重和梯度依然每卡保留全量(各 2N2N GB),但将最占显存的 Adam 状态(12N12N GB)切成若干份。卡 ii 只负责更新属于自己的那一块权重。

显存占用(GB):

2N+2N+12NGPUs2N + 2N + \frac{12N}{\text{GPUs}}

通信逻辑:反向传播(Backward)计算完全量梯度后,卡 ii 取出自己负责的梯度进行参数更新。更新完成后,触发一次 AllGather 操作,将所有卡更新后的局部权重广播,拼接成最新的完整权重。


ZeRO-2:切分梯度(Gradient Partitioning)

机制:既然卡 ii 只负责更新特定部分的权重,它也就只需要保留该部分的梯度,无需存储全量梯度。

显存占用(GB):

2N+2NGPUs+12NGPUs2N + \frac{2N}{\text{GPUs}} + \frac{12N}{\text{GPUs}}

通信逻辑:Backward 过程中,每计算出一层的梯度,立即执行一次 ReduceScatter。将属于卡 ii 的梯度归约并发送给卡 ii,同时释放其他卡上的这部分梯度显存。


ZeRO-3:切分权重(Parameter Partitioning)

机制:将 FP16 的权重也进行切分。每张卡平时只保留 1GPUs\frac{1}{\text{GPUs}} 的模型状态。

显存占用(GB):

16NGPUs\frac{16N}{\text{GPUs}}

以 70B 模型、64 张卡为例,每卡仅需约 17.5 GB,单张 A100 即可承载。

通信逻辑


3. FSDP:PyTorch 原生方案与系统调度差异

FSDP(Fully Sharded Data Parallel)是 PyTorch 原生的切分方案。宏观上,FSDP 的切分逻辑等价于 ZeRO-3(切分参数、梯度和优化器),但其在工程实现和系统调度上存在显著差异。

差异一:拦截层级(Tensor vs Module)

ZeRO-3 在底层的 Tensor 级别进行拦截和替换,动态分配显存。

FSDP 建立在 PyTorch 的 nn.Module 之上。它通过 AutoWrapPolicy 包装原生 Module,将一个 Module 内的参数打平为一维的 FlatParameter,并以此为单位进行 Sharding 和通信。

差异二:计算与通信的 Overlap

FSDP 能够深度利用 CUDA Streams 实现通信与计算的重叠(Overlap)。当 FSDP 计算第 LL 层前向传播时,可异步发起第 L+1L+1 层的 AllGather 通信。如果网络带宽充裕,通信耗时会被计算耗时完全掩盖,从而实现极高的 MFU(Model Flops Utilization)。

差异三:显存分配的确定性

ZeRO-3 运行时动态拉取和释放参数,在某些复杂网络拓扑下易引发显存碎片化。

FSDP 由于提前按 Module 定义了 Wrap 规则(如按 Transformer Block 包装),显存的分配和释放是静态且可预测的,降低了突发 OOM 的风险。


4. 总结:工业界的排障与选型

在千卡集群的实际训练中:

无论选择哪种架构,集群训练的真正难点在于应对 RDMA 网络抖动导致的 AllGather/ReduceScatter 通信死锁(Deadlock)。如何通过 Nsight Systems 对 GPU 的 Timeline 进行精准的 Profile,才是检验底层分布式训练能力的最终标准。

← 返回主页