Example #1
    def build_model(self, inputs, training=True, reuse=False):
        
        with var_storage.model_variable_scope(
            self.model_hparams.model_name,
            reuse=reuse,
            dtype=self.model_hparams.dtype):

            with tf.variable_scope("input_reshape"):
                if self.model_hparams.input_format == 'NHWC' and self.model_hparams.compute_format == 'NCHW':
                    # Reshape inputs: NHWC => NCHW
                    inputs = tf.transpose(inputs, [0, 3, 1, 2])

                elif self.model_hparams.input_format == 'NCHW' and self.model_hparams.compute_format == 'NHWC':
                    # Reshape inputs: NCHW => NHWC
                    inputs = tf.transpose(inputs, [0, 2, 3, 1])

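            # Cast inputs to the model's compute dtype (e.g. fp16 under mixed precision).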
            if self.model_hparams.dtype != inputs.dtype:
                inputs = tf.cast(inputs, self.model_hparams.dtype)

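            # Stem: 7x7/2 convolution (batch norm + ReLU), then 3x3/2 max pooling.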
            net = blocks.conv2d_block(
                inputs,
                n_channels=64,
                kernel_size=(7, 7),
                strides=(2, 2),
                mode='SAME',
                use_batch_norm=True,
                activation='relu',
                is_training=training,
                data_format=self.model_hparams.compute_format,
                conv2d_hparams=self.conv2d_hparams,
                batch_norm_hparams=self.batch_norm_hparams,
                name='conv2d'
            )

            net = layers.max_pooling2d(
                net,
                pool_size=(3, 3),
                strides=(2, 2),
                padding='SAME',
                data_format=self.model_hparams.compute_format,
                name="max_pooling2d",
            )

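            # Residual stages: each entry in layers_depth defines one stage, and the
            # first unit of every stage after the first downsamples with stride 2.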
            model_bottlenecks = self.model_hparams.layers_depth
            for block_id, block_bottleneck in enumerate(model_bottlenecks):
                for layer_id in range(self.model_hparams.layers_count[block_id]):
                    stride = 2 if (layer_id == 0 and block_id != 0) else 1

                    net = blocks.bottleneck_block(
                        inputs=net,
                        depth=block_bottleneck * self.model_hparams.expansions,
                        depth_bottleneck=block_bottleneck,
                        cardinality=self.model_hparams.cardinality,
                        stride=stride,
                        training=training,
                        data_format=self.model_hparams.compute_format,
                        conv2d_hparams=self.conv2d_hparams,
                        batch_norm_hparams=self.batch_norm_hparams,
                        block_name="btlnck_block_%d_%d" % (block_id, layer_id),
                        use_se=self.model_hparams.use_se,
                        ratio=self.model_hparams.se_ratio)

            with tf.variable_scope("output"):
                net = layers.reduce_mean(
                    net, keepdims=False, data_format=self.model_hparams.compute_format, name='spatial_mean')

                logits = layers.dense(
                    inputs=net,
                    units=self.model_hparams.n_classes,
                    use_bias=True,
                    trainable=training,
                    kernel_initializer=self.dense_hparams.kernel_initializer,
                    bias_initializer=self.dense_hparams.bias_initializer)

                if logits.dtype != tf.float32:
                    logits = tf.cast(logits, tf.float32)

                probs = layers.softmax(logits, name="softmax", axis=1)

            return probs, logits
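
The stacked bottleneck units above do the bulk of the work. As a rough sketch of what one such unit computes (the repo's blocks.bottleneck_block additionally threads through data_format, the conv2d/batch-norm hparams, cardinality, and the optional squeeze-and-excitation path, none of which are reproduced here), a plain TF 1.x version of the 1x1 -> 3x3 -> 1x1 pattern with a projection shortcut might look like:

import tensorflow as tf

def bottleneck_block_sketch(inputs, depth, depth_bottleneck, stride, training):
    # Assumes NHWC (channels-last) inputs with a static channel dimension.
    # Projection shortcut when the spatial size or channel count changes.
    shortcut = inputs
    if stride != 1 or inputs.shape[-1] != depth:
        shortcut = tf.layers.conv2d(inputs, depth, 1, strides=stride, use_bias=False)

    # 1x1 reduce -> 3x3 spatial (carries the stride) -> 1x1 expand,
    # each convolution followed by batch norm.
    net = tf.layers.conv2d(inputs, depth_bottleneck, 1, use_bias=False)
    net = tf.nn.relu(tf.layers.batch_normalization(net, training=training))
    net = tf.layers.conv2d(net, depth_bottleneck, 3, strides=stride,
                           padding='same', use_bias=False)
    net = tf.nn.relu(tf.layers.batch_normalization(net, training=training))
    net = tf.layers.conv2d(net, depth, 1, use_bias=False)
    net = tf.layers.batch_normalization(net, training=training)

    # Residual add, then the final activation.
    return tf.nn.relu(net + shortcut)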
Example #2
    def build_model(self, inputs, training=True, reuse=False):

        with var_storage.model_variable_scope(self.model_hparams.model_name,
                                              reuse=reuse,
                                              dtype=self.model_hparams.dtype):

            with tf.variable_scope("input_reshape"):

                if self.model_hparams.input_format == 'NHWC' and self.model_hparams.compute_format == 'NCHW':
                    # Reshape inputs: NHWC => NCHW
                    inputs = tf.transpose(inputs, [0, 3, 1, 2])

                elif self.model_hparams.input_format == 'NCHW' and self.model_hparams.compute_format == 'NHWC':

                    # Reshape inputs: NCHW => NHWC
                    inputs = tf.transpose(inputs, [0, 2, 3, 1])

            if self.model_hparams.dtype != inputs.dtype:
                inputs = tf.cast(inputs, self.model_hparams.dtype)

            net = blocks.conv2d_block(
                inputs,
                n_channels=64,
                kernel_size=(7, 7),
                strides=(2, 2),
                mode='SAME_RESNET',
                use_batch_norm=True,
                activation='relu',
                is_training=training,
                data_format=self.model_hparams.compute_format,
                conv2d_hparams=self.conv2d_hparams,
                batch_norm_hparams=self.batch_norm_hparams,
                name='conv2d')

            net = layers.max_pooling2d(
                net,
                pool_size=(3, 3),
                strides=(2, 2),
                padding='SAME',
                data_format=self.model_hparams.compute_format,
                name="max_pooling2d",
            )
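
            # Four bottleneck stages with the standard ResNet-50 output depths
            # (256, 512, 1024, 2048); stages 2-4 downsample in their first unit.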

            for block_id in range(self.model_hparams.layer_counts[0]):
                net = blocks.bottleneck_block(
                    inputs=net,
                    depth=256,
                    depth_bottleneck=64,
                    stride=1,
                    training=training,
                    data_format=self.model_hparams.compute_format,
                    conv2d_hparams=self.conv2d_hparams,
                    batch_norm_hparams=self.batch_norm_hparams,
                    block_name="btlnck_block_1_%d" % (block_id + 1))

            for block_id in range(self.model_hparams.layer_counts[1]):
                stride = 2 if block_id == 0 else 1

                net = blocks.bottleneck_block(
                    inputs=net,
                    depth=512,
                    depth_bottleneck=128,
                    stride=stride,
                    training=training,
                    data_format=self.model_hparams.compute_format,
                    conv2d_hparams=self.conv2d_hparams,
                    batch_norm_hparams=self.batch_norm_hparams,
                    block_name="btlnck_block_2_%d" % (block_id + 1))

            for block_id in range(self.model_hparams.layer_counts[2]):
                stride = 2 if block_id == 0 else 1

                net = blocks.bottleneck_block(
                    inputs=net,
                    depth=1024,
                    depth_bottleneck=256,
                    stride=stride,
                    training=training,
                    data_format=self.model_hparams.compute_format,
                    conv2d_hparams=self.conv2d_hparams,
                    batch_norm_hparams=self.batch_norm_hparams,
                    block_name="btlnck_block_3_%d" % (block_id + 1))

            for block_id in range(self.model_hparams.layer_counts[3]):
                stride = 2 if block_id == 0 else 1

                net = blocks.bottleneck_block(
                    inputs=net,
                    depth=2048,
                    depth_bottleneck=512,
                    stride=stride,
                    training=training,
                    data_format=self.model_hparams.compute_format,
                    conv2d_hparams=self.conv2d_hparams,
                    batch_norm_hparams=self.batch_norm_hparams,
                    block_name="btlnck_block_4_%d" % (block_id + 1))

            with tf.variable_scope("output"):

                net = layers.reduce_mean(
                    net,
                    keepdims=False,
                    data_format=self.model_hparams.compute_format,
                    name='spatial_mean')

                logits = layers.dense(
                    inputs=net,
                    units=self.model_hparams.n_classes,
                    use_bias=True,
                    trainable=training,
                    kernel_initializer=self.dense_hparams.kernel_initializer,
                    bias_initializer=self.dense_hparams.bias_initializer)

                if logits.dtype != tf.float32:
                    logits = tf.cast(logits, tf.float32, name="logits")

                probs = layers.softmax(logits, name="softmax", axis=1)

            return probs, logits
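
For reference, both examples share the same output head: a global spatial mean, a dense projection to n_classes, a cast back to fp32 (so the softmax and any loss stay in full precision under mixed-precision training), and a softmax. A minimal standalone sketch using plain TF 1.x ops in place of the repo's layers helpers (names here are illustrative):

import tensorflow as tf

def classification_head_sketch(net, n_classes, compute_format='NCHW'):
    # Global average pool; the spatial axes depend on the data layout.
    spatial_axes = [2, 3] if compute_format == 'NCHW' else [1, 2]
    net = tf.reduce_mean(net, axis=spatial_axes, name='spatial_mean')

    logits = tf.layers.dense(net, n_classes, use_bias=True)

    # Keep the logits in fp32 even when the backbone ran in fp16.
    if logits.dtype != tf.float32:
        logits = tf.cast(logits, tf.float32)

    probs = tf.nn.softmax(logits, axis=1, name='softmax')
    return probs, logits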