Example #1
def test_traceenum_elbo(length):
    hidden_dim = 10
    transition = pyro.param("transition",
                            0.3 / hidden_dim + 0.7 * torch.eye(hidden_dim),
                            constraint=constraints.positive)
    means = pyro.param("means", torch.arange(float(hidden_dim)))
    data = 1 + 2 * torch.randn(length)

    @ignore_jit_warnings()
    def model(data):
        transition = pyro.param("transition")
        means = pyro.param("means")
        states = [torch.tensor(0)]
        for t in pyro.markov(range(len(data))):
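            # The discrete state is sampled with infer={"enumerate": "parallel"},
            # so TraceEnum_ELBO sums it out exactly rather than sampling it.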
            states.append(pyro.sample("states_{}".format(t),
                                      dist.Categorical(transition[states[-1]]),
                                      infer={"enumerate": "parallel"}))
            pyro.sample("obs_{}".format(t),
                        dist.Normal(means[states[-1]], 1.),
                        obs=data[t])
        return tuple(states)

    def guide(data):
        pass

    expected_loss = TraceEnum_ELBO(max_plate_nesting=0).differentiable_loss(model, guide, data)
    actual_loss = JitTraceEnum_ELBO(max_plate_nesting=0).differentiable_loss(model, guide, data)
    assert_equal(expected_loss, actual_loss)

    expected_grads = grad(expected_loss, [transition, means], allow_unused=True)
    actual_grads = grad(actual_loss, [transition, means], allow_unused=True)
    for e, a, name in zip(expected_grads, actual_grads, ["transition", "means"]):
        assert_equal(e, a, msg="bad gradient for {}".format(name))
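
This test-style snippet (it expects a length argument, presumably supplied by a test parametrization) builds a small enumerated HMM and checks that TraceEnum_ELBO and its JIT-compiled counterpart agree on both the loss and the gradients with respect to the parameters. It is not self-contained; a sketch of the imports it appears to assume (assert_equal is guessed to be a tensor-comparison test helper from the surrounding test suite):

import torch
from torch.autograd import grad
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
from pyro.infer import TraceEnum_ELBO, JitTraceEnum_ELBO
from pyro.util import ignore_jit_warnings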
Example #2
def make_svi(model,
             guide,
             args=None,
             kwargs=None,
             steps=1000,
             lr=0.05,
             cut_time=slice(None, None),
             max_steps=2000,
             ensure_convergence=False,
             loss='ELBO'):
    adam_params = {
        "lr": lr,
        "betas": (0.90, 0.999),
        'weight_decay': 0.005,
        'clip_norm': 10
    }
    optimizer = ClippedAdam(adam_params)

    #     svi = SVI(model, guide, optimizer, loss=trace_mle(cut_time))
    if loss == 'ELBO':
        svi = SVI(model, guide, optimizer, loss=JitTraceEnum_ELBO())
    elif loss == 'MLE':
        svi = SVI(model, guide, optimizer, loss=trace_mle())
    else:
        raise ValueError("loss must be 'ELBO' or 'MLE'")

    pbar = tqdm(range(1, steps + 1))
    time_start = 0

    loss_arr = []

    for i in pbar:
        loss, time_start = make_step(svi, pbar, time_start, args, kwargs)
        loss_arr.append(loss)

    while ensure_convergence:
        std_prev = np.std(loss_arr[-20:-1])
        mean_cur = np.mean(loss_arr[-100:])
        mean_prev = np.mean(loss_arr[-200:-100])
        prob = stat.norm(mean_prev, std_prev).cdf(mean_cur)
        #         print(prob, mean_cur, mean_prev, std_prev)
        if mean_cur < mean_prev and prob < 0.05 and len(loss_arr) < max_steps:
            pbar = tqdm(range(1, 100 + 1), leave=False)
            for j in pbar:
                loss, time_start = make_step(svi,
                                             pbar,
                                             time_start,
                                             args,
                                             kwargs,
                                             prefix='Extra: ')
                loss_arr.append(loss)
        else:
            break

    return loss
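
make_step and trace_mle are helpers defined elsewhere in the original module, which also appears to assume numpy as np, scipy.stats as stat, tqdm, and pyro.optim.ClippedAdam. A purely hypothetical call, just to illustrate the intended usage (model, guide, and data are placeholders):

final_loss = make_svi(model, guide,
                      args=(data,), kwargs={},
                      steps=500,
                      ensure_convergence=True,
                      loss='ELBO')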
Example #3
def svi(data, ratings, model, guide, epoch, model_type, if_save=True,
        if_print=True, num_sample=200):
    elbo = JitTraceEnum_ELBO(max_plate_nesting=1)
    svi_model = SVI(model,
                    guide,
                    optim.Adam({"lr": .005}),
                    loss=elbo,
                    num_samples=num_sample)

    pyro.clear_param_store()
    loss_list = []
    for i in range(epoch):
        ELBO = svi_model.step(data, ratings)
        loss_list.append(ELBO)
        if i % 500 == 0 and if_print:
            print(ELBO)
    if if_save:
        to_pickle(loss_list, "data_pickle/{}_svi_loss".format(model_type))
    return svi_model, loss_list
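
A hypothetical invocation (data, ratings, model, and guide are placeholders; to_pickle is a helper from the original module):

svi_model, loss_list = svi(data, ratings, model, guide,
                           epoch=5000, model_type='my_model', if_save=False)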
Example #4
def main(_argv):
    transition_alphas = torch.tensor([[10., 90.],
                                      [90., 10.]])
    emission_alphas = torch.tensor([[[30., 20., 5.]],
                                    [[5., 10., 100.]]])
    lengths = torch.randint(10, 30, (10000,))
    trace = poutine.trace(model).get_trace(transition_alphas, emission_alphas, lengths)
    obs_sequences = [site['value'] for name, site in
                     trace.nodes.items() if name.startswith("element_")]
    obs_sequences = torch.stack(obs_sequences, dim=-2)
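    # MAP (AutoDelta) guide over the continuous sites only; poutine.block hides the
    # discrete "state..." sites so that JitTraceEnum_ELBO can enumerate them instead.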
    guide = AutoDelta(poutine.block(model, hide_fn=lambda site: site['name'].startswith('state')),
                      init_loc_fn=init_to_sample)
    svi = SVI(model, guide, Adam(dict(lr=0.1)), JitTraceEnum_ELBO())
    total = 1000
    with tqdm.trange(total) as t:
        for i in t:
            loss = svi.step(0.5 * torch.ones((2, 2), dtype=torch.float),
                            0.3 * torch.ones((2, 1, 3), dtype=torch.float),
                            lengths, obs_sequences)
            t.set_description_str(f"SVI ({i}/{total}): {loss}")
    median = guide.median()
    print("Transition probs: ", median['transition_probs'].detach().numpy())
    print("Emission probs: ", median['emission_probs'].squeeze().detach().numpy())
Example #5
def run_inference(dataset_obj: SingleCellRNACountsDataset,
                  args) -> RemoveBackgroundPyroModel:
    """Run a full inference procedure, training a latent variable model.

    Args:
        dataset_obj: Input data in the form of a SingleCellRNACountsDataset
            object.
        args: Input command line parsed arguments.

    Returns:
        model: cellbender.model.RemoveBackgroundPyroModel that has had
            inference run.

    """

    # Get the trimmed count matrix (transformed if called for).
    count_matrix = dataset_obj.get_count_matrix()

    # Configure pyro options (skip validations to improve speed).
    pyro.enable_validation(False)
    pyro.distributions.enable_validation(False)
    pyro.set_rng_seed(0)
    pyro.clear_param_store()

    # Set up the variational autoencoder:

    # Encoder.
    encoder_z = EncodeZ(input_dim=count_matrix.shape[1],
                        hidden_dims=args.z_hidden_dims,
                        output_dim=args.z_dim,
                        input_transform='normalize')

    encoder_other = EncodeNonZLatents(
        n_genes=count_matrix.shape[1],
        z_dim=args.z_dim,
        hidden_dims=consts.ENC_HIDDEN_DIMS,
        log_count_crossover=dataset_obj.priors['log_counts_crossover'],
        prior_log_cell_counts=np.log1p(dataset_obj.priors['cell_counts']),
        input_transform='normalize')

    encoder = CompositeEncoder({'z': encoder_z, 'other': encoder_other})

    # Decoder.
    decoder = Decoder(input_dim=args.z_dim,
                      hidden_dims=args.z_hidden_dims[::-1],
                      output_dim=count_matrix.shape[1])

    # Set up the pyro model for variational inference.
    model = RemoveBackgroundPyroModel(model_type=args.model,
                                      encoder=encoder,
                                      decoder=decoder,
                                      dataset_obj=dataset_obj,
                                      use_cuda=args.use_cuda)

    # Load the dataset into DataLoaders.
    frac = args.training_fraction  # Fraction of barcodes to use for training
    batch_size = int(
        min(300, frac * dataset_obj.analyzed_barcode_inds.size / 2))
    train_loader, test_loader = \
        prep_data_for_training(dataset=count_matrix,
                               empty_drop_dataset=dataset_obj.get_count_matrix_empties(),
                               random_state=dataset_obj.random,
                               batch_size=batch_size,
                               training_fraction=frac,
                               fraction_empties=args.fraction_empties,
                               shuffle=True,
                               use_cuda=args.use_cuda)

    # Set up the optimizer.
    optimizer = pyro.optim.clipped_adam.ClippedAdam
    optimizer_args = {'lr': args.learning_rate, 'clip_norm': 10.}

    # Set up a learning rate scheduler.
    minibatches_per_epoch = int(
        np.ceil(len(train_loader) / train_loader.batch_size).item())
    scheduler_args = {
        'optimizer': optimizer,
        'max_lr': args.learning_rate * 10,
        'steps_per_epoch': minibatches_per_epoch,
        'epochs': args.epochs,
        'optim_args': optimizer_args
    }
    scheduler = pyro.optim.OneCycleLR(scheduler_args)
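    # Note: pyro's LR-scheduler wrappers are not stepped automatically;
    # scheduler.step() is presumably called inside run_training().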

    # Determine the loss function.
    if args.use_jit:

        # Call guide() once as a warm-up.
        model.guide(
            torch.zeros([10, dataset_obj.analyzed_gene_inds.size
                         ]).to(model.device))

        if args.model == "simple":
            loss_function = JitTrace_ELBO()
        else:
            loss_function = JitTraceEnum_ELBO(max_plate_nesting=1,
                                              strict_enumeration_warning=False)
    else:

        if args.model == "simple":
            loss_function = Trace_ELBO()
        else:
            loss_function = TraceEnum_ELBO(max_plate_nesting=1)

    # Set up the inference process.
    svi = SVI(model.model, model.guide, scheduler, loss=loss_function)

    # Run training.
    run_training(model,
                 svi,
                 train_loader,
                 test_loader,
                 epochs=args.epochs,
                 test_freq=5)

    return model
Example #6
def main(args):
    # setup logging
    log = get_logger(args.log)
    log(args)

    data = poly.load_data(poly.JSB_CHORALES)
    training_seq_lengths = data['train']['sequence_lengths']
    training_data_sequences = data['train']['sequences']
    test_seq_lengths = data['test']['sequence_lengths']
    test_data_sequences = data['test']['sequences']
    val_seq_lengths = data['valid']['sequence_lengths']
    val_data_sequences = data['valid']['sequences']
    N_train_data = len(training_seq_lengths)
    N_train_time_slices = float(torch.sum(training_seq_lengths))
    N_mini_batches = int(N_train_data / args.mini_batch_size +
                         int(N_train_data % args.mini_batch_size > 0))

    log("N_train_data: %d     avg. training seq. length: %.2f    N_mini_batches: %d" %
        (N_train_data, training_seq_lengths.float().mean(), N_mini_batches))

    # how often we do validation/test evaluation during training
    val_test_frequency = 50
    # the number of samples we use to do the evaluation
    n_eval_samples = 1

    # package repeated copies of val/test data for faster evaluation
    # (i.e. set us up for vectorization)
    def rep(x):
        rep_shape = torch.Size([x.size(0) * n_eval_samples]) + x.size()[1:]
        repeat_dims = [1] * len(x.size())
        repeat_dims[0] = n_eval_samples
        return x.repeat(repeat_dims).reshape(n_eval_samples, -1).transpose(1, 0).reshape(rep_shape)

    # get the validation/test data ready for the dmm: pack into sequences, etc.
    val_seq_lengths = rep(val_seq_lengths)
    test_seq_lengths = rep(test_seq_lengths)
    val_batch, val_batch_reversed, val_batch_mask, val_seq_lengths = poly.get_mini_batch(
        torch.arange(n_eval_samples * val_data_sequences.shape[0]), rep(val_data_sequences),
        val_seq_lengths, cuda=args.cuda)
    test_batch, test_batch_reversed, test_batch_mask, test_seq_lengths = poly.get_mini_batch(
        torch.arange(n_eval_samples * test_data_sequences.shape[0]), rep(test_data_sequences),
        test_seq_lengths, cuda=args.cuda)

    # instantiate the dmm
    dmm = DMM(rnn_dropout_rate=args.rnn_dropout_rate, num_iafs=args.num_iafs,
              iaf_dim=args.iaf_dim, use_cuda=args.cuda)

    # setup optimizer
    adam_params = {"lr": args.learning_rate, "betas": (args.beta1, args.beta2),
                   "clip_norm": args.clip_norm, "lrd": args.lr_decay,
                   "weight_decay": args.weight_decay}
    adam = ClippedAdam(adam_params)

    # setup inference algorithm
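    # When args.tmcelbo is set, the guide is wrapped by config_enumerate(..., num_samples=...),
    # giving a tensor Monte Carlo (TMC) estimate of the ELBO; otherwise a standard
    # reparameterized (Jit)Trace_ELBO is used.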
    if args.tmcelbo:
        elbo = JitTraceEnum_ELBO() if args.jit else TraceEnum_ELBO()
        dmm_guide = config_enumerate(dmm.guide, default="parallel", num_samples=args.tmc_num_samples, expand=False)
        svi = SVI(dmm.model, dmm_guide, adam, loss=elbo)
    else:
        elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
        svi = SVI(dmm.model, dmm.guide, adam, loss=elbo)

    # now we're going to define some functions we need to form the main training loop

    # saves the model and optimizer states to disk
    def save_checkpoint():
        log("saving model to %s..." % args.save_model)
        torch.save(dmm.state_dict(), args.save_model)
        log("saving optimizer states to %s..." % args.save_opt)
        adam.save(args.save_opt)
        log("done saving model and optimizer checkpoints to disk.")

    # loads the model and optimizer states from disk
    def load_checkpoint():
        assert exists(args.load_opt) and exists(args.load_model), \
            "--load-model and/or --load-opt misspecified"
        log("loading model from %s..." % args.load_model)
        dmm.load_state_dict(torch.load(args.load_model))
        log("loading optimizer states from %s..." % args.load_opt)
        adam.load(args.load_opt)
        log("done loading model and optimizer states.")

    # prepare a mini-batch and take a gradient step to minimize -elbo
    def process_minibatch(epoch, which_mini_batch, shuffled_indices):
        if args.annealing_epochs > 0 and epoch < args.annealing_epochs:
            # compute the KL annealing factor appropriate for the current mini-batch in the current epoch
            min_af = args.minimum_annealing_factor
            annealing_factor = min_af + (1.0 - min_af) * \
                (float(which_mini_batch + epoch * N_mini_batches + 1) /
                 float(args.annealing_epochs * N_mini_batches))
        else:
            # by default the KL annealing factor is unity
            annealing_factor = 1.0

        # compute which sequences in the training set we should grab
        mini_batch_start = (which_mini_batch * args.mini_batch_size)
        mini_batch_end = np.min([(which_mini_batch + 1) * args.mini_batch_size, N_train_data])
        mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end]
        # grab a fully prepped mini-batch using the helper function in the data loader
        mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
            = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                  training_seq_lengths, cuda=args.cuda)
        # do an actual gradient step
        loss = svi.step(mini_batch, mini_batch_reversed, mini_batch_mask,
                        mini_batch_seq_lengths, annealing_factor)
        # keep track of the training loss
        return loss

    # helper function for doing evaluation
    def do_evaluation():
        # put the RNN into evaluation mode (i.e. turn off drop-out if applicable)
        dmm.rnn.eval()

        # compute the validation and test loss n_samples many times
        val_nll = svi.evaluate_loss(val_batch, val_batch_reversed, val_batch_mask,
                                    val_seq_lengths) / torch.sum(val_seq_lengths)
        test_nll = svi.evaluate_loss(test_batch, test_batch_reversed, test_batch_mask,
                                     test_seq_lengths) / torch.sum(test_seq_lengths)

        # put the RNN back into training mode (i.e. turn on drop-out if applicable)
        dmm.rnn.train()
        return val_nll, test_nll

    # if checkpoint files provided, load model and optimizer states from disk before we start training
    if args.load_opt != '' and args.load_model != '':
        load_checkpoint()

    #################
    # TRAINING LOOP #
    #################
    times = [time.time()]
    for epoch in range(args.num_epochs):
        # if specified, save model and optimizer states to disk every checkpoint_freq epochs
        if args.checkpoint_freq > 0 and epoch > 0 and epoch % args.checkpoint_freq == 0:
            save_checkpoint()

        # accumulator for our estimate of the negative log likelihood (or rather -elbo) for this epoch
        epoch_nll = 0.0
        # prepare mini-batch subsampling indices for this epoch
        shuffled_indices = torch.randperm(N_train_data)

        # process each mini-batch; this is where we take gradient steps
        for which_mini_batch in range(N_mini_batches):
            epoch_nll += process_minibatch(epoch, which_mini_batch, shuffled_indices)

        # report training diagnostics
        times.append(time.time())
        epoch_time = times[-1] - times[-2]
        log("[training epoch %04d]  %.4f \t\t\t\t(dt = %.3f sec)" %
            (epoch, epoch_nll / N_train_time_slices, epoch_time))

        # do evaluation on test and validation data and report results
        if val_test_frequency > 0 and epoch > 0 and epoch % val_test_frequency == 0:
            val_nll, test_nll = do_evaluation()
            log("[val/test epoch %04d]  %.4f  %.4f" % (epoch, val_nll, test_nll))
Example #7
def guide():
    # Model 'lmbd' as a LogNormal variate.
    a = pyro.param('a', torch.zeros(1))
    b = pyro.param('b',
                   torch.ones(1),
                   constraint=pyro.distributions.constraints.positive)
    lmbd = pyro.sample('lmbd', dist.LogNormal(a, b))


# Data from the Luria-Delbruck experiment.
obs = torch.tensor(
    [1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1]).float()
cnd_model = pyro.poutine.condition(model, data={'x': obs})

optimizer = Adam({'lr': 0.05})
ELBO = JitTraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(cnd_model, guide, optimizer, ELBO)

l = 0
for step in range(1, 101):
    l += svi.step()
    if step % 5 == 0:
        print(float(l))
        l = 0.

print('===')

a = pyro.param('a')
b = pyro.param('b')

#print(torch.distributions.log_normal.LogNormal(a,b).sample([100]).view(-1))
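
As a small follow-up (mirroring the commented-out line above), the fitted variational posterior over lmbd can be reconstructed from the learned parameters:

posterior = dist.LogNormal(pyro.param('a'), pyro.param('b'))
print(posterior.mean, posterior.variance)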
Example #8
def run_inference(dataset_obj: SingleCellRNACountsDataset,
                  args) -> RemoveBackgroundPyroModel:
    """Run a full inference procedure, training a latent variable model.

    Args:
        dataset_obj: Input data in the form of a SingleCellRNACountsDataset
            object.
        args: Input command line parsed arguments.

    Returns:
        model: cellbender.model.RemoveBackgroundPyroModel that has had
            inference run.

    """

    # Get the trimmed count matrix (transformed if called for).
    count_matrix = dataset_obj.get_count_matrix()

    # Configure pyro options (skip validations to improve speed).
    pyro.enable_validation(False)
    pyro.distributions.enable_validation(False)
    pyro.set_rng_seed(0)
    pyro.clear_param_store()

    # Set up the variational autoencoder:

    # Encoder.
    encoder_z = EncodeZ(input_dim=count_matrix.shape[1],
                        hidden_dims=args.z_hidden_dims,
                        output_dim=args.z_dim,
                        input_transform='normalize')

    encoder_d = EncodeD(
        input_dim=count_matrix.shape[1],
        hidden_dims=args.d_hidden_dims,
        output_dim=1,
        log_count_crossover=dataset_obj.priors['log_counts_crossover'])

    if args.model == "simple":

        # If using the simple model, there is no need for p.
        encoder = CompositeEncoder({'z': encoder_z, 'd_loc': encoder_d})

    else:

        # Models that include empty droplets.
        encoder_p = EncodeNonEmptyDropletLogitProb(
            input_dim=count_matrix.shape[1],
            hidden_dims=args.p_hidden_dims,
            output_dim=1,
            input_transform='normalize',
            log_count_crossover=dataset_obj.priors['log_counts_crossover'])
        encoder = CompositeEncoder({
            'z': encoder_z,
            'd_loc': encoder_d,
            'p_y': encoder_p
        })

    # Decoder.
    decoder = Decoder(input_dim=args.z_dim,
                      hidden_dims=args.z_hidden_dims[::-1],
                      output_dim=count_matrix.shape[1])

    # Set up the pyro model for variational inference.
    model = RemoveBackgroundPyroModel(model_type=args.model,
                                      encoder=encoder,
                                      decoder=decoder,
                                      dataset_obj=dataset_obj,
                                      use_cuda=args.use_cuda)

    # Set up the optimizer.
    adam_args = {"lr": args.learning_rate}
    optimizer = ClippedAdam(adam_args)

    # Determine the loss function.
    if args.use_jit:
        loss_function = JitTraceEnum_ELBO(max_plate_nesting=1,
                                          strict_enumeration_warning=False)
    else:
        loss_function = TraceEnum_ELBO(max_plate_nesting=1)

    if args.model == "simple":
        if args.use_jit:
            loss_function = JitTrace_ELBO()
        else:
            loss_function = Trace_ELBO()

    # Set up the inference process.
    svi = SVI(model.model, model.guide, optimizer, loss=loss_function)

    # Load the dataset into DataLoaders.
    frac = args.training_fraction  # Fraction of barcodes to use for training
    batch_size = int(
        min(500, frac * dataset_obj.analyzed_barcode_inds.size / 2))
    train_loader, test_loader = \
        prep_data_for_training(dataset=count_matrix,
                               empty_drop_dataset=dataset_obj.get_count_matrix_empties(),
                               random_state=dataset_obj.random,
                               batch_size=batch_size,
                               training_fraction=frac,
                               fraction_empties=args.fraction_empties,
                               shuffle=True,
                               use_cuda=args.use_cuda)

    # Run training.
    run_training(model,
                 svi,
                 train_loader,
                 test_loader,
                 epochs=args.epochs,
                 test_freq=10)

    return model
Example #9
def run_inference(dataset_obj: Dataset, args) -> VariationalInferenceModel:
    """Run a full inference procedure, training a latent variable model.

    Args:
        dataset_obj: Input data in the form of a Dataset object.
        args: Input command line parsed arguments.

    Returns:
        model: cellbender.model.VariationalInferenceModel that has had
            inference run.

    """

    # Get the trimmed count matrix (transformed if called for).
    count_matrix = dataset_obj.get_count_matrix()

    # Configure pyro options (skip validations to improve speed).
    pyro.enable_validation(False)
    pyro.distributions.enable_validation(False)
    pyro.set_rng_seed(0)
    pyro.clear_param_store()

    # Set up the variational autoencoder:

    # Encoder.
    encoder_z = EncodeZ(input_dim=count_matrix.shape[1],
                        hidden_dims=args.z_hidden_dims,
                        output_dim=args.z_dim,
                        input_transform='normalize')

    encoder_d = EncodeD(
        input_dim=count_matrix.shape[1],
        hidden_dims=args.d_hidden_dims,
        output_dim=1,
        log_count_crossover=dataset_obj.priors['log_counts_crossover'])

    if args.model[0] == "simple":

        # If using the simple model, there is no need for p.
        encoder = CompositeEncoder({'z': encoder_z, 'd_loc': encoder_d})

    else:

        # Models that include empty droplets.
        encoder_p = EncodePAmbient(
            input_dim=count_matrix.shape[1],
            hidden_dims=args.p_hidden_dims,
            output_dim=1,
            input_transform='normalize',
            log_count_crossover=dataset_obj.priors['log_counts_crossover'])
        encoder = CompositeEncoder({
            'z': encoder_z,
            'd_loc': encoder_d,
            'p_y': encoder_p
        })

    # Decoder.
    decoder = Decoder(input_dim=args.z_dim,
                      hidden_dims=args.z_hidden_dims[::-1],
                      output_dim=count_matrix.shape[1])

    # Set up the pyro model for variational inference.
    model = VariationalInferenceModel(
        model_type=args.model[0],
        encoder=encoder,
        decoder=decoder,
        dataset_obj=dataset_obj,
        use_decaying_avg_baseline=args.use_decaying_average_baseline,
        use_IAF=args.use_IAF,
        use_cuda=args.use_cuda)

    # Load the dataset into DataLoaders.
    frac = args.training_fraction
    batch_size = int(
        min(500, frac * dataset_obj.analyzed_barcode_inds.size / 2))
    train_loader, test_loader = \
        prep_data_for_training(dataset=count_matrix,
                               empty_drop_dataset=dataset_obj.get_count_matrix_empties(),
                               batch_size=batch_size,
                               training_fraction=frac,
                               fraction_empties=args.fraction_empties,
                               shuffle=True,
                               use_cuda=args.use_cuda)

    # Run the guide once for Jit. (can hang on StopIteration if no test data!)
    # model.guide(test_loader.__iter__().__next__())  # This seems unnecessary

    # Set up the optimizer.
    adam_args = {"lr": args.learning_rate}
    optimizer = ClippedAdam(adam_args)

    # Determine the loss function.
    # loss_function = TraceEnum_ELBO(max_plate_nesting=1)
    loss_function = JitTraceEnum_ELBO(max_plate_nesting=1,
                                      strict_enumeration_warning=False)
    if args.model[0] == "simple":
        loss_function = JitTrace_ELBO()

    # Set up the inference process.
    svi = SVI(model.model, model.guide, optimizer, loss=loss_function)

    # Run training.
    run_training(model,
                 svi,
                 train_loader,
                 test_loader,
                 epochs=args.epochs,
                 test_freq=10)

    return model