Example #1
0
    def test_simple_conv_block_pool_before_convs(self):
        """Pool-first SimpleConvBlock with batch norm: check structure and sizes."""
        conv_block = encoder_decoder.SimpleConvBlock(
            pool=True,
            pool_before_convs=True,
            pooling_stride=2,
            num_convs=3,
            filters=16,
            kernel_size=3,
            use_bias=True,
            batch_norm=True,
            batch_norm_before_activation=True,
            activation="relu",
        )
        inputs = tf.keras.Input((8, 8, 1))
        outputs = conv_block.make_block(inputs)
        net = tf.keras.Model(inputs, outputs)

        # Input layer + 3 x (conv, batch norm, activation) + 1 pooling layer.
        self.assertEqual(len(net.layers), 1 + 3 * 3 + 1)
        # 3 convs x (kernel, bias) + 3 batch norms x 2 trainable vectors.
        self.assertEqual(len(net.trainable_weights), 12)
        # Convs: (3*3*1*16 + 16) + 2 * (3*3*16*16 + 16) = 4800;
        # batch norms: 3 * 4 * 16 = 192 (includes non-trainable statistics).
        self.assertEqual(net.count_params(), 4992)
        # Stride-2 pooling halves the 8x8 input; channels become `filters`.
        self.assertAllEqual(net.output.shape, (None, 4, 4, 16))

        # Pooling comes first, followed by conv -> batch norm -> activation.
        expected_leading_layers = [
            tf.keras.layers.MaxPooling2D,
            tf.keras.layers.Conv2D,
            tf.keras.layers.BatchNormalization,
            tf.keras.layers.Activation,
        ]
        for position, layer_type in enumerate(expected_leading_layers, start=1):
            self.assertIsInstance(net.layers[position], layer_type)
Example #2
0
    def stem_stack(self) -> Optional[List[encoder_decoder.SimpleConvBlock]]:
        """Build the downsampling stem, or return None when it is disabled.

        Stem block ``i`` uses ``int(filters * filters_rate**i)`` filters and
        ``stem_kernel_size`` kernels. Every block except the first pools before
        its convolutions, and a trailing PoolingBlock is appended. The stem
        therefore contains one pooling step per stem block.

        Returns:
            None if ``self.stem_blocks`` is 0. Otherwise, the list of stem conv
            blocks followed by a single PoolingBlock.
        """
        if self.stem_blocks == 0:
            return None

        stem = [
            encoder_decoder.SimpleConvBlock(
                pool=(level > 0),
                pool_before_convs=True,
                pooling_stride=2,
                num_convs=self.convs_per_block,
                filters=int(self.filters * (self.filters_rate**level)),
                kernel_size=self.stem_kernel_size,
                use_bias=True,
                batch_norm=False,
                activation="relu",
            )
            for level in range(self.stem_blocks)
        ]

        # The first block does not pool, so finish with an explicit pooling
        # step to make up for the pool-before-convs ordering.
        stem.append(PoolingBlock(pool=True, pooling_stride=2))
        return stem
Example #3
0
    def test_simple_conv_block(self):
        """SimpleConvBlock without batch norm: check layer, weight and param counts."""
        # `pool` is not passed, so this relies on the block's default. The
        # halved output size asserted below shows pooling still happens.
        # batch_norm_before_activation is presumably ignored because
        # batch_norm=False; the counts below include no batch-norm layers.
        block_config = dict(
            pooling_stride=2,
            num_convs=3,
            filters=16,
            kernel_size=3,
            use_bias=True,
            batch_norm=False,
            batch_norm_before_activation=True,
            activation="relu",
        )
        conv_block = encoder_decoder.SimpleConvBlock(**block_config)
        inputs = tf.keras.Input((8, 8, 1))
        net = tf.keras.Model(inputs, conv_block.make_block(inputs))

        # Input layer + 3 x (conv, activation) + 1 pooling layer.
        self.assertEqual(len(net.layers), 1 + 2 * 3 + 1)
        # 3 convs x (kernel, bias).
        self.assertEqual(len(net.trainable_weights), 6)
        # (3*3*1*16 + 16) + 2 * (3*3*16*16 + 16) = 160 + 4640.
        self.assertEqual(net.count_params(), 4800)
        self.assertAllEqual(net.output.shape, (None, 4, 4, 16))
Example #4
0
    def encoder_stack(self) -> List[encoder_decoder.SimpleConvBlock]:
        """Build the encoder (downsampling) stack, plus an optional middle block.

        Encoder block ``i`` uses ``int(filters * filters_rate**(i + stem_blocks))``
        filters. Every block except the first pools before its convolutions, and
        a trailing PoolingBlock completes the downsampling.

        When ``self.middle_block`` is set, a CARE-style middle block is appended:
        - ``convs_per_block - 1`` expansion convs, only when ``convs_per_block > 1``;
        - one final conv that contracts the channels when
          ``self.block_contraction`` is set, and otherwise keeps the expansion
          width.

        Returns:
            The list of encoder blocks. It includes the trailing PoolingBlock and,
            when enabled, the middle blocks.
        """

        def filters_at(exponent):
            # Channel count for a given depth exponent.
            return int(self.filters * (self.filters_rate**exponent))

        def conv_block(filters, *, pool, pool_before_convs, num_convs, **extra):
            # Settings shared by every conv block built here.
            return encoder_decoder.SimpleConvBlock(
                pool=pool,
                pool_before_convs=pool_before_convs,
                pooling_stride=2,
                num_convs=num_convs,
                filters=filters,
                kernel_size=self.kernel_size,
                use_bias=True,
                batch_norm=False,
                activation="relu",
                **extra,
            )

        blocks = [
            conv_block(
                filters_at(level + self.stem_blocks),
                pool=(level > 0),
                pool_before_convs=True,
                num_convs=self.convs_per_block,
            )
            for level in range(self.down_blocks)
        ]

        # The first encoder block does not pool, so finish with an explicit
        # pooling step to make up for the pool-before-convs ordering.
        blocks.append(PoolingBlock(pool=True, pooling_stride=2))

        if not self.middle_block:
            return blocks

        # Middle block (like the CARE implementation).
        depth = self.down_blocks + self.stem_blocks
        if self.convs_per_block > 1:
            # Expansion convs use one exponent higher than the last encoder block.
            blocks.append(
                conv_block(
                    filters_at(depth),
                    pool=False,
                    pool_before_convs=False,
                    num_convs=self.convs_per_block - 1,
                    block_prefix="_middle_expand",
                ))

        # Contraction drops one exponent below the expansion width, which is
        # the same width as the last encoder block. Without contraction, the
        # final conv keeps the expansion width.
        final_exponent = depth - 1 if self.block_contraction else depth
        blocks.append(
            conv_block(
                filters_at(final_exponent),
                pool=False,
                pool_before_convs=False,
                num_convs=1,
                block_prefix="_middle_contract",
            ))

        return blocks