Example 1
def bench_mpi(args):
    guess_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
    os.environ["UCX_NET_DEVICES"] = best_device_map[guess_rank]

    torch.distributed.init_process_group(backend="mpi")
    os.environ["MASTER_ADDR"] = args.host
    os.environ["MASTER_PORT"] = "10639"
    if args.socket_name:
        os.environ["GLOO_SOCKET_IFNAME"] = args.socket_name
        os.environ["TP_SOCKET_IFNAME"] = args.socket_name
    init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    torch.cuda.set_device(rank % torch.cuda.device_count())

    rpc.init_rpc(
        f"Test{rank}",
        rank=rank,
        world_size=world_size,
        backend=rpc.BackendType.PROCESS_GROUP,
        rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(
            rpc_timeout=20, init_method=init_method),
    )

    initialize_model_parallel(1, world_size)
    init_random_seed(0)

    run_mp_worker(args, world_size)

    rpc.shutdown()
    torch.distributed.destroy_process_group()
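Note: best_device_map above is defined elsewhere in the benchmark script. A hypothetical sketch of such a mapping, assuming one network interface per rank (the device names are placeholders and depend on the host's NIC/GPU topology):

# Hypothetical mapping from rank to the UCX network device it should use.
best_device_map = {
    0: "mlx5_0:1",
    1: "mlx5_1:1",
    2: "mlx5_2:1",
    3: "mlx5_3:1",
}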
Example 2
def worker_process(rank: int, world_size: int, filename: str,
                   filename_rpc: str, func: Callable, args: Any,
                   error_queue: Any) -> None:
    """Main function for unit tests launced with torch_spawn"""

    if not dist_init(rank, world_size, filename, filename_rpc):
        logging.warning("failed initializing torch distributed")
        return

    kwargs = {}
    if "OMPI_COMM_WORLD_RANK" not in os.environ:
        kwargs["pipeline_backend"] = "gloo"

    initialize_model_parallel(1, world_size, **kwargs)

    try:
        func(*args)
        teardown()
    except BaseException as e:
        logging.warning(f" Rank {rank}: {e}")

        # Make sure that the group is properly destroyed, even for tests which check for exceptions being raised
        teardown()

        # If the function raises 'Skipped', this indicates pytest.skip(), so
        # forward it to parent so we can call pytest.skip() there
        if e.__class__.__name__ == "Skipped":
            error_queue.put(str(e))
            return

        raise e
Example 3
def benchmark_multiprocess(rank, world_size, args):

    init_method_pgroup = "tcp://localhost:{}".format(MPI_PORT)
    # TODO(anj-s): Add regression benchmarks for nccl as well.
    torch.distributed.init_process_group(
        backend="gloo", rank=rank, world_size=world_size, init_method=init_method_pgroup
    )

    torch.cuda.set_device(rank % torch.cuda.device_count())
    # TODO(anj-s): Move to TensorPipeRpcBackendOptions.
    rpc.init_rpc(
        f"Test{rank}",
        rank=rank,
        world_size=world_size,
        backend=rpc.BackendType.PROCESS_GROUP,
        rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(
            rpc_timeout=20, init_method="tcp://localhost:{}".format(RPC_PORT)
        ),
    )
    initialize_model_parallel(1, world_size)
    init_random_seed(0)
    run_mp_worker(args, world_size)

    rpc.shutdown()
    torch.distributed.destroy_process_group()
Example 4
def run_worker(rank, world_size, args):
    if args.world_size != 0:
        world_size = args.world_size
    dist_init(rank + args.rank_base, world_size, hostname=args.host)
    initialize_model_parallel(1, world_size)
    init_random_seed(0)
    run_mp_worker(args, world_size)

    rpc.shutdown()
    torch.distributed.destroy_process_group()
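Note: the dist_init helper used here is not shown. A minimal sketch of what it is assumed to do, mirroring the explicit setup in the surrounding examples (the ports and the gloo backend are assumptions):

import os

import torch.distributed
from torch.distributed import rpc

# Hypothetical sketch: set up the default process group and the RPC framework,
# so that the rpc.shutdown()/destroy_process_group() calls above have something
# to tear down.
def dist_init(rank, world_size, hostname="localhost"):
    os.environ["MASTER_ADDR"] = hostname
    os.environ["MASTER_PORT"] = "10638"
    torch.distributed.init_process_group(backend="gloo", rank=rank, world_size=world_size)
    os.environ["MASTER_PORT"] = "10639"  # separate port for the RPC agent
    rpc.init_rpc(f"Test{rank}", rank=rank, world_size=world_size)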
Example 5
def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "10638"
    torch.distributed.init_process_group(backend="nccl",
                                         rank=rank,
                                         world_size=world_size)
    os.environ["MASTER_PORT"] = "10639"
    torch.distributed.rpc.init_rpc(f"worker{rank}",
                                   rank=rank,
                                   world_size=world_size)
    initialize_model_parallel(1, world_size)

    model = nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU(),
                          torch.nn.Linear(10, 5))
    target = torch.randint(0, 2, size=(20, 1)).squeeze()
    data = torch.randn(20, 10)
    loss_fn = F.nll_loss

    device = torch.device("cuda", rank)

    model = fairscale.nn.Pipe(
        model,
        balance=[2, 1],
        style=fairscale.nn.Pipe.MultiProcess,
        worker_map={
            0: "worker0",
            1: "worker1"
        },  # Needed to convert ranks to RPC worker names
        input_device=device,
    ).to(device)

    # define optimizer and loss function
    optimizer = optim.SGD(model.parameters(), lr=0.001)

    # zero the parameter gradients
    optimizer.zero_grad()

    # outputs and target need to be on the same device
    # forward step
    outputs = model(data.to(device))
    # compute loss
    if rank == 1:
        loss = loss_fn(outputs.to(device), target.to(device))

        # backward + optimize
        loss.backward()
        optimizer.step()
    else:
        model.back_helper(outputs)

    print(f"Finished Training Step on {rank}")

    del model
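These run(rank, world_size) functions are launched once per rank. A minimal, hypothetical launcher using torch.multiprocessing.spawn (a world size of 2 matches the two-entry worker_map above):

import torch.multiprocessing as mp

if __name__ == "__main__":
    world_size = 2  # two pipeline stages: balance=[2, 1], workers "worker0"/"worker1"
    # spawn() invokes run(rank, world_size) once for each rank in [0, world_size).
    mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)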
Example 6
def run(rank, world_size):
    torch_pg.init_mpi()
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "10638"
    torch.distributed.init_process_group(backend="nccl",
                                         rank=rank,
                                         world_size=world_size)
    os.environ["MASTER_PORT"] = "10639"
    torch.distributed.rpc.init_rpc(f"worker{rank}",
                                   rank=rank,
                                   world_size=world_size)
    initialize_model_parallel(1, world_size, pipeline_backend="mpi")

    if rank == 1:
        # For RPC, all ranks other than 0 just need to call rpc.shutdown()
        torch.distributed.rpc.shutdown()
        return

    model = nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU(),
                          torch.nn.Linear(10, 5))
    target = torch.randint(0, 2, size=(20, 1)).squeeze()
    data = torch.randn(20, 10)
    loss_fn = F.nll_loss

    device = torch.device("cuda", rank)

    model = fairscale.nn.PipeRPCWrapper(
        model,
        balance=[2, 1],
        worker_map={
            0: "worker0",
            1: "worker1"
        },  # Needed to convert ranks to RPC worker names
        input_device=device,
    ).to(device)

    # We can't directly access the model on each worker, so we need to call
    # foreach_worker with a callback to setup the optimizer
    model.foreach_worker(register_optimizer, {"lr": 0.001}, include_self=True)

    outputs = model(data.to(device))
    loss = loss_fn(outputs.to(device), target.to(device))
    loss.backward()

    # Same as earlier, use foreach_worker to step the optimizer on each rank
    model.foreach_worker(run_optimizer, include_self=True)

    print(f"Finished Training Step on {rank}")

    torch.distributed.rpc.shutdown()

    del model
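The register_optimizer and run_optimizer callbacks passed to foreach_worker are not part of this snippet. A plausible sketch, assuming the wrapper invokes each callback with a user context and the local portion of the pipeline:

import torch.optim as optim

# Hypothetical callbacks for PipeRPCWrapper.foreach_worker, assuming the
# signature callback(ctx, model) where model is the local pipeline shard.
def register_optimizer(ctx, model):
    # Stash the optimizer on the local module so a later callback can reach it.
    model.optimizer = optim.SGD(model.parameters(), **ctx)


def run_optimizer(ctx, model):
    model.optimizer.step()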
Example 7
        def replacement(*args: Any, **kwargs: Any) -> None:
            assert args == tuple()
            assert world_sizes is not None  # mypy crutch

            args = tuple(
                kwargs[p] for p in parameters if p != "rank"
            )  # converting named parameters to positional parameters to pass to `spawn`

            error_queue = multiprocessing.get_context("spawn").SimpleQueue()
            if "OMPI_COMM_WORLD_RANK" in os.environ:
                # TODO (Min): this global used to be assigned every time this file was imported.
                #     I changed it to be assigned on first use. It should behave the same, but I am
                #     not sure this is used or correct, since different processes would pass
                #     different file names to init_process_group below. By initializing it here, we
                #     don't leave a temp file behind at import time.
                global filename_mpi
                if filename_mpi is None:
                    filename_mpi = tempfile.mkstemp()[1]

                os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
                os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
                torch.distributed.init_process_group(
                    "mpi", init_method=f"file://{filename_mpi}")

                world_size = torch.distributed.get_world_size()
                destroy_model_parallel()
                initialize_model_parallel(1, world_size)
                torch.cuda.set_device(torch.distributed.get_rank() %
                                      torch.cuda.device_count())
                if world_size in world_sizes:
                    try:
                        func(*args)
                        teardown()
                    except BaseException as e:
                        teardown()
                        import traceback

                        print(f"{traceback.format_exc()}")
                        raise e
                else:
                    pytest.skip(
                        "Requested world size doesn't match current world size"
                    )
            else:
                spawn_for_all_world_sizes(worker_process, world_sizes,
                                          (func, args, error_queue))

            if not error_queue.empty():
                msg = error_queue.get()
                pytest.skip(msg)
Example 8
def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "10638"
    dist_init(rank, world_size)
    os.environ["MASTER_PORT"] = "10639"
    dist.rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
    initialize_model_parallel(1, world_size)

    model = get_model()
    data, target = get_data()[0]
    loss_fn = get_loss_fun()

    device = torch.device("cuda",
                          rank) if DEVICE == "cuda" else torch.device("cpu")

    model = MultiProcessPipe(
        model,
        balance=[2, 1],
        style=MultiProcessPipe.MultiProcess,
        worker_map={
            0: "worker0",
            1: "worker1"
        },  # Needed to convert ranks to RPC worker names
        input_device=device,
    ).to(device)

    # define optimizer and loss function
    optimizer = optim.SGD(model.parameters(), lr=0.001)

    # zero the parameter gradients
    optimizer.zero_grad()

    # outputs and target need to be on the same device
    # forward step
    outputs = model(data.to(device))
    # compute loss
    if rank == 1:
        loss = loss_fn(outputs.to(device), target.to(device))

        # backward + optimize
        loss.backward()
        optimizer.step()
    else:
        model.back_helper(outputs)

    print(f"Finished Training Step on {rank}")
    dist.rpc.shutdown()

    del model
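The get_model, get_data, and get_loss_fun helpers are defined elsewhere. A plausible sketch, assuming they mirror the inline definitions from Example 5:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical helpers matching the tensor shapes used in Example 5.
def get_model():
    return nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 5))


def get_data():
    # A single (input, target) batch, consumed via get_data()[0] above.
    return [(torch.randn(20, 10), torch.randint(0, 2, size=(20, 1)).squeeze())]


def get_loss_fun():
    return F.nll_loss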
Example 9
def run(rank, world_size):
    torch_pg.init_mpi()
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "10638"
    dist_init(rank, world_size)  # FIXME (supports gloo)
    os.environ["MASTER_PORT"] = "10639"
    torch.distributed.rpc.init_rpc(f"worker{rank}",
                                   rank=rank,
                                   world_size=world_size)
    initialize_model_parallel(1, world_size, pipeline_backend="mpi")

    if rank == 1:
        # For RPC, all ranks other than 0 just need to call rpc.shutdown()
        torch.distributed.rpc.shutdown()
        return

    model = getModel()
    data, target = getData()[0]
    loss_fn = getLossFun()

    device = torch.device("cuda", rank)

    model = fairscale.nn.PipeRPCWrapper(
        model,
        balance=[2, 1],
        worker_map={
            0: "worker0",
            1: "worker1"
        },  # Needed to convert ranks to RPC worker names
        input_device=device,
    ).to(device)

    # We can't directly access the model on each worker, so we need to call
    # foreach_worker with a callback to setup the optimizer
    model.foreach_worker(register_optimizer, {"lr": 0.001}, include_self=True)

    outputs = model(data.to(device))
    loss = loss_fn(outputs.to(device), target.to(device))
    loss.backward()

    # Same as earlier, use foreach_worker to step the optimizer on each rank
    model.foreach_worker(run_optimizer, include_self=True)

    print(f"Finished Training Step on {rank}")

    torch.distributed.rpc.shutdown()

    del model
Example 10
def worker_process(rank, world_size, func, args, error_queue):
    """Main function for unit tests launced with torch_spawn"""

    dist_init(rank, world_size)
    kwargs = {}
    if "OMPI_COMM_WORLD_RANK" not in os.environ:
        kwargs["pipeline_backend"] = "gloo"
    initialize_model_parallel(1, world_size, **kwargs)
    try:
        func(*args)
    except BaseException as e:
        # If the function raises 'Skipped', this indicates pytest.skip(), so
        # forward it to parent so we can call pytest.skip() there
        if e.__class__.__name__ == "Skipped":
            error_queue.put(str(e))
            return
        raise e
Example 11
def worker_process(rank: int, world_size: int, filename: str,
                   filename_rpc: str, func: Callable, args: Any,
                   error_queue: Any) -> None:
    """Main function for unit tests launched with torch_spawn"""

    if not dist_init(rank, world_size, filename, filename_rpc):
        logging.warning("failed initializing torch distributed")
        teardown()
        return

    kwargs = {}
    if "OMPI_COMM_WORLD_RANK" not in os.environ:
        kwargs["pipeline_backend"] = "gloo"

    initialize_model_parallel(1, world_size, **kwargs)

    # Make sure that CUDA operations are repeatable
    context = (
        torch.backends.cudnn.flags(benchmark=False,
                                   deterministic=True)  # type: ignore
        if torch.cuda.is_available() and hasattr(torch.backends.cudnn, "flags")
        else contextlib.suppress())

    if torch.cuda.is_available() and not hasattr(torch.backends.cudnn,
                                                 "flags"):
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    try:
        with context:
            func(*args)
        teardown()
    except BaseException as e:
        logging.warning(f" Rank {rank}: {e}")

        # Make sure that the group is properly destroyed, even for tests which check for exceptions being raised
        teardown()

        # If the function raises 'Skipped', this indicates pytest.skip(), so
        # forward it to parent so we can call pytest.skip() there
        if e.__class__.__name__ == "Skipped":
            error_queue.put(str(e))
            return

        raise e
Example 12
def bench_mpi(args):
    guess_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
    world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
    local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
    os.environ["UCX_NET_DEVICES"] = best_device_map[local_rank]

    os.environ["MASTER_ADDR"] = args.host
    os.environ["MASTER_PORT"] = "10638"
    if args.socket_name:
        os.environ["GLOO_SOCKET_IFNAME"] = args.socket_name
        os.environ["TP_SOCKET_IFNAME"] = args.socket_name

    torch.distributed.init_process_group(backend="gloo",
                                         rank=guess_rank,
                                         world_size=world_size)

    os.environ["MASTER_ADDR"] = args.host
    os.environ["MASTER_PORT"] = "10639"
    init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    rpc.init_rpc(
        f"Test{rank}",
        rank=rank,
        world_size=world_size,
        backend=rpc.BackendType.PROCESS_GROUP,
        rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(
            rpc_timeout=20, init_method=init_method),
    )

    backends = {
        "model_parallel_backend": "nccl",
        "pipeline_backend": "mpi",
        "ddp_backend": "nccl"
    }

    initialize_model_parallel(1, world_size, **backends)
    init_random_seed(0)

    run_mp_worker(args, world_size)

    rpc.shutdown()
    torch.distributed.destroy_process_group()
Example 13
        def replacement(*args: Any, **kwargs: Any) -> None:
            assert args == tuple()
            assert world_sizes is not None  # mypy crutch

            args = tuple(
                kwargs[p] for p in parameters if p != "rank"
            )  # converting named parameters to positional parameters to pass to `spawn`

            error_queue = multiprocessing.get_context("spawn").SimpleQueue()
            if "OMPI_COMM_WORLD_RANK" in os.environ:
                global filename_mpi

                os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
                os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
                torch.distributed.init_process_group(
                    "mpi", init_method=f"file://{filename_mpi}")

                world_size = torch.distributed.get_world_size()
                destroy_model_parallel()
                initialize_model_parallel(1, world_size)
                torch.cuda.set_device(torch.distributed.get_rank() %
                                      torch.cuda.device_count())
                if world_size in world_sizes:
                    try:
                        func(*args)
                        teardown()
                    except BaseException as e:
                        teardown()
                        import traceback

                        print(f"{traceback.format_exc()}")
                        raise e
                else:
                    pytest.skip(
                        "Requested world size doesn't match current world size"
                    )
            else:
                spawn_for_all_world_sizes(worker_process, world_sizes,
                                          (func, args, error_queue))

            if not error_queue.empty():
                msg = error_queue.get()
                pytest.skip(msg)
Example 14
        def replacement(*args, **kwargs):
            assert args == tuple()
            args = tuple(
                kwargs[p] for p in parameters if p != "rank"
            )  # converting named parameters to positional parameters to pass to `spawn`

            if "OMPI_COMM_WORLD_RANK" in os.environ:
                torch.distributed.init_process_group("mpi")
                world_size = torch.distributed.get_world_size()
                initialize_model_parallel(1, world_size)
                torch.cuda.set_device(torch.distributed.get_rank() %
                                      torch.cuda.device_count())
                if world_size in world_sizes:
                    func(*args)
                else:
                    pytest.skip(
                        "requested world size doesn't match current world size"
                    )
            else:
                spawn_for_all_world_sizes(helper, world_sizes, (func, args))
Example 15
        def replacement(*args, **kwargs):
            assert args == tuple()
            args = tuple(
                kwargs[p] for p in parameters if p != "rank"
            )  # converting named parameters to positional parameters to pass to `spawn`

            error_queue = multiprocessing.get_context("spawn").SimpleQueue()
            if "OMPI_COMM_WORLD_RANK" in os.environ:
                os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
                os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
                os.environ["MASTER_ADDR"] = "localhost"
                os.environ["MASTER_PORT"] = "10638"
                torch.distributed.init_process_group("mpi")
                world_size = torch.distributed.get_world_size()
                initialize_model_parallel(1, world_size)
                torch.cuda.set_device(torch.distributed.get_rank() %
                                      torch.cuda.device_count())
                if world_size in world_sizes:
                    try:
                        func(*args)
                    except BaseException as e:
                        print(f"got exception {e} from test")
                        import traceback

                        print(f"{traceback.format_exc()}")
                        raise e
                else:
                    pytest.skip(
                        "requested world size doesn't match current world size"
                    )
            else:
                spawn_for_all_world_sizes(worker_process, world_sizes,
                                          (func, args, error_queue))

            if not error_queue.empty():
                msg = error_queue.get()
                pytest.skip(msg)
Example 16
def helper(rank, world_size, func, args):
    dist_init(rank, world_size)
    initialize_model_parallel(1, world_size)
    func(*args)
Example 17
def init_model_parallel_groups(self):
    num_model_parallel = 1  # TODO currently no support for vertical model parallel
    mpu.initialize_model_parallel(model_parallel_size_=num_model_parallel, pipeline_length=self.gpus_per_model)
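As a rough illustration of the partitioning implied by this call, assuming a Megatron-style mpu where the world size factors into model-parallel, pipeline, and data-parallel groups:

# Assumed arithmetic only; the exact group construction depends on the mpu
# implementation in use.
world_size = 8
model_parallel_size = 1      # num_model_parallel above
pipeline_length = 4          # self.gpus_per_model above
data_parallel_size = world_size // (model_parallel_size * pipeline_length)  # -> 2
# Each model replica spans 4 ranks along the pipeline dimension; the remaining
# factor of 2 is used for data parallelism across replicas.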