def test_default_attributes():
    """Check SLURMEnvironment defaults when no SLURM variables are present.

    Address and port fall back to fixed values, the world size is reported
    as None, and both ranks raise KeyError because they can only come from
    environment variables.
    """
    slurm_env = SLURMEnvironment()

    # Fallback values used when nothing has been exported.
    assert slurm_env.creates_children()
    assert slurm_env.master_address() == "127.0.0.1"
    assert slurm_env.master_port() == 12910
    assert slurm_env.world_size() is None

    # local rank and node_rank are required to be passed as env variables.
    for required_rank in (slurm_env.local_rank, slurm_env.node_rank):
        with pytest.raises(KeyError):
            required_rank()
def test_default_attributes():
    """Check SLURMEnvironment defaults when no SLURM variables are exported.

    The environment reports that processes are created externally, the main
    address/port use fixed fallbacks, and world size, local rank and node
    rank all raise KeyError without their environment variables.
    """
    slurm_env = SLURMEnvironment()

    assert slurm_env.creates_processes_externally
    assert slurm_env.main_address == "127.0.0.1"
    assert slurm_env.main_port == 12910

    # world size, local rank and node_rank are required to be passed as env variables
    required_values = (slurm_env.world_size, slurm_env.local_rank, slurm_env.node_rank)
    for read_required in required_values:
        with pytest.raises(KeyError):
            read_required()
def test_attributes_from_environment_variables(caplog):
    """Check that SLURMEnvironment reads its attributes from environment variables.

    Also verifies that ``set_global_rank``/``set_world_size`` leave the values
    unchanged and log a message saying that setting is not allowed.
    """
    slurm_env = SLURMEnvironment()

    # Property-style attributes.
    assert slurm_env.auto_requeue is True
    assert slurm_env.main_address == "1.1.1.1"
    assert slurm_env.main_port == 15000 + 1234

    # Method-style attributes, checked in a fixed order.
    expected_by_getter = [
        (slurm_env.job_id, int("0001234")),
        (slurm_env.world_size, 20),
        (slurm_env.global_rank, 1),
        (slurm_env.local_rank, 2),
        (slurm_env.node_rank, 3),
        (slurm_env.job_name, "JOB"),
    ]
    for getter, expected in expected_by_getter:
        assert getter() == expected

    # setter should be no-op
    no_op_setters = [
        (slurm_env.set_global_rank, slurm_env.global_rank, 1, "setting global rank is not allowed"),
        (slurm_env.set_world_size, slurm_env.world_size, 20, "setting world size is not allowed"),
    ]
    for position, (setter, getter, unchanged, message) in enumerate(no_op_setters):
        if position > 0:
            # Drop earlier records so each message check only sees its own setter.
            caplog.clear()
        with caplog.at_level(logging.DEBUG, logger="pytorch_lightning.plugins.environments"):
            setter(100)
        assert getter() == unchanged
        assert message in caplog.text
def test_attributes_from_environment_variables():
    """Check that SLURMEnvironment takes address, port and ranks from the environment.

    With this API the world size is still reported as None.
    """
    expected_address = "1.1.1.1"
    expected_port = 15000 + 1234

    slurm_env = SLURMEnvironment()

    assert slurm_env.master_address() == expected_address
    assert slurm_env.master_port() == expected_port
    assert slurm_env.world_size() is None
    assert slurm_env.local_rank() == 2
    assert slurm_env.node_rank() == 3
def _setup_logger(args):
    """Return ``args.logger``, or a root logger at INFO (DEBUG if ``args.debug``) when it is None."""
    logger = args.logger
    if logger is None:
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        if args.debug:
            logger.setLevel(logging.DEBUG)
    return logger


def _get_rank_info(args):
    """Return ``(env, rank, size, local_rank)`` from the SLURM/LSF environment, or single-process defaults."""
    env = None
    if args.slurm:
        env = SLURMEnvironment()
    elif args.lsf:
        env = LSFEnvironment()
    if env is None:
        return None, 0, 1, 0
    local_rank = env.local_rank()
    rank = env.global_rank()
    size = env.world_size()
    return env, rank, size, local_rank


def _resolve_output_path(args):
    """Fill in ``args.checkpoint`` and ``args.output`` when no output path was given.

    If ``args.checkpoint`` is a directory it must contain exactly one ``*.ckpt``
    file, otherwise an error is printed to stderr and the process exits with
    status 1. Outputs go to ``<checkpoint path minus .ckpt>/outputs.h5``.
    """
    if args.output is not None:
        return
    if os.path.isdir(args.checkpoint):
        ckpt = glob.glob(f"{args.checkpoint}/*.ckpt")
        if not ckpt:
            print(f'No checkpoint file found in {args.checkpoint}', file=sys.stderr)
            sys.exit(1)
        elif len(ckpt) > 1:
            print(
                f'More than one checkpoint file found in {args.checkpoint}. '
                'Please specify checkpoint with -c', file=sys.stderr)
            sys.exit(1)
        args.checkpoint = ckpt[0]
    outdir = args.checkpoint
    if outdir.endswith('.ckpt'):
        outdir = outdir[:-5]
    # Several ranks may reach this point at once; exist_ok avoids the
    # isdir/mkdir race where a losing rank would raise FileExistsError.
    os.makedirs(outdir, exist_ok=True)
    args.output = os.path.join(outdir, 'outputs.h5')


def process_args(args, comm=None):
    """Process arguments for running inference.

    Steps:
      - fill missing options from ``args.config``;
      - set up logging;
      - detect SLURM/LSF rank information;
      - resolve the checkpoint and output paths;
      - load the model in inference mode;
      - build the dataset and data loader.

    Args:
        args: parsed command-line namespace; mutated in place (``logger``,
            ``checkpoint``, ``output``, ``classify``, ``n_outputs``,
            ``save_seq_ids``, ``difile``, ``device``, and possibly ``features``).
        comm: communicator forwarded to ``LazySeqDataset`` when running with
            more than one process.

    Returns:
        tuple: ``(model, dataset, loader, args, env)`` where ``env`` is the
        SLURM/LSF environment object, or None for a single-process run.
    """
    conf_args = process_config(args.config)
    # Config values only fill in options not already set on the command line.
    for k, v in vars(conf_args).items():
        if not hasattr(args, k):
            setattr(args, k, v)

    args.logger = _setup_logger(args)
    env, rank, size, local_rank = _get_rank_info(args)

    # Figure out the checkpoint file to read from and where to save outputs to.
    _resolve_output_path(args)

    # Set classify to True so that we get labels when we load data. We do
    # this here because we assume the network will output features, and we
    # want to use the labels for downstream analysis.
    args.classify = True

    # Load the model and override the batch size.
    model = process_model(args, inference=True)
    model.set_inference(True)
    if args.batch_size is not None:
        model.hparams.batch_size = args.batch_size
    args.n_outputs = model.hparams.n_outputs
    args.save_seq_ids = model.hparams.window is not None

    # Remove the ResNet classification head and output its features instead.
    if args.resnet_features:
        if 'ResNet' not in model.__class__.__name__:
            raise ValueError("Cannot use -f without ResNet model - got %s" % model.__class__.__name__)
        from .models.resnet import ResNetFeatures
        args.n_outputs = model.fc.in_features if isinstance(model.fc, nn.Linear) else model.fc[0].in_features
        model = ResNetFeatures(model)
        args.features = True

    hparams = argparse.Namespace(**model.hparams)
    if size > 1:
        dataset = LazySeqDataset(path=args.input, hparams=hparams, keep_open=True,
                                 comm=comm, size=size, rank=rank)
    else:
        dataset = LazySeqDataset(path=args.input, hparams=hparams, keep_open=True)

    tot_bases = dataset.orig_difile.get_seq_lengths().sum()
    args.logger.info(
        f'rank {rank} - processing {tot_bases} bases across {len(dataset)} samples'
    )

    # NOTE(review): when args.batch_size is None this forwards batch_size=None
    # to get_loader; confirm get_loader falls back to the model's batch size.
    kwargs = dict(batch_size=args.batch_size, shuffle=False)
    if args.num_workers > 0:
        kwargs['num_workers'] = args.num_workers
        kwargs['multiprocessing_context'] = 'spawn'
        kwargs['worker_init_fn'] = dataset.worker_init
        kwargs['persistent_workers'] = True
    loader = get_loader(dataset, inference=True, **kwargs)

    args.difile = dataset.difile

    if not args.features:
        model = nn.Sequential(model, nn.Softmax(dim=1))
    model.eval()

    # Multi-process runs pin each process to the GPU matching its local rank.
    if size > 1:
        args.device = torch.device(f'cuda:{local_rank}')
    else:
        args.device = torch.device('cuda')

    return model, dataset, loader, args, env