Ray Tune 基于群体的训练指南#

try-anyscale-quickstart

Tune 包含 基于群体的训练 (PBT) 的分布式实现,作为一种 调度器

Paper figure

PBT 通过并行训练许多具有随机超参数的神经网络开始,利用群体中其他成员的信息来优化这些超参数,并将资源分配给有前景的模型。让我们来学习如何使用这个算法。

使用基于群体的训练的功能 API#

PBT 从遗传算法中汲取灵感,在遗传算法中,群体中表现不佳的成员可以利用群体中表现最佳成员的信息。在我们的例子中,群体 是并行运行的一组 Tune trials,trial 的性能由用户指定的指标决定,例如 mean_accuracy

PBT 有两个主要步骤:利用 (exploitation)探索 (exploration)。利用的一个例子是一个 trial 从表现更好的 trial 复制模型参数。探索的一个例子是通过随机扰动当前值来生成新的超参数配置。

随着神经网络群体的训练进展,这种利用和探索的过程会定期进行,确保群体中的所有 worker 都具有良好的基本性能水平,并且持续探索新的超参数配置。这意味着 PBT 可以快速利用好的超参数,将更多的训练时间投入到有前景的模型中,并且至关重要的是,在整个训练过程中变异超参数值,从而学习到最佳的自适应超参数调度。

在这里,我们将通过一个 MNIST ConvNet 训练示例来学习如何使用 PBT。首先,我们定义一个使用 SGD 训练 ConvNet 模型的训练函数。

!pip install "ray[tune]"
import os
import tempfile

import torch
import torch.optim as optim

import ray
from ray import tune
from ray.tune.examples.mnist_pytorch import ConvNet, get_data_loaders, test_func
from ray.tune.schedulers import PopulationBasedTraining


def train_convnet(config):
    # Create our data loaders, model, and optmizer.
    step = 1
    train_loader, test_loader = get_data_loaders()
    model = ConvNet()
    optimizer = optim.SGD(
        model.parameters(),
        lr=config.get("lr", 0.01),
        momentum=config.get("momentum", 0.9),
    )

    # Myabe resume from a checkpoint.
    checkpoint = tune.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            checkpoint_dict = torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"))

        # Load model state and iteration step from checkpoint.
        model.load_state_dict(checkpoint_dict["model_state_dict"])
        # Load optimizer state (needed since we're using momentum),
        # then set the `lr` and `momentum` according to the config.
        optimizer.load_state_dict(checkpoint_dict["optimizer_state_dict"])
        for param_group in optimizer.param_groups:
            if "lr" in config:
                param_group["lr"] = config["lr"]
            if "momentum" in config:
                param_group["momentum"] = config["momentum"]

        # Note: Make sure to increment the checkpointed step by 1 to get the current step.
        last_step = checkpoint_dict["step"]
        step = last_step + 1

    while True:
        ray.tune.examples.mnist_pytorch.train_func(model, optimizer, train_loader)
        acc = test_func(model, test_loader)
        metrics = {"mean_accuracy": acc, "lr": config["lr"]}

        # Every `checkpoint_interval` steps, checkpoint our current state.
        if step % config["checkpoint_interval"] == 0:
            with tempfile.TemporaryDirectory() as tmpdir:
                torch.save(
                    {
                        "step": step,
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                    },
                    os.path.join(tmpdir, "checkpoint.pt"),
                )
                tune.report(metrics, checkpoint=tune.Checkpoint.from_directory(tmpdir))
        else:
            tune.report(metrics)

        step += 1

该示例重用了 ray/tune/examples/mnist_pytorch.py 中的一些函数:这也是一个很好的演示,展示如何解耦调优逻辑和原始训练代码。

PBT 需要保存和加载检查点,因此我们必须在通过 train.get_checkpoint() 提供检查点时加载它,并定期通过 tune.report(...) 将模型状态保存到检查点中——在本例中,每 checkpoint_interval 次迭代保存一次,这是我们稍后设置的一个配置。

然后,我们定义一个 PBT 调度器

perturbation_interval = 5
scheduler = PopulationBasedTraining(
    time_attr="training_iteration",
    perturbation_interval=perturbation_interval,
    metric="mean_accuracy",
    mode="max",
    hyperparam_mutations={
        # distribution for resampling
        "lr": tune.uniform(0.0001, 1),
        # allow perturbations within this set of categorical values
        "momentum": [0.8, 0.9, 0.99],
    },
)

一些最重要的参数是

  • 使用 hyperparam_mutationscustom_explore_fn 来变异超参数。hyperparam_mutations 是一个字典,其中每个键/值对指定超参数的候选值或函数。custom_explore_fn 在应用 hyperparam_mutations 的内置扰动后应用,应返回根据需要更新的 config。

  • resample_probability:应用 hyperparam_mutations 时,从原始分布重新采样的概率。如果不重新采样,如果值是连续的,则会按 1.2 或 0.8 的因子进行扰动;如果值是离散的,则会更改为相邻值。请注意,resample_probability 默认为 0.25,因此带有分布的超参数可能会超出指定范围。

现在我们可以通过调用 Tuner.fit() 来启动调优过程

if ray.is_initialized():
    ray.shutdown()
ray.init()

tuner = tune.Tuner(
    train_convnet,
    run_config=tune.RunConfig(
        name="pbt_test",
        # Stop when we've reached a threshold accuracy, or a maximum
        # training_iteration, whichever comes first
        stop={"mean_accuracy": 0.96, "training_iteration": 50},
        checkpoint_config=tune.CheckpointConfig(
            checkpoint_score_attribute="mean_accuracy",
            num_to_keep=4,
        ),
        storage_path="/tmp/ray_results",
    ),
    tune_config=tune.TuneConfig(
        scheduler=scheduler,
        num_samples=4,
    ),
    param_space={
        "lr": tune.uniform(0.001, 1),
        "momentum": tune.uniform(0.001, 1),
        "checkpoint_interval": perturbation_interval,
    },
)

results_grid = tuner.fit()

注意

我们建议将 PBT 配置中的 checkpoint_intervalperturbation_interval 相匹配。这确保 PBT 算法能够利用最近一次迭代的 trial。

如果你的 perturbation_interval 很大,并且想更频繁地保存检查点,请将 perturbation_interval 设置为 checkpoint_interval 的倍数(例如,每 2 步保存检查点,每 4 步进行扰动)。

{LOG_DIR}/{MY_EXPERIMENT_NAME}/ 中,所有变异都记录在 pbt_global.txt 中,而单个策略扰动记录在 pbt_policy_{i}.txt 中。Tune 在每个扰动步骤记录以下信息:目标 trial 标签、克隆 trial 标签、目标 trial 迭代、克隆 trial 迭代、旧 config、新 config。

检查准确性

import matplotlib.pyplot as plt
import os

# Get the best trial result
best_result = results_grid.get_best_result(metric="mean_accuracy", mode="max")

# Print `path` where checkpoints are stored
print('Best result path:', best_result.path)

# Print the best trial `config` reported at the last iteration
# NOTE: This config is just what the trial ended up with at the last iteration.
# See the next section for replaying the entire history of configs.
print("Best final iteration hyperparameter config:\n", best_result.config)

# Plot the learning curve for the best trial
df = best_result.metrics_dataframe
# Deduplicate, since PBT might introduce duplicate data
df = df.drop_duplicates(subset="training_iteration", keep="last")
df.plot("training_iteration", "mean_accuracy")
plt.xlabel("Training Iterations")
plt.ylabel("Test Accuracy")
plt.show()
Best result logdir: /tmp/ray_results/pbt_test/train_convnet_69158_00000_0_lr=0.0701,momentum=0.1774_2022-10-20_11-31-32
Best final iteration hyperparameter config:
 {'lr': 0.07008752890101211, 'momentum': 0.17736213114751204, 'checkpoint_interval': 5}
../../_images/ce1c0fd33903bfaf581ee2245c0dbb68d5c7079ad64a4f35dd0e1278b34e1ea7.png

重放 PBT 运行#

基于群体的训练运行会以完全训练好的模型结束。但是,有时你可能想从头开始训练模型,但使用从 PBT 获得的相同超参数调度。Ray Tune 为此提供了重放工具。

你只需传递要重放的 trial 的策略日志文件。这通常存储在实验目录中,例如 ~/ray_results/pbt_test/pbt_policy_ba982_00000.txt

重放工具读取 trial 的原始配置,并在原始扰动发生时更新它。因此,你可以在重放运行中(并且应该)使用相同的 Trainable。请注意,最终结果不会完全相同,因为只重放超参数配置的更改,而不重放从其他样本加载的检查点。

import glob

from ray import tune
from ray.tune.schedulers import PopulationBasedTrainingReplay

# Get a random replay policy from the experiment we just ran
sample_pbt_trial_log = glob.glob(
    os.path.expanduser("/tmp/ray_results/pbt_test/pbt_policy*.txt")
)[0]
replay = PopulationBasedTrainingReplay(sample_pbt_trial_log)

tuner = tune.Tuner(
    train_convnet,
    tune_config=tune.TuneConfig(scheduler=replay),
    run_config=tune.RunConfig(stop={"training_iteration": 50}),
)
results_grid = tuner.fit()

Tune 状态

当前时间2022-10-20 11:32:49
运行时间00:00:30.39
内存3.8/62.0 GiB

系统信息

PopulationBasedTraining 重放:步骤 39,扰动 2
请求的资源:0/16 CPU,0/0 GPU,0.0/34.21 GiB 堆内存,0.0/17.1 GiB 对象

Trial 状态

Trial 名称状态位置准确率迭代总时间 (s)学习率
train_convnet_87836_00000已终止172.31.111.100:180210.93125 100 21.09940.00720379

Trial 进度

Trial 名称日期完成总回合数实验 ID主机名自恢复以来的迭代次数学习率平均准确率节点 IPpid自恢复以来的时间本次迭代时间 (s)总时间 (s)时间戳自恢复以来的时间步数总时间步数训练迭代trial ID预热时间
train_convnet_87836_000002022-10-20_11-32-49True 2a88b6f21b54451aa81c935c77ffbce5ip-172-31-111-100 610.00720379 0.93125172.31.111.10018021 12.787 0.196162 21.0994 1666290769 0 10087836_00000 0.00894547
2022-10-20 11:32:28,900	INFO pbt.py:1085 -- Population Based Training replay is now at step 32. Configuration will be changed to {'lr': 0.08410503468121452, 'momentum': 0.99, 'checkpoint_interval': 5}.
(train_convnet pid=17974) 2022-10-20 11:32:32,098	INFO trainable.py:772 -- Restored on 172.31.111.100 from checkpoint: /home/ray/ray_results/train_convnet_2022-10-20_11-32-19/train_convnet_87836_00000_0_2022-10-20_11-32-19/checkpoint_tmp4ab367
(train_convnet pid=17974) 2022-10-20 11:32:32,098	INFO trainable.py:781 -- Current state after restoring: {'_iteration': 32, '_timesteps_total': None, '_time_total': 6.83707332611084, '_episodes_total': None}
2022-10-20 11:32:33,575	INFO pbt.py:1085 -- Population Based Training replay is now at step 39. Configuration will be changed to {'lr': 0.007203792764253441, 'momentum': 0.9, 'checkpoint_interval': 5}.
(train_convnet pid=18021) 2022-10-20 11:32:36,764	INFO trainable.py:772 -- Restored on 172.31.111.100 from checkpoint: /home/ray/ray_results/train_convnet_2022-10-20_11-32-19/train_convnet_87836_00000_0_2022-10-20_11-32-19/checkpoint_tmpb82652
(train_convnet pid=18021) 2022-10-20 11:32:36,765	INFO trainable.py:781 -- Current state after restoring: {'_iteration': 39, '_timesteps_total': None, '_time_total': 8.312420129776001, '_episodes_total': None}
2022-10-20 11:32:49,668	INFO tune.py:787 -- Total run time: 30.50 seconds (30.38 seconds for the tuning loop).

示例:使用 PBT 的 DCGAN#

让我们看一个更复杂的例子:训练生成对抗网络 (GAN) (Goodfellow 等人,2014 年)。GAN 框架通过由两个相互竞争的模块(一个生成器和一个判别器)组成的训练范式来学习生成模型。面对次优的超参数选择,GAN 训练可能非常脆弱且不稳定,生成器通常会塌缩到单一模式或完全发散。

正如 基于群体的训练 (PBT) 中介绍的,PBT 可以帮助进行 DCGAN 训练。现在我们将讲解如何在 Tune 中实现。完整的代码示例在 Github 上。

我们使用标准 PyTorch API 定义生成器和判别器

# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


# Generator Code
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        return self.main(input)


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        return self.main(input)


要使用 PBT 训练模型,我们需要为调度器定义一个指标来评估模型候选。对于 GAN 网络,inception score 可以说是最常用的指标。我们训练了一个 mnist 分类模型 (LeNet),并用它对生成的图像进行推理并评估图像质量。

提示

inception score 使用一个训练好的分类模型,我们将其保存在对象存储中,并作为对象引用传递给 inception_score 函数。

class Net(nn.Module):
    """
    LeNet for MNist classification, used for inception_score
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def inception_score(imgs, mnist_model_ref, batch_size=32, splits=1):
    N = len(imgs)
    dtype = torch.FloatTensor
    dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)
    cm = ray.get(mnist_model_ref)  # Get the mnist model from Ray object store.
    up = nn.Upsample(size=(28, 28), mode="bilinear").type(dtype)

    def get_pred(x):
        x = up(x)
        x = cm(x)
        return F.softmax(x).data.cpu().numpy()

    preds = np.zeros((N, 10))
    for i, batch in enumerate(dataloader, 0):
        batch = batch.type(dtype)
        batchv = Variable(batch)
        batch_size_i = batch.size()[0]
        preds[i * batch_size : i * batch_size + batch_size_i] = get_pred(batchv)

    # Now compute the mean kl-div
    split_scores = []
    for k in range(splits):
        part = preds[k * (N // splits) : (k + 1) * (N // splits), :]
        py = np.mean(part, axis=0)
        scores = []
        for i in range(part.shape[0]):
            pyx = part[i, :]
            scores.append(entropy(pyx, py))
        split_scores.append(np.exp(np.mean(scores)))

    return np.mean(split_scores), np.std(split_scores)


我们定义了一个训练函数,其中包括一个生成器和一个判别器,每个都拥有独立的学习率和优化器。我们确保为训练实现检查点功能。特别要注意的是,从检查点加载后,我们需要设置优化器的学习率,因为我们希望使用传递给我们的 config 中的扰动配置,而不是与我们正在利用的 trial 完全相同的配置。

def dcgan_train(config):
    use_cuda = config.get("use_gpu") and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    netD = Discriminator().to(device)
    netD.apply(weights_init)
    netG = Generator().to(device)
    netG.apply(weights_init)
    criterion = nn.BCELoss()
    optimizerD = optim.Adam(
        netD.parameters(), lr=config.get("lr", 0.01), betas=(beta1, 0.999)
    )
    optimizerG = optim.Adam(
        netG.parameters(), lr=config.get("lr", 0.01), betas=(beta1, 0.999)
    )
    with FileLock(os.path.expanduser("~/ray_results/.data.lock")):
        dataloader = get_data_loader()

    step = 1
    checkpoint = tune.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            checkpoint_dict = torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"))
        netD.load_state_dict(checkpoint_dict["netDmodel"])
        netG.load_state_dict(checkpoint_dict["netGmodel"])
        optimizerD.load_state_dict(checkpoint_dict["optimD"])
        optimizerG.load_state_dict(checkpoint_dict["optimG"])
        # Note: Make sure to increment the loaded step by 1 to get the
        # current step.
        last_step = checkpoint_dict["step"]
        step = last_step + 1

        # NOTE: It's important to set the optimizer learning rates
        # again, since we want to explore the parameters passed in by PBT.
        # Without this, we would continue using the exact same
        # configuration as the trial whose checkpoint we are exploiting.
        if "netD_lr" in config:
            for param_group in optimizerD.param_groups:
                param_group["lr"] = config["netD_lr"]
        if "netG_lr" in config:
            for param_group in optimizerG.param_groups:
                param_group["lr"] = config["netG_lr"]

    while True:
        lossG, lossD, is_score = train_func(
            netD,
            netG,
            optimizerG,
            optimizerD,
            criterion,
            dataloader,
            step,
            device,
            config["mnist_model_ref"],
        )
        metrics = {"lossg": lossG, "lossd": lossD, "is_score": is_score}

        if step % config["checkpoint_interval"] == 0:
            with tempfile.TemporaryDirectory() as tmpdir:
                torch.save(
                    {
                        "netDmodel": netD.state_dict(),
                        "netGmodel": netG.state_dict(),
                        "optimD": optimizerD.state_dict(),
                        "optimG": optimizerG.state_dict(),
                        "step": step,
                    },
                    os.path.join(tmpdir, "checkpoint.pt"),
                )
                tune.report(metrics, checkpoint=Checkpoint.from_directory(tmpdir))
        else:
            tune.report(metrics)

        step += 1


我们将 inception score 指定为指标并开始调优

import torch
import ray
from ray import tune
from ray.tune.schedulers import PopulationBasedTraining

from ray.tune.examples.pbt_dcgan_mnist.common import Net
from ray.tune.examples.pbt_dcgan_mnist.pbt_dcgan_mnist_func import (
    dcgan_train,
    download_mnist_cnn,
)

# Load the pretrained mnist classification model for inception_score
mnist_cnn = Net()
model_path = download_mnist_cnn()
mnist_cnn.load_state_dict(torch.load(model_path))
mnist_cnn.eval()
# Put the model in Ray object store.
mnist_model_ref = ray.put(mnist_cnn)

perturbation_interval = 5
scheduler = PopulationBasedTraining(
    perturbation_interval=perturbation_interval,
    hyperparam_mutations={
        # Distribution for resampling
        "netG_lr": tune.uniform(1e-2, 1e-5),
        "netD_lr": tune.uniform(1e-2, 1e-5),
    },
)

smoke_test = True  # For testing purposes: set this to False to run the full experiment
tuner = tune.Tuner(
    dcgan_train,
    run_config=tune.RunConfig(
        name="pbt_dcgan_mnist_tutorial",
        stop={"training_iteration": 5 if smoke_test else 150},
    ),
    tune_config=tune.TuneConfig(
        metric="is_score",
        mode="max",
        num_samples=2 if smoke_test else 8,
        scheduler=scheduler,
    ),
    param_space={
        # Define how initial values of the learning rates should be chosen.
        "netG_lr": tune.choice([0.0001, 0.0002, 0.0005]),
        "netD_lr": tune.choice([0.0001, 0.0002, 0.0005]),
        "mnist_model_ref": mnist_model_ref,
        "checkpoint_interval": perturbation_interval,
    },
)
results_grid = tuner.fit()

可以从检查点加载训练好的生成器模型,以从噪声信号生成数字图像。

可视化#

下面,我们可视化训练日志中不断增加的 inception score。

import matplotlib.pyplot as plt

# Uncomment to apply plotting styles
# !pip install seaborn
# import seaborn as sns
# sns.set_style("darkgrid")

result_dfs = [result.metrics_dataframe for result in results_grid]
best_result = results_grid.get_best_result(metric="is_score", mode="max")

plt.figure(figsize=(7, 4))
for i, df in enumerate(result_dfs):
    plt.plot(df["is_score"], label=i)
plt.legend()
plt.title("Inception Score During Training")
plt.xlabel("Training Iterations")
plt.ylabel("Inception Score")
plt.show()
../../_images/60e0d8726f400ee9d95b197738260f8989346ef4fe60d815b275d3d0b3c4b1f9.png

接下来,我们看看生成器和判别器的损失

fig, axs = plt.subplots(1, 2, figsize=(12, 4))

for i, df in enumerate(result_dfs):
    axs[0].plot(df["lossg"], label=i)
axs[0].legend()
axs[0].set_title("Generator Loss During Training")
axs[0].set_xlabel("Training Iterations")
axs[0].set_ylabel("Generator Loss")

for i, df in enumerate(result_dfs):
    axs[1].plot(df["lossd"], label=i)
axs[1].legend()
axs[1].set_title("Discriminator Loss During Training")
axs[1].set_xlabel("Training Iterations")
axs[1].set_ylabel("Discriminator Loss")

plt.show()
../../_images/6eeb09c68eafec5e2618a6c9485d8d03c0980af88139b9c100a0640f21c21def.png
from ray.tune.examples.pbt_dcgan_mnist.common import demo_gan

with best_result.checkpoint.as_directory() as best_checkpoint:
    demo_gan([best_checkpoint])
../../_images/dea1643da130b41b1e6280052aa99f0df3a03448fa7d8412a3c8e0d7a3091ae9.png

MNIST 生成器的训练应该需要几分钟。该示例可以轻松修改,以生成其他数据集(例如 cifar10 或 LSUN)的图像。

总结#

本教程涵盖了

  1. 使用基于群体的训练调优深度学习超参数的两个示例(CNN 和 GAN 训练)

  2. 保存和加载检查点并确保所有超参数都被使用(例如:优化器状态)

  3. 训练后可视化报告的指标

要了解更多信息,请查看下一篇教程 可视化基于群体的训练 (PBT) 超参数优化,这是一份理解 PBT 及其底层行为的可视化指南。

如果你有任何问题、建议或遇到任何问题,请在 DiscussGitHubRay Slack 上联系我们!