Ejemplo n.º 1
0
    def test_four_layers_wrong_paddig(self):
        """A non-integer padding (1.5) must be rejected with a TypeError."""
        # NOTE(review): "paddig" is a typo for "padding"; the method name is
        # kept unchanged so existing test IDs and test filters still match.
        n_images, side = 2, 256
        image_shape = (n_images, side, side, 3)
        images = tf.ones(image_shape)
        with self.assertRaises(TypeError):
            discriminator.pix2pix_discriminator(
                images,
                num_filters=[64, 128, 256, 512],
                padding=1.5,
            )
Ejemplo n.º 2
0
    def test_four_layers_negative_padding(self):
        """A negative padding (-1) must be rejected.

        The expected error type depends on the execution mode: under eager
        execution this test expects ``tf.errors.InvalidArgumentError``,
        otherwise (graph mode) it expects ``ValueError``.
        """
        n_images, side = 2, 256
        images = tf.ones((n_images, side, side, 3))
        expected_error = (tf.errors.InvalidArgumentError
                          if tf.executing_eagerly() else ValueError)
        with self.assertRaises(expected_error):
            discriminator.pix2pix_discriminator(
                images,
                num_filters=[64, 128, 256, 512],
                padding=-1,
            )
Ejemplo n.º 3
0
def discriminator(image_batch, unused_conditioning=None):
    """Adapt the Pix2Pix discriminator to the calling convention TF-GAN uses.

    Args:
      image_batch: Images forwarded to ``d_module.pix2pix_discriminator``.
      unused_conditioning: Ignored. It is accepted so the function can be
        called with a conditioning argument alongside the image batch.

    Returns:
      The discriminator's 4D patch logits, flattened to 2D with
      ``tf.compat.v1.layers.flatten``.
    """
    patch_logits, _end_points = d_module.pix2pix_discriminator(
        image_batch, num_filters=[64, 128, 256, 512])
    # Pix2Pix produces rank-4 logits; fail loudly if that ever changes,
    # because TF-GAN consumes the flattened 2D form.
    patch_logits.shape.assert_has_rank(4)
    return tf.compat.v1.layers.flatten(patch_logits)
Ejemplo n.º 4
0
    def test_four_layers_no_padding(self):
        """With padding=0, logits and predictions have the predicted shape.

        The expected spatial size comes from applying
        ``self._layer_output_size`` five times with ``pad=0``. The first three
        applications use the helper's default stride and the last two use
        ``stride=1``.
        """
        n_images, side = 2, 256

        # Extra keyword arguments per layer; an empty dict keeps the helper's
        # default stride.
        per_layer_overrides = ({}, {}, {}, {'stride': 1}, {'stride': 1})
        out_side = side
        for overrides in per_layer_overrides:
            out_side = self._layer_output_size(out_side, pad=0, **overrides)

        images = tf.ones((n_images, side, side, 3))
        logits, end_points = discriminator.pix2pix_discriminator(
            images, num_filters=[64, 128, 256, 512], padding=0)

        expected_shape = [n_images, out_side, out_side, 1]
        self.assertListEqual(expected_shape, logits.shape.as_list())
        predictions = end_points['predictions']
        self.assertListEqual(expected_shape, predictions.shape.as_list())