加载数据#

Ray Data 可以加载各种来源的数据。本指南将向您展示如何

读取文件#

Ray Data 可以从本地磁盘或云存储读取各种文件格式的文件。要查看支持的文件格式的完整列表,请参阅 输入/输出参考

要读取 Parquet 文件,请调用 read_parquet()

import ray

ds = ray.data.read_parquet("s3://anonymous@ray-example-data/iris.parquet")

print(ds.schema())
Column        Type
------        ----
sepal.length  double
sepal.width   double
petal.length  double
petal.width   double
variety       string

提示

在读取 Parquet 文件时,您可以利用列裁剪功能,在文件扫描级别高效地过滤列。有关投影下推功能的更多详细信息,请参阅 Parquet 列裁剪

要读取原始图片,请调用 read_images()。Ray Data 将图片表示为 NumPy ndarrays。

import ray

ds = ray.data.read_images("s3://anonymous@ray-example-data/batoidea/JPEGImages/")

print(ds.schema())
Column  Type
------  ----
image   ArrowTensorTypeV2(shape=(32, 32, 3), dtype=uint8)

要读取文本行,请调用 read_text()

import ray

ds = ray.data.read_text("s3://anonymous@ray-example-data/this.txt")

print(ds.schema())
Column  Type
------  ----
text    string

要读取 CSV 文件,请调用 read_csv()

import ray

ds = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")

print(ds.schema())
Column             Type
------             ----
sepal length (cm)  double
sepal width (cm)   double
petal length (cm)  double
petal width (cm)   double
target             int64

要读取原始二进制文件,请调用 read_binary_files()

import ray

ds = ray.data.read_binary_files("s3://anonymous@ray-example-data/documents")

print(ds.schema())
Column  Type
------  ----
bytes   binary

要读取 TFRecords 文件,请调用 read_tfrecords()

import ray

ds = ray.data.read_tfrecords("s3://anonymous@ray-example-data/iris.tfrecords")

print(ds.schema())
Column        Type
------        ----
label         binary
petal.length  float
sepal.width   float
petal.width   float
sepal.length  float

从本地磁盘读取文件#

要从本地磁盘读取文件,请调用类似 read_parquet() 的函数,并使用 local:// schema 指定路径。路径可以指向文件或目录。

要读取 Parquet 以外的其他格式,请参阅 输入/输出参考

提示

如果您的文件在所有节点上都可访问,请排除 local://,以便在集群中并行化读取任务。

import ray

ds = ray.data.read_parquet("local:///tmp/iris.parquet")

print(ds.schema())
Column        Type
------        ----
sepal.length  double
sepal.width   double
petal.length  double
petal.width   double
variety       string

从云存储读取文件#

要从云存储读取文件,请对所有节点进行身份验证,以匹配您的云服务提供商。然后,调用类似 read_parquet() 的方法,并使用适当的 schema 指定 URI。URI 可以指向存储桶、文件夹或对象。

要读取 Parquet 以外的其他格式,请参阅 输入/输出参考

要从 Amazon S3 读取文件,请使用 s3:// scheme 指定 URI。

import ray

ds = ray.data.read_parquet("s3://anonymous@ray-example-data/iris.parquet")

print(ds.schema())
Column        Type
------        ----
sepal.length  double
sepal.width   double
petal.length  double
petal.width   double
variety       string

Ray Data 依赖 PyArrow 进行 Amazon S3 身份验证。有关如何配置与 PyArrow 兼容的凭据的更多信息,请参阅其 S3 文件系统文档

要从 Google Cloud Storage 读取文件,请安装 Google Cloud Storage 的文件系统接口

pip install gcsfs

然后,创建一个 GCSFileSystem 并使用 gs:// scheme 指定 URI。

import ray

filesystem = gcsfs.GCSFileSystem(project="my-google-project")
ds = ray.data.read_parquet(
    "gs://...",
    filesystem=filesystem
)

print(ds.schema())
Column        Type
------        ----
sepal.length  double
sepal.width   double
petal.length  double
petal.width   double
variety       string

Ray Data 依赖 PyArrow 进行 Google Cloud Storage 身份验证。有关如何配置与 PyArrow 兼容的凭据的更多信息,请参阅其 GCS 文件系统文档

要从 Azure Blob Storage 读取文件,请安装 Azure-Datalake Gen1 和 Gen2 Storage 的文件系统接口

pip install adlfs

然后,创建一个 AzureBlobFileSystem 并使用 az:// scheme 指定 URI。

import adlfs
import ray

ds = ray.data.read_parquet(
    "az://ray-example-data/iris.parquet",
    adlfs.AzureBlobFileSystem(account_name="azureopendatastorage")
)

print(ds.schema())
Column        Type
------        ----
sepal.length  double
sepal.width   double
petal.length  double
petal.width   double
variety       string

Ray Data 依赖 PyArrow 进行 Azure Blob Storage 身份验证。有关如何配置与 PyArrow 兼容的凭据的更多信息,请参阅其 fsspec 兼容文件系统文档

从 NFS 读取文件#

要从 NFS 文件系统读取文件,请调用类似 read_parquet() 的函数,并指定挂载文件系统上的文件。路径可以指向文件或目录。

要读取 Parquet 以外的其他格式,请参阅 输入/输出参考

import ray

ds = ray.data.read_parquet("/mnt/cluster_storage/iris.parquet")

print(ds.schema())
Column        Type
------        ----
sepal.length  double
sepal.width   double
petal.length  double
petal.width   double
variety       string

处理压缩文件#

要读取压缩文件,请在 arrow_open_stream_args 中指定 compression。您可以使用 Arrow 支持的任何 编解码器

import ray

ds = ray.data.read_csv(
    "s3://anonymous@ray-example-data/iris.csv.gz",
    arrow_open_stream_args={"compression": "gzip"},
)

从 URI 下载文件#

有时您可能有一个包含 URI 列的元数据表,并且您想下载 URI 引用的文件。

您可以通过利用 with_column() 方法以及 download() 表达式来批量下载数据。这种方法允许系统处理数据集中 URI 引用的文件的并行下载,而无需在自己的转换中管理异步代码。

以下示例展示了如何从 Parquet 文件中列出的 URL 下载一批图片

import ray
from ray.data.expressions import download

# Read a Parquet file containing a column of image URLs
ds = ray.data.read_parquet("s3://anonymous@ray-example-data/imagenet/metadata_file.parquet")

# Use `with_column` and `download` to download the images in parallel.
# This creates a new column 'bytes' with the downloaded file contents.
ds = ds.with_column(
    "bytes",
    download("image_url"),
)

ds.take(1)

从其他库加载数据#

从单节点数据库加载数据#

Ray Data 与 pandas、NumPy 和 Arrow 等库互操作。

要从 Python 对象创建 Dataset,请调用 from_items() 并传入一个 Dict 列表。Ray Data 将每个 Dict 视为一行。

import ray

ds = ray.data.from_items([
    {"food": "spam", "price": 9.34},
    {"food": "ham", "price": 5.37},
    {"food": "eggs", "price": 0.94}
])

print(ds)
MaterializedDataset(
   num_blocks=3,
   num_rows=3,
   schema={food: string, price: double}
)

您还可以从普通 Python 对象列表中创建 Dataset。在 schema 中,列名默认为“item”。

import ray

ds = ray.data.from_items([1, 2, 3, 4, 5])

print(ds)
MaterializedDataset(num_blocks=5, num_rows=5, schema={item: int64})

要从 NumPy 数组创建 Dataset,请调用 from_numpy()。Ray Data 将外层轴视为行维度。

import numpy as np
import ray

array = np.ones((3, 2, 2))
ds = ray.data.from_numpy(array)

print(ds)
MaterializedDataset(
   num_blocks=1,
   num_rows=3,
   schema={data: ArrowTensorTypeV2(shape=(2, 2), dtype=double)}
)

要从 pandas DataFrame 创建 Dataset,请调用 from_pandas()

import pandas as pd
import ray

df = pd.DataFrame({
    "food": ["spam", "ham", "eggs"],
    "price": [9.34, 5.37, 0.94]
})
ds = ray.data.from_pandas(df)

print(ds)
MaterializedDataset(
   num_blocks=1,
   num_rows=3,
   schema={food: object, price: float64}
)

要从 Arrow 表创建 Dataset,请调用 from_arrow()

import pyarrow as pa

table = pa.table({
    "food": ["spam", "ham", "eggs"],
    "price": [9.34, 5.37, 0.94]
})
ds = ray.data.from_arrow(table)

print(ds)
MaterializedDataset(
   num_blocks=1,
   num_rows=3,
   schema={food: string, price: double}
)

从分布式 DataFrame 库加载数据#

Ray Data 与分布式数据处理框架(如 DaftDaskSparkModinMars)互操作。

注意

Ray Community 提供这些操作,但可能不会积极维护它们。如果您遇到问题,请在此处 创建 GitHub issue

要从 Daft DataFrame 创建 Dataset,请调用 from_daft()。此函数执行 Daft DataFrame 并构建一个由 Daft 查询生成的 Arrow 数据支持的 Dataset

警告

from_daft() 不支持 PyArrow 14 及更高版本。有关更多信息,请参阅 此 issue

import daft
import ray

df = daft.from_pydict({"int_col": [i for i in range(10000)], "str_col": [str(i) for i in range(10000)]})
ds = ray.data.from_daft(df)

ds.show(3)
{'int_col': 0, 'str_col': '0'}
{'int_col': 1, 'str_col': '1'}
{'int_col': 2, 'str_col': '2'}

要从 Dask DataFrame 创建 Dataset,请调用 from_dask()。此函数构建一个由 Dask DataFrame 底层的分布式 Pandas DataFrame 分区支持的 Dataset

import dask.dataframe as dd
import pandas as pd
import ray

df = pd.DataFrame({"col1": list(range(10000)), "col2": list(map(str, range(10000)))})
ddf = dd.from_pandas(df, npartitions=4)
# Create a Dataset from a Dask DataFrame.
ds = ray.data.from_dask(ddf)

ds.show(3)
{'col1': 0, 'col2': '0'}
{'col1': 1, 'col2': '1'}
{'col1': 2, 'col2': '2'}

要从 Spark DataFrame 创建 Dataset,请调用 from_spark()。此函数创建一个由 Spark DataFrame 底层的分布式 Spark DataFrame 分区支持的 Dataset

import ray
import raydp

spark = raydp.init_spark(app_name="Spark -> Datasets Example",
                        num_executors=2,
                        executor_cores=2,
                        executor_memory="500MB")
df = spark.createDataFrame([(i, str(i)) for i in range(10000)], ["col1", "col2"])
ds = ray.data.from_spark(df)

ds.show(3)
{'col1': 0, 'col2': '0'}
{'col1': 1, 'col2': '1'}
{'col1': 2, 'col2': '2'}

要从 Iceberg Table 创建 Dataset,请调用 read_iceberg()。此函数创建一个由 Iceberg 表底层分布式文件支持的 Dataset

import ray
from pyiceberg.expressions import EqualTo

ds = ray.data.read_iceberg(
    table_identifier="db_name.table_name",
    row_filter=EqualTo("column_name", "literal_value"),
    catalog_kwargs={"name": "default", "type": "glue"}
)
ds.show(3)
{'col1': 0, 'col2': '0'}
{'col1': 1, 'col2': '1'}
{'col1': 2, 'col2': '2'}

要从 Modin DataFrame 创建 Dataset,请调用 from_modin()。此函数构建一个由 Modin DataFrame 底层的分布式 Pandas DataFrame 分区支持的 Dataset

import modin.pandas as md
import pandas as pd
import ray

df = pd.DataFrame({"col1": list(range(10000)), "col2": list(map(str, range(10000)))})
mdf = md.DataFrame(df)
# Create a Dataset from a Modin DataFrame.
ds = ray.data.from_modin(mdf)

ds.show(3)
{'col1': 0, 'col2': '0'}
{'col1': 1, 'col2': '1'}
{'col1': 2, 'col2': '2'}

要从 Mars DataFrame 创建 Dataset,请调用 from_mars()。此函数构建一个由 Mars DataFrame 底层的分布式 Pandas DataFrame 分区支持的 Dataset

import mars
import mars.dataframe as md
import pandas as pd
import ray

cluster = mars.new_cluster_in_ray(worker_num=2, worker_cpu=1)

df = pd.DataFrame({"col1": list(range(10000)), "col2": list(map(str, range(10000)))})
mdf = md.DataFrame(df, num_partitions=8)
# Create a tabular Dataset from a Mars DataFrame.
ds = ray.data.from_mars(mdf)

ds.show(3)
{'col1': 0, 'col2': '0'}
{'col1': 1, 'col2': '1'}
{'col1': 2, 'col2': '2'}

加载 Hugging Face 数据集#

要从 Hugging Face Hub 读取数据集,请使用 HfFileSystem 文件系统,配合 read_parquet()(或其他读取函数)。与先将数据集加载到内存中相比,这种方法提供了更好的性能和可扩展性。

首先,安装所需的依赖项

pip install huggingface_hub

设置您的 Hugging Face 令牌以进行身份验证。虽然公共数据集无需令牌即可读取,但在没有令牌的情况下,Hugging Face 的速率限制会更严格。要读取没有令牌的 Hugging Face 数据集,只需将文件系统参数设置为 HfFileSystem()

export HF_TOKEN=<YOUR HUGGING FACE TOKEN>

对于大多数 Hugging Face 数据集,数据都存储在 Parquet 文件中。您可以直接从数据集路径读取

import os
import ray
from huggingface_hub import HfFileSystem

ds = ray.data.read_parquet(
    "hf://datasets/wikimedia/wikipedia",
    file_extensions=["parquet"],
    filesystem=HfFileSystem(token=os.environ["HF_TOKEN"]),
)

print(f"Dataset count: {ds.count()}")
print(ds.schema())
Dataset count: 61614907
Column  Type
------  ----
id      string
url     string
title   string
text    string

提示

如果您在从 Hugging Face 文件系统读取时遇到序列化错误,请尝试将 huggingface_hub 升级到 1.1.6 或更高版本。有关更多详细信息,请参阅此 issue:ray-project/ray#59029

从 ML 库加载数据#

Ray Data 与 PyTorch 和 TensorFlow 数据集互操作。

要将 HuggingFace 数据集加载到 Ray Data 中,请使用 HuggingFace Hub 的 HfFileSystem,配合 read_parquet()read_csv()read_json()。由于 HuggingFace 数据集通常以这些文件格式为后端,因此这种方法可以直接从 Hub 实现高效的分布式读取。

import ray.data
from huggingface_hub import HfFileSystem

path = "hf://datasets/Salesforce/wikitext/wikitext-2-raw-v1/"
fs = HfFileSystem()
ds = ray.data.read_parquet(path, filesystem=fs)
print(ds.take(5))
[{'text': '...'}, {'text': '...'}]

要将 PyTorch 数据集转换为 Ray Dataset,请调用 from_torch()

import ray
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor

tds = datasets.CIFAR10(root="data", train=True, download=True, transform=ToTensor())
ds = ray.data.from_torch(tds)

print(ds)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz
100%|███████████████████████| 170498071/170498071 [00:07<00:00, 23494838.54it/s]
Extracting data/cifar-10-python.tar.gz to data
Dataset(num_rows=50000, schema={item: object})

要将 TensorFlow 数据集转换为 Ray Dataset,请调用 from_tf()

警告

from_tf 不支持并行读取。仅将此函数用于小型数据集,如 MNIST 或 CIFAR。

import ray
import tensorflow_datasets as tfds

tf_ds, _ = tfds.load("cifar10", split=["train", "test"])
ds = ray.data.from_tf(tf_ds)

print(ds)
MaterializedDataset(
   num_blocks=...,
   num_rows=50000,
   schema={
      id: binary,
      image: ArrowTensorTypeV2(shape=(32, 32, 3), dtype=uint8),
      label: int64
   }
)

读取数据库#

Ray Data 可以读取 MySQL、PostgreSQL、MongoDB 和 BigQuery 等数据库。

读取 SQL 数据库#

调用 read_sql() 来从提供 Python DB API2 兼容连接器的数据库读取数据。

要从 MySQL 读取,请安装 MySQL Connector/Python。这是官方的 MySQL 数据库连接器。

pip install mysql-connector-python

然后,定义您的连接逻辑并查询数据库。

import mysql.connector

import ray

def create_connection():
    return mysql.connector.connect(
        user="admin",
        password=...,
        host="example-mysql-database.c2c2k1yfll7o.us-west-2.rds.amazonaws.com",
        connection_timeout=30,
        database="example",
    )

# Get all movies
dataset = ray.data.read_sql("SELECT * FROM movie", create_connection)
# Get movies after the year 1980
dataset = ray.data.read_sql(
    "SELECT title, score FROM movie WHERE year >= 1980", create_connection
)
# Get the number of movies per year
dataset = ray.data.read_sql(
    "SELECT year, COUNT(*) FROM movie GROUP BY year", create_connection
)

要从 PostgreSQL 读取,请安装 Psycopg 2。这是最流行的 PostgreSQL 数据库连接器。

pip install psycopg2-binary

然后,定义您的连接逻辑并查询数据库。

import psycopg2

import ray

def create_connection():
    return psycopg2.connect(
        user="postgres",
        password=...,
        host="example-postgres-database.c2c2k1yfll7o.us-west-2.rds.amazonaws.com",
        dbname="example",
    )

# Get all movies
dataset = ray.data.read_sql("SELECT * FROM movie", create_connection)
# Get movies after the year 1980
dataset = ray.data.read_sql(
    "SELECT title, score FROM movie WHERE year >= 1980", create_connection
)
# Get the number of movies per year
dataset = ray.data.read_sql(
    "SELECT year, COUNT(*) FROM movie GROUP BY year", create_connection
)

要从 Snowflake 读取,请安装 Snowflake Connector for Python

pip install snowflake-connector-python

然后,定义您的连接逻辑并查询数据库。

import snowflake.connector

import ray

def create_connection():
    return snowflake.connector.connect(
        user=...,
        password=...
        account="ZZKXUVH-IPB52023",
        database="example",
    )

# Get all movies
dataset = ray.data.read_sql("SELECT * FROM movie", create_connection)
# Get movies after the year 1980
dataset = ray.data.read_sql(
    "SELECT title, score FROM movie WHERE year >= 1980", create_connection
)
# Get the number of movies per year
dataset = ray.data.read_sql(
    "SELECT year, COUNT(*) FROM movie GROUP BY year", create_connection
)

要从 Databricks 读取,请将 DATABRICKS_TOKEN 环境变量设置为您的 Databricks SQL 仓库访问令牌。

export DATABRICKS_TOKEN=...

如果您不是在 Databricks 运行时上运行程序,请同时设置 DATABRICKS_HOST 环境变量。

export DATABRICKS_HOST=adb-<workspace-id>.<random-number>.azuredatabricks.net

然后,调用 ray.data.read_databricks_tables() 从 Databricks SQL 仓库读取。

import ray

dataset = ray.data.read_databricks_tables(
    warehouse_id='...',  # Databricks SQL warehouse ID
    catalog='catalog_1',  # Unity catalog name
    schema='db_1',  # Schema name
    query="SELECT title, score FROM movie WHERE year >= 1980",
)

要从 BigQuery 读取,请安装 Google BigQuery 的 Python 客户端Google BigQueryStorage 的 Python 客户端

pip install google-cloud-bigquery
pip install google-cloud-bigquery-storage

要从 BigQuery 读取数据,请调用 read_bigquery() 并指定项目 ID、数据集和查询(如果适用)。

import ray

# Read the entire dataset. Do not specify query.
ds = ray.data.read_bigquery(
    project_id="my_gcloud_project_id",
    dataset="bigquery-public-data.ml_datasets.iris",
)

# Read from a SQL query of the dataset. Do not specify dataset.
ds = ray.data.read_bigquery(
    project_id="my_gcloud_project_id",
    query = "SELECT * FROM `bigquery-public-data.ml_datasets.iris` LIMIT 50",
)

# Write back to BigQuery
ds.write_bigquery(
    project_id="my_gcloud_project_id",
    dataset="destination_dataset.destination_table",
    overwrite_table=True,
)

读取 MongoDB#

要从 MongoDB 读取数据,请调用 read_mongo() 并指定源 URI、数据库和集合。您还需要指定一个将在集合上运行的管道。

import ray

# Read a local MongoDB.
ds = ray.data.read_mongo(
    uri="mongodb://:27017",
    database="my_db",
    collection="my_collection",
    pipeline=[{"$match": {"col": {"$gte": 0, "$lt": 10}}}, {"$sort": "sort_col"}],
)

# Reading a remote MongoDB is the same.
ds = ray.data.read_mongo(
    uri="mongodb://username:password@mongodb0.example.com:27017/?authSource=admin",
    database="my_db",
    collection="my_collection",
    pipeline=[{"$match": {"col": {"$gte": 0, "$lt": 10}}}, {"$sort": "sort_col"}],
)

# Write back to MongoDB.
ds.write_mongo(
    MongoDatasource(),
    uri="mongodb://username:password@mongodb0.example.com:27017/?authSource=admin",
    database="my_db",
    collection="my_collection",
)

从 Kafka 读取#

Ray Data 可以从 Kafka 等消息队列读取。

要从 Kafka 主题读取数据,请调用 read_kafka() 并指定主题名称和代理地址。Ray Data 在起始偏移量和结束偏移量之间执行有界读取。

首先,安装所需的依赖项

pip install kafka-python

然后,指定您的 Kafka 配置并从主题读取。

import ray

# Read from a single topic with offset range
ds = ray.data.read_kafka(
    topics="my-topic",
    bootstrap_servers="localhost:9092",
    start_offset=0,
    end_offset=1000,
)

# Read from multiple topics
ds = ray.data.read_kafka(
    topics=["topic1", "topic2"],
    bootstrap_servers="localhost:9092",
    start_offset="earliest",
    end_offset="latest",
)

# Read with authentication
from ray.data import KafkaAuthConfig

auth_config = KafkaAuthConfig(
    security_protocol="SASL_SSL",
    sasl_mechanism="PLAIN",
    sasl_plain_username="your-username",
    sasl_plain_password="your-password",
)

ds = ray.data.read_kafka(
    topics="secure-topic",
    bootstrap_servers="localhost:9092",
    kafka_auth_config=auth_config,
)

print(ds.schema())
Column          Type
------          ----
offset          int64
key             binary
value           binary
topic           string
partition       int32
timestamp       int64
timestamp_type  int32
headers         map<string, binary>

创建合成数据#

合成数据集对于测试和基准测试非常有用。

要从整数范围创建合成 Dataset,请调用 range()。Ray Data 将整数范围存储在一个名为“id”的列中。

import ray

ds = ray.data.range(10000)

print(ds.schema())
Column  Type
------  ----
id      int64

要创建包含数组的合成 Dataset,请调用 range_tensor()。Ray Data 将整数范围打包到具有指定形状的 ndarrays 中。在 schema 中,列名默认为“data”。

import ray

ds = ray.data.range_tensor(10, shape=(64, 64))

print(ds.schema())
Column  Type
------  ----
data    ArrowTensorTypeV2(shape=(64, 64), dtype=int64)

加载其他数据源#

如果 Ray Data 无法加载您的数据,请继承 Datasource。然后,构建您的自定义数据源的实例,并将其传递给 read_datasource()。要写入结果,您可能还需要继承 ray.data.Datasink。然后,创建您的自定义数据汇的实例,并将其传递给 write_datasink()。有关更多详细信息,请参阅 高级:读取和写入自定义文件类型

# Read from a custom datasource.
ds = ray.data.read_datasource(YourCustomDatasource(), **read_args)

# Write to a custom datasink.
ds.write_datasink(YourCustomDatasink())

性能注意事项#

默认情况下,所有读取任务的输出块数量是根据输入数据大小和可用资源动态决定的。在大多数情况下,这应该效果不错。但是,您也可以通过设置 override_num_blocks 参数来覆盖默认值。Ray Data 会在内部决定同时运行多少个读取任务以最佳地利用集群,范围从 1...override_num_blocks 个任务。换句话说,override_num_blocks 的值越高,Dataset 中的数据块就越小,从而提供更多的并行执行机会。

有关如何调整输出块数量以及优化读取性能的其他建议,请参阅 优化读取