Distributed checkpointing with KubeRay and GCSFuse

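This example orchestrates distributed checkpointing with KubeRay, using the GCSFuse CSI driver and Google Cloud Storage as the remote storage system. To illustrate the concepts, this guide uses the Finetuning a PyTorch Image Classifier with Ray Train example.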

Why distributed checkpointing with GCSFuse?

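In large-scale, high-performance machine learning, distributed checkpointing is crucial for fault tolerance: if a node fails during training, Ray can resume the process from the latest saved checkpoint instead of starting from scratch. While it's possible to reference remote storage paths directly (for example, gs://my-checkpoint-bucket), Google Cloud Storage FUSE (GCSFuse) has distinct advantages for distributed applications. GCSFuse lets you mount Cloud Storage buckets like local file systems, which makes checkpoint management more intuitive for distributed applications that rely on these semantics. In addition, GCSFuse is designed for high-performance workloads and delivers the performance and scalability needed for distributed checkpointing of large models.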

Combining distributed checkpointing with GCSFuse allows for larger-scale model training with increased availability and efficiency.
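To make the "local file system semantics" point concrete, the following is a minimal sketch, not the code used by the example job. It writes and reads a checkpoint with ordinary file I/O; a temporary directory stands in for the GCSFuse mount point, and the helper names and the `state.json` file name are assumptions made for illustration.

```python
import json
import os
import tempfile


def save_checkpoint(state: dict, root: str, step: int) -> str:
    """Write `state` under root/checkpoint_<step> using plain file I/O."""
    ckpt_dir = os.path.join(root, f"checkpoint_{step:06d}")
    os.makedirs(ckpt_dir, exist_ok=True)
    path = os.path.join(ckpt_dir, "state.json")  # hypothetical file name
    with open(path, "w") as f:
        json.dump(state, f)
    return path


def load_checkpoint(root: str, step: int) -> dict:
    """Read back the state written by save_checkpoint."""
    path = os.path.join(root, f"checkpoint_{step:06d}", "state.json")
    with open(path) as f:
        return json.load(f)


# A temporary directory stands in for the mounted bucket here.
# In the cluster described in this guide, root would be /mnt/cluster_storage.
with tempfile.TemporaryDirectory() as root:
    save_checkpoint({"epoch": 3, "loss": 0.09}, root, step=3)
    restored = load_checkpoint(root, step=3)
    print(restored)
```

Because the code only uses `open` and `os.makedirs`, the same functions would work unchanged against a mounted bucket path, which is the property that a `gs://` URL does not offer.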

Create a Kubernetes cluster on GKE

Create a GKE cluster with the GCSFuse CSI driver and Workload Identity enabled, as well as a GPU node pool with 4 L4 GPUs:

export PROJECT_ID=<your project id>
gcloud container clusters create kuberay-with-gcsfuse \
    --addons GcsFuseCsiDriver \
    --cluster-version=1.29.4 \
    --location=us-east4-c \
    --machine-type=g2-standard-8 \
    --release-channel=rapid \
    --num-nodes=4 \
    --accelerator type=nvidia-l4,count=1,gpu-driver-version=latest \
    --workload-pool=${PROJECT_ID}.svc.id.goog

Verify that your cluster was created successfully with 4 GPUs:

$ kubectl get nodes "-o=custom-columns=NAME:.metadata.name,GPU:.status.allocatable.nvidia\.com/gpu"
NAME                                                  GPU
gke-kuberay-with-gcsfuse-default-pool-xxxx-0000       1
gke-kuberay-with-gcsfuse-default-pool-xxxx-1111       1
gke-kuberay-with-gcsfuse-default-pool-xxxx-2222       1
gke-kuberay-with-gcsfuse-default-pool-xxxx-3333       1
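If you want to check the GPU count in a script rather than by eye, one possible approach is to parse the two-column output shown above. This is a sketch under the assumption that the output has exactly the `NAME` and `GPU` columns; the function name `total_gpus` is hypothetical.

```python
def total_gpus(kubectl_output: str) -> int:
    """Sum the GPU column of the custom-columns output shown above."""
    total = 0
    for line in kubectl_output.strip().splitlines()[1:]:  # skip the header row
        fields = line.split()
        # Nodes without allocatable GPUs report "<none>", which is skipped.
        if len(fields) == 2 and fields[1].isdigit():
            total += int(fields[1])
    return total


sample = """\
NAME                                                  GPU
gke-kuberay-with-gcsfuse-default-pool-xxxx-0000       1
gke-kuberay-with-gcsfuse-default-pool-xxxx-1111       1
gke-kuberay-with-gcsfuse-default-pool-xxxx-2222       1
gke-kuberay-with-gcsfuse-default-pool-xxxx-3333       1
"""

print(total_gpus(sample))
```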

Install the KubeRay operator

Follow Deploy a KubeRay operator to install the latest stable KubeRay operator from the Helm repository. If you set up the taint for the GPU node pool correctly, the KubeRay operator Pod must run on a CPU node.

Configure the GCS bucket

Create a GCS bucket that Ray uses as the remote filesystem:

BUCKET=<your GCS bucket>
gcloud storage buckets create gs://$BUCKET --uniform-bucket-level-access

Create a Kubernetes ServiceAccount that grants the RayCluster permission to mount the GCS bucket:

kubectl create serviceaccount pytorch-distributed-training

Bind the roles/storage.objectUser role to the Kubernetes service account in the bucket's IAM policy. See Identifying projects to find your project ID and project number:

PROJECT_ID=<your project ID>
PROJECT_NUMBER=<your project number>
gcloud storage buckets add-iam-policy-binding gs://${BUCKET} --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/default/sa/pytorch-distributed-training"  --role "roles/storage.objectUser"

For more details, see Access Cloud Storage buckets with the Cloud Storage FUSE CSI driver.
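The `--member` value in the command above is easy to mistype because it mixes the project number and the project ID. The sketch below only assembles the same string from its parts so that each component is visible; the function name and the sample project values are hypothetical.

```python
def workload_identity_member(
    project_number: str,
    project_id: str,
    namespace: str,
    service_account: str,
) -> str:
    """Build the principal string used in the add-iam-policy-binding command."""
    return (
        f"principal://iam.googleapis.com/projects/{project_number}"
        f"/locations/global/workloadIdentityPools/{project_id}.svc.id.goog"
        f"/subject/ns/{namespace}/sa/{service_account}"
    )


# Placeholder project values; the namespace and ServiceAccount match this guide.
member = workload_identity_member(
    project_number="123456789012",
    project_id="my-project",
    namespace="default",
    service_account="pytorch-distributed-training",
)
print(member)
```

Note that the project number appears in the `projects/` segment, while the project ID appears in the workload identity pool name.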

Deploy the RayJob

Download the RayJob that runs all the steps documented in Finetuning a PyTorch Image Classifier with Ray Train. The source code is also in the KubeRay repository.

curl -LO https://raw.githubusercontent.com/ray-project/kuberay/master/ray-operator/config/samples/pytorch-resnet-image-classifier/ray-job.pytorch-image-classifier.yaml

Modify the RayJob by replacing every GCS_BUCKET placeholder with the Google Cloud Storage bucket you created earlier. Alternatively, use sed:

sed -i "s/GCS_BUCKET/$BUCKET/g" ray-job.pytorch-image-classifier.yaml
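For environments where `sed -i` behaves differently (for example, BSD sed on macOS requires a backup-suffix argument), the same substitution can be sketched in Python. The function name is hypothetical and the manifest fragment is only a short stand-in for the downloaded file.

```python
def replace_placeholder(
    manifest_text: str, bucket: str, placeholder: str = "GCS_BUCKET"
) -> str:
    """Replace every occurrence of the placeholder, like sed's s/.../.../g."""
    return manifest_text.replace(placeholder, bucket)


# A short stand-in for the contents of ray-job.pytorch-image-classifier.yaml.
fragment = """\
csi:
  driver: gcsfuse.csi.storage.gke.io
  volumeAttributes:
    bucketName: GCS_BUCKET
"""

updated = replace_placeholder(fragment, "my-ray-bucket")
print(updated)
```

To apply it to the real file, read the manifest with `pathlib.Path.read_text`, pass the text through the function, and write the result back with `write_text`.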

Deploy the RayJob:

kubectl create -f ray-job.pytorch-image-classifier.yaml

The deployed RayJob includes the following configuration to enable distributed checkpointing to a shared filesystem:

  • 4 Ray workers, each with a single GPU.

  • All Ray nodes use the pytorch-distributed-training ServiceAccount created earlier.

  • Volumes managed by the gcsfuse.csi.storage.gke.io CSI driver.

  • A shared storage path mounted at /mnt/cluster_storage, backed by the GCS bucket you created earlier.

You can configure the Pod with annotations, which allow finer-grained control of the GCSFuse sidecar container. See Specify Pod annotations for more details.

annotations:
  gke-gcsfuse/volumes: "true"
  gke-gcsfuse/cpu-limit: "0"
  gke-gcsfuse/memory-limit: 5Gi
  gke-gcsfuse/ephemeral-storage-limit: 10Gi

You can also specify mount options when defining the GCSFuse container volume:

csi:
  driver: gcsfuse.csi.storage.gke.io
  volumeAttributes:
    bucketName: GCS_BUCKET
    mountOptions: "implicit-dirs,uid=1000,gid=100"

See Mount options to learn more about the available mount options.
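To show how the annotations and the CSI volume sit together in a Pod template, here is a sketch that assembles both fragments as a Python dictionary and prints them as JSON. It mirrors the values shown above; the function name and the volume name `cluster-storage` are assumptions, and the downloaded RayJob manifest remains the authoritative layout.

```python
import json


def gcsfuse_pod_fragment(bucket: str, mount_options: str) -> dict:
    """Return the Pod template pieces that enable the GCSFuse sidecar."""
    return {
        "metadata": {
            # Annotation values are strings, as Kubernetes requires.
            "annotations": {
                "gke-gcsfuse/volumes": "true",
                "gke-gcsfuse/cpu-limit": "0",
                "gke-gcsfuse/memory-limit": "5Gi",
                "gke-gcsfuse/ephemeral-storage-limit": "10Gi",
            }
        },
        "spec": {
            "serviceAccountName": "pytorch-distributed-training",
            "volumes": [
                {
                    "name": "cluster-storage",  # hypothetical volume name
                    "csi": {
                        "driver": "gcsfuse.csi.storage.gke.io",
                        "volumeAttributes": {
                            "bucketName": bucket,
                            "mountOptions": mount_options,
                        },
                    },
                }
            ],
        },
    }


fragment = gcsfuse_pod_fragment("my-ray-bucket", "implicit-dirs,uid=1000,gid=100")
print(json.dumps(fragment, indent=2))
```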

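The logs from the Ray job should indicate that it uses the shared remote filesystem at /mnt/cluster_storage and show the checkpoint directory. For example: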

Training finished iteration 10 at 2024-04-29 10:22:08. Total running time: 1min 30s
╭─────────────────────────────────────────╮
│ Training result                         │
├─────────────────────────────────────────┤
│ checkpoint_dir_name   checkpoint_000009 │
│ time_this_iter_s                6.47154 │
│ time_total_s                    74.5547 │
│ training_iteration                   10 │
│ acc                             0.24183 │
│ loss                            0.06882 │
╰─────────────────────────────────────────╯
Training saved a checkpoint for iteration 10 at: (local)/mnt/cluster_storage/finetune-resnet/TorchTrainer_cbb82_00000_0_2024-04-29_10-20-37/checkpoint_000009

Inspect checkpointing data

Once the RayJob completes, you can inspect the contents of your bucket with a tool such as gsutil:

gsutil ls gs://my-ray-bucket/**
gs://my-ray-bucket/finetune-resnet/
gs://my-ray-bucket/finetune-resnet/.validate_storage_marker
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000007/
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000007/checkpoint.pt
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000008/
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000008/checkpoint.pt
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000009/
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000009/checkpoint.pt
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/error.pkl
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/error.txt
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/events.out.tfevents.1714436502.orch-image-classifier-nc2sq-raycluster-tdrfx-head-xzcl8
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/events.out.tfevents.1714436809.orch-image-classifier-zz4sj-raycluster-vn7kz-head-lwx8k
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/params.json
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/params.pkl
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/progress.csv
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/result.json
gs://my-ray-bucket/finetune-resnet/basic-variant-state-2024-04-29_17-21-29.json
gs://my-ray-bucket/finetune-resnet/basic-variant-state-2024-04-29_17-26-35.json
gs://my-ray-bucket/finetune-resnet/experiment_state-2024-04-29_17-21-29.json
gs://my-ray-bucket/finetune-resnet/experiment_state-2024-04-29_17-26-35.json
gs://my-ray-bucket/finetune-resnet/trainer.pkl
gs://my-ray-bucket/finetune-resnet/tuner.pkl
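The trial directory above contains several `checkpoint_NNNNNN` subdirectories next to other files. As an illustration of how the most recent one could be located through the mounted path, here is a sketch that relies only on the naming pattern visible in the listing. The function name is hypothetical, and this is not how Ray Train selects checkpoints internally.

```python
import os
import re
import tempfile

# Matches directory names such as checkpoint_000009 from the listing above.
_CHECKPOINT_DIR = re.compile(r"^checkpoint_(\d{6})$")


def latest_checkpoint(trial_dir: str):
    """Return the path of the highest-numbered checkpoint dir, or None."""
    best = None
    for entry in os.listdir(trial_dir):
        match = _CHECKPOINT_DIR.match(entry)
        if match and os.path.isdir(os.path.join(trial_dir, entry)):
            index = int(match.group(1))
            if best is None or index > best[0]:
                best = (index, entry)
    return None if best is None else os.path.join(trial_dir, best[1])


# A temporary directory stands in for a trial directory under the mount.
with tempfile.TemporaryDirectory() as trial_dir:
    for name in ("checkpoint_000007", "checkpoint_000008", "checkpoint_000009"):
        os.makedirs(os.path.join(trial_dir, name))
    open(os.path.join(trial_dir, "params.json"), "w").close()
    latest_name = os.path.basename(latest_checkpoint(trial_dir))
    print(latest_name)
```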

Resume from checkpoint

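If a job fails, you can resume model training from the latest checkpoint. This example configures TorchTrainer to restore from the latest checkpoint automatically: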

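import os

from ray.train.torch import TorchTrainer

# train_loop_per_worker, train_loop_config, scaling_config, and run_config
# are defined earlier in the training script.
experiment_path = os.path.expanduser("/mnt/cluster_storage/finetune-resnet")
if TorchTrainer.can_restore(experiment_path):
    # Experiment state exists on the shared mount: resume from it.
    trainer = TorchTrainer.restore(
        experiment_path,
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config=train_loop_config,
        scaling_config=scaling_config,
        run_config=run_config,
    )
else:
    # No previous state found: start a fresh run.
    trainer = TorchTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config=train_loop_config,
        scaling_config=scaling_config,
        run_config=run_config,
    )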

You can verify automatic checkpoint recovery by redeploying the same RayJob:

kubectl create -f ray-job.pytorch-image-classifier.yaml

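If the previous job completed successfully, the training job should restore the checkpoint state from the checkpoint_000009 directory and then finish immediately after 0 iterations: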

2024-04-29 15:51:32,528 INFO experiment_state.py:366 -- Trying to find and download experiment checkpoint at /mnt/cluster_storage/finetune-resnet
2024-04-29 15:51:32,651 INFO experiment_state.py:396 -- A remote experiment checkpoint was found and will be used to restore the previous experiment state.
2024-04-29 15:51:32,652 INFO tune_controller.py:404 -- Using the newest experiment state file found within the experiment directory: experiment_state-2024-04-29_15-43-40.json

View detailed results here: /mnt/cluster_storage/finetune-resnet
To visualize your results with TensorBoard, run: `tensorboard --logdir /home/ray/ray_results/finetune-resnet`

Result(
  metrics={'loss': 0.070047477101968, 'acc': 0.23529411764705882},
  path='/mnt/cluster_storage/finetune-resnet/TorchTrainer_ecc04_00000_0_2024-04-29_15-43-40',
  filesystem='local',
  checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/finetune-resnet/TorchTrainer_ecc04_00000_0_2024-04-29_15-43-40/checkpoint_000009)
)

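If the previous job failed at an earlier checkpoint, the job should resume from the last saved checkpoint and run until max_epochs=10. For example, if the last run failed at the 7th epoch, training automatically resumes from checkpoint_000006 and runs 3 more iterations until epoch 10: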

(TorchTrainer pid=611, ip=10.108.2.65) Restored on 10.108.2.65 from checkpoint: Checkpoint(filesystem=local, path=/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000006)
(RayTrainWorker pid=671, ip=10.108.2.65) Setting up process group for: env:// [rank=0, world_size=4]
(TorchTrainer pid=611, ip=10.108.2.65) Started distributed worker processes:
(TorchTrainer pid=611, ip=10.108.2.65) - (ip=10.108.2.65, pid=671) world_rank=0, local_rank=0, node_rank=0
(TorchTrainer pid=611, ip=10.108.2.65) - (ip=10.108.1.83, pid=589) world_rank=1, local_rank=0, node_rank=1
(TorchTrainer pid=611, ip=10.108.2.65) - (ip=10.108.0.72, pid=590) world_rank=2, local_rank=0, node_rank=2
(TorchTrainer pid=611, ip=10.108.2.65) - (ip=10.108.3.76, pid=590) world_rank=3, local_rank=0, node_rank=3
(RayTrainWorker pid=589, ip=10.108.1.83) Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /home/ray/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
(RayTrainWorker pid=671, ip=10.108.2.65)
  0%|          | 0.00/97.8M [00:00<?, ?B/s]
(RayTrainWorker pid=671, ip=10.108.2.65)
 22%|██▏       | 21.8M/97.8M [00:00<00:00, 229MB/s]
(RayTrainWorker pid=671, ip=10.108.2.65)
 92%|█████████▏| 89.7M/97.8M [00:00<00:00, 327MB/s]
(RayTrainWorker pid=671, ip=10.108.2.65)
100%|██████████| 97.8M/97.8M [00:00<00:00, 316MB/s]
(RayTrainWorker pid=671, ip=10.108.2.65) Moving model to device: cuda:0
(RayTrainWorker pid=671, ip=10.108.2.65) Wrapping provided model in DistributedDataParallel.
(RayTrainWorker pid=671, ip=10.108.2.65) Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /home/ray/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth [repeated 3x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)
(RayTrainWorker pid=590, ip=10.108.3.76)
  0%|          | 0.00/97.8M [00:00<?, ?B/s] [repeated 3x across cluster]
(RayTrainWorker pid=590, ip=10.108.0.72)
 85%|████████▍ | 82.8M/97.8M [00:00<00:00, 256MB/s]
100%|██████████| 97.8M/97.8M [00:00<00:00, 231MB/s] [repeated 11x across cluster]
(RayTrainWorker pid=590, ip=10.108.3.76)
100%|██████████| 97.8M/97.8M [00:00<00:00, 238MB/s]
(RayTrainWorker pid=671, ip=10.108.2.65) Epoch 7-train Loss: 0.0903 Acc: 0.2418
(RayTrainWorker pid=671, ip=10.108.2.65) Epoch 7-val Loss: 0.0881 Acc: 0.2353
(RayTrainWorker pid=590, ip=10.108.0.72) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000007)
(RayTrainWorker pid=590, ip=10.108.0.72) Moving model to device: cuda:0 [repeated 3x across cluster]
(RayTrainWorker pid=590, ip=10.108.0.72) Wrapping provided model in DistributedDataParallel. [repeated 3x across cluster]

Training finished iteration 8 at 2024-04-29 17:27:29. Total running time: 54s
╭─────────────────────────────────────────╮
│ Training result                         │
├─────────────────────────────────────────┤
│ checkpoint_dir_name   checkpoint_000007 │
│ time_this_iter_s               40.46113 │
│ time_total_s                   95.00043 │
│ training_iteration                    8 │
│ acc                             0.23529 │
│ loss                            0.08811 │
╰─────────────────────────────────────────╯
Training saved a checkpoint for iteration 8 at: (local)/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000007
(RayTrainWorker pid=671, ip=10.108.2.65) Epoch 8-train Loss: 0.0893 Acc: 0.2459
(RayTrainWorker pid=671, ip=10.108.2.65) Epoch 8-val Loss: 0.0859 Acc: 0.2353
(RayTrainWorker pid=589, ip=10.108.1.83) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000008) [repeated 4x across cluster]

Training finished iteration 9 at 2024-04-29 17:27:36. Total running time: 1min 1s
╭─────────────────────────────────────────╮
│ Training result                         │
├─────────────────────────────────────────┤
│ checkpoint_dir_name   checkpoint_000008 │
│ time_this_iter_s                5.99923 │
│ time_total_s                  100.99965 │
│ training_iteration                    9 │
│ acc                             0.23529 │
│ loss                            0.08592 │
╰─────────────────────────────────────────╯
Training saved a checkpoint for iteration 9 at: (local)/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000008
2024-04-29 17:27:37,170 WARNING util.py:202 -- The `process_trial_save` operation took 0.540 s, which may be a performance bottleneck.
(RayTrainWorker pid=671, ip=10.108.2.65) Epoch 9-train Loss: 0.0866 Acc: 0.2377
(RayTrainWorker pid=671, ip=10.108.2.65) Epoch 9-val Loss: 0.0833 Acc: 0.2353
(RayTrainWorker pid=589, ip=10.108.1.83) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000009) [repeated 4x across cluster]

Training finished iteration 10 at 2024-04-29 17:27:43. Total running time: 1min 8s
╭─────────────────────────────────────────╮
│ Training result                         │
├─────────────────────────────────────────┤
│ checkpoint_dir_name   checkpoint_000009 │
│ time_this_iter_s                6.71457 │
│ time_total_s                  107.71422 │
│ training_iteration                   10 │
│ acc                             0.23529 │
│ loss                            0.08333 │
╰─────────────────────────────────────────╯
Training saved a checkpoint for iteration 10 at: (local)/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000009

Training completed after 10 iterations at 2024-04-29 17:27:45. Total running time: 1min 9s
2024-04-29 17:27:46,236 WARNING experiment_state.py:323 -- Experiment checkpoint syncing has been triggered multiple times in the last 30.0 seconds. A sync will be triggered whenever a trial has checkpointed more than `num_to_keep` times since last sync or if 300 seconds have passed since last sync. If you have set `num_to_keep` in your `CheckpointConfig`, consider increasing the checkpoint frequency or keeping more checkpoints. You can suppress this warning by changing the `TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S` environment variable.

Result(
  metrics={'loss': 0.08333033206416111, 'acc': 0.23529411764705882},
  path='/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29',
  filesystem='local',
  checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000009)
)
(RayTrainWorker pid=590, ip=10.108.3.76) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000009) [repeated 3x across cluster]
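The relationship between the checkpoint directory name and the number of remaining iterations follows from the logs above: checkpoint directories are numbered from zero, so checkpoint_000009 is written by training iteration 10. The sketch below encodes that arithmetic; the function name is hypothetical and the default of 10 mirrors max_epochs in this example.

```python
def remaining_iterations(checkpoint_dir_name: str, max_epochs: int = 10) -> int:
    """Iterations left to run after restoring from the given checkpoint."""
    index = int(checkpoint_dir_name.rsplit("_", 1)[1])
    completed = index + 1  # checkpoint_000006 is saved by iteration 7
    return max(max_epochs - completed, 0)


print(remaining_iterations("checkpoint_000006"))  # 3
print(remaining_iterations("checkpoint_000009"))  # 0
```

The first call matches the resumed run shown above, which trains epochs 7 through 9 and finishes at iteration 10. The second matches the case where the previous job succeeded and the restored job completes with 0 iterations.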