Example #1
0
def infogan(cat_dim=10, cont_dim=2, noise_dims=64,
            generator_lr=0.001, discriminator_lr=0.00009):
    """Build, sanity-check and train an InfoGAN model with TF-GAN.

    Steps:
      1. Load the real images via ``prepare_data()``.
      2. Build the InfoGAN model with ``tfgan.infogan_model``. The noise sizes
         are bound into the generator and discriminator builders and passed to
         the noise sampler.
      3. Compute the loss with ``tfgan.gan_loss``, using a gradient penalty
         (weight 1.0) and a mutual information penalty (weight 1.0).
      4. Evaluate the loss once with ``evaluate_tfgan_loss`` as a sanity check.
      5. Train with Adam (``beta1=0.5``) until ``FLAGS.max_number_of_steps``,
         using ``logdir=FLAGS.train_log_dir``.

    The noise batch size comes from the module-level ``batch_size``.

    Args:
        cat_dim: Categorical (structured) code size. Passed as
            ``categorical_dim`` to both networks and to the noise sampler.
        cont_dim: Number of continuous (structured) codes. Passed as
            ``continuous_dim`` to the discriminator and to the noise sampler.
        noise_dims: Size of the unstructured noise input.
        generator_lr: Adam learning rate for the generator.
        discriminator_lr: Adam learning rate for the discriminator.

    Returns:
        None. The function is run for its training side effects.
    """
    real_images = prepare_data()

    # The same structured/unstructured noise sizes are passed to the
    # generator, the discriminator and the noise sampler so they stay in sync.
    generator_fn = functools.partial(infogan_generator,
                                     categorical_dim=cat_dim)
    discriminator_fn = functools.partial(infogan_discriminator,
                                         categorical_dim=cat_dim,
                                         continuous_dim=cont_dim)
    unstructured_inputs, structured_inputs = util.get_infogan_noise(
        batch_size, cat_dim, cont_dim, noise_dims)

    infogan_model = tfgan.infogan_model(
        generator_fn=generator_fn,
        discriminator_fn=discriminator_fn,
        real_data=real_images,
        unstructured_generator_inputs=unstructured_inputs,
        structured_generator_inputs=structured_inputs)

    # The mutual information penalty is what distinguishes the InfoGAN loss
    # from the plain GAN loss: it ties the structured codes to the generated
    # samples.
    infogan_loss = tfgan.gan_loss(
        infogan_model,
        gradient_penalty_weight=1.0,
        mutual_information_penalty_weight=1.0)

    # Sanity check that the losses can be evaluated before starting training.
    evaluate_tfgan_loss(infogan_loss)

    # Train ops.
    generator_optimizer = tf.train.AdamOptimizer(generator_lr, beta1=0.5)
    discriminator_optimizer = tf.train.AdamOptimizer(discriminator_lr,
                                                     beta1=0.5)
    gan_train_ops = tfgan.gan_train_ops(infogan_model, infogan_loss,
                                        generator_optimizer,
                                        discriminator_optimizer)

    # Train until the configured step limit.
    tfgan.gan_train(
        gan_train_ops,
        hooks=[tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps)],
        logdir=FLAGS.train_log_dir)

# ### InfoGANModel Tuple
#
# The InfoGAN model carries extra information (e.g. the structured generator
# inputs), so TF-GAN represents it with a subclassed tuple, `InfoGANModel`.

# In[14]:

# Noise sizes. The structured part is a categorical code with `cat_dim`
# classes plus `cont_dim` continuous codes. The unstructured part has
# `noise_dims` values.
cat_dim = 10
cont_dim = 2
noise_dims = 64

# Pre-bind the noise sizes into the network builders.
discriminator_fn = functools.partial(
    infogan_discriminator,
    categorical_dim=cat_dim,
    continuous_dim=cont_dim,
)
generator_fn = functools.partial(
    infogan_generator,
    categorical_dim=cat_dim,
)

# Draw the generator inputs, sized by the module-level `batch_size`.
unstructured_inputs, structured_inputs = util.get_infogan_noise(
    batch_size, cat_dim, cont_dim, noise_dims)

# Assemble the InfoGAN model tuple from the builders, real data and noise.
infogan_model = tfgan.infogan_model(
    generator_fn=generator_fn,
    discriminator_fn=discriminator_fn,
    real_data=real_images,
    unstructured_generator_inputs=unstructured_inputs,
    structured_generator_inputs=structured_inputs,
)

# <a id='infogan_loss'></a>
# ## Losses
#
# The loss is the same as before, plus a mutual information penalty term.

# In[17]: