예제 #1
0
    def test_upsampling_stack_upsampling_add(self):
        """Skip sources are merged into the upsampling path by addition.

        With ``skip_add=True`` the test expects each skip source to pass
        through a ``*_skip_conv1x1`` layer whose output has 16 channels (the
        ``transposed_conv_filters`` value). It then expects the result to be
        merged via a ``tf.keras.layers.Add`` layer named ``*_skip_add``, at
        both the s16->s8 and s8->s4 steps.
        """
        upsampling_stack = upsampling.UpsamplingStack(
            output_stride=2,
            upsampling_stride=2,
            skip_add=True,
            transposed_conv=True,
            transposed_conv_filters=16,
            refine_convs=0,
        )
        # Skip sources are created before the main input; keep this order so
        # auto-generated Keras input names stay stable.
        skip_sources = [
            upsampling.IntermediateFeature(tensor=tf.keras.Input((16, 16, 1)),
                                           stride=8),
            upsampling.IntermediateFeature(tensor=tf.keras.Input((32, 32, 2)),
                                           stride=4),
        ]
        x, intermediate_feats = upsampling_stack.make_stack(
            tf.keras.Input((8, 8, 32)),
            current_stride=16,
            skip_sources=skip_sources)
        model = tf.keras.Model(tf.keras.utils.get_source_inputs(x), x)

        # Stride 16 -> 2 is three 2x steps: 8 * 8 = 64.
        self.assertAllEqual(x.shape, (None, 64, 64, 16))
        self.assertEqual(len(intermediate_feats), 4)

        # The 1x1 projections bring the skip sources (1 and 2 channels) up to
        # the 16 channels of the upsampled tensor so they can be added.
        self.assertAllEqual(
            model.get_layer("upsample_s16_to_s8_skip_conv1x1").output.shape,
            (None, 16, 16, 16),
        )
        self.assertAllEqual(
            model.get_layer("upsample_s8_to_s4_skip_conv1x1").output.shape,
            (None, 32, 32, 16),
        )

        # Both skip connections must be merged additively, not just the first.
        self.assertIsInstance(model.get_layer("upsample_s16_to_s8_skip_add"),
                              tf.keras.layers.Add)
        self.assertIsInstance(model.get_layer("upsample_s8_to_s4_skip_add"),
                              tf.keras.layers.Add)
예제 #2
0
    def test_upsampling_stack_refine_convs_filter_rate(self):
        """Refinement conv widths grow by ``refine_convs_filters_rate`` per step.

        Starting at 16 filters with a rate of 2, the three upsampling steps
        (s16->s8, s8->s4, s4->s2) are expected to use 16, 32 and 64 filters
        for both of their refinement convs.
        """
        stack = upsampling.UpsamplingStack(
            output_stride=2,
            upsampling_stride=2,
            transposed_conv=False,
            refine_convs=2,
            refine_convs_filters=16,
            refine_convs_filters_rate=2,
        )
        inputs = tf.keras.Input((4, 4, 2))
        out, _ = stack.make_stack(inputs, current_stride=16)
        model = tf.keras.Model(tf.keras.utils.get_source_inputs(out), out)

        expected = (
            ("upsample_s16_to_s8_refine0_conv", 16),
            ("upsample_s16_to_s8_refine1_conv", 16),
            ("upsample_s8_to_s4_refine0_conv", 32),
            ("upsample_s8_to_s4_refine1_conv", 32),
            ("upsample_s4_to_s2_refine0_conv", 64),
            ("upsample_s4_to_s2_refine1_conv", 64),
        )
        for layer_name, n_filters in expected:
            self.assertEqual(model.get_layer(layer_name).filters, n_filters)

        # Three 2x steps from a 4x4 input; channels equal the last step's width.
        self.assertAllEqual(out.shape, (None, 32, 32, 64))
예제 #3
0
    def test_upsampling_stack_upsampling_skip(self):
        """Skip sources are concatenated onto the upsampled tensor.

        With ``skip_add=False`` the test expects ``Concatenate`` layers. Their
        channel counts are the 16 transposed-conv filters plus the skip
        source's own channels (1 and 2 respectively).
        """
        stack = upsampling.UpsamplingStack(
            output_stride=2,
            upsampling_stride=2,
            skip_add=False,
            transposed_conv=True,
            transposed_conv_filters=16,
            refine_convs=0,
        )
        # Skip inputs are created before the main input; the positional layer
        # checks below depend on the resulting model layer order.
        first_skip = tf.keras.Input((16, 16, 1))
        second_skip = tf.keras.Input((32, 32, 2))
        sources = [
            upsampling.IntermediateFeature(tensor=first_skip, stride=8),
            upsampling.IntermediateFeature(tensor=second_skip, stride=4),
        ]
        main_input = tf.keras.Input((8, 8, 32))
        out, feats = stack.make_stack(
            main_input, current_stride=16, skip_sources=sources)
        model = tf.keras.Model(tf.keras.utils.get_source_inputs(out), out)
        layers = model.layers

        self.assertAllEqual(out.shape, (None, 64, 64, 16))
        self.assertEqual(len(feats), 4)

        expected_types = (
            (1, tf.keras.layers.Conv2DTranspose),
            (2, tf.keras.layers.BatchNormalization),
            (4, tf.keras.layers.Activation),
            (5, tf.keras.layers.Concatenate),
        )
        for index, layer_type in expected_types:
            self.assertIsInstance(layers[index], layer_type)
        self.assertAllEqual(layers[5].output.shape, (None, 16, 16, 17))

        self.assertIsInstance(layers[10], tf.keras.layers.Concatenate)
        self.assertAllEqual(layers[10].output.shape, (None, 32, 32, 18))
예제 #4
0
    def test_upsampling_stack_upsampling_stride4(self):
        """A single 4x upsampling step takes stride 16 down to stride 4."""
        stack = upsampling.UpsamplingStack(
            output_stride=4,
            upsampling_stride=4,
        )
        features = tf.keras.Input((8, 8, 32))
        out, feats = stack.make_stack(features, current_stride=16)

        # 8 * 4 = 32 spatial size after one 4x step.
        self.assertAllEqual(out.shape, (None, 32, 32, 64))
        self.assertEqual(len(feats), 2)
예제 #5
0
    def test_upsampling_stack_upsampling_interp(self):
        """Without transposed convs, the stack upsamples via ``UpSampling2D``."""
        stack = upsampling.UpsamplingStack(
            output_stride=8,
            upsampling_stride=2,
            transposed_conv=False,
        )
        features = tf.keras.Input((8, 8, 32))
        out, _ = stack.make_stack(features, current_stride=16)

        # One 2x step: stride 16 -> 8, spatial 8 -> 16.
        self.assertAllEqual(out.shape, (None, 16, 16, 64))

        model = tf.keras.Model(tf.keras.utils.get_source_inputs(out), out)
        first_layer = model.layers[1]
        self.assertIsInstance(first_layer, tf.keras.layers.UpSampling2D)
예제 #6
0
    def test_resnet50_upsampling(self):
        """ResNet50 with a stride-4 upsampling stack outputs 1/4 resolution.

        The backbone features at stride 32 are upsampled to stride 4. The
        output is expected to have 64 channels (``refine_convs_filters``).
        """
        upsampler = upsampling.UpsamplingStack(
            output_stride=4, refine_convs_filters=64)
        backbone = resnet.ResNet50(
            pretrained=False,
            frozen=False,
            features_output_stride=32,
            upsampling_stack=upsampler,
        )

        image = tf.keras.layers.Input((160, 160, 1))
        out, _ = backbone.make_backbone(image)

        expected_size = 160 // 4
        with self.subTest("output shape"):
            self.assertAllEqual(out.shape,
                                (None, expected_size, expected_size, 64))
예제 #7
0
    def test_upsampling_stack(self):
        """Transposed-conv stack with batchnorm and one refine conv per step.

        Going from stride 16 to stride 4 in 2x steps should yield two
        intermediate features (at strides 8 and 4) and a 13-layer model. The
        model's first non-input layer is expected to be a ``Conv2DTranspose``.
        """
        stack = upsampling.UpsamplingStack(
            output_stride=4,
            upsampling_stride=2,
            transposed_conv=True,
            transposed_conv_batchnorm=True,
            refine_convs=1,
            refine_convs_batchnorm=True,
        )
        features = tf.keras.Input((8, 8, 32))
        out, feats = stack.make_stack(features, current_stride=16)
        model = tf.keras.Model(tf.keras.utils.get_source_inputs(out), out)

        self.assertAllEqual(out.shape, (None, 32, 32, 64))
        self.assertEqual(len(feats), 2)
        # Length is pinned above, so a list comparison checks both strides.
        self.assertEqual([feat.stride for feat in feats], [8, 4])
        self.assertEqual(len(model.layers), 13)
        self.assertIsInstance(model.layers[1], tf.keras.layers.Conv2DTranspose)