def test_attributes_from_environment_variables(caplog):
    """Verify SLURMEnvironment picks up its attributes from the SLURM environment variables.

    Also checks that the global-rank and world-size setters leave the current
    values untouched and log a debug message saying the change is not allowed.
    """
    cluster_env = SLURMEnvironment()
    assert cluster_env.auto_requeue is True
    assert cluster_env.main_address == "1.1.1.1"
    assert cluster_env.main_port == 15000 + 1234
    # job id string with leading zeros; presumably mirrors the fixture's SLURM_JOB_ID value
    assert cluster_env.job_id() == int("0001234")
    assert cluster_env.world_size() == 20
    assert cluster_env.global_rank() == 1
    assert cluster_env.local_rank() == 2
    assert cluster_env.node_rank() == 3
    assert cluster_env.job_name() == "JOB"

    # Each setter must be a no-op that only emits a debug log line.
    log_target = "pytorch_lightning.plugins.environments"
    setter_cases = (
        (cluster_env.set_global_rank, cluster_env.global_rank, 1,
         "setting global rank is not allowed"),
        (cluster_env.set_world_size, cluster_env.world_size, 20,
         "setting world size is not allowed"),
    )
    for position, (setter, getter, unchanged, message) in enumerate(setter_cases):
        if position:
            # drop the previous case's records so each message is checked in isolation
            caplog.clear()
        with caplog.at_level(logging.DEBUG, logger=log_target):
            setter(100)
        assert getter() == unchanged
        assert message in caplog.text
# Example #2
# 0
def test_default_attributes():
    """Check SLURMEnvironment defaults when no SLURM environment variables are set.

    Address and port fall back to fixed values, while world size, local rank
    and node rank have no fallback and must raise ``KeyError``.
    """
    cluster_env = SLURMEnvironment()
    assert cluster_env.creates_processes_externally
    assert cluster_env.main_address == "127.0.0.1"
    assert cluster_env.main_port == 12910

    # These values can only be supplied through environment variables.
    env_only_accessors = (
        cluster_env.world_size,
        cluster_env.local_rank,
        cluster_env.node_rank,
    )
    for accessor in env_only_accessors:
        with pytest.raises(KeyError):
            accessor()
# Example #3
# 0
def test_attributes_from_environment_variables():
    """Verify SLURMEnvironment reads address, port and ranks from the environment (method-style ``master_*`` API)."""
    cluster_env = SLURMEnvironment()
    expected_address = "1.1.1.1"
    expected_port = 15000 + 1234

    assert cluster_env.master_address() == expected_address
    assert cluster_env.master_port() == expected_port
    # world size is not reported by this environment: it returns None
    assert cluster_env.world_size() is None
    assert cluster_env.local_rank() == 2
    assert cluster_env.node_rank() == 3
# Example #4
# 0
def test_default_attributes():
    """Check SLURMEnvironment defaults (method-style ``master_*`` API) with no SLURM variables set.

    Address and port use fixed fallbacks and world size is ``None``; local and
    node rank have no fallback and must raise ``KeyError``.
    """
    cluster_env = SLURMEnvironment()
    assert cluster_env.creates_children()
    assert cluster_env.master_address() == "127.0.0.1"
    assert cluster_env.master_port() == 12910
    assert cluster_env.world_size() is None

    # Ranks can only be supplied through environment variables.
    for accessor in (cluster_env.local_rank, cluster_env.node_rank):
        with pytest.raises(KeyError):
            accessor()
# Example #5
# 0
def _resolve_output_path(args):
    """Fill in ``args.output`` (and possibly ``args.checkpoint``) when no output was given.

    If ``args.checkpoint`` is a directory, it must contain exactly one ``*.ckpt``
    file, which replaces ``args.checkpoint``; otherwise an error is printed to
    stderr and the process exits with status 1. The output file is
    ``outputs.h5`` inside a directory named after the checkpoint with any
    ``.ckpt`` suffix stripped; that directory is created if missing.
    """
    if args.output is not None:
        return

    if os.path.isdir(args.checkpoint):
        ckpt = glob.glob(f"{args.checkpoint}/*.ckpt")
        if len(ckpt) == 0:
            print(f'No checkpoint file found in {args.checkpoint}',
                  file=sys.stderr)
            sys.exit(1)
        elif len(ckpt) > 1:
            print(
                f'More than one checkpoint file found in {args.checkpoint}. '
                'Please specify checkpoint with -c',
                file=sys.stderr)
            sys.exit(1)
        args.checkpoint = ckpt[0]

    outdir = args.checkpoint
    if outdir.endswith('.ckpt'):
        outdir = outdir[:-len('.ckpt')]
    # Every rank runs this function, so a check-then-mkdir would race and raise
    # FileExistsError on all but one rank; exist_ok makes creation idempotent.
    os.makedirs(outdir, exist_ok=True)
    args.output = os.path.join(outdir, 'outputs.h5')


def process_args(args, comm=None):
    """Process arguments for running inference.

    Merges settings from ``args.config`` into ``args`` without overriding
    attributes that are already set. It then configures logging, detects a
    SLURM/LSF cluster environment, resolves the checkpoint/output paths, loads
    the model, and builds the inference dataset and data loader.

    Args:
        args: parsed command-line namespace. It is read and mutated in place
            (``logger``, ``checkpoint``, ``output``, ``classify``,
            ``n_outputs``, ``save_seq_ids``, ``features``, ``difile``,
            ``device``).
        comm: forwarded to ``LazySeqDataset`` only when the world size is
            greater than 1 (presumably an MPI communicator; confirm with
            callers).

    Returns:
        tuple ``(model, dataset, loader, args, env)``, where ``env`` is the
        detected cluster environment or ``None`` when neither ``args.slurm``
        nor ``args.lsf`` is set.

    Raises:
        ValueError: if ``args.resnet_features`` is set but the model is not a
            ResNet.
        SystemExit: if ``args.output`` is unset and ``args.checkpoint`` is a
            directory that does not contain exactly one ``.ckpt`` file.
    """
    # Config-file values only fill in attributes the command line did not set.
    conf_args = process_config(args.config)
    for key, value in vars(conf_args).items():
        if not hasattr(args, key):
            setattr(args, key, value)

    # set up logger
    logger = args.logger
    if logger is None:
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
    if args.debug:
        logger.setLevel(logging.DEBUG)
    args.logger = logger

    # Single-process defaults; overridden when running under a cluster scheduler.
    rank, size, local_rank = 0, 1, 0
    env = None
    if args.slurm:
        env = SLURMEnvironment()
    elif args.lsf:
        env = LSFEnvironment()
    if env is not None:
        local_rank = env.local_rank()
        rank = env.global_rank()
        size = env.world_size()

    # Figure out the checkpoint file to read from and where to save outputs to.
    _resolve_output_path(args)

    # setting classify so that we can get labels when we load data. We do this
    # here because we assume that the network is going to output features, and
    # we want to use the labels for downstream analysis
    args.classify = True

    # load the model and override batch size
    model = process_model(args, inference=True)
    model.set_inference(True)
    if args.batch_size is not None:
        model.hparams.batch_size = args.batch_size

    args.n_outputs = model.hparams.n_outputs
    args.save_seq_ids = model.hparams.window is not None

    # Output the inputs of the final ``fc`` layer (ResNet features) instead of
    # the classifier outputs.
    if args.resnet_features:
        model_name = model.__class__.__name__
        if 'ResNet' not in model_name:
            raise ValueError(f"Cannot use -f without ResNet model - got {model_name}")
        from .models.resnet import ResNetFeatures
        fc = model.fc
        args.n_outputs = fc.in_features if isinstance(fc, nn.Linear) else fc[0].in_features
        model = ResNetFeatures(model)
        args.features = True

    dataset_kwargs = dict(path=args.input,
                          hparams=argparse.Namespace(**model.hparams),
                          keep_open=True)
    if size > 1:
        # presumably lets the dataset partition samples across ranks; confirm in LazySeqDataset
        dataset_kwargs.update(comm=comm, size=size, rank=rank)
    dataset = LazySeqDataset(**dataset_kwargs)

    tot_bases = dataset.orig_difile.get_seq_lengths().sum()
    args.logger.info(
        f'rank {rank} - processing {tot_bases} bases across {len(dataset)} samples'
    )

    # NOTE(review): when args.batch_size is None this forwards batch_size=None;
    # if get_loader passes it straight to torch's DataLoader, that disables
    # automatic batching. Confirm get_loader falls back to model.hparams.batch_size.
    loader_kwargs = dict(batch_size=args.batch_size, shuffle=False)
    if args.num_workers > 0:
        loader_kwargs.update(num_workers=args.num_workers,
                             multiprocessing_context='spawn',
                             worker_init_fn=dataset.worker_init,
                             persistent_workers=True)
    loader = get_loader(dataset, inference=True, **loader_kwargs)

    args.difile = dataset.difile

    # Return the model, any arguments, and Lightning Trainer args just in case we
    # want to use them down the line when we figure out how to use Lightning for
    # inference.
    if not args.features:
        model = nn.Sequential(model, nn.Softmax(dim=1))
    model.eval()

    # With multiple processes each one uses the GPU matching its local rank.
    args.device = torch.device(f'cuda:{local_rank}' if size > 1 else 'cuda')

    return model, dataset, loader, args, env