Example #1
  def test_unused_update_ops(self, create_gan_model_fn, provide_update_ops):
    if tf.executing_eagerly():
      # None of the usual utilities work in eager.
      return
    model = create_gan_model_fn()
    loss = tfgan.gan_loss(model)

    # Add generator and discriminator update ops.
    with tf.compat.v1.variable_scope(model.generator_scope):
      gen_update_count = tf.compat.v1.get_variable('gen_count', initializer=0)
      gen_update_op = gen_update_count.assign_add(1)
      tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.UPDATE_OPS,
                                     gen_update_op)
    with tf.compat.v1.variable_scope(model.discriminator_scope):
      dis_update_count = tf.compat.v1.get_variable('dis_count', initializer=0)
      dis_update_op = dis_update_count.assign_add(1)
      tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.UPDATE_OPS,
                                     dis_update_op)

    # Add an update op outside the generator and discriminator scopes.
    if provide_update_ops:
      kwargs = {'update_ops': [tf.constant(1.0), gen_update_op, dis_update_op]}
    else:
      tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.UPDATE_OPS,
                                     tf.constant(1.0))
      kwargs = {}

    g_opt = tf.compat.v1.train.GradientDescentOptimizer(1.0)
    d_opt = tf.compat.v1.train.GradientDescentOptimizer(1.0)

    with self.assertRaisesRegex(ValueError, 'There are unused update ops:'):
      tfgan.gan_train_ops(
          model, loss, g_opt, d_opt, check_for_unused_update_ops=True, **kwargs)
    train_ops = tfgan.gan_train_ops(
        model, loss, g_opt, d_opt, check_for_unused_update_ops=False, **kwargs)

    with self.cached_session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      self.assertEqual(0, sess.run(gen_update_count))
      self.assertEqual(0, sess.run(dis_update_count))

      sess.run(train_ops.generator_train_op)
      self.assertEqual(1, sess.run(gen_update_count))
      self.assertEqual(0, sess.run(dis_update_count))

      sess.run(train_ops.discriminator_train_op)
      self.assertEqual(1, sess.run(gen_update_count))
      self.assertEqual(1, sess.run(dis_update_count))
Example #2
def define_train_ops(gan_model, gan_loss, **kwargs):
    """Defines progressive GAN train ops.

  Args:
    gan_model: A `GANModel` namedtuple.
    gan_loss: A `GANLoss` namedtuple.
    **kwargs: A dictionary of keyword arguments containing:
        'adam_beta1': A float of Adam optimizer beta1.
        'adam_beta2': A float of Adam optimizer beta2.
        'generator_learning_rate': A float of generator learning rate.
        'discriminator_learning_rate': A float of discriminator learning rate.

  Returns:
    A tuple of a `GANTrainOps` namedtuple and a list of variables tracking the
    state of the optimizers.
  """
    with tf.compat.v1.variable_scope('progressive_gan_train_ops') as var_scope:
        beta1, beta2 = kwargs['adam_beta1'], kwargs['adam_beta2']
        gen_opt = tf.compat.v1.train.AdamOptimizer(
            kwargs['generator_learning_rate'], beta1, beta2)
        dis_opt = tf.compat.v1.train.AdamOptimizer(
            kwargs['discriminator_learning_rate'], beta1, beta2)
        gan_train_ops = tfgan.gan_train_ops(gan_model, gan_loss, gen_opt,
                                            dis_opt)
    return gan_train_ops, tf.compat.v1.get_collection(
        tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope=var_scope.name)
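
For context, a minimal call sketch for define_train_ops above. The gan_model and gan_loss values are assumed to be already-constructed `GANModel`/`GANLoss` namedtuples (e.g. built as in Examples #10 and #12), and the keyword values are placeholders, not recommendations:

# Sketch only: `gan_model` and `gan_loss` are assumed to exist already.
train_ops, optimizer_vars = define_train_ops(
    gan_model,
    gan_loss,
    adam_beta1=0.0,
    adam_beta2=0.99,
    generator_learning_rate=1e-3,
    discriminator_learning_rate=1e-3)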
Example #3
def _define_train_ops(cyclegan_model, cyclegan_loss, hparams):
    """Defines train ops that trains `cyclegan_model` with `cyclegan_loss`.

  Args:
    cyclegan_model: A `CycleGANModel` namedtuple.
    cyclegan_loss: A `CycleGANLoss` namedtuple containing all losses for
      `cyclegan_model`.
    hparams: An HParams instance containing the hyperparameters for training.

  Returns:
    A `GANTrainOps` namedtuple.
  """
    gen_lr = _get_lr(hparams.generator_lr, hparams.max_number_of_steps)
    dis_lr = _get_lr(hparams.discriminator_lr, hparams.max_number_of_steps)
    gen_opt, dis_opt = _get_optimizer(gen_lr, dis_lr)
    train_ops = tfgan.gan_train_ops(
        cyclegan_model,
        cyclegan_loss,
        generator_optimizer=gen_opt,
        discriminator_optimizer=dis_opt,
        summarize_gradients=True,
        colocate_gradients_with_ops=True,
        aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

    tf.summary.scalar('generator_lr', gen_lr)
    tf.summary.scalar('discriminator_lr', dis_lr)
    return train_ops
Example #4
def _define_train_ops(cyclegan_model, cyclegan_loss):
    """Defines train ops that trains `cyclegan_model` with `cyclegan_loss`.

    Args:
      cyclegan_model: A `CycleGANModel` namedtuple.
      cyclegan_loss: A `CycleGANLoss` namedtuple containing all losses for
          `cyclegan_model`.

    Returns:
      A `GANTrainOps` namedtuple.
    """
    gen_lr = _get_lr(FLAGS.generator_lr)
    dis_lr = _get_lr(FLAGS.discriminator_lr)
    gen_opt, dis_opt = _get_optimizer(gen_lr, dis_lr)

    train_ops = tfgan.gan_train_ops(
        cyclegan_model,
        cyclegan_loss,
        generator_optimizer=gen_opt,
        discriminator_optimizer=dis_opt,
        summarize_gradients=True,
        colocate_gradients_with_ops=True,
        check_for_unused_update_ops=False,
        aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

    tf.summary.scalar('generator_lr', gen_lr)
    tf.summary.scalar('discriminator_lr', dis_lr)
    return train_ops
Example #5
  def test_is_chief_in_train_hooks(self, is_chief):
    """Make sure is_chief is propagated correctly to sync hooks."""
    if tf.executing_eagerly():
      # None of the usual utilities work in eager.
      return

    model = create_gan_model()
    loss = tfgan.gan_loss(model)
    g_opt = get_sync_optimizer()
    d_opt = get_sync_optimizer()
    train_ops = tfgan.gan_train_ops(
        model,
        loss,
        g_opt,
        d_opt,
        is_chief=is_chief,
        summarize_gradients=True,
        colocate_gradients_with_ops=True)

    self.assertLen(train_ops.train_hooks, 2)

    for hook in train_ops.train_hooks:
      self.assertIsInstance(hook, get_sync_optimizer_hook_type())
    is_chief_list = [hook._is_chief for hook in train_ops.train_hooks]
    self.assertListEqual(is_chief_list, [is_chief, is_chief])
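
The get_sync_optimizer helper used above (and in Examples #9 and #11) is not shown on this page. A plausible sketch, wrapping plain SGD in tf.compat.v1.train.SyncReplicasOptimizer so that gan_train_ops attaches the sync hooks these tests check for; the learning rate and replica count are assumptions:

def get_sync_optimizer():
  # Hypothetical helper: the tests only require a SyncReplicasOptimizer
  # instance, so the wrapped optimizer and its settings are illustrative.
  return tf.compat.v1.train.SyncReplicasOptimizer(
      tf.compat.v1.train.GradientDescentOptimizer(learning_rate=1.0),
      replicas_to_aggregate=1)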
Example #6
  def test_run_helper(self, create_gan_model_fn):
    if tf.executing_eagerly():
      # None of the usual utilities work in eager.
      return
    tf.compat.v1.random.set_random_seed(1234)
    model = create_gan_model_fn()
    loss = tfgan.gan_loss(model)

    g_opt = tf.compat.v1.train.GradientDescentOptimizer(1.0)
    d_opt = tf.compat.v1.train.GradientDescentOptimizer(1.0)
    train_ops = tfgan.gan_train_ops(model, loss, g_opt, d_opt)

    final_step = tfgan.gan_train(
        train_ops, logdir='', hooks=[tf.estimator.StopAtStepHook(num_steps=2)])
    self.assertTrue(np.isscalar(final_step))
    self.assertEqual(2, final_step)
Example #7
  def test_output_type(self, create_gan_model_fn):
    if tf.executing_eagerly():
      # None of the usual utilities work in eager.
      return

    model = create_gan_model_fn()
    loss = tfgan.gan_loss(model)

    g_opt = tf.compat.v1.train.GradientDescentOptimizer(1.0)
    d_opt = tf.compat.v1.train.GradientDescentOptimizer(1.0)
    train_ops = tfgan.gan_train_ops(
        model,
        loss,
        g_opt,
        d_opt,
        summarize_gradients=True,
        colocate_gradients_with_ops=True)

    self.assertIsInstance(train_ops, tfgan.GANTrainOps)

    # Make sure there are no training hooks populated accidentally.
    self.assertEmpty(train_ops.train_hooks)
Example #8
def _define_train_ops(model, loss, gen_lr, dis_lr, beta1, beta2,
                      max_number_of_steps):
    """Defines train ops that trains `stargan_model` with `stargan_loss`.

  Args:
    model: A `StarGANModel` namedtuple.
    loss: A `StarGANLoss` namedtuple containing all losses for `stargan_model`.
    gen_lr: A scalar float `Tensor` or a Python number.  The Generator base
      learning rate.
    dis_lr: A scalar float `Tensor` or a Python number.  The Discriminator base
      learning rate.
    beta1: A scalar float `Tensor` or a Python number. The beta1 parameter to
      the `AdamOptimizer`.
    beta2: A scalar float `Tensor` or a Python number. The beta2 parameter to
      the `AdamOptimizer`.
    max_number_of_steps: A Python number. The total number of steps to train.

  Returns:
    A `GANTrainOps` namedtuple.
  """

    gen_lr = _get_lr(gen_lr, max_number_of_steps)
    dis_lr = _get_lr(dis_lr, max_number_of_steps)
    gen_opt, dis_opt = _get_optimizer(gen_lr, dis_lr, beta1, beta2)
    train_ops = tfgan.gan_train_ops(
        model,
        loss,
        generator_optimizer=gen_opt,
        discriminator_optimizer=dis_opt,
        summarize_gradients=True,
        colocate_gradients_with_ops=True,
        aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

    tf.summary.scalar('generator_lr', gen_lr)
    tf.summary.scalar('discriminator_lr', dis_lr)

    return train_ops
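
For reference, a hypothetical call that matches the signature and docstring above; `model` and `loss` are assumed to be a previously built `StarGANModel`/`StarGANLoss` pair, and every numeric value is illustrative:

# Sketch only: all argument values are placeholders.
train_ops = _define_train_ops(
    model,
    loss,
    gen_lr=1e-4,
    dis_lr=1e-4,
    beta1=0.5,
    beta2=0.999,
    max_number_of_steps=100000)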
Example #9
  def test_train_hooks_exist_in_get_hooks_fn(self, create_gan_model_fn):
    if tf.executing_eagerly():
      # None of the usual utilities work in eager.
      return
    model = create_gan_model_fn()
    loss = tfgan.gan_loss(model)

    g_opt = get_sync_optimizer()
    d_opt = get_sync_optimizer()
    train_ops = tfgan.gan_train_ops(
        model,
        loss,
        g_opt,
        d_opt,
        summarize_gradients=True,
        colocate_gradients_with_ops=True)

    sequential_train_hooks = tfgan.get_sequential_train_hooks()(train_ops)
    self.assertLen(sequential_train_hooks, 4)
    sync_opts = [
        hook._sync_optimizer
        for hook in sequential_train_hooks
        if isinstance(hook, get_sync_optimizer_hook_type())
    ]
    self.assertLen(sync_opts, 2)
    self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))

    joint_train_hooks = tfgan.get_joint_train_hooks()(train_ops)
    self.assertLen(joint_train_hooks, 5)
    sync_opts = [
        hook._sync_optimizer
        for hook in joint_train_hooks
        if isinstance(hook, get_sync_optimizer_hook_type())
    ]
    self.assertLen(sync_opts, 2)
    self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))
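
Either hook factory is intended to be handed to tfgan.gan_train through its get_hooks_fn argument, as Example #10 does with the joint variant. A minimal sketch under the same assumptions as the test above; the logdir and step count are illustrative:

# Sketch only: gan_train calls the factory on train_ops to obtain the hooks
# that drive the alternating generator/discriminator updates.
tfgan.gan_train(
    train_ops,
    logdir='/tmp/gan_logdir',
    get_hooks_fn=tfgan.get_sequential_train_hooks(),
    hooks=[tf.estimator.StopAtStepHook(num_steps=1000)])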
Example #10
def train(hparams):
    """Trains an MNIST GAN.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
  """
    if not tf.io.gfile.exists(hparams.train_log_dir):
        tf.io.gfile.makedirs(hparams.train_log_dir)

    # Force all input processing onto CPU in order to reserve the GPU for
    # the forward inference and back-propagation.
    with tf.name_scope('inputs'), tf.device('/cpu:0'):
        images, one_hot_labels = data_provider.provide_data(
            'train', hparams.batch_size, num_parallel_calls=4)

    # Define the GANModel tuple. Optionally, condition the GAN on the label or
    # use an InfoGAN to learn a latent representation.
    if hparams.gan_type == 'unconditional':
        gan_model = tfgan.gan_model(
            generator_fn=networks.unconditional_generator,
            discriminator_fn=networks.unconditional_discriminator,
            real_data=images,
            generator_inputs=tf.random.normal(
                [hparams.batch_size, hparams.noise_dims]))
    elif hparams.gan_type == 'conditional':
        noise = tf.random.normal([hparams.batch_size, hparams.noise_dims])
        gan_model = tfgan.gan_model(
            generator_fn=networks.conditional_generator,
            discriminator_fn=networks.conditional_discriminator,
            real_data=images,
            generator_inputs=(noise, one_hot_labels))
    elif hparams.gan_type == 'infogan':
        cat_dim, cont_dim = 10, 2
        generator_fn = functools.partial(networks.infogan_generator,
                                         categorical_dim=cat_dim)
        discriminator_fn = functools.partial(networks.infogan_discriminator,
                                             categorical_dim=cat_dim,
                                             continuous_dim=cont_dim)
        unstructured_inputs, structured_inputs = util.get_infogan_noise(
            hparams.batch_size, cat_dim, cont_dim, hparams.noise_dims)
        gan_model = tfgan.infogan_model(
            generator_fn=generator_fn,
            discriminator_fn=discriminator_fn,
            real_data=images,
            unstructured_generator_inputs=unstructured_inputs,
            structured_generator_inputs=structured_inputs)
    tfgan.eval.add_gan_model_image_summaries(gan_model, hparams.grid_size)

    # Get the GANLoss tuple. You can pass a custom function, use one of the
    # already-implemented losses from the losses library, or use the defaults.
    with tf.name_scope('loss'):
        if hparams.gan_type == 'infogan':
            gan_loss = tfgan.gan_loss(
                gan_model,
                generator_loss_fn=tfgan.losses.modified_generator_loss,
                discriminator_loss_fn=tfgan.losses.modified_discriminator_loss,
                mutual_information_penalty_weight=1.0,
                add_summaries=True)
        else:
            gan_loss = tfgan.gan_loss(gan_model, add_summaries=True)
        tfgan.eval.add_regularization_loss_summaries(gan_model)

    # Get the GANTrain ops using custom optimizers.
    with tf.name_scope('train'):
        gen_lr, dis_lr = _learning_rate(hparams.gan_type)
        train_ops = tfgan.gan_train_ops(
            gan_model,
            gan_loss,
            generator_optimizer=tf.compat.v1.train.AdamOptimizer(gen_lr, 0.5),
            discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(dis_lr, 0.5),
            summarize_gradients=True,
            aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

    # Run the alternating training loop. Skip it if no steps should be taken
    # (used for graph construction tests).
    status_message = tf.strings.join([
        'Starting train step: ',
        tf.as_string(tf.compat.v1.train.get_or_create_global_step())
    ],
                                     name='status_message')
    if hparams.max_number_of_steps == 0:
        return
    tfgan.gan_train(
        train_ops,
        hooks=[
            tf.estimator.StopAtStepHook(num_steps=hparams.max_number_of_steps),
            tf.estimator.LoggingTensorHook([status_message], every_n_iter=10)
        ],
        logdir=hparams.train_log_dir,
        get_hooks_fn=tfgan.get_joint_train_hooks(),
        save_checkpoint_secs=60)
Example #11
  def test_sync_replicas(self, create_gan_model_fn, create_global_step):
    if tf.executing_eagerly():
      # None of the usual utilities work in eager.
      return
    model = create_gan_model_fn()
    loss = tfgan.gan_loss(model)
    num_trainable_vars = len(get_trainable_variables())

    if create_global_step:
      gstep = tf.compat.v1.get_variable(
          'custom_gstep',
          dtype=tf.int32,
          initializer=0,
          trainable=False)
      tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.GLOBAL_STEP, gstep)

    g_opt = get_sync_optimizer()
    d_opt = get_sync_optimizer()
    train_ops = tfgan.gan_train_ops(
        model, loss, generator_optimizer=g_opt, discriminator_optimizer=d_opt)
    self.assertIsInstance(train_ops, tfgan.GANTrainOps)
    # No new trainable variables should have been added.
    self.assertLen(get_trainable_variables(), num_trainable_vars)

    # Sync hooks should be populated in the GANTrainOps.
    self.assertLen(train_ops.train_hooks, 2)
    for hook in train_ops.train_hooks:
      self.assertIsInstance(hook, get_sync_optimizer_hook_type())
    sync_opts = [hook._sync_optimizer for hook in train_ops.train_hooks]
    self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))

    g_sync_init_op = g_opt.get_init_tokens_op(num_tokens=1)
    d_sync_init_op = d_opt.get_init_tokens_op(num_tokens=1)

    # Check that update op is run properly.
    global_step = tf.compat.v1.train.get_or_create_global_step()
    with self.cached_session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      sess.run(tf.compat.v1.local_variables_initializer())

      sess.run(g_opt.chief_init_op)
      sess.run(d_opt.chief_init_op)

      gstep_before = sess.run(global_step)

      # Start required queue runner for SyncReplicasOptimizer.
      coord = tf.train.Coordinator()
      g_threads = g_opt.get_chief_queue_runner().create_threads(sess, coord)
      d_threads = d_opt.get_chief_queue_runner().create_threads(sess, coord)

      sess.run(g_sync_init_op)
      sess.run(d_sync_init_op)

      sess.run(train_ops.generator_train_op)
      # Check that global step wasn't incremented.
      self.assertEqual(gstep_before, sess.run(global_step))

      sess.run(train_ops.discriminator_train_op)
      # Check that global step wasn't incremented.
      self.assertEqual(gstep_before, sess.run(global_step))

      coord.request_stop()
      coord.join(g_threads + d_threads)
Example #12
def train(hparams):
    """Trains a CIFAR10 GAN.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
  """
    if not tf.io.gfile.exists(hparams.train_log_dir):
        tf.io.gfile.makedirs(hparams.train_log_dir)

    with tf.device(
            tf.compat.v1.train.replica_device_setter(hparams.ps_replicas)):
        # Force all input processing onto CPU in order to reserve the GPU for
        # the forward inference and back-propagation.
        with tf.compat.v1.name_scope('inputs'):
            with tf.device('/cpu:0'):
                images, _ = data_provider.provide_data('train',
                                                       hparams.batch_size,
                                                       num_parallel_calls=4)

        # Define the GANModel tuple.
        generator_fn = networks.generator
        discriminator_fn = networks.discriminator
        generator_inputs = tf.random.normal([hparams.batch_size, 64])
        gan_model = tfgan.gan_model(generator_fn,
                                    discriminator_fn,
                                    real_data=images,
                                    generator_inputs=generator_inputs)
        tfgan.eval.add_gan_model_image_summaries(gan_model)

        # Get the GANLoss tuple. Use the selected GAN loss functions.
        with tf.compat.v1.name_scope('loss'):
            gan_loss = tfgan.gan_loss(gan_model,
                                      gradient_penalty_weight=1.0,
                                      add_summaries=True)

        # Get the GANTrain ops using the custom optimizers and optional
        # discriminator weight clipping.
        with tf.compat.v1.name_scope('train'):
            gen_opt, dis_opt = _get_optimizers(hparams)
            train_ops = tfgan.gan_train_ops(gan_model,
                                            gan_loss,
                                            generator_optimizer=gen_opt,
                                            discriminator_optimizer=dis_opt,
                                            summarize_gradients=True)

        # Run the alternating training loop. Skip it if no steps should be taken
        # (used for graph construction tests).
        status_message = tf.strings.join([
            'Starting train step: ',
            tf.as_string(tf.compat.v1.train.get_or_create_global_step())
        ],
                                         name='status_message')
        if hparams.max_number_of_steps == 0:
            return
        tfgan.gan_train(train_ops,
                        hooks=([
                            tf.estimator.StopAtStepHook(
                                num_steps=hparams.max_number_of_steps),
                            tf.estimator.LoggingTensorHook([status_message],
                                                           every_n_iter=10)
                        ]),
                        logdir=hparams.train_log_dir,
                        master=hparams.master,
                        is_chief=hparams.task == 0)