Set up a gRPC service#

This section helps you understand how to:

  • Build user-defined gRPC services and protobufs

  • Start Serve with gRPC enabled

  • Deploy gRPC applications

  • Send gRPC requests to Serve deployments

  • Check proxy health

  • Use gRPC metadata

  • Use streaming and model composition

  • Handle errors

  • Use gRPC context

Define a gRPC service#

Running a gRPC server starts with defining a gRPC service, its RPC methods, and protobuf messages, similar to the example below.

// user_defined_protos.proto

syntax = "proto3";

option java_multiple_files = true;
option java_package = "io.ray.examples.user_defined_protos";
option java_outer_classname = "UserDefinedProtos";

package userdefinedprotos;

message UserDefinedMessage {
  string name = 1;
  string origin = 2;
  int64 num = 3;
}

message UserDefinedResponse {
  string greeting = 1;
  int64 num = 2;
}

message UserDefinedMessage2 {}

message UserDefinedResponse2 {
  string greeting = 1;
}

message ImageData {
  string url = 1;
  string filename = 2;
}

message ImageClass {
  repeated string classes = 1;
  repeated float probabilities = 2;
}

service UserDefinedService {
  rpc __call__(UserDefinedMessage) returns (UserDefinedResponse);
  rpc Multiplexing(UserDefinedMessage2) returns (UserDefinedResponse2);
  rpc Streaming(UserDefinedMessage) returns (stream UserDefinedResponse);
}

service ImageClassificationService {
  rpc Predict(ImageData) returns (ImageClass);
}

This example creates a file named user_defined_protos.proto with two gRPC services: UserDefinedService and ImageClassificationService. UserDefinedService has three RPC methods: __call__, Multiplexing, and Streaming. ImageClassificationService has one RPC method: Predict. Each RPC method also has its own specifically defined input and output types.

After defining the .proto services, use grpcio-tools to compile Python code for them. For example:

python -m grpc_tools.protoc -I=. --python_out=. --grpc_python_out=. ./user_defined_protos.proto

This command generates two files: user_defined_protos_pb2.py and user_defined_protos_pb2_grpc.py.

For more details on grpcio-tools, see https://grpc.org.cn/docs/languages/python/basics/#generating-client-and-server-code

Note

Make sure the generated files are in the same directory that the Ray cluster runs in, so Serve can import them when starting the proxy.

Start Serve with gRPC enabled#

The serve start CLI, the ray.serve.start API, and the Serve config file all support starting Serve with a gRPC proxy. Two options are relevant to Serve's gRPC proxy: grpc_port and grpc_servicer_functions. grpc_port is the port the gRPC proxy listens on, which defaults to 9000. grpc_servicer_functions is a list of import paths for the gRPC add_servicer_to_server functions to add to the gRPC proxy. It also acts as the flag that determines whether to start the gRPC server. The default is an empty list, meaning no gRPC server is started.

ray start --head
serve start \
  --grpc-port 9000 \
  --grpc-servicer-functions user_defined_protos_pb2_grpc.add_UserDefinedServiceServicer_to_server \
  --grpc-servicer-functions user_defined_protos_pb2_grpc.add_ImageClassificationServiceServicer_to_server
from ray import serve
from ray.serve.config import gRPCOptions


grpc_port = 9000
grpc_servicer_functions = [
    "user_defined_protos_pb2_grpc.add_UserDefinedServiceServicer_to_server",
    "user_defined_protos_pb2_grpc.add_ImageClassificationServiceServicer_to_server",
]
serve.start(
    grpc_options=gRPCOptions(
        port=grpc_port,
        grpc_servicer_functions=grpc_servicer_functions,
    ),
)
# config.yaml
grpc_options:
  port: 9000
  grpc_servicer_functions:
    - user_defined_protos_pb2_grpc.add_UserDefinedServiceServicer_to_server
    - user_defined_protos_pb2_grpc.add_ImageClassificationServiceServicer_to_server

applications:
  - name: app1
    route_prefix: /app1
    import_path: test_deployment_v2:g
    runtime_env: {}

  - name: app2
    route_prefix: /app2
    import_path: test_deployment_v2:g2
    runtime_env: {}
# Start Serve with above config file.
serve run config.yaml

Deploy gRPC applications#

gRPC applications in Serve work similarly to HTTP applications. The only differences are that the inputs and outputs of the methods need to match what's defined in the .proto file, and the methods of the application need to exactly match the predefined RPC methods (case sensitive). For example, to deploy UserDefinedService with the __call__ method, the method name needs to be __call__, the input type needs to be UserDefinedMessage, and the output type needs to be UserDefinedResponse. Serve passes the protobuf object into the method and expects a protobuf object back from it.

Example deployment:

import time

from typing import Generator
from user_defined_protos_pb2 import (
    UserDefinedMessage,
    UserDefinedMessage2,
    UserDefinedResponse,
    UserDefinedResponse2,
)

import ray
from ray import serve


@serve.deployment
class GrpcDeployment:
    def __call__(self, user_message: UserDefinedMessage) -> UserDefinedResponse:
        greeting = f"Hello {user_message.name} from {user_message.origin}"
        num = user_message.num * 2
        user_response = UserDefinedResponse(
            greeting=greeting,
            num=num,
        )
        return user_response

    @serve.multiplexed(max_num_models_per_replica=1)
    async def get_model(self, model_id: str) -> str:
        return f"loading model: {model_id}"

    async def Multiplexing(
        self, user_message: UserDefinedMessage2
    ) -> UserDefinedResponse2:
        model_id = serve.get_multiplexed_model_id()
        model = await self.get_model(model_id)
        user_response = UserDefinedResponse2(
            greeting=f"Method2 called model, {model}",
        )
        return user_response

    def Streaming(
        self, user_message: UserDefinedMessage
    ) -> Generator[UserDefinedResponse, None, None]:
        for i in range(10):
            greeting = f"{i}: Hello {user_message.name} from {user_message.origin}"
            num = user_message.num * 2 + i
            user_response = UserDefinedResponse(
                greeting=greeting,
                num=num,
            )
            yield user_response

            time.sleep(0.1)


g = GrpcDeployment.bind()

Deploy the application:

app1 = "app1"
serve.run(target=g, name=app1, route_prefix=f"/{app1}")

Note

As of Ray 2.7.0, route_prefix is still a required field because it shares a code path with HTTP. A future release will make route_prefix optional for gRPC.

Send gRPC requests to Serve deployments#

Sending a gRPC request to a Serve deployment is similar to sending a request to any other gRPC server. Create a gRPC channel and stub, then call the RPC method on the stub with the appropriate input. The output is the protobuf object that the Serve application returns.

Send a gRPC request:

import grpc
from user_defined_protos_pb2_grpc import UserDefinedServiceStub
from user_defined_protos_pb2 import UserDefinedMessage


channel = grpc.insecure_channel("localhost:9000")
stub = UserDefinedServiceStub(channel)
request = UserDefinedMessage(name="foo", num=30, origin="bar")

response, call = stub.__call__.with_call(request=request)
print(f"status code: {call.code()}")  # grpc.StatusCode.OK
print(f"greeting: {response.greeting}")  # "Hello foo from bar"
print(f"num: {response.num}")  # 60

Learn more about working with gRPC clients in Python: https://grpc.org.cn/docs/languages/python/basics/#client

Check proxy health#

Similar to HTTP's /-/routes and /-/healthz endpoints, Serve also provides gRPC service methods for health checking.

  • /ray.serve.RayServeAPIService/ListApplications lists all the applications deployed in Serve.

  • /ray.serve.RayServeAPIService/Healthz checks the proxy's health. It returns an OK status and a "success" message if the proxy is healthy.

The service methods and protobufs are defined as follows:

message ListApplicationsRequest {}

message ListApplicationsResponse {
  repeated string application_names = 1;
}

message HealthzRequest {}

message HealthzResponse {
  string message = 1;
}

service RayServeAPIService {
  rpc ListApplications(ListApplicationsRequest) returns (ListApplicationsResponse);
  rpc Healthz(HealthzRequest) returns (HealthzResponse);
}

You can call the service methods with the following code:

import grpc
from ray.serve.generated.serve_pb2_grpc import RayServeAPIServiceStub
from ray.serve.generated.serve_pb2 import HealthzRequest, ListApplicationsRequest


channel = grpc.insecure_channel("localhost:9000")
stub = RayServeAPIServiceStub(channel)
request = ListApplicationsRequest()
response = stub.ListApplications(request=request)
print(f"Applications: {response.application_names}")  # ["app1"]

request = HealthzRequest()
response = stub.Healthz(request=request)
print(f"Health: {response.message}")  # "success"

Note

Serve provides the RayServeAPIServiceStub stub, and the HealthzRequest and ListApplicationsRequest protobufs, for your use. You don't need to generate them from a proto file. They're available for your reference.

Use gRPC metadata#

Similar to HTTP headers, gRPC also supports metadata for passing request-related information. You can pass metadata to Serve's gRPC proxy, and Serve knows how to parse and use it. Serve also passes trailing metadata back to the client.

List of metadata keys Serve accepts:

  • application: the name of the Serve application to route to. If it's not passed and only one application is deployed, Serve routes to the only deployed application automatically.

  • request_id: the request ID to track the request with.

  • multiplexed_model_id: the model ID to use for model multiplexing.

List of trailing metadata keys Serve returns:

  • request_id: the request ID to track the request with.

Example of using metadata:

import grpc
from user_defined_protos_pb2_grpc import UserDefinedServiceStub
from user_defined_protos_pb2 import UserDefinedMessage2


channel = grpc.insecure_channel("localhost:9000")
stub = UserDefinedServiceStub(channel)
request = UserDefinedMessage2()
app_name = "app1"
request_id = "123"
multiplexed_model_id = "999"
metadata = (
    ("application", app_name),
    ("request_id", request_id),
    ("multiplexed_model_id", multiplexed_model_id),
)

response, call = stub.Multiplexing.with_call(request=request, metadata=metadata)
print(f"greeting: {response.greeting}")  # "Method2 called model, loading model: 999"
for key, value in call.trailing_metadata():
    print(f"trailing metadata key: {key}, value {value}")  # "request_id: 123"

Use streaming and model composition#

The gRPC proxy remains at feature parity with the HTTP proxy. Here are more examples of getting streaming responses and doing model composition with the gRPC proxy.

Streaming#

The "app1" application deployed above includes the Streaming method. The following code gets a streaming response.

import grpc
from user_defined_protos_pb2_grpc import UserDefinedServiceStub
from user_defined_protos_pb2 import UserDefinedMessage


channel = grpc.insecure_channel("localhost:9000")
stub = UserDefinedServiceStub(channel)
request = UserDefinedMessage(name="foo", num=30, origin="bar")
metadata = (("application", "app1"),)

responses = stub.Streaming(request=request, metadata=metadata)
for response in responses:
    print(f"greeting: {response.greeting}")  # greeting: n: Hello foo from bar
    print(f"num: {response.num}")  # num: 60 + n

Model composition#

Assume we have the following deployments. ImageDownloader and DataPreprocessor are two separate steps that download and process an image before PyTorch runs inference on it. The ImageClassifier deployment initializes the model, calls ImageDownloader and DataPreprocessor, and feeds the preprocessed image into the resnet model to get the classes and probabilities of the given image.

import requests
import torch
from typing import List
from PIL import Image
from io import BytesIO
from torchvision import transforms
from user_defined_protos_pb2 import (
    ImageClass,
    ImageData,
)

from ray import serve
from ray.serve.handle import DeploymentHandle


@serve.deployment
class ImageClassifier:
    def __init__(
        self,
        _image_downloader: DeploymentHandle,
        _data_preprocessor: DeploymentHandle,
    ):
        self._image_downloader = _image_downloader
        self._data_preprocessor = _data_preprocessor
        self.model = torch.hub.load(
            "pytorch/vision:v0.10.0", "resnet18", pretrained=True
        )
        self.model.eval()
        self.categories = self._image_labels()

    def _image_labels(self) -> List[str]:
        categories = []
        url = (
            "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
        )
        labels = requests.get(url).text
        for label in labels.split("\n"):
            categories.append(label.strip())
        return categories

    async def Predict(self, image_data: ImageData) -> ImageClass:
        # Download image
        image = await self._image_downloader.remote(image_data.url)

        # Preprocess image
        input_batch = await self._data_preprocessor.remote(image)
        # Predict image
        with torch.no_grad():
            output = self.model(input_batch)

        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        return self.process_model_outputs(probabilities)

    def process_model_outputs(self, probabilities: torch.Tensor) -> ImageClass:
        image_classes = []
        image_probabilities = []
        # Show top categories per image
        top5_prob, top5_catid = torch.topk(probabilities, 5)
        for i in range(top5_prob.size(0)):
            image_classes.append(self.categories[top5_catid[i]])
            image_probabilities.append(top5_prob[i].item())

        return ImageClass(
            classes=image_classes,
            probabilities=image_probabilities,
        )


@serve.deployment
class ImageDownloader:
    def __call__(self, image_url: str):
        image_bytes = requests.get(image_url).content
        return Image.open(BytesIO(image_bytes)).convert("RGB")


@serve.deployment
class DataPreprocessor:
    def __init__(self):
        self.preprocess = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def __call__(self, image: Image):
        input_tensor = self.preprocess(image)
        return input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model


image_downloader = ImageDownloader.bind()
data_preprocessor = DataPreprocessor.bind()
g2 = ImageClassifier.options(name="grpc-image-classifier").bind(
    image_downloader, data_preprocessor
)

We can deploy the application with the following code:

app2 = "app2"
serve.run(target=g2, name=app2, route_prefix=f"/{app2}")

The client code to call the application looks like the following:

import grpc
from user_defined_protos_pb2_grpc import ImageClassificationServiceStub
from user_defined_protos_pb2 import ImageData


channel = grpc.insecure_channel("localhost:9000")
stub = ImageClassificationServiceStub(channel)
request = ImageData(url="https://github.com/pytorch/hub/raw/master/images/dog.jpg")
metadata = (("application", "app2"),)  # Make sure application metadata is passed.

response, call = stub.Predict.with_call(request=request, metadata=metadata)
print(f"status code: {call.code()}")  # grpc.StatusCode.OK
print(f"Classes: {response.classes}")  # ['Samoyed', ...]
print(f"Probabilities: {response.probabilities}")  # [0.8846230506896973, ...]

Note

At this point, two applications are running on Serve: "app1" and "app2". When more than one application is running, you need to pass application in the metadata so Serve knows which application to route the request to.

Handle errors#

Like any other gRPC server, a request raises a grpc.RpcError when the response code isn't "OK". Put your request code in a try-except block and handle the error accordingly.

import grpc
from user_defined_protos_pb2_grpc import UserDefinedServiceStub
from user_defined_protos_pb2 import UserDefinedMessage


channel = grpc.insecure_channel("localhost:9000")
stub = UserDefinedServiceStub(channel)
request = UserDefinedMessage(name="foo", num=30, origin="bar")

try:
    response = stub.__call__(request=request)
except grpc.RpcError as rpc_error:
    print(f"status code: {rpc_error.code()}")  # StatusCode.NOT_FOUND
    print(f"details: {rpc_error.details()}")  # Application metadata not set...

Serve uses the following gRPC error codes:

  • NOT_FOUND: when multiple applications are deployed in Serve and the application isn't passed in the metadata, or the passed application doesn't match any deployed application.

  • UNAVAILABLE: only on the health-check methods when the proxy is in a draining state. When the health check throws UNAVAILABLE, it means the health check failed on this node and you shouldn't route traffic to this node anymore.

  • DEADLINE_EXCEEDED: the request took longer than the timeout setting and was canceled.

  • INTERNAL: other unhandled errors during the request.
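Clients can use these codes to decide whether a retry makes sense. The helper below is an illustrative sketch of such a client-side policy (should_retry is not part of Ray Serve's API, and treating DEADLINE_EXCEEDED as retryable is an assumption that depends on your workload): UNAVAILABLE and DEADLINE_EXCEEDED may be worth retrying against another node, while NOT_FOUND and INTERNAL indicate problems a retry won't fix.

```python
import grpc

# Serve's documented error codes, split by whether a retry may help.
# This classification is a client-side policy choice, not part of Ray Serve.
RETRYABLE = {grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.DEADLINE_EXCEEDED}


def should_retry(code: grpc.StatusCode) -> bool:
    """Return True when retrying the request, possibly on another node, may help."""
    return code in RETRYABLE


# Usage inside the except block shown above:
#     except grpc.RpcError as rpc_error:
#         if should_retry(rpc_error.code()):
#             ...  # retry against another proxy
```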

Use gRPC context#

Serve provides a gRPC context object to a deployment's replicas for getting information about the request, as well as setting response metadata such as the code and details. If the handler function defines a grpc_context argument, Serve passes in a RayServegRPCContext object on each request. Below is an example of how to set a custom status code, details, and trailing metadata.

from user_defined_protos_pb2 import UserDefinedMessage, UserDefinedResponse

from ray import serve
from ray.serve.grpc_util import RayServegRPCContext

import grpc
from typing import Tuple


@serve.deployment
class GrpcDeployment:
    def __init__(self):
        self.nums = {}

    def num_lookup(self, name: str) -> Tuple[int, grpc.StatusCode, str]:
        if name not in self.nums:
            self.nums[name] = len(self.nums)
            code = grpc.StatusCode.INVALID_ARGUMENT
            message = f"{name} not found, adding to nums."
        else:
            code = grpc.StatusCode.OK
            message = f"{name} found."
        return self.nums[name], code, message

    def __call__(
        self,
        user_message: UserDefinedMessage,
        grpc_context: RayServegRPCContext,  # to use grpc context, add this kwarg
    ) -> UserDefinedResponse:
        greeting = f"Hello {user_message.name} from {user_message.origin}"
        num, code, message = self.num_lookup(user_message.name)

        # Set custom code, details, and trailing metadata.
        grpc_context.set_code(code)
        grpc_context.set_details(message)
        grpc_context.set_trailing_metadata([("num", str(num))])

        user_response = UserDefinedResponse(
            greeting=greeting,
            num=num,
        )
        return user_response


g = GrpcDeployment.bind()
app1 = "app1"
serve.run(target=g, name=app1, route_prefix=f"/{app1}")

The client code for getting those attributes is defined as follows:

import grpc
from user_defined_protos_pb2_grpc import UserDefinedServiceStub
from user_defined_protos_pb2 import UserDefinedMessage


channel = grpc.insecure_channel("localhost:9000")
stub = UserDefinedServiceStub(channel)
request = UserDefinedMessage(name="foo", num=30, origin="bar")
metadata = (("application", "app1"),)

# First call is going to page miss and return INVALID_ARGUMENT status code.
try:
    response, call = stub.__call__.with_call(request=request, metadata=metadata)
except grpc.RpcError as rpc_error:
    assert rpc_error.code() == grpc.StatusCode.INVALID_ARGUMENT
    assert rpc_error.details() == "foo not found, adding to nums."
    assert any(
        [key == "num" and value == "0" for key, value in rpc_error.trailing_metadata()]
    )
    assert any([key == "request_id" for key, _ in rpc_error.trailing_metadata()])

# Second call is going to page hit and return OK status code.
response, call = stub.__call__.with_call(request=request, metadata=metadata)
assert call.code() == grpc.StatusCode.OK
assert call.details() == "foo found."
assert any([key == "num" and value == "0" for key, value in call.trailing_metadata()])
assert any([key == "request_id" for key, _ in call.trailing_metadata()])

Note

If the handler raises an unhandled exception, Serve returns an INTERNAL error code with the traceback in the details, regardless of the code and details set in the RayServegRPCContext object.