Serve Llama2-7b/70b on a single or multiple Intel Gaudi accelerators

Intel Gaudi AI processors (HPUs) are AI hardware accelerators designed by Intel Habana Labs. For more information, see Gaudi Architecture and Gaudi Developer Docs.

This tutorial has two examples:

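  1. Deploying Llama2-7b on a single HPU:

    • Load a model onto an HPU.

    • Perform generation on an HPU.

    • Enable HPU Graph optimizations.

  2. Deploying Llama2-70b on multiple HPUs in a single node:

    • Initialize a distributed backend.

    • Load a sharded model onto DeepSpeed workers.

    • Stream responses from DeepSpeed workers.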

This tutorial serves a large language model (LLM) on HPUs.

Environment setup

Use a prebuilt container to run these examples. Running a container requires Docker; see Install Docker Engine for installation instructions.

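Next, follow Run Using Containers to install the Gaudi drivers and container runtime. To verify your installation, start a shell and run hl-smi. It should print status information about the HPUs on the machine, similar to the following: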

+-----------------------------------------------------------------------------+
| HL-SMI Version:                              hl-1.20.0-fw-58.1.1.1          |
| Driver Version:                                     1.19.1-6f47ddd          |
| Nic Driver Version:                                 1.19.1-f071c23          |
|-------------------------------+----------------------+----------------------+
| AIP  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncor-Events|
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | AIP-Util  Compute M. |
|===============================+======================+======================|
|   0  HL-225              N/A  | 0000:9a:00.0     N/A |                   0  |
| N/A   22C   N/A  96W /  600W  |   768MiB /  98304MiB |     0%            0% |
|-------------------------------+----------------------+----------------------+
|   1  HL-225              N/A  | 0000:9b:00.0     N/A |                   0  |
| N/A   24C   N/A  78W /  600W  |   768MiB /  98304MiB |     0%            0% |
|-------------------------------+----------------------+----------------------+
|   2  HL-225              N/A  | 0000:b3:00.0     N/A |                   0  |
| N/A   25C   N/A  81W /  600W  |   768MiB /  98304MiB |     0%            0% |
|-------------------------------+----------------------+----------------------+
|   3  HL-225              N/A  | 0000:b4:00.0     N/A |                   0  |
| N/A   22C   N/A  92W /  600W  | 96565MiB /  98304MiB |     0%           98% |
|-------------------------------+----------------------+----------------------+
|   4  HL-225              N/A  | 0000:33:00.0     N/A |                   0  |
| N/A   22C   N/A  83W /  600W  |   768MiB /  98304MiB |     0%            0% |
|-------------------------------+----------------------+----------------------+
|   5  HL-225              N/A  | 0000:4e:00.0     N/A |                   0  |
| N/A   21C   N/A  80W /  600W  | 96564MiB /  98304MiB |     0%           98% |
|-------------------------------+----------------------+----------------------+
|   6  HL-225              N/A  | 0000:34:00.0     N/A |                   0  |
| N/A   25C   N/A  86W /  600W  |   768MiB /  98304MiB |     0%            0% |
|-------------------------------+----------------------+----------------------+
|   7  HL-225              N/A  | 0000:4d:00.0     N/A |                   0  |
| N/A   30C   N/A 100W /  600W  | 17538MiB /  98304MiB |     0%           17% |
|-------------------------------+----------------------+----------------------+
| Compute Processes:                                               AIP Memory |
|  AIP       PID   Type   Process name                             Usage      |
|=============================================================================|
|   0        N/A   N/A    N/A                                      N/A        |
|   1        N/A   N/A    N/A                                      N/A        |
|   2        N/A   N/A    N/A                                      N/A        |
|   3        N/A   N/A    N/A                                      N/A        |
|   4        N/A   N/A    N/A                                      N/A        |
|   5        N/A   N/A    N/A                                      N/A        |
|   6        N/A   N/A    N/A                                      N/A        |
|   7       107684     C   ray::_RayTrainW                         16770MiB    
+=============================================================================+

Next, start the Gaudi container:

docker pull vault.habana.ai/gaudi-docker/1.20.0/ubuntu22.04/habanalabs/pytorch-installer-2.6.0:latest
docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.20.0/ubuntu22.04/habanalabs/pytorch-installer-2.6.0:latest

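To follow the examples in this tutorial, mount the directory containing the examples and models into the container. Inside the container, run: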

pip install "ray[tune,serve]"
pip install git+https://github.com/huggingface/optimum-habana.git
# Replace 1.20.0 with the driver version of the container.
pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.20.0
# Only needed by the DeepSpeed example.
export RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES=1

Start Ray in the container with ray start --head. You are now ready to run the examples.

Running a model on a single HPU

This example shows how to deploy a Llama2-7b model on an HPU for inference.

First, define a deployment that serves a Llama2-7b model using an HPU. Note that it enables HPU Graph optimizations for better performance.

import asyncio
from functools import partial
from queue import Empty
from typing import Dict, Any

from starlette.requests import Request
from starlette.responses import StreamingResponse
import torch

from ray import serve


# Define the Ray Serve deployment
@serve.deployment(ray_actor_options={"num_cpus": 10, "resources": {"HPU": 1}})
class LlamaModel:
    def __init__(self, model_id_or_path: str):
        from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
        from optimum.habana.transformers.modeling_utils import (
            adapt_transformers_to_gaudi,
        )

        # Tweak transformers to optimize performance
        adapt_transformers_to_gaudi()

        self.device = torch.device("hpu")

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id_or_path, use_fast=False, use_auth_token=""
        )
        hf_config = AutoConfig.from_pretrained(
            model_id_or_path,
            torchscript=True,
            use_auth_token="",
            trust_remote_code=False,
        )
        # Load the model in Gaudi
        model = AutoModelForCausalLM.from_pretrained(
            model_id_or_path,
            config=hf_config,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
            use_auth_token="",
        )
        model = model.eval().to(self.device)

        from habana_frameworks.torch.hpu import wrap_in_hpu_graph

        # Enable hpu graph runtime
        self.model = wrap_in_hpu_graph(model)

        # Set pad token, etc.
        self.tokenizer.pad_token_id = self.model.generation_config.pad_token_id
        self.tokenizer.padding_side = "left"

        # Use async loop in streaming
        self.loop = asyncio.get_running_loop()

    def tokenize(self, prompt: str):
        """Tokenize the input and move to HPU."""

        input_tokens = self.tokenizer(prompt, return_tensors="pt", padding=True)
        return input_tokens.input_ids.to(device=self.device)

    def generate(self, prompt: str, **config: Dict[str, Any]):
        """Take a prompt and generate a response."""

        input_ids = self.tokenize(prompt)
        gen_tokens = self.model.generate(input_ids, **config)
        return self.tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)[0]

    async def consume_streamer_async(self, streamer):
        """Consume the streamer asynchronously."""

        while True:
            try:
                for token in streamer:
                    yield token
                break
            except Empty:
                await asyncio.sleep(0.001)

    def streaming_generate(self, prompt: str, streamer, **config: Dict[str, Any]):
        """Generate a streamed response given an input."""

        input_ids = self.tokenize(prompt)
        self.model.generate(input_ids, streamer=streamer, **config)

    async def __call__(self, http_request: Request):
        """Handle HTTP requests."""

        # Load fields from the request
        json_request: Dict[str, Any] = await http_request.json()
        text = json_request["text"]
        # Config used in generation
        config = json_request.get("config", {})
        streaming_response = json_request["stream"]

        # Prepare prompts
        prompts = []
        if isinstance(text, list):
            prompts.extend(text)
        else:
            prompts.append(text)

        # Process config
        config.setdefault("max_new_tokens", 128)

        # Enable HPU graph runtime
        config["hpu_graphs"] = True
        # Lazy mode should be True when using HPU graphs
        config["lazy_mode"] = True

        # Non-streaming case
        if not streaming_response:
            return self.generate(prompts, **config)

        # Streaming case
        from transformers import TextIteratorStreamer

        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, timeout=0, skip_special_tokens=True
        )
        # Convert the streamer into a generator
        self.loop.run_in_executor(
            None, partial(self.streaming_generate, prompts, streamer, **config)
        )
        return StreamingResponse(
            self.consume_streamer_async(streamer),
            status_code=200,
            media_type="text/plain",
        )


# Replace the model ID with path if necessary
entrypoint = LlamaModel.bind("meta-llama/Llama-2-7b-chat-hf")

Copy the code above and save it as intel_gaudi_inference_serve.py. Start the deployment as follows:

serve run intel_gaudi_inference_serve:entrypoint

The terminal should print logs as the deployment starts up:

2025-03-03 06:07:08,106 INFO scripts.py:494 -- Running import path: 'infer:entrypoint'.
2025-03-03 06:07:09,295 INFO worker.py:1654 -- Connecting to existing Ray cluster at address: 100.83.111.228:6379...
2025-03-03 06:07:09,304 INFO worker.py:1832 -- Connected to Ray cluster. View the dashboard at 127.0.0.1:8265 
(ProxyActor pid=147082) INFO 2025-03-03 06:07:11,096 proxy 100.83.111.228 -- Proxy starting on node b4d028b67678bfdd190b503b44780bc319c07b1df13ac5c577873861 (HTTP port: 8000).
INFO 2025-03-03 06:07:11,202 serve 162730 -- Started Serve in namespace "serve".
INFO 2025-03-03 06:07:11,203 serve 162730 -- Connecting to existing Serve app in namespace "serve". New http options will not be applied.
(ProxyActor pid=147082) INFO 2025-03-03 06:07:11,184 proxy 100.83.111.228 -- Got updated endpoints: {}.
(ServeController pid=147087) INFO 2025-03-03 06:07:11,278 controller 147087 -- Deploying new version of Deployment(name='LlamaModel', app='default') (initial target replicas: 1).
(ProxyActor pid=147082) INFO 2025-03-03 06:07:11,280 proxy 100.83.111.228 -- Got updated endpoints: {Deployment(name='LlamaModel', app='default'): EndpointInfo(route='/', app_is_cross_language=False)}.
(ProxyActor pid=147082) INFO 2025-03-03 06:07:11,286 proxy 100.83.111.228 -- Started <ray.serve._private.router.SharedRouterLongPollClient object at 0x7f74804e90c0>.
(ServeController pid=147087) INFO 2025-03-03 06:07:11,381 controller 147087 -- Adding 1 replica to Deployment(name='LlamaModel', app='default').
(ServeReplica:default:LlamaModel pid=147085) [WARNING|utils.py:212] 2025-03-03 06:07:15,251 >> optimum-habana v1.15.0 has been validated for SynapseAI v1.19.0 but habana-frameworks v1.20.0.543 was found, this could lead to undefined behavior!
(ServeReplica:default:LlamaModel pid=147085) /usr/local/lib/python3.10/dist-packages/transformers/deepspeed.py:24: FutureWarning: transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations
(ServeReplica:default:LlamaModel pid=147085)   warnings.warn(
(ServeReplica:default:LlamaModel pid=147085) /usr/local/lib/python3.10/dist-packages/transformers/models/auto/tokenization_auto.py:796: FutureWarning: The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.
(ServeReplica:default:LlamaModel pid=147085)   warnings.warn(
(ServeReplica:default:LlamaModel pid=147085) /usr/local/lib/python3.10/dist-packages/transformers/models/auto/configuration_auto.py:991: FutureWarning: The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.
(ServeReplica:default:LlamaModel pid=147085)   warnings.warn(
(ServeReplica:default:LlamaModel pid=147085) /usr/local/lib/python3.10/dist-packages/transformers/models/auto/auto_factory.py:471: FutureWarning: The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.
(ServeReplica:default:LlamaModel pid=147085)   warnings.warn(
Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]
Loading checkpoint shards:  50%|█████     | 1/2 [00:01<00:01,  1.72s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.45s/it]
(ServeReplica:default:LlamaModel pid=147085) ============================= HABANA PT BRIDGE CONFIGURATION =========================== 
(ServeReplica:default:LlamaModel pid=147085)  PT_HPU_LAZY_MODE = 1
(ServeReplica:default:LlamaModel pid=147085)  PT_HPU_RECIPE_CACHE_CONFIG = ,false,1024
(ServeReplica:default:LlamaModel pid=147085)  PT_HPU_MAX_COMPOUND_OP_SIZE = 9223372036854775807
(ServeReplica:default:LlamaModel pid=147085)  PT_HPU_LAZY_ACC_PAR_MODE = 1
(ServeReplica:default:LlamaModel pid=147085)  PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES = 0
(ServeReplica:default:LlamaModel pid=147085)  PT_HPU_EAGER_PIPELINE_ENABLE = 1
(ServeReplica:default:LlamaModel pid=147085)  PT_HPU_EAGER_COLLECTIVE_PIPELINE_ENABLE = 1
(ServeReplica:default:LlamaModel pid=147085)  PT_HPU_ENABLE_LAZY_COLLECTIVES = 0
(ServeReplica:default:LlamaModel pid=147085) ---------------------------: System Configuration :---------------------------
(ServeReplica:default:LlamaModel pid=147085) Num CPU Cores : 160
(ServeReplica:default:LlamaModel pid=147085) CPU RAM       : 1056374420 KB
(ServeReplica:default:LlamaModel pid=147085) ------------------------------------------------------------------------------
INFO 2025-03-03 06:07:30,359 serve 162730 -- Application 'default' is ready at http://127.0.0.1:8000/.
INFO 2025-03-03 06:07:30,359 serve 162730 -- Deployed app 'default' successfully.
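Before you send requests, note how `LlamaModel.__call__` reads the JSON body:

- `text` may be a single string or a list of strings.
- `config` holds optional generation arguments.
- `stream` is required.
- `max_new_tokens` defaults to 128.
- `hpu_graphs` and `lazy_mode` are always forced to `True`.

The helper below restates that logic as a standalone function so the request format is easy to see. `prepare_request` is a hypothetical name and is not part of Ray Serve or optimum-habana.

```python
from typing import Any, Dict, List, Tuple


def prepare_request(json_request: Dict[str, Any]) -> Tuple[List[str], Dict[str, Any], bool]:
    """Hypothetical helper mirroring the request handling in LlamaModel.__call__."""
    text = json_request["text"]
    # Copy so the caller's dict is not mutated.
    config = dict(json_request.get("config", {}))
    # "stream" is required: the deployment raises KeyError if it is missing.
    streaming = json_request["stream"]

    # A single prompt string becomes a one-element list.
    prompts = list(text) if isinstance(text, list) else [text]

    config.setdefault("max_new_tokens", 128)
    # Lazy mode must be on when HPU graphs are used, so both flags are set.
    config["hpu_graphs"] = True
    config["lazy_mode"] = True
    return prompts, config, streaming


# The same request body that the client snippet below sends.
prompts, config, streaming = prepare_request(
    {"text": "Once upon a time,", "config": {}, "stream": False}
)
print(prompts, config, streaming)
# → ['Once upon a time,'] {'max_new_tokens': 128, 'hpu_graphs': True, 'lazy_mode': True} False
```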

In another shell, use the following code to send generation requests to the deployment.

import requests

# Prompt for the model
prompt = "Once upon a time,"

# Add generation config here
config = {}

# Non-streaming response
sample_input = {"text": prompt, "config": config, "stream": False}
outputs = requests.post("http://127.0.0.1:8000/", json=sample_input, stream=False)
print(outputs.text, flush=True)

# Streaming response
sample_input["stream"] = True
outputs = requests.post("http://127.0.0.1:8000/", json=sample_input, stream=True)
outputs.raise_for_status()
for output in outputs.iter_content(chunk_size=None, decode_unicode=True):
    print(output, end="", flush=True)
print()

Here is an example output:

Once upon a time, in a small village nestled in the rolling hills of Tuscany, there lived a young girl named Sophia.

Sophia was a curious and adventurous child, always eager to explore the world around her. She spent her days playing in the fields and forests, chasing after butterflies and watching the clouds drift lazily across the sky.

One day, as Sophia was wandering through the village, she stumbled upon a beautiful old book hidden away in a dusty corner of the local library. The book was bound in worn leather and adorned with intr
in a small village nestled in the rolling hills of Tuscany, there lived a young girl named Luna.
Luna was a curious and adventurous child, always eager to explore the world around her. She spent her days wandering through the village, discovering new sights and sounds at every turn.

One day, as she was wandering through the village, Luna stumbled upon a hidden path she had never seen before. The path was overgrown with weeds and vines, and it seemed to disappear into the distance.

Luna's curiosity was piqued,
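The streaming half of this output relies on two pieces of the deployment working together:

- `loop.run_in_executor` runs the blocking `model.generate` call in a worker thread.
- The streamer is created with `timeout=0`, so iterating it raises `queue.Empty` whenever no token is ready yet. `consume_streamer_async` catches that, sleeps briefly, and retries instead of blocking the event loop.

The sketch below reproduces the pattern with only the standard library. `PollingStreamer` and `fake_generate` are hypothetical stand-ins for `TextIteratorStreamer` and `model.generate`, not their real implementations.

```python
import asyncio
import queue
import time
from functools import partial
from typing import AsyncIterator, List


class PollingStreamer:
    """Hypothetical stand-in for TextIteratorStreamer(timeout=0)."""

    _STOP = object()

    def __init__(self) -> None:
        self._queue: "queue.Queue[object]" = queue.Queue()

    def put_text(self, text: str) -> None:
        self._queue.put(text)

    def end(self) -> None:
        self._queue.put(self._STOP)

    def __iter__(self):
        return self

    def __next__(self) -> str:
        # timeout=0 raises queue.Empty immediately when nothing is ready.
        value = self._queue.get(timeout=0)
        if value is self._STOP:
            raise StopIteration
        return value


def fake_generate(tokens: List[str], streamer: PollingStreamer) -> None:
    """Blocking 'generation' that emits one token at a time."""
    for token in tokens:
        time.sleep(0.01)
        streamer.put_text(token)
    streamer.end()


async def consume_streamer_async(streamer: PollingStreamer) -> AsyncIterator[str]:
    # Same structure as LlamaModel.consume_streamer_async above.
    while True:
        try:
            for token in streamer:
                yield token
            break
        except queue.Empty:
            await asyncio.sleep(0.001)


async def stream_tokens(tokens: List[str]) -> List[str]:
    loop = asyncio.get_running_loop()
    streamer = PollingStreamer()
    # Run the blocking producer in the default thread pool, as the deployment does.
    future = loop.run_in_executor(None, partial(fake_generate, tokens, streamer))
    received = [token async for token in consume_streamer_async(streamer)]
    await future
    return received


print(asyncio.run(stream_tokens(["Once", " upon", " a", " time"])))
```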

Running a sharded model on multiple HPUs

This example deploys a Llama2-70b model on 8 HPUs, orchestrated by DeepSpeed.
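A rough, weights-only estimate shows why the 70B model is sharded across 8 HPUs while the 7B model fits on one. The sketch rests on these assumptions:

- Nominal parameter counts of 7 and 70 billion.
- 4 bytes per parameter for float32, used in the single-HPU example.
- 2 bytes per parameter for bfloat16, used here.
- 98304 MiB per module, as reported by hl-smi above.

It ignores activations, the KV cache, and runtime overhead, so real usage is higher.

```python
GIB = 1024 ** 3
# 98304 MiB per module, as reported by hl-smi in the environment setup.
HPU_MEMORY_GIB = 98304 / 1024


def weights_gib(num_params: float, bytes_per_param: int, shards: int = 1) -> float:
    """Approximate weight memory per device in GiB (weights only)."""
    return num_params * bytes_per_param / shards / GIB


llama2_7b_fp32 = weights_gib(7e9, 4)           # single-HPU example, torch.float32
llama2_70b_bf16 = weights_gib(70e9, 2)         # whole 70B model in bfloat16
llama2_70b_bf16_tp8 = weights_gib(70e9, 2, 8)  # per-HPU shard with tp_size=8

for name, size in [
    ("Llama2-7b fp32, 1 HPU", llama2_7b_fp32),
    ("Llama2-70b bf16, 1 HPU", llama2_70b_bf16),
    ("Llama2-70b bf16, 8 HPUs", llama2_70b_bf16_tp8),
]:
    verdict = "fits" if size < HPU_MEMORY_GIB else "does not fit"
    print(f"{name}: {size:.1f} GiB per HPU ({verdict} in {HPU_MEMORY_GIB:.0f} GiB)")
```

The estimates come to roughly 26 GiB, 130 GiB, and 16 GiB per HPU respectively. Only the unsharded 70B model exceeds a single module's memory.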

This example requires the Llama2-70b model to be cached. Run the following Python code in the Gaudi container to cache it:

import os
from huggingface_hub import snapshot_download
snapshot_download(
    "meta-llama/Llama-2-70b-chat-hf",
    # Replace the path if necessary.
    cache_dir=os.getenv("TRANSFORMERS_CACHE", None),
    # Specify your Hugging Face token.
    token=""
)

In this example, the deployment replica sends prompts to the DeepSpeed workers, which run in Ray actors:

import tempfile
from typing import Dict, Any
from starlette.requests import Request
from starlette.responses import StreamingResponse

import torch
from transformers import TextStreamer

import ray
from ray import serve
from ray.util.queue import Queue
from ray.runtime_env import RuntimeEnv


@ray.remote(resources={"HPU": 1})
class DeepSpeedInferenceWorker:
    def __init__(self, model_id_or_path: str, world_size: int, local_rank: int):
        """An actor that runs a DeepSpeed inference engine.

        Arguments:
            model_id_or_path: Either a Hugging Face model ID
                or a path to a cached model.
            world_size: Total number of worker processes.
            local_rank: Rank of this worker process.
                The rank 0 worker is the head worker.
        """
        from transformers import AutoTokenizer, AutoConfig
        from optimum.habana.transformers.modeling_utils import (
            adapt_transformers_to_gaudi,
        )

        # Tweak transformers for better performance on Gaudi.
        adapt_transformers_to_gaudi()

        self.model_id_or_path = model_id_or_path
        self._world_size = world_size
        self._local_rank = local_rank
        self.device = torch.device("hpu")

        self.model_config = AutoConfig.from_pretrained(
            model_id_or_path,
            torch_dtype=torch.bfloat16,
            token="",
            trust_remote_code=False,
        )

        # Load and configure the tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id_or_path, use_fast=False, token=""
        )
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        import habana_frameworks.torch.distributed.hccl as hccl

        # Initialize the distributed backend.
        hccl.initialize_distributed_hpu(
            world_size=world_size, rank=local_rank, local_rank=local_rank
        )
        torch.distributed.init_process_group(backend="hccl")

    def load_model(self):
        """Load the model to HPU and initialize the DeepSpeed inference engine."""

        import deepspeed
        from transformers import AutoModelForCausalLM
        from optimum.habana.checkpoint_utils import (
            get_ds_injection_policy,
            write_checkpoints_json,
        )

        # Construct the model with fake meta Tensors.
        # Loads the model weights from the checkpoint later.
        with deepspeed.OnDevice(dtype=torch.bfloat16, device="meta"):
            model = AutoModelForCausalLM.from_config(
                self.model_config, torch_dtype=torch.bfloat16
            )
        model = model.eval()

        # Create a file to indicate where the checkpoint is.
        checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="w+")
        write_checkpoints_json(
            self.model_id_or_path, self._local_rank, checkpoints_json, token=""
        )

        # Prepare the DeepSpeed inference configuration.
        kwargs = {"dtype": torch.bfloat16}
        kwargs["checkpoint"] = checkpoints_json.name
        kwargs["tensor_parallel"] = {"tp_size": self._world_size}
        # Enable the HPU graph, similar to the cuda graph.
        kwargs["enable_cuda_graph"] = True
        # Specify the injection policy, required by DeepSpeed Tensor parallelism.
        kwargs["injection_policy"] = get_ds_injection_policy(self.model_config)

        # Initialize the inference engine.
        self.model = deepspeed.init_inference(model, **kwargs).module

    def tokenize(self, prompt: str):
        """Tokenize the input and move it to HPU."""

        input_tokens = self.tokenizer(prompt, return_tensors="pt", padding=True)
        return input_tokens.input_ids.to(device=self.device)

    def generate(self, prompt: str, **config: Dict[str, Any]):
        """Take in a prompt and generate a response."""

        input_ids = self.tokenize(prompt)
        gen_tokens = self.model.generate(input_ids, **config)
        return self.tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)[0]

    def streaming_generate(self, prompt: str, streamer, **config: Dict[str, Any]):
        """Generate a streamed response given an input."""

        input_ids = self.tokenize(prompt)
        self.model.generate(input_ids, streamer=streamer, **config)

    def get_streamer(self):
        """Return a streamer.

        We only need the rank 0 worker's result.
        Other workers return a fake streamer.
        """

        if self._local_rank == 0:
            return RayTextIteratorStreamer(self.tokenizer, skip_special_tokens=True)
        else:

            class FakeStreamer:
                def put(self, value):
                    pass

                def end(self):
                    pass

            return FakeStreamer()


class RayTextIteratorStreamer(TextStreamer):
    def __init__(
        self,
        tokenizer,
        skip_prompt: bool = False,
        timeout: int = None,
        **decode_kwargs: Dict[str, Any],
    ):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.text_queue = Queue()
        self.stop_signal = None
        self.timeout = timeout

    def on_finalized_text(self, text: str, stream_end: bool = False):
        self.text_queue.put(text, timeout=self.timeout)
        if stream_end:
            self.text_queue.put(self.stop_signal, timeout=self.timeout)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.text_queue.get(timeout=self.timeout)
        if value == self.stop_signal:
            raise StopIteration()
        else:
            return value


Next, define a deployment:

# We need to set these variables for this example.
HABANA_ENVS = {
    "PT_HPU_LAZY_ACC_PAR_MODE": "0",
    "PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES": "0",
    "PT_HPU_ENABLE_WEIGHT_CPU_PERMUTE": "0",
    "PT_HPU_ENABLE_LAZY_COLLECTIVES": "true",
    "HABANA_VISIBLE_MODULES": "0,1,2,3,4,5,6,7",
}


# Define the Ray Serve deployment.
@serve.deployment
class DeepSpeedLlamaModel:
    def __init__(self, world_size: int, model_id_or_path: str):
        self._world_size = world_size

        # Create the DeepSpeed workers
        self.deepspeed_workers = []
        for i in range(world_size):
            self.deepspeed_workers.append(
                DeepSpeedInferenceWorker.options(
                    runtime_env=RuntimeEnv(env_vars=HABANA_ENVS)
                ).remote(model_id_or_path, world_size, i)
            )

        # Load the model to all workers.
        for worker in self.deepspeed_workers:
            worker.load_model.remote()

        # Get the workers' streamers.
        self.streamers = ray.get(
            [worker.get_streamer.remote() for worker in self.deepspeed_workers]
        )

    def generate(self, prompt: str, **config: Dict[str, Any]):
        """Send the prompt to workers for generation.

        Return after all workers finish the generation.
        Only return the rank 0 worker's result.
        """

        futures = [
            worker.generate.remote(prompt, **config)
            for worker in self.deepspeed_workers
        ]
        return ray.get(futures)[0]

    def streaming_generate(self, prompt: str, **config: Dict[str, Any]):
        """Send the prompt to workers for streaming generation.

        Only use the rank 0 worker's result.
        """

        for worker, streamer in zip(self.deepspeed_workers, self.streamers):
            worker.streaming_generate.remote(prompt, streamer, **config)

    def consume_streamer(self, streamer):
        """Consume the streamer and return a generator."""
        for token in streamer:
            yield token

    async def __call__(self, http_request: Request):
        """Handle received HTTP requests."""

        # Load fields from the request
        json_request: Dict[str, Any] = await http_request.json()
        text = json_request["text"]
        # Config used in generation
        config = json_request.get("config", {})
        streaming_response = json_request["stream"]

        # Prepare prompts
        prompts = []
        if isinstance(text, list):
            prompts.extend(text)
        else:
            prompts.append(text)

        # Process the configuration.
        config.setdefault("max_new_tokens", 128)

        # Enable HPU graph runtime.
        config["hpu_graphs"] = True
        # Lazy mode should be True when using HPU graphs.
        config["lazy_mode"] = True

        # Non-streaming case
        if not streaming_response:
            return self.generate(prompts, **config)

        # Streaming case
        self.streaming_generate(prompts, **config)
        return StreamingResponse(
            self.consume_streamer(self.streamers[0]),
            status_code=200,
            media_type="text/plain",
        )


# Replace the model ID with a path if necessary.
entrypoint = DeepSpeedLlamaModel.bind(8, "meta-llama/Llama-2-70b-chat-hf")

Copy the two code blocks above (the DeepSpeed worker code and the deployment code) into a file named intel_gaudi_inference_serve_deepspeed.py. Run this example with serve run intel_gaudi_inference_serve_deepspeed:entrypoint.

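Note: set the HABANA_VISIBLE_MODULES environment variable (in HABANA_ENVS above) with care, because it controls which HPU modules the DeepSpeed workers can use.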
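One way to apply this caution is to run a check before creating the workers. It confirms that HABANA_VISIBLE_MODULES lists exactly `world_size` distinct module indices and that none of them is a module you already know is busy. The helper below is a hypothetical pre-flight check written for illustration, not a Habana or Ray API. It assumes one worker per module, as in this example, where each worker requests `{"HPU": 1}`.

```python
from typing import Iterable, List


def check_visible_modules(value: str, world_size: int, busy: Iterable[int] = ()) -> List[int]:
    """Validate a HABANA_VISIBLE_MODULES string such as "0,1,2,3,4,5,6,7".

    Hypothetical helper; assumes one DeepSpeed worker per visible module.
    """
    modules = [int(part) for part in value.split(",") if part.strip()]
    if len(set(modules)) != len(modules):
        raise ValueError(f"duplicate module indices in {value!r}")
    if len(modules) != world_size:
        raise ValueError(f"{len(modules)} modules listed but world_size is {world_size}")
    overlap = sorted(set(modules) & set(busy))
    if overlap:
        raise ValueError(f"modules already in use: {overlap}")
    return modules


print(check_visible_modules("0,1,2,3,4,5,6,7", 8))
```

For example, passing `busy=[3, 5, 7]` (the modules with memory in use in the hl-smi output shown earlier) makes the check raise `ValueError` for the value used in HABANA_ENVS.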

The terminal should print logs as the deployment starts up:

2025-03-03 06:21:57,692 INFO scripts.py:494 -- Running import path: 'infer-ds:entrypoint'.
2025-03-03 06:22:03,064 INFO worker.py:1832 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265 
INFO 2025-03-03 06:22:07,343 serve 170212 -- Started Serve in namespace "serve".
INFO 2025-03-03 06:22:07,343 serve 170212 -- Connecting to existing Serve app in namespace "serve". New http options will not be applied.
(ServeController pid=170719) INFO 2025-03-03 06:22:07,377 controller 170719 -- Deploying new version of Deployment(name='DeepSpeedLlamaModel', app='default') (initial target replicas: 1).
(ProxyActor pid=170723) INFO 2025-03-03 06:22:07,290 proxy 100.83.111.228 -- Proxy starting on node 47721c925467a877497e66104328bb72dc7bd7f900a63b2f1fdb48b2 (HTTP port: 8000).
(ProxyActor pid=170723) INFO 2025-03-03 06:22:07,325 proxy 100.83.111.228 -- Got updated endpoints: {}.
(ProxyActor pid=170723) INFO 2025-03-03 06:22:07,379 proxy 100.83.111.228 -- Got updated endpoints: {Deployment(name='DeepSpeedLlamaModel', app='default'): EndpointInfo(route='/', app_is_cross_language=False)}.
(ServeController pid=170719) INFO 2025-03-03 06:22:07,478 controller 170719 -- Adding 1 replica to Deployment(name='DeepSpeedLlamaModel', app='default').
(ProxyActor pid=170723) INFO 2025-03-03 06:22:07,422 proxy 100.83.111.228 -- Started <ray.serve._private.router.SharedRouterLongPollClient object at 0x7fa557945210>.
(DeepSpeedInferenceWorker pid=179962) [WARNING|utils.py:212] 2025-03-03 06:22:14,611 >> optimum-habana v1.15.0 has been validated for SynapseAI v1.19.0 but habana-frameworks v1.20.0.543 was found, this could lead to undefined behavior!
(DeepSpeedInferenceWorker pid=179963) /usr/local/lib/python3.10/dist-packages/transformers/deepspeed.py:24: FutureWarning: transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations
(DeepSpeedInferenceWorker pid=179963)   warnings.warn(
(DeepSpeedInferenceWorker pid=179964) [WARNING|utils.py:212] 2025-03-03 06:22:14,613 >> optimum-habana v1.15.0 has been validated for SynapseAI v1.19.0 but habana-frameworks v1.20.0.543 was found, this could lead to undefined behavior! [repeated 3x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.rayai.org.cn/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
(DeepSpeedInferenceWorker pid=179962) [2025-03-03 06:22:23,502] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to hpu (auto detect)
Loading 2 checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]
(DeepSpeedInferenceWorker pid=179962) [2025-03-03 06:22:24,032] [INFO] [logging.py:105:log_dist] [Rank -1] DeepSpeed info: version=0.16.1+hpu.synapse.v1.20.0, git-hash=61543a96, git-branch=1.20.0
(DeepSpeedInferenceWorker pid=179962) [2025-03-03 06:22:24,035] [INFO] [logging.py:105:log_dist] [Rank -1] quantize_bits = 8 mlp_extra_grouping = False, quantize_groups = 1
(DeepSpeedInferenceWorker pid=179962) [2025-03-03 06:22:24,048] [INFO] [comm.py:652:init_distributed] cdb=None
(DeepSpeedInferenceWorker pid=179963) ============================= HABANA PT BRIDGE CONFIGURATION =========================== 
(DeepSpeedInferenceWorker pid=179963)  PT_HPU_LAZY_MODE = 1
(DeepSpeedInferenceWorker pid=179963)  PT_HPU_RECIPE_CACHE_CONFIG = ,false,1024
(DeepSpeedInferenceWorker pid=179963)  PT_HPU_MAX_COMPOUND_OP_SIZE = 9223372036854775807
(DeepSpeedInferenceWorker pid=179963)  PT_HPU_LAZY_ACC_PAR_MODE = 0
(DeepSpeedInferenceWorker pid=179963)  PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES = 0
(DeepSpeedInferenceWorker pid=179963)  PT_HPU_EAGER_PIPELINE_ENABLE = 1
(DeepSpeedInferenceWorker pid=179963)  PT_HPU_EAGER_COLLECTIVE_PIPELINE_ENABLE = 1
(DeepSpeedInferenceWorker pid=179963)  PT_HPU_ENABLE_LAZY_COLLECTIVES = 1
(DeepSpeedInferenceWorker pid=179963) ---------------------------: System Configuration :---------------------------
(DeepSpeedInferenceWorker pid=179963) Num CPU Cores : 160
(DeepSpeedInferenceWorker pid=179963) CPU RAM       : 1056374420 KB
(DeepSpeedInferenceWorker pid=179963) ------------------------------------------------------------------------------
(DeepSpeedInferenceWorker pid=179964) /usr/local/lib/python3.10/dist-packages/transformers/deepspeed.py:24: FutureWarning: transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations [repeated 3x across cluster]
(DeepSpeedInferenceWorker pid=179964)   warnings.warn( [repeated 3x across cluster]
Loading 2 checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s] [repeated 3x across cluster]
(ServeController pid=170719) WARNING 2025-03-03 06:22:37,562 controller 170719 -- Deployment 'DeepSpeedLlamaModel' in application 'default' has 1 replicas that have taken more than 30s to initialize.
(ServeController pid=170719) This may be caused by a slow __init__ or reconfigure method.
Loading 2 checkpoint shards:  50%|█████     | 1/2 [00:17<00:17, 17.51s/it]
Loading 2 checkpoint shards: 100%|██████████| 2/2 [00:21<00:00,  9.57s/it]
Loading 2 checkpoint shards: 100%|██████████| 2/2 [00:21<00:00, 10.88s/it]
Loading 2 checkpoint shards:  50%|█████     | 1/2 [00:18<00:18, 18.70s/it] [repeated 3x across cluster]
INFO 2025-03-03 06:22:48,569 serve 170212 -- Application 'default' is ready at http://127.0.0.1:8000/.
INFO 2025-03-03 06:22:48,569 serve 170212 -- Deployed app 'default' successfully.
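For a non-streaming request, `DeepSpeedLlamaModel.generate` follows a scatter/gather pattern:

- It submits the same prompt to every worker.
- It waits for all of them with `ray.get`.
- It returns only the rank 0 result.

The sketch below shows the same control flow with `concurrent.futures` standing in for Ray actors. `fake_worker_generate` is a hypothetical placeholder for `DeepSpeedInferenceWorker.generate`.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, List


def fake_worker_generate(rank: int, prompts: List[str], **config: Any) -> str:
    """Hypothetical placeholder for DeepSpeedInferenceWorker.generate on one rank."""
    return f"rank{rank}: {prompts[0]} ({config.get('max_new_tokens')} new tokens)"


def generate(world_size: int, prompts: List[str], **config: Any) -> str:
    with ThreadPoolExecutor(max_workers=world_size) as pool:
        # Scatter: every rank gets the same prompt, because each rank holds
        # only a shard of the model and must take part in generation.
        futures = [
            pool.submit(fake_worker_generate, rank, prompts, **config)
            for rank in range(world_size)
        ]
        # Gather: wait for every rank (as ray.get(futures) does), keep rank 0.
        results = [f.result() for f in futures]
    return results[0]


print(generate(8, ["Once upon a time,"], max_new_tokens=128))
```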

Use the same code snippet from the single-HPU example to send generation requests. Here is an example output:

Once upon a time, in a far-off land, there was a magical kingdom called "Happily Ever Laughter." It was a place where laughter was the key to unlocking all the joys of life, and where everyone lived in perfect harmony.

In this kingdom, there was a beautiful princess named Lily. She was kind, gentle, and had a heart full of laughter. Every day, she would wake up with a big smile on her face, ready to face whatever adventures the day might bring.

One day, a wicked sorcerer cast a spell on the kingdom
Once upon a time, in a far-off land, there was a magical kingdom called "Happily Ever Laughter." It was a place where laughter was the key to unlocking all the joys of life, and where everyone lived in perfect harmony.

In this kingdom, there was a beautiful princess named Lily. She was kind, gentle, and had a heart full of laughter. Every day, she would wake up with a big smile on her face, ready to face whatever adventures the day might bring.

One day, a wicked sorcerer cast a spell on the kingdom
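Every DeepSpeed worker has to run `generate`, because each one holds only a shard of the model, but the response needs only one copy of the text. That is why `get_streamer` gives rank 0 a queue-backed `RayTextIteratorStreamer` and gives the other ranks a `FakeStreamer` whose `put` and `end` do nothing. The sketch below imitates that split with threads, using a standard-library `queue.Queue` in place of `ray.util.queue.Queue`. `QueueStreamer` and `worker` are hypothetical names. Unlike the real streamer, `put` here receives already-decoded text rather than token IDs, so the decoding step is skipped.

```python
import queue
import threading
from typing import List, Optional


class QueueStreamer:
    """Rank-0 streamer: queues decoded text, then a stop sentinel."""

    def __init__(self, timeout: Optional[float] = None) -> None:
        self.text_queue: "queue.Queue[Optional[str]]" = queue.Queue()
        self.stop_signal = None
        self.timeout = timeout

    def put(self, text: str) -> None:
        self.text_queue.put(text, timeout=self.timeout)

    def end(self) -> None:
        self.text_queue.put(self.stop_signal, timeout=self.timeout)

    def __iter__(self):
        return self

    def __next__(self) -> str:
        value = self.text_queue.get(timeout=self.timeout)
        if value is self.stop_signal:
            raise StopIteration
        return value


class FakeStreamer:
    """Streamer for ranks other than 0: accepts every call and discards it."""

    def put(self, value) -> None:
        pass

    def end(self) -> None:
        pass


def worker(tokens: List[str], streamer) -> None:
    # Every rank runs the same generation loop; only rank 0 keeps its output.
    for token in tokens:
        streamer.put(token)
    streamer.end()


world_size = 4
tokens = ["Once", " upon", " a", " time"]
streamers = [QueueStreamer(timeout=5) if rank == 0 else FakeStreamer()
             for rank in range(world_size)]
threads = [threading.Thread(target=worker, args=(tokens, s)) for s in streamers]
for t in threads:
    t.start()
received = list(streamers[0])  # the deployment only consumes streamers[0]
for t in threads:
    t.join()
print("".join(received))
```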

Next steps

See llm-on-ray for more ways to customize and deploy LLMs at scale.