Working with PyTorch#

Ray Data integrates with the PyTorch ecosystem.

This guide describes how to iterate over Torch tensors for training, transform data with Torch tensors, run batch inference with Torch models, save Datasets containing Torch tensors, and migrate from PyTorch Datasets and DataLoaders.

Iterating over Torch tensors for training#

To iterate over batches of data in Torch format, call Dataset.iter_torch_batches(). Each batch is represented as Dict[str, torch.Tensor], with one tensor per column in the dataset.

This is useful for training Torch models with batches from your dataset. For configuration details, such as providing a collate_fn to customize the conversion, see the API reference for iter_torch_batches().

import ray
import torch

ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")

for batch in ds.iter_torch_batches(batch_size=2):
    print(batch)
{'image': tensor([[[[...]]]], dtype=torch.uint8)}
...
{'image': tensor([[[[...]]]], dtype=torch.uint8)}
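The collate_fn mentioned above controls how each batch dict is converted before it's yielded. As a minimal sketch (the function name and the scaling are illustrative, not part of the API), a collate_fn receives the batch as a Dict[str, np.ndarray] and returns whatever your training loop expects:

```python
from typing import Dict

import numpy as np
import torch


def scale_images(batch: Dict[str, np.ndarray]) -> Dict[str, torch.Tensor]:
    # Convert uint8 images to float32 tensors scaled to [0, 1].
    return {"image": torch.as_tensor(batch["image"], dtype=torch.float32) / 255.0}


# You would pass this as ds.iter_torch_batches(batch_size=2, collate_fn=scale_images).
# Standalone check on a fake batch:
fake_batch = {"image": np.full((2, 32, 32, 3), 255, dtype=np.uint8)}
out = scale_images(fake_batch)
print(out["image"].dtype, out["image"].max().item())  # torch.float32 1.0
```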

Integration with Ray Train#

Ray Data integrates with Ray Train for easy data ingest for data-parallel training, with support for PyTorch, PyTorch Lightning, or Hugging Face training.

import torch
from torch import nn
import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def train_func():
    model = nn.Sequential(nn.Linear(30, 1), nn.Sigmoid())
    loss_fn = torch.nn.BCELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Datasets can be accessed in your train_func via ``get_dataset_shard``.
    train_data_shard = train.get_dataset_shard("train")

    for epoch_idx in range(2):
        for batch in train_data_shard.iter_torch_batches(batch_size=128, dtypes=torch.float32):
            features = torch.stack([batch[col_name] for col_name in batch.keys() if col_name != "target"], axis=1)
            predictions = model(features)
            train_loss = loss_fn(predictions, batch["target"].unsqueeze(1))
            # Reset gradients before backpropagating, so they don't accumulate
            # across batches.
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()


train_dataset = ray.data.read_csv("s3://anonymous@air-example-data/breast_cancer.csv")

trainer = TorchTrainer(
    train_func,
    datasets={"train": train_dataset},
    scaling_config=ScalingConfig(num_workers=2)
)
trainer.fit()

For more details, see the Ray Train user guide.

Transformations with Torch tensors#

Transformations applied with map or map_batches can return Torch tensors.

Note

Under the hood, Ray Data automatically converts Torch tensors to NumPy arrays. Subsequent transformations accept NumPy arrays as input, not Torch tensors.

from typing import Dict
import numpy as np
import torch
import ray

ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")

def convert_to_torch(row: Dict[str, np.ndarray]) -> Dict[str, torch.Tensor]:
    return {"tensor": torch.as_tensor(row["image"])}

# The tensor gets converted into a Numpy array under the hood
transformed_ds = ds.map(convert_to_torch)
print(transformed_ds.schema())

# Subsequent transformations take in Numpy array as input.
def check_numpy(row: Dict[str, np.ndarray]):
    assert isinstance(row["tensor"], np.ndarray)
    return row

transformed_ds.map(check_numpy).take_all()
Column  Type
------  ----
tensor  ArrowTensorTypeV2(shape=(32, 32, 3), dtype=uint8)
from typing import Dict
import numpy as np
import torch
import ray

ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")

def convert_to_torch(batch: Dict[str, np.ndarray]) -> Dict[str, torch.Tensor]:
    return {"tensor": torch.as_tensor(batch["image"])}

# The tensor gets converted into a Numpy array under the hood
transformed_ds = ds.map_batches(convert_to_torch, batch_size=2)
print(transformed_ds.schema())

# Subsequent transformations take in Numpy array as input.
def check_numpy(batch: Dict[str, np.ndarray]):
    assert isinstance(batch["tensor"], np.ndarray)
    return batch

transformed_ds.map_batches(check_numpy, batch_size=2).take_all()
Column  Type
------  ----
tensor  ArrowTensorTypeV2(shape=(32, 32, 3), dtype=uint8)

For more information on transforming data, see Transforming data.

Built-in PyTorch transforms#

You can use built-in Torch transforms from torchvision, torchtext, and torchaudio.

from typing import Dict
import numpy as np
import torch
from torchvision import transforms
import ray

# Create the Dataset.
ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")

# Define the torchvision transform.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.CenterCrop(10)
    ]
)

# Define the map function
def transform_image(row: Dict[str, np.ndarray]) -> Dict[str, torch.Tensor]:
    row["transformed_image"] = transform(row["image"])
    return row

# Apply the transform over the dataset.
transformed_ds = ds.map(transform_image)
print(transformed_ds.schema())
Column             Type
------             ----
image              ArrowTensorTypeV2(shape=(32, 32, 3), dtype=uint8)
transformed_image  ArrowTensorTypeV2(shape=(3, 10, 10), dtype=float)
from typing import Dict, List
import numpy as np
from torchtext import transforms
import ray

# Create the Dataset.
ds = ray.data.read_text("s3://anonymous@ray-example-data/simple.txt")

# Define the torchtext transform.
VOCAB_FILE = "https://hugging-face.cn/bert-base-uncased/resolve/main/vocab.txt"
transform = transforms.BERTTokenizer(vocab_path=VOCAB_FILE, do_lower_case=True, return_tokens=True)

# Define the map_batches function.
def tokenize_text(batch: Dict[str, np.ndarray]) -> Dict[str, List[str]]:
    batch["tokenized_text"] = transform(list(batch["text"]))
    return batch

# Apply the transform over the dataset.
transformed_ds = ds.map_batches(tokenize_text, batch_size=2)
print(transformed_ds.schema())
Column          Type
------          ----
text            string
tokenized_text  list<item: string>

Batch inference with PyTorch#

With Ray Datasets, you can do scalable offline batch inference by mapping a pre-trained model over your data.

from typing import Dict
import numpy as np
import torch
import torch.nn as nn

import ray

# Step 1: Create a Ray Dataset from in-memory Numpy arrays.
# You can also create a Ray Dataset from many other sources and file
# formats.
ds = ray.data.from_numpy(np.ones((1, 100)))

# Step 2: Define a Predictor class for inference.
# Use a class to initialize the model just once in `__init__`
# and reuse it for inference across multiple batches.
class TorchPredictor:
    def __init__(self):
        # Load a dummy neural network.
        # Set `self.model` to your pre-trained PyTorch model.
        self.model = nn.Sequential(
            nn.Linear(in_features=100, out_features=1),
            nn.Sigmoid(),
        )
        self.model.eval()

    # Logic for inference on 1 batch of data.
    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        tensor = torch.as_tensor(batch["data"], dtype=torch.float32)
        with torch.inference_mode():
            # Get the predictions from the input batch.
            return {"output": self.model(tensor).numpy()}

# Step 3: Map the Predictor over the Dataset to get predictions.
# Use 2 parallel actors for inference. Each actor predicts on a
# different partition of data.
predictions = ds.map_batches(TorchPredictor, compute=ray.data.ActorPoolStrategy(size=2))
# Step 4: Show one prediction output.
predictions.show(limit=1)
{'output': array([0.5590901], dtype=float32)}

For more details, see the Batch inference user guide.

Saving Datasets containing Torch tensors#

Datasets containing Torch tensors can be saved to files, such as Parquet or NumPy files.

For more information on saving data, read Saving data.

Note

Torch tensors that are on GPU devices can't be serialized and written to disk. Convert the tensors to CPU (tensor.to("cpu")) before saving the data.

import torch
import ray

tensor = torch.Tensor(1)
ds = ray.data.from_items([{"tensor": tensor}])

ds.write_parquet("local:///tmp/tensor")
import torch
import ray

tensor = torch.Tensor(1)
ds = ray.data.from_items([{"tensor": tensor}])

ds.write_numpy("local:///tmp/tensor", column="tensor")
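If a transform in your pipeline produced tensors on a GPU, the note above applies before writing. A minimal sketch of a helper (the name to_cpu is illustrative, not a Ray API) that you could apply to each row before saving, assuming your rows hold Torch tensors:

```python
from typing import Any, Dict

import torch


def to_cpu(row: Dict[str, Any]) -> Dict[str, Any]:
    # Move any tensor values to CPU so the row can be serialized to disk.
    return {k: v.to("cpu") if isinstance(v, torch.Tensor) else v for k, v in row.items()}


# The tensor here is already on CPU; on a GPU machine it would be moved.
row = {"tensor": torch.ones(3), "label": 1}
print(to_cpu(row)["tensor"].device.type)  # cpu
```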

Migrating from PyTorch Datasets and DataLoaders#

If you're currently using PyTorch Datasets and DataLoaders, you can migrate to Ray Data for working with distributed datasets.

PyTorch Datasets are replaced by the Dataset abstraction, and the PyTorch DataLoader is replaced by Dataset.iter_torch_batches().

Built-in PyTorch Datasets#

If you're using built-in PyTorch Datasets, for example from torchvision, these can be converted to a Ray Dataset using the from_torch() API.

import torchvision
import ray

mnist = torchvision.datasets.MNIST(root="/tmp/", download=True)
ds = ray.data.from_torch(mnist)

# The data for each item of the Torch dataset is under the "item" key.
print(ds.schema())
Column  Type
------  ----
item    <class 'object'>

Custom PyTorch Datasets#

If you have a custom PyTorch Dataset, you can migrate to Ray Data by converting the logic in __getitem__ to Ray Data read and transform operations.

Any logic for reading data from cloud storage or disk can be replaced by one of the Ray Data read_* APIs, and any transformation logic can be applied as a map call on the Dataset.

The following example shows a custom PyTorch Dataset, and what the corresponding version looks like with Ray Data.

Note

Unlike PyTorch map-style datasets, Ray Datasets are not indexable.

import tempfile
import boto3
from botocore import UNSIGNED
from botocore.config import Config

from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image

class ImageDataset(Dataset):
    def __init__(self, bucket_name: str, dir_path: str):
        self.s3 = boto3.resource("s3", config=Config(signature_version=UNSIGNED))
        self.bucket = self.s3.Bucket(bucket_name)
        self.files = [obj.key for obj in self.bucket.objects.filter(Prefix=dir_path)]

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((128, 128)),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img_name = self.files[idx]

        # Infer the label from the file name.
        last_slash_idx = img_name.rfind("/")
        dot_idx = img_name.rfind(".")
        label = int(img_name[last_slash_idx+1:dot_idx])

        # Download the S3 file locally.
        obj = self.bucket.Object(img_name)
        tmp = tempfile.NamedTemporaryFile()
        tmp_name = "{}.jpg".format(tmp.name)

        with open(tmp_name, "wb") as f:
            obj.download_fileobj(f)
            f.flush()
            f.close()
            image = Image.open(tmp_name)

        # Preprocess the image.
        image = self.transform(image)

        return image, label

dataset = ImageDataset(bucket_name="ray-example-data", dir_path="batoidea/JPEGImages/")
from torchvision import transforms
import ray

ds = ray.data.read_images("s3://anonymous@ray-example-data/batoidea/JPEGImages", include_paths=True)

# Extract the label from the file path.
def extract_label(row: dict):
    filepath = row["path"]
    last_slash_idx = filepath.rfind("/")
    dot_idx = filepath.rfind('.')
    label = int(filepath[last_slash_idx+1:dot_idx])
    row["label"] = label
    return row

transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Resize((128, 128)),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])

# Preprocess the images.
def transform_image(row: dict):
    row["transformed_image"] = transform(row["image"])
    return row

# Map the transformations over the dataset.
ds = ds.map(extract_label).map(transform_image)

PyTorch DataLoader#

The PyTorch DataLoader can be replaced by calling Dataset.iter_torch_batches() to iterate over batches of the dataset.

The following table describes how the arguments of PyTorch DataLoader map to Ray Data. Note that the behavior isn't necessarily identical. For exact semantics and usage, see the API reference for iter_torch_batches().

PyTorch DataLoader argument → Ray Data API

batch_size → The batch_size argument of ds.iter_torch_batches().

shuffle → The local_shuffle_buffer_size argument of ds.iter_torch_batches().

collate_fn → The collate_fn argument of ds.iter_torch_batches().

sampler → Not supported. Can be implemented manually after iterating through the dataset with ds.iter_torch_batches().

batch_sampler → Not supported. Can be implemented manually after iterating through the dataset with ds.iter_torch_batches().

drop_last → The drop_last argument of ds.iter_torch_batches().

num_workers → Use the prefetch_batches argument of ds.iter_torch_batches() to indicate how many batches to prefetch. The number of prefetching threads is automatically configured according to prefetch_batches.

prefetch_factor → Use the prefetch_batches argument of ds.iter_torch_batches() to indicate how many batches to prefetch. The number of prefetching threads is automatically configured according to prefetch_batches.

pin_memory → Pass the device to ds.iter_torch_batches() to get tensors that have already been moved to the correct device.