Example 1
def run_variational_network(train_images, N_weights_dec, decoder,
                            decoder_log_like, trained_weights, all_mean):
    start_time = time.time()

    # Create aevb function
    # Training parameters

    D = train_images.shape[1]

    enc_layers = [D, hidden_units, hidden_units, 2 * latent_dimensions]

    N_weights_enc, encoder, encoder_log_like = make_gaussian_nn(enc_layers)

    # Optimize aevb
    batch_size = 100
    num_training_iters = 1600
    rs = npr.RandomState(0)

    parser = WeightsParser()
    parser.add_shape('encoding weights', (N_weights_enc,))
    initial_enc_w = rs.randn(len(parser)) * param_scale

    batch_idxs = make_batches(train_images.shape[0], batch_size)
    # Note: `all_cov` is not defined in this scope (the function receives
    # `all_mean`), and `banded_cov` is never used below.
    banded_cov = create_banded_cov(all_cov.shape[0], 10)
    log_prior = create_data_L()
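    # Stochastic estimate of the encoder-only lower bound: the decoder weights
    # (`trained_weights`) stay fixed and only the encoder weights `enc_w` are
    # optimized; `iter` cycles through the precomputed minibatch index list.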
    def batch_value_and_grad(enc_w, iter):
        iter = iter % len(batch_idxs)
        cur_data = train_images[batch_idxs[iter]]
        return enc_lower_bound(enc_w, trained_weights, encoder,
                               decoder_log_like, log_prior, N_weights_enc,
                               cur_data, samples_per_image,
                               latent_dimensions, rs)
    lb_grad = grad(batch_value_and_grad)

    def callback(params, i, grad):
        ml = batch_value_and_grad(params,i)
        print "----- log marginal likelihood:", ml
        #Generate samples
        num_samples = 100
        images_per_row = 10
        # zs = train_images[0:100,:]
        zs = np.zeros((100,10))
        zs[:,1] = 1
        (mus,log_sigs) = encoder(params, zs)
        # sigs = np.exp(log_sigs)
        # noise = rs.randn(1,100,784)
        # samples = mus + sigs*noise
        # samples = np.reshape(samples,(100*1,784),order = 'F')
        samples = mus

        fig = plt.figure(1)
        fig.clf()
        ax = fig.add_subplot(111)
        plot_images(samples, ax, ims_per_row=images_per_row)
        plt.savefig('samples.png')

    final_params = adam(lb_grad, initial_enc_w, num_training_iters, callback=callback)

    finish_time = time.time()
    print "total runtime", finish_time - start_time

    # Note: only 'encoding weights' is registered in this parser, so the
    # lookup of 'decoding weights' below looks like a leftover from the
    # joint-training variant.
    def decoder_with_weights(zs):
        return decoder(parser.get(final_params, 'decoding weights'), zs)
    return decoder_with_weights
Example 2
def run_aevb(train_images):
    start_time = time.time()

    # Optimize aevb
    batch_size = 100
    num_training_iters = 2 * 640
    rs = npr.RandomState(0)

    parser = WeightsParser()
    parser.add_shape('encoding weights', (N_weights_enc, ))
    parser.add_shape('decoding weights', (N_weights_dec, ))
    initial_combined_weights = rs.randn(len(parser)) * param_scale

    batch_idxs = make_batches(train_images.shape[0], batch_size)

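    # Monte Carlo estimate of the variational lower bound on one minibatch;
    # autograd's grad() below differentiates it w.r.t. the combined
    # encoder/decoder weight vector.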
    def batch_value_and_grad(weights, iter):
        iter = iter % len(batch_idxs)
        cur_data = train_images[batch_idxs[iter]]
        return lower_bound(weights, encoder, decoder_log_like, N_weights_enc,
                           cur_data, samples_per_image, latent_dimensions, rs)

    lb_grad = grad(batch_value_and_grad)

    def callback(params, i, grad):
        ml = batch_value_and_grad(params, i)
        print "log marginal likelihood:", ml

        #Generate samples
        num_samples = 100
        images_per_row = 10
        zs = rs.randn(num_samples, latent_dimensions)
        samples = decoder(parser.get(params, 'decoding weights'), zs)
        fig = plt.figure(1)
        fig.clf()
        ax = fig.add_subplot(111)
        plot_images(samples, ax, ims_per_row=images_per_row)
        plt.savefig('samples.png')

    final_params = adam(lb_grad,
                        initial_combined_weights,
                        num_training_iters,
                        callback=callback)

    #Validation loss:
    print '--- test loss:', lower_bound(final_params, encoder,
                                        decoder_log_like, N_weights_enc,
                                        test_images[0:100, :],
                                        samples_per_image, latent_dimensions,
                                        rs)

    parameters = final_params, N_weights_enc, samples_per_image, latent_dimensions, rs
    save_string = 'parameters50.pkl'
    print 'SAVING AS: ', save_string
    print 'LATENTS DIMS', latent_dimensions
    with open(save_string, 'w') as f:
        pickle.dump(parameters, f, 1)

    finish_time = time.time()
    print "total runtime", finish_time - start_time
Example 3
def build_grad_sampler(D, num_steps, approx):

    # Build parser
    parser = WeightsParser()
    parser.add_shape('mean', D)
    parser.add_shape('log_stddev', D)
    parser.add_shape('log_stepsize', 1)

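    # Draw initial samples from a diagonal Gaussian (learned mean and
    # log-stddev), then refine them with num_steps of entropy-tracked gradient
    # ascent on the target log-likelihood.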
    def sample_and_run_grad(params,
                            loglik_func,
                            rs,
                            num_images,
                            samples_per_image,
                            callback=None):
        gradfun = elementwise_grad(loglik_func)
        mean = parser.get(params, 'mean')
        stddevs = np.exp(parser.get(params, 'log_stddev'))
        stepsizes = np.exp(parser.get(params, 'log_stepsize'))

        initial_entropies = np.full(num_images * samples_per_image,
                                    entropy_of_a_diagonal_gaussian(stddevs))
        init_xs = mean + rs.randn(num_images * samples_per_image, D) * stddevs
        samples, entropy_estimates = \
            gradient_ascent_entropic(gradfun, entropies=initial_entropies, xs=init_xs,
                                     stepsizes=stepsizes,num_steps=num_steps,
                                     rs=rs, callback=callback, approx=approx)

        loglik_estimates = loglik_func(samples)
        return samples, loglik_estimates, entropy_estimates

    return sample_and_run_grad, parser
Example 4
def run_aevb(train_images):
    start_time = time.time()

    # Create aevb function
    # Training parameters

    D = train_images.shape[1]
    rs = npr.RandomState(0)
    sample_and_run_grad = build_mult_grad_sampler(latent_dimensions,
                                                  1,
                                                  approx=True)

    enc_layers = [D, hidden_units, 2 * latent_dimensions]
    dec_layers = [latent_dimensions, hidden_units, D]
    N_weights_NN, encoder_NN, encoder_log_like_NN = make_gaussian_nn(
        enc_layers)
    N_weights_dec, decoder, decoder_log_like = make_binary_nn(dec_layers)
    encoder = sample_and_run_grad

    #Create parser
    parser = WeightsParser()
    parser.add_shape('encoding network weights', (N_weights_NN, ))
    N_weights_enc = len(parser)
    parser.add_shape('decoding weights', (N_weights_dec, ))
    params = rs.randn(len(parser)) * param_scale

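    # Two-part encoder: the recognition network predicts a diagonal Gaussian
    # (mus, sigs), and the gradient-based sampler then refines samples drawn
    # from it against the supplied log-likelihood.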
    def two_part_encode(params, data, log_lik_func, rs, num_samples):
        network_weights = parser.get(params, 'encoding network weights')
        (mus, log_sigs) = encoder_NN(network_weights, data)
        sigs = np.exp(log_sigs)
        return sample_and_run_grad(mus, sigs, .01, log_lik_func, rs,
                                   num_samples)

    # Optimize aevb
    batch_size = 100
    num_training_iters = 1600
    rs = npr.RandomState(0)

    batch_idxs = make_batches(train_images.shape[0], batch_size)

    def batch_value_and_grad(weights, iter):
        iter = iter % len(batch_idxs)
        cur_data = train_images[batch_idxs[iter]]
        return lower_bound(weights, two_part_encode, decoder_log_like,
                           N_weights_enc, cur_data, samples_per_image,
                           latent_dimensions, rs)

    lb_grad = grad(batch_value_and_grad)

    def callback(params, i, grad):
        ml = batch_value_and_grad(params, i)
        print "log marginal likelihood:", ml

        #Generate samples
        num_samples = 100
        images_per_row = 10
        zs = rs.randn(num_samples, latent_dimensions)
        # samples = np.random.binomial(1,decoder(parser.get(params, 'decoding weights'), zs))
        samples = decoder(parser.get(params, 'decoding weights'), zs)
        fig = plt.figure(1)
        fig.clf()
        ax = fig.add_subplot(111)
        plot_images(samples, ax, ims_per_row=images_per_row)
        plt.savefig('samples.png')

    final_params = adam(lb_grad, params, num_training_iters, callback=callback)

    finish_time = time.time()
    print "total runtime", finish_time - start_time

    def decoder_with_weights(zs):
        return decoder(parser.get(final_params, 'decoding weights'), zs)

    return decoder_with_weights
Example 5
def run_aevb(train_images):

    # run_aevb(train_images)

    start_time = time.time()

    # Create aevb function
    # Training parameters

    D = train_images.shape[1]

    dec_layers = [latent_dimensions, hidden_units, hidden_units, D]

    mean = np.zeros(latent_dimensions)
    log_stddevs = np.log(1.0 * np.ones(latent_dimensions))
    log_stepsize = np.log(.005)

    rs = npr.RandomState(0)
    sample_and_run_es = build_early_stop_fixed_params(
        latent_dimensions,
        approx=True,
        mean=mean,
        log_stddevs=log_stddevs,
        log_stepsize=log_stepsize)
    N_weights_dec, decoder, decoder_log_like = make_binary_nn(dec_layers)
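    # The "encoder" here is a fixed-parameter sampler with early stopping, so
    # there are no learned encoder weights (N_weights_enc = 0 below).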
    N_weights_enc = 0
    encoder = sample_and_run_es
    # Build parser
    parser = WeightsParser()
    parser.add_shape('decoding weights', (N_weights_dec, ))

    params = np.zeros(len(parser))
    parser.put(params, 'decoding weights',
               rs.randn(N_weights_dec) * param_scale)
    assert len(parser) == N_weights_dec

    # Optimize aevb
    batch_size = 1
    num_training_iters = 1600
    rs = npr.RandomState(0)

    batch_idxs = make_batches(train_images.shape[0], batch_size)

    def batch_value_and_grad(weights, iter):
        iter = iter % len(batch_idxs)
        cur_data = train_images[batch_idxs[iter]]
        return lower_bound(weights, encoder, decoder_log_like, N_weights_enc,
                           cur_data, samples_per_image, latent_dimensions, rs)

    lb_grad = grad(batch_value_and_grad)

    def callback(params, i, grad):
        n_iter = 0.0
        sum_ml = 0
        for j in xrange(0, 1):
            ml = batch_value_and_grad(params, j)
            print "---- log marginal likelihood:", ml
            n_iter += 1
            sum_ml += ml
            print '-------- avg_ml', sum_ml / n_iter

        #Generate samples
        num_samples = 100
        images_per_row = 10
        zs = rs.randn(num_samples, latent_dimensions)
        # samples = np.random.binomial(1,decoder(parser.get(params, 'decoding weights'), zs))
        samples = decoder(parser.get(params, 'decoding weights'), zs)
        fig = plt.figure(1)
        fig.clf()
        ax = fig.add_subplot(111)
        plot_images(samples, ax, ims_per_row=images_per_row)
        plt.savefig('samples.png')

    final_params = adam(lb_grad, params, num_training_iters, callback=callback)

    finish_time = time.time()
    print "total runtime", finish_time - start_time

    def decoder_with_weights(zs):
        return decoder(parser.get(final_params, 'decoding weights'), zs)

    return decoder_with_weights
Example 6
def run_aevb(train_images):
    start_time = time.time()

    # Create aevb function
    # Training parameters

    D = train_images.shape[1]

    enc_layers = [D, hidden_units, 2 * latent_dimensions]
    dec_layers = [latent_dimensions, hidden_units, D]

    N_weights_enc, encoder, encoder_log_like = make_gaussian_nn(enc_layers)
    N_weights_dec, decoder, decoder_log_like = make_binary_nn(dec_layers)

    # Optimize aevb
    batch_size = 100
    num_training_iters = 1600
    rs = npr.RandomState(0)

    parser = WeightsParser()
    parser.add_shape('encoding weights', (N_weights_enc, ))
    parser.add_shape('decoding weights', (N_weights_dec, ))
    initial_combined_weights = rs.randn(len(parser)) * param_scale

    batch_idxs = make_batches(train_images.shape[0], batch_size)
    log_prior = build_logprob_standard_normal(latent_dimensions)

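    # Minibatch estimate of the variational lower bound, here with an explicit
    # standard-normal prior over the latent codes.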
    def batch_value_and_grad(weights, iter):
        iter = iter % len(batch_idxs)
        cur_data = train_images[batch_idxs[iter]]
        return lower_bound(weights, encoder, decoder_log_like, log_prior,
                           N_weights_enc, cur_data, samples_per_image,
                           latent_dimensions, rs)

    lb_grad = grad(batch_value_and_grad)

    def callback(params, i, grad):
        ml = batch_value_and_grad(params, i)
        print "log marginal likelihood:", ml
        #Generate samples
        num_samples = 100
        images_per_row = 10
        zs = rs.randn(num_samples, latent_dimensions)
        samples = decoder(parser.get(params, 'decoding weights'), zs)
        # samples = np.random.binomial(1,decoder(parser.get(params, 'decoding weights'), zs))

        fig = plt.figure(1)
        fig.clf()
        ax = fig.add_subplot(111)
        plot_images(samples, ax, ims_per_row=images_per_row)
        plt.savefig('samples.png')

    final_params = adam(lb_grad,
                        initial_combined_weights,
                        num_training_iters,
                        callback=callback)

    finish_time = time.time()
    print "total runtime", finish_time - start_time

    def decoder_with_weights(zs):
        return decoder(parser.get(final_params, 'decoding weights'), zs)

    return decoder_with_weights
Example 7
def run_cond_aevb(base_data, cond_data):
    start_time = time.time()

    # Create aevb function
    # Training parameters

    D_c = cond_data.shape[1]
    D_b = base_data.shape[1]
    N_data = cond_data.shape[0]
    assert cond_data.shape[0] == base_data.shape[0]
    enc_layers = [
        D_c, hidden_units, hidden_units, hidden_units, 2 * latent_dimensions
    ]
    dec_layers = [
        latent_dimensions + D_b, hidden_units, hidden_units, hidden_units, D_c
    ]

    N_weights_enc, encoder, encoder_log_like = make_gaussian_nn(enc_layers)
    N_weights_dec, decoder, decoder_log_like = make_binary_nn(dec_layers)

    # Optimize aevb
    batch_size = 100
    num_training_iters = 1600
    rs = npr.RandomState(0)

    parser = WeightsParser()
    parser.add_shape('encoding weights', (N_weights_enc, ))
    parser.add_shape('decoding weights', (N_weights_dec, ))
    initial_combined_weights = rs.randn(len(parser)) * param_scale

    batch_idxs = make_batches(N_data, batch_size)

    def batch_value_and_grad(weights, iter):
        iter = iter % len(batch_idxs)
        # cur_base = base_data[batch_idxs[iter]]
        cur_cond = cond_data[batch_idxs[iter]]
        cur_im = base_data[batch_idxs[iter]]
        cur_b = apply_mask(cur_im)
        return lower_bound(weights, encoder, decoder_log_like, N_weights_enc,
                           cur_b, cur_cond, samples_per_image,
                           latent_dimensions, rs)

    lb_grad = grad(batch_value_and_grad)

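    # Conditioning inputs for the callback plots: the first 10 masked base
    # images, each repeated 10 times, so every plotted row shares one
    # conditioning image.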
    base_test = np.repeat(apply_mask(base_data[0:10, :]), 10, axis=0)

    def callback(params, i, grad):
        ml = batch_value_and_grad(params, i)
        print "log marginal likelihood:", ml

        # #Generate samples
        num_samples = 100
        images_per_row = 10
        zs = rs.randn(100, latent_dimensions)
        # zs = rs.randn(10,latent_dimensions)
        # zs = np.repeat(zs,10,axis = 0)
        # base_test = base_data[0:num_samples,:]
        # base_test = np.repeat(base_data[0:10,:],10,axis = 0)
        dec_in = np.concatenate((zs, base_test), axis=1)
        samples = decoder(parser.get(params, 'decoding weights'), dec_in)
        fig = plt.figure(1)
        fig.clf()
        ax = fig.add_subplot(111)
        # plot_images(samples, ax, ims_per_row=images_per_row)
        plot_shape = (100, 784)
        im_samples = np.zeros(plot_shape)
        im_mean = np.zeros(plot_shape)
        im_map = np.zeros(plot_shape)
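        # Three 100-image grids: in each, the first image of every plotted row
        # is the masked conditioning image, and the rest of the row shows
        # decoder outputs (binary samples, raw means, and rounded/MAP images,
        # respectively).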
        for k in xrange(plot_shape[0]):
            if k % 10 == 0:
                im_samples[k, :] = base_test[k, :]
                im_mean[k, :] = base_test[k, :]
                im_map[k, :] = base_test[k, :]
            else:
                im_mean[k, :] = samples[k - 1, :]
                im_samples[k, :] = np.random.binomial(1, samples[k - 1, :])
                im_map[k, :] = np.round(samples[k - 1, :])

        plot_images(im_samples, ax, ims_per_row=images_per_row)
        plt.savefig('samples.png')

        fig = plt.figure(1)
        fig.clf()
        ax = fig.add_subplot(111)
        plot_images(im_mean, ax, ims_per_row=images_per_row)
        plt.savefig('mean_samples.png')

        fig = plt.figure(1)
        fig.clf()
        ax = fig.add_subplot(111)
        plot_images(im_map, ax, ims_per_row=images_per_row)
        plt.savefig('map_samples.png')

        fig = plt.figure(1)
        fig.clf()
        ax = fig.add_subplot(111)
        plot_images(base_test, ax, ims_per_row=images_per_row)
        plt.savefig('blurred_samples.png')

    final_params = adam(lb_grad,
                        initial_combined_weights,
                        num_training_iters,
                        callback=callback)

    finish_time = time.time()
    print "total runtime", finish_time - start_time

    def decoder_with_weights(zs):
        return decoder(parser.get(final_params, 'decoding weights'), zs)

    return decoder_with_weights
Example 8
def run_cond_aevb(base_data, cond_data):
    start_time = time.time()

    # Create aevb function
    # Training parameters


    D_c = cond_data.shape[1]
    D_b = base_data.shape[1]
    N_data = cond_data.shape[0]
    assert cond_data.shape[0] == base_data.shape[0]
    enc_layers = [D_c, hidden_units, 2 * latent_dimensions]
    dec_layers = [latent_dimensions + D_b, hidden_units, D_c]

    N_weights_enc, encoder, encoder_log_like = make_gaussian_nn(enc_layers)
    N_weights_dec, decoder, decoder_log_like = make_binary_nn(dec_layers)

    # Optimize aevb
    batch_size = 1000
    num_training_iters = 1600
    rs = npr.RandomState(0)

    parser = WeightsParser()
    parser.add_shape('encoding weights', (N_weights_enc,))
    parser.add_shape('decoding weights', (N_weights_dec,))
    initial_combined_weights = rs.randn(len(parser)) * param_scale

    batch_idxs = make_batches(N_data, batch_size)

    def batch_value_and_grad(weights, iter):
        iter = iter % len(batch_idxs)
        cur_cond = cond_data[batch_idxs[iter]]
        cur_base = base_data[batch_idxs[iter]]
        return lower_bound(weights, encoder, decoder_log_like, N_weights_enc,
                           cur_base, cur_cond, samples_per_image,
                           latent_dimensions, rs)

    lb_grad = grad(batch_value_and_grad)



    def callback(params, i, grad):
        ml = batch_value_and_grad(params,i)
        print "log marginal likelihood:", ml

        # #Generate samples
        num_samples = 100
        images_per_row = 10
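        # One-hot conditioning vectors: sample i activates dimension i % 10,
        # so each column of the 10x10 sample grid shares one conditioning code
        # (presumably a class label).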
        zs = rs.randn(num_samples, latent_dimensions)
        base_test = np.zeros((num_samples, D_b))
        for i in xrange(num_samples):
            base_test[i, i % 10] = 1

        dec_in = np.concatenate((zs, base_test), axis=1)
        samples = decoder(parser.get(params, 'decoding weights'), dec_in)
        fig = plt.figure(1)
        fig.clf()
        ax = fig.add_subplot(111)
        plot_images(samples, ax, ims_per_row=images_per_row)
        plt.savefig('samples.png')

    final_params = adam(lb_grad, initial_combined_weights, num_training_iters,
                        callback=callback)

    finish_time = time.time()
    print "total runtime", finish_time - start_time

    def decoder_with_weights(zs):
        return decoder(parser.get(final_params, 'decoding weights'), zs)
    return decoder_with_weights
Example 9
def run_aevb(train_images):
    start_time = time.time()

    # Create aevb function
    # Training parameters

    D = train_images.shape[1]

    enc_layers = [D, hidden_units, hidden_units, 2 * latent_dimensions]
    dec_layers = [latent_dimensions, hidden_units, hidden_units, D]

    N_weights_enc, encoder, encoder_log_like = make_gaussian_nn(enc_layers)
    N_weights_dec, decoder, decoder_log_like = make_binary_nn(dec_layers)

    # Optimize aevb
    batch_size = 500
    num_training_iters = 1600
    rs = npr.RandomState(0)

    parser = WeightsParser()
    parser.add_shape('encoding weights', (N_weights_enc, ))
    parser.add_shape('decoding weights', (N_weights_dec, ))
    initial_combined_weights = rs.randn(len(parser)) * param_scale

    batch_idxs = make_batches(train_images.shape[0], batch_size)

    def batch_value_and_grad(weights, iter):
        iter = iter % len(batch_idxs)
        cur_data = train_images[batch_idxs[iter]]
        return lower_bound(weights, encoder, decoder_log_like, N_weights_enc,
                           cur_data, samples_per_image, latent_dimensions, rs)

    lb_grad = grad(batch_value_and_grad)

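    # A larger batch used only in the monitoring callback, giving a less noisy
    # estimate of the lower bound than the training minibatches.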
    big_batch_idxs = make_batches(train_images.shape[0], 1000)

    def big_batch_value_and_grad(weights, iter):
        iter = iter % len(big_batch_idxs)
        cur_data = train_images[big_batch_idxs[iter]]
        return lower_bound(weights, encoder, decoder_log_like, N_weights_enc,
                           cur_data, samples_per_image, latent_dimensions, rs)

    def callback(params, i, grad):
        ml = big_batch_value_and_grad(params, i)
        print "log marginal likelihood:", ml

        print "----- iter ", i
        if i % 1000 == 0 and not np.isnan(
                lower_bound(params, encoder, decoder_log_like, N_weights_enc,
                            test_images[0:100, :], samples_per_image,
                            latent_dimensions, rs)):
            print 'SAVING ==== '
            save_string = 'parameters10l300hfor' + str(i) + '.pkl'

            parameters = params, N_weights_enc, samples_per_image, latent_dimensions, rs
            print 'SAVING AS: ', save_string
            print 'LATENTS DIMS', latent_dimensions
            with open(save_string, 'w') as f:
                pickle.dump(parameters, f, 1)
            #Validation loss:
            print '--- test loss:', lower_bound(params, encoder,
                                                decoder_log_like,
                                                N_weights_enc,
                                                test_images[0:100, :],
                                                samples_per_image,
                                                latent_dimensions, rs)

        #Generate samples
        num_samples = 100
        images_per_row = 10
        zs = rs.randn(num_samples, latent_dimensions)
        samples = decoder(parser.get(params, 'decoding weights'), zs)
        fig = plt.figure(1)
        fig.clf()
        ax = fig.add_subplot(111)
        plot_images(samples, ax, ims_per_row=images_per_row)
        plt.savefig('samples.png')
        if i % 100 == 0:
            enc_w = params[0:N_weights_enc]
            dec_w = params[N_weights_enc:len(params)]
            plot_latent_centers(encoder, decoder, enc_w, dec_w)

    final_params = adam(lb_grad,
                        initial_combined_weights,
                        num_training_iters,
                        callback=callback)

    finish_time = time.time()
    print "total runtime", finish_time - start_time

    enc_w = final_params[0:N_weights_enc]
    dec_w = final_params[N_weights_enc:len(final_params)]
    return encoder, decoder, enc_w, dec_w
Example 10
    N_weights_enc, encoder, _ = make_gaussian_nn(enc_layer_sizes)
    N_weights_dec, decoder, decoder_log_like = make_binary_nn(dec_layer_sizes)

    # Optimization parameters.
    batch_size = 100
    num_training_iters = 100
    sampler_learn_rate = 0.01
    batch_idxs = make_batches(train_images.shape[0], batch_size)

    init_enc_w = rs.randn(N_weights_enc) * param_scale
    init_dec_w = rs.randn(N_weights_dec) * param_scale

    flow_sampler, flow_parser = build_flow_sampler(latent_dimension, num_flow_steps)

    combined_parser = WeightsParser()
    combined_parser.add_shape('encoder weights', N_weights_enc)
    combined_parser.add_shape('decoder weights', N_weights_dec)
    combined_parser.add_shape('flow params', len(flow_parser))

    combined_params = np.zeros(len(combined_parser))
    combined_parser.put(combined_params, 'encoder weights', init_enc_w)
    combined_parser.put(combined_params, 'flow params', init_flow_params(flow_parser, rs))
    combined_parser.put(combined_params, 'decoder weights', init_dec_w)

    def get_batch_lower_bound(cur_params, iter):
        encoder_weights = combined_parser.get(cur_params, 'encoder weights')
        flow_params     = combined_parser.get(cur_params, 'flow params')
        decoder_weights = combined_parser.get(cur_params, 'decoder weights')

        cur_data = train_images[batch_idxs[iter]]
Example 11
def time_and_acc(latent_dimension):

    start_time = time.time()
    rs = npr.RandomState(0)
    #load_and_pickle_binary_mnist()
    with open('../../../autopaint/mnist_binary_data.pkl') as f:
        N_data, train_images, train_labels, test_images, test_labels = pickle.load(
            f)

    D = train_images.shape[1]
    enc_layer_sizes = [D, hidden_units, 2 * latent_dimension]
    dec_layer_sizes = [latent_dimension, hidden_units, D]

    N_weights_enc, encoder, encoder_log_like = make_gaussian_nn(
        enc_layer_sizes)
    N_weights_dec, decoder, decoder_log_like = make_binary_nn(dec_layer_sizes)

    # Optimization parameters.
    batch_size = 100
    num_training_iters = 100
    sampler_learn_rate = 0.01
    batch_idxs = make_batches(train_images.shape[0], batch_size)

    init_enc_w = rs.randn(N_weights_enc) * param_scale
    init_dec_w = rs.randn(N_weights_dec) * param_scale

    flow_sampler, flow_parser = build_flow_sampler(latent_dimension,
                                                   num_flow_steps)

    combined_parser = WeightsParser()
    combined_parser.add_shape('encoder weights', N_weights_enc)
    combined_parser.add_shape('decoder weights', N_weights_dec)
    combined_parser.add_shape('flow params', len(flow_parser))

    combined_params = np.zeros(len(combined_parser))
    combined_parser.put(combined_params, 'encoder weights', init_enc_w)
    combined_parser.put(combined_params, 'flow params',
                        init_flow_params(flow_parser, rs, latent_dimension))
    combined_parser.put(combined_params, 'decoder weights', init_dec_w)

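    # Lower bound for one minibatch: the encoder proposes a diagonal Gaussian,
    # the flow sampler transforms samples from it while tracking an entropy
    # estimate, and the bound is the mean decoder log-likelihood plus entropy.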
    def get_batch_lower_bound(cur_params, iter):
        encoder_weights = combined_parser.get(cur_params, 'encoder weights')
        flow_params = combined_parser.get(cur_params, 'flow params')
        decoder_weights = combined_parser.get(cur_params, 'decoder weights')

        cur_data = train_images[batch_idxs[iter]]
        mus, log_sigs = encoder(encoder_weights, cur_data)
        samples, entropy_estimates = flow_sampler(flow_params, mus,
                                                  np.exp(log_sigs), rs)
        loglikes = decoder_log_like(decoder_weights, samples, cur_data)

        print "Iter", iter, "loglik:", np.mean(loglikes).value, \
            "entropy:", np.mean(entropy_estimates).value, "marg. like:", np.mean(entropy_estimates + loglikes).value
        lastVal = np.mean(entropy_estimates + loglikes).value
        with open('lastVal.pkl', 'w') as f:
            pickle.dump(lastVal, f, 1)
        return np.mean(entropy_estimates + loglikes)

    lb_grad = grad(get_batch_lower_bound)

    def callback(weights, iter, grad):
        #Generate samples
        num_samples = 100
        zs = rs.randn(num_samples, latent_dimension)
        samples = decoder(combined_parser.get(weights, 'decoder weights'), zs)
        fig = plt.figure(1)
        fig.clf()
        ax = fig.add_subplot(111)
        plot_images(samples, ax, ims_per_row=10)
        plt.savefig('samples.png')

    final_params = adam(lb_grad,
                        combined_params,
                        num_training_iters,
                        callback=callback)

    finish_time = time.time()
    # #Broken and very mysterious:
    # lb_val_grad = value_and_grad(get_batch_lower_bound)
    # lb_est = lb_val_grad(final_params,num_training_iters+2)
    # print lb_est
    # lb_est = lb_est[0]
    with open('lastVal.pkl') as f:
        lb_est = pickle.load(f)
    print 'lb_est is', lb_est
    print "Total training time:", finish_time - start_time
    return finish_time, lb_est