Serving LLMs#
The Ray Serve LLM APIs allow users to deploy multiple LLM models using the familiar Ray Serve APIs, while providing compatibility with the OpenAI API.
Features#
⚡️ Automatic scaling and load balancing
🌐 Unified multi-node, multi-model deployment
🔌 OpenAI compatibility
🔄 Multi-LoRA support with shared base models
🚀 Engine-agnostic architecture (i.e. vLLM, SGLang, etc.)
Requirements#
pip install "ray[serve,llm]>=2.43.0" "vllm>=0.7.2"
# Suggested dependencies when using vllm 0.7.2:
pip install xgrammar==0.1.11 pynvml==12.0.0
Key Components#
The ray.serve.llm module provides two key deployment types for serving LLMs:
LLMServer#
The LLMServer sets up and manages the vLLM engine for model serving. It can be used standalone or combined with your own custom Ray Serve deployments.
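For instance, here is a minimal sketch of that pattern (the same calls appear in the quickstart below): wrap an LLMConfig in an LLMServer deployment and bind it, ready to be composed with your own deployments or with the LLMRouter.
from ray.serve.llm import LLMConfig, LLMServer

# An LLMConfig like the ones in the quickstart below; accelerator and
# autoscaling options are omitted here to keep the sketch short.
llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="qwen-0.5b",
        model_source="Qwen/Qwen2.5-0.5B-Instruct",
    ),
)
# Wrap the vLLM engine for this model in a Ray Serve deployment.
server_deployment = LLMServer.as_deployment(
    llm_config.get_serve_options(name_prefix="vLLM:")
).bind(llm_config)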
LLMRouter#
This deployment provides an OpenAI-compatible FastAPI ingress and routes traffic to the appropriate model in a multi-model service. The following endpoints are supported:
/v1/chat/completions: Chat interface (ChatGPT-style)
/v1/completions: Text completion
/v1/models: List available models
/v1/models/{model}: Model information
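For example, once one of the applications below is running on the default port 8000, you can list the registered models with the OpenAI Python client; the base URL and the placeholder API key here are assumptions matching the later examples.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="fake-key")

# Query the /v1/models endpoint exposed by the router and print each model id.
for model in client.models.list():
    print(model.id)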
Configuration#
LLMConfig#
The LLMConfig class specifies model details such as:
Where to load the model from (Hugging Face or cloud storage)
Hardware requirements (accelerator type)
Engine arguments (e.g. vLLM engine kwargs)
LoRA multiplexing configuration
Serve autoscaling parameters
Quickstart Examples#
Deployment through LLMRouter#
You can build the application with the build_openai_app builder:
from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app
llm_config = LLMConfig(
model_loading_config=dict(
model_id="qwen-0.5b",
model_source="Qwen/Qwen2.5-0.5B-Instruct",
),
deployment_config=dict(
autoscaling_config=dict(
min_replicas=1, max_replicas=2,
)
),
# Pass the desired accelerator type (e.g. A10G, L4, etc.)
accelerator_type="A10G",
# You can customize the engine arguments (e.g. vLLM engine kwargs)
engine_kwargs=dict(
tensor_parallel_size=2,
),
)
app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app, blocking=True)
Alternatively, you can compose the same application from the lower-level LLMServer and LLMRouter deployments:
from ray import serve
from ray.serve.llm import LLMConfig, LLMServer, LLMRouter
llm_config = LLMConfig(
model_loading_config=dict(
model_id="qwen-0.5b",
model_source="Qwen/Qwen2.5-0.5B-Instruct",
),
deployment_config=dict(
autoscaling_config=dict(
min_replicas=1, max_replicas=2,
)
),
# Pass the desired accelerator type (e.g. A10G, L4, etc.)
accelerator_type="A10G",
# You can customize the engine arguments (e.g. vLLM engine kwargs)
engine_kwargs=dict(
tensor_parallel_size=2,
),
)
# Deploy the application
deployment = LLMServer.as_deployment(llm_config.get_serve_options(name_prefix="vLLM:")).bind(llm_config)
llm_app = LLMRouter.as_deployment().bind([deployment])
serve.run(llm_app, blocking=True)
You can query the deployed models using either cURL or the OpenAI Python client:
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer fake-key" \
-d '{
"model": "qwen-0.5b",
"messages": [{"role": "user", "content": "Hello!"}]
}'
from openai import OpenAI
# Initialize client
client = OpenAI(base_url="http://localhost:8000/v1", api_key="fake-key")
# Basic chat completion with streaming
response = client.chat.completions.create(
model="qwen-0.5b",
messages=[{"role": "user", "content": "Hello!"}],
stream=True
)
for chunk in response:
if chunk.choices[0].delta.content is not None:
print(chunk.choices[0].delta.content, end="", flush=True)
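The /v1/completions text-completion endpoint can be exercised the same way; the prompt and max_tokens below are illustrative values, not part of the original example.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="fake-key")

# Text completion against the /v1/completions endpoint.
response = client.completions.create(
    model="qwen-0.5b",
    prompt="The capital of France is",
    max_tokens=32,
)
print(response.choices[0].text)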
To deploy multiple models, you can pass a list of LLMConfig objects to the LLMRouter deployment:
from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app
llm_config1 = LLMConfig(
model_loading_config=dict(
model_id="qwen-0.5b",
model_source="Qwen/Qwen2.5-0.5B-Instruct",
),
deployment_config=dict(
autoscaling_config=dict(
min_replicas=1, max_replicas=2,
)
),
accelerator_type="A10G",
)
llm_config2 = LLMConfig(
model_loading_config=dict(
model_id="qwen-1.5b",
model_source="Qwen/Qwen2.5-1.5B-Instruct",
),
deployment_config=dict(
autoscaling_config=dict(
min_replicas=1, max_replicas=2,
)
),
accelerator_type="A10G",
)
app = build_openai_app({"llm_configs": [llm_config1, llm_config2]})
serve.run(app, blocking=True)
Or, using the lower-level LLMServer and LLMRouter deployments directly:
from ray import serve
from ray.serve.llm import LLMConfig, LLMServer, LLMRouter
llm_config1 = LLMConfig(
model_loading_config=dict(
model_id="qwen-0.5b",
model_source="Qwen/Qwen2.5-0.5B-Instruct",
),
deployment_config=dict(
autoscaling_config=dict(
min_replicas=1, max_replicas=2,
)
),
accelerator_type="A10G",
)
llm_config2 = LLMConfig(
model_loading_config=dict(
model_id="qwen-1.5b",
model_source="Qwen/Qwen2.5-1.5B-Instruct",
),
deployment_config=dict(
autoscaling_config=dict(
min_replicas=1, max_replicas=2,
)
),
accelerator_type="A10G",
)
# Deploy the application
deployment1 = LLMServer.as_deployment(llm_config1.get_serve_options(name_prefix="vLLM:")).bind(llm_config1)
deployment2 = LLMServer.as_deployment(llm_config2.get_serve_options(name_prefix="vLLM:")).bind(llm_config2)
llm_app = LLMRouter.as_deployment().bind([deployment1, deployment2])
serve.run(llm_app, blocking=True)
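With both models behind a single router, a client selects a model by its configured model_id. Here is a short sketch reusing the client setup from the quickstart:
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="fake-key")

# The router dispatches the request based on the model_id configured above.
response = client.chat.completions.create(
    model="qwen-1.5b",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)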
See also Serve DeepSeek for an example of deploying DeepSeek models.
Production Deployment#
For production deployments, Ray Serve LLM provides utilities for config-driven deployments. You can specify your deployment configuration in a YAML file:
# config.yaml
applications:
- args:
llm_configs:
- model_loading_config:
model_id: qwen-0.5b
model_source: Qwen/Qwen2.5-0.5B-Instruct
accelerator_type: A10G
deployment_config:
autoscaling_config:
min_replicas: 1
max_replicas: 2
- model_loading_config:
model_id: qwen-1.5b
model_source: Qwen/Qwen2.5-1.5B-Instruct
accelerator_type: A10G
deployment_config:
autoscaling_config:
min_replicas: 1
max_replicas: 2
import_path: ray.serve.llm:build_openai_app
name: llm_app
route_prefix: "/"
Alternatively, you can keep each model's configuration in a separate file and reference those files from the Serve config:
# config.yaml
applications:
- args:
llm_configs:
- models/qwen-0.5b.yaml
- models/qwen-1.5b.yaml
import_path: ray.serve.llm:build_openai_app
name: llm_app
route_prefix: "/"
The referenced model config files:
# models/qwen-0.5b.yaml
model_loading_config:
model_id: qwen-0.5b
model_source: Qwen/Qwen2.5-0.5B-Instruct
accelerator_type: A10G
deployment_config:
autoscaling_config:
min_replicas: 1
max_replicas: 2
# models/qwen-1.5b.yaml
model_loading_config:
model_id: qwen-1.5b
model_source: Qwen/Qwen2.5-1.5B-Instruct
accelerator_type: A10G
deployment_config:
autoscaling_config:
min_replicas: 1
max_replicas: 2
To deploy with either config file:
serve run config.yaml
Generate config files#
Ray Serve LLM provides a CLI to generate deployment config files:
python -m ray.serve.llm.gen_config
Note: This command requires interactive inputs. You should execute it directly in the terminal.
The command lets you pick from a set of common OSS LLMs and helps you configure them. You can tune settings such as GPU type, tensor parallelism, and autoscaling parameters.
Note that if you configure a model whose architecture differs from the provided list of models, you should closely review the generated model config file to supply the correct values.
The command generates two files: an LLM config file, saved in the model_config/ directory, and a Ray Serve config file, serve_TIMESTAMP.yaml, that you can reference and re-run in the future.
Read and review the generated model config file. Refer to the vLLM Engine Config documentation for further customization.
Advanced Usage Patterns#
For each usage pattern, we provide a server and a client code snippet.
Multi-LoRA Deployment#
You can use LoRA (Low-Rank Adaptation) to efficiently fine-tune models by configuring the LoraConfig. We use Ray Serve's multiplexing feature to serve multiple LoRA checkpoints from the same base model. This allows the weights to be loaded on each replica on the fly and cached via an LRU mechanism.
from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app
# Configure the model with LoRA
llm_config = LLMConfig(
model_loading_config=dict(
model_id="qwen-0.5b",
model_source="Qwen/Qwen2.5-0.5B-Instruct",
),
lora_config=dict(
# Let's pretend this is where LoRA weights are stored on S3.
# For example
# s3://my_dynamic_lora_path/lora_model_1_ckpt
# s3://my_dynamic_lora_path/lora_model_2_ckpt
# are two of the LoRA checkpoints
dynamic_lora_loading_path="s3://my_dynamic_lora_path",
max_num_adapters_per_replica=16,
),
engine_kwargs=dict(
enable_lora=True,
),
deployment_config=dict(
autoscaling_config=dict(
min_replicas=1,
max_replicas=2,
)
),
accelerator_type="A10G",
)
# Build and deploy the model
app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app, blocking=True)
from openai import OpenAI
# Initialize client
client = OpenAI(base_url="http://localhost:8000/v1", api_key="fake-key")
# Make a request to the desired lora checkpoint
response = client.chat.completions.create(
model="qwen-0.5b:lora_model_1_ckpt",
messages=[{"role": "user", "content": "Hello!"}],
stream=True,
)
for chunk in response:
if chunk.choices[0].delta.content is not None:
print(chunk.choices[0].delta.content, end="", flush=True)
Structured Output#
For structured output, you can use JSON mode, similar to the OpenAI API:
from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app
llm_config = LLMConfig(
model_loading_config=dict(
model_id="qwen-0.5b",
model_source="Qwen/Qwen2.5-0.5B-Instruct",
),
deployment_config=dict(
autoscaling_config=dict(
min_replicas=1,
max_replicas=2,
)
),
accelerator_type="A10G",
)
# Build and deploy the model
app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app, blocking=True)
from openai import OpenAI
# Initialize client
client = OpenAI(base_url="http://localhost:8000/v1", api_key="fake-key")
# Request structured JSON output
response = client.chat.completions.create(
model="qwen-0.5b",
response_format={"type": "json_object"},
messages=[
{
"role": "system",
"content": "You are a helpful assistant that outputs JSON."
},
{
"role": "user",
"content": "List three colors in JSON format"
}
],
stream=True,
)
for chunk in response:
if chunk.choices[0].delta.content is not None:
print(chunk.choices[0].delta.content, end="", flush=True)
# Example response:
# {
# "colors": [
# "red",
# "blue",
# "green"
# ]
# }
If you prefer, you can also specify the schema of the response using a Pydantic model:
from openai import OpenAI
from typing import List, Literal
from pydantic import BaseModel
# Initialize client
client = OpenAI(base_url="http://localhost:8000/v1", api_key="fake-key")
# Define a pydantic model of a preset of allowed colors
class Color(BaseModel):
colors: List[Literal["cyan", "magenta", "yellow"]]
# Request structured JSON output
response = client.chat.completions.create(
model="qwen-0.5b",
response_format={
"type": "json_schema",
"json_schema": Color.model_json_schema()
},
messages=[
{
"role": "system",
"content": "You are a helpful assistant that outputs JSON."
},
{
"role": "user",
"content": "List three colors in JSON format"
}
],
stream=True,
)
for chunk in response:
if chunk.choices[0].delta.content is not None:
print(chunk.choices[0].delta.content, end="", flush=True)
# Example response:
# {
# "colors": [
# "cyan",
# "magenta",
# "yellow"
# ]
# }
Vision Language Models#
For multimodal models that can process both text and images:
from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app
# Configure a vision model
llm_config = LLMConfig(
model_loading_config=dict(
model_id="pixtral-12b",
model_source="mistral-community/pixtral-12b",
),
deployment_config=dict(
autoscaling_config=dict(
min_replicas=1,
max_replicas=2,
)
),
accelerator_type="L40S",
engine_kwargs=dict(
tensor_parallel_size=1,
max_model_len=8192,
),
)
# Build and deploy the model
app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app, blocking=True)
from openai import OpenAI
# Initialize client
client = OpenAI(base_url="http://localhost:8000/v1", api_key="fake-key")
# Create and send a request with an image
response = client.chat.completions.create(
model="pixtral-12b",
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "What's in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "https://example.com/image.jpg"
}
}
]
}
],
stream=True,
)
for chunk in response:
if chunk.choices[0].delta.content is not None:
print(chunk.choices[0].delta.content, end="", flush=True)
Using remote storage for model weights#
You can use remote storage (S3 and GCS) to store your model weights instead of downloading them from Hugging Face.
For example, if you have a model stored in S3 with the following structure:
$ aws s3 ls air-example-data/rayllm-ossci/meta-Llama-3.2-1B-Instruct/
2025-03-25 11:37:48 1519 .gitattributes
2025-03-25 11:37:48 7712 LICENSE.txt
2025-03-25 11:37:48 41742 README.md
2025-03-25 11:37:48 6021 USE_POLICY.md
2025-03-25 11:37:48 877 config.json
2025-03-25 11:37:48 189 generation_config.json
2025-03-25 11:37:48 2471645608 model.safetensors
2025-03-25 11:37:53 296 special_tokens_map.json
2025-03-25 11:37:53 9085657 tokenizer.json
2025-03-25 11:37:53 54528 tokenizer_config.json
You can then specify the bucket_uri in the model_loading_config to point to your S3 bucket:
# config.yaml
applications:
- args:
llm_configs:
- accelerator_type: A10G
engine_kwargs:
max_model_len: 8192
model_loading_config:
model_id: my_llama
model_source:
bucket_uri: s3://anonymous@air-example-data/rayllm-ossci/meta-Llama-3.2-1B-Instruct
import_path: ray.serve.llm:build_openai_app
name: llm_app
route_prefix: "/"
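If you prefer Python over YAML, the same setup can be sketched with an LLMConfig directly; this assumes model_source accepts the same mapping with bucket_uri as in the YAML form above.
from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app

llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="my_llama",
        # Assumption: mirrors the YAML schema above, pointing at the bucket
        # instead of a Hugging Face repo.
        model_source=dict(
            bucket_uri="s3://anonymous@air-example-data/rayllm-ossci/meta-Llama-3.2-1B-Instruct",
        ),
    ),
    accelerator_type="A10G",
    engine_kwargs=dict(max_model_len=8192),
)
app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app, blocking=True)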
Frequently Asked Questions#
How do I use gated Huggingface models?#
You can use runtime_env to specify the environment variables required to access the model. To set the deployment options, you can use the get_serve_options method on the LLMConfig object.
from ray import serve
from ray.serve.llm import LLMConfig, LLMServer, LLMRouter
import os
llm_config = LLMConfig(
model_loading_config=dict(
model_id="llama-3-8b-instruct",
model_source="meta-llama/Meta-Llama-3-8B-Instruct",
),
deployment_config=dict(
autoscaling_config=dict(
min_replicas=1, max_replicas=2,
)
),
# Pass the desired accelerator type (e.g. A10G, L4, etc.)
accelerator_type="A10G",
runtime_env=dict(
env_vars=dict(
HF_TOKEN=os.environ["HF_TOKEN"]
)
),
)
# Deploy the application
deployment = LLMServer.as_deployment(llm_config.get_serve_options(name_prefix="vLLM:")).bind(llm_config)
llm_app = LLMRouter.as_deployment().bind([deployment])
serve.run(llm_app, blocking=True)
Why is downloading the model so slow?#
If you are using a Huggingface model, you can enable fast downloads by setting HF_HUB_ENABLE_HF_TRANSFER and running pip install hf_transfer.
from ray import serve
from ray.serve.llm import LLMConfig, LLMServer, LLMRouter
import os
llm_config = LLMConfig(
model_loading_config=dict(
model_id="llama-3-8b-instruct",
model_source="meta-llama/Meta-Llama-3-8B-Instruct",
),
deployment_config=dict(
autoscaling_config=dict(
min_replicas=1, max_replicas=2,
)
),
# Pass the desired accelerator type (e.g. A10G, L4, etc.)
accelerator_type="A10G",
runtime_env=dict(
env_vars=dict(
HF_TOKEN=os.environ["HF_TOKEN"],
HF_HUB_ENABLE_HF_TRANSFER="1"
)
),
)
# Deploy the application
deployment = LLMServer.as_deployment(llm_config.get_serve_options(name_prefix="vLLM:")).bind(llm_config)
llm_app = LLMRouter.as_deployment().bind([deployment])
serve.run(llm_app, blocking=True)
How do I configure the tokenizer pool size so it doesn't hang?#
When using tokenizer_pool_size in vLLM's engine_kwargs, tokenizer_pool_extra_config also needs to be configured alongside it so that the tokenizer group is scheduled correctly.
An example is shown below:
# config.yaml
applications:
- args:
llm_configs:
- engine_kwargs:
max_model_len: 1000
tokenizer_pool_size: 2
tokenizer_pool_extra_config: "{\"runtime_env\": {}}"
model_loading_config:
model_id: Qwen/Qwen2.5-7B-Instruct
import_path: ray.serve.llm:build_openai_app
name: llm_app
route_prefix: "/"
Usage Data Collection#
We collect usage data to improve Ray Serve LLM. We collect data about the following features and attributes:
Model architecture used for serving
Whether JSON mode is used
Whether LoRA is used and how many LoRA weights are initially loaded at deployment time
Whether autoscaling is used and the min and max replica settings
Tensor parallel size used
Initial number of replicas
GPU type used and the number of GPUs
If you would like to opt out of usage data collection, you can disable it by following the instructions in Ray usage stats.