加载数据#

Ray Data 可以从各种源加载数据。本指南将向你展示如何:

读取文件#

Ray Data 可以从本地磁盘或云存储读取各种文件格式的文件。要查看支持的完整文件格式列表,请参阅输入/输出参考

要读取 Parquet 文件,请调用 read_parquet()

import ray

ds = ray.data.read_parquet("s3://anonymous@ray-example-data/iris.parquet")

print(ds.schema())
Column        Type
------        ----
sepal.length  double
sepal.width   double
petal.length  double
petal.width   double
variety       string

提示

读取 Parquet 文件时,你可以利用列裁剪(column pruning)在文件扫描级别高效过滤列。有关投影下推(projection pushdown)功能的更多详细信息,请参阅Parquet 列裁剪

要读取原始图像,请调用 read_images()。Ray Data 将图像表示为 NumPy ndarray。

import ray

ds = ray.data.read_images("s3://anonymous@ray-example-data/batoidea/JPEGImages/")

print(ds.schema())
Column  Type
------  ----
image   numpy.ndarray(shape=(32, 32, 3), dtype=uint8)

要按行读取文本,请调用 read_text()

import ray

ds = ray.data.read_text("s3://anonymous@ray-example-data/this.txt")

print(ds.schema())
Column  Type
------  ----
text    string

要读取 CSV 文件,请调用 read_csv()

import ray

ds = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")

print(ds.schema())
Column             Type
------             ----
sepal length (cm)  double
sepal width (cm)   double
petal length (cm)  double
petal width (cm)   double
target             int64

要读取原始二进制文件,请调用 read_binary_files()

import ray

ds = ray.data.read_binary_files("s3://anonymous@ray-example-data/documents")

print(ds.schema())
Column  Type
------  ----
bytes   binary

要读取 TFRecords 文件,请调用 read_tfrecords()

import ray

ds = ray.data.read_tfrecords("s3://anonymous@ray-example-data/iris.tfrecords")

print(ds.schema())
Column        Type
------        ----
label         binary
petal.length  float
sepal.width   float
petal.width   float
sepal.length  float

从本地磁盘读取文件#

要从本地磁盘读取文件,请调用诸如 read_parquet() 之类的函数,并使用 local:// 方案指定路径。路径可以指向文件或目录。

要读取 Parquet 以外的格式,请参阅输入/输出参考

提示

如果你的文件在每个节点上都可访问,则省略 local://,以便在整个集群中并行执行读取任务。

import ray

ds = ray.data.read_parquet("local:///tmp/iris.parquet")

print(ds.schema())
Column        Type
------        ----
sepal.length  double
sepal.width   double
petal.length  double
petal.width   double
variety       string

从云存储读取文件#

要读取云存储中的文件,请使用你的云服务提供商对所有节点进行身份验证。然后,调用诸如 read_parquet() 之类的方法,并使用适当的方案指定 URI。URI 可以指向存储桶、文件夹或对象。

要读取 Parquet 以外的格式,请参阅输入/输出参考

要从 Amazon S3 读取文件,请使用 s3:// 方案指定 URI。

import ray

ds = ray.data.read_parquet("s3://anonymous@ray-example-data/iris.parquet")

print(ds.schema())
Column        Type
------        ----
sepal.length  double
sepal.width   double
petal.length  double
petal.width   double
variety       string

Ray Data 依赖 PyArrow 进行 Amazon S3 身份验证。有关如何配置你的凭据以与 PyArrow 兼容的更多信息,请参阅其 S3 文件系统文档

要从 Google Cloud Storage 读取文件,请安装 Google Cloud Storage 文件系统接口

pip install gcsfs

然后,创建一个 GCSFileSystem,并使用 gs:// 方案指定 URI。

import ray

filesystem = gcsfs.GCSFileSystem(project="my-google-project")
ds = ray.data.read_parquet(
    "gs://...",
    filesystem=filesystem
)

print(ds.schema())
Column        Type
------        ----
sepal.length  double
sepal.width   double
petal.length  double
petal.width   double
variety       string

Ray Data 依赖 PyArrow 进行 Google Cloud Storage 身份验证。有关如何配置你的凭据以与 PyArrow 兼容的更多信息,请参阅其 GCS 文件系统文档

要从 Azure Blob Storage 读取文件,请安装 Azure-Datalake Gen1 和 Gen2 Storage 文件系统接口

pip install adlfs

然后,创建一个 AzureBlobFileSystem,并使用 az:// 方案指定 URI。

import adlfs
import ray

ds = ray.data.read_parquet(
    "az://ray-example-data/iris.parquet",
    adlfs.AzureBlobFileSystem(account_name="azureopendatastorage")
)

print(ds.schema())
Column        Type
------        ----
sepal.length  double
sepal.width   double
petal.length  double
petal.width   double
variety       string

Ray Data 依赖 PyArrow 进行 Azure Blob Storage 身份验证。有关如何配置你的凭据以与 PyArrow 兼容的更多信息,请参阅其 fsspec 兼容文件系统文档

从 NFS 读取文件#

要从 NFS 文件系统读取文件,请调用诸如 read_parquet() 之类的函数,并指定挂载文件系统上的文件。路径可以指向文件或目录。

要读取 Parquet 以外的格式,请参阅输入/输出参考

import ray

ds = ray.data.read_parquet("/mnt/cluster_storage/iris.parquet")

print(ds.schema())
Column        Type
------        ----
sepal.length  double
sepal.width   double
petal.length  double
petal.width   double
variety       string

处理压缩文件#

要读取压缩文件,请在 arrow_open_stream_args 中指定 compression。你可以使用 Arrow 支持的任何编解码器

import ray

ds = ray.data.read_csv(
    "s3://anonymous@ray-example-data/iris.csv.gz",
    arrow_open_stream_args={"compression": "gzip"},
)

从其他库加载数据#

从单节点数据库加载数据#

Ray Data 可以与 pandas、NumPy 和 Arrow 等库互操作。

要从 Python 对象创建 Dataset,请调用 from_items() 并传入一个 Dict 列表。Ray Data 将每个 Dict 视为一行。

import ray

ds = ray.data.from_items([
    {"food": "spam", "price": 9.34},
    {"food": "ham", "price": 5.37},
    {"food": "eggs", "price": 0.94}
])

print(ds)
MaterializedDataset(
   num_blocks=3,
   num_rows=3,
   schema={food: string, price: double}
)

你也可以从常规 Python 对象的列表创建 Dataset

import ray

ds = ray.data.from_items([1, 2, 3, 4, 5])

print(ds)
MaterializedDataset(num_blocks=5, num_rows=5, schema={item: int64})

要从 NumPy 数组创建 Dataset,请调用 from_numpy()。Ray Data 将最外层轴视为行维度。

import numpy as np
import ray

array = np.ones((3, 2, 2))
ds = ray.data.from_numpy(array)

print(ds)
MaterializedDataset(
   num_blocks=1,
   num_rows=3,
   schema={data: numpy.ndarray(shape=(2, 2), dtype=double)}
)

要从 pandas DataFrame 创建 Dataset,请调用 from_pandas()

import pandas as pd
import ray

df = pd.DataFrame({
    "food": ["spam", "ham", "eggs"],
    "price": [9.34, 5.37, 0.94]
})
ds = ray.data.from_pandas(df)

print(ds)
MaterializedDataset(
   num_blocks=1,
   num_rows=3,
   schema={food: object, price: float64}
)

要从 Arrow table 创建 Dataset,请调用 from_arrow()

import pyarrow as pa

table = pa.table({
    "food": ["spam", "ham", "eggs"],
    "price": [9.34, 5.37, 0.94]
})
ds = ray.data.from_arrow(table)

print(ds)
MaterializedDataset(
   num_blocks=1,
   num_rows=3,
   schema={food: string, price: double}
)

从分布式 DataFrame 库加载数据#

Ray Data 可以与 DaftDaskSparkModinMars 等分布式数据处理框架互操作。

注意

Ray 社区提供了这些操作,但可能不会积极维护它们。如果你遇到问题,请在此处创建 GitHub Issue。

要从 Daft DataFrame 创建 Dataset,请调用 from_daft()。此函数执行 Daft DataFrame,并基于 Daft 查询生成的 Arrow 数据构建一个 Dataset

import daft
import ray

ray.init()

df = daft.from_pydict({"int_col": [i for i in range(10000)], "str_col": [str(i) for i in range(10000)]})
ds = ray.data.from_daft(df)

ds.show(3)
{'int_col': 0, 'str_col': '0'}
{'int_col': 1, 'str_col': '1'}
{'int_col': 2, 'str_col': '2'}

要从 Dask DataFrame 创建 Dataset,请调用 from_dask()。此函数构建一个 Dataset,其底层是 Dask DataFrame 的分布式 Pandas DataFrame 分区。

import dask.dataframe as dd
import pandas as pd
import ray

df = pd.DataFrame({"col1": list(range(10000)), "col2": list(map(str, range(10000)))})
ddf = dd.from_pandas(df, npartitions=4)
# Create a Dataset from a Dask DataFrame.
ds = ray.data.from_dask(ddf)

ds.show(3)
{'col1': 0, 'col2': '0'}
{'col1': 1, 'col2': '1'}
{'col1': 2, 'col2': '2'}

要从 Spark DataFrame 创建 Dataset,请调用 from_spark()。此函数创建了一个 Dataset,其底层是 Spark DataFrame 的分布式 Spark DataFrame 分区。

import ray
import raydp

spark = raydp.init_spark(app_name="Spark -> Datasets Example",
                        num_executors=2,
                        executor_cores=2,
                        executor_memory="500MB")
df = spark.createDataFrame([(i, str(i)) for i in range(10000)], ["col1", "col2"])
ds = ray.data.from_spark(df)

ds.show(3)
{'col1': 0, 'col2': '0'}
{'col1': 1, 'col2': '1'}
{'col1': 2, 'col2': '2'}

要从 Iceberg Table 创建 Dataset,请调用 read_iceberg()。此函数创建了一个 Dataset,其底层是 Iceberg 表的分布式文件。

>>> import ray
>>> from pyiceberg.expressions import EqualTo
>>> ds = ray.data.read_iceberg(
...     table_identifier="db_name.table_name",
...     row_filter=EqualTo("column_name", "literal_value"),
...     catalog_kwargs={"name": "default", "type": "glue"}
... )
{'col1': 0, 'col2': '0'}
{'col1': 1, 'col2': '1'}
{'col1': 2, 'col2': '2'}

要从 Modin DataFrame 创建 Dataset,请调用 from_modin()。此函数构建一个 Dataset,其底层是 Modin DataFrame 的分布式 Pandas DataFrame 分区。

import modin.pandas as md
import pandas as pd
import ray

df = pd.DataFrame({"col1": list(range(10000)), "col2": list(map(str, range(10000)))})
mdf = md.DataFrame(df)
# Create a Dataset from a Modin DataFrame.
ds = ray.data.from_modin(mdf)

ds.show(3)
{'col1': 0, 'col2': '0'}
{'col1': 1, 'col2': '1'}
{'col1': 2, 'col2': '2'}

要从 Mars DataFrame 创建 Dataset,请调用 from_mars()。此函数构建一个 Dataset,其底层是 Mars DataFrame 的分布式 Pandas DataFrame 分区。

import mars
import mars.dataframe as md
import pandas as pd
import ray

cluster = mars.new_cluster_in_ray(worker_num=2, worker_cpu=1)

df = pd.DataFrame({"col1": list(range(10000)), "col2": list(map(str, range(10000)))})
mdf = md.DataFrame(df, num_partitions=8)
# Create a tabular Dataset from a Mars DataFrame.
ds = ray.data.from_mars(mdf)

ds.show(3)
{'col1': 0, 'col2': '0'}
{'col1': 1, 'col2': '1'}
{'col1': 2, 'col2': '2'}

从机器学习库加载数据#

Ray Data 可以与 HuggingFace、PyTorch 和 TensorFlow 数据集互操作。

要将 HuggingFace 数据集转换为 Ray 数据集,请调用 from_huggingface()。此函数访问底层 Arrow 表并将其直接转换为 Dataset。

警告

from_huggingface 仅在某些情况下支持并行读取,即对于未经转换的公共 HuggingFace 数据集。对于这些数据集,Ray Data 使用托管的 parquet 文件执行分布式读取;否则,Ray Data 使用单节点读取。对于内存中的 HuggingFace 数据集,这种行为应该不是问题,但对于大型内存映射 HuggingFace 数据集可能会导致失败。此外,不支持 HuggingFace DatasetDictIterableDatasetDict 对象。

import ray.data
from datasets import load_dataset

hf_ds = load_dataset("wikitext", "wikitext-2-raw-v1")
ray_ds = ray.data.from_huggingface(hf_ds["train"])
ray_ds.take(2)
[{'text': ''}, {'text': ' = Valkyria Chronicles III = \n'}]

要将 PyTorch 数据集转换为 Ray Dataset,请调用 from_torch()

import ray
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor

tds = datasets.CIFAR10(root="data", train=True, download=True, transform=ToTensor())
ds = ray.data.from_torch(tds)

print(ds)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz
100%|███████████████████████| 170498071/170498071 [00:07<00:00, 23494838.54it/s]
Extracting data/cifar-10-python.tar.gz to data
Dataset(num_rows=50000, schema={item: object})

要将 TensorFlow 数据集转换为 Ray Dataset,请调用 from_tf()

警告

from_tf 不支持并行读取。仅将此函数用于 MNIST 或 CIFAR 等小型数据集。

import ray
import tensorflow_datasets as tfds

tf_ds, _ = tfds.load("cifar10", split=["train", "test"])
ds = ray.data.from_tf(tf_ds)

print(ds)
MaterializedDataset(
   num_blocks=...,
   num_rows=50000,
   schema={
      id: binary,
      image: numpy.ndarray(shape=(32, 32, 3), dtype=uint8),
      label: int64
   }
)

读取数据库#

Ray Data 可以读取 MySQL、PostgreSQL、MongoDB 和 BigQuery 等数据库中的数据。

读取 SQL 数据库#

调用 read_sql() 可从提供 Python DB API2 兼容连接器的数据库读取数据。

要从 MySQL 读取,请安装 MySQL Connector/Python。它是 MySQL 的第一方数据库连接器。

pip install mysql-connector-python

然后,定义你的连接逻辑并查询数据库。

import mysql.connector

import ray

def create_connection():
    return mysql.connector.connect(
        user="admin",
        password=...,
        host="example-mysql-database.c2c2k1yfll7o.us-west-2.rds.amazonaws.com",
        connection_timeout=30,
        database="example",
    )

# Get all movies
dataset = ray.data.read_sql("SELECT * FROM movie", create_connection)
# Get movies after the year 1980
dataset = ray.data.read_sql(
    "SELECT title, score FROM movie WHERE year >= 1980", create_connection
)
# Get the number of movies per year
dataset = ray.data.read_sql(
    "SELECT year, COUNT(*) FROM movie GROUP BY year", create_connection
)

要从 PostgreSQL 读取,请安装 Psycopg 2。它是最流行的 PostgreSQL 数据库连接器。

pip install psycopg2-binary

然后,定义你的连接逻辑并查询数据库。

import psycopg2

import ray

def create_connection():
    return psycopg2.connect(
        user="postgres",
        password=...,
        host="example-postgres-database.c2c2k1yfll7o.us-west-2.rds.amazonaws.com",
        dbname="example",
    )

# Get all movies
dataset = ray.data.read_sql("SELECT * FROM movie", create_connection)
# Get movies after the year 1980
dataset = ray.data.read_sql(
    "SELECT title, score FROM movie WHERE year >= 1980", create_connection
)
# Get the number of movies per year
dataset = ray.data.read_sql(
    "SELECT year, COUNT(*) FROM movie GROUP BY year", create_connection
)

要从 Snowflake 读取,请安装 Snowflake Connector for Python

pip install snowflake-connector-python

然后,定义你的连接逻辑并查询数据库。

import snowflake.connector

import ray

def create_connection():
    return snowflake.connector.connect(
        user=...,
        password=...
        account="ZZKXUVH-IPB52023",
        database="example",
    )

# Get all movies
dataset = ray.data.read_sql("SELECT * FROM movie", create_connection)
# Get movies after the year 1980
dataset = ray.data.read_sql(
    "SELECT title, score FROM movie WHERE year >= 1980", create_connection
)
# Get the number of movies per year
dataset = ray.data.read_sql(
    "SELECT year, COUNT(*) FROM movie GROUP BY year", create_connection
)

要从 Databricks 读取,请将 DATABRICKS_TOKEN 环境变量设置为你的 Databricks 数据仓库访问令牌。

export DATABRICKS_TOKEN=...

如果你的程序未在 Databricks 运行时上运行,还需要设置 DATABRICKS_HOST 环境变量。

export DATABRICKS_HOST=adb-<workspace-id>.<random-number>.azuredatabricks.net

然后,调用 ray.data.read_databricks_tables() 从 Databricks SQL 数据仓库读取。

import ray

dataset = ray.data.read_databricks_tables(
    warehouse_id='...',  # Databricks SQL warehouse ID
    catalog='catalog_1',  # Unity catalog name
    schema='db_1',  # Schema name
    query="SELECT title, score FROM movie WHERE year >= 1980",
)

要从 BigQuery 读取,请安装 Google BigQuery Python 客户端Google BigQueryStorage Python 客户端

pip install google-cloud-bigquery
pip install google-cloud-bigquery-storage

要从 BigQuery 读取数据,请调用 read_bigquery() 并指定项目 ID、数据集和查询(如果适用)。

import ray

# Read the entire dataset. Do not specify query.
ds = ray.data.read_bigquery(
    project_id="my_gcloud_project_id",
    dataset="bigquery-public-data.ml_datasets.iris",
)

# Read from a SQL query of the dataset. Do not specify dataset.
ds = ray.data.read_bigquery(
    project_id="my_gcloud_project_id",
    query = "SELECT * FROM `bigquery-public-data.ml_datasets.iris` LIMIT 50",
)

# Write back to BigQuery
ds.write_bigquery(
    project_id="my_gcloud_project_id",
    dataset="destination_dataset.destination_table",
    overwrite_table=True,
)

读取 MongoDB#

要从 MongoDB 读取数据,请调用 read_mongo() 并指定源 URI、数据库和集合。你还需要指定要对集合运行的管道。

import ray

# Read a local MongoDB.
ds = ray.data.read_mongo(
    uri="mongodb://localhost:27017",
    database="my_db",
    collection="my_collection",
    pipeline=[{"$match": {"col": {"$gte": 0, "$lt": 10}}}, {"$sort": "sort_col"}],
)

# Reading a remote MongoDB is the same.
ds = ray.data.read_mongo(
    uri="mongodb://username:[email protected]:27017/?authSource=admin",
    database="my_db",
    collection="my_collection",
    pipeline=[{"$match": {"col": {"$gte": 0, "$lt": 10}}}, {"$sort": "sort_col"}],
)

# Write back to MongoDB.
ds.write_mongo(
    MongoDatasource(),
    uri="mongodb://username:[email protected]:27017/?authSource=admin",
    database="my_db",
    collection="my_collection",
)

创建合成数据#

合成数据集对于测试和基准测试很有用。

要从整数范围创建合成 Dataset,请调用 range()。Ray Data 将整数范围存储在单个列中。

import ray

ds = ray.data.range(10000)

print(ds.schema())
Column  Type
------  ----
id      int64

要创建包含数组的合成 Dataset,请调用 range_tensor()。Ray Data 将整数范围打包到给定形状的 ndarrays 中。

import ray

ds = ray.data.range_tensor(10, shape=(64, 64))

print(ds.schema())
Column  Type
------  ----
data    numpy.ndarray(shape=(64, 64), dtype=int64)

加载其他数据源#

如果 Ray Data 无法加载你的数据,请继承 Datasource 类。然后,构建自定义数据源实例并将其传递给 read_datasource()。要写入结果,你可能还需要继承 ray.data.Datasink。然后,创建自定义数据接收器实例并将其传递给 write_datasink()。有关更多详细信息,请参阅高级:读写自定义文件类型

# Read from a custom datasource.
ds = ray.data.read_datasource(YourCustomDatasource(), **read_args)

# Write to a custom datasink.
ds.write_datasink(YourCustomDatasink())

性能注意事项#

默认情况下,所有读取任务的输出块数量是根据输入数据大小和可用资源动态确定的。这在大多数情况下应该运行良好。但是,你也可以通过设置 override_num_blocks 参数来覆盖默认值。Ray Data 内部会决定并行运行多少个读取任务以最佳利用集群,范围从 1...override_num_blocks 个任务。换句话说,override_num_blocks 的值越高,Dataset 中的数据块就越小,从而提供了更多的并行执行机会。

有关如何调整输出块数量以及其他优化读取性能的建议,请参阅优化读取