Deploy a Stable Diffusion Model
This example runs a Stable Diffusion application with Ray Serve.
To run this example, install the following:
pip install "ray[serve]" requests torch diffusers==0.12.1 transformers
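As an optional sanity check after installation, the sketch below reports which packages cannot be found on the Python path. The helper name missing_packages is hypothetical, and the module list is an assumption derived from the pip command above (fastapi is listed on the assumption that ray[serve] pulls it in).

```python
from importlib.util import find_spec

# Top-level module names assumed from the pip command above.
REQUIRED = ["ray", "requests", "torch", "diffusers", "transformers", "fastapi"]


def missing_packages(names):
    """Return the names that cannot be located on the current Python path."""
    return [name for name in names if find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages(REQUIRED)
    print("missing:", missing if missing else "none")
```

An empty result only shows that the packages are importable; it does not confirm that a CUDA-capable GPU is available.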
This example uses the stabilityai/stable-diffusion-2 model and FastAPI to build the application. Save the following code to a file named stable_diffusion.py.
The Serve code is as follows:
from io import BytesIO
from fastapi import FastAPI
from fastapi.responses import Response
import torch

from ray import serve
from ray.serve.handle import DeploymentHandle


app = FastAPI()


@serve.deployment(num_replicas=1)
@serve.ingress(app)
class APIIngress:
    def __init__(self, diffusion_model_handle: DeploymentHandle) -> None:
        self.handle = diffusion_model_handle

    @app.get(
        "/imagine",
        responses={200: {"content": {"image/png": {}}}},
        response_class=Response,
    )
    async def generate(self, prompt: str, img_size: int = 512):
        assert len(prompt), "prompt parameter cannot be empty"

        image = await self.handle.generate.remote(prompt, img_size=img_size)
        file_stream = BytesIO()
        image.save(file_stream, "PNG")
        return Response(content=file_stream.getvalue(), media_type="image/png")


@serve.deployment(
    ray_actor_options={"num_gpus": 1},
    autoscaling_config={"min_replicas": 0, "max_replicas": 2},
)
class StableDiffusionV2:
    def __init__(self):
        from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline

        model_id = "stabilityai/stable-diffusion-2"

        scheduler = EulerDiscreteScheduler.from_pretrained(
            model_id, subfolder="scheduler"
        )
        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_id, scheduler=scheduler, revision="fp16", torch_dtype=torch.float16
        )
        self.pipe = self.pipe.to("cuda")

    def generate(self, prompt: str, img_size: int = 512):
        assert len(prompt), "prompt parameter cannot be empty"

        with torch.autocast("cuda"):
            image = self.pipe(prompt, height=img_size, width=img_size).images[0]
            return image


entrypoint = APIIngress.bind(StableDiffusionV2.bind())
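The generate endpoint above passes img_size to the pipeline without checking it. The sketch below shows one way such a check could look. The helper validate_img_size and its minimum and maximum bounds are hypothetical and not part of the example; the multiple-of-8 rule reflects the check that the diffusers pipeline is understood to apply to height and width.

```python
def validate_img_size(img_size: int, minimum: int = 64, maximum: int = 1024) -> int:
    """Return img_size if it looks usable, otherwise raise ValueError.

    The bounds are illustrative assumptions. The multiple-of-8 rule mirrors
    the constraint the diffusers pipeline places on height and width.
    """
    if img_size % 8 != 0:
        raise ValueError(f"img_size must be a multiple of 8, got {img_size}")
    if not minimum <= img_size <= maximum:
        raise ValueError(
            f"img_size must be in [{minimum}, {maximum}], got {img_size}"
        )
    return img_size
```

Calling a check like this at the start of APIIngress.generate would reject a bad request before it reaches a GPU replica.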
Start the Serve application with the command serve run stable_diffusion:entrypoint.
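Note
The autoscaling configuration sets min_replicas to 0, which means the deployment starts with no StableDiffusionV2 replicas. Replicas are created only when a request arrives. When no requests arrive for a period of time, Serve scales the number of StableDiffusionV2 replicas back down to 0 to free GPU resources.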
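Because the deployment can scale to zero, the first request after an idle period has to wait for a replica to start and load the model, so a client with a short timeout may give up too early. The sketch below shows a client-side retry with exponential backoff. The helper fetch_with_retry, its defaults, and the injected sleep function are hypothetical and only illustrate the idea.

```python
import time
from typing import Callable


def fetch_with_retry(
    fetch: Callable[[], bytes],
    attempts: int = 3,
    backoff_s: float = 2.0,
    sleep: Callable[[float], None] = time.sleep,
) -> bytes:
    """Call fetch() until it succeeds or the attempts run out."""
    last_error = None
    for attempt in range(attempts):
        try:
            return fetch()
        except Exception as exc:  # real code should catch only timeout/connection errors
            last_error = exc
            if attempt < attempts - 1:
                # Wait 2 s, then 4 s, ... before trying again.
                sleep(backoff_s * 2 ** attempt)
    raise RuntimeError(f"request failed after {attempts} attempts") from last_error
```

With requests, fetch could be something like lambda: requests.get(url, timeout=120).content, where the 120-second timeout is an assumed value chosen to cover model loading.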
The serve run command should print messages like the following:
(ServeController pid=362, ip=10.0.44.233) INFO 2023-03-08 16:44:57,579 controller 362 http_state.py:129 - Starting HTTP proxy with name 'SERVE_CONTROLLER_ACTOR:SERVE_PROXY_ACTOR-7396d5a9efdb59ee01b7befba448433f6c6fc734cfa5421d415da1b3' on node '7396d5a9efdb59ee01b7befba448433f6c6fc734cfa5421d415da1b3' listening on '127.0.0.1:8000'
(ServeController pid=362, ip=10.0.44.233) INFO 2023-03-08 16:44:57,588 controller 362 http_state.py:129 - Starting HTTP proxy with name 'SERVE_CONTROLLER_ACTOR:SERVE_PROXY_ACTOR-a30ea53938547e0bf88ce8672e578f0067be26a7e26d23465c46300b' on node 'a30ea53938547e0bf88ce8672e578f0067be26a7e26d23465c46300b' listening on '127.0.0.1:8000'
(ProxyActor pid=439, ip=10.0.44.233) INFO: Started server process [439]
(ProxyActor pid=5779) INFO: Started server process [5779]
(ServeController pid=362, ip=10.0.44.233) INFO 2023-03-08 16:44:59,362 controller 362 deployment_state.py:1333 - Adding 1 replica to deployment 'APIIngress'.
2023-03-08 16:45:01,316 SUCC <string>:93 -- Deployed Serve app successfully.
Use the following code to send a request:
import requests

prompt = "a cute cat is dancing on the grass."
# Let requests URL-encode the prompt instead of joining words by hand.
resp = requests.get("http://127.0.0.1:8000/imagine", params={"prompt": prompt})
resp.raise_for_status()
with open("output.png", "wb") as f:
    f.write(resp.content)
The application saves the output.png file locally. The following is an example of an output image.
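The /imagine endpoint also accepts an optional img_size query parameter. The sketch below builds the request URL with the standard library so that both parameters are percent-encoded, and checks that a response body starts with the PNG signature before it is written to disk. The helper names build_imagine_url and looks_like_png are hypothetical; the base URL is the address shown in the log output above.

```python
from urllib.parse import quote, urlencode

# Every PNG file starts with these eight bytes.
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def build_imagine_url(
    prompt: str, img_size: int = 512, base: str = "http://127.0.0.1:8000"
) -> str:
    """Build the /imagine URL with a percent-encoded query string."""
    query = urlencode({"prompt": prompt, "img_size": img_size}, quote_via=quote)
    return f"{base}/imagine?{query}"


def looks_like_png(data: bytes) -> bool:
    """Return True if data begins with the PNG file signature."""
    return data.startswith(PNG_SIGNATURE)
```

A check like looks_like_png helps catch the case where the server returns an error body instead of an image, which would otherwise be saved as a corrupt output.png.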