如何保存和加载 Trial Checkpoint#
Trial checkpoint 是 Tune 存储的三种数据类型之一。它们是用户定义的,用于快照你的训练进度!
Trial 级别的 checkpoint 通过 Tune Trainable API 保存:这是你定义自定义训练逻辑的方式,也是你定义要 checkpoint 的 Trial 状态的地方。本指南将展示如何为 Tune 的 Function Trainable 和 Class Trainable API 保存和加载 checkpoint,并引导你了解配置选项。
Function API Checkpointing#
如果使用 Ray Tune 的 Function API,可以按照以下方式保存和加载 checkpoint。要创建 checkpoint,请使用 from_directory()
API。
import os
import tempfile
from ray import tune
from ray.tune import Checkpoint
def train_func(config):
start = 1
my_model = MyModel()
checkpoint = tune.get_checkpoint()
if checkpoint:
with checkpoint.as_directory() as checkpoint_dir:
checkpoint_dict = torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"))
start = checkpoint_dict["epoch"] + 1
my_model.load_state_dict(checkpoint_dict["model_state"])
for epoch in range(start, config["epochs"] + 1):
# Model training here
# ...
metrics = {"metric": 1}
with tempfile.TemporaryDirectory() as tempdir:
torch.save(
{"epoch": epoch, "model_state": my_model.state_dict()},
os.path.join(tempdir, "checkpoint.pt"),
)
tune.report(metrics=metrics, checkpoint=Checkpoint.from_directory(tempdir))
tuner = tune.Tuner(train_func, param_space={"epochs": 5})
result_grid = tuner.fit()
在上面的代码片段中
我们使用
tune.report(..., checkpoint=checkpoint)
实现checkpoint 保存。请注意,每个 checkpoint 必须与一组指标一起报告——这样,checkpoint 就可以根据指定的指标排序。训练迭代
epoch
期间保存的 checkpoint 将保存到训练发生节点上的路径<storage_path>/<exp_name>/<trial_name>/checkpoint_<epoch>
,并可以根据存储配置进一步同步到统一的存储位置。我们使用
tune.get_checkpoint()
实现checkpoint 加载。每当 Tune 恢复 Trial 时,此函数将填充 Trial 的最新 checkpoint。这发生在以下情况:(1)Trial 配置为在遇到故障后重试,(2)实验正在恢复,以及(3)Trial 在暂停后恢复(例如:PBT)。
注意
checkpoint_frequency
和 checkpoint_at_end
不适用于 Function API checkpointing。这些需要通过 Function Trainable 手动配置。例如,如果你想每三个 epoch checkpoint 一次,可以通过以下方式实现
NUM_EPOCHS = 12
# checkpoint every three epochs.
CHECKPOINT_FREQ = 3
def train_func(config):
for epoch in range(1, config["epochs"] + 1):
# Model training here
# ...
# Report metrics and save a checkpoint
metrics = {"metric": "my_metric"}
if epoch % CHECKPOINT_FREQ == 0:
with tempfile.TemporaryDirectory() as tempdir:
# Save a checkpoint in tempdir.
tune.report(metrics, checkpoint=Checkpoint.from_directory(tempdir))
else:
tune.report(metrics)
tuner = tune.Tuner(train_func, param_space={"epochs": NUM_EPOCHS})
result_grid = tuner.fit()
请参阅此处 了解 有关 创建 checkpoint
的更多信息。
Class API Checkpointing#
你还可以使用 Trainable Class API 实现 checkpoint/恢复
import os
import torch
from torch import nn
from ray import tune
class MyTrainableClass(tune.Trainable):
def setup(self, config):
self.model = nn.Sequential(
nn.Linear(config.get("input_size", 32), 32), nn.ReLU(), nn.Linear(32, 10)
)
def step(self):
return {}
def save_checkpoint(self, tmp_checkpoint_dir):
checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
torch.save(self.model.state_dict(), checkpoint_path)
return tmp_checkpoint_dir
def load_checkpoint(self, tmp_checkpoint_dir):
checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
self.model.load_state_dict(torch.load(checkpoint_path))
tuner = tune.Tuner(
MyTrainableClass,
param_space={"input_size": 64},
run_config=tune.RunConfig(
stop={"training_iteration": 2},
checkpoint_config=tune.CheckpointConfig(checkpoint_frequency=2),
),
)
tuner.fit()
你可以通过三种不同的机制进行 checkpoint:手动、周期性和终止时。
手动 Checkpointing#
自定义 Trainable 可以通过在 step
的结果字典中返回 should_checkpoint: True
(或 tune.result.SHOULD_CHECKPOINT: True
) 来手动触发 checkpointing。这在 Spot 实例中特别有用
import random
# to be implemented by user.
def detect_instance_preemption():
choice = random.randint(1, 100)
# simulating a 1% chance of preemption.
return choice <= 1
def train_func(self):
# training code
result = {"mean_accuracy": "my_accuracy"}
if detect_instance_preemption():
result.update(should_checkpoint=True)
return result
在上面的例子中,如果 detect_instance_preemption
返回 True,则可以触发手动 checkpointing。
周期性 Checkpointing#
这可以通过设置 checkpoint_frequency=N
来启用,以便每 N 次迭代 checkpoint Trial,例如
tuner = tune.Tuner(
MyTrainableClass,
run_config=tune.RunConfig(
stop={"training_iteration": 2},
checkpoint_config=tune.CheckpointConfig(checkpoint_frequency=10),
),
)
tuner.fit()
终止时 Checkpointing#
checkpoint_frequency
可能与实验的精确结束时间不一致。如果你希望在 Trial 结束时创建一个 checkpoint,你可以额外设置 checkpoint_at_end=True
tuner = tune.Tuner(
MyTrainableClass,
run_config=tune.RunConfig(
stop={"training_iteration": 2},
checkpoint_config=tune.CheckpointConfig(
checkpoint_frequency=10, checkpoint_at_end=True
),
),
)
tuner.fit()
配置#
可以通过 CheckpointConfig
配置 Checkpointing。某些配置不适用于 Function Trainable API,因为 checkpointing 频率是在用户定义的训练循环中手动确定的。请参阅下面的兼容性矩阵。
Class API |
Function API |
|
---|---|---|
|
✅ |
✅ |
|
✅ |
✅ |
|
✅ |
✅ |
|
✅ |
❌ |
|
✅ |
❌ |
总结#
在本用户指南中,我们介绍了如何在 Tune 中保存和加载 Trial checkpoint。启用 checkpointing 后,请继续阅读以下指南之一,了解如何
附录:Tune 存储的数据类型#
实验 Checkpoint#
实验级 checkpoint 保存实验状态。这包括搜索器的状态、Trial 列表及其状态(例如 PENDING、RUNNING、TERMINATED、ERROR),以及与每个 Trial 相关的元数据(例如超参数配置、一些派生的 Trial 结果(最小、最大、最后)等)。
实验级 checkpoint 由驱动程序在 head 节点上周期性保存。默认情况下,保存频率会自动调整,以便最多 5% 的时间用于保存实验 checkpoint,其余时间用于处理训练结果和调度。此时间也可以通过 TUNE_GLOBAL_CHECKPOINT_S 环境变量进行调整。
Trial Checkpoint#
Trial 级别的 checkpoint 捕获每个 Trial 的状态。这通常包括模型和优化器状态。Trial checkpoint 的一些用途如下
如果 Trial 因某些原因中断(例如在 Spot 实例上),可以从上次状态恢复。不会丢失训练时间。
一些搜索器或调度器会暂停 Trial,以便为其他 Trial 腾出资源进行训练。这只有在 Trial 能够从最新状态继续训练时才有意义。
该 checkpoint 稍后可用于其他下游任务,例如批量推理。
在此了解如何保存和加载 Trial checkpoint。
Trial 结果#
Trial 报告的指标会保存并记录到各自的 Trial 目录中。这是以 CSV、JSON 或 Tensorboard (events.out.tfevents.*) 格式存储的数据,可以通过 Tensorboard 进行检查并用于实验后分析。