Ray Tune 回调和指标指南#

如何在 Ray Tune 中使用回调?#

Ray Tune 支持在训练过程的各种时间点被调用的回调。回调可以作为参数传递给 RunConfig,由 Tuner 接收,您提供的子方法将自动调用。

这个简单的回调会在每次收到结果时打印一个指标

from ray import tune
from ray.tune import Callback


class MyCallback(Callback):
    def on_trial_result(self, iteration, trials, trial, result, **info):
        print(f"Got result: {result['metric']}")


def train_fn(config):
    for i in range(10):
        tune.report({"metric": i})


tuner = tune.Tuner(
    train_fn,
    run_config=tune.RunConfig(callbacks=[MyCallback()]))
tuner.fit()

有关更多详细信息和可用钩子,请查看 Ray Tune 回调的 API 文档

如何在 Tune 中使用记录指标?#

您可以在 Function 和 Class 训练 API 中记录任意值和指标

def trainable(config):
    for i in range(num_epochs):
        ...
        tune.report({"acc": accuracy, "metric_foo": random_metric_1, "bar": metric_2})

class Trainable(tune.Trainable):
    def step(self):
        ...
        # don't call report here!
        return dict(acc=accuracy, metric_foo=random_metric_1, bar=metric_2)

提示

请注意,tune.report() 并非用于传输大量数据,例如模型或数据集。这样做可能会产生大量开销,并显着减慢您的 Tune 运行速度。

哪些 Tune 指标会被自动填充?#

Tune 具有自动填充指标的概念。在训练期间,Tune 将自动记录以下指标,以及任何用户提供的数值。所有这些都可以用作停止条件或作为 Trial Schedulers/Search Algorithms 的参数。

  • config:超参数配置

  • date:处理结果的日期和时间(字符串格式)

  • done:如果 Trial 已完成,则为 True,否则为 False

  • episodes_total:总回合数(适用于 RLlib 可训练对象)

  • experiment_id:唯一的实验 ID

  • experiment_tag:唯一的实验标签(包含参数值)

  • hostname:工作节点的 hostname

  • iterations_since_restore:从检查点恢复 worker 后调用 tune.report 的次数

  • node_ip:工作节点的 IP 地址

  • pid:工作节点进程的进程 ID (PID)

  • time_since_restore:从检查点恢复以来的秒数。

  • time_this_iter_s:当前训练迭代的运行时间(秒),即一次调用可训练函数或类 API 中的 _train()

  • time_total_s:总运行时间(秒)。

  • timestamp:处理结果的时间戳

  • timesteps_since_restore:从检查点恢复以来的时间步数

  • timesteps_total:总时间步数

  • training_iteration:调用 tune.report() 的次数

  • trial_id:唯一的 Trial ID

所有这些指标都可以在 Trial.last_result 字典中看到。