Example #1
def generator_three_layer_block(input_layer,
                                out_channels,
                                do_pixel_norm=False,
                                conditional_layer=None,
                                unet_end_points=None):
    # Upsample
    ret = pggan_utils.resize_twice_as_big(input_layer)
    # Concat extra layers.
    ret = pggan_utils.maybe_concat_conditional_layer(ret, conditional_layer)
    ret = pggan_utils.maybe_concat_unet_layer(ret, unet_end_points)
    # Conv
    conv2d_out = ret
    conv2d_out = pggan_utils.maybe_pixel_norm(
        pggan_utils.maybe_equalized_conv2d(
            conv2d_out,
            out_channels,
        ),
        do_pixel_norm=do_pixel_norm)
    conv2d_out = pggan_utils.maybe_pixel_norm(
        pggan_utils.maybe_equalized_conv2d(
            conv2d_out,
            out_channels,
        ),
        do_pixel_norm=do_pixel_norm)
    ret = pggan_utils.maybe_resblock(ret, out_channels, conv2d_out)
    return ret
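
The upsampling helper pggan_utils.resize_twice_as_big is not shown in these examples. A minimal sketch of what such a helper typically does in a PGGAN-style generator, assuming nearest-neighbor resizing and the TF1 image API:

import tensorflow as tf

def resize_twice_as_big_sketch(x):
    # Double the spatial resolution with nearest-neighbor upsampling.
    # Hypothetical stand-in; the repository helper may differ.
    height, width = int(x.shape[1]), int(x.shape[2])
    return tf.image.resize_nearest_neighbor(x, (2 * height, 2 * width))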
Example #2
def encoder_classification(source,
                           output_dim=4,
                           is_training=False,
                           arg_scope_fn=pggan_utils.pggan_generator_arg_scope,
                           prediction_scope_name='prediction',
                           **kwargs_unused):
    """Adds some last few layers of convolutions followed by a fully-connected layer. Outputs [batch, 1, 1, channels]."""
    end_points = {}
    net = source
    with tf.contrib.framework.arg_scope(arg_scope_fn(is_training=is_training)):
        before_fc_num_channels = FLAGS.pggan_max_num_channels
        with tf.variable_scope('before_fc_1x1x%d' % before_fc_num_channels):
            net = pggan_utils.maybe_equalized_conv2d(net,
                                                     before_fc_num_channels,
                                                     kernel_size=3,
                                                     padding='SAME')
            net = pggan_utils.maybe_equalized_conv2d(net,
                                                     before_fc_num_channels,
                                                     kernel_size=4,
                                                     padding='VALID')
            end_points['before_fc_1x1x%d' % before_fc_num_channels] = net

        with tf.variable_scope(prediction_scope_name, reuse=tf.AUTO_REUSE):
            weights_init_stddev = 1.0 if FLAGS.equalized_learning_rate else 0.02
            net = pggan_utils.maybe_equalized_fc(
                tf.squeeze(net, axis=(1, 2)),
                output_dim,
                activation_fn=None,
                weights_initializer=tf.random_normal_initializer(
                    0, weights_init_stddev),
                weights_regularizer=None,
            )
        end_points[prediction_scope_name] = net
        return net, end_points
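
In the block above, weights_init_stddev is 1.0 whenever FLAGS.equalized_learning_rate is set. That only makes sense together with runtime weight scaling: the equalized-learning-rate trick from the PGGAN paper initializes weights from N(0, 1) and rescales them by the He constant at graph-build time. A minimal sketch, assuming maybe_equalized_fc follows this scheme (the helper itself is not shown):

import numpy as np
import tensorflow as tf

def equalized_fc_sketch(x, output_dim):
    # Initialize from N(0, 1), then scale by sqrt(2 / fan_in) so every
    # layer trains at a comparable effective learning rate.
    fan_in = int(x.shape[-1])
    w = tf.get_variable('w', [fan_in, output_dim],
                        initializer=tf.random_normal_initializer(0.0, 1.0))
    b = tf.get_variable('b', [output_dim], initializer=tf.zeros_initializer())
    return tf.matmul(x, w * np.sqrt(2.0 / fan_in)) + b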
Example #3
def encoder_from_rgb_block(input_layer, out_channels, do_pixel_norm=False):
    conv2d_out = pggan_utils.maybe_pixel_norm(
        pggan_utils.maybe_equalized_conv2d(input_layer,
                                           out_channels,
                                           kernel_size=1),
        do_pixel_norm=do_pixel_norm)
    ret = pggan_utils.maybe_resblock(input_layer, out_channels, conv2d_out)
    return ret
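
maybe_pixel_norm appears throughout these examples; the underlying operation is the pixel normalization from the PGGAN paper, which rescales each pixel's feature vector to unit average magnitude across channels. A self-contained sketch of that operation:

import tensorflow as tf

def pixel_norm_sketch(x, epsilon=1e-8):
    # Divide each pixel's channel vector by its root-mean-square magnitude.
    return x * tf.rsqrt(
        tf.reduce_mean(tf.square(x), axis=-1, keepdims=True) + epsilon)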
Example #4
def discriminator_two_layer_block(input_layer, out_channels, maybe_gdrop_fn):
    input_shape = input_layer.shape
    input_channels = input_shape[3]
    conv2d_out = input_layer
    # The first layer's depth is the same as the input's.
    conv2d_out = pggan_utils.maybe_equalized_conv2d(maybe_gdrop_fn(conv2d_out),
                                                    input_channels,
                                                    is_discriminator=True)
    # The second layer's depth is the output channel count.
    conv2d_out = pggan_utils.maybe_equalized_conv2d(maybe_gdrop_fn(conv2d_out),
                                                    out_channels,
                                                    kernel_size=3,
                                                    is_discriminator=True)
    ret = pggan_utils.maybe_resblock(input_layer,
                                     out_channels,
                                     conv2d_out,
                                     is_discriminator=True)
    return ret
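
pggan_utils.maybe_resblock is not included in these examples. One common way to realize the optional residual connection it implies is a shortcut from the block input to the stacked-conv output, with a 1x1 projection when channel counts differ; this is a hypothetical sketch, not necessarily the repository's exact behavior:

import tensorflow as tf

def resblock_sketch(block_input, out_channels, conv_out):
    # Shortcut from the block input to the conv-stack output. A 1x1
    # convolution projects the input when its channel count differs.
    if int(block_input.shape[-1]) != out_channels:
        block_input = tf.layers.conv2d(block_input, out_channels,
                                       kernel_size=1, use_bias=False)
    return block_input + conv_out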
Example #5
def encoder_two_layer_block(input_layer, out_channels, do_pixel_norm=False):
    input_shape = input_layer.shape
    input_channels = input_shape[3]
    conv2d_out = input_layer
    # The first layer's depth is the same as the input's.
    conv2d_out = pggan_utils.maybe_pixel_norm(
        pggan_utils.maybe_equalized_conv2d(
            conv2d_out,
            input_channels,
        ),
        do_pixel_norm=do_pixel_norm)
    # The second layer's depth is the output channel count.
    conv2d_out = pggan_utils.maybe_pixel_norm(
        pggan_utils.maybe_equalized_conv2d(conv2d_out,
                                           out_channels,
                                           kernel_size=3),
        do_pixel_norm=do_pixel_norm)
    ret = pggan_utils.maybe_resblock(input_layer, out_channels, conv2d_out)
    return ret
Example #6
def discriminator_from_rgb_block(input_layer, out_channels):
    conv2d_out = input_layer
    conv2d_out = pggan_utils.maybe_equalized_conv2d(conv2d_out,
                                                    out_channels,
                                                    kernel_size=1,
                                                    is_discriminator=True)
    ret = pggan_utils.maybe_resblock(input_layer,
                                     out_channels,
                                     conv2d_out,
                                     is_discriminator=True)
    return ret
Example #7
def discriminator_before_fc(
    source,
    maybe_gdrop_fn=tf.identity,
    is_training=False,
    is_growing=False,
    alpha_grow=0.0,
    conditional_embed=None,
    do_self_attention=False,
    self_attention_hw=64,
    arg_scope_fn=pggan_utils.pggan_discriminator_arg_scope,
):
    """The main body of the PGGAN discriminator. Contains everything before the fully connected prediction layer.

  :param source: Input to discriminate.
  :param maybe_gdrop_fn: optional gdrop function. Default is do not apply gdrop.
  :param is_training: Affects update ops.
  :param is_growing: See PGGAN paper for details.
  :param conditional_embed: Optional conditional embedding of the input source. E.g. 'female, red hair, ...' etc.
  :param alpha_grow: See PGGAN paper for details.
  :param do_self_attention: See SAGAN paper for details.
  :param self_attention_hw: The height and width to start self attention at.
  :param arg_scope_fn: A tf.contrib.framework.arg_scope.
  :return: (A tensor of shape [batch, 1], end_points dictionary)
  """
    # Note: the discriminator does not use local response normalization (a.k.a. pixel norm in this code).
    max_num_channels = pggan_utils.get_discriminator_max_num_channels()
    max_stage = get_discriminator_max_stage(int(source.shape[1]))
    assert max_stage >= 0
    end_points = {}
    source_shrinked_from_rgb = None
    source_hw = int(source.shape[1])  # Assume same height and width.
    with tf.contrib.framework.arg_scope(arg_scope_fn(is_training=is_training)):
        # From RGB blocks
        if is_growing:
            source_shrinked_from_rgb = tf.nn.avg_pool(source, (1, 2, 2, 1),
                                                      (1, 2, 2, 1), 'VALID')
            scope_name = 'from_rgb_%dx%d' % (source_hw // 2, source_hw // 2)
            with tf.variable_scope(scope_name):
                rgb_out = pggan_utils.get_num_channels(
                    max_stage - 1, max_num_channels=max_num_channels)
                source_shrinked_from_rgb = discriminator_from_rgb_block(
                    source_shrinked_from_rgb, rgb_out)
            end_points[scope_name] = source_shrinked_from_rgb

        scope_name = 'from_rgb_%dx%d' % (source_hw, source_hw)
        with tf.variable_scope(scope_name):
            net = discriminator_from_rgb_block(
                source,
                pggan_utils.get_num_channels(
                    max_stage, max_num_channels=max_num_channels))
            end_points[scope_name] = net

        # Discriminator blocks.
        # These blocks go down to 8x8; the final 4x4 block is handled after the for loop.
        for stage in range(max_stage, 0, -1):
            num_channels = pggan_utils.get_num_channels(
                stage - 1, max_num_channels=max_num_channels)
            hw_div_by = 2**(max_stage - stage)
            current_hw = source_hw // hw_div_by

            # Self attention. Note that this may not get called if self_attention_hw == min_hw == 4.
            net = pggan_utils.maybe_add_self_attention(do_self_attention,
                                                       self_attention_hw,
                                                       current_hw,
                                                       num_channels, net,
                                                       end_points)
            # Conv.
            scope_name = 'encoder_block_%dx%dx%d' % (current_hw, current_hw,
                                                     num_channels)
            with tf.variable_scope(scope_name):
                net = discriminator_two_layer_block(
                    net, num_channels, maybe_gdrop_fn=maybe_gdrop_fn)
                end_points[scope_name] = net
            # Down-sample.
            current_hw //= 2
            scope_name = 'downsample_to_%dx%dx%d' % (current_hw, current_hw,
                                                     num_channels)
            with tf.variable_scope(scope_name):
                net = tf.nn.avg_pool(net, (1, 2, 2, 1), (1, 2, 2, 1), 'VALID')
                end_points[scope_name] = net
            # While growing, blend with the from-RGB pathway of the down-sampled image.
            if stage == max_stage and is_growing:
                assert source_shrinked_from_rgb is not None
                assert current_hw == source_hw // 2
                scope_name = 'encoder_block_interpolated_%dx%dx%d' % (
                    current_hw, current_hw, num_channels)
                with tf.variable_scope(scope_name):
                    net = net * alpha_grow + (
                        1 - alpha_grow) * source_shrinked_from_rgb
                    end_points[scope_name] = net

        # The final 4x4 block, which is a little bit different from all others.
        if conditional_embed is not None:
            net_h = int(net.shape[1])
            net_w = int(net.shape[2])
            repeated = tf.expand_dims(tf.expand_dims(conditional_embed,
                                                     axis=1),
                                      axis=2)
            repeated = util_misc.tf_repeat(repeated, [1, net_h, net_w, 1])
            net = tf.concat((net, repeated),
                            axis=-1,
                            name='concat_conditional_embed')

        with tf.variable_scope('before_fc_1x1x%d' % max_num_channels):
            # TODO: is this compatible with all normalization and GAN training methods (e.g. Dragan)?
            net = pggan_utils.minibatch_state_concat(net)
            net = pggan_utils.maybe_equalized_conv2d(maybe_gdrop_fn(net),
                                                     max_num_channels,
                                                     kernel_size=3,
                                                     is_discriminator=True,
                                                     padding='SAME')
            net = pggan_utils.maybe_equalized_conv2d(maybe_gdrop_fn(net),
                                                     max_num_channels,
                                                     kernel_size=4,
                                                     is_discriminator=True,
                                                     padding='VALID')
            end_points['before_fc_1x1x%d' % max_num_channels] = net

    end_points['before_fc'] = net
    return net, end_points
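
pggan_utils.minibatch_state_concat used above is presumably a variant of PGGAN's minibatch standard deviation layer, which lets the discriminator see batch-level statistics and penalize mode collapse. A minimal sketch of that technique, offered as a stand-in since the repository helper is not shown:

import tensorflow as tf

def minibatch_stddev_sketch(x):
    # Standard deviation of every feature over the batch, averaged into
    # one scalar and appended to each pixel as an extra feature map.
    mean = tf.reduce_mean(x, axis=0, keepdims=True)
    stddev = tf.sqrt(tf.reduce_mean(tf.square(x - mean), axis=0) + 1e-8)
    avg = tf.reduce_mean(stddev)
    shape = tf.stack([tf.shape(x)[0], int(x.shape[1]), int(x.shape[2]), 1])
    return tf.concat([x, avg * tf.ones(shape)], axis=-1)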
Example #8
def generator(
    source=None,
    dtype=tf.float32,
    is_training=False,
    is_growing=False,
    alpha_grow=0.0,
    target_shape=None,
    max_num_channels=None,
    arg_scope_fn=pggan_utils.pggan_generator_arg_scope,
    do_pixel_norm=False,
    do_self_attention=False,
    self_attention_hw=64,
    conditional_layer=None,
    unet_end_points=None,
):
    """PGGAN generator.

  :param source: An optional tensor specifying the embedding the generator is conditioned on.
  :param dtype: tf dtype.
  :param is_training: Affects batch norm and update ops.
  :param is_growing: See PGGAN paper for details.
  :param alpha_grow: See PGGAN paper for details.
  :param target_shape: Output shape of format [batch, height, width, channels].
  :param max_num_channels: Maximum number of channels for latent embeddings.
  :param arg_scope_fn: A function returning a tf.contrib.framework.arg_scope.
  :param do_pixel_norm: See PGGAN paper for details.
  :param do_self_attention: See SAGAN paper for details.
  :param self_attention_hw: The height and width to start self attention at.
  :param conditional_layer: Generates the image conditioned on this tensor.
  :param unet_end_points: Optional UNet end points to concatenate onto the corresponding generator layers.
  :return: (Tensor with shape target_shape containing generated images, end_points dictionary)
  """
    if max_num_channels is None:
        max_num_channels = FLAGS.pggan_max_num_channels
    max_stage = int(np.log2(int(
        target_shape[1]))) - 2  # hw=4->max_stage=0, hw=8->max_stage=1 ...
    assert max_stage >= 0
    end_points = {}
    # Get latent vector the generator is conditioned on.
    with tf.variable_scope('latent_vector'):
        if source is None:
            source = tf.random_normal(get_noise_shape(target_shape[0],
                                                      max_num_channels),
                                      dtype=dtype)
        if len(source.shape) == 2:
            source = tf.expand_dims(tf.expand_dims(source, 1), 2)
        assert len(source.shape) == 4, 'incorrect source shape for generator.'
        if source.shape[1] == 1 and source.shape[2] == 1:
            # Pads to 7x7 so that after the first conv layer with kernel size = 4, the size will be 4x4.
            source = tf.pad(source, paddings=((0, 0), (3, 3), (3, 3), (0, 0)))
    end_points['source'] = source
    net = source
    net_before_growth = None  # To be filled inside the for loop.

    with tf.contrib.framework.arg_scope(arg_scope_fn(is_training=is_training)):
        for stage in range(0, max_stage + 1):
            hw = 2**(stage + 2)
            output_channels = pggan_utils.get_num_channels(
                stage, max_num_channels)
            # 4x4 is a little different.
            if hw == 4:
                scope_name = 'block_%dx%dx%d' % (hw, hw, output_channels)
                with tf.variable_scope(scope_name):
                    if source.shape[1] == 7 and source.shape[2] == 7:
                        net = pggan_utils.maybe_pixel_norm(
                            pggan_utils.maybe_equalized_conv2d(
                                net,
                                output_channels,
                                kernel_size=4,
                                padding='VALID'),
                            do_pixel_norm=do_pixel_norm)
                    else:
                        # When the source is not random noise but is a tensor provided to the generator.
                        assert source.shape[1] == 4 and source.shape[2] == 4
                        net = pggan_utils.maybe_pixel_norm(
                            pggan_utils.maybe_equalized_conv2d(net,
                                                               output_channels,
                                                               kernel_size=3,
                                                               padding='SAME'),
                            do_pixel_norm=do_pixel_norm)
                    # Concatenate conditional layer before each block.
                    net = pggan_utils.maybe_concat_conditional_layer(
                        net, conditional_layer)
                    net = pggan_utils.maybe_pixel_norm(
                        pggan_utils.maybe_equalized_conv2d(
                            net, output_channels),
                        do_pixel_norm=do_pixel_norm)
                    assert net.shape[1] == net.shape[2] == hw
                    end_points[scope_name] = net
            else:
                # Output the RGB image at the previous resolution [hw/2, hw/2] for fade-in blending later.
                if stage == max_stage and is_growing:
                    scope_name = 'generator_to_rgb_%dx%d' % (hw // 2, hw // 2)
                    with tf.variable_scope(scope_name):
                        if FLAGS.use_larger_filter_at_rgb_layer:
                            kernel_size = min(7, hw // 2)
                        else:
                            kernel_size = 1
                        # No pixel norm in to_rgb layers.
                        net_before_growth = pggan_utils.maybe_equalized_conv2d(
                            net,
                            target_shape[-1],
                            kernel_size=kernel_size,
                            activation_fn=None)
                        net_before_growth = pggan_utils.resize_twice_as_big(
                            net_before_growth)
                        end_points[scope_name] = net_before_growth
                # Generator block.
                scope_name = 'block_%dx%dx%d' % (hw, hw, output_channels)
                with tf.variable_scope(scope_name):
                    net = generator_three_layer_block(
                        net,
                        output_channels,
                        do_pixel_norm=do_pixel_norm,
                        conditional_layer=conditional_layer,
                        unet_end_points=unet_end_points)
                    end_points[scope_name] = net

            # Adds self attention to the current layer if the `hw` matches `self_attention_hw`.
            net = pggan_utils.maybe_add_self_attention(do_self_attention,
                                                       self_attention_hw, hw,
                                                       output_channels, net,
                                                       end_points)

        scope_name = 'generator_to_rgb_%dx%d' % (hw, hw)
        with tf.variable_scope(scope_name):
            if FLAGS.use_larger_filter_at_rgb_layer:
                kernel_size = min(7, hw // 2)
            else:
                kernel_size = 1
            # No pixel norm in to_rgb layers.
            to_rgb_layer = pggan_utils.maybe_equalized_conv2d(
                net,
                target_shape[-1],
                kernel_size=kernel_size,
                activation_fn=None)
            if not is_growing:
                output_layer = to_rgb_layer
            else:
                assert net_before_growth is not None
                output_layer = to_rgb_layer * alpha_grow + (
                    1 - alpha_grow) * net_before_growth
                end_points['alpha_grow'] = alpha_grow

        end_points['output'] = output_layer
        if conditional_layer is not None:
            end_points['conditional_layer'] = conditional_layer
    return end_points['output'], end_points
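
While growing, the generator blends the new to_rgb output with the previous resolution's upsampled image through alpha_grow. The schedule driving alpha_grow lives outside these examples; a hypothetical linear fade-in, assuming a global step counter and a fixed number of fade-in steps:

import tensorflow as tf

def linear_alpha_grow(global_step, phase_start_step, fade_in_steps):
    # Ramp from 0.0 to 1.0 over fade_in_steps, then hold at 1.0.
    progress = tf.cast(global_step - phase_start_step,
                       tf.float32) / float(fade_in_steps)
    return tf.clip_by_value(progress, 0.0, 1.0)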