Serve a Stable Diffusion model on AWS NeuronCores using FastAPI

This example deploys a precompiled Stable Diffusion XL model to an AWS Inferentia2 (Inf2) instance using Ray Serve and FastAPI.

Note

Before starting this example:

  • Set up PyTorch Neuron

  • Install the AWS NeuronCore drivers and tools, as well as torch-neuronx, for your instance type

pip install "optimum-neuron==0.0.13" "diffusers==0.21.4"
pip install "ray[serve]" requests transformers
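Because the commands above pin exact versions of optimum-neuron and diffusers, it can help to confirm what actually got installed before starting Serve. A minimal sketch using the standard library's importlib.metadata; the helper name installed_versions is hypothetical and not part of the original example:

```python
from __future__ import annotations

from importlib.metadata import PackageNotFoundError, version

# Versions pinned by the pip install command above.
PINNED = {"optimum-neuron": "0.0.13", "diffusers": "0.21.4"}


def installed_versions(names) -> dict[str, str | None]:
    """Return the installed version of each distribution, or None if it is missing."""
    found: dict[str, str | None] = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    # Compare what is installed against the pins and report each package.
    for name, installed in installed_versions(PINNED).items():
        expected = PINNED[name]
        state = "ok" if installed == expected else f"expected {expected}"
        print(f"{name}: {installed or 'not installed'} ({state})")
```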

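This example uses the Stable Diffusion XL model and FastAPI. The model has already been compiled with AWS Neuron and is ready to run inference. You can also choose a different Stable Diffusion model and compile it yourself so that it can run inference on AWS Inferentia2 instances.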
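A minimal sketch of compiling a different SDXL checkpoint yourself. It assumes optimum-neuron's `export=True` path of `NeuronStableDiffusionXLPipeline.from_pretrained`, with the `auto_cast`/`auto_cast_type` compiler arguments and static `batch_size`/`height`/`width` shapes shown in the optimum-neuron export documentation; the function name is hypothetical, compilation has to run on an instance with the Neuron SDK, and the exact export arguments may differ in the version you install:

```python
def compile_sdxl(model_id: str, save_directory: str, height: int = 1024, width: int = 1024) -> None:
    """Hypothetical helper: export an SDXL checkpoint for Neuron and save it to disk."""
    # Imported lazily: optimum-neuron and the Neuron compiler are only usable
    # on an instance with the AWS Neuron SDK installed.
    from optimum.neuron import NeuronStableDiffusionXLPipeline

    # Compiler options as used in the optimum-neuron export examples (assumption).
    compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16"}
    # Neuron compiles for static shapes, so batch size and resolution are fixed here.
    input_shapes = {"batch_size": 1, "height": height, "width": width}

    pipe = NeuronStableDiffusionXLPipeline.from_pretrained(
        model_id, export=True, **compiler_args, **input_shapes
    )
    # Writes the compiled artifacts so they can be reloaded without recompiling.
    pipe.save_pretrained(save_directory)
```

For example, compile_sdxl("stabilityai/stable-diffusion-xl-base-1.0", "sdxl_neuron/") would write a compiled pipeline that could then be passed to from_pretrained in place of compiled_model_id. Keep height and width equal to the resolution you intend to serve, because the compiled model only accepts the shapes it was built for.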

The model in this example is ready to deploy. Save the following code to a file named aws_neuron_core_inference_serve_stable_diffusion.py.

Start the Serve application with serve run aws_neuron_core_inference_serve_stable_diffusion:entrypoint.

from io import BytesIO
from fastapi import FastAPI
from fastapi.responses import Response
from ray import serve


app = FastAPI()

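# NeuronCores reserved for each StableDiffusionV2 replica
# (matches the two device_ids passed to the pipeline below).
neuron_cores = 2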


@serve.deployment(num_replicas=1, route_prefix="/")
@serve.ingress(app)
class APIIngress:
    def __init__(self, diffusion_model_handle) -> None:
        self.handle = diffusion_model_handle

    @app.get(
        "/imagine",
        responses={200: {"content": {"image/png": {}}}},
        response_class=Response,
    )
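    async def generate(self, prompt: str):
        # Forward the prompt to the model deployment and await the PIL image.
        image = await self.handle.generate.remote(prompt)
        # Encode the image as PNG in memory and return it as the response body.
        file_stream = BytesIO()
        image.save(file_stream, "PNG")
        return Response(content=file_stream.getvalue(), media_type="image/png")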


@serve.deployment(
    ray_actor_options={"resources": {"neuron_cores": neuron_cores}},
    autoscaling_config={"min_replicas": 1, "max_replicas": 2},
)
class StableDiffusionV2:
    def __init__(self):
        from optimum.neuron import NeuronStableDiffusionXLPipeline

        compiled_model_id = "aws-neuron/stable-diffusion-xl-base-1-0-1024x1024"
        self.pipe = NeuronStableDiffusionXLPipeline.from_pretrained(
            compiled_model_id, device_ids=[0, 1]
        )

    async def generate(self, prompt: str):
        assert len(prompt), "prompt parameter cannot be empty"
        # Run the SDXL pipeline on the NeuronCores and return the first generated image.
        image = self.pipe(prompt).images[0]
        return image


entrypoint = APIIngress.bind(StableDiffusionV2.bind())
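The StableDiffusionV2 deployment requests neuron_cores = 2 per replica through ray_actor_options and lets the autoscaler go up to max_replicas = 2, so reaching the maximum needs four NeuronCores in the cluster. If fewer are free, Ray cannot schedule the additional replica until resources become available. The arithmetic as a small sketch; both helper names are hypothetical:

```python
def neuron_cores_needed(max_replicas: int, cores_per_replica: int) -> int:
    """NeuronCores the cluster must offer for the deployment to reach max_replicas."""
    return max_replicas * cores_per_replica


def replicas_that_fit(available_cores: int, cores_per_replica: int) -> int:
    """Replicas Ray could place given free NeuronCores (ignores other workloads)."""
    if cores_per_replica <= 0:
        raise ValueError("cores_per_replica must be positive")
    return available_cores // cores_per_replica


# Values from the StableDiffusionV2 deployment above.
NEURON_CORES_PER_REPLICA = 2
MAX_REPLICAS = 2

print(neuron_cores_needed(MAX_REPLICAS, NEURON_CORES_PER_REPLICA))  # 4
print(replicas_that_fit(2, NEURON_CORES_PER_REPLICA))  # 1
```

On a node that exposes only two NeuronCores, setting max_replicas to 1 avoids an autoscaling target that can never be met.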

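When the Ray Serve deployment succeeds, you should see log messages like the following. Loading the compiled pipeline takes a while, so the WARNING lines about a replica taking more than 30s to initialize are expected before the final success message.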

2024-02-07 17:53:28,299	INFO worker.py:1715 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8265 
(ProxyActor pid=25282) INFO 2024-02-07 17:53:31,751 proxy 172.31.10.188 proxy.py:1128 - Proxy actor fd464602af1e456162edf6f901000000 starting on node 5a8e0c24b22976f1f7672cc54f13ace25af3664a51429d8e332c0679.
(ProxyActor pid=25282) INFO 2024-02-07 17:53:31,755 proxy 172.31.10.188 proxy.py:1333 - Starting HTTP server on node: 5a8e0c24b22976f1f7672cc54f13ace25af3664a51429d8e332c0679 listening on port 8000
(ProxyActor pid=25282) INFO:     Started server process [25282]
(ServeController pid=25233) INFO 2024-02-07 17:53:31,921 controller 25233 deployment_state.py:1545 - Deploying new version of deployment StableDiffusionV2 in application 'default'. Setting initial target number of replicas to 1.
(ServeController pid=25233) INFO 2024-02-07 17:53:31,922 controller 25233 deployment_state.py:1545 - Deploying new version of deployment APIIngress in application 'default'. Setting initial target number of replicas to 1.
(ServeController pid=25233) INFO 2024-02-07 17:53:32,024 controller 25233 deployment_state.py:1829 - Adding 1 replica to deployment StableDiffusionV2 in application 'default'.
(ServeController pid=25233) INFO 2024-02-07 17:53:32,029 controller 25233 deployment_state.py:1829 - Adding 1 replica to deployment APIIngress in application 'default'.
Fetching 20 files: 100%|██████████| 20/20 [00:00<00:00, 195538.65it/s]
(ServeController pid=25233) WARNING 2024-02-07 17:54:02,114 controller 25233 deployment_state.py:2171 - Deployment 'StableDiffusionV2' in application 'default' has 1 replicas that have taken more than 30s to initialize. This may be caused by a slow __init__ or reconfigure method.
(ServeController pid=25233) WARNING 2024-02-07 17:54:32,170 controller 25233 deployment_state.py:2171 - Deployment 'StableDiffusionV2' in application 'default' has 1 replicas that have taken more than 30s to initialize. This may be caused by a slow __init__ or reconfigure method.
(ServeController pid=25233) WARNING 2024-02-07 17:55:02,344 controller 25233 deployment_state.py:2171 - Deployment 'StableDiffusionV2' in application 'default' has 1 replicas that have taken more than 30s to initialize. This may be caused by a slow __init__ or reconfigure method.
(ServeController pid=25233) WARNING 2024-02-07 17:55:32,418 controller 25233 deployment_state.py:2171 - Deployment 'StableDiffusionV2' in application 'default' has 1 replicas that have taken more than 30s to initialize. This may be caused by a slow __init__ or reconfigure method.
2024-02-07 17:55:46,263	SUCC scripts.py:483 -- Deployed Serve app successfully.
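In the log above, the HTTP proxy starts listening on port 8000 about two minutes before the app finishes deploying, so a script that sends a request right after launching Serve may not get an image back. A minimal stdlib sketch that waits for the app's route first. It assumes the Ray Serve proxy's /-/routes endpoint, which returns a JSON object keyed by route prefix; verify this against your Ray version. The function name is hypothetical:

```python
import json
import time
import urllib.request


def wait_for_route(base_url: str = "http://127.0.0.1:8000", route: str = "/",
                   timeout_s: float = 600.0, interval_s: float = 5.0) -> bool:
    """Poll the Serve proxy until `route` is listed, or give up after `timeout_s`.

    Assumes the proxy exposes /-/routes as a JSON object keyed by route prefix.
    """
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            with urllib.request.urlopen(f"{base_url}/-/routes", timeout=interval_s) as resp:
                if route in json.load(resp):
                    return True
        except (OSError, ValueError):
            # OSError covers refused connections, timeouts, and HTTP errors
            # (urllib.error.URLError and HTTPError subclass OSError); ValueError
            # covers a body that is not valid JSON. Keep polling in all cases.
            pass
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval_s)
```

For example, call wait_for_route() after starting serve run and only send the /imagine request if it returns True. The default route "/" corresponds to the route_prefix set on APIIngress.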

Send a request with the following code:

import requests

prompt = "a zebra is dancing in the grass, river, sunlit"
# Let requests URL-encode the prompt; joining words with "%20" by hand
# would leave characters such as "&" or "#" unescaped.
resp = requests.get("http://127.0.0.1:8000/imagine", params={"prompt": prompt}, timeout=120)
resp.raise_for_status()
print("Write the response to `output.png`.")
with open("output.png", "wb") as f:
    f.write(resp.content)
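If you would rather not depend on requests, the standard library can do the same. Encoding only the spaces by hand (for example by joining words with %20) leaves characters like &, # or + unescaped, which would cut off or corrupt the prompt; urllib.parse.urlencode escapes the whole value, and FastAPI decodes the resulting "+" back into spaces. A minimal sketch; the helper names and the 120-second timeout are assumptions:

```python
import urllib.parse
import urllib.request

BASE_URL = "http://127.0.0.1:8000/imagine"


def build_imagine_url(prompt: str, base_url: str = BASE_URL) -> str:
    """Percent-encode the whole prompt (spaces, commas, &, # ...) into the query string."""
    return f"{base_url}?{urllib.parse.urlencode({'prompt': prompt})}"


def fetch_image(prompt: str, path: str = "output.png", timeout_s: float = 120.0) -> None:
    """Request an image and write the PNG bytes to `path`; raises on HTTP errors."""
    with urllib.request.urlopen(build_imagine_url(prompt), timeout=timeout_s) as resp:
        data = resp.read()
    with open(path, "wb") as f:
        f.write(data)
```

With the Serve app running, fetch_image("a zebra is dancing in the grass, river, sunlit") writes output.png just like the requests version.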

When you send a request to the endpoint, you should see log messages like the following:

(ServeReplica:default:StableDiffusionV2 pid=25320) Prompt:  a zebra is dancing in the grass, river, sunlit
  0%|          | 0/50 [00:00<?, ?it/s]
  2%|▏         | 1/50 [00:00<00:14,  3.43it/s]
  4%|▍         | 2/50 [00:00<00:13,  3.62it/s]
  6%|▌         | 3/50 [00:00<00:12,  3.73it/s]
  8%|▊         | 4/50 [00:01<00:12,  3.78it/s]
 10%|█         | 5/50 [00:01<00:11,  3.81it/s]
 12%|█▏        | 6/50 [00:01<00:11,  3.82it/s]
 14%|█▍        | 7/50 [00:01<00:11,  3.83it/s]
 16%|█▌        | 8/50 [00:02<00:10,  3.84it/s]
 18%|█▊        | 9/50 [00:02<00:10,  3.85it/s]
 20%|██        | 10/50 [00:02<00:10,  3.85it/s]
 22%|██▏       | 11/50 [00:02<00:10,  3.85it/s]
 24%|██▍       | 12/50 [00:03<00:09,  3.86it/s]
 26%|██▌       | 13/50 [00:03<00:09,  3.86it/s]
 28%|██▊       | 14/50 [00:03<00:09,  3.85it/s]
 30%|███       | 15/50 [00:03<00:09,  3.85it/s]
 32%|███▏      | 16/50 [00:04<00:08,  3.85it/s]
 34%|███▍      | 17/50 [00:04<00:08,  3.85it/s]
 36%|███▌      | 18/50 [00:04<00:08,  3.85it/s]
 38%|███▊      | 19/50 [00:04<00:08,  3.86it/s]
 40%|████      | 20/50 [00:05<00:07,  3.85it/s]
 42%|████▏     | 21/50 [00:05<00:07,  3.85it/s]
 44%|████▍     | 22/50 [00:05<00:07,  3.85it/s]
 46%|████▌     | 23/50 [00:06<00:07,  3.81it/s]
 48%|████▊     | 24/50 [00:06<00:06,  3.81it/s]
 50%|█████     | 25/50 [00:06<00:06,  3.82it/s]
 52%|█████▏    | 26/50 [00:06<00:06,  3.83it/s]
 54%|█████▍    | 27/50 [00:07<00:05,  3.84it/s]
 56%|█████▌    | 28/50 [00:07<00:05,  3.84it/s]
 58%|█████▊    | 29/50 [00:07<00:05,  3.84it/s]
 60%|██████    | 30/50 [00:07<00:05,  3.84it/s]
 62%|██████▏   | 31/50 [00:08<00:04,  3.84it/s]
 64%|██████▍   | 32/50 [00:08<00:04,  3.84it/s]
 66%|██████▌   | 33/50 [00:08<00:04,  3.85it/s]
 68%|██████▊   | 34/50 [00:08<00:04,  3.85it/s]
 70%|███████   | 35/50 [00:09<00:03,  3.84it/s]
 72%|███████▏  | 36/50 [00:09<00:03,  3.84it/s]
 74%|███████▍  | 37/50 [00:09<00:03,  3.84it/s]
 76%|███████▌  | 38/50 [00:09<00:03,  3.84it/s]
 78%|███████▊  | 39/50 [00:10<00:02,  3.84it/s]
 80%|████████  | 40/50 [00:10<00:02,  3.84it/s]
 82%|████████▏ | 41/50 [00:10<00:02,  3.84it/s]
 84%|████████▍ | 42/50 [00:10<00:02,  3.84it/s]
 86%|████████▌ | 43/50 [00:11<00:01,  3.84it/s]
 88%|████████▊ | 44/50 [00:11<00:01,  3.84it/s]
 90%|█████████ | 45/50 [00:11<00:01,  3.84it/s]
 92%|█████████▏| 46/50 [00:11<00:01,  3.85it/s]
 94%|█████████▍| 47/50 [00:12<00:00,  3.85it/s]
 96%|█████████▌| 48/50 [00:12<00:00,  3.84it/s]
 98%|█████████▊| 49/50 [00:12<00:00,  3.84it/s]
100%|██████████| 50/50 [00:13<00:00,  3.83it/s]
(ServeReplica:default:StableDiffusionV2 pid=25320) INFO 2024-02-07 17:58:36,604 default_StableDiffusionV2 OXPzZm 33133be7-246f-4492-9ab6-6a4c2666b306 /imagine replica.py:772 - GENERATE OK 14167.2ms
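The request log gives enough to estimate where the time goes: the progress bar reports 50 denoising steps at about 3.83 it/s, and the replica logs GENERATE OK 14167.2ms for the whole call. A back-of-the-envelope sketch; the helper name is hypothetical:

```python
from __future__ import annotations


def split_latency(steps: int, steps_per_second: float, request_ms: float) -> tuple[float, float]:
    """Split a request's latency (seconds) into the denoising loop and everything else."""
    denoise_s = steps / steps_per_second
    other_s = request_ms / 1000.0 - denoise_s
    return denoise_s, other_s


# Figures copied from the log above: 50 steps at ~3.83 it/s, "GENERATE OK 14167.2ms".
denoise_s, other_s = split_latency(50, 3.83, 14167.2)
print(f"denoising loop: ~{denoise_s:.1f} s, remainder: ~{other_s:.1f} s")
```

That puts the denoising loop at roughly 13 s of the 14.2 s request. The remaining ~1.1 s is presumably spent in the parts of the pipeline call outside the denoising loop (such as prompt encoding and VAE decoding) plus replica overhead; this is an inference from the logged numbers, not a measured breakdown, and the tqdm rate is itself an average.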

The app saves the output.png file locally. The following is an example of the output image.