import tensorflow as tf

# Assumed repo-local import: `layers` (like `blocks` and `var_storage`
# used further below) is this codebase's own wrapper module, not a
# TensorFlow namespace. The snippets use TF1.x-style graph APIs
# (tf.variable_scope etc.).
from model import layers


def squeeze_excitation_layer(
        inputs,
        ratio,
        training=True,
        data_format='NCHW',
        kernel_initializer=tf.compat.v1.variance_scaling_initializer(),
        bias_initializer=tf.zeros_initializer(),
        name="squeeze_excitation_layer"):

    if data_format not in ['NHWC', 'NCHW']:
        raise ValueError(
            "Unknown data format: `%s` (accepted: ['NHWC', 'NCHW'])" %
            data_format)

    in_shape = inputs.get_shape()

    num_channels = in_shape[1] if data_format == "NCHW" else in_shape[-1]

    with tf.variable_scope(name):

        net = inputs

        # squeeze
        squeeze = layers.reduce_mean(net,
                                     keepdims=False,
                                     data_format=data_format,
                                     name='squeeze_spatial_mean')

        # fc + relu
        excitation = layers.dense(inputs=squeeze,
                                  units=num_channels // ratio,
                                  use_bias=True,
                                  trainable=training,
                                  kernel_initializer=kernel_initializer,
                                  bias_initializer=bias_initializer)
        excitation = layers.relu(excitation)

        # fc + sigmoid
        excitation = layers.dense(inputs=excitation,
                                  units=num_channels,
                                  use_bias=True,
                                  trainable=training,
                                  kernel_initializer=kernel_initializer,
                                  bias_initializer=bias_initializer)
        excitation = layers.sigmoid(excitation)

        # Shape the per-channel gates so they broadcast over the spatial
        # dimensions in the rescale below.
        out_shape = ([-1, num_channels, 1, 1] if data_format == "NCHW"
                     else [-1, 1, 1, num_channels])

        excitation = tf.reshape(excitation, out_shape)

        net = net * excitation

        return net
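
A minimal usage sketch, assuming TF1-style graph mode and that the repo's
`layers` wrappers mirror their tf.layers counterparts; the placeholder shape
and ratio below are illustrative only:

# Hypothetical usage (illustrative shape; ratio=16 is the common SE
# default, not something this snippet mandates).
x = tf.compat.v1.placeholder(tf.float32, [None, 64, 56, 56])  # NCHW
gated = squeeze_excitation_layer(x, ratio=16, training=True,
                                 data_format='NCHW')
# `gated` has the same shape as `x`: each channel is scaled by a learned
# gate in (0, 1) computed from its pooled statistics.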
Example #2
    def build_model(self, inputs, training=True, reuse=False):
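        """Assemble the conv stem, grouped bottleneck stages (with optional
        squeeze-excitation), and classifier head; returns `(probs, logits)`
        with `logits` cast to float32."""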
        
        with var_storage.model_variable_scope(
            self.model_hparams.model_name,
            reuse=reuse,
            dtype=self.model_hparams.dtype):

            with tf.variable_scope("input_reshape"):
                if self.model_hparams.input_format == 'NHWC' and self.model_hparams.compute_format == 'NCHW':
                    # Reshape inputs: NHWC => NCHW
                    inputs = tf.transpose(inputs, [0, 3, 1, 2])

                elif self.model_hparams.input_format == 'NCHW' and self.model_hparams.compute_format == 'NHWC':
                    # Reshape inputs: NCHW => NHWC
                    inputs = tf.transpose(inputs, [0, 2, 3, 1])

            if self.model_hparams.dtype != inputs.dtype:
                inputs = tf.cast(inputs, self.model_hparams.dtype)

            net = blocks.conv2d_block(
                inputs,
                n_channels=64,
                kernel_size=(7, 7),
                strides=(2, 2),
                mode='SAME',
                use_batch_norm=True,
                activation='relu',
                is_training=training,
                data_format=self.model_hparams.compute_format,
                conv2d_hparams=self.conv2d_hparams,
                batch_norm_hparams=self.batch_norm_hparams,
                name='conv2d'
            )

            net = layers.max_pooling2d(
                net,
                pool_size=(3, 3),
                strides=(2, 2),
                padding='SAME',
                data_format=self.model_hparams.compute_format,
                name="max_pooling2d",
            )

            model_bottlenecks = self.model_hparams.layers_depth
            for block_id, block_bottleneck in enumerate(model_bottlenecks):
                for layer_id in range(self.model_hparams.layers_count[block_id]):
                    # The first unit of every stage after the first downsamples.
                    stride = 2 if (layer_id == 0 and block_id != 0) else 1

                    net = blocks.bottleneck_block(
                        inputs=net,
                        depth=block_bottleneck * self.model_hparams.expansions,
                        depth_bottleneck=block_bottleneck,
                        cardinality=self.model_hparams.cardinality,
                        stride=stride,
                        training=training,
                        data_format=self.model_hparams.compute_format,
                        conv2d_hparams=self.conv2d_hparams,
                        batch_norm_hparams=self.batch_norm_hparams,
                        block_name="btlnck_block_%d_%d" % (block_id, layer_id),
                        use_se=self.model_hparams.use_se,
                        ratio=self.model_hparams.se_ratio)

            with tf.variable_scope("output"):
                net = layers.reduce_mean(
                    net, keepdims=False, data_format=self.model_hparams.compute_format, name='spatial_mean')

                logits = layers.dense(
                    inputs=net,
                    units=self.model_hparams.n_classes,
                    use_bias=True,
                    trainable=training,
                    kernel_initializer=self.dense_hparams.kernel_initializer,
                    bias_initializer=self.dense_hparams.bias_initializer)

                if logits.dtype != tf.float32:
                    logits = tf.cast(logits, tf.float32)

                probs = layers.softmax(logits, name="softmax", axis=1)

            return probs, logits
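
A hedged wiring sketch; the model object, input placeholder, and shapes below
are illustrative assumptions, not the repo's actual driver code:

# Hypothetical call site: `model` is an instance of the class defining
# build_model (name and hparams assumed).
images = tf.compat.v1.placeholder(tf.float32, [None, 224, 224, 3])  # NHWC
probs, logits = model.build_model(images, training=True, reuse=False)
# `logits` stays float32 even under a float16 compute dtype, so the loss
# can be computed in full precision.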
Example #3
    def build_model(self, inputs, training=True, reuse=False):
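        """Assemble a four-stage bottleneck ResNet (stage depths
        256/512/1024/2048) plus classifier head; returns `(probs, logits)`."""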

        with var_storage.model_variable_scope(self.model_hparams.model_name,
                                              reuse=reuse,
                                              dtype=self.model_hparams.dtype):

            with tf.variable_scope("input_reshape"):

                if self.model_hparams.input_format == 'NHWC' and self.model_hparams.compute_format == 'NCHW':
                    # Reshape inputs: NHWC => NCHW
                    inputs = tf.transpose(inputs, [0, 3, 1, 2])

                elif self.model_hparams.input_format == 'NCHW' and self.model_hparams.compute_format == 'NHWC':

                    # Reshape inputs: NCHW => NHWC
                    inputs = tf.transpose(inputs, [0, 2, 3, 1])

            if self.model_hparams.dtype != inputs.dtype:
                inputs = tf.cast(inputs, self.model_hparams.dtype)

            net = blocks.conv2d_block(
                inputs,
                n_channels=64,
                kernel_size=(7, 7),
                strides=(2, 2),
                mode='SAME_RESNET',
                use_batch_norm=True,
                activation='relu',
                is_training=training,
                data_format=self.model_hparams.compute_format,
                conv2d_hparams=self.conv2d_hparams,
                batch_norm_hparams=self.batch_norm_hparams,
                name='conv2d')

            net = layers.max_pooling2d(
                net,
                pool_size=(3, 3),
                strides=(2, 2),
                padding='SAME',
                data_format=self.model_hparams.compute_format,
                name="max_pooling2d",
            )

            for block_id in range(self.model_hparams.layer_counts[0]):
                net = blocks.bottleneck_block(
                    inputs=net,
                    depth=256,
                    depth_bottleneck=64,
                    stride=1,
                    training=training,
                    data_format=self.model_hparams.compute_format,
                    conv2d_hparams=self.conv2d_hparams,
                    batch_norm_hparams=self.batch_norm_hparams,
                    block_name="btlnck_block_1_%d" % (block_id + 1))

            for block_id in range(self.model_hparams.layer_counts[1]):
                stride = 2 if block_id == 0 else 1

                net = blocks.bottleneck_block(
                    inputs=net,
                    depth=512,
                    depth_bottleneck=128,
                    stride=stride,
                    training=training,
                    data_format=self.model_hparams.compute_format,
                    conv2d_hparams=self.conv2d_hparams,
                    batch_norm_hparams=self.batch_norm_hparams,
                    block_name="btlnck_block_2_%d" % (block_id + 1))

            for block_id in range(self.model_hparams.layer_counts[2]):
                stride = 2 if block_id == 0 else 1

                net = blocks.bottleneck_block(
                    inputs=net,
                    depth=1024,
                    depth_bottleneck=256,
                    stride=stride,
                    training=training,
                    data_format=self.model_hparams.compute_format,
                    conv2d_hparams=self.conv2d_hparams,
                    batch_norm_hparams=self.batch_norm_hparams,
                    block_name="btlnck_block_3_%d" % (block_id + 1))

            for block_id in range(self.model_hparams.layer_counts[3]):
                stride = 2 if block_id == 0 else 1

                net = blocks.bottleneck_block(
                    inputs=net,
                    depth=2048,
                    depth_bottleneck=512,
                    stride=stride,
                    training=training,
                    data_format=self.model_hparams.compute_format,
                    conv2d_hparams=self.conv2d_hparams,
                    batch_norm_hparams=self.batch_norm_hparams,
                    block_name="btlnck_block_4_%d" % (block_id + 1))

            with tf.variable_scope("output"):

                net = layers.reduce_mean(
                    net,
                    keepdims=False,
                    data_format=self.model_hparams.compute_format,
                    name='spatial_mean')

                logits = layers.dense(
                    inputs=net,
                    units=self.model_hparams.n_classes,
                    use_bias=True,
                    trainable=training,
                    kernel_initializer=self.dense_hparams.kernel_initializer,
                    bias_initializer=self.dense_hparams.bias_initializer)

                if logits.dtype != tf.float32:
                    logits = tf.cast(logits, tf.float32, name="logits")

                probs = layers.softmax(logits, name="softmax", axis=1)

            return probs, logits
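
For reference, the per-stage unit counts this four-stage layout takes at the
standard ResNet depths (assuming `layer_counts` follows the He et al.
convention):

# Standard bottleneck unit counts per stage (He et al., "Deep Residual
# Learning for Image Recognition"); `layer_counts` is assumed to be one
# of these lists.
resnet_layer_counts = {
    50: [3, 4, 6, 3],
    101: [3, 4, 23, 3],
    152: [3, 8, 36, 3],
}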