将 Weights & Biases 与 Tune 结合使用#
Weights & Biases (Wandb) 是一款用于实验跟踪、模型优化和数据集版本控制的工具。由于其出色的可视化工具,它在机器学习和数据科学社区非常受欢迎。

Ray Tune 目前为 Weights & Biases 提供了两种轻量级集成。其中一种是 WandbLoggerCallback,它会自动将报告给 Tune 的指标记录到 Wandb API。
另一种是 setup_wandb() 函数,可用于函数 API。它会自动使用 Tune 的训练信息初始化 Wandb API。您可以像往常一样使用 Wandb API,例如使用 wandb.log()
记录您的训练过程。
运行 Weights & Biases 示例#
在以下示例中,我们将使用上述两种方法,即 WandbLoggerCallback
和 setup_wandb
函数来记录指标。
第一步,请确保您在所有运行训练的机器上都已登录 wandb
wandb login
然后我们可以开始进行一些重要的导入
import numpy as np
import ray
from ray import tune
from ray.air.integrations.wandb import WandbLoggerCallback, setup_wandb
接下来,让我们定义一个简单的 train_function
函数(一个 Tune Trainable
),它向 Tune 报告一个随机损失。目标函数本身对于本示例并不重要,因为我们主要关注 Weights & Biases 集成。
def train_function(config):
for i in range(30):
loss = config["mean"] + config["sd"] * np.random.randn()
tune.report({"loss": loss})
您可以使用 WandbLoggerCallback
定义一个简单的网格搜索 Tune 运行,如下所示
def tune_with_callback():
"""Example for using a WandbLoggerCallback with the function API"""
tuner = tune.Tuner(
train_function,
tune_config=tune.TuneConfig(
metric="loss",
mode="min",
),
run_config=tune.RunConfig(
callbacks=[WandbLoggerCallback(project="Wandb_example")]
),
param_space={
"mean": tune.grid_search([1, 2, 3, 4, 5]),
"sd": tune.uniform(0.2, 0.8),
},
)
tuner.fit()
要使用 setup_wandb
实用程序,您只需在您的目标函数中调用此函数即可。请注意,我们还使用 wandb.log(...)
将 loss
作为字典记录到 Weights & Biases。除此之外,此版本的我们的目标函数与其原始版本相同。
def train_function_wandb(config):
wandb = setup_wandb(config, project="Wandb_example")
for i in range(30):
loss = config["mean"] + config["sd"] * np.random.randn()
tune.report({"loss": loss})
wandb.log(dict(loss=loss))
定义了 train_function_wandb
后,您的 Tune 实验将在每个 Trial 启动时设置 wandb!
def tune_with_setup():
"""Example for using the setup_wandb utility with the function API"""
tuner = tune.Tuner(
train_function_wandb,
tune_config=tune.TuneConfig(
metric="loss",
mode="min",
),
param_space={
"mean": tune.grid_search([1, 2, 3, 4, 5]),
"sd": tune.uniform(0.2, 0.8),
},
)
tuner.fit()
最后,您还可以通过在 setup()
方法中使用 setup_wandb
并将运行对象存储为属性来定义基于类的 Tune Trainable
。请注意,对于基于类的 Trainable,您必须单独传递 trial id、名称和组
class WandbTrainable(tune.Trainable):
def setup(self, config):
self.wandb = setup_wandb(
config,
trial_id=self.trial_id,
trial_name=self.trial_name,
group="Example",
project="Wandb_example",
)
def step(self):
for i in range(30):
loss = self.config["mean"] + self.config["sd"] * np.random.randn()
self.wandb.log({"loss": loss})
return {"loss": loss, "done": True}
def save_checkpoint(self, checkpoint_dir: str):
pass
def load_checkpoint(self, checkpoint_dir: str):
pass
使用此 WandbTrainable
运行 Tune 与使用函数 API 完全相同。下面的 tune_trainable
函数与上面的 tune_decorated
仅在传递给 Tuner()
的第一个参数上有所不同。
def tune_trainable():
"""Example for using a WandTrainableMixin with the class API"""
tuner = tune.Tuner(
WandbTrainable,
tune_config=tune.TuneConfig(
metric="loss",
mode="min",
),
param_space={
"mean": tune.grid_search([1, 2, 3, 4, 5]),
"sd": tune.uniform(0.2, 0.8),
},
)
results = tuner.fit()
return results.get_best_result().config
由于您可能没有 Wandb 的 API 密钥,我们可以模拟 Wandb 日志记录器并按如下方式测试我们的所有三个训练函数。如果您已登录 wandb,可以将 mock_api = False
设置为实际将结果上传到 Weights & Biases。
import os
mock_api = True
if mock_api:
os.environ.setdefault("WANDB_MODE", "disabled")
os.environ.setdefault("WANDB_API_KEY", "abcd")
ray.init(
runtime_env={"env_vars": {"WANDB_MODE": "disabled", "WANDB_API_KEY": "abcd"}}
)
tune_with_callback()
tune_with_setup()
tune_trainable()
2022-11-02 16:02:45,355 INFO worker.py:1534 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8266
2022-11-02 16:02:46,513 INFO wandb.py:282 -- Already logged into W&B.
Tune 状态
当前时间 | 2022-11-02 16:03:13 |
运行时间 | 00:00:27.28 |
内存 | 10.8/16.0 GiB |
系统信息
使用 FIFO 调度算法。请求的资源:0/16 CPU, 0/0 GPU, 0.0/3.44 GiB 堆内存, 0.0/1.72 GiB 对象
Trial 状态
Trial 名称 | 状态 | loc | 均值 | 标准差 | 迭代 | 总时间 (秒) | 损失 |
---|---|---|---|---|---|---|---|
train_function_7676d_00000 | 已终止 | 127.0.0.1:14578 | 1 | 0.411212 | 30 | 0.236137 | 0.828527 |
train_function_7676d_00001 | 已终止 | 127.0.0.1:14591 | 2 | 0.756339 | 30 | 5.57185 | 3.13156 |
train_function_7676d_00002 | 已终止 | 127.0.0.1:14593 | 3 | 0.436643 | 30 | 5.50237 | 3.26679 |
train_function_7676d_00003 | 已终止 | 127.0.0.1:14595 | 4 | 0.295929 | 30 | 5.60986 | 3.70388 |
train_function_7676d_00004 | 已终止 | 127.0.0.1:14596 | 5 | 0.335292 | 30 | 5.61385 | 4.74294 |
Trial 进度
Trial 名称 | 日期 | 完成 | 总 episodes | 实验 ID | 实验标签 | 主机名 | 恢复以来的迭代次数 | 损失 | 节点 IP | PID | 恢复以来的时间 | 本次迭代时间 (秒) | 总时间 (秒) | 时间戳 | 恢复以来的时间步数 | 总时间步数 | 训练迭代 | Trial ID | 预热时间 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
train_function_7676d_00000 | 2022-11-02_16-02-53 | True | a9f242fa70184d9dadd8952b16fb0ecc | 0_mean=1,sd=0.4112 | Kais-MBP.local.meter | 30 | 0.828527 | 127.0.0.1 | 14578 | 0.236137 | 0.00381589 | 0.236137 | 1667430173 | 0 | 30 | 7676d_00000 | 0.00366998 | ||
train_function_7676d_00001 | 2022-11-02_16-03-03 | True | f57118365bcb4c229fe41c5911f05ad6 | 1_mean=2,sd=0.7563 | Kais-MBP.local.meter | 30 | 3.13156 | 127.0.0.1 | 14591 | 5.57185 | 0.00627518 | 5.57185 | 1667430183 | 0 | 30 | 7676d_00001 | 0.0027349 | ||
train_function_7676d_00002 | 2022-11-02_16-03-03 | True | 394021d4515d4616bae7126668f73b2b | 2_mean=3,sd=0.4366 | Kais-MBP.local.meter | 30 | 3.26679 | 127.0.0.1 | 14593 | 5.50237 | 0.00494576 | 5.50237 | 1667430183 | 0 | 30 | 7676d_00002 | 0.00286222 | ||
train_function_7676d_00003 | 2022-11-02_16-03-03 | True | a575e79c9d95485fa37deaa86267aea4 | 3_mean=4,sd=0.2959 | Kais-MBP.local.meter | 30 | 3.70388 | 127.0.0.1 | 14595 | 5.60986 | 0.00689816 | 5.60986 | 1667430183 | 0 | 30 | 7676d_00003 | 0.00299597 | ||
train_function_7676d_00004 | 2022-11-02_16-03-03 | True | 91ce57dcdbb54536b1874666b711350d | 4_mean=5,sd=0.3353 | Kais-MBP.local.meter | 30 | 4.74294 | 127.0.0.1 | 14596 | 5.61385 | 0.00672579 | 5.61385 | 1667430183 | 0 | 30 | 7676d_00004 | 0.00323987 |
2022-11-02 16:03:13,913 INFO tune.py:788 -- Total run time: 28.53 seconds (27.28 seconds for the tuning loop).
Tune 状态
当前时间 | 2022-11-02 16:03:22 |
运行时间 | 00:00:08.49 |
内存 | 9.9/16.0 GiB |
系统信息
使用 FIFO 调度算法。请求的资源:0/16 CPU, 0/0 GPU, 0.0/3.44 GiB 堆内存, 0.0/1.72 GiB 对象
Trial 状态
Trial 名称 | 状态 | loc | 均值 | 标准差 | 迭代 | 总时间 (秒) | 损失 |
---|---|---|---|---|---|---|---|
train_function_wandb_877eb_00000 | 已终止 | 127.0.0.1:14647 | 1 | 0.738281 | 30 | 1.61319 | 0.555153 |
train_function_wandb_877eb_00001 | 已终止 | 127.0.0.1:14660 | 2 | 0.321178 | 30 | 1.72447 | 2.52109 |
train_function_wandb_877eb_00002 | 已终止 | 127.0.0.1:14661 | 3 | 0.202487 | 30 | 1.8159 | 2.45412 |
train_function_wandb_877eb_00003 | 已终止 | 127.0.0.1:14662 | 4 | 0.515434 | 30 | 1.715 | 4.51413 |
train_function_wandb_877eb_00004 | 已终止 | 127.0.0.1:14663 | 5 | 0.216098 | 30 | 1.72827 | 5.2814 |
(train_function_wandb pid=14647) 2022-11-02 16:03:17,149 INFO wandb.py:282 -- Already logged into W&B.
Trial 进度
Trial 名称 | 日期 | 完成 | 总 episodes | 实验 ID | 实验标签 | 主机名 | 恢复以来的迭代次数 | 损失 | 节点 IP | PID | 恢复以来的时间 | 本次迭代时间 (秒) | 总时间 (秒) | 时间戳 | 恢复以来的时间步数 | 总时间步数 | 训练迭代 | Trial ID | 预热时间 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
train_function_wandb_877eb_00000 | 2022-11-02_16-03-18 | True | 7b250c9f31ab484dad1a1fd29823afdf | 0_mean=1,sd=0.7383 | Kais-MBP.local.meter | 30 | 0.555153 | 127.0.0.1 | 14647 | 1.61319 | 0.00232315 | 1.61319 | 1667430198 | 0 | 30 | 877eb_00000 | 0.00391102 | ||
train_function_wandb_877eb_00001 | 2022-11-02_16-03-22 | True | 5172868368074557a3044ea3a9146673 | 1_mean=2,sd=0.3212 | Kais-MBP.local.meter | 30 | 2.52109 | 127.0.0.1 | 14660 | 1.72447 | 0.0152011 | 1.72447 | 1667430202 | 0 | 30 | 877eb_00001 | 0.00901699 | ||
train_function_wandb_877eb_00002 | 2022-11-02_16-03-22 | True | b13d9bccb1964b4b95e1a858a3ea64c7 | 2_mean=3,sd=0.2025 | Kais-MBP.local.meter | 30 | 2.45412 | 127.0.0.1 | 14661 | 1.8159 | 0.00437403 | 1.8159 | 1667430202 | 0 | 30 | 877eb_00002 | 0.00844812 | ||
train_function_wandb_877eb_00003 | 2022-11-02_16-03-22 | True | 869d7ec7a3544a8387985103e626818f | 3_mean=4,sd=0.5154 | Kais-MBP.local.meter | 30 | 4.51413 | 127.0.0.1 | 14662 | 1.715 | 0.00247812 | 1.715 | 1667430202 | 0 | 30 | 877eb_00003 | 0.00282907 | ||
train_function_wandb_877eb_00004 | 2022-11-02_16-03-22 | True | 84d3112d66f64325bc469e44b8447ef5 | 4_mean=5,sd=0.2161 | Kais-MBP.local.meter | 30 | 5.2814 | 127.0.0.1 | 14663 | 1.72827 | 0.00517201 | 1.72827 | 1667430202 | 0 | 30 | 877eb_00004 | 0.00272107 |
(train_function_wandb pid=14660) 2022-11-02 16:03:20,600 INFO wandb.py:282 -- Already logged into W&B.
(train_function_wandb pid=14661) 2022-11-02 16:03:20,600 INFO wandb.py:282 -- Already logged into W&B.
(train_function_wandb pid=14663) 2022-11-02 16:03:20,628 INFO wandb.py:282 -- Already logged into W&B.
(train_function_wandb pid=14662) 2022-11-02 16:03:20,723 INFO wandb.py:282 -- Already logged into W&B.
2022-11-02 16:03:22,565 INFO tune.py:788 -- Total run time: 8.60 seconds (8.48 seconds for the tuning loop).
Tune 状态
当前时间 | 2022-11-02 16:03:31 |
运行时间 | 00:00:09.28 |
内存 | 9.9/16.0 GiB |
系统信息
使用 FIFO 调度算法。请求的资源:0/16 CPU, 0/0 GPU, 0.0/3.44 GiB 堆内存, 0.0/1.72 GiB 对象
Trial 状态
Trial 名称 | 状态 | loc | 均值 | 标准差 | 迭代 | 总时间 (秒) | 损失 |
---|---|---|---|---|---|---|---|
WandbTrainable_8ca33_00000 | 已终止 | 127.0.0.1:14718 | 1 | 0.397894 | 1 | 0.000187159 | 0.742345 |
WandbTrainable_8ca33_00001 | 已终止 | 127.0.0.1:14737 | 2 | 0.386883 | 1 | 0.000151873 | 2.5709 |
WandbTrainable_8ca33_00002 | 已终止 | 127.0.0.1:14738 | 3 | 0.290693 | 1 | 0.00014019 | 2.99601 |
WandbTrainable_8ca33_00003 | 已终止 | 127.0.0.1:14739 | 4 | 0.33333 | 1 | 0.00015831 | 3.91276 |
WandbTrainable_8ca33_00004 | 已终止 | 127.0.0.1:14740 | 5 | 0.645479 | 1 | 0.000150919 | 5.47779 |
(WandbTrainable pid=14718) 2022-11-02 16:03:25,742 INFO wandb.py:282 -- Already logged into W&B.
Trial 进度
Trial 名称 | 日期 | 完成 | 总 episodes | 实验 ID | 主机名 | 恢复以来的迭代次数 | 损失 | 节点 IP | PID | 恢复以来的时间 | 本次迭代时间 (秒) | 总时间 (秒) | 时间戳 | 恢复以来的时间步数 | 总时间步数 | 训练迭代 | Trial ID | 预热时间 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
WandbTrainable_8ca33_00000 | 2022-11-02_16-03-27 | True | 3adb4d0ae0d74d1c9ddd07924b5653b0 | Kais-MBP.local.meter | 1 | 0.742345 | 127.0.0.1 | 14718 | 0.000187159 | 0.000187159 | 0.000187159 | 1667430207 | 0 | 1 | 8ca33_00000 | 1.31382 | ||
WandbTrainable_8ca33_00001 | 2022-11-02_16-03-31 | True | f1511cfd51f94b3d9cf192181ccc08a9 | Kais-MBP.local.meter | 1 | 2.5709 | 127.0.0.1 | 14737 | 0.000151873 | 0.000151873 | 0.000151873 | 1667430211 | 0 | 1 | 8ca33_00001 | 1.31668 | ||
WandbTrainable_8ca33_00002 | 2022-11-02_16-03-31 | True | a7528ec6adf74de0b73aa98ebedab66d | Kais-MBP.local.meter | 1 | 2.99601 | 127.0.0.1 | 14738 | 0.00014019 | 0.00014019 | 0.00014019 | 1667430211 | 0 | 1 | 8ca33_00002 | 1.32008 | ||
WandbTrainable_8ca33_00003 | 2022-11-02_16-03-31 | True | b7af756ca586449ba2d4c44141b53b06 | Kais-MBP.local.meter | 1 | 3.91276 | 127.0.0.1 | 14739 | 0.00015831 | 0.00015831 | 0.00015831 | 1667430211 | 0 | 1 | 8ca33_00003 | 1.31879 | ||
WandbTrainable_8ca33_00004 | 2022-11-02_16-03-31 | True | 196624f42bcc45c18a26778573a43a2c | Kais-MBP.local.meter | 1 | 5.47779 | 127.0.0.1 | 14740 | 0.000150919 | 0.000150919 | 0.000150919 | 1667430211 | 0 | 1 | 8ca33_00004 | 1.31945 |
(WandbTrainable pid=14739) 2022-11-02 16:03:30,360 INFO wandb.py:282 -- Already logged into W&B.
(WandbTrainable pid=14740) 2022-11-02 16:03:30,393 INFO wandb.py:282 -- Already logged into W&B.
(WandbTrainable pid=14737) 2022-11-02 16:03:30,454 INFO wandb.py:282 -- Already logged into W&B.
(WandbTrainable pid=14738) 2022-11-02 16:03:30,510 INFO wandb.py:282 -- Already logged into W&B.
2022-11-02 16:03:31,985 INFO tune.py:788 -- Total run time: 9.40 seconds (9.27 seconds for the tuning loop).
{'mean': 1, 'sd': 0.3978937765393781, 'wandb': {'project': 'Wandb_example'}}
至此,我们的 Tune 和 Wandb 演练就完成了。在接下来的章节中,您可以找到有关 Tune-Wandb 集成 API 的更多详细信息。
Tune Wandb API 参考#
WandbLoggerCallback#
- class ray.air.integrations.wandb.WandbLoggerCallback(project: str | None = None, group: str | None = None, api_key_file: str | None = None, api_key: str | None = None, excludes: List[str] | None = None, log_config: bool = False, upload_checkpoints: bool = False, save_checkpoints: bool = False, upload_timeout: int = 1800, **kwargs)[源码]
Weights and biases (https://www.wandb.ai/) 是一款用于实验跟踪、模型优化和数据集版本控制的工具。这个 Ray Tune
LoggerCallback
将指标发送到 Wandb 进行自动跟踪和可视化。示例
import random from ray import tune from ray.air.integrations.wandb import WandbLoggerCallback def train_func(config): offset = random.random() / 5 for epoch in range(2, config["epochs"]): acc = 1 - (2 + config["lr"]) ** -epoch - random.random() / epoch - offset loss = (2 + config["lr"]) ** -epoch + random.random() / epoch + offset train.report({"acc": acc, "loss": loss}) tuner = tune.Tuner( train_func, param_space={ "lr": tune.grid_search([0.001, 0.01, 0.1, 1.0]), "epochs": 10, }, run_config=tune.RunConfig( callbacks=[WandbLoggerCallback(project="Optimization_Project")] ), ) results = tuner.fit()
- 参数:
project – Wandb 项目名称。必需参数。
group – Wandb 组名称。默认为 trainable 名称。
api_key_file – 包含 Wandb API KEY 的文件路径。如果使用 WandbLogger,此文件只需存在于运行 Tune 脚本的节点上。
api_key – Wandb API Key。api_key_file 的替代方案。
excludes – 应从日志中排除的指标和配置列表。
log_config – 布尔值,指示是否应记录
results
字典的config
参数。如果参数在训练期间会发生变化(例如使用 PopulationBasedTraining),这很有意义。默认为 False。upload_checkpoints – 如果
True
,模型检查点将作为 Artifact 上传到 Wandb。默认为False
。**kwargs – 关键字参数将传递给
wandb.init()
。
Wandb 的
group
、run_id
和run_name
会由 Tune 自动选择,但可以通过填写相应的配置值来覆盖。有关所有其他有效的配置设置,请参见此处: https://docs.wandb.ai/library/init
PublicAPI (alpha): 此 API 处于 alpha 阶段,在稳定之前可能会发生变化。
setup_wandb#
- ray.air.integrations.wandb.setup_wandb(config: Dict | None = None, api_key: str | None = None, api_key_file: str | None = None, rank_zero_only: bool = True, **kwargs) wandb.wandb_run.Run | wandb.sdk.lib.disabled.RunDisabled [源码]
设置 Weights & Biases session。
此函数可用于在(分布式)训练或调优运行中初始化 Weights & Biases session。
默认情况下,run ID 是 trial ID,run name 是 trial 名称,run group 是实验名称。这些设置可以通过将相应的参数作为
kwargs
传递来覆盖,kwargs
将传递给wandb.init()
。在 Ray Train 的分布式训练中,只有零级别 worker 会初始化 wandb。所有其他 worker 将返回一个禁用的 run 对象,以避免在分布式运行中重复记录日志。这可以通过传递
rank_zero_only=False
来禁用,此时将在每个训练 worker 中初始化 wandb。config 参数将传递给 Weights and Biases,并作为运行配置进行记录。
如果未传递 API 密钥或密钥文件,wandb 将尝试使用本地存储的凭据进行认证,例如通过运行
wandb login
创建的凭据。传递给
setup_wandb()
的关键字参数将传递给wandb.init()
,并优先于任何潜在的默认设置。- 参数:
config – 要记录到 Weights and Biases 的配置字典。可以包含
wandb.init()
的参数以及认证信息。api_key – 用于与 Weights and Biases 进行认证的 API 密钥。
api_key_file – 指向 Weights and Biases API 密钥的文件。
rank_zero_only – 如果为 True,在分布式训练中将仅为零级别 worker 返回一个已初始化的 session。如果为 False,将为所有 worker 初始化一个 session。
kwargs – 传递给
wandb.init()
。
示例
from ray.air.integrations.wandb import setup_wandb def training_loop(config): wandb = setup_wandb(config) # ... wandb.log({"loss": 0.123})
PublicAPI (alpha): 此 API 处于 alpha 阶段,在稳定之前可能会发生变化。