def test_node_rank(tmp_path):
    environ = {
        "LSB_DJOB_RANKFILE": _make_rankfile(tmp_path),
        "LSB_JOBID": "1234",
        "JSM_NAMESPACE_SIZE": "4",
        "JSM_NAMESPACE_RANK": "3",
        "JSM_NAMESPACE_LOCAL_RANK": "1",
    }
    with mock.patch.dict(os.environ, environ), \
            mock.patch("socket.gethostname", return_value="10.10.10.2"):
        env = LSFEnvironment()
        assert env.node_rank() == 2
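The test snippets in this listing omit their imports and a `_make_rankfile` helper. A minimal sketch of both, assuming the rank file only needs to name the batch/launch node followed by one compute host per line (so that `10.10.10.0` is the main address and `10.10.10.2` is node rank 2, matching the assertions above):

import os
from unittest import mock

import pytest
from pytorch_lightning.plugins.environments import LSFEnvironment


def _make_rankfile(tmp_path):
    # Hypothetical helper: write an LSB_DJOB_RANKFILE-style host list and return its path.
    # The first entry is the batch/launch node, followed by one host per task.
    hosts = "batch\n10.10.10.0\n10.10.10.1\n10.10.10.2\n10.10.10.3"
    p = tmp_path / "lsb_djob_rankfile"
    p.write_text(hosts)
    return str(p)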
def select_cluster_environment(self) -> ClusterEnvironment:
    if self._cluster_environment is not None:
        return self._cluster_environment
    if self.is_slurm_managing_tasks:
        env = SLURMEnvironment()
    elif TorchElasticEnvironment.is_using_torchelastic():
        env = TorchElasticEnvironment()
    elif KubeflowEnvironment.is_using_kubeflow():
        env = KubeflowEnvironment()
    elif LSFEnvironment.is_using_lsf():
        env = LSFEnvironment()
    else:
        env = LightningEnvironment()
    return env
Example #3
def select_cluster_environment(self) -> ClusterEnvironment:
    if self._cluster_environment is not None:
        return self._cluster_environment
    if self._is_slurm_managing_tasks():
        env = SLURMEnvironment()
        rank_zero_info("Multiprocessing is handled by SLURM.")
    elif TorchElasticEnvironment.is_using_torchelastic():
        env = TorchElasticEnvironment()
    elif KubeflowEnvironment.is_using_kubeflow():
        env = KubeflowEnvironment()
    elif LSFEnvironment.is_using_lsf():
        env = LSFEnvironment()
    else:
        env = LightningEnvironment()
    return env
def test_detect():
    """Test the detection of a LSF environment configuration."""
    with mock.patch.dict(os.environ, {}, clear=True):
        assert not LSFEnvironment.detect()

    with mock.patch.dict(
        os.environ,
        {
            "LSB_DJOB_RANKFILE": "",
            "LSB_JOBID": "",
            "JSM_NAMESPACE_SIZE": "",
            "JSM_NAMESPACE_LOCAL_RANK": "",
        },
    ):
        assert LSFEnvironment.detect()
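The empty string values above suggest that detection keys on the presence of the LSF/jsrun variables rather than their contents. A small sketch of an equivalent check (an assumption drawn from this test, not a copy of the library's `detect()` implementation):

import os

# Presence of these variables is taken as the signal for an LSF job launched with jsrun
# (assumption based on the test above).
_LSF_VARS = ("LSB_DJOB_RANKFILE", "LSB_JOBID", "JSM_NAMESPACE_SIZE", "JSM_NAMESPACE_LOCAL_RANK")


def looks_like_lsf() -> bool:
    return all(name in os.environ for name in _LSF_VARS)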
def test_missing_lsb_job_id():
    """Test an error when the job id cannot be found."""
    del os.environ["LSB_JOBID"]
    with pytest.raises(
            ValueError,
            match="Could not find job id in environment variable LSB_JOBID"):
        LSFEnvironment()
def test_missing_lsb_hosts():
    """Test an error when the lsb hosts list cannot be found."""
    del os.environ["LSB_HOSTS"]
    with pytest.raises(
            ValueError,
            match="Could not find hosts in environment variable LSB_HOSTS"):
        LSFEnvironment()
def test_missing_lsb_job_id(tmp_path):
    """Test an error when the job id cannot be found."""
    with mock.patch.dict(
        os.environ, {"LSB_DJOB_RANKFILE": _make_rankfile(tmp_path)}
    ), pytest.raises(
        ValueError,
        match="Could not find job id in environment variable LSB_JOBID",
    ):
        LSFEnvironment()
def test_manual_main_port_and_address(tmp_path):
    """Test a user can set the port manually through the MASTER_PORT env variable."""
    environ = {
        "LSB_DJOB_RANKFILE": _make_rankfile(tmp_path),
        "LSB_JOBID": "1234",
        "JSM_NAMESPACE_SIZE": "4",
        "JSM_NAMESPACE_RANK": "3",
        "JSM_NAMESPACE_LOCAL_RANK": "1",
    }
    with mock.patch.dict(os.environ, environ), \
            mock.patch("socket.gethostname", return_value="10.10.10.2"):
        env = LSFEnvironment()
        assert env.main_port == 10234
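The asserted value 10234 is consistent with deriving the port from the job id as 10000 + LSB_JOBID % 1000 when MASTER_PORT is not set; the arithmetic for the patched environment above (an assumption about the derivation, not the library source):

job_id = 1234  # LSB_JOBID from the patched environment above
assert 10000 + job_id % 1000 == 10234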
def test_attributes_from_environment_variables(tmp_path):
    """Test that the LSF environment takes the attributes from the environment variables."""
    environ = {
        "LSB_DJOB_RANKFILE": _make_rankfile(tmp_path),
        "LSB_JOBID": "1234",
        "JSM_NAMESPACE_SIZE": "4",
        "JSM_NAMESPACE_RANK": "3",
        "JSM_NAMESPACE_LOCAL_RANK": "1",
    }
    with mock.patch.dict(os.environ, environ), \
            mock.patch("socket.gethostname", return_value="10.10.10.2"):
        env = LSFEnvironment()
        assert env.creates_processes_externally
        assert env.main_address == "10.10.10.0"
        assert env.main_port == 10234
        assert env.world_size() == 4
        assert env.global_rank() == 3
        assert env.local_rank() == 1
        env.set_global_rank(100)
        assert env.global_rank() == 3
        env.set_world_size(100)
        assert env.world_size() == 4
        assert LSFEnvironment.detect()
def test_empty_lsb_djob_rankfile():
    """Test an error when the LSB_DJOB_RANKFILE is not populated."""
    with pytest.raises(
            ValueError,
            match="The environment variable `LSB_DJOB_RANKFILE` is empty"):
        LSFEnvironment()
def test_missing_lsb_djob_rankfile():
    """Test an error when the LSB_DJOB_RANKFILE cannot be found."""
    with pytest.raises(
            ValueError,
            match="Did not find the environment variable `LSB_DJOB_RANKFILE`"):
        LSFEnvironment()
def test_node_rank(_):
    env = LSFEnvironment()
    assert env.node_rank() == 2
def test_attributes_from_environment_variables():
    """Test that the LSF environment takes the attributes from the environment variables."""
    env = LSFEnvironment()
    assert env.creates_children()
    assert env.master_address() == "10.10.10.0"
    assert env.master_port() == 10234
    assert env.world_size() == 4
    assert env.global_rank() == 3
    assert env.local_rank() == 1
    env.set_global_rank(100)
    assert env.global_rank() == 3
    env.set_world_size(100)
    assert env.world_size() == 4
    assert LSFEnvironment.is_using_lsf()
def test_manual_master_port_and_address():
    """Test a user can set the port manually through the MASTER_PORT env variable."""
    env = LSFEnvironment()
    assert env.master_port() == 4321
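The three snippets above (test_node_rank(_), test_attributes_from_environment_variables, test_manual_master_port_and_address) target an older LSFEnvironment API (creates_children(), master_address(), master_port(), is_using_lsf()) and rely on environment variables patched in by decorators that the listing dropped. A hedged sketch of the environment those assertions imply:

# Assumed environment for the older-API tests above (not shown in the listing):
environ = {
    "LSB_HOSTS": "batch 10.10.10.0 10.10.10.1 10.10.10.2 10.10.10.3",  # -> node_rank() == 2 for host 10.10.10.2
    "LSB_JOBID": "1234",
    "MASTER_PORT": "4321",            # would explain master_port() == 4321
    "JSM_NAMESPACE_SIZE": "4",        # world_size() == 4
    "JSM_NAMESPACE_RANK": "3",        # global_rank() == 3
    "JSM_NAMESPACE_LOCAL_RANK": "1",  # local_rank() == 1
}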
Example #15
def test_attributes_from_environment_variables():
    """Test that the LSF environment takes the attributes from the environment variables."""
    env = LSFEnvironment()
    assert env.creates_processes_externally
    assert env.main_address == "10.10.10.0"
    assert env.main_port == 10234
    assert env.world_size() == 4
    assert env.global_rank() == 3
    assert env.local_rank() == 1
    env.set_global_rank(100)
    assert env.global_rank() == 3
    env.set_world_size(100)
    assert env.world_size() == 4
    assert LSFEnvironment.detect()
Example #16
def process_args(args=None, return_io=False):
    """
    Process arguments for running training
    """
    if not isinstance(args, argparse.Namespace):
        args = parse_args(args)

    args.loader_kwargs = dict()

    targs = dict(max_epochs=args.epochs)

    targs['accumulate_grad_batches'] = args.accumulate

    env = None

    if args.ipu:
        targs['accelerator'] = 'ipu'
        targs['devices'] = process_gpus(args.gpus)
    else:
        targs['gpus'] = process_gpus(args.gpus)
        targs['num_nodes'] = args.num_nodes
        if args.lsf:
            ##########################################################################################
            # Currently coding against pytorch-lightning 1.4.3
            ##########################################################################################
            if args.num_workers > 4:
                print0(
                    "num_workers (-k) > 4 can lead to hanging on Summit -- setting to 4",
                    file=sys.stderr)
                args.num_workers = 4
            # Set as a default. This will get overridden elsewhere.
            args.loader_kwargs['num_workers'] = 1
            args.loader_kwargs['multiprocessing_context'] = 'spawn'
            env = LSFEnvironment()
        elif args.slurm:
            env = SLURMEnvironment()

        if env is not None:
            global RANK
            global SIZE
            try:
                RANK = env.global_rank()
                SIZE = env.world_size()
            except Exception:
                print(
                    ">>> Could not get global rank -- setting RANK to 0 and SIZE to 1",
                    file=sys.stderr)
                RANK = 0
                SIZE = 1

        if targs['gpus'] is not None:
            targs['accelerator'] = 'gpu'
            if targs['gpus'] == 1:
                targs['devices'] = 1
            else:
                if env is None:
                    raise ValueError(
                        'Please specify environment (--lsf or --slurm) if using more than one GPU'
                    )
                # parallel_devices = [torch.device(i) for i in range(torch.cuda.device_count()) if i < targs['gpus']]
                # precision_plugin = NativeMixedPrecisionPlugin(16, 'cuda')
                torch.cuda.set_device(env.local_rank())
                targs['devices'] = targs['gpus']
                targs['strategy'] = DDPStrategy(
                    find_unused_parameters=False,
                    cluster_environment=env,
                    #accelerator=GPUAccelerator(),
                    #parallel_devices=parallel_devices,
                    #precision_plugin=precision_plugin,
                )

                print(
                    "---- Rank %s  -  Using GPUAccelerator with DDPStrategy" %
                    env.global_rank(),
                    file=sys.stderr)
        else:
            targs['accelerator'] = 'cpu'

    del args.gpus

    if args.sanity:
        if isinstance(args.sanity, str):
            args.sanity = int(args.sanity)
        else:
            args.sanity = 4000
        targs['limit_train_batches'] = args.sanity
        targs['limit_val_batches'] = args.sanity // 4

    if args.lr_find:
        targs['auto_lr_find'] = True
    del args.lr_find

    if args.checkpoint is not None:
        if os.path.exists(args.checkpoint):
            targs['resume_from_checkpoint'] = args.checkpoint
        else:
            warnings.warn(
                f"Ignoring -c/--checkpoint argument because {args.checkpoint} does not exist."
            )
            args.checkpoint = None

    if args.cuda_profile:
        targs['profiler'] = PyTorchProfiler(
            filename=f'pytorch_prof.{RANK:0{len(str(SIZE))}}', emit_nvtx=True)

    targs['replace_sampler_ddp'] = False

    # make sure we are classifying if we are adding classifier layers
    # to a resnet features model
    if args.features_checkpoint is not None:
        if args.manifold:
            raise ValueError(
                'Cannot use manifold loss (i.e. -M) if adding classifier (i.e. -F)'
            )
        args.classify = True

    data_mod = DeepIndexDataModule(args,
                                   keep_open=True,
                                   seed=args.seed + RANK,
                                   rank=RANK,
                                   size=SIZE)

    # if classification problem, use the number of taxa as the number of outputs
    if args.classify:
        args.n_outputs = data_mod.dataset.n_outputs

    args.input_nc = 136 if args.tnf else len(data_mod.dataset.vocab)

    model = process_model(args, taxa_table=data_mod.dataset.difile.taxa_table)

    if args.num_workers > 0:
        data_mod.dataset.close()

    ret = [model, args, targs]
    if return_io:
        ret.append(io)

    ret.append(data_mod)

    return tuple(ret)
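A hedged usage sketch for `process_args`: with `return_io=False` it returns `(model, args, targs, data_mod)`, and `targs` is meant to be splatted into the `Trainer` constructor. The import and the exact Trainer call are assumptions for illustration, matching the pytorch-lightning 1.x style used above:

from pytorch_lightning import Trainer

# Unpack the tuple assembled at the end of process_args().
model, args, targs, data_mod = process_args()

trainer = Trainer(**targs)               # targs collects the Trainer keyword arguments built above
trainer.fit(model, datamodule=data_mod)  # data_mod is the DeepIndexDataModule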