  def _detection_loss(self, cls_outputs, box_outputs, labels, loss_vals):
    """Computes total detection loss.

    Computes total detection loss including box and class loss from all levels.

    Args:
      cls_outputs: an OrderedDict with keys representing levels and values
        representing logits in [batch_size, height, width, num_anchors].
      box_outputs: an OrderedDict with keys representing levels and values
        representing box regression targets in
        [batch_size, height, width, num_anchors * 4].
      labels: the dictionary returned from the dataloader that includes
        groundtruth targets.
      loss_vals: a dict of loss values, updated in place with 'det_loss',
        'cls_loss', 'box_loss' and, when iou loss is enabled, 'box_iou_loss'.

    Returns:
      total_loss: a float tensor representing the total loss reduced from the
        class and box losses of all levels.
    """
    # Sum all positives in a batch for normalization and avoid zero
    # num_positives_sum, which would lead to inf loss during training.
    precision = utils.get_precision(self.config.strategy,
                                    self.config.mixed_precision)
    dtype = precision.split('_')[-1]
    num_positives_sum = tf.reduce_sum(labels['mean_num_positives']) + 1.0
    positives_momentum = self.config.positives_momentum or 0
    if positives_momentum > 0:
      # Normalize the number of positive examples for training stability.
      moving_normalizer_var = tf.Variable(
          0.0,
          name='moving_normalizer',
          dtype=dtype,
          synchronization=tf.VariableSynchronization.ON_READ,
          trainable=False,
          aggregation=tf.VariableAggregation.MEAN)
      num_positives_sum = tf.keras.backend.moving_average_update(
          moving_normalizer_var,
          num_positives_sum,
          momentum=self.config.positives_momentum)
    elif positives_momentum < 0:
      num_positives_sum = utils.cross_replica_mean(num_positives_sum)
    num_positives_sum = tf.cast(num_positives_sum, dtype)

    levels = range(len(cls_outputs))
    cls_losses = []
    box_losses = []
    for level in levels:
      # One-hot encoding for classification labels.
      cls_targets_at_level = tf.one_hot(
          labels['cls_targets_%d' % (level + self.config.min_level)],
          self.config.num_classes,
          dtype=dtype)

      if self.config.data_format == 'channels_first':
        bs, _, width, height, _ = cls_targets_at_level.get_shape().as_list()
        cls_targets_at_level = tf.reshape(cls_targets_at_level,
                                          [bs, -1, width, height])
      else:
        bs, width, height, _, _ = cls_targets_at_level.get_shape().as_list()
        cls_targets_at_level = tf.reshape(cls_targets_at_level,
                                          [bs, width, height, -1])

      class_loss_layer = self.loss.get(FocalLoss.__name__, None)
      if class_loss_layer:
        cls_loss = class_loss_layer([num_positives_sum, cls_targets_at_level],
                                    cls_outputs[level])

        if self.config.data_format == 'channels_first':
          cls_loss = tf.reshape(
              cls_loss, [bs, -1, width, height, self.config.num_classes])
        else:
          cls_loss = tf.reshape(
              cls_loss, [bs, width, height, -1, self.config.num_classes])
        # Zero out the loss of anchors labeled -2, i.e. ignored matches from
        # the anchor labeler.
        cls_loss *= tf.cast(
            tf.expand_dims(
                tf.not_equal(
                    labels['cls_targets_%d' % (level + self.config.min_level)],
                    -2), -1), dtype)
        cls_loss_sum = tf.clip_by_value(tf.reduce_sum(cls_loss), 0.0, 2.0)
        cls_losses.append(tf.cast(cls_loss_sum, dtype))

      if self.config.box_loss_weight and self.loss.get(BoxLoss.__name__, None):
        box_targets_at_level = (
            labels['box_targets_%d' % (level + self.config.min_level)])
        box_loss_layer = self.loss[BoxLoss.__name__]
        box_losses.append(
            box_loss_layer([num_positives_sum, box_targets_at_level],
                           box_outputs[level]))

    if self.config.iou_loss_type:
      box_outputs = tf.concat([tf.reshape(v, [-1, 4]) for v in box_outputs],
                              axis=0)
      box_targets = tf.concat([
          tf.reshape(
              labels['box_targets_%d' % (level + self.config.min_level)],
              [-1, 4]) for level in levels
      ], axis=0)
      box_iou_loss_layer = self.loss[BoxIouLoss.__name__]
      box_iou_loss = box_iou_loss_layer([num_positives_sum, box_targets],
                                        box_outputs)
      loss_vals['box_iou_loss'] = box_iou_loss
    else:
      box_iou_loss = 0

    cls_loss = tf.add_n(cls_losses) if cls_losses else 0
    box_loss = tf.add_n(box_losses) if box_losses else 0
    total_loss = (
        cls_loss + self.config.box_loss_weight * box_loss +
        self.config.iou_loss_weight * box_iou_loss)
    loss_vals['det_loss'] = total_loss
    loss_vals['cls_loss'] = cls_loss
    loss_vals['box_loss'] = box_loss
    return total_loss
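  # --- Illustrative sketch (not from the original source) --------------------
  # A minimal example of how `_detection_loss` is typically wired into a
  # custom train step. It assumes `self` is the Keras detection model that
  # owns `self.loss` and `self.config`, that calling the model returns
  # per-level (cls_outputs, box_outputs), and that an optimizer has been
  # attached via compile(); those names are assumptions for illustration only.
  def _example_train_step(self, features, labels):
    """Hedged sketch: one gradient step driven by `_detection_loss`."""
    with tf.GradientTape() as tape:
      cls_outputs, box_outputs = self(features, training=True)
      loss_vals = {}
      total_loss = self._detection_loss(cls_outputs, box_outputs, labels,
                                        loss_vals)
    grads = tape.gradient(total_loss, self.trainable_variables)
    self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
    # loss_vals now holds 'det_loss', 'cls_loss', 'box_loss' (and
    # 'box_iou_loss' when config.iou_loss_type is set).
    return loss_vals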
def detection_loss(cls_outputs, box_outputs, labels, params):
  """Computes total detection loss.

  Computes total detection loss including box and class loss from all levels.

  Args:
    cls_outputs: an OrderedDict with keys representing levels and values
      representing logits in [batch_size, height, width, num_anchors].
    box_outputs: an OrderedDict with keys representing levels and values
      representing box regression targets in
      [batch_size, height, width, num_anchors * 4].
    labels: the dictionary returned from the dataloader that includes
      groundtruth targets.
    params: the dictionary including training parameters specified in the
      default_hparams function in this file.

  Returns:
    total_loss: a float tensor representing the total loss reduced from the
      class and box losses of all levels.
    cls_loss: a float tensor representing the total class loss.
    box_loss: a float tensor representing the total box regression loss.
  """
  # Sum all positives in a batch for normalization and avoid zero
  # num_positives_sum, which would lead to inf loss during training.
  num_positives_sum = tf.reduce_sum(labels['mean_num_positives']) + 1.0
  positives_momentum = params.get('positives_momentum', None) or 0
  if positives_momentum > 0:
    # Normalize the number of positive examples for training stability.
    moving_normalizer_var = tf.Variable(
        0.0,
        name='moving_normalizer',
        dtype=tf.float32,
        synchronization=tf.VariableSynchronization.ON_READ,
        trainable=False,
        aggregation=tf.VariableAggregation.MEAN)
    num_positives_sum = tf.keras.backend.moving_average_update(
        moving_normalizer_var,
        num_positives_sum,
        momentum=params['positives_momentum'])
  elif positives_momentum < 0:
    num_positives_sum = utils.cross_replica_mean(num_positives_sum)

  levels = cls_outputs.keys()
  cls_losses = []
  box_losses = []
  for level in levels:
    # One-hot encoding for classification labels.
    cls_targets_at_level = tf.one_hot(
        labels['cls_targets_%d' % level],
        params['num_classes'],
        dtype=cls_outputs[level].dtype)

    if params['data_format'] == 'channels_first':
      bs, _, width, height, _ = cls_targets_at_level.get_shape().as_list()
      cls_targets_at_level = tf.reshape(cls_targets_at_level,
                                        [bs, -1, width, height])
    else:
      bs, width, height, _, _ = cls_targets_at_level.get_shape().as_list()
      cls_targets_at_level = tf.reshape(cls_targets_at_level,
                                        [bs, width, height, -1])

    box_targets_at_level = labels['box_targets_%d' % level]

    cls_loss = focal_loss(
        cls_outputs[level],
        cls_targets_at_level,
        params['alpha'],
        params['gamma'],
        normalizer=num_positives_sum,
        label_smoothing=params['label_smoothing'])

    if params['data_format'] == 'channels_first':
      cls_loss = tf.reshape(cls_loss,
                            [bs, -1, width, height, params['num_classes']])
    else:
      cls_loss = tf.reshape(cls_loss,
                            [bs, width, height, -1, params['num_classes']])
    # Zero out the loss of anchors labeled -2, i.e. ignored matches from the
    # anchor labeler.
    cls_loss *= tf.cast(
        tf.expand_dims(tf.not_equal(labels['cls_targets_%d' % level], -2), -1),
        cls_loss.dtype)
    cls_loss_sum = tf.reduce_sum(cls_loss)
    cls_losses.append(tf.cast(cls_loss_sum, tf.float32))

    if params['box_loss_weight']:
      box_losses.append(
          _box_loss(
              box_outputs[level],
              box_targets_at_level,
              num_positives_sum,
              delta=params['delta']))

  # Sum per-level losses to total loss.
  cls_loss = tf.add_n(cls_losses)
  box_loss = tf.add_n(box_losses) if box_losses else tf.constant(0.)
  total_loss = cls_loss + params['box_loss_weight'] * box_loss
  return total_loss, cls_loss, box_loss
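# --- Illustrative sketch (not from the original source) ----------------------
# A minimal example of combining `detection_loss` with an L2 regularization
# term, as a model_fn-style training step might. The 'weight_decay' key and
# the use of tf.compat.v1.trainable_variables() are assumptions for
# illustration; only `detection_loss` itself is defined above.
def _example_total_loss(cls_outputs, box_outputs, labels, params):
  """Hedged sketch: detection loss plus an (assumed) weight-decay term."""
  det_loss, cls_loss, box_loss = detection_loss(cls_outputs, box_outputs,
                                                labels, params)
  # Assumed L2 regularization over all trainable weights, scaled by an
  # assumed params['weight_decay'] entry.
  reg_l2_loss = params['weight_decay'] * tf.add_n(
      [tf.nn.l2_loss(v) for v in tf.compat.v1.trainable_variables()])
  return det_loss + reg_l2_loss, cls_loss, box_loss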