Example #1
0
# NOTE(review): this excerpt is truncated — the training-loop body is cut off
# after the first sess.run, and `z_batch` (used below) is never assigned here.
# Names such as input_data, DATA_DIR, M, d, hidden_units, gen_data, gen_latent
# and discriminative_network are presumably defined earlier in the full script.

# DATA. MNIST batches are fed at training time.
mnist = input_data.read_data_sets(DATA_DIR, one_hot=True)
# Placeholders for a minibatch of flattened 28x28 images and latent codes.
x_ph = tf.placeholder(tf.float32, [M, 784])
z_ph = tf.placeholder(tf.float32, [M, d])

# MODEL
# Generator x = G(z) and encoder z = E(x) live in one shared variable scope.
with tf.variable_scope("Gen"):
  xf = gen_data(z_ph, hidden_units)
  zf = gen_latent(x_ph, hidden_units)

# INFERENCE:
# Separate Adam optimizers for the generator/encoder pair and the discriminator.
optimizer = tf.train.AdamOptimizer()
optimizer_d = tf.train.AdamOptimizer()
inference = ed.BiGANInference(
    latent_vars={zf: z_ph}, data={xf: x_ph},
    discriminator=discriminative_network)

inference.initialize(
    optimizer=optimizer, optimizer_d=optimizer_d, n_iter=100000, n_print=3000)

sess = ed.get_session()
init_op = tf.global_variables_initializer()
sess.run(init_op)

# Fixed indices of 16 examples used for periodic visualization; `i` counts
# how many snapshots have been written so far.
idx = np.random.randint(M, size=16)
i = 0
for t in range(inference.n_iter):
  # `== 1` (not 0) skips t == 0 — batches are presumably drawn later in the
  # loop body, which is missing from this excerpt. TODO confirm against the
  # complete example (see Example #2 below for the full loop).
  if t % inference.n_print == 1:

    # Generate images from the latest latent batch. `z_batch` is assigned in
    # the truncated part of the loop body.
    samples = sess.run(xf, feed_dict={z_ph: z_batch})
Example #2
0
def main(_):
    """Train a BiGAN on MNIST, periodically saving generated, real, and
    reconstructed image grids to FLAGS.out_dir."""
    ed.set_seed(42)

    # DATA. MNIST minibatches of size FLAGS.M are drawn during training.
    (x_train, _), (x_test, _) = mnist(FLAGS.data_dir)
    x_train_generator = generator(x_train, FLAGS.M)
    x_ph = tf.placeholder(tf.float32, [FLAGS.M, 784])
    z_ph = tf.placeholder(tf.float32, [FLAGS.M, FLAGS.d])

    # MODEL: generator x = G(z) and encoder z = E(x) share one variable scope.
    with tf.variable_scope("Gen"):
        xf = gen_data(z_ph, FLAGS.hidden_units)
        zf = gen_latent(x_ph, FLAGS.hidden_units)

    # INFERENCE: adversarial training of (G, E) against the discriminator,
    # with separate Adam optimizers for each side.
    optimizer = tf.train.AdamOptimizer()
    optimizer_d = tf.train.AdamOptimizer()
    inference = ed.BiGANInference(latent_vars={zf: z_ph},
                                  data={xf: x_ph},
                                  discriminator=discriminative_network)
    inference.initialize(optimizer=optimizer,
                         optimizer_d=optimizer_d,
                         n_iter=100000,
                         n_print=3000)

    sess = ed.get_session()
    sess.run(tf.global_variables_initializer())

    def _save_grid(images, label, counter):
        # Render one 16-image grid and write it as e.g. "Generated007.png".
        fig = plot(images)
        plt.savefig(os.path.join(FLAGS.out_dir, '{}{}.png').format(
            label,
            str(counter).zfill(3)),
                    bbox_inches='tight')
        plt.close(fig)

    # Fixed sample indices for visualization; i counts snapshots written.
    idx = np.random.randint(FLAGS.M, size=16)
    i = 0
    for t in range(inference.n_iter):
        # `== 1` (not 0) deliberately skips t == 0: x_batch / z_batch are only
        # bound at the bottom of the loop, so snapshots use the previous
        # iteration's batches.
        if t % inference.n_print == 1:
            samples = sess.run(xf, feed_dict={z_ph: z_batch})
            _save_grid(samples[idx, ], 'Generated', i)
            _save_grid(x_batch[idx, ], 'Base', i)

            # Reconstruction: encode real images, then decode the codes.
            zsam = sess.run(zf, feed_dict={x_ph: x_batch})
            reconstructions = sess.run(xf, feed_dict={z_ph: zsam})
            _save_grid(reconstructions[idx, ], 'Reconstruct', i)

            i += 1

        x_batch = next(x_train_generator)
        z_batch = np.random.normal(0, 1, [FLAGS.M, FLAGS.d])

        info_dict = inference.update(feed_dict={x_ph: x_batch, z_ph: z_batch})
        inference.print_progress(info_dict)