Set up a gRPC service

This section helps you understand how to:

  • Build a user-defined gRPC service and protobuf

  • Start Serve with gRPC enabled

  • Deploy gRPC applications

  • Send gRPC requests to Serve deployments

  • Check proxy health

  • Work with gRPC metadata

  • Use streaming and model composition

  • Handle errors

  • Use gRPC context

Define a gRPC service

Running a gRPC server starts with defining gRPC services, RPC methods, and protobufs, similar to the ones below.

// user_defined_protos.proto

syntax = "proto3";

option java_multiple_files = true;
option java_package = "io.ray.examples.user_defined_protos";
option java_outer_classname = "UserDefinedProtos";

package userdefinedprotos;

message UserDefinedMessage {
  string name = 1;
  string origin = 2;
  int64 num = 3;
}

message UserDefinedResponse {
  string greeting = 1;
  int64 num = 2;
}

message UserDefinedMessage2 {}

message UserDefinedResponse2 {
  string greeting = 1;
}

message ImageData {
  string url = 1;
  string filename = 2;
}

message ImageClass {
  repeated string classes = 1;
  repeated float probabilities = 2;
}

service UserDefinedService {
  rpc __call__(UserDefinedMessage) returns (UserDefinedResponse);
  rpc Multiplexing(UserDefinedMessage2) returns (UserDefinedResponse2);
  rpc Streaming(UserDefinedMessage) returns (stream UserDefinedResponse);
}

service ImageClassificationService {
  rpc Predict(ImageData) returns (ImageClass);
}

This example creates a file named user_defined_protos.proto with two gRPC services: UserDefinedService and ImageClassificationService. UserDefinedService defines three RPC methods: __call__, Multiplexing, and Streaming. ImageClassificationService defines one RPC method: Predict. The input and output types of each RPC method are also explicitly defined.

After defining the .proto services, use grpcio-tools to compile Python code for those services. An example command looks like the following:

python -m grpc_tools.protoc -I=. --python_out=. --grpc_python_out=. ./user_defined_protos.proto

The compilation generates two files: user_defined_protos_pb2.py and user_defined_protos_pb2_grpc.py.

For more details on grpcio-tools, see https://grpc.org.cn/docs/languages/python/basics/#generating-client-and-server-code

Note

Make sure the generated files are in the same directory where the Ray cluster is running so that Serve can import them when starting the proxies.

Start Serve with gRPC enabled

The serve start CLI, the ray.serve.start API, and the Serve config file all support starting Serve with a gRPC proxy. Two options are related to Serve's gRPC proxy: grpc_port and grpc_servicer_functions. grpc_port is the port that the gRPC proxies listen on; it defaults to 9000. grpc_servicer_functions is a list of import paths for the gRPC add_servicer_to_server functions to add to the gRPC proxy. It also serves as the flag for whether to start the gRPC server; the default is an empty list, meaning no gRPC server is started.

ray start --head
serve start \
  --grpc-port 9000 \
  --grpc-servicer-functions user_defined_protos_pb2_grpc.add_UserDefinedServiceServicer_to_server \
  --grpc-servicer-functions user_defined_protos_pb2_grpc.add_ImageClassificationServiceServicer_to_server
from ray import serve
from ray.serve.config import gRPCOptions


grpc_port = 9000
grpc_servicer_functions = [
    "user_defined_protos_pb2_grpc.add_UserDefinedServiceServicer_to_server",
    "user_defined_protos_pb2_grpc.add_ImageClassificationServiceServicer_to_server",
]
serve.start(
    grpc_options=gRPCOptions(
        port=grpc_port,
        grpc_servicer_functions=grpc_servicer_functions,
    ),
)
# config.yaml
grpc_options:
  port: 9000
  grpc_servicer_functions:
    - user_defined_protos_pb2_grpc.add_UserDefinedServiceServicer_to_server
    - user_defined_protos_pb2_grpc.add_ImageClassificationServiceServicer_to_server

applications:
  - name: app1
    route_prefix: /app1
    import_path: test_deployment_v2:g
    runtime_env: {}

  - name: app2
    route_prefix: /app2
    import_path: test_deployment_v2:g2
    runtime_env: {}
# Start Serve with above config file.
serve run config.yaml

Deploy gRPC applications

gRPC applications in Serve work similarly to HTTP applications. The only differences are that the inputs and outputs of the methods need to match what's defined in the .proto file, and the method names of the application must match the predefined RPC method names exactly (case sensitive). For example, to deploy UserDefinedService with the __call__ RPC method, the method name must be __call__, the input type must be UserDefinedMessage, and the output type must be UserDefinedResponse. Serve passes the protobuf object into the method and expects a protobuf object back from it.

Example deployment:

import time

from typing import Generator
from user_defined_protos_pb2 import (
    UserDefinedMessage,
    UserDefinedMessage2,
    UserDefinedResponse,
    UserDefinedResponse2,
)

import ray
from ray import serve


@serve.deployment
class GrpcDeployment:
    def __call__(self, user_message: UserDefinedMessage) -> UserDefinedResponse:
        greeting = f"Hello {user_message.name} from {user_message.origin}"
        num = user_message.num * 2
        user_response = UserDefinedResponse(
            greeting=greeting,
            num=num,
        )
        return user_response

    @serve.multiplexed(max_num_models_per_replica=1)
    async def get_model(self, model_id: str) -> str:
        return f"loading model: {model_id}"

    async def Multiplexing(
        self, user_message: UserDefinedMessage2
    ) -> UserDefinedResponse2:
        model_id = serve.get_multiplexed_model_id()
        model = await self.get_model(model_id)
        user_response = UserDefinedResponse2(
            greeting=f"Method2 called model, {model}",
        )
        return user_response

    def Streaming(
        self, user_message: UserDefinedMessage
    ) -> Generator[UserDefinedResponse, None, None]:
        for i in range(10):
            greeting = f"{i}: Hello {user_message.name} from {user_message.origin}"
            num = user_message.num * 2 + i
            user_response = UserDefinedResponse(
                greeting=greeting,
                num=num,
            )
            yield user_response

            time.sleep(0.1)


g = GrpcDeployment.bind()

Deploy the application:

app1 = "app1"
serve.run(target=g, name=app1, route_prefix=f"/{app1}")

Note

As of Ray 2.7.0, route_prefix is still a required field due to shared code paths with HTTP. A future release will make it optional for gRPC.

Send gRPC requests to Serve deployments

Sending a gRPC request to a Serve deployment is similar to sending a gRPC request to any other gRPC server. Create a gRPC channel and stub, then call the RPC method on the stub with the appropriate input. The output is the protobuf object that your Serve application returns.

Send a gRPC request:

import grpc
from user_defined_protos_pb2_grpc import UserDefinedServiceStub
from user_defined_protos_pb2 import UserDefinedMessage


channel = grpc.insecure_channel("localhost:9000")
stub = UserDefinedServiceStub(channel)
request = UserDefinedMessage(name="foo", num=30, origin="bar")

response, call = stub.__call__.with_call(request=request)
print(f"status code: {call.code()}")  # grpc.StatusCode.OK
print(f"greeting: {response.greeting}")  # "Hello foo from bar"
print(f"num: {response.num}")  # 60

Read more about Python gRPC clients: https://grpc.org.cn/docs/languages/python/basics/#client

Check proxy health

Similar to HTTP's /-/routes and /-/healthz endpoints, Serve also provides gRPC service methods for health checks.

  • /ray.serve.RayServeAPIService/ListApplications lists all applications deployed in Serve.

  • /ray.serve.RayServeAPIService/Healthz checks the proxy's health. It returns an OK status and a "success" message if the proxy is healthy.

The service methods and protobufs are defined as follows:

message ListApplicationsRequest {}

message ListApplicationsResponse {
  repeated string application_names = 1;
}

message HealthzRequest {}

message HealthzResponse {
  string message = 1;
}

service RayServeAPIService {
  rpc ListApplications(ListApplicationsRequest) returns (ListApplicationsResponse);
  rpc Healthz(HealthzRequest) returns (HealthzResponse);
}

You can call the service methods with the following code:

import grpc
from ray.serve.generated.serve_pb2_grpc import RayServeAPIServiceStub
from ray.serve.generated.serve_pb2 import HealthzRequest, ListApplicationsRequest


channel = grpc.insecure_channel("localhost:9000")
stub = RayServeAPIServiceStub(channel)
request = ListApplicationsRequest()
response = stub.ListApplications(request=request)
print(f"Applications: {response.application_names}")  # ["app1"]

request = HealthzRequest()
response = stub.Healthz(request=request)
print(f"Health: {response.message}")  # "success"

Note

Serve provides the RayServeAPIServiceStub stub, and the HealthzRequest and ListApplicationsRequest protobufs, for your use. You don't need to generate them from a proto file; they are shown above for your reference.

Work with gRPC metadata

Just like HTTP headers, gRPC also supports passing request-related information to the server via metadata. You can pass metadata to Serve's gRPC proxy, and Serve knows how to parse and use it. Serve also passes trailing metadata back to the client.

Metadata keys that Serve accepts:

  • application: the name of the Serve application to route to. If it's not passed and only one application is deployed, Serve routes to that application automatically.

  • request_id: the request ID used to track the request.

  • multiplexed_model_id: the model ID used for model multiplexing.

Trailing metadata keys that Serve returns:

  • request_id: the request ID used to track the request.

Example of using metadata:

import grpc
from user_defined_protos_pb2_grpc import UserDefinedServiceStub
from user_defined_protos_pb2 import UserDefinedMessage2


channel = grpc.insecure_channel("localhost:9000")
stub = UserDefinedServiceStub(channel)
request = UserDefinedMessage2()
app_name = "app1"
request_id = "123"
multiplexed_model_id = "999"
metadata = (
    ("application", app_name),
    ("request_id", request_id),
    ("multiplexed_model_id", multiplexed_model_id),
)

response, call = stub.Multiplexing.with_call(request=request, metadata=metadata)
print(f"greeting: {response.greeting}")  # "Method2 called model, loading model: 999"
for key, value in call.trailing_metadata():
    print(f"trailing metadata key: {key}, value {value}")  # "request_id: 123"
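The metadata plumbing above can be factored into two small helpers. These are hypothetical conveniences, not part of the Serve API; they only assemble the (key, value) tuples that the proxy understands and pick request_id out of the trailing metadata:

```python
from typing import Optional, Sequence, Tuple

Metadata = Tuple[Tuple[str, str], ...]


def build_serve_metadata(
    application: Optional[str] = None,
    request_id: Optional[str] = None,
    multiplexed_model_id: Optional[str] = None,
) -> Metadata:
    """Build the metadata tuple with the keys Serve's gRPC proxy accepts."""
    items = []
    if application is not None:
        items.append(("application", application))
    if request_id is not None:
        items.append(("request_id", request_id))
    if multiplexed_model_id is not None:
        items.append(("multiplexed_model_id", multiplexed_model_id))
    return tuple(items)


def trailing_request_id(trailing: Sequence[Tuple[str, str]]) -> Optional[str]:
    """Pull the request_id Serve returns in the trailing metadata, if any."""
    for key, value in trailing:
        if key == "request_id":
            return value
    return None


print(build_serve_metadata(application="app1", multiplexed_model_id="999"))
print(trailing_request_id([("request_id", "123")]))  # 123
```

You could then pass metadata=build_serve_metadata(application="app1", multiplexed_model_id="999") to with_call, and feed call.trailing_metadata() into trailing_request_id.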

Use streaming and model composition

gRPC proxies have feature parity with HTTP proxies. Below are more examples of getting streaming responses and doing model composition with the gRPC proxy.

Streaming

The "app1" application deployed above includes the Streaming method. The following code gets a streaming response.

import grpc
from user_defined_protos_pb2_grpc import UserDefinedServiceStub
from user_defined_protos_pb2 import UserDefinedMessage


channel = grpc.insecure_channel("localhost:9000")
stub = UserDefinedServiceStub(channel)
request = UserDefinedMessage(name="foo", num=30, origin="bar")
metadata = (("application", "app1"),)

responses = stub.Streaming(request=request, metadata=metadata)
for response in responses:
    print(f"greeting: {response.greeting}")  # greeting: n: Hello foo from bar
    print(f"num: {response.num}")  # num: 60 + n

Model composition

Assume we have the following deployments. ImageDownloader and DataPreprocessor are two separate steps that download and process the image before PyTorch runs inference on it. The ImageClassifier deployment initializes the model, calls both ImageDownloader and DataPreprocessor, and feeds the results into the resnet model to get the classes and probabilities of the given image.

import requests
import torch
from typing import List
from PIL import Image
from io import BytesIO
from torchvision import transforms
from user_defined_protos_pb2 import (
    ImageClass,
    ImageData,
)

from ray import serve
from ray.serve.handle import DeploymentHandle


@serve.deployment
class ImageClassifier:
    def __init__(
        self,
        _image_downloader: DeploymentHandle,
        _data_preprocessor: DeploymentHandle,
    ):
        self._image_downloader = _image_downloader
        self._data_preprocessor = _data_preprocessor
        self.model = torch.hub.load(
            "pytorch/vision:v0.10.0", "resnet18", pretrained=True
        )
        self.model.eval()
        self.categories = self._image_labels()

    def _image_labels(self) -> List[str]:
        categories = []
        url = (
            "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
        )
        labels = requests.get(url).text
        for label in labels.split("\n"):
            categories.append(label.strip())
        return categories

    async def Predict(self, image_data: ImageData) -> ImageClass:
        # Download image
        image = await self._image_downloader.remote(image_data.url)

        # Preprocess image
        input_batch = await self._data_preprocessor.remote(image)
        # Predict image
        with torch.no_grad():
            output = self.model(input_batch)

        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        return self.process_model_outputs(probabilities)

    def process_model_outputs(self, probabilities: torch.Tensor) -> ImageClass:
        image_classes = []
        image_probabilities = []
        # Show top categories per image
        top5_prob, top5_catid = torch.topk(probabilities, 5)
        for i in range(top5_prob.size(0)):
            image_classes.append(self.categories[top5_catid[i]])
            image_probabilities.append(top5_prob[i].item())

        return ImageClass(
            classes=image_classes,
            probabilities=image_probabilities,
        )


@serve.deployment
class ImageDownloader:
    def __call__(self, image_url: str):
        image_bytes = requests.get(image_url).content
        return Image.open(BytesIO(image_bytes)).convert("RGB")


@serve.deployment
class DataPreprocessor:
    def __init__(self):
        self.preprocess = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def __call__(self, image: Image):
        input_tensor = self.preprocess(image)
        return input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model


image_downloader = ImageDownloader.bind()
data_preprocessor = DataPreprocessor.bind()
g2 = ImageClassifier.options(name="grpc-image-classifier").bind(
    image_downloader, data_preprocessor
)

We can deploy the application with the following code:

app2 = "app2"
serve.run(target=g2, name=app2, route_prefix=f"/{app2}")

The client code to call the application looks like the following:

import grpc
from user_defined_protos_pb2_grpc import ImageClassificationServiceStub
from user_defined_protos_pb2 import ImageData


channel = grpc.insecure_channel("localhost:9000")
stub = ImageClassificationServiceStub(channel)
request = ImageData(url="https://github.com/pytorch/hub/raw/master/images/dog.jpg")
metadata = (("application", "app2"),)  # Make sure application metadata is passed.

response, call = stub.Predict.with_call(request=request, metadata=metadata)
print(f"status code: {call.code()}")  # grpc.StatusCode.OK
print(f"Classes: {response.classes}")  # ['Samoyed', ...]
print(f"Probabilities: {response.probabilities}")  # [0.8846230506896973, ...]

Note

At this point, two applications are running on Serve: "app1" and "app2". When more than one application is running, you need to pass application in the metadata so that Serve knows which application to route to.

Handle errors

Similar to any other gRPC server, a request throws a grpc.RpcError when the response code is not "OK". Put your request code in a try-except block and handle the error accordingly.

import grpc
from user_defined_protos_pb2_grpc import UserDefinedServiceStub
from user_defined_protos_pb2 import UserDefinedMessage


channel = grpc.insecure_channel("localhost:9000")
stub = UserDefinedServiceStub(channel)
request = UserDefinedMessage(name="foo", num=30, origin="bar")

try:
    response = stub.__call__(request=request)
except grpc.RpcError as rpc_error:
    print(f"status code: {rpc_error.code()}")  # StatusCode.NOT_FOUND
    print(f"details: {rpc_error.details()}")  # Application metadata not set...

Serve uses the following gRPC error codes:

  • NOT_FOUND: multiple applications are deployed to Serve, and the application name either isn't passed in the metadata or doesn't match any deployed application.

  • UNAVAILABLE: only on the health check methods, when the proxy is in a draining state. When a health check throws UNAVAILABLE, the health check failed on that node and you should no longer route to it.

  • DEADLINE_EXCEEDED: the request took longer than the timeout setting and was cancelled.

  • INTERNAL: another, unhandled error occurred during the request.
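One way for a client to act on these codes is a small retry policy. The helper below is a hypothetical sketch, not part of Serve: it treats the draining-proxy UNAVAILABLE and DEADLINE_EXCEEDED as retryable (against another node, or with a longer deadline), while NOT_FOUND and INTERNAL indicate problems to surface to the caller:

```python
import grpc  # pip install grpcio

# Hypothetical client-side policy for the Serve error codes listed above.
RETRYABLE = {
    grpc.StatusCode.UNAVAILABLE,        # proxy draining: retry on another node
    grpc.StatusCode.DEADLINE_EXCEEDED,  # timed out: retry with a longer deadline
}


def should_retry(code: grpc.StatusCode) -> bool:
    # NOT_FOUND (bad or missing application metadata) and INTERNAL (unhandled
    # server error) are not retryable: fix the call or inspect the details.
    return code in RETRYABLE


print(should_retry(grpc.StatusCode.UNAVAILABLE))  # True
print(should_retry(grpc.StatusCode.NOT_FOUND))    # False
```

In a try-except like the one above, you would check should_retry(rpc_error.code()) before giving up on the request.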

Use gRPC context

Serve provides a gRPC context object to the deployment replica for getting request information as well as setting response metadata such as code and details. If the handler function defines a grpc_context argument, Serve passes a RayServegRPCContext object in for each request. Below is an example of setting a custom status code, details, and trailing metadata.

from user_defined_protos_pb2 import UserDefinedMessage, UserDefinedResponse

from ray import serve
from ray.serve.grpc_util import RayServegRPCContext

import grpc
from typing import Tuple


@serve.deployment
class GrpcDeployment:
    def __init__(self):
        self.nums = {}

    def num_lookup(self, name: str) -> Tuple[int, grpc.StatusCode, str]:
        if name not in self.nums:
            self.nums[name] = len(self.nums)
            code = grpc.StatusCode.INVALID_ARGUMENT
            message = f"{name} not found, adding to nums."
        else:
            code = grpc.StatusCode.OK
            message = f"{name} found."
        return self.nums[name], code, message

    def __call__(
        self,
        user_message: UserDefinedMessage,
        grpc_context: RayServegRPCContext,  # to use grpc context, add this kwarg
    ) -> UserDefinedResponse:
        greeting = f"Hello {user_message.name} from {user_message.origin}"
        num, code, message = self.num_lookup(user_message.name)

        # Set custom code, details, and trailing metadata.
        grpc_context.set_code(code)
        grpc_context.set_details(message)
        grpc_context.set_trailing_metadata([("num", str(num))])

        user_response = UserDefinedResponse(
            greeting=greeting,
            num=num,
        )
        return user_response


g = GrpcDeployment.bind()
app1 = "app1"
serve.run(target=g, name=app1, route_prefix=f"/{app1}")

The client code to fetch those attributes is defined as follows.

import grpc
from user_defined_protos_pb2_grpc import UserDefinedServiceStub
from user_defined_protos_pb2 import UserDefinedMessage


channel = grpc.insecure_channel("localhost:9000")
stub = UserDefinedServiceStub(channel)
request = UserDefinedMessage(name="foo", num=30, origin="bar")
metadata = (("application", "app1"),)

# First call is going to page miss and return INVALID_ARGUMENT status code.
try:
    response, call = stub.__call__.with_call(request=request, metadata=metadata)
except grpc.RpcError as rpc_error:
    assert rpc_error.code() == grpc.StatusCode.INVALID_ARGUMENT
    assert rpc_error.details() == "foo not found, adding to nums."
    assert any(
        [key == "num" and value == "0" for key, value in rpc_error.trailing_metadata()]
    )
    assert any([key == "request_id" for key, _ in rpc_error.trailing_metadata()])

# Second call is going to page hit and return OK status code.
response, call = stub.__call__.with_call(request=request, metadata=metadata)
assert call.code() == grpc.StatusCode.OK
assert call.details() == "foo found."
assert any([key == "num" and value == "0" for key, value in call.trailing_metadata()])
assert any([key == "request_id" for key, _ in call.trailing_metadata()])

Note

If the handler raises an unhandled exception, Serve returns an INTERNAL error code with the stack trace in the details, regardless of what code and details are set in the RayServegRPCContext object.