使用图像#
借助 Ray Data,您可以轻松读取和转换大型图像数据集。
本指南将向您展示如何
读取图像#
Ray Data 可以读取多种格式的图像。
要查看支持的文件格式的完整列表,请参阅 输入/输出参考。
要加载 JPEG 文件等原始图像,请调用 read_images()。在 schema 中,“image” 列名是默认值。
注意
read_images() 使用 PIL。支持的文件格式列表,请参阅 图像文件格式。
import ray
ds = ray.data.read_images("s3://anonymous@ray-example-data/batoidea/JPEGImages")
print(ds.schema())
Column Type
------ ----
image ArrowTensorTypeV2(shape=(32, 32, 3), dtype=uint8)
要从 URI 数据集中加载图像,请将 with_column() 方法与 download() 表达式结合使用。
import ray
from ray.data.expressions import download
ds = ray.data.read_parquet("s3://anonymous@ray-example-data/imagenet/metadata_file.parquet")
ds = ds.with_column("bytes", download("image_url"))
print(ds.schema())
Column Type
------ ----
image_url string
bytes null
要加载以 NumPy 格式存储的图像,请调用 read_numpy()。
import ray
ds = ray.data.read_numpy("s3://anonymous@air-example-data/cifar-10/images.npy")
print(ds.schema())
Column Type
------ ----
data ArrowTensorTypeV2(shape=(32, 32, 3), dtype=uint8)
图像数据集通常包含如下所示的 tf.train.Example 消息
features {
feature {
key: "image"
value {
bytes_list {
value: ... # Raw image bytes
}
}
}
feature {
key: "label"
value {
int64_list {
value: 3
}
}
}
}
要加载此格式存储的示例,请调用 read_tfrecords()。然后,调用 map() 来解码原始图像字节。
import io
from typing import Any, Dict
import numpy as np
from PIL import Image
import ray
def decode_bytes(row: Dict[str, Any]) -> Dict[str, Any]:
data = row["image"]
image = Image.open(io.BytesIO(data))
row["image"] = np.asarray(image)
return row
ds = (
ray.data.read_tfrecords(
"s3://anonymous@air-example-data/cifar-10/tfrecords"
)
.map(decode_bytes)
)
print(ds.schema())
Column Type
------ ----
image ArrowTensorTypeV2(shape=(32, 32, 3), dtype=uint8)
label int64
要加载以 Parquet 文件存储的图像数据,请调用 ray.data.read_parquet()。
import ray
ds = ray.data.read_parquet("s3://anonymous@air-example-data/cifar-10/parquet")
print(ds.schema())
Column Type
------ ----
img struct<bytes: binary, path: string>
label int64
有关创建数据集的更多信息,请参阅 加载数据。
转换图像#
要转换图像,请调用 map() 或 map_batches()。
from typing import Any, Dict
import numpy as np
import ray
def increase_brightness(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
batch["image"] = np.clip(batch["image"] + 4, 0, 255)
return batch
ds = (
ray.data.read_images("s3://anonymous@ray-example-data/batoidea/JPEGImages")
.map_batches(increase_brightness)
)
有关转换数据的更多信息,请参阅 转换数据。
对图像进行推理#
要使用预训练模型进行推理,请先加载和转换您的数据。
from typing import Any, Dict
from torchvision import transforms
import ray
def transform_image(row: Dict[str, Any]) -> Dict[str, Any]:
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize((32, 32))
])
row["image"] = transform(row["image"])
return row
ds = (
ray.data.read_images("s3://anonymous@ray-example-data/batoidea/JPEGImages")
.map(transform_image)
)
接下来,实现一个可调用的类来设置和调用您的模型。
import torch
from torchvision import models
class ImageClassifier:
def __init__(self):
weights = models.ResNet18_Weights.DEFAULT
self.model = models.resnet18(weights=weights)
self.model.eval()
def __call__(self, batch):
inputs = torch.from_numpy(batch["image"])
with torch.inference_mode():
outputs = self.model(inputs)
return {"class": outputs.argmax(dim=1)}
最后,调用 Dataset.map_batches()。
predictions = ds.map_batches(
ImageClassifier,
compute=ray.data.ActorPoolStrategy(size=2),
batch_size=4
)
predictions.show(3)
{'class': 118}
{'class': 153}
{'class': 296}
有关执行推理的更多信息,请参阅 端到端:离线批量推理 和 有状态转换。
保存图像#
以 PNG、Parquet 和 NumPy 等格式保存图像。要查看所有支持的格式,请参阅 输入/输出参考。
要将图像保存为图像文件,请调用 write_images()。
import ray
ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")
ds.write_images("/tmp/simple", column="image", file_format="png")
要将图像保存到 Parquet 文件,请调用 write_parquet()。
import ray
ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")
ds.write_parquet("/tmp/simple")
要将图像保存到 NumPy 文件,请调用 write_numpy()。
import ray
ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")
ds.write_numpy("/tmp/simple", column="image")
有关保存数据的更多信息,请参阅 保存数据。