Example #1
  def test_compute_progress(self):
    if tf.executing_eagerly():
      progress_output = []
      for current_image_id in [0, 3, 6, 7, 8, 10, 15, 29, 100]:
        progress = networks.compute_progress(
            current_image_id,
            stable_stage_num_images=7,
            transition_stage_num_images=8,
            num_blocks=2)
        with self.cached_session(use_gpu=True) as sess:
          progress_output.append(sess.run(progress))
    else:
      current_image_id_ph = tf.compat.v1.placeholder(tf.int32, [])
      progress = networks.compute_progress(
          current_image_id_ph,
          stable_stage_num_images=7,
          transition_stage_num_images=8,
          num_blocks=2)
      with self.cached_session(use_gpu=True) as sess:
        progress_output = [
            sess.run(progress, feed_dict={current_image_id_ph: cur_image_id})
            for cur_image_id in [0, 3, 6, 7, 8, 10, 15, 29, 100]
        ]

    self.assertArrayNear(progress_output,
                         [0.0, 0.0, 0.0, 0.0, 0.125, 0.375, 1.0, 1.0, 1.0],
                         1.0e-6)
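
The expected values above trace a piecewise-linear schedule: progress stays at an integer level k throughout the k-th stable stage, ramps linearly from k to k + 1 over the following transition stage, and is capped at num_blocks - 1. The snippet below is a minimal reference sketch of that schedule, implied by the test's expected outputs (it is not the actual networks.compute_progress implementation):

def compute_progress_reference(current_image_id, stable_stage_num_images,
                               transition_stage_num_images, num_blocks):
  """Piecewise-linear schedule implied by the test's expected values (sketch)."""
  images_per_stage = stable_stage_num_images + transition_stage_num_images
  stage, offset = divmod(current_image_id, images_per_stage)
  ramp = min(max((offset - stable_stage_num_images) / transition_stage_num_images,
                 0.0), 1.0)
  return min(stage + ramp, float(num_blocks - 1))

# Reproduces the expected output asserted in the test above.
assert [compute_progress_reference(i, 7, 8, 2)
        for i in [0, 3, 6, 7, 8, 10, 15, 29, 100]] == \
    [0.0, 0.0, 0.0, 0.0, 0.125, 0.375, 1.0, 1.0, 1.0]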
Example #2
  def test_generator_grad_norm_progress(self):
    if tf.executing_eagerly():
      # tf.placeholder() is not compatible with eager execution.
      return
    stable_stage_num_images = 2
    transition_stage_num_images = 3

    current_image_id_ph = tf.compat.v1.placeholder(tf.int32, [])
    progress = networks.compute_progress(
        current_image_id_ph,
        stable_stage_num_images,
        transition_stage_num_images,
        num_blocks=3)
    z = tf.random.normal([2, 10], dtype=tf.float32)
    x, _ = networks.generator(
        z, progress, _num_filters_stub,
        networks.ResolutionSchedule(
            start_resolutions=(4, 4), scale_base=2, num_resolutions=3))
    fake_loss = tf.reduce_sum(input_tensor=tf.square(x))
    grad_norms = [
        _get_grad_norm(
            fake_loss,
            tf.compat.v1.trainable_variables('.*/progressive_gan_block_1/.*')),
        _get_grad_norm(
            fake_loss,
            tf.compat.v1.trainable_variables('.*/progressive_gan_block_2/.*')),
        _get_grad_norm(
            fake_loss,
            tf.compat.v1.trainable_variables('.*/progressive_gan_block_3/.*'))
    ]

    grad_norms_output = None
    with self.cached_session(use_gpu=True) as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
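      # The int32 placeholder truncates these float feed values to 0 and 1.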
      x1_np = sess.run(x, feed_dict={current_image_id_ph: 0.12})
      x2_np = sess.run(x, feed_dict={current_image_id_ph: 1.8})
      grad_norms_output = np.array([
          sess.run(grad_norms, feed_dict={current_image_id_ph: i})
          for i in range(15)  # total num of images
      ])

    self.assertEqual((2, 16, 16, 3), x1_np.shape)
    self.assertEqual((2, 16, 16, 3), x2_np.shape)
    # The gradient of block_1 is always on.
    self.assertEqual(
        np.argmax(grad_norms_output[:, 0] > 0), 0,
        'gradient norms {} for block 1 are not always on'.format(
            grad_norms_output[:, 0]))
    # The gradient of block_2 is on after 1 stable stage.
    self.assertEqual(
        np.argmax(grad_norms_output[:, 1] > 0), 3,
        'gradient norms {} for block 2 are not on at step 3'.format(
            grad_norms_output[:, 1]))
    # The gradient of block_3 is on after 2 stable stages + 1 transition stage.
    self.assertEqual(
        np.argmax(grad_norms_output[:, 2] > 0), 8,
        'gradient norms {} for block 3 are not on at step 8'.format(
            grad_norms_output[:, 2]))
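
Assuming each higher block only contributes gradient once progress exceeds the previous integer level (which is exactly what the assertions above check), the activation steps 3 and 8 can be cross-checked with a small self-contained sketch; the helper _reference_progress below is hypothetical and mirrors the schedule sketched after Example #1:

def _reference_progress(image_id, stable=2, transition=3, num_blocks=3):
  stage, offset = divmod(image_id, stable + transition)
  ramp = min(max((offset - stable) / transition, 0.0), 1.0)
  return min(stage + ramp, float(num_blocks - 1))

progress_per_image = [_reference_progress(i) for i in range(15)]
# block_2 blends in once progress > 0; block_3 once progress > 1.
assert next(i for i, p in enumerate(progress_per_image) if p > 0) == 3
assert next(i for i, p in enumerate(progress_per_image) if p > 1) == 8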
Example #3
    def test_discriminator_grad_norm_progress(self):
        if tf.executing_eagerly():
            # tf.placeholder() is not compatible with eager execution.
            return
        stable_stage_num_images = 2
        transition_stage_num_images = 3

        current_image_id_ph = tf.compat.v1.placeholder(tf.int32, [])
        progress = networks.compute_progress(current_image_id_ph,
                                             stable_stage_num_images,
                                             transition_stage_num_images,
                                             num_blocks=3)
        x = tf.random.normal([2, 16, 16, 3])
        logits, _ = networks.discriminator(
            x, progress, _num_filters_stub,
            networks.ResolutionSchedule(start_resolutions=(4, 4),
                                        scale_base=2,
                                        num_resolutions=3))
        fake_loss = tf.reduce_sum(input_tensor=tf.square(logits))
        grad_norms = [
            _get_grad_norm(
                fake_loss,
                tf.compat.v1.trainable_variables(
                    '.*/progressive_gan_block_1/.*')),
            _get_grad_norm(
                fake_loss,
                tf.compat.v1.trainable_variables(
                    '.*/progressive_gan_block_2/.*')),
            _get_grad_norm(
                fake_loss,
                tf.compat.v1.trainable_variables(
                    '.*/progressive_gan_block_3/.*'))
        ]

        grad_norms_output = None
        with self.cached_session(use_gpu=True) as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            grad_norms_output = np.array([
                sess.run(grad_norms, feed_dict={current_image_id_ph: i})
                for i in range(15)  # total num of images
            ])

        # The gradient of block_1 is always on.
        self.assertEqual(
            np.argmax(grad_norms_output[:, 0] > 0), 0,
            'gradient norms {} for block 1 are not always on'.format(
                grad_norms_output[:, 0]))
        # The gradient of block_2 is on after 1 stable stage.
        self.assertEqual(
            np.argmax(grad_norms_output[:, 1] > 0), 3,
            'gradient norms {} for block 2 are not on at step 3'.format(
                grad_norms_output[:, 1]))
        # The gradient of block_3 is on after 2 stable stages + 1 transition stage.
        self.assertEqual(
            np.argmax(grad_norms_output[:, 2] > 0), 8,
            'gradient norms {} for block 3 are not on at step 8'.format(
                grad_norms_output[:, 2]))
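
Both grad-norm tests build a ResolutionSchedule with start_resolutions=(4, 4), scale_base=2 and num_resolutions=3, which is why the image tensors are 16x16. A short sketch of that relationship (the helper below is hypothetical, not part of the ResolutionSchedule API):

def _final_resolution(start_resolutions, scale_base, num_resolutions):
  # Each additional resolution scales the previous one by scale_base.
  return tuple(r * scale_base ** (num_resolutions - 1) for r in start_resolutions)

assert _final_resolution((4, 4), 2, 3) == (16, 16)  # matches the (2, 16, 16, 3) shapes above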
Example #4
def build_model(stage_id, batch_size, real_images, **kwargs):
    """Builds progressive GAN model.

  Args:
    stage_id: An integer of training stage index.
    batch_size: Number of training images in each minibatch.
    real_images: A 4D `Tensor` of NHWC format.
    **kwargs: A dictionary of
        'start_height': An integer of start image height.
        'start_width': An integer of start image width.
        'scale_base': An integer of resolution multiplier.
        'num_resolutions': An integer of number of progressive resolutions.
        'stable_stage_num_images': An integer of number of training images in
          the stable stage.
        'transition_stage_num_images': An integer of number of training images
          in the transition stage.
        'total_num_images': An integer of total number of training images.
        'kernel_size': Convolution kernel size.
        'colors': Number of image channels.
        'to_rgb_use_tanh_activation': Whether to apply tanh activation when
          outputting RGB.
        'fmap_base': Base number of filters.
        'fmap_decay': Decay of number of filters.
        'fmap_max': Max number of filters.
        'latent_vector_size': An integer of latent vector size.
        'gradient_penalty_weight': A float of gradient penalty weight for
          Wasserstein loss.
        'gradient_penalty_target': A float of gradient norm target for
          Wasserstein loss.
        'real_score_penalty_weight': A float of additional penalty to keep the
          scores from drifting too far from zero.
        'adam_beta1': A float of Adam optimizer beta1.
        'adam_beta2': A float of Adam optimizer beta2.
        'generator_learning_rate': A float of generator learning rate.
        'discriminator_learning_rate': A float of discriminator learning rate.

  Returns:
    An internal object that wraps all information about the model.
  """
    kernel_size = kwargs['kernel_size']
    colors = kwargs['colors']
    resolution_schedule = make_resolution_schedule(**kwargs)

    num_blocks, num_images = get_stage_info(stage_id, **kwargs)

    current_image_id = tf.compat.v1.train.get_or_create_global_step()
    current_image_id_inc_op = current_image_id.assign_add(batch_size)
    tf.compat.v1.summary.scalar('current_image_id', current_image_id)

    progress = networks.compute_progress(current_image_id,
                                         kwargs['stable_stage_num_images'],
                                         kwargs['transition_stage_num_images'],
                                         num_blocks)
    tf.compat.v1.summary.scalar('progress', progress)

    real_images = networks.blend_images(real_images,
                                        progress,
                                        resolution_schedule,
                                        num_blocks=num_blocks)

    def _num_filters_fn(block_id):
        """Computes number of filters of block `block_id`."""
        return networks.num_filters(block_id, kwargs['fmap_base'],
                                    kwargs['fmap_decay'], kwargs['fmap_max'])

    def _generator_fn(z):
        """Builds generator network."""
        to_rgb_act = tf.tanh if kwargs['to_rgb_use_tanh_activation'] else None
        return networks.generator(z,
                                  progress,
                                  _num_filters_fn,
                                  resolution_schedule,
                                  num_blocks=num_blocks,
                                  kernel_size=kernel_size,
                                  colors=colors,
                                  to_rgb_activation=to_rgb_act)

    def _discriminator_fn(x):
        """Builds discriminator network."""
        return networks.discriminator(x,
                                      progress,
                                      _num_filters_fn,
                                      resolution_schedule,
                                      num_blocks=num_blocks,
                                      kernel_size=kernel_size)

    ########## Define model.
    z = make_latent_vectors(batch_size, **kwargs)

    gan_model = tfgan.gan_model(
        generator_fn=lambda z: _generator_fn(z)[0],
        discriminator_fn=lambda x, unused_z: _discriminator_fn(x)[0],
        real_data=real_images,
        generator_inputs=z)

    ########## Define loss.
    gan_loss = define_loss(gan_model, **kwargs)

    ########## Define train ops.
    gan_train_ops, optimizer_var_list = define_train_ops(
        gan_model, gan_loss, **kwargs)
    gan_train_ops = gan_train_ops._replace(
        global_step_inc_op=current_image_id_inc_op)

    ########## Generator smoothing.
    generator_ema = tf.train.ExponentialMovingAverage(decay=0.999)
    gan_train_ops, generator_vars_to_restore = add_generator_smoothing_ops(
        generator_ema, gan_model, gan_train_ops)

    class Model(object):
        pass

    model = Model()
    model.stage_id = stage_id
    model.batch_size = batch_size
    model.resolution_schedule = resolution_schedule
    model.num_images = num_images
    model.num_blocks = num_blocks
    model.current_image_id = current_image_id
    model.progress = progress
    model.num_filters_fn = _num_filters_fn
    model.generator_fn = _generator_fn
    model.discriminator_fn = _discriminator_fn
    model.gan_model = gan_model
    model.gan_loss = gan_loss
    model.gan_train_ops = gan_train_ops
    model.optimizer_var_list = optimizer_var_list
    model.generator_ema = generator_ema
    model.generator_vars_to_restore = generator_vars_to_restore
    return model
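
A hypothetical call sketch for build_model, passing only the kwargs documented in its docstring; the values are illustrative placeholders rather than recommended hyperparameters, and real_images_batch stands in for an NHWC batch from a real input pipeline:

model = build_model(
    stage_id=0,
    batch_size=8,
    real_images=real_images_batch,  # placeholder name for an NHWC image batch
    start_height=4,
    start_width=4,
    scale_base=2,
    num_resolutions=4,
    stable_stage_num_images=100000,
    transition_stage_num_images=100000,
    total_num_images=800000,
    kernel_size=3,
    colors=3,
    to_rgb_use_tanh_activation=False,
    fmap_base=4096,
    fmap_decay=1.0,
    fmap_max=256,
    latent_vector_size=128,
    gradient_penalty_weight=10.0,
    gradient_penalty_target=1.0,
    real_score_penalty_weight=0.001,
    adam_beta1=0.0,
    adam_beta2=0.99,
    generator_learning_rate=0.001,
    discriminator_learning_rate=0.001)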