反模式:不必要地调用 ray.get 会损害性能#
简而言之: 避免在中间步骤中不必要地调用 ray.get()。直接使用对象引用进行操作,仅在最后调用 ray.get() 来获取最终结果。
调用 ray.get() 时,对象必须传输到调用 ray.get() 的 worker/节点。如果您不需要操作该对象,那么很可能不需要在它上面调用 ray.get()!
通常,最佳实践是尽可能延迟调用 ray.get(),或者甚至设计您的程序以完全避免调用 ray.get()。
代码示例#
反模式
import ray
import numpy as np
ray.init()
@ray.remote
def generate_rollout():
return np.ones((10000, 10000))
@ray.remote
def reduce(rollout):
return np.sum(rollout)
# `ray.get()` downloads the result here.
rollout = ray.get(generate_rollout.remote())
# Now we have to reupload `rollout`
reduced = ray.get(reduce.remote(rollout))
更好的方法
# Don't need ray.get here.
rollout_obj_ref = generate_rollout.remote()
# Rollout object is passed by reference.
reduced = ray.get(reduce.remote(rollout_obj_ref))
在反模式示例中,我们调用了 ray.get(),这强制我们将大型 rollout 传输到 driver,然后再传输到 _reduce_ worker。
在修复后的版本中,我们仅将对象引用传递给 _reduce_ 任务。_reduce_ worker 将隐式调用 ray.get(),直接从 generate_rollout worker 获取实际的 rollout 数据,避免了到 driver 的额外复制。
其他与 ray.get() 相关的反模式有