处理 Tensor / NumPy#
N 维数组(换句话说,Tensor)在机器学习工作负载中无处不在。本指南介绍处理此类数据的限制和最佳实践。
Tensor 数据表示#
Ray Data 将 Tensor 表示为 NumPy ndarrays。
import ray
ds = ray.data.read_images("s3://anonymous@air-example-data/digits")
print(ds)
Dataset(
num_rows=100,
schema={image: numpy.ndarray(shape=(28, 28), dtype=uint8)}
)
固定形状 Tensor 的批次#
如果您的 Tensor 具有固定形状,Ray Data 会将其批次表示为常规的 ndarrays。
>>> import ray
>>> ds = ray.data.read_images("s3://anonymous@air-example-data/digits")
>>> batch = ds.take_batch(batch_size=32)
>>> batch["image"].shape
(32, 28, 28)
>>> batch["image"].dtype
dtype('uint8')
可变形状 Tensor 的批次#
如果您的 Tensor 形状可变,Ray Data 会将其批次表示为对象 dtype 的数组。
>>> import ray
>>> ds = ray.data.read_images("s3://anonymous@air-example-data/AnimalDetection")
>>> batch = ds.take_batch(batch_size=32)
>>> batch["image"].shape
(32,)
>>> batch["image"].dtype
dtype('O')
这些对象数组的单个元素是常规的 ndarrays。
>>> batch["image"][0].dtype
dtype('uint8')
>>> batch["image"][0].shape
(375, 500, 3)
>>> batch["image"][3].shape
(333, 465, 3)
转换 Tensor 数据#
调用 map()
或 map_batches()
来转换 Tensor 数据。
from typing import Any, Dict
import ray
import numpy as np
ds = ray.data.read_images("s3://anonymous@air-example-data/AnimalDetection")
def increase_brightness(row: Dict[str, Any]) -> Dict[str, Any]:
row["image"] = np.clip(row["image"] + 4, 0, 255)
return row
# Increase the brightness, record at a time.
ds.map(increase_brightness)
def batch_increase_brightness(batch: Dict[str, np.ndarray]) -> Dict:
batch["image"] = np.clip(batch["image"] + 4, 0, 255)
return batch
# Increase the brightness, batch at a time.
ds.map_batches(batch_increase_brightness)
除了 NumPy ndarrays,Ray Data 还将返回的 NumPy ndarrays 列表以及实现了 __array__
的对象(例如 torch.Tensor
)视为 Tensor 数据。
有关转换数据的更多信息,请阅读转换数据。
保存 Tensor 数据#
使用 Parquet, NumPy 和 JSON 等格式保存 Tensor 数据。要查看所有支持的格式,请参阅输入/输出参考。
调用 write_parquet()
将数据保存到 Parquet 文件。
import ray
ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")
ds.write_parquet("/tmp/simple")
调用 write_numpy()
将 ndarray 列保存到 NumPy 文件。
import ray
ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")
ds.write_numpy("/tmp/simple", column="image")
要将图像保存到 JSON 文件,请调用 write_json()
。
import ray
ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")
ds.write_json("/tmp/simple")
有关保存数据的更多信息,请阅读保存数据。