高级:扩展昂贵的 colate 函数#

默认情况下,当您调用 ray.data.DataIterator.iter_torch_batches() 时,colate 函数会在训练 worker 上执行。这种方法有两个主要缺点:

  • 低可扩展性:colate 函数在每个训练 worker 上顺序运行,限制了并行性。

  • 资源竞争:colate 函数消耗训练 worker 的 CPU 和内存资源,可能导致模型训练速度变慢。

将 colate 函数扩展到 Ray Data 可以让您在多个 CPU 节点上独立于训练 worker 来扩展 collation,从而提高整体流水线吞吐量,尤其是在 colate 函数负载较高时。

此优化在 colate 函数计算成本高昂(例如,标记化、图像增强或复杂的特征工程)且您有额外的 CPU 资源可用于数据预处理时特别有效。

将 colate 函数移至 Ray Data#

以下示例显示了一个在训练 worker 上运行的典型 colate 函数。

train_dataset = read_parquet().map(...)

def train_func():
    for batch in ray.train.get_dataset_shard("train").iter_torch_batches(
        collate_fn=collate_fn,
        batch_size=BATCH_SIZE
    ):
        # Training logic here
        pass

trainer = TorchTrainer(
    train_func,
    datasets={"train": train_dataset},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True)
)

result = trainer.fit()

如果 colate 函数耗时/计算密集,并且您想将其扩展,您应该:

创建在 Ray Data 中运行的自定义 colate 函数#

为了扩展,您需要将 collate_fn 移至 Ray Data 的 map_batches 操作。

def collate_fn(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    return batch

train_dataset = train_dataset.map_batches(collate_fn, batch_size=BATCH_SIZE)

def train_func():
    for batch in ray.train.get_dataset_shard("train").iter_torch_batches(
        collate_fn=None,
        batch_size=BATCH_SIZE,
    ):
        # Training logic here
        pass

trainer = TorchTrainer(
    train_func,
    datasets={"train": train_dataset},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True)
)

result = trainer.fit()

有几点需要注意:

  • collate_fn 返回一个 NumPy 数组字典,这是 Ray Data 的标准批次格式。

  • iter_torch_batches 方法使用 collate_fn=None,这减少了训练 worker 进程上完成的工作量。

确保批次大小对齐#

通常,colate 函数用于创建具有目标批次大小的完整数据批次。但是,如果您使用 ray.data.Dataset.map_batches() 将 colate 函数移至 Ray Data,默认情况下,它不会保证每次函数调用的批次大小。

您可能会遇到两个常见问题:

  1. colate 函数需要提供一定数量的行作为输入才能正常工作。

  2. 您想避免在训练 worker 进程上进行任何数据重格式化/重批次处理。

为了解决这些问题,您可以使用 ray.data.Dataset.repartition()target_num_rows_per_block 来确保批次大小对齐。

通过在 map_batches 之前调用 repartition,可以确保输入块包含所需的行数。

# Note: If you only use map_batches(batch_size=BATCH_SIZE), you are not guaranteed to get the desired number of rows as an input.
dataset = dataset.repartition(target_num_rows_per_block=BATCH_SIZE).map_batches(collate_fn, batch_size=BATCH_SIZE)

通过在 map_batches 之后调用 repartition,可以确保输出块包含所需的行数。这可以避免在训练 worker 进程上进行任何数据重格式化/重批次处理。

dataset = dataset.map_batches(collate_fn, batch_size=BATCH_SIZE).repartition(target_num_rows_per_block=BATCH_SIZE)

def train_func():
    for batch in ray.train.get_dataset_shard("train").iter_torch_batches(
        collate_fn=None,
        batch_size=BATCH_SIZE,
    ):
        # Training logic here
        pass

trainer = TorchTrainer(
    train_func,
    datasets={"train": train_dataset},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True)
)

result = trainer.fit()

整合#

在本指南中,我们使用一个模拟文本数据集来演示此优化。您可以在 随机文本生成器 中找到模拟数据集的实现。

以下示例显示了一个在训练 worker 上运行的典型 colate 函数。

from transformers import AutoTokenizer
import torch
import numpy as np
from typing import Dict
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
from mock_dataset import create_mock_ray_text_dataset

BATCH_SIZE = 10000

def vanilla_collate_fn(tokenizer: AutoTokenizer, batch: Dict[str, np.ndarray]) -> Dict[str, torch.Tensor]:
    outputs = tokenizer(
        list(batch["text"]),
        truncation=True,
        padding="longest",
        return_tensors="pt",
    )
    outputs["labels"] = torch.LongTensor(batch["label"])
    return outputs

def train_func():
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    collate_fn = lambda x: vanilla_collate_fn(tokenizer, x)

    # Collate function runs on the training worker
    for batch in ray.train.get_dataset_shard("train").iter_torch_batches(
        collate_fn=collate_fn,
        batch_size=BATCH_SIZE
    ):
        # Training logic here
        pass

train_dataset = create_mock_ray_text_dataset(
    dataset_size=1000000,
    min_len=1000,
    max_len=3000
)

trainer = TorchTrainer(
    train_func,
    datasets={"train": train_dataset},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True)
)

result = trainer.fit()

以下示例将 colate 函数移至 Ray Data 预处理。

from transformers import AutoTokenizer
import numpy as np
from typing import Dict
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
from mock_dataset import create_mock_ray_text_dataset
import pyarrow as pa

BATCH_SIZE = 10000

class CollateFnRayData:
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

    def __call__(self, batch: pa.Table) -> Dict[str, np.ndarray]:
        results = self.tokenizer(
            batch["text"].to_pylist(),
            truncation=True,
            padding="longest",
            return_tensors="np",
        )
        results["labels"] = np.array(batch["label"])
        return results

def train_func():
    # Collate function already ran in Ray Data
    for batch in ray.train.get_dataset_shard("train").iter_torch_batches(
        collate_fn=None,
        batch_size=BATCH_SIZE,
    ):
        # Training logic here
        pass

# Apply preprocessing in Ray Data
train_dataset = (
    create_mock_ray_text_dataset(
        dataset_size=1000000,
        min_len=1000,
        max_len=3000
    )
    .map_batches(
        CollateFnRayData,
        batch_size=BATCH_SIZE,
        batch_format="pyarrow",
    )
    .repartition(target_num_rows_per_block=BATCH_SIZE)  # Ensure batch size alignment
)

trainer = TorchTrainer(
    train_func,
    datasets={"train": train_dataset},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True)
)

result = trainer.fit()

优化实现做了以下更改:

  • 在 Ray Data 中进行预处理:标记化逻辑从 train_func 移至 CollateFnRayData,后者在 map_batches 中运行。

  • NumPy 输出:colate 函数返回 Dict[str, np.ndarray] 而不是 PyTorch 张量,Ray Data 原生支持此格式。

  • 批次对齐map_batches 之后的 repartition(target_num_rows_per_block=BATCH_SIZE) 可确保 colate 函数接收确切的批次大小,并且输出块与批次大小对齐。

  • 迭代器中无 collate_fniter_torch_batches 使用 collate_fn=None,因为预处理已在 Ray Data 中完成。

基准测试结果#

以下基准测试展示了扩展 colate 函数所带来的性能提升。测试在 100 万行、文本长度在 1000 到 3000 个字符之间的数据集上,使用批次大小为 10,000 的文本标记化。

单节点(g4dn.12xlarge:48 vCPU,4 NVIDIA T4 GPU,192 GiB 内存)

配置

吞吐量

迭代器中的 Collate(基线)

1,588 行/秒

Ray Data 中的 Collate

3,437 行/秒

添加 2 个额外的 CPU 节点(m5.8xlarge:每个 32 vCPU,128 GiB 内存)

配置

吞吐量

迭代器中的 Collate(基线)

1,659 行/秒

Ray Data 中的 Collate

10,717 行/秒

结果表明,将 colate 函数扩展到 Ray Data 在单节点上可提供 2 倍的加速,在添加仅 CPU 的节点进行预处理时可提供 6 倍的加速。

高级:处理自定义数据类型#

上述优化实现返回 Dict[str, np.ndarray],Ray Data 原生支持此格式。但是,如果您的 colate 函数需要返回 PyTorch 张量或其他 ray.data.Dataset.map_batches() 不直接支持的自定义数据类型,您需要对其进行序列化。

张量序列化工具#

以下工具将 PyTorch 张量序列化为 PyArrow 格式。它将批次中的所有张量展平为单个二进制缓冲区,存储有关张量形状和数据类型的元数据,并将所有内容打包到一个单行 PyArrow 表中。在训练端,它将表反序列化回原始张量结构。

与实际的 colate 函数工作(例如标记化或图像处理)相比,序列化和反序列化操作通常很轻量,因此开销相对于扩展 colate 函数的性能提升而言很小。

您可以参考 Collate Utilities 作为参考实现,并根据您的需求进行调整。

带张量序列化的示例#

当您的 colate 函数必须返回 PyTorch 张量时,以下示例演示了使用张量序列化。此方法需要在 map_batches 之前使用 repartition,因为 colate 函数会更改输出行的数量(每个批次变成一个序列化行)。

from transformers import AutoTokenizer
import torch
from typing import Dict
from ray.data.collate_fn import ArrowBatchCollateFn
import pyarrow as pa
from collate_utils import serialize_tensors_to_table, deserialize_table_to_tensors
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
from mock_dataset import create_mock_ray_text_dataset

BATCH_SIZE = 10000

class TextTokenizerCollateFn:
    """Collate function that runs in Ray Data preprocessing."""
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

    def __call__(self, batch: pa.Table) -> pa.Table:
        # Tokenize the batch
        outputs = self.tokenizer(
            batch["text"].to_pylist(),
            truncation=True,
            padding="longest",
            return_tensors="pt",
        )
        outputs["labels"] = torch.LongTensor(batch["label"].to_numpy())

        # Serialize to single-row table using the utility
        return serialize_tensors_to_table(outputs)

class IteratorCollateFn(ArrowBatchCollateFn):
    """Collate function for iter_torch_batches that deserializes the batch."""
    def __init__(self, pin_memory=False):
        self._pin_memory = pin_memory

    def __call__(self, batch: pa.Table) -> Dict[str, torch.Tensor]:
        # Deserialize from single-row table using the utility
        return deserialize_table_to_tensors(batch, pin_memory=self._pin_memory)

def train_func():
    collate_fn = IteratorCollateFn()

    # Collate function only deserializes on the training worker
    for batch in ray.train.get_dataset_shard("train").iter_torch_batches(
        collate_fn=collate_fn,
        batch_size=1  # Each "row" is actually a full batch
    ):
        # Training logic here
        pass

# Apply preprocessing in Ray Data
# Use repartition BEFORE map_batches because output row count changes
train_dataset = (
    create_mock_ray_text_dataset(
        dataset_size=1000000,
        min_len=1000,
        max_len=3000
    )
    .repartition(target_num_rows_per_block=BATCH_SIZE)
    .map_batches(
        TextTokenizerCollateFn,
        batch_size=BATCH_SIZE,
        batch_format="pyarrow",
    )
)

trainer = TorchTrainer(
    train_func,
    datasets={"train": train_dataset},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True)
)

result = trainer.fit()