Example #1
File: train.py  Project: sts-sadr/gan-2
def _discriminator_fn(x):
  """Builds discriminator network."""
  return networks.discriminator(x,
                                progress,
                                _num_filters_fn,
                                resolution_schedule,
                                num_blocks=num_blocks,
                                kernel_size=kernel_size)
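
In train.py, progress, _num_filters_fn, resolution_schedule, num_blocks, and kernel_size are free variables captured from the enclosing scope, which lets _discriminator_fn keep the single-argument signature its caller expects. A minimal sketch of what that enclosing context might look like; the import path, filter schedule, and stage lengths below are illustrative assumptions, not the project's actual values:

# Hypothetical enclosing context for _discriminator_fn (values are assumed).
import tensorflow as tf

from tensorflow_gan.examples.progressive_gan import networks  # assumed path


def _num_filters_fn(block_id):
  # Assumed schedule: many filters at low resolution, fewer as blocks grow.
  return max(512 // (2 ** (block_id - 1)), 32)


resolution_schedule = networks.ResolutionSchedule(
    start_resolutions=(4, 4), scale_base=2, num_resolutions=4)
num_blocks = 4   # assumed: one block per resolution in the schedule
kernel_size = 3  # assumed: 3x3 convolutions throughout

# progress advances as training consumes images; the stage lengths here are
# placeholders, not the project's real settings.
progress = networks.compute_progress(
    tf.compat.v1.train.get_or_create_global_step(),
    stable_stage_num_images=10000,
    transition_stage_num_images=10000,
    num_blocks=num_blocks)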
Example #2
  def test_discriminator_grad_norm_progress(self):
    if tf.executing_eagerly():
      # tf.placeholder() is not compatible with eager execution.
      return
    stable_stage_num_images = 2
    transition_stage_num_images = 3
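    # Each stage pair spans stable + transition = 5 images, so the 15 image
    # ids fed below walk progress through all three blocks.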

    current_image_id_ph = tf.compat.v1.placeholder(tf.int32, [])
    progress = networks.compute_progress(
        current_image_id_ph,
        stable_stage_num_images,
        transition_stage_num_images,
        num_blocks=3)
    x = tf.random.normal([2, 16, 16, 3])
    logits, _ = networks.discriminator(
        x, progress, _num_filters_stub,
        networks.ResolutionSchedule(
            start_resolutions=(4, 4), scale_base=2, num_resolutions=3))
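    # With start_resolutions=(4, 4), scale_base=2, and num_resolutions=3, the
    # final resolution is 4 * 2**(3 - 1) = 16, matching x's 16x16 spatial size.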
    fake_loss = tf.reduce_sum(input_tensor=tf.square(logits))
    grad_norms = [
        _get_grad_norm(
            fake_loss,
            tf.compat.v1.trainable_variables('.*/progressive_gan_block_1/.*')),
        _get_grad_norm(
            fake_loss,
            tf.compat.v1.trainable_variables('.*/progressive_gan_block_2/.*')),
        _get_grad_norm(
            fake_loss,
            tf.compat.v1.trainable_variables('.*/progressive_gan_block_3/.*'))
    ]

    grad_norms_output = None
    with self.cached_session(use_gpu=True) as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      grad_norms_output = np.array([
          sess.run(grad_norms, feed_dict={current_image_id_ph: i})
          for i in range(15)  # total num of images
      ])
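    # grad_norms_output has shape [15, 3]: one gradient norm per image id and
    # per progressive block.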

    # The gradient of block_1 is always on.
    self.assertEqual(
        np.argmax(grad_norms_output[:, 0] > 0), 0,
        'gradient norms {} for block 1 are not always on'.format(
            grad_norms_output[:, 0]))
    # The gradient of block_2 is on after 1 stable stage.
    self.assertEqual(
        np.argmax(grad_norms_output[:, 1] > 0), 3,
        'gradient norms {} for block 2 are not on at step 3'.format(
            grad_norms_output[:, 1]))
    # The gradient of block_3 is on after 2 stable stages + 1 transition stage.
    self.assertEqual(
        np.argmax(grad_norms_output[:, 2] > 0), 8,
        'gradient norms {} for block 3 are not on at step 8'.format(
            grad_norms_output[:, 2]))
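
Example #2 relies on two module-level helpers that sit outside this excerpt, _num_filters_stub and _get_grad_norm, plus the usual imports. A minimal sketch consistent with how the test calls them (a constant per-block filter count, and the 2-norm of the gradients of ys with respect to xs):

import numpy as np
import tensorflow as tf

from tensorflow_gan.examples.progressive_gan import networks  # assumed path


def _num_filters_stub(block_id):
  # Any fixed filter count works for this gradient-flow check.
  del block_id  # unused
  return 32


def _get_grad_norm(ys, xs):
  # 2-norm of d(ys)/d(xs), summed over every variable in xs.
  return tf.sqrt(
      tf.add_n([
          tf.reduce_sum(input_tensor=tf.square(g))
          for g in tf.gradients(ys=ys, xs=xs)
      ]))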