Ejemplo n.º 1
0
 def initialize_optimizer(self, lr):
     """Create the SVI trainer for this object's model/guide pair.

     :param lr: learning rate handed to the Adam optimizer
     :return: an ``SVI`` instance built from ``self.model`` and
         ``self.guide`` with the selected ELBO as its loss
     """
     adam = Adam({'lr': lr})
     # self.args.jit switches to the JIT-compiled ELBO implementation
     if self.args.jit:
         loss_fn = JitTrace_ELBO()
     else:
         loss_fn = Trace_ELBO()
     return SVI(self.model, self.guide, adam, loss=loss_fn)
Ejemplo n.º 2
0
Archivo: vae.py Proyecto: ucals/pyro
def main(args):
    """Train a VAE on MNIST with stochastic variational inference.

    Every epoch runs one SVI step per training mini-batch and prints the
    per-example training loss. Every ``args.test_frequency`` epochs the loss
    is also evaluated over the test loader and, when ``args.visdom_flag`` is
    set, samples plus three random reconstructions from the first test
    mini-batch are pushed to visdom. At epoch ``args.tsne_iter`` a t-SNE plot
    and the loss curves are produced.

    :param args: parsed command-line arguments; reads ``cuda``,
        ``learning_rate``, ``jit``, ``visdom_flag``, ``num_epochs``,
        ``test_frequency`` and ``tsne_iter``
    :return: the trained ``VAE`` instance
    """
    # start from an empty parameter store
    pyro.clear_param_store()

    # MNIST loaders for training and testing
    train_loader, test_loader = setup_data_loaders(
        MNIST, use_cuda=args.cuda, batch_size=256)

    # the model/guide container
    vae = VAE(use_cuda=args.cuda)

    # optimizer
    optimizer = Adam({"lr": args.learning_rate})

    # inference algorithm
    if args.jit:
        elbo = JitTrace_ELBO()
    else:
        elbo = Trace_ELBO()
    svi = SVI(vae.model, vae.guide, optimizer, loss=elbo)

    # the visdom client only exists when visualization was requested
    if args.visdom_flag:
        vis = visdom.Visdom()

    train_elbo, test_elbo = [], []

    for epoch in range(args.num_epochs):
        # ---- training pass: one gradient step per mini-batch ----
        running_train_loss = 0.
        for batch, _ in train_loader:
            if args.cuda:
                # move the mini-batch into CUDA memory
                batch = batch.cuda()
            running_train_loss += svi.step(batch)

        # report the per-example training loss
        avg_train_loss = running_train_loss / len(train_loader.dataset)
        train_elbo.append(avg_train_loss)
        print("[epoch %03d]  average training loss: %.4f" %
              (epoch, avg_train_loss))

        if epoch % args.test_frequency == 0:
            # ---- evaluation pass over the whole test set ----
            running_test_loss = 0.
            for batch_no, (batch, _) in enumerate(test_loader):
                if args.cuda:
                    # move the mini-batch into CUDA memory
                    batch = batch.cuda()
                running_test_loss += svi.evaluate_loss(batch)

                # only the first mini-batch is visualized: show samples and
                # reconstructions of three randomly chosen test images
                if batch_no == 0 and args.visdom_flag:
                    plot_vae_samples(vae, vis)
                    picks = np.random.randint(0, batch.shape[0], 3)
                    for pick in picks:
                        original = batch[pick, :]
                        rebuilt = vae.reconstruct_img(original)
                        original_np = original.reshape(
                            28, 28).detach().cpu().numpy()
                        vis.image(original_np,
                                  opts={'caption': 'test image'})
                        rebuilt_np = rebuilt.reshape(
                            28, 28).detach().cpu().numpy()
                        vis.image(rebuilt_np,
                                  opts={'caption': 'reconstructed image'})

            # report the per-example test loss
            avg_test_loss = running_test_loss / len(test_loader.dataset)
            test_elbo.append(avg_test_loss)
            print("[epoch %03d]  average test loss: %.4f" %
                  (epoch, avg_test_loss))

        if epoch == args.tsne_iter:
            mnist_test_tsne(vae=vae, test_loader=test_loader)
            plot_llk(np.array(train_elbo), np.array(test_elbo))

    return vae
Ejemplo n.º 3
0
def main(args):
    """
    run inference for SS-VAE

    Builds the SS-VAE and its SVI loss objects, trains for ``args.num_epochs``
    epochs, logs per-epoch average losses together with validation and test
    accuracy, reports the best validation accuracy (and the test accuracy
    observed at that epoch), and finally visualizes conditional samples.

    :param args: arguments for SS-VAE; reads ``seed``, ``visualize``,
        ``z_dim``, ``hidden_layers``, ``cuda``, ``enum_discrete``,
        ``aux_loss_multiplier``, ``learning_rate``, ``beta_1``, ``jit``,
        ``aux_loss``, ``logfile``, ``batch_size``, ``sup_num`` and
        ``num_epochs``
    :return: None
    """
    if args.seed is not None:
        pyro.set_rng_seed(args.seed)

    viz = None
    if args.visualize:
        viz = Visdom()
        mkdir_p("./vae_results")

    # batch_size: number of images (and labels) to be considered in a batch
    ss_vae = SSVAE(z_dim=args.z_dim,
                   hidden_layers=args.hidden_layers,
                   use_cuda=args.cuda,
                   config_enum=args.enum_discrete,
                   aux_loss_multiplier=args.aux_loss_multiplier)

    # setup the optimizer
    adam_params = {"lr": args.learning_rate, "betas": (args.beta_1, 0.999)}
    optimizer = Adam(adam_params)

    # set up the loss(es) for inference. wrapping the guide in config_enumerate builds the loss as a sum
    # by enumerating each class label for the sampled discrete categorical distribution in the model
    guide = config_enumerate(ss_vae.guide, args.enum_discrete, expand=True)
    Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
    elbo = Elbo(max_plate_nesting=1, strict_enumeration_warning=False)
    loss_basic = SVI(ss_vae.model, guide, optimizer, loss=elbo)

    # build a list of all losses considered
    losses = [loss_basic]

    # aux_loss: whether to use the auxiliary loss from NIPS 14 paper (Kingma et al)
    if args.aux_loss:
        elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
        loss_aux = SVI(ss_vae.model_classify,
                       ss_vae.guide_classify,
                       optimizer,
                       loss=elbo)
        losses.append(loss_aux)

    # Bind the name before entering the try block: if open() itself fails the
    # finally clause must not touch an unbound variable (that would replace
    # the real I/O error with an UnboundLocalError).
    logger = None
    try:
        # setup the logger if a filename is provided
        if args.logfile:
            logger = open(args.logfile, "w")

        data_loaders = setup_data_loaders(MNISTCached,
                                          args.cuda,
                                          args.batch_size,
                                          sup_num=args.sup_num)

        # how often would a supervised batch be encountered during inference
        # e.g. if sup_num is 3000, we would have every 16th = int(50000/3000) batch supervised
        # until we have traversed through the all supervised batches
        periodic_interval_batches = int(MNISTCached.train_data_size /
                                        (1.0 * args.sup_num))

        # number of unsupervised examples
        unsup_num = MNISTCached.train_data_size - args.sup_num

        # initializing local variables to maintain the best validation accuracy
        # seen across epochs over the supervised training set
        # and the corresponding testing set and the state of the networks
        best_valid_acc, corresponding_test_acc = 0.0, 0.0

        # run inference for a certain number of epochs
        for i in range(args.num_epochs):

            # get the losses for an epoch
            epoch_losses_sup, epoch_losses_unsup = \
                run_inference_for_epoch(data_loaders, losses, periodic_interval_batches)

            # compute average epoch losses i.e. losses per example and
            # format them for the log line
            str_loss_sup = " ".join(
                str(v / args.sup_num) for v in epoch_losses_sup)
            str_loss_unsup = " ".join(
                str(v / unsup_num) for v in epoch_losses_unsup)

            str_print = "{} epoch: avg losses {}".format(
                i, "{} {}".format(str_loss_sup, str_loss_unsup))

            validation_accuracy = get_accuracy(data_loaders["valid"],
                                               ss_vae.classifier,
                                               args.batch_size)
            str_print += " validation accuracy {}".format(validation_accuracy)

            # this test accuracy is only for logging, this is not used
            # to make any decisions during training
            test_accuracy = get_accuracy(data_loaders["test"],
                                         ss_vae.classifier, args.batch_size)
            str_print += " test accuracy {}".format(test_accuracy)

            # update the best validation accuracy and the corresponding
            # testing accuracy
            if best_valid_acc < validation_accuracy:
                best_valid_acc = validation_accuracy
                corresponding_test_acc = test_accuracy

            print_and_log(logger, str_print)

        final_test_accuracy = get_accuracy(data_loaders["test"],
                                           ss_vae.classifier, args.batch_size)
        print_and_log(
            logger,
            "best validation accuracy {} corresponding testing accuracy {} "
            "last testing accuracy {}".format(best_valid_acc,
                                              corresponding_test_acc,
                                              final_test_accuracy))

        # visualize the conditional samples
        visualize(ss_vae, viz, data_loaders["test"])
    finally:
        # close the logger file object only if we actually opened it
        if logger is not None:
            logger.close()
Ejemplo n.º 4
0
def main(args):
    """Transcribe one JSB-chorale test sequence with a DMM-driven particle filter.

    Steps performed:

    1. build the DMM (optionally restoring model/optimizer checkpoints);
    2. render the ground-truth piano roll of a fixed test sequence to audio
       (``test_ground_truth_out.wav``);
    3. turn that audio into a neural activity pattern and auto-correlated
       frames, and feed each frame to the DNN front end to obtain per-frame
       note observation probabilities;
    4. run a particle filter that proposes notes from the DMM and weights the
       particles with the observation probabilities;
    5. take the highest-weighted particle of every frame as the estimate,
       print precision/recall/F, render it to ``test_estimated_out.wav`` and
       show ground truth and estimate side by side.

    :param args: parsed command-line arguments; reads ``log``,
        ``mini_batch_size``, ``rnn_dropout_rate``, ``num_iafs``, ``iaf_dim``,
        ``cuda``, ``learning_rate``, ``beta1``, ``beta2``, ``clip_norm``,
        ``lr_decay``, ``weight_decay``, ``jit``, ``load_opt`` and
        ``load_model``
    :return: None
    """
    # setup logging
    log = get_logger(args.log)
    log(args)

    data = poly.load_data(poly.JSB_CHORALES)
    training_seq_lengths = data['train']['sequence_lengths']
    test_seq_lengths = data['test']['sequence_lengths']
    test_data_sequences = data['test']['sequences']
    N_train_data = len(training_seq_lengths)
    N_mini_batches = int(N_train_data / args.mini_batch_size +
                         int(N_train_data % args.mini_batch_size > 0))

    log("N_train_data: %d     avg. training seq. length: %.2f    N_mini_batches: %d"
        % (N_train_data, training_seq_lengths.float().mean(), N_mini_batches))

    ## instantiate the dmm
    dmm = dmm_model.DMM(rnn_dropout_rate=args.rnn_dropout_rate,
                        num_iafs=args.num_iafs,
                        iaf_dim=args.iaf_dim,
                        use_cuda=args.cuda)
    dmm.eval()

    # setup optimizer
    adam_params = {
        "lr": args.learning_rate,
        "betas": (args.beta1, args.beta2),
        "clip_norm": args.clip_norm,
        "lrd": args.lr_decay,
        "weight_decay": args.weight_decay
    }
    adam = ClippedAdam(adam_params)

    # setup inference algorithm
    elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
    svi = SVI(dmm.model, dmm.guide, adam, loss=elbo)

    # loads the model and optimizer states from disk
    def load_checkpoint():
        assert exists(args.load_opt) and exists(args.load_model), \
            "--load-model and/or --load-opt misspecified"
        log("loading model from %s..." % args.load_model)
        dmm.load_state_dict(torch.load(args.load_model))
        log("loading optimizer states from %s..." % args.load_opt)
        adam.load(args.load_opt)
        log("done loading model and optimizer states.")

    if args.load_opt != '' and args.load_model != '':
        load_checkpoint()

    #######################################
    # LOAD TRAINED MODEL AND SAMPLE FROM IT
    #######################################
    ## Basic parameters
    fs_aud = 12000  # sampling rate for audio rendering
    fig = plt.figure()
    ax_gt = fig.add_subplot(1, 2, 1)
    ax_estimated = fig.add_subplot(1, 2, 2)
    # NOTE(review): the condense=True branches below use `ax_dist` and
    # `ax_estimated_c`, which are only defined in these commented-out lines;
    # enabling `condense` without restoring them raises NameError.
    # ax_estimated_c = fig.add_subplot(2,2,4)
    # ax_dist        = fig.add_subplot(2,2,2)
    MIDI_lo = 21
    MIDI_hi = 21 + 87
    condense = False
    num_notes = MIDI_hi - MIDI_lo + 1

    ## Load the front-end model
    net = dnn.Net(ac_length=1024)
    net.eval()
    save_prefix = "dnn_frontend_poly_test"
    save_path = project_directory + "dnn_front_end/saved_models/"
    net.load_state_dict(torch.load(save_path + save_prefix + ".pt"))

    ## Select a testing set and collect data and initial distribution
    # idx = np.random.randint(test_seq_lengths.shape[0])
    idx = 22
    print("Using test sequence number {}".format(idx))
    seq_len = test_seq_lengths[idx].item()
    piano_roll_gt = test_data_sequences[idx, :seq_len, :].data.numpy()
    # the z1 distribution is computed from the sequence flipped along axis 1
    piano_roll_gt_rev = np.ascontiguousarray(np.flip(piano_roll_gt, axis=1))
    z1_dist = dmm.get_z1_dist(torch.tensor(piano_roll_gt_rev))

    ## Plot ground truth and render as audio
    piano_roll_gt_full = np.zeros((128, seq_len))
    piano_roll_gt_full[MIDI_lo:MIDI_hi +
                       1, :] += piano_roll_gt.transpose().astype(int)
    piano_roll_gt_full *= 64
    print("Sythensizing audio for input...", end="")
    with suppress_stdout():
        pm = piano_roll_to_pretty_midi(piano_roll_gt_full.astype(float),
                                       fs=2,
                                       program=52)
        audio = pm.fluidsynth(fs=fs_aud,
                              sf2_path='/usr/share/soundfonts/FluidR3_GM.sf2')
    print("done.")
    wav.write("test_ground_truth_out.wav", fs_aud, audio)
    ax_gt.imshow(np.flip(np.transpose(piano_roll_gt), axis=0))
    ax_gt.set_yticks(np.arange(1, 89)[::10])
    ax_gt.set_yticklabels(np.arange(88, 0, -1)[::10])

    ## Generate Neural Activity Patterns
    # Periphery Parameters
    x_lo = 0.05  # 45 Hz
    x_hi = 0.75  # 6050 Hz
    num_channels = 72
    print("Generating NAP...")
    nap, channel_cfs = pyc.carfac_nap(audio,
                                      float(fs_aud),
                                      num_sections=num_channels,
                                      x_lo=x_lo,
                                      x_hi=x_hi,
                                      b=0.5)
    print("Finished.")

    ## Generate auto-correlated frames
    len_sig_n = len(audio)
    len_frame_n = len_sig_n / (seq_len + 2)  # to account of PM's padding
    num_frames = seq_len
    # frame centre positions, first in samples and then in seconds
    c_times_n = np.arange(0, num_frames) * len_frame_n + int(len_frame_n // 2)
    c_times_t = c_times_n / fs_aud
    win_size_n = 2048  # analysis window size
    win_size_t = win_size_n / fs_aud
    print("Calculating frame data...")
    sac_frames = datagen.gen_nap_sac_frames(nap,
                                            float(fs_aud),
                                            c_times_t,
                                            win_size_t,
                                            normalize=True)
    # `frames` is not consumed below; the call is kept unchanged from the
    # original script (it was only used by optional debugging plots)
    frames = gen_nap_frames(nap, float(fs_aud), c_times_t, win_size_t)
    print("Finished.")
    assert len(sac_frames) == seq_len

    ## Generate the observation probabilities
    num_hops = piano_roll_gt.shape[0]
    assert num_hops == sac_frames.shape[0]
    obs_probs = torch.zeros((num_hops, 2, num_notes), requires_grad=False)
    for k in range(num_hops):
        # obs_probs[k,:,:] = midi_probs_nap_klap(frames[k], fs_aud, 1024, b=0.01)
        obs_probs[k, :, :] = midi_probs_from_signal_dnn(
            sac_frames[k],
            fs_aud,
            MIDI_lo,
            MIDI_hi,
            net,
            ac_size=1024,
            compression_factor=0.825,
            offset=0.05)

    # TEST: calculate the probability of the first notes
    piano_roll_gt = torch.from_numpy(piano_roll_gt).type(torch.long)

    # Particle Filtering Parameters
    # num_particles = 10000     # worked!
    num_particles = 500
    z_dim = 100
    x_dim = 88
    z = torch.ones((num_particles, z_dim), requires_grad=False)
    x = torch.ones((num_hops, num_particles, x_dim), requires_grad=False)
    w = torch.ones((num_hops, num_particles), requires_grad=False)
    w_naive = torch.ones((num_hops, num_particles), requires_grad=False)

    # Generate initial particles
    count = 0
    for p in range(num_particles):
        z_prev = pyro.sample("init_z", z1_dist)
        z[p, :], x[0, p, :] = dmm.get_sample(z_prev, p)
        num_same = torch.sum(
            x[0, p, :].type(torch.long) == piano_roll_gt[0]).item()
        if num_same == 88:
            count += 1
    print("Got {} correct samples in step {}".format(count, 0))
    count = 0

    ## Calculate initial weights
    ## Uniform
    # w[0,:] = 1.0/w.shape[1]

    ## Use calculates obs_probs (from DNN or elsewhere)
    for p in range(num_particles):
        prob = calc_obs_probs(obs_probs[0, :, :], x[0, p, :])
        w[0, p] *= prob
        w_naive[0, p] *= prob

    # Normalize weights
    w[0, :] = normalize_weights(w[0, :])
    w_naive[0, :] = normalize_weights(w_naive[0, :])

    ## Main Particle Filtering Loop
    # NOTE(review): good_samples[0] is never filled in (the step-0 count is
    # only printed above), so step 0 never counts as "available" in the
    # selection-rate statistic further down — confirm whether that is intended.
    good_samples = np.zeros(num_hops)
    for f in range(1, num_hops):
        ## Sample new particles
        z_vals, z_probs = particles_to_dist(z, w[f - 1, :])
        for p in range(num_particles):
            idx = discrete_sample(z_probs)
            z[p, :], x[f, p, :] = dmm.get_sample(z_vals[idx],
                                                 p + f * num_particles)
            num_same = torch.sum(
                x[f, p, :].type(torch.long) == piano_roll_gt[f]).item()
            if num_same == 88:
                count += 1

            ## Calculate Weights
            #- Primitive observation-free model of observations (more notes correct, higher prob)
            sx_prob = ((num_same - 78) / 10)**2
            w_naive[f, p] *= sx_prob  #*xz_prob
            #- Use Observation probabilities
            sx_prob = calc_obs_probs(obs_probs[f, :, :], x[f, p, :])
            w[f, p] *= sx_prob  #*xz_prob
            # Calculate probability of x given z
            # xz_dist = dist.Bernoulli(dmm.emitter(z[p,:]))
            # xz_prob = torch.exp(torch.sum(xz_dist.log_prob(x[f,p,:]))).item()

        # Report number of samples that corresponded with ground truth
        print("Got {} correct samples in step {} \t".format(count, f), end='')
        good_samples[f] = count
        count = 0

        ## Normalize
        # the observation-based weights are flattened with a 0.25 power
        # before normalization
        w[f, :] = w[f, :].pow(0.25)
        w[f, :] = normalize_weights(w[f, :])
        w_naive[f, :] = normalize_weights(w_naive[f, :])
        print("\tDone!")

    # Now pull out the most probable path
    piano_roll_dist = np.zeros((num_hops, 88))

    if condense:
        print("\"Condensing\" final distribution")
        w_c = []
        x_c = []
        for f in range(num_hops):
            w_condensed, x_condensed = make_final_dist(w[f, :], x[f, :, :])
            w_c.append(w_condensed)
            x_c.append(x_condensed)
            print("{} unique samples in step {}/{}".format(
                len(w_condensed), f + 1, num_hops))
            piano_roll_dist[f, :] = np.sum(x_condensed *
                                           w_condensed[:, np.newaxis],
                                           axis=0)
        ax_dist.imshow(np.flip(np.transpose(piano_roll_dist), axis=0))

    ## Most probable path by picking highest weighted particle
    piano_roll_estimated = np.zeros((num_hops, 88))
    piano_roll_estimated_c = np.zeros((num_hops, 88))
    for f in range(num_hops):
        ## just picking highest weight
        idx = np.argmax(w[f, :].numpy())
        piano_roll_estimated[f, :] = x[f, idx, :].numpy()
        ## picking from highest condensed weight
        if condense:
            idx = np.argmax(w_c[f])
            piano_roll_estimated_c[f, :] = x_c[f][idx, :]
    ax_estimated.imshow(np.flip(np.transpose(piano_roll_estimated), axis=0))
    ax_estimated.set_yticks(np.arange(1, 89)[::10])
    ax_estimated.set_yticklabels(np.arange(88, 0, -1)[::10])
    if condense:
        ax_estimated_c.imshow(
            np.flip(np.transpose(piano_roll_estimated_c), axis=0))

    ## Calculate and report precision and recall
    # distinct names so the metrics do not reuse the loop variables p and f
    precision, recall, f_metric = precision_recall_f(piano_roll_gt.numpy(),
                                                     piano_roll_estimated)
    print("Precision (regular): \t", precision)
    print("Recall (regular):    \t", recall)
    print("F-metric (regular):  \t", f_metric)

    ## Check how often the correct sample was chosen when available
    gt = piano_roll_gt.numpy()
    num_available = 0
    num_chosen = 0
    num_chosen_c = 0
    for f in range(num_hops):
        if good_samples[f] > 0:
            num_available += 1
            if np.array_equal(gt[f, :], piano_roll_estimated[f, :]):
                num_chosen += 1
            if condense:
                if np.array_equal(gt[f, :], piano_roll_estimated_c[f, :]):
                    num_chosen_c += 1
    if num_available > 0:
        print("Correct select rate (normal)   : {}".format(num_chosen /
                                                           num_available))
        if condense:
            print("Correct select rate (condensed): {}".format(
                num_chosen_c / num_available))
    else:
        # no step contained a fully correct particle, so the rate is
        # undefined; report that instead of dividing by zero
        print("Correct select rate (normal)   : n/a "
              "(no step produced a fully correct sample)")

    ## Make audio for estimate piano roll
    piano_roll_estimated_full = np.zeros((128, seq_len), dtype=int)
    # use the same MIDI rows as the ground-truth rendering above so the two
    # audio files are at the same pitch (the previous 20:108 slice was one
    # semitone lower than MIDI_lo:MIDI_hi + 1)
    piano_roll_estimated_full[
        MIDI_lo:MIDI_hi + 1, :] += piano_roll_estimated.transpose().astype(int)
    piano_roll_estimated_full *= 64
    print("Sythensizing audio for input...", end="")
    with suppress_stdout():
        pm = piano_roll_to_pretty_midi(piano_roll_estimated_full.astype(float),
                                       fs=2,
                                       program=52)
        audio = pm.fluidsynth(fs=fs_aud,
                              sf2_path='/usr/share/soundfonts/FluidR3_GM.sf2')
    print("done.")
    wav.write("test_estimated_out.wav", fs_aud, audio)

    plt.show()
Ejemplo n.º 5
0
Archivo: dmm.py Proyecto: ucals/pyro
def main(args):
    """Train the Deep Markov Model on the JSB chorales with SVI.

    Loads the data, prepares vectorized validation/test batches, builds the
    DMM together with a ``ClippedAdam`` optimizer and the requested ELBO
    flavour (TMC, TMC-ELBO or plain/JIT trace ELBO), optionally restores a
    checkpoint, and then runs ``args.num_epochs`` epochs of mini-batch
    training with optional KL annealing, periodic checkpointing and periodic
    validation/test evaluation.

    :param args: parsed command-line arguments; reads ``log``,
        ``mini_batch_size``, ``cuda``, ``rnn_dropout_rate``, ``num_iafs``,
        ``iaf_dim``, ``learning_rate``, ``beta1``, ``beta2``, ``clip_norm``,
        ``lr_decay``, ``weight_decay``, ``tmc``, ``tmcelbo``, ``jit``,
        ``tmc_num_samples``, ``save_model``, ``save_opt``, ``load_model``,
        ``load_opt``, ``annealing_epochs``, ``minimum_annealing_factor``,
        ``checkpoint_freq`` and ``num_epochs``
    :return: None
    """
    # setup logging
    log = get_logger(args.log)
    log(args)

    data = poly.load_data(poly.JSB_CHORALES)
    train_split = data['train']
    training_seq_lengths = train_split['sequence_lengths']
    training_data_sequences = train_split['sequences']
    test_split = data['test']
    test_seq_lengths = test_split['sequence_lengths']
    test_data_sequences = test_split['sequences']
    valid_split = data['valid']
    val_seq_lengths = valid_split['sequence_lengths']
    val_data_sequences = valid_split['sequences']

    N_train_data = len(training_seq_lengths)
    N_train_time_slices = float(torch.sum(training_seq_lengths))
    # one extra mini-batch when the data does not divide evenly
    has_partial_batch = int(N_train_data % args.mini_batch_size > 0)
    N_mini_batches = int(N_train_data / args.mini_batch_size +
                         has_partial_batch)

    log("N_train_data: %d     avg. training seq. length: %.2f    N_mini_batches: %d"
        % (N_train_data, training_seq_lengths.float().mean(), N_mini_batches))

    # evaluate on validation/test data every this many epochs
    val_test_frequency = 50
    # number of samples used for each evaluation
    n_eval_samples = 1

    def rep(x):
        """Repeat ``x`` ``n_eval_samples`` times along dim 0 so that the
        copies of each entry end up adjacent (set-up for vectorization)."""
        target_shape = torch.Size([x.size(0) * n_eval_samples]) + x.size()[1:]
        tiling = [1] * len(x.size())
        tiling[0] = n_eval_samples
        tiled = x.repeat(tiling)
        interleaved = tiled.reshape(n_eval_samples, -1).transpose(1, 0)
        return interleaved.reshape(target_shape)

    # get the validation/test data ready for the dmm: pack into sequences, etc.
    val_seq_lengths = rep(val_seq_lengths)
    test_seq_lengths = rep(test_seq_lengths)
    val_indices = torch.arange(n_eval_samples * val_data_sequences.shape[0])
    val_batch, val_batch_reversed, val_batch_mask, val_seq_lengths = \
        poly.get_mini_batch(val_indices,
                            rep(val_data_sequences),
                            val_seq_lengths,
                            cuda=args.cuda)
    test_indices = torch.arange(n_eval_samples * test_data_sequences.shape[0])
    test_batch, test_batch_reversed, test_batch_mask, test_seq_lengths = \
        poly.get_mini_batch(test_indices,
                            rep(test_data_sequences),
                            test_seq_lengths,
                            cuda=args.cuda)

    # instantiate the dmm
    dmm = DMM(rnn_dropout_rate=args.rnn_dropout_rate,
              num_iafs=args.num_iafs,
              iaf_dim=args.iaf_dim,
              use_cuda=args.cuda)

    # setup optimizer
    adam = ClippedAdam({
        "lr": args.learning_rate,
        "betas": (args.beta1, args.beta2),
        "clip_norm": args.clip_norm,
        "lrd": args.lr_decay,
        "weight_decay": args.weight_decay
    })

    # setup inference algorithm
    if args.tmc:
        if args.jit:
            raise NotImplementedError("no JIT support yet for TMC")
        tmc_loss = TraceTMC_ELBO()
        enumerated_guide = config_enumerate(dmm.guide,
                                            default="parallel",
                                            num_samples=args.tmc_num_samples,
                                            expand=False)
        svi = SVI(dmm.model, enumerated_guide, adam, loss=tmc_loss)
    elif args.tmcelbo:
        if args.jit:
            raise NotImplementedError("no JIT support yet for TMC ELBO")
        enum_elbo = TraceEnum_ELBO()
        enumerated_guide = config_enumerate(dmm.guide,
                                            default="parallel",
                                            num_samples=args.tmc_num_samples,
                                            expand=False)
        svi = SVI(dmm.model, enumerated_guide, adam, loss=enum_elbo)
    else:
        if args.jit:
            trace_elbo = JitTrace_ELBO()
        else:
            trace_elbo = Trace_ELBO()
        svi = SVI(dmm.model, dmm.guide, adam, loss=trace_elbo)

    # helper closures used by the training loop below

    def save_checkpoint():
        """Write the model and optimizer states to disk."""
        log("saving model to %s..." % args.save_model)
        torch.save(dmm.state_dict(), args.save_model)
        log("saving optimizer states to %s..." % args.save_opt)
        adam.save(args.save_opt)
        log("done saving model and optimizer checkpoints to disk.")

    def load_checkpoint():
        """Restore the model and optimizer states from disk."""
        assert exists(args.load_opt) and exists(args.load_model), \
            "--load-model and/or --load-opt misspecified"
        log("loading model from %s..." % args.load_model)
        dmm.load_state_dict(torch.load(args.load_model))
        log("loading optimizer states from %s..." % args.load_opt)
        adam.load(args.load_opt)
        log("done loading model and optimizer states.")

    def process_minibatch(epoch, which_mini_batch, shuffled_indices):
        """Prepare one mini-batch, take a gradient step on -elbo and return
        the loss reported by ``svi.step``."""
        # the KL annealing factor is unity unless we are still inside the
        # annealing phase
        annealing_factor = 1.0
        if args.annealing_epochs > 0 and epoch < args.annealing_epochs:
            # ramp from the minimum factor towards 1 over the annealing epochs
            min_af = args.minimum_annealing_factor
            steps_done = float(which_mini_batch + epoch * N_mini_batches + 1)
            steps_total = float(args.annealing_epochs * N_mini_batches)
            annealing_factor = min_af + (1.0 - min_af) * \
                (steps_done / steps_total)

        # slice out the training sequences belonging to this mini-batch
        first = (which_mini_batch * args.mini_batch_size)
        last = np.min([(which_mini_batch + 1) * args.mini_batch_size,
                       N_train_data])
        batch_indices = shuffled_indices[first:last]
        # let the data-loader helper build the fully prepped mini-batch
        batch, batch_reversed, batch_mask, batch_seq_lengths = \
            poly.get_mini_batch(batch_indices, training_data_sequences,
                                training_seq_lengths, cuda=args.cuda)
        # the actual gradient step; its loss is the training loss we track
        return svi.step(batch, batch_reversed, batch_mask,
                        batch_seq_lengths, annealing_factor)

    def do_evaluation():
        """Return the validation and test loss, each divided by the summed
        sequence lengths of its batch."""
        # put the RNN into evaluation mode (i.e. turn off drop-out if applicable)
        dmm.rnn.eval()

        val_loss = svi.evaluate_loss(val_batch, val_batch_reversed,
                                     val_batch_mask, val_seq_lengths)
        val_nll = val_loss / float(torch.sum(val_seq_lengths))
        test_loss = svi.evaluate_loss(test_batch, test_batch_reversed,
                                      test_batch_mask, test_seq_lengths)
        test_nll = test_loss / float(torch.sum(test_seq_lengths))

        # put the RNN back into training mode (i.e. turn on drop-out if applicable)
        dmm.rnn.train()
        return val_nll, test_nll

    # restore a checkpoint before training when both files were given
    if args.load_opt != '' and args.load_model != '':
        load_checkpoint()

    #################
    # TRAINING LOOP #
    #################
    last_tick = time.time()
    for epoch in range(args.num_epochs):
        # if specified, save model and optimizer states to disk every checkpoint_freq epochs
        checkpoint_due = (args.checkpoint_freq > 0 and epoch > 0
                          and epoch % args.checkpoint_freq == 0)
        if checkpoint_due:
            save_checkpoint()

        # running estimate of the negative log likelihood (or rather -elbo)
        epoch_nll = 0.0
        # fresh mini-batch subsampling order for this epoch
        shuffled_indices = torch.randperm(N_train_data)

        # one gradient step per mini-batch
        for batch_no in range(N_mini_batches):
            epoch_nll += process_minibatch(epoch, batch_no, shuffled_indices)

        # report training diagnostics
        now = time.time()
        epoch_time = now - last_tick
        last_tick = now
        log("[training epoch %04d]  %.4f \t\t\t\t(dt = %.3f sec)" %
            (epoch, epoch_nll / N_train_time_slices, epoch_time))

        # do evaluation on test and validation data and report results
        evaluation_due = (val_test_frequency > 0 and epoch > 0
                          and epoch % val_test_frequency == 0)
        if evaluation_due:
            val_nll, test_nll = do_evaluation()
            log("[val/test epoch %04d]  %.4f  %.4f" %
                (epoch, val_nll, test_nll))
Ejemplo n.º 6
0
def main(args):
    """Train a VAE with SVI on batches of ``x`` / ``edge_index`` tensors.

    Every epoch runs one pass of gradient steps over the training loader
    followed by a full evaluation pass over the test loader (the
    ``args.test_frequency`` gate is not consulted here). When ``args.save``
    is set, a checkpoint is written once the epoch loop has finished.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``cuda``, ``learning_rate``, ``jit``, ``visdom_flag``,
            ``num_epochs``, ``save``, ``name`` and ``time``.

    Returns:
        The ``VAE`` instance that was trained.
    """
    # clear param store
    pyro.clear_param_store()

    # data loaders
    train_loader, test_loader = get_data()

    # setup the VAE
    vae = VAE(use_cuda=args.cuda)

    # setup the optimizer
    optimizer = Adam({"lr": args.learning_rate})

    # setup the inference algorithm
    elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
    svi = SVI(vae.model, vae.guide, optimizer, loss=elbo)

    # setup visdom for visualization
    if args.visdom_flag:
        vis = visdom.Visdom()

    train_elbo = []
    test_elbo = []

    # Defaults so the checkpoint below is well-defined even when
    # args.num_epochs == 0 and the loop body never runs (previously this
    # raised NameError on the unbound loop variables).
    epoch = None
    total_epoch_loss_train = None
    total_epoch_loss_test = None

    for epoch in range(args.num_epochs):
        # ---- training pass: one gradient step per mini-batch ----
        epoch_loss = 0.
        for batch in train_loader:
            print("batch", batch)
            x = batch['x']
            adj = batch['edge_index']
            # if on GPU put mini-batch into CUDA memory
            if args.cuda:
                x = x.cuda()
                adj = adj.cuda()
            print("x_shape", x.shape)
            print("adj_shape", adj.shape)

            # do ELBO gradient and accumulate loss
            epoch_loss += svi.step(x, adj)

        # report training diagnostics (loss averaged per dataset item)
        normalizer_train = len(train_loader.dataset)
        total_epoch_loss_train = epoch_loss / normalizer_train
        train_elbo.append(total_epoch_loss_train)
        print("[epoch %03d]  average training loss: %.4f" %
              (epoch, total_epoch_loss_train))

        # ---- evaluation pass: runs after every epoch ----
        test_loss = 0.
        for batch in test_loader:
            x = batch['x']
            adj = batch['edge_index']
            # if on GPU put mini-batch into CUDA memory
            if args.cuda:
                x = x.cuda()
                adj = adj.cuda()

            # compute ELBO estimate and accumulate loss
            test_loss += svi.evaluate_loss(x, adj)

            # With visdom enabled, plot samples and three randomly chosen
            # reconstructions for every test mini-batch.
            # NOTE(review): the 28x28 reshape looks inherited from an image
            # example; confirm it matches the feature size used here.
            if args.visdom_flag:
                plot_vae_samples(vae, vis)
                reco_indices = np.random.randint(0, x.shape[0], 3)
                for index in reco_indices:
                    test_img = x[index, :]
                    reco_img = vae.reconstruct_graph(test_img)
                    vis.image(test_img.reshape(28,
                                               28).detach().cpu().numpy(),
                              opts={'caption': 'test image'})
                    vis.image(reco_img.reshape(28,
                                               28).detach().cpu().numpy(),
                              opts={'caption': 'reconstructed image'})

        # report test diagnostics (loss averaged per dataset item)
        normalizer_test = len(test_loader.dataset)
        total_epoch_loss_test = test_loss / normalizer_test
        test_elbo.append(total_epoch_loss_test)
        print("[epoch %03d]  average test loss: %.4f" %
              (epoch, total_epoch_loss_test))

    if args.save:
        # The misspelled 'optimzier_state_dict' key is kept on purpose so
        # that code reading existing checkpoints keeps working.
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': vae.state_dict(),
                'optimzier_state_dict': optimizer.get_state(),
                'train_loss': total_epoch_loss_train,
                'test_loss': total_epoch_loss_test
            }, 'vae_' + args.name + str(args.time) + '.pt')

    return vae
Ejemplo n.º 7
0
def run_inference(dataset_obj: SingleCellRNACountsDataset,
                  args) -> RemoveBackgroundPyroModel:
    """Build the requested model, train it with SVI, and hand it back.

    Args:
        dataset_obj: Dataset wrapper supplying the count matrix, the
            empty-droplet count matrix, the ``priors`` dict, the analyzed
            barcode indices and a random state.
        args: Parsed command-line arguments. Attributes read here:
            ``model``, ``z_hidden_dims``, ``z_dim``, ``d_hidden_dims``,
            ``p_hidden_dims``, ``use_cuda``, ``learning_rate``, ``use_jit``,
            ``training_fraction``, ``fraction_empties`` and ``epochs``.

    Returns:
        The ``RemoveBackgroundPyroModel`` after ``run_training`` has been
        called on it.

    """

    # Trimmed count matrix (transformed if called for).
    counts = dataset_obj.get_count_matrix()

    # Pyro configuration: validation is switched off for speed, the seed is
    # fixed, and any parameters left over from a previous run are dropped.
    pyro.enable_validation(False)
    pyro.distributions.enable_validation(False)
    pyro.set_rng_seed(0)
    pyro.clear_param_store()

    n_features = counts.shape[1]
    is_simple = args.model == "simple"

    # Encoders. They are created in the order z, d, (p) right after seeding.
    z_encoder = EncodeZ(input_dim=n_features,
                        hidden_dims=args.z_hidden_dims,
                        output_dim=args.z_dim,
                        input_transform='normalize')
    d_encoder = EncodeD(
        input_dim=n_features,
        hidden_dims=args.d_hidden_dims,
        output_dim=1,
        log_count_crossover=dataset_obj.priors['log_counts_crossover'])

    parts = {'z': z_encoder, 'd_loc': d_encoder}
    if not is_simple:
        # Models that include empty droplets additionally need p.
        parts['p_y'] = EncodeNonEmptyDropletLogitProb(
            input_dim=n_features,
            hidden_dims=args.p_hidden_dims,
            output_dim=1,
            input_transform='normalize',
            log_count_crossover=dataset_obj.priors['log_counts_crossover'])
    encoder = CompositeEncoder(parts)

    # Decoder: mirrors the hidden layer sizes of the z encoder.
    decoder = Decoder(input_dim=args.z_dim,
                      hidden_dims=args.z_hidden_dims[::-1],
                      output_dim=n_features)

    # Pyro model for variational inference.
    model = RemoveBackgroundPyroModel(model_type=args.model,
                                      encoder=encoder,
                                      decoder=decoder,
                                      dataset_obj=dataset_obj,
                                      use_cuda=args.use_cuda)

    # Optimizer.
    optimizer = ClippedAdam({"lr": args.learning_rate})

    # Loss: the simple model uses the plain trace ELBO, every other model
    # the enumerating one; --use_jit picks the JIT flavour of either.
    if is_simple:
        loss_function = JitTrace_ELBO() if args.use_jit else Trace_ELBO()
    elif args.use_jit:
        loss_function = JitTraceEnum_ELBO(max_plate_nesting=1,
                                          strict_enumeration_warning=False)
    else:
        loss_function = TraceEnum_ELBO(max_plate_nesting=1)

    # Inference driver.
    svi = SVI(model.model, model.guide, optimizer, loss=loss_function)

    # Data loaders. The batch size is capped at 500 and otherwise half the
    # number of training barcodes.
    frac = args.training_fraction  # Fraction of barcodes used for training
    half_training = frac * dataset_obj.analyzed_barcode_inds.size / 2
    batch_size = int(min(500, half_training))
    train_loader, test_loader = prep_data_for_training(
        dataset=counts,
        empty_drop_dataset=dataset_obj.get_count_matrix_empties(),
        random_state=dataset_obj.random,
        batch_size=batch_size,
        training_fraction=frac,
        fraction_empties=args.fraction_empties,
        shuffle=True,
        use_cuda=args.use_cuda)

    # Training.
    run_training(model,
                 svi,
                 train_loader,
                 test_loader,
                 epochs=args.epochs,
                 test_freq=10)

    return model
Ejemplo n.º 8
0
def training(args, rel_embeddings, word_embeddings):
    """Train a ``SimpleGenerator`` with SVI and save its state dict.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``seed``, ``cuda``, ``decoder_hidden_dim``,
            ``encoder_hidden_dim``, ``batch_size``, ``learning_rate``,
            ``beta_1``, ``enumerate``, ``enum_discrete``, ``jit``,
            ``aux_loss``, ``experiment_type``, ``experiment_name`` and
            ``num_epochs``.
        rel_embeddings: Passed through to ``RealModel``.
        word_embeddings: Passed to ``RealModel`` and, converted with
            ``torch.FloatTensor``, to ``Config``.

    Returns:
        The trained ``SimpleGenerator``; its ``state_dict`` is also written
        to ``./models/test_generator_state_dict.pth``.
    """
    if args.seed is not None:
        pyro.set_rng_seed(args.seed)
    # CUDA for PyTorch
    if torch.cuda.is_available() and args.cuda:
        torch.cuda.set_device(0)
        print("using gpu acceleration")

    print("Generating Config")
    config = Config(
        word_embeddings=torch.FloatTensor(word_embeddings),
        decoder_hidden_dim=args.decoder_hidden_dim,
        num_relations=7,
        encoder_hidden_dim=args.encoder_hidden_dim,
        num_predicates=1000,
        batch_size=args.batch_size
    )

    # initialize the generator model
    generator = SimpleGenerator(config)

    # setup the optimizer
    adam_params = {"lr": args.learning_rate, "betas": (args.beta_1, 0.999)}
    optimizer = ClippedAdam(adam_params)

    # set up the loss(es) for inference. wrapping the guide in config_enumerate builds the loss as a sum
    # by enumerating each class label for the sampled discrete categorical distribution in the model
    if args.enumerate:
        guide = config_enumerate(generator.guide, args.enum_discrete, expand=True)
    else:
        guide = generator.guide
    elbo = (JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO)(max_plate_nesting=1)
    loss_basic = SVI(generator.model, guide, optimizer, loss=elbo)

    # build a list of all losses considered
    losses = [loss_basic]

    # aux_loss: whether to use the auxiliary loss from NIPS 14 paper (Kingma et al)
    if args.aux_loss:
        elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
        loss_aux = SVI(generator.model_identify, generator.guide_identify, optimizer, loss=elbo)
        losses.append(loss_aux)

    # prepare data: the first 100 items are the supervised training set,
    # the next 600 the unsupervised one, then 200 for evaluation, and the
    # remainder for testing
    real_model = RealModel(rel_embeddings, word_embeddings)
    data = real_model.generate_data()
    sup_train_set = data[:100]
    unsup_train_set = data[100:700]
    eval_set = data[700:900]
    test_set = data[900:]

    data_loaders = setup_data_loaders(sup_train_set,
                                      unsup_train_set,
                                      eval_set,
                                      test_set,
                                      batch_size=args.batch_size)

    num_train = len(sup_train_set) + len(unsup_train_set)

    # how often a supervised batch is encountered during inference: the
    # ratio of all training items to supervised ones, truncated to an int
    # (7 when the data fills the 100/600 split above)
    periodic_interval_batches = int(1.0 * num_train / len(sup_train_set))

    log_fn = "./logs/" + args.experiment_type + '/' + args.experiment_name + '.log'

    # The log file is held open for the whole run; the context manager makes
    # sure it is flushed and closed even if an epoch raises.
    with open(log_fn, "w") as logger:
        # run inference for a certain number of epochs
        for i in tqdm(range(0, args.num_epochs)):
            # get the losses for an epoch
            epoch_losses_sup, epoch_losses_unsup = \
                train_epoch(data_loaders=data_loaders,
                            models=losses,
                            periodic_interval_batches=periodic_interval_batches)

            # average epoch losses, i.e. losses per example, one value per
            # entry in `losses`
            str_loss_sup = " ".join(
                str(v / len(sup_train_set)) for v in epoch_losses_sup)
            str_loss_unsup = " ".join(
                str(v / len(unsup_train_set)) for v in epoch_losses_unsup)

            # echo the averaged losses and append them to the logfile
            str_print = "{} epoch: avg losses {} {}".format(
                i, str_loss_sup, str_loss_unsup)
            print_and_log(logger, str_print)

    # save trained models
    torch.save(generator.state_dict(), './models/test_generator_state_dict.pth')
    return generator
Ejemplo n.º 9
0
def main(args):
    """Train a VAE with SVI and periodically evaluate and plot it.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``cuda``, ``learning_rate``, ``jit``, ``num_epochs`` and
            ``test_frequency``.

    Returns:
        The ``VAE`` instance that was trained.
    """
    # clear param store
    pyro.clear_param_store()

    # setup data loaders and plot the training data once up front
    train_loader, test_loader = setup_data_loaders(batch_size=256, use_cuda=args.cuda)
    ramaPlot.plot_loader(train_loader)

    # setup the VAE
    vae = VAE(use_cuda=args.cuda)

    # setup optimizer
    optimizer = Adam({"lr": args.learning_rate})

    # setup the inference algorithm
    elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
    svi = SVI(vae.model, vae.guide, optimizer, loss=elbo)

    train_elbo = []
    test_elbo = []

    for epoch in range(args.num_epochs):
        # initialize loss accumulator
        epoch_loss = 0.
        # do a training epoch over each mini-batch x returned
        # by the data loader
        for x in train_loader:
            # Cast once to float32: per the original author's note the
            # loader yields doubles, which raise an error downstream and
            # are slower on GPU.
            x = x.float()
            # if on GPU put mini-batch into CUDA memory
            if args.cuda:
                x = x.cuda()
            # do ELBO gradient and accumulate loss
            epoch_loss += svi.step(x)

        # report training diagnostics (loss averaged per dataset item)
        normalizer_train = len(train_loader.dataset)
        total_epoch_loss_train = epoch_loss / normalizer_train
        train_elbo.append(total_epoch_loss_train)
        print("[epoch %03d]  average training loss: %.4f" % (epoch, total_epoch_loss_train))

        if epoch % args.test_frequency == 0:
            # initialize loss accumulator
            test_loss = 0.
            for i, x in enumerate(test_loader):
                # same float32 cast as in the training loop
                x = x.float()
                # if on GPU put mini-batch into CUDA memory
                if args.cuda:
                    x = x.cuda()
                # compute ELBO estimate and accumulate loss
                test_loss += svi.evaluate_loss(x)

                # on the first mini-batch, refresh the plots and run the
                # reconstruction plot for three randomly chosen test rows
                if i == 0:
                    ramaPlot.make_plots2(train_elbo, train_loader, elbo, vae)
                    reco_indices = np.random.randint(0, x.shape[0], 3)
                    for index in reco_indices:
                        # called for its effect; the return value was
                        # never used
                        vae.reconstruct_plot(x[index, :])

            # report test diagnostics (loss averaged per dataset item)
            normalizer_test = len(test_loader.dataset)
            total_epoch_loss_test = test_loss / normalizer_test
            test_elbo.append(total_epoch_loss_test)
            print("[epoch %03d]  average test loss: %.4f" % (epoch, total_epoch_loss_test))
    return vae
Ejemplo n.º 10
0
def run_inference(dataset_obj: Dataset, args) -> VariationalInferenceModel:
    """Build the requested model, train it with JIT-compiled SVI, return it.

    Args:
        dataset_obj: Dataset wrapper supplying the count matrix, the
            empty-droplet count matrix, the ``priors`` dict and the
            analyzed barcode indices.
        args: Parsed command-line arguments. Attributes read here:
            ``model`` (only its first element is used), ``z_hidden_dims``,
            ``z_dim``, ``d_hidden_dims``, ``p_hidden_dims``,
            ``use_decaying_average_baseline``, ``use_IAF``, ``use_cuda``,
            ``training_fraction``, ``fraction_empties``, ``learning_rate``
            and ``epochs``.

    Returns:
        The ``VariationalInferenceModel`` after ``run_training`` has been
        called on it.

    """

    # Trimmed count matrix (transformed if called for).
    counts = dataset_obj.get_count_matrix()

    # Pyro configuration: validation is switched off for speed, the seed is
    # fixed, and any parameters left over from a previous run are dropped.
    pyro.enable_validation(False)
    pyro.distributions.enable_validation(False)
    pyro.set_rng_seed(0)
    pyro.clear_param_store()

    n_features = counts.shape[1]
    model_type = args.model[0]
    is_simple = model_type == "simple"

    # Encoders. They are created in the order z, d, (p) right after seeding.
    z_encoder = EncodeZ(input_dim=n_features,
                        hidden_dims=args.z_hidden_dims,
                        output_dim=args.z_dim,
                        input_transform='normalize')
    d_encoder = EncodeD(
        input_dim=n_features,
        hidden_dims=args.d_hidden_dims,
        output_dim=1,
        log_count_crossover=dataset_obj.priors['log_counts_crossover'])

    parts = {'z': z_encoder, 'd_loc': d_encoder}
    if not is_simple:
        # Models that include empty droplets additionally need p.
        parts['p_y'] = EncodePAmbient(
            input_dim=n_features,
            hidden_dims=args.p_hidden_dims,
            output_dim=1,
            input_transform='normalize',
            log_count_crossover=dataset_obj.priors['log_counts_crossover'])
    encoder = CompositeEncoder(parts)

    # Decoder: mirrors the hidden layer sizes of the z encoder.
    decoder = Decoder(input_dim=args.z_dim,
                      hidden_dims=args.z_hidden_dims[::-1],
                      output_dim=n_features)

    # Pyro model for variational inference.
    model = VariationalInferenceModel(
        model_type=model_type,
        encoder=encoder,
        decoder=decoder,
        dataset_obj=dataset_obj,
        use_decaying_avg_baseline=args.use_decaying_average_baseline,
        use_IAF=args.use_IAF,
        use_cuda=args.use_cuda)

    # Data loaders. The batch size is capped at 500 and otherwise half the
    # number of training barcodes.
    frac = args.training_fraction
    half_training = frac * dataset_obj.analyzed_barcode_inds.size / 2
    batch_size = int(min(500, half_training))
    train_loader, test_loader = prep_data_for_training(
        dataset=counts,
        empty_drop_dataset=dataset_obj.get_count_matrix_empties(),
        batch_size=batch_size,
        training_fraction=frac,
        fraction_empties=args.fraction_empties,
        shuffle=True,
        use_cuda=args.use_cuda)

    # Optimizer.
    optimizer = ClippedAdam({"lr": args.learning_rate})

    # Loss: always a JIT ELBO here; the simple model uses the plain trace
    # variant, every other model the enumerating one.
    if is_simple:
        loss_function = JitTrace_ELBO()
    else:
        loss_function = JitTraceEnum_ELBO(max_plate_nesting=1,
                                          strict_enumeration_warning=False)

    # Inference driver.
    svi = SVI(model.model, model.guide, optimizer, loss=loss_function)

    # Training.
    run_training(model,
                 svi,
                 train_loader,
                 test_loader,
                 epochs=args.epochs,
                 test_freq=10)

    return model
Ejemplo n.º 11
0
    else:
        raise Exception(f'model {args.irt_model} not recognized')

    model = model_class(
        args.ability_dim,
        num_item,
        hidden_dim=args.hidden_dim,
        ability_merge=args.ability_merge,
        conditional_posterior=args.conditional_posterior,
        generative_model=args.generative_model,
        num_iafs=args.num_iafs,
        iaf_dim=args.iaf_dim,
    ).to(device)

    optimizer = Adam({"lr": args.lr})
    elbo = (JitTrace_ELBO(num_particles=args.num_particles)
            if args.jit else Trace_ELBO(num_particles=args.num_particles))
    svi = SVI(model.model, model.guide, optimizer, loss=elbo)

    def get_annealing_factor(epoch, which_mini_batch):
        """Return the KL weight to use for one mini-batch of one epoch.

        Without ``args.anneal_kl`` the constant ``args.beta_kl`` is
        returned. With it, the weight is the 1-based count of mini-batches
        seen so far divided by the number of mini-batches in the first
        ``args.epochs // 2`` epochs. The ratio is not clamped, and a zero
        denominator raises ``ZeroDivisionError``.
        """
        if not args.anneal_kl:
            return args.beta_kl
        batches_seen = which_mini_batch + epoch * N_mini_batches + 1
        ramp_length = args.epochs // 2 * N_mini_batches
        return float(batches_seen) / float(ramp_length)

    def train(epoch):
        model.train()
        train_loss = AverageMeter()
Ejemplo n.º 12
0
def main(args):
    """Train a 17-label VAE with SVI on 400x400 images and evaluate it.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``cuda``, ``learning_rate``, ``jit``, ``visdom_flag``,
            ``num_epochs`` and ``test_frequency``.

    Returns:
        A ``(vae, optimizer)`` tuple: the trained model and the optimizer
        wrapper that was used to train it.
    """
    # clear param store
    pyro.clear_param_store()

    def make_pipeline():
        # resize to 400x400, then convert to a tensor
        return transforms.Compose([
            transforms.Resize((400, 400)),
            transforms.ToTensor(),
        ])

    # the train and test splits use the same preprocessing
    transform = {"train": make_pipeline(), "test": make_pipeline()}

    train_loader, test_loader = setup_data_loaders(
        dataset=GameCharacterFullData,
        root_path=DATA_PATH,
        batch_size=32,
        transforms=transform)

    # model, optimizer and inference algorithm
    vae = VAE(use_cuda=args.cuda, num_labels=17)
    optimizer = Adam({"lr": args.learning_rate})
    elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
    svi = SVI(vae.model, vae.guide, optimizer, loss=elbo)

    # visdom is only set up when visualization was requested
    if args.visdom_flag:
        vis = visdom.Visdom(port='8097')

    def on_device(batch):
        """Unpack the eight tensors of a mini-batch, moving them to the GPU
        when ``args.cuda`` is set, and return them as a tuple."""
        x, y, actor, reactor, actor_type, reactor_type, action, reaction = batch
        fields = (x, y, actor, reactor, actor_type, reactor_type, action,
                  reaction)
        if args.cuda:
            fields = tuple(t.cuda() for t in fields)
        return fields

    train_elbo = []
    test_elbo = []
    # training loop
    for epoch in range(args.num_epochs):
        # one gradient step per training mini-batch, accumulating the loss
        epoch_loss = 0.
        for batch in train_loader:
            epoch_loss += svi.step(*on_device(batch))

        # report training diagnostics (loss averaged per dataset item)
        total_epoch_loss_train = epoch_loss / len(train_loader.dataset)
        train_elbo.append(total_epoch_loss_train)
        print("[epoch %03d]  average training loss: %.4f" %
              (epoch, total_epoch_loss_train))

        # evaluation only happens every args.test_frequency epochs
        if epoch % args.test_frequency != 0:
            continue

        # compute the loss over the entire test set
        test_loss = 0.
        for i, batch in enumerate(test_loader):
            inputs = on_device(batch)
            test_loss += svi.evaluate_loss(*inputs)

            # pick three random test images from the first mini-batch and
            # visualize how well we're reconstructing them
            if i == 0 and args.visdom_flag:
                x = inputs[0]
                plot_vae_samples(vae, vis)
                for index in np.random.randint(0, x.shape[0], 3):
                    test_img = x[index, :]
                    reco_img = vae.reconstruct_img(test_img)
                    for img, caption in ((test_img, 'test image'),
                                         (reco_img, 'reconstructed image')):
                        vis.image(img.reshape(400,
                                              400).detach().cpu().numpy(),
                                  opts={'caption': caption})

        # report test diagnostics (loss averaged per dataset item)
        total_epoch_loss_test = test_loss / len(test_loader.dataset)
        test_elbo.append(total_epoch_loss_test)
        print("[epoch %03d]  average test loss: %.4f" %
              (epoch, total_epoch_loss_test))

    return vae, optimizer