Using PyTorch Lightning with Tune#
PyTorch Lightning is a framework that brings structure into training PyTorch models. It aims to avoid boilerplate code, so you don't have to write the same training loops all over again when building a new model.
The main abstraction of PyTorch Lightning is the LightningModule class, which should be extended by your application. There is a great post on how to transfer your models from vanilla PyTorch to Lightning.
The class structure of PyTorch Lightning makes it very easy to define and tune model parameters. This tutorial will show you how to use Tune with PyTorch Lightning. Notably, the LightningModule does not have to be altered at all for this: you can use it plug and play with your existing models, provided their parameters are configurable!
Note
To run this example, you will need to install the following:
$ pip install -q "ray[tune]" torch torchvision pytorch_lightning
PyTorch Lightning classifier for MNIST#
Let's first start with a basic PyTorch Lightning implementation of an MNIST classifier. At this point, the classifier does not include any tuning code.
First, we run some imports:
import os
import torch
import tempfile
import pytorch_lightning as pl
import torch.nn.functional as F
from filelock import FileLock
from torchmetrics import Accuracy
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
Our example builds on the MNIST example from the blog post we mentioned earlier. We adapted the original model and dataset definitions into MNISTClassifier and MNISTDataModule.
class MNISTClassifier(pl.LightningModule):
    def __init__(self, config):
        super(MNISTClassifier, self).__init__()
        self.accuracy = Accuracy(task="multiclass", num_classes=10, top_k=1)
        self.layer_1_size = config["layer_1_size"]
        self.layer_2_size = config["layer_2_size"]
        self.lr = config["lr"]

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, self.layer_1_size)
        self.layer_2 = torch.nn.Linear(self.layer_1_size, self.layer_2_size)
        self.layer_3 = torch.nn.Linear(self.layer_2_size, 10)
        self.eval_loss = []
        self.eval_accuracy = []

    def cross_entropy_loss(self, logits, labels):
        return F.nll_loss(logits, labels)

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)

        x = self.layer_1(x)
        x = torch.relu(x)

        x = self.layer_2(x)
        x = torch.relu(x)

        x = self.layer_3(x)
        x = torch.log_softmax(x, dim=1)

        return x

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = self.accuracy(logits, y)

        self.log("ptl/train_loss", loss)
        self.log("ptl/train_accuracy", accuracy)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = self.accuracy(logits, y)
        self.eval_loss.append(loss)
        self.eval_accuracy.append(accuracy)
        return {"val_loss": loss, "val_accuracy": accuracy}

    def on_validation_epoch_end(self):
        avg_loss = torch.stack(self.eval_loss).mean()
        avg_acc = torch.stack(self.eval_accuracy).mean()
        self.log("ptl/val_loss", avg_loss, sync_dist=True)
        self.log("ptl/val_accuracy", avg_acc, sync_dist=True)
        self.eval_loss.clear()
        self.eval_accuracy.clear()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=128):
        super().__init__()
        self.data_dir = tempfile.mkdtemp()
        self.batch_size = batch_size
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def setup(self, stage=None):
        with FileLock(f"{self.data_dir}.lock"):
            mnist = MNIST(
                self.data_dir, train=True, download=True, transform=self.transform
            )
            self.mnist_train, self.mnist_val = random_split(mnist, [55000, 5000])

            self.mnist_test = MNIST(
                self.data_dir, train=False, download=True, transform=self.transform
            )

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=4)
default_config = {
    "layer_1_size": 128,
    "layer_2_size": 256,
    "lr": 1e-3,
}
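As a quick sanity check (not part of the original tutorial), you can build the classifier from default_config and run a forward pass on a fake MNIST-shaped batch; the batch size of 8 below is just for illustration.
# Minimal sanity check (illustrative only): the model built from the default
# config should map a (batch, 1, 28, 28) input to per-class log-probabilities.
model = MNISTClassifier(default_config)
dummy_batch = torch.randn(8, 1, 28, 28)  # (batch, channels, height, width)
log_probs = model(dummy_batch)
assert log_probs.shape == (8, 10)  # one log-probability per MNIST digit class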
Define a training function that creates the model, the DataModule, and the PyTorch Lightning Trainer.
from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback
def train_func(config):
    dm = MNISTDataModule(batch_size=config["batch_size"])
    model = MNISTClassifier(config)
    trainer = pl.Trainer(
        devices="auto",
        accelerator="auto",
        callbacks=[TuneReportCheckpointCallback()],
        enable_progress_bar=False,
    )
    trainer.fit(model, datamodule=dm)
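The TuneReportCheckpointCallback used above reports the metrics logged via self.log (such as ptl/val_loss and ptl/val_accuracy) back to Tune and saves a checkpoint at the end of each validation epoch. If you prefer shorter metric names on the Tune side, the callback can, depending on your Ray version, take an explicit mapping; a sketch (the metrics, filename, and on arguments are assumptions about your installed version):
# Sketch of an explicit metric mapping (illustrative): the keys on the left are
# the names Tune sees, the values are the names logged by the LightningModule.
callback = TuneReportCheckpointCallback(
    metrics={"loss": "ptl/val_loss", "mean_accuracy": "ptl/val_accuracy"},
    filename="checkpoint",
    on="validation_end",
)
If you use such a mapping, remember to reference the mapped names (for example mean_accuracy) instead of ptl/val_accuracy in the Tuner configuration further below.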
Tuning the model parameters#
The parameters above should already give you a good accuracy of over 90%. However, we might improve on this simply by changing some of the hyperparameters. For instance, maybe we get an even higher accuracy if we use a smaller learning rate and larger middle layer sizes.
Instead of manually looping through all the parameter combinations, let's use Tune to systematically try out combinations and find the best performing set.
First, we need some additional imports:
from ray import tune
from ray.tune.schedulers import ASHAScheduler
Configuring the search space#
Now we configure the parameter search space. We would like to choose between different layer dimensions, learning rates, and batch sizes. The learning rate should be sampled between 0.0001 and 0.1. The tune.loguniform() function is syntactic sugar that makes sampling across these different orders of magnitude easier, so that small values are sampled as well. The same goes for tune.choice(), which samples from all of the provided options.
search_space = {
    "layer_1_size": tune.choice([32, 64, 128]),
    "layer_2_size": tune.choice([64, 128, 256]),
    "lr": tune.loguniform(1e-4, 1e-1),
    "batch_size": tune.choice([32, 64]),
}
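Each entry in this dictionary is a Tune search space object rather than a concrete value; Tune draws a concrete configuration from it for every trial. To get a feel for what will be passed into train_func, you can draw a few random configurations yourself. This sketch assumes the .sample() method exposed by Tune's search space primitives and is purely for inspection:
# Draw a handful of example configurations from the search space (illustrative).
# Tune performs this sampling internally for every trial it schedules.
for _ in range(3):
    sampled = {name: domain.sample() for name, domain in search_space.items()}
    print(sampled)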
Selecting a scheduler#
In this example, we use an asynchronous Hyperband scheduler. This scheduler decides at each iteration which trials are likely performing badly and stops them. This way we don't waste any resources on bad hyperparameter configurations.
# The maximum training epochs
num_epochs = 5
# Number of samples from parameter space
num_samples = 10
If you have more resources available, you can modify the above parameters accordingly, e.g. more epochs, more parameter samples.
scheduler = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)
Training with GPUs#
We can specify the number of resources, including GPUs, that Tune should request for each trial.
train_fn_with_resources = tune.with_resources(train_func, resources={"CPU": 1, "GPU": 1})
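If you are on a machine without GPUs, you can request CPUs only. A sketch mirroring the call above (the name and resource amount below are just an example):
# CPU-only variant (illustrative): each trial reserves one CPU core and no GPU.
train_fn_cpu_only = tune.with_resources(train_func, resources={"CPU": 1})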
Putting it together#
Lastly, we need to create a Tuner() object and start Ray Tune with tuner.fit(). The full code looks like this:
def tune_mnist_asha(num_samples=10):
    scheduler = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)

    tuner = tune.Tuner(
        train_fn_with_resources,
        param_space=search_space,
        tune_config=tune.TuneConfig(
            metric="ptl/val_accuracy",
            mode="max",
            num_samples=num_samples,
            scheduler=scheduler,
        ),
        run_config=tune.RunConfig(
            checkpoint_config=tune.CheckpointConfig(
                num_to_keep=2,
                checkpoint_score_attribute="ptl/val_accuracy",
                checkpoint_score_order="max",
            ),
        ),
    )
    return tuner.fit()
results = tune_mnist_asha(num_samples=num_samples)
results.get_best_result(metric="ptl/val_accuracy", mode="max")
Result(
  metrics={'ptl/train_loss': 0.001267582061700523, 'ptl/train_accuracy': 1.0, 'ptl/val_loss': 0.1036270260810852, 'ptl/val_accuracy': 0.9721123576164246},
  path='/home/ray/ray_results/train_func_2025-09-23_13-37-55/train_func_2f534_00006_6_batch_size=64,layer_1_size=64,layer_2_size=64,lr=0.0020_2025-09-23_13-37-55',
  filesystem='local',
  checkpoint=Checkpoint(filesystem=local, path=/home/ray/ray_results/train_func_2025-09-23_13-37-55/train_func_2f534_00006_6_batch_size=64,layer_1_size=64,layer_2_size=64,lr=0.0020_2025-09-23_13-37-55/checkpoint_000004)
)
In the example above, Tune ran 10 trials with different hyperparameter configurations.
Trials with a high loss (and low accuracy) were terminated early by the scheduler. As the result above shows, the best performing trial used batch_size=64, layer_1_size=64, layer_2_size=64, and lr=0.0020.
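To reuse the best model afterwards, you can load it back from the checkpoint attached to the best result. A sketch, assuming the default checkpoint filename ("checkpoint") written by TuneReportCheckpointCallback; adjust the filename if your version differs:
# Illustrative only: restore the best model from the best trial's checkpoint.
best_result = results.get_best_result(metric="ptl/val_accuracy", mode="max")

with best_result.checkpoint.as_directory() as checkpoint_dir:
    # `config=...` forwards the trial's hyperparameters to MNISTClassifier.__init__,
    # since the module does not call save_hyperparameters().
    best_model = MNISTClassifier.load_from_checkpoint(
        os.path.join(checkpoint_dir, "checkpoint"),
        config=best_result.config,
    )
best_model.eval()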
More PyTorch Lightning Examples#
To run distributed PyTorch Lightning training with Ray Train, see the Getting Started guide.
MLflow PyTorch Lightning Example: an example of using MLflow and PyTorch Lightning with Ray Tune.