使用 Ray DAG API 的惰性计算图#

使用 ray.remote,您可以灵活地运行计算在运行时远程执行的应用。对于用 ray.remote 装饰的类或函数,您还可以在其主体上使用 .bind 来构建静态计算图。

注意

Ray DAG 被设计为一个面向开发者的 API,推荐的使用场景包括:

  1. 在本地迭代和测试由更高级别库编写的应用。

  2. 在 Ray DAG API 之上构建库。

注意

Ray 引入了一个实验性 API,用于处理高性能工作负载,尤其适用于使用多个 GPU 的应用。此 API 构建在 Ray DAG API 之上。

有关更多详细信息,请参见 Ray 编译图

当在用 ray.remote 装饰的类或函数上调用 .bind() 时,它会生成一个中间表示 (IR) 节点,该节点作为 DAG 的骨干和构建块,静态地将计算图连接在一起。每个 IR 节点在执行时会根据其拓扑顺序被解析为值。

IR 节点也可以赋值给变量,并作为参数传递给其他节点。

带有函数的 Ray DAG#

在用 ray.remote 装饰的函数上通过 .bind() 生成的 IR 节点在执行时会作为 Ray Task 运行,并解析为任务输出。

此示例展示了如何构建一个函数链,其中每个节点在迭代时可以作为根节点执行,或作为其他函数的输入参数或关键字参数来形成更复杂的 DAG。

任何 IR 节点都可以通过 dag_node.execute() 直接执行,这相当于将该节点作为 DAG 的根节点执行,所有从根节点无法到达的其他节点都将被忽略。

import ray

ray.init()

@ray.remote
def func(src, inc=1):
    return src + inc

a_ref = func.bind(1, inc=2)
assert ray.get(a_ref.execute()) == 3 # 1 + 2 = 3
b_ref = func.bind(a_ref, inc=3)
assert ray.get(b_ref.execute()) == 6 # (1 + 2) + 3 = 6
c_ref = func.bind(b_ref, inc=a_ref)
assert ray.get(c_ref.execute()) == 9 # ((1 + 2) + 3) + (1 + 2) = 9

带有类和类方法的 Ray DAG#

在用 ray.remote 装饰的类上通过 .bind() 生成的 IR 节点在执行时会作为 Ray Actor 运行。每次执行该节点时都会实例化一个 Actor,并且类方法的调用可以形成特定于父 Actor 实例的函数调用链。

从函数、类或类方法生成的 DAG IR 节点可以组合起来形成一个 DAG。

import ray

ray.init()

@ray.remote
class Actor:
    def __init__(self, init_value):
        self.i = init_value

    def inc(self, x):
        self.i += x

    def get(self):
        return self.i

a1 = Actor.bind(10)  # Instantiate Actor with init_value 10.
val = a1.get.bind()  # ClassMethod that returns value from get() from
                     # the actor created.
assert ray.get(val.execute()) == 10

@ray.remote
def combine(x, y):
    return x + y

a2 = Actor.bind(10) # Instantiate another Actor with init_value 10.
a1.inc.bind(2)  # Call inc() on the actor created with increment of 2.
a1.inc.bind(4)  # Call inc() on the actor created with increment of 4.
a2.inc.bind(6)  # Call inc() on the actor created with increment of 6.

# Combine outputs from a1.get() and a2.get()
dag = combine.bind(a1.get.bind(), a2.get.bind())

# a1 +  a2 + inc(2) + inc(4) + inc(6)
# 10 + (10 + ( 2   +    4    +   6)) = 32
assert ray.get(dag.execute()) == 32

带有自定义 InputNode 的 Ray DAG#

InputNode 是 DAG 中的单例节点,代表运行时的用户输入值。它应在没有参数的上下文管理器中使用,并作为 dag_node.execute() 的参数调用。

import ray

ray.init()

from ray.dag.input_node import InputNode

@ray.remote
def a(user_input):
    return user_input * 2

@ray.remote
def b(user_input):
    return user_input + 1

@ray.remote
def c(x, y):
    return x + y

with InputNode() as dag_input:
    a_ref = a.bind(dag_input)
    b_ref = b.bind(dag_input)
    dag = c.bind(a_ref, b_ref)

#   a(2)  +   b(2)  = c
# (2 * 2) + (2 + 1)
assert ray.get(dag.execute(2)) == 7

#   a(3)  +   b(3)  = c
# (3 * 2) + (3 + 1)
assert ray.get(dag.execute(3)) == 10

带有多个 MultiOutputNode 的 Ray DAG#

当您的 DAG 有多个输出时,MultiOutputNode 非常有用。dag_node.execute() 返回传递给 MultiOutputNode 的 Ray object reference 列表。下面的示例展示了一个有 2 个输出的多输出节点。

import ray

from ray.dag.input_node import InputNode
from ray.dag.output_node import MultiOutputNode

@ray.remote
def f(input):
    return input + 1

with InputNode() as input_data:
    dag = MultiOutputNode([f.bind(input_data["x"]), f.bind(input_data["y"])])

refs = dag.execute({"x": 1, "y": 2})
assert ray.get(refs) == [2, 3]

在 DAG 中重用 Ray Actor#

Actor 可以通过 Actor.bind() API 成为 DAG 定义的一部分。但是,当 DAG 执行完成后,Ray 会杀死通过 bind 创建的 Actor。

您可以通过使用 Actor.remote() 创建 Actor 来避免在 DAG 完成时杀死它们。

import ray
from ray.dag.input_node import InputNode
from ray.dag.output_node import MultiOutputNode

@ray.remote
class Worker:
    def __init__(self):
        self.forwarded = 0

    def forward(self, input_data: int):
        self.forwarded += 1
        return input_data + 1

    def num_forwarded(self):
        return self.forwarded

# Create an actor via ``remote`` API not ``bind`` API to avoid
# killing actors when a DAG is finished.
worker = Worker.remote()

with InputNode() as input_data:
    dag = MultiOutputNode([worker.forward.bind(input_data)])

# Actors are reused. The DAG definition doesn't include
# actor creation.
assert ray.get(dag.execute(1)) == [2]
assert ray.get(dag.execute(2)) == [3]
assert ray.get(dag.execute(3)) == [4]

# You can still use other actor methods via `remote` API.
assert ray.get(worker.num_forwarded.remote()) == 3

更多资源#

您可以在基于 Ray DAG API 构建的其他 Ray 库的以下资源中找到更多应用模式和示例,这些库使用了相同的机制。