Ray Train 概述#

为了有效使用 Ray Train,您需要了解四个主要概念

  1. 训练函数:一个包含模型训练逻辑的 Python 函数。

  2. Worker:一个运行训练函数的进程。

  3. 扩展配置:一个关于 worker 数量和计算资源(例如,CPU 或 GPU)的配置。

  4. Trainer:一个 Python 类,它将训练函数、worker 和扩展配置结合起来执行分布式训练作业。

../_images/overview.png

训练函数#

训练函数是用户定义的 Python 函数,其中包含端到端的模型训练循环逻辑。在启动分布式训练作业时,每个 worker 都会执行此训练函数。

Ray Train 文档使用以下约定

  1. train_func 是一个用户定义的函数,其中包含训练代码。

  2. train_func 被传递到 Trainer 的 train_loop_per_worker 参数。

def train_func():
    """User-defined training function that runs on each distributed worker process.

    This function typically contains logic for loading the model,
    loading the dataset, training the model, saving checkpoints,
    and logging metrics.
    """
    ...

Worker#

Ray Train 将模型训练计算分配给集群中的各个 worker 进程。每个 worker 都是执行 train_func 的进程。worker 的数量决定了训练作业的并行度,并在 ScalingConfig 中配置。

扩展配置#

ScalingConfig 是定义训练作业规模的机制。为 worker 并行度和计算资源指定两个基本参数

  • num_workers:为分布式训练作业启动的 worker 数量。

  • use_gpu:每个 worker 是否应使用 GPU 或 CPU。

from ray.train import ScalingConfig

# Single worker with a CPU
scaling_config = ScalingConfig(num_workers=1, use_gpu=False)

# Single worker with a GPU
scaling_config = ScalingConfig(num_workers=1, use_gpu=True)

# Multiple workers, each with a GPU
scaling_config = ScalingConfig(num_workers=4, use_gpu=True)

Trainer#

Trainer 将前面三个概念结合起来以启动分布式训练作业。Ray Train 为不同的框架提供了 Trainer 类。调用 fit() 方法通过以下方式执行训练作业

  1. 根据 scaling_config 定义启动 worker。

  2. 在所有 worker 上设置框架的分布式环境。

  3. 在所有 worker 上运行 train_func

from ray.train.torch import TorchTrainer

trainer = TorchTrainer(train_func, scaling_config=scaling_config)
trainer.fit()