def test_traceenum_elbo(length):
    # Check that the JIT-compiled enumeration ELBO matches the eager version
    # in both the loss value and the gradients w.r.t. the parameters.
    hidden_dim = 10
    transition = pyro.param("transition",
                            0.3 / hidden_dim + 0.7 * torch.eye(hidden_dim),
                            constraint=constraints.positive)
    means = pyro.param("means", torch.arange(float(hidden_dim)))
    data = 1 + 2 * torch.randn(length)

    @ignore_jit_warnings()
    def model(data):
        transition = pyro.param("transition")
        means = pyro.param("means")
        states = [torch.tensor(0)]
        for t in pyro.markov(range(len(data))):
            states.append(pyro.sample("states_{}".format(t),
                                      dist.Categorical(transition[states[-1]]),
                                      infer={"enumerate": "parallel"}))
            pyro.sample("obs_{}".format(t),
                        dist.Normal(means[states[-1]], 1.),
                        obs=data[t])
        return tuple(states)

    def guide(data):
        pass

    expected_loss = TraceEnum_ELBO(max_plate_nesting=0).differentiable_loss(model, guide, data)
    actual_loss = JitTraceEnum_ELBO(max_plate_nesting=0).differentiable_loss(model, guide, data)
    assert_equal(expected_loss, actual_loss)

    expected_grads = grad(expected_loss, [transition, means], allow_unused=True)
    actual_grads = grad(actual_loss, [transition, means], allow_unused=True)
    for e, a, name in zip(expected_grads, actual_grads, ["transition", "means"]):
        assert_equal(e, a, msg="bad gradient for {}".format(name))
def make_svi(model, guide, args=None, kwargs=None, steps=1000, lr=0.05,
             cut_time=slice(None, None), max_steps=2000,
             ensure_convergence=False, loss='ELBO'):
    adam_params = {"lr": lr, "betas": (0.90, 0.999),
                   'weight_decay': 0.005, 'clip_norm': 10}
    optimizer = ClippedAdam(adam_params)
    # svi = SVI(model, guide, optimizer, loss=trace_mle(cut_time))
    if loss == 'ELBO':
        svi = SVI(model, guide, optimizer, loss=JitTraceEnum_ELBO())
    if loss == 'MLE':
        svi = SVI(model, guide, optimizer, loss=trace_mle())

    pbar = tqdm(range(1, steps + 1))
    time_start = 0
    loss_arr = []
    for i in pbar:
        loss, time_start = make_step(svi, pbar, time_start, args, kwargs)
        loss_arr.append(loss)

    while ensure_convergence:
        std_prev = np.std(loss_arr[-20:-1])
        mean_cur = np.mean(loss_arr[-100:])
        mean_prev = np.mean(loss_arr[-200:-100])
        prob = stat.norm(mean_prev, std_prev).cdf(mean_cur)
        # print(prob, mean_cur, mean_prev, std_prev)
        if mean_cur < mean_prev and prob < 0.05 and len(loss_arr) < max_steps:
            pbar = tqdm(range(1, 100 + 1), leave=False)
            for j in pbar:
                loss, time_start = make_step(svi, pbar, time_start, args, kwargs,
                                             prefix='Extra: ')
                loss_arr.append(loss)
        else:
            break
    return loss
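# make_svi above calls a make_step helper that is not included in this snippet.
# The sketch below is only an illustration of what such a helper might look like,
# assuming it takes a single SVI step, updates the progress-bar description, and
# returns the loss together with an updated elapsed-time counter; the real
# signature and bookkeeping may differ.
import time

def make_step(svi, pbar, time_start, args=None, kwargs=None, prefix=''):
    # Hypothetical helper (assumed API): one gradient step plus progress reporting.
    args = args or ()
    kwargs = kwargs or {}
    step_start = time.time()
    loss = svi.step(*args, **kwargs)  # one gradient step on -ELBO
    elapsed = time.time() - step_start
    pbar.set_description("{}loss {:.3f} ({:.2f}s/step)".format(prefix, loss, elapsed))
    return loss, time_start + elapsed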
def svi(data, ratings, model, guide, epoch, model_type,
        if_save=True, if_print=True, num_sample=200):
    elbo = JitTraceEnum_ELBO(max_plate_nesting=1)
    svi_model = SVI(model, guide, optim.Adam({"lr": .005}),
                    loss=elbo, num_samples=num_sample)
    pyro.clear_param_store()
    loss_list = []
    for i in range(epoch):
        ELBO = svi_model.step(data, ratings)
        loss_list.append(ELBO)
        if i % 500 == 0 and if_print:
            print(ELBO)
    if if_save:
        to_pickle(loss_list, "data_pickle/{}_svi_loss".format(model_type))
    return svi_model, loss_list
def main(_argv):
    transition_alphas = torch.tensor([[10., 90.],
                                      [90., 10.]])
    emission_alphas = torch.tensor([[[30., 20., 5.]],
                                    [[5., 10., 100.]]])
    lengths = torch.randint(10, 30, (10000,))
    trace = poutine.trace(model).get_trace(transition_alphas, emission_alphas, lengths)
    obs_sequences = [site['value'] for name, site in trace.nodes.items()
                     if name.startswith("element_")]
    obs_sequences = torch.stack(obs_sequences, dim=-2)
    guide = AutoDelta(
        poutine.block(model, hide_fn=lambda site: site['name'].startswith('state')),
        init_loc_fn=init_to_sample)
    svi = SVI(model, guide, Adam(dict(lr=0.1)), JitTraceEnum_ELBO())
    total = 1000
    with tqdm.trange(total) as t:
        for i in t:
            loss = svi.step(0.5 * torch.ones((2, 2), dtype=torch.float),
                            0.3 * torch.ones((2, 1, 3), dtype=torch.float),
                            lengths, obs_sequences)
            t.set_description_str(f"SVI ({i}/{total}): {loss}")
    median = guide.median()
    print("Transition probs: ", median['transition_probs'].detach().numpy())
    print("Emission probs: ", median['emission_probs'].squeeze().detach().numpy())
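# The `model` driven by main(_argv) above is not shown. Below is a minimal
# single-sequence sketch of the kind of enumerated HMM that would be consistent
# with the site names used above ("transition_probs", "emission_probs",
# "state_*", "element_*"); the batching over sequences, the plate structure, and
# the exact shapes of the real model are assumptions made only for illustration.
def hmm_model(transition_alphas, emission_alphas, length=20, sequence=None):
    # Dirichlet priors over each row of the transition and emission matrices.
    transition_probs = pyro.sample("transition_probs",
                                   dist.Dirichlet(transition_alphas).to_event(1))
    emission_probs = pyro.sample("emission_probs",
                                 dist.Dirichlet(emission_alphas.squeeze(-2)).to_event(1))
    state = 0
    for t in pyro.markov(range(length)):
        # Hidden state chain, marginalized by parallel enumeration.
        state = pyro.sample("state_{}".format(t),
                            dist.Categorical(transition_probs[state]),
                            infer={"enumerate": "parallel"})
        # Observed emission at time t (sampled if no data is supplied).
        pyro.sample("element_{}".format(t),
                    dist.Categorical(emission_probs[state]),
                    obs=None if sequence is None else sequence[t])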
def run_inference(dataset_obj: SingleCellRNACountsDataset,
                  args) -> RemoveBackgroundPyroModel:
    """Run a full inference procedure, training a latent variable model.

    Args:
        dataset_obj: Input data in the form of a SingleCellRNACountsDataset
            object.
        args: Input command line parsed arguments.

    Returns:
        model: cellbender.model.RemoveBackgroundPyroModel that has had
            inference run.
    """

    # Get the trimmed count matrix (transformed if called for).
    count_matrix = dataset_obj.get_count_matrix()

    # Configure pyro options (skip validations to improve speed).
    pyro.enable_validation(False)
    pyro.distributions.enable_validation(False)
    pyro.set_rng_seed(0)
    pyro.clear_param_store()

    # Set up the variational autoencoder:

    # Encoder.
    encoder_z = EncodeZ(input_dim=count_matrix.shape[1],
                        hidden_dims=args.z_hidden_dims,
                        output_dim=args.z_dim,
                        input_transform='normalize')
    encoder_other = EncodeNonZLatents(
        n_genes=count_matrix.shape[1],
        z_dim=args.z_dim,
        hidden_dims=consts.ENC_HIDDEN_DIMS,
        log_count_crossover=dataset_obj.priors['log_counts_crossover'],
        prior_log_cell_counts=np.log1p(dataset_obj.priors['cell_counts']),
        input_transform='normalize')
    encoder = CompositeEncoder({'z': encoder_z,
                                'other': encoder_other})

    # Decoder.
    decoder = Decoder(input_dim=args.z_dim,
                      hidden_dims=args.z_hidden_dims[::-1],
                      output_dim=count_matrix.shape[1])

    # Set up the pyro model for variational inference.
    model = RemoveBackgroundPyroModel(model_type=args.model,
                                      encoder=encoder,
                                      decoder=decoder,
                                      dataset_obj=dataset_obj,
                                      use_cuda=args.use_cuda)

    # Load the dataset into DataLoaders.
    frac = args.training_fraction  # Fraction of barcodes to use for training
    batch_size = int(
        min(300, frac * dataset_obj.analyzed_barcode_inds.size / 2))
    train_loader, test_loader = \
        prep_data_for_training(dataset=count_matrix,
                               empty_drop_dataset=
                               dataset_obj.get_count_matrix_empties(),
                               random_state=dataset_obj.random,
                               batch_size=batch_size,
                               training_fraction=frac,
                               fraction_empties=args.fraction_empties,
                               shuffle=True,
                               use_cuda=args.use_cuda)

    # Set up the optimizer.
    optimizer = pyro.optim.clipped_adam.ClippedAdam
    optimizer_args = {'lr': args.learning_rate, 'clip_norm': 10.}

    # Set up a learning rate scheduler.
    minibatches_per_epoch = int(
        np.ceil(len(train_loader) / train_loader.batch_size).item())
    scheduler_args = {'optimizer': optimizer,
                      'max_lr': args.learning_rate * 10,
                      'steps_per_epoch': minibatches_per_epoch,
                      'epochs': args.epochs,
                      'optim_args': optimizer_args}
    scheduler = pyro.optim.OneCycleLR(scheduler_args)

    # Determine the loss function.
    if args.use_jit:

        # Call guide() once as a warm-up.
        model.guide(torch.zeros([10, dataset_obj.analyzed_gene_inds.size])
                    .to(model.device))

        if args.model == "simple":
            loss_function = JitTrace_ELBO()
        else:
            loss_function = JitTraceEnum_ELBO(max_plate_nesting=1,
                                              strict_enumeration_warning=False)
    else:

        if args.model == "simple":
            loss_function = Trace_ELBO()
        else:
            loss_function = TraceEnum_ELBO(max_plate_nesting=1)

    # Set up the inference process.
    svi = SVI(model.model, model.guide, scheduler, loss=loss_function)

    # Run training.
    run_training(model, svi, train_loader, test_loader,
                 epochs=args.epochs, test_freq=5)

    return model
def main(args):
    # setup logging
    log = get_logger(args.log)
    log(args)

    data = poly.load_data(poly.JSB_CHORALES)
    training_seq_lengths = data['train']['sequence_lengths']
    training_data_sequences = data['train']['sequences']
    test_seq_lengths = data['test']['sequence_lengths']
    test_data_sequences = data['test']['sequences']
    val_seq_lengths = data['valid']['sequence_lengths']
    val_data_sequences = data['valid']['sequences']
    N_train_data = len(training_seq_lengths)
    N_train_time_slices = float(torch.sum(training_seq_lengths))
    N_mini_batches = int(N_train_data / args.mini_batch_size +
                         int(N_train_data % args.mini_batch_size > 0))

    log("N_train_data: %d   avg. training seq. length: %.2f   N_mini_batches: %d" %
        (N_train_data, training_seq_lengths.float().mean(), N_mini_batches))

    # how often we do validation/test evaluation during training
    val_test_frequency = 50
    # the number of samples we use to do the evaluation
    n_eval_samples = 1

    # package repeated copies of val/test data for faster evaluation
    # (i.e. set us up for vectorization)
    def rep(x):
        rep_shape = torch.Size([x.size(0) * n_eval_samples]) + x.size()[1:]
        repeat_dims = [1] * len(x.size())
        repeat_dims[0] = n_eval_samples
        return x.repeat(repeat_dims).reshape(n_eval_samples, -1).transpose(1, 0).reshape(rep_shape)

    # get the validation/test data ready for the dmm: pack into sequences, etc.
    val_seq_lengths = rep(val_seq_lengths)
    test_seq_lengths = rep(test_seq_lengths)
    val_batch, val_batch_reversed, val_batch_mask, val_seq_lengths = poly.get_mini_batch(
        torch.arange(n_eval_samples * val_data_sequences.shape[0]),
        rep(val_data_sequences),
        val_seq_lengths, cuda=args.cuda)
    test_batch, test_batch_reversed, test_batch_mask, test_seq_lengths = poly.get_mini_batch(
        torch.arange(n_eval_samples * test_data_sequences.shape[0]),
        rep(test_data_sequences),
        test_seq_lengths, cuda=args.cuda)

    # instantiate the dmm
    dmm = DMM(rnn_dropout_rate=args.rnn_dropout_rate, num_iafs=args.num_iafs,
              iaf_dim=args.iaf_dim, use_cuda=args.cuda)

    # setup optimizer
    adam_params = {"lr": args.learning_rate, "betas": (args.beta1, args.beta2),
                   "clip_norm": args.clip_norm, "lrd": args.lr_decay,
                   "weight_decay": args.weight_decay}
    adam = ClippedAdam(adam_params)

    # setup inference algorithm
    if args.tmcelbo:
        elbo = JitTraceEnum_ELBO() if args.jit else TraceEnum_ELBO()
        dmm_guide = config_enumerate(dmm.guide, default="parallel",
                                     num_samples=args.tmc_num_samples, expand=False)
        svi = SVI(dmm.model, dmm_guide, adam, loss=elbo)
    else:
        elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
        svi = SVI(dmm.model, dmm.guide, adam, loss=elbo)

    # now we're going to define some functions we need to form the main training loop

    # saves the model and optimizer states to disk
    def save_checkpoint():
        log("saving model to %s..." % args.save_model)
        torch.save(dmm.state_dict(), args.save_model)
        log("saving optimizer states to %s..." % args.save_opt)
        adam.save(args.save_opt)
        log("done saving model and optimizer checkpoints to disk.")

    # loads the model and optimizer states from disk
    def load_checkpoint():
        assert exists(args.load_opt) and exists(args.load_model), \
            "--load-model and/or --load-opt misspecified"
        log("loading model from %s..." % args.load_model)
        dmm.load_state_dict(torch.load(args.load_model))
        log("loading optimizer states from %s..." % args.load_opt)
        adam.load(args.load_opt)
        log("done loading model and optimizer states.")

    # prepare a mini-batch and take a gradient step to minimize -elbo
    def process_minibatch(epoch, which_mini_batch, shuffled_indices):
        if args.annealing_epochs > 0 and epoch < args.annealing_epochs:
            # compute the KL annealing factor appropriate for the current mini-batch in the current epoch
            min_af = args.minimum_annealing_factor
            annealing_factor = min_af + (1.0 - min_af) * \
                (float(which_mini_batch + epoch * N_mini_batches + 1) /
                 float(args.annealing_epochs * N_mini_batches))
        else:
            # by default the KL annealing factor is unity
            annealing_factor = 1.0

        # compute which sequences in the training set we should grab
        mini_batch_start = (which_mini_batch * args.mini_batch_size)
        mini_batch_end = np.min([(which_mini_batch + 1) * args.mini_batch_size, N_train_data])
        mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end]

        # grab a fully prepped mini-batch using the helper function in the data loader
        mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
            = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                  training_seq_lengths, cuda=args.cuda)

        # do an actual gradient step
        loss = svi.step(mini_batch, mini_batch_reversed, mini_batch_mask,
                        mini_batch_seq_lengths, annealing_factor)
        # keep track of the training loss
        return loss

    # helper function for doing evaluation
    def do_evaluation():
        # put the RNN into evaluation mode (i.e. turn off drop-out if applicable)
        dmm.rnn.eval()

        # compute the validation and test loss n_samples many times
        val_nll = svi.evaluate_loss(val_batch, val_batch_reversed, val_batch_mask,
                                    val_seq_lengths) / torch.sum(val_seq_lengths)
        test_nll = svi.evaluate_loss(test_batch, test_batch_reversed, test_batch_mask,
                                     test_seq_lengths) / torch.sum(test_seq_lengths)

        # put the RNN back into training mode (i.e. turn on drop-out if applicable)
        dmm.rnn.train()
        return val_nll, test_nll

    # if checkpoint files provided, load model and optimizer states from disk before we start training
    if args.load_opt != '' and args.load_model != '':
        load_checkpoint()

    #################
    # TRAINING LOOP #
    #################
    times = [time.time()]
    for epoch in range(args.num_epochs):
        # if specified, save model and optimizer states to disk every checkpoint_freq epochs
        if args.checkpoint_freq > 0 and epoch > 0 and epoch % args.checkpoint_freq == 0:
            save_checkpoint()

        # accumulator for our estimate of the negative log likelihood (or rather -elbo) for this epoch
        epoch_nll = 0.0
        # prepare mini-batch subsampling indices for this epoch
        shuffled_indices = torch.randperm(N_train_data)

        # process each mini-batch; this is where we take gradient steps
        for which_mini_batch in range(N_mini_batches):
            epoch_nll += process_minibatch(epoch, which_mini_batch, shuffled_indices)

        # report training diagnostics
        times.append(time.time())
        epoch_time = times[-1] - times[-2]
        log("[training epoch %04d]  %.4f \t\t\t\t(dt = %.3f sec)" %
            (epoch, epoch_nll / N_train_time_slices, epoch_time))

        # do evaluation on test and validation data and report results
        if val_test_frequency > 0 and epoch > 0 and epoch % val_test_frequency == 0:
            val_nll, test_nll = do_evaluation()
            log("[val/test epoch %04d]  %.4f  %.4f" % (epoch, val_nll, test_nll))
def guide():
    # Model 'lmbd' as a LogNormal variate.
    a = pyro.param('a', torch.zeros(1))
    b = pyro.param('b', torch.ones(1),
                   constraint=pyro.distributions.constraints.positive)
    lmbd = pyro.sample('lmbd', dist.LogNormal(a, b))


# Data from the Luria-Delbruck experiment.
obs = torch.tensor(
    [1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1]).float()

cnd_model = pyro.poutine.condition(model, data={'x': obs})

optimizer = Adam({'lr': 0.05})
ELBO = JitTraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(cnd_model, guide, optimizer, ELBO)

l = 0
for step in range(1, 101):
    l += svi.step()
    if step % 5 == 0:
        print(float(l))
        l = 0.

print('===')
a = pyro.param('a')
b = pyro.param('b')
# print(torch.distributions.log_normal.LogNormal(a,b).sample([100]).view(-1))
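# The `model` that is conditioned on 'x' above is not part of this snippet. The
# sketch below is a plausible minimal model, under the assumption that each entry
# of `obs` indicates whether any mutants appeared in a culture, so that given the
# mutation rate lmbd a culture is positive with probability 1 - exp(-lmbd) (the
# probability that a Poisson(lmbd) count is nonzero); the LogNormal prior on lmbd
# is likewise an assumption for illustration.
def model():
    # lmbd is the expected number of mutation events per culture.
    lmbd = pyro.sample('lmbd', dist.LogNormal(torch.zeros(1), torch.ones(1)))
    with pyro.plate('cultures', 20):
        # Each 0/1 observation 'x' is Bernoulli with probability 1 - exp(-lmbd).
        pyro.sample('x', dist.Bernoulli((1. - torch.exp(-lmbd)).expand(20)))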
def run_inference(dataset_obj: SingleCellRNACountsDataset,
                  args) -> RemoveBackgroundPyroModel:
    """Run a full inference procedure, training a latent variable model.

    Args:
        dataset_obj: Input data in the form of a SingleCellRNACountsDataset
            object.
        args: Input command line parsed arguments.

    Returns:
        model: cellbender.model.RemoveBackgroundPyroModel that has had
            inference run.
    """

    # Get the trimmed count matrix (transformed if called for).
    count_matrix = dataset_obj.get_count_matrix()

    # Configure pyro options (skip validations to improve speed).
    pyro.enable_validation(False)
    pyro.distributions.enable_validation(False)
    pyro.set_rng_seed(0)
    pyro.clear_param_store()

    # Set up the variational autoencoder:

    # Encoder.
    encoder_z = EncodeZ(input_dim=count_matrix.shape[1],
                        hidden_dims=args.z_hidden_dims,
                        output_dim=args.z_dim,
                        input_transform='normalize')
    encoder_d = EncodeD(input_dim=count_matrix.shape[1],
                        hidden_dims=args.d_hidden_dims,
                        output_dim=1,
                        log_count_crossover=dataset_obj.priors['log_counts_crossover'])

    if args.model == "simple":
        # If using the simple model, there is no need for p.
        encoder = CompositeEncoder({'z': encoder_z,
                                    'd_loc': encoder_d})
    else:
        # Models that include empty droplets.
        encoder_p = EncodeNonEmptyDropletLogitProb(
            input_dim=count_matrix.shape[1],
            hidden_dims=args.p_hidden_dims,
            output_dim=1,
            input_transform='normalize',
            log_count_crossover=dataset_obj.priors['log_counts_crossover'])
        encoder = CompositeEncoder({'z': encoder_z,
                                    'd_loc': encoder_d,
                                    'p_y': encoder_p})

    # Decoder.
    decoder = Decoder(input_dim=args.z_dim,
                      hidden_dims=args.z_hidden_dims[::-1],
                      output_dim=count_matrix.shape[1])

    # Set up the pyro model for variational inference.
    model = RemoveBackgroundPyroModel(model_type=args.model,
                                      encoder=encoder,
                                      decoder=decoder,
                                      dataset_obj=dataset_obj,
                                      use_cuda=args.use_cuda)

    # Set up the optimizer.
    adam_args = {"lr": args.learning_rate}
    optimizer = ClippedAdam(adam_args)

    # Determine the loss function.
    if args.use_jit:
        loss_function = JitTraceEnum_ELBO(max_plate_nesting=1,
                                          strict_enumeration_warning=False)
    else:
        loss_function = TraceEnum_ELBO(max_plate_nesting=1)

    if args.model == "simple":
        if args.use_jit:
            loss_function = JitTrace_ELBO()
        else:
            loss_function = Trace_ELBO()

    # Set up the inference process.
    svi = SVI(model.model, model.guide, optimizer, loss=loss_function)

    # Load the dataset into DataLoaders.
    frac = args.training_fraction  # Fraction of barcodes to use for training
    batch_size = int(
        min(500, frac * dataset_obj.analyzed_barcode_inds.size / 2))
    train_loader, test_loader = \
        prep_data_for_training(dataset=count_matrix,
                               empty_drop_dataset=
                               dataset_obj.get_count_matrix_empties(),
                               random_state=dataset_obj.random,
                               batch_size=batch_size,
                               training_fraction=frac,
                               fraction_empties=args.fraction_empties,
                               shuffle=True,
                               use_cuda=args.use_cuda)

    # Run training.
    run_training(model, svi, train_loader, test_loader,
                 epochs=args.epochs, test_freq=10)

    return model
def run_inference(dataset_obj: Dataset, args) -> VariationalInferenceModel:
    """Run a full inference procedure, training a latent variable model.

    Args:
        dataset_obj: Input data in the form of a Dataset object.
        args: Input command line parsed arguments.

    Returns:
        model: cellbender.model.VariationalInferenceModel that has had
            inference run.
    """

    # Get the trimmed count matrix (transformed if called for).
    count_matrix = dataset_obj.get_count_matrix()

    # Configure pyro options (skip validations to improve speed).
    pyro.enable_validation(False)
    pyro.distributions.enable_validation(False)
    pyro.set_rng_seed(0)
    pyro.clear_param_store()

    # Set up the variational autoencoder:

    # Encoder.
    encoder_z = EncodeZ(input_dim=count_matrix.shape[1],
                        hidden_dims=args.z_hidden_dims,
                        output_dim=args.z_dim,
                        input_transform='normalize')
    encoder_d = EncodeD(input_dim=count_matrix.shape[1],
                        hidden_dims=args.d_hidden_dims,
                        output_dim=1,
                        log_count_crossover=dataset_obj.priors['log_counts_crossover'])

    if args.model[0] == "simple":
        # If using the simple model, there is no need for p.
        encoder = CompositeEncoder({'z': encoder_z,
                                    'd_loc': encoder_d})
    else:
        # Models that include empty droplets.
        encoder_p = EncodePAmbient(
            input_dim=count_matrix.shape[1],
            hidden_dims=args.p_hidden_dims,
            output_dim=1,
            input_transform='normalize',
            log_count_crossover=dataset_obj.priors['log_counts_crossover'])
        encoder = CompositeEncoder({'z': encoder_z,
                                    'd_loc': encoder_d,
                                    'p_y': encoder_p})

    # Decoder.
    decoder = Decoder(input_dim=args.z_dim,
                      hidden_dims=args.z_hidden_dims[::-1],
                      output_dim=count_matrix.shape[1])

    # Set up the pyro model for variational inference.
    model = VariationalInferenceModel(
        model_type=args.model[0],
        encoder=encoder,
        decoder=decoder,
        dataset_obj=dataset_obj,
        use_decaying_avg_baseline=args.use_decaying_average_baseline,
        use_IAF=args.use_IAF,
        use_cuda=args.use_cuda)

    # Load the dataset into DataLoaders.
    frac = args.training_fraction
    batch_size = int(
        min(500, frac * dataset_obj.analyzed_barcode_inds.size / 2))
    train_loader, test_loader = \
        prep_data_for_training(dataset=count_matrix,
                               empty_drop_dataset=
                               dataset_obj.get_count_matrix_empties(),
                               batch_size=batch_size,
                               training_fraction=frac,
                               fraction_empties=args.fraction_empties,
                               shuffle=True,
                               use_cuda=args.use_cuda)

    # Run the guide once for Jit.  (can hang on StopIteration if no test data!)
    # model.guide(test_loader.__iter__().__next__())  # This seems unnecessary

    # Set up the optimizer.
    adam_args = {"lr": args.learning_rate}
    optimizer = ClippedAdam(adam_args)

    # Determine the loss function.
    # loss_function = TraceEnum_ELBO(max_plate_nesting=1)
    loss_function = JitTraceEnum_ELBO(max_plate_nesting=1,
                                      strict_enumeration_warning=False)
    if args.model[0] == "simple":
        loss_function = JitTrace_ELBO()

    # Set up the inference process.
    svi = SVI(model.model, model.guide, optimizer, loss=loss_function)

    # Run training.
    run_training(model, svi, train_loader, test_loader,
                 epochs=args.epochs, test_freq=10)

    return model
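# The run_inference variants above all delegate the optimization itself to a
# run_training helper that is not shown. The sketch below only illustrates what
# such a loop might look like given the call signature used above; the batch
# handling, loss bookkeeping, and reporting of the real implementation are
# assumptions for illustration.
def run_training(model, svi, train_loader, test_loader, epochs, test_freq=10):
    # Hypothetical training loop: step over the train loader each epoch and
    # periodically evaluate the (non-gradient) loss on the test loader.
    for epoch in range(1, epochs + 1):
        train_loss, n_train = 0., 0
        for batch in train_loader:
            train_loss += svi.step(batch)  # one gradient step on -ELBO
            n_train += batch.shape[0]
        msg = "epoch %03d  train loss %.4f" % (epoch, train_loss / max(n_train, 1))
        if test_freq > 0 and epoch % test_freq == 0:
            test_loss, n_test = 0., 0
            for batch in test_loader:
                test_loss += svi.evaluate_loss(batch)  # evaluation only, no step
                n_test += batch.shape[0]
            msg += "  test loss %.4f" % (test_loss / max(n_test, 1))
        print(msg)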