def infogan(
    *,
    cat_dim=10,
    cont_dim=2,
    noise_dims=64,
    gradient_penalty_weight=1.0,
    mutual_information_penalty_weight=1.0,
    generator_learning_rate=0.001,
    discriminator_learning_rate=0.00009,
    beta1=0.5,
    num_steps=None,
    logdir=None,
):
    """Build, sanity-check and train an InfoGAN on ``prepare_data()`` images.

    Every keyword argument defaults to the value that was previously
    hard-coded, so a bare ``infogan()`` call behaves exactly as before.

    Args:
        cat_dim: Categorical structured-noise dimension; passed to both the
            generator and discriminator functions and to the noise sampler.
        cont_dim: Continuous structured-noise dimension; passed to the
            discriminator function and to the noise sampler.
        noise_dims: Unstructured-noise dimension passed to the noise sampler.
        gradient_penalty_weight: Weight of the gradient penalty term given to
            ``tfgan.gan_loss``.
        mutual_information_penalty_weight: Weight of the InfoGAN mutual
            information penalty given to ``tfgan.gan_loss``.
        generator_learning_rate: Adam learning rate for the generator.
        discriminator_learning_rate: Adam learning rate for the discriminator.
        beta1: Adam ``beta1`` used by both optimizers.
        num_steps: Step count for ``StopAtStepHook``; ``None`` means
            ``FLAGS.max_number_of_steps``.
        logdir: Training log directory; ``None`` means ``FLAGS.train_log_dir``.

    Returns:
        None. Training runs for its side effects (checkpoints/summaries under
        ``logdir`` as written by ``tfgan.gan_train``).
    """
    real_images = prepare_data()

    # The generator only needs the categorical size; the discriminator is
    # given both structured sizes.
    generator_fn = functools.partial(infogan_generator, categorical_dim=cat_dim)
    discriminator_fn = functools.partial(
        infogan_discriminator,
        categorical_dim=cat_dim,
        continuous_dim=cont_dim,
    )

    # NOTE(review): `batch_size` is not defined in this function; it is
    # resolved from the enclosing module scope — confirm it is set before
    # this is called.
    unstructured_inputs, structured_inputs = util.get_infogan_noise(
        batch_size, cat_dim, cont_dim, noise_dims)

    infogan_model = tfgan.infogan_model(
        generator_fn=generator_fn,
        discriminator_fn=discriminator_fn,
        real_data=real_images,
        unstructured_generator_inputs=unstructured_inputs,
        structured_generator_inputs=structured_inputs,
    )

    # Standard GAN loss plus the InfoGAN mutual information penalty.
    infogan_loss = tfgan.gan_loss(
        infogan_model,
        gradient_penalty_weight=gradient_penalty_weight,
        mutual_information_penalty_weight=mutual_information_penalty_weight,
    )

    # Sanity check that the losses can be evaluated before training starts.
    evaluate_tfgan_loss(infogan_loss)

    generator_optimizer = tf.train.AdamOptimizer(
        generator_learning_rate, beta1=beta1)
    discriminator_optimizer = tf.train.AdamOptimizer(
        discriminator_learning_rate, beta1=beta1)
    gan_train_ops = tfgan.gan_train_ops(
        infogan_model, infogan_loss, generator_optimizer,
        discriminator_optimizer)

    # FLAGS is only consulted here, at the same point the original read it.
    if num_steps is None:
        num_steps = FLAGS.max_number_of_steps
    if logdir is None:
        logdir = FLAGS.train_log_dir

    tfgan.gan_train(
        gan_train_ops,
        hooks=[tf.train.StopAtStepHook(num_steps=num_steps)],
        logdir=logdir,
    )
# ### InfoGANModel Tuple
#
# An InfoGAN model carries more information than a plain GAN model (for
# example, the structured generator inputs), so TF-GAN represents it with a
# subclassed tuple.

# In[14]:

# Sizes of the structured (categorical and continuous) and unstructured noise.
cat_dim = 10
cont_dim = 2
noise_dims = 64

# The discriminator is told both structured sizes; the generator only needs
# the categorical one.
discriminator_fn = functools.partial(
    infogan_discriminator,
    continuous_dim=cont_dim,
    categorical_dim=cat_dim,
)
generator_fn = functools.partial(
    infogan_generator,
    categorical_dim=cat_dim,
)

unstructured_inputs, structured_inputs = util.get_infogan_noise(
    batch_size,
    cat_dim,
    cont_dim,
    noise_dims,
)

infogan_model = tfgan.infogan_model(
    generator_fn=generator_fn,
    discriminator_fn=discriminator_fn,
    real_data=real_images,
    unstructured_generator_inputs=unstructured_inputs,
    structured_generator_inputs=structured_inputs,
)

# <a id='infogan_loss'></a>
# ## Losses
#
# The loss is the same as before, with the mutual information loss added.

# In[17]: