This tutorial walks you through converting an existing PyTorch Lightning script to use Ray Train.

Learn how to:

  1. Configure the Lightning Trainer so that it runs distributed with Ray and on the correct CPU or GPU device.

  2. Configure the training function to report metrics and save checkpoints.

  3. Configure the scale of the training job and its CPU or GPU resource requirements.

  4. Launch a distributed training job with a TorchTrainer.

Quickstart#

For reference, the final code follows:

train_func is the Python code that executes on each distributed training worker.

from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig

def train_func():
    # Your PyTorch Lightning training code here.
    ...

scaling_config = ScalingConfig(num_workers=2, use_gpu=True)
trainer = TorchTrainer(train_func, scaling_config=scaling_config)
result = trainer.fit()
  1. ScalingConfig defines the number of distributed training workers and whether to use GPUs.

  2. TorchTrainer launches the distributed training job.

Compare a PyTorch Lightning training script with and without Ray Train.

PyTorch Lightning

import torch
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose
from torch.utils.data import DataLoader
import lightning.pytorch as pl

# Model, Loss, Optimizer
class ImageClassifier(pl.LightningModule):
    def __init__(self):
        super(ImageClassifier, self).__init__()
        self.model = resnet18(num_classes=10)
        self.model.conv1 = torch.nn.Conv2d(
            1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
        )
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        outputs = self.forward(x)
        loss = self.criterion(outputs, y)
        self.log("loss", loss, on_step=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.001)

# Data
transform = Compose([ToTensor(), Normalize((0.28604,), (0.32025,))])
train_data = FashionMNIST(root='./data', train=True, download=True, transform=transform)
train_dataloader = DataLoader(train_data, batch_size=128, shuffle=True)

# Training
model = ImageClassifier()
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, train_dataloaders=train_dataloader)
PyTorch Lightning + Ray Train

import os
import tempfile

import torch
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose
import lightning.pytorch as pl

import ray.train.lightning
from ray.train.torch import TorchTrainer

# Model, Loss, Optimizer
class ImageClassifier(pl.LightningModule):
    def __init__(self):
        super(ImageClassifier, self).__init__()
        self.model = resnet18(num_classes=10)
        self.model.conv1 = torch.nn.Conv2d(
            1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
        )
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        outputs = self.forward(x)
        loss = self.criterion(outputs, y)
        self.log("loss", loss, on_step=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.001)


def train_func():
    # Data
    transform = Compose([ToTensor(), Normalize((0.28604,), (0.32025,))])
    data_dir = os.path.join(tempfile.gettempdir(), "data")
    train_data = FashionMNIST(root=data_dir, train=True, download=True, transform=transform)
    train_dataloader = DataLoader(train_data, batch_size=128, shuffle=True)

    # Training
    model = ImageClassifier()
    # [1] Configure PyTorch Lightning Trainer.
    trainer = pl.Trainer(
        max_epochs=10,
        devices="auto",
        accelerator="auto",
        strategy=ray.train.lightning.RayDDPStrategy(),
        plugins=[ray.train.lightning.RayLightningEnvironment()],
        callbacks=[ray.train.lightning.RayTrainReportCallback()],
        # [1a] Optionally, disable the default checkpointing behavior
        # in favor of the `RayTrainReportCallback` above.
        enable_checkpointing=False,
    )
    trainer = ray.train.lightning.prepare_trainer(trainer)
    trainer.fit(model, train_dataloaders=train_dataloader)

# [2] Configure scaling and resource requirements.
scaling_config = ray.train.ScalingConfig(num_workers=2, use_gpu=True)

# [3] Launch distributed training job.
trainer = TorchTrainer(
    train_func,
    scaling_config=scaling_config,
    # [3a] If running in a multi-node cluster, this is where you
    # should configure the run's persistent storage that is accessible
    # across all worker nodes.
    # run_config=ray.train.RunConfig(storage_path="s3://..."),
)
result: ray.train.Result = trainer.fit()

# [4] Load the trained model.
with result.checkpoint.as_directory() as checkpoint_dir:
    model = ImageClassifier.load_from_checkpoint(
        os.path.join(
            checkpoint_dir,
            ray.train.lightning.RayTrainReportCallback.CHECKPOINT_NAME,
        ),
    )

Set up a training function#

First, update your training code to support distributed training. Begin by wrapping your code in a training function:

Each distributed training worker executes this function.

def train_func():
    # Your model training code here.
    ...

You can also specify the input argument of train_func as a dictionary via the Trainer's train_loop_config. For example:

def train_func(config):
    lr = config["lr"]
    num_epochs = config["num_epochs"]

config = {"lr": 1e-4, "num_epochs": 10}
trainer = ray.train.torch.TorchTrainer(train_func, train_loop_config=config, ...)

Warning

Avoid passing large data objects through train_loop_config to reduce serialization and deserialization overhead. Instead, initialize large objects (such as datasets and models) directly in train_func:

 def load_dataset():
     # Return a large in-memory dataset
     ...

 def load_model():
     # Return a large in-memory model instance
     ...

-config = {"data": load_dataset(), "model": load_model()}

 def train_func(config):
-    data = config["data"]
-    model = config["model"]

+    data = load_dataset()
+    model = load_model()
     ...

 trainer = ray.train.torch.TorchTrainer(train_func, train_loop_config=config, ...)

Ray Train sets up your distributed process group on each worker. You only need to make a few changes to your Lightning Trainer definition. The following sections discuss each change.

 import lightning.pytorch as pl
-from lightning.pytorch.strategies import DDPStrategy
-from lightning.pytorch.plugins.environments import LightningEnvironment
+import ray.train.lightning

 def train_func():
     ...
     model = MyLightningModule(...)
     datamodule = MyLightningDataModule(...)

     trainer = pl.Trainer(
-        devices=[0, 1, 2, 3],
-        strategy=DDPStrategy(),
-        plugins=[LightningEnvironment()],
+        devices="auto",
+        accelerator="auto",
+        strategy=ray.train.lightning.RayDDPStrategy(),
+        plugins=[ray.train.lightning.RayLightningEnvironment()]
     )
+    trainer = ray.train.lightning.prepare_trainer(trainer)

     trainer.fit(model, datamodule=datamodule)

Configure the distributed strategy#

Ray Train offers several subclassed distributed strategies for Lightning. These strategies retain the same argument list as their base strategy classes and, internally, configure the root device and the distributed sampler arguments.

RayDDPStrategy

 import lightning.pytorch as pl
-from lightning.pytorch.strategies import DDPStrategy
+import ray.train.lightning

 def train_func():
     ...
     trainer = pl.Trainer(
         ...
-        strategy=DDPStrategy(),
+        strategy=ray.train.lightning.RayDDPStrategy(),
         ...
     )
     ...
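
Because RayDDPStrategy keeps the same argument list as Lightning's DDPStrategy, the usual DDP keyword arguments can be passed straight through. A minimal sketch, assuming a model with branches that aren't used on every step (hence find_unused_parameters=True, a standard DDP option rather than anything Ray-specific):

import lightning.pytorch as pl
import ray.train.lightning

def train_func():
    ...
    trainer = pl.Trainer(
        devices="auto",
        accelerator="auto",
        # RayDDPStrategy accepts the same keyword arguments as DDPStrategy.
        strategy=ray.train.lightning.RayDDPStrategy(find_unused_parameters=True),
        plugins=[ray.train.lightning.RayLightningEnvironment()],
    )
    trainer = ray.train.lightning.prepare_trainer(trainer)
    ...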

Configure the Ray cluster environment plugin#

Ray Train also provides the RayLightningEnvironment class as a specification for the Ray cluster. This utility class configures the worker's local, global, and node rank, as well as the world size.


 import lightning.pytorch as pl
-from lightning.pytorch.plugins.environments import LightningEnvironment
+import ray.train.lightning

 def train_func():
     ...
     trainer = pl.Trainer(
         ...
-        plugins=[LightningEnvironment()],
+        plugins=[ray.train.lightning.RayLightningEnvironment()],
         ...
     )
     ...
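
The same rank and world size information that RayLightningEnvironment reports to Lightning is also available inside the training function through ray.train.get_context(). A minimal sketch (assuming Ray 2.7+), for example to print only from the rank 0 worker:

import ray.train

def train_func():
    ...
    ctx = ray.train.get_context()
    # Ranks and world size are set up by Ray Train for every worker.
    if ctx.get_world_rank() == 0:
        print(
            f"world_size={ctx.get_world_size()}, "
            f"local_rank={ctx.get_local_rank()}, "
            f"node_rank={ctx.get_node_rank()}"
        )
    ...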

Configure parallel devices#

In addition, the Ray TorchTrainer has already configured the correct CUDA_VISIBLE_DEVICES for you. Always use all available GPUs by setting devices="auto" and accelerator="auto".


 import lightning.pytorch as pl

 def train_func():
     ...
     trainer = pl.Trainer(
         ...
-        devices=[0,1,2,3],
+        devices="auto",
+        accelerator="auto",
         ...
     )
     ...

Report checkpoints and metrics#

To persist your checkpoints and monitor training progress, add a ray.train.lightning.RayTrainReportCallback utility callback to your Trainer.

Reporting metrics and checkpoints to Ray Train enables fault-tolerant training and hyperparameter optimization. Note that the ray.train.lightning.RayTrainReportCallback class only provides a simple implementation and can be further customized, as the sketch after the diff below shows.

 import lightning.pytorch as pl
 from ray.train.lightning import RayTrainReportCallback

 def train_func():
     ...
     trainer = pl.Trainer(
         ...
-        callbacks=[...],
+        callbacks=[..., RayTrainReportCallback()],
     )
     ...
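
If the default behavior isn't enough, a custom Lightning callback can call ray.train.report directly. A minimal sketch of that idea; MyReportCallback is a hypothetical name, and reporting a checkpoint from every worker (rather than only rank 0) is a simplification:

import os
import tempfile

import lightning.pytorch as pl
import ray.train

class MyReportCallback(pl.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # Collect the scalar metrics Lightning has logged so far.
        metrics = {k: v.item() for k, v in trainer.callback_metrics.items()}
        with tempfile.TemporaryDirectory() as tmpdir:
            # Save a Lightning checkpoint and hand the directory to Ray Train.
            trainer.save_checkpoint(os.path.join(tmpdir, "checkpoint.ckpt"))
            ray.train.report(
                metrics, checkpoint=ray.train.Checkpoint.from_directory(tmpdir)
            )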

Prepare your Lightning Trainer#

Finally, pass your Lightning Trainer into prepare_trainer() to validate your configuration.


 import lightning.pytorch as pl
 import ray.train.lightning

 def train_func():
     ...
     trainer = pl.Trainer(...)
+    trainer = ray.train.lightning.prepare_trainer(trainer)
     ...

Configure scale and GPUs#

Outside of your training function, create a ScalingConfig object to configure:

  1. num_workers - The number of distributed training worker processes.

  2. use_gpu - Whether each worker should use a GPU (or CPU).

For more details, see Configuring Scale and GPUs.

from ray.train import ScalingConfig
scaling_config = ScalingConfig(num_workers=2, use_gpu=True)
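
ScalingConfig also accepts finer-grained resource requests through resources_per_worker. A minimal sketch; the numbers below are illustrative, not a recommendation:

from ray.train import ScalingConfig

# 4 workers, each reserving 1 GPU and 2 CPUs.
scaling_config = ScalingConfig(
    num_workers=4,
    use_gpu=True,
    resources_per_worker={"CPU": 2, "GPU": 1},
)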

Configure persistent storage#

Create a RunConfig object to specify the path where results (including checkpoints and artifacts) are saved.

Specifying a shared storage location (such as cloud storage or NFS) is optional for single-node clusters, but it is required for multi-node clusters. Using a local path raises an error during checkpointing for multi-node clusters.

from ray.train import RunConfig

# Local path (/some/local/path/unique_run_name)
run_config = RunConfig(storage_path="/some/local/path", name="unique_run_name")

# Shared cloud storage URI (s3://bucket/unique_run_name)
run_config = RunConfig(storage_path="s3://bucket", name="unique_run_name")

# Shared NFS path (/mnt/nfs/unique_run_name)
run_config = RunConfig(storage_path="/mnt/nfs", name="unique_run_name")


For more details, see Configuring Persistent Storage.

Launch a training job#

Putting this all together, you can now launch a distributed training job with a TorchTrainer.


from ray.train.torch import TorchTrainer

trainer = TorchTrainer(
    train_func, scaling_config=scaling_config, run_config=run_config
)
result = trainer.fit()

Access training results#

After training completes, Ray Train returns a Result object, which contains information about the training run, including the metrics and checkpoints reported during training.

For more usage examples, see Inspecting Training Results.

result.metrics     # The metrics reported during training.
result.checkpoint  # The latest checkpoint reported during training.
result.path        # The path where logs are stored.
result.error       # The exception that was raised, if training failed.
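
For example, a minimal sketch of consuming this Result for the script above, which logs a "loss" metric:

# Print the last reported loss and re-raise any training failure.
print("Last reported loss:", result.metrics.get("loss"))
if result.error is not None:
    raise result.error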

Next steps#

After you have converted your PyTorch Lightning training script to use Ray Train:

  • See the User Guides to learn more about how to perform specific tasks.

  • Browse the Examples for end-to-end examples of how to use Ray Train.

  • Consult the API Reference for more details on the classes and methods used in this tutorial.

Version Compatibility#

Ray Train is tested with pytorch_lightning versions 1.6.5 and 2.1.2. For full compatibility, use pytorch_lightning>=1.6.5. Earlier versions aren't prohibited, but they may result in unexpected issues. If you run into any compatibility problems, consider upgrading your PyTorch Lightning version or filing an issue.

Note

If you are using Lightning 2.x, use the import path lightning.pytorch.xxx instead of pytorch_lightning.xxx.
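
For example, with Lightning 2.x the imports used throughout this tutorial look like the following; the legacy pytorch_lightning path is shown only for comparison:

# Lightning 2.x import style used in this tutorial.
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint

# Legacy pytorch_lightning 1.x style; avoid mixing the two in one script.
# import pytorch_lightning as pl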

LightningTrainer Migration Guide#

Ray 2.4 introduced LightningTrainer and exposed a LightningConfigBuilder to define configurations for pl.LightningModule and pl.Trainer.

It then instantiates the model and trainer objects and runs a pre-defined training function in a black box.

This version of the LightningTrainer API was constraining and limited your ability to manage the training functionality.

Ray 2.7 introduced the newly unified TorchTrainer API, which offers enhanced transparency, flexibility, and simplicity. This API aligns more closely with standard PyTorch Lightning scripts, ensuring that you have better control over your native Lightning code.

(Deprecated) LightningTrainer

from ray.train import CheckpointConfig, RunConfig, ScalingConfig
from ray.train.lightning import LightningConfigBuilder, LightningTrainer

config_builder = LightningConfigBuilder()
# [1] Collect model configs
config_builder.module(cls=MyLightningModule, lr=1e-3, feature_dim=128)

# [2] Collect checkpointing configs
config_builder.checkpointing(monitor="val_accuracy", mode="max", save_top_k=3)

# [3] Collect pl.Trainer configs
config_builder.trainer(
    max_epochs=10,
    accelerator="gpu",
    log_every_n_steps=100,
)

# [4] Build datasets on the head node
datamodule = MyLightningDataModule(batch_size=32)
config_builder.fit_params(datamodule=datamodule)

# [5] Execute the internal training function in a black box
ray_trainer = LightningTrainer(
    lightning_config=config_builder.build(),
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
    run_config=RunConfig(
        checkpoint_config=CheckpointConfig(
            num_to_keep=3,
            checkpoint_score_attribute="val_accuracy",
            checkpoint_score_order="max",
        ),
    )
)
result = ray_trainer.fit()

# [6] Load the trained model from an opaque Lightning-specific checkpoint.
lightning_checkpoint = result.checkpoint
model = lightning_checkpoint.get_model(MyLightningModule)
(New API) TorchTrainer

import os

import lightning.pytorch as pl

import ray.train
from ray.train.torch import TorchTrainer
from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
    prepare_trainer
)

def train_func():
    # [1] Create a Lightning model
    model = MyLightningModule(lr=1e-3, feature_dim=128)

    # [2] Report Checkpoint with callback
    ckpt_report_callback = RayTrainReportCallback()

    # [3] Create a Lightning Trainer
    trainer = pl.Trainer(
        max_epochs=10,
        log_every_n_steps=100,
        # New configurations below
        devices="auto",
        accelerator="auto",
        strategy=RayDDPStrategy(),
        plugins=[RayLightningEnvironment()],
        callbacks=[ckpt_report_callback],
    )

    # Validate your Lightning trainer configuration
    trainer = prepare_trainer(trainer)

    # [4] Build your datasets on each worker
    datamodule = MyLightningDataModule(batch_size=32)
    trainer.fit(model, datamodule=datamodule)

# [5] Explicitly define and run the training function
ray_trainer = TorchTrainer(
    train_func,
    scaling_config=ray.train.ScalingConfig(num_workers=4, use_gpu=True),
    run_config=ray.train.RunConfig(
        checkpoint_config=ray.train.CheckpointConfig(
            num_to_keep=3,
            checkpoint_score_attribute="val_accuracy",
            checkpoint_score_order="max",
        ),
    )
)
result = ray_trainer.fit()

# [6] Load the trained model from a simplified checkpoint interface.
checkpoint: ray.train.Checkpoint = result.checkpoint
with checkpoint.as_directory() as checkpoint_dir:
    print("Checkpoint contents:", os.listdir(checkpoint_dir))
    checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.ckpt")
    model = MyLightningModule.load_from_checkpoint(checkpoint_path)