Train a ResNet Model with Intel Gaudi#


In this Jupyter notebook, we train a ResNet-50 model on HPUs to classify images of ants and bees. We use PyTorch for model training and Ray for distributed training. The dataset is downloaded and processed using torchvision's datasets and transforms.

Intel Gaudi AI Processors (HPUs) are AI hardware accelerators designed by Intel Habana Labs. For more information, see Gaudi Architecture and Gaudi Developer Documentation.

Configuration#

Running this example requires a node with Gaudi/Gaudi2 installed. Both Gaudi and Gaudi2 have 8 HPUs. We use 2 workers to train the model, each worker using 1 HPU.

We recommend using a prebuilt container to run these examples. To run a container, you need Docker. See Install Docker Engine for installation instructions.

Next, follow Run Using Containers to install the Gaudi drivers and container runtime.

Then start the Gaudi container:

docker pull vault.habana.ai/gaudi-docker/1.20.0/ubuntu22.04/habanalabs/pytorch-installer-2.6.0:latest
docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.20.0/ubuntu22.04/habanalabs/pytorch-installer-2.6.0:latest

Inside the container, install Ray and Jupyter to run this notebook.

pip install ray[train] notebook
import os
from typing import Dict
from tempfile import TemporaryDirectory

import torch
from filelock import FileLock
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from tqdm import tqdm

import ray
import ray.train as train
from ray.train import ScalingConfig, Checkpoint
from ray.train.torch import TorchTrainer
from ray.train.torch import TorchConfig
from ray.runtime_env import RuntimeEnv

import habana_frameworks.torch.core as htcore

Define the data transforms#

We set up the data transforms for preprocessing images in the training and validation sets. This includes random cropping, flipping, and normalization for the training set, and resizing and normalization for the validation set.

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    "train": transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    "val": transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}

Dataset download function#

We define a function to download the Hymenoptera dataset. This dataset contains images of ants and bees for a binary classification problem.

def download_datasets():
    os.system("wget https://download.pytorch.org/tutorial/hymenoptera_data.zip >/dev/null 2>&1")
    os.system("unzip hymenoptera_data.zip >/dev/null 2>&1")

Dataset preparation function#

After downloading the dataset, we need to build PyTorch datasets for training and validation. The build_datasets function applies the previously defined data transforms and creates the datasets.

def build_datasets():
    torch_datasets = {}
    for split in ["train", "val"]:
        torch_datasets[split] = datasets.ImageFolder(
            os.path.join("./hymenoptera_data", split), data_transforms[split]
        )
    return torch_datasets

Model initialization functions#

We define two functions to initialize our model. The initialize_model function loads a pretrained ResNet-50 model and replaces the final classification layer for our binary classification task. The initialize_model_from_checkpoint function loads the model from a saved checkpoint when one is available.

def initialize_model():
    # Load pretrained model params
    model = models.resnet50(pretrained=True)

    # Replace the original classifier with a new Linear layer
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, 2)

    # Ensure all params get updated during finetuning
    for param in model.parameters():
        param.requires_grad = True
    return model
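
A minimal sketch of initialize_model_from_checkpoint, assuming the checkpoint directory stores the model weights as model.pt (the file name is an assumption made for this sketch, not a Ray convention):

def initialize_model_from_checkpoint(checkpoint: Checkpoint):
    # Load model weights from a Ray Train checkpoint.
    # Assumes the weights were saved as "model.pt" inside the checkpoint directory.
    with checkpoint.as_directory() as checkpoint_dir:
        state_dict = torch.load(
            os.path.join(checkpoint_dir, "model.pt"), map_location="cpu"
        )
    model = initialize_model()
    model.load_state_dict(state_dict)
    return model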

Evaluation function#

To evaluate the model's performance during training, we define an evaluate function. It computes the number of correct predictions by comparing the predicted labels with the ground-truth labels.

def evaluate(logits, labels):
    _, preds = torch.max(logits, 1)
    corrects = torch.sum(preds == labels).item()
    return corrects
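
For example, with two samples whose logits both favor class 1 and labels [1, 0], the function counts one correct prediction (an illustrative snippet, not part of the training code):

# Illustration: logits for two samples, both predicting class 1
logits = torch.tensor([[0.2, 0.8], [0.1, 0.9]])
labels = torch.tensor([1, 0])
print(evaluate(logits, labels))  # 1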

Training loop function#

This function defines the training loop that each worker executes. It includes downloading the dataset, preparing data loaders, initializing the model, and running the training and validation phases. Compared to a training function written for GPUs, porting to HPU requires no code changes. Ray Train handles the following internally (roughly equivalent to the manual setup sketched after this list):

  • Detects the HPU and sets the device.

  • Initializes the Habana PyTorch backend.

  • Initializes the Habana distributed backend.
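
For illustration, that manual setup would look roughly like the sketch below. This is not Ray Train's actual internal code, and the snippet assumes the standard torch.distributed environment variables (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT) have already been set by a launcher:

import torch
import torch.distributed as dist

import habana_frameworks.torch.core as htcore  # registers the "hpu" device with PyTorch
import habana_frameworks.torch.distributed.hccl  # registers the "hccl" process-group backend

device = torch.device("hpu")  # select the HPU device
dist.init_process_group(backend="hccl")  # initialize the Habana distributed backend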

def train_loop_per_worker(configs):
    import warnings

    warnings.filterwarnings("ignore")

    # Calculate the batch size for a single worker
    worker_batch_size = configs["batch_size"] // train.get_context().get_world_size()

    # Download dataset once on local rank 0 worker
    if train.get_context().get_local_rank() == 0:
        download_datasets()
    torch.distributed.barrier()

    # Build datasets on each worker
    torch_datasets = build_datasets()

    # Prepare dataloader for each worker
    dataloaders = dict()
    dataloaders["train"] = DataLoader(
        torch_datasets["train"], batch_size=worker_batch_size, shuffle=True
    )
    dataloaders["val"] = DataLoader(
        torch_datasets["val"], batch_size=worker_batch_size, shuffle=False
    )

    # Distribute
    dataloaders["train"] = train.torch.prepare_data_loader(dataloaders["train"])
    dataloaders["val"] = train.torch.prepare_data_loader(dataloaders["val"])

    # Obtain HPU device automatically
    device = train.torch.get_device()

    # Prepare model, optimizer, and loss function
    model = initialize_model()
    model = model.to(device)

    optimizer = optim.SGD(
        model.parameters(), lr=configs["lr"], momentum=configs["momentum"]
    )
    criterion = nn.CrossEntropyLoss()

    # Start training loops
    for epoch in range(configs["num_epochs"]):
        # Each epoch has a training and validation phase
        for phase in ["train", "val"]:
            if phase == "train":
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                with torch.set_grad_enabled(phase == "train"):
                    # Get model outputs and calculate loss
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                # calculate statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += evaluate(outputs, labels)

            size = len(torch_datasets[phase]) // train.get_context().get_world_size()
            epoch_loss = running_loss / size
            epoch_acc = running_corrects / size

            if train.get_context().get_world_rank() == 0:
                print(
                    "Epoch {}-{} Loss: {:.4f} Acc: {:.4f}".format(
                        epoch, phase, epoch_loss, epoch_acc
                    )
                )

            # Report metrics every epoch
            if phase == "val":
                train.report(
                    metrics={"loss": epoch_loss, "acc": epoch_acc},
                )
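
The loop above reports metrics only. To also save a checkpoint with each report, the TemporaryDirectory and Checkpoint imports from the top of the notebook can be used. A minimal sketch that would replace the train.report call in the validation phase ("model.pt" is an arbitrary file name chosen for this sketch):

if phase == "val":
    with TemporaryDirectory() as tmpdir:
        # Save the model weights and attach them to the report as a checkpoint
        torch.save(model.state_dict(), os.path.join(tmpdir, "model.pt"))
        train.report(
            metrics={"loss": epoch_loss, "acc": epoch_acc},
            checkpoint=Checkpoint.from_directory(tmpdir),
        )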

Main training function#

The train_resnet function sets up the distributed training environment with Ray and launches the training process. It specifies the batch size, number of epochs, learning rate, and momentum for the SGD optimizer. To enable training with HPUs, we only need to make the following changes:

  • Require an HPU for each worker in ScalingConfig.

  • Set the backend to "hccl" in TorchConfig.

def train_resnet(num_workers=2):
    global_batch_size = 32

    train_loop_config = {
        "input_size": 224,  # Input image size (224 x 224)
        "batch_size": global_batch_size,  # Global batch size for training
        "num_epochs": 10,  # Number of epochs to train for
        "lr": 0.001,  # Learning Rate
        "momentum": 0.9,  # SGD optimizer momentum
    }
    # Configure computation resources
    # In ScalingConfig, require an HPU for each worker
    scaling_config = ScalingConfig(
        num_workers=num_workers, resources_per_worker={"CPU": 1, "HPU": 1}
    )
    # Set backend to hccl in TorchConfig
    torch_config = TorchConfig(backend="hccl")

    ray.init()
    
    # Initialize a Ray TorchTrainer
    trainer = TorchTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config=train_loop_config,
        torch_config=torch_config,
        scaling_config=scaling_config,
    )

    result = trainer.fit()
    print(f"Training result: {result}")

Start training#

Finally, we call the train_resnet function to start the training process. You can adjust the number of workers to use. Before running this cell, make sure Ray is correctly set up in your environment to handle distributed training.

Note: the following warning is expected and is resolved in SynapseAI version 1.14.0+:

/usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py:252: UserWarning: Device capability of hccl unspecified, assuming `cpu` and `cuda`. Please specify it via the `devices` argument of `register_backend`.
train_resnet(num_workers=2) 

Possible output#

2025-03-03 03:32:12,620 INFO worker.py:1841 -- Started a local Ray instance.
/usr/local/lib/python3.10/dist-packages/ray/tune/impl/tuner_internal.py:125: RayDeprecationWarning: The `RunConfig` class should be imported from `ray.tune` when passing it to the Tuner. Please update your imports. See this issue for more context and migration options: https://github.com/ray-project/ray/issues/49454. Disable these warnings by setting the environment variable: RAY_TRAIN_ENABLE_V2_MIGRATION_WARNINGS=0
  _log_deprecation_warning(
(RayTrainWorker pid=63669) Setting up process group for: env:// [rank=0, world_size=2]
(TorchTrainer pid=63280) Started distributed worker processes: 
(TorchTrainer pid=63280) - (node_id=9f2c34ea47fe405f3227e9168aa857f81655a83e95fd6be359fd76db, ip=100.83.111.228, pid=63669) world_rank=0, local_rank=0, node_rank=0
(TorchTrainer pid=63280) - (node_id=9f2c34ea47fe405f3227e9168aa857f81655a83e95fd6be359fd76db, ip=100.83.111.228, pid=63668) world_rank=1, local_rank=1, node_rank=0
(RayTrainWorker pid=63669) ============================= HABANA PT BRIDGE CONFIGURATION =========================== 
(RayTrainWorker pid=63669)  PT_HPU_LAZY_MODE = 1
(RayTrainWorker pid=63669)  PT_HPU_RECIPE_CACHE_CONFIG = ,false,1024
(RayTrainWorker pid=63669)  PT_HPU_MAX_COMPOUND_OP_SIZE = 9223372036854775807
(RayTrainWorker pid=63669)  PT_HPU_LAZY_ACC_PAR_MODE = 1
(RayTrainWorker pid=63669)  PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES = 0
(RayTrainWorker pid=63669)  PT_HPU_EAGER_PIPELINE_ENABLE = 1
(RayTrainWorker pid=63669)  PT_HPU_EAGER_COLLECTIVE_PIPELINE_ENABLE = 1
(RayTrainWorker pid=63669)  PT_HPU_ENABLE_LAZY_COLLECTIVES = 0
(RayTrainWorker pid=63669) ---------------------------: System Configuration :---------------------------
(RayTrainWorker pid=63669) Num CPU Cores : 160
(RayTrainWorker pid=63669) CPU RAM       : 1056374420 KB
(RayTrainWorker pid=63669) ------------------------------------------------------------------------------
(RayTrainWorker pid=63668) Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
  0%|          | 0.00/97.8M [00:00<?, ?B/s]
  9%|▊         | 8.38M/97.8M [00:00<00:01, 87.7MB/s]
100%|██████████| 97.8M/97.8M [00:00<00:00, 193MB/s]
100%|██████████| 97.8M/97.8M [00:00<00:00, 203MB/s]

View detailed results here: /root/ray_results/TorchTrainer_2025-03-03_03-32-15
To visualize your results with TensorBoard, run: `tensorboard --logdir /tmp/ray/session_2025-03-03_03-32-10_695011_53838/artifacts/2025-03-03_03-32-15/TorchTrainer_2025-03-03_03-32-15/driver_artifacts`

Training started with configuration:
╭──────────────────────────────────────╮
│ Training config                      │
├──────────────────────────────────────┤
│ train_loop_config/batch_size      32 │
│ train_loop_config/input_size     224 │
│ train_loop_config/lr           0.001 │
│ train_loop_config/momentum       0.9 │
│ train_loop_config/num_epochs      10 │
╰──────────────────────────────────────╯
(RayTrainWorker pid=63669) Epoch 0-train Loss: 0.6574 Acc: 0.6066

Training finished iteration 1 at 2025-03-03 03:32:45. Total running time: 29s
╭───────────────────────────────╮
│ Training result               │
├───────────────────────────────┤
│ checkpoint_dir_name           │
│ time_this_iter_s       24.684 │
│ time_total_s           24.684 │
│ training_iteration          1 │
│ acc                   0.71053 │
│ loss                  0.51455 │
╰───────────────────────────────╯
(RayTrainWorker pid=63669) Epoch 0-val Loss: 0.5146 Acc: 0.7105
(RayTrainWorker pid=63669) Epoch 1-train Loss: 0.5016 Acc: 0.7541

Training finished iteration 2 at 2025-03-03 03:32:46. Total running time: 31s
╭───────────────────────────────╮
│ Training result               │
├───────────────────────────────┤
│ checkpoint_dir_name           │
│ time_this_iter_s      1.39649 │
│ time_total_s          26.0805 │
│ training_iteration          2 │
│ acc                   0.93421 │
│ loss                  0.30218 │
╰───────────────────────────────╯
(RayTrainWorker pid=63669) Epoch 1-val Loss: 0.3022 Acc: 0.9342
(RayTrainWorker pid=63669) Epoch 2-train Loss: 0.3130 Acc: 0.9180

Training finished iteration 3 at 2025-03-03 03:32:47. Total running time: 32s
╭───────────────────────────────╮
│ Training result               │
├───────────────────────────────┤
│ checkpoint_dir_name           │
│ time_this_iter_s      1.37042 │
│ time_total_s          27.4509 │
│ training_iteration          3 │
│ acc                   0.93421 │
│ loss                  0.22201 │
╰───────────────────────────────╯
(RayTrainWorker pid=63669) Epoch 2-val Loss: 0.2220 Acc: 0.9342
(RayTrainWorker pid=63669) Epoch 3-train Loss: 0.2416 Acc: 0.9262

Training finished iteration 4 at 2025-03-03 03:32:49. Total running time: 34s
╭───────────────────────────────╮
│ Training result               │
├───────────────────────────────┤
│ checkpoint_dir_name           │
│ time_this_iter_s      1.38353 │
│ time_total_s          28.8345 │
│ training_iteration          4 │
│ acc                   0.96053 │
│ loss                  0.17815 │
╰───────────────────────────────╯
(RayTrainWorker pid=63669) Epoch 3-val Loss: 0.1782 Acc: 0.9605
(RayTrainWorker pid=63669) Epoch 4-train Loss: 0.1900 Acc: 0.9508

Training finished iteration 5 at 2025-03-03 03:32:50. Total running time: 35s
╭───────────────────────────────╮
│ Training result               │
├───────────────────────────────┤
│ checkpoint_dir_name           │
│ time_this_iter_s      1.37318 │
│ time_total_s          30.2077 │
│ training_iteration          5 │
│ acc                   0.93421 │
│ loss                  0.17063 │
╰───────────────────────────────╯
(RayTrainWorker pid=63669) Epoch 4-val Loss: 0.1706 Acc: 0.9342
(RayTrainWorker pid=63669) Epoch 5-train Loss: 0.1346 Acc: 0.9672

Training finished iteration 6 at 2025-03-03 03:32:52. Total running time: 36s
╭───────────────────────────────╮
│ Training result               │
├───────────────────────────────┤
│ checkpoint_dir_name           │
│ time_this_iter_s      1.37999 │
│ time_total_s          31.5876 │
│ training_iteration          6 │
│ acc                   0.96053 │
│ loss                   0.1552 │
╰───────────────────────────────╯
(RayTrainWorker pid=63669) Epoch 5-val Loss: 0.1552 Acc: 0.9605
(RayTrainWorker pid=63669) Epoch 6-train Loss: 0.1184 Acc: 0.9672

Training finished iteration 7 at 2025-03-03 03:32:53. Total running time: 38s
╭───────────────────────────────╮
│ Training result               │
├───────────────────────────────┤
│ checkpoint_dir_name           │
│ time_this_iter_s      1.39198 │
│ time_total_s          32.9796 │
│ training_iteration          7 │
│ acc                   0.94737 │
│ loss                  0.14702 │
╰───────────────────────────────╯
(RayTrainWorker pid=63669) Epoch 6-val Loss: 0.1470 Acc: 0.9474
(RayTrainWorker pid=63669) Epoch 7-train Loss: 0.0864 Acc: 0.9836

Training finished iteration 8 at 2025-03-03 03:32:54. Total running time: 39s
╭───────────────────────────────╮
│ Training result               │
├───────────────────────────────┤
│ checkpoint_dir_name           │
│ time_this_iter_s       1.3736 │
│ time_total_s          34.3532 │
│ training_iteration          8 │
│ acc                   0.94737 │
│ loss                  0.14443 │
╰───────────────────────────────╯
(RayTrainWorker pid=63669) Epoch 7-val Loss: 0.1444 Acc: 0.9474
(RayTrainWorker pid=63669) Epoch 8-train Loss: 0.1085 Acc: 0.9590

Training finished iteration 9 at 2025-03-03 03:32:56. Total running time: 40s
╭───────────────────────────────╮
│ Training result               │
├───────────────────────────────┤
│ checkpoint_dir_name           │
│ time_this_iter_s      1.37868 │
│ time_total_s          35.7319 │
│ training_iteration          9 │
│ acc                   0.94737 │
│ loss                  0.14194 │
╰───────────────────────────────╯
(RayTrainWorker pid=63669) Epoch 8-val Loss: 0.1419 Acc: 0.9474
(RayTrainWorker pid=63669) Epoch 9-train Loss: 0.0829 Acc: 0.9754

2025-03-03 03:32:58,628 INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/root/ray_results/TorchTrainer_2025-03-03_03-32-15' in 0.0028s.
Training finished iteration 10 at 2025-03-03 03:32:57. Total running time: 42s
╭───────────────────────────────╮
│ Training result               │
├───────────────────────────────┤
│ checkpoint_dir_name           │
│ time_this_iter_s      1.36497 │
│ time_total_s          37.0969 │
│ training_iteration         10 │
│ acc                   0.96053 │
│ loss                  0.14297 │
╰───────────────────────────────╯
(RayTrainWorker pid=63669) Epoch 9-val Loss: 0.1430 Acc: 0.9605

Training completed after 10 iterations at 2025-03-03 03:32:58. Total running time: 43s

Training result: Result(
  metrics={'loss': 0.1429688463869848, 'acc': 0.9605263157894737},
  path='/root/ray_results/TorchTrainer_2025-03-03_03-32-15/TorchTrainer_19fd8_00000_0_2025-03-03_03-32-15',
  filesystem='local',
  checkpoint=None
)
(RayTrainWorker pid=63669) Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
  0%|          | 0.00/97.8M [00:00<?, ?B/s]
 68%|██████▊   | 66.1M/97.8M [00:00<00:00, 160MB/s] [repeated 6x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.rayai.org.cn/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)