This tutorial walks you through converting an existing PyTorch Lightning script to use Ray Train.

Learn how to:

Configure the Lightning Trainer so that it runs distributed with Ray and on the correct CPU or GPU device.
Configure your training function to report metrics and save checkpoints.
Configure the scale of the training job and its CPU or GPU resource requirements.
Launch a distributed training job with a TorchTrainer.

Quickstart#
For reference, the final code looks like this:
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig

def train_func():
    # Your PyTorch Lightning training code here.

scaling_config = ScalingConfig(num_workers=2, use_gpu=True)
trainer = TorchTrainer(train_func, scaling_config=scaling_config)
result = trainer.fit()
train_func is the Python code that executes on each distributed training worker.

ScalingConfig defines the number of distributed training workers and whether to use GPUs.

TorchTrainer launches the distributed training job.

Compare a PyTorch Lightning training script with and without Ray Train.
PyTorch Lightning
import torch
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose
from torch.utils.data import DataLoader
import lightning.pytorch as pl

# Model, Loss, Optimizer
class ImageClassifier(pl.LightningModule):
    def __init__(self):
        super(ImageClassifier, self).__init__()
        self.model = resnet18(num_classes=10)
        self.model.conv1 = torch.nn.Conv2d(
            1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
        )
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        outputs = self.forward(x)
        loss = self.criterion(outputs, y)
        self.log("loss", loss, on_step=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.001)

# Data
transform = Compose([ToTensor(), Normalize((0.28604,), (0.32025,))])
train_data = FashionMNIST(root='./data', train=True, download=True, transform=transform)
train_dataloader = DataLoader(train_data, batch_size=128, shuffle=True)

# Training
model = ImageClassifier()
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, train_dataloaders=train_dataloader)
PyTorch Lightning + Ray Train

import os
import tempfile

import torch
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose
import lightning.pytorch as pl

import ray.train.lightning
from ray.train.torch import TorchTrainer

# Model, Loss, Optimizer
class ImageClassifier(pl.LightningModule):
    def __init__(self):
        super(ImageClassifier, self).__init__()
        self.model = resnet18(num_classes=10)
        self.model.conv1 = torch.nn.Conv2d(
            1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
        )
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        outputs = self.forward(x)
        loss = self.criterion(outputs, y)
        self.log("loss", loss, on_step=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.001)

def train_func():
    # Data
    transform = Compose([ToTensor(), Normalize((0.28604,), (0.32025,))])
    data_dir = os.path.join(tempfile.gettempdir(), "data")
    train_data = FashionMNIST(root=data_dir, train=True, download=True, transform=transform)
    train_dataloader = DataLoader(train_data, batch_size=128, shuffle=True)

    # Training
    model = ImageClassifier()
    # [1] Configure PyTorch Lightning Trainer.
    trainer = pl.Trainer(
        max_epochs=10,
        devices="auto",
        accelerator="auto",
        strategy=ray.train.lightning.RayDDPStrategy(),
        plugins=[ray.train.lightning.RayLightningEnvironment()],
        callbacks=[ray.train.lightning.RayTrainReportCallback()],
        # [1a] Optionally, disable the default checkpointing behavior
        # in favor of the `RayTrainReportCallback` above.
        enable_checkpointing=False,
    )
    trainer = ray.train.lightning.prepare_trainer(trainer)
    trainer.fit(model, train_dataloaders=train_dataloader)

# [2] Configure scaling and resource requirements.
scaling_config = ray.train.ScalingConfig(num_workers=2, use_gpu=True)

# [3] Launch distributed training job.
trainer = TorchTrainer(
    train_func,
    scaling_config=scaling_config,
    # [3a] If running in a multi-node cluster, this is where you
    # should configure the run's persistent storage that is accessible
    # across all worker nodes.
    # run_config=ray.train.RunConfig(storage_path="s3://..."),
)
result: ray.train.Result = trainer.fit()

# [4] Load the trained model.
with result.checkpoint.as_directory() as checkpoint_dir:
    model = ImageClassifier.load_from_checkpoint(
        os.path.join(
            checkpoint_dir,
            ray.train.lightning.RayTrainReportCallback.CHECKPOINT_NAME,
        ),
    )
First, update your training code to support distributed training. Begin by wrapping your code in a training function; each distributed training worker executes this function:
def train_func():
    # Your model training code here.
    ...
You can also pass input arguments to train_func as a dictionary via the Trainer's train_loop_config. For example:
def train_func(config):
    lr = config["lr"]
    num_epochs = config["num_epochs"]

config = {"lr": 1e-4, "num_epochs": 10}
trainer = ray.train.torch.TorchTrainer(train_func, train_loop_config=config, ...)
Warning

Avoid passing large data objects through train_loop_config to reduce serialization and deserialization overhead. Instead, initialize large objects (such as datasets and models) directly in train_func:
def load_dataset():
    # Return a large in-memory dataset
    ...

def load_model():
    # Return a large in-memory model instance
    ...

-config = {"data": load_dataset(), "model": load_model()}

def train_func(config):
-    data = config["data"]
-    model = config["model"]
+    data = load_dataset()
+    model = load_model()
    ...

trainer = ray.train.torch.TorchTrainer(train_func, train_loop_config=config, ...)
Ray Train sets up your distributed process group on each worker. You only need to make a few changes to your Lightning Trainer definition.
 import lightning.pytorch as pl
-from lightning.pytorch.strategies import DDPStrategy
-from lightning.pytorch.plugins.environments import LightningEnvironment
+import ray.train.lightning

 def train_func():
     ...
     model = MyLightningModule(...)
     datamodule = MyLightningDataModule(...)

     trainer = pl.Trainer(
-        devices=[0, 1, 2, 3],
-        strategy=DDPStrategy(),
-        plugins=[LightningEnvironment()],
+        devices="auto",
+        accelerator="auto",
+        strategy=ray.train.lightning.RayDDPStrategy(),
+        plugins=[ray.train.lightning.RayLightningEnvironment()],
     )
+    trainer = ray.train.lightning.prepare_trainer(trainer)
     trainer.fit(model, datamodule=datamodule)

The following sections discuss each change.
Configure distributed strategy#
Ray Train offers several subclassed distributed strategies for Lightning. These strategies retain the same argument list as their base strategy classes. Internally, they configure the root device and the distributed sampler arguments.

RayDDPStrategy
 import lightning.pytorch as pl
-from lightning.pytorch.strategies import DDPStrategy
+import ray.train.lightning

 def train_func():
     ...
     trainer = pl.Trainer(
         ...
-        strategy=DDPStrategy(),
+        strategy=ray.train.lightning.RayDDPStrategy(),
         ...
     )
     ...
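Because RayDDPStrategy keeps the argument list of Lightning's DDPStrategy, the usual strategy options can be passed through unchanged. A minimal sketch, assuming you want DDP's find_unused_parameters behavior (an illustrative choice, not something Ray Train requires):

import lightning.pytorch as pl
import ray.train.lightning

def train_func():
    ...
    trainer = pl.Trainer(
        ...,
        # Any argument accepted by Lightning's DDPStrategy can be passed here.
        strategy=ray.train.lightning.RayDDPStrategy(find_unused_parameters=True),
    )
    ...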
Configure Ray cluster environment plugin#

Ray Train also provides a RayLightningEnvironment class as a specification for Ray clusters. This utility class configures the worker's local, global, and node ranks, as well as the world size.
 import lightning.pytorch as pl
-from lightning.pytorch.plugins.environments import LightningEnvironment
+import ray.train.lightning

 def train_func():
     ...
     trainer = pl.Trainer(
         ...
-        plugins=[LightningEnvironment()],
+        plugins=[ray.train.lightning.RayLightningEnvironment()],
         ...
     )
     ...
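For reference, the same rank and world-size information that RayLightningEnvironment wires into Lightning can be inspected inside the training function through Ray Train's context API. A minimal sketch (the print calls are only for illustration):

import ray.train

def train_func():
    ctx = ray.train.get_context()
    # Rank and world-size information for this worker.
    print("world size:", ctx.get_world_size())
    print("global rank:", ctx.get_world_rank())
    print("local rank:", ctx.get_local_rank())
    print("node rank:", ctx.get_node_rank())
    ...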
Configure parallel devices#

In addition, Ray TorchTrainer has already configured the correct CUDA_VISIBLE_DEVICES for you. Always use all available GPUs by setting devices="auto" and accelerator="auto".
 import lightning.pytorch as pl

 def train_func():
     ...
     trainer = pl.Trainer(
         ...
-        devices=[0, 1, 2, 3],
+        devices="auto",
+        accelerator="auto",
         ...
     )
     ...
Report checkpoints and metrics#

To persist your checkpoints and monitor training progress, add a ray.train.lightning.RayTrainReportCallback utility callback to your Trainer.

Reporting metrics and checkpoints to Ray Train enables fault-tolerant training and hyperparameter optimization. Note that the ray.train.lightning.RayTrainReportCallback class only provides a simple implementation and can be further customized, as sketched after the code below.
 import lightning.pytorch as pl
 from ray.train.lightning import RayTrainReportCallback

 def train_func():
     ...
     trainer = pl.Trainer(
         ...
-        callbacks=[...],
+        callbacks=[..., RayTrainReportCallback()],
     )
     ...
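As one example of such customization, the sketch below defines a hypothetical callback (the class name and the per-epoch reporting choice are assumptions for illustration) that reports the latest logged metrics together with a checkpoint to Ray Train at the end of every training epoch:

import os
import tempfile

import lightning.pytorch as pl
import ray.train
from ray.train import Checkpoint

class MyEpochReportCallback(pl.Callback):
    """Hypothetical custom callback: report metrics and a checkpoint each epoch."""

    def on_train_epoch_end(self, trainer, pl_module):
        # Collect the metrics Lightning has logged so far.
        metrics = {k: v.item() for k, v in trainer.callback_metrics.items()}
        with tempfile.TemporaryDirectory() as tmpdir:
            # Save a Lightning checkpoint and report it to Ray Train.
            trainer.save_checkpoint(os.path.join(tmpdir, "checkpoint.ckpt"))
            ray.train.report(metrics, checkpoint=Checkpoint.from_directory(tmpdir))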
Prepare your Lightning Trainer#
Finally, pass your Lightning Trainer into prepare_trainer() to validate your configurations.
 import lightning.pytorch as pl
 import ray.train.lightning

 def train_func():
     ...
     trainer = pl.Trainer(...)
+    trainer = ray.train.lightning.prepare_trainer(trainer)
     ...
Configure scale and GPUs#

Outside of your training function, create a ScalingConfig object to configure:

num_workers - The number of distributed training worker processes.
use_gpu - Whether each worker should use a GPU (or CPU).
from ray.train import ScalingConfig
scaling_config = ScalingConfig(num_workers=2, use_gpu=True)
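ScalingConfig also accepts a resources_per_worker mapping if each worker needs more than the default single CPU (or single GPU when use_gpu=True). A minimal sketch, with resource amounts chosen purely for illustration:

from ray.train import ScalingConfig

# Each of the 4 workers reserves 2 CPUs and 1 GPU (illustrative values).
scaling_config = ScalingConfig(
    num_workers=4,
    use_gpu=True,
    resources_per_worker={"CPU": 2, "GPU": 1},
)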
Configure persistent storage#
Create a RunConfig object to specify the path where results (including checkpoints and artifacts) are saved.

Specifying a shared storage location (such as cloud storage or NFS) is optional for single-node clusters, but required for multi-node clusters. Using a local path raises an error during checkpointing for multi-node clusters.
from ray.train import RunConfig
# Local path (/some/local/path/unique_run_name)
run_config = RunConfig(storage_path="/some/local/path", name="unique_run_name")
# Shared cloud storage URI (s3://bucket/unique_run_name)
run_config = RunConfig(storage_path="s3://bucket", name="unique_run_name")
# Shared NFS path (/mnt/nfs/unique_run_name)
run_config = RunConfig(storage_path="/mnt/nfs", name="unique_run_name")
For more details, see Configure Persistent Storage.
Launch a training job#
Putting it all together, you can now launch a distributed training job with a TorchTrainer.
from ray.train.torch import TorchTrainer

trainer = TorchTrainer(
    train_func, scaling_config=scaling_config, run_config=run_config
)
result = trainer.fit()
Access training results#

After training completes, Ray Train returns a Result object that contains information about the training run, including the metrics and checkpoints reported during training. For more usage examples, see Inspecting Training Results.
result.metrics # The metrics reported during training.
result.checkpoint # The latest checkpoint reported during training.
result.path # The path where logs are stored.
result.error # The exception that was raised, if training failed.
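If the run keeps multiple checkpoints ranked by a reported metric (configured through CheckpointConfig, which is an assumption here rather than something the snippet above sets up), you can also inspect them via result.best_checkpoints. A minimal sketch:

# Assumes CheckpointConfig(checkpoint_score_attribute="loss", ...) was set on the run.
for checkpoint, metrics in result.best_checkpoints:
    print("checkpoint path:", checkpoint.path)
    print("reported metrics:", metrics)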
Next steps#

After you have converted your PyTorch Lightning training script to use Ray Train, see the User Guides to learn more about how to perform specific tasks.
Ray Train is tested with pytorch_lightning versions 1.6.5 and 2.1.2. For full compatibility, use pytorch_lightning>=1.6.5. Earlier versions aren't prohibited but may result in unexpected issues. If you run into any compatibility problems, consider upgrading your PyTorch Lightning version or filing an issue.
Note

If you are using Lightning 2.x, use the import path lightning.pytorch.xxx instead of pytorch_lightning.xxx.
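For example, the two import paths look like this (the pl alias is just a common convention):

# Lightning 2.x
import lightning.pytorch as pl

# Lightning 1.x
# import pytorch_lightning as pl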
LightningTrainer Migration Guide#
Ray 2.4 introduced LightningTrainer and exposed a LightningConfigBuilder for defining the configurations of pl.LightningModule and pl.Trainer.

It then instantiated the model and trainer objects and ran a pre-defined training function in a black box.

This version of the LightningTrainer API was constraining and limited your ability to manage the training functionality.

Ray 2.7 introduced the newly unified TorchTrainer API, which offers enhanced transparency, flexibility, and simplicity. This API aligns more closely with standard PyTorch Lightning scripts, ensuring that you have better control over your native Lightning code.
(Deprecated) LightningTrainer
from ray.train import CheckpointConfig, RunConfig, ScalingConfig
from ray.train.lightning import LightningConfigBuilder, LightningTrainer

config_builder = LightningConfigBuilder()

# [1] Collect model configs
config_builder.module(cls=MyLightningModule, lr=1e-3, feature_dim=128)

# [2] Collect checkpointing configs
config_builder.checkpointing(monitor="val_accuracy", mode="max", save_top_k=3)

# [3] Collect pl.Trainer configs
config_builder.trainer(
    max_epochs=10,
    accelerator="gpu",
    log_every_n_steps=100,
)

# [4] Build datasets on the head node
datamodule = MyLightningDataModule(batch_size=32)
config_builder.fit_params(datamodule=datamodule)

# [5] Execute the internal training function in a black box
ray_trainer = LightningTrainer(
    lightning_config=config_builder.build(),
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
    run_config=RunConfig(
        checkpoint_config=CheckpointConfig(
            num_to_keep=3,
            checkpoint_score_attribute="val_accuracy",
            checkpoint_score_order="max",
        ),
    ),
)
result = ray_trainer.fit()

# [6] Load the trained model from an opaque Lightning-specific checkpoint.
lightning_checkpoint = result.checkpoint
model = lightning_checkpoint.get_model(MyLightningModule)
(New API) TorchTrainer

import os
import lightning.pytorch as pl

import ray.train
from ray.train.torch import TorchTrainer
from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
    prepare_trainer,
)

def train_func():
    # [1] Create a Lightning model
    model = MyLightningModule(lr=1e-3, feature_dim=128)

    # [2] Report Checkpoint with callback
    ckpt_report_callback = RayTrainReportCallback()

    # [3] Create a Lightning Trainer
    trainer = pl.Trainer(
        max_epochs=10,
        log_every_n_steps=100,
        # New configurations below
        devices="auto",
        accelerator="auto",
        strategy=RayDDPStrategy(),
        plugins=[RayLightningEnvironment()],
        callbacks=[ckpt_report_callback],
    )

    # Validate your Lightning trainer configuration
    trainer = prepare_trainer(trainer)

    # [4] Build your datasets on each worker
    datamodule = MyLightningDataModule(batch_size=32)
    trainer.fit(model, datamodule=datamodule)

# [5] Explicitly define and run the training function
ray_trainer = TorchTrainer(
    train_func,
    scaling_config=ray.train.ScalingConfig(num_workers=4, use_gpu=True),
    run_config=ray.train.RunConfig(
        checkpoint_config=ray.train.CheckpointConfig(
            num_to_keep=3,
            checkpoint_score_attribute="val_accuracy",
            checkpoint_score_order="max",
        ),
    ),
)
result = ray_trainer.fit()

# [6] Load the trained model from a simplified checkpoint interface.
checkpoint: ray.train.Checkpoint = result.checkpoint
with checkpoint.as_directory() as checkpoint_dir:
    print("Checkpoint contents:", os.listdir(checkpoint_dir))
    checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.ckpt")
    model = MyLightningModule.load_from_checkpoint(checkpoint_path)