如何保存和加载实验检查点#

实验检查点是 Tune 存储的三种数据类型之一。这些是用户定义的,用于快照您的训练进度!

实验级别的检查点通过 Tune Trainable API 保存:这是您定义自定义训练逻辑的方式,也是您定义要检查的实验状态的地方。在本指南中,我们将展示如何为 Tune 的 Function Trainable 和 Class Trainable API 保存和加载检查点,并介绍配置选项。

函数 API 检查点#

如果使用 Ray Tune 的函数 API,可以按以下方式保存和加载检查点。要创建检查点,请使用 `from_directory()` API。

import os
import tempfile

from ray import tune
from ray.tune import Checkpoint


def train_func(config):
    start = 1
    my_model = MyModel()

    checkpoint = tune.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            checkpoint_dict = torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"))
            start = checkpoint_dict["epoch"] + 1
            my_model.load_state_dict(checkpoint_dict["model_state"])

    for epoch in range(start, config["epochs"] + 1):
        # Model training here
        # ...

        metrics = {"metric": 1}
        with tempfile.TemporaryDirectory() as tempdir:
            torch.save(
                {"epoch": epoch, "model_state": my_model.state_dict()},
                os.path.join(tempdir, "checkpoint.pt"),
            )
            tune.report(metrics=metrics, checkpoint=Checkpoint.from_directory(tempdir))


tuner = tune.Tuner(train_func, param_space={"epochs": 5})
result_grid = tuner.fit()

在上面的代码片段中

  • 我们通过 `tune.report(..., checkpoint=checkpoint)` 实现检查点保存。请注意,每个检查点都必须与一组指标一起报告——这样,检查点就可以根据指定的指标进行排序。

  • 在节点上训练时,训练迭代 `epoch` 期间保存的检查点会保存到路径 `///checkpoint_`,并且可以根据存储配置进一步同步到统一的存储位置。

  • 我们通过 `tune.get_checkpoint()` 实现检查点加载。当 Tune 恢复一个实验时,这个会用实验的最新检查点填充。当(1)实验配置为在遇到失败后重试,(2)实验正在恢复,以及(3)实验在暂停后恢复(例如 PBT)时,就会发生这种情况。

注意

`checkpoint_frequency` 和 `checkpoint_at_end` 将不适用于函数 API 检查点。这些是通过 Function Trainable 手动配置的。例如,如果您想每三个 epoch 进行一次检查点,可以通过以下方式实现:

NUM_EPOCHS = 12
# checkpoint every three epochs.
CHECKPOINT_FREQ = 3


def train_func(config):
    for epoch in range(1, config["epochs"] + 1):
        # Model training here
        # ...

        # Report metrics and save a checkpoint
        metrics = {"metric": "my_metric"}
        if epoch % CHECKPOINT_FREQ == 0:
            with tempfile.TemporaryDirectory() as tempdir:
                # Save a checkpoint in tempdir.
                tune.report(metrics, checkpoint=Checkpoint.from_directory(tempdir))
        else:
            tune.report(metrics)


tuner = tune.Tuner(train_func, param_space={"epochs": NUM_EPOCHS})
result_grid = tuner.fit()

有关创建检查点的更多信息,请参阅 此处

类 API 检查点#

您还可以使用 Trainable Class API 实现检查点/恢复

import os
import torch
from torch import nn

from ray import tune


class MyTrainableClass(tune.Trainable):
    def setup(self, config):
        self.model = nn.Sequential(
            nn.Linear(config.get("input_size", 32), 32), nn.ReLU(), nn.Linear(32, 10)
        )

    def step(self):
        return {}

    def save_checkpoint(self, tmp_checkpoint_dir):
        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
        torch.save(self.model.state_dict(), checkpoint_path)
        return tmp_checkpoint_dir

    def load_checkpoint(self, tmp_checkpoint_dir):
        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
        self.model.load_state_dict(torch.load(checkpoint_path))


tuner = tune.Tuner(
    MyTrainableClass,
    param_space={"input_size": 64},
    run_config=tune.RunConfig(
        stop={"training_iteration": 2},
        checkpoint_config=tune.CheckpointConfig(checkpoint_frequency=2),
    ),
)
tuner.fit()

您可以通过三种不同的机制进行检查点:手动、定期和在终止时。

通过 Trainable 手动检查点#

自定义 Trainable 可以通过在 `step` 的结果字典中返回 `should_checkpoint: True`(或 `tune.result.SHOULD_CHECKPOINT: True`)来手动触发检查点。这在抢占式实例上尤其有用

import random


# to be implemented by user.
def detect_instance_preemption():
    choice = random.randint(1, 100)
    # simulating a 1% chance of preemption.
    return choice <= 1


def train_func(self):
    # training code
    result = {"mean_accuracy": "my_accuracy"}
    if detect_instance_preemption():
        result.update(should_checkpoint=True)
    return result


在上面的示例中,如果 `detect_instance_preemption` 返回 True,则可以触发手动检查点。

通过 Tuner Callback 手动检查点#

与[通过 Trainable 手动检查点](#tune-class-trainable-checkpointing-manual-checkpointing)类似,您还可以通过 `Tuner` Callback 方法触发检查点,方法是在自定义回调的 `on_trial_result()` 方法中设置 `result["should_checkpoint"] = True`(或 `result[tune.result.SHOULD_CHECKPOINT] = True`)标志。与在 Trainable Class API 中进行检查点相比,此方法将检查点逻辑与训练逻辑分离,并提供了对所有 `Trial` 实例的访问,从而可以实现更复杂的检查点策略。

from ray import tune
from ray.tune.experiment import Trial
from ray.tune.result import SHOULD_CHECKPOINT, TRAINING_ITERATION


class CheckpointByStepsTaken(tune.Callback):
    def __init__(self, iterations_per_checkpoint: int):
        self.steps_per_checkpoint = iterations_per_checkpoint
        self._trials_last_checkpoint = {}

    def on_trial_result(
        self, iteration: int, trials: list[Trial], trial: Trial, result: dict, **info
    ):
        current_iteration = result[TRAINING_ITERATION]
        if (
            current_iteration - self._trials_last_checkpoint.get(trial, -1)
            >= self.steps_per_checkpoint
        ):
            result[SHOULD_CHECKPOINT] = True
            self._trials_last_checkpoint[trial] = current_iteration


定期检查点#

可以通过设置 `checkpoint_frequency=N` 来实现,每 N 次迭代检查点一次,例如:


tuner = tune.Tuner(
    MyTrainableClass,
    run_config=tune.RunConfig(
        stop={"training_iteration": 2},
        checkpoint_config=tune.CheckpointConfig(checkpoint_frequency=10),
    ),
)
tuner.fit()

终止时检查点#

`checkpoint_frequency` 可能不与实验的确切结束点重合。如果您希望在实验结束时创建一个检查点,您可以额外设置 `checkpoint_at_end=True`

tuner = tune.Tuner(
    MyTrainableClass,
    run_config=tune.RunConfig(
        stop={"training_iteration": 2},
        checkpoint_config=tune.CheckpointConfig(
            checkpoint_frequency=10, checkpoint_at_end=True
        ),
    ),
)
tuner.fit()

配置#

可以通过 `CheckpointConfig` 配置检查点。由于检查点频率是在用户定义的训练循环中手动确定的,因此其中一些配置不适用于 Function Trainable API。请参阅下面的兼容性矩阵。

类 API

函数 API

num_to_keep

checkpoint_score_attribute

checkpoint_score_order

checkpoint_frequency

checkpoint_at_end

总结#

在本用户指南中,我们介绍了如何在 Tune 中保存和加载实验检查点。启用检查点后,请继续阅读以下指南,了解如何

附录:Tune 存储的数据类型#

实验检查点#

实验级别的检查点保存实验状态。这包括搜索器的状态、实验列表及其状态(例如,PENDING、RUNNING、TERMINATED、ERROR),以及每个实验相关的元数据(例如,超参数配置、一些派生的实验结果(min、max、last)等)。

实验级别的检查点由驱动程序定期在头节点上保存。默认情况下,其保存频率会自动调整,以便最多 5% 的时间用于保存实验检查点,剩余时间用于处理训练结果和调度。此时间也可以通过 `TUNE_GLOBAL_CHECKPOINT_S` 环境变量进行调整。

实验检查点#

实验级别的检查点捕获每个实验的状态。这通常包括模型和优化器状态。以下是实验检查点的几个用途:

  • 如果实验因某种原因中断(例如,在抢占式实例上),可以从最后的状态恢复。不会丢失训练时间。

  • 某些搜索器或调度程序会暂停实验,以便为其他实验腾出资源进行训练。这只有在实验可以从最新状态继续训练时才有意义。

  • 之后,检查点可用于其他下游任务,如批量推理。

在此处了解如何保存和加载实验检查点。

实验结果#

实验报告的指标会被保存并记录到各自的实验目录中。这是以 CSV、JSON 或 Tensorboard (events.out.tfevents.*) 格式存储的数据,可以由 Tensorboard 检查并用于实验后分析。