def test_node_rank(tmp_path):
    """Check the node rank reported for a patched hostname.

    The LSF variables point at a rankfile produced by ``_make_rankfile`` and
    the hostname is patched to ``10.10.10.2``; the test expects node rank 2.
    """
    lsf_vars = dict(
        LSB_DJOB_RANKFILE=_make_rankfile(tmp_path),
        LSB_JOBID="1234",
        JSM_NAMESPACE_SIZE="4",
        JSM_NAMESPACE_RANK="3",
        JSM_NAMESPACE_LOCAL_RANK="1",
    )
    with mock.patch.dict(os.environ, lsf_vars):
        with mock.patch("socket.gethostname", return_value="10.10.10.2"):
            lsf_env = LSFEnvironment()
            assert lsf_env.node_rank() == 2
def select_cluster_environment(self) -> ClusterEnvironment:
    """Return the cluster environment to use.

    An environment already stored on ``self._cluster_environment`` always
    wins. Otherwise the first matching check decides, in this order: SLURM
    managing tasks, TorchElastic, Kubeflow, LSF. When none match, a
    ``LightningEnvironment`` is returned.
    """
    if self._cluster_environment is not None:
        return self._cluster_environment

    # Order matters: the first positive check determines the environment.
    if self.is_slurm_managing_tasks:
        return SLURMEnvironment()
    if TorchElasticEnvironment.is_using_torchelastic():
        return TorchElasticEnvironment()
    if KubeflowEnvironment.is_using_kubeflow():
        return KubeflowEnvironment()
    if LSFEnvironment.is_using_lsf():
        return LSFEnvironment()
    return LightningEnvironment()
def select_cluster_environment(self) -> ClusterEnvironment:
    """Return the cluster environment to use.

    An environment already stored on ``self._cluster_environment`` always
    wins. Otherwise the first matching check decides, in this order: SLURM
    managing tasks (which also emits an info message), TorchElastic,
    Kubeflow, LSF. When none match, a ``LightningEnvironment`` is returned.
    """
    if self._cluster_environment is not None:
        return self._cluster_environment

    # Order matters: the first positive check determines the environment.
    if self._is_slurm_managing_tasks():
        slurm_env = SLURMEnvironment()
        rank_zero_info("Multiprocessing is handled by SLURM.")
        return slurm_env
    if TorchElasticEnvironment.is_using_torchelastic():
        return TorchElasticEnvironment()
    if KubeflowEnvironment.is_using_kubeflow():
        return KubeflowEnvironment()
    if LSFEnvironment.is_using_lsf():
        return LSFEnvironment()
    return LightningEnvironment()
def test_detect():
    """Check ``LSFEnvironment.detect()`` against empty and populated environments."""
    # With every variable removed, detection must report False.
    with mock.patch.dict(os.environ, {}, clear=True):
        detected = LSFEnvironment.detect()
        assert not detected

    # The four variables only need to be present; their values are empty here.
    required_vars = dict.fromkeys(
        (
            "LSB_DJOB_RANKFILE",
            "LSB_JOBID",
            "JSM_NAMESPACE_SIZE",
            "JSM_NAMESPACE_LOCAL_RANK",
        ),
        "",
    )
    with mock.patch.dict(os.environ, required_vars):
        detected = LSFEnvironment.detect()
        assert detected
def test_missing_lsb_job_id():
    """Expect a ``ValueError`` from the constructor once ``LSB_JOBID`` is removed."""
    # NOTE(review): presumably an enclosing patch of os.environ (not visible
    # here) provides LSB_JOBID and restores it afterwards -- confirm.
    os.environ.pop("LSB_JOBID")
    expected_message = "Could not find job id in environment variable LSB_JOBID"
    with pytest.raises(ValueError, match=expected_message):
        LSFEnvironment()
def test_missing_lsb_hosts():
    """Expect a ``ValueError`` from the constructor once ``LSB_HOSTS`` is removed."""
    # NOTE(review): presumably an enclosing patch of os.environ (not visible
    # here) provides LSB_HOSTS and restores it afterwards -- confirm.
    os.environ.pop("LSB_HOSTS")
    expected_message = "Could not find hosts in environment variable LSB_HOSTS"
    with pytest.raises(ValueError, match=expected_message):
        LSFEnvironment()
def test_missing_lsb_job_id(tmp_path):
    """Expect a ``ValueError`` when only the rankfile variable is supplied."""
    rankfile_only = {"LSB_DJOB_RANKFILE": _make_rankfile(tmp_path)}
    expected_message = "Could not find job id in environment variable LSB_JOBID"
    with mock.patch.dict(os.environ, rankfile_only):
        with pytest.raises(ValueError, match=expected_message):
            LSFEnvironment()
def test_manual_main_port_and_address(tmp_path):
    """Check the main port reported for job id ``1234``.

    NOTE(review): despite the test name, no ``MASTER_PORT`` variable is set in
    the patched environment below; the assertion expects port ``10234``.
    """
    lsf_vars = dict(
        LSB_DJOB_RANKFILE=_make_rankfile(tmp_path),
        LSB_JOBID="1234",
        JSM_NAMESPACE_SIZE="4",
        JSM_NAMESPACE_RANK="3",
        JSM_NAMESPACE_LOCAL_RANK="1",
    )
    with mock.patch.dict(os.environ, lsf_vars):
        with mock.patch("socket.gethostname", return_value="10.10.10.2"):
            lsf_env = LSFEnvironment()
            assert lsf_env.main_port == 10234
def test_attributes_from_environment_variables(tmp_path):
    """Check the LSF environment's attributes against the patched variables."""
    lsf_vars = dict(
        LSB_DJOB_RANKFILE=_make_rankfile(tmp_path),
        LSB_JOBID="1234",
        JSM_NAMESPACE_SIZE="4",
        JSM_NAMESPACE_RANK="3",
        JSM_NAMESPACE_LOCAL_RANK="1",
    )
    with mock.patch.dict(os.environ, lsf_vars):
        with mock.patch("socket.gethostname", return_value="10.10.10.2"):
            lsf_env = LSFEnvironment()

            assert lsf_env.creates_processes_externally
            assert lsf_env.main_address == "10.10.10.0"
            assert lsf_env.main_port == 10234
            assert lsf_env.world_size() == 4
            assert lsf_env.global_rank() == 3
            assert lsf_env.local_rank() == 1

            # The setters are expected to leave the reported values untouched.
            lsf_env.set_global_rank(100)
            assert lsf_env.global_rank() == 3
            lsf_env.set_world_size(100)
            assert lsf_env.world_size() == 4

            assert LSFEnvironment.detect()
def test_empty_lsb_djob_rankfile():
    """Expect a ``ValueError`` reporting that ``LSB_DJOB_RANKFILE`` is empty."""
    expected_message = "The environment variable `LSB_DJOB_RANKFILE` is empty"
    with pytest.raises(ValueError, match=expected_message):
        LSFEnvironment()
def test_missing_lsb_djob_rankfile():
    """Expect a ``ValueError`` reporting that ``LSB_DJOB_RANKFILE`` was not found."""
    expected_message = "Did not find the environment variable `LSB_DJOB_RANKFILE`"
    with pytest.raises(ValueError, match=expected_message):
        LSFEnvironment()
def test_node_rank(_):
    """Check that a freshly constructed LSF environment reports node rank 2."""
    # The single positional argument is accepted but never used.
    node_rank = LSFEnvironment().node_rank()
    assert node_rank == 2
def test_attributes_from_environment_variables():
    """Check the LSF environment's attributes (``master_*`` method API)."""
    lsf_env = LSFEnvironment()

    assert lsf_env.creates_children()
    assert lsf_env.master_address() == "10.10.10.0"
    assert lsf_env.master_port() == 10234
    assert lsf_env.world_size() == 4
    assert lsf_env.global_rank() == 3
    assert lsf_env.local_rank() == 1

    # The setters are expected to leave the reported values untouched.
    lsf_env.set_global_rank(100)
    assert lsf_env.global_rank() == 3
    lsf_env.set_world_size(100)
    assert lsf_env.world_size() == 4

    assert LSFEnvironment.is_using_lsf()
def test_manual_master_port_and_address():
    """Check that the environment reports master port ``4321``.

    NOTE(review): presumably ``MASTER_PORT`` is supplied by an environment
    patch outside this function -- confirm.
    """
    reported_port = LSFEnvironment().master_port()
    assert reported_port == 4321
def test_attributes_from_environment_variables():
    """Check the LSF environment's attributes (``main_*`` property API)."""
    lsf_env = LSFEnvironment()

    assert lsf_env.creates_processes_externally
    assert lsf_env.main_address == "10.10.10.0"
    assert lsf_env.main_port == 10234
    assert lsf_env.world_size() == 4
    assert lsf_env.global_rank() == 3
    assert lsf_env.local_rank() == 1

    # The setters are expected to leave the reported values untouched.
    lsf_env.set_global_rank(100)
    assert lsf_env.global_rank() == 3
    lsf_env.set_world_size(100)
    assert lsf_env.world_size() == 4

    assert LSFEnvironment.detect()
def process_args(args=None, return_io=False):
    """
    Process arguments for running training

    Args:
        args: an ``argparse.Namespace``, or anything else, which is handed to
            ``parse_args`` to build one.
        return_io: when true, ``io`` is inserted into the returned tuple
            before the data module.

    Returns:
        A tuple ``(model, args, targs, data_mod)`` -- or
        ``(model, args, targs, io, data_mod)`` when ``return_io`` is true --
        where ``targs`` is the dict of trainer keyword arguments built here.

    Raises:
        ValueError: if more than one GPU is requested without ``--lsf`` or
            ``--slurm``, or if manifold loss is combined with a features
            checkpoint.

    Side effects:
        Mutates ``args`` (adds ``loader_kwargs``, ``input_nc`` and possibly
        ``n_outputs``/``classify``; deletes ``gpus`` and ``lr_find``) and may
        rebind the module-level ``RANK`` and ``SIZE``.
    """
    # Declared up front; these are only rebound when a cluster environment is
    # selected below, but they are read further down in every case.
    global RANK
    global SIZE

    if not isinstance(args, argparse.Namespace):
        args = parse_args(args)

    args.loader_kwargs = dict()

    targs = dict(
        max_epochs=args.epochs,
    )

    targs['accumulate_grad_batches'] = args.accumulate

    env = None

    if args.ipu:
        targs['accelerator'] = 'ipu'
        targs['devices'] = process_gpus(args.gpus)
    else:
        # Everything in this branch relies on targs['gpus'], which only exists
        # when IPUs are not in use.
        targs['gpus'] = process_gpus(args.gpus)
        targs['num_nodes'] = args.num_nodes
        if args.lsf:
            ##########################################################################################
            # Currently coding against pytorch-lightning 1.4.3
            ##########################################################################################
            if args.num_workers > 4:
                print0(
                    "num_workers (-k) > 4 can lead to hanging on Summit -- setting to 4",
                    file=sys.stderr)
                args.num_workers = 4
            args.loader_kwargs['num_workers'] = 1  # Set as a default. This will get overridden elsewhere
            args.loader_kwargs['multiprocessing_context'] = 'spawn'
            env = LSFEnvironment()
        elif args.slurm:
            env = SLURMEnvironment()

        if env is not None:
            try:
                RANK = env.global_rank()
                SIZE = env.world_size()
            except Exception:
                # Best effort: fall back to a single-process layout, but let
                # KeyboardInterrupt/SystemExit propagate.
                print(
                    ">>> Could not get global rank -- setting RANK to 0 and SIZE to 1",
                    file=sys.stderr)
                RANK = 0
                SIZE = 1

        if targs['gpus'] is not None:
            targs['accelerator'] = 'gpu'
            if targs['gpus'] == 1:
                targs['devices'] = 1
            else:
                if env is None:
                    raise ValueError(
                        'Please specify environment (--lsf or --slurm) if using more than one GPU'
                    )
                # parallel_devices = [torch.device(i) for i in range(torch.cuda.device_count()) if i < targs['gpus']]
                # precision_plugin = NativeMixedPrecisionPlugin(16, 'cuda')
                torch.cuda.set_device(env.local_rank())
                targs['devices'] = targs['gpus']
                targs['strategy'] = DDPStrategy(
                    find_unused_parameters=False,
                    cluster_environment=env,
                    #accelerator=GPUAccelerator(),
                    #parallel_devices=parallel_devices,
                    #precision_plugin=precision_plugin,
                )
                print(
                    "---- Rank %s - Using GPUAccelerator with DDPStrategy" %
                    env.global_rank(),
                    file=sys.stderr)
        else:
            targs['accelerator'] = 'cpu'

    del args.gpus

    if args.sanity:
        # A string value is taken as an explicit batch limit; any other truthy
        # value selects the default of 4000.
        if isinstance(args.sanity, str):
            args.sanity = int(args.sanity)
        else:
            args.sanity = 4000
        targs['limit_train_batches'] = args.sanity
        targs['limit_val_batches'] = args.sanity // 4

    if args.lr_find:
        targs['auto_lr_find'] = True
    del args.lr_find

    if args.checkpoint is not None:
        if os.path.exists(args.checkpoint):
            targs['resume_from_checkpoint'] = args.checkpoint
        else:
            # f-string so the offending path actually appears in the warning.
            warnings.warn(
                f"Ignoring -c/--checkpoint argument because {args.checkpoint} does not exist."
            )
            args.checkpoint = None

    if args.cuda_profile:
        # The rank is zero-padded to the width of SIZE in the file name.
        targs['profiler'] = PyTorchProfiler(
            filename=f'pytorch_prof.{RANK:0{len(str(SIZE))}}',
            emit_nvtx=True)

    targs['replace_sampler_ddp'] = False

    # NOTE(review): this discards the loader_kwargs entries set in the --lsf
    # branch above (num_workers, multiprocessing_context) -- confirm intended.
    args.loader_kwargs = dict()

    # make sure we are classifying if we are using adding classifier layers
    # to a resnet features model
    if args.features_checkpoint is not None:
        if args.manifold:
            raise ValueError(
                'Cannot use manifold loss (i.e. -M) if adding classifier (i.e. -F)'
            )
        args.classify = True

    data_mod = DeepIndexDataModule(args,
                                   keep_open=True,
                                   seed=args.seed + RANK,
                                   rank=RANK,
                                   size=SIZE)

    # if classification problem, use the number of taxa as the number of outputs
    if args.classify:
        args.n_outputs = data_mod.dataset.n_outputs

    args.input_nc = 136 if args.tnf else len(data_mod.dataset.vocab)

    model = process_model(args, taxa_table=data_mod.dataset.difile.taxa_table)

    if args.num_workers > 0:
        data_mod.dataset.close()

    ret = [model, args, targs]
    if return_io:
        # NOTE(review): `io` is not bound anywhere in this function; it
        # resolves to whatever module-level `io` exists -- confirm intended.
        ret.append(io)

    ret.append(data_mod)

    return tuple(ret)