# NOTE(review): the line below is a whitespace-collapsed fragment of a BiGAN training script. It is cut off inside the periodic sampling branch of the training loop (x_batch / z_batch are never assigned here), and because it starts with '#', Python treats the entire line as a comment. It presumably is a module-level variant of main() that uses M, d, hidden_units and DATA_DIR instead of FLAGS -- TODO confirm and either restore or delete it.
# DATA. MNIST batches are fed at training time. mnist = input_data.read_data_sets(DATA_DIR, one_hot=True) x_ph = tf.placeholder(tf.float32, [M, 784]) z_ph = tf.placeholder(tf.float32, [M, d]) # MODEL with tf.variable_scope("Gen"): xf = gen_data(z_ph, hidden_units) zf = gen_latent(x_ph, hidden_units) # INFERENCE: optimizer = tf.train.AdamOptimizer() optimizer_d = tf.train.AdamOptimizer() inference = ed.BiGANInference( latent_vars={zf: z_ph}, data={xf: x_ph}, discriminator=discriminative_network) inference.initialize( optimizer=optimizer, optimizer_d=optimizer_d, n_iter=100000, n_print=3000) sess = ed.get_session() init_op = tf.global_variables_initializer() sess.run(init_op) idx = np.random.randint(M, size=16) i = 0 for t in range(inference.n_iter): if t % inference.n_print == 1: samples = sess.run(xf, feed_dict={z_ph: z_batch})
def _save_image_grid(images, label, index):
    """Plot ``images`` and save the figure under ``FLAGS.out_dir``.

    The file is named ``<label><index>.png`` with ``index`` zero-padded to
    three digits (e.g. ``Generated007.png``).

    Args:
        images: Batch of images, passed unchanged to ``plot``.
        label: File name prefix, e.g. ``'Generated'``.
        index: Snapshot counter used as the numeric suffix.
    """
    fig = plot(images)
    # Format only the file name: formatting the joined path would treat any
    # '{' or '}' inside FLAGS.out_dir as a replacement field.
    filename = '{}{}.png'.format(label, str(index).zfill(3))
    plt.savefig(os.path.join(FLAGS.out_dir, filename), bbox_inches='tight')
    plt.close(fig)


def main(_):
    """Train a BiGAN on MNIST and periodically save image snapshots.

    Builds the generator (latent -> data) and the encoder (data -> latent)
    inside the "Gen" variable scope, runs ``ed.BiGANInference`` with one
    Adam optimizer for the generator side and one for the discriminator,
    and on every snapshot step writes three PNG files to ``FLAGS.out_dir``:
    generated samples, the corresponding training images, and their
    reconstructions.

    Args:
        _: Unused positional argument (presumably argv supplied by the
            app runner -- confirm against the caller).
    """
    ed.set_seed(42)

    # DATA. MNIST batches are fed at training time.
    # Only the training images are used; labels and the test split are
    # discarded.
    (x_train, _), _ = mnist(FLAGS.data_dir)
    x_train_generator = generator(x_train, FLAGS.M)
    x_ph = tf.placeholder(tf.float32, [FLAGS.M, 784])
    z_ph = tf.placeholder(tf.float32, [FLAGS.M, FLAGS.d])

    # MODEL
    with tf.variable_scope("Gen"):
        xf = gen_data(z_ph, FLAGS.hidden_units)
        zf = gen_latent(x_ph, FLAGS.hidden_units)

    # INFERENCE:
    optimizer = tf.train.AdamOptimizer()
    optimizer_d = tf.train.AdamOptimizer()
    inference = ed.BiGANInference(
        latent_vars={zf: z_ph},
        data={xf: x_ph},
        discriminator=discriminative_network)
    inference.initialize(
        optimizer=optimizer,
        optimizer_d=optimizer_d,
        n_iter=100000,
        n_print=3000)

    sess = ed.get_session()
    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    # Fixed random subset of 16 batch positions shown in every snapshot.
    idx = np.random.randint(FLAGS.M, size=16)
    snapshot = 0
    for t in range(inference.n_iter):
        # The remainder is compared with 1 (not 0) so that the branch is
        # never taken at t == 0: x_batch and z_batch are first bound at the
        # bottom of the loop body, and the snapshot reuses the batches of
        # the previous iteration.
        if t % inference.n_print == 1:
            samples = sess.run(xf, feed_dict={z_ph: z_batch})
            _save_image_grid(samples[idx], 'Generated', snapshot)

            _save_image_grid(x_batch[idx], 'Base', snapshot)

            # Encode the real batch, then decode the codes again.
            zsam = sess.run(zf, feed_dict={x_ph: x_batch})
            reconstructions = sess.run(xf, feed_dict={z_ph: zsam})
            _save_image_grid(reconstructions[idx], 'Reconstruct', snapshot)

            snapshot += 1

        x_batch = next(x_train_generator)
        z_batch = np.random.normal(0, 1, [FLAGS.M, FLAGS.d])

        info_dict = inference.update(feed_dict={x_ph: x_batch, z_ph: z_batch})
        inference.print_progress(info_dict)