def test_upsampling_stack_upsampling_add(self):
    """Skip connections merged by addition in a 3-block transposed-conv stack.

    Upsampling from stride 16 to stride 2 in 2x steps takes three blocks
    (16->8, 8->4, 4->2), so the 8x8 input grows to 64x64 with 16 filters.
    The skip sources at strides 8 and 4 are each projected to 16 channels by
    a ``*_skip_conv1x1`` layer and merged through a ``*_skip_add`` layer.
    """
    upsampling_stack = upsampling.UpsamplingStack(
        output_stride=2,
        upsampling_stride=2,
        skip_add=True,
        transposed_conv=True,
        transposed_conv_filters=16,
        refine_convs=0,
    )
    skip_sources = [
        upsampling.IntermediateFeature(
            tensor=tf.keras.Input((16, 16, 1)), stride=8
        ),
        upsampling.IntermediateFeature(
            tensor=tf.keras.Input((32, 32, 2)), stride=4
        ),
    ]
    x, intermediate_feats = upsampling_stack.make_stack(
        tf.keras.Input((8, 8, 32)), current_stride=16, skip_sources=skip_sources
    )
    model = tf.keras.Model(tf.keras.utils.get_source_inputs(x), x)

    self.assertAllEqual(x.shape, (None, 64, 64, 16))
    # One intermediate feature per upsampling block (16->8, 8->4, 4->2).
    # This matches test_upsampling_stack, where a two-block stack yields
    # exactly two features (strides 8 and 4) and the input is not counted.
    self.assertEqual(len(intermediate_feats), 3)
    self.assertAllEqual(
        model.get_layer("upsample_s16_to_s8_skip_conv1x1").output.shape,
        (None, 16, 16, 16),
    )
    self.assertAllEqual(
        model.get_layer("upsample_s8_to_s4_skip_conv1x1").output.shape,
        (None, 32, 32, 16),
    )
    self.assertIsInstance(
        model.get_layer("upsample_s16_to_s8_skip_add"), tf.keras.layers.Add
    )
def test_upsampling_stack_refine_convs_filter_rate(self):
    """Refine-conv filter counts grow by ``refine_convs_filters_rate`` per block.

    Starting from 16 filters with a rate of 2, the three blocks
    (16->8, 8->4, 4->2) use 16, 32 and 64 filters for both refine convs.
    The final 4x4 -> 32x32 output therefore carries 64 channels.
    """
    stack = upsampling.UpsamplingStack(
        output_stride=2,
        upsampling_stride=2,
        transposed_conv=False,
        refine_convs=2,
        refine_convs_filters=16,
        refine_convs_filters_rate=2,
    )
    inputs = tf.keras.Input((4, 4, 2))
    out, _ = stack.make_stack(inputs, current_stride=16)
    model = tf.keras.Model(tf.keras.utils.get_source_inputs(out), out)

    # (layer name, expected filter count), listed in block order.
    expected_filters = [
        ("upsample_s16_to_s8_refine0_conv", 16),
        ("upsample_s16_to_s8_refine1_conv", 16),
        ("upsample_s8_to_s4_refine0_conv", 32),
        ("upsample_s8_to_s4_refine1_conv", 32),
        ("upsample_s4_to_s2_refine0_conv", 64),
        ("upsample_s4_to_s2_refine1_conv", 64),
    ]
    for layer_name, n_filters in expected_filters:
        self.assertEqual(model.get_layer(layer_name).filters, n_filters)

    self.assertAllEqual(out.shape, (None, 32, 32, 64))
def test_upsampling_stack_upsampling_skip(self):
    """Skip connections merged by concatenation in a 3-block stack.

    With ``skip_add=False`` the skip sources are concatenated with the
    16-channel upsampled features. At stride 8 this gives 16 + 1 = 17
    channels, and at stride 4 it gives 16 + 2 = 18 channels. Three blocks
    (16->8, 8->4, 4->2) take the 8x8 input to 64x64.
    """
    upsampling_stack = upsampling.UpsamplingStack(
        output_stride=2,
        upsampling_stride=2,
        skip_add=False,
        transposed_conv=True,
        transposed_conv_filters=16,
        refine_convs=0,
    )
    skip_sources = [
        upsampling.IntermediateFeature(
            tensor=tf.keras.Input((16, 16, 1)), stride=8
        ),
        upsampling.IntermediateFeature(
            tensor=tf.keras.Input((32, 32, 2)), stride=4
        ),
    ]
    x, intermediate_feats = upsampling_stack.make_stack(
        tf.keras.Input((8, 8, 32)), current_stride=16, skip_sources=skip_sources
    )
    model = tf.keras.Model(tf.keras.utils.get_source_inputs(x), x)

    self.assertAllEqual(x.shape, (None, 64, 64, 16))
    # One intermediate feature per upsampling block (16->8, 8->4, 4->2),
    # consistent with test_upsampling_stack (two blocks -> two features).
    self.assertEqual(len(intermediate_feats), 3)
    self.assertIsInstance(model.layers[1], tf.keras.layers.Conv2DTranspose)
    self.assertIsInstance(model.layers[2], tf.keras.layers.BatchNormalization)
    # NOTE(review): index 3 is not asserted; presumably it is the stride-8
    # skip Input layer, placed there by Keras' layer ordering. Confirm.
    self.assertIsInstance(model.layers[4], tf.keras.layers.Activation)
    self.assertIsInstance(model.layers[5], tf.keras.layers.Concatenate)
    self.assertAllEqual(model.layers[5].output.shape, (None, 16, 16, 17))
    self.assertIsInstance(model.layers[10], tf.keras.layers.Concatenate)
    self.assertAllEqual(model.layers[10].output.shape, (None, 32, 32, 18))
def test_upsampling_stack_upsampling_stride4(self):
    """A single 4x upsampling block takes stride 16 directly to stride 4.

    The 8x8 input becomes 32x32 with 64 channels, and exactly one block runs.
    """
    upsampling_stack = upsampling.UpsamplingStack(
        output_stride=4, upsampling_stride=4
    )
    x, intermediate_feats = upsampling_stack.make_stack(
        tf.keras.Input((8, 8, 32)), current_stride=16
    )

    self.assertAllEqual(x.shape, (None, 32, 32, 64))
    # One block (16->4) yields one intermediate feature. In
    # test_upsampling_stack, two blocks yield exactly two features, so the
    # input is not counted.
    self.assertEqual(len(intermediate_feats), 1)
def test_upsampling_stack_upsampling_interp(self):
    """With ``transposed_conv=False`` the stack upsamples by interpolation.

    One 2x block (stride 16 -> 8) turns the 8x8 input into a 16x16 map with
    64 channels. The first layer after the input is ``UpSampling2D``.
    """
    stack = upsampling.UpsamplingStack(
        output_stride=8, upsampling_stride=2, transposed_conv=False
    )
    features, _ = stack.make_stack(tf.keras.Input((8, 8, 32)), current_stride=16)
    self.assertAllEqual(features.shape, (None, 16, 16, 64))

    model = tf.keras.Model(tf.keras.utils.get_source_inputs(features), features)
    first_layer = model.layers[1]
    self.assertIsInstance(first_layer, tf.keras.layers.UpSampling2D)
def test_resnet50_upsampling(self):
    """A ResNet50 backbone with an upsampling stack decodes to stride 4.

    Backbone features are taken at stride 32 and upsampled to stride 4, so a
    160x160 single-channel image yields a 40x40 map with 64 channels.
    """
    output_stride = 4
    image_size = 160
    backbone = resnet.ResNet50(
        pretrained=False,
        frozen=False,
        features_output_stride=32,
        upsampling_stack=upsampling.UpsamplingStack(
            output_stride=output_stride, refine_convs_filters=64
        ),
    )
    inputs = tf.keras.layers.Input((image_size, image_size, 1))
    features, _ = backbone.make_backbone(inputs)

    side = image_size // output_stride
    with self.subTest("output shape"):
        self.assertAllEqual(features.shape, (None, side, side, 64))
def test_upsampling_stack(self):
    """Transposed-conv stack with batchnorm and one refine conv per block.

    Two 2x blocks take stride 16 to stride 4 (8x8 -> 32x32, 64 channels).
    The stack reports one intermediate feature per block, at strides 8 and 4.
    The built model has 13 layers, with a Conv2DTranspose right after the
    input.
    """
    stack = upsampling.UpsamplingStack(
        output_stride=4,
        upsampling_stride=2,
        transposed_conv=True,
        transposed_conv_batchnorm=True,
        refine_convs=1,
        refine_convs_batchnorm=True,
    )
    out, feats = stack.make_stack(tf.keras.Input((8, 8, 32)), current_stride=16)
    model = tf.keras.Model(tf.keras.utils.get_source_inputs(out), out)

    self.assertAllEqual(out.shape, (None, 32, 32, 64))
    self.assertEqual(len(feats), 2)
    for feat, expected_stride in zip(feats, (8, 4)):
        self.assertEqual(feat.stride, expected_stride)
    # NOTE(review): 13 presumably = 1 input + 6 layers per block (transposed
    # conv, BN, activation, refine conv, BN, activation). Confirm.
    self.assertEqual(len(model.layers), 13)
    self.assertIsInstance(model.layers[1], tf.keras.layers.Conv2DTranspose)