Beispiel #1
0
def discriminator_128(image, labels, df_dim, number_classes, act=tf.nn.relu):
    """Builds the discriminator graph for 128x128 inputs.

    A stack of residual blocks with a self-attention (non-local) block after
    the second block, followed by global sum pooling, a real/fake logit with a
    label-projection term, and an auxiliary classification head.

    Args:
      image: The current batch of images to classify as fake or real.
      labels: The corresponding labels for the images, or None. When None, the
        argmax of the classification head is used as pseudo-labels for the
        projection term.
      df_dim: The df dimension (base channel count).
      number_classes: The number of classes in the labels. Sizes both the
        label embedding and the classification head.
      act: The activation function used in the discriminator.

    Returns:
      - A `Tensor` with one real/fake logit per example (projection included).
      - A `Tensor` of class logits from the auxiliary classification head.
      - A list containing all trainable variables defined by the model.
    """
    with tf.compat.v1.variable_scope(
            'discriminator', reuse=tf.compat.v1.AUTO_REUSE) as dis_scope:
        h0 = optimized_block(image, df_dim, 'd_optimized_block1',
                             act=act)  # 64 * 64
        h1 = block(h0, df_dim * 2, 'd_block2', act=act)  # 32 * 32
        h1 = ops.sn_non_local_block_sim(h1, name='d_ops')  # 32 * 32
        h2 = block(h1, df_dim * 4, 'd_block3', act=act)  # 16 * 16
        h3 = block(h2, df_dim * 8, 'd_block4', act=act)  # 8 * 8
        h4 = block(h3, df_dim * 16, 'd_block5', act=act)  # 4 * 4
        h5 = block(h4, df_dim * 16, 'd_block6', downsample=False, act=act)
        h5_act = act(h5)
        # Sum-pool over axes 1 and 2 (the spatial axes, assuming NHWC).
        h6 = tf.reduce_sum(input_tensor=h5_act, axis=[1, 2])
        output = ops.snlinear(h6, 1, name='d_sn_linear')
        # Size the classification head from `number_classes` (not a global
        # flag) so its argmax is always a valid row of the label embedding.
        classification_output = ops.snlinear(h6,
                                             number_classes,
                                             name='d_sn_linear_class')
        if labels is None:
            embed_labels = tf.argmax(classification_output, axis=1)
        else:
            embed_labels = labels
        h_labels = ops.sn_embedding(embed_labels,
                                    number_classes,
                                    df_dim * 16,
                                    name='d_embedding')

        # Projection term: inner product of pooled features and label embedding.
        output += tf.reduce_sum(input_tensor=h6 * h_labels,
                                axis=1,
                                keepdims=True)

    var_list = tf.compat.v1.get_collection(
        tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, dis_scope.name)
    return output, classification_output, var_list
Beispiel #2
0
def biggan_discriminator_128(image,
                             labels,
                             df_dim,
                             number_classes,
                             act=tf.nn.relu):
    """Builds the BigGAN-style discriminator graph for 128x128 inputs.

    TODO(ilyak): debug, this implementation doesn't work as is.

    Same architecture as the SAGAN-style 128 discriminator; only the position
    of the non-local (self-attention) block changes: here it follows the
    first (optimized) block, at 64x64 resolution.

    Args:
      image: The current batch of images to classify as fake or real.
      labels: The corresponding labels for the images, or None. When None, the
        argmax of the classification head is used as pseudo-labels for the
        projection term.
      df_dim: The df dimension (base channel count).
      number_classes: The number of classes in the labels. Sizes both the
        label embedding and the classification head.
      act: The activation function used in the discriminator.

    Returns:
      - A `Tensor` with one real/fake logit per example (projection included).
      - A `Tensor` of class logits from the auxiliary classification head.
      - A list containing all trainable variables defined by the model.
    """
    with tf.compat.v1.variable_scope(
            'discriminator', reuse=tf.compat.v1.AUTO_REUSE) as dis_scope:
        h0 = optimized_block(image, df_dim, 'd_optimized_block1',
                             act=act)  # 64 * 64
        h0 = ops.sn_non_local_block_sim(h0, name='d_ops')  # 64 * 64
        h1 = block(h0, df_dim * 2, 'd_block2', act=act)  # 32 * 32
        h2 = block(h1, df_dim * 4, 'd_block3', act=act)  # 16 * 16
        h3 = block(h2, df_dim * 8, 'd_block4', act=act)  # 8 * 8
        h4 = block(h3, df_dim * 16, 'd_block5', act=act)  # 4 * 4
        h5 = block(h4, df_dim * 16, 'd_block6', downsample=False, act=act)
        h5_act = act(h5)
        # Sum-pool over axes 1 and 2 (the spatial axes, assuming NHWC).
        h6 = tf.reduce_sum(input_tensor=h5_act, axis=[1, 2])
        output = ops.snlinear(h6, 1, name='d_sn_linear')
        # Size the classification head from `number_classes` (not a global
        # flag) so its argmax is always a valid row of the label embedding.
        classification_output = ops.snlinear(h6,
                                             number_classes,
                                             name='d_sn_linear_class')
        if labels is None:
            embed_labels = tf.argmax(classification_output, axis=1)
        else:
            embed_labels = labels
        h_labels = ops.sn_embedding(embed_labels,
                                    number_classes,
                                    df_dim * 16,
                                    name='d_embedding')

        # Projection term: inner product of pooled features and label embedding.
        output += tf.reduce_sum(input_tensor=h6 * h_labels,
                                axis=1,
                                keepdims=True)

    var_list = tf.compat.v1.get_collection(
        tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, dis_scope.name)
    return output, classification_output, var_list
Beispiel #3
0
def biggan_generator_128(z, target_class, gf_dim, num_classes, training=True):
    """Builds a BigGAN-style generator graph producing 128x128 images.

    TODO(ilyak): Fix, this does not work as is.

    The class label is one-hot encoded and embedded by a bias-free linear
    layer. `z` is split into `num_blocks + 1` chunks (hierarchical / skip z):
    the first chunk feeds the initial 4x4 projection, and each remaining
    chunk is concatenated with the shared label embedding and passed to one
    residual block. The 4th block is followed by a self-attention
    (non-local) block at 64x64 resolution. Per the original design notes,
    the blocks use conditional batch norm and no layer norm.

    Args:
      z: Latent tensor. It is split into 6 equal chunks along axis 1, so its
        size along axis 1 must be divisible by 6.
      target_class: Integer class ids to generate (one-hot encoded here).
      gf_dim: The gf dimension (base channel count).
      num_classes: Number of classes (depth of the one-hot encoding).
      training: Whether in train mode or not.

    Returns:
      - The generated images, mapped into [0, 1] via (tanh(x) + 1) / 2.
      - A list containing all trainable variables defined by the model.
    """
    # setables
    embed_y_dim = 128
    embed_bias = False
    with tf.compat.v1.variable_scope(
            'generator', reuse=tf.compat.v1.AUTO_REUSE) as gen_scope:
        num_blocks = 5
        # Embedding of y that is shared by every block.
        target_class_onehot = tf.one_hot(target_class, num_classes)
        y = ops.linear(target_class_onehot,
                       embed_y_dim,
                       use_bias=embed_bias,
                       name="embed_y")
        # Skip z connections / hierarchical z: one z chunk per block plus one
        # for the initial projection.
        z_chunks = tf.split(z, num_blocks + 1, axis=1)
        z0, z_per_block = z_chunks[0], z_chunks[1:]
        y_per_block = [tf.concat([zi, y], 1) for zi in z_per_block]

        act0 = ops.snlinear(z0,
                            gf_dim * 16 * 4 * 4,
                            training=training,
                            name='g_snh0')
        act0 = tf.reshape(act0, [-1, 4, 4, gf_dim * 16])

        # pylint: disable=line-too-long
        act1 = biggan_block(act0, y_per_block[0], gf_dim * 16, num_classes,
                            'g_block1', training)  # 8
        act2 = biggan_block(act1, y_per_block[1], gf_dim * 8, num_classes,
                            'g_block2', training)  # 16
        act3 = biggan_block(act2, y_per_block[2], gf_dim * 4, num_classes,
                            'g_block3', training)  # 32
        act4 = biggan_block(act3, y_per_block[3], gf_dim * 2, num_classes,
                            'g_block4', training)  # 64
        act4 = ops.sn_non_local_block_sim(act4, training, name='g_ops')  # 64
        act5 = biggan_block(act4, y_per_block[4], gf_dim, num_classes,
                            'g_block5', training)  # 128
        act5 = tf.nn.relu(
            tfgan.tpu.batch_norm(act5,
                                 training,
                                 conditional_class_labels=None,
                                 name='g_bn'))
        act6 = ops.snconv2d(act5, 3, 3, 3, 1, 1, training, 'g_snconv_last')
        # NOTE(review): output range is [0, 1], not tanh's [-1, 1]; confirm
        # this matches what downstream consumers of the images expect.
        out = (tf.nn.tanh(act6) + 1.0) / 2.0
    var_list = tf.compat.v1.get_collection(
        tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, gen_scope.name)
    return out, var_list
Beispiel #4
0
    def test_snlinear_shapes(self):
        """Checks that the spectrally normalized linear layer keeps shapes.

        Minimal smoke test: projecting a [10, 32] input to 32 units must
        yield a [10, 32] output. Under eager execution the test returns
        early without asserting anything.
        """
        # `compute_spectral_norm` doesn't work when executing eagerly.
        if not tf.executing_eagerly():
            inputs = tf.random.normal([10, 32])
            projected = ops.snlinear(inputs, 32)
            self.assertEqual([10, 32], projected.shape.as_list())
Beispiel #5
0
def generator(zs, target_class, gf_dim, num_classes, training=True):
    """Builds the generator segment of the graph, mapping z -> G(z).

    A linear projection to a 4x4 feature map is followed by five
    class-conditioned residual blocks (4 -> 8 -> 16 -> 32 -> 64 -> 128), with
    a self-attention (non-local) block after the third one, then batch norm,
    ReLU, a final 3-channel convolution and tanh.

    Args:
      zs: Tensor representing the latent variables.
      target_class: The class from which we seek to sample.
      gf_dim: The gf dimension (base channel count).
      num_classes: Number of classes in the labels.
      training: Whether in train mode or not. This affects things like batch
        normalization and spectral normalization.

    Returns:
      - The output layer of the generator (tanh-activated).
      - A list containing all trainable variables defined by the model.
    """
    # (channel multiplier, block name, add self-attention afterwards)
    stages = (
        (16, 'g_block1', False),  # -> 8
        (8, 'g_block2', False),  # -> 16
        (4, 'g_block3', True),  # -> 32, then non-local block at 32
        (2, 'g_block4', False),  # -> 64
        (1, 'g_block5', False),  # -> 128
    )
    with tf.compat.v1.variable_scope(
            'generator', reuse=tf.compat.v1.AUTO_REUSE) as gen_scope:
        net = ops.snlinear(zs,
                           gf_dim * 16 * 4 * 4,
                           training=training,
                           name='g_snh0')
        net = tf.reshape(net, [-1, 4, 4, gf_dim * 16])

        for multiplier, block_name, add_attention in stages:
            net = block(net, target_class, gf_dim * multiplier, num_classes,
                        block_name, training)
            if add_attention:
                net = ops.sn_non_local_block_sim(net, training, name='g_ops')

        normed = tfgan.tpu.batch_norm(net,
                                      training,
                                      conditional_class_labels=None,
                                      name='g_bn')
        net = tf.nn.relu(normed)
        net = ops.snconv2d(net, 3, 3, 3, 1, 1, training, 'g_snconv_last')
        out = tf.nn.tanh(net)
    var_list = tf.compat.v1.get_collection(
        tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, gen_scope.name)
    return out, var_list
Beispiel #6
0
def conditional_batch_norm(inputs,
                           y,
                           is_training,
                           axis=-1,
                           variance_epsilon=1e-3,
                           center=True,
                           scale=True,
                           beta_initializer=tf.compat.v1.initializers.zeros(),
                           gamma_initializer=tf.compat.v1.initializers.ones(),
                           batch_axis=0,
                           name='batch_norm'):
    """Adds Conditional Batch Norm when the conditioning is not a class label.

    Taken from compare_gan arch_ops's conditional_batch_norm. The input is
    standardized with `tfgan.tpu.standardize_batch` and then modulated by a
    per-example, per-channel scale `gamma` and offset `beta`. Each is produced
    by a bias-free `ops.snlinear` projection of `y` to `num_channels` values.

    Args:
      inputs: Tensor of inputs (e.g. images). Channels are read from the last
        axis. When the static rank is unknown, a 4-D layout
        ([batch, H, W, channels]) is assumed for the modulation step.
      y: Conditioning tensor of rank 2 (e.g. [batch, features]); it need not
        be class labels / one-hot.
      is_training: Whether or not the layer is in training mode. It is
        forwarded to `tfgan.tpu.standardize_batch`, which is called here with
        decay=0.9, epsilon=1e-5 and use_moving_averages=False.
      axis: Currently unused; the channel axis is always the last one.
      variance_epsilon: Currently unused; epsilon is fixed at 1e-5.
      center: If True, add the conditional offset `beta`. If False, `beta` is
        not created.
      scale: If True, multiply by the conditional scale `gamma`. If False,
        `gamma` is not created.
      beta_initializer: Currently unused.
      gamma_initializer: Currently unused.
      batch_axis: Currently unused.
      name: String name to be used for variable scoping.

    Returns:
      Output tensor.

    Raises:
      ValueError: If `y` is None or `y` does not have rank 2.
    """
    if y is None:
        raise ValueError(
            "You must provide y for conditional batch normalization.")
    if y.shape.ndims != 2:
        raise ValueError("Conditioning must have rank 2.")
    with tf.compat.v1.variable_scope(name,
                                     values=[inputs],
                                     reuse=tf.compat.v1.AUTO_REUSE):
        outputs = tfgan.tpu.standardize_batch(inputs,
                                              is_training=is_training,
                                              decay=0.9,
                                              epsilon=1e-5,
                                              use_moving_averages=False)
        num_channels = tf.compat.dimension_value(inputs.shape[-1])
        # Broadcastable shape [batch, 1, ..., 1, channels] so every example
        # gets its own per-channel gamma/beta. Rank 4 gives [-1, 1, 1, C]
        # (the historical behavior); an unknown static rank also falls back
        # to that 4-D layout.
        rank = inputs.shape.ndims
        if rank is None:
            rank = 4
        param_shape = [-1] + [1] * (rank - 2) + [num_channels]
        with tf.compat.v1.variable_scope("condition",
                                         values=[inputs, y],
                                         reuse=tf.compat.v1.AUTO_REUSE):
            if scale:
                gamma = ops.snlinear(y,
                                     num_channels,
                                     name="gamma",
                                     use_bias=False)
                gamma = tf.reshape(gamma, param_shape)
                outputs *= gamma
            if center:
                beta = ops.snlinear(y,
                                    num_channels,
                                    name="beta",
                                    use_bias=False)
                beta = tf.reshape(beta, param_shape)
                outputs += beta
            return outputs