Exemple #1
0
    def forward(self, inputs, drop_connect_rate=None):
        """
        Apply the block to ``inputs``.

        Stages, in order: optional expansion convolution, depthwise
        convolution, optional squeeze-and-excitation gating, projection
        convolution, and an optional residual connection back to ``inputs``.

        :param inputs: input tensor
        :param drop_connect_rate: drop connect rate (float, between 0 and 1);
            a falsy value (``None`` or ``0``) skips drop connect
        :return: output of block
        """
        block_args = self._block_args

        # Expansion (skipped when the expand ratio is exactly 1)
        out = inputs
        if block_args.expand_ratio != 1:
            out = self._expand_conv(inputs)
            out = self._bn0(out)
            out = relu_fn(out)

        # Depthwise convolution
        out = self._depthwise_conv(out)
        out = self._bn1(out)
        out = relu_fn(out)

        # Squeeze and Excitation: pool to 1x1, reduce, expand, then use the
        # sigmoid of the result as a multiplicative gate on the features.
        if self.has_se:
            gate = F.adaptive_avg_pool2d(out, 1)
            gate = self._se_reduce(gate)
            gate = relu_fn(gate)
            gate = self._se_expand(gate)
            out = torch.sigmoid(gate) * out

        # Projection (no activation after the batch norm)
        out = self._project_conv(out)
        out = self._bn2(out)

        # The residual path is only taken when skipping is enabled, the
        # stride equals 1 and the filter counts match.
        in_filters = block_args.input_filters
        out_filters = block_args.output_filters
        use_skip = (self.id_skip and block_args.stride == 1
                    and in_filters == out_filters)
        if not use_skip:
            return out

        # Drop connect is applied to the block output before adding the input.
        if drop_connect_rate:
            out = drop_connect(out, p=drop_connect_rate, training=self.training)
        return out + inputs  # skip connection
Exemple #2
0
    def extract_features(self, inputs):
        """ Returns output of the final convolution layer.

        Runs the stem, every block in ``self._blocks`` (each with a drop
        connect rate scaled by the block's position), and the head.
        """

        # Stem
        features = self._conv_stem(inputs)
        features = relu_fn(self._bn0(features))

        # Blocks: the global drop connect rate is scaled by idx / num_blocks,
        # so the first block gets 0 and later blocks get larger rates.
        for block_idx, block in enumerate(self._blocks):
            rate = self._global_params.drop_connect_rate
            if rate:
                depth_fraction = float(block_idx) / len(self._blocks)
                rate *= depth_fraction
            features = block(features, drop_connect_rate=rate)

        # Head
        features = self._conv_head(features)
        features = self._bn1(features)
        return relu_fn(features)
Exemple #3
0
    def forward(self, inputs):
        """
        Run the stem and all blocks, then apply every head to the result.

        :param inputs: input tensor
        :return: a single-element list holding a dict that maps each name in
            ``self._heads`` to the output of the attribute of that name
            called on the final block output
        """
        # Stem
        x = relu_fn(self._bn0(self._conv_stem(inputs)))

        # Blocks
        # NOTE: no drop connect rate is passed to the blocks. The original
        # code computed a per-block rate but left the argument commented out
        # with a pointer to https://github.com/tensorflow/tpu/issues/381, so
        # the unused computation has been removed.
        for block in self._blocks:
            x = block(x)

        # Heads: look each head up by name on this object and apply it to the
        # shared feature map.
        ret = {head: getattr(self, head)(x) for head in self._heads}
        return [ret]
def build_bifpn_layer(feats,
                      feat_sizes,
                      fpn_name,
                      fpn_config,
                      is_training,
                      fpn_num_filters,
                      min_level,
                      max_level,
                      separable_conv,
                      apply_bn_for_resampling,
                      conv_after_downsample,
                      use_native_resize_op,
                      conv_bn_relu_pattern,
                      pooling_type,
                      use_tpu=False):
    """Builds a feature pyramid given previous feature pyramid and config.

    For every node in ``config.nodes`` the inputs named by the node's
    ``'inputs_offsets'`` are resampled to a common width, fused according to
    ``config.weight_method``, passed through a 3x3 convolution and batch
    norm, and appended to ``feats``.

    Args:
      feats: list of feature maps, indexed by each node's
        ``'inputs_offsets'``. NOTE: this list is extended in place with one
        new entry per node in the config.
      feat_sizes: sequence indexed by each node's ``'width_index'``; the
        looked-up value is passed to ``resample_feature_map`` as the target
        width.
      fpn_name: name passed to ``get_fpn_config`` when ``fpn_config`` is
        falsy.
      fpn_config: config object with ``nodes`` and ``weight_method``; takes
        precedence over ``fpn_name`` when truthy.
      is_training: forwarded to ``resample_feature_map`` and, as
        ``is_training_bn``, to ``utils.batch_norm_relu``.
      fpn_num_filters: number of filters for resampling and for the
        convolution applied after fusing.
      min_level: first level (inclusive) to collect into the result.
      max_level: last level (inclusive) to collect into the result.
      separable_conv: if truthy use ``tf.layers.separable_conv2d`` (with
        ``depth_multiplier=1``), otherwise ``tf.layers.conv2d``.
      apply_bn_for_resampling: forwarded to ``resample_feature_map``.
      conv_after_downsample: forwarded to ``resample_feature_map``.
      use_native_resize_op: forwarded to ``resample_feature_map``.
      conv_bn_relu_pattern: if falsy, the activation is applied before the
        convolution, the convolution has a bias, and the batch norm is not
        followed by a relu; if truthy, the convolution has no bias and the
        batch norm is followed by a relu.
      pooling_type: forwarded to ``resample_feature_map``.
      use_tpu: forwarded to ``utils.batch_norm_relu``.

    Returns:
      A dict mapping each level ``l`` in ``[min_level, max_level]`` to the
      most recently created node whose ``'width_index'`` equals ``l``.
      Levels without a matching node are absent from the dict.

    Raises:
      ValueError: if ``config.weight_method`` is not one of ``'attn'``,
        ``'fastattn'`` or ``'sum'``.
    """
    config = fpn_config or get_fpn_config(fpn_name)

    for node_idx, fnode in enumerate(config.nodes):
        with tf.variable_scope('fnode{}'.format(node_idx)):
            logging.info('fnode %d : %s', node_idx, fnode)
            new_node_width = feat_sizes[fnode['width_index']]
            input_offsets = fnode['inputs_offsets']

            # Bring every input of this node to the same width / filters.
            # len(feats) is part of the name, so names differ per node as
            # feats grows.
            nodes = []
            for idx, input_offset in enumerate(input_offsets):
                resampled = resample_feature_map(
                    feats[input_offset],
                    '{}_{}_{}'.format(idx, input_offset, len(feats)),
                    new_node_width, fpn_num_filters, apply_bn_for_resampling,
                    is_training, conv_after_downsample, use_native_resize_op,
                    pooling_type)
                nodes.append(resampled)

            # Combine all nodes.
            dtype = nodes[0].dtype
            if config.weight_method == 'attn':
                # Softmax-normalized learnable weight per input edge.
                edge_weights = [
                    tf.cast(tf.Variable(1.0, name='WSM'), dtype=dtype)
                    for _ in input_offsets
                ]
                normalized_weights = tf.nn.softmax(tf.stack(edge_weights))
                nodes = tf.stack(nodes, axis=-1)
                new_node = tf.reduce_sum(
                    tf.multiply(nodes, normalized_weights), -1)
            elif config.weight_method == 'fastattn':
                # Relu keeps weights non-negative; they are normalized by
                # their sum, with a small constant to avoid division by zero.
                edge_weights = [
                    tf.nn.relu(
                        tf.cast(tf.Variable(1.0, name='WSM'), dtype=dtype))
                    for _ in input_offsets
                ]
                weights_sum = tf.add_n(edge_weights)
                nodes = [
                    node * weight / (weights_sum + 0.0001)
                    for node, weight in zip(nodes, edge_weights)
                ]
                new_node = tf.add_n(nodes)
            elif config.weight_method == 'sum':
                new_node = tf.add_n(nodes)
            else:
                raise ValueError('unknown weight_method {}'.format(
                    config.weight_method))

            with tf.variable_scope('op_after_combine{}'.format(len(feats))):
                # Without the conv-bn-relu pattern the activation comes
                # first (activation -> conv -> bn).
                if not conv_bn_relu_pattern:
                    new_node = utils.relu_fn(new_node)

                if separable_conv:
                    conv_op = functools.partial(tf.layers.separable_conv2d,
                                                depth_multiplier=1)
                else:
                    conv_op = tf.layers.conv2d

                new_node = conv_op(
                    new_node,
                    filters=fpn_num_filters,
                    kernel_size=(3, 3),
                    padding='same',
                    use_bias=not conv_bn_relu_pattern,
                    name='conv')

                new_node = utils.batch_norm_relu(
                    new_node,
                    is_training_bn=is_training,
                    relu=bool(conv_bn_relu_pattern),
                    data_format='channels_last',
                    use_tpu=use_tpu,
                    name='bn')

            feats.append(new_node)

    # Walk the config nodes from last to first so that, for every level, the
    # most recently created matching node is picked. Node k from the end of
    # config.nodes corresponds to feats[-1 - k] because each node appended
    # exactly one entry to feats.
    output_feats = {}
    for level in range(min_level, max_level + 1):
        for back_idx, fnode in enumerate(reversed(config.nodes)):
            if fnode['width_index'] == level:
                output_feats[level] = feats[-1 - back_idx]
                break
    return output_feats