def test_resolve_distributed_mode_slurm3():
    """Submit ``DistributedOption.init`` twice to a process pool under a fake SLURM env.

    The arguments request the ``slurm`` launcher with one GPU and
    ``multiprocessing_distributed=True``. ``os.environ`` is patched with
    single-task SLURM variables while the options are resolved, built and
    submitted. The test passes when both submitted calls complete without
    raising.
    """
    args = argparse.Namespace(
        multiprocessing_distributed=True,
        dist_world_size=None,
        dist_rank=None,
        ngpu=1,
        local_rank=None,
        dist_launcher="slurm",
        dist_backend="nccl",
        dist_init_method="env://",
        dist_master_addr=None,
        dist_master_port=10000,
    )
    env = dict(
        SLURM_PROCID="0",
        SLURM_NTASKS="1",
        SLURM_STEP_NUM_NODES="1",
        SLURM_STEP_NODELIST="localhost",
        SLURM_NODEID="0",
        CUDA_VISIBLE_DEVICES="0,1",
    )

    # Use the executor as a context manager so the pool is always shut down,
    # even when one of the futures raises; otherwise the workers would leak.
    with ProcessPoolExecutor(max_workers=2) as e:
        # Each submit happens while os.environ is patched with the SLURM values.
        with unittest.mock.patch.dict("os.environ", dict(env, SLURM_LOCALID="0")):
            resolve_distributed_mode(args)
            option = build_dataclass(DistributedOption, args)
            fn = e.submit(option.init)

        with unittest.mock.patch.dict("os.environ", dict(env, SLURM_LOCALID="0")):
            option2 = build_dataclass(DistributedOption, args)
            fn2 = e.submit(option2.init)

        # result() re-raises any exception that occurred in the worker.
        fn.result()
        fn2.result()
def test_resolve_distributed_mode_slurm2(dist_init_method):
    """Expect ``RuntimeError`` for this slurm setup with two tasks and ``ngpu=2``.

    ``os.environ`` is patched with a single-node SLURM step of two tasks while
    ``resolve_distributed_mode`` is called with ``multiprocessing_distributed``
    disabled.
    """
    settings = {
        "multiprocessing_distributed": False,
        "dist_world_size": None,
        "dist_rank": None,
        "ngpu": 2,
        "local_rank": None,
        "dist_launcher": "slurm",
        "dist_backend": "nccl",
        "dist_init_method": dist_init_method,
        "dist_master_addr": None,
        "dist_master_port": None,
    }
    args = argparse.Namespace(**settings)

    slurm_env = {
        "SLURM_PROCID": "0",
        "SLURM_NTASKS": "2",
        "SLURM_STEP_NUM_NODES": "1",
        "SLURM_STEP_NODELIST": "host1",
        "SLURM_NODEID": "0",
        "SLURM_LOCALID": "0",
        "CUDA_VISIBLE_DEVICES": "0,1",
    }
    # Environment patch is entered first, then the exception expectation.
    with unittest.mock.patch.dict("os.environ", slurm_env), pytest.raises(
        RuntimeError
    ):
        resolve_distributed_mode(args)
def test_default_work():
    """Run the distributed setup end to end with the parser's default arguments.

    Parses an empty command line with the ``AbsTask`` parser, resolves the
    distributed mode, builds a ``DistributedOption`` and calls both of its
    init steps. The test passes when nothing raises.
    """
    default_args = AbsTask.get_parser().parse_args([])
    resolve_distributed_mode(default_args)

    dist_option = build_dataclass(DistributedOption, default_args)
    dist_option.init_options()
    dist_option.init_torch_distributed()
def test_resolve_distributed_mode3(dist_init_method):
    """Smoke test: ``ngpu=2`` with no launcher, world size or rank must not raise."""
    settings = {
        "multiprocessing_distributed": False,
        "dist_world_size": None,
        "dist_rank": None,
        "ngpu": 2,
        "local_rank": None,
        "dist_launcher": None,
        "dist_backend": "nccl",
        "dist_init_method": dist_init_method,
        "dist_master_addr": None,
        "dist_master_port": None,
    }
    # No assertion follows: the call completing without an exception is the check.
    resolve_distributed_mode(argparse.Namespace(**settings))
def test_resolve_distributed_mode5(dist_init_method):
    """Expect ``RuntimeError`` for the slurm launcher with these explicit settings.

    World size, rank and ``local_rank`` are given explicitly, and this test
    does not patch any SLURM variables into the environment.
    """
    settings = {
        "multiprocessing_distributed": False,
        "dist_world_size": 2,
        "dist_rank": 0,
        "ngpu": 2,
        "local_rank": 1,
        "dist_launcher": "slurm",
        "dist_backend": "nccl",
        "dist_init_method": dist_init_method,
        "dist_master_addr": None,
        "dist_master_port": None,
    }
    args = argparse.Namespace(**settings)

    with pytest.raises(RuntimeError):
        resolve_distributed_mode(args)
def test_resolve_distributed_mode7(dist_init_method):
    """Explicit world size and rank should win over ``multiprocessing_distributed``.

    With ``dist_world_size=2`` and ``dist_rank=0`` supplied and no launcher,
    the resolved args must have ``distributed`` truthy and
    ``multiprocessing_distributed`` falsy.
    """
    settings = {
        "multiprocessing_distributed": True,
        "dist_world_size": 2,
        "dist_rank": 0,
        "ngpu": 1,
        "local_rank": None,
        "dist_launcher": None,
        "dist_backend": "nccl",
        "dist_init_method": dist_init_method,
        "dist_master_addr": None,
        "dist_master_port": None,
    }
    args = argparse.Namespace(**settings)

    # resolve_distributed_mode updates the namespace in place.
    resolve_distributed_mode(args)

    assert args.distributed
    assert not args.multiprocessing_distributed