如何保存和加载 Trial Checkpoint#

Trial checkpoint 是 Tune 存储的三种数据类型之一。它们是用户定义的,用于快照你的训练进度!

Trial 级别的 checkpoint 通过 Tune Trainable API 保存:这是你定义自定义训练逻辑的方式,也是你定义要 checkpoint 的 Trial 状态的地方。本指南将展示如何为 Tune 的 Function Trainable 和 Class Trainable API 保存和加载 checkpoint,并引导你了解配置选项。

Function API Checkpointing#

如果使用 Ray Tune 的 Function API,可以按照以下方式保存和加载 checkpoint。要创建 checkpoint,请使用 from_directory() API。

import os
import tempfile

from ray import tune
from ray.tune import Checkpoint


def train_func(config):
    start = 1
    my_model = MyModel()

    checkpoint = tune.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            checkpoint_dict = torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"))
            start = checkpoint_dict["epoch"] + 1
            my_model.load_state_dict(checkpoint_dict["model_state"])

    for epoch in range(start, config["epochs"] + 1):
        # Model training here
        # ...

        metrics = {"metric": 1}
        with tempfile.TemporaryDirectory() as tempdir:
            torch.save(
                {"epoch": epoch, "model_state": my_model.state_dict()},
                os.path.join(tempdir, "checkpoint.pt"),
            )
            tune.report(metrics=metrics, checkpoint=Checkpoint.from_directory(tempdir))


tuner = tune.Tuner(train_func, param_space={"epochs": 5})
result_grid = tuner.fit()

在上面的代码片段中

  • 我们使用 tune.report(..., checkpoint=checkpoint) 实现checkpoint 保存。请注意,每个 checkpoint 必须与一组指标一起报告——这样,checkpoint 就可以根据指定的指标排序。

  • 训练迭代 epoch 期间保存的 checkpoint 将保存到训练发生节点上的路径 <storage_path>/<exp_name>/<trial_name>/checkpoint_<epoch>,并可以根据存储配置进一步同步到统一的存储位置。

  • 我们使用 tune.get_checkpoint() 实现checkpoint 加载。每当 Tune 恢复 Trial 时,此函数将填充 Trial 的最新 checkpoint。这发生在以下情况:(1)Trial 配置为在遇到故障后重试,(2)实验正在恢复,以及(3)Trial 在暂停后恢复(例如:PBT)。

注意

checkpoint_frequencycheckpoint_at_end 不适用于 Function API checkpointing。这些需要通过 Function Trainable 手动配置。例如,如果你想每三个 epoch checkpoint 一次,可以通过以下方式实现

NUM_EPOCHS = 12
# checkpoint every three epochs.
CHECKPOINT_FREQ = 3


def train_func(config):
    for epoch in range(1, config["epochs"] + 1):
        # Model training here
        # ...

        # Report metrics and save a checkpoint
        metrics = {"metric": "my_metric"}
        if epoch % CHECKPOINT_FREQ == 0:
            with tempfile.TemporaryDirectory() as tempdir:
                # Save a checkpoint in tempdir.
                tune.report(metrics, checkpoint=Checkpoint.from_directory(tempdir))
        else:
            tune.report(metrics)


tuner = tune.Tuner(train_func, param_space={"epochs": NUM_EPOCHS})
result_grid = tuner.fit()

请参阅此处 了解 有关 创建 checkpoint 的更多信息。

Class API Checkpointing#

你还可以使用 Trainable Class API 实现 checkpoint/恢复

import os
import torch
from torch import nn

from ray import tune


class MyTrainableClass(tune.Trainable):
    def setup(self, config):
        self.model = nn.Sequential(
            nn.Linear(config.get("input_size", 32), 32), nn.ReLU(), nn.Linear(32, 10)
        )

    def step(self):
        return {}

    def save_checkpoint(self, tmp_checkpoint_dir):
        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
        torch.save(self.model.state_dict(), checkpoint_path)
        return tmp_checkpoint_dir

    def load_checkpoint(self, tmp_checkpoint_dir):
        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
        self.model.load_state_dict(torch.load(checkpoint_path))


tuner = tune.Tuner(
    MyTrainableClass,
    param_space={"input_size": 64},
    run_config=tune.RunConfig(
        stop={"training_iteration": 2},
        checkpoint_config=tune.CheckpointConfig(checkpoint_frequency=2),
    ),
)
tuner.fit()

你可以通过三种不同的机制进行 checkpoint:手动、周期性和终止时。

手动 Checkpointing#

自定义 Trainable 可以通过在 step 的结果字典中返回 should_checkpoint: True (或 tune.result.SHOULD_CHECKPOINT: True) 来手动触发 checkpointing。这在 Spot 实例中特别有用

import random


# to be implemented by user.
def detect_instance_preemption():
    choice = random.randint(1, 100)
    # simulating a 1% chance of preemption.
    return choice <= 1


def train_func(self):
    # training code
    result = {"mean_accuracy": "my_accuracy"}
    if detect_instance_preemption():
        result.update(should_checkpoint=True)
    return result


在上面的例子中,如果 detect_instance_preemption 返回 True,则可以触发手动 checkpointing。

周期性 Checkpointing#

这可以通过设置 checkpoint_frequency=N 来启用,以便每 N 次迭代 checkpoint Trial,例如


tuner = tune.Tuner(
    MyTrainableClass,
    run_config=tune.RunConfig(
        stop={"training_iteration": 2},
        checkpoint_config=tune.CheckpointConfig(checkpoint_frequency=10),
    ),
)
tuner.fit()

终止时 Checkpointing#

checkpoint_frequency 可能与实验的精确结束时间不一致。如果你希望在 Trial 结束时创建一个 checkpoint,你可以额外设置 checkpoint_at_end=True

tuner = tune.Tuner(
    MyTrainableClass,
    run_config=tune.RunConfig(
        stop={"training_iteration": 2},
        checkpoint_config=tune.CheckpointConfig(
            checkpoint_frequency=10, checkpoint_at_end=True
        ),
    ),
)
tuner.fit()

配置#

可以通过 CheckpointConfig 配置 Checkpointing。某些配置不适用于 Function Trainable API,因为 checkpointing 频率是在用户定义的训练循环中手动确定的。请参阅下面的兼容性矩阵。

Class API

Function API

num_to_keep

checkpoint_score_attribute

checkpoint_score_order

checkpoint_frequency

checkpoint_at_end

总结#

在本用户指南中,我们介绍了如何在 Tune 中保存和加载 Trial checkpoint。启用 checkpointing 后,请继续阅读以下指南之一,了解如何

附录:Tune 存储的数据类型#

实验 Checkpoint#

实验级 checkpoint 保存实验状态。这包括搜索器的状态、Trial 列表及其状态(例如 PENDING、RUNNING、TERMINATED、ERROR),以及与每个 Trial 相关的元数据(例如超参数配置、一些派生的 Trial 结果(最小、最大、最后)等)。

实验级 checkpoint 由驱动程序在 head 节点上周期性保存。默认情况下,保存频率会自动调整,以便最多 5% 的时间用于保存实验 checkpoint,其余时间用于处理训练结果和调度。此时间也可以通过 TUNE_GLOBAL_CHECKPOINT_S 环境变量进行调整。

Trial Checkpoint#

Trial 级别的 checkpoint 捕获每个 Trial 的状态。这通常包括模型和优化器状态。Trial checkpoint 的一些用途如下

  • 如果 Trial 因某些原因中断(例如在 Spot 实例上),可以从上次状态恢复。不会丢失训练时间。

  • 一些搜索器或调度器会暂停 Trial,以便为其他 Trial 腾出资源进行训练。这只有在 Trial 能够从最新状态继续训练时才有意义。

  • 该 checkpoint 稍后可用于其他下游任务,例如批量推理。

在此了解如何保存和加载 Trial checkpoint。

Trial 结果#

Trial 报告的指标会保存并记录到各自的 Trial 目录中。这是以 CSV、JSON 或 Tensorboard (events.out.tfevents.*) 格式存储的数据,可以通过 Tensorboard 进行检查并用于实验后分析。