Beispiel #1
0
 def test_blend_images_in_stable_stage(self):
   """Checks `blend_images` at progress 0.0 on random 8x8 images.

   The blended result is expected to match the input downscaled by 2 and
   then upscaled by 2, within an absolute tolerance of 1e-6.
   """
   schedule = networks.ResolutionSchedule(scale_base=2, num_resolutions=2)
   images_np = np.random.normal(size=[2, 8, 8, 3])
   images = tf.constant(images_np, tf.float32)
   blended = networks.blend_images(
       images,
       progress=tf.constant(0.0),
       resolution_schedule=schedule,
       num_blocks=2)
   with self.cached_session(use_gpu=True) as sess:
     actual_np = sess.run(blended)
     # Reference value: one round trip through downscale and upscale.
     round_trip = layers.upscale(layers.downscale(images, 2), 2)
     expected_np = sess.run(round_trip)
   self.assertNDArrayNear(actual_np, expected_np, 1.0e-6)
Beispiel #2
0
def build_model(stage_id, batch_size, real_images, **kwargs):
    """Assembles the progressive GAN graph for one training stage.

    The global step is used as the running count of training images: an op
    that adds `batch_size` to it is created here and installed as the
    `global_step_inc_op` of the returned train ops. Real images are passed
    through `networks.blend_images` before being handed to the GAN model.

    Args:
      stage_id: An integer of training stage index.
      batch_size: Number of training images in each minibatch.
      real_images: A 4D `Tensor` of NHWC format.
      **kwargs: A dictionary of
          'start_height': An integer of start image height.
          'start_width': An integer of start image width.
          'scale_base': An integer of resolution multiplier.
          'num_resolutions': An integer of number of progressive resolutions.
          'stable_stage_num_images': An integer of number of training images
            in the stable stage.
          'transition_stage_num_images': An integer of number of training
            images in the transition stage.
          'total_num_images': An integer of total number of training images.
          'kernel_size': Convolution kernel size.
          'colors': Number of image channels.
          'to_rgb_use_tanh_activation': Whether to apply tanh activation when
            outputting rgb.
          'fmap_base': Base number of filters.
          'fmap_decay': Decay of number of filters.
          'fmap_max': Max number of filters.
          'latent_vector_size': An integer of latent vector size.
          'gradient_penalty_weight': A float of gradient penalty weight for
            wasserstein loss.
          'gradient_penalty_target': A float of gradient norm target for
            wasserstein loss.
          'real_score_penalty_weight': A float of additional penalty to keep
            the scores from drifting too far from zero.
          'adam_beta1': A float of Adam optimizer beta1.
          'adam_beta2': A float of Adam optimizer beta2.
          'generator_learning_rate': A float of generator learning rate.
          'discriminator_learning_rate': A float of discriminator learning
            rate.

    Returns:
      An internal object that wraps all information about the model. Its
      attributes are `stage_id`, `batch_size`, `resolution_schedule`,
      `num_images`, `num_blocks`, `current_image_id`, `progress`,
      `num_filters_fn`, `generator_fn`, `discriminator_fn`, `gan_model`,
      `gan_loss`, `gan_train_ops`, `optimizer_var_list`, `generator_ema` and
      `generator_vars_to_restore`.
    """
    kernel_size, colors = kwargs['kernel_size'], kwargs['colors']
    resolution_schedule = make_resolution_schedule(**kwargs)
    num_blocks, num_images = get_stage_info(stage_id, **kwargs)

    # The global step counts images seen so far, so it advances by one
    # minibatch worth of images at a time.
    image_counter = tf.compat.v1.train.get_or_create_global_step()
    image_counter_inc_op = image_counter.assign_add(batch_size)
    tf.compat.v1.summary.scalar('current_image_id', image_counter)

    progress = networks.compute_progress(
        image_counter,
        kwargs['stable_stage_num_images'],
        kwargs['transition_stage_num_images'],
        num_blocks)
    tf.compat.v1.summary.scalar('progress', progress)

    real_images = networks.blend_images(
        real_images, progress, resolution_schedule, num_blocks=num_blocks)

    def _num_filters_fn(block_id):
        """Returns the number of filters for block `block_id`."""
        return networks.num_filters(
            block_id,
            kwargs['fmap_base'],
            kwargs['fmap_decay'],
            kwargs['fmap_max'])

    def _generator_fn(z):
        """Builds the generator network on input `z`."""
        if kwargs['to_rgb_use_tanh_activation']:
            rgb_activation = tf.tanh
        else:
            rgb_activation = None
        return networks.generator(
            z,
            progress,
            _num_filters_fn,
            resolution_schedule,
            num_blocks=num_blocks,
            kernel_size=kernel_size,
            colors=colors,
            to_rgb_activation=rgb_activation)

    def _discriminator_fn(x):
        """Builds the discriminator network on input `x`."""
        return networks.discriminator(
            x,
            progress,
            _num_filters_fn,
            resolution_schedule,
            num_blocks=num_blocks,
            kernel_size=kernel_size)

    # ---- Model: only the first element of each network's output is used.
    latent_vectors = make_latent_vectors(batch_size, **kwargs)
    gan_model = tfgan.gan_model(
        generator_fn=lambda z: _generator_fn(z)[0],
        discriminator_fn=lambda x, unused_z: _discriminator_fn(x)[0],
        real_data=real_images,
        generator_inputs=latent_vectors)

    # ---- Loss.
    gan_loss = define_loss(gan_model, **kwargs)

    # ---- Train ops, with the image counter increment as the step op.
    train_ops, optimizer_var_list = define_train_ops(
        gan_model, gan_loss, **kwargs)
    train_ops = train_ops._replace(global_step_inc_op=image_counter_inc_op)

    # ---- Generator smoothing via an exponential moving average.
    generator_ema = tf.train.ExponentialMovingAverage(decay=0.999)
    train_ops, generator_vars_to_restore = add_generator_smoothing_ops(
        generator_ema, gan_model, train_ops)

    # Plain attribute container for everything built above.
    class Model(object):
        pass

    model = Model()
    attributes = (
        ('stage_id', stage_id),
        ('batch_size', batch_size),
        ('resolution_schedule', resolution_schedule),
        ('num_images', num_images),
        ('num_blocks', num_blocks),
        ('current_image_id', image_counter),
        ('progress', progress),
        ('num_filters_fn', _num_filters_fn),
        ('generator_fn', _generator_fn),
        ('discriminator_fn', _discriminator_fn),
        ('gan_model', gan_model),
        ('gan_loss', gan_loss),
        ('gan_train_ops', train_ops),
        ('optimizer_var_list', optimizer_var_list),
        ('generator_ema', generator_ema),
        ('generator_vars_to_restore', generator_vars_to_restore),
    )
    for attr_name, attr_value in attributes:
        setattr(model, attr_name, attr_value)
    return model