Example #1
    def lower_bound_and_log_likelihood(relaxed=False):
        def log_joint(observed):
            model = vae(observed, n, n_x, n_z, n_k, tau_p, n_particles,
                        relaxed)
            log_pz, log_px_z = model.local_log_prob(['z', 'x'])
            return log_pz + log_px_z

        variational = q_net({}, x, n_z, n_k, tau_q, n_particles, relaxed)
        qz_samples, log_qz = variational.query('z',
                                               outputs=True,
                                               local_log_prob=True)

        lower_bound = zs.variational.elbo(log_joint,
                                          observed={'x': x_obs},
                                          latent={'z': [qz_samples, log_qz]},
                                          axis=0)
        cost = tf.reduce_mean(lower_bound.sgvb())
        lower_bound = tf.reduce_mean(lower_bound)

        # Importance sampling estimates of marginal log likelihood
        is_log_likelihood = tf.reduce_mean(
            zs.is_loglikelihood(log_joint, {'x': x_obs},
                                {'z': [qz_samples, log_qz]},
                                axis=0))

        return cost, lower_bound, is_log_likelihood
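    # A possible way to wire the two variants together (an added sketch, not
    # part of the original script): the relaxed objective (e.g. a
    # Gumbel-softmax style relaxation controlled by tau_p/tau_q) provides
    # reparameterization gradients for training, while the non-relaxed
    # objective is kept for reporting the bound and the IS estimate.
    relaxed_cost, _, _ = lower_bound_and_log_likelihood(relaxed=True)
    _, lower_bound, is_log_likelihood = lower_bound_and_log_likelihood(
        relaxed=False)
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    infer_op = optimizer.minimize(relaxed_cost)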
Example #2
    variational = q_net({}, x, n_z, n_particles, is_training)
    qz_samples, log_qz = variational.query('z', outputs=True,
                                           local_log_prob=True)
    # TODO: add tests for repeated calls of flows
    qz_samples, log_qz = zs.planar_normalizing_flow(qz_samples, log_qz,
                                                    n_iters=n_planar_flows)
    qz_samples, log_qz = zs.planar_normalizing_flow(qz_samples, log_qz,
                                                    n_iters=n_planar_flows)

    lower_bound = tf.reduce_mean(
        zs.sgvb(log_joint, {'x': x_obs}, {'z': [qz_samples, log_qz]}, axis=0))

    # Importance sampling estimates of log likelihood:
    # Fast, used for evaluation during training
    is_log_likelihood = tf.reduce_mean(
        zs.is_loglikelihood(log_joint, {'x': x_obs},
                            {'z': [qz_samples, log_qz]}, axis=0))

    learning_rate_ph = tf.placeholder(tf.float32, shape=[], name='lr')
    optimizer = tf.train.AdamOptimizer(learning_rate_ph, epsilon=1e-4)
    grads = optimizer.compute_gradients(-lower_bound)
    infer = optimizer.apply_gradients(grads)

    params = tf.trainable_variables()
    for i in params:
        print(i.name, i.get_shape())

    saver = tf.train.Saver(max_to_keep=10)

    # Run the inference
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
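The two `zs.planar_normalizing_flow` calls above transform the posterior samples and update their log-densities with the log-determinant of the flow's Jacobian. For intuition only, a single planar-flow step in plain TensorFlow might look roughly like the sketch below; this illustrates the technique rather than ZhuSuan's implementation, and the variable names, initializers and the omitted invertibility constraint on `u` are assumptions.

def planar_flow_step(z, log_q, scope="planar_flow_step"):
    # z: [..., z_dim] samples from q; log_q: [...] their log-densities.
    # Applies f(z) = z + u * tanh(w^T z + b) and corrects log_q by
    # -log|1 + u^T psi(z)|, where psi(z) = (1 - tanh^2(w^T z + b)) * w.
    z_dim = int(z.shape[-1])
    with tf.variable_scope(scope):
        w = tf.get_variable("w", [z_dim, 1], tf.float32,
                            initializer=tf.random_normal_initializer(stddev=0.01))
        u = tf.get_variable("u", [z_dim, 1], tf.float32,
                            initializer=tf.random_normal_initializer(stddev=0.01))
        b = tf.get_variable("b", [], tf.float32,
                            initializer=tf.zeros_initializer())
    a = tf.tensordot(z, w, [[-1], [0]]) + b            # [..., 1]
    u_vec = tf.squeeze(u, -1)                          # [z_dim]
    z_new = z + u_vec * tf.tanh(a)
    psi = (1. - tf.tanh(a) ** 2) * tf.squeeze(w, -1)   # [..., z_dim]
    det = 1. + tf.reduce_sum(psi * u_vec, axis=-1)
    log_q_new = log_q - tf.log(tf.abs(det) + 1e-8)
    return z_new, log_q_new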
Example #3
def main():
    tf.set_random_seed(1234)
    np.random.seed(1234)

    # Load MNIST
    data_path = os.path.join(conf.data_dir, 'mnist.pkl.gz')
    x_train, t_train, x_valid, t_valid, x_test, t_test = \
        dataset.load_mnist_realval(data_path)
    x_train = np.vstack([x_train, x_valid])
    x_test = np.random.binomial(1, x_test, size=x_test.shape)
    x_dim = x_train.shape[1]

    # Define model/inference parameters
    z_dim = 40
    n_planar_flows = 10

    # Build the computation graph
    n_particles = tf.placeholder(tf.int32, shape=[], name="n_particles")
    x_input = tf.placeholder(tf.float32, shape=[None, x_dim], name="x")
    x = tf.cast(tf.less(tf.random_uniform(tf.shape(x_input)), x_input),
                tf.int32)
    n = tf.placeholder(tf.int32, shape=[], name="n")
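    # The binarization of x_input above is "dynamic binarization": every time
    # a batch is fed, each pixel value in [0, 1] is treated as a Bernoulli
    # probability and re-sampled to 0/1. An equivalent NumPy helper, added here
    # only as a clarifying sketch (it is not used by the graph below):
    def binarize(x_batch):
        return (np.random.uniform(size=x_batch.shape) < x_batch).astype(np.int32)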

    model = build_gen(n, x_dim, z_dim, n_particles)
    q_net = build_q_net(x, z_dim, n_particles)
    qz_samples, log_qz = q_net.query('z', outputs=True, local_log_prob=True)
    # TODO: add tests for repeated calls of flows
    qz_samples, log_qz = zs.planar_normalizing_flow(qz_samples,
                                                    log_qz,
                                                    n_iters=n_planar_flows)
    qz_samples, log_qz = zs.planar_normalizing_flow(qz_samples,
                                                    log_qz,
                                                    n_iters=n_planar_flows)

    lower_bound = zs.variational.elbo(model,
                                      observed={"x": x},
                                      latent={"z": [qz_samples, log_qz]},
                                      axis=0)
    cost = tf.reduce_mean(lower_bound.sgvb())
    lower_bound = tf.reduce_mean(lower_bound)

    # Importance sampling estimates of marginal log likelihood
    is_log_likelihood = tf.reduce_mean(
        zs.is_loglikelihood(model, {'x': x}, {'z': [qz_samples, log_qz]},
                            axis=0))

    optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
    infer_op = optimizer.minimize(cost)

    # Define training/evaluation parameters
    epochs = 3000
    batch_size = 128
    iters = x_train.shape[0] // batch_size
    test_freq = 10
    test_batch_size = 400
    test_iters = x_test.shape[0] // test_batch_size

    # Run the inference
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for epoch in range(1, epochs + 1):
            time_epoch = -time.time()
            np.random.shuffle(x_train)
            lbs = []
            for t in range(iters):
                x_batch = x_train[t * batch_size:(t + 1) * batch_size]
                _, lb = sess.run([infer_op, lower_bound],
                                 feed_dict={
                                     x_input: x_batch,
                                     n_particles: 1,
                                     n: batch_size
                                 })
                lbs.append(lb)
            time_epoch += time.time()
            print('Epoch {} ({:.1f}s): Lower bound = {}'.format(
                epoch, time_epoch, np.mean(lbs)))

            if epoch % test_freq == 0:
                time_test = -time.time()
                test_lbs = []
                test_lls = []
                for t in range(test_iters):
                    test_x_batch = x_test[t * test_batch_size:(t + 1) *
                                          test_batch_size]
                    test_lb = sess.run(lower_bound,
                                       feed_dict={
                                           x: test_x_batch,
                                           n_particles: 1,
                                           n: test_batch_size
                                       })
                    test_ll = sess.run(is_log_likelihood,
                                       feed_dict={
                                           x: test_x_batch,
                                           n_particles: 1000,
                                           n: test_batch_size
                                       })
                    test_lbs.append(test_lb)
                    test_lls.append(test_ll)
                time_test += time.time()
                print('>>> TEST ({:.1f}s)'.format(time_test))
                print('>> Test lower bound = {}'.format(np.mean(test_lbs)))
                print('>> Test log likelihood (IS) = {}'.format(
                    np.mean(test_lls)))
Example #4
def main():
    data_path = "/home/zhikun/PycharmProjects/data/MNIST/mnist.pkl.gz"
    # the file comes from http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz
    x_train, t_train, x_valid, t_valid, x_test, t_test = load_mnist(data_path)
    x_train = np.vstack([x_train, x_valid])
    y_train = np.vstack([t_train, t_valid])
    x_test = np.random.binomial(1, x_test, size=x_test.shape)
    x_dim = x_train.shape[1]
    y_dim = y_train.shape[1]
    # Define model parameters
    z_dim = 40
    n_particles = tf.placeholder(tf.int32, shape=[], name="n_particles")
    x_input = tf.placeholder(tf.float32, shape=[None, x_dim], name="x")
    x = tf.cast(tf.less(tf.random_uniform(tf.shape(x_input)), x_input),
                tf.int32)
    y = tf.placeholder(tf.float32, shape=[None, 10], name="y")
    n = tf.placeholder(tf.int32, shape=[], name="n")

    model = build_gen(x_dim, z_dim, y, n, n_particles)
    variational = build_q_net(x, z_dim, y, n_particles)

    lower_bound = zs.variational.elbo(model, {"x": x},
                                      variational=variational,
                                      axis=0)
    cost = tf.reduce_mean(lower_bound.sgvb())
    lower_bound = tf.reduce_mean(lower_bound)

    # Importance sampling estimates of marginal log likelihood
    is_log_likelihood = tf.reduce_mean(
        zs.is_loglikelihood(model, {"x": x}, proposal=variational, axis=0))

    optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
    infer_op = optimizer.minimize(cost)

    # Random generation
    y_observe = tf.placeholder(tf.float32, shape=[None, 10], name="y_observe")
    x_gen = tf.reshape(model.observe(y=y_observe)["x_mean"], [-1, 28, 28, 1])

    # Define training/evaluation parameters
    epochs = 3000
    batch_size = 128
    iters = x_train.shape[0] // batch_size
    save_freq = 10
    test_freq = 10
    test_batch_size = 50
    test_iters = x_test.shape[0] // test_batch_size
    result_path = "results/c-vae"

    # Run the inference
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(1, epochs + 1):
            time_epoch = -time.time()
            lbs = []
            for t in range(iters):
                x_batch = x_train[t * batch_size:(t + 1) * batch_size]
                y_batch = y_train[t * batch_size:(t + 1) * batch_size]
                _, lb = sess.run([infer_op, lower_bound],
                                 feed_dict={
                                     x_input: x_batch,
                                     y: y_batch,
                                     n_particles: 1,
                                     n: batch_size
                                 })
                lbs.append(lb)
            time_epoch += time.time()
            print("Epoch {} ({:.1f}s): Lower bound = {}".format(
                epoch, time_epoch, np.mean(lbs)))

            if epoch % test_freq == 0:
                time_test = -time.time()
                test_lbs, test_lls = [], []
                for t in range(test_iters):
                    test_x_batch = x_test[t * test_batch_size:(t + 1) *
                                          test_batch_size]
                    test_y_batch = t_test[t * test_batch_size:(t + 1) *
                                          test_batch_size]
                    test_lb = sess.run(lower_bound,
                                       feed_dict={
                                           x: test_x_batch,
                                           y: test_y_batch,
                                           n_particles: 1,
                                           n: test_batch_size
                                       })
                    test_ll = sess.run(is_log_likelihood,
                                       feed_dict={
                                           x: test_x_batch,
                                           y: test_y_batch,
                                           n_particles: 1000,
                                           n: test_batch_size
                                       })
                    test_lbs.append(test_lb)
                    test_lls.append(test_ll)
                time_test += time.time()
                print(">>> TEST ({:.1f}s)".format(time_test))
                print(">> Test lower bound = {}".format(np.mean(test_lbs)))
                print('>> Test log likelihood (IS) = {}'.format(
                    np.mean(test_lls)))

            if epoch % save_freq == 0:
                y_index = np.repeat(np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
                                    10)
                y_index = np.eye(10)[y_index.reshape(-1)]
                images = sess.run(x_gen,
                                  feed_dict={
                                      y_observe: y_index,
                                      y: y_index,
                                      n: 100,
                                      n_particles: 1
                                  })
                name = os.path.join(result_path,
                                    "c-vae.epoch.{}.png".format(epoch))
                save_image_collections(images, name)
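        # A small usage sketch added here (not part of the original script):
        # once training finishes, generate 100 samples of a single chosen
        # digit class by feeding its one-hot code to the conditional generator.
        digit = 7
        y_onehot = np.eye(10)[np.full(100, digit)]
        digit_images = sess.run(x_gen,
                                feed_dict={
                                    y_observe: y_onehot,
                                    y: y_onehot,
                                    n: 100,
                                    n_particles: 1
                                })
        save_image_collections(
            digit_images,
            os.path.join(result_path, "c-vae.digit.{}.png".format(digit)))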
Example #5
def main():
    # Load MNIST
    data_path = os.path.join(conf.data_dir, 'mnist.pkl.gz')
    x_train, t_train, x_valid, t_valid, x_test, t_test = \
        dataset.load_mnist_realval(data_path)
    x_train = np.vstack([x_train, x_valid])
    x_test = np.random.binomial(1, x_test, size=x_test.shape)
    x_dim = x_train.shape[1]

    # Define model parameters
    z_dim = 40

    # Build the computation graph
    n_particles = tf.placeholder(tf.int32, shape=[], name='n_particles')
    x_input = tf.placeholder(tf.float32, shape=[None, x_dim], name='x')
    x = tf.to_int32(tf.less(tf.random_uniform(tf.shape(x_input)), x_input))
    n = tf.shape(x)[0]

    def log_joint(observed):
        model = vae(observed, x_dim, z_dim, n, n_particles)
        log_pz, log_px_z = model.local_log_prob(['z', 'x'])
        return log_pz + log_px_z

    variational = q_net({'x': x}, x_dim, z_dim, n_particles)
    qz_samples, log_qz = variational.query('z', outputs=True,
                                           local_log_prob=True)
    lower_bound = zs.variational.elbo(log_joint,
                                      observed={'x': x},
                                      latent={'z': [qz_samples, log_qz]},
                                      axis=0)
    cost = tf.reduce_mean(lower_bound.sgvb())
    lower_bound = tf.reduce_mean(lower_bound)

    # Importance sampling estimates of marginal log likelihood
    is_log_likelihood = tf.reduce_mean(
        zs.is_loglikelihood(log_joint, {'x': x},
                            {'z': [qz_samples, log_qz]}, axis=0))
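    # For clarity (an added sketch; the script itself uses zs.is_loglikelihood
    # above): the IS estimate is log p(x) ~= log (1/K) sum_k exp(log p(x, z_k)
    # - log q(z_k | x)), with z_k ~ q(z | x) along axis 0 (the particle axis).
    log_w = log_joint({'x': x, 'z': qz_samples}) - log_qz
    manual_is_estimate = tf.reduce_logsumexp(log_w, axis=0) - \
        tf.log(tf.to_float(n_particles))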

    optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
    infer_op = optimizer.minimize(cost)

    # Generate images
    n_gen = 100
    x_mean = vae({}, x_dim, z_dim, n_gen).outputs('x_mean')
    x_gen = tf.reshape(x_mean, [-1, 28, 28, 1])

    # Define training/evaluation parameters
    epochs = 3000
    batch_size = 128
    iters = x_train.shape[0] // batch_size
    save_freq = 10
    test_freq = 10
    test_batch_size = 400
    test_iters = x_test.shape[0] // test_batch_size
    result_path = "results/vae"

    # Run the inference
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for epoch in range(1, epochs + 1):
            time_epoch = -time.time()
            np.random.shuffle(x_train)
            lbs = []
            for t in range(iters):
                x_batch = x_train[t * batch_size:(t + 1) * batch_size]
                _, lb = sess.run([infer_op, lower_bound],
                                 feed_dict={x_input: x_batch,
                                            n_particles: 1})
                lbs.append(lb)
            time_epoch += time.time()
            print('Epoch {} ({:.1f}s): Lower bound = {}'.format(
                epoch, time_epoch, np.mean(lbs)))

            if epoch % test_freq == 0:
                time_test = -time.time()
                test_lbs = []
                test_lls = []
                for t in range(test_iters):
                    test_x_batch = x_test[t * test_batch_size:
                                          (t + 1) * test_batch_size]
                    test_lb = sess.run(lower_bound,
                                       feed_dict={x: test_x_batch,
                                                  n_particles: 1})
                    test_ll = sess.run(is_log_likelihood,
                                       feed_dict={x: test_x_batch,
                                                  n_particles: 1000})
                    test_lbs.append(test_lb)
                    test_lls.append(test_ll)
                time_test += time.time()
                print('>>> TEST ({:.1f}s)'.format(time_test))
                print('>> Test lower bound = {}'.format(np.mean(test_lbs)))
                print('>> Test log likelihood (IS) = {}'.format(
                    np.mean(test_lls)))

            if epoch % save_freq == 0:
                images = sess.run(x_gen)
                name = os.path.join(result_path,
                                    "vae.epoch.{}.png".format(epoch))
                save_image_collections(images, name)
Example #6
def main():
    # Load MNIST
    data_path = os.path.join(conf.data_dir, "mnist.pkl.gz")
    x_train, t_train, x_valid, t_valid, x_test, t_test = \
        dataset.load_mnist_realval(data_path)
    x_train = np.vstack([x_train, x_valid])
    x_test = np.random.binomial(1, x_test, size=x_test.shape)
    x_dim = x_train.shape[1]

    # Define model parameters
    z_dim = 40

    # Build the computation graph
    n_particles = tf.placeholder(tf.int32, shape=[], name="n_particles")
    x_input = tf.placeholder(tf.float32, shape=[None, x_dim], name="x")
    x = tf.cast(tf.less(tf.random_uniform(tf.shape(x_input)), x_input),
                tf.int32)
    n = tf.placeholder(tf.int32, shape=[], name="n")

    model = build_gen(x_dim, z_dim, n, n_particles)
    variational = build_q_net(x, z_dim, n_particles)

    lower_bound = zs.variational.elbo(model, {"x": x},
                                      variational=variational,
                                      axis=0)
    cost = tf.reduce_mean(lower_bound.sgvb())
    lower_bound = tf.reduce_mean(lower_bound)

    # Importance sampling estimates of marginal log likelihood
    is_log_likelihood = tf.reduce_mean(
        zs.is_loglikelihood(model, {"x": x}, proposal=variational, axis=0))

    optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
    infer_op = optimizer.minimize(cost)

    # Random generation
    x_gen = tf.reshape(model.observe()["x_mean"], [-1, 28, 28, 1])

    # Define training/evaluation parameters
    epochs = 3000
    batch_size = 128
    iters = x_train.shape[0] // batch_size
    save_freq = 10
    test_freq = 10
    test_batch_size = 400
    test_iters = x_test.shape[0] // test_batch_size
    result_path = "results/vae"

    # Run the inference
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for epoch in range(1, epochs + 1):
            time_epoch = -time.time()
            np.random.shuffle(x_train)
            lbs = []
            for t in range(iters):
                x_batch = x_train[t * batch_size:(t + 1) * batch_size]
                _, lb = sess.run([infer_op, lower_bound],
                                 feed_dict={
                                     x_input: x_batch,
                                     n_particles: 1,
                                     n: batch_size
                                 })
                lbs.append(lb)
            time_epoch += time.time()
            print("Epoch {} ({:.1f}s): Lower bound = {}".format(
                epoch, time_epoch, np.mean(lbs)))

            if epoch % test_freq == 0:
                time_test = -time.time()
                test_lbs, test_lls = [], []
                for t in range(test_iters):
                    test_x_batch = x_test[t * test_batch_size:(t + 1) *
                                          test_batch_size]
                    test_lb = sess.run(lower_bound,
                                       feed_dict={
                                           x: test_x_batch,
                                           n_particles: 1,
                                           n: test_batch_size
                                       })
                    test_ll = sess.run(is_log_likelihood,
                                       feed_dict={
                                           x: test_x_batch,
                                           n_particles: 1000,
                                           n: test_batch_size
                                       })
                    test_lbs.append(test_lb)
                    test_lls.append(test_ll)
                time_test += time.time()
                print(">>> TEST ({:.1f}s)".format(time_test))
                print(">> Test lower bound = {}".format(np.mean(test_lbs)))
                print('>> Test log likelihood (IS) = {}'.format(
                    np.mean(test_lls)))

            if epoch % save_freq == 0:
                images = sess.run(x_gen, feed_dict={n: 100, n_particles: 1})
                name = os.path.join(result_path,
                                    "vae.epoch.{}.png".format(epoch))
                save_image_collections(images, name)
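`build_gen` and `build_q_net` themselves are not shown in this listing. As a rough sketch of what the generative side could look like with ZhuSuan's `meta_bayesian_net` API (the decorator arguments, layer sizes and the `_sketch` name are assumptions; only the node names "z", "x_mean" and "x" are taken from the example above):

@zs.meta_bayesian_net(scope="gen_sketch", reuse_variables=True)
def build_gen_sketch(x_dim, z_dim, n, n_particles):
    bn = zs.BayesianNet()
    # Prior p(z) = N(0, I) with a leading n_particles axis.
    z_mean = tf.zeros([n, z_dim])
    z = bn.normal("z", z_mean, std=1., group_ndims=1, n_samples=n_particles)
    # Decoder p(x | z): two hidden layers, Bernoulli likelihood over pixels.
    h = tf.layers.dense(z, 500, activation=tf.nn.relu)
    h = tf.layers.dense(h, 500, activation=tf.nn.relu)
    x_logits = tf.layers.dense(h, x_dim)
    bn.deterministic("x_mean", tf.sigmoid(x_logits))  # read via observe()["x_mean"]
    bn.bernoulli("x", x_logits, group_ndims=1)
    return bn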
Example #7
def main():
    # Load MNIST
    data_path = os.path.join(conf.data_dir, 'mnist.pkl.gz')
    x_train, t_train, x_valid, t_valid, x_test, t_test = \
        dataset.load_mnist_realval(data_path)
    x_train = np.vstack([x_train, x_valid])
    x_test = np.random.binomial(1, x_test, size=x_test.shape)
    x_dim = x_train.shape[1]

    # Define model parameters
    z_dim = 40

    # Build the computation graph
    is_training = tf.placeholder(tf.bool, shape=[], name='is_training')
    n_particles = tf.placeholder(tf.int32, shape=[], name='n_particles')
    x_input = tf.placeholder(tf.float32, shape=[None, x_dim], name='x')
    x = tf.to_int32(tf.less(tf.random_uniform(tf.shape(x_input)), x_input))
    n = tf.shape(x)[0]

    def log_joint(observed):
        model = vae(observed, n, x_dim, z_dim, n_particles, is_training)
        log_pz, log_px_z = model.local_log_prob(['z', 'x'])
        return log_pz + log_px_z

    variational = q_net(x, z_dim, n_particles, is_training)
    qz_samples, log_qz = variational.query('z',
                                           outputs=True,
                                           local_log_prob=True)
    cx = tf.expand_dims(baseline_net(x), 0)
    lower_bound = zs.variational.elbo(log_joint,
                                      observed={'x': x},
                                      latent={'z': [qz_samples, log_qz]},
                                      axis=0)
    cost, baseline_cost = lower_bound.reinforce(baseline=cx)
    cost = tf.reduce_mean(cost + baseline_cost)
    lower_bound = tf.reduce_mean(lower_bound)

    # Importance sampling estimates of marginal log likelihood
    is_log_likelihood = tf.reduce_mean(
        zs.is_loglikelihood(log_joint, {'x': x}, {'z': [qz_samples, log_qz]},
                            axis=0))

    optimizer = tf.train.AdamOptimizer(0.001)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        infer_op = optimizer.minimize(cost)

    # Define training/evaluation parameters
    epochs = 3000
    batch_size = 128
    iters = x_train.shape[0] // batch_size
    test_freq = 10
    test_batch_size = 400
    test_iters = x_test.shape[0] // test_batch_size

    # Run the inference
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for epoch in range(1, epochs + 1):
            time_epoch = -time.time()
            np.random.shuffle(x_train)
            lbs = []
            for t in range(iters):
                x_batch = x_train[t * batch_size:(t + 1) * batch_size]
                _, lb = sess.run([infer_op, lower_bound],
                                 feed_dict={
                                     x_input: x_batch,
                                     is_training: True,
                                     n_particles: 1
                                 })
                lbs.append(lb)
            time_epoch += time.time()
            print('Epoch {} ({:.1f}s): Lower bound = {}'.format(
                epoch, time_epoch, np.mean(lbs)))

            if epoch % test_freq == 0:
                time_test = -time.time()
                test_lbs = []
                test_lls = []
                for t in range(test_iters):
                    test_x_batch = x_test[t * test_batch_size:(t + 1) *
                                          test_batch_size]
                    test_lb = sess.run(lower_bound,
                                       feed_dict={
                                           x: test_x_batch,
                                           is_training: False,
                                           n_particles: 1
                                       })
                    test_ll = sess.run(is_log_likelihood,
                                       feed_dict={
                                           x: test_x_batch,
                                           is_training: False,
                                           n_particles: 1000
                                       })
                    test_lbs.append(test_lb)
                    test_lls.append(test_ll)
                time_test += time.time()
                print('>>> TEST ({:.1f}s)'.format(time_test))
                print('>> Test lower bound = {}'.format(np.mean(test_lbs)))
                print('>> Test log likelihood (IS) = {}'.format(
                    np.mean(test_lls)))
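The input-dependent baseline `baseline_net` used by the REINFORCE estimator above is not included in this listing. A minimal sketch, under the assumption that it maps each observation to a single scalar control variate (layer sizes, activation and scope name are guesses):

def baseline_net_sketch(x):
    # Maps x ([batch, x_dim]) to a scalar baseline c(x) per datapoint, used to
    # reduce the variance of the score-function (REINFORCE) gradient.
    with tf.variable_scope("baseline_net_sketch", reuse=tf.AUTO_REUSE):
        h = tf.layers.dense(tf.to_float(x), 100, activation=tf.nn.tanh)
        return tf.squeeze(tf.layers.dense(h, 1), axis=-1)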
Example #8
    qh3_samples, log_qh3 = variational.query('h3',
                                             outputs=True,
                                             local_log_prob=True)
    qh2_samples, log_qh2 = variational.query('h2',
                                             outputs=True,
                                             local_log_prob=True)
    qh1_samples, log_qh1 = variational.query('h1',
                                             outputs=True,
                                             local_log_prob=True)
    cost, lower_bound = zs.rws(
        log_joint, {'x': x_obs},
        {'h3': [qh3_samples, log_qh3],
         'h2': [qh2_samples, log_qh2],
         'h1': [qh1_samples, log_qh1]},
        axis=0)
    lower_bound = tf.reduce_mean(lower_bound)
    cost = tf.reduce_mean(cost)
    log_likelihood = tf.reduce_mean(
        zs.is_loglikelihood(
            log_joint, {'x': x_obs},
            {'h3': [qh3_samples, log_qh3],
             'h2': [qh2_samples, log_qh2],
             'h1': [qh1_samples, log_qh1]},
            axis=0))

    learning_rate_ph = tf.placeholder(tf.float32, shape=[], name='lr')
    optimizer = tf.train.AdamOptimizer(learning_rate_ph, epsilon=1e-4)
    grads = optimizer.compute_gradients(cost)
    infer = optimizer.apply_gradients(grads)

    params = tf.trainable_variables()
    for i in params:
        print(i.name, i.get_shape())

    # Run the inference
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
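        # A minimal continuation sketch (added; assumptions: x and n_particles
        # are placeholders defined earlier in this script, x_train is the
        # binarized training set, and epochs/batch_size/iters are set as in
        # the other examples). RWS typically uses several particles per point.
        for epoch in range(1, epochs + 1):
            np.random.shuffle(x_train)
            lbs = []
            for t in range(iters):
                x_batch = x_train[t * batch_size:(t + 1) * batch_size]
                _, lb = sess.run([infer, lower_bound],
                                 feed_dict={x: x_batch,
                                            learning_rate_ph: 1e-3,
                                            n_particles: 10})
                lbs.append(lb)
            print('Epoch {}: Lower bound = {}'.format(epoch, np.mean(lbs)))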
Example #9
    #             'z_1': [qz_samples1, log_qz1],'z_2': [qz_samples2, log_qz2]}, axis=0))

    _, x_recon = VLAE(
        {
            'z_0': qz_samples0,
            'z_1': qz_samples1,
            'z_2': qz_samples2
        }, n, n_x, n_z_0, n_z_1, n_z_2, n_particles, is_training)
    # Note: despite its name, this "lower_bound" is a mean squared
    # reconstruction error, which is minimized directly below rather than
    # maximized like an ELBO.
    lower_bound = tf.reduce_mean(
        tf.square(tf.reshape(x_recon, [-1, n_x]) - tf.cast(x, tf.float32)))

    # Importance sampling estimates of marginal log likelihood
    is_log_likelihood = tf.reduce_mean(
        zs.is_loglikelihood(
            log_joint, {'x': x_obs},
            {'z_0': [qz_samples0, log_qz0],
             'z_1': [qz_samples1, log_qz1],
             'z_2': [qz_samples2, log_qz2]},
            axis=0))

    learning_rate_ph = tf.placeholder(tf.float32, shape=[], name='lr')
    optimizer = tf.train.AdamOptimizer(learning_rate_ph, epsilon=1e-4)
    grads = optimizer.compute_gradients(lower_bound)
    infer = optimizer.apply_gradients(grads)

    params = tf.trainable_variables()
    for i in params:
        print(i.name, i.get_shape())

    saver = tf.train.Saver(max_to_keep=10)

    # Run the inference
Example #10
def main():
    # Load MNIST
    data_path = os.path.join(conf.data_dir, "mnist.pkl.gz")
    x_train, t_train, x_valid, t_valid, x_test, t_test = \
        dataset.load_mnist_realval(data_path)
    x_train = np.vstack([x_train, x_valid])
    x_test = np.random.binomial(1, x_test, size=x_test.shape)

    # Define model parameters
    x_dim = x_train.shape[1]
    z_dim = 40

    # Build the computation graph

    # number of samples to draw from the latent distribution; more samples
    # give a more accurate (but slower) estimate
    n_particles = tf.placeholder(tf.int32, shape=[], name="n_particles")

    # input data fed to the variational (inference) network
    x_input = tf.placeholder(tf.float32, shape=[None, x_dim], name="x")
    x = tf.cast(tf.less(tf.random_uniform(tf.shape(x_input)), x_input),
                tf.int32)

    # batch size
    n = tf.placeholder(tf.int32, shape=[], name="n")

    # add random noise to the variance of the q_model so as to get more varied
    # samples when generating new digits
    std_noise = tf.placeholder_with_default(0., shape=[], name="std_noise")

    # build the generative model (the decoder) and the q_model (the
    # variational posterior, i.e. the encoder)
    model = build_gen(x_dim, z_dim, n, n_particles)
    q_model = build_q_net(x, z_dim, n_particles, std_noise)
    variational = q_model.observe()
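    # For reference, a possible shape for build_q_net (an assumption added
    # here; the original definition is not shown in this listing): a diagonal
    # Gaussian q(z | x) whose log-std is shifted by std_noise, so samples
    # drawn at generation time are more spread out. The decorator, layer sizes
    # and method signatures follow ZhuSuan's 0.4-style API and are best
    # guesses.
    @zs.meta_bayesian_net(scope="q_net_sketch", reuse_variables=True)
    def build_q_net_sketch(x, z_dim, n_particles, std_noise=0.):
        bn = zs.BayesianNet()
        h = tf.layers.dense(tf.cast(x, tf.float32), 500, activation=tf.nn.relu)
        h = tf.layers.dense(h, 500, activation=tf.nn.relu)
        z_mean = tf.layers.dense(h, z_dim)
        z_logstd = tf.layers.dense(h, z_dim) + std_noise
        bn.normal("z", z_mean, logstd=z_logstd, group_ndims=1,
                  n_samples=n_particles)
        return bn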

    # calculate ELBO
    lower_bound = zs.variational.elbo(model, {"x": x},
                                      variational=variational,
                                      axis=0)
    cost = tf.reduce_mean(lower_bound.sgvb())
    lower_bound = tf.reduce_mean(lower_bound)

    # calculate marginal log likelihood
    is_log_likelihood = tf.reduce_mean(
        zs.is_loglikelihood(model, {"x": x}, proposal=variational, axis=0))

    # optimizer
    optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
    infer_op = optimizer.minimize(cost)

    # define training/evaluation parameters
    epochs = 1000
    batch_size = 128
    iters = x_train.shape[0] // batch_size
    test_freq = 100
    test_batch_size = 400
    test_iters = x_test.shape[0] // test_batch_size
    result_path = "results/vae_digits"
    checkpoints_path = "checkpoints/vae_digits"

    # used to save checkpoints during training
    saver = tf.train.Saver(max_to_keep=10)
    save_model_freq = 100

    # run the inference
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # restore the model parameters from the latest checkpoint
        ckpt_file = tf.train.latest_checkpoint(checkpoints_path)
        begin_epoch = 1
        if ckpt_file is not None:
            print('Restoring model from {}...'.format(ckpt_file))
            begin_epoch = int(ckpt_file.split('.')[-2]) + 1
            saver.restore(sess, ckpt_file)

        # begin training
        for epoch in range(begin_epoch, epochs + 1):
            time_epoch = -time.time()
            np.random.shuffle(x_train)
            lbs = []
            for t in range(iters):
                x_batch = x_train[t * batch_size:(t + 1) * batch_size]
                _, lb = sess.run([infer_op, lower_bound],
                                 feed_dict={
                                     x_input: x_batch,
                                     n_particles: 1,
                                     n: batch_size
                                 })
                lbs.append(lb)
            time_epoch += time.time()
            print("Epoch {} ({:.1f}s): Lower bound = {}".format(
                epoch, time_epoch, np.mean(lbs)))

            # test marginal log likelihood
            if epoch % test_freq == 0:
                time_test = -time.time()
                test_lbs, test_lls = [], []
                for t in range(test_iters):
                    test_x_batch = x_test[t * test_batch_size:(t + 1) *
                                          test_batch_size]
                    test_lb = sess.run(lower_bound,
                                       feed_dict={
                                           x: test_x_batch,
                                           n_particles: 1,
                                           n: test_batch_size
                                       })
                    test_ll = sess.run(is_log_likelihood,
                                       feed_dict={
                                           x: test_x_batch,
                                           n_particles: 1000,
                                           n: test_batch_size
                                       })
                    test_lbs.append(test_lb)
                    test_lls.append(test_ll)
                time_test += time.time()
                print(">>> TEST ({:.1f}s)".format(time_test))
                print(">> Test lower bound = {}".format(np.mean(test_lbs)))
                print('>> Test log likelihood (IS) = {}'.format(
                    np.mean(test_lls)))

            # save model parameters
            if epoch % save_model_freq == 0:
                print('Saving model...')
                save_path = os.path.join(checkpoints_path,
                                         "vae.epoch.{}.ckpt".format(epoch))
                if not os.path.exists(os.path.dirname(save_path)):
                    os.makedirs(os.path.dirname(save_path))
                saver.save(sess, save_path)
                print('Done')

        # random generation of images from latent distribution
        x_gen = tf.reshape(model.observe()["x_mean"], [-1, 28, 28, 1])
        images = sess.run(x_gen, feed_dict={n: 100, n_particles: 1})
        name = os.path.join(result_path, "random_samples.png")
        save_image_collections(images, name)

        # the following code generates 100 samples for each digit; each entry
        # of test_n maps a digit to the index of a test-set sample of that
        # digit, so the generated images resemble that digit
        test_n = [3, 2, 1, 90, 95, 23, 11, 0, 84, 7]
        for i in range(len(test_n)):
            # get the latent code from the variational network, using a fixed
            # test-set sample as input
            z = q_model.observe(x=np.expand_dims(x_test[test_n[i]], 0))['z']
            # run the computation graph, adding noise to the computed variance
            # to get varied output samples
            latent = sess.run(z,
                              feed_dict={
                                  x_input:
                                  np.expand_dims(x_test[test_n[i]], 0),
                                  n: 1,
                                  n_particles: 100,
                                  std_noise: 0.7
                              })
            # generate the images from the model, feeding the sampled latent
            # codes as z
            x_gen = tf.reshape(
                model.observe(z=latent)["x_mean"], [-1, 28, 28, 1])
            images = sess.run(x_gen, feed_dict={})
            name = os.path.join(result_path, "{}.png".format(i))
            save_image_collections(images, name)
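Because the script saves checkpoints, generation can also be run later without retraining. A short follow-up sketch (added here as an assumption: the same graph-building code above has been executed first, so model, saver, checkpoints_path, result_path, n and n_particles all exist):

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ckpt_file = tf.train.latest_checkpoint(checkpoints_path)
    if ckpt_file is not None:
        saver.restore(sess, ckpt_file)
    x_gen = tf.reshape(model.observe()["x_mean"], [-1, 28, 28, 1])
    images = sess.run(x_gen, feed_dict={n: 100, n_particles: 1})
    save_image_collections(images,
                           os.path.join(result_path, "restored_samples.png"))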