Exemplo n.º 1
0
    def _detection_loss(self, cls_outputs, box_outputs, labels, loss_vals):
        """Computes the total detection loss over all feature levels.

        Per-level class and box losses are summed across levels, optionally
        combined with a box IoU loss computed over all levels at once, and
        the individual terms are recorded in `loss_vals`.

        Args:
          cls_outputs: per-level class logits, indexed by level position
            (0 .. len(cls_outputs) - 1).
          box_outputs: per-level box regression outputs, indexed the same
            way as `cls_outputs`.
          labels: the dictionary returned from the dataloader. It is read
            for 'mean_num_positives' and, for each level, for
            'cls_targets_<n>' and 'box_targets_<n>', where <n> is the level
            position plus `config.min_level`.
          loss_vals: a dict of loss values, updated in place with
            'det_loss', 'cls_loss', 'box_loss' and, when
            `config.iou_loss_type` is set, 'box_iou_loss'.

        Returns:
          The total loss only:
          cls_loss + box_loss_weight * box_loss
          + iou_loss_weight * box_iou_loss.
          The individual terms are available through `loss_vals`.
        """
        cfg = self.config
        compute_dtype = utils.get_precision(
            cfg.strategy, cfg.mixed_precision).split('_')[-1]

        # Sum all positives in the batch for normalization; the +1.0 avoids
        # a zero normalizer, which would lead to inf loss during training.
        normalizer = tf.reduce_sum(labels['mean_num_positives']) + 1.0
        momentum = cfg.positives_momentum or 0
        if momentum > 0:
            # Smooth the number of positives with a moving average for
            # training stability.
            normalizer_var = tf.Variable(
                0.0,
                name='moving_normalizer',
                dtype=compute_dtype,
                synchronization=tf.VariableSynchronization.ON_READ,
                trainable=False,
                aggregation=tf.VariableAggregation.MEAN)
            normalizer = tf.keras.backend.moving_average_update(
                normalizer_var, normalizer, momentum=momentum)
        elif momentum < 0:
            normalizer = utils.cross_replica_mean(normalizer)
        normalizer = tf.cast(normalizer, compute_dtype)

        # Label keys are numbered from config.min_level, while the model
        # outputs are indexed from zero.
        level_ids = [idx + cfg.min_level for idx in range(len(cls_outputs))]
        cls_losses = []
        box_losses = []
        for idx, level_id in enumerate(level_ids):
            cls_labels = labels['cls_targets_%d' % level_id]

            # One-hot encode the classification labels, then fold the
            # trailing anchor/class dimensions into a single -1 axis.
            onehot_targets = tf.one_hot(
                cls_labels, cfg.num_classes, dtype=compute_dtype)
            channels_first = cfg.data_format == 'channels_first'
            dims = onehot_targets.get_shape().as_list()
            if channels_first:
                bs, _, dim_a, dim_b, _ = dims
                target_shape = [bs, -1, dim_a, dim_b]
            else:
                bs, dim_a, dim_b, _, _ = dims
                target_shape = [bs, dim_a, dim_b, -1]
            onehot_targets = tf.reshape(onehot_targets, target_shape)

            focal_layer = self.loss.get(FocalLoss.__name__, None)
            if focal_layer:
                level_cls_loss = focal_layer([normalizer, onehot_targets],
                                             cls_outputs[idx])
                # Same layout as the targets, with the class axis split out
                # again as the last dimension.
                level_cls_loss = tf.reshape(
                    level_cls_loss, target_shape + [cfg.num_classes])
                # Positions whose class target equals -2 contribute no
                # class loss.
                keep_mask = tf.expand_dims(tf.not_equal(cls_labels, -2), -1)
                level_cls_loss *= tf.cast(keep_mask, compute_dtype)
                clipped_sum = tf.clip_by_value(
                    tf.reduce_sum(level_cls_loss), 0.0, 2.0)
                cls_losses.append(tf.cast(clipped_sum, compute_dtype))

            # The box loss needs both a non-zero weight and a configured
            # BoxLoss layer.
            box_layer = None
            if cfg.box_loss_weight:
                box_layer = self.loss.get(BoxLoss.__name__, None)
            if box_layer:
                level_box_targets = labels['box_targets_%d' % level_id]
                box_losses.append(
                    box_layer([normalizer, level_box_targets],
                              box_outputs[idx]))

        if cfg.iou_loss_type:
            # The IoU loss is computed once over the boxes of all levels,
            # flattened to rows of 4 values. The BoxIouLoss layer is indexed
            # directly, so it must be present in self.loss here.
            flat_outputs = tf.concat(
                [tf.reshape(out, [-1, 4]) for out in box_outputs], axis=0)
            flat_targets = tf.concat(
                [
                    tf.reshape(labels['box_targets_%d' % level_id], [-1, 4])
                    for level_id in level_ids
                ],
                axis=0)
            iou_layer = self.loss[BoxIouLoss.__name__]
            box_iou_loss = iou_layer([normalizer, flat_targets], flat_outputs)
            loss_vals['box_iou_loss'] = box_iou_loss
        else:
            box_iou_loss = 0

        # Sum the per-level losses; a term with no entries falls back to 0.
        cls_loss = tf.add_n(cls_losses) if cls_losses else 0
        box_loss = tf.add_n(box_losses) if box_losses else 0
        total_loss = cls_loss + cfg.box_loss_weight * box_loss
        total_loss = total_loss + cfg.iou_loss_weight * box_iou_loss

        loss_vals['det_loss'] = total_loss
        loss_vals['cls_loss'] = cls_loss
        loss_vals['box_loss'] = box_loss
        return total_loss
Exemplo n.º 2
0
def detection_loss(cls_outputs, box_outputs, labels, params):
    """Computes the total detection loss over all feature levels.

    The focal class loss and (when enabled) the box regression loss are
    computed per level and summed across levels.

    Args:
      cls_outputs: an OrderedDict with keys representing levels and values
        representing class logits.
      box_outputs: a mapping from the same level keys to the box regression
        outputs.
      labels: the dictionary returned from the dataloader. It is read for
        'mean_num_positives' and, for each level, for 'cls_targets_<level>'
        and 'box_targets_<level>'.
      params: the dictionary of training parameters. Keys read here:
        'num_classes', 'data_format', 'alpha', 'gamma', 'label_smoothing',
        'box_loss_weight', 'delta' (only when the box loss is enabled) and
        the optional 'positives_momentum'.

    Returns:
      A `(total_loss, cls_loss, box_loss)` tuple where
      total_loss = cls_loss + params['box_loss_weight'] * box_loss.
    """
    # Sum all positives in the batch for normalization; the +1.0 avoids a
    # zero normalizer, which would lead to inf loss during training.
    normalizer = tf.reduce_sum(labels['mean_num_positives']) + 1.0
    momentum = params.get('positives_momentum') or 0
    if momentum > 0:
        # Smooth the number of positives with a moving average for training
        # stability.
        normalizer_var = tf.Variable(
            0.0,
            name='moving_normalizer',
            dtype=tf.float32,
            synchronization=tf.VariableSynchronization.ON_READ,
            trainable=False,
            aggregation=tf.VariableAggregation.MEAN)
        normalizer = tf.keras.backend.moving_average_update(
            normalizer_var, normalizer, momentum=momentum)
    elif momentum < 0:
        normalizer = utils.cross_replica_mean(normalizer)

    cls_losses = []
    box_losses = []
    for level, logits in cls_outputs.items():
        cls_labels = labels['cls_targets_%d' % level]
        num_classes = params['num_classes']

        # One-hot encode the classification labels, then fold the trailing
        # anchor/class dimensions into a single -1 axis.
        onehot_targets = tf.one_hot(cls_labels, num_classes,
                                    dtype=logits.dtype)
        channels_first = params['data_format'] == 'channels_first'
        dims = onehot_targets.get_shape().as_list()
        if channels_first:
            bs, _, dim_a, dim_b, _ = dims
            target_shape = [bs, -1, dim_a, dim_b]
        else:
            bs, dim_a, dim_b, _, _ = dims
            target_shape = [bs, dim_a, dim_b, -1]
        onehot_targets = tf.reshape(onehot_targets, target_shape)

        # Read unconditionally, so a missing box target entry is reported
        # even when the box loss is disabled.
        level_box_targets = labels['box_targets_%d' % level]

        level_cls_loss = focal_loss(
            logits,
            onehot_targets,
            params['alpha'],
            params['gamma'],
            normalizer=normalizer,
            label_smoothing=params['label_smoothing'])
        # Same layout as the targets, with the class axis split out again as
        # the last dimension.
        level_cls_loss = tf.reshape(level_cls_loss,
                                    target_shape + [num_classes])

        # Positions whose class target equals -2 contribute no class loss.
        keep_mask = tf.expand_dims(tf.not_equal(cls_labels, -2), -1)
        level_cls_loss *= tf.cast(keep_mask, level_cls_loss.dtype)
        cls_losses.append(tf.cast(tf.reduce_sum(level_cls_loss), tf.float32))

        if params['box_loss_weight']:
            level_box_loss = _box_loss(box_outputs[level],
                                       level_box_targets,
                                       normalizer,
                                       delta=params['delta'])
            box_losses.append(level_box_loss)

    # Sum per level losses to total loss.
    cls_loss = tf.add_n(cls_losses)
    if box_losses:
        box_loss = tf.add_n(box_losses)
    else:
        box_loss = tf.constant(0.)

    total_loss = cls_loss + params['box_loss_weight'] * box_loss
    return total_loss, cls_loss, box_loss