#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""VAE with planar normalizing flows on binarized MNIST (ZhuSuan example)."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time

import numpy as np
import tensorflow as tf
import zhusuan as zs

# Assumed import paths: `conf` and `dataset` are the config/data helpers
# shipped alongside ZhuSuan's examples.
from examples import conf
from examples.utils import dataset


# Graph construction written against the pre-0.4 ZhuSuan API
# (log_joint + zs.sgvb); main() below builds the equivalent model with the
# 0.4 variational API.  This fragment assumes `vae` and `q_net` (the legacy
# model and variational constructors) plus `n_x`, `n_z`, `n_particles`,
# `n_planar_flows`, and `is_training` are defined in the surrounding scope;
# 0.4-style constructors for main() are sketched further down.
x_orig = tf.placeholder(tf.float32, shape=[None, n_x],
                        name='x_orig')  # assumed placeholder, not shown above
# Dynamic binarization of the real-valued inputs.
x_bin = tf.cast(
    tf.less(tf.random_uniform(tf.shape(x_orig), 0, 1), x_orig), tf.int32)
x = tf.placeholder(tf.int32, shape=[None, n_x], name='x')
x_obs = tf.tile(tf.expand_dims(x, 0), [n_particles, 1, 1])
n = tf.shape(x)[0]


def log_joint(observed):
    model = vae(observed, n, n_x, n_z, n_particles, is_training)
    log_pz, log_px_z = model.local_log_prob(['z', 'x'])
    return log_pz + log_px_z


variational = q_net({}, x, n_z, n_particles, is_training)
qz_samples, log_qz = variational.query('z', outputs=True,
                                       local_log_prob=True)
# TODO: add tests for repeated calls of flows
qz_samples, log_qz = zs.planar_normalizing_flow(qz_samples, log_qz,
                                                n_iters=n_planar_flows)
qz_samples, log_qz = zs.planar_normalizing_flow(qz_samples, log_qz,
                                                n_iters=n_planar_flows)

lower_bound = tf.reduce_mean(
    zs.sgvb(log_joint, {'x': x_obs}, {'z': [qz_samples, log_qz]}, axis=0))

# Importance sampling estimate of the log likelihood:
# fast, used for evaluation during training.
is_log_likelihood = tf.reduce_mean(
    zs.is_loglikelihood(log_joint, {'x': x_obs},
                        {'z': [qz_samples, log_qz]}, axis=0))

learning_rate_ph = tf.placeholder(tf.float32, shape=[], name='lr')
optimizer = tf.train.AdamOptimizer(learning_rate_ph, epsilon=1e-4)
grads = optimizer.compute_gradients(-lower_bound)
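# main() below calls build_gen() and build_q_net(), which this fragment does
# not define.  The following is a minimal sketch in ZhuSuan's 0.4
# BayesianNet style, assuming the stock two 500-unit hidden layers of the
# standard VAE example; treat the layer sizes and scope names as
# illustrative, not as the original definitions.
@zs.meta_bayesian_net(scope='gen', reuse_variables=True)
def build_gen(n, x_dim, z_dim, n_particles):
    bn = zs.BayesianNet()
    # Standard normal prior over the latent codes, n_particles samples each.
    z_mean = tf.zeros([n, z_dim])
    z = bn.normal('z', z_mean, std=1., group_ndims=1, n_samples=n_particles)
    h = tf.layers.dense(z, 500, activation=tf.nn.relu)
    h = tf.layers.dense(h, 500, activation=tf.nn.relu)
    x_logits = tf.layers.dense(h, x_dim)
    # Bernoulli likelihood over the binarized pixels.
    bn.bernoulli('x', x_logits, group_ndims=1)
    return bn


@zs.reuse_variables(scope='q_net')
def build_q_net(x, z_dim, n_particles):
    bn = zs.BayesianNet()
    h = tf.layers.dense(tf.cast(x, tf.float32), 500, activation=tf.nn.relu)
    h = tf.layers.dense(h, 500, activation=tf.nn.relu)
    z_mean = tf.layers.dense(h, z_dim)
    z_logstd = tf.layers.dense(h, z_dim)
    # Diagonal Gaussian base posterior; the planar flows in main() are
    # applied on top of its samples.
    bn.normal('z', z_mean, logstd=z_logstd, group_ndims=1,
              n_samples=n_particles)
    return bn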
def main():
    # Fix random seeds for reproducibility.
    tf.set_random_seed(1234)
    np.random.seed(1234)

    # Load MNIST.
    data_path = os.path.join(conf.data_dir, 'mnist.pkl.gz')
    x_train, t_train, x_valid, t_valid, x_test, t_test = \
        dataset.load_mnist_realval(data_path)
    x_train = np.vstack([x_train, x_valid])
    # Binarize the test set once; training batches are binarized dynamically
    # inside the graph.
    x_test = np.random.binomial(1, x_test, size=x_test.shape)
    x_dim = x_train.shape[1]

    # Define model/inference parameters.
    z_dim = 40
    n_planar_flows = 10

    # Build the computation graph.
    n_particles = tf.placeholder(tf.int32, shape=[], name='n_particles')
    x_input = tf.placeholder(tf.float32, shape=[None, x_dim], name='x')
    # Dynamic binarization: resample binary pixels from the real-valued
    # inputs on every run.
    x = tf.cast(tf.less(tf.random_uniform(tf.shape(x_input)), x_input),
                tf.int32)
    n = tf.placeholder(tf.int32, shape=[], name='n')

    model = build_gen(n, x_dim, z_dim, n_particles)
    q_net = build_q_net(x, z_dim, n_particles)
    qz_samples, log_qz = q_net.query('z', outputs=True, local_log_prob=True)
    # TODO: add tests for repeated calls of flows
    qz_samples, log_qz = zs.planar_normalizing_flow(qz_samples, log_qz,
                                                    n_iters=n_planar_flows)
    qz_samples, log_qz = zs.planar_normalizing_flow(qz_samples, log_qz,
                                                    n_iters=n_planar_flows)

    lower_bound = zs.variational.elbo(
        model, observed={'x': x}, latent={'z': [qz_samples, log_qz]}, axis=0)
    cost = tf.reduce_mean(lower_bound.sgvb())
    lower_bound = tf.reduce_mean(lower_bound)

    # Importance sampling estimate of the marginal log likelihood.
    is_log_likelihood = tf.reduce_mean(
        zs.is_loglikelihood(model, {'x': x}, {'z': [qz_samples, log_qz]},
                            axis=0))

    optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
    infer_op = optimizer.minimize(cost)

    # Define training/evaluation parameters.
    epochs = 3000
    batch_size = 128
    iters = x_train.shape[0] // batch_size
    test_freq = 10
    test_batch_size = 400
    test_iters = x_test.shape[0] // test_batch_size

    # Run the inference.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(1, epochs + 1):
            time_epoch = -time.time()
            np.random.shuffle(x_train)
            lbs = []
            for t in range(iters):
                x_batch = x_train[t * batch_size:(t + 1) * batch_size]
                _, lb = sess.run([infer_op, lower_bound],
                                 feed_dict={x_input: x_batch,
                                            n_particles: 1,
                                            n: batch_size})
                lbs.append(lb)
            time_epoch += time.time()
            print('Epoch {} ({:.1f}s): Lower bound = {}'.format(
                epoch, time_epoch, np.mean(lbs)))

            if epoch % test_freq == 0:
                time_test = -time.time()
                test_lbs = []
                test_lls = []
                for t in range(test_iters):
                    test_x_batch = x_test[t * test_batch_size:
                                          (t + 1) * test_batch_size]
                    # The test set is already binarized, so feed the derived
                    # tensor `x` directly, bypassing the in-graph
                    # binarization of `x_input`.
                    test_lb = sess.run(lower_bound,
                                       feed_dict={x: test_x_batch,
                                                  n_particles: 1,
                                                  n: test_batch_size})
                    test_ll = sess.run(is_log_likelihood,
                                       feed_dict={x: test_x_batch,
                                                  n_particles: 1000,
                                                  n: test_batch_size})
                    test_lbs.append(test_lb)
                    test_lls.append(test_ll)
                time_test += time.time()
                print('>>> TEST ({:.1f}s)'.format(time_test))
                print('>> Test lower bound = {}'.format(np.mean(test_lbs)))
                print('>> Test log likelihood (IS) = {}'.format(
                    np.mean(test_lls)))
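# Assumed entry point, added so the example runs as a script.
if __name__ == '__main__':
    main()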