使用 JAX 进行分布式训练入门#
本指南概述了 Ray Train 中的 JaxTrainer。
什么是 JAX?#
JAX 是一个用于加速器驱动的数组计算和程序转换的 Python 库,专为高性能数值计算和大规模机器学习而设计。
JAX 提供了一个可扩展的系统,用于转换数值函数,如 jax.grad、jax.jit 和 jax.vmap,利用 XLA 编译器创建高度优化的代码,这些代码可以在 GPU 和 TPU 等加速器上高效扩展。JAX 的核心优势在于其可组合性,允许将这些转换组合起来,构建用于分布式执行的复杂、高性能数值程序。
什么是 TPU?#
Tensor Processing Units (TPU),是 Google 为优化机器学习工作负载而创建的定制设计加速器。与通用 CPU 或并行处理 GPU 不同,TPU 高度专注于深度学习中涉及的海量矩阵和张量计算,因此效率极高。
TPU 的主要优势在于规模化性能,因为它们可以通过高速 ICI 互连连接成大型多主机配置,称为“PodSlices”,这使得它们非常适合训练无法容纳在单个节点上的大型模型。
要了解有关使用 KubeRay 配置 TPU 的更多信息,请参阅 使用 KubeRay 的 TPU。
JaxTrainer API#
在 Ray Train 中,JaxTrainer 是协调分布式 JAX 训练的核心组件。它遵循 SPMD(Single-Program, Multi-Data,单程序多数据)范例,您的训练代码会在多个工作节点上同时执行,每个节点运行在 TPU 切片内的独立 TPU 虚拟机上。Ray 会自动原子性地预留一个 TPU 多主机切片。
JaxTrainer 使用您在 train_loop_per_worker 函数中定义的训练逻辑以及指定分布式硬件布局的 ScalingConfig 进行初始化。JaxTrainer 目前仅支持 TPU 加速器类型。
配置规模和 TPU#
对于 TPU 训练,您可以在 ScalingConfig 中定义硬件切片的具体细节。关键字段包括:
use_tpu:这是 Ray 2.49.0 中为 V2ScalingConfig新增的字段。此布尔标志明确指示 Ray Train 初始化 JAX 后端以进行 TPU 执行。topology:这是 Ray 2.49.0 中为 V2ScalingConfig新增的字段。Topology 是一个字符串,定义了 TPU 芯片的物理布局(例如,“4x4”)。这对于多主机训练是必需的,并确保 Ray 正确地在切片中放置工作节点。有关按代支持的 TPU Topology 列表,请参阅 GKE 文档。num_workers:设置为 TPU 切片中的虚拟机数量。对于具有 2x2x4 Topology 的 v4-32 切片,这将是 4。resources_per_worker:一个字典,指定每个工作节点所需的资源。对于 TPU,您通常会请求每个虚拟机中的芯片数量(例如:{“TPU”:4})。accelerator_type:对于 TPU,accelerator_type指定您正在使用的 TPU 代(例如,“TPU-V6E”),确保您的工作负载调度到所需的 TPU 切片上。
这些配置共同提供了一个声明式 API,用于定义您的整个分布式 JAX 训练环境,使 Ray Train 能够处理在 TPU 切片上启动和协调工作节点的复杂任务。
快速入门#
供参考,最终代码如下:
from ray.train.v2.jax import JaxTrainer
from ray.train import ScalingConfig
def train_func():
# Your JAX training code here.
scaling_config = ScalingConfig(num_workers=4, use_tpu=True, topology="4x4", accelerator_type="TPU-V6E")
trainer = JaxTrainer(train_func, scaling_config=scaling_config)
result = trainer.fit()
train_func是在每个分布式训练工作节点上执行的 Python 代码。ScalingConfig定义了分布式训练工作节点的数量以及是否使用 TPU。JaxTrainer启动分布式训练作业。
比较使用 Ray Train 和不使用 Ray Train 的 JAX 训练脚本。
import jax
import jax.numpy as jnp
import optax
import ray.train
from ray.train.v2.jax import JaxTrainer
from ray.train import ScalingConfig
def train_func():
"""This function is run on each distributed worker."""
key = jax.random.PRNGKey(jax.process_index())
X = jax.random.normal(key, (100, 1))
noise = jax.random.normal(key, (100, 1)) * 0.1
y = 2 * X + 1 + noise
def linear_model(params, x):
return x @ params['w'] + params['b']
def loss_fn(params, x, y):
preds = linear_model(params, x)
return jnp.mean((preds - y) ** 2)
@jax.jit
def train_step(params, opt_state, x, y):
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
return params, opt_state, loss
# Initialize parameters and optimizer.
key, w_key, b_key = jax.random.split(key, 3)
params = {'w': jax.random.normal(w_key, (1, 1)), 'b': jax.random.normal(b_key, (1,))}
optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(params)
# Training loop
epochs = 100
for epoch in range(epochs):
params, opt_state, loss = train_step(params, opt_state, X, y)
# Report metrics back to Ray Train.
ray.train.report({"loss": float(loss), "epoch": epoch})
# Define the hardware configuration for your distributed job.
scaling_config = ScalingConfig(
num_workers=4,
use_tpu=True,
topology="4x4",
accelerator_type="TPU-V6E",
placement_strategy="SPREAD"
)
# Define and run the JaxTrainer.
trainer = JaxTrainer(
train_loop_per_worker=train_func,
scaling_config=scaling_config,
)
result = trainer.fit()
print(f"Training finished. Final loss: {result.metrics['loss']:.4f}")
import jax
import jax.numpy as jnp
import optax
# In a non-Ray script, you would manually initialize the
# distributed environment for multi-host training.
# import jax.distributed
# jax.distributed.initialize()
# Generate synthetic data.
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (100, 1))
noise = jax.random.normal(key, (100, 1)) * 0.1
y = 2 * X + 1 + noise
# Model and loss function are standard JAX.
def linear_model(params, x):
return x @ params['w'] + params['b']
def loss_fn(params, x, y):
preds = linear_model(params, x)
return jnp.mean((preds - y) ** 2)
@jax.jit
def train_step(params, opt_state, x, y):
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
return params, opt_state, loss
# Initialize parameters and optimizer.
key, w_key, b_key = jax.random.split(key, 3)
params = {'w': jax.random.normal(w_key, (1, 1)), 'b': jax.random.normal(b_key, (1,))}
optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(params)
# Training loop
epochs = 100
print("Starting training...")
for epoch in range(epochs):
params, opt_state, loss = train_step(params, opt_state, X, y)
if epoch % 10 == 0:
print(f"Epoch {epoch}, Loss: {loss:.4f}")
print("Training finished.")
print(f"Learned parameters: w={params['w'].item():.4f}, b={params['b'].item():.4f}")
设置训练函数#
Ray Train 会自动在每个 TPU 工作节点上初始化 JAX 分布式环境。要适应您现有的 JAX 代码,您只需将训练逻辑包装在一个可以传递给 JaxTrainer 的 Python 函数中。
此函数是 Ray 将在每个远程节点上执行的入口点。
+from ray.train.v2.jax import JaxTrainer
+from ray.train import ScalingConfig, report
-def main_logic()
+def train_func():
"""This function is run on each distributed worker."""
# ... (JAX model, data, and training step definitions) ...
# Training loop
for epoch in range(epochs):
params, opt_state, loss = train_step(params, opt_state, X, y)
- print(f"Epoch {epoch}, Loss: {loss:.4f}")
+ # In Ray Train, you can report metrics back to the trainer
+ report({"loss": float(loss), "epoch": epoch})
-if __name__ == "__main__":
- main_logic()
+# Define the hardware configuration for your distributed job.
+scaling_config = ScalingConfig(
+ num_workers=4,
+ use_tpu=True,
+ topology="4x4",
+ accelerator_type="TPU-V6E",
+ placement_strategy="SPREAD"
+)
+
+# Define and run the JaxTrainer, which executes `train_func`.
+trainer = JaxTrainer(
+ train_loop_per_worker=train_func,
+ scaling_config=scaling_config
+)
+result = trainer.fit()
配置持久存储#
创建一个 RunConfig 对象来指定结果(包括检查点和工件)将要保存的路径。
from ray.train import RunConfig
# Local path (/some/local/path/unique_run_name)
run_config = RunConfig(storage_path="/some/local/path", name="unique_run_name")
# Shared cloud storage URI (s3://bucket/unique_run_name)
run_config = RunConfig(storage_path="s3://bucket", name="unique_run_name")
# Shared NFS path (/mnt/nfs/unique_run_name)
run_config = RunConfig(storage_path="/mnt/nfs", name="unique_run_name")
警告
指定一个*共享存储位置*(如云存储或 NFS)对于单节点集群是*可选的*,但对于多节点集群是*必需的*。使用本地路径将在多节点集群的检查点过程中*引发错误*。
有关更多详细信息,请参阅 配置持久存储。
启动训练作业#
将所有内容联系起来,您现在可以使用 JaxTrainer 启动分布式训练作业。
from ray.train import ScalingConfig
train_func = lambda: None
scaling_config = ScalingConfig(num_workers=4, use_tpu=True, topology="4x4", accelerator_type="TPU-V6E")
run_config = None
from ray.train.v2.jax import JaxTrainer
trainer = JaxTrainer(
train_func, scaling_config=scaling_config, run_config=run_config
)
result = trainer.fit()
访问训练结果#
训练完成后,将返回一个 Result 对象,其中包含有关训练运行的信息,包括训练期间报告的指标和检查点。
result.metrics # The metrics reported during training.
result.checkpoint # The latest checkpoint reported during training.
result.path # The path where logs are stored.
result.error # The exception that was raised, if training failed.
有关更多用法示例,请参阅 检查训练结果。
下一步#
在您将 JAX 训练脚本转换为使用 Ray Train 后