Batch Inference with Structured Outputs (Guided Decoding)

Structured output (also known as guided decoding or JSON mode) is a useful feature that ensures LLM responses follow a given output schema, such as a JSON schema or a context-free grammar.
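
For example, the constraint can be a fixed set of allowed strings or a JSON schema. In Ray Data LLM it is passed through the guided_decoding field of the vLLM sampling parameters, as the full example below shows. A minimal sketch of the two most common forms (the schema contents here are illustrative only):

# Constrain the output to one of a fixed set of strings:
guided_decoding = dict(choice=["yes", "no"])
# Or constrain the output to conform to a JSON schema:
guided_decoding = dict(json={"type": "object", "properties": {"answer": {"type": "integer"}}})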

In this example, we demonstrate how to use Ray Data LLM to perform batch inference with structured output in JSON format. To run this example, install the following dependencies:

pip install -qU "ray[data]" "vllm==0.7.2" "xgrammar==0.1.11"

from pydantic import BaseModel

import ray
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig

# 1. Construct a guided decoding schema. It can be:
# choice: List[str]
# json: str or dict
# grammar: str
# See https://docs.vllm.com.cn/en/latest/getting_started/examples/structured_outputs.html
# for more details about how to construct the schema. Here we use JSON as an example.
class AnswerWithExplain(BaseModel):
    problem: str
    answer: int
    explain: str

json_schema = AnswerWithExplain.model_json_schema()

# 2. Construct a vLLM processor config.
processor_config = vLLMEngineProcessorConfig(
    # The base model.
    model_source="unsloth/Llama-3.2-1B-Instruct",
    # vLLM engine config.
    engine_kwargs=dict(
        # Specify the guided decoding library to use. The default is "xgrammar".
        # See https://docs.vllm.com.cn/en/latest/serving/engine_args.html
        # for other available libraries.
        guided_decoding_backend="xgrammar",
        # Older GPUs (e.g. T4) don't support bfloat16. Remove this line
        # if you're using newer GPUs.
        dtype="half",
        # Reduce the maximum model length to fit small GPUs. Remove this
        # line if you're using larger GPUs.
        max_model_len=1024,
    ),
    # The batch size used in Ray Data.
    batch_size=16,
    # Use one GPU in this example.
    concurrency=1,
)

# 3. Construct a processor using the processor config.
processor = build_llm_processor(
    processor_config,
    # Convert the input data to the OpenAI chat format.
    preprocess=lambda row: dict(
        messages=[
            {
                "role": "system",
                "content": "You are a math teacher. Give the answer to "
                "the equation and explain it. Output the problem, answer and "
                "explanation in JSON",
            },
            {
                "role": "user",
                "content": f"3 * {row['id']} + 5 = ?",
            },
        ],
        sampling_params=dict(
            temperature=0.3,
            max_tokens=150,
            detokenize=False,
            # Specify the guided decoding schema.
            guided_decoding=dict(json=json_schema),
        ),
    ),
    # Only keep the generated text in the output dataset.
    postprocess=lambda row: {
        "resp": row["generated_text"],
    },
)

# 4. Synthesize a dataset with 30 rows.
# Each row has a single column "id" ranging from 0 to 29.
ds = ray.data.range(30)
# 5. Apply the processor to the dataset. Note that this line won't kick off
# any execution because the processor is applied lazily.
ds = processor(ds)
# Materialization kicks off the pipeline execution.
ds = ds.materialize()

# 6. Print all outputs.
# Example output:
# {
#     "problem": "3 * 6 + 5 = ?",
#     "answer": 23,
#     "explain": "To solve this equation, we need to follow the order of
#       operations (PEMDAS): Parentheses, Exponents, Multiplication and Division,
#       and Addition and Subtraction. In this case, we first multiply 3 and 6,
#       which equals 18. Then we add 5 to 18, which equals 23."
# }
for out in ds.take_all():
    print(out["resp"])
    print("==========")

# 7. Shutdown Ray to release resources.
ray.shutdown()