Serving LLMs#

The Ray Serve LLM APIs allow users to deploy multiple LLM models using the familiar Ray Serve APIs, while providing compatibility with the OpenAI API.

Features#

  • ⚡️ Automatic scaling and load balancing

  • 🌐 Unified multi-node multi-model deployment

  • 🔌 OpenAI compatibility

  • 🔄 Multi-LoRA support with shared base models

  • 🚀 Engine-agnostic architecture (e.g., vLLM, SGLang)

Requirements#

pip install "ray[serve,llm]>=2.43.0" "vllm>=0.7.2"

# Suggested dependencies when using vllm 0.7.2:
pip install xgrammar==0.1.11 pynvml==12.0.0

Key Components#

The ray.serve.llm module provides two key deployment types for serving LLMs:

LLMServer#

The LLMServer sets up and manages the vLLM engine for model serving. It can be used standalone or combined with your own custom Ray Serve deployments.

LLMRouter#

This deployment provides an OpenAI-compatible FastAPI ingress and routes traffic to the appropriate model in a multi-model service. It supports the following endpoints (see the example request after this list):

  • /v1/chat/completions: Chat interface (ChatGPT-style)

  • /v1/completions: Text completion

  • /v1/models: List available models

  • /v1/models/{model}: Model information
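
For example, once an application is running you can list the registered models with a plain HTTP request (a minimal sketch; it assumes Serve is listening on the default port 8000 and that a placeholder API key is used, as in the quickstart below):

curl http://localhost:8000/v1/models \
     -H "Authorization: Bearer fake-key"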

Configuration#

LLMConfig#

Model details are specified through the LLMConfig class, for example:

  • Model loading sources (HuggingFace or cloud storage)

  • Hardware requirements (accelerator type)

  • Engine arguments (e.g. vLLM engine kwargs)

  • LoRA multiplexing configuration

  • Serve autoscaling parameters

Quickstart Examples#

Deployment through LLMRouter#

from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app

llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="qwen-0.5b",
        model_source="Qwen/Qwen2.5-0.5B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1, max_replicas=2,
        )
    ),
    # Pass the desired accelerator type (e.g. A10G, L4, etc.)
    accelerator_type="A10G",
    # You can customize the engine arguments (e.g. vLLM engine kwargs)
    engine_kwargs=dict(
        tensor_parallel_size=2,
    ),
)

app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app, blocking=True)

Alternatively, you can build the same application by constructing the LLMServer and LLMRouter deployments directly:

from ray import serve
from ray.serve.llm import LLMConfig, LLMServer, LLMRouter

llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="qwen-0.5b",
        model_source="Qwen/Qwen2.5-0.5B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1, max_replicas=2,
        )
    ),
    # Pass the desired accelerator type (e.g. A10G, L4, etc.)
    accelerator_type="A10G",
    # You can customize the engine arguments (e.g. vLLM engine kwargs)
    engine_kwargs=dict(
        tensor_parallel_size=2,
    ),
)

# Deploy the application
deployment = LLMServer.as_deployment(llm_config.get_serve_options(name_prefix="vLLM:")).bind(llm_config)
llm_app = LLMRouter.as_deployment().bind([deployment])
serve.run(llm_app, blocking=True)

You can query the deployed models using either cURL or the OpenAI Python client:

curl -X POST http://localhost:8000/v1/chat/completions \
     -H "Content-Type: application/json" \
     -H "Authorization: Bearer fake-key" \
     -d '{
           "model": "qwen-0.5b",
           "messages": [{"role": "user", "content": "Hello!"}]
         }'

Or with the OpenAI Python client:

from openai import OpenAI

# Initialize client
client = OpenAI(base_url="http://localhost:8000/v1", api_key="fake-key")

# Basic chat completion with streaming
response = client.chat.completions.create(
    model="qwen-0.5b",
    messages=[{"role": "user", "content": "Hello!"}],
    stream=True
)

for chunk in response:
    if chunk.choices[0].delta.content is not None:
        print(chunk.choices[0].delta.content, end="", flush=True)

To deploy multiple models, you can pass a list of LLMConfig objects to the LLMRouter deployment:

from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app


llm_config1 = LLMConfig(
    model_loading_config=dict(
        model_id="qwen-0.5b",
        model_source="Qwen/Qwen2.5-0.5B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1, max_replicas=2,
        )
    ),
    accelerator_type="A10G",
)

llm_config2 = LLMConfig(
    model_loading_config=dict(
        model_id="qwen-1.5b",
        model_source="Qwen/Qwen2.5-1.5B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1, max_replicas=2,
        )
    ),
    accelerator_type="A10G",
)

app = build_openai_app({"llm_configs": [llm_config1, llm_config2]})
serve.run(app, blocking=True)

Or, equivalently, by constructing the deployments directly:

from ray import serve
from ray.serve.llm import LLMConfig, LLMServer, LLMRouter

llm_config1 = LLMConfig(
    model_loading_config=dict(
        model_id="qwen-0.5b",
        model_source="Qwen/Qwen2.5-0.5B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1, max_replicas=2,
        )
    ),
    accelerator_type="A10G",
)

llm_config2 = LLMConfig(
    model_loading_config=dict(
        model_id="qwen-1.5b",
        model_source="Qwen/Qwen2.5-1.5B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1, max_replicas=2,
        )
    ),
    accelerator_type="A10G",
)

# Deploy the application
deployment1 = LLMServer.as_deployment(llm_config1.get_serve_options(name_prefix="vLLM:")).bind(llm_config1)
deployment2 = LLMServer.as_deployment(llm_config2.get_serve_options(name_prefix="vLLM:")).bind(llm_config2)
llm_app = LLMRouter.as_deployment().bind([deployment1, deployment2])
serve.run(llm_app, blocking=True)

See also Serve DeepSeek for an example of deploying DeepSeek models.

Production Deployment#

For production deployments, Ray Serve LLM provides utilities for config-driven deployments. You can specify the deployment configuration in a YAML file:

# config.yaml
applications:
- args:
    llm_configs:
        - model_loading_config:
            model_id: qwen-0.5b
            model_source: Qwen/Qwen2.5-0.5B-Instruct
          accelerator_type: A10G
          deployment_config:
            autoscaling_config:
                min_replicas: 1
                max_replicas: 2
        - model_loading_config:
            model_id: qwen-1.5b
            model_source: Qwen/Qwen2.5-1.5B-Instruct
          accelerator_type: A10G
          deployment_config:
            autoscaling_config:
                min_replicas: 1
                max_replicas: 2
  import_path: ray.serve.llm:build_openai_app
  name: llm_app
  route_prefix: "/"

Alternatively, you can keep each model's configuration in its own YAML file and reference those files from the Serve config:

# config.yaml
applications:
- args:
    llm_configs:
        - models/qwen-0.5b.yaml
        - models/qwen-1.5b.yaml
  import_path: ray.serve.llm:build_openai_app
  name: llm_app
  route_prefix: "/"

The referenced model config files:

# models/qwen-0.5b.yaml
model_loading_config:
  model_id: qwen-0.5b
  model_source: Qwen/Qwen2.5-0.5B-Instruct
accelerator_type: A10G
deployment_config:
  autoscaling_config:
    min_replicas: 1
    max_replicas: 2
# models/qwen-1.5b.yaml
model_loading_config:
  model_id: qwen-1.5b
  model_source: Qwen/Qwen2.5-1.5B-Instruct
accelerator_type: A10G
deployment_config:
  autoscaling_config:
    min_replicas: 1
    max_replicas: 2

To deploy with either config file:

serve run config.yaml
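
Once the application is running, you can check its state with the Ray Serve CLI (a quick sanity check, run from an environment attached to the same Ray cluster):

serve status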

Generating config files#

Ray Serve LLM provides a CLI to generate deployment config files:

python -m ray.serve.llm.gen_config

Note: This command requires interactive inputs. You should run it directly in a terminal.

This command lets you pick from a set of common OSS LLMs and helps you configure them. You can tune settings such as the GPU type, tensor parallelism, and autoscaling parameters.

Note that if you are configuring a model whose architecture differs from the provided list of models, you should closely review the generated model config file to provide the correct values.

This command generates two files: an LLM config file, saved in the model_config/ directory, and a Ray Serve config file, serve_TIMESTAMP.yaml, that you can reference and re-run in the future.

Read and review the generated model config file. Refer to the vLLM Engine Config for further customization.
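
As a sketch of the resulting workflow, the generated Serve config file is deployed the same way as a hand-written one (the timestamped file name below is illustrative, not the exact name the CLI produces):

serve run serve_20250312_093000.yaml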

Advanced Usage Patterns#

For each usage pattern, we provide a server and a client code snippet.

Multi-LoRA Deployment#

You can use LoRA (Low-Rank Adaptation) to fine-tune models efficiently by configuring the LoraConfig. We use Ray Serve's multiplexing feature to serve multiple LoRA checkpoints from the same base model. The weights are loaded on each replica on the fly and cached via an LRU mechanism.

from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app

# Configure the model with LoRA
llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="qwen-0.5b",
        model_source="Qwen/Qwen2.5-0.5B-Instruct",
    ),
    lora_config=dict(
        # Let's pretend this is where LoRA weights are stored on S3.
        # For example
        # s3://my_dynamic_lora_path/lora_model_1_ckpt
        # s3://my_dynamic_lora_path/lora_model_2_ckpt
        # are two of the LoRA checkpoints
        dynamic_lora_loading_path="s3://my_dynamic_lora_path",
        max_num_adapters_per_replica=16,
    ),
    engine_kwargs=dict(
        enable_lora=True,
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1,
            max_replicas=2,
        )
    ),
    accelerator_type="A10G",
)

# Build and deploy the model
app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app, blocking=True)

On the client side, target a specific LoRA checkpoint by appending its name to the model ID:

from openai import OpenAI

# Initialize client
client = OpenAI(base_url="http://localhost:8000/v1", api_key="fake-key")

# Make a request to the desired lora checkpoint
response = client.chat.completions.create(
    model="qwen-0.5b:lora_model_1_ckpt",
    messages=[{"role": "user", "content": "Hello!"}],
    stream=True,
)

for chunk in response:
    if chunk.choices[0].delta.content is not None:
        print(chunk.choices[0].delta.content, end="", flush=True)

Structured Output#

For structured output, you can use JSON mode, similar to the OpenAI API:

from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app

llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="qwen-0.5b",
        model_source="Qwen/Qwen2.5-0.5B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1,
            max_replicas=2,
        )
    ),
    accelerator_type="A10G",
)

# Build and deploy the model
app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app, blocking=True)

On the client side, request JSON output through the response_format field:

from openai import OpenAI

# Initialize client
client = OpenAI(base_url="http://localhost:8000/v1", api_key="fake-key")

# Request structured JSON output
response = client.chat.completions.create(
    model="qwen-0.5b",
    response_format={"type": "json_object"},
    messages=[
        {
            "role": "system",
            "content": "You are a helpful assistant that outputs JSON."
        },
        {
            "role": "user",
            "content": "List three colors in JSON format"
        }
    ],
    stream=True,
)

for chunk in response:
    if chunk.choices[0].delta.content is not None:
        print(chunk.choices[0].delta.content, end="", flush=True)
# Example response:
# {
#   "colors": [
#     "red",
#     "blue",
#     "green"
#   ]
# }

You can also specify the desired response schema with a pydantic model, if you prefer:

from openai import OpenAI
from typing import List, Literal
from pydantic import BaseModel

# Initialize client
client = OpenAI(base_url="http://localhost:8000/v1", api_key="fake-key")

# Define a pydantic model of a preset of allowed colors
class Color(BaseModel):
    colors: List[Literal["cyan", "magenta", "yellow"]]

# Request structured JSON output
response = client.chat.completions.create(
    model="qwen-0.5b",
    response_format={
        "type": "json_schema",
        "json_schema": Color.model_json_schema()

    },
    messages=[
        {
            "role": "system",
            "content": "You are a helpful assistant that outputs JSON."
        },
        {
            "role": "user",
            "content": "List three colors in JSON format"
        }
    ],
    stream=True,
)

for chunk in response:
    if chunk.choices[0].delta.content is not None:
        print(chunk.choices[0].delta.content, end="", flush=True)
# Example response:
# {
#   "colors": [
#     "cyan",
#     "magenta",
#     "yellow"
#   ]
# }

Vision Language Models#

For multimodal models that can process both text and images:

from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app


# Configure a vision model
llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="pixtral-12b",
        model_source="mistral-community/pixtral-12b",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1,
            max_replicas=2,
        )
    ),
    accelerator_type="L40S",
    engine_kwargs=dict(
        tensor_parallel_size=1,
        max_model_len=8192,
    ),
)

# Build and deploy the model
app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app, blocking=True)

On the client side, send a request that includes an image URL:

from openai import OpenAI

# Initialize client
client = OpenAI(base_url="http://localhost:8000/v1", api_key="fake-key")

# Create and send a request with an image
response = client.chat.completions.create(
    model="pixtral-12b",
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://example.com/image.jpg"
                    }
                }
            ]
        }
    ],
    stream=True,
)

for chunk in response:
    if chunk.choices[0].delta.content is not None:
        print(chunk.choices[0].delta.content, end="", flush=True)

Using remote storage for model weights#

You can use remote storage (S3 and GCS) to host your model weights instead of downloading them from Hugging Face.

For example, if you have a model stored in S3 with the following structure:

$ aws s3 ls air-example-data/rayllm-ossci/meta-Llama-3.2-1B-Instruct/
2025-03-25 11:37:48       1519 .gitattributes
2025-03-25 11:37:48       7712 LICENSE.txt
2025-03-25 11:37:48      41742 README.md
2025-03-25 11:37:48       6021 USE_POLICY.md
2025-03-25 11:37:48        877 config.json
2025-03-25 11:37:48        189 generation_config.json
2025-03-25 11:37:48 2471645608 model.safetensors
2025-03-25 11:37:53        296 special_tokens_map.json
2025-03-25 11:37:53    9085657 tokenizer.json
2025-03-25 11:37:53      54528 tokenizer_config.json

You can then specify the bucket_uri in the model_loading_config to point to your S3 bucket:

# config.yaml
applications:
- args:
    llm_configs:
        - accelerator_type: A10G
          engine_kwargs:
            max_model_len: 8192
          model_loading_config:
            model_id: my_llama
            model_source:
              bucket_uri: s3://anonymous@air-example-data/rayllm-ossci/meta-Llama-3.2-1B-Instruct
  import_path: ray.serve.llm:build_openai_app
  name: llm_app
  route_prefix: "/"

Frequently Asked Questions#

How do I use gated Hugging Face models?#

You can use runtime_env to specify the environment variables required to access the model. To set the deployment options, you can use the get_serve_options method on the LLMConfig object.

from ray import serve
from ray.serve.llm import LLMConfig, LLMServer, LLMRouter
import os

llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="llama-3-8b-instruct",
        model_source="meta-llama/Meta-Llama-3-8B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1, max_replicas=2,
        )
    ),
    # Pass the desired accelerator type (e.g. A10G, L4, etc.)
    accelerator_type="A10G",
    runtime_env=dict(
        env_vars=dict(
            HF_TOKEN=os.environ["HF_TOKEN"]
        )
    ),
)

# Deploy the application
deployment = LLMServer.as_deployment(llm_config.get_serve_options(name_prefix="vLLM:")).bind(llm_config)
llm_app = LLMRouter.as_deployment().bind([deployment])
serve.run(llm_app, blocking=True)

Why is downloading the model so slow?#

If you are using a model from Hugging Face, you can enable fast downloads by setting HF_HUB_ENABLE_HF_TRANSFER and installing hf_transfer (pip install hf_transfer).

from ray import serve
from ray.serve.llm import LLMConfig, LLMServer, LLMRouter
import os

llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="llama-3-8b-instruct",
        model_source="meta-llama/Meta-Llama-3-8B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1, max_replicas=2,
        )
    ),
    # Pass the desired accelerator type (e.g. A10G, L4, etc.)
    accelerator_type="A10G",
    runtime_env=dict(
        env_vars=dict(
            HF_TOKEN=os.environ["HF_TOKEN"],
            HF_HUB_ENABLE_HF_TRANSFER="1"
        )
    ),
)

# Deploy the application
deployment = LLMServer.as_deployment(llm_config.get_serve_options(name_prefix="vLLM:")).bind(llm_config)
llm_app = LLMRouter.as_deployment().bind([deployment])
serve.run(llm_app, blocking=True)

How do I configure the tokenizer pool size so it doesn't hang?#

When using tokenizer_pool_size in vLLM's engine_kwargs, tokenizer_pool_extra_config also needs to be configured so that the tokenizer group is scheduled correctly.

An example is shown below:

# config.yaml
applications:
- args:
    llm_configs:
        - engine_kwargs:
            max_model_len: 1000
            tokenizer_pool_size: 2
            tokenizer_pool_extra_config: "{\"runtime_env\": {}}"
          model_loading_config:
            model_id: Qwen/Qwen2.5-7B-Instruct
  import_path: ray.serve.llm:build_openai_app
  name: llm_app
  route_prefix: "/"

Usage Data Collection#

We collect usage data to improve Ray Serve LLM. We collect data about the following features and attributes:

  • Model architecture used for serving

  • Whether JSON mode is used

  • Whether LoRA is used and how many LoRA weights are loaded initially at deployment time

  • Whether autoscaling is used and the min and max replica settings

  • Tensor parallel size used

  • Initial replica count

  • GPU type used and number of GPUs used

If you want to opt out of usage data collection, you can disable it by following the instructions in Ray usage stats.
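
For example, one way to disable usage stats collection is to set the corresponding environment variable before starting Ray (a minimal sketch; see the Ray usage stats documentation for all supported options):

export RAY_USAGE_STATS_ENABLED=0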