使用 JAX 进行分布式训练入门#

本指南概述了 Ray Train 中的 JaxTrainer

什么是 JAX?#

JAX 是一个用于加速器驱动的数组计算和程序转换的 Python 库,专为高性能数值计算和大规模机器学习而设计。

JAX 提供了一个可扩展的系统,用于转换数值函数,如 jax.gradjax.jitjax.vmap,利用 XLA 编译器创建高度优化的代码,这些代码可以在 GPU 和 TPU 等加速器上高效扩展。JAX 的核心优势在于其可组合性,允许将这些转换组合起来,构建用于分布式执行的复杂、高性能数值程序。

什么是 TPU?#

Tensor Processing Units (TPU),是 Google 为优化机器学习工作负载而创建的定制设计加速器。与通用 CPU 或并行处理 GPU 不同,TPU 高度专注于深度学习中涉及的海量矩阵和张量计算,因此效率极高。

TPU 的主要优势在于规模化性能,因为它们可以通过高速 ICI 互连连接成大型多主机配置,称为“PodSlices”,这使得它们非常适合训练无法容纳在单个节点上的大型模型。

要了解有关使用 KubeRay 配置 TPU 的更多信息,请参阅 使用 KubeRay 的 TPU

JaxTrainer API#

在 Ray Train 中,JaxTrainer 是协调分布式 JAX 训练的核心组件。它遵循 SPMD(Single-Program, Multi-Data,单程序多数据)范例,您的训练代码会在多个工作节点上同时执行,每个节点运行在 TPU 切片内的独立 TPU 虚拟机上。Ray 会自动原子性地预留一个 TPU 多主机切片。

JaxTrainer 使用您在 train_loop_per_worker 函数中定义的训练逻辑以及指定分布式硬件布局的 ScalingConfig 进行初始化。JaxTrainer 目前仅支持 TPU 加速器类型。

配置规模和 TPU#

对于 TPU 训练,您可以在 ScalingConfig 中定义硬件切片的具体细节。关键字段包括:

  • use_tpu:这是 Ray 2.49.0 中为 V2 ScalingConfig 新增的字段。此布尔标志明确指示 Ray Train 初始化 JAX 后端以进行 TPU 执行。

  • topology:这是 Ray 2.49.0 中为 V2 ScalingConfig 新增的字段。Topology 是一个字符串,定义了 TPU 芯片的物理布局(例如,“4x4”)。这对于多主机训练是必需的,并确保 Ray 正确地在切片中放置工作节点。有关按代支持的 TPU Topology 列表,请参阅 GKE 文档

  • num_workers:设置为 TPU 切片中的虚拟机数量。对于具有 2x2x4 Topology 的 v4-32 切片,这将是 4。

  • resources_per_worker:一个字典,指定每个工作节点所需的资源。对于 TPU,您通常会请求每个虚拟机中的芯片数量(例如:{“TPU”:4})。

  • accelerator_type:对于 TPU,accelerator_type 指定您正在使用的 TPU 代(例如,“TPU-V6E”),确保您的工作负载调度到所需的 TPU 切片上。

这些配置共同提供了一个声明式 API,用于定义您的整个分布式 JAX 训练环境,使 Ray Train 能够处理在 TPU 切片上启动和协调工作节点的复杂任务。

快速入门#

供参考,最终代码如下:

from ray.train.v2.jax import JaxTrainer
from ray.train import ScalingConfig

def train_func():
    # Your JAX training code here.

scaling_config = ScalingConfig(num_workers=4, use_tpu=True, topology="4x4", accelerator_type="TPU-V6E")
trainer = JaxTrainer(train_func, scaling_config=scaling_config)
result = trainer.fit()
  1. train_func 是在每个分布式训练工作节点上执行的 Python 代码。

  2. ScalingConfig 定义了分布式训练工作节点的数量以及是否使用 TPU。

  3. JaxTrainer 启动分布式训练作业。

比较使用 Ray Train 和不使用 Ray Train 的 JAX 训练脚本。

import jax
import jax.numpy as jnp
import optax
import ray.train

from ray.train.v2.jax import JaxTrainer
from ray.train import ScalingConfig

def train_func():
    """This function is run on each distributed worker."""
    key = jax.random.PRNGKey(jax.process_index())
    X = jax.random.normal(key, (100, 1))
    noise = jax.random.normal(key, (100, 1)) * 0.1
    y = 2 * X + 1 + noise

    def linear_model(params, x):
        return x @ params['w'] + params['b']

    def loss_fn(params, x, y):
        preds = linear_model(params, x)
        return jnp.mean((preds - y) ** 2)

    @jax.jit
    def train_step(params, opt_state, x, y):
        loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss

    # Initialize parameters and optimizer.
    key, w_key, b_key = jax.random.split(key, 3)
    params = {'w': jax.random.normal(w_key, (1, 1)), 'b': jax.random.normal(b_key, (1,))}
    optimizer = optax.adam(learning_rate=0.01)
    opt_state = optimizer.init(params)

    # Training loop
    epochs = 100
    for epoch in range(epochs):
        params, opt_state, loss = train_step(params, opt_state, X, y)
        # Report metrics back to Ray Train.
        ray.train.report({"loss": float(loss), "epoch": epoch})

# Define the hardware configuration for your distributed job.
scaling_config = ScalingConfig(
    num_workers=4,
    use_tpu=True,
    topology="4x4",
    accelerator_type="TPU-V6E",
    placement_strategy="SPREAD"
)

# Define and run the JaxTrainer.
trainer = JaxTrainer(
    train_loop_per_worker=train_func,
    scaling_config=scaling_config,
)
result = trainer.fit()
print(f"Training finished. Final loss: {result.metrics['loss']:.4f}")
import jax
import jax.numpy as jnp
import optax

# In a non-Ray script, you would manually initialize the
# distributed environment for multi-host training.
# import jax.distributed
# jax.distributed.initialize()

# Generate synthetic data.
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (100, 1))
noise = jax.random.normal(key, (100, 1)) * 0.1
y = 2 * X + 1 + noise

# Model and loss function are standard JAX.
def linear_model(params, x):
    return x @ params['w'] + params['b']

def loss_fn(params, x, y):
    preds = linear_model(params, x)
    return jnp.mean((preds - y) ** 2)

@jax.jit
def train_step(params, opt_state, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# Initialize parameters and optimizer.
key, w_key, b_key = jax.random.split(key, 3)
params = {'w': jax.random.normal(w_key, (1, 1)), 'b': jax.random.normal(b_key, (1,))}
optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(params)

# Training loop
epochs = 100
print("Starting training...")
for epoch in range(epochs):
    params, opt_state, loss = train_step(params, opt_state, X, y)
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss:.4f}")

print("Training finished.")
print(f"Learned parameters: w={params['w'].item():.4f}, b={params['b'].item():.4f}")

设置训练函数#

Ray Train 会自动在每个 TPU 工作节点上初始化 JAX 分布式环境。要适应您现有的 JAX 代码,您只需将训练逻辑包装在一个可以传递给 JaxTrainer 的 Python 函数中。

此函数是 Ray 将在每个远程节点上执行的入口点。

+from ray.train.v2.jax import JaxTrainer
+from ray.train import ScalingConfig, report

-def main_logic()
+def train_func():
    """This function is run on each distributed worker."""
    # ... (JAX model, data, and training step definitions) ...

    # Training loop
    for epoch in range(epochs):
        params, opt_state, loss = train_step(params, opt_state, X, y)
-       print(f"Epoch {epoch}, Loss: {loss:.4f}")
+       # In Ray Train, you can report metrics back to the trainer
+       report({"loss": float(loss), "epoch": epoch})

-if __name__ == "__main__":
-    main_logic()
+# Define the hardware configuration for your distributed job.
+scaling_config = ScalingConfig(
+    num_workers=4,
+    use_tpu=True,
+    topology="4x4",
+    accelerator_type="TPU-V6E",
+    placement_strategy="SPREAD"
+)
+
+# Define and run the JaxTrainer, which executes `train_func`.
+trainer = JaxTrainer(
+    train_loop_per_worker=train_func,
+    scaling_config=scaling_config
+)
+result = trainer.fit()

配置持久存储#

创建一个 RunConfig 对象来指定结果(包括检查点和工件)将要保存的路径。

from ray.train import RunConfig

# Local path (/some/local/path/unique_run_name)
run_config = RunConfig(storage_path="/some/local/path", name="unique_run_name")

# Shared cloud storage URI (s3://bucket/unique_run_name)
run_config = RunConfig(storage_path="s3://bucket", name="unique_run_name")

# Shared NFS path (/mnt/nfs/unique_run_name)
run_config = RunConfig(storage_path="/mnt/nfs", name="unique_run_name")

警告

指定一个*共享存储位置*(如云存储或 NFS)对于单节点集群是*可选的*,但对于多节点集群是*必需的*。使用本地路径将在多节点集群的检查点过程中*引发错误*。

有关更多详细信息,请参阅 配置持久存储

启动训练作业#

将所有内容联系起来,您现在可以使用 JaxTrainer 启动分布式训练作业。

from ray.train import ScalingConfig

train_func = lambda: None
scaling_config = ScalingConfig(num_workers=4, use_tpu=True, topology="4x4", accelerator_type="TPU-V6E")
run_config = None
from ray.train.v2.jax import JaxTrainer

trainer = JaxTrainer(
    train_func, scaling_config=scaling_config, run_config=run_config
)
result = trainer.fit()

访问训练结果#

训练完成后,将返回一个 Result 对象,其中包含有关训练运行的信息,包括训练期间报告的指标和检查点。

result.metrics     # The metrics reported during training.
result.checkpoint  # The latest checkpoint reported during training.
result.path        # The path where logs are stored.
result.error       # The exception that was raised, if training failed.

有关更多用法示例,请参阅 检查训练结果

下一步#

在您将 JAX 训练脚本转换为使用 Ray Train 后

  • 请参阅 用户指南 以了解有关执行特定任务的更多信息。

  • 浏览 示例 以获取有关如何使用 Ray Train 的端到端示例。

  • 请参阅 API 参考,了解本教程中的类和方法的更多详细信息。