How to use Tune with PyTorch#


In this walkthrough, we will show you how to integrate Tune into your PyTorch training workflow. We will follow this tutorial from the PyTorch documentation for training a CIFAR10 image classifier.


Hyperparameter tuning can make the difference between an average model and a highly accurate one. Often, simple things like choosing a different learning rate or changing a network layer size can have a dramatic impact on your model's performance. Fortunately, Tune makes exploring these optimal parameter combinations easy, and it works nicely with PyTorch.

As you will see, we only need to make a few slight modifications. In particular, we need to

  1. wrap the data loading and training in functions,

  2. make some network parameters configurable,

  3. add checkpointing (optional),

  4. and define the search space for the model tuning.

Setup / Imports#

First, the required dependencies (uncomment to install):

#!pip install -Uq "ray[tune]" torch torchvision pandas ipywidgets

Then, the imports:

import os
import tempfile

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from filelock import FileLock
from torch.utils.data import random_split

from ray import train, tune
from ray.tune.schedulers import ASHAScheduler

Most of the imports are needed for building the PyTorch model; only the imports from ray at the end are for Ray Tune.

Data loaders#

We wrap the data loaders in their own function and pass a global data directory. This way we can share a data directory between different trials.

def load_data(data_dir="./data"):
    """Create dataloaders for normalized CIFAR10 training/test subsets."""
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    # We add FileLock here because multiple workers will want to
    # download data, and this may cause overwrites since
    # DataLoader is not threadsafe.
    with FileLock(os.path.expanduser("~/.data.lock")):
        trainset = torchvision.datasets.CIFAR10(
            root=data_dir, train=True, download=True, transform=transform)

        testset = torchvision.datasets.CIFAR10(
            root=data_dir, train=False, download=True, transform=transform)

    return trainset, testset

def create_dataloaders(trainset, batch_size, num_workers=8):
    """Create train/val splits and dataloaders."""
    train_size = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(
        trainset, [train_size, len(trainset) - train_size])

    train_loader = torch.utils.data.DataLoader(
        train_subset,
        batch_size=batch_size, 
        shuffle=True,
        num_workers=num_workers
    )
    val_loader = torch.utils.data.DataLoader(
        val_subset,
        batch_size=batch_size,
        shuffle=False, 
        num_workers=num_workers
    )
    return train_loader, val_loader

def load_test_data():
    # Load fake data for running a quick smoke-test.
    trainset = torchvision.datasets.FakeData(
        128, (3, 32, 32), num_classes=10, transform=transforms.ToTensor()
    )
    testset = torchvision.datasets.FakeData(
        16, (3, 32, 32), num_classes=10, transform=transforms.ToTensor()
    )
    return trainset, testset

Configurable neural network#

We can only tune those parameters that are configurable. In this example, we can specify the layer sizes of the fully connected layers:

class Net(nn.Module):
    def __init__(self, l1=120, l2=84):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

The train function#

Now it gets interesting, because we introduce some changes to the example from the PyTorch documentation.

The full code example looks like this:

def train_cifar(config):
    net = Net(config["l1"], config["l2"])
    device = config["device"]
    if device == "cuda":
        net = nn.DataParallel(net)
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9, weight_decay=5e-5)

    # Load existing checkpoint through `get_checkpoint()` API.
    if tune.get_checkpoint():
        loaded_checkpoint = tune.get_checkpoint()
        with loaded_checkpoint.as_directory() as loaded_checkpoint_dir:
            model_state, optimizer_state = torch.load(
                os.path.join(loaded_checkpoint_dir, "checkpoint.pt")
            )
            net.load_state_dict(model_state)
            optimizer.load_state_dict(optimizer_state)

    # Data setup
    if config["smoke_test"]:
        trainset, _ = load_test_data()
    else:
        trainset, _ = load_data()
    train_loader, val_loader = create_dataloaders(
        trainset, 
        config["batch_size"],
        num_workers=0 if config["smoke_test"] else 8
    )

    for epoch in range(config["max_num_epochs"]):  # loop over the dataset multiple times
        net.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            # forward + backward + optimize
            optimizer.zero_grad()  # reset gradients
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()

        # Validation
        net.eval()
        val_loss = 0.0
        correct = total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = net(inputs)
                val_loss += criterion(outputs, labels).item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        # Report metrics
        metrics = {
            "loss": val_loss / len(val_loader),
            "accuracy": correct / total,
        }

        # Here we save a checkpoint. It is automatically registered with
        # Ray Tune and may be accessed through ``get_checkpoint()`` in future
        # iterations.
        # Note that to save a file-like checkpoint, you still need to put it
        # under a directory in order to construct a Checkpoint from it.
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            path = os.path.join(temp_checkpoint_dir, "checkpoint.pt")
            torch.save(
                (net.state_dict(), optimizer.state_dict()), path
            )
            checkpoint = tune.Checkpoint.from_directory(temp_checkpoint_dir)
            tune.report(metrics, checkpoint=checkpoint)
    print("Finished Training!")

As you can see, most of the code is adapted directly from the original example.

Test set accuracy#

Commonly, the performance of a machine learning model is tested on a hold-out test set with data that has not been used for training. We also wrap this in a function:

def test_best_model(best_result, smoke_test=False):
    best_trained_model = Net(best_result.config["l1"], best_result.config["l2"])
    device = best_result.config["device"]
    if device == "cuda":
        best_trained_model = nn.DataParallel(best_trained_model)
    best_trained_model.to(device)

    checkpoint_path = os.path.join(best_result.checkpoint.to_directory(), "checkpoint.pt")

    model_state, _optimizer_state = torch.load(checkpoint_path)
    best_trained_model.load_state_dict(model_state)

    if smoke_test:
        _trainset, testset = load_test_data()
    else:
        _trainset, testset = load_data()

    testloader = torch.utils.data.DataLoader(
        testset, batch_size=4, shuffle=False, num_workers=2
    )

    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = best_trained_model(images)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    print(f"Best trial test set accuracy: {correct / total}")

As you can see, the function also expects a device parameter, so we can do the test set validation on a GPU.

Configuring the search space#

Lastly, we need to define Tune's search space. Here is an example:

# Set this to True for a smoke test that runs with a small synthetic dataset.
SMOKE_TEST = False
config = {
    "l1": tune.sample_from(lambda _: 2**np.random.randint(2, 9)),
    "l2": tune.sample_from(lambda _: 2**np.random.randint(2, 9)),
    "lr": tune.loguniform(1e-4, 1e-1),
    "batch_size": tune.choice([2, 4, 8, 16]),
    "smoke_test": SMOKE_TEST,
    "num_trials": 10 if not SMOKE_TEST else 2,
    "max_num_epochs": 10 if not SMOKE_TEST else 2,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}

The tune.sample_from() function makes it possible to define your own sample methods for obtaining hyperparameters. In this example, the layer sizes l1 and l2 should be powers of 2 between 4 and 256, so either 4, 8, 16, 32, 64, 128, or 256. The lr (learning rate) is sampled log-uniformly between 0.0001 and 0.1. Lastly, the batch size is a choice between 2, 4, 8, and 16.
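
If you prefer an explicit list of values over a custom sampling function, the same layer-size space can also be expressed with tune.choice; a minimal, equivalent sketch:

# Equivalent to the sample_from version above: pick a power of 2 between 4 and 256.
config["l1"] = tune.choice([4, 8, 16, 32, 64, 128, 256])
config["l2"] = tune.choice([4, 8, 16, 32, 64, 128, 256])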

In each trial, Tune will now randomly sample a combination of parameters from these search spaces. It then trains a number of models in parallel and finds the best performing one among them. We also use the ASHAScheduler, which terminates badly performing trials early.
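
For reference, this is how the scheduler is constructed in the main function further below, annotated with what its arguments roughly mean:

scheduler = ASHAScheduler(
    time_attr="training_iteration",  # measure trial progress in reported iterations (one per epoch here)
    max_t=config["max_num_epochs"],  # run each trial for at most this many iterations
    grace_period=1,                  # never stop a trial before it has reported at least once
    reduction_factor=2)              # at each rung, only roughly the top half of trials continue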

You can specify the number of CPUs, which are then available, for example, to increase the num_workers of the PyTorch DataLoader instances. The selected number of GPUs is made visible to PyTorch in each trial. Trials do not have access to GPUs that haven't been requested for them, so you don't have to worry about two trials using the same set of resources.

Here we can also specify fractional GPUs, so something like gpus_per_trial=0.5 is completely valid. The trials will then share GPUs among each other. You just have to make sure that the models still fit in GPU memory.
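
Per-trial resources are requested by wrapping the training function with tune.with_resources, as done in the main function below. A minimal sketch requesting two CPUs and half a GPU per trial:

# Two trials can then share one GPU; each trial still gets 2 CPUs for its dataloader workers.
trainable_with_resources = tune.with_resources(
    train_cifar, resources={"cpu": 2, "gpu": 0.5}
)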

After training the models, we find the best performing one and load the trained network from the checkpoint file. We then obtain the test set accuracy and report everything by printing.

The full main function looks like this:

def main(config, gpus_per_trial=1):
    scheduler = ASHAScheduler(
        time_attr="training_iteration",
        max_t=config["max_num_epochs"],
        grace_period=1,
        reduction_factor=2)
    
    tuner = tune.Tuner(
        tune.with_resources(
            tune.with_parameters(train_cifar),
            resources={"cpu": 2, "gpu": gpus_per_trial}
        ),
        tune_config=tune.TuneConfig(
            metric="loss",
            mode="min",
            scheduler=scheduler,
            num_samples=config["num_trials"],
        ),
        param_space=config,
    )
    results = tuner.fit()
    
    best_result = results.get_best_result("loss", "min")

    print(f"Best trial config: {best_result.config}")
    print(f"Best trial final validation loss: {best_result.metrics['loss']}")
    print(f"Best trial final validation accuracy: {best_result.metrics['accuracy']}")

    test_best_model(best_result, smoke_test=config["smoke_test"])

main(config, gpus_per_trial=1 if torch.cuda.is_available() else 0)

If you run the code, an example output could look like this:

  Number of trials: 10 (10 TERMINATED)
  +-------------------------+------------+-------+------+------+-------------+--------------+---------+------------+----------------------+
  | Trial name              | status     | loc   |   l1 |   l2 |          lr |   batch_size |    loss |   accuracy |   training_iteration |
  |-------------------------+------------+-------+------+------+-------------+--------------+---------+------------+----------------------|
  | train_cifar_87d1f_00000 | TERMINATED |       |   64 |    4 | 0.00011629  |            2 | 1.87273 |     0.244  |                    2 |
  | train_cifar_87d1f_00001 | TERMINATED |       |   32 |   64 | 0.000339763 |            8 | 1.23603 |     0.567  |                    8 |
  | train_cifar_87d1f_00002 | TERMINATED |       |    8 |   16 | 0.00276249  |           16 | 1.1815  |     0.5836 |                   10 |
  | train_cifar_87d1f_00003 | TERMINATED |       |    4 |   64 | 0.000648721 |            4 | 1.31131 |     0.5224 |                    8 |
  | train_cifar_87d1f_00004 | TERMINATED |       |   32 |   16 | 0.000340753 |            8 | 1.26454 |     0.5444 |                    8 |
  | train_cifar_87d1f_00005 | TERMINATED |       |    8 |    4 | 0.000699775 |            8 | 1.99594 |     0.1983 |                    2 |
  | train_cifar_87d1f_00006 | TERMINATED |       |  256 |    8 | 0.0839654   |           16 | 2.3119  |     0.0993 |                    1 |
  | train_cifar_87d1f_00007 | TERMINATED |       |   16 |  128 | 0.0758154   |           16 | 2.33575 |     0.1327 |                    1 |
  | train_cifar_87d1f_00008 | TERMINATED |       |   16 |    8 | 0.0763312   |           16 | 2.31129 |     0.1042 |                    4 |
  | train_cifar_87d1f_00009 | TERMINATED |       |  128 |   16 | 0.000124903 |            4 | 2.26917 |     0.1945 |                    1 |
  +-------------------------+------------+-------+------+------+-------------+--------------+---------+------------+----------------------+


  Best trial config: {'l1': 8, 'l2': 16, 'lr': 0.0027624906698231976, 'batch_size': 16, 'data_dir': '...'}
  Best trial final validation loss: 1.1815014744281769
  Best trial final validation accuracy: 0.5836
  Best trial test set accuracy: 0.5806

As you can see, most trials have been stopped early to avoid wasting resources. The best performing trial achieved a validation accuracy of about 58%, which is confirmed on the test set.

And that's it! You can now tune the parameters of your PyTorch models.

See more PyTorch examples#

  • MNIST PyTorch Example: Converts the PyTorch MNIST example to use Tune with the function-based API. Also shows how to easily convert something relying on argparse to use Tune.

  • PBT ConvNet Example: Example training a ConvNet with checkpointing using the function API.

  • MNIST PyTorch Trainable Example: Converts the PyTorch MNIST example to use Tune with the Trainable API. Also uses the HyperBandScheduler and checkpoints the model at the end.