Fine-tuning a PyTorch Image Classifier with Ray Train


This example fine-tunes a pretrained ResNet model with Ray Train.

In this example, the network architecture consists of the intermediate layer output of a pretrained ResNet model, which feeds into a randomly initialized linear layer that outputs classification logits for our new task.

Load and preprocess the fine-tuning dataset

This example is adapted from PyTorch's Finetuning Torchvision Models tutorial. We will use hymenoptera_data as the fine-tuning dataset, which contains two classes (bees and ants) and 397 images in total across the training and validation sets. This is a quite small dataset, used only for demonstration purposes.

The dataset is publicly available at https://download.pytorch.org/tutorial/hymenoptera_data.zip. Note that it is structured with directory names as the labels. Use torchvision.datasets.ImageFolder() to load the images and their corresponding labels.
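For reference, the extracted archive is laid out as follows, with the directory names doubling as class labels (roughly 120 training and 75 validation images per class):

hymenoptera_data/
├── train/
│   ├── ants/
│   └── bees/
└── val/
    ├── ants/
    └── bees/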

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
import numpy as np

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    "train": transforms.Compose(
        [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
    "val": transforms.Compose(
        [
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
}
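
# Quick sanity check (illustrative, not part of the original tutorial):
# the validation transform maps any RGB image to a normalized
# [3, 224, 224] tensor.
from PIL import Image

_dummy = Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))
assert data_transforms["val"](_dummy).shape == (3, 224, 224)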

def download_datasets():
    os.system(
        "wget https://download.pytorch.org/tutorial/hymenoptera_data.zip >/dev/null 2>&1"
    )
    os.system("unzip hymenoptera_data.zip >/dev/null 2>&1")

# Build torch datasets for the train and validation splits
def build_datasets():
    torch_datasets = {}
    for split in ["train", "val"]:
        torch_datasets[split] = datasets.ImageFolder(
            os.path.join("./hymenoptera_data", split), data_transforms[split]
        )
    return torch_datasets
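
A quick usage sketch (assuming download_datasets() has already run): ImageFolder infers the class names from the subdirectory names.

download_datasets()
torch_datasets = build_datasets()
print(torch_datasets["train"].classes)  # ['ants', 'bees']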

Initialize the model and fine-tuning configs

Next, define the training configuration that will be passed to the training loop function later.

train_loop_config = {
    "input_size": 224,  # Input image size (224 x 224)
    "batch_size": 32,  # Batch size for training
    "num_epochs": 10,  # Number of epochs to train for
    "lr": 0.001,  # Learning Rate
    "momentum": 0.9,  # SGD optimizer momentum
}

Next, define the model. You can either create a model from pretrained weights or reload a model checkpoint from a previous run.

import os
import torch
from ray.train import Checkpoint

# Option 1: Initialize model with pretrained weights
def initialize_model():
    # Load ImageNet-pretrained weights (the weights enum replaces the
    # deprecated pretrained=True flag)
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

    # Replace the original classifier with a new Linear layer
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, 2)

    # Ensure all params get updated during finetuning
    for param in model.parameters():
        param.requires_grad = True
    return model


# Option 2: Initialize model with a Ray Train checkpoint
# Replace this with your own checkpoint URI
CHECKPOINT_FROM_S3 = Checkpoint(
    path="s3://air-example-data/finetune-resnet-checkpoint/TorchTrainer_4f69f_00000_0_2023-02-14_14-04-09/checkpoint_000001/"
)


def initialize_model_from_checkpoint(checkpoint: Checkpoint):
    with checkpoint.as_directory() as tmpdir:
        # Load on CPU so restoring also works on machines without a GPU
        state_dict = torch.load(os.path.join(tmpdir, "checkpoint.pt"), map_location="cpu")
    resnet50 = initialize_model()
    resnet50.load_state_dict(state_dict["model"])
    return resnet50
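
For example, to restore from the example checkpoint above (substitute your own URI in practice):

model = initialize_model_from_checkpoint(CHECKPOINT_FROM_S3)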

Define the training loop

The train_loop_per_worker function defines the fine-tuning procedure for each worker.

1. Prepare dataloaders for each worker:

  • This tutorial assumes you are using PyTorch's native torch.utils.data.Dataset for data input. train.torch.prepare_data_loader() prepares your DataLoader for distributed execution. You can also use Ray Data for more efficient preprocessing; see the short sketch after this list, and the Working with Images Ray Data user guide for more details on handling images with Ray Data.

2. Prepare your model:

  • train.torch.prepare_model() prepares the model for distributed training. Under the hood, it converts your torch model to a DistributedDataParallel model, which synchronizes its weights across all workers.

3. Report metrics and checkpoints:

  • Use train.report() at the end of each epoch to send intermediate metrics and a checkpoint back to Ray Train; reported checkpoints are persisted to the storage_path configured in the RunConfig.
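As a minimal sketch of the Ray Data alternative mentioned in step 1 (illustrative only: read_images() yields an "image" column and does not infer labels from directory names, so a real pipeline would attach labels as well):

import ray
import ray.train as train

# Read images into a Ray Dataset, resized to a common shape so batches collate
train_ds = ray.data.read_images("./hymenoptera_data/train", size=(224, 224), mode="RGB")

def train_loop_with_ray_data(config):
    # Each worker iterates over its automatically assigned shard
    shard = train.get_dataset_shard("train")
    for batch in shard.iter_torch_batches(batch_size=32):
        images = batch["image"]  # tensor of shape [B, 224, 224, 3]
        ...

# Pass the dataset to the trainer instead of building DataLoaders per worker:
# TorchTrainer(train_loop_with_ray_data, datasets={"train": train_ds}, ...)

The rest of this example sticks with the native DataLoader path.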

import os
from tempfile import TemporaryDirectory

import ray.train as train
from ray.train import Checkpoint



def evaluate(logits, labels):
    # Count the number of correct top-1 predictions in the batch
    _, preds = torch.max(logits, 1)
    corrects = torch.sum(preds == labels).item()
    return corrects


def train_loop_per_worker(configs):
    import warnings

    warnings.filterwarnings("ignore")

    # Calculate the batch size for a single worker
    worker_batch_size = configs["batch_size"] // train.get_context().get_world_size()

    # Download dataset once on local rank 0 worker
    if train.get_context().get_local_rank() == 0:
        download_datasets()
    torch.distributed.barrier()

    # Build datasets on each worker
    torch_datasets = build_datasets()

    # Prepare dataloader for each worker
    dataloaders = dict()
    dataloaders["train"] = DataLoader(
        torch_datasets["train"], batch_size=worker_batch_size, shuffle=True
    )
    dataloaders["val"] = DataLoader(
        torch_datasets["val"], batch_size=worker_batch_size, shuffle=False
    )

    # Distribute
    dataloaders["train"] = train.torch.prepare_data_loader(dataloaders["train"])
    dataloaders["val"] = train.torch.prepare_data_loader(dataloaders["val"])

    device = train.torch.get_device()

    # Prepare DDP Model, optimizer, and loss function
    model = initialize_model()
    model = train.torch.prepare_model(model)

    optimizer = optim.SGD(
        model.parameters(), lr=configs["lr"], momentum=configs["momentum"]
    )
    criterion = nn.CrossEntropyLoss()

    # Start training loops
    for epoch in range(configs["num_epochs"]):
        # Each epoch has a training and validation phase
        for phase in ["train", "val"]:
            if phase == "train":
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            if train.get_context().get_world_size() > 1:
                dataloaders[phase].sampler.set_epoch(epoch)

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                with torch.set_grad_enabled(phase == "train"):
                    # Get model outputs and calculate loss
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                # calculate statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += evaluate(outputs, labels)

            size = len(torch_datasets[phase]) // train.get_context().get_world_size()
            epoch_loss = running_loss / size
            epoch_acc = running_corrects / size

            if train.get_context().get_world_rank() == 0:
                print(
                    "Epoch {}-{} Loss: {:.4f} Acc: {:.4f}".format(
                        epoch, phase, epoch_loss, epoch_acc
                    )
                )

            # Report metrics and checkpoint every epoch
            if phase == "val":
                with TemporaryDirectory() as tmpdir:
                    state_dict = {
                        "epoch": epoch,
                        "model": model.module.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                    }
                    torch.save(state_dict, os.path.join(tmpdir, "checkpoint.pt"))
                    train.report(
                        metrics={"loss": epoch_loss, "acc": epoch_acc},
                        checkpoint=Checkpoint.from_directory(tmpdir),
                    )

Next, set up the TorchTrainer:

from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig, RunConfig, CheckpointConfig

# Scale out model training across 4 GPUs.
scaling_config = ScalingConfig(
    num_workers=4, use_gpu=True, resources_per_worker={"CPU": 1, "GPU": 1}
)

# Save the latest checkpoint
checkpoint_config = CheckpointConfig(num_to_keep=1)

# Set experiment name and checkpoint configs
run_config = RunConfig(
    name="finetune-resnet",
    storage_path="/tmp/ray_results",
    checkpoint_config=checkpoint_config,
)
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_loop_config,
    scaling_config=scaling_config,
    run_config=run_config,
)

result = trainer.fit()
print(result)
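
The returned Result object carries the final reported metrics and the latest checkpoint; a quick way to inspect them:

print(result.metrics)     # last reported metrics, e.g. {"loss": ..., "acc": ...}
print(result.checkpoint)  # ray.train.Checkpoint stored under storage_path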

Load the checkpoint for prediction

The metadata and checkpoints are saved under the storage_path specified in the TorchTrainer.

Now load the trained model and evaluate it on the validation data. Use the initialize_model_from_checkpoint() function defined earlier to load the resulting checkpoint from the fine-tuning run.

model = initialize_model_from_checkpoint(result.checkpoint)
# Fall back to CPU if no GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Finally, define a simple evaluation loop and check the performance of the checkpointed model.

model = model.to(device)
model.eval()

download_datasets()
torch_datasets = build_datasets()
dataloader = DataLoader(torch_datasets["val"], batch_size=32, num_workers=4)
corrects = 0
# Disable gradient tracking for inference
with torch.no_grad():
    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        corrects += evaluate(outputs, labels)

print("Accuracy: ", corrects / len(dataloader.dataset))
Accuracy:  0.934640522875817