Beispiel #1
0
def discriminator(image, labels, df_dim, number_classes, act=tf.nn.relu):
    """Constructs the discriminator graph and returns its logits.

    Args:
      image: Batch of images to be scored as fake or real.
      labels: Labels that belong to the images in the batch.
      df_dim: The df dimension; the blocks receive multiples of it.
      number_classes: How many classes the labels can take.
      act: Activation function handed to every block and applied to the
        final feature map before pooling.

    Returns:
      - A `Tensor` holding the logits of the discriminator.
      - A list with all trainable variables collected under the
        discriminator variable scope.
    """
    # (multiple of df_dim, scope name) for the blocks that follow the
    # non-local block; the trailing comments are the original author's
    # resolution notes.
    later_blocks = (
        (4, 'd_block3'),  # 16 * 16
        (8, 'd_block4'),  # 8 * 8
        (16, 'd_block5'),  # 4 * 4
    )
    with tf.variable_scope('discriminator', reuse=tf.AUTO_REUSE) as dis_scope:
        net = optimized_block(
            image, df_dim, 'd_optimized_block1', act=act)  # 64 * 64
        net = block(net, df_dim * 2, 'd_block2', act=act)  # 32 * 32
        net = ops.sn_non_local_block_sim(net, name='d_ops')  # 32 * 32
        for multiplier, block_name in later_blocks:
            net = block(net, df_dim * multiplier, block_name, act=act)
        net = block(net, df_dim * 16, 'd_block6', downsample=False, act=act)

        # Sum over axes 1 and 2 of the activated feature map.
        pooled = tf.reduce_sum(input_tensor=act(net), axis=[1, 2])
        logits = ops.snlinear(pooled, 1, name='d_sn_linear')

        # Label-conditional term: inner product of the pooled features with
        # the label embedding, added onto the unconditional logits.
        label_embedding = ops.sn_embedding(
            labels, number_classes, df_dim * 16, name='d_embedding')
        projection = tf.reduce_sum(
            input_tensor=pooled * label_embedding, axis=1, keepdims=True)
        logits = logits + projection

    trainable_vars = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, dis_scope.name)
    return logits, trainable_vars
Beispiel #2
0
 def test_sn_non_local_block_sim_shapes(self):
     """Checks that `sn_non_local_block_sim` keeps the input's static shape."""
     # `compute_spectral_norm` doesn't work when executing eagerly, so the
     # check is only performed outside of eager execution.
     if not tf.executing_eagerly():
         features = tf.random.normal([10, 8, 8, 64])
         result = ops.sn_non_local_block_sim(features, name='test_sa')
         self.assertEqual([10, 8, 8, 64], result.shape.as_list())
Beispiel #3
0
def biggan_generator_128(z, target_class, gf_dim, num_classes, training=True):
    """Builds a BigGAN-style generator graph, going from z -> G(z).

    TODO(ilyak): Fix, this does not work as is.

    Design notes:
      * y (the class) is embedded once by a shared linear layer and
        concatenated with a chunk of z for every block (skip-z connections).
      * The 4th block is followed by the non-local (attention) block
        (64x64 resolution according to the original notes).
      * Batch norm is conditional batch norm; no layer_norm. The final
        batch norm before the last convolution is called without class
        labels.

    Args:
      z: Tensor of latent variables. It is split along axis 1 into
        `num_blocks + 1` (= 6) pieces with `tf.split`, so that dimension
        presumably has to be divisible by 6 -- TODO confirm.
      target_class: The class from which to sample; it is one-hot encoded
        with depth `num_classes`.
      gf_dim: The gf dimension; the blocks receive multiples of it.
      num_classes: Number of classes in the labels.
      training: Forwarded to the spectral-norm layers, the blocks and the
        final batch norm.

    Returns:
      - The generator output, `(tanh(x) + 1) / 2`, i.e. rescaled from the
        tanh range to [0, 1].
      - A list containing all trainable variables collected under the
        generator variable scope.
    """
    # Settables.
    embed_y_dim = 128
    embed_bias = False
    num_blocks = 5
    with tf.compat.v1.variable_scope(
            'generator', reuse=tf.compat.v1.AUTO_REUSE) as gen_scope:
        # Embedding of y that is shared by all blocks.
        target_class_onehot = tf.one_hot(target_class, num_classes)
        y = ops.linear(target_class_onehot,
                       embed_y_dim,
                       use_bias=embed_bias,
                       name="embed_y")
        # Skip z connections / hierarchical z: the first chunk seeds the
        # network, each remaining chunk conditions one block together with y.
        z_chunks = tf.split(z, num_blocks + 1, axis=1)
        z0, z_per_block = z_chunks[0], z_chunks[1:]
        y_per_block = [tf.concat([zi, y], 1) for zi in z_per_block]

        act0 = ops.snlinear(z0,
                            gf_dim * 16 * 4 * 4,
                            training=training,
                            name='g_snh0')
        act0 = tf.reshape(act0, [-1, 4, 4, gf_dim * 16])

        # Trailing comments are the original author's resolution notes.
        # pylint: disable=line-too-long
        act1 = biggan_block(act0, y_per_block[0], gf_dim * 16, num_classes,
                            'g_block1', training)  # 8
        act2 = biggan_block(act1, y_per_block[1], gf_dim * 8, num_classes,
                            'g_block2', training)  # 16
        act3 = biggan_block(act2, y_per_block[2], gf_dim * 4, num_classes,
                            'g_block3', training)  # 32
        act4 = biggan_block(act3, y_per_block[3], gf_dim * 2, num_classes,
                            'g_block4', training)  # 64
        act4 = ops.sn_non_local_block_sim(act4, training, name='g_ops')  # 64
        act5 = biggan_block(act4, y_per_block[4], gf_dim, num_classes,
                            'g_block5', training)  # 128
        act5 = tf.nn.relu(
            tfgan.tpu.batch_norm(act5,
                                 training,
                                 conditional_class_labels=None,
                                 name='g_bn'))
        act6 = ops.snconv2d(act5, 3, 3, 3, 1, 1, training, 'g_snconv_last')
        # Map the tanh output from [-1, 1] to [0, 1].
        out = (tf.nn.tanh(act6) + 1.0) / 2.0
    var_list = tf.compat.v1.get_collection(
        tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, gen_scope.name)
    return out, var_list
Beispiel #4
0
def generator(zs, target_class, gf_dim, num_classes, training=True):
    """Constructs the generator part of the graph, mapping z to G(z).

    Args:
      zs: Tensor holding the latent variables.
      target_class: The class to sample from.
      gf_dim: The gf dimension; the blocks receive multiples of it.
      num_classes: Number of classes in the labels.
      training: Whether the graph is built in train mode. This affects
        things like batch normalization and spectral normalization.

    Returns:
      - The output layer of the generator (after tanh).
      - A list with all trainable variables collected under the generator
        variable scope.
    """
    # (multiple of gf_dim, scope name) for the blocks placed before the
    # non-local block; trailing comments are the original resolution notes.
    early_blocks = (
        (16, 'g_block1'),  # 8
        (8, 'g_block2'),  # 16
        (4, 'g_block3'),  # 32
    )
    with tf.compat.v1.variable_scope(
            'generator', reuse=tf.compat.v1.AUTO_REUSE) as gen_scope:
        net = ops.snlinear(
            zs, gf_dim * 16 * 4 * 4, training=training, name='g_snh0')
        net = tf.reshape(net, [-1, 4, 4, gf_dim * 16])

        for multiplier, block_name in early_blocks:
            net = block(net, target_class, gf_dim * multiplier, num_classes,
                        block_name, training)
        net = ops.sn_non_local_block_sim(net, training, name='g_ops')  # 32
        net = block(
            net, target_class, gf_dim * 2, num_classes, 'g_block4',
            training)  # 64
        net = block(
            net, target_class, gf_dim, num_classes, 'g_block5',
            training)  # 128

        # Final batch norm is called without class labels.
        normalized = tfgan.tpu.batch_norm(
            net, training, conditional_class_labels=None, name='g_bn')
        last_conv = ops.snconv2d(
            tf.nn.relu(normalized), 3, 3, 3, 1, 1, training, 'g_snconv_last')
        out = tf.nn.tanh(last_conv)

    trainable_vars = tf.compat.v1.get_collection(
        tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, gen_scope.name)
    return out, trainable_vars
Beispiel #5
0
def biggan_discriminator_128(image,
                             labels,
                             df_dim,
                             number_classes,
                             act=tf.nn.relu):
    """Builds the discriminator graph with an extra classification head.

    TODO(ilyak): debug, this implementation doesn't work as is.

    Compared with the plain discriminator, only the position of the non
    local block changes (here it directly follows the first block), and a
    second linear head produces class logits. When `labels` is None, the
    argmax of those class logits is used as pseudo-labels for the label
    embedding.

    Args:
      image: The current batch of images to classify as fake or real.
      labels: The corresponding labels for the images, or None to use the
        pseudo-labels predicted by the classification head.
      df_dim: The df dimension; the blocks receive multiples of it.
      number_classes: The number of classes in the labels. It sizes both
        the classification head and the label embedding, so pseudo-labels
        stay within the range passed to `ops.sn_embedding`.
      act: The activation function used in the discriminator.

    Returns:
      - A `Tensor` representing the logits of the discriminator.
      - A `Tensor` with the class logits of the classification head.
      - A list containing all trainable variables collected under the
        discriminator variable scope.
    """
    with tf.compat.v1.variable_scope(
            'discriminator', reuse=tf.compat.v1.AUTO_REUSE) as dis_scope:
        # Trailing comments are the original author's resolution notes.
        h0 = optimized_block(image, df_dim, 'd_optimized_block1',
                             act=act)  # 64 * 64
        h0 = ops.sn_non_local_block_sim(h0, name='d_ops')  # 64 * 64
        h1 = block(h0, df_dim * 2, 'd_block2', act=act)  # 32 * 32
        h2 = block(h1, df_dim * 4, 'd_block3', act=act)  # 16 * 16
        h3 = block(h2, df_dim * 8, 'd_block4', act=act)  # 8 * 8
        h4 = block(h3, df_dim * 16, 'd_block5', act=act)  # 4 * 4
        h5 = block(h4, df_dim * 16, 'd_block6', downsample=False, act=act)
        h5_act = act(h5)
        # Sum over axes 1 and 2 of the activated feature map.
        h6 = tf.reduce_sum(input_tensor=h5_act, axis=[1, 2])
        output = ops.snlinear(h6, 1, name='d_sn_linear')
        # Use the `number_classes` argument rather than a global flag so the
        # head's width always agrees with the embedding below.
        classification_output = ops.snlinear(h6,
                                             number_classes,
                                             name='d_sn_linear_class')

        # Without ground-truth labels, fall back to the head's prediction.
        if labels is None:
            embedding_labels = tf.argmax(classification_output, axis=1)
        else:
            embedding_labels = labels
        h_labels = ops.sn_embedding(embedding_labels,
                                    number_classes,
                                    df_dim * 16,
                                    name='d_embedding')

        # Label-conditional term added onto the unconditional logits.
        output += tf.reduce_sum(input_tensor=h6 * h_labels,
                                axis=1,
                                keepdims=True)

    var_list = tf.compat.v1.get_collection(
        tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, dis_scope.name)
    return output, classification_output, var_list