如何保存和加载实验检查点#
实验检查点是 Tune 存储的三种数据类型之一。这些是用户定义的,用于快照您的训练进度!
实验级别的检查点通过 Tune Trainable API 保存:这是您定义自定义训练逻辑的方式,也是您定义要检查的实验状态的地方。在本指南中,我们将展示如何为 Tune 的 Function Trainable 和 Class Trainable API 保存和加载检查点,并介绍配置选项。
函数 API 检查点#
如果使用 Ray Tune 的函数 API,可以按以下方式保存和加载检查点。要创建检查点,请使用 `from_directory()` API。
import os
import tempfile
from ray import tune
from ray.tune import Checkpoint
def train_func(config):
start = 1
my_model = MyModel()
checkpoint = tune.get_checkpoint()
if checkpoint:
with checkpoint.as_directory() as checkpoint_dir:
checkpoint_dict = torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"))
start = checkpoint_dict["epoch"] + 1
my_model.load_state_dict(checkpoint_dict["model_state"])
for epoch in range(start, config["epochs"] + 1):
# Model training here
# ...
metrics = {"metric": 1}
with tempfile.TemporaryDirectory() as tempdir:
torch.save(
{"epoch": epoch, "model_state": my_model.state_dict()},
os.path.join(tempdir, "checkpoint.pt"),
)
tune.report(metrics=metrics, checkpoint=Checkpoint.from_directory(tempdir))
tuner = tune.Tuner(train_func, param_space={"epochs": 5})
result_grid = tuner.fit()
在上面的代码片段中
我们通过 `tune.report(..., checkpoint=checkpoint)` 实现检查点保存。请注意,每个检查点都必须与一组指标一起报告——这样,检查点就可以根据指定的指标进行排序。
在节点上训练时,训练迭代 `epoch` 期间保存的检查点会保存到路径 `
/ / /checkpoint_ `,并且可以根据存储配置进一步同步到统一的存储位置。 我们通过 `tune.get_checkpoint()` 实现检查点加载。当 Tune 恢复一个实验时,这个会用实验的最新检查点填充。当(1)实验配置为在遇到失败后重试,(2)实验正在恢复,以及(3)实验在暂停后恢复(例如 PBT)时,就会发生这种情况。
注意
`checkpoint_frequency` 和 `checkpoint_at_end` 将不适用于函数 API 检查点。这些是通过 Function Trainable 手动配置的。例如,如果您想每三个 epoch 进行一次检查点,可以通过以下方式实现:
NUM_EPOCHS = 12
# checkpoint every three epochs.
CHECKPOINT_FREQ = 3
def train_func(config):
for epoch in range(1, config["epochs"] + 1):
# Model training here
# ...
# Report metrics and save a checkpoint
metrics = {"metric": "my_metric"}
if epoch % CHECKPOINT_FREQ == 0:
with tempfile.TemporaryDirectory() as tempdir:
# Save a checkpoint in tempdir.
tune.report(metrics, checkpoint=Checkpoint.from_directory(tempdir))
else:
tune.report(metrics)
tuner = tune.Tuner(train_func, param_space={"epochs": NUM_EPOCHS})
result_grid = tuner.fit()
有关创建检查点的更多信息,请参阅 此处。
类 API 检查点#
您还可以使用 Trainable Class API 实现检查点/恢复
import os
import torch
from torch import nn
from ray import tune
class MyTrainableClass(tune.Trainable):
def setup(self, config):
self.model = nn.Sequential(
nn.Linear(config.get("input_size", 32), 32), nn.ReLU(), nn.Linear(32, 10)
)
def step(self):
return {}
def save_checkpoint(self, tmp_checkpoint_dir):
checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
torch.save(self.model.state_dict(), checkpoint_path)
return tmp_checkpoint_dir
def load_checkpoint(self, tmp_checkpoint_dir):
checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
self.model.load_state_dict(torch.load(checkpoint_path))
tuner = tune.Tuner(
MyTrainableClass,
param_space={"input_size": 64},
run_config=tune.RunConfig(
stop={"training_iteration": 2},
checkpoint_config=tune.CheckpointConfig(checkpoint_frequency=2),
),
)
tuner.fit()
您可以通过三种不同的机制进行检查点:手动、定期和在终止时。
通过 Trainable 手动检查点#
自定义 Trainable 可以通过在 `step` 的结果字典中返回 `should_checkpoint: True`(或 `tune.result.SHOULD_CHECKPOINT: True`)来手动触发检查点。这在抢占式实例上尤其有用
import random
# to be implemented by user.
def detect_instance_preemption():
choice = random.randint(1, 100)
# simulating a 1% chance of preemption.
return choice <= 1
def train_func(self):
# training code
result = {"mean_accuracy": "my_accuracy"}
if detect_instance_preemption():
result.update(should_checkpoint=True)
return result
在上面的示例中,如果 `detect_instance_preemption` 返回 True,则可以触发手动检查点。
通过 Tuner Callback 手动检查点#
与[通过 Trainable 手动检查点](#tune-class-trainable-checkpointing-manual-checkpointing)类似,您还可以通过 `Tuner` Callback 方法触发检查点,方法是在自定义回调的 `on_trial_result()` 方法中设置 `result["should_checkpoint"] = True`(或 `result[tune.result.SHOULD_CHECKPOINT] = True`)标志。与在 Trainable Class API 中进行检查点相比,此方法将检查点逻辑与训练逻辑分离,并提供了对所有 `Trial` 实例的访问,从而可以实现更复杂的检查点策略。
from ray import tune
from ray.tune.experiment import Trial
from ray.tune.result import SHOULD_CHECKPOINT, TRAINING_ITERATION
class CheckpointByStepsTaken(tune.Callback):
def __init__(self, iterations_per_checkpoint: int):
self.steps_per_checkpoint = iterations_per_checkpoint
self._trials_last_checkpoint = {}
def on_trial_result(
self, iteration: int, trials: list[Trial], trial: Trial, result: dict, **info
):
current_iteration = result[TRAINING_ITERATION]
if (
current_iteration - self._trials_last_checkpoint.get(trial, -1)
>= self.steps_per_checkpoint
):
result[SHOULD_CHECKPOINT] = True
self._trials_last_checkpoint[trial] = current_iteration
定期检查点#
可以通过设置 `checkpoint_frequency=N` 来实现,每 N 次迭代检查点一次,例如:
tuner = tune.Tuner(
MyTrainableClass,
run_config=tune.RunConfig(
stop={"training_iteration": 2},
checkpoint_config=tune.CheckpointConfig(checkpoint_frequency=10),
),
)
tuner.fit()
终止时检查点#
`checkpoint_frequency` 可能不与实验的确切结束点重合。如果您希望在实验结束时创建一个检查点,您可以额外设置 `checkpoint_at_end=True`
tuner = tune.Tuner(
MyTrainableClass,
run_config=tune.RunConfig(
stop={"training_iteration": 2},
checkpoint_config=tune.CheckpointConfig(
checkpoint_frequency=10, checkpoint_at_end=True
),
),
)
tuner.fit()
配置#
可以通过 `CheckpointConfig` 配置检查点。由于检查点频率是在用户定义的训练循环中手动确定的,因此其中一些配置不适用于 Function Trainable API。请参阅下面的兼容性矩阵。
类 API |
函数 API |
|
|---|---|---|
|
✅ |
✅ |
|
✅ |
✅ |
|
✅ |
✅ |
|
✅ |
❌ |
|
✅ |
❌ |
总结#
在本用户指南中,我们介绍了如何在 Tune 中保存和加载实验检查点。启用检查点后,请继续阅读以下指南,了解如何
附录:Tune 存储的数据类型#
实验检查点#
实验级别的检查点保存实验状态。这包括搜索器的状态、实验列表及其状态(例如,PENDING、RUNNING、TERMINATED、ERROR),以及每个实验相关的元数据(例如,超参数配置、一些派生的实验结果(min、max、last)等)。
实验级别的检查点由驱动程序定期在头节点上保存。默认情况下,其保存频率会自动调整,以便最多 5% 的时间用于保存实验检查点,剩余时间用于处理训练结果和调度。此时间也可以通过 `TUNE_GLOBAL_CHECKPOINT_S` 环境变量进行调整。
实验检查点#
实验级别的检查点捕获每个实验的状态。这通常包括模型和优化器状态。以下是实验检查点的几个用途:
如果实验因某种原因中断(例如,在抢占式实例上),可以从最后的状态恢复。不会丢失训练时间。
某些搜索器或调度程序会暂停实验,以便为其他实验腾出资源进行训练。这只有在实验可以从最新状态继续训练时才有意义。
之后,检查点可用于其他下游任务,如批量推理。
在此处了解如何保存和加载实验检查点。
实验结果#
实验报告的指标会被保存并记录到各自的实验目录中。这是以 CSV、JSON 或 Tensorboard (events.out.tfevents.*) 格式存储的数据,可以由 Tensorboard 检查并用于实验后分析。