Esempio n. 1
0
def refine_by_decoder(features,
                      end_points,
                      decoder_height,
                      decoder_width,
                      decoder_use_separable_conv=False,
                      model_variant=None,
                      weight_decay=0.0001,
                      reuse=tf.AUTO_REUSE,
                      is_training=False,
                      fine_tune_batch_norm=False):
    """Adds the decoder to obtain sharper segmentation results.

    For every decoder end point registered for `model_variant`, the current
    decoder features and a 48-channel 1x1 projection of the matching low-level
    end point are bilinearly resized to [decoder_height, decoder_width],
    concatenated along the channel axis, and refined by two 3x3 convolutions
    (separable or regular) with 256 output channels.

    Args:
      features: A tensor of size [batch, features_height, features_width,
        features_channels].
      end_points: A dictionary from components of the network to the corresponding
        activation. The prefix of its first key (the text before the first '/')
        is used as the scope name when looking up the decoder end points.
      decoder_height: The height of decoder feature maps.
      decoder_width: The width of decoder feature maps.
      decoder_use_separable_conv: Employ separable convolution for decoder or not.
      model_variant: Model variant for feature extraction.
      weight_decay: The weight decay for model variables.
      reuse: Reuse the model variables or not.
      is_training: Is training or not.
      fine_tune_batch_norm: Fine-tune the batch norm parameters or not.

    Returns:
      Decoder output with size [batch, decoder_height, decoder_width,
        decoder_channels]. If `model_variant` has no decoder end points,
      `features` is returned unchanged.
    """
    batch_norm_params = {
        'is_training': is_training and fine_tune_batch_norm,
        'decay': 0.9997,
        'eps': 1e-5,
        'affine': True,
    }
    regularize_func = regularizer('l2', weight_decay)
    # Output channels of both 3x3 refinement convolutions.
    decoder_depth = 256
    with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
        with arg_scope([sep_conv2d], activate=tf.nn.relu, activate_middle=tf.nn.relu,
                       batch_norm=True, depthwise_weight_reg=None, pointwise_weight_reg=regularize_func,
                       padding='SAME', strides=[1, 1]):
            with arg_scope([conv2d], activate=tf.nn.relu, weight_reg=regularize_func,
                           batch_norm=True, padding='SAME', strides=[1, 1]):
                with arg_scope([batch_norm2d], **batch_norm_params):
                    with tf.variable_scope(_DECODER_SCOPE, _DECODER_SCOPE, [features]):
                        feature_list = feature_extractor.networks_to_feature_maps[
                            model_variant][feature_extractor.DECODER_END_POINTS]
                        if feature_list is None:
                            tf.logging.info('Not found any decoder end points.')
                            return features

                        decoder_features = features
                        # NOTE(review): with more than one decoder end point, later
                        # iterations reuse the names 'decoder_conv0'/'decoder_conv1';
                        # confirm that sharing those variables is intended.
                        for i, name in enumerate(feature_list):
                            decoder_features_list = [decoder_features]

                            # End points are looked up as '<prefix>/<name>', where the
                            # prefix is taken from the first key of `end_points`.
                            suffix = list(end_points.keys())[0].split('/')[0]
                            feature_name = '{}/{}'.format(suffix, name)
                            # 1x1 projection of the low-level features to 48 channels.
                            decoder_features_list.append(
                                conv2d(
                                    inputs=end_points[feature_name],
                                    outc=48,
                                    ksize=[1, 1],
                                    name='feature_projection' + str(i)))
                            # Resize both inputs to decoder_height/decoder_width.
                            for j, feature in enumerate(decoder_features_list):
                                decoder_features_list[j] = tf.image.resize_bilinear(
                                    feature, [decoder_height, decoder_width], align_corners=True)
                                decoder_features_list[j].set_shape(
                                    [None, decoder_height, decoder_width, None])

                            if decoder_use_separable_conv:
                                # Two [3, 3] separable convolutions.
                                decoder_features = sep_conv2d(
                                    inputs=tf.concat(decoder_features_list, 3),
                                    ksize=[3, 3],
                                    outc=decoder_depth,
                                    ratios=[1, 1],
                                    name='decoder_conv0')
                                decoder_features = sep_conv2d(
                                    inputs=decoder_features,
                                    ksize=[3, 3],
                                    outc=decoder_depth,
                                    ratios=[1, 1],
                                    name='decoder_conv1')
                                DEBUG_VARS.decoder_features = decoder_features
                            else:
                                # Two [3, 3] regular convolutions. `outc` is an int, as in
                                # every other conv2d call. The two convs need distinct
                                # names; presumably conv2d scopes its variables by `name`,
                                # so a shared name would make them clash or share weights.
                                decoder_features = conv2d(
                                    inputs=tf.concat(decoder_features_list, 3),
                                    outc=decoder_depth,
                                    ksize=[3, 3],
                                    name='decoder_conv0')
                                decoder_features = conv2d(
                                    inputs=decoder_features,
                                    outc=decoder_depth,
                                    ksize=[3, 3],
                                    name='decoder_conv1')
                        return decoder_features
Esempio n. 2
0
def xception_module(inputs,
                    depth_list,
                    skip_connection_type,
                    strides,
                    unit_rate_list=None,
                    rate=1,
                    activation_fn_in_separable_conv=False,
                    outputs_collections=None,
                    scope=None):
    """An Xception module.

    The output of one Xception module is equal to the sum of `residual` and
    `shortcut`, where `residual` is the feature computed by three separable
    convolutions. The `shortcut` is the feature computed by 1x1 convolution with
    or without striding. In some cases, the `shortcut` path could be a simple
    identity function or none (i.e, no shortcut).

    Note that we replace the max pooling operations in the Xception module with
    another separable convolution with striding, since atrous rate is not properly
    supported in current TensorFlow max pooling implementation.

    Args:
      inputs: A tensor of size [batch, height, width, channels].
      depth_list: A list of three integers specifying the depth values of one
        Xception module.
      skip_connection_type: Skip connection type for the residual path. Only
        supports 'conv', 'sum', or 'none'.
      strides: The block unit's stride. It is applied to the third separable
        convolution and to the 'conv' shortcut, and determines the amount of
        downsampling of the unit's output compared to its input.
      unit_rate_list: A list of three integers, determining the unit rate for
        each separable convolution in the xception module. None means
        [1, 1, 1].
      rate: An integer, rate for atrous convolution.
      activation_fn_in_separable_conv: Activation applied between the depthwise
        and pointwise convolutions. A callable (e.g. tf.nn.relu) is used as-is,
        and True is treated as tf.nn.relu. In both cases tf.nn.relu is also used
        as the output activation of each separable convolution. A falsy value
        (None or False) disables both activations; instead, a ReLU is applied
        to the input of each separable convolution (pre-activation).
      outputs_collections: Collection to add the Xception unit output.
      scope: Optional variable_scope.

    Returns:
      The Xception module's output.

    Raises:
      ValueError: If depth_list or unit_rate_list does not contain exactly three
        elements, or if skip_connection_type is unsupported.
    """
    if len(depth_list) != 3:
        raise ValueError('Expect three elements in depth_list.')
    if unit_rate_list is None:
        # No per-unit rates given: every separable conv uses just `rate`.
        unit_rate_list = [1, 1, 1]
    elif len(unit_rate_list) != 3:
        raise ValueError('Expect three elements in unit_rate_list.')
    # Validate before building any ops so a bad config does not leave
    # orphaned separable convolutions in the graph.
    if skip_connection_type not in ('conv', 'sum', 'none'):
        raise ValueError('Unsupported skip connection type.')

    if activation_fn_in_separable_conv is True:
        # Boolean flag (DeepLab convention): use ReLU inside the separable conv.
        middle_activation = tf.nn.relu
    elif activation_fn_in_separable_conv:
        middle_activation = activation_fn_in_separable_conv
    else:
        middle_activation = None

    with tf.variable_scope(scope, 'xception_module', [inputs]):
        residual = inputs

        for i in range(3):
            if middle_activation is None:
                # Pre-activation: ReLU the input, no activation inside the conv.
                residual = tf.nn.relu(residual)
                activate_fn = None
            else:
                activate_fn = tf.nn.relu
            residual = sep_conv2d(
                inputs=residual,
                outc=depth_list[i],
                ksize=[3, 3],
                depth_multiplier=1,
                ratios=[rate * unit_rate_list[i], rate * unit_rate_list[i]],
                activate_middle=middle_activation,
                activate=activate_fn,
                # Only the last separable conv downsamples (replaces max pooling).
                strides=strides if i == 2 else [1, 1],
                name='separable_conv' + str(i + 1))
        if skip_connection_type == 'conv':
            shortcut = conv2d(inputs=inputs,
                              outc=depth_list[-1],
                              ksize=[1, 1],
                              strides=strides,
                              activate=None,
                              name='shortcut')
            outputs = residual + shortcut
        elif skip_connection_type == 'sum':
            outputs = residual + inputs
        else:  # 'none'
            outputs = residual

        add_to_collection(outputs_collections, outputs)
        return outputs
Esempio n. 3
0
def _extract_features(images,
                      model_options,
                      weight_decay=0.0001,
                      reuse=tf.AUTO_REUSE,
                      is_training=False,
                      fine_tune_batch_norm=False):
    """Runs the backbone and, optionally, an ASPP head on `images`.

    The backbone is chosen by `model_options.model_variant`. If
    `model_options.aspp_with_batch_norm` is false, the raw backbone features are
    returned. Otherwise an ASPP head is built on top of them: an optional
    image-level pooling branch, a 1x1 convolution branch, one 3x3 atrous branch
    per entry of `model_options.atrous_rates`, a 1x1 projection of the
    concatenated branches, and dropout with keep probability 0.9.

    Args:
      images: A tensor of size [batch, height, width, channels].
      model_options: A ModelOptions instance to configure models.
      weight_decay: The weight decay for model variables.
      reuse: Reuse the model variables or not.
      is_training: Is training or not.
      fine_tune_batch_norm: Fine-tune the batch norm parameters or not.

    Returns:
      concat_logits: A tensor of size [batch, feature_height, feature_width,
        feature_channels], where feature_height/feature_width are determined by
        the images height/width and output_stride.
      end_points: A dictionary from components of the network to the corresponding
        activation.
    """
    # `feature_extractor` acts as a factory for the backbone network.
    DEBUG_VARS.raw_image = images
    features, end_points = feature_extractor.extract_features(
        images,
        output_stride=model_options.output_stride,
        multi_grid=model_options.multi_grid,
        model_variant=model_options.model_variant,
        weight_decay=weight_decay,
        reuse=reuse,
        is_training=is_training,
        fine_tune_batch_norm=fine_tune_batch_norm)
    DEBUG_VARS.xception_feature = features

    # Without batch norm in the ASPP head, the backbone output is used as-is.
    if not model_options.aspp_with_batch_norm:
        return features, end_points

    bn_params = dict(is_training=is_training and fine_tune_batch_norm,
                     decay=0.9997,
                     eps=1e-5,
                     affine=True)
    l2_reg = regularizer('l2', weight_decay)
    with tf.variable_scope(tf.get_variable_scope(), reuse=reuse), \
            arg_scope([sep_conv2d], activate=tf.nn.relu, activate_middle=tf.nn.relu,
                      batch_norm=True, depthwise_weight_reg=None,
                      pointwise_weight_reg=l2_reg, padding='SAME', strides=[1, 1]), \
            arg_scope([conv2d], activate=tf.nn.relu, weight_reg=l2_reg,
                      batch_norm=True, padding='SAME', strides=[1, 1]), \
            arg_scope([batch_norm2d], **bn_params):
        depth = 256
        branches = []

        if model_options.add_image_level_feature:
            # crop_size has already been rescaled by the caller to this model's
            # actual input size.
            pool_h = scale_dimension(model_options.crop_size[0],
                                     1. / model_options.output_stride)
            pool_w = scale_dimension(model_options.crop_size[1],
                                     1. / model_options.output_stride)
            # Kernel and stride both equal [pool_h, pool_w]; presumably this is
            # the feature-map size, i.e. a global average pool -- TODO confirm.
            pooled = avg_pool2d(features, [pool_h, pool_w], [pool_h, pool_w],
                                padding='VALID')
            # Project the pooled features to `depth` channels.
            pooled = conv2d(inputs=pooled, outc=depth, ksize=[1, 1],
                            name=_IMAGE_POOLING_SCOPE)
            DEBUG_VARS.image_feature = pooled
            # Upsample back to [pool_h, pool_w] so it can be concatenated.
            pooled = tf.image.resize_bilinear(pooled, [pool_h, pool_w],
                                              align_corners=True)
            pooled.set_shape([None, pool_h, pool_w, depth])
            branches.append(pooled)

        # Plain 1x1 convolution branch.
        branches.append(conv2d(features, outc=depth, ksize=[1, 1],
                               name=_ASPP_SCOPE + '0'))

        if model_options.atrous_rates:
            DEBUG_VARS.aspp_features = []
            # One 3x3 atrous branch per rate; scope indices continue from 1.
            for branch_idx, atrous_rate in enumerate(model_options.atrous_rates, start=1):
                branch_scope = _ASPP_SCOPE + str(branch_idx)
                if not model_options.aspp_with_separable_conv:
                    branch = conv2d(features, outc=depth, ksize=[3, 3],
                                    ratios=[atrous_rate, atrous_rate], name=branch_scope)
                else:
                    branch = sep_conv2d(features, outc=depth, ksize=[3, 3],
                                        ratios=[atrous_rate, atrous_rate], name=branch_scope)
                    DEBUG_VARS.aspp_features.append(branch)
                branches.append(branch)

        # Merge all branches along channels, project, then apply dropout.
        merged = tf.concat(branches, 3)
        DEBUG_VARS.aspp_concat_feature = merged
        projected = conv2d(inputs=merged, outc=depth, ksize=[1, 1],
                           name=_CONCAT_PROJECTION_SCOPE)
        concat_logits = drop_out(projected, kp_prob=0.9, is_training=is_training,
                                 name=_CONCAT_PROJECTION_SCOPE + '_dropout')
        return concat_logits, end_points