def test_generator_graph(self):
    """Check the generator output shape equals the input shape for fixed sizes.

    Each candidate is a [batch, height, width] triple; a channel dimension of
    3 is appended to form the full input shape.
    """
    candidate_shapes = ([4, 32, 32], [3, 128, 128], [2, 80, 400])
    for spatial in candidate_shapes:
        # Start each size from an empty default graph, presumably so that
        # whatever the previous generator call added does not interfere.
        tf.reset_default_graph()
        expected = spatial + [3]
        generated = networks.generator(tf.ones(expected))
        self.assertAllEqual(expected, generated.shape.as_list())
def test_generator_invalid_channels(self):
    """Check the generator rejects input whose last (channel) dim is unknown."""
    # Build the input outside the assertion context so that only the
    # generator call itself is expected to raise.
    img = tf.placeholder(tf.float32, shape=[4, 32, 32, None])
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias
    # (removed in Python 3.12).
    with self.assertRaisesRegex(
            ValueError, 'Last dimension shape must be known but is None'):
        networks.generator(img)
def test_generator_run_multi_channel(self):
    """Check the generator graph executes on a 5-channel input batch."""
    inputs = tf.zeros([3, 128, 128, 5])
    outputs = networks.generator(inputs)
    with self.test_session() as session:
        # Initialize variables first, then evaluate the generator output.
        init_op = tf.global_variables_initializer()
        session.run(init_op)
        session.run(outputs)
def test_generator_invalid_input(self):
    """Check the generator rejects a non-rank-4 input (here a rank-3 tensor)."""
    # Create the input before entering the assertion context so only the
    # generator call is covered by the expected-exception check.
    rank3_input = tf.zeros([28, 28, 3])
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias
    # (removed in Python 3.12).
    with self.assertRaisesRegex(ValueError, 'must have rank 4'):
        networks.generator(rank3_input)
def test_generator_graph_unknown_batch_dim(self):
    """Check the static output shape is preserved when batch size is unknown."""
    expected_shape = [None, 32, 32, 3]
    inputs = tf.placeholder(tf.float32, shape=expected_shape)
    generated = networks.generator(inputs)
    self.assertAllEqual(expected_shape, generated.shape.as_list())