예제 #1
0
파일: testing.py 프로젝트: wns823/fairscale
def teardown() -> None:
    """Release the distributed state a test may have set up.

    Destroys the model-parallel state, then the default process group when
    one is initialized, and finally attempts to shut down RPC. Any
    ``Exception`` raised by the RPC shutdown is ignored.
    """
    destroy_model_parallel()

    dist = torch.distributed
    if dist.is_initialized():
        dist.destroy_process_group()

    try:
        dist.rpc.shutdown()
    except Exception:
        # Best-effort cleanup: a failed RPC shutdown must not fail teardown.
        pass
예제 #2
0
def pytest_runtest_teardown(item):
    """Clean up distributed state after a test, but only under Open MPI.

    Does nothing unless ``OMPI_COMM_WORLD_RANK`` is present in the
    environment. Otherwise it destroys the model-parallel state, destroys the
    default process group if one is initialized, and attempts an RPC
    shutdown whose ``Exception`` failures are ignored.

    Args:
        item: Not used by this function.
    """
    if "OMPI_COMM_WORLD_RANK" not in os.environ:
        return

    destroy_model_parallel()

    dist = torch.distributed
    if dist.is_initialized():
        dist.destroy_process_group()

    try:
        dist.rpc.shutdown()
    except Exception:
        # Best-effort cleanup: a failed RPC shutdown must not fail teardown.
        pass
예제 #3
0
def teardown() -> None:
    """Release the distributed state a test may have set up.

    Destroys the model-parallel state, then the default process group when
    one is initialized, and finally attempts a non-graceful RPC shutdown.
    Any ``Exception`` raised by the RPC shutdown is ignored.
    """
    destroy_model_parallel()

    dist = torch.distributed
    if dist.is_initialized():
        dist.destroy_process_group()

    try:
        # torch 1.5 hangs on shutdown if waiting for all processes
        dist.rpc.shutdown(graceful=False)
    except Exception:
        # Best-effort cleanup: a failed RPC shutdown must not fail teardown.
        pass
예제 #4
0
        def replacement(*args: Any, **kwargs: Any) -> None:
            """Run ``func`` for the requested world sizes.

            Must be called with keyword arguments only. They are converted to
            positional arguments in the order given by ``parameters``, leaving
            out ``"rank"``.

            When ``OMPI_COMM_WORLD_RANK`` is set in the environment, the "mpi"
            process group is initialized in this process and ``func`` is
            called directly if the current world size is one of
            ``world_sizes``; otherwise the test is skipped. Without that
            variable the work is handed to ``spawn_for_all_world_sizes``.
            In either case, a message found on the error queue afterwards
            causes the test to be skipped with that message.

            Raises:
                BaseException: Whatever ``func`` (or the teardown that
                    follows it) raises in the MPI path, re-raised after
                    teardown has run and the traceback has been printed.
            """
            assert args == tuple()
            assert world_sizes is not None  # mypy crutch

            args = tuple(
                kwargs[p] for p in parameters if p != "rank"
            )  # converting named parameters to positional parameters to pass to `spawn`

            error_queue = multiprocessing.get_context("spawn").SimpleQueue()
            if "OMPI_COMM_WORLD_RANK" in os.environ:
                # TODO (Min): this global used to be assigned every time this file is imported.
                #     I changed it to be assigned on first use. Should be the same, but I am not
                #     sure this is used or is correct since different processes would have different
                #     file names to init_process_group below. By initing, here, we don't leave
                #     a temp file behind on importing time.
                global filename_mpi
                if filename_mpi is None:
                    # mkstemp() returns an open OS-level descriptor together
                    # with the path. Only the path is needed, so close the
                    # descriptor instead of leaking it.
                    fd, filename_mpi = tempfile.mkstemp()
                    os.close(fd)

                os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
                os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
                torch.distributed.init_process_group(
                    "mpi", init_method=f"file://{filename_mpi}")

                world_size = torch.distributed.get_world_size()
                destroy_model_parallel()
                initialize_model_parallel(1, world_size)
                torch.cuda.set_device(torch.distributed.get_rank() %
                                      torch.cuda.device_count())
                if world_size in world_sizes:
                    try:
                        func(*args)
                        teardown()
                    except BaseException:
                        # Clean up first, then surface the failure.
                        teardown()
                        import traceback

                        print(traceback.format_exc())
                        raise
                else:
                    pytest.skip(
                        "Requested world size doesn't match current world size"
                    )
            else:
                spawn_for_all_world_sizes(worker_process, world_sizes,
                                          (func, args, error_queue))

            if not error_queue.empty():
                msg = error_queue.get()
                pytest.skip(msg)
예제 #5
0
        def replacement(*args: Any, **kwargs: Any) -> None:
            """Run ``func`` for the requested world sizes.

            Must be called with keyword arguments only. They are converted to
            positional arguments in the order given by ``parameters``, leaving
            out ``"rank"``.

            When ``OMPI_COMM_WORLD_RANK`` is set in the environment, the "mpi"
            process group is initialized in this process (using the
            module-level ``filename_mpi`` as the init file) and ``func`` is
            called directly if the current world size is one of
            ``world_sizes``; otherwise the test is skipped. Without that
            variable the work is handed to ``spawn_for_all_world_sizes``.
            In either case, a message found on the error queue afterwards
            causes the test to be skipped with that message.
            """
            global filename_mpi

            assert args == tuple()
            assert world_sizes is not None  # mypy crutch

            # Named parameters become positional ones (minus "rank") so they
            # can be passed on to `spawn`.
            positional = tuple(
                kwargs[name] for name in parameters if name != "rank"
            )

            spawn_context = multiprocessing.get_context("spawn")
            error_queue = spawn_context.SimpleQueue()

            if "OMPI_COMM_WORLD_RANK" not in os.environ:
                spawn_for_all_world_sizes(
                    worker_process,
                    world_sizes,
                    (func, positional, error_queue),
                )
            else:
                env = os.environ
                env["RANK"] = env["OMPI_COMM_WORLD_RANK"]
                env["WORLD_SIZE"] = env["OMPI_COMM_WORLD_SIZE"]

                dist = torch.distributed
                dist.init_process_group(
                    "mpi", init_method=f"file://{filename_mpi}"
                )

                world_size = dist.get_world_size()
                destroy_model_parallel()
                initialize_model_parallel(1, world_size)

                device_index = dist.get_rank() % torch.cuda.device_count()
                torch.cuda.set_device(device_index)

                if world_size not in world_sizes:
                    pytest.skip(
                        "Requested world size doesn't match current world size"
                    )
                else:
                    try:
                        func(*positional)
                        teardown()
                    except BaseException as exc:
                        # Clean up first, then surface the failure.
                        teardown()
                        import traceback

                        print(f"{traceback.format_exc()}")
                        raise exc

            if not error_queue.empty():
                pytest.skip(error_queue.get())