AI训练可能持续数天、数周甚至数月。不幸的是,这个过程经常容易受到意外中断的影响。无论您是想使用Spot实例节省高达3倍的成本,还是想防止随机的GPU故障毁掉您的一天,定期使用检查点保存进度是构建容错训练过程的关键一步

在这篇文章中,我们将探讨在云上加速AI模型检查点的技术,以及如何在SkyPilot中轻松实现这些技术。

提要

  • 使用高性能磁盘写入检查点。
  • 将检查点上传到云存储桶以安全地存储检查点。
  • 使用本地磁盘作为云存储桶的缓存,可将检查点速度提高9.6倍。

这是一个快速的SkyPilot YAML配置示例,展示了这种方法

resources:
  accelerators: A100:8
  disk_tier: best

workdir: .

file_mounts:
  /checkpoints:
    source: gs://my-checkpoints-bucket
    mode: MOUNT_CACHED

run: |
  python train.py --outputs /checkpoints  

检查点不仅仅是保存文件

检查点可能看起来很简单——只需将内存中的状态保存到磁盘——但现实要复杂得多。

具体来说,检查点会带来性能上的权衡。由于内存中的模型在导出到磁盘时必须保持不变,模型训练必须在写入检查点期间暂停。在等待磁盘I/O时保持GPU空闲会降低GPU利用率,导致成本显著增加(如下图所示,最高可达30%)。

通过最小化写入每个检查点所需的时间,从而最小化GPU的空闲时间,可以减轻检查点对性能的负面影响。写入检查点的时间取决于两个主要因素:检查点大小写入速度

在8x A100 GCP上微调Llama 7B,检查点周期为30分钟。

要写入的检查点大小与模型大小成正比。随着AI模型在LLM热潮后持续扩展,检查点的大小也随之增长,使得检查点的写入成本更高。如果没有适当的设置,您可能会发现检查点占用了高达三分之一的训练时间。在此期间,昂贵的GPU处于空闲状态,等待检查点操作完成。

技巧1:使用高性能磁盘

使用高端磁盘可显著提升检查点性能。如图所示,使用高端磁盘显著降低了检查点开销,在相同时间内可以进行更多训练计算。

在8x A100 GCP上使用高端磁盘微调Llama 7B。更快的磁盘减少了检查点花费的时间,提高了GPU利用率。

SkyPilot简化了高性能磁盘的选择

sky launch --disk-tier best train.yaml

或者,直接在SkyPilot YAML文件train.yaml中指定

resources:
  accelerators: A100:8
  disk_tier: best

使用disk_tier: best指示SkyPilot选择支持的性能最高的卷/磁盘。例如

  • 在AWS上,SkyPilot配置VM使用io2
  • 在GCP上,SkyPilot使用pd-ssd
  • 在Azure上,SkyPilot使用Premium_LRS磁盘

将检查点持久化到云存储桶

在云原生环境中,磁盘是短暂的。当VM或Pod被删除时,任何本地保存的检查点都可能随之消失。为了真正持久化检查点,必须将它们上传到与计算资源生命周期分离的远程存储。云存储桶就是一种远程存储检查点的选项。

云存储桶与常见的POSIX文件系统行为不同,因此拥有一套不同的API。由于这种差异,用于在磁盘上保存和加载检查点的PyTorch脚本无法直接在云存储桶上工作。

def save_training_state(step: int, model: torch.nn.Module, optimizer: torch.optim.Optimizer):
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'step': step
    }, f"checkpoints/checkpoint_{step}.pt")

技巧2:将检查点上传到云存储桶,无需修改代码

SkyPilot提供了一个简单的选项,可以将云存储桶挂载到您创建的Pod/VM上,这样您就可以像与磁盘上的普通路径交互一样,在存储桶上进行读写操作。*

file_mounts:
  /checkpoints:
    source: gs://my-checkpoint-bucket

train.yaml中添加此块后,写入路径/checkpoints的所有文件都会直接同步到指定的云存储桶:gs://my-checkpoint-bucket

这对小型模型效果很好,例如,一个1.3B的LLM模型只需24秒即可将检查点写入/checkpoints。然而,对于更大的模型,一次保存模型检查点可能需要8分钟以上。如果每30分钟写入一次检查点,这几乎占用了三分之一的计算时间。云存储桶上传的性能限制使得它的表现像低端磁盘一样,抵消了技巧1带来的改进。

参数大小检查点大小 (GB)检查点延迟 (s)
1.3B2.324
6.7B13491

* 默认挂载模式不支持某些特定的POSIX操作,例如随机写入或追加。

技巧3:使用本地磁盘作为云存储桶的缓存

在性能敏感的应用中,可以通过添加中间缓存来屏蔽慢速的IO操作。例如,CPU依赖多级硬件缓存来屏蔽与RAM交互的IO延迟。

同样的方法也可应用于挂载云存储桶。SkyPilot最近发布了一种新的云存储桶挂载模式MOUNT_CACHED,它带来了显著的9.6倍性能提升。

在8x A100 GCP上使用MOUNT_CACHED微调Llama 7B。使检查点上传异步化可解除对GPU的阻塞,提高利用率。

这只需对SkyPilot YAML文件train.yaml进行一行修改

file_mounts:
  /checkpoints:
    source: s3://my-checkpoint-bucket
    mode: MOUNT_CACHED

这种方法结合了前两个技巧的经验,提供了两全其美的方案。检查点最初保存到本地高性能磁盘,使训练可以尽快继续。当GPU持续工作时,检查点会异步上传到云存储桶进行持久存储。

通过将检查点上传到云存储桶的操作与训练过程交错进行,我们可以优化检查点速度和资源利用率。

额外技巧:使用ML框架提供的异步检查点

将缓慢的检查点操作移出训练的关键路径是一个众所周知的概念。一些ML框架实现了异步检查点功能,例如PyTorch的分布式异步检查点dcp.async_save),它利用RAM来加快阻塞式写入的完成速度。这些方法可以作为我们上面讨论的方法的补充,引入另一层缓存。

MOUNT_CACHED + 异步检查点的示例

结论

为了加速云上的模型检查点操作,请使用高端磁盘作为缓存,并异步将检查点上传到云存储桶。SkyPilot最近推出的针对云存储桶的MOUNT_CACHED模式,只需对SkyPilot YAML进行一行修改,即可将检查点性能提升9.6倍

# Install via: pip install 'skypilot-nightly[aws,gcp,azure,kubernetes]'

resources:
  accelerators: A100:8
  disk_tier: best

workdir: .

file_mounts:
  /checkpoints:
    source: gs://my-checkpoint-bucket
    mode: MOUNT_CACHED

run: |
  python train.py --outputs /checkpoints  

后续步骤

致谢:我们感谢Doyoung KimMOUNT_CACHED的初步设计和实现方面付出的巨大努力,并感谢Hriday Sheth对原型的早期基准测试和测试。