def conditional_discriminator(img, conditioning, is_training=True):
  """Scores CIFAR images as real or generated, conditioned on class labels.

  The DCGAN discriminator is used as a feature extractor; its last
  convolutional activations are flattened, combined with the one-hot labels,
  and projected to a single logit by a linear layer.

  Args:
    img: A Tensor of shape [batch size, width, height, channels], holding
      either real or generated images. The discriminator's goal is to tell
      the two apart.
    conditioning: A 2-tuple of Tensors `(noise, one_hot_labels)`. Only the
      one-hot labels are used by this function.
    is_training: If `True`, batch norm uses batch statistics. If `False`,
      batch norm uses the exponential moving average collected from
      population statistics.

  Returns:
    The logits Tensor produced by a single-unit linear layer applied to the
    label-conditioned features. Values can lie in [-inf, inf], with positive
    values indicating high confidence that the images are real.
  """
  # Only the intermediate activations are needed here: the logits of the
  # unconditioned network are discarded and recomputed after conditioning.
  _, end_points = dcgan.discriminator(
      img, is_training=is_training, fused_batch_norm=True)

  # Condition the last convolution layer on the labels.
  _noise, one_hot_labels = conditioning
  last_conv = _last_conv_layer(end_points)
  flat_features = tf.contrib.layers.flatten(last_conv)
  conditioned = tfgan.features.condition_tensor_from_onehot(
      flat_features, one_hot_labels)

  return tf.contrib.layers.linear(conditioned, 1)
def test_discriminator_invalid_input(self):
  """Checks that `dcgan.discriminator` raises ValueError on bad image inputs."""
  # Rank-3 input instead of the rank-4 image batches used elsewhere.
  rank3_img = tf.zeros([5, 32, 32])
  with self.assertRaises(ValueError):
    dcgan.discriminator(rank3_img)

  # One image dimension is statically unknown.
  partially_unknown_img = tf.placeholder(tf.float32, [5, 32, None, 3])
  with self.assertRaises(ValueError):
    dcgan.discriminator(partially_unknown_img)

  # Cases that must also report a specific error message.
  message_cases = (
      ([5, 32, 16, 3], 'not have equal width and height'),
      ([5, 30, 30, 3], 'not a power of 2'),
  )
  for bad_shape, expected_message in message_cases:
    bad_img = tf.zeros(bad_shape)
    with self.assertRaisesRegexp(ValueError, expected_message):
      dcgan.discriminator(bad_img)
def test_discriminator_graph(self):
  """Builds the discriminator for several image and batch sizes.

  For every configuration the test verifies the output shape, the set of
  end-point names, and the channel count of each convolutional end point.
  """
  base_depth = 32
  # Image width is 2**i for i in 1..5, paired with batch sizes 3..7.
  for i, batch_size in zip(xrange(1, 6), xrange(3, 8)):
    tf.reset_default_graph()
    side = 2**i
    image = tf.random_uniform([batch_size, side, side, 3], -1, 1)
    output, end_points = dcgan.discriminator(image, depth=base_depth)

    self.assertAllEqual([batch_size, 1], output.get_shape().as_list())

    # One 'conv<j>' end point per power of two in the image width, plus the
    # final 'logits' entry.
    conv_names = ['conv%i' % j for j in xrange(1, i + 1)]
    self.assertSetEqual(
        set(conv_names + ['logits']), set(end_points.keys()))

    # Check layer depths: the channel count doubles with each conv layer,
    # starting from `base_depth` at conv1.
    for exponent, conv_name in enumerate(conv_names):
      channels = end_points[conv_name].get_shape().as_list()[-1]
      self.assertEqual(base_depth * 2**exponent, channels)
def _encoder(img_batch, is_training=True, bits=64, depth=64):
  """Maps images to internal representation.

  The DCGAN discriminator (built under the 'Encoder' scope) is used as a
  convolutional feature extractor. Its last convolutional layer is projected
  to `bits` channels with a 1x1 convolution and squashed with softsign.

  Args:
    img_batch: Image Tensor passed directly to `dcgan.discriminator`.
    is_training: Forwarded to `dcgan.discriminator` as `is_training`.
    bits: Number of bits per patch. Used as the number of output channels of
      the final 1x1 convolution.
    depth: Forwarded to `dcgan.discriminator` as `depth`.

  Returns:
    Real-valued 2D Tensor of size [batch_size, bits].
  """
  _, end_points = dcgan.discriminator(
      img_batch, depth=depth, is_training=is_training, scope='Encoder')

  # TODO(joelshor): Make the DCGAN convolutional layer that converts to logits
  # not trainable, since it doesn't affect the encoder output.

  # Get the pre-logit layer, which is the last conv.
  features = _last_conv_layer(end_points)

  # Transform the features to the proper number of bits. No normalizer and no
  # activation are applied, so this is a plain linear 1x1 projection.
  with tf.variable_scope('EncoderTransformer'):
    projected = tf.contrib.layers.conv2d(
        features,
        bits,
        kernel_size=1,
        stride=1,
        padding='VALID',
        normalizer_fn=None,
        activation_fn=None)

  # Drop axes 1 and 2 and verify that a rank-2 tensor remains.
  squeezed = tf.squeeze(projected, [1, 2])
  squeezed.shape.assert_has_rank(2)

  # Map encoded to the range [-1, 1].
  return tf.nn.softsign(squeezed)
def discriminator(img, unused_conditioning, is_training=True):
  """Unconditional discriminator for CIFAR images.

  Thin wrapper around `dcgan.discriminator` (with fused batch norm enabled)
  that returns only the logits and drops the end points.

  Args:
    img: A Tensor of shape [batch size, width, height, channels], holding
      either real or generated images. The discriminator's goal is to tell
      the two apart.
    unused_conditioning: Ignored. The TFGAN API passes extra `condition`
      information to both generator and discriminator for conditional GANs;
      this example is not conditional, so the argument is accepted only to
      satisfy that calling convention.
    is_training: If `True`, batch norm uses batch statistics. If `False`,
      batch norm uses the exponential moving average collected from
      population statistics.

  Returns:
    The logits Tensor returned by `dcgan.discriminator`. Values can lie in
    [-inf, inf], with positive values indicating high confidence that the
    images are real.
  """
  logits, _unused_end_points = dcgan.discriminator(
      img, is_training=is_training, fused_batch_norm=True)
  return logits
def test_discriminator_run(self):
  """Smoke test: the discriminator output can be evaluated in a session."""
  random_images = tf.random_uniform([5, 32, 32, 3], -1, 1)
  logits, _ = dcgan.discriminator(random_images)

  with self.test_session() as sess:
    # Variables must be initialized before the output can be computed.
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    sess.run(logits)