Beispiel #1
0
 def test_discriminator_invalid_input(self):
     """Checks that the discriminator rejects a rank-3 input tensor.

     The input built here has shape [28, 28, 3]. The test passes only if
     building the discriminator on it raises ValueError or
     tf.errors.InvalidArgumentError. Any other exception, or no exception at
     all, makes the test fail or error out.
     """
     # TF1 raises ValueError, while TF2 raises tf.errors.InvalidArgumentError,
     # so either is accepted.
     with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)):
         networks.discriminator(tf.zeros([28, 28, 3]))
Beispiel #2
0
    def test_discriminator_graph(self):
        """Builds the discriminator graph for several input configurations.

        Each configuration pairs a batch size with a square spatial size; the
        channel depth is fixed at 3. For every configuration the default graph
        is reset first. The static output shape must then have rank 2 and a
        leading dimension equal to the batch size.
        """
        cases = ((3, 70), (6, 128))  # (batch_size, patch_size) pairs
        for batch_size, patch_size in cases:
            # Start each configuration from a fresh default graph.
            tf.reset_default_graph()
            input_shape = [batch_size, patch_size, patch_size, 3]
            output = networks.discriminator(tf.ones(input_shape))

            output_shape = output.shape
            self.assertEqual(2, output_shape.ndims)
            self.assertEqual(batch_size, output_shape.as_list()[0])
Beispiel #3
0
 def test_discriminator_run(self):
     """Builds the discriminator on an all-zero batch and evaluates it.

     The input has shape [3, 70, 70, 3]. Variables are initialized and the
     output is fetched in a cached test session. The fetched value must be
     rank 2 with a leading dimension equal to the batch size.
     """
     batch_size = 3
     img_batch = tf.zeros([batch_size, 70, 70, 3])
     disc_output = networks.discriminator(img_batch)
     with self.cached_session() as sess:
         sess.run(tf.global_variables_initializer())
         disc_value = sess.run(disc_output)
     # Check the concrete evaluated value, not just that evaluation succeeds;
     # the runtime shape must match the expected [batch, ...] rank-2 layout.
     self.assertEqual(2, disc_value.ndim)
     self.assertEqual(batch_size, disc_value.shape[0])