Example #1
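These snippets assume the standard TF-GAN example imports. A plausible set is shown below; the `ops` module path is an assumption about the source tree (it is the local file that defines `snconv2d`, `snlinear`, and `sn_non_local_block_sim`):

import tensorflow as tf
import tensorflow_gan as tfgan
# Assumed location of the spectral-norm helpers used throughout these examples.
from tensorflow_gan.examples.self_attention_estimator import ops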
def block(x, labels, out_channels, num_classes, name, training=True):
    """Builds the residual blocks used in the generator.

  Args:
    x: The 4D input tensor.
    labels: The labels of the class we seek to sample from.
    out_channels: Integer number of features in the output layer.
    num_classes: Integer number of classes in the labels.
    name: The variable scope name for the block.
    training: Whether this block is for training or not.
  Returns:
    A `Tensor` representing the output of the operation.
  """
    with tf.compat.v1.variable_scope(name):
        labels_onehot = tf.one_hot(labels, num_classes)
        x_0 = x
        x = tf.nn.relu(
            tfgan.tpu.batch_norm(x, training, labels_onehot, name='cbn_0'))
        x = usample(x)
        x = ops.snconv2d(x, out_channels, 3, 3, 1, 1, training, 'snconv1')
        x = tf.nn.relu(
            tfgan.tpu.batch_norm(x, training, labels_onehot, name='cbn_1'))
        x = ops.snconv2d(x, out_channels, 3, 3, 1, 1, training, 'snconv2')

        x_0 = usample(x_0)
        x_0 = ops.snconv2d(x_0, out_channels, 1, 1, 1, 1, training, 'snconv3')

        return x_0 + x
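The block above also relies on a `usample` helper that is not shown. A minimal sketch, assuming nearest-neighbor 2x upsampling as in SAGAN/BigGAN-style generators (the exact resize method is an assumption):

def usample(x):
    """Doubles the spatial size of a 4D NHWC tensor (sketch, nearest-neighbor)."""
    _, height, width, _ = x.shape.as_list()
    return tf.image.resize(x, [height * 2, width * 2], method='nearest')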
Example #2
def block(x, out_channels, name, downsample=True, act=tf.nn.relu):
    """Builds the residual blocks used in the discriminator.

  Args:
    x: The 4D input vector.
    out_channels: Number of features in the output layer.
    name: The variable scope name for the block.
    downsample: If True, downsample the spatial size the input tensor by
                a factor of 2 on each side. If False, the spatial size of the
                input tensor is unchanged.
    act: The activation function used in the block.
  Returns:
    A `Tensor` representing the output of the operation.
  """
    with tf.compat.v1.variable_scope(name):
        input_channels = x.shape.as_list()[-1]
        x_0 = x
        x = act(x)
        x = ops.snconv2d(x, out_channels, 3, 3, 1, 1, name='sn_conv1')
        x = act(x)
        x = ops.snconv2d(x, out_channels, 3, 3, 1, 1, name='sn_conv2')
        if downsample:
            x = dsample(x)
        if downsample or input_channels != out_channels:
            x_0 = ops.snconv2d(x_0, out_channels, 1, 1, 1, 1, name='sn_conv3')
            if downsample:
                x_0 = dsample(x_0)
        return x_0 + x
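This discriminator block calls a `dsample` helper that is likewise not shown. A minimal sketch, assuming 2x2 average pooling with stride 2 (a common choice in SN-GAN discriminators):

def dsample(x):
    """Halves the spatial size of a 4D NHWC tensor (sketch, average pooling)."""
    return tf.nn.avg_pool2d(x, ksize=2, strides=2, padding='VALID')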
Example #3
def biggan_block(x, y, out_channels, num_classes, name, training=True):
    """Builds the residual blocks used in the generator.
  ...
  """
    with tf.compat.v1.variable_scope(name):
        x_0 = x
        x = tf.nn.relu(conditional_batch_norm(x, y, training, name='cbn_0'))
        x = usample(x)
        x = ops.snconv2d(x, out_channels, 3, 3, 1, 1, training, 'snconv1')
        x = tf.nn.relu(conditional_batch_norm(x, y, training, name='cbn_1'))
        x = ops.snconv2d(x, out_channels, 3, 3, 1, 1, training, 'snconv2')

        x_0 = usample(x_0)
        x_0 = ops.snconv2d(x_0, out_channels, 1, 1, 1, 1, training, 'snconv3')

        return x_0 + x
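Example #3 differs from Example #1 only in feeding an embedding `y` to `conditional_batch_norm` instead of a one-hot label. As a rough illustration of what such a layer does, here is a hypothetical sketch that predicts a per-sample scale and offset from `y` with linear layers; it assumes `ops.linear` is a plain fully connected layer and omits the moving-average bookkeeping a real implementation needs for inference:

def conditional_batch_norm(x, y, training, name='cbn'):
    # Hypothetical sketch. `training` would gate moving-average updates
    # in a full implementation; it is unused here.
    with tf.compat.v1.variable_scope(name):
        channels = x.shape.as_list()[-1]
        gamma = ops.linear(y, channels, name='gamma')  # per-sample scale
        beta = ops.linear(y, channels, name='beta')  # per-sample offset
        mean, variance = tf.nn.moments(x, axes=[0, 1, 2], keepdims=True)
        x = (x - mean) * tf.math.rsqrt(variance + 1e-5)
        gamma = tf.reshape(1.0 + gamma, [-1, 1, 1, channels])
        beta = tf.reshape(beta, [-1, 1, 1, channels])
        return x * gamma + beta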
Example #4
def biggan_generator_128(z, target_class, gf_dim, num_classes, training=True):
    """...
  
  TODO(ilyak): Fix, this does not work as is.
  
  y is embedded, and skip concatenate with z
  
  4th block has attention (64x64 resolution)
  
  batch norm is conditional batch norm
  no layer_norm
  """
    # settable hyperparameters
    embed_y_dim = 128
    embed_bias = False
    with tf.compat.v1.variable_scope(
            'generator', reuse=tf.compat.v1.AUTO_REUSE) as gen_scope:
        num_blocks = 5
        # embedding of y that is shared
        target_class_onehot = tf.one_hot(target_class, num_classes)
        y = ops.linear(target_class_onehot,
                       embed_y_dim,
                       use_bias=embed_bias,
                       name="embed_y")
        y_per_block = num_blocks * [y]  # Note: overwritten by the skip-z list below.
        # skip z connections / hierarchical z
        z_per_block = tf.split(z, num_blocks + 1, axis=1)
        z0, z_per_block = z_per_block[0], z_per_block[1:]
        y_per_block = [tf.concat([zi, y], 1) for zi in z_per_block]

        act0 = ops.snlinear(z0,
                            gf_dim * 16 * 4 * 4,
                            training=training,
                            name='g_snh0')
        act0 = tf.reshape(act0, [-1, 4, 4, gf_dim * 16])

        # pylint: disable=line-too-long
        act1 = biggan_block(act0, y_per_block[0], gf_dim * 16, num_classes,
                            'g_block1', training)  # 8
        act2 = biggan_block(act1, y_per_block[1], gf_dim * 8, num_classes,
                            'g_block2', training)  # 16
        act3 = biggan_block(act2, y_per_block[2], gf_dim * 4, num_classes,
                            'g_block3', training)  # 32
        act4 = biggan_block(act3, y_per_block[3], gf_dim * 2, num_classes,
                            'g_block4', training)  # 64
        act4 = ops.sn_non_local_block_sim(act4, training, name='g_ops')  # 64
        act5 = biggan_block(act4, y_per_block[4], gf_dim, num_classes,
                            'g_block5', training)  # 128
        act5 = tf.nn.relu(
            tfgan.tpu.batch_norm(act5,
                                 training,
                                 conditional_class_labels=None,
                                 name='g_bn'))
        act6 = ops.snconv2d(act5, 3, 3, 3, 1, 1, training, 'g_snconv_last')
        out = (tf.nn.tanh(act6) + 1.0) / 2.0
    var_list = tf.compat.v1.get_collection(
        tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, gen_scope.name)
    return out, var_list
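Note the constraint implied by `tf.split` above: the latent `z` must divide evenly into `num_blocks + 1 = 6` chunks along axis 1. A hypothetical call (dimensions chosen for illustration only, and subject to the TODO flagging this function as broken):

batch_size = 8  # illustrative
z = tf.random.normal([batch_size, 120])  # 120 = 6 chunks of 20
labels = tf.random.uniform([batch_size], maxval=1000, dtype=tf.int32)
images, g_vars = biggan_generator_128(z, labels, gf_dim=64, num_classes=1000)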
Example #5
    def test_snconv2d_shapes(self):
        """Tests the spectrally normalized 2d conv function.

    This is a minimal test to make sure that shapes are OK.
    The image shape should match after snconv is applied.
    """
        if tf.executing_eagerly():
            # `compute_spectral_norm` doesn't work when executing eagerly.
            return
        image = tf.random.normal([10, 32, 32, 3])
        snconv_image = ops.snconv2d(image, 3, k_h=3, k_w=3, d_h=1, d_w=1)
        self.assertEqual([10, 32, 32, 3], snconv_image.shape.as_list())
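The `compute_spectral_norm` mentioned in the comment is typically a power-iteration estimate of the weight matrix's largest singular value, which spectral normalization then divides the weight by. A minimal standalone sketch of that idea, not the library's exact implementation (a real version persists `u` across training steps rather than sampling it fresh):

def power_iteration_sigma(w, num_iters=1):
    """Estimates the largest singular value of `w` by power iteration (sketch)."""
    w_mat = tf.reshape(w, [-1, w.shape.as_list()[-1]])  # flatten to 2D
    u = tf.random.normal([1, w_mat.shape.as_list()[-1]])
    for _ in range(num_iters):
        v = tf.math.l2_normalize(tf.matmul(u, w_mat, transpose_b=True))
        u = tf.math.l2_normalize(tf.matmul(v, w_mat))
    return tf.squeeze(tf.matmul(tf.matmul(v, w_mat), u, transpose_b=True))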
Example #6
def optimized_block(x, out_channels, name, act=tf.nn.relu):
    """Builds optimized residual blocks for downsampling.

  Compared with block, optimized_block always downsamples the spatial resolution
  by a factor of 2 on each side.

  Args:
    x: The 4D input vector.
    out_channels: Number of features in the output layer.
    name: The variable scope name for the block.
    act: The activation function used in the block.
  Returns:
    A `Tensor` representing the output of the operation.
  """
    with tf.compat.v1.variable_scope(name):
        x_0 = x
        x = ops.snconv2d(x, out_channels, 3, 3, 1, 1, name='sn_conv1')
        x = act(x)
        x = ops.snconv2d(x, out_channels, 3, 3, 1, 1, name='sn_conv2')
        x = dsample(x)
        x_0 = dsample(x_0)
        x_0 = ops.snconv2d(x_0, out_channels, 1, 1, 1, 1, name='sn_conv3')
        return x + x_0
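In SN-GAN style discriminators the optimized variant is typically the first block, applied directly to the RGB input, with the regular `block` handling later stages. A hypothetical composition (`df_dim` and the resolutions are illustrative):

images = tf.random.normal([8, 128, 128, 3])  # illustrative input batch
df_dim = 64  # illustrative base filter count
h = optimized_block(images, df_dim, 'd_optimized_block1')  # 128x128 -> 64x64
h = block(h, df_dim * 2, 'd_block2')  # 64x64 -> 32x32
h = block(h, df_dim * 4, 'd_block3')  # 32x32 -> 16x16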
Example #7
def generator(zs, target_class, gf_dim, num_classes, training=True):
    """Builds the generator segment of the graph, going from z -> G(z).

  Args:
    zs: Tensor representing the latent variables.
    target_class: The class from which we seek to sample.
    gf_dim: The gf dimension.
    num_classes: Number of classes in the labels.
    training: Whether in train mode or not. This affects things like batch
      normalization and spectral normalization.

  Returns:
    - The output layer of the generator.
    - A list containing all trainable varaibles defined by the model.
  """
    with tf.compat.v1.variable_scope(
            'generator', reuse=tf.compat.v1.AUTO_REUSE) as gen_scope:

        act0 = ops.snlinear(zs,
                            gf_dim * 16 * 4 * 4,
                            training=training,
                            name='g_snh0')
        act0 = tf.reshape(act0, [-1, 4, 4, gf_dim * 16])

        # pylint: disable=line-too-long
        act1 = block(act0, target_class, gf_dim * 16, num_classes, 'g_block1',
                     training)  # 8
        act2 = block(act1, target_class, gf_dim * 8, num_classes, 'g_block2',
                     training)  # 16
        act3 = block(act2, target_class, gf_dim * 4, num_classes, 'g_block3',
                     training)  # 32
        act3 = ops.sn_non_local_block_sim(act3, training, name='g_ops')  # 32
        act4 = block(act3, target_class, gf_dim * 2, num_classes, 'g_block4',
                     training)  # 64
        act5 = block(act4, target_class, gf_dim, num_classes, 'g_block5',
                     training)  # 128
        act5 = tf.nn.relu(
            tfgan.tpu.batch_norm(act5,
                                 training,
                                 conditional_class_labels=None,
                                 name='g_bn'))
        act6 = ops.snconv2d(act5, 3, 3, 3, 1, 1, training, 'g_snconv_last')
        out = tf.nn.tanh(act6)
    var_list = tf.compat.v1.get_collection(
        tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, gen_scope.name)
    return out, var_list
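A hypothetical call to this generator (shapes are illustrative; the final tanh puts the output in [-1, 1] at 128x128 resolution):

batch_size, num_classes = 8, 1000  # illustrative
zs = tf.random.normal([batch_size, 128])
target_class = tf.random.uniform([batch_size], maxval=num_classes, dtype=tf.int32)
images, g_vars = generator(zs, target_class, gf_dim=64, num_classes=num_classes)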