使用 PyTorch 进行对象检测批量推理#

try-anyscale-quickstart

此示例演示了如何使用预训练的 PyTorch 模型和 Ray Data 大规模执行对象检测批量推理。

你将执行以下操作

  1. 使用预训练的 PyTorch 模型对单个图像执行对象检测。

  2. 使用 Ray Data 对 PyTorch 模型进行扩展,并对大量图像执行对象检测批量推理。

  3. 验证推理结果并将其保存到外部存储。

  4. 学习如何将 Ray Data 与 GPU 结合使用。

开始之前#

如果尚未安装,请安装以下依赖项。

!pip install -q "ray[data]" torchvision

使用 PyTorch 对单个图像进行对象检测#

在深入了解 Ray Data 之前,让我们看看 PyTorch 官方文档中的这个对象检测示例。该示例使用了预训练模型 (FasterRCNN_ResNet50) 对单个图像进行对象检测推理。

首先,从互联网下载一张图片。

import requests
from PIL import Image

url = "https://s3-us-west-2.amazonaws.com/air-example-data/AnimalDetection/JPEGImages/2007_000063.jpg"
img = Image.open(requests.get(url, stream=True).raw)
display(img)
../../_images/612f82906f846ed344b779b0175f581ceb56c3f5283b42fc83e00b8f80f3904b.png

其次,加载并初始化一个预训练的 PyTorch 模型。

from torchvision import transforms
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights

weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
model.eval();

然后应用预处理转换。

img = transforms.Compose([transforms.PILToTensor()])(img)
preprocess = weights.transforms()
batch = [preprocess(img)]

然后使用模型进行推理。

prediction = model(batch)[0]

最后,可视化结果。

from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image

labels = [weights.meta["categories"][i] for i in prediction["labels"]]
box = draw_bounding_boxes(img,
                          boxes=prediction["boxes"],
                          labels=labels,
                          colors="red",
                          width=4)
im = to_pil_image(box.detach())
display(im)
../../_images/ac3f24c4e3d83b1b0ff3c0e8f2d4cc3dd1d0d1ee1c62f94e936da56d9b469883.png

使用 Ray Data 进行扩展#

接下来,让我们看看如何将前面的示例扩展到大量图像。我们将使用 Ray Data 以流式和分布式方式进行批量推理,充分利用集群中的所有 CPU 和 GPU 资源。

加载图像数据集#

我们将使用的数据集是 Pascal VOC 的一个子集,包含猫和狗(完整数据集有 20 个类别)。该数据集中共有 2434 张图片。

首先,我们使用 ray.data.read_images API 从 S3 加载准备好的图像数据集。我们可以使用 schema API 检查数据集的模式。正如我们所见,它有一列名为“image”,值是表示为 np.ndarray 格式的图像数据。

import ray

ds = ray.data.read_images("s3://anonymous@air-example-data/AnimalDetection/JPEGImages")
display(ds.schema())
2025-02-05 14:22:50,021	INFO worker.py:1841 -- Started a local Ray instance.
2025-02-05 14:22:50,698	INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-02-05_14-22-49_425292_37149/logs/ray-data
2025-02-05 14:22:50,698	INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadImage]
[dataset]: Run `pip install tqdm` to enable progress reporting.
Column  Type
------  ----
image   numpy.ndarray(ndim=3, dtype=uint8)

使用 Ray Data 进行批量推理#

正如我们在 PyTorch 示例中看到的,模型推理包括 2 个步骤:图像预处理和模型推理。

预处理#

首先,让我们将预处理代码转换为 Ray Data。我们将预处理代码打包到一个 preprocess_image 函数中。该函数应该只接受一个参数,即一个字典,其中包含数据集中表示为 numpy 数组的单张图像。

import numpy as np
import torch
from torchvision import transforms
from torchvision.models.detection import (FasterRCNN_ResNet50_FPN_V2_Weights,
                                          fasterrcnn_resnet50_fpn_v2)
from typing import Dict


def preprocess_image(data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    preprocessor = transforms.Compose(
        [transforms.ToTensor(), weights.transforms()]
    )
    return {
        "image": data["image"],
        "transformed": preprocessor(data["image"]),
    }

然后我们使用 map API 将该函数应用于整个数据集。通过使用 Ray Data 的 map,我们可以将预处理扩展到 Ray 集群中的所有资源。请注意,map 方法是惰性的,直到我们开始消费结果才会执行。

ds = ds.map(preprocess_image)

模型推理#

接下来,让我们转换模型推理部分。与预处理相比,模型推理有 2 个不同之处:

  1. 模型加载和初始化通常很耗时。

  2. 如果我们以批量方式处理数据,可以使用硬件加速来优化模型推理。使用更大的批量可以提高 GPU 利用率和推理作业的整体运行时长。

因此,我们将模型推理代码转换为以下 ObjectDetectionModel 类。在该类中,我们将耗时的模型加载和初始化代码放在 __init__ 构造函数中,它只会运行一次。我们将模型推理代码放在 __call__ 方法中,该方法将为每个批量调用。

__call__ 方法接收一批数据项,而不是单个数据项。在这种情况下,批量也是一个字典,其中包含一个名为“image”的键,其值是表示为 np.ndarray 格式的图像数组。我们还可以使用 take_batch API 来获取单个批量,并检查其内部数据结构。

single_batch = ds.take_batch(batch_size=3)
display(single_batch)
2025-02-05 14:22:51,757	INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-02-05_14-22-49_425292_37149/logs/ray-data
2025-02-05 14:22:51,757	INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadImage->Map(preprocess_image)] -> LimitOperator[limit=3]
{'image': array([array([[[137,  59,   0],
                [139,  61,   0],
                [145,  65,   2],
                ...,
                [141,  71,   2],
                [140,  69,   7],
                [138,  68,   8]],
 
               [[135,  55,   0],
                [138,  58,   0],
                [143,  63,   2],
                ...,
                [142,  69,   1],
                [140,  69,   5],
                [138,  68,   6]],
 
               [[141,  59,   1],
                [145,  63,   3],
                [146,  64,   6],
                ...,
                [143,  70,   1],
                [141,  70,   4],
                [139,  68,   4]],
 
               ...,
 
               [[223, 193, 157],
                [219, 189, 153],
                [188, 156, 118],
                ...,
                [151,  51,  15],
                [147,  47,  11],
                [142,  42,   6]],
 
               [[224, 194, 158],
                [225, 195, 159],
                [224, 192, 154],
                ...,
                [148,  48,  12],
                [145,  45,   9],
                [139,  39,   3]],
 
               [[227, 195, 157],
                [236, 204, 166],
                [215, 181, 144],
                ...,
                [148,  50,  13],
                [145,  47,  10],
                [138,  40,   3]]], shape=(375, 500, 3), dtype=uint8),
        array([[[ 78, 111, 104],
                [ 80, 113, 104],
                [ 83, 116, 105],
                ...,
                [153, 179, 192],
                [177, 200, 216],
                [192, 215, 231]],
 
               [[ 70, 105, 101],
                [ 76, 111, 105],
                [ 73, 106,  99],
                ...,
                [154, 180, 193],
                [141, 167, 182],
                [127, 153, 168]],
 
               [[ 83, 122, 127],
                [ 54,  92,  95],
                [ 72, 103, 105],
                ...,
                [157, 185, 197],
                [157, 185, 199],
                [154, 181, 198]],
 
               ...,
 
               [[  1, 103, 107],
                [  1, 103, 107],
                [  5, 102, 108],
                ...,
                [127,  40,  46],
                [145,  54,  61],
                [144,  53,  60]],
 
               [[  1, 103, 107],
                [  0, 102, 106],
                [  2, 101, 106],
                ...,
                [139,  54,  61],
                [134,  47,  53],
                [134,  47,  53]],
 
               [[  0, 102, 105],
                [  0, 102, 106],
                [  4, 103, 109],
                ...,
                [121,  47,  48],
                [148,  68,  71],
                [137,  52,  57]]], shape=(375, 500, 3), dtype=uint8),
        array([[[19,  1,  1],
                [23,  5,  5],
                [22,  2,  3],
                ...,
                [56, 29, 10],
                [62, 34, 13],
                [69, 41, 20]],
 
               [[25,  7,  7],
                [22,  4,  4],
                [21,  1,  2],
                ...,
                [55, 27,  6],
                [73, 42, 22],
                [67, 36, 15]],
 
               [[19,  3,  3],
                [18,  2,  2],
                [19,  1,  1],
                ...,
                [59, 28,  7],
                [75, 43, 22],
                [69, 37, 14]],
 
               ...,
 
               [[10,  2, 17],
                [14, 11, 22],
                [ 9, 12, 17],
                ...,
                [14, 18, 30],
                [10, 12, 27],
                [ 8, 10, 23]],
 
               [[12,  4, 19],
                [ 9,  6, 17],
                [ 7, 12, 16],
                ...,
                [ 8, 12, 24],
                [ 3,  7, 19],
                [ 5,  9, 21]],
 
               [[ 9,  2, 18],
                [ 4,  2, 15],
                [ 5, 10, 14],
                ...,
                [ 8, 10, 22],
                [ 2,  4, 16],
                [ 7,  9, 21]]], shape=(375, 500, 3), dtype=uint8)],
       dtype=object),
 'transformed': array([array([[[0.5372549 , 0.54509807, 0.5686275 , ..., 0.5529412 ,
                 0.54901963, 0.5411765 ],
                [0.5294118 , 0.5411765 , 0.56078434, ..., 0.5568628 ,
                 0.54901963, 0.5411765 ],
                [0.5529412 , 0.5686275 , 0.57254905, ..., 0.56078434,
                 0.5529412 , 0.54509807],
                ...,
                [0.8745098 , 0.85882354, 0.7372549 , ..., 0.5921569 ,
                 0.5764706 , 0.5568628 ],
                [0.8784314 , 0.88235295, 0.8784314 , ..., 0.5803922 ,
                 0.5686275 , 0.54509807],
                [0.8901961 , 0.9254902 , 0.84313726, ..., 0.5803922 ,
                 0.5686275 , 0.5411765 ]],
 
               [[0.23137255, 0.23921569, 0.25490198, ..., 0.2784314 ,
                 0.27058825, 0.26666668],
                [0.21568628, 0.22745098, 0.24705882, ..., 0.27058825,
                 0.27058825, 0.26666668],
                [0.23137255, 0.24705882, 0.2509804 , ..., 0.27450982,
                 0.27450982, 0.26666668],
                ...,
                [0.75686276, 0.7411765 , 0.6117647 , ..., 0.2       ,
                 0.18431373, 0.16470589],
                [0.7607843 , 0.7647059 , 0.7529412 , ..., 0.1882353 ,
                 0.1764706 , 0.15294118],
                [0.7647059 , 0.8       , 0.70980394, ..., 0.19607843,
                 0.18431373, 0.15686275]],
 
               [[0.        , 0.        , 0.00784314, ..., 0.00784314,
                 0.02745098, 0.03137255],
                [0.        , 0.        , 0.00784314, ..., 0.00392157,
                 0.01960784, 0.02352941],
                [0.00392157, 0.01176471, 0.02352941, ..., 0.00392157,
                 0.01568628, 0.01568628],
                ...,
                [0.6156863 , 0.6       , 0.4627451 , ..., 0.05882353,
                 0.04313726, 0.02352941],
                [0.61960787, 0.62352943, 0.6039216 , ..., 0.04705882,
                 0.03529412, 0.01176471],
                [0.6156863 , 0.6509804 , 0.5647059 , ..., 0.05098039,
                 0.03921569, 0.01176471]]], shape=(3, 375, 500), dtype=float32),
        array([[[0.30588236, 0.3137255 , 0.3254902 , ..., 0.6       ,
                 0.69411767, 0.7529412 ],
                [0.27450982, 0.29803923, 0.28627452, ..., 0.6039216 ,
                 0.5529412 , 0.49803922],
                [0.3254902 , 0.21176471, 0.28235295, ..., 0.6156863 ,
                 0.6156863 , 0.6039216 ],
                ...,
                [0.00392157, 0.00392157, 0.01960784, ..., 0.49803922,
                 0.5686275 , 0.5647059 ],
                [0.00392157, 0.        , 0.00784314, ..., 0.54509807,
                 0.5254902 , 0.5254902 ],
                [0.        , 0.        , 0.01568628, ..., 0.4745098 ,
                 0.5803922 , 0.5372549 ]],
 
               [[0.43529412, 0.44313726, 0.45490196, ..., 0.7019608 ,
                 0.78431374, 0.84313726],
                [0.4117647 , 0.43529412, 0.41568628, ..., 0.7058824 ,
                 0.654902  , 0.6       ],
                [0.47843137, 0.36078432, 0.40392157, ..., 0.7254902 ,
                 0.7254902 , 0.70980394],
                ...,
                [0.40392157, 0.40392157, 0.4       , ..., 0.15686275,
                 0.21176471, 0.20784314],
                [0.40392157, 0.4       , 0.39607844, ..., 0.21176471,
                 0.18431373, 0.18431373],
                [0.4       , 0.4       , 0.40392157, ..., 0.18431373,
                 0.26666668, 0.20392157]],
 
               [[0.40784314, 0.40784314, 0.4117647 , ..., 0.7529412 ,
                 0.84705883, 0.90588236],
                [0.39607844, 0.4117647 , 0.3882353 , ..., 0.75686276,
                 0.7137255 , 0.65882355],
                [0.49803922, 0.37254903, 0.4117647 , ..., 0.77254903,
                 0.78039217, 0.7764706 ],
                ...,
                [0.41960785, 0.41960785, 0.42352942, ..., 0.18039216,
                 0.23921569, 0.23529412],
                [0.41960785, 0.41568628, 0.41568628, ..., 0.23921569,
                 0.20784314, 0.20784314],
                [0.4117647 , 0.41568628, 0.42745098, ..., 0.1882353 ,
                 0.2784314 , 0.22352941]]], shape=(3, 375, 500), dtype=float32),
        array([[[0.07450981, 0.09019608, 0.08627451, ..., 0.21960784,
                 0.24313726, 0.27058825],
                [0.09803922, 0.08627451, 0.08235294, ..., 0.21568628,
                 0.28627452, 0.2627451 ],
                [0.07450981, 0.07058824, 0.07450981, ..., 0.23137255,
                 0.29411766, 0.27058825],
                ...,
                [0.03921569, 0.05490196, 0.03529412, ..., 0.05490196,
                 0.03921569, 0.03137255],
                [0.04705882, 0.03529412, 0.02745098, ..., 0.03137255,
                 0.01176471, 0.01960784],
                [0.03529412, 0.01568628, 0.01960784, ..., 0.03137255,
                 0.00784314, 0.02745098]],
 
               [[0.00392157, 0.01960784, 0.00784314, ..., 0.11372549,
                 0.13333334, 0.16078432],
                [0.02745098, 0.01568628, 0.00392157, ..., 0.10588235,
                 0.16470589, 0.14117648],
                [0.01176471, 0.00784314, 0.00392157, ..., 0.10980392,
                 0.16862746, 0.14509805],
                ...,
                [0.00784314, 0.04313726, 0.04705882, ..., 0.07058824,
                 0.04705882, 0.03921569],
                [0.01568628, 0.02352941, 0.04705882, ..., 0.04705882,
                 0.02745098, 0.03529412],
                [0.00784314, 0.00784314, 0.03921569, ..., 0.03921569,
                 0.01568628, 0.03529412]],
 
               [[0.00392157, 0.01960784, 0.01176471, ..., 0.03921569,
                 0.05098039, 0.07843138],
                [0.02745098, 0.01568628, 0.00784314, ..., 0.02352941,
                 0.08627451, 0.05882353],
                [0.01176471, 0.00784314, 0.00392157, ..., 0.02745098,
                 0.08627451, 0.05490196],
                ...,
                [0.06666667, 0.08627451, 0.06666667, ..., 0.11764706,
                 0.10588235, 0.09019608],
                [0.07450981, 0.06666667, 0.0627451 , ..., 0.09411765,
                 0.07450981, 0.08235294],
                [0.07058824, 0.05882353, 0.05490196, ..., 0.08627451,
                 0.0627451 , 0.08235294]]], shape=(3, 375, 500), dtype=float32)],
       dtype=object)}
class ObjectDetectionModel:
    def __init__(self):
        # Define the model loading and initialization code in `__init__`.
        self.weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
        self.model = fasterrcnn_resnet50_fpn_v2(
            weights=self.weights,
            box_score_thresh=0.9,
        )
        if torch.cuda.is_available():
            # Move the model to GPU if it's available.
            self.model = self.model.cuda()
        self.model.eval()

    def __call__(self, input_batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        # Define the per-batch inference code in `__call__`.
        batch = [torch.from_numpy(image) for image in input_batch["transformed"]]
        if torch.cuda.is_available():
            # Move the data to GPU if it's available.
            batch = [image.cuda() for image in batch]
        predictions = self.model(batch)
        # keep the original image for visualization purposes
        return {
            "image": input_batch["image"],
            "labels": [pred["labels"].detach().cpu().numpy() for pred in predictions],
            "boxes": [pred["boxes"].detach().cpu().numpy() for pred in predictions],
        }

然后我们使用 map_batches API 将模型应用于整个数据集。

mapmap_batches 的第一个参数是用户定义函数 (UDF),它可以是函数或类。基于函数的 UDF 作为短时运行的 Ray 任务运行,而基于类的 UDF 作为长时间运行的 Ray Actor 运行。对于基于类的 UDF,使用 concurrency 参数指定并行 Actor 的数量。batch_size 参数表示每个批量中的图像数量。

num_gpus 参数指定每个 ObjectDetectionModel 实例所需的 GPU 数量。Ray 调度器可以处理异构资源需求,以便最大化资源利用率。在这种情况下,ObjectDetectionModel 实例将在 GPU 上运行,而 preprocess_image 实例将在 CPU 上运行。

ds = ds.map_batches(
    ObjectDetectionModel,
    # Use 4 model replicas. Change this number based on the number of GPUs in your cluster.
    concurrency=4,
    batch_size=4,  # Use the largest batch size that can fit in GPU memory.
    # Specify 1 GPU per model replica. Set to 0 if you are doing CPU inference.
    num_gpus=1,
)

验证并保存结果#

接下来,让我们获取一小部分批量,并通过可视化验证推理结果。

from torchvision.transforms.functional import convert_image_dtype, to_tensor

batch = ds.take_batch(batch_size=2)
for image, labels, boxes in zip(batch["image"], batch["labels"], batch["boxes"]):
    image = convert_image_dtype(to_tensor(image), torch.uint8)
    labels = [weights.meta["categories"][i] for i in labels]
    boxes = torch.from_numpy(boxes)
    img = to_pil_image(draw_bounding_boxes(
        image,
        boxes,
        labels=labels,
        colors="red",
        width=4,
    ))
    display(img)
2025-02-05 14:22:53,627	INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-02-05_14-22-49_425292_37149/logs/ray-data
2025-02-05 14:22:53,628	INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> ActorPoolMapOperator[ReadImage->Map(preprocess_image)->MapBatches(ObjectDetectionModel)] -> LimitOperator[limit=2]
2025-02-05 14:22:55,891	WARNING progress_bar.py:120 -- Truncating long operator name to 100 characters. To disable this behavior, set `ray.data.DataContext.get_current().DEFAULT_ENABLE_PROGRESS_BAR_NAME_TRUNCATION = False`.
../../_images/f6eabde66ffa1c2720d15d8f3b5c7b814f29c5617bf5676bf17dcd731fe73b9f.png ../../_images/0b3ebf95963769a84edfe37dd3ee262a6922c54d89fe21b3c3ed13b1533e1c2e.png

如果样本看起来不错,我们可以继续将结果保存到外部存储,例如 S3 或本地磁盘。有关所有支持的存储和文件格式,请参阅Ray Data 输入/输出

ds.write_parquet("local://tmp/inference_results")