使用 PyTorch Lightning 进行分布式训练入门#
本教程将引导您完成将现有 PyTorch Lightning 脚本转换为使用 Ray Train 的过程。
了解如何
配置 Lightning Trainer,使其能够在 Ray 上以分布式方式并在正确的 CPU 或 GPU 设备上运行。
配置 训练函数 来报告指标和保存检查点。
为训练作业配置 缩放 以及 CPU 或 GPU 资源需求。
使用
TorchTrainer启动分布式训练作业。
快速入门#
供参考,最终代码如下
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig


def train_func():
    # Your PyTorch Lightning training code here.
    # A function body needs at least one statement; `...` keeps this
    # placeholder syntactically valid until real training code is added.
    ...


# Run on 2 distributed workers, each using a GPU.
scaling_config = ScalingConfig(num_workers=2, use_gpu=True)
trainer = TorchTrainer(train_func, scaling_config=scaling_config)
result = trainer.fit()
train_func是在每个分布式训练工作节点上执行的 Python 代码。ScalingConfig定义了分布式训练工作节点的数量以及是否使用 GPU。TorchTrainer启动分布式训练作业。
比较使用 Ray Train 和不使用 Ray Train 的 PyTorch Lightning 训练脚本。
import os
import tempfile
import torch
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose
import lightning.pytorch as pl
import ray.train.lightning
from ray.train.torch import TorchTrainer
# Model, Loss, Optimizer
class ImageClassifier(pl.LightningModule):
    """ResNet-18 classifier adapted to 10-class, single-channel input
    (FashionMNIST)."""

    def __init__(self):
        # Zero-argument super() is the modern Python 3 idiom.
        super().__init__()
        self.model = resnet18(num_classes=10)
        # Replace the stem convolution so the network accepts 1-channel
        # (grayscale) images instead of the default 3-channel RGB.
        self.model.conv1 = torch.nn.Conv2d(
            1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
        )
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        outputs = self.forward(x)
        loss = self.criterion(outputs, y)
        # Log per-step loss so it shows up in the progress bar.
        self.log("loss", loss, on_step=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.001)
def train_func():
    """Per-worker training loop executed by Ray Train on each distributed worker."""
    # Data
    transform = Compose([ToTensor(), Normalize((0.28604,), (0.32025,))])
    data_dir = os.path.join(tempfile.gettempdir(), "data")
    train_data = FashionMNIST(root=data_dir, train=True, download=True, transform=transform)
    train_dataloader = DataLoader(train_data, batch_size=128, shuffle=True)

    # Training
    model = ImageClassifier()
    # [1] Configure PyTorch Lightning Trainer.
    trainer = pl.Trainer(
        max_epochs=10,
        devices="auto",
        accelerator="auto",
        strategy=ray.train.lightning.RayDDPStrategy(),
        plugins=[ray.train.lightning.RayLightningEnvironment()],
        callbacks=[ray.train.lightning.RayTrainReportCallback()],
        # [1a] Optionally, disable the default checkpointing behavior
        # in favor of the `RayTrainReportCallback` above.
        enable_checkpointing=False,
    )
    trainer = ray.train.lightning.prepare_trainer(trainer)
    trainer.fit(model, train_dataloaders=train_dataloader)
# [2] Configure scaling and resource requirements.
scaling_config = ray.train.ScalingConfig(num_workers=2, use_gpu=True)

# [3] Launch distributed training job.
trainer = TorchTrainer(
    train_func,
    scaling_config=scaling_config,
    # [3a] If running in a multi-node cluster, this is where you
    # should configure the run's persistent storage that is accessible
    # across all worker nodes.
    # run_config=ray.train.RunConfig(storage_path="s3://..."),
)
result: ray.train.Result = trainer.fit()

# [4] Load the trained model.
with result.checkpoint.as_directory() as checkpoint_dir:
    model = ImageClassifier.load_from_checkpoint(
        os.path.join(
            checkpoint_dir,
            ray.train.lightning.RayTrainReportCallback.CHECKPOINT_NAME,
        ),
    )
import torch
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose
from torch.utils.data import DataLoader
import lightning.pytorch as pl
# Model, Loss, Optimizer
class ImageClassifier(pl.LightningModule):
    """LightningModule wrapping a ResNet-18 adapted to grayscale input."""

    def __init__(self):
        super(ImageClassifier, self).__init__()
        net = resnet18(num_classes=10)
        # Swap the first convolution for a single-channel variant.
        net.conv1 = torch.nn.Conv2d(
            1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
        )
        self.model = net
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        logits = self.forward(images)
        loss = self.criterion(logits, labels)
        # Surface the running loss in the progress bar.
        self.log("loss", loss, on_step=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.001)
# Data
transform = Compose([ToTensor(), Normalize((0.28604,), (0.32025,))])
train_data = FashionMNIST(root='./data', train=True, download=True, transform=transform)
train_dataloader = DataLoader(train_data, batch_size=128, shuffle=True)

# Training
model = ImageClassifier()
# Single-process Lightning Trainer (no Ray Train integration).
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, train_dataloaders=train_dataloader)
设置训练函数#
首先,更新您的训练代码以支持分布式训练。首先将您的代码封装在一个 训练函数 中
def train_func():
    """Entry point executed on each distributed training worker."""
    # Your model training code here.
    ...
每个分布式训练工作节点都会执行此函数。
您也可以通过 Trainer 的 train_loop_config 将 train_func 的输入参数指定为字典。例如
def train_func(config):
    # Read hyperparameters passed in via `train_loop_config`.
    lr = config["lr"]
    num_epochs = config["num_epochs"]


config = {"lr": 1e-4, "num_epochs": 10}
trainer = ray.train.torch.TorchTrainer(train_func, train_loop_config=config, ...)
警告
避免通过 train_loop_config 传递大型数据对象,以减少序列化和反序列化开销。相反,更推荐在 train_func 中直接初始化大型对象(例如数据集、模型)。
def load_dataset():
# Return a large in-memory dataset
...
def load_model():
# Return a large in-memory model instance
...
-config = {"data": load_dataset(), "model": load_model()}
def train_func(config):
- data = config["data"]
- model = config["model"]
+ data = load_dataset()
+ model = load_model()
...
trainer = ray.train.torch.TorchTrainer(train_func, train_loop_config=config, ...)
Ray Train 在每个 worker 上设置您的分布式进程组。您只需要对 Lightning Trainer 定义进行少量更改。
import lightning.pytorch as pl
-from pl.strategies import DDPStrategy
-from pl.plugins.environments import LightningEnvironment
+import ray.train.lightning
def train_func():
...
model = MyLightningModule(...)
datamodule = MyLightningDataModule(...)
trainer = pl.Trainer(
- devices=[0, 1, 2, 3],
- strategy=DDPStrategy(),
- plugins=[LightningEnvironment()],
+ devices="auto",
+ accelerator="auto",
+ strategy=ray.train.lightning.RayDDPStrategy(),
+ plugins=[ray.train.lightning.RayLightningEnvironment()]
)
+ trainer = ray.train.lightning.prepare_trainer(trainer)
trainer.fit(model, datamodule=datamodule)
以下各节将讨论每一项更改。
配置分布式策略#
Ray Train 为 Lightning 提供了几种子类化的分布式策略。这些策略保留了与其基类相同的参数列表。在内部,它们配置了根设备和分布式采样器参数。
import lightning.pytorch as pl
-from pl.strategies import DDPStrategy
+import ray.train.lightning
def train_func():
...
trainer = pl.Trainer(
...
- strategy=DDPStrategy(),
+ strategy=ray.train.lightning.RayDDPStrategy(),
...
)
...
配置 Ray 集群环境插件#
Ray Train 还提供了一个 RayLightningEnvironment 类作为 Ray 集群的规范。这个实用类配置了 worker 的本地、全局和节点排名以及世界大小。
import lightning.pytorch as pl
-from pl.plugins.environments import LightningEnvironment
+import ray.train.lightning
def train_func():
...
trainer = pl.Trainer(
...
- plugins=[LightningEnvironment()],
+ plugins=[ray.train.lightning.RayLightningEnvironment()],
...
)
...
配置并行设备#
此外,Ray TorchTrainer 已经为您配置了正确的 CUDA_VISIBLE_DEVICES。应该始终通过设置 devices="auto" 和 accelerator="auto" 来使用所有可用的 GPU。
import lightning.pytorch as pl
def train_func():
...
trainer = pl.Trainer(
...
- devices=[0,1,2,3],
+ devices="auto",
+ accelerator="auto",
...
)
...
报告检查点和指标#
要持久化您的检查点并监控训练进度,请将 ray.train.lightning.RayTrainReportCallback 实用回调添加到您的 Trainer 中。
import lightning.pytorch as pl
from ray.train.lightning import RayTrainReportCallback
def train_func():
...
trainer = pl.Trainer(
...
- callbacks=[...],
+ callbacks=[..., RayTrainReportCallback()],
)
...
将指标和检查点报告给 Ray Train 使您能够支持 容错训练 和 超参数优化。请注意,ray.train.lightning.RayTrainReportCallback 类只提供了一个简单的实现,并且可以 进一步自定义。
准备您的 Lightning Trainer#
最后,将您的 Lightning Trainer 传递给 prepare_trainer() 来验证您的配置。
import lightning.pytorch as pl
import ray.train.lightning
def train_func():
...
trainer = pl.Trainer(...)
+ trainer = ray.train.lightning.prepare_trainer(trainer)
...
配置缩放和 GPU#
在训练函数之外,创建一个 ScalingConfig 对象来配置
num_workers — 分布式训练工作节点的数量。
use_gpu — 每个工作节点是否应使用 GPU(或 CPU)。
from ray.train import ScalingConfig
scaling_config = ScalingConfig(num_workers=2, use_gpu=True)
有关更多详细信息,请参阅 配置缩放和 GPU。
配置持久存储#
创建一个 RunConfig 对象来指定结果(包括检查点和工件)将要保存的路径。
from ray.train import RunConfig
# Local path (/some/local/path/unique_run_name)
run_config = RunConfig(storage_path="/some/local/path", name="unique_run_name")
# Shared cloud storage URI (s3://bucket/unique_run_name)
run_config = RunConfig(storage_path="s3://bucket", name="unique_run_name")
# Shared NFS path (/mnt/nfs/unique_run_name)
run_config = RunConfig(storage_path="/mnt/nfs", name="unique_run_name")
警告
指定一个*共享存储位置*(如云存储或 NFS)对于单节点集群是*可选的*,但对于多节点集群是*必需的*。使用本地路径将在多节点集群的检查点过程中*引发错误*。
有关更多详细信息,请参阅 配置持久存储。
启动训练作业#
将所有这些结合起来,您现在可以使用 TorchTrainer 启动分布式训练作业。
from ray.train.torch import TorchTrainer
trainer = TorchTrainer(
train_func, scaling_config=scaling_config, run_config=run_config
)
result = trainer.fit()
访问训练结果#
训练完成后,将返回一个 Result 对象,其中包含有关训练运行的信息,包括训练期间报告的指标和检查点。
result.metrics # The metrics reported during training.
result.checkpoint # The latest checkpoint reported during training.
result.path # The path where logs are stored.
result.error # The exception that was raised, if training failed.
有关更多用法示例,请参阅 检查训练结果。
下一步#
将您的 PyTorch Lightning 训练脚本转换为使用 Ray Train 后,您可以继续参阅以下内容以了解更多信息。
版本兼容性#
Ray Train 已与 pytorch_lightning 版本 1.6.5 和 2.1.2 进行了测试。为获得完整兼容性,请使用 pytorch_lightning>=1.6.5。不禁止更早的版本,但可能会导致意外问题。如果您遇到任何兼容性问题,请考虑升级您的 PyTorch Lightning 版本或 提交 issue。
注意
如果您使用的是 Lightning 2.x,请使用导入路径 lightning.pytorch.xxx 而不是 pytorch_lightning.xxx。
LightningTrainer 迁移指南#
Ray 2.4 引入了 LightningTrainer,并公开了 LightningConfigBuilder 来定义 pl.LightningModule 和 pl.Trainer 的配置。
然后,它在黑盒中实例化模型和 Trainer 对象并运行预定义的训练函数。
此版本的 LightningTrainer API 具有限制性,限制了您管理训练功能的能力。
Ray 2.7 引入了新统一的 TorchTrainer API,它提供了增强的透明度、灵活性和简单性。此 API 更符合标准的 PyTorch Lightning 脚本,确保用户可以更好地控制其原生 Lightning 代码。
from ray.train.lightning import LightningConfigBuilder, LightningTrainer

# NOTE(review): ScalingConfig, RunConfig, and CheckpointConfig are used below
# but not imported in this snippet — presumably `from ray.train import ...`;
# confirm against the full example.

config_builder = LightningConfigBuilder()
# [1] Collect model configs
config_builder.module(cls=MyLightningModule, lr=1e-3, feature_dim=128)
# [2] Collect checkpointing configs
config_builder.checkpointing(monitor="val_accuracy", mode="max", save_top_k=3)
# [3] Collect pl.Trainer configs
config_builder.trainer(
    max_epochs=10,
    accelerator="gpu",
    log_every_n_steps=100,
)
# [4] Build datasets on the head node
datamodule = MyLightningDataModule(batch_size=32)
config_builder.fit_params(datamodule=datamodule)
# [5] Execute the internal training function in a black box
ray_trainer = LightningTrainer(
    lightning_config=config_builder.build(),
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
    run_config=RunConfig(
        checkpoint_config=CheckpointConfig(
            num_to_keep=3,
            checkpoint_score_attribute="val_accuracy",
            checkpoint_score_order="max",
        ),
    )
)
result = ray_trainer.fit()
# [6] Load the trained model from an opaque Lightning-specific checkpoint.
lightning_checkpoint = result.checkpoint
model = lightning_checkpoint.get_model(MyLightningModule)
import os
import lightning.pytorch as pl
import ray.train
from ray.train.torch import TorchTrainer
from ray.train.lightning import (
RayDDPStrategy,
RayLightningEnvironment,
RayTrainReportCallback,
prepare_trainer
)
def train_func():
    """Per-worker training function for the unified TorchTrainer API."""
    # [1] Create a Lightning model
    model = MyLightningModule(lr=1e-3, feature_dim=128)

    # [2] Report Checkpoint with callback
    ckpt_report_callback = RayTrainReportCallback()

    # [3] Create a Lightning Trainer
    trainer = pl.Trainer(
        max_epochs=10,
        log_every_n_steps=100,
        # New configurations below
        devices="auto",
        accelerator="auto",
        strategy=RayDDPStrategy(),
        plugins=[RayLightningEnvironment()],
        callbacks=[ckpt_report_callback],
    )

    # Validate your Lightning trainer configuration
    trainer = prepare_trainer(trainer)

    # [4] Build your datasets on each worker
    datamodule = MyLightningDataModule(batch_size=32)
    trainer.fit(model, datamodule=datamodule)
# [5] Explicitly define and run the training function
ray_trainer = TorchTrainer(
    train_func,
    scaling_config=ray.train.ScalingConfig(num_workers=4, use_gpu=True),
    run_config=ray.train.RunConfig(
        checkpoint_config=ray.train.CheckpointConfig(
            num_to_keep=3,
            checkpoint_score_attribute="val_accuracy",
            checkpoint_score_order="max",
        ),
    )
)
result = ray_trainer.fit()

# [6] Load the trained model from a simplified checkpoint interface.
checkpoint: ray.train.Checkpoint = result.checkpoint
with checkpoint.as_directory() as checkpoint_dir:
    print("Checkpoint contents:", os.listdir(checkpoint_dir))
    checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.ckpt")
    model = MyLightningModule.load_from_checkpoint(checkpoint_path)