Code Example #1
File: train_test.py  Project: sts-sadr/gan-2
  def test_patchgan(self, create_gan_model_fn):
    """Ensure that patch-based discriminators work end-to-end."""
    if tf.executing_eagerly():
      # None of the usual utilities work in eager.
      return

    tf.compat.v1.random.set_random_seed(1234)
    model = create_gan_model_fn()
    loss = tfgan.gan_loss(model)

    g_opt = tf.compat.v1.train.GradientDescentOptimizer(1.0)
    d_opt = tf.compat.v1.train.GradientDescentOptimizer(1.0)
    train_ops = tfgan.gan_train_ops(model, loss, g_opt, d_opt)

    final_step = tfgan.gan_train(
        train_ops, logdir='', hooks=[tf.estimator.StopAtStepHook(num_steps=2)])
    self.assertTrue(np.isscalar(final_step))
    self.assertEqual(2, final_step)
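
These tests are parameterized by a `create_gan_model_fn` factory defined elsewhere in the test file. A minimal sketch of such a factory, assuming tiny dense generator and discriminator networks (all names and shapes here are hypothetical, for illustration only):

import tensorflow as tf
import tensorflow_gan as tfgan

def create_gan_model_fn():
  """Hypothetical factory: builds a tiny GANModel for graph-mode tests."""

  def generator_fn(noise):
    # Map noise to the same shape as the real data.
    return tf.compat.v1.layers.dense(noise, 4)

  def discriminator_fn(data, unused_conditioning):
    # Produce a single score per example.
    return tf.compat.v1.layers.dense(data, 1)

  return tfgan.gan_model(
      generator_fn=generator_fn,
      discriminator_fn=discriminator_fn,
      real_data=tf.zeros([3, 4]),
      generator_inputs=tf.random.normal([3, 2]))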
Code Example #2
    def test_get_sync_estimator_spec(self):
        """Make sure spec is loaded with sync hooks for sync opts."""
        with tf.Graph().as_default():
            gan_model = get_dummy_gan_model()
            gan_loss = tfgan.gan_loss(gan_model, dummy_loss_fn, dummy_loss_fn)
            g_opt = get_sync_optimizer()
            d_opt = get_sync_optimizer()

            spec = get_train_estimator_spec(gan_model,
                                            gan_loss,
                                            Optimizers(g_opt, d_opt),
                                            get_hooks_fn=None)  # use default.

            self.assertLen(spec.training_hooks, 4)
            sync_opts = [
                hook._sync_optimizer for hook in spec.training_hooks
                if isinstance(hook, get_sync_optimizer_hook_type())
            ]
            self.assertLen(sync_opts, 2)
            self.assertSetEqual(frozenset(sync_opts), frozenset(
                (g_opt, d_opt)))
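
`get_sync_optimizer` and `get_sync_optimizer_hook_type` are helpers defined elsewhere in the test file. A plausible sketch, assuming they wrap `tf.compat.v1.train.SyncReplicasOptimizer` (the exact bodies are not part of this listing):

import tensorflow as tf

def get_sync_optimizer():
  """Hypothetical helper: plain SGD wrapped for synchronous replicas."""
  return tf.compat.v1.train.SyncReplicasOptimizer(
      tf.compat.v1.train.GradientDescentOptimizer(1.0),
      replicas_to_aggregate=1)

def get_sync_optimizer_hook_type():
  """Hypothetical helper: the SessionRunHook class produced by
  SyncReplicasOptimizer.make_session_run_hook()."""
  return type(get_sync_optimizer().make_session_run_hook(is_chief=True))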
Code Example #3
File: train_test.py  Project: Aerochip7/gan
  def test_output_type(self, create_gan_model_fn):
    if tf.executing_eagerly():
      # None of the usual utilities work in eager.
      return

    model = create_gan_model_fn()
    loss = tfgan.gan_loss(model)

    g_opt = tf.compat.v1.train.GradientDescentOptimizer(1.0)
    d_opt = tf.compat.v1.train.GradientDescentOptimizer(1.0)
    train_ops = tfgan.gan_train_ops(
        model,
        loss,
        g_opt,
        d_opt,
        summarize_gradients=True,
        colocate_gradients_with_ops=True)

    self.assertIsInstance(train_ops, tfgan.GANTrainOps)

    # Make sure there are no training hooks populated accidentally.
    self.assertEmpty(train_ops.train_hooks)
Code Example #4
def define_loss(gan_model, **kwargs):
  """Defines progressive GAN losses.

  The generator and discriminator both use wasserstein loss. In addition,
  a small penalty term is added to the discriminator loss to keep it from
  getting too large.

  Args:
    gan_model: A `GANModel` namedtuple.
    **kwargs: A dictionary of
        'gradient_penalty_weight': A float gradient penalty weight for the
          wasserstein loss.
        'gradient_penalty_target': A float gradient norm target for the
          wasserstein loss.
        'real_score_penalty_weight': A float weight for an additional penalty
          that keeps the real scores from drifting too far from zero.

  Returns:
    A `GANLoss` namedtuple.
  """
  gan_loss = tfgan.gan_loss(
      gan_model,
      generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
      discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
      gradient_penalty_weight=kwargs['gradient_penalty_weight'],
      gradient_penalty_target=kwargs['gradient_penalty_target'],
      gradient_penalty_epsilon=0.0)

  real_score_penalty = tf.reduce_mean(
      input_tensor=tf.square(gan_model.discriminator_real_outputs))
  tf.compat.v1.summary.scalar('real_score_penalty', real_score_penalty)

  return gan_loss._replace(
      discriminator_loss=(
          gan_loss.discriminator_loss +
          kwargs['real_score_penalty_weight'] * real_score_penalty))
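
A sketch of how `define_loss` might be invoked. The weights below are illustrative (the progressive GAN paper uses a gradient penalty weight of 10.0 with a norm target of 1.0); they are not taken from this file:

gan_loss = define_loss(
    gan_model,
    gradient_penalty_weight=10.0,
    gradient_penalty_target=1.0,
    real_score_penalty_weight=0.001)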
Code Example #5
File: train_test.py  Project: Aerochip7/gan
  def test_train_hooks_exist_in_get_hooks_fn(self, create_gan_model_fn):
    if tf.executing_eagerly():
      # None of the usual utilities work in eager.
      return
    model = create_gan_model_fn()
    loss = tfgan.gan_loss(model)

    g_opt = get_sync_optimizer()
    d_opt = get_sync_optimizer()
    train_ops = tfgan.gan_train_ops(
        model,
        loss,
        g_opt,
        d_opt,
        summarize_gradients=True,
        colocate_gradients_with_ops=True)

    sequential_train_hooks = tfgan.get_sequential_train_hooks()(train_ops)
    self.assertLen(sequential_train_hooks, 4)
    sync_opts = [
        hook._sync_optimizer
        for hook in sequential_train_hooks
        if isinstance(hook, get_sync_optimizer_hook_type())
    ]
    self.assertLen(sync_opts, 2)
    self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))

    joint_train_hooks = tfgan.get_joint_train_hooks()(train_ops)
    self.assertLen(joint_train_hooks, 5)
    sync_opts = [
        hook._sync_optimizer
        for hook in joint_train_hooks
        if isinstance(hook, get_sync_optimizer_hook_type())
    ]
    self.assertLen(sync_opts, 2)
    self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))
Code Example #6
File: train_lib.py  Project: zhouyonglong/gan
def train(hparams):
    """Trains an MNIST GAN.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
  """
    if not tf.io.gfile.exists(hparams.train_log_dir):
        tf.io.gfile.makedirs(hparams.train_log_dir)

    # Force all input processing onto CPU in order to reserve the GPU for
    # the forward inference and back-propagation.
    with tf.name_scope('inputs'), tf.device('/cpu:0'):
        images, one_hot_labels = data_provider.provide_data(
            'train', hparams.batch_size, num_parallel_calls=4)

    # Define the GANModel tuple. Optionally, condition the GAN on the label or
    # use an InfoGAN to learn a latent representation.
    if hparams.gan_type == 'unconditional':
        gan_model = tfgan.gan_model(
            generator_fn=networks.unconditional_generator,
            discriminator_fn=networks.unconditional_discriminator,
            real_data=images,
            generator_inputs=tf.random.normal(
                [hparams.batch_size, hparams.noise_dims]))
    elif hparams.gan_type == 'conditional':
        noise = tf.random.normal([hparams.batch_size, hparams.noise_dims])
        gan_model = tfgan.gan_model(
            generator_fn=networks.conditional_generator,
            discriminator_fn=networks.conditional_discriminator,
            real_data=images,
            generator_inputs=(noise, one_hot_labels))
    elif hparams.gan_type == 'infogan':
        cat_dim, cont_dim = 10, 2
        generator_fn = functools.partial(networks.infogan_generator,
                                         categorical_dim=cat_dim)
        discriminator_fn = functools.partial(networks.infogan_discriminator,
                                             categorical_dim=cat_dim,
                                             continuous_dim=cont_dim)
        unstructured_inputs, structured_inputs = util.get_infogan_noise(
            hparams.batch_size, cat_dim, cont_dim, hparams.noise_dims)
        gan_model = tfgan.infogan_model(
            generator_fn=generator_fn,
            discriminator_fn=discriminator_fn,
            real_data=images,
            unstructured_generator_inputs=unstructured_inputs,
            structured_generator_inputs=structured_inputs)
    tfgan.eval.add_gan_model_image_summaries(gan_model, hparams.grid_size)

    # Get the GANLoss tuple. You can pass a custom function, use one of the
    # already-implemented losses from the losses library, or use the defaults.
    with tf.name_scope('loss'):
        if hparams.gan_type == 'infogan':
            gan_loss = tfgan.gan_loss(
                gan_model,
                generator_loss_fn=tfgan.losses.modified_generator_loss,
                discriminator_loss_fn=tfgan.losses.modified_discriminator_loss,
                mutual_information_penalty_weight=1.0,
                add_summaries=True)
        else:
            gan_loss = tfgan.gan_loss(gan_model, add_summaries=True)
        tfgan.eval.add_regularization_loss_summaries(gan_model)

    # Get the GANTrain ops using custom optimizers.
    with tf.name_scope('train'):
        gen_lr, dis_lr = _learning_rate(hparams.gan_type)
        train_ops = tfgan.gan_train_ops(
            gan_model,
            gan_loss,
            generator_optimizer=tf.compat.v1.train.AdamOptimizer(gen_lr, 0.5),
            discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(
                dis_lr, 0.5),
            summarize_gradients=True,
            aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

    # Run the alternating training loop. Skip it if no steps should be taken
    # (used for graph construction tests).
    status_message = tf.strings.join(
        ['Starting train step: ',
         tf.as_string(tf.compat.v1.train.get_or_create_global_step())],
        name='status_message')
    if hparams.max_number_of_steps == 0:
        return
    tfgan.gan_train(
        train_ops,
        hooks=[
            tf.estimator.StopAtStepHook(num_steps=hparams.max_number_of_steps),
            tf.estimator.LoggingTensorHook([status_message], every_n_iter=10)
        ],
        logdir=hparams.train_log_dir,
        get_hooks_fn=tfgan.get_joint_train_hooks(),
        save_checkpoint_secs=60)
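
The `hparams` object is constructed outside this function. Judging only by the fields the code above reads, a minimal hypothetical stand-in could be:

import collections

HParams = collections.namedtuple('HParams', [
    'train_log_dir', 'batch_size', 'gan_type', 'noise_dims',
    'grid_size', 'max_number_of_steps'
])

train(HParams(
    train_log_dir='/tmp/mnist_gan',
    batch_size=32,
    gan_type='unconditional',  # or 'conditional' / 'infogan'
    noise_dims=64,
    grid_size=5,
    max_number_of_steps=0))  # 0 builds the graph but skips the training loop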
Code Example #7
File: train_test.py  Project: Aerochip7/gan
  def test_mutual_info_penalty(self, create_gan_model_fn):
    """Test mutual information penalty option."""
    tfgan.gan_loss(
        create_gan_model_fn(),
        mutual_information_penalty_weight=tf.constant(1.0))
Code Example #8
File: train_test.py  Project: Aerochip7/gan
  def test_args_passed_in_correctly(self):

    def loss_fn(gan_model, add_summaries):
      del gan_model
      self.assertFalse(add_summaries)
      return 0

    tfgan.gan_loss(get_gan_model(), loss_fn, loss_fn, add_summaries=False)
Code Example #9
File: train_test.py  Project: Aerochip7/gan
  def test_no_reduction_or_add_summaries_loss(self):

    def loss_fn(_):
      return 0

    tfgan.gan_loss(get_gan_model(), loss_fn, loss_fn)
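
`get_gan_model` is another helper from the test file that is not shown here. A plausible sketch that fills a `tfgan.GANModel` namedtuple with just enough structure for loss construction (all field values are placeholders):

import tensorflow as tf
import tensorflow_gan as tfgan

def get_gan_model():
  """Hypothetical helper: a mostly-empty GANModel for loss-only tests."""
  with tf.compat.v1.variable_scope('generator') as gen_scope:
    pass
  with tf.compat.v1.variable_scope('discriminator') as dis_scope:
    pass
  return tfgan.GANModel(
      generator_inputs=None,
      generated_data=None,
      generator_variables=None,
      generator_scope=gen_scope,
      generator_fn=None,
      real_data=tf.ones([1, 2, 3]),
      discriminator_real_outputs=tf.ones([1, 2, 3]),
      discriminator_gen_outputs=tf.ones([1, 2, 3]),
      discriminator_variables=None,
      discriminator_scope=dis_scope,
      discriminator_fn=None)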
Code Example #10
File: train_test.py  Project: Aerochip7/gan
  def test_sync_replicas(self, create_gan_model_fn, create_global_step):
    if tf.executing_eagerly():
      # None of the usual utilities work in eager.
      return
    model = create_gan_model_fn()
    loss = tfgan.gan_loss(model)
    num_trainable_vars = len(get_trainable_variables())

    if create_global_step:
      gstep = tf.compat.v1.get_variable(
          'custom_gstep',
          dtype=tf.int32,
          initializer=0,
          trainable=False)
      tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.GLOBAL_STEP, gstep)

    g_opt = get_sync_optimizer()
    d_opt = get_sync_optimizer()
    train_ops = tfgan.gan_train_ops(
        model, loss, generator_optimizer=g_opt, discriminator_optimizer=d_opt)
    self.assertIsInstance(train_ops, tfgan.GANTrainOps)
    # No new trainable variables should have been added.
    self.assertLen(get_trainable_variables(), num_trainable_vars)

    # Sync hooks should be populated in the GANTrainOps.
    self.assertLen(train_ops.train_hooks, 2)
    for hook in train_ops.train_hooks:
      self.assertIsInstance(hook, get_sync_optimizer_hook_type())
    sync_opts = [hook._sync_optimizer for hook in train_ops.train_hooks]
    self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))

    g_sync_init_op = g_opt.get_init_tokens_op(num_tokens=1)
    d_sync_init_op = d_opt.get_init_tokens_op(num_tokens=1)

    # Check that update op is run properly.
    global_step = tf.compat.v1.train.get_or_create_global_step()
    with self.cached_session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      sess.run(tf.compat.v1.local_variables_initializer())

      sess.run(g_opt.chief_init_op)
      sess.run(d_opt.chief_init_op)

      gstep_before = sess.run(global_step)

      # Start required queue runner for SyncReplicasOptimizer.
      coord = tf.train.Coordinator()
      g_threads = g_opt.get_chief_queue_runner().create_threads(sess, coord)
      d_threads = d_opt.get_chief_queue_runner().create_threads(sess, coord)

      sess.run(g_sync_init_op)
      sess.run(d_sync_init_op)

      sess.run(train_ops.generator_train_op)
      # Check that global step wasn't incremented.
      self.assertEqual(gstep_before, sess.run(global_step))

      sess.run(train_ops.discriminator_train_op)
      # Check that global step wasn't incremented.
      self.assertEqual(gstep_before, sess.run(global_step))

      coord.request_stop()
      coord.join(g_threads + d_threads)
Code Example #11
File: train_lib.py  Project: yyht/gan
def train(hparams):
    """Trains a CIFAR10 GAN.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
  """
    if not tf.io.gfile.exists(hparams.train_log_dir):
        tf.io.gfile.makedirs(hparams.train_log_dir)

    with tf.device(
            tf.compat.v1.train.replica_device_setter(hparams.ps_replicas)):
        # Force all input processing onto CPU in order to reserve the GPU for
        # the forward inference and back-propagation.
        with tf.compat.v1.name_scope('inputs'):
            with tf.device('/cpu:0'):
                images, _ = data_provider.provide_data('train',
                                                       hparams.batch_size,
                                                       num_parallel_calls=4)

        # Define the GANModel tuple.
        generator_fn = networks.generator
        discriminator_fn = networks.discriminator
        generator_inputs = tf.random.normal([hparams.batch_size, 64])
        gan_model = tfgan.gan_model(generator_fn,
                                    discriminator_fn,
                                    real_data=images,
                                    generator_inputs=generator_inputs)
        tfgan.eval.add_gan_model_image_summaries(gan_model)

        # Get the GANLoss tuple. Use the selected GAN loss functions.
        with tf.compat.v1.name_scope('loss'):
            gan_loss = tfgan.gan_loss(gan_model,
                                      gradient_penalty_weight=1.0,
                                      add_summaries=True)

        # Get the GANTrain ops using the custom optimizers and optional
        # discriminator weight clipping.
        with tf.compat.v1.name_scope('train'):
            gen_opt, dis_opt = _get_optimizers(hparams)
            train_ops = tfgan.gan_train_ops(gan_model,
                                            gan_loss,
                                            generator_optimizer=gen_opt,
                                            discriminator_optimizer=dis_opt,
                                            summarize_gradients=True)

        # Run the alternating training loop. Skip it if no steps should be taken
        # (used for graph construction tests).
        status_message = tf.strings.join(
            ['Starting train step: ',
             tf.as_string(tf.compat.v1.train.get_or_create_global_step())],
            name='status_message')
        if hparams.max_number_of_steps == 0:
            return
        tfgan.gan_train(train_ops,
                        hooks=([
                            tf.estimator.StopAtStepHook(
                                num_steps=hparams.max_number_of_steps),
                            tf.estimator.LoggingTensorHook([status_message],
                                                           every_n_iter=10)
                        ]),
                        logdir=hparams.train_log_dir,
                        master=hparams.master,
                        is_chief=hparams.task == 0)
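
As in the MNIST example, `hparams` is built elsewhere; this CIFAR10 variant additionally reads the distributed-training fields `ps_replicas`, `master`, and `task` (plus whatever `_get_optimizers` consumes, which is not visible here). A hypothetical single-machine stand-in:

import collections

CifarHParams = collections.namedtuple('CifarHParams', [
    'train_log_dir', 'batch_size', 'max_number_of_steps',
    'ps_replicas', 'master', 'task'
])

train(CifarHParams(
    train_log_dir='/tmp/cifar_gan',
    batch_size=32,
    max_number_of_steps=0,  # 0 builds the graph but skips the training loop
    ps_replicas=0,          # no parameter servers: place all ops locally
    master='',              # empty target runs an in-process session
    task=0))                # task 0 acts as chief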