在 Ray Serve 中使用 Triton Server 服务模型#

本指南展示了如何在 Ray Serve 中使用 NVIDIA Triton Server 构建一个带有 stable diffusion 模型服务应用。

准备工作#

安装#

建议使用已经安装了 Triton Server Python API 库的 nvcr.io/nvidia/tritonserver:23.12-py3 镜像,并在镜像内通过 pip install "ray[serve]" 安装 ray serve 库。

构建并导出模型#

对于此应用,编码器被导出为 ONNX 格式,而 stable diffusion 模型被导出为 TensorRT 引擎格式,该格式与 Triton Server 兼容。以下是导出模型为 ONNX 格式的示例。

import torch
from diffusers import AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer

prompt = "Draw a dog"
vae = AutoencoderKL.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="vae", use_auth_token=True
)

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

vae.forward = vae.decode
torch.onnx.export(
    vae,
    (torch.randn(1, 4, 64, 64), False),
    "vae.onnx",
    input_names=["latent_sample", "return_dict"],
    output_names=["sample"],
    dynamic_axes={
        "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
    },
    do_constant_folding=True,
    opset_version=14,
)

text_input = tokenizer(
    prompt,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)

torch.onnx.export(
    text_encoder,
    (text_input.input_ids.to(torch.int32)),
    "encoder.onnx",
    input_names=["input_ids"],
    output_names=["last_hidden_state", "pooler_output"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "sequence"},
    },
    opset_version=14,
    do_constant_folding=True,
)

从脚本中,输出文件为 vae.onnxencoder.onnx

导出 ONNX 模型后,将 ONNX 模型转换为 TensorRT 引擎序列化文件。(关于 trtexec CLI 的详细信息

trtexec --onnx=vae.onnx --saveEngine=vae.plan --minShapes=latent_sample:1x4x64x64 --optShapes=latent_sample:4x4x64x64 --maxShapes=latent_sample:8x4x64x64 --fp16

准备模型仓库#

Triton Server 需要一个模型仓库来存储模型,它是一个包含模型配置和模型文件的本地目录或远程 blob 存储(例如 AWS S3)。在我们的示例中,我们将使用本地目录作为模型仓库来保存所有模型文件。

model_repo/
├── stable_diffusion
│   ├── 1      └── model.py
│   └── config.pbtxt
├── text_encoder
│   ├── 1      └── model.onnx
│   └── config.pbtxt
└── vae
    ├── 1
       └── model.plan
    └── config.pbtxt

模型仓库包含三个模型:stable_diffusiontext_encodervae。每个模型都有一个 config.pbtxt 文件和一个模型文件。config.pbtxt 文件包含模型配置,用于描述模型类型和输入/输出格式。(你可以在此处了解更多关于模型配置文件的信息)。要获取我们示例的配置文件,你可以从此处下载。我们使用 1 作为每个模型的版本。模型文件保存在版本目录下。

在 Ray Serve 应用内启动 Triton Server#

在每个 Serve 副本中,都运行着一个 Triton Server 实例。API 接收模型仓库路径作为参数,并在副本初始化期间启动 Triton Serve 实例。模型可以在推理请求期间加载,加载的模型会在 Triton Server 实例中缓存。

以下是使用 Triton Server 服务模型的推理代码示例。(来源

import numpy
import requests
import tritonserver
from fastapi import FastAPI
from PIL import Image
from ray import serve


app = FastAPI()

@serve.deployment(ray_actor_options={"num_gpus": 1})
@serve.ingress(app)
class TritonDeployment:
    def __init__(self):
        self._triton_server = tritonserver

        model_repository = ["/workspace/models"]

        self._triton_server = tritonserver.Server(
            model_repository=model_repository,
            model_control_mode=tritonserver.ModelControlMode.EXPLICIT,
            log_info=False,
        )
        self._triton_server.start(wait_until_ready=True)

    @app.get("/generate")
    def generate(self, prompt: str, filename: str = "generated_image.jpg") -> None:
        if not self._triton_server.model("stable_diffusion").ready():
            try:
                self._triton_server.load("text_encoder")
                self._triton_server.load("vae")
                self._stable_diffusion = self._triton_server.load("stable_diffusion")
                if not self._stable_diffusion.ready():
                    raise Exception("Model not ready")
            except Exception as error:
                print(f"Error can't load stable diffusion model, {error}")
                return

        for response in self._stable_diffusion.infer(inputs={"prompt": [[prompt]]}):
            generated_image = (
                numpy.from_dlpack(response.outputs["generated_image"])
                .squeeze()
                .astype(numpy.uint8)
            )

            image_ = Image.fromarray(generated_image)
            image_.save(filename)


if __name__ == "__main__":
    # Deploy the deployment.
    serve.run(TritonDeployment.bind())

    # Query the deployment.
    requests.get(
        "http://localhost:8000/generate",
        params={"prompt": "dogs in new york, realistic, 4k, photograph"},
    )

将上述代码保存到名为 triton_serve.py 的文件中,然后运行 python triton_serve.py 启动服务器并发送分类请求。运行上述代码后,你应该会看到生成的图片 generated_image.jpg。快看看吧! image

注意

你也可以使用远程模型仓库,例如 AWS S3,来存储模型文件。要使用远程模型仓库,你需要将 model_repository 变量设置为远程模型仓库路径。例如 model_repository = s3://<bucket_name>/<model_repository_path>

如果你发现任何 bug 或有任何建议,请在 GitHub 上提交 Issue 告诉我们。