예제 #1
0
파일: ops_test.py 프로젝트: sts-sadr/gan-2
    def test_batch_norm_shapes(self):
        """Checks that ops.BatchNorm returns a tensor of the input's shape.

        Minimal smoke test: a random 4-D tensor of shape [10, 32, 32, 3] is
        passed through a freshly built BatchNorm layer, and the output's
        static shape is compared with the input shape.
        """
        expected_shape = [10, 32, 32, 3]
        layer = ops.BatchNorm()
        normalized = layer(tf.random.normal(expected_shape))
        self.assertEqual(expected_shape, normalized.shape.as_list())
예제 #2
0
def generator(zs, target_class, gf_dim, num_classes, training=True):
    """Builds the generator part of the graph, mapping z -> G(z).

    Args:
      zs: Tensor of latent variables.
      target_class: The class to sample from; forwarded to every `block`.
      gf_dim: Base channel width; stage widths are multiples of it.
      num_classes: Number of classes in the labels.
      training: Whether the graph is built in training mode. It is forwarded
        to the layers and is expected to affect behavior such as batch and
        spectral normalization.

    Returns:
      A pair `(out, var_list)`:
        - out: the tanh-activated output of the final convolution.
        - var_list: the trainable variables registered under the generator's
          variable scope.
    """
    with tf.compat.v1.variable_scope(
            'generator', reuse=tf.compat.v1.AUTO_REUSE) as scope:
        # Project the latent vector linearly, then fold it into a 4x4 map
        # with gf_dim * 16 channels.
        net = ops.snlinear(zs, gf_dim * 16 * 4 * 4, training=training,
                           name='g_snh0')
        net = tf.reshape(net, [-1, 4, 4, gf_dim * 16])

        # (output channels, scope name) for each stage, in build order. Each
        # stage is annotated as doubling the spatial size: 8, 16, 32, 64, 128.
        stages = [
            (gf_dim * 16, 'g_block1'),
            (gf_dim * 8, 'g_block2'),
            (gf_dim * 4, 'g_block3'),
            (gf_dim * 2, 'g_block4'),
            (gf_dim, 'g_block5'),
        ]
        for channels, block_name in stages:
            net = block(net, target_class, channels, num_classes, block_name,
                        training)
            if block_name == 'g_block3':
                # The non-local block is inserted after the third stage
                # (annotated as the 32x32 resolution).
                net = ops.sn_non_local_block_sim(net, training, name='g_ops')

        # Final normalization and ReLU, a 3-output-channel convolution, and
        # tanh.
        net = tf.nn.relu(ops.BatchNorm(name='g_bn')(net))
        net = ops.snconv2d(net, 3, 3, 3, 1, 1, training, 'g_snconv_last')
        out = tf.nn.tanh(net)

    # Collect every trainable variable created under the generator scope.
    trainable = tf.compat.v1.get_collection(
        tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope.name)
    return out, trainable