def refine_by_decoder(features,
                      end_points,
                      decoder_height,
                      decoder_width,
                      decoder_use_separable_conv=False,
                      model_variant=None,
                      weight_decay=0.0001,
                      reuse=tf.AUTO_REUSE,
                      is_training=False,
                      fine_tune_batch_norm=False):
    """Adds the decoder to obtain sharper segmentation results.

    For every decoder end point registered for `model_variant`, the matching
    low-level activation is projected to 48 channels with a 1x1 convolution,
    resized (together with the running decoder features) to
    [decoder_height, decoder_width], concatenated along the channel axis and
    refined by two 3x3 convolutions with 256 output channels.

    Args:
      features: A tensor of size [batch, features_height, features_width,
        features_channels].
      end_points: A dictionary from components of the network to the
        corresponding activation.
      decoder_height: The height of decoder feature maps.
      decoder_width: The width of decoder feature maps.
      decoder_use_separable_conv: Employ separable convolution for decoder or
        not.
      model_variant: Model variant for feature extraction.
      weight_decay: The weight decay for model variables.
      reuse: Reuse the model variables or not.
      is_training: Is training or not.
      fine_tune_batch_norm: Fine-tune the batch norm parameters or not.

    Returns:
      Decoder output with size [batch, decoder_height, decoder_width,
      decoder_channels]. If no decoder end points are registered for
      `model_variant` (the lookup yields None), `features` is returned
      unchanged.
    """
    batch_norm_params = {
        'is_training': is_training and fine_tune_batch_norm,
        'decay': 0.9997,
        'eps': 1e-5,
        'affine': True,
    }
    regularize_func = regularizer('l2', weight_decay)
    # Output channels of each 3x3 refinement convolution (loop-invariant).
    decoder_depth = 256

    with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
        with arg_scope([sep_conv2d],
                       activate=tf.nn.relu,
                       activate_middle=tf.nn.relu,
                       batch_norm=True,
                       depthwise_weight_reg=None,
                       pointwise_weight_reg=regularize_func,
                       padding='SAME',
                       strides=[1, 1]):
            with arg_scope([conv2d],
                           activate=tf.nn.relu,
                           weight_reg=regularize_func,
                           batch_norm=True,
                           padding='SAME',
                           strides=[1, 1]):
                with arg_scope([batch_norm2d], **batch_norm_params):
                    with tf.variable_scope(_DECODER_SCOPE, _DECODER_SCOPE, [features]):
                        feature_list = feature_extractor.networks_to_feature_maps[
                            model_variant][feature_extractor.DECODER_END_POINTS]
                        if feature_list is None:
                            tf.logging.info('Not found any decoder end points.')
                            return features

                        decoder_features = features
                        for i, name in enumerate(feature_list):
                            decoder_features_list = [decoder_features]

                            # End-point keys look like '<scope>/<...>'; the
                            # leading component of the first key is used as the
                            # scope prefix for the decoder end point.
                            # NOTE(review): assumes every end point shares that
                            # leading scope - confirm against the backbone.
                            scope_prefix = list(end_points.keys())[0].split('/')[0]
                            feature_name = '{}/{}'.format(scope_prefix, name)

                            # 1x1 convolution reducing the low-level feature to
                            # 48 channels.
                            decoder_features_list.append(
                                conv2d(
                                    inputs=end_points[feature_name],
                                    outc=48,
                                    ksize=[1, 1],
                                    name='feature_projection' + str(i)))

                            # Resize to decoder_height/decoder_width.
                            for j, feature in enumerate(decoder_features_list):
                                decoder_features_list[j] = tf.image.resize_bilinear(
                                    feature, [decoder_height, decoder_width],
                                    align_corners=True)
                                decoder_features_list[j].set_shape(
                                    [None, decoder_height, decoder_width, None])

                            if decoder_use_separable_conv:
                                # Two 3x3 separable convolutions.
                                decoder_features = sep_conv2d(
                                    inputs=tf.concat(decoder_features_list, 3),
                                    ksize=[3, 3],
                                    outc=decoder_depth,
                                    ratios=[1, 1],
                                    name='decoder_conv0')
                                decoder_features = sep_conv2d(
                                    inputs=decoder_features,
                                    ksize=[3, 3],
                                    outc=decoder_depth,
                                    ratios=[1, 1],
                                    name='decoder_conv1')
                                DEBUG_VARS.decoder_features = decoder_features
                            else:
                                # Two 3x3 regular convolutions. `outc` is an int
                                # like every other conv2d call in this file, and
                                # the second layer gets its own name so it does
                                # not collide with (or try to reuse) the
                                # variables of 'decoder_conv0'.
                                decoder_features = conv2d(
                                    inputs=tf.concat(decoder_features_list, 3),
                                    outc=decoder_depth,
                                    ksize=[3, 3],
                                    name='decoder_conv0')
                                decoder_features = conv2d(
                                    inputs=decoder_features,
                                    outc=decoder_depth,
                                    ksize=[3, 3],
                                    name='decoder_conv1')
                        return decoder_features
def xception_module(inputs,
                    depth_list,
                    skip_connection_type,
                    strides,
                    unit_rate_list=None,
                    rate=1,
                    activation_fn_in_separable_conv=False,
                    outputs_collections=None,
                    scope=None):
    """An Xception module.

    The output of one Xception module is equal to the sum of `residual` and
    `shortcut`, where `residual` is the feature computed by three separable
    convolutions. The `shortcut` is the feature computed by a 1x1 convolution
    with or without striding. In some cases, the `shortcut` path could be a
    simple identity function or none (i.e., no shortcut).

    Note that the max pooling operations of the Xception module are replaced
    with another separable convolution with striding, since atrous rate is not
    properly supported in the current TensorFlow max pooling implementation.

    Args:
      inputs: A tensor of size [batch, height, width, channels].
      depth_list: A list of three integers specifying the depth values of one
        Xception module.
      skip_connection_type: Skip connection type for the residual path. Only
        supports 'conv', 'sum', or 'none'.
      strides: The block unit's stride. Applied to the third separable
        convolution and to the 'conv' shortcut; the first two separable
        convolutions always use [1, 1].
      unit_rate_list: A list of three integers, determining the unit rate for
        each separable convolution in the xception module. When None or
        empty, [1, 1, 1] is used.
      rate: An integer, rate for atrous convolution.
      activation_fn_in_separable_conv: Activation passed to `sep_conv2d` as
        `activate_middle`. When given, ReLU is also applied as the output
        activation of each separable convolution. When falsy (None or the
        default False), no activation is used inside the separable
        convolution and ReLU is applied to its input instead.
      outputs_collections: Collection to add the Xception unit output.
      scope: Optional variable_scope.

    Returns:
      The Xception module's output.

    Raises:
      ValueError: If depth_list or a non-empty unit_rate_list does not contain
        three elements, or if the skip connection type is unsupported.
    """
    if len(depth_list) != 3:
        raise ValueError('Expect three elements in depth_list.')
    if unit_rate_list:
        if len(unit_rate_list) != 3:
            raise ValueError('Expect three elements in unit_rate_list.')
    else:
        # The default (None) used to be indexed directly and crashed; a unit
        # rate of 1 leaves `rate` unchanged for every separable convolution.
        unit_rate_list = [1, 1, 1]
    # Fail before any graph op or variable is created.
    if skip_connection_type not in ('conv', 'sum', 'none'):
        raise ValueError('Unsupported skip connection type.')

    # Any falsy value (None or the default False) means "no activation inside
    # the separable convolution"; forward None rather than False downstream.
    middle_activation = activation_fn_in_separable_conv or None
    pre_activate = middle_activation is None
    output_activation = None if pre_activate else tf.nn.relu

    with tf.variable_scope(scope, 'xception_module', [inputs]):
        residual = inputs
        for i in range(3):
            if pre_activate:
                residual = tf.nn.relu(residual)
            unit_rate = rate * unit_rate_list[i]
            residual = sep_conv2d(
                inputs=residual,
                outc=depth_list[i],
                ksize=[3, 3],
                depth_multiplier=1,
                ratios=[unit_rate, unit_rate],
                activate_middle=middle_activation,
                activate=output_activation,
                # Only the last separable convolution carries the stride.
                strides=strides if i == 2 else [1, 1],
                name='separable_conv' + str(i + 1))

        if skip_connection_type == 'conv':
            shortcut = conv2d(
                inputs=inputs,
                outc=depth_list[-1],
                ksize=[1, 1],
                strides=strides,
                activate=None,
                name='shortcut')
            outputs = residual + shortcut
        elif skip_connection_type == 'sum':
            outputs = residual + inputs
        else:
            # 'none': validated above, no shortcut is added.
            outputs = residual

        add_to_collection(outputs_collections, outputs)
        return outputs
def _extract_features(images,
                      model_options,
                      weight_decay=0.0001,
                      reuse=tf.AUTO_REUSE,
                      is_training=False,
                      fine_tune_batch_norm=False):
    """Runs the backbone and, optionally, the ASPP head on top of it.

    Args:
      images: A tensor of size [batch, height, width, channels].
      model_options: A ModelOptions instance to configure models.
      weight_decay: The weight decay for model variables.
      reuse: Reuse the model variables or not.
      is_training: Is training or not.
      fine_tune_batch_norm: Fine-tune the batch norm parameters or not.

    Returns:
      A pair `(features, end_points)`. When
      `model_options.aspp_with_batch_norm` is falsy, `features` is the raw
      backbone output. Otherwise it is the 256-channel projection of the
      concatenated ASPP branches, followed by dropout. `end_points` is the
      dictionary returned by the feature extractor.
    """
    # The feature extractor acts as a backbone factory.
    DEBUG_VARS.raw_image = images
    features, end_points = feature_extractor.extract_features(
        images,
        output_stride=model_options.output_stride,
        multi_grid=model_options.multi_grid,
        model_variant=model_options.model_variant,
        weight_decay=weight_decay,
        reuse=reuse,
        is_training=is_training,
        fine_tune_batch_norm=fine_tune_batch_norm)
    # TODO: check
    DEBUG_VARS.xception_feature = features

    # Without the batch-normalised ASPP head the backbone output is final.
    if not model_options.aspp_with_batch_norm:
        return features, end_points

    bn_defaults = {
        'is_training': is_training and fine_tune_batch_norm,
        'decay': 0.9997,
        'eps': 1e-5,
        'affine': True,
    }
    l2_reg = regularizer('l2', weight_decay)
    sep_conv_defaults = dict(
        activate=tf.nn.relu,
        activate_middle=tf.nn.relu,
        batch_norm=True,
        depthwise_weight_reg=None,
        pointwise_weight_reg=l2_reg,
        padding='SAME',
        strides=[1, 1])
    conv_defaults = dict(
        activate=tf.nn.relu,
        weight_reg=l2_reg,
        batch_norm=True,
        padding='SAME',
        strides=[1, 1])

    with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
        with arg_scope([sep_conv2d], **sep_conv_defaults):
            with arg_scope([conv2d], **conv_defaults):
                # ASPP is implemented below.
                with arg_scope([batch_norm2d], **bn_defaults):
                    aspp_depth = 256
                    branches = []

                    # Image-level (pooled) branch.
                    if model_options.add_image_level_feature:
                        # crop_size has already been rescaled outside to the
                        # exact input size of this model.
                        stride_scale = 1. / model_options.output_stride
                        pool_h = scale_dimension(
                            model_options.crop_size[0], stride_scale)
                        pool_w = scale_dimension(
                            model_options.crop_size[1], stride_scale)

                        # Average-pool with a window (and stride) equal to the
                        # expected feature-map size.
                        pooled = avg_pool2d(
                            features,
                            [pool_h, pool_w],
                            [pool_h, pool_w],
                            padding='VALID')
                        # Project the pooled feature to aspp_depth channels.
                        pooled = conv2d(
                            inputs=pooled,
                            outc=aspp_depth,
                            ksize=[1, 1],
                            name=_IMAGE_POOLING_SCOPE)
                        # TODO: check
                        DEBUG_VARS.image_feature = pooled

                        # Bring it back to the feature-map resolution.
                        pooled = tf.image.resize_bilinear(
                            pooled, [pool_h, pool_w], align_corners=True)
                        pooled.set_shape([None, pool_h, pool_w, aspp_depth])
                        branches.append(pooled)

                    # 1x1 convolution branch.
                    pointwise_branch = conv2d(
                        features,
                        outc=aspp_depth,
                        ksize=[1, 1],
                        name=_ASPP_SCOPE + str(0))
                    branches.append(pointwise_branch)

                    # 3x3 branches, one per atrous rate.
                    if model_options.atrous_rates:
                        DEBUG_VARS.aspp_features = []
                        for index, atrous_rate in enumerate(
                                model_options.atrous_rates, 1):
                            use_separable = model_options.aspp_with_separable_conv
                            atrous_conv = sep_conv2d if use_separable else conv2d
                            branch = atrous_conv(
                                features,
                                outc=aspp_depth,
                                ksize=[3, 3],
                                ratios=[atrous_rate, atrous_rate],
                                name=_ASPP_SCOPE + str(index))
                            if use_separable:
                                DEBUG_VARS.aspp_features.append(branch)
                            branches.append(branch)

                    # Merge all branches along the channel axis and project.
                    merged = tf.concat(branches, 3)
                    DEBUG_VARS.aspp_concat_feature = merged
                    projected = conv2d(
                        inputs=merged,
                        outc=aspp_depth,
                        ksize=[1, 1],
                        name=_CONCAT_PROJECTION_SCOPE)
                    projected = drop_out(
                        projected,
                        kp_prob=0.9,
                        is_training=is_training,
                        name=_CONCAT_PROJECTION_SCOPE + '_dropout')
                    return projected, end_points