如何为 Ray Tune 实验定义停止标准#

在运行 Tune 实验时,预先确定理想的训练时长可能具有挑战性。Tune 中的停止标准对于根据特定条件终止训练非常有用。

例如,您可能希望设置实验在以下情况下停止

  1. 设置实验在 N 个 epoch 后结束,或在报告的评估分数超过特定阈值时结束,以先发生者为准。

  2. T 秒后停止实验。

  3. 当试用遇到运行时错误时终止。

  4. 通过利用 Tune 的提前停止调度器,尽早停止表现不佳的试用。

本用户指南将演示如何在 Tune 实验中实现这些类型的停止标准。

对于所有代码示例,我们使用以下训练函数进行演示

from ray import tune
import time

def my_trainable(config):
    i = 1
    while True:
        # Do some training...
        time.sleep(1)

        # Report some metrics for demonstration...
        tune.report({"mean_accuracy": min(i / 10, 1.0)})
        i += 1

手动停止 Tune 实验#

如果您向运行 Tuner.fit() 的进程发送 SIGINT 信号(这通常是您在终端中按下 Ctrl+C 时发生的情况),Ray Tune 将优雅地关闭训练并保存最终实验状态。

注意

强制终止 Tune 实验,例如通过多次 Ctrl+C 命令,将不会给 Tune 最后一次快照实验状态的机会。如果您将来恢复实验,这可能导致恢复到过时的状态。

Ray Tune 也接受 SIGUSR1 信号来优雅地中断训练。当在远程 Ray 任务中运行 Ray Tune 时应该使用此信号,因为 Ray 默认会过滤掉 SIGINTSIGTERM 信号。

使用基于指标的标准停止#

除了手动停止,Tune 还提供了几种编程方式来停止实验。最简单的方法是使用基于指标的标准。这些是固定的阈值集合,用于确定实验何时应该停止。

您可以使用字典、函数或自定义 Stopper 来实现停止标准。

如果传入一个字典,其键可以是 Function API 中 session.report 返回结果的任何字段,或 Class API 中 step() 返回结果的任何字段。

注意

这包括 自动填充的指标,例如 training_iteration

在下面的示例中,每个试用将在完成 10 次迭代时或在平均准确率达到 0.8 或更高时停止。

这些指标假定是递增的,因此一旦报告的指标超过字典中指定的阈值,试用就会停止。

from ray import tune

tuner = tune.Tuner(
    my_trainable,
    run_config=tune.RunConfig(stop={"training_iteration": 10, "mean_accuracy": 0.8}),
)
result_grid = tuner.fit()

为了更灵活,您可以传入一个函数。如果传入一个函数,它必须接受 (trial_id: str, result: dict) 作为参数,并返回一个布尔值(如果试用应该停止则返回 True,否则返回 False)。

在下面的示例中,每个试用将在完成 10 次迭代时或在平均准确率达到 0.8 或更高时停止。

from ray import tune


def stop_fn(trial_id: str, result: dict) -> bool:
    return result["mean_accuracy"] >= 0.8 or result["training_iteration"] >= 10


tuner = tune.Tuner(my_trainable, run_config=tune.RunConfig(stop=stop_fn))
result_grid = tuner.fit()

最后,您可以实现 Stopper 接口,用于根据自定义停止标准停止单个试用或整个实验。例如,以下示例在任何单个试用达到标准后停止所有试用,并阻止新试用开始

from ray import tune
from ray.tune import Stopper


class CustomStopper(Stopper):
    def __init__(self):
        self.should_stop = False

    def __call__(self, trial_id: str, result: dict) -> bool:
        if not self.should_stop and result["mean_accuracy"] >= 0.8:
            self.should_stop = True
        return self.should_stop

    def stop_all(self) -> bool:
        """Returns whether to stop trials and prevent new ones from starting."""
        return self.should_stop


stopper = CustomStopper()
tuner = tune.Tuner(
    my_trainable,
    run_config=tune.RunConfig(stop=stopper),
    tune_config=tune.TuneConfig(num_samples=2),
)
result_grid = tuner.fit()

在该示例中,一旦任何试用达到 mean_accuracy 0.8 或更高的值,所有试用都将停止。

注意

当从 stop_all 返回 True 时,当前正在运行的试用不会立即停止。它们会在完成当前的训练迭代后停止(在 session.reportstep 之后)。

Ray Tune 自带一套开箱即用的 stopper 类。请参阅 Stopper 文档。

在一定时间后停止试用#

有两种基于时间停止 Tune 实验的选择:在指定超时后单独停止试用,或在一定时间后停止整个实验。

设置超时以单独停止试用#

您可以使用如上所述的字典停止标准,并使用 Tune 自动填充的 time_total_s 指标。

from ray import tune

tuner = tune.Tuner(
    my_trainable,
    # Stop a trial after it's run for more than 5 seconds.
    run_config=tune.RunConfig(stop={"time_total_s": 5}),
)
result_grid = tuner.fit()

注意

如果使用 Function Trainable API,您需要通过 tune.report 包含一些中间报告。每次报告都会自动记录试用的 time_total_s,这使得 Tune 可以基于时间作为指标进行停止。

如果训练循环在某个地方挂起,Tune 将无法拦截训练并为您停止试用。在这种情况下,您可以在训练循环中明确实现超时逻辑。

设置超时以停止实验#

使用 TuneConfig(time_budget_s) 配置告知 Tune 在 time_budget_s 秒后停止实验。

from ray import tune

# Stop the entire experiment after ANY trial has run for more than 5 seconds.
tuner = tune.Tuner(my_trainable, tune_config=tune.TuneConfig(time_budget_s=5.0))
result_grid = tuner.fit()

注意

如果使用 Function Trainable API,您需要通过 tune.report 包含一些中间报告,原因同上。

试用失败时停止#

除了根据试用性能停止,如果任何试用遇到运行时错误,您还可以停止整个实验。为此,您可以使用 ray.tune.FailureConfig 类。

通过此配置,如果任何试用遇到错误,整个实验将立即停止。

from ray import tune
import time


def my_failing_trainable(config):
    if config["should_fail"]:
        raise RuntimeError("Failing (on purpose)!")
    # Do some training...
    time.sleep(10)
    tune.report({"mean_accuracy": 0.9})


tuner = tune.Tuner(
    my_failing_trainable,
    param_space={"should_fail": tune.grid_search([True, False])},
    run_config=tune.RunConfig(failure_config=tune.FailureConfig(fail_fast=True)),
)
result_grid = tuner.fit()

这在调试包含许多试用的 Tune 实验时非常有用。

使用 Tune 调度器提前停止#

另一种停止 Tune 实验的方法是使用提前停止调度器。这些调度器监控试用的性能,如果它们没有取得足够的进展,就会提前停止它们。

AsyncHyperBandSchedulerHyperBandForBOHB 是 Tune 内置的提前停止调度器示例。有关完整列表以及更实际的示例,请参阅 Tune 调度器 API 参考

在下面的示例中,我们同时使用了字典停止标准和提前停止标准

from ray import tune
from ray.tune.schedulers import AsyncHyperBandScheduler


scheduler = AsyncHyperBandScheduler(time_attr="training_iteration")

tuner = tune.Tuner(
    my_trainable,
    run_config=tune.RunConfig(stop={"training_iteration": 10}),
    tune_config=tune.TuneConfig(
        scheduler=scheduler, num_samples=2, metric="mean_accuracy", mode="max"
    ),
)
result_grid = tuner.fit()

总结#

在本用户指南中,我们学习了如何使用指标、试用错误和提前停止调度器来停止 Tune 实验。

请参阅以下资源以获取更多信息