def graph_fn(prediction, target):
    """Evaluate a PenaltyReducedLogisticFocalLoss with all-ones weights.

    The loss object is configured with alpha=2.0 and beta=0.5 and applied to
    `prediction` and `target` with a constant weight tensor of shape
    [2, 2, 1] filled with 1.0. The inputs presumably need to be compatible
    with that weight shape (TODO confirm against the calling test).

    Returns:
      The value returned by calling the loss object on these inputs.
    """
    unit_weights = tf.constant([[[1.0], [1.0]],
                                [[1.0], [1.0]]])
    focal_loss = centernet_losses.PenaltyReducedLogisticFocalLoss(
        alpha=2.0, beta=0.5)
    return focal_loss(prediction, target, weights=unit_weights)
# Example No. 2
# 0
    def build_losses(self, outputs, labels, aux_losses=None):
        """Build the CenterNet training losses.

        Ground-truth targets are produced per batch element by
        `target_assigner.assign_centernet_targets` (mapped with `tf.map_fn`).
        Three losses are then computed. Each is summed over every prediction
        in the corresponding `outputs` sequence and divided by
        (number of predictions * number of ground-truth boxes):

        * 'ct_loss': `PenaltyReducedLogisticFocalLoss` between the flattened
          predicted and target center heatmaps, weighted by the valid-anchor
          weights derived from `labels['unpad_image_shapes']`.
        * 'scale_loss': `L1LocalizationLoss` on sizes gathered at the
          ground-truth box indices, masked by the ground-truth box mask.
        * 'ct_offset_loss': the same masked L1 loss on center offsets.

        Args:
          outputs: dict whose 'ct_heatmaps', 'ct_size' and 'ct_offset' entries
            are sequences of prediction tensors. They are iterated and
            measured with `len()`. The spatial output size is read from
            `outputs['ct_heatmaps'][0]`.
          labels: ground-truth batch passed as `elems` to `tf.map_fn`. It must
            also provide 'unpad_image_shapes'.
          aux_losses: optional list of extra loss tensors. Their sum is added
            to the total loss.

        Returns:
          dict with keys 'ct_loss', 'scale_loss', 'ct_offset_loss' and
          'total_loss'.
        """
        model_cfg = self.task_config.model
        loss_cfg = self.task_config.losses
        num_classes = model_cfg.num_classes
        max_instances = model_cfg.max_num_instances

        input_size = model_cfg.input_size[0:2]
        output_size = outputs['ct_heatmaps'][0].get_shape().as_list()[1:3]

        gt_label = tf.map_fn(
            # pylint: disable=g-long-lambda
            fn=lambda x: target_assigner.assign_centernet_targets(
                labels=x,
                input_size=input_size,
                output_size=output_size,
                num_classes=num_classes,
                max_num_instances=max_instances,
                gaussian_iou=loss_cfg.gaussian_iou,
                class_offset=loss_cfg.class_offset),
            elems=labels,
            fn_output_signature={
                'ct_heatmaps': tf.TensorSpec(
                    shape=[output_size[0], output_size[1], num_classes],
                    dtype=tf.float32),
                'ct_offset': tf.TensorSpec(
                    shape=[max_instances, 2], dtype=tf.float32),
                'size': tf.TensorSpec(
                    shape=[max_instances, 2], dtype=tf.float32),
                'box_mask': tf.TensorSpec(
                    shape=[max_instances], dtype=tf.int32),
                'box_indices': tf.TensorSpec(
                    shape=[max_instances, 2], dtype=tf.int32),
            })

        # Create loss functions.
        object_center_loss_fn = centernet_losses.PenaltyReducedLogisticFocalLoss(
        )
        localization_loss_fn = centernet_losses.L1LocalizationLoss()

        # Set up box indices so that they have a batch element as well.
        box_indices = loss_ops.add_batch_to_indices(gt_label['box_indices'])

        box_mask = tf.cast(gt_label['box_mask'], dtype=tf.float32)
        num_boxes = tf.cast(
            loss_ops.get_num_instances_from_weights(gt_label['box_mask']),
            dtype=tf.float32)

        def _normalize(summed_loss, num_predictions):
            """Reduce a summed loss and divide by predictions * boxes."""
            # The denominator stays a float32 tensor. Python float() on a
            # tensor only works eagerly (forcing a host sync) or when
            # AutoGraph rewrites the builtin; plain graph tracing raises.
            denominator = tf.cast(num_predictions, tf.float32) * num_boxes
            return tf.reduce_sum(summed_loss) / denominator

        def _box_regression_loss(pred_maps, true_values):
            """Masked L1 loss at ground-truth box indices, normalized."""
            total = 0.0
            for pred_map in pred_maps:
                pred_values = tf.cast(
                    loss_ops.get_batch_predictions_from_indices(
                        pred_map, box_indices), tf.float32)
                # Only apply loss for boxes that appear in the ground truth.
                total += tf.reduce_sum(
                    localization_loss_fn(
                        target_tensor=true_values,
                        prediction_tensor=pred_values),
                    axis=-1) * box_mask
            return _normalize(total, len(pred_maps))

        # Calculate center heatmap loss, restricted to the un-padded region of
        # each image (scaled down to the output resolution).
        output_unpad_image_shapes = tf.math.ceil(
            tf.cast(labels['unpad_image_shapes'], tf.float32) /
            self._net_down_scale)
        valid_anchor_weights = loss_ops.get_valid_anchor_weights_in_flattened_image(
            output_unpad_image_shapes, output_size[0], output_size[1])
        valid_anchor_weights = tf.expand_dims(valid_anchor_weights, 2)

        pred_ct_heatmap_list = outputs['ct_heatmaps']
        true_flattened_ct_heatmap = tf.cast(
            loss_ops.flatten_spatial_dimensions(gt_label['ct_heatmaps']),
            tf.float32)

        total_center_loss = 0.0
        for ct_heatmap in pred_ct_heatmap_list:
            pred_flattened_ct_heatmap = tf.cast(
                loss_ops.flatten_spatial_dimensions(ct_heatmap), tf.float32)
            total_center_loss += object_center_loss_fn(
                target_tensor=true_flattened_ct_heatmap,
                prediction_tensor=pred_flattened_ct_heatmap,
                weights=valid_anchor_weights)
        center_loss = _normalize(total_center_loss, len(pred_ct_heatmap_list))

        # Calculate scale and center-offset regression losses.
        scale_loss = _box_regression_loss(
            outputs['ct_size'], tf.cast(gt_label['size'], tf.float32))
        offset_loss = _box_regression_loss(
            outputs['ct_offset'], tf.cast(gt_label['ct_offset'], tf.float32))

        # Aggregate and finalize loss.
        loss_weights = loss_cfg.detection
        total_loss = (loss_weights.object_center_weight * center_loss +
                      loss_weights.scale_weight * scale_loss +
                      loss_weights.offset_weight * offset_loss)

        if aux_losses:
            total_loss += tf.add_n(aux_losses)

        return {
            'ct_loss': center_loss,
            'scale_loss': scale_loss,
            'ct_offset_loss': offset_loss,
            'total_loss': total_loss,
        }