Ray Train概述#

要有效地使用Ray Train,您需要理解四个主要概念

  1. 训练函数 (Training function):一个包含模型训练逻辑的Python函数。

  2. 工作进程 (Worker):一个运行训练函数的进程。

  3. 扩展配置 (Scaling configuration): 关于工作进程数量和计算资源(例如CPU或GPU)的配置。

  4. 训练器 (Trainer):一个Python类,将训练函数、工作进程和扩展配置结合起来执行分布式训练任务。

../_images/overview.png

训练函数#

训练函数是一个用户定义的Python函数,包含端到端的模型训练循环逻辑。启动分布式训练任务时,每个工作进程都会执行此训练函数。

Ray Train文档使用以下约定

  1. train_func 是一个用户定义的函数,包含训练代码。

  2. train_func 作为参数传递给训练器的 train_loop_per_worker

def train_func():
    """User-defined training function that runs on each distributed worker process.

    This function typically contains logic for loading the model,
    loading the dataset, training the model, saving checkpoints,
    and logging metrics.
    """
    ...

工作进程#

Ray Train将模型训练计算分发到集群中的各个工作进程。每个工作进程都是一个执行 train_func 的进程。工作进程的数量决定了训练任务的并行度,并在 ScalingConfig 中配置。

扩展配置#

ScalingConfig 是定义训练任务规模的机制。为工作进程并行度和计算资源指定两个基本参数

  • num_workers:为分布式训练任务启动的工作进程数量。

  • use_gpu:每个工作进程是否应使用GPU或CPU。

from ray.train import ScalingConfig

# Single worker with a CPU
scaling_config = ScalingConfig(num_workers=1, use_gpu=False)

# Single worker with a GPU
scaling_config = ScalingConfig(num_workers=1, use_gpu=True)

# Multiple workers, each with a GPU
scaling_config = ScalingConfig(num_workers=4, use_gpu=True)

训练器#

训练器将前三个概念结合起来,用于启动分布式训练任务。Ray Train为不同框架提供了训练器类 (Trainer classes)。调用 fit() 方法执行训练任务,具体方式为

  1. 根据扩展配置 (scaling_config) 定义的方式启动工作进程。

  2. 在所有工作进程上设置框架的分布式环境。

  3. 在所有工作进程上运行 train_func

from ray.train.torch import TorchTrainer

trainer = TorchTrainer(train_func, scaling_config=scaling_config)
trainer.fit()