def test_batch_norm_shapes(self):
  """Checks that `ops.BatchNorm` keeps the shape of its input unchanged.

  This is a minimal smoke test. It compares only the static output shape
  against the input shape, and does not inspect the normalized values.
  """
  input_shape = [10, 32, 32, 3]
  norm_layer = ops.BatchNorm()
  normalized = norm_layer(tf.random.normal(input_shape))
  self.assertEqual(input_shape, normalized.shape.as_list())
def generator(zs, target_class, gf_dim, num_classes, training=True):
  """Builds the generator segment of the graph, going from z -> G(z).

  The pipeline runs in this order:

  - `zs` is projected with a spectrally-normalized linear layer and reshaped
    to a `[-1, 4, 4, gf_dim * 16]` feature map.
  - The map passes through five `block` calls conditioned on `target_class`,
    with `ops.sn_non_local_block_sim` applied after the third block.
  - A final batch norm + ReLU is applied, then `ops.snconv2d` and a tanh.

  All variables are created under the 'generator' variable scope with
  AUTO_REUSE, so repeated calls share weights.

  Args:
    zs: Tensor representing the latent variables.
    target_class: The class from which we seek to sample.
    gf_dim: The gf dimension (base channel multiplier for the generator).
    num_classes: Number of classes in the labels.
    training: Whether in train mode or not. This affects things like batch
      normalization and spectral normalization.

  Returns:
    A tuple `(out, var_list)`:
      - out: The output layer of the generator (the tanh of the last conv).
      - var_list: A list containing all trainable variables collected under
        the generator's variable scope.
  """
  with tf.compat.v1.variable_scope(
      'generator', reuse=tf.compat.v1.AUTO_REUSE) as gen_scope:
    act0 = ops.snlinear(
        zs, gf_dim * 16 * 4 * 4, training=training, name='g_snh0')
    act0 = tf.reshape(act0, [-1, 4, 4, gf_dim * 16])

    # The trailing numbers are the original author's spatial-size annotations.
    act1 = block(act0, target_class, gf_dim * 16, num_classes, 'g_block1',
                 training)  # 8
    act2 = block(act1, target_class, gf_dim * 8, num_classes, 'g_block2',
                 training)  # 16
    act3 = block(act2, target_class, gf_dim * 4, num_classes, 'g_block3',
                 training)  # 32
    act3 = ops.sn_non_local_block_sim(act3, training, name='g_ops')  # 32
    act4 = block(act3, target_class, gf_dim * 2, num_classes, 'g_block4',
                 training)  # 64
    act5 = block(act4, target_class, gf_dim, num_classes, 'g_block5',
                 training)  # 128

    bn = ops.BatchNorm(name='g_bn')
    act5 = tf.nn.relu(bn(act5))
    # NOTE(review): the positional args to snconv2d presumably mean 3 output
    # channels, a 3x3 kernel and stride 1 -- confirm against ops.snconv2d.
    act6 = ops.snconv2d(act5, 3, 3, 3, 1, 1, training, 'g_snconv_last')
    out = tf.nn.tanh(act6)

  # gen_scope.name stays valid after the `with` exits, so the collection
  # lookup can run outside the scope.
  var_list = tf.compat.v1.get_collection(
      tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, gen_scope.name)
  return out, var_list