Exemplo n.º 1
0
    def test_simple_upsampling_block_refine_convs_bn_post(self):
        """Refine convs with batch norm placed after the activation.

        With ``refine_convs_batch_norm_before_activation=False`` every refine
        conv is expected to appear as Conv2D -> Activation ->
        BatchNormalization, following the bilinear UpSampling2D layer.
        """
        n_refine_convs = 2
        block = encoder_decoder.SimpleUpsamplingBlock(
            upsampling_stride=2,
            transposed_conv=False,
            interp_method="bilinear",
            skip_connection=True,
            refine_convs=n_refine_convs,
            refine_convs_filters=16,
            refine_convs_use_bias=True,
            refine_convs_kernel_size=3,
            refine_convs_batch_norm=True,
            refine_convs_batch_norm_before_activation=False,
            refine_convs_activation="relu",
        )
        x_in = tf.keras.Input((8, 8, 1))
        x = block.make_block(x_in)
        model = tf.keras.Model(x_in, x)

        # Input + UpSampling2D + 2 x (Conv2D, Activation, BatchNormalization).
        self.assertEqual(len(model.layers), 8)
        # Per refine conv: conv kernel + bias, BN gamma + beta.
        self.assertEqual(len(model.trainable_weights), 8)
        # conv1: 3*3*1*16 + 16 = 160, bn1: 4*16 = 64,
        # conv2: 3*3*16*16 + 16 = 2320, bn2: 4*16 = 64  -> 2608 total.
        self.assertEqual(model.count_params(), 2608)
        self.assertAllEqual(model.output.shape, (None, 16, 16, 16))
        self.assertIsInstance(model.layers[1], tf.keras.layers.UpSampling2D)
        # Verify the post-activation BN ordering for every refine conv, not
        # only the first one.
        for i in range(n_refine_convs):
            offset = 2 + 3 * i
            self.assertIsInstance(model.layers[offset], tf.keras.layers.Conv2D)
            self.assertIsInstance(model.layers[offset + 1],
                                  tf.keras.layers.Activation)
            self.assertIsInstance(model.layers[offset + 2],
                                  tf.keras.layers.BatchNormalization)
Exemplo n.º 2
0
 def decoder_stack(self) -> List[encoder_decoder.SimpleUpsamplingBlock]:
     """Build the list of upsampling blocks that make up the decoder.

     Block ``i`` (for ``i`` in ``range(self.up_blocks)``) uses
     ``int(filters * filters_rate ** (down_blocks + stem_blocks - 1 - i))``
     filters for its transposed conv and first refine conv. The filter count
     passed as ``refine_convs_filters`` uses an exponent one lower when
     ``block_contraction`` is set, and otherwise equals the input count.

     Returns:
         A list of ``self.up_blocks`` ``SimpleUpsamplingBlock`` instances whose
         filter exponents decrease from the first block to the last.
     """
     depth = self.down_blocks + self.stem_blocks

     def filters_at(exponent):
         # Filter count scaled by filters_rate raised to the given exponent.
         return int(self.filters * (self.filters_rate**exponent))

     decoder = []
     for i in range(self.up_blocks):
         filters_in = filters_at(depth - 1 - i)
         if not self.block_contraction:
             filters_out = filters_in
         else:
             filters_out = filters_at(depth - 2 - i)

         decoder.append(
             encoder_decoder.SimpleUpsamplingBlock(
                 upsampling_stride=2,
                 # Interpolation replaces the transposed conv when enabled.
                 transposed_conv=(not self.up_interpolate),
                 transposed_conv_filters=filters_in,
                 transposed_conv_kernel_size=self.kernel_size,
                 transposed_conv_batch_norm=False,
                 interp_method="bilinear",
                 skip_connection=True,
                 skip_add=False,
                 refine_convs=self.convs_per_block,
                 refine_convs_first_filters=filters_in,
                 refine_convs_filters=filters_out,
                 refine_convs_kernel_size=self.kernel_size,
                 refine_convs_batch_norm=False,
             ))
     return decoder
Exemplo n.º 3
0
    def test_simple_upsampling_block(self):
        """Interpolation-only upsampling without refine convs has no weights."""
        upsampler = encoder_decoder.SimpleUpsamplingBlock(
            upsampling_stride=2,
            transposed_conv=False,
            interp_method="bilinear",
            refine_convs=0,
        )
        inputs = tf.keras.Input((8, 8, 1))
        model = tf.keras.Model(inputs, upsampler.make_block(inputs))

        # Only the Input layer and a single UpSampling2D layer are expected.
        expected_layer_count = 2
        self.assertEqual(len(model.layers), expected_layer_count)
        # Bilinear upsampling is parameter-free.
        self.assertEqual(len(model.trainable_weights), 0)
        self.assertEqual(model.count_params(), 0)
        # Stride 2 doubles the spatial size and keeps the channel count.
        self.assertAllEqual(model.output.shape, (None, 16, 16, 1))
        upsample_layer = model.layers[1]
        self.assertIsInstance(upsample_layer, tf.keras.layers.UpSampling2D)
Exemplo n.º 4
0
    def test_simple_upsampling_block_skip_concat(self):
        """A concatenated skip source stacks its channels onto the output."""
        upsampler = encoder_decoder.SimpleUpsamplingBlock(
            upsampling_stride=2,
            transposed_conv=False,
            interp_method="bilinear",
            skip_connection=True,
            skip_add=False,
            refine_convs=0,
        )
        x_in = tf.keras.Input((8, 8, 1))
        skip_src = tf.keras.Input((16, 16, 4))
        model = tf.keras.Model(
            [x_in, skip_src],
            upsampler.make_block(x_in, skip_source=skip_src),
        )

        # Two Input layers, UpSampling2D and Concatenate.
        self.assertEqual(len(model.layers), 4)
        self.assertEqual(len(model.trainable_weights), 0)
        self.assertEqual(model.count_params(), 0)
        # 1 upsampled channel + 4 skip channels = 5.
        self.assertAllEqual(model.output.shape, (None, 16, 16, 5))
        expected_types = (
            (2, tf.keras.layers.UpSampling2D),
            (3, tf.keras.layers.Concatenate),
        )
        for index, layer_type in expected_types:
            self.assertIsInstance(model.layers[index], layer_type)
Exemplo n.º 5
0
    def test_simple_upsampling_block_skip_add(self):
        """An additive skip source is summed element-wise with the output.

        The skip source here is a constant all-ones tensor, so upsampling an
        all-ones input and adding it should yield all twos.
        """
        block = encoder_decoder.SimpleUpsamplingBlock(
            upsampling_stride=2,
            transposed_conv=False,
            interp_method="bilinear",
            skip_connection=True,
            skip_add=True,
            refine_convs=0,
        )
        x_in = tf.keras.Input((8, 8, 1))
        skip_src = tf.ones((1, 16, 16, 1))
        x = block.make_block(x_in, skip_source=skip_src)
        model = tf.keras.Model(x_in, x)

        # Input, UpSampling2D and the op layer performing the add.
        self.assertEqual(len(model.layers), 3)
        self.assertEqual(len(model.trainable_weights), 0)
        self.assertEqual(model.count_params(), 0)
        self.assertAllEqual(model.output.shape, (None, 16, 16, 1))
        self.assertIsInstance(model.layers[1], tf.keras.layers.UpSampling2D)
        # assertIn reports the actual layer name on failure
        # (e.g. tf_op_layer_AddV2).
        self.assertIn("add", model.layers[2].name.lower())
        self.assertAllClose(model(tf.ones((1, 8, 8, 1))),
                            tf.ones((1, 16, 16, 1)) * 2)
Exemplo n.º 6
0
    def test_simple_upsampling_block_trans_conv_bn_post(self):
        """Transposed-conv upsampling with batch norm after the activation.

        With ``transposed_conv_batch_norm_before_activation=False`` the block
        is expected to be Conv2DTranspose -> Activation -> BatchNormalization.
        """
        block = encoder_decoder.SimpleUpsamplingBlock(
            upsampling_stride=2,
            transposed_conv=True,
            transposed_conv_filters=8,
            transposed_conv_kernel_size=3,
            transposed_conv_use_bias=True,
            transposed_conv_batch_norm=True,
            transposed_conv_batch_norm_before_activation=False,
            transposed_conv_activation="relu",
            refine_convs=0,
        )
        x_in = tf.keras.Input((8, 8, 1))
        x = block.make_block(x_in)
        model = tf.keras.Model(x_in, x)

        # Input layer + (Conv2DTranspose, Activation, BatchNormalization).
        self.assertEqual(len(model.layers), 1 + 3)
        # Transposed conv kernel + bias, BN gamma + beta.
        self.assertEqual(len(model.trainable_weights), 4)
        # Conv2DTranspose: 3*3*8*1 + 8 = 80; BN: 4*8 = 32  -> 112 total.
        self.assertEqual(model.count_params(), 112)
        self.assertAllEqual(model.output.shape, (None, 16, 16, 8))
        self.assertIsInstance(model.layers[1], tf.keras.layers.Conv2DTranspose)
        self.assertIsInstance(model.layers[2], tf.keras.layers.Activation)
        self.assertIsInstance(model.layers[3],
                              tf.keras.layers.BatchNormalization)