Experimental: Overlap communication and computation

Compiled Graph currently provides experimental support for overlapping GPU communication and computation. When this feature is enabled, it automatically overlaps GPU communication with compute operations, hiding the communication overhead and improving performance.

To enable this feature, specify _overlap_gpu_communication=True when calling dag.experimental_compile().
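
For example, assuming dag is a Ray DAG that has already been constructed (as in the full example below), enabling overlap is a single argument to the compile call:

compiled_dag = dag.experimental_compile(_overlap_gpu_communication=True)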

The following code contains GPU communication and compute operations that benefit from overlapping.

import ray
import time
import torch
from ray.dag import InputNode, MultiOutputNode


@ray.remote(num_cpus=0, num_gpus=1)
class TorchTensorWorker:
    def send(self, shape, dtype, value: int, send_tensor=True):
        if not send_tensor:
            return 1
        return torch.ones(shape, dtype=dtype, device="cuda") * value

    def recv_and_matmul(self, two_d_tensor):
        """
        Receive the tensor and do some expensive computation (matmul).

        Args:
            two_d_tensor: A square 2D tensor (both dimensions have the same size).
        """
        # Check that the received tensor is a square 2D tensor as expected.
        assert two_d_tensor.dim() == 2
        assert two_d_tensor.size(0) == two_d_tensor.size(1)
        torch.matmul(two_d_tensor, two_d_tensor)
        return (two_d_tensor[0][0].item(), two_d_tensor.shape, two_d_tensor.dtype)


def test(overlap_gpu_communication):
    num_senders = 3
    senders = [TorchTensorWorker.remote() for _ in range(num_senders)]
    receiver = TorchTensorWorker.remote()

    shape = (10000, 10000)
    dtype = torch.float16

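    # Build a DAG in which each sender creates a large square tensor on its own
    # GPU, the tensor is transferred to the receiver over NCCL, and the
    # receiver runs a matmul on it. With overlap enabled, these transfers can
    # proceed concurrently with the receiver's computation.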
    with InputNode() as inp:
        branches = [sender.send.bind(shape, dtype, inp) for sender in senders]
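        # Transport each tensor with NCCL. _static_shape and _direct_return
        # hint that the tensor's shape/dtype stay the same across executions
        # and that the tensor is returned directly, enabling additional
        # optimizations.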
        branches = [
            branch.with_tensor_transport(
                transport="nccl", _static_shape=True, _direct_return=True
            )
            # For Ray versions before 2.42, use `with_type_hint()` instead.
            # branch.with_type_hint(
            #     TorchTensorType(
            #         transport="nccl", _static_shape=True, _direct_return=True
            #     )
            # )
            for branch in branches
        ]
        branches = [receiver.recv_and_matmul.bind(branch) for branch in branches]
        dag = MultiOutputNode(branches)

    compiled_dag = dag.experimental_compile(
        _overlap_gpu_communication=overlap_gpu_communication
    )

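    # Execute the compiled DAG several times and measure end-to-end latency.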
    start = time.monotonic()
    for i in range(5):
        ref = compiled_dag.execute(i)
        result = ray.get(ref)
        assert result == [(i, shape, dtype)] * num_senders
    duration = time.monotonic() - start
    print(f"{overlap_gpu_communication=}, {duration=}")
    compiled_dag.teardown(kill_actors=True)


for overlap_gpu_communication in [False, True]:
    test(overlap_gpu_communication)

The output of the preceding code includes the following two lines:

overlap_gpu_communication=False, duration=1.0670117866247892
overlap_gpu_communication=True, duration=0.9211348341777921

Actual performance numbers may vary across hardware, but for this example, enabling _overlap_gpu_communication reduces latency by about 14% ((1.067 - 0.921) / 1.067 ≈ 0.14).