Example #1
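A multi-GPU VAE on binarized MNIST: each tower evaluates the expected joint log-likelihood of the generative model together with an entropy surrogate whose gradient comes from a Stein or spectral score estimator (selected via FLAGS.estimator); tower gradients are averaged and applied with Adam, and generated samples and checkpoints are written out periodically.
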
def main():
    tf.set_random_seed(1234)
    np.random.seed(1234)

    # Load MNIST
    data_path = os.path.join('data', 'mnist.pkl.gz')
    x_train, t_train, x_valid, t_valid, x_test, t_test = \
        dataset.load_mnist_realval(data_path)
    x_train = np.vstack([x_train, x_valid])
    x_test = np.random.binomial(1, x_test, size=x_test.shape)
    n_x = x_train.shape[1]
    n_z = FLAGS.n_z

    n_particles = tf.placeholder(tf.int32, shape=[], name='n_particles')
    x_input = tf.placeholder(tf.float32, shape=[None, n_x], name='x')
    x = tf.to_int32(tf.random_uniform(tf.shape(x_input)) <= x_input)
    learning_rate_ph = tf.placeholder(tf.float32, shape=[], name='lr')
    optimizer = tf.train.AdamOptimizer(learning_rate_ph, beta1=0.5)

    def build_tower_graph(x, id_):
        tower_x = x[id_ * tf.shape(x)[0] // FLAGS.num_gpus:(id_ + 1) *
                    tf.shape(x)[0] // FLAGS.num_gpus]
        n = tf.shape(tower_x)[0]

        # qz_samples: [n_particles, n, n_z]
        qz_samples = q_net(tower_x, n_z, n_particles)
        # Use a single particle for the reconstruction term
        observed = {'x': tower_x, 'z': qz_samples[:1]}
        model, z, _ = vae(observed, n, n_x, n_z, 1)
        # log_px_qz: [1, n]
        log_px_qz = model.local_log_prob('x')
        eq_ll = tf.reduce_mean(log_px_qz)
        # log_p_qz: [n_particles, n]
        log_p_qz = z.log_prob(qz_samples)
        eq_joint = eq_ll + tf.reduce_mean(log_p_qz)

        if FLAGS.estimator == "stein":
            # `eta` is defined elsewhere in the original script
            # (e.g. as a flag controlling the estimator's regularization).
            estimator = SteinScoreEstimator(eta=eta)
        elif FLAGS.estimator == "spectral":
            estimator = SpectralScoreEstimator(n_eigen=None,
                                               eta=None,
                                               n_eigen_threshold=0.99)
        else:
            raise ValueError(
                "Unknown estimator: {}".format(FLAGS.estimator))

        # qzs: [n, n_particles, n_z]
        qzs = tf.transpose(qz_samples, [1, 0, 2])
        # Score-estimator gradient of log q(z); it is treated as a constant
        # so that differentiating the surrogate reproduces the entropy
        # gradient of the ELBO.
        dlog_q = estimator.compute_gradients(qzs)
        entropy_surrogate = tf.reduce_mean(
            tf.reduce_sum(tf.stop_gradient(-dlog_q) * qzs, -1))
        cost = -eq_joint - entropy_surrogate
        grads_and_vars = optimizer.compute_gradients(cost)

        return grads_and_vars, eq_joint

    tower_losses = []
    tower_grads = []
    for i in range(FLAGS.num_gpus):
        with tf.device('/gpu:%d' % i):
            with tf.name_scope('tower_%d' % i):
                grads, tower_eq_joint = build_tower_graph(x, i)
                tower_losses.append([tower_eq_joint])
                tower_grads.append(grads)

    eq_joint = average_losses(tower_losses)[0]
    grads = average_gradients(tower_grads)
    infer_op = optimizer.apply_gradients(grads)

    # Generate images
    n_gen = 100
    _, _, x_logits = vae({}, n_gen, n_x, n_z, 1)
    x_gen = tf.reshape(tf.sigmoid(x_logits), [-1, 28, 28, 1])

    # Define training parameters
    learning_rate = 1e-4
    epochs = 3000
    batch_size = 128
    iters = x_train.shape[0] // batch_size
    save_image_freq = 10
    save_model_freq = 100
    test_freq = 10
    test_batch_size = 400
    test_iters = x_test.shape[0] // test_batch_size
    result_path = "results/vae_conv_{}_{}".format(
        n_z, FLAGS.estimator) + time.strftime("_%Y%m%d_%H%M%S")

    saver = tf.train.Saver(max_to_keep=10)
    logger = setup_logger('vae_conv_' + FLAGS.estimator, __file__, result_path)

    with create_session(FLAGS.log_device_placement) as sess:
        sess.run(tf.global_variables_initializer())

        # Restore from the latest checkpoint
        ckpt_file = tf.train.latest_checkpoint(result_path)
        begin_epoch = 1
        if ckpt_file is not None:
            logger.info('Restoring model from {}...'.format(ckpt_file))
            begin_epoch = int(ckpt_file.split('.')[-2]) + 1
            saver.restore(sess, ckpt_file)

        for epoch in range(begin_epoch, epochs + 1):
            time_epoch = -time.time()
            np.random.shuffle(x_train)
            eq_joints = []
            for t in range(iters):
                x_batch = x_train[t * batch_size:(t + 1) * batch_size]
                _, eq_joint_ = sess.run(
                    [infer_op, eq_joint],
                    feed_dict={
                        x_input: x_batch,
                        learning_rate_ph: learning_rate,
                        n_particles: n_est  # defined elsewhere in the script
                    },
                )

                eq_joints.append(eq_joint_)

            time_epoch += time.time()
            logger.info('Epoch {} ({:.1f}s): log joint = {}'.format(
                epoch, time_epoch, np.mean(eq_joints)))

            if epoch % test_freq == 0:
                time_test = -time.time()
                test_eq_joints = []
                for t in range(test_iters):
                    test_x_batch = x_test[t * test_batch_size:(t + 1) *
                                          test_batch_size]
                    test_eq_joint = sess.run(eq_joint,
                                             feed_dict={
                                                 x: test_x_batch,
                                                 n_particles: n_est
                                             })
                    test_eq_joints.append(test_eq_joint)
                time_test += time.time()
                logger.info('>>> TEST ({:.1f}s)'.format(time_test))
                logger.info('>> Test log joint = {}'.format(
                    np.mean(test_eq_joints)))

            if epoch % save_image_freq == 0:
                logger.info('Saving images...')
                images = sess.run(x_gen)
                name = os.path.join(result_path,
                                    "vae.epoch.{}.png".format(epoch))
                save_image_collections(images, name)

            if epoch % save_model_freq == 0:
                logger.info('Saving model...')
                save_path = os.path.join(result_path,
                                         "vae.epoch.{}.ckpt".format(epoch))
                if not os.path.exists(os.path.dirname(save_path)):
                    os.makedirs(os.path.dirname(save_path))
                saver.save(sess, save_path)
                logger.info('Done')
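Example #2

Bayesian semi-crowdsourced deep clustering (SCDC) on MNIST: a mixture prior over the latent codes is fit by structured variational message passing, natural-gradient updates are applied to the global variational parameters, and crowd annotations fed as sparse tensors contribute a second lower-bound term; clustering accuracy and NMI are reported on the train and test splits.
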
def main():
    seed = FLAGS.seed
    result_path = "results/mnist_crowd_{}_{}".format(time.strftime("%Y%m%d_%H%M%S"), seed)
    logger = setup_logger('mnist', __file__, result_path)
    np.random.seed(seed)
    tf.set_random_seed(seed)

    # Load MNIST
    data_path = os.path.join('data', 'mnist.pkl.gz')
    o_train, t_train, o_valid, t_valid, o_test, t_test = \
        dataset.load_mnist_realval(data_path, one_hot=False)
    o_train = np.vstack([o_train, o_valid])
    t_train = np.hstack([t_train, t_valid])
    n_train, o_dim = o_train.shape
    # indices = np.random.permutation(n_train)
    # o_train = o_train[indices]
    # t_train = t_train[indices]
    o_test = np.random.binomial(1, o_test, size=o_test.shape)
    n_test, _ = o_test.shape
    # n_class = np.max(t_test) + 1

    # Prior parameters
    d = 8
    K = 50
    W = 20
    prior_alpha = 1.05
    prior_niw_conc = 0.5
    prior_tau = 1.

    # Variational initialization
    alpha = 2.
    niw_conc = 1.
    random_scale = 3.
    tau = 10.

    # Learning rate
    learning_rate = 1e-3
    nat_grad_scale = 1e4

    # Load annotations
    # [i, j, w, L]
    annotations = load_annotations(t_train, W, method="real")
    n_annotations = annotations.shape[0]
    W = len(set(annotations[:, 2]))

    # Define training parameters
    epochs = 200
    batch_size = 128
    iters = o_train.shape[0] // batch_size
    ann_batch_size = annotations.shape[0] // iters
    save_freq = 1
    test_freq = 10
    test_batch_size = 400
    test_iters = o_test.shape[0] // test_batch_size

    prior_global_params = get_global_params(
        "prior", d, K, W, prior_alpha, prior_niw_conc, prior_tau,
        trainable=False)
    global_params = get_global_params(
        "variational", d, K, W, alpha, niw_conc, tau,
        random_scale=random_scale, trainable=True)

    # n_particles = tf.placeholder(tf.int32, shape=[], name='n_particles')
    o_input = tf.placeholder(tf.float32, shape=[None, o_dim], name='o')
    o = tf.to_int32(tf.random_uniform(tf.shape(o_input)) <= o_input)

    ann_o_input = tf.placeholder(tf.float32, shape=[None, o_dim], name='ann_o')
    ann_o = tf.to_int32(tf.random_uniform(tf.shape(ann_o_input)) <= ann_o_input)
    L_ph = tf.sparse_placeholder(tf.float32, shape=[None, None, W])
    I_ph = tf.sparse_placeholder(tf.float32, shape=[None, None, W])

    lower_bound, global_nat_grads, z_stats, niw_stats, dir_stats = \
        variational_message_passing(
            prior_global_params, global_params, o, o_dim, d, K, n_train,
            n_iters=4)
    z_pred = tf.argmax(z_stats, axis=-1)

    ann_lower_bound, ann_nat_grads, _, _, _ = variational_message_passing(
        prior_global_params, global_params, ann_o, o_dim, d, K, n_train,
        L_ph, I_ph, n_annotations, ann_batch_size, n_iters=4)
    # ann_lower_bound = tf.constant(0.)
    # ann_nat_grads = [tf.zeros_like(param) for param in global_params]

    optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                           momentum=0.9)
    net_vars = (tf.trainable_variables(scope="encoder") +
                tf.trainable_variables(scope="decoder"))
    net_grads_and_vars = optimizer.compute_gradients(
        -0.5 * (lower_bound + ann_lower_bound), var_list=net_vars)
    global_nat_grads.extend([0, 0])
    nat_grads = [-nat_grad_scale * 0.5 * (g + ann_g)
                 for g, ann_g in zip(global_nat_grads, ann_nat_grads)]
    global_grads_and_vars = list(zip(nat_grads, global_params))
    infer_op = optimizer.apply_gradients(net_grads_and_vars +
                                         global_grads_and_vars)

    # Generation
    # niw_stats: [K, d + d^2 + 2]
    gen_mvn_params = niw_stats[:, :-2]
    # transparency: [K]
    transp = tf.exp(dir_stats) / tf.reduce_max(tf.exp(dir_stats))
    # x_samples: [K, d, 10]
    x_samples = mvn.sample(gen_mvn_params, d, n_samples=10)
    # o_mean: [10, K, o_dim]
    _, o_mean = decoder(tf.transpose(x_samples, [2, 0, 1]), o_dim)
    # o_gen: [10 * K, 28, 28, 1]
    o_gen = tf.reshape(o_mean * transp[:, None], [-1, 28, 28, 1])

    def _evaluate(pred_batches, labels):
        preds = np.hstack(pred_batches)
        truths = labels[:preds.size]
        acc, _ = cluster_acc(preds, truths)
        nmi = adjusted_mutual_info_score(truths, labels_pred=preds)
        return acc, nmi

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for epoch in range(1, epochs + 1):
            time_epoch = -time.time()
            indices = np.random.permutation(n_train)
            o_train_raw = o_train[indices]
            t_train_raw = t_train[indices]
            lbs, ann_lbs = [], []
            t_preds, ann_t_preds = [], []
            for t in range(iters):
                # Without annotation
                o_batch = o_train_raw[t * batch_size:(t + 1) * batch_size]

                # With annotation
                ann_indices = np.random.randint(0, n_annotations,
                                                size=ann_batch_size)
                ann_batch = annotations[ann_indices]
                o_indices, orig_to_batch_ind, batch_to_orig_ind, \
                    sparse_ann_batch, sparse_ann_ind = make_sparse_ann_batch(
                        ann_batch, W)
                ann_o_batch = o_train[o_indices]

                _, lb, t_pred, ann_lb = sess.run(
                    [infer_op, lower_bound, z_pred, ann_lower_bound],
                    feed_dict={o_input: o_batch,
                               ann_o_input: ann_o_batch,
                               L_ph: sparse_ann_batch,
                               I_ph: sparse_ann_ind})
                lbs.append(lb)
                t_preds.append(t_pred)
                # print("lb: {}".format(lb))
                ann_lbs.append(ann_lb)

            time_epoch += time.time()
            train_acc, train_nmi = _evaluate(t_preds, t_train_raw)
            logger.info(
                'Epoch {} ({:.1f}s): Lower bound = {}, ann LB = {}, '
                'acc = {}, nmi = {}'
                .format(epoch, time_epoch, np.mean(lbs), np.mean(ann_lbs),
                        train_acc, train_nmi))

            if epoch % test_freq == 0:
                time_test = -time.time()
                test_lbs = []
                test_t_preds = []
                for t in range(test_iters):
                    test_o_batch = o_test[t * test_batch_size:
                                          (t + 1) * test_batch_size]
                    test_lb, test_t_pred = sess.run([lower_bound, z_pred],
                                                    feed_dict={o: test_o_batch})
                    test_lbs.append(test_lb)
                    test_t_preds.append(test_t_pred)

                time_test += time.time()
                test_acc, test_nmi = _evaluate(test_t_preds, t_test)
                logger.info('>>> TEST ({:.1f}s)'.format(time_test))
                logger.info('>> Test lower bound = {}, acc = {}, nmi = {}'
                            .format(np.mean(test_lbs), test_acc, test_nmi))

                if epoch == epochs:
                    with open('results/mnist_bayesSCDC.txt', "a") as myfile:
                        myfile.write("seed: %d train_acc: %f train_nmi: %f "
                                     "test_acc: %f test_nmi: %f\n" % (
                                         seed, train_acc, train_nmi,
                                         test_acc, test_nmi))

            if epoch % save_freq == 0:
                logger.info('Saving images...')
                images = sess.run(o_gen)
                name = os.path.join(result_path,
                                    "vae.epoch.{}.png".format(epoch))
                save_image_collections(images, name, shape=(10, K))
Example #3
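A conditional VAE on MNIST: both the generative model and the inference network are conditioned on the class label y, the ELBO is optimized with ZhuSuan's SGVB estimator, the test marginal log-likelihood is estimated by importance sampling, and class-conditioned samples are saved every few epochs.
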
def main():
    # Load MNIST
    data_path = os.path.join(conf.data_dir, "mnist.pkl.gz")
    x_train, t_train, x_valid, t_valid, x_test, t_test = \
        dataset.load_mnist_realval(data_path)
    x_train = np.vstack([x_train, x_valid])
    y_train = np.vstack([t_train, t_valid])
    x_test = np.random.binomial(1, x_test, size=x_test.shape)
    x_dim = x_train.shape[1]
    y_dim = y_train.shape[1]
    # Define model parameters
    z_dim = 40

    # Build the computation graph
    n_particles = tf.placeholder(tf.int32, shape=[], name="n_particles")
    x_input = tf.placeholder(tf.float32, shape=[None, x_dim], name="x")
    x = tf.cast(tf.less(tf.random_uniform(tf.shape(x_input)), x_input),
                tf.int32)
    y = tf.placeholder(tf.float32, shape=[None, 10], name="y")
    n = tf.placeholder(tf.int32, shape=[], name="n")

    model = build_gen(x_dim, z_dim, y, n, n_particles)
    variational = build_q_net(x, z_dim, y, n_particles)

    lower_bound = zs.variational.elbo(
        model, {"x": x}, variational=variational, axis=0)
    cost = tf.reduce_mean(lower_bound.sgvb())
    lower_bound = tf.reduce_mean(lower_bound)

    # Importance sampling estimates of marginal log likelihood
    is_log_likelihood = tf.reduce_mean(
        zs.is_loglikelihood(model, {"x": x}, proposal=variational, axis=0))

    optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
    infer_op = optimizer.minimize(cost)

    # Random generation
    y_observe = tf.placeholder(tf.float32, shape=[None, 10], name="y_observe")
    x_gen = tf.reshape(model.observe(y=y_observe)["x_mean"], [-1, 28, 28, 1])

    # Define training/evaluation parameters
    epochs = 3000
    batch_size = 128
    iters = x_train.shape[0] // batch_size
    save_freq = 10
    test_freq = 10
    test_batch_size = 400
    test_iters = x_test.shape[0] // test_batch_size
    result_path = "results/vae"

    # Run the inference
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for epoch in range(1, epochs + 1):
            time_epoch = -time.time()
            np.random.shuffle(x_train)
            lbs = []
            for t in range(iters):
                x_batch = x_train[t * batch_size:(t + 1) * batch_size]
                y_batch = y_train[t * batch_size:(t + 1) * batch_size]
                _, lb = sess.run([infer_op, lower_bound],
                                 feed_dict={x_input: x_batch,
                                            y: y_batch,
                                            n_particles: 1,
                                            n: batch_size})
                lbs.append(lb)
            time_epoch += time.time()
            print("Epoch {} ({:.1f}s): Lower bound = {}".format(
                epoch, time_epoch, np.mean(lbs)))

            if epoch % test_freq == 0:
                time_test = -time.time()
                test_lbs, test_lls = [], []
                for t in range(test_iters):
                    test_x_batch = x_test[t * test_batch_size:
                                          (t + 1) * test_batch_size]
                    test_y_batch = t_test[t * test_batch_size:
                                          (t + 1) * test_batch_size]
                    test_lb = sess.run(lower_bound,
                                       feed_dict={x: test_x_batch,
                                                  y: test_y_batch,
                                                  n_particles: 1,
                                                  n: test_batch_size})
                    test_ll = sess.run(is_log_likelihood,
                                       feed_dict={x: test_x_batch,
                                                  y: test_y_batch,
                                                  n_particles: 1000,
                                                  n: test_batch_size})
                    test_lbs.append(test_lb)
                    test_lls.append(test_ll)
                time_test += time.time()
                print(">>> TEST ({:.1f}s)".format(time_test))
                print(">> Test lower bound = {}".format(np.mean(test_lbs)))
                print('>> Test log likelihood (IS) = {}'.format(
                    np.mean(test_lls)))

            if epoch % save_freq == 0:
                y_index = np.repeat(np.arange(10), 10)
                y_index = np.eye(10)[y_index]
                images = sess.run(x_gen,
                                  feed_dict={y_observe: y_index, y: y_index,
                                             n: 100, n_particles: 1})
                name = os.path.join(result_path,
                                    "c-vae.epoch.{}.png".format(epoch))
                save_image_collections(images, name)
Example #4
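An evaluation script that seeds the RNGs from FLAGS.seed, binarizes and subsamples 2048 MNIST test images, recovers the latent dimension n_z from the result-directory name, and imports ZhuSuan's AIS for estimating the marginal log-likelihood of a trained VAE.
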
from zhusuan.evaluation import AIS

from utils.utils import setup_logger

FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string("dir", "", """The result directory.""")
tf.flags.DEFINE_integer("seed", 1234, """Random seed.""")

if __name__ == "__main__":
    seed = FLAGS.seed
    tf.set_random_seed(seed)

    # Load MNIST
    data_path = os.path.join('data', 'mnist.pkl.gz')
    x_train, t_train, x_valid, t_valid, x_test, t_test = \
        dataset.load_mnist_realval(data_path)
    np.random.seed(seed)
    x_test = np.random.binomial(1, x_test, size=x_test.shape)
    test_idx = np.arange(x_test.shape[0])
    np.random.shuffle(test_idx)
    x_test = x_test[test_idx[:2048]]
    n_x = x_test.shape[1]

    # Define model parameters
    from .vae_conv import vae
    n_z = int(FLAGS.dir.split('_')[2])

    # Define training/evaluation parameters
    if "vae_conv_" in FLAGS.dir:
        test_batch_size = 256
    elif "vae_" in FLAGS.dir:
            val_ML = prng.binomial(1, alpha[m], num_ML)
            val_CL = prng.binomial(1, 1 - beta[m], num_CL)
            Sm_ML = np.hstack((ML, np.ones((num_ML, 1)) * (m + start_expert),
                               val_ML.reshape(val_ML.size, 1)))
            Sm_CL = np.hstack((CL, np.ones((num_CL, 1)) * (m + start_expert),
                               val_CL.reshape(val_CL.size, 1)))
            S = np.vstack((S, Sm_ML, Sm_CL)).astype(int)

    return S


if __name__ == "__main__":
    # Load MNIST
    data_path = os.path.join('data', 'mnist.pkl.gz')
    x_train, t_train, x_valid, t_valid, x_test, t_test = \
        dataset.load_mnist_realval(data_path, one_hot=False)
    n_x = x_train.shape[1]
    n_y = 10
    n_j = 400  # number of workers
    n_l = 20  # number of items each worker annotates each class
    n_l_all = 100  # number of items each worker annotates
    confusion_rate = 0
    # x_by_class, t_by_class, ind_by_class = group_by_class(x_test, t_test)
    anns = []
    ids = []
    mask_id = np.zeros(len(t_test))  # whether the id is labeled
    end_id = n_l_all
    indices = np.arange(len(t_test))
    for j in range(n_j):
        # _, t_labeled, indices_labeled = select_by_class(x_by_class, t_by_class, ind_by_class, n_l)
        indices_labeled = indices[end_id - n_l_all:end_id]
Example #6
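A plain convolutional VAE on binarized MNIST with a Gaussian variational posterior: the ELBO uses the analytic KL between the posterior and the standard normal prior (kl_normal_normal), training uses Adam with a single posterior sample per data point, and generated images plus checkpoints are saved periodically.
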
def main():
    tf.set_random_seed(1234)
    np.random.seed(1234)

    # Load MNIST
    data_path = os.path.join('data', 'mnist.pkl.gz')
    x_train, t_train, x_valid, t_valid, x_test, t_test = \
        dataset.load_mnist_realval(data_path)
    x_train = np.vstack([x_train, x_valid])
    x_test = np.random.binomial(1, x_test, size=x_test.shape)
    n_x = x_train.shape[1]
    n_z = FLAGS.n_z

    n_particles = tf.placeholder(tf.int32, shape=[], name='n_particles')
    x_input = tf.placeholder(tf.float32, shape=[None, n_x], name='x')
    x = tf.to_int32(tf.random_uniform(tf.shape(x_input)) <= x_input)
    n = tf.shape(x)[0]

    qz = q_net(x, n_z, n_particles)
    # log_qz = qz.log_prob(qz)
    model, _ = vae({'x': x, 'z': qz}, n, n_x, n_z, n_particles)
    log_px_qz = model.local_log_prob('x')
    eq_ll = tf.reduce_mean(log_px_qz)

    kl = kl_normal_normal(
        qz.distribution.mean, qz.distribution.logstd, 0., 0.)
    kl_term = tf.reduce_mean(tf.reduce_sum(kl, -1))
    lower_bound = eq_ll - kl_term
    cost = -lower_bound

    # log_pz = model.local_log_prob('z')
    # kl_term_est = tf.reduce_mean(log_qz - log_pz)
    # cost = kl_term

    learning_rate_ph = tf.placeholder(tf.float32, shape=[], name='lr')
    optimizer = tf.train.AdamOptimizer(learning_rate_ph, beta1=0.5)
    infer_op = optimizer.minimize(cost)

    # Generate images
    n_gen = 100
    _, x_logits = vae({}, n_gen, n_x, n_z, 1)
    x_gen = tf.reshape(tf.sigmoid(x_logits), [-1, 28, 28, 1])

    # Define training parameters
    lb_samples = 1
    learning_rate = 1e-4
    epochs = 3000
    batch_size = 128
    iters = x_train.shape[0] // batch_size
    save_image_freq = 10
    save_model_freq = 100
    test_freq = 10
    test_batch_size = 400
    test_iters = x_test.shape[0] // test_batch_size
    result_path = "results/vae_conv_{}_".format(n_z) + \
        time.strftime("%Y%m%d_%H%M%S")

    saver = tf.train.Saver(max_to_keep=10)
    logger = setup_logger('vae_conv', __file__, result_path)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Restore from the latest checkpoint
        ckpt_file = tf.train.latest_checkpoint(result_path)
        begin_epoch = 1
        if ckpt_file is not None:
            logger.info('Restoring model from {}...'.format(ckpt_file))
            begin_epoch = int(ckpt_file.split('.')[-2]) + 1
            saver.restore(sess, ckpt_file)

        for epoch in range(begin_epoch, epochs + 1):
            time_epoch = -time.time()
            np.random.shuffle(x_train)
            lbs = []
            for t in range(iters):
                x_batch = x_train[t * batch_size:(t + 1) * batch_size]
                _, lb = sess.run(
                    [infer_op, lower_bound],
                    feed_dict={x_input: x_batch,
                               learning_rate_ph: learning_rate,
                               n_particles: lb_samples})
                lbs.append(lb)

            time_epoch += time.time()
            logger.info(
                'Epoch {} ({:.1f}s): Lower bound = {}'
                .format(epoch, time_epoch, np.mean(lbs)))

            if epoch % test_freq == 0:
                time_test = -time.time()
                test_lbs = []
                for t in range(test_iters):
                    test_x_batch = x_test[t * test_batch_size:
                                          (t + 1) * test_batch_size]
                    test_lb = sess.run(lower_bound,
                                       feed_dict={x: test_x_batch,
                                                  n_particles: lb_samples})
                    test_lbs.append(test_lb)
                time_test += time.time()
                logger.info('>>> TEST ({:.1f}s)'.format(time_test))
                logger.info('>> Test lower bound = {}'
                            .format(np.mean(test_lbs)))

            if epoch % save_image_freq == 0:
                logger.info('Saving images...')
                images = sess.run(x_gen)
                name = os.path.join(result_path,
                                    "vae.epoch.{}.png".format(epoch))
                save_image_collections(images, name)

            if epoch % save_model_freq == 0:
                logger.info('Saving model...')
                save_path = os.path.join(result_path,
                                         "vae.epoch.{}.ckpt".format(epoch))
                if not os.path.exists(os.path.dirname(save_path)):
                    os.makedirs(os.path.dirname(save_path))
                saver.save(sess, save_path)
                logger.info('Done')