def initialize_optimizer(self, lr):
    optimizer = Adam({'lr': lr})
    elbo = JitTrace_ELBO() if self.args.jit else Trace_ELBO()
    return SVI(self.model, self.guide, optimizer, loss=elbo)
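# Usage sketch (hypothetical): assuming a wrapper class that exposes `model`,
# `guide`, and parsed `args`, the helper above could drive a minimal training
# loop like this. `MyModel` and `train_loader` are illustrative names only,
# not part of the code in this collection.
#
# trainer = MyModel(args)                          # hypothetical wrapper object
# svi = trainer.initialize_optimizer(lr=1e-3)      # returns a pyro.infer.SVI instance
# for epoch in range(args.num_epochs):
#     epoch_loss = sum(svi.step(x) for x, _ in train_loader)
#     print("epoch %d  avg loss %.4f" % (epoch, epoch_loss / len(train_loader.dataset)))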
def main(args): # clear param store pyro.clear_param_store() # setup MNIST data loaders # train_loader, test_loader train_loader, test_loader = setup_data_loaders(MNIST, use_cuda=args.cuda, batch_size=256) # setup the VAE vae = VAE(use_cuda=args.cuda) # setup the optimizer adam_args = {"lr": args.learning_rate} optimizer = Adam(adam_args) # setup the inference algorithm elbo = JitTrace_ELBO() if args.jit else Trace_ELBO() svi = SVI(vae.model, vae.guide, optimizer, loss=elbo) # setup visdom for visualization if args.visdom_flag: vis = visdom.Visdom() train_elbo = [] test_elbo = [] # training loop for epoch in range(args.num_epochs): # initialize loss accumulator epoch_loss = 0. # do a training epoch over each mini-batch x returned # by the data loader for x, _ in train_loader: # if on GPU put mini-batch into CUDA memory if args.cuda: x = x.cuda() # do ELBO gradient and accumulate loss epoch_loss += svi.step(x) # report training diagnostics normalizer_train = len(train_loader.dataset) total_epoch_loss_train = epoch_loss / normalizer_train train_elbo.append(total_epoch_loss_train) print("[epoch %03d] average training loss: %.4f" % (epoch, total_epoch_loss_train)) if epoch % args.test_frequency == 0: # initialize loss accumulator test_loss = 0. # compute the loss over the entire test set for i, (x, _) in enumerate(test_loader): # if on GPU put mini-batch into CUDA memory if args.cuda: x = x.cuda() # compute ELBO estimate and accumulate loss test_loss += svi.evaluate_loss(x) # pick three random test images from the first mini-batch and # visualize how well we're reconstructing them if i == 0: if args.visdom_flag: plot_vae_samples(vae, vis) reco_indices = np.random.randint(0, x.shape[0], 3) for index in reco_indices: test_img = x[index, :] reco_img = vae.reconstruct_img(test_img) vis.image(test_img.reshape( 28, 28).detach().cpu().numpy(), opts={'caption': 'test image'}) vis.image(reco_img.reshape( 28, 28).detach().cpu().numpy(), opts={'caption': 'reconstructed image'}) # report test diagnostics normalizer_test = len(test_loader.dataset) total_epoch_loss_test = test_loss / normalizer_test test_elbo.append(total_epoch_loss_test) print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test)) if epoch == args.tsne_iter: mnist_test_tsne(vae=vae, test_loader=test_loader) plot_llk(np.array(train_elbo), np.array(test_elbo)) return vae
def main(args): """ run inference for SS-VAE :param args: arguments for SS-VAE :return: None """ if args.seed is not None: pyro.set_rng_seed(args.seed) viz = None if args.visualize: viz = Visdom() mkdir_p("./vae_results") # batch_size: number of images (and labels) to be considered in a batch ss_vae = SSVAE(z_dim=args.z_dim, hidden_layers=args.hidden_layers, use_cuda=args.cuda, config_enum=args.enum_discrete, aux_loss_multiplier=args.aux_loss_multiplier) # setup the optimizer adam_params = {"lr": args.learning_rate, "betas": (args.beta_1, 0.999)} optimizer = Adam(adam_params) # set up the loss(es) for inference. wrapping the guide in config_enumerate builds the loss as a sum # by enumerating each class label for the sampled discrete categorical distribution in the model guide = config_enumerate(ss_vae.guide, args.enum_discrete, expand=True) Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO elbo = Elbo(max_plate_nesting=1, strict_enumeration_warning=False) loss_basic = SVI(ss_vae.model, guide, optimizer, loss=elbo) # build a list of all losses considered losses = [loss_basic] # aux_loss: whether to use the auxiliary loss from NIPS 14 paper (Kingma et al) if args.aux_loss: elbo = JitTrace_ELBO() if args.jit else Trace_ELBO() loss_aux = SVI(ss_vae.model_classify, ss_vae.guide_classify, optimizer, loss=elbo) losses.append(loss_aux) try: # setup the logger if a filename is provided logger = open(args.logfile, "w") if args.logfile else None data_loaders = setup_data_loaders(MNISTCached, args.cuda, args.batch_size, sup_num=args.sup_num) # how often would a supervised batch be encountered during inference # e.g. if sup_num is 3000, we would have every 16th = int(50000/3000) batch supervised # until we have traversed through the all supervised batches periodic_interval_batches = int(MNISTCached.train_data_size / (1.0 * args.sup_num)) # number of unsupervised examples unsup_num = MNISTCached.train_data_size - args.sup_num # initializing local variables to maintain the best validation accuracy # seen across epochs over the supervised training set # and the corresponding testing set and the state of the networks best_valid_acc, corresponding_test_acc = 0.0, 0.0 # run inference for a certain number of epochs for i in range(0, args.num_epochs): # get the losses for an epoch epoch_losses_sup, epoch_losses_unsup = \ run_inference_for_epoch(data_loaders, losses, periodic_interval_batches) # compute average epoch losses i.e. 
losses per example avg_epoch_losses_sup = map(lambda v: v / args.sup_num, epoch_losses_sup) avg_epoch_losses_unsup = map(lambda v: v / unsup_num, epoch_losses_unsup) # store the loss and validation/testing accuracies in the logfile str_loss_sup = " ".join(map(str, avg_epoch_losses_sup)) str_loss_unsup = " ".join(map(str, avg_epoch_losses_unsup)) str_print = "{} epoch: avg losses {}".format( i, "{} {}".format(str_loss_sup, str_loss_unsup)) validation_accuracy = get_accuracy(data_loaders["valid"], ss_vae.classifier, args.batch_size) str_print += " validation accuracy {}".format(validation_accuracy) # this test accuracy is only for logging, this is not used # to make any decisions during training test_accuracy = get_accuracy(data_loaders["test"], ss_vae.classifier, args.batch_size) str_print += " test accuracy {}".format(test_accuracy) # update the best validation accuracy and the corresponding # testing accuracy and the state of the parent module (including the networks) if best_valid_acc < validation_accuracy: best_valid_acc = validation_accuracy corresponding_test_acc = test_accuracy print_and_log(logger, str_print) final_test_accuracy = get_accuracy(data_loaders["test"], ss_vae.classifier, args.batch_size) print_and_log( logger, "best validation accuracy {} corresponding testing accuracy {} " "last testing accuracy {}".format(best_valid_acc, corresponding_test_acc, final_test_accuracy)) # visualize the conditional samples visualize(ss_vae, viz, data_loaders["test"]) finally: # close the logger file object if we opened it earlier if args.logfile: logger.close()
def main(args):
    # setup logging
    log = get_logger(args.log)
    log(args)

    data = poly.load_data(poly.JSB_CHORALES)
    training_seq_lengths = data['train']['sequence_lengths']
    training_data_sequences = data['train']['sequences']
    test_seq_lengths = data['test']['sequence_lengths']
    test_data_sequences = data['test']['sequences']
    val_seq_lengths = data['valid']['sequence_lengths']
    val_data_sequences = data['valid']['sequences']
    N_train_data = len(training_seq_lengths)
    N_train_time_slices = float(torch.sum(training_seq_lengths))
    N_mini_batches = int(N_train_data / args.mini_batch_size +
                         int(N_train_data % args.mini_batch_size > 0))

    log("N_train_data: %d  avg. training seq. length: %.2f  N_mini_batches: %d" %
        (N_train_data, training_seq_lengths.float().mean(), N_mini_batches))

    ## instantiate the dmm
    dmm = dmm_model.DMM(rnn_dropout_rate=args.rnn_dropout_rate,
                        num_iafs=args.num_iafs,
                        iaf_dim=args.iaf_dim,
                        use_cuda=args.cuda)
    dmm.eval()

    # setup optimizer
    adam_params = {
        "lr": args.learning_rate,
        "betas": (args.beta1, args.beta2),
        "clip_norm": args.clip_norm,
        "lrd": args.lr_decay,
        "weight_decay": args.weight_decay
    }
    adam = ClippedAdam(adam_params)

    # setup inference algorithm
    elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
    svi = SVI(dmm.model, dmm.guide, adam, loss=elbo)

    # loads the model and optimizer states from disk
    def load_checkpoint():
        assert exists(args.load_opt) and exists(args.load_model), \
            "--load-model and/or --load-opt misspecified"
        log("loading model from %s..." % args.load_model)
        dmm.load_state_dict(torch.load(args.load_model))
        log("loading optimizer states from %s..." % args.load_opt)
        adam.load(args.load_opt)
        log("done loading model and optimizer states.")

    if args.load_opt != '' and args.load_model != '':
        load_checkpoint()

    #######################################
    # LOAD TRAINED MODEL AND SAMPLE FROM IT
    #######################################

    ## Basic parameters
    fs_aud = 12000  # sampling rate for audio rendering
    fig = plt.figure()
    ax_gt = fig.add_subplot(1, 2, 1)
    ax_estimated = fig.add_subplot(1, 2, 2)
    # ax_estimated_c = fig.add_subplot(2, 2, 4)
    # ax_dist = fig.add_subplot(2, 2, 2)
    MIDI_lo = 21
    MIDI_hi = 21 + 87
    # MIDI_lo_p =
    # MIDI_hi_p =
    condense = False
    num_notes = MIDI_hi - MIDI_lo + 1

    ## Load the front-end model
    net = dnn.Net(ac_length=1024)
    net.eval()
    save_prefix = "dnn_frontend_poly_test"
    save_path = project_directory + "dnn_front_end/saved_models/"
    net.load_state_dict(torch.load(save_path + save_prefix + ".pt"))

    ## Select a test sequence and collect data and initial distribution
    num_test_seqs = test_seq_lengths.shape[0]
    # idx = np.random.randint(num_test_seqs)
    idx = 22
    print("Using test sequence number {}".format(idx))
    seq_len = test_seq_lengths[idx].item()
    piano_roll_gt = test_data_sequences[idx, :seq_len, :].data.numpy()
    piano_roll_gt_rev = np.ascontiguousarray(np.flip(piano_roll_gt, axis=1))
    z1_dist = dmm.get_z1_dist(torch.tensor(piano_roll_gt_rev))

    ## Plot ground truth and render as audio
    piano_roll_gt_full = np.zeros((128, seq_len))
    piano_roll_gt_full[MIDI_lo:MIDI_hi + 1, :] += piano_roll_gt.transpose().astype(int)
    piano_roll_gt_full *= 64
    print("Synthesizing audio for input...", end="")
    with suppress_stdout():
        pm = piano_roll_to_pretty_midi(piano_roll_gt_full.astype(float), fs=2, program=52)
        audio = pm.fluidsynth(fs=fs_aud, sf2_path='/usr/share/soundfonts/FluidR3_GM.sf2')
    print("done.")
    wav.write("test_ground_truth_out.wav", fs_aud, audio)
    ax_gt.imshow(np.flip(np.transpose(piano_roll_gt), axis=0))
    ax_gt.set_yticks(np.arange(1, 89)[::10])
    ax_gt.set_yticklabels(np.arange(88, 0, -1)[::10])
    # plt.show()

    ## Generate Neural Activity Patterns
    # Periphery Parameters
    x_lo = 0.05   # 45 Hz
    x_hi = 0.75   # 6050 Hz
    num_channels = 72
    print("Generating NAP...")
    nap, channel_cfs = pyc.carfac_nap(audio,
                                      float(fs_aud),
                                      num_sections=num_channels,
                                      x_lo=x_lo,
                                      x_hi=x_hi,
                                      b=0.5)
    print("Finished.")

    ## Generate auto-correlated frames
    len_sig_n = len(audio)
    len_frame_n = len_sig_n / (seq_len + 2)   # to account for PM's padding
    # num_frames = int(len_sig_n / len_frame_n)
    num_frames = seq_len
    c_times_n = np.arange(0, num_frames) * len_frame_n + int(len_frame_n // 2)
    c_times_t = c_times_n / fs_aud
    win_size_n = 2048   # analysis window size
    win_size_t = win_size_n / fs_aud
    print("Calculating frame data...")
    sac_frames = datagen.gen_nap_sac_frames(nap, float(fs_aud), c_times_t,
                                            win_size_t, normalize=True)
    frames = gen_nap_frames(nap, float(fs_aud), c_times_t, win_size_t)
    print("Finished.")
    assert len(sac_frames) == seq_len

    ## Plot some sample frames
    # fig = plt.figure()
    # idcs = np.random.randint(seq_len, size=10)
    # for k in range(10):
    #     ax = fig.add_subplot(2, 5, k + 1)
    #     ax.plot(frames[k])
    # plt.show()

    ## Generate the observation probabilities
    win_size = int(len(audio) / piano_roll_gt.shape[0])
    sig_len = len(audio)
    num_hops = piano_roll_gt.shape[0]
    assert num_hops == sac_frames.shape[0]
    obs_probs = torch.zeros((num_hops, 2, num_notes), requires_grad=False)
    obs_probs_dnn = torch.zeros((num_hops, 2, num_notes), requires_grad=False)
    dnn_ests = torch.zeros((num_hops, num_notes), requires_grad=False)
    for k in range(num_hops):
        # obs_probs[k, :, :] = midi_probs_nap_klap(frames[k], fs_aud, 1024, b=0.01)
        obs_probs[k, :, :] = midi_probs_from_signal_dnn(sac_frames[k],
                                                        fs_aud,
                                                        MIDI_lo,
                                                        MIDI_hi,
                                                        net,
                                                        ac_size=1024,
                                                        compression_factor=0.825,
                                                        offset=0.05)
        # dnn_ests[k] = torch.where(obs_probs[k, 1, :] > 0.25,
        #                           torch.ones_like(obs_probs[k, 1, :]),
        #                           torch.zeros_like(obs_probs[k, 1, :]))

    ## Plot some sample front-end outputs
    # fig = plt.figure()
    # idcs = np.random.randint(seq_len, size=10)
    # for k in range(10):
    #     ax = fig.add_subplot(2, 5, k + 1)
    #     ax.plot(obs_probs[idcs[k], 1, :].numpy())
    #     ax.plot(obs_probs_dnn[idcs[k], 1, :].numpy())
    #     ax.set_title("{}".format(idcs[k]))
    #     for q in range(piano_roll_gt.shape[1]):
    #         if piano_roll_gt[idcs[k], q] == 1:
    #             ax.plot([q, q], [0, 1.0], color='C3')
    # plt.show()

    # TEST: calculate the probability of the first notes
    piano_roll_gt = torch.from_numpy(piano_roll_gt).type(torch.long)

    # Particle Filtering Parameters
    # num_particles = 10000  # worked!
    num_particles = 500
    z_dim = 100
    x_dim = 88
    z = torch.ones((num_particles, z_dim), requires_grad=False)
    x = torch.ones((num_hops, num_particles, x_dim), requires_grad=False)
    w = torch.ones((num_hops, num_particles), requires_grad=False)
    w_naive = torch.ones((num_hops, num_particles), requires_grad=False)

    # Generate initial particles
    count = 0
    for p in range(num_particles):
        z_prev = pyro.sample("init_z", z1_dist)
        z[p, :], x[0, p, :] = dmm.get_sample(z_prev, p)
        num_same = torch.sum(x[0, p, :].type(torch.long) == piano_roll_gt[0]).item()
        if num_same == 88:
            count += 1
    print("Got {} correct samples in step {}".format(count, 0))
    count = 0

    ## Calculate initial weights
    ## Uniform
    # w[0, :] = 1.0 / w.shape[1]
    ## Use calculated obs_probs (from DNN or elsewhere)
    for p in range(num_particles):
        prob = calc_obs_probs(obs_probs[0, :, :], x[0, p, :])
        w[0, p] *= prob
        w_naive[0, p] *= prob
    # Normalize weights
    w[0, :] = normalize_weights(w[0, :])
    w_naive[0, :] = normalize_weights(w_naive[0, :])

    ## Main Particle Filtering Loop
    good_samples = np.zeros(num_hops)
    for f in range(1, num_hops):
        ## Sample new particles
        z_vals, z_probs = particles_to_dist(z, w[f - 1, :])
        for p in range(num_particles):
            idx = discrete_sample(z_probs)
            z[p, :], x[f, p, :] = dmm.get_sample(z_vals[idx], p + f * num_particles)
            num_same = torch.sum(x[f, p, :].type(torch.long) == piano_roll_gt[f]).item()
            if num_same == 88:
                count += 1

            ## Calculate Weights -- probably bring this into loop above
            # - Primitive observation-free model of observations (more notes correct, higher prob)
            sx_prob = ((num_same - 78) / 10) ** 2
            w_naive[f, p] *= sx_prob  # * xz_prob
            # - Use observation probabilities
            sx_prob = calc_obs_probs(obs_probs[f, :, :], x[f, p, :])
            w[f, p] *= sx_prob  # * xz_prob

            # Calculate probability of x given z
            # xz_dist = dist.Bernoulli(dmm.emitter(z[p, :]))
            # xz_prob = torch.exp(torch.sum(xz_dist.log_prob(x[f, p, :]))).item()

        # Report number of samples that corresponded with ground truth
        print("Got {} correct samples in step {} \t".format(count, f), end='')
        good_samples[f] = count
        count = 0

        ## Normalize
        w[f, :] = w[f, :].pow(0.25)
        w[f, :] = normalize_weights(w[f, :])
        w_naive[f, :] = normalize_weights(w_naive[f, :])
        # plt.plot(w[f, :].numpy())
        # plt.plot(w_naive[f, :].numpy())
        # plt.show()
    print("\tDone!")

    # Now pull out the most probable path
    piano_roll_dist = np.zeros((num_hops, 88))
    if condense:
        print("\"Condensing\" final distribution")
        w_c = []
        x_c = []
        for f in range(num_hops):
            w_condensed, x_condensed = make_final_dist(w[f, :], x[f, :, :])
            w_c.append(w_condensed)
            x_c.append(x_condensed)
            print("{} unique samples in step {}/{}".format(len(w_condensed), f + 1, num_hops))
            piano_roll_dist[f, :] = np.sum(x_condensed * w_condensed[:, np.newaxis], axis=0)
        ax_dist.imshow(np.flip(np.transpose(piano_roll_dist), axis=0))

    ## Most probable path by picking highest weighted particle
    piano_roll_estimated = np.zeros((num_hops, 88))
    piano_roll_estimated_c = np.zeros((num_hops, 88))
    for f in range(num_hops):
        ## just picking highest weight
        idx = np.argmax(w[f, :].numpy())
        piano_roll_estimated[f, :] = x[f, idx, :].numpy()
        ## picking from highest condensed weight
        if condense:
            idx = np.argmax(w_c[f])
            piano_roll_estimated_c[f, :] = x_c[f][idx, :]
    ax_estimated.imshow(np.flip(np.transpose(piano_roll_estimated), axis=0))
    ax_estimated.set_yticks(np.arange(1, 89)[::10])
    ax_estimated.set_yticklabels(np.arange(88, 0, -1)[::10])
    if condense:
        ax_estimated_c.imshow(np.flip(np.transpose(piano_roll_estimated_c), axis=0))

    ## Calculate and report precision and recall
    p, r, f = precision_recall_f(piano_roll_gt.numpy(), piano_roll_estimated)
    print("Precision (regular): \t", p)
    print("Recall (regular): \t", r)
    print("F-metric (regular): \t", f)

    ## Check how often the correct sample was chosen when available
    gt = piano_roll_gt.numpy()
    num_available = 0
    num_chosen = 0
    num_chosen_c = 0
    for f in range(num_hops):
        if good_samples[f] > 0:
            num_available += 1
            if np.array_equal(gt[f, :], piano_roll_estimated[f, :]):
                num_chosen += 1
            if condense:
                if np.array_equal(gt[f, :], piano_roll_estimated_c[f, :]):
                    num_chosen_c += 1
    print("Correct select rate (normal)   : {}".format(num_chosen / num_available))
    if condense:
        print("Correct select rate (condensed): {}".format(num_chosen_c / num_available))

    ## Make audio for estimated piano roll
    piano_roll_estimated_full = np.zeros((128, seq_len), dtype=int)
    piano_roll_estimated_full[20:108, :] += piano_roll_estimated.transpose().astype(int)
    piano_roll_estimated_full *= 64
    print("Synthesizing audio for input...", end="")
    with suppress_stdout():
        pm = piano_roll_to_pretty_midi(piano_roll_estimated_full.astype(float), fs=2, program=52)
        audio = pm.fluidsynth(fs=fs_aud, sf2_path='/usr/share/soundfonts/FluidR3_GM.sf2')
    print("done.")
    wav.write("test_estimated_out.wav", fs_aud, audio)
    # ax_estimated.imshow(np.flip(piano_roll_estimated_full, axis=0))

    ## Sample a random sequence starting at the same initial latent state
    # x_vals = []
    # z_vals = []
    # piano_roll_sampled = np.zeros((seq_len, 88))
    # for k in range(seq_len):
    #     z_new, x_new = dmm.get_sample(z_prev, k)
    #     x_vals.append(x_new)
    #     z_vals.append(z_new)
    #     piano_roll_sampled[k, :] = x_new.data.numpy()
    #     z_prev = z_new
    #
    # # Get MIDI from sampled piano roll
    # piano_roll_sampled_full = np.zeros((128, seq_len), dtype=int)
    # piano_roll_sampled_full[20:108, :] += piano_roll_sampled.transpose().astype(int)
    # piano_roll_sampled_full *= 64
    # print("Synthesizing audio for input...", end="")
    # with suppress_stdout():
    #     pm = piano_roll_to_pretty_midi(piano_roll_sampled_full.astype(float), fs=1, program=52)
    #     audio = pm.fluidsynth(fs=fs_aud,
    #                           sf2_path='/usr/share/soundfonts/FluidR3_GM.sf2')
    # print('done.')
    # wav.write("test_sampled_out.wav", fs_aud, audio)
    # ax_sampled.imshow(np.flip(np.transpose(piano_roll_sampled), axis=0))

    # plt.tight_layout()
    plt.show()
def main(args):
    # setup logging
    log = get_logger(args.log)
    log(args)

    data = poly.load_data(poly.JSB_CHORALES)
    training_seq_lengths = data['train']['sequence_lengths']
    training_data_sequences = data['train']['sequences']
    test_seq_lengths = data['test']['sequence_lengths']
    test_data_sequences = data['test']['sequences']
    val_seq_lengths = data['valid']['sequence_lengths']
    val_data_sequences = data['valid']['sequences']
    N_train_data = len(training_seq_lengths)
    N_train_time_slices = float(torch.sum(training_seq_lengths))
    N_mini_batches = int(N_train_data / args.mini_batch_size +
                         int(N_train_data % args.mini_batch_size > 0))

    log("N_train_data: %d  avg. training seq. length: %.2f  N_mini_batches: %d" %
        (N_train_data, training_seq_lengths.float().mean(), N_mini_batches))

    # how often we do validation/test evaluation during training
    val_test_frequency = 50
    # the number of samples we use to do the evaluation
    n_eval_samples = 1

    # package repeated copies of val/test data for faster evaluation
    # (i.e. set us up for vectorization)
    def rep(x):
        rep_shape = torch.Size([x.size(0) * n_eval_samples]) + x.size()[1:]
        repeat_dims = [1] * len(x.size())
        repeat_dims[0] = n_eval_samples
        return x.repeat(repeat_dims).reshape(n_eval_samples, -1).transpose(1, 0).reshape(rep_shape)

    # get the validation/test data ready for the dmm: pack into sequences, etc.
    val_seq_lengths = rep(val_seq_lengths)
    test_seq_lengths = rep(test_seq_lengths)
    val_batch, val_batch_reversed, val_batch_mask, val_seq_lengths = poly.get_mini_batch(
        torch.arange(n_eval_samples * val_data_sequences.shape[0]),
        rep(val_data_sequences),
        val_seq_lengths, cuda=args.cuda)
    test_batch, test_batch_reversed, test_batch_mask, test_seq_lengths = poly.get_mini_batch(
        torch.arange(n_eval_samples * test_data_sequences.shape[0]),
        rep(test_data_sequences),
        test_seq_lengths, cuda=args.cuda)

    # instantiate the dmm
    dmm = DMM(rnn_dropout_rate=args.rnn_dropout_rate,
              num_iafs=args.num_iafs,
              iaf_dim=args.iaf_dim,
              use_cuda=args.cuda)

    # setup optimizer
    adam_params = {
        "lr": args.learning_rate,
        "betas": (args.beta1, args.beta2),
        "clip_norm": args.clip_norm,
        "lrd": args.lr_decay,
        "weight_decay": args.weight_decay
    }
    adam = ClippedAdam(adam_params)

    # setup inference algorithm
    if args.tmc:
        if args.jit:
            raise NotImplementedError("no JIT support yet for TMC")
        tmc_loss = TraceTMC_ELBO()
        dmm_guide = config_enumerate(dmm.guide, default="parallel",
                                     num_samples=args.tmc_num_samples, expand=False)
        svi = SVI(dmm.model, dmm_guide, adam, loss=tmc_loss)
    elif args.tmcelbo:
        if args.jit:
            raise NotImplementedError("no JIT support yet for TMC ELBO")
        elbo = TraceEnum_ELBO()
        dmm_guide = config_enumerate(dmm.guide, default="parallel",
                                     num_samples=args.tmc_num_samples, expand=False)
        svi = SVI(dmm.model, dmm_guide, adam, loss=elbo)
    else:
        elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
        svi = SVI(dmm.model, dmm.guide, adam, loss=elbo)

    # now we're going to define some functions we need to form the main training loop

    # saves the model and optimizer states to disk
    def save_checkpoint():
        log("saving model to %s..." % args.save_model)
        torch.save(dmm.state_dict(), args.save_model)
        log("saving optimizer states to %s..." % args.save_opt)
        adam.save(args.save_opt)
        log("done saving model and optimizer checkpoints to disk.")

    # loads the model and optimizer states from disk
    def load_checkpoint():
        assert exists(args.load_opt) and exists(args.load_model), \
            "--load-model and/or --load-opt misspecified"
        log("loading model from %s..." % args.load_model)
        dmm.load_state_dict(torch.load(args.load_model))
        log("loading optimizer states from %s..." % args.load_opt)
        adam.load(args.load_opt)
        log("done loading model and optimizer states.")

    # prepare a mini-batch and take a gradient step to minimize -elbo
    def process_minibatch(epoch, which_mini_batch, shuffled_indices):
        if args.annealing_epochs > 0 and epoch < args.annealing_epochs:
            # compute the KL annealing factor appropriate for the current mini-batch in the current epoch
            min_af = args.minimum_annealing_factor
            annealing_factor = min_af + (1.0 - min_af) * \
                (float(which_mini_batch + epoch * N_mini_batches + 1) /
                 float(args.annealing_epochs * N_mini_batches))
        else:
            # by default the KL annealing factor is unity
            annealing_factor = 1.0

        # compute which sequences in the training set we should grab
        mini_batch_start = (which_mini_batch * args.mini_batch_size)
        mini_batch_end = np.min([(which_mini_batch + 1) * args.mini_batch_size, N_train_data])
        mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end]
        # grab a fully prepped mini-batch using the helper function in the data loader
        mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
            = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                  training_seq_lengths, cuda=args.cuda)
        # do an actual gradient step
        loss = svi.step(mini_batch, mini_batch_reversed, mini_batch_mask,
                        mini_batch_seq_lengths, annealing_factor)
        # keep track of the training loss
        return loss

    # helper function for doing evaluation
    def do_evaluation():
        # put the RNN into evaluation mode (i.e. turn off drop-out if applicable)
        dmm.rnn.eval()

        # compute the validation and test loss n_samples many times
        val_nll = svi.evaluate_loss(val_batch, val_batch_reversed, val_batch_mask,
                                    val_seq_lengths) / float(torch.sum(val_seq_lengths))
        test_nll = svi.evaluate_loss(test_batch, test_batch_reversed, test_batch_mask,
                                     test_seq_lengths) / float(torch.sum(test_seq_lengths))

        # put the RNN back into training mode (i.e. turn on drop-out if applicable)
        dmm.rnn.train()
        return val_nll, test_nll

    # if checkpoint files provided, load model and optimizer states from disk before we start training
    if args.load_opt != '' and args.load_model != '':
        load_checkpoint()

    #################
    # TRAINING LOOP #
    #################
    times = [time.time()]
    for epoch in range(args.num_epochs):
        # if specified, save model and optimizer states to disk every checkpoint_freq epochs
        if args.checkpoint_freq > 0 and epoch > 0 and epoch % args.checkpoint_freq == 0:
            save_checkpoint()

        # accumulator for our estimate of the negative log likelihood (or rather -elbo) for this epoch
        epoch_nll = 0.0
        # prepare mini-batch subsampling indices for this epoch
        shuffled_indices = torch.randperm(N_train_data)

        # process each mini-batch; this is where we take gradient steps
        for which_mini_batch in range(N_mini_batches):
            epoch_nll += process_minibatch(epoch, which_mini_batch, shuffled_indices)

        # report training diagnostics
        times.append(time.time())
        epoch_time = times[-1] - times[-2]
        log("[training epoch %04d]  %.4f \t\t\t\t(dt = %.3f sec)" %
            (epoch, epoch_nll / N_train_time_slices, epoch_time))

        # do evaluation on test and validation data and report results
        if val_test_frequency > 0 and epoch > 0 and epoch % val_test_frequency == 0:
            val_nll, test_nll = do_evaluation()
            log("[val/test epoch %04d]  %.4f  %.4f" % (epoch, val_nll, test_nll))
def main(args): # clear param store pyro.clear_param_store() ### SETUP train_loader, test_loader = get_data() # setup the VAE vae = VAE(use_cuda=args.cuda) # setup the optimizer adam_args = {"lr": args.learning_rate} optimizer = Adam(adam_args) # setup the inference algorithm elbo = JitTrace_ELBO() if args.jit else Trace_ELBO() svi = SVI(vae.model, vae.guide, optimizer, loss=elbo) inputSize = 0 # setup visdom for visualization if args.visdom_flag: vis = visdom.Visdom() train_elbo = [] test_elbo = [] for epoch in range(args.num_epochs): # initialize loss accumulator epoch_loss = 0. # do a training epoch over each mini-batch x returned # by the data loader for step, batch in enumerate(train_loader): print("batch", batch) x, adj = 0, 0 # if on GPU put mini-batch into CUDA memory if args.cuda: x = batch['x'].cuda() adj = batch['edge_index'].cuda() else: x = batch['x'] adj = batch['edge_index'] print("x_shape", x.shape) print("adj_shape", adj.shape) inputSize = x.shape[0] * x.shape[1] epoch_loss += svi.step(x, adj) # report training diagnostics normalizer_train = len(train_loader.dataset) total_epoch_loss_train = epoch_loss / normalizer_train train_elbo.append(total_epoch_loss_train) print("[epoch %03d] average training loss: %.4f" % (epoch, total_epoch_loss_train)) if True: # if epoch % args.test_frequency == 0: # initialize loss accumulator test_loss = 0. # compute the loss over the entire test set for step, batch in enumerate(test_loader): x, adj = 0, 0 # if on GPU put mini-batch into CUDA memory if args.cuda: x = batch['x'].cuda() adj = batch['edge_index'].cuda() else: x = batch['x'] adj = batch['edge_index'] # compute ELBO estimate and accumulate loss # print('before evaluating test loss') test_loss += svi.evaluate_loss(x, adj) # print('after evaluating test loss') # pick three random test images from the first mini-batch and # visualize how well we're reconstructing them # if i == 0: # if args.visdom_flag: # plot_vae_samples(vae, vis) # reco_indices = np.random.randint(0, x.shape[0], 3) # for index in reco_indices: # test_img = x[index, :] # reco_img = vae.reconstruct_img(test_img) # vis.image(test_img.reshape(28, 28).detach().cpu().numpy(), # opts={'caption': 'test image'}) # vis.image(reco_img.reshape(28, 28).detach().cpu().numpy(), # opts={'caption': 'reconstructed image'}) if args.visdom_flag: plot_vae_samples(vae, vis) reco_indices = np.random.randint(0, x.shape[0], 3) for index in reco_indices: test_img = x[index, :] reco_img = vae.reconstruct_graph(test_img) vis.image(test_img.reshape(28, 28).detach().cpu().numpy(), opts={'caption': 'test image'}) vis.image(reco_img.reshape(28, 28).detach().cpu().numpy(), opts={'caption': 'reconstructed image'}) # report test diagnostics normalizer_test = len(test_loader.dataset) total_epoch_loss_test = test_loss / normalizer_test test_elbo.append(total_epoch_loss_test) print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test)) # if epoch == args.tsne_iter: # mnist_test_tsne(vae=vae, test_loader=test_loader) # plot_llk(np.array(train_elbo), np.array(test_elbo)) if args.save: torch.save( { 'epoch': epoch, 'model_state_dict': vae.state_dict(), 'optimzier_state_dict': optimizer.get_state(), 'train_loss': total_epoch_loss_train, 'test_loss': total_epoch_loss_test }, 'vae_' + args.name + str(args.time) + '.pt') return vae
def run_inference(dataset_obj: SingleCellRNACountsDataset,
                  args) -> RemoveBackgroundPyroModel:
    """Run a full inference procedure, training a latent variable model.

    Args:
        dataset_obj: Input data in the form of a SingleCellRNACountsDataset
            object.
        args: Input command line parsed arguments.

    Returns:
        model: cellbender.model.RemoveBackgroundPyroModel that has had
            inference run.
    """

    # Get the trimmed count matrix (transformed if called for).
    count_matrix = dataset_obj.get_count_matrix()

    # Configure pyro options (skip validations to improve speed).
    pyro.enable_validation(False)
    pyro.distributions.enable_validation(False)
    pyro.set_rng_seed(0)
    pyro.clear_param_store()

    # Set up the variational autoencoder:

    # Encoder.
    encoder_z = EncodeZ(input_dim=count_matrix.shape[1],
                        hidden_dims=args.z_hidden_dims,
                        output_dim=args.z_dim,
                        input_transform='normalize')
    encoder_d = EncodeD(input_dim=count_matrix.shape[1],
                        hidden_dims=args.d_hidden_dims,
                        output_dim=1,
                        log_count_crossover=dataset_obj.priors['log_counts_crossover'])

    if args.model == "simple":
        # If using the simple model, there is no need for p.
        encoder = CompositeEncoder({'z': encoder_z, 'd_loc': encoder_d})
    else:
        # Models that include empty droplets.
        encoder_p = EncodeNonEmptyDropletLogitProb(
            input_dim=count_matrix.shape[1],
            hidden_dims=args.p_hidden_dims,
            output_dim=1,
            input_transform='normalize',
            log_count_crossover=dataset_obj.priors['log_counts_crossover'])
        encoder = CompositeEncoder({'z': encoder_z,
                                    'd_loc': encoder_d,
                                    'p_y': encoder_p})

    # Decoder.
    decoder = Decoder(input_dim=args.z_dim,
                      hidden_dims=args.z_hidden_dims[::-1],
                      output_dim=count_matrix.shape[1])

    # Set up the pyro model for variational inference.
    model = RemoveBackgroundPyroModel(model_type=args.model,
                                      encoder=encoder,
                                      decoder=decoder,
                                      dataset_obj=dataset_obj,
                                      use_cuda=args.use_cuda)

    # Set up the optimizer.
    adam_args = {"lr": args.learning_rate}
    optimizer = ClippedAdam(adam_args)

    # Determine the loss function.
    if args.use_jit:
        loss_function = JitTraceEnum_ELBO(max_plate_nesting=1,
                                          strict_enumeration_warning=False)
    else:
        loss_function = TraceEnum_ELBO(max_plate_nesting=1)

    if args.model == "simple":
        if args.use_jit:
            loss_function = JitTrace_ELBO()
        else:
            loss_function = Trace_ELBO()

    # Set up the inference process.
    svi = SVI(model.model, model.guide, optimizer, loss=loss_function)

    # Load the dataset into DataLoaders.
    frac = args.training_fraction  # Fraction of barcodes to use for training
    batch_size = int(min(500, frac * dataset_obj.analyzed_barcode_inds.size / 2))
    train_loader, test_loader = \
        prep_data_for_training(dataset=count_matrix,
                               empty_drop_dataset=dataset_obj.get_count_matrix_empties(),
                               random_state=dataset_obj.random,
                               batch_size=batch_size,
                               training_fraction=frac,
                               fraction_empties=args.fraction_empties,
                               shuffle=True,
                               use_cuda=args.use_cuda)

    # Run training.
    run_training(model, svi, train_loader, test_loader,
                 epochs=args.epochs, test_freq=10)

    return model
def training(args, rel_embeddings, word_embeddings):
    if args.seed is not None:
        pyro.set_rng_seed(args.seed)

    # CUDA for PyTorch
    cuda_available = torch.cuda.is_available()
    if cuda_available and args.cuda:
        device = torch.device("cuda")
        torch.cuda.set_device(0)
        print("using gpu acceleration")

    print("Generating Config")
    config = Config(word_embeddings=torch.FloatTensor(word_embeddings),
                    decoder_hidden_dim=args.decoder_hidden_dim,
                    num_relations=7,
                    encoder_hidden_dim=args.encoder_hidden_dim,
                    num_predicates=1000,
                    batch_size=args.batch_size)

    # initialize the generator model
    generator = SimpleGenerator(config)

    # setup the optimizer
    adam_params = {"lr": args.learning_rate, "betas": (args.beta_1, 0.999)}
    optimizer = ClippedAdam(adam_params)

    # set up the loss(es) for inference. wrapping the guide in config_enumerate builds the loss as a sum
    # by enumerating each class label for the sampled discrete categorical distribution in the model
    if args.enumerate:
        guide = config_enumerate(generator.guide, args.enum_discrete, expand=True)
    else:
        guide = generator.guide
    elbo = (JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO)(max_plate_nesting=1)
    loss_basic = SVI(generator.model, guide, optimizer, loss=elbo)

    # build a list of all losses considered
    losses = [loss_basic]

    # aux_loss: whether to use the auxiliary loss from NIPS 14 paper (Kingma et al)
    if args.aux_loss:
        elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
        loss_aux = SVI(generator.model_identify, generator.guide_identify, optimizer, loss=elbo)
        losses.append(loss_aux)

    # prepare data
    real_model = RealModel(rel_embeddings, word_embeddings)
    data = real_model.generate_data()
    sup_train_set = data[:100]
    unsup_train_set = data[100:700]
    eval_set = data[700:900]
    test_set = data[900:]
    data_loaders = setup_data_loaders(sup_train_set, unsup_train_set, eval_set, test_set,
                                      batch_size=args.batch_size)
    num_train = len(sup_train_set) + len(unsup_train_set)
    num_eval = len(eval_set)
    num_test = len(test_set)

    # how often would a supervised batch be encountered during inference
    # e.g. if sup_num is 3000, we would have every 16th = int(50000/3000) batch supervised
    # until we have traversed through all the supervised batches
    periodic_interval_batches = int(1.0 * num_train / len(sup_train_set))

    # setup the logger if a filename is provided
    log_fn = "./logs/" + args.experiment_type + '/' + args.experiment_name + '.log'
    logger = open(log_fn, "w")

    # run inference for a certain number of epochs
    for i in tqdm(range(0, args.num_epochs)):
        # get the losses for an epoch
        epoch_losses_sup, epoch_losses_unsup = \
            train_epoch(data_loaders=data_loaders,
                        models=losses,
                        periodic_interval_batches=periodic_interval_batches)

        # compute average epoch losses i.e. losses per example
        avg_epoch_losses_sup = map(lambda v: v / len(sup_train_set), epoch_losses_sup)
        avg_epoch_losses_unsup = map(lambda v: v / len(unsup_train_set), epoch_losses_unsup)

        # store the loss and validation/testing accuracies in the logfile
        str_loss_sup = " ".join(map(str, avg_epoch_losses_sup))
        str_loss_unsup = " ".join(map(str, avg_epoch_losses_unsup))
        str_print = "{} epoch: avg losses {}".format(i, "{} {}".format(str_loss_sup, str_loss_unsup))
        print_and_log(logger, str_print)

    # save trained models
    torch.save(generator.state_dict(), './models/test_generator_state_dict.pth')

    return generator
def main(args): # clear param store pyro.clear_param_store() # setup Protein data loaders # train_loader, test_loader train_loader, test_loader = setup_data_loaders(batch_size=256, use_cuda=args.cuda) ramaPlot.plot_loader(train_loader) # setup the VAE vae = VAE(use_cuda=args.cuda) #set_trace() # setup optimizer adam_args = {"lr": args.learning_rate} optimizer = Adam(adam_args) # setup the inference algorithm elbo = JitTrace_ELBO() if args.jit else Trace_ELBO() svi = SVI(vae.model, vae.guide, optimizer, loss=elbo) """ for x, y in train_loader: print("x in train_loader:", x) set_trace() """ train_elbo = [] test_elbo = [] for epoch in range(args.num_epochs): # initialize loss accumulator epoch_loss = 0. # do a training epoch over each mini-batch x returned # by the data loader for x in train_loader: # if on GPU put mini-batch into CUDA memory # float() needed because it is currently double on that gives error # plus double datatype is considerably slower on GPU x = x.float() if args.cuda: x = x.float() x = x.cuda() # do ELBO gradient and accumulate los #set_trace() epoch_loss += svi.step(x) # report training diagnostics normalizer_train = len(train_loader.dataset) total_epoch_loss_train = epoch_loss / normalizer_train train_elbo.append(total_epoch_loss_train) print("[epoch %03d] average training loss: %.4f" % (epoch, total_epoch_loss_train)) if epoch % args.test_frequency == 0: # initialize loss accumulator test_loss = 0. for i, x in enumerate(test_loader): #set_trace() # float() needed because it is currently double on that gives error # plus double datatype is considerably slower on GPU x = x.float() # if on GPU put mini-batch into CUDA memory if args.cuda: x = x.cuda() # compute ELBO estimate and accumulate loss test_loss += svi.evaluate_loss(x) # pick three random test angles from the first mini-batch and # visualize how well we're reconstructing them if i == 0: #print("i = 0") #ramaPlot.plot_vae_samples(vae,str(epoch)) #ramaPlot.make_plots(train_elbo, train_loader, elbo, vae, x) ramaPlot.make_plots2(train_elbo, train_loader, elbo, vae) reco_indices = np.random.randint(0, x.shape[0], 3) for index in reco_indices: test_angle = x[index, :] reco_angle = vae.reconstruct_plot(test_angle) #set_trace() #set_trace() #ramaPlot.plot(test_angle, str(epoch)+"test_angle") #ramaPlot.plot(reco_angle, str(epoch)+"reconstructed_angle") # report test diagnostics normalizer_test = len(test_loader.dataset) total_epoch_loss_test = test_loss / normalizer_test test_elbo.append(total_epoch_loss_test) print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test)) """ if epoch == args.tsne_iter: mnist_test_tsne(vae=vae, test_loader=test_loader) plot_llk(np.array(train_elbo), np.array(test_elbo)) """ return vae
def run_inference(dataset_obj: Dataset, args) -> VariationalInferenceModel:
    """Run a full inference procedure, training a latent variable model.

    Args:
        dataset_obj: Input data in the form of a Dataset object.
        args: Input command line parsed arguments.

    Returns:
        model: cellbender.model.VariationalInferenceModel that has had
            inference run.
    """

    # Get the trimmed count matrix (transformed if called for).
    count_matrix = dataset_obj.get_count_matrix()

    # Configure pyro options (skip validations to improve speed).
    pyro.enable_validation(False)
    pyro.distributions.enable_validation(False)
    pyro.set_rng_seed(0)
    pyro.clear_param_store()

    # Set up the variational autoencoder:

    # Encoder.
    encoder_z = EncodeZ(input_dim=count_matrix.shape[1],
                        hidden_dims=args.z_hidden_dims,
                        output_dim=args.z_dim,
                        input_transform='normalize')
    encoder_d = EncodeD(input_dim=count_matrix.shape[1],
                        hidden_dims=args.d_hidden_dims,
                        output_dim=1,
                        log_count_crossover=dataset_obj.priors['log_counts_crossover'])

    if args.model[0] == "simple":
        # If using the simple model, there is no need for p.
        encoder = CompositeEncoder({'z': encoder_z, 'd_loc': encoder_d})
    else:
        # Models that include empty droplets.
        encoder_p = EncodePAmbient(
            input_dim=count_matrix.shape[1],
            hidden_dims=args.p_hidden_dims,
            output_dim=1,
            input_transform='normalize',
            log_count_crossover=dataset_obj.priors['log_counts_crossover'])
        encoder = CompositeEncoder({'z': encoder_z,
                                    'd_loc': encoder_d,
                                    'p_y': encoder_p})

    # Decoder.
    decoder = Decoder(input_dim=args.z_dim,
                      hidden_dims=args.z_hidden_dims[::-1],
                      output_dim=count_matrix.shape[1])

    # Set up the pyro model for variational inference.
    model = VariationalInferenceModel(
        model_type=args.model[0],
        encoder=encoder,
        decoder=decoder,
        dataset_obj=dataset_obj,
        use_decaying_avg_baseline=args.use_decaying_average_baseline,
        use_IAF=args.use_IAF,
        use_cuda=args.use_cuda)

    # Load the dataset into DataLoaders.
    frac = args.training_fraction  # Fraction of barcodes to use for training
    batch_size = int(min(500, frac * dataset_obj.analyzed_barcode_inds.size / 2))
    train_loader, test_loader = \
        prep_data_for_training(dataset=count_matrix,
                               empty_drop_dataset=dataset_obj.get_count_matrix_empties(),
                               batch_size=batch_size,
                               training_fraction=frac,
                               fraction_empties=args.fraction_empties,
                               shuffle=True,
                               use_cuda=args.use_cuda)

    # Run the guide once for Jit. (can hang on StopIteration if no test data!)
    # model.guide(test_loader.__iter__().__next__())  # This seems unnecessary

    # Set up the optimizer.
    adam_args = {"lr": args.learning_rate}
    optimizer = ClippedAdam(adam_args)

    # Determine the loss function.
    # loss_function = TraceEnum_ELBO(max_plate_nesting=1)
    loss_function = JitTraceEnum_ELBO(max_plate_nesting=1,
                                      strict_enumeration_warning=False)
    if args.model[0] == "simple":
        loss_function = JitTrace_ELBO()

    # Set up the inference process.
    svi = SVI(model.model, model.guide, optimizer, loss=loss_function)

    # Run training.
    run_training(model, svi, train_loader, test_loader,
                 epochs=args.epochs, test_freq=10)

    return model
    else:
        raise Exception(f'model {args.irt_model} not recognized')

    model = model_class(
        args.ability_dim,
        num_item,
        hidden_dim=args.hidden_dim,
        ability_merge=args.ability_merge,
        conditional_posterior=args.conditional_posterior,
        generative_model=args.generative_model,
        num_iafs=args.num_iafs,
        iaf_dim=args.iaf_dim,
    ).to(device)

    optimizer = Adam({"lr": args.lr})
    elbo = (JitTrace_ELBO(num_particles=args.num_particles)
            if args.jit else Trace_ELBO(num_particles=args.num_particles))
    svi = SVI(model.model, model.guide, optimizer, loss=elbo)

    def get_annealing_factor(epoch, which_mini_batch):
        if args.anneal_kl:
            annealing_factor = \
                (float(which_mini_batch + epoch * N_mini_batches + 1) /
                 float(args.epochs // 2 * N_mini_batches))
        else:
            annealing_factor = args.beta_kl
        return annealing_factor

    def train(epoch):
        model.train()
        train_loss = AverageMeter()
def main(args): # clear param store pyro.clear_param_store() #pyro.enable_validation(True) # train_loader, test_loader transform = {} transform["train"] = transforms.Compose([ transforms.Resize((400, 400)), transforms.ToTensor(), ]) transform["test"] = transforms.Compose( [transforms.Resize((400, 400)), transforms.ToTensor()]) train_loader, test_loader = setup_data_loaders( dataset=GameCharacterFullData, root_path=DATA_PATH, batch_size=32, transforms=transform) # setup the VAE vae = VAE(use_cuda=args.cuda, num_labels=17) # setup the optimizer adam_args = {"lr": args.learning_rate} optimizer = Adam(adam_args) # setup the inference algorithm elbo = JitTrace_ELBO() if args.jit else Trace_ELBO() svi = SVI(vae.model, vae.guide, optimizer, loss=elbo) # setup visdom for visualization if args.visdom_flag: vis = visdom.Visdom(port='8097') train_elbo = [] test_elbo = [] # training loop for epoch in range(args.num_epochs): # initialize loss accumulator epoch_loss = 0. # do a training epoch over each mini-batch x returned # by the data loader for x, y, actor, reactor, actor_type, reactor_type, action, reaction in train_loader: # if on GPU put mini-batch into CUDA memory if args.cuda: x = x.cuda() y = y.cuda() actor = actor.cuda() reactor = reactor.cuda() actor_type = actor_type.cuda() reactor_type = reactor_type.cuda() action = action.cuda() reaction = reaction.cuda() # do ELBO gradient and accumulate loss epoch_loss += svi.step(x, y, actor, reactor, actor_type, reactor_type, action, reaction) # report training diagnostics normalizer_train = len(train_loader.dataset) total_epoch_loss_train = epoch_loss / normalizer_train train_elbo.append(total_epoch_loss_train) print("[epoch %03d] average training loss: %.4f" % (epoch, total_epoch_loss_train)) if epoch % args.test_frequency == 0: # initialize loss accumulator test_loss = 0. # compute the loss over the entire test set for i, (x, y, actor, reactor, actor_type, reactor_type, action, reaction) in enumerate(test_loader): # if on GPU put mini-batch into CUDA memory if args.cuda: x = x.cuda() y = y.cuda() actor = actor.cuda() reactor = reactor.cuda() actor_type = actor_type.cuda() reactor_type = reactor_type.cuda() action = action.cuda() reaction = reaction.cuda() # compute ELBO estimate and accumulate loss test_loss += svi.evaluate_loss(x, y, actor, reactor, actor_type, reactor_type, action, reaction) # pick three random test images from the first mini-batch and # visualize how well we're reconstructing them if i == 0: if args.visdom_flag: plot_vae_samples(vae, vis) reco_indices = np.random.randint(0, x.shape[0], 3) for index in reco_indices: test_img = x[index, :] reco_img = vae.reconstruct_img(test_img) vis.image(test_img.reshape( 400, 400).detach().cpu().numpy(), opts={'caption': 'test image'}) vis.image(reco_img.reshape( 400, 400).detach().cpu().numpy(), opts={'caption': 'reconstructed image'}) # report test diagnostics normalizer_test = len(test_loader.dataset) total_epoch_loss_test = test_loss / normalizer_test test_elbo.append(total_epoch_loss_test) print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test)) return vae, optimizer