def test_nonsquare_inputs_raise_exception(self):
    """Checks that pix2pix_generator rejects a non-square input batch.

    A 240x320 image batch is fed to the generator (default blocks,
    'nn_upsample_conv' upsampling) inside the pix2pix arg scope; the call
    is expected to raise ValueError.
    """
    n_batch = 2
    rows, cols = 240, 320
    n_out = 4
    inputs = tf.ones((n_batch, rows, cols, 3))

    # assertRaises is entered first so that anything raised while building
    # the graph under the arg scope is caught, exactly as with nested `with`s.
    with self.assertRaises(ValueError), tf.contrib.framework.arg_scope(
            pix2pix.pix2pix_arg_scope()):
        pix2pix.pix2pix_generator(
            inputs,
            n_out,
            upsample_method='nn_upsample_conv',
        )
def test_output_size_conv2d_transpose(self):
    """Checks the generator output shape when upsampling with conv2d_transpose.

    A 256x256 batch is run through pix2pix_generator using the block list
    returned by self._reduced_default_blocks(); the evaluated logits must
    have shape [batch_size, height, width, num_outputs].
    """
    n_batch = 2
    rows, cols = 256, 256
    n_out = 4
    inputs = tf.ones((n_batch, rows, cols, 3))

    generator_scope = pix2pix.pix2pix_arg_scope()
    with tf.contrib.framework.arg_scope(generator_scope):
        logits, _ = pix2pix.pix2pix_generator(
            inputs,
            n_out,
            blocks=self._reduced_default_blocks(),
            upsample_method='conv2d_transpose',
        )

    expected_shape = [n_batch, rows, cols, n_out]
    with self.test_session() as sess:
        # Variables must be initialized before the logits can be evaluated.
        sess.run(tf.global_variables_initializer())
        evaluated = sess.run(logits)
        self.assertListEqual(expected_shape, list(evaluated.shape))
def test_block_number_dictates_number_of_layers(self):
    """Checks that the generator builds one encoder and one decoder layer per block.

    The generator is built with an explicit two-element block list, and the
    returned end points are counted by name prefix: the number whose name
    starts with 'encoder' and the number whose name starts with 'decoder'
    must each equal len(blocks).
    """
    batch_size = 2
    height, width = 256, 256
    num_outputs = 4
    images = tf.ones((batch_size, height, width, 3))
    # NOTE(review): Block args presumably are (num_filters, decoder_keep_prob)
    # — confirm against pix2pix.Block.
    blocks = [
        pix2pix.Block(64, 0.5),
        pix2pix.Block(128, 0),
    ]
    with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()):
        _, end_points = pix2pix.pix2pix_generator(
            images, num_outputs, blocks)

    # A name cannot start with both prefixes, so two independent counts are
    # equivalent to a single if/elif pass over the end points.
    num_encoder_layers = sum(
        1 for name in end_points if name.startswith('encoder'))
    num_decoder_layers = sum(
        1 for name in end_points if name.startswith('decoder'))
    self.assertEqual(num_encoder_layers, len(blocks))
    self.assertEqual(num_decoder_layers, len(blocks))