Iterating over Data#
Ray Data lets you iterate over rows or batches of data.

This guide shows you how to iterate over rows, iterate over batches (with or without shuffling), and split datasets for distributed parallel training.
Iterating over rows#
To iterate over the rows of your dataset, call Dataset.iter_rows(). Ray Data represents each row as a dictionary.
import ray

ds = ray.data.read_csv("s3://anonymous@air-example-data/iris.csv")

for row in ds.iter_rows():
    print(row)
{'sepal length (cm)': 5.1, 'sepal width (cm)': 3.5, 'petal length (cm)': 1.4, 'petal width (cm)': 0.2, 'target': 0}
{'sepal length (cm)': 4.9, 'sepal width (cm)': 3.0, 'petal length (cm)': 1.4, 'petal width (cm)': 0.2, 'target': 0}
...
{'sepal length (cm)': 5.9, 'sepal width (cm)': 3.0, 'petal length (cm)': 5.1, 'petal width (cm)': 1.8, 'target': 2}
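If you only need a few rows rather than a full pass over the dataset, Dataset.take() returns them as the same row dictionaries. A minimal sketch using the same dataset (the row count of 3 is an arbitrary choice for illustration):

import ray

ds = ray.data.read_csv("s3://anonymous@air-example-data/iris.csv")

# take(n) returns up to n rows as a list of dicts, without iterating
# over the entire dataset.
for row in ds.take(3):
    print(row)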
Iterating over batches#
A batch contains data from multiple rows. Iterate over the batches of a dataset in different formats by calling one of the following methods:

Dataset.iter_batches()
Dataset.iter_torch_batches()
Dataset.to_tf()
NumPy:

import ray

ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")

for batch in ds.iter_batches(batch_size=2, batch_format="numpy"):
    print(batch)
{'image': array([[[[...]]]], dtype=uint8)}
...
{'image': array([[[[...]]]], dtype=uint8)}
pandas:

import ray

ds = ray.data.read_csv("s3://anonymous@air-example-data/iris.csv")

for batch in ds.iter_batches(batch_size=2, batch_format="pandas"):
    print(batch)
sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) target
0 5.1 3.5 1.4 0.2 0
1 4.9 3.0 1.4 0.2 0
...
sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) target
0 6.2 3.4 5.4 2.3 2
1 5.9 3.0 5.1 1.8 2
Torch:

import ray

ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")

for batch in ds.iter_torch_batches(batch_size=2):
    print(batch)
{'image': tensor([[[[...]]]], dtype=torch.uint8)}
...
{'image': tensor([[[[...]]]], dtype=torch.uint8)}
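iter_torch_batches() can also cast and place tensors for you. A minimal sketch, assuming its dtypes and device parameters and an available GPU (use "cpu" for device otherwise):

import ray
import torch

ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")

# Cast the uint8 image tensors to float32 and move each batch to the GPU
# as it's produced.
for batch in ds.iter_torch_batches(
    batch_size=2,
    dtypes=torch.float32,
    device="cuda",
):
    print(batch["image"].dtype, batch["image"].device)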
TensorFlow:

import ray

ds = ray.data.read_csv("s3://anonymous@air-example-data/iris.csv")

tf_dataset = ds.to_tf(
    feature_columns="sepal length (cm)",
    label_columns="target",
    batch_size=2
)

for features, labels in tf_dataset:
    print(features, labels)
tf.Tensor([5.1 4.9], shape=(2,), dtype=float64) tf.Tensor([0 0], shape=(2,), dtype=int64)
...
tf.Tensor([6.2 5.9], shape=(2,), dtype=float64) tf.Tensor([2 2], shape=(2,), dtype=int64)
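feature_columns also accepts a list of column names, in which case each element of the resulting tf.data.Dataset yields the features as a dict of tensors keyed by column name. A short sketch under that assumption:

import ray

ds = ray.data.read_csv("s3://anonymous@air-example-data/iris.csv")

# With a list of feature columns, features arrives as a dict of tensors,
# one entry per column.
tf_dataset = ds.to_tf(
    feature_columns=["sepal length (cm)", "sepal width (cm)"],
    label_columns="target",
    batch_size=2,
)

for features, labels in tf_dataset:
    print(features, labels)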
Iterating over batches with shuffling#
Dataset.random_shuffle is slow because it shuffles all rows. If a full global shuffle isn't required, you can shuffle a subset of rows up to a provided buffer size during iteration by specifying local_shuffle_buffer_size. While this isn't a true global shuffle like random_shuffle, it's more performant because it doesn't require excessive data movement. For more details about these options, see Shuffling Data.
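For contrast, a full global shuffle looks like the following sketch. random_shuffle() reorders every row across the cluster before iteration, which is why it's slower than a local shuffle buffer:

import ray

ds = ray.data.read_csv("s3://anonymous@air-example-data/iris.csv")

# Globally shuffle all rows. This is fully random but has to move every
# row, unlike a bounded local shuffle buffer.
shuffled = ds.random_shuffle()

for batch in shuffled.iter_batches(batch_size=2, batch_format="pandas"):
    print(batch)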
Tip

To configure local_shuffle_buffer_size, choose the smallest value that achieves sufficient randomness. Higher values result in more randomness at the cost of slower iteration. To diagnose slowdowns, see Local shuffle when iterating over batches.
NumPy:

import ray

ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")

for batch in ds.iter_batches(
    batch_size=2,
    batch_format="numpy",
    local_shuffle_buffer_size=250,
):
    print(batch)
{'image': array([[[[...]]]], dtype=uint8)}
...
{'image': array([[[[...]]]], dtype=uint8)}
pandas:

import ray

ds = ray.data.read_csv("s3://anonymous@air-example-data/iris.csv")

for batch in ds.iter_batches(
    batch_size=2,
    batch_format="pandas",
    local_shuffle_buffer_size=250,
):
    print(batch)
sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) target
0 6.3 2.9 5.6 1.8 2
1 5.7 4.4 1.5 0.4 0
...
sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) target
0 5.6 2.7 4.2 1.3 1
1 4.8 3.0 1.4 0.1 0
Torch:

import ray

ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")

for batch in ds.iter_torch_batches(
    batch_size=2,
    local_shuffle_buffer_size=250,
):
    print(batch)
{'image': tensor([[[[...]]]], dtype=torch.uint8)}
...
{'image': tensor([[[[...]]]], dtype=torch.uint8)}
TensorFlow:

import ray

ds = ray.data.read_csv("s3://anonymous@air-example-data/iris.csv")

tf_dataset = ds.to_tf(
    feature_columns="sepal length (cm)",
    label_columns="target",
    batch_size=2,
    local_shuffle_buffer_size=250,
)

for features, labels in tf_dataset:
    print(features, labels)
tf.Tensor([5.2 6.3], shape=(2,), dtype=float64) tf.Tensor([1 2], shape=(2,), dtype=int64)
...
tf.Tensor([5. 5.8], shape=(2,), dtype=float64) tf.Tensor([0 0], shape=(2,), dtype=int64)
Splitting datasets for distributed parallel training#
If you're performing distributed data parallel training, call Dataset.streaming_split to split your dataset into disjoint shards.
Note

If you're using Ray Train, you don't need to split the dataset. Ray Train automatically splits the dataset for your workers. To learn more, see the Data loading for ML training guide.
import ray

@ray.remote
class Worker:

    def train(self, data_iterator):
        for batch in data_iterator.iter_batches(batch_size=8):
            pass

ds = ray.data.read_csv("s3://anonymous@air-example-data/iris.csv")
workers = [Worker.remote() for _ in range(4)]
shards = ds.streaming_split(n=4, equal=True)
ray.get([w.train.remote(s) for w, s in zip(workers, shards)])
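If you want each shard to stay close to its consumer, streaming_split also accepts a locality_hints argument. A minimal sketch, assuming locality_hints takes node IDs and extending the Worker actor above with a hypothetical helper that reports the node it runs on:

import ray

@ray.remote
class Worker:

    def get_node_id(self) -> str:
        # Hypothetical helper: report the ID of the node this actor runs on.
        return ray.get_runtime_context().get_node_id()

    def train(self, data_iterator):
        for batch in data_iterator.iter_batches(batch_size=8):
            pass

ds = ray.data.read_csv("s3://anonymous@air-example-data/iris.csv")
workers = [Worker.remote() for _ in range(4)]

# Hint each shard to prefer the node where its consuming worker lives.
locality_hints = ray.get([w.get_node_id.remote() for w in workers])
shards = ds.streaming_split(n=4, equal=True, locality_hints=locality_hints)
ray.get([w.train.remote(s) for w, s in zip(workers, shards)])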