反模式:不必要地调用 ray.get 会损害性能#

TLDR: 避免对中间步骤不必要地调用 ray.get()。直接使用对象引用,仅在最后调用 ray.get() 获取最终结果。

调用 ray.get() 时,对象必须传输到调用 ray.get() 的 worker/节点。如果你不需要操作该对象,你可能不需要对其调用 ray.get()

通常,最佳实践是尽可能晚地调用 ray.get(),甚至设计程序来完全避免调用 ray.get()

代码示例#

反模式

import ray
import numpy as np

ray.init()


@ray.remote
def generate_rollout():
    return np.ones((10000, 10000))


@ray.remote
def reduce(rollout):
    return np.sum(rollout)


# `ray.get()` downloads the result here.
rollout = ray.get(generate_rollout.remote())
# Now we have to reupload `rollout`
reduced = ray.get(reduce.remote(rollout))
../../_images/unnecessary-ray-get-anti.svg

更好的方法

# Don't need ray.get here.
rollout_obj_ref = generate_rollout.remote()
# Rollout object is passed by reference.
reduced = ray.get(reduce.remote(rollout_obj_ref))
../../_images/unnecessary-ray-get-better.svg

注意在反模式示例中,我们调用了 ray.get(),这迫使我们将大型 rollout 传输到驱动器,然后再传输到 reduce worker。

在改进版本中,我们只将对象引用传递给 reduce 任务。reduce worker 将隐式调用 ray.get() 直接从 generate_rollout worker 获取实际的 rollout 数据,避免了额外复制到驱动器。

其他与 ray.get() 相关的反模式包括