def test_provide_data(self, mock_tfds):
  """Checks shape and value range of one provided batch of MNIST data."""
  num_examples = 5
  mock_tfds.load.return_value = self.mock_ds

  image_batch, label_batch = data_provider.provide_data('test', num_examples)

  with self.cached_session() as sess:
    # Lookup tables created by the provider must be initialized before use.
    sess.run(tf.compat.v1.tables_initializer())
    image_vals, label_vals = sess.run([image_batch, label_batch])
    # Images: NHWC batch of 28x28 grayscale, scaled into [-1, 1].
    self.assertTupleEqual(image_vals.shape, (num_examples, 28, 28, 1))
    self.assertTrue(np.all(np.abs(image_vals) <= 1))
    # Labels: one-hot over the 10 MNIST classes.
    self.assertTupleEqual(label_vals.shape, (num_examples, 10))
def test_provide_data_can_be_reinitialized(self, mock_tfds):
  """Verifies the provided tensors can be evaluated from fresh sessions.

  Args:
    mock_tfds: Mock of the `tensorflow_datasets` module.
  """
  # Graph-mode `Session`s are unavailable when TF executes eagerly, so
  # `sess.run` on these tensors cannot work; skip in that case. The same
  # guard appears in the other reinitialization test in this codebase.
  if tf.executing_eagerly():
    return

  batch_size = 5
  mock_tfds.load.return_value = self.mock_ds

  images, labels = data_provider.provide_data('test', batch_size)

  # Evaluate twice within one session, then again in a brand-new session,
  # to confirm the underlying dataset/iterator reinitializes cleanly.
  with self.session() as sess:
    sess.run([images, labels])
    sess.run([images, labels])

  with self.session() as sess:
    sess.run([images, labels])
    sess.run([images, labels])
def evaluate(hparams, run_eval_loop=True): """Runs an evaluation loop. Args: hparams: An HParams instance containing the eval hyperparameters. run_eval_loop: Whether to run the full eval loop. Set to False for testing. """ # Fetch real images. with tf.compat.v1.name_scope('inputs'): real_images, _ = data_provider.provide_data( 'train', hparams.num_images_generated, hparams.dataset_dir) image_write_ops = None if hparams.eval_real_images: tf.compat.v1.summary.scalar( 'MNIST_Classifier_score', util.mnist_score(real_images, hparams.classifier_filename)) else: # In order for variables to load, use the same variable scope as in the # train job. with tf.compat.v1.variable_scope('Generator'): images = networks.unconditional_generator(tf.random.normal( [hparams.num_images_generated, hparams.noise_dims]), is_training=False) tf.compat.v1.summary.scalar( 'MNIST_Frechet_distance', util.mnist_frechet_distance(real_images, images, hparams.classifier_filename)) tf.compat.v1.summary.scalar( 'MNIST_Classifier_score', util.mnist_score(images, hparams.classifier_filename)) if hparams.num_images_generated >= 100 and hparams.write_to_disk: reshaped_images = tfgan.eval.image_reshaper(images[:100, ...], num_cols=10) uint8_images = data_provider.float_image_to_uint8(reshaped_images) image_write_ops = tf.io.write_file( '%s/%s' % (hparams.eval_dir, 'unconditional_gan.png'), tf.image.encode_png(uint8_images[0])) # For unit testing, use `run_eval_loop=False`. if not run_eval_loop: return evaluation.evaluate_repeatedly( hparams.checkpoint_dir, hooks=[ evaluation.SummaryAtEndHook(hparams.eval_dir), evaluation.StopAfterNEvalsHook(1) ], eval_ops=image_write_ops, max_number_of_evaluations=hparams.max_number_of_evaluations)
def test_provide_data_can_be_reinitialized(self, mock_tfds):
  """Runs the provided tensors twice each in two independent sessions."""
  if tf.executing_eagerly():
    # Graph-mode sessions from self.session() cannot run these tensors
    # under eager execution, so the reinitialization check does not apply.
    return

  num_examples = 5
  mock_tfds.load.return_value = self.mock_ds
  image_batch, label_batch = data_provider.provide_data('test', num_examples)
  fetches = [image_batch, label_batch]

  # Two separate sessions, two evaluations per session: the dataset must be
  # consumable again after a session is torn down.
  for _ in range(2):
    with self.session() as sess:
      sess.run(fetches)
      sess.run(fetches)
def train_input_fn():
  """Returns a (noise, real-images) pair for GAN estimator training."""
  generator_noise = tf.random.normal([batch_size, noise_dims])
  real_images, _ = data_provider.provide_data(
      'train', batch_size, num_parallel_calls=num_parallel_calls)
  return generator_noise, real_images
def train(hparams):
  """Trains an MNIST GAN.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.

  Raises:
    ValueError: If `hparams.gan_type` is not 'unconditional', 'conditional',
      or 'infogan'.
  """
  if not tf.io.gfile.exists(hparams.train_log_dir):
    tf.io.gfile.makedirs(hparams.train_log_dir)

  # Force all input processing onto CPU in order to reserve the GPU for
  # the forward inference and back-propagation.
  with tf.name_scope('inputs'), tf.device('/cpu:0'):
    images, one_hot_labels = data_provider.provide_data(
        'train', hparams.batch_size, num_parallel_calls=4)

  # Define the GANModel tuple. Optionally, condition the GAN on the label or
  # use an InfoGAN to learn a latent representation.
  if hparams.gan_type == 'unconditional':
    gan_model = tfgan.gan_model(
        generator_fn=networks.unconditional_generator,
        discriminator_fn=networks.unconditional_discriminator,
        real_data=images,
        generator_inputs=tf.random.normal(
            [hparams.batch_size, hparams.noise_dims]))
  elif hparams.gan_type == 'conditional':
    noise = tf.random.normal([hparams.batch_size, hparams.noise_dims])
    gan_model = tfgan.gan_model(
        generator_fn=networks.conditional_generator,
        discriminator_fn=networks.conditional_discriminator,
        real_data=images,
        generator_inputs=(noise, one_hot_labels))
  elif hparams.gan_type == 'infogan':
    cat_dim, cont_dim = 10, 2
    generator_fn = functools.partial(
        networks.infogan_generator, categorical_dim=cat_dim)
    discriminator_fn = functools.partial(
        networks.infogan_discriminator,
        categorical_dim=cat_dim,
        continuous_dim=cont_dim)
    unstructured_inputs, structured_inputs = util.get_infogan_noise(
        hparams.batch_size, cat_dim, cont_dim, hparams.noise_dims)
    gan_model = tfgan.infogan_model(
        generator_fn=generator_fn,
        discriminator_fn=discriminator_fn,
        real_data=images,
        unstructured_generator_inputs=unstructured_inputs,
        structured_generator_inputs=structured_inputs)
  else:
    # Fail fast with a clear message instead of hitting a NameError on
    # `gan_model` below when an unsupported GAN type is configured.
    raise ValueError('Invalid gan_type: %s' % hparams.gan_type)
  tfgan.eval.add_gan_model_image_summaries(gan_model, hparams.grid_size)

  # Get the GANLoss tuple. You can pass a custom function, use one of the
  # already-implemented losses from the losses library, or use the defaults.
  with tf.name_scope('loss'):
    if hparams.gan_type == 'infogan':
      gan_loss = tfgan.gan_loss(
          gan_model,
          generator_loss_fn=tfgan.losses.modified_generator_loss,
          discriminator_loss_fn=tfgan.losses.modified_discriminator_loss,
          mutual_information_penalty_weight=1.0,
          add_summaries=True)
    else:
      gan_loss = tfgan.gan_loss(gan_model, add_summaries=True)
    tfgan.eval.add_regularization_loss_summaries(gan_model)

  # Get the GANTrain ops using custom optimizers.
  with tf.name_scope('train'):
    gen_lr, dis_lr = _learning_rate(hparams.gan_type)
    train_ops = tfgan.gan_train_ops(
        gan_model,
        gan_loss,
        generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
        discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5),
        summarize_gradients=True,
        aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

  # Run the alternating training loop. Skip it if no steps should be taken
  # (used for graph construction tests).
  status_message = tf.strings.join([
      'Starting train step: ',
      tf.as_string(tf.train.get_or_create_global_step())
  ], name='status_message')
  if hparams.max_number_of_steps == 0:
    return
  tfgan.gan_train(
      train_ops,
      hooks=[
          tf.estimator.StopAtStepHook(num_steps=hparams.max_number_of_steps),
          tf.estimator.LoggingTensorHook([status_message], every_n_iter=10)
      ],
      logdir=hparams.train_log_dir,
      get_hooks_fn=tfgan.get_joint_train_hooks(),
      save_checkpoint_secs=60)