Ray Direct Transport (RDT)#

Ray 对象通常存储在 Ray 的基于 CPU 的对象存储中,并在 Ray 任务或 actor 访问时进行复制和反序列化。对于 GPU 数据,这可能导致不必要且昂贵的数据传输。例如,将一个 CUDA torch.Tensor 从一个 Ray 任务传递到另一个任务,需要先从 GPU 复制到 CPU 内存,然后再复制回 GPU 内存。

Ray Direct Transport (RDT) 是一项新功能,它允许 Ray 直接在 Ray actor 之间存储和传递对象。此功能通过以下方式增强了熟悉的 Ray ObjectRef API:

  • 在必要传输之前将 GPU 数据保留在 GPU 内存中

  • 避免将数据复制到 Ray 对象存储并从中反序列化(这会产生高昂的成本)

  • 使用高效的数据传输方式,例如集体通信库(GlooNCCL)或点对点 RDMA(通过 NVIDIA 的 NIXL),直接在包括 CPU 和 GPU 在内的设备之间传输数据。

注意

RDT 目前处于 **alpha** 阶段,尚未支持所有 Ray Core API。未来版本可能会引入破坏性的 API 更改。有关更多详细信息,请参阅 局限性 部分。

入门#

提示

RDT 目前支持由 Ray actor 任务创建的 torch.Tensor 对象。未来的版本可能会支持其他数据类型和 Ray 非 actor 任务。

本教程将展示如何使用不同的张量传输(即,用于在 actor 之间传输张量的机制)来创建和使用 RDT。目前,RDT 支持以下张量传输:

  1. Gloo:一个用于 PyTorch 和 CPU 的集体通信库。

  2. NVIDIA NCCL:一个用于 NVIDIA GPU 的集体通信库。

  3. NVIDIA NIXL(由 UCX 支持):一个用于加速点对点传输(通过 RDMA)的库,尤其是在各种类型的内存和 NVIDIA GPU 之间。

为了便于跟随,我们将从 Gloo 传输开始,该传输无需任何物理 GPU 即可使用。

与 Gloo 的用法(仅限 CPU)#

安装#

注意

正在建设中。

教程#

首先,定义一个 actor 类和一个返回 torch.Tensor 的任务。

import torch
import ray


@ray.remote
class MyActor:
    def random_tensor(self):
        return torch.randn(1000, 1000)


按照当前写法,当 torch.Tensor 返回时,它将被复制到 Ray 的基于 CPU 的对象存储中。对于基于 CPU 的张量,这可能需要一个昂贵的步骤来复制和序列化对象,而基于 GPU 的张量还需要复制到 CPU 内存和从中复制出来。

要启用 RDT,请在 @ray.method 装饰器中使用 tensor_transport 选项。

@ray.remote
class MyActor:
    @ray.method(tensor_transport="gloo")
    def random_tensor(self):
        return torch.randn(1000, 1000)


此装饰器可以添加到任何返回 torch.Tensor 的 actor 任务,或者返回嵌套在其他 Python 对象中的 torch.Tensors 的 actor 任务。添加此装饰器将改变 Ray 的行为,如下所示:

  1. 返回张量时,Ray 将存储张量的引用,而不是将其复制到 CPU 内存。

  2. ray.ObjectRef 传递给另一个任务时,Ray 将使用 Gloo 将张量传输到目标任务。

请注意,对于 (2) 的工作,@ray.method(tensor_transport) 装饰器只需要添加到返回张量的 actor 任务上。不应将其添加到消耗张量的 actor 任务(除非这些任务也返回张量)。

另外,为了让 (2) 工作,我们必须先创建一个 actor 的集体组

创建集体组#

要创建用于 RDT 的集体组:

  1. 创建多个 Ray actor。

  2. 使用 ray.experimental.collective.create_collective_group 函数在 actor 上创建集体组。指定的 backend 必须与 @ray.method 装饰器中使用的 tensor_transport 匹配。

以下是一个示例:

import torch
import ray
from ray.experimental.collective import create_collective_group


@ray.remote
class MyActor:
    @ray.method(tensor_transport="gloo")
    def random_tensor(self):
        return torch.randn(1000, 1000)

    def sum(self, tensor: torch.Tensor):
        return torch.sum(tensor)


sender, receiver = MyActor.remote(), MyActor.remote()
# The tensor_transport specified here must match the one used in the @ray.method
# decorator.
group = create_collective_group([sender, receiver], backend="torch_gloo")

现在 actor 可以通过 gloo 直接通信。还可以使用 ray.experimental.collective.destroy_collective_group 函数销毁该组。调用此函数后,可以在同一 actor 上创建新的集体组。

将对象传递给其他 actor#

既然我们有了一个集体组,就可以创建 RDT 对象并在 actor 之间传递。这是一个完整的示例:

import torch
import ray
from ray.experimental.collective import create_collective_group


@ray.remote
class MyActor:
    @ray.method(tensor_transport="gloo")
    def random_tensor(self):
        return torch.randn(1000, 1000)

    def sum(self, tensor: torch.Tensor):
        return torch.sum(tensor)


sender, receiver = MyActor.remote(), MyActor.remote()
group = create_collective_group([sender, receiver], backend="torch_gloo")

# The tensor will be stored by the `sender` actor instead of in Ray's object
# store.
tensor = sender.random_tensor.remote()
result = receiver.sum.remote(tensor)
print(ray.get(result))

ray.ObjectRef 传递给另一个任务时,Ray 将使用 Gloo 将张量直接从源 actor 传输到目标 actor,而不是默认的对象存储。请注意,@ray.method(tensor_transport) 装饰器仅添加到返回张量的 actor 任务上;添加此提示后,接收 actor 任务 receiver.sum 将自动使用 Gloo 接收张量。在此示例中,由于 MyActor.sum 没有 @ray.method(tensor_transport) 装饰器,它将使用默认的 Ray 对象存储传输来返回 torch.sum(tensor)

RDT 还支持在 Python 数据结构嵌套中传递张量,以及 actor 任务返回多个张量,如下例所示:

import torch
import ray
from ray.experimental.collective import create_collective_group


@ray.remote
class MyActor:
    @ray.method(tensor_transport="gloo")
    def random_tensor_dict(self):
        return {"tensor1": torch.randn(1000, 1000), "tensor2": torch.randn(1000, 1000)}

    def sum(self, tensor_dict: dict):
        return torch.sum(tensor_dict["tensor1"]) + torch.sum(tensor_dict["tensor2"])


sender, receiver = MyActor.remote(), MyActor.remote()
group = create_collective_group([sender, receiver], backend="torch_gloo")

# Both tensor values in the dictionary will be stored by the `sender` actor
# instead of in Ray's object store.
tensor_dict = sender.random_tensor_dict.remote()
result = receiver.sum.remote(tensor_dict)
print(ray.get(result))

将 RDT 对象传递给生成它们的 actor#

RDT ray.ObjectRefs 也可以传递给生成它们的 actor。这避免了任何复制,只是提供了对之前创建的 torch.Tensor 的引用。例如:

import torch
import ray
import pytest
from ray.experimental.collective import create_collective_group


@ray.remote
class MyActor:
    @ray.method(tensor_transport="gloo")
    def random_tensor(self):
        return torch.randn(1000, 1000)

    def sum(self, tensor: torch.Tensor):
        return torch.sum(tensor)


sender, receiver = MyActor.remote(), MyActor.remote()
group = create_collective_group([sender, receiver], backend="torch_gloo")

tensor = sender.random_tensor.remote()
# Pass the ObjectRef back to the actor that produced it. The tensor will be
# passed back to the same actor without copying.
sum1 = sender.sum.remote(tensor)
sum2 = receiver.sum.remote(tensor)
assert torch.allclose(*ray.get([sum1, sum2]))

注意

Ray 只保留用户创建的张量的引用,因此张量对象是可变的。如果上面的示例中 sender.sum 修改了张量,那么 receiver.sum 也会看到更改。这与标准的 Ray Core API 不同,后者总是为 actor 返回的数据创建一个不可变的副本。

ray.get#

与往常一样,也可以使用 ray.get 函数来检索 RDT 对象的结果。但是,ray.get 默认将使用与 @ray.method 装饰器中指定的相同的张量传输。对于基于集体的传输,如果调用者不是集体组成员,这将不起作用。

因此,用户需要通过在 ray.get 中设置 _tensor_transport 来显式指定 Ray 对象存储作为张量传输。


# Wrong example of ray.get(). Since the tensor transport in the @ray.method decorator is Gloo,
# ray.get() will try to use Gloo to fetch the tensor, which is not supported
# because the caller is not part of the collective group.
with pytest.raises(ValueError) as e:
    ray.get(tensor)

assert (
    "Trying to use two-sided tensor transport: GLOO for ray.get. This is only supported for one-sided transports such as NIXL or the OBJECT_STORE."
    in str(e.value)
)

# Correct example of ray.get(), explicitly setting the tensor transport to use the Ray object store.
print(ray.get(tensor, _tensor_transport="object_store"))
# torch.Tensor(...)

对象可变性#

与 Ray 对象存储中的对象不同,RDT 对象是可变的,这意味着 Ray 只保留对张量的引用,直到请求传输才会复制。因此,如果返回张量的 actor 也保留对张量的引用,并在 Ray 仍存储张量引用时对其进行原地修改,则接收 actor 可能会看到部分或全部更改。

以下是可能出错的一个例子:

import torch
import ray
from ray.experimental.collective import create_collective_group


@ray.remote
class MyActor:
    @ray.method(tensor_transport="gloo")
    def random_tensor(self):
        self.tensor = torch.randn(1000, 1000)
        # After this function returns, Ray and this actor will both hold a
        # reference to the same tensor.
        return self.tensor

    def increment_and_sum_stored_tensor(self):
        # NOTE: In-place update, while Ray still holds a reference to the same tensor.
        self.tensor += 1
        return torch.sum(self.tensor)

    def increment_and_sum(self, tensor: torch.Tensor):
        return torch.sum(tensor + 1)


sender, receiver = MyActor.remote(), MyActor.remote()
group = create_collective_group([sender, receiver], backend="torch_gloo")

tensor = sender.random_tensor.remote()
tensor1 = sender.increment_and_sum_stored_tensor.remote()
# Wait for sender.increment_and_sum_stored_tensor task to finish.
tensor1 = ray.get(tensor1)
# Receiver will now receive the updated value instead of the original.
tensor2 = receiver.increment_and_sum.remote(tensor)

try:
    # This assertion will fail because sender.increment_and_sum_stored_tensor
    # modified the tensor in place before sending it to
    # receiver.increment_and_sum.
    assert torch.allclose(tensor1, ray.get(tensor2))
except AssertionError:
    print("AssertionError: sender and receiver returned different sums.")

在此示例中,发送 actor 将一个张量返回给 Ray,但它也在其本地状态中保留了对该张量的引用。然后,在 sender.increment_and_sum_stored_tensor 中,发送 actor 在 Ray 仍持有张量引用时,原地修改了张量。然后,receiver.increment_and_sum 任务接收到的是修改后的张量,而不是原始张量,因此断言失败。

为了修复此类错误,请使用 ray.experimental.wait_tensor_freed 函数等待 Ray 释放对张量的所有引用,以便 actor 可以安全地再次写入张量。wait_tensor_freed 将在所有依赖于该张量的任务执行完毕并且所有相应的 ObjectRefs 都超出范围后解除阻塞。Ray 通过跟踪哪些任务将与张量对应的 ObjectRef 作为参数来跟踪依赖于该张量的任务。

这是前面示例的修复版本。

import torch
import ray
from ray.experimental.collective import create_collective_group


@ray.remote
class MyActor:
    @ray.method(tensor_transport="gloo")
    def random_tensor(self):
        self.tensor = torch.randn(1000, 1000)
        return self.tensor

    def increment_and_sum_stored_tensor(self):
        # 1. Sender actor waits for Ray to release all references to the tensor
        # before modifying the tensor in place.
        ray.experimental.wait_tensor_freed(self.tensor)
        # NOTE: In-place update, but Ray guarantees that it has already released
        # its references to this tensor.
        self.tensor += 1
        return torch.sum(self.tensor)

    def increment_and_sum(self, tensor: torch.Tensor):
        # Receiver task remains the same.
        return torch.sum(tensor + 1)


sender, receiver = MyActor.remote(), MyActor.remote()
group = create_collective_group([sender, receiver], backend="torch_gloo")

tensor = sender.random_tensor.remote()
tensor1 = sender.increment_and_sum_stored_tensor.remote()
# 2. Skip `ray.get`` because `wait_tensor_freed`` will block until all
# references to `tensor` are freed, so calling `ray.get` here would cause a
# deadlock.
# tensor1 = ray.get(tensor1)
tensor2 = receiver.increment_and_sum.remote(tensor)

# 3. Delete all references to `tensor`, to unblock wait_tensor_freed.
del tensor

# This assertion will now pass.
assert torch.allclose(ray.get(tensor1), ray.get(tensor2))

主要更改包括:1. sender 在原地修改张量之前调用 wait_tensor_freed。2. Driver 跳过 ray.get,因为 wait_tensor_freed 会阻塞直到所有指向张量的 ObjectRefs 被释放,因此在此处调用 ray.get 会导致死锁。3. Driver 调用 del tensor 来释放其对张量的引用。同样,这是必需的,因为 wait_tensor_freed 会阻塞直到所有指向张量的 ObjectRefs 被释放。

当 RDT ObjectRef 被传递回生成它的 actor 时,Ray 会传递张量的引用而不是副本。因此,可能会发生相同的 bug。为了帮助捕获这种情况,如果 RDT 对象被传递给生成它的 actor 和另一个 actor,Ray 会打印一个警告,如下所示:

import torch
import ray
from ray.experimental.collective import create_collective_group


@ray.remote
class MyActor:
    @ray.method(tensor_transport="gloo")
    def random_tensor(self):
        return torch.randn(1000, 1000)

    def increment_and_sum(self, tensor: torch.Tensor):
        # In-place update.
        tensor += 1
        return torch.sum(tensor)


sender, receiver = MyActor.remote(), MyActor.remote()
group = create_collective_group([sender, receiver], backend="torch_gloo")

tensor = sender.random_tensor.remote()
tensor1 = sender.increment_and_sum.remote(tensor)
tensor2 = receiver.increment_and_sum.remote(tensor)
# A warning will be printed:
# UserWarning: GPU ObjectRef(...) is being passed back to the actor that created it Actor(MyActor, ...). Note that GPU objects are mutable. If the tensor is modified, Ray's internal copy will also be updated, and subsequent passes to other actors will receive the updated version instead of the original.

try:
    # This assertion may fail because the tensor returned by sender.random_tensor
    # is modified in-place by sender.increment_and_sum while being sent to
    # receiver.increment_and_sum.
    assert torch.allclose(ray.get(tensor1), ray.get(tensor2))
except AssertionError:
    print("AssertionError: sender and receiver returned different sums.")

与 NCCL 的用法(仅限 NVIDIA GPU)#

RDT 只需要少量代码更改即可切换张量传输。这是 Gloo 示例,已修改为使用 NVIDIA GPU 和 NCCL 库进行集体 GPU 通信。

import torch
import ray
from ray.experimental.collective import create_collective_group


@ray.remote(num_gpus=1)
class MyActor:
    @ray.method(tensor_transport="nccl")
    def random_tensor(self):
        return torch.randn(1000, 1000).cuda()

    def sum(self, tensor: torch.Tensor):
        return torch.sum(tensor)


sender, receiver = MyActor.remote(), MyActor.remote()
group = create_collective_group([sender, receiver], backend="nccl")

# The tensor will be stored by the `sender` actor instead of in Ray's object
# store.
tensor = sender.random_tensor.remote()
result = receiver.sum.remote(tensor)
ray.get(result)

主要的代码差异是:

  1. @ray.method 使用 tensor_transport="nccl" 而不是 tensor_transport="gloo"

  2. 使用 ray.experimental.collective.create_collective_group 函数来创建集体组。

  3. 使用 .cuda() 方法在 GPU 上创建张量。

与 NIXL 的用法(CPU 或 NVIDIA GPU)#

安装#

为获得最佳性能,请运行 install_gdrcopy.sh 脚本(例如,install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64")。您可以在 此处找到可用的操作系统版本。如果未安装 gdrcopy,情况仍将与普通 pip install nixl 一样,只是性能较低。nixlucx 是通过 pip 作为依赖项安装的。

教程#

NIXL 可以在不同设备之间传输数据,包括 CPU 和 NVIDIA GPU,但不需要提前创建集体组。这意味着任何在其环境中安装了 NIXL 的 actor 都可以用于创建和传递 RDT 对象。

否则,用法与 Gloo 示例 中的相同。

这是一个关于如何使用 NIXL 在两个 actor 之间传输 RDT 对象的示例:

import torch
import ray


@ray.remote(num_gpus=1)
class MyActor:
    @ray.method(tensor_transport="nixl")
    def random_tensor(self):
        return torch.randn(1000, 1000).cuda()

    def sum(self, tensor: torch.Tensor):
        return torch.sum(tensor)

    def produce(self, tensors):
        refs = []
        for t in tensors:
            refs.append(ray.put(t, _tensor_transport="nixl"))
        return refs

    def consume_with_nixl(self, refs):
        # ray.get will also use NIXL to retrieve the
        # result.
        tensors = [ray.get(ref) for ref in refs]
        sum = 0
        for t in tensors:
            assert t.device.type == "cuda"
            sum += t.sum().item()
        return sum


# No collective group is needed. The two actors just need to have NIXL
# installed.
sender, receiver = MyActor.remote(), MyActor.remote()

# The tensor will be stored by the `sender` actor instead of in Ray's object
# store.
tensor = sender.random_tensor.remote()
result = receiver.sum.remote(tensor)
ray.get(result)

Gloo 示例相比,主要的代码差异是:

  1. @ray.method 使用 tensor_transport="nixl" 而不是 tensor_transport="gloo"

  2. 不需要集体组。

ray.put 和 ray.get 与 NIXL 一起使用#

与基于集体的张量传输(Gloo 和 NCCL)不同,ray.get 函数可以使用 NIXL 来检索结果的副本。默认情况下,ray.get 的张量传输将是 @ray.method 装饰器中指定的。

# ray.get will also use NIXL to retrieve the
# result.
print(ray.get(tensor))
# torch.Tensor(...)

您也可以使用 NIXL 从 ray.put 创建的引用中检索结果。

tensor1 = torch.randn(1000, 1000).cuda()
tensor2 = torch.randn(1000, 1000).cuda()
refs = sender.produce.remote([tensor1, tensor2])
ref1 = receiver.consume_with_nixl.remote(refs)
print(ray.get(ref1))

总结#

RDT 允许 Ray 直接在 Ray actor 之间存储和传递对象,使用 Gloo、NCCL 和 NIXL 等加速传输。以下是需要注意的主要几点:

  • 如果使用基于集体的张量传输(Gloo 或 NCCL),则必须提前创建集体组。NIXL 只要求所有涉及的 actor 都安装了 NIXL。

  • 与 Ray 对象存储中的对象不同,RDT 对象是可变的,这意味着 Ray 只持有对存储的张量的引用,而不是副本。

  • 否则,actor 可以像往常一样使用。

有关完整的局限性列表,请参阅 局限性 部分。

微基准测试#

注意

正在建设中。

限制#

RDT 目前处于 alpha 阶段,目前存在以下局限性,未来版本可能会解决这些问题:

  • 仅支持 torch.Tensor 对象。

  • 仅支持 Ray actor,不支持 Ray 任务。

  • 尚未与 asyncio 兼容。请关注 跟踪问题以获取更新。

  • 支持以下传输:Gloo、NCCL 和 NIXL。

  • 仅支持 CPU 和 NVIDIA GPU。

  • RDT 对象是可变的。这意味着 Ray 只持有对张量的引用,并且在请求传输之前不会复制它。因此,如果应用程序代码在返回张量之前也保留了对张量的引用,并在原地修改了张量,那么接收 actor 可能会看到部分或全部更改。

对于基于集体的张量传输(Gloo 和 NCCL):

  • 只有创建集体组的进程才能提交返回和传递 RDT 对象的 actor 任务。如果创建进程将 actor 句柄传递给其他进程,则这些进程可以像往常一样提交 actor 任务,但无法使用 RDT 对象。

  • 同样,创建集体组的进程无法序列化并将其 RDT ray.ObjectRefs 传递给其他 Ray 任务或 actor。相反,ray.ObjectRefs 只能作为直接参数传递给其他 actor 任务,并且这些 actor 必须在同一个集体组中。

  • 每个 actor 在同一时间只能属于一个集体组(每个张量传输)。

  • 不支持 ray.put

由于一个已知问题,对于 NIXL,我们目前不支持在同一 actor 中存储不同的 GPU 对象,这些对象包含重叠但不完全相同的张量集。要支持此模式,请确保第一个 ObjectRef 超出范围,然后再将相同的张量存储在第二个对象中。

from ray.exceptions import ActorDiedError

@ray.remote(num_gpus=1)
class Actor:
    def __init__(self):
        self.tensor1 = torch.tensor([1, 2, 3])
        self.tensor2 = torch.tensor([4, 5, 6])
        self.tensor3 = torch.tensor([7, 8, 9])

    @ray.method(tensor_transport="nixl")
    def send_dict1(self):
        return {"round1-1": self.tensor1, "round1-2": self.tensor2}

    @ray.method(tensor_transport="nixl")
    def send_dict2(self):
        return {"round2-1": self.tensor1, "round2-3": self.tensor3}

    def sum_dict(self, dict):
        return sum(v.sum().item() for v in dict.values())


sender, receiver = Actor.remote(), Actor.remote()
ref1 = sender.send_dict1.remote()
result1 = receiver.sum_dict.remote(ref1)
print(ray.get(result1))
ref2 = sender.send_dict2.remote()
result2 = receiver.sum_dict.remote(ref2)
try:
    print(ray.get(result2))
except ValueError as e:
    print("Error caught:", e)

错误处理#

  • 应用程序级别的错误,即用户代码引发的异常,不会销毁集体组,而是会像非 RDT Ray 对象一样传播到任何依赖的任务。

  • 如果 GLOO 或 NCCL 集体操作期间发生系统级别错误,集体组将被销毁,actor 将被杀死以防止挂起。

  • 如果 NIXL 传输期间发生系统级别错误,Ray 或 NIXL 将中止传输并抛出异常,Ray 将在依赖任务中或在 NIXL ref 的 ray.get 上抛出该异常。

  • 系统级别错误包括:
    • 第三方传输内部错误,例如 NCCL 网络错误。

    • Actor 或节点故障。

    • 由于张量设备/传输不匹配导致的传输错误,例如,在使用 NCCL 时是 CPU 张量。

    • Ray RDT 对象获取超时(可以通过设置 RAY_rdt_fetch_fail_timeout_milliseconds 环境变量来覆盖)。

    • 任何意外的系统 bug。

高级:RDT 内部机制#

注意

正在建设中。