反模式:不必要地调用 ray.get 会损害性能#
TLDR: 避免对中间步骤不必要地调用 ray.get()
。直接使用对象引用,仅在最后调用 ray.get()
获取最终结果。
调用 ray.get()
时,对象必须传输到调用 ray.get()
的 worker/节点。如果你不需要操作该对象,你可能不需要对其调用 ray.get()
!
通常,最佳实践是尽可能晚地调用 ray.get()
,甚至设计程序来完全避免调用 ray.get()
。
代码示例#
反模式
import ray
import numpy as np
ray.init()
@ray.remote
def generate_rollout():
return np.ones((10000, 10000))
@ray.remote
def reduce(rollout):
return np.sum(rollout)
# `ray.get()` downloads the result here.
rollout = ray.get(generate_rollout.remote())
# Now we have to reupload `rollout`
reduced = ray.get(reduce.remote(rollout))
更好的方法
# Don't need ray.get here.
rollout_obj_ref = generate_rollout.remote()
# Rollout object is passed by reference.
reduced = ray.get(reduce.remote(rollout_obj_ref))
注意在反模式示例中,我们调用了 ray.get()
,这迫使我们将大型 rollout 传输到驱动器,然后再传输到 reduce worker。
在改进版本中,我们只将对象引用传递给 reduce 任务。reduce
worker 将隐式调用 ray.get()
直接从 generate_rollout
worker 获取实际的 rollout 数据,避免了额外复制到驱动器。
其他与 ray.get()
相关的反模式包括