Example #1
0
    def test_generator_graph(self):
        """Check the generator's static output shape matches its input's.

        Each 3-channel input of ones is built in a freshly reset default
        graph, and the output's static shape must equal the input's.
        """
        spatial_shapes = ([4, 32, 32], [3, 128, 128], [2, 80, 400])
        for prefix in spatial_shapes:
            # Start every case from an empty default graph.
            tf.reset_default_graph()
            expected_shape = prefix + [3]
            output_imgs = networks.generator(tf.ones(expected_shape))
            self.assertAllEqual(expected_shape, output_imgs.shape.as_list())
Example #2
0
 def test_generator_invalid_channels(self):
     """Check the generator rejects input whose last dimension is unknown.

     The placeholder is created outside the ``assertRaisesRegexp`` context
     so that only ``networks.generator`` can produce the expected
     ValueError. A failure while building the input cannot be mistaken
     for the behavior under test.
     """
     img = tf.placeholder(tf.float32, shape=[4, 32, 32, None])
     with self.assertRaisesRegexp(
             ValueError, 'Last dimension shape must be known but is None'):
         networks.generator(img)
Example #3
0
 def test_generator_run_multi_channel(self):
     """Check a zero batch with last dimension 5 builds and evaluates.

     Only successful execution is checked; the output values are not
     inspected.
     """
     model_output = networks.generator(tf.zeros([3, 128, 128, 5]))
     with self.test_session() as sess:
         # The initializer op is created inside the session context,
         # matching the original statement order.
         init_op = tf.global_variables_initializer()
         sess.run(init_op)
         sess.run(model_output)
Example #4
0
 def test_generator_invalid_input(self):
     """Check a rank-3 tensor makes the generator raise a rank ValueError.

     The input tensor is built before entering the ``assertRaisesRegexp``
     context so that only ``networks.generator`` can satisfy the expected
     exception.
     """
     rank3_input = tf.zeros([28, 28, 3])
     with self.assertRaisesRegexp(ValueError, 'must have rank 4'):
         networks.generator(rank3_input)
Example #5
0
    def test_generator_graph_unknown_batch_dim(self):
        """Check an unknown (None) batch dimension survives the generator.

        The output's static shape must equal the placeholder's shape,
        including the leading ``None``.
        """
        expected_shape = [None, 32, 32, 3]
        img = tf.placeholder(tf.float32, shape=expected_shape)
        self.assertAllEqual(
            expected_shape, networks.generator(img).shape.as_list())