Object Detection Batch Inference with PyTorch#
This example demonstrates how to perform object detection batch inference at scale with a pre-trained PyTorch model and Ray Data.
Here is what you'll do:
Perform object detection on a single image with a pre-trained PyTorch model.
Scale the PyTorch model with Ray Data and perform object detection batch inference on a large set of images.
Verify the inference results and save them to external storage.
Learn how to use Ray Data with GPUs.
Before You Begin#
Install the following dependencies if you haven't already.
!pip install -q "ray[data]" torchvision
Object Detection on a Single Image with PyTorch#
Before diving into Ray Data, let's take a look at this object detection example from PyTorch's official documentation. The example uses a pre-trained model (FasterRCNN_ResNet50) to do object detection inference on a single image.
First, download an image from the internet.
import requests
from PIL import Image
url = "https://s3-us-west-2.amazonaws.com/air-example-data/AnimalDetection/JPEGImages/2007_000063.jpg"
img = Image.open(requests.get(url, stream=True).raw)
display(img)

Second, load and initialize a pre-trained PyTorch model.
from torchvision import transforms
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
model.eval();
Then apply the preprocessing transforms.
img = transforms.Compose([transforms.PILToTensor()])(img)
preprocess = weights.transforms()
batch = [preprocess(img)]
Then use the model for inference.
prediction = model(batch)[0]
Lastly, visualize the results.
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image
labels = [weights.meta["categories"][i] for i in prediction["labels"]]
box = draw_bounding_boxes(img,
                          boxes=prediction["boxes"],
                          labels=labels,
                          colors="red",
                          width=4)
im = to_pil_image(box.detach())
display(im)

Scaling with Ray Data#
Next, let's see how to scale the previous example to a large set of images. We will use Ray Data to do batch inference in a streaming and distributed fashion, leveraging all the CPU and GPU resources in our cluster.
Loading the Image Dataset#
The dataset that we will be using is a subset of Pascal VOC that contains cats and dogs (the full dataset has 20 classes). There are 2434 images in total in this dataset.
First, we use the ray.data.read_images API to load a prepared image dataset from S3. We can use the schema API to check the schema of the dataset. As we can see, it has one column named "image", and the value is the image data represented in np.ndarray format.
import ray
ds = ray.data.read_images("s3://anonymous@air-example-data/AnimalDetection/JPEGImages")
display(ds.schema())
2025-02-05 14:22:50,021 INFO worker.py:1841 -- Started a local Ray instance.
2025-02-05 14:22:50,698 INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-02-05_14-22-49_425292_37149/logs/ray-data
2025-02-05 14:22:50,698 INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadImage]
[dataset]: Run `pip install tqdm` to enable progress reporting.
Column  Type
------  ----
image   numpy.ndarray(ndim=3, dtype=uint8)
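As a quick, optional sanity check, you can confirm the dataset size mentioned above by counting the rows (note that this may trigger reading the data):
# Optional: confirm the number of images in the dataset.
print(ds.count())  # 2434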
Batch Inference with Ray Data#
As we have seen in the PyTorch example, model inference consists of 2 steps: preprocessing the image and model inference.
Preprocessing#
First, let's convert the preprocessing code to Ray Data. We'll package the preprocessing code within a preprocess_image function. This function should take only one argument: a dict that contains a single image from the dataset, represented as a numpy array.
import numpy as np
import torch
from torchvision import transforms
from torchvision.models.detection import (FasterRCNN_ResNet50_FPN_V2_Weights,
                                          fasterrcnn_resnet50_fpn_v2)
from typing import Dict

def preprocess_image(data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    preprocessor = transforms.Compose(
        [transforms.ToTensor(), weights.transforms()]
    )
    return {
        "image": data["image"],
        "transformed": preprocessor(data["image"]),
    }
Then we use the map API to apply the function to the whole dataset. By using Ray Data's map, we can scale out the preprocessing to all the resources in our Ray cluster. Note that the map method is lazy; it doesn't execute until we start to consume the results.
ds = ds.map(preprocess_image)
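Because map is lazy, the line above returns immediately without doing any work. If you want to eagerly verify the preprocessing, a minimal optional check is to consume a single row, which triggers execution:
# Pull one row to confirm the preprocessed tensor has the expected layout.
row = ds.take(1)[0]
print(row["transformed"].shape, row["transformed"].dtype)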
Model Inference#
Next, let's convert the model inference part. Compared with preprocessing, model inference has 2 differences:
Model loading and initialization are usually expensive.
Model inference can be optimized with hardware acceleration if we process data in batches. Using larger batches improves GPU utilization and the overall runtime of the inference job.
Thus, we convert the model inference code to the following ObjectDetectionModel class. In this class, we put the expensive model loading and initialization code in the __init__ constructor, which runs only once. And we put the model inference code in the __call__ method, which is called for each batch.
The __call__ method takes a batch of data items instead of a single one. In this case, the batch is also a dict that has one key named "image", and the value is an array of images represented in np.ndarray format. We can also use the take_batch API to fetch a single batch and inspect its internal data structure.
single_batch = ds.take_batch(batch_size=3)
display(single_batch)
2025-02-05 14:22:51,757 INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-02-05_14-22-49_425292_37149/logs/ray-data
2025-02-05 14:22:51,757 INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadImage->Map(preprocess_image)] -> LimitOperator[limit=3]
{'image': array([array([[[137,  59,   0], ...], ...], shape=(375, 500, 3), dtype=uint8),
                 array([[[ 78, 111, 104], ...], ...], shape=(375, 500, 3), dtype=uint8),
                 array([[[ 19,   1,   1], ...], ...], shape=(375, 500, 3), dtype=uint8)],
                dtype=object),
 'transformed': array([array([[[0.5372549 , 0.54509807, ...], ...], ...], shape=(3, 375, 500), dtype=float32),
                       array([[[0.30588236, 0.3137255 , ...], ...], ...], shape=(3, 375, 500), dtype=float32),
                       array([[[0.07450981, 0.09019608, ...], ...], ...], shape=(3, 375, 500), dtype=float32)],
                      dtype=object)}
class ObjectDetectionModel:
    def __init__(self):
        # Define the model loading and initialization code in `__init__`.
        self.weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
        self.model = fasterrcnn_resnet50_fpn_v2(
            weights=self.weights,
            box_score_thresh=0.9,
        )
        if torch.cuda.is_available():
            # Move the model to GPU if it's available.
            self.model = self.model.cuda()
        self.model.eval()

    def __call__(self, input_batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        # Define the per-batch inference code in `__call__`.
        batch = [torch.from_numpy(image) for image in input_batch["transformed"]]
        if torch.cuda.is_available():
            # Move the data to GPU if it's available.
            batch = [image.cuda() for image in batch]
        predictions = self.model(batch)
        # Keep the original image for visualization purposes.
        return {
            "image": input_batch["image"],
            "labels": [pred["labels"].detach().cpu().numpy() for pred in predictions],
            "boxes": [pred["boxes"].detach().cpu().numpy() for pred in predictions],
        }
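Before scaling out, you can optionally smoke-test the class locally on the sample batch we fetched earlier with take_batch. This is purely illustrative; it instantiates a single replica and runs it on 3 images (on GPU if one is available):
# Illustrative local check: run one replica of the UDF on `single_batch`.
model_udf = ObjectDetectionModel()
output = model_udf(single_batch)
print([len(labels) for labels in output["labels"]])  # detections per image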
Then we use the map_batches API to apply the model to the whole dataset.
The first parameter of map and map_batches is the user-defined function (UDF), which can either be a function or a class. Function-based UDFs run as short-running Ray tasks, while class-based UDFs run as long-running Ray actors. For class-based UDFs, use the concurrency argument to specify the number of parallel actors. The batch_size argument indicates the number of images in each batch.
The num_gpus argument specifies the number of GPUs needed for each ObjectDetectionModel instance. The Ray scheduler can handle heterogeneous resource requirements in order to maximize resource utilization. In this case, the ObjectDetectionModel instances run on GPUs, while the preprocess_image instances run on CPUs.
ds = ds.map_batches(
    ObjectDetectionModel,
    # Use 4 model replicas. Change this number based on the number of GPUs in your cluster.
    concurrency=4,
    batch_size=4,  # Use the largest batch size that can fit in GPU memory.
    # Specify 1 GPU per model replica. Set to 0 if you are doing CPU inference.
    num_gpus=1,
)
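If you are running without GPUs, a minimal CPU-only variant of this call simply drops the num_gpus argument; the concurrency and batch_size values below are illustrative, so tune them to your cluster:
# CPU-only sketch: without `num_gpus`, the actors are scheduled on CPUs.
ds_cpu = ds.map_batches(
    ObjectDetectionModel,
    concurrency=4,
    batch_size=4,
)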
Verify and Save Results#
Next, let's take a small batch and verify the inference results with visualization.
from torchvision.transforms.functional import convert_image_dtype, to_tensor

batch = ds.take_batch(batch_size=2)
for image, labels, boxes in zip(batch["image"], batch["labels"], batch["boxes"]):
    image = convert_image_dtype(to_tensor(image), torch.uint8)
    labels = [weights.meta["categories"][i] for i in labels]
    boxes = torch.from_numpy(boxes)
    img = to_pil_image(draw_bounding_boxes(
        image,
        boxes,
        labels=labels,
        colors="red",
        width=4,
    ))
    display(img)
2025-02-05 14:22:53,627 INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-02-05_14-22-49_425292_37149/logs/ray-data
2025-02-05 14:22:53,628 INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> ActorPoolMapOperator[ReadImage->Map(preprocess_image)->MapBatches(ObjectDetectionModel)] -> LimitOperator[limit=2]
2025-02-05 14:22:55,891 WARNING progress_bar.py:120 -- Truncating long operator name to 100 characters. To disable this behavior, set `ray.data.DataContext.get_current().DEFAULT_ENABLE_PROGRESS_BAR_NAME_TRUNCATION = False`.


If the samples look good, we can proceed with saving the results to external storage, e.g., S3 or local disks. See Ray Data Input/Output for all supported storage and file formats.
ds.write_parquet("local://tmp/inference_results")
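To verify the write, or to feed the results into a downstream job, you can read them back with ray.data.read_parquet (the path here assumes the write above succeeded):
# Read the saved inference results back from local disk.
results = ray.data.read_parquet("local://tmp/inference_results")
print(results.schema())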