Using PyTorch Lightning with Tune#

PyTorch Lightning is a framework which brings structure into training PyTorch models. It aims to avoid boilerplate code, so you don't have to write the same training loops all over again when building a new model.

[Image: pytorch_lightning_full.png]

The main abstraction of PyTorch Lightning is the LightningModule class, which should be extended by your application. There is a great post on how to transfer your models from vanilla PyTorch to Lightning.

The class structure of PyTorch Lightning makes it very easy to define and tune model parameters. This tutorial will show you how to use Tune with PyTorch Lightning. Notably, the LightningModule does not have to be altered at all for this, so you can plug it into your existing models, assuming their parameters are configurable!

Note

To run this example, you will need to install the following:

$ pip install -q "ray[tune]" torch torchvision pytorch_lightning

PyTorch Lightning classifier for MNIST#

Let's first start with the basic PyTorch Lightning implementation of an MNIST classifier. At this point, the classifier does not contain any tuning code.

First, we run some imports:

import os
import torch
import tempfile
import pytorch_lightning as pl
import torch.nn.functional as F
from filelock import FileLock
from torchmetrics import Accuracy
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms

Our example builds on the MNIST example from the blog post mentioned above. We adapted the original model and dataset definitions into MNISTClassifier and MNISTDataModule:

class MNISTClassifier(pl.LightningModule):
    def __init__(self, config):
        super(MNISTClassifier, self).__init__()
        self.accuracy = Accuracy(task="multiclass", num_classes=10, top_k=1)
        self.layer_1_size = config["layer_1_size"]
        self.layer_2_size = config["layer_2_size"]
        self.lr = config["lr"]

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, self.layer_1_size)
        self.layer_2 = torch.nn.Linear(self.layer_1_size, self.layer_2_size)
        self.layer_3 = torch.nn.Linear(self.layer_2_size, 10)
        self.eval_loss = []
        self.eval_accuracy = []

    def cross_entropy_loss(self, logits, labels):
        return F.nll_loss(logits, labels)

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)

        x = self.layer_1(x)
        x = torch.relu(x)

        x = self.layer_2(x)
        x = torch.relu(x)

        x = self.layer_3(x)
        x = torch.log_softmax(x, dim=1)

        return x

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = self.accuracy(logits, y)

        self.log("ptl/train_loss", loss)
        self.log("ptl/train_accuracy", accuracy)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = self.accuracy(logits, y)
        self.eval_loss.append(loss)
        self.eval_accuracy.append(accuracy)
        return {"val_loss": loss, "val_accuracy": accuracy}

    def on_validation_epoch_end(self):
        avg_loss = torch.stack(self.eval_loss).mean()
        avg_acc = torch.stack(self.eval_accuracy).mean()
        self.log("ptl/val_loss", avg_loss, sync_dist=True)
        self.log("ptl/val_accuracy", avg_acc, sync_dist=True)
        self.eval_loss.clear()
        self.eval_accuracy.clear()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=128):
        super().__init__()
        self.data_dir = tempfile.mkdtemp()
        self.batch_size = batch_size
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def setup(self, stage=None):
        with FileLock(f"{self.data_dir}.lock"):
            mnist = MNIST(
                self.data_dir, train=True, download=True, transform=self.transform
            )
            self.mnist_train, self.mnist_val = random_split(mnist, [55000, 5000])

            self.mnist_test = MNIST(
                self.data_dir, train=False, download=True, transform=self.transform
            )

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=4)

default_config = {
    "layer_1_size": 128,
    "layer_2_size": 256,
    "lr": 1e-3,
}
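
Before adding any tuning, you could sanity-check the model by training it once with the default configuration. The following is a minimal sketch, not part of the original tutorial; max_epochs=5 and batch_size=128 are illustrative choices.

# Train the un-tuned classifier with the default configuration.
model = MNISTClassifier(default_config)
dm = MNISTDataModule(batch_size=128)  # illustrative batch size

# A plain Lightning Trainer without any Tune integration.
trainer = pl.Trainer(
    max_epochs=5, accelerator="auto", devices="auto", enable_progress_bar=False
)
trainer.fit(model, datamodule=dm)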

Define a training function that creates the model, the DataModule, and the PyTorch Lightning Trainer:

from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback

def train_func(config):
    dm = MNISTDataModule(batch_size=config["batch_size"])
    model = MNISTClassifier(config)

    trainer = pl.Trainer(
        devices="auto",
        accelerator="auto",
        callbacks=[TuneReportCheckpointCallback()],
        enable_progress_bar=False,
    )
    trainer.fit(model, datamodule=dm)
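
The TuneReportCheckpointCallback reports the metrics logged by Lightning back to Tune and saves a checkpoint after each validation epoch. If you only want to report a subset of metrics, or rename them, you can pass them explicitly. The sketch below relies on the callback's metrics, filename, and on parameters and is not required for this tutorial.

# Report only two metrics (renamed) and checkpoint at the end of each validation run.
report_callback = TuneReportCheckpointCallback(
    metrics={"loss": "ptl/val_loss", "mean_accuracy": "ptl/val_accuracy"},
    filename="checkpoint",  # name of the Lightning checkpoint file inside each Tune checkpoint
    on="validation_end",
)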

Tuning the model parameters#

The parameters above should already give you a good accuracy of over 90%. However, we might improve on this simply by changing some of the hyperparameters. For instance, maybe we get an even higher accuracy if we use a smaller learning rate and larger middle layer sizes.

Instead of manually looping through all the parameter combinations, let's use Tune to systematically try out parameter combinations and find the best performing set.

First, we need some additional imports:

from ray import tune
from ray.tune.schedulers import ASHAScheduler

Configuring the search space#

Now we configure the parameter search space. We would like to choose between different layer dimensions, learning rates, and batch sizes. The learning rate should be sampled uniformly on a log scale between 0.0001 and 0.1. The tune.loguniform() function is syntactic sugar that makes sampling between these different orders of magnitude easier; in particular, it also lets us sample small values. Similarly, tune.choice() samples from all of the provided options.

search_space = {
    "layer_1_size": tune.choice([32, 64, 128]),
    "layer_2_size": tune.choice([64, 128, 256]),
    "lr": tune.loguniform(1e-4, 1e-1),
    "batch_size": tune.choice([32, 64]),
}
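
Each value in search_space is a Tune search space object. As a quick sanity check, you can draw one random configuration to see roughly what Tune will pass to train_func; this relies on the sample() helper of Tune's Domain objects, which is an assumption and not something the tutorial depends on.

# Draw one random configuration from the search space (for inspection only).
example_config = {name: domain.sample() for name, domain in search_space.items()}
print(example_config)
# e.g. {'layer_1_size': 64, 'layer_2_size': 256, 'lr': 0.003..., 'batch_size': 64}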

Selecting a scheduler#

In this example, we use an Asynchronous Hyperband (ASHA) scheduler. This scheduler decides at each iteration which trials are likely to perform badly, and stops them. This way we don't waste any resources on bad hyperparameter configurations.

# The maximum training epochs
num_epochs = 5

# Number of samples from parameter space
num_samples = 10

If you have more resources available, you can modify the above parameters accordingly, e.g. more epochs and more parameter samples.

scheduler = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)
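
For reference, here is the same scheduler with its arguments spelled out as a commented sketch; time_attr is shown with the scheduler's default value.

scheduler = ASHAScheduler(
    time_attr="training_iteration",  # unit in which max_t and grace_period are measured
    max_t=num_epochs,                # no trial runs for more than this many iterations
    grace_period=1,                  # every trial gets at least one iteration before it can be stopped
    reduction_factor=2,              # at each rung, only roughly the best half of trials continue
)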

Training with GPUs#

We can specify the number of resources, including GPUs, that Tune should request for each trial.

train_fn_with_resources = tune.with_resources(train_func, resources={"CPU": 1, "GPU": 1})
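
If the model comfortably fits into GPU memory, you can also request a fraction of a GPU per trial so that several trials share one device. This is a sketch; the value 0.5 and the variable name train_fn_with_half_gpu are illustrative.

# Illustrative only: with 0.5 GPUs per trial, two trials can share one GPU if memory allows.
train_fn_with_half_gpu = tune.with_resources(train_func, resources={"CPU": 1, "GPU": 0.5})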

Putting it together#

Lastly, we need to create a Tuner() object and start Ray Tune with tuner.fit(). The full code looks like this:

def tune_mnist_asha(num_samples=10):
    scheduler = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)

    tuner = tune.Tuner(
        train_fn_with_resources,
        param_space=search_space,
        tune_config=tune.TuneConfig(
            metric="ptl/val_accuracy",
            mode="max",
            num_samples=num_samples,
            scheduler=scheduler,
        ),
        run_config=tune.RunConfig(
            checkpoint_config=tune.CheckpointConfig(
                num_to_keep=2,
                checkpoint_score_attribute="ptl/val_accuracy",
                checkpoint_score_order="max",
            ),
        ),
    )
    return tuner.fit()


results = tune_mnist_asha(num_samples=num_samples)
results.get_best_result(metric="ptl/val_accuracy", mode="max")
Result(
  metrics={'ptl/train_loss': 0.001267582061700523, 'ptl/train_accuracy': 1.0, 'ptl/val_loss': 0.1036270260810852, 'ptl/val_accuracy': 0.9721123576164246},
  path='/home/ray/ray_results/train_func_2025-09-23_13-37-55/train_func_2f534_00006_6_batch_size=64,layer_1_size=64,layer_2_size=64,lr=0.0020_2025-09-23_13-37-55',
  filesystem='local',
  checkpoint=Checkpoint(filesystem=local, path=/home/ray/ray_results/train_func_2025-09-23_13-37-55/train_func_2f534_00006_6_batch_size=64,layer_1_size=64,layer_2_size=64,lr=0.0020_2025-09-23_13-37-55/checkpoint_000004)
)

In the example above, Tune ran 10 trials with different hyperparameter configurations.

As you can see from the training_iteration values, trials with a high loss (and low accuracy) were terminated early. The best performing trial, shown in the Result above, used batch_size=64, layer_1_size=64, layer_2_size=64, and lr≈0.002.
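
After tuning, you can inspect all trials and restore the best model from its checkpoint. The sketch below assumes the default checkpoint filename ("checkpoint") written by TuneReportCheckpointCallback; since MNISTClassifier does not call save_hyperparameters(), the config is passed explicitly to load_from_checkpoint().

# All trials as a pandas DataFrame; columns include training_iteration and the reported ptl/* metrics.
df = results.get_dataframe()
print(df[["training_iteration", "ptl/val_accuracy"]].sort_values("ptl/val_accuracy", ascending=False))

best_result = results.get_best_result(metric="ptl/val_accuracy", mode="max")
print(best_result.config)  # the winning hyperparameters

# Restore the best model from the best trial's checkpoint directory.
with best_result.checkpoint.as_directory() as checkpoint_dir:
    best_model = MNISTClassifier.load_from_checkpoint(
        os.path.join(checkpoint_dir, "checkpoint"),  # default filename used by the callback
        config=best_result.config,
    )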

More PyTorch Lightning Examples#