Exemple #1
0
def test_xla_distrib_single_node_no_spawn():
    """Initialize the "xla-tpu" backend in-process and check the resulting config.

    No worker processes are spawned. The expected values handed to
    ``_test_distrib_config`` (a helper defined elsewhere) are:
    local rank 0, backend "xla-tpu", world size 1 (``ws``), device "xla".

    ``idist.finalize()`` runs in a ``finally`` clause. A failed check
    therefore does not leave the backend initialized for tests that run
    later in the same process.
    """
    idist.initialize("xla-tpu")
    try:
        _test_distrib_config(local_rank=0, backend="xla-tpu", ws=1, true_device="xla")
    finally:
        idist.finalize()
Exemple #2
0
def test_native_distrib_single_node_launch_tool_nccl(local_rank, world_size):
    """Initialize the native "nccl" backend and check the resulting config.

    Args:
        local_rank: Rank of this process on the node. On a single node it is
            also used as the global rank.
        world_size: Expected total number of processes.

    The ``RANK`` environment variable is set for the duration of the test.
    Afterwards it is restored to its previous value, or removed if it was
    unset, so that it does not leak into other tests. ``idist.finalize()``
    runs in a ``finally`` clause so a failed check does not leave the
    process group initialized.
    """
    import os

    # Single node: the global rank equals the local rank.
    rank = local_rank
    saved_rank = os.environ.get("RANK")
    os.environ["RANK"] = str(rank)
    try:
        idist.initialize("nccl")
        try:
            _test_distrib_config(local_rank, "nccl", world_size, "cuda", rank)
        finally:
            idist.finalize()
    finally:
        if saved_rank is None:
            os.environ.pop("RANK", None)
        else:
            os.environ["RANK"] = saved_rank
Exemple #3
0
def test_native_distrib_single_node_launch_tool_gloo(local_rank, world_size):
    """Initialize the native "gloo" backend (20 s timeout) and check its config.

    Args:
        local_rank: Rank of this process on the node. On a single node it is
            also used as the global rank.
        world_size: Expected total number of processes.

    The ``RANK`` environment variable is set for the duration of the test.
    Afterwards it is restored to its previous value, or removed if it was
    unset, so that it does not leak into other tests. ``idist.finalize()``
    runs in a ``finally`` clause so a failed check does not leave the
    process group initialized.
    """
    import os
    from datetime import timedelta

    timeout = timedelta(seconds=20)
    # Single node: the global rank equals the local rank.
    rank = local_rank
    saved_rank = os.environ.get("RANK")
    os.environ["RANK"] = str(rank)
    try:
        idist.initialize("gloo", timeout=timeout)
        try:
            _test_distrib_config(local_rank, "gloo", world_size, "cpu", rank)
        finally:
            idist.finalize()
    finally:
        if saved_rank is None:
            os.environ.pop("RANK", None)
        else:
            os.environ["RANK"] = saved_rank
Exemple #4
0
def test_hvd_distrib_single_node_single_device():
    """Initialize the "horovod" backend and check the config against Horovod.

    The expected device is "cpu" when no CUDA device is visible, and "cuda"
    otherwise. The expected local rank, world size and rank are read back
    from Horovod after initialization.

    ``idist.finalize()`` runs in a ``finally`` clause. A failed check
    therefore does not leave the backend initialized for tests that run
    later in the same process.
    """
    import horovod.torch as hvd

    idist.initialize("horovod")
    try:
        device = "cuda" if torch.cuda.device_count() >= 1 else "cpu"
        _test_distrib_config(hvd.local_rank(), "horovod", hvd.size(), device, hvd.rank())
    finally:
        idist.finalize()
Exemple #5
0
def _test_native_distrib_single_node_launch_tool(backend,
                                                 device,
                                                 local_rank,
                                                 world_size,
                                                 init_method=None,
                                                 **kwargs):
    """Initialize a native backend for a single-node launch and check its config.

    Args:
        backend: Backend name forwarded to ``idist.initialize``.
        device: Expected device passed to ``_test_distrib_config``.
        local_rank: Rank of this process on the node. On a single node it is
            also used as the global rank.
        world_size: Expected total number of processes.
        init_method: Forwarded to ``idist.initialize``. Also passed as the
            expected ``true_init_method``.
        **kwargs: Extra keyword arguments forwarded to ``idist.initialize``.

    The ``RANK`` environment variable is set for the duration of the call.
    Afterwards it is restored to its previous value, or removed if it was
    unset, so that it does not leak into other tests. ``idist.finalize()``
    runs in a ``finally`` clause so a failed check does not leave the
    process group initialized.
    """
    import os

    # Single node: the global rank equals the local rank.
    rank = local_rank
    saved_rank = os.environ.get("RANK")
    os.environ["RANK"] = f"{rank}"
    try:
        idist.initialize(backend, init_method=init_method, **kwargs)
        try:
            _test_distrib_config(local_rank,
                                 backend,
                                 world_size,
                                 device,
                                 rank,
                                 true_init_method=init_method)
        finally:
            idist.finalize()
    finally:
        if saved_rank is None:
            os.environ.pop("RANK", None)
        else:
            os.environ["RANK"] = saved_rank