def load_image_data(data, n_xl, n_channels, output_batch_size):
    """Load one of the supported image datasets for training.

    Args:
        data: dataset name — 'mnist', 'svhn', 'lfw'; anything else falls
            through to CIFAR-10.
        n_xl: image side length used when reshaping flat arrays.
        n_channels: number of image channels used when reshaping.
        output_batch_size: size of the preview batch returned alongside
            the full training set.

    Returns:
        (x_train, sorted_x_train): the full training array and a preview
        batch of ``output_batch_size`` images, ordered by class label
        where labels are available (LFW has none, so its preview is just
        the leading slice).
    """
    if data == 'mnist':
        # MNIST lives next to this file; train + validation folds merged.
        data_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), 'data',
            'mnist.pkl.gz')
        x_train, t_train, x_valid, t_valid, _, _ = \
            dataset.load_mnist_realval(data_path)
        x_train = np.vstack([x_train, x_valid]).astype('float32')
        x_train = np.reshape(x_train, [-1, n_xl, n_xl, n_channels])
        preview_x = x_train[:output_batch_size]
        preview_t = t_train[:output_batch_size]
        # One-hot label rows -> integer class ids, then sort by class.
        class_ids = np.nonzero(preview_t)[1]
        sorted_x_train = preview_x[np.argsort(class_ids)]
    elif data == 'svhn':
        # SVHN is stored as pre-split .npy arrays.
        print('Reading svhn...')
        started = time.time()
        print('Train')
        x_train = np.load('data/svhn_train1_x.npy')
        y_train = np.load('data/svhn_train1_y.npy')
        print('Test')
        x_test = np.load('data/svhn_test_x.npy')
        y_test = np.load('data/svhn_test_y.npy')
        print('Finished in {:.4f} seconds'.format(time.time() - started))
        preview_x = x_train[:output_batch_size]
        preview_y = y_train[:output_batch_size]
        # SVHN labels are already integers — sort the preview directly.
        sorted_x_train = preview_x[np.argsort(preview_y)]
    elif data == 'lfw':
        print('Reading lfw...')
        started = time.time()
        x_train = np.load('data/lfw.npy').astype(np.float32)
        print(x_train.shape)
        x_train = np.reshape(x_train, [-1, n_xl, n_xl, n_channels])
        print('Finished in {:.4f} seconds'.format(time.time() - started))
        # LFW is unlabeled: the preview is simply the leading images.
        sorted_x_train = x_train[:output_batch_size]
    else:
        # Default: CIFAR-10, train and test stacked into one array.
        x_train, t_train, x_test, t_test = \
            dataset.load_cifar10('data/cifar10/cifar-10-python.tar.gz',
                                 normalize=True, one_hot=True)
        x_train = np.vstack((x_train, x_test))
        all_t = np.vstack((t_train, t_test))
        preview_x = x_train[:output_batch_size]
        class_ids = np.argmax(all_t[:output_batch_size], 1)
        sorted_x_train = preview_x[np.argsort(class_ids)]
    return x_train, sorted_x_train
import os

import numpy as np
import tensorflow as tf
from six.moves import range
from tensorflow.contrib import layers
import zhusuan as zs

import conf
import dataset

if __name__ == "__main__":
    # FIX: `os` was used below (os.path.join) but never imported in this
    # fragment's import block; it is now imported above.
    tf.set_random_seed(1237)

    # Load MNIST; merge train and validation folds into one training set.
    data_path = os.path.join(conf.data_dir, 'mnist.pkl.gz')
    x_train, t_train, x_valid, t_valid, x_test, t_test = \
        dataset.load_mnist_realval(data_path)
    x_train = np.vstack([x_train, x_valid]).astype('float32')
    # Binarize the test set once with a fixed seed so evaluation is
    # reproducible across runs.
    np.random.seed(1234)
    x_test = np.random.binomial(1, x_test, size=x_test.shape).astype('float32')
    n_x = x_train.shape[1]

    # Define model parameters
    n_z = 40  # latent dimensionality

    # Define training/evaluation parameters
    lb_samples = 10    # presumably MC samples for the lower bound — used below this chunk
    ll_samples = 1000  # presumably MC samples for log-likelihood — used below this chunk
    epochs = 3000
    batch_size = 100
    iters = x_train.shape[0] // batch_size
    learning_rate = 0.001
def run_experiment(args):
    """Train an LBT-GAN on stacked MNIST.

    Stacked MNIST is built here by stacking 3 randomly chosen MNIST digits
    as the 3 channels of one 28x28x3 image. The model couples a DCGAN-style
    generator/discriminator with a VAE "estimator" whose parameter updates
    are unrolled (hand-written Adam) so the generator can be trained
    against the estimator's post-update ELBO (the LBT objective).

    All hyper-parameters come from ``args``; results, logs and model
    checkpoints are written to the directory created by
    ``utils.setup_logging``.
    """
    import os

    # set environment variables for tensorflow (must happen before TF import)
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    import inspect
    import shutil
    import numpy as np
    import tensorflow as tf
    from collections import OrderedDict
    import matplotlib.pyplot as plt
    plt.switch_backend('Agg')  # headless plotting

    import utils
    import paramgraphics
    import nn
    from tensorflow.contrib.framework.python.ops import arg_scope
    # import tensorflow.contrib.layers as layers

    # ----------------------------------------------------------------
    # Arguments and Settings
    args.message = 'LBT-GAN-smnist_' + args.message
    np.random.seed(12345)
    tf.set_random_seed(args.seed)

    # copy file for reproducibility
    logger, dirname = utils.setup_logging(args)
    script_fn = inspect.getfile(inspect.currentframe())
    script_src = os.path.abspath(script_fn)
    script_dst = os.path.abspath(os.path.join(dirname, script_fn))
    shutil.copyfile(script_src, script_dst)
    logger.info("script copied from %s to %s" % (script_src, script_dst))

    # print arguments
    for k, v in sorted(vars(args).items()):
        logger.info(" %20s: %s" % (k, v))

    # get arguments
    batch_size = args.batch_size
    batch_size_est = args.batch_size_est
    gen_lr = args.gen_lr
    dis_lr = args.dis_lr
    est_lr = args.est_lr
    lambda_gan = args.lambda_gan
    beta1 = 0.5
    epsilon = 1e-8
    max_iter = args.max_iter
    viz_every = args.viz_every
    z_dim, vae_z_dim = utils.get_ints(args.z_dims)
    unrolling_steps = args.unrolling_steps
    assert unrolling_steps > 0
    n_viz = args.n_viz

    # ----------------------------------------------------------------
    # Dataset: build stacked MNIST — each sample takes 3 random MNIST
    # digits as its 3 channels.
    from dataset import load_mnist_realval, DataSet
    train_x, _, test_x, _ = load_mnist_realval(validation=False,
                                               asimage=True)
    train_x = np.concatenate([train_x, test_x], 0)
    ids = np.random.randint(0, train_x.shape[0], size=(128000, 3))
    X_training = np.zeros(shape=(ids.shape[0], 28, 28, ids.shape[1]))
    for i in range(ids.shape[0]):
        for j in range(ids.shape[1]):
            X_training[i, :, :, j] = train_x[ids[i, j], :, :, 0]
    smnist = DataSet(X_training, None)
    # data_channel = 3
    x_dim = 784 * 3
    dim_input = (28, 28)
    feature_dim = 16

    # ----------------------------------------------------------------
    # Model setup
    logger.info("Setting up model ...")

    def discriminator(x, Reuse=tf.AUTO_REUSE, is_training=True):
        # DCGAN-style conv discriminator; returns raw logits.

        def leaky_relu(x, alpha=0.2):
            return tf.maximum(alpha * x, x)

        D_feature_dim = int(feature_dim * args.d_ratio)
        with tf.variable_scope("discriminator", reuse=Reuse):

            def bn_layer(x):
                # Batch norm in D is optional (args.d_bn).
                if args.d_bn is True:
                    # print("Use bn in D")
                    return tf.layers.batch_normalization(
                        x, training=is_training)
                else:
                    # print("No BN in D")
                    return x

            x = tf.reshape(x, [batch_size, 28, 28, 3])
            conv1 = tf.layers.conv2d(x, D_feature_dim, 4, 2, use_bias=True,
                                     padding='same')
            conv1 = leaky_relu(conv1)
            conv2 = tf.layers.conv2d(conv1, 2 * D_feature_dim, 4, 2,
                                     use_bias=False, padding='same')
            conv2 = bn_layer(conv2)
            conv2 = leaky_relu(conv2)
            conv2 = tf.layers.flatten(conv2)
            fc1 = tf.layers.dense(conv2, 1024, use_bias=False)
            fc1 = bn_layer(fc1)
            fc1 = leaky_relu(fc1)
            fc2 = tf.layers.dense(fc1, 1)
            return fc2

    def generator(z, Reuse=tf.AUTO_REUSE, flatten=True, is_training=True):
        # DCGAN-style transposed-conv generator; outputs sigmoid images.
        if args.g_nonlin == 'relu':
            # print("Use Relu in G")
            nonlin = tf.nn.relu
        else:
            # print("Use tanh in G")
            nonlin = tf.nn.tanh
        # nonlin = tf.nn.relu if args.g_nonlin == 'relu' else tf.nn.tanh
        # norm_prms = {'is_training': is_training, 'decay': 0.9, 'scale': False}
        with tf.variable_scope("generator", reuse=Reuse):
            # lx = layers.fully_connected(z, 1024)
            lx = tf.layers.dense(z, 1024, use_bias=False)
            lx = tf.layers.batch_normalization(lx, training=is_training)
            lx = nonlin(lx)
            lx = tf.layers.dense(lx, feature_dim * 2 * 7 * 7, use_bias=False)
            lx = tf.layers.batch_normalization(lx, training=is_training)
            lx = nonlin(lx)
            lx = tf.reshape(lx, [-1, 7, 7, feature_dim * 2])
            lx = tf.layers.conv2d_transpose(lx, feature_dim, 5, 2,
                                            use_bias=False, padding='same')
            lx = tf.layers.batch_normalization(lx, training=is_training)
            lx = nonlin(lx)
            lx = tf.layers.conv2d_transpose(lx, 3, 5, 2, padding='same')
            lx = tf.nn.sigmoid(lx)
            if flatten is True:
                lx = tf.layers.flatten(lx)
            return lx

    nonlin = tf.nn.relu

    def compute_est_samples(z, params=None, reuse=tf.AUTO_REUSE):
        # Decoder-only pass of the estimator VAE: latent z -> x logits.
        # `params` lets callers substitute (unrolled) parameter values.
        with tf.variable_scope("estimator"):
            with arg_scope([nn.dense], params=params):
                with tf.variable_scope("decoder", reuse=reuse):
                    h_dec_1 = nn.dense(z, vae_z_dim, 200 * 2, "dense1",
                                       nonlinearity=nonlin)
                    h_dec_2 = nn.dense(h_dec_1, 200 * 2, 500 * 2, "dense2",
                                       nonlinearity=nonlin)
                    x_mean = nn.dense(h_dec_2, 500 * 2, x_dim, "dense3",
                                      nonlinearity=None)
                    return x_mean

    def compute_est_ll(x, params=None, reuse=tf.AUTO_REUSE):
        # Full estimator VAE pass: returns (ELBO, reconstruction probs).
        # Encoder -> reparameterized z -> decoder; Bernoulli likelihood
        # plus analytic Gaussian KL.
        with tf.variable_scope("estimator"):
            with arg_scope([nn.dense], params=params):
                with tf.variable_scope("encoder", reuse=reuse):
                    h_enc_1 = nn.dense(x, x_dim, 500 * 2, "dense1",
                                       nonlinearity=nonlin)
                    # h_enc_1 = nn.batch_norm(h_enc_1, "bn1", 129, 2)
                    h_enc_2 = nn.dense(h_enc_1, 500 * 2, 200 * 2, "dense2",
                                       nonlinearity=nonlin)
                    # h_enc_2 = nn.batch_norm(h_enc_2, "bn2", 128, 2)
                    z_mean = nn.dense(h_enc_2, 200 * 2, vae_z_dim, "dense3",
                                      nonlinearity=None)
                    z_logvar = nn.dense(h_enc_2, 200 * 2, vae_z_dim, "dense4",
                                        nonlinearity=None)
                # Reparameterization trick.
                epsilon = tf.random_normal(tf.shape(z_mean), dtype=tf.float32)
                z = z_mean + tf.exp(0.5 * z_logvar) * epsilon
                with tf.variable_scope("decoder", reuse=reuse):
                    h_dec_1 = nn.dense(z, vae_z_dim, 200 * 2, "dense1",
                                       nonlinearity=nonlin)
                    # h_dec_1 = nn.batch_norm(h_dec_1, "bn1", 127, 2)
                    h_dec_2 = nn.dense(h_dec_1, 200 * 2, 500 * 2, "dense2",
                                       nonlinearity=nonlin)
                    # h_dec_2 = nn.batch_norm(h_dec_2, "bn2", 128, 2)
                    x_mean = nn.dense(h_dec_2, 500 * 2, x_dim, "dense3",
                                      nonlinearity=None)
                elbo = tf.reduce_mean(
                    tf.reduce_sum(-tf.nn.sigmoid_cross_entropy_with_logits(
                        logits=x_mean, labels=x), axis=1) -
                    tf.reduce_sum(-0.5 * (1 + z_logvar - tf.square(z_mean) -
                                          tf.exp(z_logvar)), axis=1))
                return elbo, tf.nn.sigmoid(x_mean)

    def compute_est_updated_with_SGD(x, lr=0.001, params=None):
        # One symbolic SGD *ascent* step on the estimator ELBO; returns
        # the updated parameter dict without touching the variables.
        elbo, _ = compute_est_ll(x, params=params)
        grads = tf.gradients(elbo, params.values())
        new_params = params.copy()
        for key, g in zip(params, grads):
            new_params[key] += lr * g
        return elbo, new_params

    def compute_est_updated_with_Adam(x, lr=0.001, beta_1=0.9, beta_2=0.999,
                                      epsilon=1e-7, decay=0., params=None,
                                      adam_params=None):
        # One symbolic Adam ascent step on the estimator ELBO. Keeps the
        # update differentiable w.r.t. the generator (note the
        # stop_gradient only on the second moment), which is what makes
        # the unrolled objective work.
        elbo, _ = compute_est_ll(x, params=params)
        grads = tf.gradients(elbo, params.values())
        new_params = params.copy()
        new_adam_params = adam_params.copy()
        new_adam_params['iterations'] += 1
        lr = lr * \
            (1. / (1. + decay * tf.cast(adam_params['iterations'],
                                        tf.float32)))
        t = tf.cast(new_adam_params['iterations'], tf.float32)
        lr_t = lr * (tf.sqrt(1. - tf.pow(beta_2, t)) /
                     (1. - tf.pow(beta_1, t)))
        for key, g in zip(params, grads):
            new_adam_params['m_' + key] = (
                beta_1 * adam_params['m_' + key]) + (1. - beta_1) * g
            new_adam_params['v_' + key] = tf.stop_gradient(
                (beta_2 * adam_params['v_' + key]) +
                (1. - beta_2) * tf.square(g))
            new_params[key] = params[key] + lr_t * new_adam_params[
                'm_' + key] / tf.sqrt(new_adam_params['v_' + key] + epsilon)
        return elbo, new_params, new_adam_params

    lr = tf.placeholder(tf.float32)
    data = tf.placeholder(tf.float32, shape=(batch_size, x_dim))

    # Construct generator and estimator nets (first call populates
    # est_params_dict via nn.dense's params side-channel).
    est_params_dict = OrderedDict()
    _, _ = compute_est_ll(data, params=est_params_dict)
    gen_noise = tf.random_normal((batch_size_est, z_dim), dtype=tf.float32)
    samples_gen = generator(gen_noise)
    vae_noise = tf.random_normal((batch_size_est, vae_z_dim),
                                 dtype=tf.float32)
    samples_est = tf.nn.sigmoid(
        compute_est_samples(z=vae_noise, params=est_params_dict))
    # for key in est_params_dict:
    #     print(key, est_params_dict[key])

    # Adam state (slot variables) for the hand-written estimator updates.
    adam_params_dict = OrderedDict()
    with tf.variable_scope("adam"):
        adam_params_dict['iterations'] = tf.Variable(0, dtype=tf.int64,
                                                     name='iterations')
        for key in est_params_dict:
            adam_params_dict['m_' + key] = tf.Variable(tf.zeros_like(
                est_params_dict[key]), name='m_' + key)
            adam_params_dict['v_' + key] = tf.Variable(tf.zeros_like(
                est_params_dict[key]), name='v_' + key)

    gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 "generator")
    est_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 "estimator")
    adam_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "adam")

    # unrolling estimator updates: chain `unrolling_steps` symbolic Adam
    # steps, each on a fresh batch of generator samples.
    cur_params = est_params_dict
    cur_adam_params = adam_params_dict
    elbo_genx_at_steps = []
    for _ in range(unrolling_steps):
        samples_gen = generator(
            tf.random_normal((batch_size_est, z_dim), dtype=tf.float32))
        elbo_genx_step, cur_params, cur_adam_params = \
            compute_est_updated_with_Adam(
                samples_gen, lr=lr, beta_1=beta1, epsilon=epsilon,
                params=cur_params, adam_params=cur_adam_params)
        elbo_genx_at_steps.append(elbo_genx_step)

    # estimator update: write the fully-unrolled parameters/Adam state
    # back into the real variables.
    updates = []
    for key in est_params_dict:
        updates.append(tf.assign(est_params_dict[key], cur_params[key]))
    for key in adam_params_dict:
        updates.append(tf.assign(adam_params_dict[key],
                                 cur_adam_params[key]))
    e_train_op = tf.group(*updates, name="e_train_op")

    # Optimize the generator on the unrolled ELBO loss (data ELBO under
    # the post-update estimator parameters).
    unrolled_elbo_data, _ = compute_est_ll(data, params=cur_params)
    # unrolled_elbo_samp, _ = compute_est_ll(
    #     tf.stop_gradient(samples_gen), params=cur_params)

    # GAN-loss for discriminator and generator (non-saturating G loss).
    samples_gen_gan = generator(
        tf.random_normal((batch_size_est, z_dim), dtype=tf.float32))
    fake_D_output = discriminator(samples_gen_gan)
    real_D_output = discriminator(data)
    # print(fake_D_output, real_D_output)
    ganloss_g = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(fake_D_output), logits=fake_D_output))
    ganloss_D_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(fake_D_output), logits=fake_D_output))
    ganloss_D_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(real_D_output), logits=real_D_output))

    # use_e_sym gates the estimator (ELBO) term in the generator loss;
    # it is fed 0. during warm-up and 1. afterwards (see training loop).
    use_e_sym = tf.placeholder(tf.float32, shape=(), name="use_E")
    if args.lbt:
        logger.info("Using lbt")
        object_g = lambda_gan * ganloss_g - use_e_sym * unrolled_elbo_data
    else:
        logger.info("Using GAN")
        object_g = lambda_gan * ganloss_g  # - use_e_sym * unrolled_elbo_data
    # object_g = -1 * unrolled_elbo_data
    object_d = ganloss_D_fake + ganloss_D_real

    dis_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 "discriminator")
    g_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, "generator")

    g_train_opt = tf.train.AdamOptimizer(learning_rate=gen_lr, beta1=beta1,
                                         epsilon=epsilon)
    # g_train_opt = tf.train.RMSPropOptimizer(learning_rate=gen_lr, epsilon=epsilon)
    g_grads = g_train_opt.compute_gradients(object_g, var_list=gen_vars)
    # g_grads_clipped = [(tf.clip_by_value(grad, -1., 1.), var) for grad, var in g_grads]
    g_grads_, g_vars_ = zip(*g_grads)
    g_grads_clipped_, g_grads_norm_ = tf.clip_by_global_norm(g_grads_, 5.)
    g_grads_clipped = zip(g_grads_clipped_, g_vars_)
    if args.clip_grad:
        logger.info("Clipping gradients of generator parameters.")
        with tf.control_dependencies(g_update_ops):
            g_train_op = g_train_opt.apply_gradients(g_grads_clipped)
    else:
        with tf.control_dependencies(g_update_ops):
            g_train_op = g_train_opt.apply_gradients(g_grads)
    # g_train_op = g_train_opt.apply_gradients(g_grads)

    d_train_opt = tf.train.AdamOptimizer(learning_rate=dis_lr, beta1=beta1,
                                         epsilon=epsilon)
    d_update_op = tf.get_collection(tf.GraphKeys.UPDATE_OPS, "discriminator")
    with tf.control_dependencies(d_update_op):
        d_train_op = d_train_opt.minimize(object_d, var_list=dis_vars)

    # ----------------------------------------------------------------
    # Training
    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(max_to_keep=None)
    if args.model_path:
        saver.restore(sess, args.model_path)

    # # print variables
    # logger.info("Generator parameters:")
    # for p in gen_vars:
    #     logger.debug("%s: %s" % (p.name, sess.run(tf.shape(p))))
    # logger.info("Estimator parameters:")
    # for p in est_vars:
    #     logger.debug("%s: %s" % (p.name, sess.run(tf.shape(p))))
    # logger.info("Adam parameters:")
    # for p in adam_vars:
    #     logger.debug("%s: %s" % (p.name, sess.run(tf.shape(p))))

    elbo_vals = []
    ganloss_vals = []
    tgan_g, tgan_d_fake, tgan_d_real = 0., 0., 0.
    elbo_genx_val, elbo_data_val, gradients_nrom = -np.inf, -np.inf, 0
    use_e_flag = 0.
    for i in range(max_iter + 1):
        # train estimator and generator
        # x_mini_batch0 = mnist.train.next_batch(batch_size)[0].reshape(
        #     [batch_size, 28, 28, 1])
        # x_mini_batch1 = mnist.train.next_batch(batch_size)[0].reshape(
        #     [batch_size, 28, 28, 1])
        # x_mini_batch2 = mnist.train.next_batch(batch_size)[0].reshape(
        #     [batch_size, 28, 28, 1])
        # x_mini_batch = np.concatenate(
        #     [x_mini_batch0, x_mini_batch1, x_mini_batch2],
        #     axis=-1).reshape([batch_size, 28 * 28 * 3])
        x_mini_batch = smnist.next_batch(batch_size)[0].reshape(
            [batch_size, 28 * 28 * 3])

        # NOTE(review): reconstructed nesting — the estimator is only
        # trained (and its ELBO term enabled) after a 3000-iteration pure
        # GAN warm-up; confirm against the original source layout.
        if i > 3000:
            use_e_flag = 1.
            for _ in range(args.n_est):
                elbo_genx_val, _ = sess.run(
                    [elbo_genx_at_steps[-1], e_train_op],
                    feed_dict={lr: 3. * est_lr})

        for _ in range(args.n_dis):
            _, tgan_g, tgan_d_real, tgan_d_fake = sess.run(
                [d_train_op, ganloss_g, ganloss_D_real, ganloss_D_fake],
                feed_dict={data: x_mini_batch})

        elbo_data_val, gradients_nrom, _ = sess.run(
            [unrolled_elbo_data, g_grads_norm_, g_train_op],
            feed_dict={
                data: x_mini_batch,
                lr: est_lr,
                use_e_sym: use_e_flag
            })
        elbo_vals.append([elbo_genx_val, elbo_data_val])
        ganloss_vals.append([tgan_g, tgan_d_real, tgan_d_fake])

        # visualization: dump generator/estimator samples and a data
        # batch as image grids, plus the running ELBO curve.
        if i % viz_every == 0:
            np_samples_gen, np_samples_est, np_data = sess.run(
                [samples_gen, samples_est, data],
                feed_dict={data: x_mini_batch})
            np_samples_est = np_samples_est.reshape(
                [-1, 28, 28, 3]).transpose(
                    [0, 3, 1, 2]).reshape([-1, 28 * 28 * 3])
            np_samples_gen = np_samples_gen.reshape(
                [-1, 28, 28, 3]).transpose(
                    [0, 3, 1, 2]).reshape([-1, 28 * 28 * 3])
            np_data = np_data.reshape([-1, 28, 28, 3]).transpose(
                [0, 3, 1, 2]).reshape([-1, 28 * 28 * 3])
            paramgraphics.mat_to_img(np_samples_gen[:n_viz], dim_input,
                                     colorImg=True,
                                     save_path=os.path.join(
                                         dirname,
                                         'sample_' + str(i) + '_gen.png'))
            paramgraphics.mat_to_img(np_data[:n_viz], dim_input,
                                     colorImg=True,
                                     save_path=os.path.join(
                                         dirname,
                                         'sample_' + str(i) + '_dat.png'))
            paramgraphics.mat_to_img(np_samples_est[:n_viz], dim_input,
                                     colorImg=True,
                                     save_path=os.path.join(
                                         dirname,
                                         'sample_' + str(i) + '_est.png'))

            fig = plt.figure(figsize=(6, 4))
            plt.plot(elbo_vals, '.', markersize=2, markeredgecolor='none',
                     linestyle='none',
                     alpha=min(1.0, 0.01 * max_iter / (i + 1)))
            plt.ylim((-200.0, 0.0))
            legend = plt.legend(('elbo_genx', 'elbo_data'), markerscale=6)
            for lh in legend.legendHandles:
                lh._legmarker.set_alpha(1.)
            plt.grid(True)
            plt.tight_layout()
            plt.savefig(os.path.join(dirname, 'curve.png'),
                        bbox_inches='tight')
            plt.close(fig)

        # training log (moving averages over the last 200 iterations)
        if i % viz_every == 0:
            elbo_genx_ma_val, elbo_data_ma_val = np.mean(elbo_vals[-200:],
                                                         axis=0)
            logger.info(
                "Iter %d: gradients norm = %.4f. samples LL = %.4f, data LL = %.4f." %
                (i, gradients_nrom, elbo_genx_ma_val, elbo_data_ma_val))
            logger.info(
                "Iter %d: gan_g = %.4f. gan_d_real = %.4f, gan_d_fake = %.4f." %
                (i, tgan_g, tgan_d_real, tgan_d_fake))

        if i % args.model_every == 0:
            saver.save(sess, os.path.join(dirname, 'model_' + str(i)))
def print_param(param_name):
    """Fetch a graph value by name through the module-level ``sess``,
    normalize it to sum to one, and print it.

    NOTE(review): values whose name contains 'log' are assumed to be
    stored in log-space and are exponentiated before normalizing —
    confirm this naming convention against the model definition.
    """
    val = sess.run(param_name)
    if 'log' in param_name:
        val = np.exp(val)
    val_normalized = val / np.sum(val)
    print('{}: {}'.format(param_name, val_normalized))


if __name__ == "__main__":
    # Fix seeds for reproducibility.
    tf.set_random_seed(666)
    np.random.seed(666)

    # Load data from MNIST
    data_dir = './data'
    data_path = os.path.join(data_dir, 'mnist.pkl.gz')
    x_train, t_train, x_val, t_val, x_test, t_test = dataset.load_mnist_realval(
        data_path)
    # Merge train and validation folds into one training set.
    x_train = np.vstack([x_train, x_val]).astype('float32')
    n_x = x_train.shape[1]  # 784=28*28

    # Define model parameters
    n_h = 40  # D
    n_z = 10  # K

    # Define training/evaluation parameters
    lb_samples = 10  # presumably MC samples for the lower bound — used below this chunk
    epoches = 100
    batch_size = 100
    iters = x_train.shape[0] // batch_size
    learning_rate = 0.001
    save_freq = 20
    ckpt_path = "./ckpt/10x10_2"
def main():
    """Train a VAE on dynamically binarized MNIST with ZhuSuan.

    Uses the module-level ``build_gen``/``build_q_net`` model builders,
    trains with SGVB, evaluates with the ELBO and an importance-sampling
    log-likelihood estimate, and periodically saves generated samples.
    """
    # Load MNIST; merge train and validation folds.
    data_path = os.path.join(conf.data_dir, "mnist.pkl.gz")
    x_train, t_train, x_valid, t_valid, x_test, t_test = \
        dataset.load_mnist_realval(data_path)
    x_train = np.vstack([x_train, x_valid])
    # The test set is binarized once up front (fixed evaluation set).
    x_test = np.random.binomial(1, x_test, size=x_test.shape)
    x_dim = x_train.shape[1]

    # Define model parameters
    z_dim = 40

    # Build the computation graph
    n_particles = tf.placeholder(tf.int32, shape=[], name="n_particles")
    x_input = tf.placeholder(tf.float32, shape=[None, x_dim], name="x")
    # Training data is re-binarized on the fly each run of the graph.
    x = tf.cast(tf.less(tf.random_uniform(tf.shape(x_input)), x_input),
                tf.int32)
    n = tf.placeholder(tf.int32, shape=[], name="n")

    model = build_gen(x_dim, z_dim, n, n_particles)
    variational = build_q_net(x, z_dim, n_particles)

    lower_bound = zs.variational.elbo(model, {"x": x},
                                      variational=variational, axis=0)
    cost = tf.reduce_mean(lower_bound.sgvb())
    lower_bound = tf.reduce_mean(lower_bound)

    # # Importance sampling estimates of marginal log likelihood
    is_log_likelihood = tf.reduce_mean(
        zs.is_loglikelihood(model, {"x": x}, proposal=variational, axis=0))

    optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
    infer_op = optimizer.minimize(cost)

    # Random generation
    x_gen = tf.reshape(model.observe()["x_mean"], [-1, 28, 28, 1])

    # Define training/evaluation parameters
    epochs = 3000
    batch_size = 128
    iters = x_train.shape[0] // batch_size
    save_freq = 10
    test_freq = 10
    test_batch_size = 400
    test_iters = x_test.shape[0] // test_batch_size
    result_path = "results/vae"
    if not os.path.exists(result_path):
        os.makedirs(result_path)

    # Run the inference
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(1, epochs + 1):
            time_epoch = -time.time()
            np.random.shuffle(x_train)
            lbs = []
            for t in range(iters):
                x_batch = x_train[t * batch_size:(t + 1) * batch_size]
                _, lb = sess.run([infer_op, lower_bound],
                                 feed_dict={
                                     x_input: x_batch,
                                     n_particles: 1,
                                     n: batch_size
                                 })
                lbs.append(lb)
            time_epoch += time.time()
            print("Epoch {} ({:.1f}s): Lower bound = {}".format(
                epoch, time_epoch, np.mean(lbs)))

            if epoch % test_freq == 0:
                time_test = -time.time()
                test_lbs, test_lls = [], []
                for t in range(test_iters):
                    test_x_batch = x_test[t * test_batch_size:(t + 1) *
                                          test_batch_size]
                    # Feed the pre-binarized test batch directly into the
                    # derived tensor `x`, bypassing the random binarization.
                    test_lb = sess.run(lower_bound,
                                       feed_dict={
                                           x: test_x_batch,
                                           n_particles: 1,
                                           n: test_batch_size
                                       })
                    test_ll = sess.run(is_log_likelihood,
                                       feed_dict={
                                           x: test_x_batch,
                                           n_particles: 1000,
                                           n: test_batch_size
                                       })
                    test_lbs.append(test_lb)
                    test_lls.append(test_ll)
                time_test += time.time()
                print(">>> TEST ({:.1f}s)".format(time_test))
                print(">> Test lower bound = {}".format(np.mean(test_lbs)))
                print('>> Test log likelihood (IS) = {}'.format(
                    np.mean(test_lls)))

            if epoch % save_freq == 0:
                images = sess.run(x_gen, feed_dict={n: 100, n_particles: 1})
                name = os.path.join(result_path,
                                    "vae.epoch.{}.png".format(epoch))
                save_image_collections(images, name)
def train_vae(args):
    """Train a label-conditional VAE on binarized MNIST.

    Hyper-parameters and paths come from ``args``; the model graph is
    built by the module-level ``build_gen``/``build_q_net`` helpers.
    Checkpoints are written to ``args.checkpoints_path`` and generated
    digit grids to ``args.results_path``.

    Fixes over the previous version:
      * training now resumes from ``begin_epoch`` after restoring a
        checkpoint instead of always restarting at epoch 1 (which re-ran
        and re-saved already-completed epochs);
      * removed the dead local ``save_model_freq`` (the loop reads
        ``args.save_model_freq``).
    """
    # Load MNIST and binarize training images by Bernoulli sampling.
    data_path = os.path.join(args.data_dir, "mnist.pkl.gz")
    x_train, y_train, x_valid, y_valid, x_test, y_test = \
        dataset.load_mnist_realval(data_path)
    x_train = np.random.binomial(1, x_train, size=x_train.shape)
    x_dim = x_train.shape[1]
    y_dim = y_train.shape[1]

    # Define model parameters
    z_dim = args.z_dim

    # Build the computation graph
    x = tf.placeholder(tf.float32, shape=[None, x_dim], name="x")
    y = tf.placeholder(tf.float32, shape=[None, y_dim], name="y")
    n = tf.placeholder(tf.int32, shape=[], name="n")

    # Get the models
    model = build_gen(y, x_dim, z_dim, n)
    variational = build_q_net(x, y, z_dim)

    # Calculate ELBO and the SGVB surrogate cost.
    lower_bound = zs.variational.elbo(model, {"x": x},
                                      variational=variational)
    cost = tf.reduce_mean(lower_bound.sgvb())
    lower_bound = tf.reduce_mean(lower_bound)

    optimizer = tf.train.AdamOptimizer(learning_rate=args.lr)
    infer_op = optimizer.minimize(cost)

    # Random generation (conditioned on the fed labels `y`).
    x_gen = tf.reshape(model.observe()["x_mean"], [-1, 28, 28, 1])

    # One-hot label batches: 100 copies of each digit class 0..9, used to
    # generate class-conditional sample grids.
    labels = []
    for c in range(10):
        l = np.zeros((100, 10))
        l[:, c] = 1
        labels.append(l)

    epochs = args.epochs
    batch_size = args.batch_size
    iters = x_train.shape[0] // batch_size

    saver = tf.train.Saver(max_to_keep=10)

    # Run the inference
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Resume from the latest checkpoint, if one exists.
        ckpt_file = tf.train.latest_checkpoint(args.checkpoints_path)
        begin_epoch = 1
        if ckpt_file is not None:
            print('Restoring model from {}...'.format(ckpt_file))
            # Checkpoint names look like "vae.epoch.<N>.ckpt".
            begin_epoch = int(ckpt_file.split('.')[-2]) + 1
            saver.restore(sess, ckpt_file)

        # FIX: start from begin_epoch (was hard-coded to 1, ignoring the
        # restored checkpoint).
        for epoch in range(begin_epoch, epochs + 1):
            time_epoch = -time.time()
            lbs = []
            for t in range(iters):
                x_batch = x_train[t * batch_size:(t + 1) * batch_size]
                y_batch = y_train[t * batch_size:(t + 1) * batch_size]
                _, lb = sess.run(
                    [infer_op, lower_bound],
                    feed_dict={
                        x: x_batch,
                        y: y_batch,
                        n: batch_size
                    }
                )
                lbs.append(lb)
            time_epoch += time.time()
            print("Epoch {} ({:.1f}s): Lower bound = {}".format(
                epoch, time_epoch, np.mean(lbs)))

            if epoch % args.save_model_freq == 0:
                save_path = os.path.join(
                    args.checkpoints_path,
                    "vae.epoch.{}.ckpt".format(epoch))
                if not os.path.exists(os.path.dirname(save_path)):
                    os.makedirs(os.path.dirname(save_path))
                saver.save(sess, save_path)

            if epoch % args.save_img_freq == 0:
                # One grid of 100 samples per digit class.
                for c in range(10):
                    images = sess.run(x_gen,
                                      feed_dict={y: labels[c], n: 100})
                    name = os.path.join(args.results_path,
                                        str(epoch).zfill(3),
                                        "{}.png".format(c))
                    utils.save_image_collections(images, name)