Esempio n. 1
0
    def test_provide_data(self, mock_tfds):
        """Checks shape and value range of one batch served by provide_data."""
        batch_size = 5
        mock_tfds.load.return_value = self.mock_ds

        image_tensor, label_tensor = data_provider.provide_data(
            'test', batch_size)

        with self.cached_session() as sess:
            # Initialize any lookup tables the input pipeline may have created
            # before evaluating the tensors.
            sess.run(tf.compat.v1.tables_initializer())
            image_batch, label_batch = sess.run([image_tensor, label_tensor])

        expected_image_shape = (batch_size, 28, 28, 1)
        self.assertTupleEqual(image_batch.shape, expected_image_shape)
        # Every pixel value must lie in the closed interval [-1, 1].
        self.assertTrue((np.abs(image_batch) <= 1).all())
        self.assertTupleEqual(label_batch.shape, (batch_size, 10))
Esempio n. 2
0
  def test_provide_data_can_be_reinitialized(self, mock_tfds):
    """Evaluates the same input tensors twice in each of two fresh sessions.

    The second session must be able to re-run the input ops built once
    above, i.e. the pipeline is re-initializable across sessions.
    """
    batch_size = 5
    mock_tfds.load.return_value = self.mock_ds

    images, labels = data_provider.provide_data('test', batch_size)
    fetches = [images, labels]

    for _session_index in range(2):
      with self.session() as sess:
        for _step in range(2):
          sess.run(fetches)
Esempio n. 3
0
def evaluate(hparams, run_eval_loop=True):
    """Builds the MNIST GAN evaluation graph and, optionally, runs it.

    If ``hparams.eval_real_images`` is set, only the MNIST classifier score
    of real training images is summarized. Otherwise images are sampled from
    the unconditional generator, and both the Frechet distance (against the
    real images) and the classifier score are summarized. In that case, when
    at least 100 images are generated and ``hparams.write_to_disk`` is set,
    the first 100 samples are also tiled 10 per row and written as
    ``<eval_dir>/unconditional_gan.png``.

    Args:
      hparams: An HParams instance containing the eval hyperparameters.
      run_eval_loop: Whether to run the full eval loop. Set to False for
        testing, in which case only the graph is built.
    """
    # NOTE(review): the third positional argument must match this
    # data_provider's signature (dataset_dir vs. num_parallel_calls in
    # other versions) -- confirm.
    with tf.compat.v1.name_scope('inputs'):
        real_images, _ = data_provider.provide_data(
            'train', hparams.num_images_generated, hparams.dataset_dir)

    classifier_file = hparams.classifier_filename
    eval_ops = None

    if not hparams.eval_real_images:
        # Match the train job's variable scope so the generator's
        # checkpointed variables can be restored.
        with tf.compat.v1.variable_scope('Generator'):
            noise = tf.random.normal(
                [hparams.num_images_generated, hparams.noise_dims])
            generated = networks.unconditional_generator(
                noise, is_training=False)

        frechet_distance = util.mnist_frechet_distance(
            real_images, generated, classifier_file)
        tf.compat.v1.summary.scalar('MNIST_Frechet_distance', frechet_distance)

        classifier_score = util.mnist_score(generated, classifier_file)
        tf.compat.v1.summary.scalar('MNIST_Classifier_score', classifier_score)

        if hparams.num_images_generated >= 100 and hparams.write_to_disk:
            image_grid = tfgan.eval.image_reshaper(
                generated[:100, ...], num_cols=10)
            image_grid_uint8 = data_provider.float_image_to_uint8(image_grid)
            png_path = '%s/%s' % (hparams.eval_dir, 'unconditional_gan.png')
            eval_ops = tf.io.write_file(
                png_path, tf.image.encode_png(image_grid_uint8[0]))
    else:
        real_score = util.mnist_score(real_images, classifier_file)
        tf.compat.v1.summary.scalar('MNIST_Classifier_score', real_score)

    # With `run_eval_loop=False` (unit tests) only the graph is constructed.
    if run_eval_loop:
        eval_hooks = [
            evaluation.SummaryAtEndHook(hparams.eval_dir),
            evaluation.StopAfterNEvalsHook(1),
        ]
        evaluation.evaluate_repeatedly(
            hparams.checkpoint_dir,
            hooks=eval_hooks,
            eval_ops=eval_ops,
            max_number_of_evaluations=hparams.max_number_of_evaluations)
Esempio n. 4
0
    def test_provide_data_can_be_reinitialized(self, mock_tfds):
        """Evaluates the same input tensors twice in each of two fresh sessions.

        Skipped under eager execution: there, using the result of
        ``self.session()`` fails ("Trying to access properties or call
        methods on the result of self.session()"), so this graph-mode check
        cannot run.
        """
        if tf.executing_eagerly():
            return

        batch_size = 5
        mock_tfds.load.return_value = self.mock_ds

        images, labels = data_provider.provide_data('test', batch_size)
        fetches = [images, labels]

        # A second session must be able to re-run the same input ops.
        for _session_index in range(2):
            with self.session() as sess:
                for _step in range(2):
                    sess.run(fetches)
Esempio n. 5
0
 def train_input_fn():
   """Returns one training batch as a ``(noise, real_images)`` pair.

   `batch_size`, `noise_dims` and `num_parallel_calls` come from the
   enclosing scope. The (features, labels) ordering presumably feeds a
   TF-GAN estimator input_fn -- confirm against the caller.
   """
   real_images, _unused_labels = data_provider.provide_data(
       'train', batch_size, num_parallel_calls=num_parallel_calls)
   generator_noise = tf.random.normal([batch_size, noise_dims])
   return generator_noise, real_images
Esempio n. 6
0
def train(hparams):
    """Trains an MNIST GAN.

    Builds the input pipeline, the GAN model selected by ``hparams.gan_type``,
    its losses and train ops, then runs the alternating training loop.

    Args:
      hparams: An HParams instance containing the hyperparameters for
        training. ``hparams.gan_type`` must be one of 'unconditional',
        'conditional' or 'infogan'.

    Raises:
      ValueError: If ``hparams.gan_type`` is not a supported GAN type.
    """
    gan_type = hparams.gan_type
    # Fail fast, before touching the filesystem or building any graph. An
    # unknown type used to fall through the model-selection chain and
    # surface later as an UnboundLocalError on `gan_model`.
    if gan_type not in ('unconditional', 'conditional', 'infogan'):
        raise ValueError(
            "Unsupported gan_type %r; expected 'unconditional', "
            "'conditional' or 'infogan'." % (gan_type,))

    if not tf.io.gfile.exists(hparams.train_log_dir):
        tf.io.gfile.makedirs(hparams.train_log_dir)

    # Force all input processing onto CPU in order to reserve the GPU for
    # the forward inference and back-propagation.
    with tf.name_scope('inputs'), tf.device('/cpu:0'):
        images, one_hot_labels = data_provider.provide_data(
            'train', hparams.batch_size, num_parallel_calls=4)

    # Define the GANModel tuple. Optionally, condition the GAN on the label or
    # use an InfoGAN to learn a latent representation.
    if gan_type == 'unconditional':
        gan_model = tfgan.gan_model(
            generator_fn=networks.unconditional_generator,
            discriminator_fn=networks.unconditional_discriminator,
            real_data=images,
            generator_inputs=tf.random.normal(
                [hparams.batch_size, hparams.noise_dims]))
    elif gan_type == 'conditional':
        noise = tf.random.normal([hparams.batch_size, hparams.noise_dims])
        gan_model = tfgan.gan_model(
            generator_fn=networks.conditional_generator,
            discriminator_fn=networks.conditional_discriminator,
            real_data=images,
            generator_inputs=(noise, one_hot_labels))
    else:  # gan_type == 'infogan' (validated above).
        # Sizes of the InfoGAN categorical and continuous latent codes.
        cat_dim, cont_dim = 10, 2
        generator_fn = functools.partial(networks.infogan_generator,
                                         categorical_dim=cat_dim)
        discriminator_fn = functools.partial(networks.infogan_discriminator,
                                             categorical_dim=cat_dim,
                                             continuous_dim=cont_dim)
        unstructured_inputs, structured_inputs = util.get_infogan_noise(
            hparams.batch_size, cat_dim, cont_dim, hparams.noise_dims)
        gan_model = tfgan.infogan_model(
            generator_fn=generator_fn,
            discriminator_fn=discriminator_fn,
            real_data=images,
            unstructured_generator_inputs=unstructured_inputs,
            structured_generator_inputs=structured_inputs)
    tfgan.eval.add_gan_model_image_summaries(gan_model, hparams.grid_size)

    # Get the GANLoss tuple. You can pass a custom function, use one of the
    # already-implemented losses from the losses library, or use the defaults.
    with tf.name_scope('loss'):
        if gan_type == 'infogan':
            gan_loss = tfgan.gan_loss(
                gan_model,
                generator_loss_fn=tfgan.losses.modified_generator_loss,
                discriminator_loss_fn=tfgan.losses.modified_discriminator_loss,
                mutual_information_penalty_weight=1.0,
                add_summaries=True)
        else:
            gan_loss = tfgan.gan_loss(gan_model, add_summaries=True)
        tfgan.eval.add_regularization_loss_summaries(gan_model)

    # Get the GANTrain ops using custom optimizers.
    with tf.name_scope('train'):
        gen_lr, dis_lr = _learning_rate(gan_type)
        train_ops = tfgan.gan_train_ops(
            gan_model,
            gan_loss,
            generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
            discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5),
            summarize_gradients=True,
            aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

    # Run the alternating training loop. Skip it if no steps should be taken
    # (used for graph construction tests). The status message is built before
    # the early return so graph-construction tests still cover it.
    status_message = tf.strings.join(
        ['Starting train step: ',
         tf.as_string(tf.train.get_or_create_global_step())],
        name='status_message')
    if hparams.max_number_of_steps == 0:
        return
    tfgan.gan_train(
        train_ops,
        hooks=[
            tf.estimator.StopAtStepHook(num_steps=hparams.max_number_of_steps),
            tf.estimator.LoggingTensorHook([status_message], every_n_iter=10),
        ],
        logdir=hparams.train_log_dir,
        get_hooks_fn=tfgan.get_joint_train_hooks(),
        save_checkpoint_secs=60)