def add_lovasz_softmax_loss_for_each_scale(scales_to_logits,
                                           labels,
                                           num_classes,
                                           ignore_label,
                                           loss_weight=1.0,
                                           upsample_logits=True,
                                           scope=None):
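    """Adds Lovasz-softmax loss for logits of each scale."""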

    if labels is None:
        raise ValueError('No label for lovasz softmax loss.')

    for scale, logits in six.iteritems(scales_to_logits):
        loss_scope = None
        if scope:
            loss_scope = '%s_%s' % (scope, scale)

        if upsample_logits:
            # Label is not downsampled, and instead we upsample logits.
            logits = tf.image.resize_bilinear(logits,
                                              preprocess_utils.resolve_shape(
                                                  labels, 4)[1:3],
                                              align_corners=True)
            scaled_labels = labels
        else:
            # Label is downsampled to the same size as logits.
            scaled_labels = tf.image.resize_nearest_neighbor(
                labels,
                preprocess_utils.resolve_shape(logits, 4)[1:3],
                align_corners=True)
        logits = tf.nn.softmax(logits)
        tf.losses.add_loss(
            lovasz_softmax(logits,
                           scaled_labels,
                           ignore=ignore_label,
                           classes='present'))
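
A minimal usage sketch for the function above (TF1-style; the 'merged_logits'
key, shapes, and class count are illustrative assumptions, not taken from the
surrounding code):

import tensorflow as tf

logits = tf.random_normal([2, 65, 65, 21])   # [batch, h, w, num_classes]
labels = tf.random_uniform([2, 513, 513, 1], maxval=21, dtype=tf.int32)
add_lovasz_softmax_loss_for_each_scale({'merged_logits': logits}, labels,
                                       num_classes=21, ignore_label=255,
                                       scope='semantic')
total_loss = tf.losses.get_total_loss(add_regularization_losses=False)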
Example #2
def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
                                                  labels,
                                                  num_classes,
                                                  ignore_label,
                                                  loss_weight_0=10.0,
                                                  loss_weight_1=50.0,
                                                  upsample_logits=True,
                                                  scope=None):
    """Adds softmax cross entropy loss for logits of each scale.

  Args:
    scales_to_logits: A map from logits names for different scales to logits.
      The logits have shape [batch, logits_height, logits_width, num_classes].
    labels: Groundtruth labels with shape [batch, image_height, image_width, 1].
    num_classes: Integer, number of target classes.
    ignore_label: Integer, label to ignore.
    loss_weight_0: Float, loss weight for pixels labeled 0.
    loss_weight_1: Float, loss weight for pixels labeled 1.
    upsample_logits: Boolean, upsample logits or not.
    scope: String, the scope for the loss.

  Raises:
    ValueError: Label or logits is None.
  """
    if labels is None:
        raise ValueError('No label for softmax cross entropy loss.')

    for scale, logits in six.iteritems(scales_to_logits):
        loss_scope = None
        if scope:
            loss_scope = '%s_%s' % (scope, scale)

        if upsample_logits:
            # Label is not downsampled, and instead we upsample logits.
            logits = tf.image.resize_bilinear(logits,
                                              preprocess_utils.resolve_shape(
                                                  labels, 4)[1:3],
                                              align_corners=True)
            scaled_labels = labels
        else:
            # Label is downsampled to the same size as logits.
            scaled_labels = tf.image.resize_nearest_neighbor(
                labels,
                preprocess_utils.resolve_shape(logits, 4)[1:3],
                align_corners=True)

        scaled_labels = tf.reshape(scaled_labels, shape=[-1])
        not_ignore_mask = tf.to_float(tf.equal(
            scaled_labels, 0)) * loss_weight_0 + tf.to_float(
                tf.equal(scaled_labels, 1)) * loss_weight_1 + tf.to_float(
                    tf.equal(scaled_labels, ignore_label)) * 0
        one_hot_labels = slim.one_hot_encoding(scaled_labels,
                                               num_classes,
                                               on_value=1.0,
                                               off_value=0.0)
        tf.losses.softmax_cross_entropy(one_hot_labels,
                                        tf.reshape(logits,
                                                   shape=[-1, num_classes]),
                                        weights=not_ignore_mask,
                                        scope=loss_scope)
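
The not_ignore_mask above gives weight 10 to label-0 pixels, 50 to label-1
pixels, and 0 to everything else; the explicit ignore_label term is a no-op,
since any label other than 0 or 1 already receives zero weight. A quick
numeric check of the same expression (illustrative, NumPy stand-in for the
tensor ops):

import numpy as np

scaled_labels = np.array([0, 1, 255, 1])
mask = ((scaled_labels == 0) * 10.0 + (scaled_labels == 1) * 50.0
        + (scaled_labels == 255) * 0.0)
print(mask)  # [10. 50.  0. 50.]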
Example #3
def _prep_logits(logits, labels, upsample_logits):
    if upsample_logits:
        # Label is not downsampled, and instead we upsample logits.
        logits = tf.image.resize_bilinear(
            logits,
            preprocess_utils.resolve_shape(labels, 4)[1:3],
            align_corners=True)
        scaled_labels = labels
    else:
        # Label is downsampled to the same size as logits.
        scaled_labels = tf.image.resize_nearest_neighbor(
            labels,
            preprocess_utils.resolve_shape(logits, 4)[1:3],
            align_corners=True)

    return logits, scaled_labels
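
A sketch of how the helper above can replace the duplicated resize branches in
the loss functions (hypothetical wiring; shapes are illustrative):

logits = tf.random_normal([2, 65, 65, 21])
labels = tf.random_uniform([2, 513, 513, 1], maxval=21, dtype=tf.int32)
logits, scaled_labels = _prep_logits(logits, labels, upsample_logits=True)
# logits is now [2, 513, 513, 21]; scaled_labels is the untouched label map.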
def sram(in_node,
         guidance,
         num_conv=1,
         conv_type="conv",
         conv_node=64,
         scope=None):
    """Single Residual Attention Module"""
    with tf.variable_scope(scope, "sram", reuse=tf.AUTO_REUSE):
        net = in_node
        if conv_type == "conv":
            conv_op = slim.conv2d
        elif conv_type == "separable_conv":
            conv_op = slim.separable_conv2d
        else:
            raise ValueError("Unknown convolution type")

        for i in range(num_conv - 1):
            net = conv_op(net, conv_node, kernel_size=[3, 3],
                          scope=conv_type + str(i + 1))
        net = conv_op(net, conv_node, kernel_size=[3, 3],
                      scope=conv_type + "out", activation_fn=None)

        guidance_filters = preprocess_utils.resolve_shape(guidance, rank=4)[3]
        if guidance_filters == 1:
            guidance_tile = tf.tile(guidance, [1, 1, 1, conv_node])
        elif guidance_filters == conv_node:
            guidance_tile = guidance
        else:
            raise ValueError("Unknown guidance filters number")

        # tf.add_to_collection("/sram_embed", {"in_node": in_node,
        #                                      "conv2": conv2,
        #                                      "guidance_tile": guidance_tile,
        #                                      "output": output})
        guided = tf.multiply(net, guidance_tile)
        output = in_node + guided
        tf.add_to_collection(scope + "_guided_feature", guided)
        return output
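
A usage sketch for sram (illustrative shapes; a single-channel guidance map is
tiled up to conv_node channels inside the module, so the output keeps the
input's shape):

feat = tf.random_normal([1, 32, 32, 64])
guide = tf.random_normal([1, 32, 32, 1])   # one channel -> tiled to 64
out = sram(feat, guide, num_conv=2, conv_type="conv", conv_node=64,
           scope="sram_demo")              # residual: feat + guidance * conv(feat)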
Example #5
    def concat_convolution(self, cur, last, out_node, scope):
        with tf.variable_scope(scope, "concat_conv"):
            h, w = preprocess_utils.resolve_shape(cur, rank=4)[1:3]
            last = resize_bilinear(last, [h, w])
            net = slim.conv2d(tf.concat([cur, last], axis=3),
                              out_node,
                              scope="conv1")
            return net
def add_softmax_generalized_dice_loss_for_each_scale(scales_to_logits,
                                                     labels,
                                                     num_classes,
                                                     ignore_label,
                                                     alpha=0.5,
                                                     beta=0.5,
                                                     loss_weight=1.0,
                                                     scope=None):
    """Adds softmax genralized dice loss (GDL) for logits of each scale."""
    if labels is None:
        raise ValueError('No label for softmax dice loss.')

    for scale, logits in scales_to_logits.items():
        loss_scope = None
        if scope:
            loss_scope = '%s_%s' % (scope, scale)

        shape = preprocess_utils.resolve_shape(labels, 4)
        logits = tf.image.resize_bilinear(logits,
                                          shape[1:3],
                                          align_corners=True)
        scaled_labels = labels

        scaled_labels = tf.reshape(scaled_labels,
                                   shape=[-1, shape[1] * shape[2]])

        logits = tf.reshape(logits,
                            shape=[-1, shape[1] * shape[2], num_classes])
        train_labels = tf.one_hot(scaled_labels,
                                  num_classes,
                                  on_value=1.0,
                                  off_value=0.0)

        # Per-class weights: reciprocal of the squared label area.
        area = tf.reduce_sum(train_labels, axis=1)
        weights = tf.ones_like(area) / (tf.square(area) + _EPSILON)
        weights = tf.where(tf.greater(weights, tf.ones_like(weights)),
                           tf.zeros_like(weights), weights)
        weights = weights * loss_weight
        with tf.name_scope(loss_scope, 'softmax_all_pixel_loss',
                           [logits, train_labels, weights]):
            # Compute the loss for all pixels.
            prediction = tf.nn.softmax(logits, 2)
            train_labels = tf.stop_gradient(train_labels,
                                            name='train_labels_stop_gradient')

            intersection = tf.reduce_sum(train_labels * prediction, axis=1)
            union = tf.reduce_sum(train_labels, axis=1) + tf.reduce_sum(
                prediction, axis=1)

            weighted_intersection = tf.reduce_sum(tf.multiply(
                intersection, weights),
                                                  axis=1)
            weighted_union = tf.reduce_sum(tf.multiply(union, weights), axis=1)
            loss = 1 - 2 * tf.reduce_mean((weighted_intersection + _EPSILON) /
                                          (weighted_union + _EPSILON))

            tf.losses.add_loss(loss)
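
For reference, the block above computes a generalized dice loss with
inverse-squared-area class weights (per batch element b, averaged over the
batch; classes whose weight would exceed 1, i.e. empty classes, are zeroed
out). In LaTeX:

\mathcal{L}_{GDL} = 1 - \frac{2}{B} \sum_{b=1}^{B}
    \frac{\sum_c w_{bc} \sum_i g_{bci}\, p_{bci} + \epsilon}
         {\sum_c w_{bc} \sum_i (g_{bci} + p_{bci}) + \epsilon},
\qquad w_{bc} = \frac{1}{\left(\sum_i g_{bci}\right)^2 + \epsilon}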
def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
                                                  labels,
                                                  num_classes,
                                                  ignore_label,
                                                  loss_weight=1.0,
                                                  upsample_logits=True,
                                                  scope=None):
    """Adds softmax cross entropy loss for logits of each scale."""
    if labels is None:
        raise ValueError('No label for softmax cross entropy loss.')

    for scale, logits in six.iteritems(scales_to_logits):
        loss_scope = None
        if scope:
            loss_scope = '%s_%s' % (scope, scale)

        if upsample_logits:
            # Label is not downsampled, and instead we upsample logits.
            logits = tf.image.resize_bilinear(logits,
                                              preprocess_utils.resolve_shape(
                                                  labels, 4)[1:3],
                                              align_corners=True)
            scaled_labels = labels
        else:
            # Label is downsampled to the same size as logits.
            scaled_labels = tf.image.resize_nearest_neighbor(
                labels,
                preprocess_utils.resolve_shape(logits, 4)[1:3],
                align_corners=True)

        scaled_labels = tf.reshape(scaled_labels, shape=[-1])
        not_ignore_mask = tf.to_float(tf.not_equal(scaled_labels,
                                                   ignore_label)) * loss_weight
        # Convert the sparse labels to one-hot encoding.
        one_hot_labels = slim.one_hot_encoding(scaled_labels,
                                               num_classes,
                                               on_value=1.0,
                                               off_value=0.0)
        tf.losses.softmax_cross_entropy(one_hot_labels,
                                        tf.reshape(logits,
                                                   shape=[-1, num_classes]),
                                        weights=not_ignore_mask,
                                        scope=loss_scope)
def guid_attention(cur, last, guid, out_node, scope=None, guid_conv_nums=2,
                   guid_conv_type="conv2d", apply_sram2=True):
  """Guid attention module"""
  h, w = preprocess_utils.resolve_shape(cur, rank=4)[1:3]
	guid_node = preprocess_utils.resolve_shape(guid, rank=4)[3]
	if guid_node != out_node and guid_node != 1:
		raise ValueError("Unknown guidance node number %d, should be 1 or out_node" %guid_node)

	with tf.variable_scope(scope, 'guid_attention'):

		guid = resize_bilinear(guid, [h, w])
		last = resize_bilinear(last, [h, w])
		net = sram(cur, guid, guid_conv_nums, guid_conv_type, out_node, "sram1")
		tf.add_to_collection("sram1", net)
		if last is not None:
			net = net + last
			if apply_sram2:
				net = sram(net, guid, guid_conv_nums, guid_conv_type, out_node, "sram2")
		tf.add_to_collection("sram2", net)
		return net
def add_softmax_dice_loss_for_each_scale(scales_to_logits,
                                         labels,
                                         num_classes,
                                         ignore_label,
                                         alpha=0.5,
                                         beta=0.5,
                                         loss_weight=1.0,
                                         activation="softmax",
                                         scope=None):
    """Adds softmax dice loss for logits of each scale."""
    if labels is None:
        raise ValueError('No label for softmax dice loss.')

    for scale, logits in scales_to_logits.items():
        loss_scope = None
        if scope:
            loss_scope = '%s_%s' % (scope, scale)

        logits = tf.image.resize_bilinear(logits,
                                          preprocess_utils.resolve_shape(
                                              labels, 4)[1:3],
                                          align_corners=True)

        labels = tf.reshape(labels, shape=[-1])
        weights = tf.constant(loss_weight, tf.float32)

        logits = tf.reshape(logits, shape=[-1, num_classes])
        train_labels = tf.one_hot(labels,
                                  num_classes,
                                  on_value=1.0,
                                  off_value=0.0)

        with tf.name_scope(loss_scope, '%s_all_pixel_loss' % activation,
                           [logits, train_labels, weights]):
            # Compute the loss for all pixels.
            if activation == "softmax":
                prediction = tf.nn.softmax(logits, 1)
            elif activation == "sigmoid":
                prediction = tf.nn.sigmoid(logits)
            else:
                raise ValueError("Unknown activation for prediction")
            train_labels = tf.stop_gradient(train_labels,
                                            name='train_labels_stop_gradient')

            intersection = tf.reduce_sum(train_labels * prediction, 0)
            union = tf.reduce_sum(train_labels, 0) + tf.reduce_sum(
                prediction, 0)

            pixel_losses = (2 * intersection + _EPSILON) / (union + _EPSILON)
            weighted_pixel_losses = tf.multiply(pixel_losses, weights)
            loss = 1 - tf.reduce_mean(weighted_pixel_losses)

            tf.losses.add_loss(loss)
def context_attention(cur, last, guid, out_node, scope=None, guid_conv_nums=2,
                      guid_conv_type="conv2d"):
  """Context attention module."""
  guid_node = preprocess_utils.resolve_shape(guid, rank=4)[3]

  if guid_node != out_node and guid_node != 1:
    raise ValueError("Unknown guidance node number %d, should be 1 or out_node" % guid_node)

  with tf.variable_scope(scope, 'context_attention'):
    context = sram(cur, guid, guid_conv_nums, guid_conv_type, out_node, "sram1")
    tf.add_to_collection("sram1", context)
    # Fall back to the SRAM output when there is no previous-stage feature.
    net = context
    if last is not None:
      ca_layer = attentions.self_attention(out_node)
      net = ca_layer(last, context, last, "context_att1")
      tf.add_to_collection("context_att1", net)
    return net
Example #11
    def __call__(self, f, g, h, scope):
        with tf.variable_scope(scope, 'self_attention'):
            self.n, self.h, self.w, self.c = preprocess_utils.resolve_shape(
                f, rank=4)
            f = self.embedding(f, "f")  # [bs, h, w, emb_c]
            g = self.embedding(g, "g")
            h = self.embedding(h, "h")

            # N = h * w
            o = self.get_attention(self.flatten(f), self.flatten(g),
                                   self.flatten(h))

            o = tf.reshape(o, shape=tf.shape(f))  # [bs, h, w, emb_c]
            y = f + self.embedding(o, "y")
            return y
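
flatten and get_attention are not part of this excerpt; assuming they
implement standard dot-product non-local attention, a minimal sketch would be:

def get_attention_sketch(f, g, h):
    # f, g, h: [batch, N, c] flattened embeddings, with N = height * width.
    attn = tf.nn.softmax(tf.matmul(f, g, transpose_b=True), axis=-1)  # [batch, N, N]
    return tf.matmul(attn, h)                                         # [batch, N, c]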
Example #12
    def self_attention(self, x1, x2, guid, out_node, scope, *args, **kwargs):
        guid_node = preprocess_utils.resolve_shape(guid, rank=4)[3]

        if guid_node != out_node and guid_node != 1:
            raise ValueError(
                "Unknown guidance node number %d, should be 1 or out_node" %
                guid_node)

        with tf.variable_scope(scope, 'guid'):
            net = slim_sram(x1, guid, self.guid_conv_nums, self.guid_conv_type,
                            self.embed_node, "sram1")
            tf.add_to_collection("sram1", net)
            if x2 is not None:
                net = net + x2
                sa_layer = attentions.self_attention(out_node)
                net = sa_layer(net, net, net, "self_att1")
                tf.add_to_collection("self_att1", net)
            return net
def guid_class_attention(cur, last, guid, num_class, out_node, scope=None,
                         guid_conv_nums=2, guid_conv_type="conv2d",
                         apply_sram2=True):
  """Guided class attention module."""
  h, w = preprocess_utils.resolve_shape(cur, rank=4)[1:3]
  guid_node = guid.get_shape().as_list()[3]
  if guid_node != num_class:
    raise ValueError("Unknown guidance node number %d, should equal class number" % guid_node)
  with tf.variable_scope(scope, "guid_class_attention"):
    guid = resize_bilinear(guid, [h, w])
    if last is not None:
      last = resize_bilinear(last, [h, w])
    total_att = []
    for i in range(1, num_class):
      net = sram(cur, guid[..., i:i+1], guid_conv_nums, guid_conv_type, out_node, "sram1")
      if last is not None:
        net = net + last
        if apply_sram2:
          net = sram(net, guid[..., i:i+1], guid_conv_nums, guid_conv_type, out_node, "sram2")
      total_att.append(net)
    fuse = slim.conv2d(tf.concat(total_att, axis=3), out_node, kernel_size=[1, 1], scope="fuse")
    return fuse
Example #14
def seq_model(inputs, ny, nx, n_class, weight_decay, is_training, cell_type='ConvGRU'):
  with slim.arg_scope([slim.batch_norm],
                        is_training=is_training):
    with slim.arg_scope([slim.conv2d],
                      weights_initializer=tf.initializers.he_normal(),
                      weights_regularizer=slim.l2_regularizer(weight_decay),
                      normalizer_fn=slim.batch_norm):
      # in_shape = inputs.get_shape().as_list()
      in_shape = preprocess_utils.resolve_shape(inputs, rank=5)
      batch_size = in_shape[0]
      # seq_length = in_shape[1]
      nx = in_shape[2]
      ny = in_shape[3]

      if cell_type == 'ConvGRU':
        with tf.variable_scope("forward_cell") as scope:
            cell_forward = cell.ConvGRUCell(shape=[ny, nx], filters=n_class, kernel=[3, 3])
            outputs_forward, state_forward = tf.nn.dynamic_rnn(
              cell=cell_forward, dtype=tf.float32, inputs=inputs,
              initial_state=cell_forward.zero_state(batch_size, dtype=tf.float32))
        feats = state_forward
      elif cell_type == 'BiConvGRU':
        with tf.variable_scope("forward_cell") as scope:
            cell_forward = cell.ConvGRUCell(shape=[ny, nx], filters=n_class, kernel=[3, 3])
            outputs_forward, state_forward = tf.nn.dynamic_rnn(
              cell=cell_forward, dtype=tf.float32, inputs=inputs,
              initial_state=cell_forward.zero_state(batch_size, dtype=tf.float32))

        inputs_b = inputs[:,::-1]
        with tf.variable_scope("backward_cell") as scope:
            cell_backward = cell.ConvGRUCell(shape=[ny, nx], filters=n_class, kernel=[3, 3])
            outputs_backward, state_backward = tf.nn.dynamic_rnn(
              cell=cell_backward, dtype=tf.float32, inputs=inputs_b,
              initial_state=cell_backward.zero_state(batch_size, dtype=tf.float32))

        feats = tf.concat([state_forward, state_backward], axis=3)

      y = slim.conv2d(feats, n_class, kernel_size=[1, 1], stride=1, activation_fn=None, scope='fuse')
    # print(60*"X", inputs, y, state_forward)
    return y
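
A hypothetical call to seq_model (rank-5 input [batch, time, height, width,
channels]; note the function overwrites its ny/nx arguments from the input
shape):

frames = tf.random_normal([2, 4, 64, 64, 1])   # 4-frame sequences
pred = seq_model(frames, ny=64, nx=64, n_class=3, weight_decay=1e-4,
                 is_training=True, cell_type='BiConvGRU')
# pred: [2, 64, 64, 3] per-pixel logits fused from both temporal directions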
Example #15
def _build_deeplab(iterator_seg, iterator, outputs_to_num_classes,
                   ignore_label):
    """Builds a clone of Supervised DeepLab.

  Args:
    iterator_seg: An iterator of type tf.data.Iterator for segmentation images
      and labels.
    iterator: An iterator of type tf.data.Iterator for images and labels.
    outputs_to_num_classes: A map from output type to the number of classes. For
      example, for the task of semantic segmentation with 21 semantic classes,
      we would have outputs_to_num_classes['semantic'] = 21.
    ignore_label: Ignore label.
  """
    if FLAGS.weakly:
        samples = iterator.get_next()
        samples[common.IMAGE] = tf.identity(samples[common.IMAGE],
                                            name=common.IMAGE)
        samples[common.LABEL] = tf.identity(samples[common.LABEL],
                                            name=common.LABEL)

    samples_seg = iterator_seg.get_next()
    samples_seg[common.IMAGE] = tf.identity(samples_seg[common.IMAGE],
                                            name=common.IMAGE + '_seg')
    samples_seg[common.LABEL] = tf.identity(samples_seg[common.LABEL],
                                            name=common.LABEL + '_seg')

    model_options = common.ModelOptions(
        outputs_to_num_classes=outputs_to_num_classes,
        crop_size=[int(sz) for sz in FLAGS.train_crop_size],
        atrous_rates=FLAGS.atrous_rates,
        output_stride=FLAGS.output_stride)

    ### Cls data
    if FLAGS.weakly:
        _, end_points_cls = feature_extractor.extract_features(
            samples[common.IMAGE],
            output_stride=model_options.output_stride,
            multi_grid=model_options.multi_grid,
            model_variant=model_options.model_variant,
            depth_multiplier=model_options.depth_multiplier,
            divisible_by=model_options.divisible_by,
            weight_decay=FLAGS.weight_decay,
            reuse=tf.AUTO_REUSE,
            is_training=True,
            preprocessed_images_dtype=model_options.preprocessed_images_dtype,
            fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
            global_pool=True,
            num_classes=outputs_to_num_classes[common.OUTPUT_TYPE] - 1)

        # ResNet beta version has an additional suffix in FLAGS.model_variant, but
        # it shares the same variable names with original version. Add a special
        # handling here for beta version ResNet.
        logits_cls = end_points_cls['{}/logits'.format(
            FLAGS.model_variant).replace('_beta', '')]
        logits_cls = tf.reshape(
            logits_cls, [-1, outputs_to_num_classes[common.OUTPUT_TYPE] - 1])
        # Seems that people usually use multi-label soft margin loss
        loss_cls = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=samples['cls_label'], logits=logits_cls)
        loss_cls = tf.reduce_mean(loss_cls)
        loss_cls = tf.identity(loss_cls, name='loss_cls')
        tf.compat.v1.losses.add_loss(loss_cls)

    ### Seg data
    outputs_to_scales_to_logits = model.multi_scale_logits(
        samples_seg[common.IMAGE],
        model_options=model_options,
        image_pyramid=FLAGS.image_pyramid,
        weight_decay=FLAGS.weight_decay,
        is_training=True,
        fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
        nas_training_hyper_parameters={
            'drop_path_keep_prob': FLAGS.drop_path_keep_prob,
            'total_training_steps': FLAGS.training_number_of_steps,
        })

    # Add name to graph node so we can add to summary.
    output_type_dict = outputs_to_scales_to_logits[common.OUTPUT_TYPE]
    output_type_dict[model.MERGED_LOGITS_SCOPE] = tf.identity(
        output_type_dict[model.MERGED_LOGITS_SCOPE],
        name=common.OUTPUT_TYPE + '_seg')

    for output, num_classes in six.iteritems(outputs_to_num_classes):
        train_utils.add_softmax_cross_entropy_loss_for_each_scale(
            outputs_to_scales_to_logits[output],
            samples_seg[common.LABEL],
            num_classes,
            ignore_label,
            loss_weight=model_options.label_weights,
            upsample_logits=FLAGS.upsample_logits,
            hard_example_mining_step=FLAGS.hard_example_mining_step,
            top_k_percent_pixels=FLAGS.top_k_percent_pixels,
            scope=output)

    ## Sanity check. Monitor pixel accuracy
    logits_seg = outputs_to_scales_to_logits[common.OUTPUT_TYPE][
        model.MERGED_LOGITS_SCOPE]
    temp_label = tf.compat.v1.image.resize_nearest_neighbor(
        samples_seg[common.LABEL],
        preprocess_utils.resolve_shape(logits_seg, 4)[1:3])
    temp_label = tf.reshape(temp_label, [-1])

    dump = tf.concat(
        [tf.range(outputs_to_num_classes[common.OUTPUT_TYPE]), temp_label],
        axis=-1)
    _, _, count = tf.unique_with_counts(dump)
    num_pixel_list = count - 1
    # Exclude the ignore region
    num_pixel_list = num_pixel_list[:outputs_to_num_classes[common.
                                                            OUTPUT_TYPE]]
    num_pixel_list = tf.cast(num_pixel_list, tf.float32)
    inverse_ratio = train_utils._div_maybe_zero(1, num_pixel_list)
    inverse_ratio = inverse_ratio / tf.reduce_sum(inverse_ratio)

    # Create weight mask to balance each class
    weight_mask = tf.einsum(
        '...y,y->...',
        tf.one_hot(temp_label,
                   outputs_to_num_classes[common.OUTPUT_TYPE],
                   dtype=tf.float32), inverse_ratio)
    temp_valid = tf.not_equal(temp_label, ignore_label)
    temp_label_valid = tf.boolean_mask(temp_label, temp_valid)
    weight_mask_valid = tf.boolean_mask(weight_mask, temp_valid)

    pred_seg = tf.argmax(logits_seg, axis=-1)
    pred_seg = tf.reshape(pred_seg, [-1])
    acc_seg, acc_seg_op = tf.metrics.mean_per_class_accuracy(
        temp_label_valid,
        tf.boolean_mask(pred_seg, temp_valid),
        outputs_to_num_classes[common.OUTPUT_TYPE],
        weights=weight_mask_valid)
    with tf.control_dependencies([acc_seg_op]):
        acc_seg = tf.identity(acc_seg, name='acc_seg')
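
The tf.range concatenation above is a counting trick: prepending one instance
of every class id guarantees tf.unique_with_counts returns an entry for each
class, so subtracting 1 recovers the true per-class pixel counts even for
absent classes. A NumPy check (illustrative; np.unique sorts its output while
tf.unique_with_counts orders by first occurrence, but the first num_classes
entries agree here):

import numpy as np

num_classes = 3
temp_label = np.array([0, 0, 2, 2, 2, 255])  # class 1 absent; 255 = ignore
dump = np.concatenate([np.arange(num_classes), temp_label])
_, count = np.unique(dump, return_counts=True)
print(count[:num_classes] - 1)  # [2 0 3]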
def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
                                                  labels,
                                                  num_classes,
                                                  ignore_label,
                                                  loss_weight=1.0,
                                                  upsample_logits=True,
                                                  hard_example_mining_step=0,
                                                  top_k_percent_pixels=1.0,
                                                  gt_is_matting_map=False,
                                                  activation="softmax",
                                                  scope=None):
    """Adds softmax cross entropy loss for logits of each scale.
  Args:
    scales_to_logits: A map from logits names for different scales to logits.
      The logits have shape [batch, logits_height, logits_width, num_classes].
    labels: Groundtruth labels with shape [batch, image_height, image_width, 1].
    num_classes: Integer, number of target classes.
    ignore_label: Integer, label to ignore.
    loss_weight: A float or a list of loss weights. If it is a float, it means
      all the labels have the same weight. If it is a list of weights, then each
      element in the list represents the weight for the label of its index, for
      example, loss_weight = [0.1, 0.5] means the weight for label 0 is 0.1 and
      the weight for label 1 is 0.5.
    upsample_logits: Boolean, upsample logits or not.
    hard_example_mining_step: An integer, the training step in which the hard
      example mining kicks off. Note that we gradually reduce the mining
      percent to the top_k_percent_pixels. For example, if
      hard_example_mining_step = 100K and top_k_percent_pixels = 0.25, then
      mining percent will gradually reduce from 100% to 25% until 100K steps
      after which we only mine top 25% pixels.
    top_k_percent_pixels: A float, the value lies in [0.0, 1.0]. When its value
      < 1.0, only compute the loss for the top k percent pixels (e.g., the top
      20% pixels). This is useful for hard pixel mining.
    gt_is_matting_map: If true, the groundtruth is a matting map of confidence
      score. If false, the groundtruth is an integer valued class mask.
    scope: String, the scope for the loss.
  Raises:
    ValueError: Label or logits is None, or groundtruth is matting map while
      label is not floating value.
  """
    if labels is None:
        raise ValueError('No label for softmax cross entropy loss.')

    # If input groundtruth is a matting map of confidence, check if the input
    # labels are floating point values.
    if gt_is_matting_map and not labels.dtype.is_floating:
        raise ValueError(
            'Labels must be floats if groundtruth is a matting map.')

    for scale, logits in six.iteritems(scales_to_logits):
        loss_scope = None
        if scope:
            loss_scope = '%s_%s' % (scope, scale)

        if upsample_logits:
            # Label is not downsampled, and instead we upsample logits.
            logits = tf.image.resize_bilinear(logits,
                                              preprocess_utils.resolve_shape(
                                                  labels, 4)[1:3],
                                              align_corners=True)
            scaled_labels = labels
        else:
            # Label is downsampled to the same size as logits.
            # When gt_is_matting_map = true, label downsampling with nearest neighbor
            # method may introduce artifacts. However, to avoid ignore_label from
            # being interpolated with other labels, we still perform nearest neighbor
            # interpolation.
            # TODO(huizhongc): Change to bilinear interpolation by processing padded
            # and non-padded label separately.
            if gt_is_matting_map:
                tf.logging.warning(
                    'Label downsampling with nearest neighbor may introduce artifacts.'
                )

            scaled_labels = tf.image.resize_nearest_neighbor(
                labels,
                preprocess_utils.resolve_shape(logits, 4)[1:3],
                align_corners=True)

        scaled_labels = tf.reshape(scaled_labels, shape=[-1])
        if activation == "sigmoid":
            keep_class_dims = True
            loss_func = tf.nn.sigmoid_cross_entropy_with_logits
        elif activation == "softmax":
            keep_class_dims = False
            loss_func = tf.nn.softmax_cross_entropy_with_logits_v2
        else:
            raise ValueError("Unknown activation for prediction")
        weights = get_label_weight_mask(scaled_labels,
                                        ignore_label,
                                        num_classes,
                                        label_weights=loss_weight,
                                        keep_class_dims=keep_class_dims)
        # Dimension of keep_mask is equal to the total number of pixels.
        keep_mask = tf.cast(tf.not_equal(scaled_labels, ignore_label),
                            dtype=tf.float32)

        train_labels = None
        logits = tf.reshape(logits, shape=[-1, num_classes])

        if gt_is_matting_map:
            # When the groundtruth is integer label mask, we can assign class
            # dependent label weights to the loss. When the groundtruth is image
            # matting confidence, we do not apply class-dependent label weight (i.e.,
            # label_weight = 1.0).
            if loss_weight != 1.0:
                raise ValueError(
                    'loss_weight must equal to 1 if groundtruth is matting map.'
                )

            # Assign label value 0 to ignore pixels. The exact label value of ignore
            # pixel does not matter, because those ignore_value pixel losses will be
            # multiplied to 0 weight.
            train_labels = scaled_labels * keep_mask

            train_labels = tf.expand_dims(train_labels, 1)
            train_labels = tf.concat([1 - train_labels, train_labels], axis=1)
        else:
            train_labels = tf.one_hot(scaled_labels,
                                      num_classes,
                                      on_value=1.0,
                                      off_value=0.0)

        default_loss_scope = ('softmax_all_pixel_loss' if top_k_percent_pixels
                              == 1.0 else 'softmax_hard_example_mining')
        with tf.name_scope(loss_scope, default_loss_scope,
                           [logits, train_labels, weights]):
            # Compute the loss for all pixels.
            pixel_losses = loss_func(labels=tf.stop_gradient(
                train_labels, name='train_labels_stop_gradient'),
                                     logits=logits,
                                     name='pixel_losses')
            weighted_pixel_losses = tf.multiply(pixel_losses, weights)

            if top_k_percent_pixels == 1.0:
                total_loss = tf.reduce_sum(weighted_pixel_losses)
                num_present = tf.reduce_sum(keep_mask)
                loss = _div_maybe_zero(total_loss, num_present)
                tf.losses.add_loss(loss)
            else:
                num_pixels = tf.to_float(tf.shape(logits)[0])
                # Compute the top_k_percent pixels based on current training step.
                if hard_example_mining_step == 0:
                    # Directly focus on the top_k pixels.
                    top_k_pixels = tf.to_int32(top_k_percent_pixels *
                                               num_pixels)
                else:
                    # Gradually reduce the mining percent to top_k_percent_pixels.
                    global_step = tf.to_float(
                        tf.train.get_or_create_global_step())
                    ratio = tf.minimum(1.0,
                                       global_step / hard_example_mining_step)
                    top_k_pixels = tf.to_int32((ratio * top_k_percent_pixels +
                                                (1.0 - ratio)) * num_pixels)
                top_k_losses, _ = tf.nn.top_k(weighted_pixel_losses,
                                              k=top_k_pixels,
                                              sorted=True,
                                              name='top_k_percent_pixels')
                total_loss = tf.reduce_sum(top_k_losses)
                num_present = tf.reduce_sum(
                    tf.to_float(tf.not_equal(top_k_losses, 0.0)))
                loss = _div_maybe_zero(total_loss, num_present)
                tf.losses.add_loss(loss)
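
A worked example of the hard-example-mining schedule above (pure Python,
illustrative values):

def mining_percent(global_step, hard_example_mining_step=100000,
                   top_k_percent_pixels=0.25):
    ratio = min(1.0, global_step / hard_example_mining_step)
    return ratio * top_k_percent_pixels + (1.0 - ratio)

print(mining_percent(0))       # 1.0   -> all pixels at the start
print(mining_percent(50000))   # 0.625 -> halfway through the ramp
print(mining_percent(200000))  # 0.25  -> only the hardest 25% afterwards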
Example #17
def _build_pseudo_seg(iterator_seg,
                      iterator,
                      outputs_to_num_classes,
                      ignore_label,
                      batch_size=8):
    """Builds a clone of PseudoSeg.

  Args:
    iterator_seg: An iterator of type tf.data.Iterator for segmentation images
      and labels.
    iterator: An iterator of type tf.data.Iterator for images and labels.
    outputs_to_num_classes: A map from output type to the number of classes. For
      example, for the task of semantic segmentation with 21 semantic classes,
      we would have outputs_to_num_classes['semantic'] = 21.
    ignore_label: Ignore label.
    batch_size: Training batch size for each clone.
  """
    samples_cls = iterator.get_next()
    samples_cls[common.IMAGE] = tf.identity(samples_cls[common.IMAGE],
                                            name='weak')
    samples_cls['strong'] = tf.identity(samples_cls['strong'], name='strong')
    samples_cls[common.LABEL] = tf.identity(samples_cls[common.LABEL],
                                            name='unlabeled')

    samples_seg = iterator_seg.get_next()
    samples_seg[common.IMAGE] = tf.identity(samples_seg[common.IMAGE],
                                            name=common.IMAGE + '_seg')
    samples_seg[common.LABEL] = tf.identity(samples_seg[common.LABEL],
                                            name=common.LABEL + '_seg')

    model_options = common.ModelOptions(
        outputs_to_num_classes=outputs_to_num_classes,
        crop_size=[int(sz) for sz in FLAGS.train_crop_size],
        atrous_rates=FLAGS.atrous_rates,
        output_stride=FLAGS.output_stride)

    if FLAGS.att_v2:
        cam_func = train_utils_core.compute_cam_v2
    else:
        cam_func = train_utils_core.compute_cam

    ### Cls/unlabeled data
    ## 1) If we have image-level label, we train the classifier here
    if FLAGS.weakly:
        if FLAGS.cls_with_cls:
            curr_samples = samples_cls
        else:
            curr_samples = samples_seg

        _, end_points_cls = feature_extractor.extract_features(
            curr_samples[common.IMAGE],
            output_stride=model_options.output_stride,
            multi_grid=model_options.multi_grid,
            model_variant=model_options.model_variant,
            depth_multiplier=model_options.depth_multiplier,
            divisible_by=model_options.divisible_by,
            weight_decay=FLAGS.weight_decay,
            reuse=tf.AUTO_REUSE,
            is_training=True,
            preprocessed_images_dtype=model_options.preprocessed_images_dtype,
            fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
            global_pool=True,
            num_classes=outputs_to_num_classes[common.OUTPUT_TYPE] - 1)

        # ResNet beta version has an additional suffix in FLAGS.model_variant, but
        # it shares the same variable names with original version. Add a special
        # handling here for beta version ResNet.
        logits_cls = end_points_cls['{}/logits'.format(
            FLAGS.model_variant).replace('_beta', '')]
        logits_cls = tf.reshape(
            logits_cls, [-1, outputs_to_num_classes[common.OUTPUT_TYPE] - 1])
        # Seems that people usually use multi-label soft margin loss in PyTorch
        loss_cls = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=curr_samples['cls_label'], logits=logits_cls)
        loss_cls = tf.reduce_mean(loss_cls)
        loss_cls = tf.identity(loss_cls, name='loss_cls')
        tf.compat.v1.losses.add_loss(loss_cls)

    ## 2) Consistency
    with tf.name_scope('cls_weak'):
        outputs_to_scales_to_logits, _ = model.multi_scale_logits(
            samples_cls[common.IMAGE],
            model_options=model_options,
            image_pyramid=FLAGS.image_pyramid,
            weight_decay=FLAGS.weight_decay,
            is_training=True,
            fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
            nas_training_hyper_parameters={
                'drop_path_keep_prob': FLAGS.drop_path_keep_prob,
                'total_training_steps': FLAGS.training_number_of_steps,
            },
            output_end_points=True)
        logits_weak = outputs_to_scales_to_logits[common.OUTPUT_TYPE][
            model.MERGED_LOGITS_SCOPE]

    prob_weak = tf.nn.softmax(logits_weak, axis=-1)
    logits_weak = tf.identity(logits_weak, name='logits_weak')
    # Monitor max score
    max_prob_weak = tf.reduce_max(prob_weak, axis=-1)
    max_prob_weak = tf.identity(max_prob_weak, name='max_prob_weak')

    valid_mask_pad = samples_cls['valid']
    valid_mask_pad = tf.compat.v1.image.resize_nearest_neighbor(
        valid_mask_pad,
        preprocess_utils.resolve_shape(logits_weak, 4)[1:3])
    valid_mask_pad = tf.cast(valid_mask_pad, tf.float32)

    if FLAGS.use_attention:
        # Using inference mode to generate Grad-CAM
        with tf.name_scope('cls_data_cls_inference'):
            _, end_points_cls = feature_extractor.extract_features(
                samples_cls[common.IMAGE],
                output_stride=model_options.output_stride,
                multi_grid=model_options.multi_grid,
                model_variant=model_options.model_variant,
                depth_multiplier=model_options.depth_multiplier,
                divisible_by=model_options.divisible_by,
                weight_decay=FLAGS.weight_decay,
                reuse=tf.AUTO_REUSE,
                is_training=False,
                preprocessed_images_dtype=model_options.
                preprocessed_images_dtype,
                fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
                global_pool=True,
                num_classes=outputs_to_num_classes[common.OUTPUT_TYPE] - 1)

        logits_cls = end_points_cls['{}/logits'.format(
            FLAGS.model_variant).replace('_beta', '')]
        logits_cls = tf.reshape(
            logits_cls, [-1, outputs_to_num_classes[common.OUTPUT_TYPE] - 1])

        # We can only get ground truth image-level label in weakly+semi setting
        if FLAGS.weakly:
            image_level_label = samples_cls['cls_label']
        else:
            prob_cls = tf.sigmoid(logits_cls)
            # TODO(ylzou): Might use a variable threshold for different classes
            pred_cls = tf.greater_equal(prob_cls, 0.5)
            image_level_label = tf.stop_gradient(tf.cast(pred_cls, tf.float32))

        cam_weak, att_cam_weak = cam_func(
            end_points_cls,
            logits_cls,
            image_level_label,
            num_class=outputs_to_num_classes[common.OUTPUT_TYPE],
            use_attention=True,
            attention_dim=FLAGS.attention_dim,
            strides=[int(st) for st in FLAGS.att_strides],
            is_training=True,
            valid_mask=valid_mask_pad,
            net=FLAGS.model_variant.replace('_beta', ''))
        att_logits_weak = att_cam_weak
        # Upsample att-cam
        att_logits_weak = tf.compat.v1.image.resize_bilinear(
            att_logits_weak,
            preprocess_utils.resolve_shape(logits_weak, 4)[1:3],
            align_corners=True)
        # Monitor vanilla cam
        cam_weak = tf.compat.v1.image.resize_bilinear(
            cam_weak,
            preprocess_utils.resolve_shape(logits_weak, 4)[1:3],
            align_corners=True)
        cam_weak = tf.identity(cam_weak, name='cam_weak')

        att_prob_weak = tf.nn.softmax(att_logits_weak, axis=-1)
        att_logits_weak = tf.identity(att_logits_weak, name='att_logits_weak')
        # Monitor max score
        max_att_prob_weak = tf.reduce_max(att_prob_weak, axis=-1)
        max_att_prob_weak = tf.identity(max_att_prob_weak,
                                        name='max_att_prob_weak')

        # Ensemble
        if FLAGS.pseudo_src == 'att':
            prob_weak = att_prob_weak
        else:
            if FLAGS.logit_norm:
                v = tf.concat([logits_weak, att_logits_weak], axis=0)
                all_logits_weak = v * tf.rsqrt(
                    tf.reduce_mean(tf.square(v)) + 1e-8)
                scaled_logits_weak = all_logits_weak[:batch_size]
                prob_weak = tf.nn.softmax(scaled_logits_weak, axis=-1)
                scaled_att_logits_weak = all_logits_weak[batch_size:]
                att_prob_weak = tf.nn.softmax(scaled_att_logits_weak, axis=-1)
            prob_weak = (prob_weak + att_prob_weak) / 2.

        # Monitor max score
        max_prob_avg = tf.reduce_max(prob_weak, axis=-1)
        max_prob_avg = tf.identity(max_prob_avg, name='max_prob_avg')

    # Temperature
    if FLAGS.soft_pseudo_label and FLAGS.temperature != 1.0:
        prob_weak = tf.pow(prob_weak, 1. / FLAGS.temperature)
        prob_weak /= tf.reduce_sum(prob_weak, axis=-1, keepdims=True)
        # Monitor max score
        max_prob_avg_t = tf.reduce_max(prob_weak, axis=-1)
        max_prob_avg_t = tf.identity(max_prob_avg_t, name='max_prob_avg_t')
        # Monitor merged logits
        prob_weak = tf.identity(prob_weak, name='merged_logits')

    with tf.name_scope('cls_strong'):
        outputs_to_scales_to_logits, _ = model.multi_scale_logits(
            samples_cls['strong'],
            model_options=model_options,
            image_pyramid=FLAGS.image_pyramid,
            weight_decay=FLAGS.weight_decay,
            is_training=True,
            fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
            nas_training_hyper_parameters={
                'drop_path_keep_prob': FLAGS.drop_path_keep_prob,
                'total_training_steps': FLAGS.training_number_of_steps,
            },
            output_end_points=True)
        logits_strong = outputs_to_scales_to_logits[common.OUTPUT_TYPE][
            model.MERGED_LOGITS_SCOPE]
    logits_strong = tf.identity(logits_strong, name='logits_strong')

    if FLAGS.pseudo_label_threshold > 0:
        confidence_weak = tf.expand_dims(tf.reduce_max(prob_weak, axis=-1),
                                         axis=-1)
        valid_mask_score = tf.greater_equal(confidence_weak,
                                            FLAGS.pseudo_label_threshold)
        valid_mask_score = tf.cast(valid_mask_score, tf.float32)
        valid_mask = valid_mask_score * valid_mask_pad
    else:
        valid_mask_score = None
        valid_mask = valid_mask_pad
    # Save for visualization
    valid_mask = tf.identity(valid_mask, name='valid_mask')

    logits_strong = tf.reshape(
        logits_strong, [-1, outputs_to_num_classes[common.OUTPUT_TYPE]])

    if not FLAGS.soft_pseudo_label:
        pseudo_label = tf.argmax(prob_weak, axis=-1)
        pseudo_label = tf.reshape(pseudo_label, [-1])
        pseudo_label = tf.stop_gradient(pseudo_label)
        loss_consistency = tf.compat.v1.nn.sparse_softmax_cross_entropy_with_logits(
            labels=pseudo_label,
            logits=logits_strong,
            name='consistency_losses')
        loss_consistency = loss_consistency * tf.reshape(valid_mask, [-1])
        pred_pseudo = pseudo_label
    else:
        pseudo_label = prob_weak
        pseudo_label = tf.reshape(
            pseudo_label, [-1, outputs_to_num_classes[common.OUTPUT_TYPE]])
        pseudo_label = tf.stop_gradient(pseudo_label)
        loss_consistency = tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
            labels=pseudo_label,
            logits=logits_strong,
            name='consistency_losses')
        loss_consistency = loss_consistency * tf.reshape(valid_mask, [-1])
        pred_pseudo = tf.argmax(pseudo_label, axis=-1)

    # NOTE: When averaging, we divide by the number of pixels excluding padding.
    loss_consistency = tf.reduce_sum(loss_consistency)
    loss_consistency = train_utils._div_maybe_zero(
        loss_consistency, tf.reduce_sum(valid_mask_pad))
    loss_consistency *= FLAGS.unlabeled_weight
    loss_consistency = tf.identity(loss_consistency, 'loss_consistency')
    tf.compat.v1.losses.add_loss(loss_consistency)

    ## 3) Monitor prediction quality
    temp_label = tf.compat.v1.image.resize_nearest_neighbor(
        samples_cls[common.LABEL],
        preprocess_utils.resolve_shape(logits_weak, 4)[1:3])
    temp_label = tf.reshape(temp_label, [-1])

    # Get #pixel of each class, so that we can re-weight them for pixel acc.
    dump = tf.concat(
        [tf.range(outputs_to_num_classes[common.OUTPUT_TYPE]), temp_label],
        axis=-1)
    _, _, count = tf.unique_with_counts(dump)
    num_pixel_list = count - 1
    # Exclude the ignore region
    num_pixel_list = num_pixel_list[:outputs_to_num_classes[common.
                                                            OUTPUT_TYPE]]
    num_pixel_list = tf.cast(num_pixel_list, tf.float32)
    inverse_ratio = train_utils._div_maybe_zero(1, num_pixel_list)
    inverse_ratio = inverse_ratio / tf.reduce_sum(inverse_ratio)

    # Since tf.metrics.mean_per_class_accuracy does not support weighted average
    # for each class directly, we here convert it to pixel-wise weighted mask to
    # compute weighted average pixel accuracy.
    weight_mask = tf.einsum(
        '...y,y->...',
        tf.one_hot(temp_label,
                   outputs_to_num_classes[common.OUTPUT_TYPE],
                   dtype=tf.float32), inverse_ratio)
    temp_valid = tf.not_equal(temp_label, ignore_label)
    if valid_mask_score is not None:
        temp_valid_confident = tf.cast(temp_valid, tf.float32) * tf.reshape(
            valid_mask_score, [-1])
        temp_valid_confident = tf.cast(temp_valid_confident, tf.bool)
    else:
        temp_valid_confident = temp_valid

    temp_label_confident = tf.boolean_mask(temp_label, temp_valid_confident)
    temp_label_valid = tf.boolean_mask(temp_label, temp_valid)
    weight_mask_confident = tf.boolean_mask(weight_mask, temp_valid_confident)
    weight_mask_valid = tf.boolean_mask(weight_mask, temp_valid)

    if FLAGS.pseudo_label_threshold > 0:
        acc_pseudo, acc_pseudo_op = tf.metrics.mean_per_class_accuracy(
            temp_label_confident,
            tf.boolean_mask(pred_pseudo, temp_valid_confident),
            outputs_to_num_classes[common.OUTPUT_TYPE],
            weights=weight_mask_confident)
        with tf.control_dependencies([acc_pseudo_op]):
            acc_pseudo = tf.identity(acc_pseudo, name='acc_pseudo')

    pred_weak = tf.cast(tf.argmax(prob_weak, axis=-1), tf.int32)
    pred_weak = tf.reshape(pred_weak, [-1])
    acc_weak, acc_weak_op = tf.metrics.mean_per_class_accuracy(
        temp_label_valid,
        tf.boolean_mask(pred_weak, temp_valid),
        outputs_to_num_classes[common.OUTPUT_TYPE],
        weights=weight_mask_valid)
    with tf.control_dependencies([acc_weak_op]):
        acc_weak = tf.identity(acc_weak, name='acc_weak')

    pred_strong = tf.cast(tf.argmax(logits_strong, axis=-1), tf.int32)
    pred_strong = tf.reshape(pred_strong, [-1])
    # For all pixels
    acc_strong, acc_strong_op = tf.metrics.mean_per_class_accuracy(
        temp_label_valid,
        tf.boolean_mask(pred_strong, temp_valid),
        outputs_to_num_classes[common.OUTPUT_TYPE],
        weights=weight_mask_valid)
    with tf.control_dependencies([acc_strong_op]):
        acc_strong = tf.identity(acc_strong, name='acc_strong')

    # For confident pixels
    if FLAGS.pseudo_label_threshold > 0:
        acc_strong_confident, acc_strong_confident_op = tf.metrics.mean_per_class_accuracy(
            temp_label_confident,
            tf.boolean_mask(pred_strong, temp_valid_confident),
            outputs_to_num_classes[common.OUTPUT_TYPE],
            weights=weight_mask_confident)
        with tf.control_dependencies([acc_strong_confident_op]):
            acc_strong_confident = tf.identity(acc_strong_confident,
                                               name='acc_strong_confident')

        valid_ratio = tf.reduce_sum(valid_mask) / tf.reduce_sum(valid_mask_pad)
        valid_ratio = tf.identity(valid_ratio, name='valid_ratio')

    ### Pixel-level data
    ## 1) Segmentation
    outputs_to_scales_to_logits = model.multi_scale_logits(
        samples_seg[common.IMAGE],
        model_options=model_options,
        image_pyramid=FLAGS.image_pyramid,
        weight_decay=FLAGS.weight_decay,
        is_training=True,
        fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
        nas_training_hyper_parameters={
            'drop_path_keep_prob': FLAGS.drop_path_keep_prob,
            'total_training_steps': FLAGS.training_number_of_steps,
        })

    # Add name to graph node so we can add to summary.
    output_type_dict = outputs_to_scales_to_logits[common.OUTPUT_TYPE]
    output_type_dict[model.MERGED_LOGITS_SCOPE] = tf.identity(
        output_type_dict[model.MERGED_LOGITS_SCOPE],
        name=common.OUTPUT_TYPE + '_seg')

    for output, num_classes in six.iteritems(outputs_to_num_classes):
        train_utils.add_softmax_cross_entropy_loss_for_each_scale(
            outputs_to_scales_to_logits[output],
            samples_seg[common.LABEL],
            num_classes,
            ignore_label,
            loss_weight=model_options.label_weights,
            upsample_logits=FLAGS.upsample_logits,
            hard_example_mining_step=FLAGS.hard_example_mining_step,
            top_k_percent_pixels=FLAGS.top_k_percent_pixels,
            scope=output)

    ## 2) Train self-attention module
    if FLAGS.use_attention:
        valid_mask_pad = samples_seg['valid']
        valid_mask_pad = tf.compat.v1.image.resize_nearest_neighbor(
            valid_mask_pad,
            preprocess_utils.resolve_shape(logits_weak, 4)[1:3])
        valid_mask_pad = tf.cast(valid_mask_pad, tf.float32)

        with tf.name_scope('seg_data_cls'):
            _, end_points_cls = feature_extractor.extract_features(
                samples_seg[common.IMAGE],
                output_stride=model_options.output_stride,
                multi_grid=model_options.multi_grid,
                model_variant=model_options.model_variant,
                depth_multiplier=model_options.depth_multiplier,
                divisible_by=model_options.divisible_by,
                weight_decay=FLAGS.weight_decay,
                reuse=tf.AUTO_REUSE,
                is_training=True,
                preprocessed_images_dtype=model_options.
                preprocessed_images_dtype,
                fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
                global_pool=True,
                num_classes=outputs_to_num_classes[common.OUTPUT_TYPE] - 1)

        logits_cls = end_points_cls['{}/logits'.format(
            FLAGS.model_variant).replace('_beta', '')]
        logits_cls = tf.reshape(
            logits_cls, [-1, outputs_to_num_classes[common.OUTPUT_TYPE] - 1])

        _, att_cam_labeled = cam_func(
            end_points_cls,
            logits_cls,
            samples_seg['cls_label'],
            num_class=outputs_to_num_classes[common.OUTPUT_TYPE],
            use_attention=True,
            attention_dim=FLAGS.attention_dim,
            strides=[int(st) for st in FLAGS.att_strides],
            is_training=True,
            valid_mask=valid_mask_pad,
            net=FLAGS.model_variant.replace('_beta', ''))
        att_logits_labeled = att_cam_labeled

        # Loss
        train_utils.add_softmax_cross_entropy_loss_for_each_scale(
            {'self-attention_logits': att_logits_labeled},
            samples_seg[common.LABEL],
            outputs_to_num_classes[common.OUTPUT_TYPE],
            ignore_label,
            loss_weight=model_options.label_weights,
            upsample_logits=FLAGS.upsample_logits,
            hard_example_mining_step=FLAGS.hard_example_mining_step,
            top_k_percent_pixels=FLAGS.top_k_percent_pixels,
            scope='self-attention')

        att_logits_labeled = tf.identity(att_logits_labeled,
                                         name='att_logits_labeled')

        ## 3) If no image-level label, convert pixel-level label to train classifier
        if not FLAGS.weakly:
            # Seems that people usually use multi-label soft margin loss in PyTorch
            loss_cls = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=samples_seg['cls_label'], logits=logits_cls)
            loss_cls = tf.reduce_mean(loss_cls)
            loss_cls = tf.identity(loss_cls, name='loss_cls')
            tf.compat.v1.losses.add_loss(loss_cls)

    ## 4) Sanity check. Monitor pixel accuracy
    logits_seg = outputs_to_scales_to_logits[common.OUTPUT_TYPE][
        model.MERGED_LOGITS_SCOPE]
    temp_label = tf.compat.v1.image.resize_nearest_neighbor(
        samples_seg[common.LABEL],
        preprocess_utils.resolve_shape(logits_seg, 4)[1:3])
    temp_label = tf.reshape(temp_label, [-1])

    dump = tf.concat(
        [tf.range(outputs_to_num_classes[common.OUTPUT_TYPE]), temp_label],
        axis=-1)
    _, _, count = tf.unique_with_counts(dump)
    num_pixel_list = count - 1
    # Exclude the ignore region
    num_pixel_list = num_pixel_list[:outputs_to_num_classes[common.
                                                            OUTPUT_TYPE]]
    num_pixel_list = tf.cast(num_pixel_list, tf.float32)
    inverse_ratio = train_utils._div_maybe_zero(1, num_pixel_list)
    inverse_ratio = inverse_ratio / tf.reduce_sum(inverse_ratio)

    # Create weight mask to balance each class
    weight_mask = tf.einsum(
        '...y,y->...',
        tf.one_hot(temp_label,
                   outputs_to_num_classes[common.OUTPUT_TYPE],
                   dtype=tf.float32), inverse_ratio)
    temp_valid = tf.not_equal(temp_label, ignore_label)
    temp_label_valid = tf.boolean_mask(temp_label, temp_valid)
    weight_mask_valid = tf.boolean_mask(weight_mask, temp_valid)

    pred_seg = tf.argmax(logits_seg, axis=-1)
    pred_seg = tf.reshape(pred_seg, [-1])
    acc_seg, acc_seg_op = tf.metrics.mean_per_class_accuracy(
        temp_label_valid,
        tf.boolean_mask(pred_seg, temp_valid),
        outputs_to_num_classes[common.OUTPUT_TYPE],
        weights=weight_mask_valid)
    with tf.control_dependencies([acc_seg_op]):
        acc_seg = tf.identity(acc_seg, name='acc_seg')

def sum_convolution(cur, last, out_node, scope=None):
    with tf.variable_scope(scope, "sum_conv"):
        h, w = preprocess_utils.resolve_shape(cur, rank=4)[1:3]
        last = resize_bilinear(last, [h, w])
        # kernel_size made explicit so the call also works outside an
        # arg_scope that sets a default.
        net = slim.conv2d(cur + last, out_node, kernel_size=[3, 3],
                          scope="conv1")
        return net
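
# A minimal usage sketch (shapes and node count are illustrative assumptions,
# not from the source):
#   cur = tf.zeros([1, 16, 16, 32])   # current-stage feature map
#   last = tf.zeros([1, 8, 8, 32])    # coarser map from the previous stage
#   fused = sum_convolution(cur, last, out_node=32, scope="fuse0")
#   # `last` is bilinearly resized to 16x16, summed with `cur`, then passed
#   # through a 3x3 convolution.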
Exemple #19
    def model(self):
        # TODO: resolve_shape
        # TODO: image size
        # TODO: Remove after finishing the code
        batch_norm = slim.batch_norm
        batch_norm_params = get_batch_norm_params(
            decay=0.9997,
            epsilon=1e-5,
            scale=True,
            is_training=(self.is_training and self.fine_tune_batch_norm),
            # sync_batch_norm_method=model_options.sync_batch_norm_method
        )

        with tf.variable_scope(self.scope, 'Refine_Network'):
            with slim.arg_scope(
                [slim.conv2d],
                    trainable=True,
                    activation_fn=tf.nn.relu,
                    weights_initializer=tf.initializers.he_normal(),
                    weights_regularizer=slim.l2_regularizer(self.weight_decay),
                    kernel_size=[3, 3],
                    padding='SAME',
                    normalizer_fn=slim.batch_norm):
                with slim.arg_scope([batch_norm], **batch_norm_params):
                    y_tm1 = self.prior_seg
                    # Ensure `guid` is defined even if no branch below matches.
                    guid = None

                    # TODO: Would default variable values cause an error?
                    if self.prior_seg is not None or self.prior_pred is not None:
                        if "guid" in self.fusions:
                            guid = self.prior_seg
                        elif "guid_class" in self.fusions:
                            guid = self.prior_pred
                        elif "guid_uni" in self.fusions or "context_att" in self.fusions or "self_att" in self.fusions:
                            guid = tf.reduce_mean(self.prior_pred,
                                                  axis=3,
                                                  keepdims=True)
                    out_node = self.embed_node
                    tf.add_to_collection("guidance", guid)

                    for i, v in enumerate(self.low_level):
                        module_order = self.num_stage - i
                        fuse_method = self.fusions[i]
                        embed = self.embed(v,
                                           fuse_method,
                                           out_node,
                                           scope="embed%d" % module_order)
                        tf.add_to_collection("embed", embed)

                        fuse_func = self.get_fusion_method(fuse_method)
                        h, w = preprocess_utils.resolve_shape(embed,
                                                              rank=4)[1:3]

                        if y_tm1 is not None:
                            y_tm1 = resize_bilinear(y_tm1, [h, w])
                            tf.add_to_collection("feature", y_tm1)
                        else:
                            # TODO: remove
                            tf.add_to_collection("feature",
                                                 tf.zeros_like(embed))

                        if fuse_method in ("concat", "sum"):
                            if y_tm1 is not None:
                                y = fuse_func(embed, y_tm1, out_node,
                                              fuse_method + str(module_order))
                            else:
                                y = tf.identity(embed,
                                                name="identity%d" %
                                                module_order)
                        elif fuse_method in ("guid", "guid_class", "guid_uni",
                                             "context_att", "self_att"):
                            # guid = resize_bilinear(guid, [h, w])
                            if guid is not None:
                                guid = resize_bilinear(guid, [h, w])
                            # tf.add_to_collection("guid", guid)

                            fuse = fuse_func(embed,
                                             y_tm1,
                                             guid,
                                             out_node,
                                             fuse_method + str(module_order),
                                             num_classes=self.num_class,
                                             apply_sram2=self.apply_sram2)
                            """
              fuse = tf.reshape(fuse, [4, 3, h, w, out_node])
              _, fuse = seq_model(fuse, h, w, out_node, self.weight_decay, self.is_training,
                                  scope="gru"+str(i), cell_type='ConvGRU', output_wo_fuse=True)
              fuse = tf.reshape(fuse, [-1, h, w, out_node])
              """
                            y = slim.conv2d(fuse,
                                            self.embed_node,
                                            scope='fuse' + str(i))
                            tf.add_to_collection("refining", y)

                        if self.stage_pred_loss_name is not None:

                            num_class = self.num_class
                            if self.predict_without_background:
                                num_class -= 1

                            stage_pred = slim.conv2d(
                                fuse,
                                num_class,
                                kernel_size=[1, 1],
                                activation_fn=None,
                                scope="stage_pred%d_pred_class%d" %
                                (module_order, num_class))

                            # preds["guidance%d" %module_order] = stage_pred
                            tf.add_to_collection("stage_pred", stage_pred)

                        if fuse_method in ("guid"):
                            guid = y
                            y_tm1 = None

                        elif fuse_method in ("guid_class", "guid_uni",
                                             "context_att", "self_att"):
                            if i < len(self.low_level) - 1:
                                if "softmax" in self.stage_pred_loss_name:
                                    guid = tf.nn.softmax(stage_pred, axis=3)
                                elif "sigmoid" in self.stage_pred_loss_name:
                                    guid = tf.nn.sigmoid(stage_pred)

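                                # Fuse the per-class guidance maps along the
                                # class axis; each mode below is a different
                                # reduction (most collapse to a single channel).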
                                if self.guid_fuse == "sum":
                                    guid = tf.reduce_sum(guid,
                                                         axis=3,
                                                         keepdims=True)
                                elif self.guid_fuse == "mean":
                                    guid = tf.reduce_mean(guid,
                                                          axis=3,
                                                          keepdims=True)
                                elif self.guid_fuse == "entropy":
                                    guid = tf.clip_by_value(guid, 1e-10, 1.0)
                                    guid = -tf.reduce_sum(guid * tf.log(guid),
                                                          axis=3,
                                                          keepdims=True)
                                elif self.guid_fuse == "conv":
                                    if i < len(self.low_level) - 1:
                                        guid = slim.conv2d(guid,
                                                           out_node,
                                                           kernel_size=[3, 3],
                                                           activation_fn=None)
                                    else:
                                        guid = tf.reduce_sum(guid,
                                                             axis=3,
                                                             keepdims=True)
                                elif self.guid_fuse == "sum_dilated":
                                    size = [8, 6, 4, 2, 1]
                                    kernel = tf.ones(
                                        (size[i], size[i], num_class))
                                    guid = tf.nn.dilation2d(guid,
                                                            filter=kernel,
                                                            strides=(1, 1, 1,
                                                                     1),
                                                            rates=(1, 1, 1, 1),
                                                            padding="SAME")
                                    guid = guid - tf.ones_like(guid)
                                    guid = tf.reduce_sum(guid,
                                                         axis=3,
                                                         keepdims=True)
                                elif self.guid_fuse == "w_sum":
                                    w = tf.nn.softmax(tf.reduce_sum(
                                        guid, axis=[1, 2], keepdims=True),
                                                      axis=3)
                                    rev_w = tf.ones_like(w) - w
                                    guid = tf.reduce_sum(tf.multiply(
                                        guid, rev_w),
                                                         axis=3,
                                                         keepdims=True)
                                elif self.guid_fuse == "conv_sum":
                                    k_size_list = [1, 1, 1, 3, 5]
                                    k_size = 2 * k_size_list[i] + 1
                                    guid = slim.conv2d(
                                        guid,
                                        1,
                                        kernel_size=[k_size, k_size],
                                        activation_fn=None,
                                        weights_initializer=tf.ones_initializer(),
                                        trainable=False,
                                        normalizer_fn=None)
                                    guid = guid / (k_size * k_size *
                                                   num_class * 1)
                                elif self.guid_fuse == "w_sum_conv":
                                    # TODO: make it right
                                    k_size_list = [1, 1, 1, 2, 4]
                                    k_size = 3 * k_size_list[i] + 1
                                    w = tf.reduce_sum(guid,
                                                      axis=[1, 2],
                                                      keepdims=True)
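                                    # Inverse-sqrt class-frequency weights,
                                    # stabilized by a small epsilon.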
                                    rev_w = (tf.ones_like(w) +
                                             1e-5) / (tf.sqrt(w) + 1e-5)
                                    rev_w = tf.tile(rev_w,
                                                    [1, k_size, k_size, 1])
                                    rev_w = tf.expand_dims(rev_w, axis=4)

                                    n, h, w, channels_img = preprocess_utils.resolve_shape(
                                        guid, rank=4)
                                    n, fh, fw, channels, out_channels = preprocess_utils.resolve_shape(
                                        rev_w, rank=5)
                                    # F has shape (n, k_size, k_size, channels, out_channels)

                                    rev_w = tf.transpose(
                                        rev_w, [1, 2, 0, 3, 4])
                                    rev_w = tf.reshape(
                                        rev_w,
                                        [fh, fw, channels * n, out_channels])

                                    # Fold the batch into the channel axis so
                                    # a single depthwise conv applies each
                                    # sample's own filter.
                                    guid = tf.transpose(
                                        guid, [1, 2, 0, 3])  # (H, W, N, C)
                                    guid = tf.reshape(
                                        guid, [1, h, w, n * channels_img])

                                    out = tf.nn.depthwise_conv2d(
                                        guid,
                                        filter=rev_w,
                                        strides=[1, 1, 1, 1],
                                        padding="SAME")
                                    # With padding="SAME", `out` has shape
                                    # (1, H, W, N * channels * out_channels).

                                    out = tf.reshape(
                                        out, [h, w, n, channels, out_channels])
                                    out = tf.transpose(out, [2, 0, 1, 3, 4])
                                    out = tf.reduce_sum(out, axis=3)

                                    guid = out
                                elif self.guid_fuse == "sum_wo_back":
                                    flag = tf.concat([
                                        tf.zeros([1, 1, 1, 1]),
                                        tf.ones([1, 1, 1, num_class - 1])
                                    ],
                                                     axis=3)
                                    guid = tf.multiply(guid, flag)
                                    guid = tf.reduce_sum(guid,
                                                         axis=3,
                                                         keepdims=True)
                                elif self.guid_fuse == "mean_wo_back":
                                    flag = tf.concat([
                                        tf.zeros([1, 1, 1, 1]),
                                        tf.ones([1, 1, 1, num_class - 1])
                                    ],
                                                     axis=3)
                                    guid = tf.multiply(guid, flag)
                                    guid = tf.reduce_mean(guid,
                                                          axis=3,
                                                          keepdims=True)
                                elif self.guid_fuse == "same":
                                    pass
                                else:
                                    raise ValueError("Unknown guid fuse")

                                tf.add_to_collection("guidance", guid)

                            y_tm1 = y
                        elif fuse_method in ("concat", "sum"):
                            y_tm1 = y

                    # h and w still hold the spatial size of the last stage.
                    y = resize_bilinear(y, [2 * h, 2 * w])
                    y = slim.conv2d(y, self.embed_node, scope="decoder_output")
                    y = slim.conv2d(y,
                                    self.num_class,
                                    kernel_size=[1, 1],
                                    stride=1,
                                    activation_fn=None,
                                    scope='logits_pred_class%d' %
                                    self.num_class)

        return y
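
# Note: the decoder returns `num_class`-channel logits at twice the spatial
# size of the last fused stage (via the final 2x bilinear upsampling).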
Exemple #20
def guidance_fusion_method(logits, guid_fuse, num_class, out_node, level,
                           num_levels):
    # Standalone version of the fusion logic above: the stray `self.`
    # references are replaced by parameters, and `num_levels` (len(low_level)
    # in the class version) is added so the function carries no hidden state.
    if guid_fuse == "sum":
        guid = tf.reduce_sum(logits, axis=3, keepdims=True)
    elif self.guid_fuse == "mean":
        guid = tf.reduce_mean(logits, axis=3, keepdims=True)
    elif self.guid_fuse == "entropy":
        guid = tf.clip_by_value(logits, 1e-10, 1.0)
        guid = -tf.reduce_sum(guid * tf.log(guid), axis=3, keepdims=True)
    elif self.guid_fuse == "conv":
        if level < len(self.low_level) - 1:
            guid = slim.conv2d(logits,
                               out_node,
                               kernel_size=[3, 3],
                               activation_fn=None)
        else:
            guid = tf.reduce_sum(logits, axis=3, keepdims=True)
    elif self.guid_fuse == "sum_dilated":
        size = [8, 6, 4, 2, 1]
        kernel = tf.ones((size[level], size[level], num_class))
        guid = tf.nn.dilation2d(logits,
                                filter=kernel,
                                strides=(1, 1, 1, 1),
                                rates=(1, 1, 1, 1),
                                padding="SAME")
        guid = guid - tf.ones_like(guid)
        guid = tf.reduce_sum(guid, axis=3, keepdims=True)
    elif self.guid_fuse == "w_sum":
        w = tf.nn.softmax(tf.reduce_sum(logits, axis=[1, 2], keepdims=True),
                          axis=3)
        rev_w = tf.ones_like(w) - w
        guid = tf.reduce_sum(tf.multiply(logits, rev_w), axis=3, keepdims=True)
    elif self.guid_fuse == "conv_sum":
        k_size_list = [1, 1, 1, 3, 5]
        k_size = 2 * k_size_list[level] + 1
        guid = slim.conv2d(logits,
                           1,
                           kernel_size=[k_size, k_size],
                           activation_fn=None,
                           weights_initializer=tf.ones_initializer(),
                           trainable=False,
                           normalizer_fn=None)
        guid = guid / (k_size * k_size * num_class * 1)
    elif self.guid_fuse == "w_sum_conv":
        # TODO: make it right
        k_size_list = [1, 1, 1, 2, 4]
        k_size = 3 * k_size_list[level] + 1
        w = tf.reduce_sum(logits, axis=[1, 2], keepdims=True)
        rev_w = (tf.ones_like(w) + 1e-5) / (tf.sqrt(w) + 1e-5)
        rev_w = tf.tile(rev_w, [1, k_size, k_size, 1])
        rev_w = tf.expand_dims(rev_w, axis=4)

        n, h, w, channels_img = preprocess_utils.resolve_shape(logits, rank=4)
        n, fh, fw, channels, out_channels = preprocess_utils.resolve_shape(
            rev_w, rank=5)
        # F has shape (n, k_size, k_size, channels, out_channels)

        rev_w = tf.transpose(rev_w, [1, 2, 0, 3, 4])
        rev_w = tf.reshape(rev_w, [fh, fw, channels * n, out_channels])

        # Fold the batch into the channel axis so a single depthwise conv
        # applies each sample's own filter.
        guid = tf.transpose(logits, [1, 2, 0, 3])  # (H, W, N, C)
        guid = tf.reshape(guid, [1, h, w, n * channels_img])

        out = tf.nn.depthwise_conv2d(
            guid, filter=rev_w, strides=[1, 1, 1, 1], padding="SAME")
        # With padding="SAME", `out` has shape
        # (1, H, W, N * channels * out_channels).

        out = tf.reshape(out, [h, w, n, channels, out_channels])
        out = tf.transpose(out, [2, 0, 1, 3, 4])
        out = tf.reduce_sum(out, axis=3)

        guid = out
    elif self.guid_fuse == "sum_wo_back":
        flag = tf.concat(
            [tf.zeros([1, 1, 1, 1]),
             tf.ones([1, 1, 1, num_class - 1])],
            axis=3)
        guid = tf.multiply(logits, flag)
        guid = tf.reduce_sum(guid, axis=3, keepdims=True)
    elif self.guid_fuse == "mean_wo_back":
        flag = tf.concat(
            [tf.zeros([1, 1, 1, 1]),
             tf.ones([1, 1, 1, num_class - 1])],
            axis=3)
        guid = tf.multiply(logits, flag)
        guid = tf.reduce_mean(guid, axis=3, keepdims=True)
    elif self.guid_fuse == "same":
        pass
    else:
        raise ValueError("Unknown guid fuse")

    tf.add_to_collection("guidance", guid)
    return guid
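
# A minimal usage sketch (shapes and the 5-level setting are illustrative
# assumptions): collapse 4 class probability maps into one guidance channel.
#   probs = tf.nn.softmax(tf.zeros([1, 32, 32, 4]), axis=3)
#   guid = guidance_fusion_method(probs, "mean", num_class=4, out_node=32,
#                                 level=0, num_levels=5)
#   # guid: [1, 32, 32, 1]; it is also added to the "guidance" collection.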
Exemple #21
def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
                                                  labels,
                                                  num_classes,
                                                  ignore_label,
                                                  loss_weight=1.0,
                                                  upsample_logits=True,
                                                  hard_example_mining_step=0,
                                                  top_k_percent_pixels=1.0,
                                                  scope=None):
    """Adds softmax cross entropy loss for logits of each scale.

  Args:
    scales_to_logits: A map from logits names for different scales to logits.
      The logits have shape [batch, logits_height, logits_width, num_classes].
    labels: Groundtruth labels with shape [batch, image_height, image_width, 1].
    num_classes: Integer, number of target classes.
    ignore_label: Integer, label to ignore.
    loss_weight: Float, loss weight.
    upsample_logits: Boolean, upsample logits or not.
    hard_example_mining_step: An integer, the training step at which hard
      example mining kicks in. Note that we gradually reduce the mining
      percent to the top_k_percent_pixels. For example, if
      hard_example_mining_step = 100K and top_k_percent_pixels = 0.25, then
      mining percent will gradually reduce from 100% to 25% until 100K steps
      after which we only mine top 25% pixels.
    top_k_percent_pixels: A float, the value lies in [0.0, 1.0]. When its value
      < 1.0, only compute the loss for the top k percent pixels (e.g., the top
      20% pixels). This is useful for hard pixel mining.
    scope: String, the scope for the loss.

  Raises:
    ValueError: Label or logits is None.
  """
    if labels is None:
        raise ValueError('No label for softmax cross entropy loss.')

    for scale, logits in six.iteritems(scales_to_logits):
        loss_scope = None
        if scope:
            loss_scope = '%s_%s' % (scope, scale)

        if upsample_logits:
            # Label is not downsampled, and instead we upsample logits.
            logits = tf.image.resize_bilinear(logits,
                                              preprocess_utils.resolve_shape(
                                                  labels, 4)[1:3],
                                              align_corners=True)
            scaled_labels = labels
        else:
            # Label is downsampled to the same size as logits.
            scaled_labels = tf.image.resize_nearest_neighbor(
                labels,
                preprocess_utils.resolve_shape(logits, 4)[1:3],
                align_corners=True)

        scaled_labels = tf.reshape(scaled_labels, shape=[-1])
        not_ignore_mask = tf.to_float(tf.not_equal(scaled_labels,
                                                   ignore_label)) * loss_weight
        one_hot_labels = tf.one_hot(scaled_labels,
                                    num_classes,
                                    on_value=1.0,
                                    off_value=0.0)

        if top_k_percent_pixels == 1.0:
            # Compute the loss for all pixels.
            tf.losses.softmax_cross_entropy(one_hot_labels,
                                            tf.reshape(logits,
                                                       shape=[-1,
                                                              num_classes]),
                                            weights=not_ignore_mask,
                                            scope=loss_scope)
        else:
            logits = tf.reshape(logits, shape=[-1, num_classes])
            weights = not_ignore_mask
            with tf.name_scope(loss_scope, 'softmax_hard_example_mining',
                               [logits, one_hot_labels, weights]):
                one_hot_labels = tf.stop_gradient(one_hot_labels,
                                                  name='labels_stop_gradient')
                pixel_losses = tf.nn.softmax_cross_entropy_with_logits_v2(
                    labels=one_hot_labels, logits=logits, name='pixel_losses')
                weighted_pixel_losses = tf.multiply(pixel_losses, weights)
                num_pixels = tf.to_float(tf.shape(logits)[0])
                # Compute the top_k_percent pixels based on current training step.
                if hard_example_mining_step == 0:
                    # Directly focus on the top_k pixels.
                    top_k_pixels = tf.to_int32(top_k_percent_pixels *
                                               num_pixels)
                else:
                    # Gradually reduce the mining percent to top_k_percent_pixels.
                    global_step = tf.to_float(
                        tf.train.get_or_create_global_step())
                    ratio = tf.minimum(1.0,
                                       global_step / hard_example_mining_step)
                    top_k_pixels = tf.to_int32((ratio * top_k_percent_pixels +
                                                (1.0 - ratio)) * num_pixels)
                top_k_losses, _ = tf.nn.top_k(weighted_pixel_losses,
                                              k=top_k_pixels,
                                              sorted=True,
                                              name='top_k_percent_pixels')
                total_loss = tf.reduce_sum(top_k_losses)
                num_present = tf.reduce_sum(
                    tf.to_float(tf.not_equal(top_k_losses, 0.0)))
                loss = _div_maybe_zero(total_loss, num_present)
                tf.losses.add_loss(loss)
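
# A minimal usage sketch (names and numbers are illustrative assumptions):
# mine the top 25% hardest pixels once training passes 100K steps.
#   add_softmax_cross_entropy_loss_for_each_scale(
#       {'merged_logits': logits}, labels, num_classes=21, ignore_label=255,
#       loss_weight=1.0, upsample_logits=True,
#       hard_example_mining_step=100000, top_k_percent_pixels=0.25,
#       scope='semantic')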
    def _preprocessing(self, sample):
        """Preprocesses one sample of a slice sequence.

        image: [num_frame, height, width, channel]
        label: [num_frame, height, width, 1]
        prior_segs: [num_frame, height, width, class]
        """
        height = sample[common.HEIGHT]
        width = sample[common.WIDTH]
        image = tf.reshape(sample[common.IMAGE],
                           [self.seq_length, height, width])
        image = tf.transpose(image, [1, 2, 0])
        if common.LABEL in sample:
            label = tf.reshape(sample[common.LABEL],
                               [self.seq_length, height, width])
            label = tf.transpose(label, [1, 2, 0])
        else:
            label = None
        depth = sample[common.DEPTH]
        num_slices = sample[common.NUM_SLICES]

        # get prior
        # TODO: prior for pgn-v1
        if self.guidance_type == "training_data_fusion":
            # print("Input Prior Infomrmation: Slice=%d, Subject=%d" % (
            #     self.prior_num_slice, self.prior_num_subject))
            prior_segs = self.load_prior_from_dir(height, width)
            [_, _, prior_channel] = preprocess_utils.resolve_shape(prior_segs,
                                                                   rank=3)
        elif self.guidance_type == "ground_truth":
            prior_segs = label
        elif self.guidance_type == "zeros":
            prior_segs = tf.zeros_like(label)
        else:
            prior_segs = None

        # Preprocessing for images, label and z_label
        original_image, image, label, _, prior_segs = input_preprocess.preprocess_image_and_label_seq(
            image=image,
            label=label,
            prior_segs=prior_segs,
            crop_height=self.crop_size[0],
            crop_width=self.crop_size[1],
            channel=self.dataset_infos.channel,
            seq_length=self.seq_length,
            label_for_each_frame=self.label_for_each_frame,
            pre_crop_height=self.pre_crop_size[0],
            pre_crop_width=self.pre_crop_size[1],
            num_class=self.num_of_classes,
            HU_window=self.dataset_infos.HU_window,
            min_resize_value=self.min_resize_value,
            max_resize_value=self.max_resize_value,
            resize_factor=self.resize_factor,
            min_scale_factor=self.min_scale_factor,
            max_scale_factor=self.max_scale_factor,
            scale_factor_step_size=self.scale_factor_step_size,
            ignore_label=self.ignore_label,
            is_training=self.is_training,
            model_variant=self.model_variant)

        if self.seq_length > 1:
            image = tf.expand_dims(tf.transpose(image, [2, 0, 1]), axis=3)
            if label is not None:
                label = tf.expand_dims(tf.transpose(label, [2, 0, 1]), axis=3)

        sample[common.IMAGE] = image
        if not self.is_training:
            # Original image is only used during visualization.
            sample[common.ORIGINAL_IMAGE] = original_image

        if label is not None:
            sample[common.LABEL] = label

        if prior_segs is not None:
            sample[common.PRIOR_SEGS] = tf.reshape(
                prior_segs,
                [self.crop_size[0], self.crop_size[1], prior_channel, 1])

        # get multi-task label
        if self.z_loss_name is not None:
            mt_label = get_z_label(self.z_loss_name,
                                   num_slices,
                                   depth,
                                   z_class=self.mt_class)
            sample[common.Z_LABEL] = mt_label
        return sample