使用 Ray DAG API 进行延迟计算图#

使用 ray.remote,您可以灵活地在运行时远程执行计算的应用程序。对于 ray.remote 装饰的类或函数,您还可以使用正文上的 .bind 来构建静态计算图。

注意

Ray DAG 被设计为一个面向开发者的 API,推荐用例包括:

  1. 在本地迭代和测试由更高级别库编写的应用程序。

  2. 在 Ray DAG API 之上构建库。

注意

Ray 引入了一个实验性 API,用于高性能工作负载,尤其适用于使用多个 GPU 的应用程序。该 API 构建在 Ray DAG API 之上。

有关更多详细信息,请参阅 Ray Compiled Graph

当在 ray.remote 装饰的类或函数上调用 .bind() 时,它会生成一个中间表示 (IR) 节点,该节点充当 DAG 的骨干和构建块,静态地将计算图结合在一起,其中每个 IR 节点将在执行时根据其拓扑顺序解析为值。

IR 节点也可以赋值给变量并作为参数传递给其他节点。

使用函数进行 Ray DAG#

ray.remote 装饰的函数上调用 .bind() 生成的 IR 节点将在执行时作为 Ray 任务执行,并解析为任务输出。

本示例展示了如何构建一个函数链,其中每个节点都可以作为根节点进行迭代执行,或者用作其他函数的输入参数或关键字参数,以形成更复杂的 DAG。

任何 IR 节点都可以直接执行 dag_node.execute(),充当 DAG 的根节点,所有其他从根节点无法访问的节点都将被忽略。

import ray

ray.init()

@ray.remote
def func(src, inc=1):
    return src + inc

a_ref = func.bind(1, inc=2)
assert ray.get(a_ref.execute()) == 3 # 1 + 2 = 3
b_ref = func.bind(a_ref, inc=3)
assert ray.get(b_ref.execute()) == 6 # (1 + 2) + 3 = 6
c_ref = func.bind(b_ref, inc=a_ref)
assert ray.get(c_ref.execute()) == 9 # ((1 + 2) + 3) + (1 + 2) = 9

使用类和类方法进行 Ray DAG#

ray.remote 装饰的类上调用 .bind() 生成的 IR 节点将在执行时作为 Ray Actor 执行。Actor 将在每次执行节点时实例化,并且类方法调用可以形成特定于父 Actor 实例的函数调用链。

从函数、类或类方法生成的 DAG IR 节点可以组合在一起形成一个 DAG。

import ray

ray.init()

@ray.remote
class Actor:
    def __init__(self, init_value):
        self.i = init_value

    def inc(self, x):
        self.i += x

    def get(self):
        return self.i

a1 = Actor.bind(10)  # Instantiate Actor with init_value 10.
val = a1.get.bind()  # ClassMethod that returns value from get() from
                     # the actor created.
assert ray.get(val.execute()) == 10

@ray.remote
def combine(x, y):
    return x + y

a2 = Actor.bind(10) # Instantiate another Actor with init_value 10.
a1.inc.bind(2)  # Call inc() on the actor created with increment of 2.
a1.inc.bind(4)  # Call inc() on the actor created with increment of 4.
a2.inc.bind(6)  # Call inc() on the actor created with increment of 6.

# Combine outputs from a1.get() and a2.get()
dag = combine.bind(a1.get.bind(), a2.get.bind())

# a1 +  a2 + inc(2) + inc(4) + inc(6)
# 10 + (10 + ( 2   +    4    +   6)) = 32
assert ray.get(dag.execute()) == 32

使用自定义 InputNode 进行 Ray DAG#

InputNode 是 DAG 的单例节点,代表运行时用户输入值。它应该在不带参数的上下文管理器中使用,并作为 dag_node.execute() 的参数调用。

import ray

ray.init()

from ray.dag.input_node import InputNode

@ray.remote
def a(user_input):
    return user_input * 2

@ray.remote
def b(user_input):
    return user_input + 1

@ray.remote
def c(x, y):
    return x + y

with InputNode() as dag_input:
    a_ref = a.bind(dag_input)
    b_ref = b.bind(dag_input)
    dag = c.bind(a_ref, b_ref)

#   a(2)  +   b(2)  = c
# (2 * 2) + (2 + 1)
assert ray.get(dag.execute(2)) == 7

#   a(3)  +   b(3)  = c
# (3 * 2) + (3 + 1)
assert ray.get(dag.execute(3)) == 10

使用多个 MultiOutputNode 进行 Ray DAG#

MultiOutputNode 在 DAG 有多个输出时非常有用。dag_node.execute() 返回传递给 MultiOutputNode 的 Ray 对象引用的列表。下面的示例展示了具有 2 个输出的多输出节点。

import ray

from ray.dag.input_node import InputNode
from ray.dag.output_node import MultiOutputNode

@ray.remote
def f(input):
    return input + 1

with InputNode() as input_data:
    dag = MultiOutputNode([f.bind(input_data["x"]), f.bind(input_data["y"])])

refs = dag.execute({"x": 1, "y": 2})
assert ray.get(refs) == [2, 3]

在 DAG 中重用 Ray Actor#

Actor 可以通过 Actor.bind() API 成为 DAG 定义的一部分。但是,当 DAG 执行完成时,Ray 会销毁使用 bind 创建的 Actor。

您可以通过使用 Actor.remote() 创建 Actor 来避免在 DAG 完成时销毁 Actor。

import ray
from ray.dag.input_node import InputNode
from ray.dag.output_node import MultiOutputNode

@ray.remote
class Worker:
    def __init__(self):
        self.forwarded = 0

    def forward(self, input_data: int):
        self.forwarded += 1
        return input_data + 1

    def num_forwarded(self):
        return self.forwarded

# Create an actor via ``remote`` API not ``bind`` API to avoid
# killing actors when a DAG is finished.
worker = Worker.remote()

with InputNode() as input_data:
    dag = MultiOutputNode([worker.forward.bind(input_data)])

# Actors are reused. The DAG definition doesn't include
# actor creation.
assert ray.get(dag.execute(1)) == [2]
assert ray.get(dag.execute(2)) == [3]
assert ray.get(dag.execute(3)) == [4]

# You can still use other actor methods via `remote` API.
assert ray.get(worker.num_forwarded.remote()) == 3

更多资源#

您可以在以下资源中找到更多应用程序模式和示例,这些资源来自基于 Ray DAG API 构建的其他 Ray 库,采用相同的机制。