Serve a Text Classification Model

This example uses a DistilBERT model and Ray Serve to build an IMDB review classification application.

To run this example, install the following dependencies:

pip install "ray[serve]" requests torch transformers

This example uses the distilbert-base-uncased model and FastAPI. Save the following Serve code to a file named distilbert_app.py:

from fastapi import FastAPI
import torch
from transformers import pipeline

from ray import serve
from ray.serve.handle import DeploymentHandle


app = FastAPI()


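@serve.deployment(num_replicas=1)
@serve.ingress(app)
class APIIngress:
    """HTTP entry point: receives requests through FastAPI and forwards them to the model."""

    def __init__(self, distilbert_model_handle: DeploymentHandle) -> None:
        self.handle = distilbert_model_handle

    @app.get("/classify")
    async def classify(self, sentence: str):
        # Run DistilBertModel.classify on a model replica and await the result.
        return await self.handle.classify.remote(sentence)


@serve.deployment(
    ray_actor_options={"num_gpus": 1},
    # Scale between 0 and 2 replicas with traffic; at 0, no GPU is held.
    autoscaling_config={"min_replicas": 0, "max_replicas": 2},
)
class DistilBertModel:
    def __init__(self):
        # Load the pipeline once per replica. Ray assigns each replica one GPU,
        # which the replica sees as cuda:0.
        self.classifier = pipeline(
            "sentiment-analysis",
            model="distilbert-base-uncased",
            framework="pt",
            # Transformers requires you to pass device with index
            device=torch.device("cuda:0"),
        )

    def classify(self, sentence: str):
        return self.classifier(sentence)


# Bind the model deployment and pass its handle to the ingress deployment.
entrypoint = APIIngress.bind(DistilBertModel.bind())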

Use serve run distilbert_app:entrypoint to start the Serve application.

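Note

The autoscaling config sets min_replicas to 0, which means the deployment starts with no DistilBertModel replicas. Serve spawns replicas only when a request arrives. After a period with no incoming requests, Serve scales DistilBertModel back down to 0 replicas to save GPU resources.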
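To make the scale-to-zero behavior concrete, here is a toy Python model of it. It is not Ray Serve's autoscaling algorithm: ToyAutoscaler, target_per_replica, and idle_timeout_s are hypothetical names chosen for this illustration, and the 600-second idle window is an arbitrary assumption. Ray Serve controls the real behavior through autoscaling_config (for example, downscale_delay_s); check the Ray Serve autoscaling documentation for the fields and defaults in your version.

```python
import math
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyAutoscaler:
    """Hypothetical scale-to-zero model; NOT Ray Serve's actual autoscaler."""

    min_replicas: int = 0
    max_replicas: int = 2
    target_per_replica: int = 1    # hypothetical: requests one replica should handle
    idle_timeout_s: float = 600.0  # hypothetical idle window before scaling to the floor
    replicas: int = 0
    last_request_s: Optional[float] = None

    def observe(self, now_s: float, ongoing_requests: int) -> int:
        """Return the replica count after observing the load at time now_s."""
        if ongoing_requests > 0:
            # Traffic present: size to the load, clamped to [min_replicas, max_replicas].
            self.last_request_s = now_s
            wanted = math.ceil(ongoing_requests / self.target_per_replica)
            self.replicas = max(self.min_replicas, min(self.max_replicas, wanted))
        elif (
            self.last_request_s is not None
            and now_s - self.last_request_s >= self.idle_timeout_s
        ):
            # Idle long enough: release replicas (and their GPUs) down to the floor.
            self.replicas = self.min_replicas
        return self.replicas


scaler = ToyAutoscaler()
print(scaler.observe(0, 0))    # 0: no traffic yet, so no replicas
print(scaler.observe(5, 1))    # 1: the first request triggers a replica
print(scaler.observe(6, 5))    # 2: capped at max_replicas
print(scaler.observe(300, 0))  # 2: idle, but not for long enough
print(scaler.observe(700, 0))  # 0: idle for at least 600 s since t=6
```

The toy resizes to the current load immediately and waits only before dropping to the floor; Ray Serve applies its own delays when scaling in either direction.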

You should see the following messages in the logs:

(ServeController pid=362, ip=10.0.44.233) INFO 2023-03-08 16:44:57,579 controller 362 http_state.py:129 - Starting HTTP proxy with name 'SERVE_CONTROLLER_ACTOR:SERVE_PROXY_ACTOR-7396d5a9efdb59ee01b7befba448433f6c6fc734cfa5421d415da1b3' on node '7396d5a9efdb59ee01b7befba448433f6c6fc734cfa5421d415da1b3' listening on '127.0.0.1:8000'
(ServeController pid=362, ip=10.0.44.233) INFO 2023-03-08 16:44:57,588 controller 362 http_state.py:129 - Starting HTTP proxy with name 'SERVE_CONTROLLER_ACTOR:SERVE_PROXY_ACTOR-a30ea53938547e0bf88ce8672e578f0067be26a7e26d23465c46300b' on node 'a30ea53938547e0bf88ce8672e578f0067be26a7e26d23465c46300b' listening on '127.0.0.1:8000'
(ProxyActor pid=439, ip=10.0.44.233) INFO:     Started server process [439]
(ProxyActor pid=5779) INFO:     Started server process [5779]
(ServeController pid=362, ip=10.0.44.233) INFO 2023-03-08 16:44:59,362 controller 362 deployment_state.py:1333 - Adding 1 replica to deployment 'APIIngress'.
2023-03-08 16:45:01,316 SUCC <string>:93 -- Deployed Serve app successfully.

Use the following code to send a request:

import requests

prompt = (
    "This was a masterpiece. Not completely faithful to the books, but enthralling "
    "from beginning to end. Might be my favorite of the three."
)
# Pass the sentence through params= so requests URL-encodes the query string.
resp = requests.get("http://127.0.0.1:8000/classify", params={"sentence": prompt})
print(resp.status_code, resp.json())
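Encoding the query string properly matters for sentences that contain reserved URL characters. Joining words with %20 by hand escapes only spaces, so characters such as & or # change how the URL is parsed. The sketch below uses only the Python standard library to compare that approach with urlencode, which is roughly what requests applies to params=. The example sentence and the helper names build_url_by_hand, build_url_encoded, and decoded_sentence are hypothetical.

```python
from urllib.parse import parse_qs, urlencode, urlsplit

BASE_URL = "http://127.0.0.1:8000/classify"


def build_url_by_hand(sentence: str) -> str:
    # Escapes spaces only; reserved characters such as & and # pass through raw.
    return f"{BASE_URL}?sentence={'%20'.join(sentence.split(' '))}"


def build_url_encoded(sentence: str) -> str:
    # urlencode percent-escapes every reserved character in the value.
    return f"{BASE_URL}?{urlencode({'sentence': sentence})}"


def decoded_sentence(url: str) -> list:
    # Parse the query string the way a server would.
    return parse_qs(urlsplit(url).query).get("sentence", [])


prompt = "Tom & Jerry: 10/10, would watch #again"
print(decoded_sentence(build_url_by_hand(prompt)))  # → ['Tom ']
print(decoded_sentence(build_url_encoded(prompt)) == [prompt])  # → True
```

In the hand-built URL, the raw & ends the sentence parameter early and the raw # starts a URL fragment, so the server would receive only "Tom ".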

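The client prints the response status code and the JSON body, which contains the predicted label (LABEL_1 here) and that label's score. Because distilbert-base-uncased is not fine-tuned for sentiment analysis, the pipeline reports generic label names such as LABEL_1 rather than positive or negative.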

200 [{'label': 'LABEL_1', 'score': 0.9994940757751465}]
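The response body is a JSON list of objects with label and score keys. Below is a minimal sketch of how a client might pick the highest-scoring prediction from such a body. The top_prediction helper is hypothetical, and the body is copied from the sample output above.

```python
import json

# Body copied from the sample output: a list of {"label": ..., "score": ...} objects.
body = json.loads('[{"label": "LABEL_1", "score": 0.9994940757751465}]')


def top_prediction(results: list) -> tuple:
    """Return (label, score) for the highest-scoring entry."""
    best = max(results, key=lambda item: item["score"])
    return best["label"], best["score"]


label, score = top_prediction(body)
print(f"{label} ({score:.3f})")  # → LABEL_1 (0.999)
```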