技术指南

激活重新计算权衡

激活重新计算(梯度或激活检查点)通过丢弃前向传递中的中间激活并在后向传递中重新计算它们,在训练期间节省 GPU 内存。

概述

激活重新计算(梯度或激活检查点)通过丢弃前向传递中的中间激活并在后向传递中重新计算它们,在训练期间节省 GPU 内存。它以额外的计算能力换取在相同硬件上训练更大模型或更长序列的能力。

激活重新计算权衡是一个技术构建块,会大规模影响模型质量、基础设施成本、延迟和可靠性。

深入探讨

反向传播需要前向传递激活来计算梯度,因此默认情况下存储每一层的输出——巨大的内存成本随着模型大小、批量大小和序列长度的增加而增长。激活重新计算仅保留一些“检查点”张量(通常只是层边界)并丢弃其余部分。在向后传递期间,它重新运行检查点之间的前向计算,以根据需要重新生成丢弃的激活。经典的结果是,在每 sqrt(N) 层放置检查点时,内存会下降到大约 O(sqrt(N)),同时增加大约一次额外的前向传递(计算量增加约 33%)。选择性变体仅重新计算廉价但占用大量内存的操作(例如注意力或丢失),同时缓存昂贵的操作,以少得多的重新计算开销获得大部分内存节省。

技术洞察

基本的权衡是内存与 FLOPs。完全重新计算大约每步增加一次额外的前向传递(慢约 30-40%),但可以将激活内存减少一个数量级。明智之举是选择性检查点:识别内存大但计算成本低的操作(softmax、layernorm、GELU、注意力分数)并仅重新计算这些操作,同时缓存昂贵的 GEMM 的结果 - 最大限度地减少计算浪费。

掌握激活重新计算权衡

激活重新计算(梯度或激活检查点)通过丢弃前向传递中的中间激活并在后向传递中重新计算它们,在训练期间节省 GPU 内存。它以额外的计算能力换取在相同硬件上训练更大模型或更长序列的能力。激活重新计算权衡是一个技术构建块,会大规模影响模型质量、基础设施成本、延迟和可靠性。为了建立深入的理解,请将激活重新计算权衡视为一种操作模型,而不是单个功能:定义所需的结果,澄清假设,并将系统可以可靠地执行的操作与仍需要专家判断的操作分开。

在实践中,使用激活重新计算权衡的强大团队可以根据可靠性和成本来优化架构、数据和基础设施选择。他们记录明确的成功标准,根据实际数据和工作流程进行测试,并根据观察到的失败模式而不是一次性基准测试胜利进行迭代。这就是理论理解转变为跨产品、政策和运营的持久能力的地方。

多年来,架构决策决定着性能和运营成本。与此同时,优化一个基准测试可以隐藏更广泛的系统弱点。最具弹性的方法是将实验速度与治理规则结合起来:运行试点、捕获证据、发布决策日志,并随着模型行为、用户期望和监管要求的发展不断更新保障措施。

战略影响

多年来,架构决策决定着性能和运营成本。

多年来,架构决策决定着性能和运营成本。在高质量部署中,这会转化为可衡量的操作规则、所有权边界和定期审查仪式,以便团队可以增强信心,而不是扩大模糊性。

技术教育帮助团队选择正确的堆栈,而不仅仅是最新的堆栈。

技术教育帮助团队选择正确的堆栈,而不仅仅是最新的堆栈。在高质量部署中,这会转化为可衡量的操作规则、所有权边界和定期审查仪式,以便团队可以增强信心,而不是扩大模糊性。

更好的工程选择可以减少生产中的可靠性事故。

更好的工程选择可以减少生产中的可靠性事故。在高质量部署中,这会转化为可衡量的操作规则、所有权边界和定期审查仪式,以便团队可以增强信心,而不是扩大模糊性。

激活重新计算权衡的未来

重新计算变得越来越自动化和选择性。现在,框架会分析每个操作的内存和 FLOP 成本,以选择最佳检查点,并将重新计算与激活卸载到 CPU/NVMe 以及并行策略相结合。随着上下文长度和模型大小不断增长,预计编译器驱动的策略(在 PyTorch、JAX/XLA 中)会自动选择每个操作的重新计算决策,再加上重新计算与通信的更紧密重叠,因此额外的 FLOP 被部分隐藏。

现实世界的实施

通过检查每个层块来训练一个大型变压器,否则该变压器无法适应

使用 PyTorch 的 torch.utils.checkpoint 包装变压器块并切割激活内存

Megatron-LM 中注意力/softmax 的选择性重新计算,以节省内存且速度减慢最小

通过重新计算激活而不是存储它们,在固定的 GPU 预算上实现更长的序列长度

实施模式

实践中激活重新计算的权衡

通过检查每个层块来训练一个无法适应的大型变压器。

通过对每个层块设置检查点来训练原本不适合的大型变压器 团队在预先定义质量阈值、为边缘情况保留人工升级路径并随着时间的推移跟踪生产力增益和错误成本时通常会获得更好的结果。

实践中激活重新计算的权衡

使用 PyTorch 的 torch.utils.checkpoint 包装变压器块并剪切激活内存。

使用 PyTorch 的 torch.utils.checkpoint 来包装变压器块并减少激活内存 当团队预先定义质量阈值、为边缘情况保留人工升级路径并随着时间的推移跟踪生产力增益和错误成本时,通常会获得更好的结果。

实践中激活重新计算的权衡

Megatron-LM 中选择性重新计算注意力/softmax,以节省内存并最小化速度减慢。

在 Megatron-LM 中选择性地重新计算注意力/softmax,以最小化速度节省内存 当团队预先定义质量阈值、为边缘情况保留人工升级路径并随着时间的推移跟踪生产力增益和错误成本时,通常会获得更好的结果。

实践中激活重新计算的权衡

通过重新计算激活而不是存储它们,在固定的 GPU 预算上实现更长的序列长度。

通过重新计算激活而不是存储它们,在固定的 GPU 预算上实现更长的序列长度 当团队预先定义质量阈值、为边缘情况保留人工升级路径并随着时间的推移跟踪生产力增益和错误成本时,通常会获得更好的结果。

风险与防护栏

!

优化一项基准测试可以隐藏更广泛的系统弱点。

!

基础设施和维护成本常常被低估。

!

随着系统变得更加复杂,安全性和可观察性差距可能会扩大。

实施路线图

1

在实施之前定义延迟、质量和成本目标。

在实施之前定义延迟、质量和成本目标。将每个步骤视为证据门:如果不满足标准,则暂停推出,缩小差距,然后再扩大使用。

2

在实际负载和数据条件下进行基准测试。

在实际负载和数据条件下进行基准测试。将每个步骤视为证据门:如果不满足标准,则暂停推出,缩小差距,然后再扩大使用。

3

仪器监控错误、漂移和用户影响。

仪器监控错误、漂移和用户影响。将每个步骤视为证据门:如果不满足标准,则暂停推出,缩小差距,然后再扩大使用。

4

在扩展之前准备回滚和事件响应路径。

在扩展之前准备回滚和事件响应路径。将每个步骤视为证据门:如果不满足标准,则暂停推出,缩小差距,然后再扩大使用。

不断探索