Пример #1
0
  def test_generator_grad_norm_progress(self):
    """Checks at which image id each generator block starts getting gradient.

    Builds a 3-block generator whose progress is driven by an int32
    placeholder holding the current image id. It then evaluates, for image
    ids 0..14, the gradient norm of each block's trainable variables with
    respect to a simple sum-of-squares loss on the generator output.
    """
    if tf.executing_eagerly():
      # tf.placeholder() is not compatible with eager execution. Skip
      # explicitly so the runner reports "skipped" instead of "passed".
      self.skipTest('tf.placeholder() is not compatible with eager execution.')
    stable_stage_num_images = 2
    transition_stage_num_images = 3
    num_images = 15  # total num of images

    current_image_id_ph = tf.compat.v1.placeholder(tf.int32, [])
    progress = networks.compute_progress(
        current_image_id_ph,
        stable_stage_num_images,
        transition_stage_num_images,
        num_blocks=3)
    z = tf.random.normal([2, 10], dtype=tf.float32)
    x, _ = networks.generator(
        z, progress, _num_filters_stub,
        networks.ResolutionSchedule(
            start_resolutions=(4, 4), scale_base=2, num_resolutions=3))
    fake_loss = tf.reduce_sum(input_tensor=tf.square(x))
    # One gradient norm per block, selecting each block's variables by scope.
    grad_norms = [
        _get_grad_norm(
            fake_loss,
            tf.compat.v1.trainable_variables(
                '.*/progressive_gan_block_{}/.*'.format(block_id)))
        for block_id in (1, 2, 3)
    ]

    with self.cached_session(use_gpu=True) as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      # The placeholder is int32, so feed integer image ids (fractional
      # values would only be cast down to these integers anyway).
      x1_np = sess.run(x, feed_dict={current_image_id_ph: 0})
      x2_np = sess.run(x, feed_dict={current_image_id_ph: 1})
      # Row i holds the three per-block gradient norms at image id i.
      grad_norms_output = np.array([
          sess.run(grad_norms, feed_dict={current_image_id_ph: i})
          for i in range(num_images)
      ])

    self.assertEqual((2, 16, 16, 3), x1_np.shape)
    self.assertEqual((2, 16, 16, 3), x2_np.shape)
    # The gradient of block_1 is always on. Check every step explicitly:
    # argmax() of an all-False mask is also 0, so an argmax-based check
    # would pass even if block_1 never received any gradient.
    self.assertTrue(
        np.all(grad_norms_output[:, 0] > 0),
        'gradient norms {} for block 1 is not always on'.format(
            grad_norms_output[:, 0]))
    # The gradient of block_2 is on after 1 stable stage. argmax() returns
    # the first True index, so this also checks it is off before step 3.
    self.assertEqual(
        np.argmax(grad_norms_output[:, 1] > 0), 3,
        'gradient norms {} for block 2 is not on at step 3'.format(
            grad_norms_output[:, 1]))
    # The gradient of block_3 is on after 2 stable stages + 1 transition stage.
    self.assertEqual(
        np.argmax(grad_norms_output[:, 2] > 0), 8,
        'gradient norms {} for block 3 is not on at step 8'.format(
            grad_norms_output[:, 2]))
Пример #2
0
    def test_discriminator_grad_norm_progress(self):
        """Checks at which image id each discriminator block gets gradient.

        Builds a 3-block discriminator whose progress is driven by an int32
        placeholder holding the current image id. It then evaluates, for
        image ids 0..14, the gradient norm of each block's trainable
        variables with respect to a sum-of-squares loss on the logits.
        """
        if tf.executing_eagerly():
            # tf.placeholder() is not compatible with eager execution. Skip
            # explicitly so the runner reports "skipped" instead of failing.
            self.skipTest(
                'tf.placeholder() is not compatible with eager execution.')
        stable_stage_num_images = 2
        transition_stage_num_images = 3
        num_images = 15  # total num of images

        current_image_id_ph = tf.compat.v1.placeholder(tf.int32, [])
        progress = networks.compute_progress(current_image_id_ph,
                                             stable_stage_num_images,
                                             transition_stage_num_images,
                                             num_blocks=3)
        x = tf.random.normal([2, 16, 16, 3])
        logits, _ = networks.discriminator(
            x, progress, _num_filters_stub,
            networks.ResolutionSchedule(start_resolutions=(4, 4),
                                        scale_base=2,
                                        num_resolutions=3))
        fake_loss = tf.reduce_sum(input_tensor=tf.square(logits))
        # One gradient norm per block, selecting each block's variables by
        # scope.
        grad_norms = [
            _get_grad_norm(
                fake_loss,
                tf.compat.v1.trainable_variables(
                    '.*/progressive_gan_block_{}/.*'.format(block_id)))
            for block_id in (1, 2, 3)
        ]

        with self.cached_session(use_gpu=True) as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            # Row i holds the three per-block gradient norms at image id i.
            grad_norms_output = np.array([
                sess.run(grad_norms, feed_dict={current_image_id_ph: i})
                for i in range(num_images)
            ])

        # The gradient of block_1 is always on. Check every step explicitly:
        # argmax() of an all-False mask is also 0, so an argmax-based check
        # would pass even if block_1 never received any gradient.
        self.assertTrue(
            np.all(grad_norms_output[:, 0] > 0),
            'gradient norms {} for block 1 is not always on'.format(
                grad_norms_output[:, 0]))
        # The gradient of block_2 is on after 1 stable stage. argmax()
        # returns the first True index, so this also checks it is off before.
        self.assertEqual(
            np.argmax(grad_norms_output[:, 1] > 0), 3,
            'gradient norms {} for block 2 is not on at step 3'.format(
                grad_norms_output[:, 1]))
        # The gradient of block_3 is on after 2 stable stages + 1 transition
        # stage.
        self.assertEqual(
            np.argmax(grad_norms_output[:, 2] > 0), 8,
            'gradient norms {} for block 3 is not on at step 8'.format(
                grad_norms_output[:, 2]))
Пример #3
0
 def test_blend_images_in_stable_stage(self):
   """At progress 0, blend_images equals a downscale-by-2 / upscale-by-2 of x."""
   images_np = np.random.normal(size=[2, 8, 8, 3])
   images = tf.constant(images_np, tf.float32)
   zero_progress = tf.constant(0.0)
   schedule = networks.ResolutionSchedule(scale_base=2, num_resolutions=2)
   blended = networks.blend_images(
       images,
       progress=zero_progress,
       resolution_schedule=schedule,
       num_blocks=2)
   with self.cached_session(use_gpu=True) as sess:
     blended_np = sess.run(blended)
     # Reference value: the same images downscaled by 2, then upscaled by 2.
     round_trip = layers.upscale(layers.downscale(images, 2), 2)
     expected_np = sess.run(round_trip)
   self.assertNDArrayNear(blended_np, expected_np, 1.0e-6)
Пример #4
0
 def test_resolution_schedule_correct(self):
   """Checks ResolutionSchedule attributes and scale_factor() for 3 levels."""
   schedule = networks.ResolutionSchedule(
       start_resolutions=[5, 3], scale_base=2, num_resolutions=3)
   # A list input must come back as a tuple; final = start * 2**(3 - 1).
   expected_attributes = (
       ('start_resolutions', (5, 3)),
       ('scale_base', 2),
       ('num_resolutions', 3),
       ('final_resolutions', (20, 12)),
   )
   for attr_name, expected_value in expected_attributes:
     self.assertEqual(getattr(schedule, attr_name), expected_value)
   # Valid arguments here are 1..num_resolutions; the factor halves each step.
   for level, expected_factor in ((1, 4), (2, 2), (3, 1)):
     self.assertEqual(schedule.scale_factor(level), expected_factor)
   # Values just outside that range must be rejected.
   for out_of_range in (0, 4):
     with self.assertRaises(ValueError):
       schedule.scale_factor(out_of_range)
Пример #5
0
def make_resolution_schedule(**kwargs):
    """Build a `networks.ResolutionSchedule` from flat keyword arguments.

    Args:
      **kwargs: Must contain `start_height`, `start_width`, `scale_base` and
        `num_resolutions`; any other keys are ignored.

    Returns:
      A `networks.ResolutionSchedule` whose `start_resolutions` is the pair
      `(start_height, start_width)`.

    Raises:
      KeyError: If one of the required keys is missing.
    """
    height, width = kwargs['start_height'], kwargs['start_width']
    return networks.ResolutionSchedule(
        start_resolutions=(height, width),
        scale_base=kwargs['scale_base'],
        num_resolutions=kwargs['num_resolutions'])