Saving and Loading Checkpoints#
Ray Train provides a way to snapshot training progress with Checkpoints.
This is useful for:
Storing the best-performing model weights: Save your model to persistent storage and use it for downstream serving or inference.
Fault tolerance: Handle worker process and node failures in long-running training jobs, and leverage preemptible machines.
Distributed checkpointing: Ray Train checkpoints can upload model shards from multiple workers in parallel.
Saving checkpoints during training#
``Checkpoint`` is a lightweight interface provided by Ray Train that represents a directory that lives on local or remote storage.
For example, a checkpoint could point to a directory in cloud storage: s3://my-bucket/my-checkpoint-dir. A locally available checkpoint points to a location on the local filesystem: /tmp/my-checkpoint-dir.
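As a minimal sketch (reusing the placeholder paths above), either location can be wrapped in a ``Checkpoint`` object:
from ray.train import Checkpoint

# A checkpoint backed by cloud storage.
remote_checkpoint = Checkpoint("s3://my-bucket/my-checkpoint-dir")

# A checkpoint backed by a local directory.
local_checkpoint = Checkpoint.from_directory("/tmp/my-checkpoint-dir")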
Here's how you save a checkpoint in the training loop:
1. Write your model checkpoint to a local directory.
Since a ``Checkpoint`` just points to a directory, its contents are completely up to you. This means that you can use any serialization format you want.
This makes it easy to use familiar checkpointing utilities provided by training frameworks, such as ``torch.save``, ``pl.Trainer.save_checkpoint``, Accelerate's ``accelerator.save_model``, Transformers' ``save_pretrained``, ``tf.keras.Model.save``, and so on.
2. Create a ``Checkpoint`` from the directory using ``Checkpoint.from_directory``.
3. Report the checkpoint to Ray Train using ``ray.train.report(metrics, checkpoint=...)``. The metrics reported alongside the checkpoint are used to keep track of the best-performing checkpoints.
This uploads the checkpoint to persistent storage if configured. See Configuring persistent storage.
The lifecycle of a ``Checkpoint``, from being saved locally to disk to being uploaded to persistent storage via ``train.report``.#
As shown in the figure above, the best practice for saving checkpoints is to first dump the checkpoint to a local temporary directory. Then, the call to ``train.report`` uploads the checkpoint to its final persistent storage location. The local temporary directory can then be safely cleaned up to free up disk space (for example, by exiting the ``tempfile.TemporaryDirectory`` context).
Tip
In standard DDP training, where each worker holds a full copy of the model, you should only save and report a checkpoint from a single worker to prevent redundant uploads.
This typically looks like:
import tempfile
from ray import train
from ray.train import Checkpoint
def train_fn(config):
...
metrics = {...}
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
checkpoint = None
# Only the global rank 0 worker saves and reports the checkpoint
if train.get_context().get_world_rank() == 0:
... # Save checkpoint to temp_checkpoint_dir
checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
train.report(metrics, checkpoint=checkpoint)
If using parallel training strategies such as DeepSpeed ZeRO and FSDP, where each worker only holds a shard of the full training state, you can save and report a checkpoint from every worker. See Saving checkpoints from multiple workers (distributed checkpointing) for an example.
Here are a few examples of saving checkpoints with different training frameworks:
import os
import tempfile
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
import ray.train.torch
from ray import train
from ray.train import Checkpoint, ScalingConfig
from ray.train.torch import TorchTrainer
def train_func(config):
n = 100
# create a toy dataset
# data : X - dim = (n, 4)
# target : Y - dim = (n, 1)
X = torch.Tensor(np.random.normal(0, 1, size=(n, 4)))
Y = torch.Tensor(np.random.uniform(0, 1, size=(n, 1)))
# toy neural network : 1-layer
# Wrap the model in DDP
model = ray.train.torch.prepare_model(nn.Linear(4, 1))
criterion = nn.MSELoss()
optimizer = Adam(model.parameters(), lr=3e-4)
for epoch in range(config["num_epochs"]):
y = model.forward(X)
loss = criterion(y, Y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
metrics = {"loss": loss.item()}
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
checkpoint = None
should_checkpoint = epoch % config.get("checkpoint_freq", 1) == 0
# In standard DDP training, where the model is the same across all ranks,
# only the global rank 0 worker needs to save and report the checkpoint
if train.get_context().get_world_rank() == 0 and should_checkpoint:
torch.save(
model.module.state_dict(), # NOTE: Unwrap the model.
os.path.join(temp_checkpoint_dir, "model.pt"),
)
checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
train.report(metrics, checkpoint=checkpoint)
trainer = TorchTrainer(
train_func,
train_loop_config={"num_epochs": 5},
scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
Tip
You most likely want to unwrap the DDP model before saving it to a checkpoint. ``model.module.state_dict()`` is the state dict without the ``"module."`` prefix.
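A minimal sketch of the difference, assuming ``model`` is the DDP-wrapped module from the example above:
# Keys of the wrapped model carry the "module." prefix, e.g. "module.weight".
print(list(model.state_dict().keys()))

# Keys of the unwrapped model match a plain nn.Linear(4, 1), e.g. "weight" and "bias",
# so this state dict can later be loaded into a non-DDP model directly.
print(list(model.module.state_dict().keys()))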
Ray Train leverages PyTorch Lightning's ``Callback`` interface to report metrics and checkpoints. We provide a simple callback implementation that reports on ``on_train_epoch_end``.
Specifically, at the end of each training epoch, it:
Collects all the logged metrics from ``trainer.callback_metrics``.
Saves a checkpoint via ``trainer.save_checkpoint``.
Reports to Ray Train via ``ray.train.report(metrics, checkpoint)``.
import pytorch_lightning as pl
from ray import train
from ray.train.lightning import RayTrainReportCallback
from ray.train.torch import TorchTrainer
class MyLightningModule(pl.LightningModule):
# ...
def on_validation_epoch_end(self):
...
mean_acc = calculate_accuracy()
self.log("mean_accuracy", mean_acc, sync_dist=True)
def train_func():
...
model = MyLightningModule(...)
datamodule = MyLightningDataModule(...)
trainer = pl.Trainer(
# ...
callbacks=[RayTrainReportCallback()]
)
trainer.fit(model, datamodule=datamodule)
ray_trainer = TorchTrainer(
train_func,
scaling_config=train.ScalingConfig(num_workers=2),
run_config=train.RunConfig(
checkpoint_config=train.CheckpointConfig(
num_to_keep=2,
checkpoint_score_attribute="mean_accuracy",
checkpoint_score_order="max",
),
),
)
You can always retrieve the saved checkpoint paths via ``result.checkpoint`` and ``result.best_checkpoints``.
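For example, a minimal sketch of launching the run configured above and reading back the latest checkpoint path (``ray_trainer`` is the ``TorchTrainer`` defined above):
result = ray_trainer.fit()

# The latest reported checkpoint; `.path` is its location under the storage path.
print(result.checkpoint.path)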
For more advanced usage (for example, reporting at a different frequency, or reporting custom checkpoint files), you can implement your own custom callback. Here is a simple example that reports a checkpoint every 3 epochs:
import os
from tempfile import TemporaryDirectory
from pytorch_lightning.callbacks import Callback
import ray
import ray.train
from ray.train import Checkpoint
class CustomRayTrainReportCallback(Callback):
def on_train_epoch_end(self, trainer, pl_module):
should_checkpoint = trainer.current_epoch % 3 == 0
with TemporaryDirectory() as tmpdir:
# Fetch metrics from `self.log(..)` in the LightningModule
metrics = trainer.callback_metrics
metrics = {k: v.item() for k, v in metrics.items()}
# Add customized metrics
metrics["epoch"] = trainer.current_epoch
metrics["custom_metric"] = 123
checkpoint = None
global_rank = ray.train.get_context().get_world_rank()
if global_rank == 0 and should_checkpoint:
# Save model checkpoint file to tmpdir
ckpt_path = os.path.join(tmpdir, "ckpt.pt")
trainer.save_checkpoint(ckpt_path, weights_only=False)
checkpoint = Checkpoint.from_directory(tmpdir)
# Report to train session
ray.train.report(metrics=metrics, checkpoint=checkpoint)
Ray Train leverages the Hugging Face Transformers Trainer's ``Callback`` interface to report metrics and checkpoints.
Option 1: Use Ray Train's default report callback
We provide a simple callback implementation ``RayTrainReportCallback`` that reports on checkpoint save. You can change the checkpointing frequency with ``save_strategy`` and ``save_steps``. It collects the latest logged metrics and reports them together with the latest saved checkpoint.
import transformers
from transformers import TrainingArguments
from ray import train
from ray.train.huggingface.transformers import RayTrainReportCallback, prepare_trainer
from ray.train.torch import TorchTrainer
def train_func(config):
...
# Configure logging, saving, evaluation strategies as usual.
args = TrainingArguments(
...,
evaluation_strategy="epoch",
save_strategy="epoch",
logging_strategy="step",
)
trainer = transformers.Trainer(args, ...)
# Add a report callback to transformers Trainer
# =============================================
trainer.add_callback(RayTrainReportCallback())
trainer = prepare_trainer(trainer)
trainer.train()
ray_trainer = TorchTrainer(
train_func,
run_config=train.RunConfig(
checkpoint_config=train.CheckpointConfig(
num_to_keep=3,
checkpoint_score_attribute="eval_loss", # The monitoring metric
checkpoint_score_order="min",
)
),
)
Note that ``RayTrainReportCallback`` binds the latest metrics and checkpoints together, so users should properly configure ``logging_strategy``, ``save_strategy``, and ``evaluation_strategy`` to ensure that the monitoring metric is logged at the same step as checkpoint saving.
For example, the evaluation metrics (``eval_loss`` here) are logged during evaluation. If users want to keep the best 3 checkpoints according to ``eval_loss``, they should align the saving and evaluation frequencies. Below are two examples of valid configurations:
args = TrainingArguments(
...,
evaluation_strategy="epoch",
save_strategy="epoch",
)
args = TrainingArguments(
...,
evaluation_strategy="steps",
save_strategy="steps",
eval_steps=50,
save_steps=100,
)
# And more ...
Option 2: Implement your customized report callback
If you feel that Ray Train's default ``RayTrainReportCallback`` is not sufficient for your use case, you can also implement a callback yourself! Below is an example implementation that collects the latest metrics and reports on checkpoint save.
import ray.train
from ray import train
from ray.train import Checkpoint

import transformers
from transformers.trainer_callback import TrainerCallback
class MyTrainReportCallback(TrainerCallback):
def __init__(self):
super().__init__()
self.metrics = {}
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
"""Log is called on evaluation step and logging step."""
self.metrics.update(logs)
def on_save(self, args, state, control, **kwargs):
"""Event called after a checkpoint save."""
checkpoint = None
if train.get_context().get_world_rank() == 0:
# Build a Ray Train Checkpoint from the latest checkpoint
checkpoint_path = transformers.trainer.get_last_checkpoint(args.output_dir)
checkpoint = Checkpoint.from_directory(checkpoint_path)
# Report to Ray Train with up-to-date metrics
ray.train.report(metrics=self.metrics, checkpoint=checkpoint)
# Clear the metrics buffer
self.metrics = {}
You can customize when (``on_save``, ``on_epoch_end``, ``on_evaluate``) and what (customized metrics and checkpoint files) to report by implementing your own Transformers Trainer callback.
Saving checkpoints from multiple workers (distributed checkpointing)#
In model parallel training strategies, where each worker only holds a shard of the full model, you can save and report checkpoint shards in parallel from each worker.
Distributed checkpointing in Ray Train. Each worker uploads its own checkpoint shard to persistent storage independently.#
Distributed checkpointing is the best practice for saving checkpoints when doing model-parallel training (for example, DeepSpeed, FSDP, Megatron-LM).
There are two major benefits:
It is faster, resulting in less idle time. Faster checkpointing incentivizes more frequent checkpointing! Each worker can upload its checkpoint shard in parallel, maximizing the network bandwidth of the cluster. Instead of a single node uploading the full model of size ``M``, the cluster distributes the load across ``N`` nodes, with each node uploading a shard of size ``M / N``.
Distributed checkpointing avoids needing to gather the full model onto a single worker's CPU memory. This gather operation puts a large CPU memory requirement on the worker that performs checkpointing, and is a common source of OOM errors.
Here is an example of distributed checkpointing with PyTorch:
import os
import tempfile

import torch

from ray import train
from ray.train import Checkpoint
from ray.train.torch import TorchTrainer
def train_func(config):
...
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
rank = train.get_context().get_world_rank()
torch.save(
...,
os.path.join(temp_checkpoint_dir, f"model-rank={rank}.pt"),
)
checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
train.report(metrics, checkpoint=checkpoint)
trainer = TorchTrainer(
train_func,
scaling_config=train.ScalingConfig(num_workers=2),
run_config=train.RunConfig(storage_path="s3://bucket/"),
)
# The checkpoint in cloud storage will contain: model-rank=0.pt, model-rank=1.pt
Note
Checkpoint files with the same name collide between workers. You can get around this by adding a rank-specific suffix to checkpoint filenames.
Note that a filename collision doesn't raise an error; instead, the last uploaded version is the one that gets persisted. This is fine if the file contents are the same across all workers.
Model shard saving utilities provided by frameworks such as DeepSpeed create rank-specific filenames automatically, so you usually don't need to worry about this.
Checkpoint upload modes#
By default, when you call ``report()``, Ray Train synchronously pushes your checkpoint from the local ``checkpoint.path`` on disk to the ``checkpoint_dir_name`` under your ``storage_path``. This is equivalent to calling ``report()`` with ``CheckpointUploadMode`` set to ``ray.train.CheckpointUploadMode.SYNC``.
import tempfile

from ray import train
from ray.train import Checkpoint

def train_fn(config):
...
metrics = {...}
with tempfile.TemporaryDirectory() as tmpdir:
... # Save checkpoint to tmpdir
checkpoint = Checkpoint.from_directory(tmpdir)
train.report(
metrics,
checkpoint=checkpoint,
checkpoint_upload_mode=train.CheckpointUploadMode.SYNC,
)
Asynchronous checkpoint uploads#
You may want to upload checkpoints asynchronously so that the next training step can start in parallel. If so, use ``ray.train.CheckpointUploadMode.ASYNC``, which launches a new thread to upload the checkpoint. This is useful for large checkpoints that may take longer to upload, but it may add unnecessary complexity if you just want to upload small checkpoints immediately (see below).
Each ``report`` call blocks until the checkpoint from the previous ``report`` has finished uploading before starting a new checkpoint upload thread. Ray Train does this to avoid accumulating too many upload threads and potentially running out of memory.
Because ``report`` returns without waiting for the checkpoint upload to finish, you must make sure that the local checkpoint directory exists until the upload completes. This means you can't use temporary directories that might be deleted before the upload finishes, such as those from ``tempfile.TemporaryDirectory``. ``report`` also exposes the ``delete_local_checkpoint_after_upload`` argument, which defaults to ``True`` when ``checkpoint_upload_mode`` is ``ray.train.CheckpointUploadMode.ASYNC``.
import tempfile

from ray import train
from ray.train import Checkpoint

def train_fn(config):
...
metrics = {...}
tmpdir = tempfile.mkdtemp()
... # Save checkpoint to tmpdir
checkpoint = Checkpoint.from_directory(tmpdir)
train.report(
metrics,
checkpoint=checkpoint,
checkpoint_upload_mode=train.CheckpointUploadMode.ASYNC,
)
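As a variation on the snippet above, you can override the default cleanup behavior if you want to keep the local copy after the asynchronous upload finishes. This is a sketch that reuses the ``metrics`` and ``checkpoint`` from the snippet above:
train.report(
    metrics,
    checkpoint=checkpoint,
    checkpoint_upload_mode=train.CheckpointUploadMode.ASYNC,
    # Defaults to True for ASYNC uploads; set it to False to keep the local checkpoint directory.
    delete_local_checkpoint_after_upload=False,
)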
This figure illustrates the difference between synchronous and asynchronous checkpoint uploads.#
Custom checkpoint uploads#
By default, ``report()`` uses PyArrow's filesystem copying utilities to upload the checkpoint from disk to the remote ``storage_path`` before reporting the checkpoint to Ray Train. If you would rather upload the checkpoint yourself, or with a third-party library such as Torch Distributed Checkpointing, you have the following options:
If you want to upload the checkpoint synchronously, you can upload the checkpoint to the ``storage_path`` first and then report a reference to the already uploaded checkpoint with ``ray.train.CheckpointUploadMode.NO_UPLOAD``.
import ray.train
from ray import train
from ray.train import Checkpoint

from s3torchconnector.dcp import S3StorageWriter
from torch.distributed.checkpoint.state_dict_saver import save
from torch.distributed.checkpoint.state_dict import get_state_dict
def train_fn(config):
...
for epoch in range(config["num_epochs"]):
# Directly upload checkpoint to s3 with Torch
model, optimizer = ...
storage_context = ray.train.get_context().get_storage()
checkpoint_path = (
f"s3://{storage_context.build_checkpoint_path_from_name(str(epoch))}"
)
storage_writer = S3StorageWriter(region="us-west-2", path=checkpoint_path)
model_dict, opt_dict = get_state_dict(model=model, optimizers=optimizer)
save(
{"model": model_dict, "opt": opt_dict},
storage_writer=storage_writer,
)
# Report that checkpoint to Ray Train
metrics = {...}
checkpoint = Checkpoint(checkpoint_path)
train.report(
metrics,
checkpoint=checkpoint,
checkpoint_upload_mode=train.CheckpointUploadMode.NO_UPLOAD,
)
If you want to upload the checkpoint asynchronously, you can set ``checkpoint_upload_mode`` to ``ray.train.CheckpointUploadMode.ASYNC`` and pass a ``checkpoint_upload_fn`` to ``ray.train.report``. This function takes in the ``Checkpoint`` and the ``checkpoint_dir_name`` passed to ``ray.train.report``, and returns the persisted ``Checkpoint``.
from torch.distributed.checkpoint.state_dict_saver import async_save
from s3torchconnector.dcp import S3StorageWriter
from torch.distributed.checkpoint.state_dict import get_state_dict
from ray import train
from ray.train import Checkpoint
def train_fn(config):
...
for epoch in range(config["num_epochs"]):
# Start async checkpoint upload to s3 with Torch
model, optimizer = ...
storage_context = train.get_context().get_storage()
checkpoint_path = (
f"s3://{storage_context.build_checkpoint_path_from_name(str(epoch))}"
)
storage_writer = S3StorageWriter(region="us-west-2", path=checkpoint_path)
model_dict, opt_dict = get_state_dict(model=model, optimizers=optimizer)
ckpt_ref = async_save(
{"model": model_dict, "opt": opt_dict},
storage_writer=storage_writer,
)
def wait_async_save(checkpoint, checkpoint_dir_name):
# This function waits for checkpoint to be finalized before returning it as is
ckpt_ref.result()
return checkpoint
# Ray Train kicks off a thread that waits for the async checkpoint upload to complete
# before reporting the checkpoint
metrics = {...}
checkpoint = Checkpoint(checkpoint_path)
train.report(
metrics=metrics,
checkpoint=checkpoint,
checkpoint_upload_mode=train.CheckpointUploadMode.ASYNC,
checkpoint_upload_fn=wait_async_save,
)
Warning
You shouldn't call ``ray.train.report`` inside your ``checkpoint_upload_fn``, as this may lead to unexpected behavior. You should also avoid collective operations, such as ``report()`` or ``model.state_dict()``, which may lead to deadlocks.
Note
Don't pass a ``checkpoint_upload_fn`` together with ``checkpoint_upload_mode=ray.train.CheckpointUploadMode.NO_UPLOAD``, because Ray Train simply ignores the ``checkpoint_upload_fn``. You can pass a ``checkpoint_upload_fn`` with ``checkpoint_upload_mode=ray.train.CheckpointUploadMode.SYNC``, but this is equivalent to uploading the checkpoint yourself and reporting it with ``ray.train.CheckpointUploadMode.NO_UPLOAD``.
Configuring checkpoints#
Ray Train provides a few configuration options for checkpointing via ``CheckpointConfig``. The primary configuration is keeping only the top ``K`` checkpoints with respect to a metric. Lower-performing checkpoints are deleted to save storage space. By default, all checkpoints are kept.
from ray.train import RunConfig, CheckpointConfig
# Example 1: Only keep the 2 *most recent* checkpoints and delete the others.
run_config = RunConfig(checkpoint_config=CheckpointConfig(num_to_keep=2))
# Example 2: Only keep the 2 *best* checkpoints and delete the others.
run_config = RunConfig(
checkpoint_config=CheckpointConfig(
num_to_keep=2,
# *Best* checkpoints are determined by these params:
checkpoint_score_attribute="mean_accuracy",
checkpoint_score_order="max",
),
# This will store checkpoints on S3.
storage_path="s3://remote-bucket/location",
)
Note
If you want to save the top ``num_to_keep`` checkpoints with respect to a metric via ``CheckpointConfig``, make sure that the metric is always reported together with the checkpoints.
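For instance, a minimal sketch that matches the ``mean_accuracy``-based configuration above (here ``train`` is ``ray.train``, and ``accuracy`` and ``checkpoint`` are placeholders computed in your training loop):
train.report(
    # Always include the scored metric ("mean_accuracy" here) when attaching a checkpoint.
    {"mean_accuracy": accuracy},
    checkpoint=checkpoint,
)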
Using checkpoints during training#
During training, you may want to access the checkpoints you have reported and their associated metrics for various reasons (for example, reporting the best-performing checkpoint to an experiment tracker). You can do this by calling ``get_all_reported_checkpoints()`` inside the training function. This function returns a list of ``ReportedCheckpoint`` objects, which represent all the ``Checkpoint``s you have reported so far and their associated metrics, subject to the checkpoints retained by your checkpoint configuration.
This function supports two consistency modes:
CheckpointConsistencyMode.COMMITTED: Blocks until the checkpoint from the latest ``ray.train.report`` has been uploaded to persistent storage and committed.
CheckpointConsistencyMode.VALIDATED: Blocks until the checkpoint from the latest ``ray.train.report`` has been uploaded to persistent storage, committed, and validated (see Validate checkpoints asynchronously). This is the default consistency mode, and it behaves the same as ``CheckpointConsistencyMode.COMMITTED`` if your reports don't launch validation.
import ray.train
from ray.train import CheckpointConsistencyMode
def train_fn():
for epoch in range(2):
metrics = {"train/loss": 0.1}
checkpoint = ...
ray.train.report(
metrics,
checkpoint=checkpoint,
validate_fn=...,
validate_config=...,
)
# Get committed checkpoints which may still have ongoing validations.
committed_checkpoints = ray.train.get_all_reported_checkpoints(
consistency_mode=CheckpointConsistencyMode.COMMITTED)
# Wait for all pending validations to finish to access reported checkpoints
# with validation metrics attached.
validated_checkpoints = ray.train.get_all_reported_checkpoints()
...
Using checkpoints after training#
The latest saved checkpoint can be accessed with ``Result.checkpoint``.
The full list of persisted checkpoints can be accessed with ``Result.best_checkpoints``. If ``CheckpointConfig(num_to_keep)`` is set, this list contains the best ``num_to_keep`` checkpoints.
See Inspecting Training Results for a full guide on inspecting training results.
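A minimal sketch of iterating over the kept checkpoints and their reported metrics, assuming a ``result`` returned by ``trainer.fit()`` and a reported metric named ``mean_accuracy``:
# Each entry is a (Checkpoint, metrics) pair.
for checkpoint, metrics in result.best_checkpoints:
    print(checkpoint.path, metrics.get("mean_accuracy"))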
``Checkpoint.as_directory`` and ``Checkpoint.to_directory`` are the two main APIs for interacting with Train checkpoints:
from pathlib import Path
from ray.train import Checkpoint
# For demonstration, create a locally available directory with a `model.pt` file.
example_checkpoint_dir = Path("/tmp/test-checkpoint")
example_checkpoint_dir.mkdir()
example_checkpoint_dir.joinpath("model.pt").touch()
# Create the checkpoint, which is a reference to the directory.
checkpoint = Checkpoint.from_directory(example_checkpoint_dir)
# Inspect the checkpoint's contents with either `as_directory` or `to_directory`:
with checkpoint.as_directory() as checkpoint_dir:
assert Path(checkpoint_dir).joinpath("model.pt").exists()
checkpoint_dir = checkpoint.to_directory()
assert Path(checkpoint_dir).joinpath("model.pt").exists()
For Lightning and Transformers, if you are using the default ``RayTrainReportCallback`` for checkpoint saving in your training function, you can retrieve the original checkpoint files as follows:
# After training finished
checkpoint = result.checkpoint
with checkpoint.as_directory() as checkpoint_dir:
lightning_checkpoint_path = f"{checkpoint_dir}/checkpoint.ckpt"
# After training finished
checkpoint = result.checkpoint
with checkpoint.as_directory() as checkpoint_dir:
hf_checkpoint_path = f"{checkpoint_dir}/checkpoint/"
Restoring training state from a checkpoint#
For fault tolerance, you should modify your training loop to restore training state from a ``Checkpoint``.
The ``Checkpoint`` to restore from can be accessed in the training function with ``ray.train.get_checkpoint``.
The checkpoint returned by ``ray.train.get_checkpoint`` is populated with the latest reported checkpoint during automatic failure recovery.
See Handling Failures and Node Preemption for more details on restoration and fault tolerance.
import os
import tempfile
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
import ray.train.torch
from ray import train
from ray.train import Checkpoint, ScalingConfig
from ray.train.torch import TorchTrainer
def train_func(config):
n = 100
# create a toy dataset
# data : X - dim = (n, 4)
# target : Y - dim = (n, 1)
X = torch.Tensor(np.random.normal(0, 1, size=(n, 4)))
Y = torch.Tensor(np.random.uniform(0, 1, size=(n, 1)))
# toy neural network : 1-layer
model = nn.Linear(4, 1)
optimizer = Adam(model.parameters(), lr=3e-4)
criterion = nn.MSELoss()
# Wrap the model in DDP and move it to GPU.
model = ray.train.torch.prepare_model(model)
# ====== Resume training state from the checkpoint. ======
start_epoch = 0
checkpoint = train.get_checkpoint()
if checkpoint:
with checkpoint.as_directory() as checkpoint_dir:
model_state_dict = torch.load(
os.path.join(checkpoint_dir, "model.pt"),
# map_location=..., # Load onto a different device if needed.
)
model.module.load_state_dict(model_state_dict)
optimizer.load_state_dict(
torch.load(os.path.join(checkpoint_dir, "optimizer.pt"))
)
start_epoch = (
torch.load(os.path.join(checkpoint_dir, "extra_state.pt"))["epoch"] + 1
)
# ========================================================
for epoch in range(start_epoch, config["num_epochs"]):
y = model.forward(X)
loss = criterion(y, Y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
metrics = {"loss": loss.item()}
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
checkpoint = None
should_checkpoint = epoch % config.get("checkpoint_freq", 1) == 0
# In standard DDP training, where the model is the same across all ranks,
# only the global rank 0 worker needs to save and report the checkpoint
if train.get_context().get_world_rank() == 0 and should_checkpoint:
# === Make sure to save all state needed for resuming training ===
torch.save(
model.module.state_dict(), # NOTE: Unwrap the model.
os.path.join(temp_checkpoint_dir, "model.pt"),
)
torch.save(
optimizer.state_dict(),
os.path.join(temp_checkpoint_dir, "optimizer.pt"),
)
torch.save(
{"epoch": epoch},
os.path.join(temp_checkpoint_dir, "extra_state.pt"),
)
# ================================================================
checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
train.report(metrics, checkpoint=checkpoint)
if epoch == 1:
raise RuntimeError("Intentional error to showcase restoration!")
trainer = TorchTrainer(
train_func,
train_loop_config={"num_epochs": 5},
scaling_config=ScalingConfig(num_workers=2),
run_config=train.RunConfig(failure_config=train.FailureConfig(max_failures=1)),
)
result = trainer.fit()
import os

import pytorch_lightning as pl

from ray import train
from ray.train import Checkpoint
from ray.train.torch import TorchTrainer
from ray.train.lightning import RayTrainReportCallback
def train_func():
model = MyLightningModule(...)
datamodule = MyLightningDataModule(...)
trainer = pl.Trainer(..., callbacks=[RayTrainReportCallback()])
checkpoint = train.get_checkpoint()
if checkpoint:
with checkpoint.as_directory() as ckpt_dir:
ckpt_path = os.path.join(ckpt_dir, RayTrainReportCallback.CHECKPOINT_NAME)
trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
else:
trainer.fit(model, datamodule=datamodule)
ray_trainer = TorchTrainer(
train_func,
scaling_config=train.ScalingConfig(num_workers=2),
run_config=train.RunConfig(
checkpoint_config=train.CheckpointConfig(num_to_keep=2),
),
)
Note
In these examples, ``Checkpoint.as_directory`` is used to view the checkpoint contents as a local directory.
If the checkpoint points to a local directory, this method just returns the local directory path without making a copy.
If the checkpoint points to a remote directory, this method downloads the checkpoint to a local temporary directory and returns the path to that temporary directory.
If multiple processes on the same node call this method simultaneously, only one process performs the download while the others wait for it to finish. Once the download finishes, all processes receive the same local (temporary) directory to read from.
Once all processes have finished working with the checkpoint, the temporary directory is cleaned up.