def fn():
    """Loss function for when number of input and output boxes is positive.

    Each predicted box dimension is divided by the matching ground-truth
    dimension, and the resulting ratios are regressed towards one with the
    loss selected by `loss_type`. All tensors and options are free variables
    defined outside this function.

    Returns:
      The mean of the weighted size losses.

    Raises:
      ValueError: If `loss_type` is neither 'huber' nor
        'absolute_difference'.
    """
    if is_balanced:
        box_weights = loss_utils.get_balanced_loss_weights_multiclass(
            labels=input_boxes_instance_id)
    else:
        num_boxes = tf.shape(input_boxes_instance_id)[0]
        box_weights = tf.ones([num_boxes, 1], dtype=tf.float32)

    def _as_column(tensor):
        """Reshapes `tensor` into a single column."""
        return tf.reshape(tensor, [-1, 1])

    # Predicted / ground-truth ratio per dimension; a perfect prediction is 1.
    size_ratios = tf.concat([
        _as_column(output_boxes_length) / _as_column(input_boxes_length),
        _as_column(output_boxes_height) / _as_column(input_boxes_height),
        _as_column(output_boxes_width) / _as_column(input_boxes_width),
    ], axis=1)
    unit_targets = tf.ones_like(size_ratios)

    # Reduction is disabled so the weights can be applied before averaging.
    if loss_type == 'huber':
        criterion = tf.keras.losses.Huber(
            delta=delta, reduction=tf.keras.losses.Reduction.NONE)
    elif loss_type == 'absolute_difference':
        criterion = tf.keras.losses.MeanAbsoluteError(
            reduction=tf.keras.losses.Reduction.NONE)
    else:
        raise ValueError(('Unknown loss type %s.' % loss_type))

    unweighted = criterion(y_true=unit_targets, y_pred=size_ratios)
    flat_weights = tf.reshape(box_weights, [-1])
    return tf.reduce_mean(unweighted * flat_weights)
 def fn():
     """Loss function for when number of input and output boxes is positive.

     Computes a regression loss between the 3D corners of the predicted and
     ground-truth boxes. All tensors and options are free variables defined
     outside this function.

     Supported `loss_type` values are 'huber', 'normalized_huber',
     'absolute_difference' and 'normalized_absolute_difference'. For the
     'normalized_*' variants the predicted dimensions are rescaled by
     `normalized_box_size / gt_dimension` and the ground-truth dimensions are
     replaced by `normalized_box_size`, so the corner error is measured
     relative to the ground-truth box size.

     Returns:
       The mean of the weighted corner losses.

     Raises:
       ValueError: If `loss_type` is not one of the supported values.
     """
     # Pick the loss first so an unsupported type fails before any ops are
     # built. Reduction is disabled so weights can be applied per corner.
     if loss_type in ('huber', 'normalized_huber'):
         loss_fn = tf.keras.losses.Huber(
             delta=delta, reduction=tf.keras.losses.Reduction.NONE)
     elif loss_type in ('absolute_difference',
                        'normalized_absolute_difference'):
         loss_fn = tf.keras.losses.MeanAbsoluteError(
             reduction=tf.keras.losses.Reduction.NONE)
     else:
         raise ValueError(('Unknown loss type %s.' % loss_type))

     if is_balanced:
         weights = loss_utils.get_balanced_loss_weights_multiclass(
             labels=input_boxes_instance_id)
     else:
         weights = tf.ones([tf.shape(input_boxes_instance_id)[0], 1],
                           dtype=tf.float32)

     normalized_box_size = 5.0
     # Local names are needed because the dimensions may be rebound below.
     predicted_boxes_length = output_boxes_length
     predicted_boxes_height = output_boxes_height
     predicted_boxes_width = output_boxes_width
     gt_boxes_length = input_boxes_length
     gt_boxes_height = input_boxes_height
     gt_boxes_width = input_boxes_width

     # Every supported 'normalized_*' loss type is normalized here, keeping
     # this check in sync with the loss selection above.
     if loss_type in ('normalized_huber', 'normalized_absolute_difference'):
         predicted_boxes_length /= (gt_boxes_length / normalized_box_size)
         predicted_boxes_height /= (gt_boxes_height / normalized_box_size)
         predicted_boxes_width /= (gt_boxes_width / normalized_box_size)
         gt_boxes_length = tf.ones_like(
             gt_boxes_length, dtype=tf.float32) * normalized_box_size
         gt_boxes_height = tf.ones_like(
             gt_boxes_height, dtype=tf.float32) * normalized_box_size
         gt_boxes_width = tf.ones_like(
             gt_boxes_width, dtype=tf.float32) * normalized_box_size

     gt_box_corners = box_utils.get_box_corners_3d(
         boxes_length=gt_boxes_length,
         boxes_height=gt_boxes_height,
         boxes_width=gt_boxes_width,
         boxes_rotation_matrix=input_boxes_rotation_matrix,
         boxes_center=input_boxes_center)
     predicted_box_corners = box_utils.get_box_corners_3d(
         boxes_length=predicted_boxes_length,
         boxes_height=predicted_boxes_height,
         boxes_width=predicted_boxes_width,
         boxes_rotation_matrix=output_boxes_rotation_matrix,
         boxes_center=output_boxes_center)

     # Repeat each box weight once per corner.
     # NOTE(review): assumes get_box_corners_3d yields 8 corners per box.
     corner_weights = tf.tile(weights, [1, 8])
     box_corner_losses = loss_fn(
         y_true=tf.reshape(gt_box_corners, [-1, 3]),
         y_pred=tf.reshape(predicted_box_corners, [-1, 3]))
     return tf.reduce_mean(box_corner_losses *
                           tf.reshape(corner_weights, [-1]))
def classification_loss_using_mask_iou_func_unbatched(
        embeddings, instance_ids, sampled_embeddings, sampled_instance_ids,
        sampled_class_labels, sampled_logits, similarity_strategy,
        is_balanced):
    """Classification loss whose foreground targets are scaled by mask IoU.

    A soft mask is predicted for every sampled embedding, thresholded at 0.5
    and compared against the ground-truth mask of the sample's instance. The
    resulting IoU scales the sample's foreground one-hot label; the remaining
    mass (clamped at zero) is assigned to class 0.

    Args:
      embeddings: A tf.float32 tensor of size [n, f].
      instance_ids: A tf.int32 tensor of size [n].
      sampled_embeddings: A tf.float32 tensor of size [num_samples, f].
      sampled_instance_ids: A tf.int32 tensor of size [num_samples].
      sampled_class_labels: A tf.int32 tensor of size [num_samples, 1].
      sampled_logits: A tf.float32 tensor of size [num_samples, num_classes].
      similarity_strategy: Defines the method for computing similarity
        between embedding vectors. Possible values are 'dotproduct' and
        'distance'.
      is_balanced: If True, weights computed by
        `loss_utils.get_balanced_loss_weights_multiclass` from
        `sampled_instance_ids` are passed to the classification loss.

    Returns:
      A tf.float32 loss scalar tensor.
    """
    soft_masks = metric_learning_utils.embedding_centers_to_soft_masks(
        embedding=embeddings,
        centers=sampled_embeddings,
        similarity_strategy=similarity_strategy)
    binary_masks = tf.cast(tf.greater(soft_masks, 0.5), dtype=tf.float32)

    # Ground-truth mask of a sample: points sharing the sample's instance id.
    sample_ids_column = tf.expand_dims(sampled_instance_ids, axis=1)
    point_ids_row = tf.expand_dims(instance_ids, axis=0)
    target_masks = tf.cast(tf.equal(sample_ids_column, point_ids_row),
                           dtype=tf.float32)

    mask_iou = instance_segmentation_utils.points_mask_pairwise_iou(
        masks1=binary_masks, masks2=target_masks)

    num_classes = sampled_logits.get_shape().as_list()[1]
    flat_class_labels = tf.reshape(sampled_class_labels, [-1])
    one_hot_labels = tf.one_hot(indices=flat_class_labels, depth=num_classes)

    # Scale the foreground columns (classes 1..num_classes-1) by the IoU.
    iou_per_class = tf.tile(tf.reshape(mask_iou, [-1, 1]),
                            [1, num_classes - 1])
    foreground_targets = one_hot_labels[:, 1:] * iou_per_class
    foreground_mass = tf.math.reduce_sum(
        foreground_targets, axis=1, keepdims=True)
    background_targets = tf.maximum(1.0 - foreground_mass, 0.0)
    soft_targets = tf.concat([background_targets, foreground_targets],
                             axis=1)

    if not is_balanced:
        return classification_loss_fn(logits=sampled_logits,
                                      labels=soft_targets)
    balanced_weights = loss_utils.get_balanced_loss_weights_multiclass(
        labels=tf.expand_dims(sampled_instance_ids, axis=1))
    return classification_loss_fn(logits=sampled_logits,
                                  labels=soft_targets,
                                  weights=balanced_weights)
def _box_classification_loss_unbatched(inputs_1, outputs_1, is_intermediate,
                                       is_balanced, mine_hard_negatives,
                                       hard_negative_score_threshold):
    """Voxel classification loss for inputs and outputs of batch size 1.

    Args:
      inputs_1: Dictionary of input tensors, indexed with
        `standard_fields.InputDataFields` keys.
      outputs_1: Dictionary of output tensors, indexed with
        `standard_fields.DetectionResultFields` keys.
      is_intermediate: If True, the intermediate semantic logits are used
        instead of the final ones.
      is_balanced: If True, weights computed from the voxel instance ids by
        `loss_utils.get_balanced_loss_weights_multiclass` are passed to the
        classification loss.
      mine_hard_negatives: If True, voxels labelled 0 whose class-0 softmax
        score is below `hard_negative_score_threshold` are appended a second
        time, under a new instance id.
      hard_negative_score_threshold: Score threshold used for hard negative
        mining.

    Returns:
      The value returned by `classification_loss_fn`.

    Raises:
      ValueError: If the number of classes is not statically known.
    """
    result_fields = standard_fields.DetectionResultFields
    input_fields = standard_fields.InputDataFields

    valid_mask = _get_voxels_valid_mask(inputs_1=inputs_1)
    logits_key = (result_fields.intermediate_object_semantic_voxels
                  if is_intermediate else
                  result_fields.object_semantic_voxels)
    logits = outputs_1[logits_key]
    num_classes = logits.get_shape().as_list()[-1]
    if num_classes is None:
        raise ValueError('Number of classes is unknown.')

    # Flatten to one row per voxel and keep only the valid voxels.
    logits = tf.boolean_mask(tf.reshape(logits, [-1, num_classes]),
                             valid_mask)
    flat_labels = tf.reshape(inputs_1[input_fields.object_class_voxels],
                             [-1, 1])
    labels = tf.boolean_mask(flat_labels, valid_mask)
    if mine_hard_negatives or is_balanced:
        flat_instances = tf.reshape(
            inputs_1[input_fields.object_instance_id_voxels], [-1])
        instances = tf.boolean_mask(flat_instances, valid_mask)

    if mine_hard_negatives:
        background_scores = tf.reshape(tf.nn.softmax(logits)[:, 0], [-1])
        is_hard_negative = tf.logical_and(
            tf.less(background_scores, hard_negative_score_threshold),
            tf.equal(tf.reshape(labels, [-1]), 0))
        # Hard negatives share one fresh instance id above all existing ids.
        fresh_instance_id = tf.reduce_max(instances) + 1
        mined_instances = tf.boolean_mask(
            tf.ones_like(instances) * fresh_instance_id, is_hard_negative)
        mined_logits = tf.boolean_mask(logits, is_hard_negative)
        mined_labels = tf.boolean_mask(labels, is_hard_negative)
        logits = tf.concat([logits, mined_logits], axis=0)
        instances = tf.concat([instances, mined_instances], axis=0)
        labels = tf.concat([labels, mined_labels], axis=0)

    if not is_balanced:
        return classification_loss_fn(logits=logits, labels=labels)
    balanced_weights = loss_utils.get_balanced_loss_weights_multiclass(
        labels=tf.expand_dims(instances, axis=1))
    return classification_loss_fn(logits=logits,
                                  labels=labels,
                                  weights=balanced_weights)
 def fn():
     """Loss function for when number of input and output boxes is positive.

     Regresses the predicted box centers towards the ground-truth centers
     with the loss selected by `loss_type`. All tensors and options are free
     variables defined outside this function.

     Returns:
       The mean of the weighted center losses.

     Raises:
       ValueError: If `loss_type` is neither 'huber' nor
         'absolute_difference'.
     """
     if is_balanced:
         box_weights = loss_utils.get_balanced_loss_weights_multiclass(
             labels=input_boxes_instance_id)
     else:
         num_boxes = tf.shape(input_boxes_instance_id)[0]
         box_weights = tf.ones([num_boxes, 1], dtype=tf.float32)

     target_centers = tf.reshape(input_boxes_center, [-1, 3])
     estimated_centers = tf.reshape(output_boxes_center, [-1, 3])

     # Reduction is disabled so the weights can be applied before averaging.
     if loss_type == 'huber':
         criterion = tf.keras.losses.Huber(
             delta=delta, reduction=tf.keras.losses.Reduction.NONE)
     elif loss_type == 'absolute_difference':
         criterion = tf.keras.losses.MeanAbsoluteError(
             reduction=tf.keras.losses.Reduction.NONE)
     else:
         raise ValueError(('Unknown loss type %s.' % loss_type))

     unweighted = criterion(y_true=target_centers, y_pred=estimated_centers)
     flat_weights = tf.reshape(box_weights, [-1])
     return tf.reduce_mean(unweighted * flat_weights)
def _box_classification_using_center_distance_loss_unbatched(
        inputs_1, outputs_1, is_intermediate, is_balanced,
        max_positive_normalized_distance):
    """Distance-weighted classification loss for batch size 1.

    The foreground one-hot label of every voxel is scaled by
    `1 - distance / max_positive_normalized_distance` (clamped to [0, 1]),
    where `distance` comes from
    `loss_utils.get_normalized_corner_distances`. The remaining mass is
    assigned to class 0.

    Args:
      inputs_1: Dictionary of input tensors, indexed with
        `standard_fields.InputDataFields` keys.
      outputs_1: Dictionary of output tensors, indexed with
        `standard_fields.DetectionResultFields` keys.
      is_intermediate: If True, the intermediate outputs are used instead of
        the final ones.
      is_balanced: If True, weights computed from the voxel instance ids by
        `loss_utils.get_balanced_loss_weights_multiclass` are passed to the
        classification loss.
      max_positive_normalized_distance: Normalized distance at which the
        foreground coefficient reaches zero.

    Returns:
      The classification loss, or a constant 0.0 when there are no labels.
    """
    inputs_1, outputs_1 = _get_voxels_valid_inputs_outputs(
        inputs_1=inputs_1, outputs_1=outputs_1)

    result_fields = standard_fields.DetectionResultFields
    input_fields = standard_fields.InputDataFields
    if is_intermediate:
        center_key = result_fields.intermediate_object_center_voxels
        length_key = result_fields.intermediate_object_length_voxels
        height_key = result_fields.intermediate_object_height_voxels
        width_key = result_fields.intermediate_object_width_voxels
        rotation_key = (
            result_fields.intermediate_object_rotation_matrix_voxels)
        semantic_key = result_fields.intermediate_object_semantic_voxels
    else:
        center_key = result_fields.object_center_voxels
        length_key = result_fields.object_length_voxels
        height_key = result_fields.object_height_voxels
        width_key = result_fields.object_width_voxels
        rotation_key = result_fields.object_rotation_matrix_voxels
        semantic_key = result_fields.object_semantic_voxels

    logits = outputs_1[semantic_key]
    normalized_center_distance = loss_utils.get_normalized_corner_distances(
        predicted_boxes_center=outputs_1[center_key],
        predicted_boxes_length=outputs_1[length_key],
        predicted_boxes_height=outputs_1[height_key],
        predicted_boxes_width=outputs_1[width_key],
        predicted_boxes_rotation_matrix=outputs_1[rotation_key],
        gt_boxes_center=inputs_1[input_fields.object_center_voxels],
        gt_boxes_length=inputs_1[input_fields.object_length_voxels],
        gt_boxes_height=inputs_1[input_fields.object_height_voxels],
        gt_boxes_width=inputs_1[input_fields.object_width_voxels],
        gt_boxes_rotation_matrix=inputs_1[
            input_fields.object_rotation_matrix_voxels])

    labels = tf.reshape(inputs_1[input_fields.object_class_voxels], [-1])
    instances = tf.reshape(
        inputs_1[input_fields.object_instance_id_voxels], [-1])

    loss_kwargs = {}
    if is_balanced:
        loss_kwargs['weights'] = (
            loss_utils.get_balanced_loss_weights_multiclass(
                labels=tf.expand_dims(instances, axis=1)))

    def _loss_with_labels():
        """Builds the loss for the case with at least one label."""
        num_classes = logits.get_shape().as_list()[-1]
        if num_classes is None:
            raise ValueError('Number of classes is unknown.')
        # Shifting by one makes label 0 map to an all-zero foreground row.
        foreground_one_hot = tf.one_hot(indices=(labels - 1),
                                        depth=(num_classes - 1))
        distance_ratio = (normalized_center_distance /
                          max_positive_normalized_distance)
        closeness = tf.maximum(tf.minimum(1.0 - distance_ratio, 1.0), 0.0)
        foreground_targets = (tf.reshape(closeness, [-1, 1]) *
                              foreground_one_hot)
        background_targets = 1.0 - tf.math.reduce_sum(
            foreground_targets, axis=1, keepdims=True)
        soft_targets = tf.concat([background_targets, foreground_targets],
                                 axis=1)
        return classification_loss_fn(logits=logits,
                                      labels=soft_targets,
                                      **loss_kwargs)

    def _zero_loss():
        """Returns a zero loss for the case without any labels."""
        return tf.constant(0.0, dtype=tf.float32)

    has_labels = tf.greater(tf.shape(labels)[0], 0)
    return tf.cond(has_labels, _loss_with_labels, _zero_loss)
 def test_get_balanced_loss_weights_multiclass(self):
     """Checks the sum and the relative size of the balanced weights.

     The labels hold four 1s, two 2s and a single 3. The weights are
     expected to sum to 7.0, and each halving of the label count is expected
     to double the weight.
     """
     labels = tf.constant([1, 1, 1, 1, 2, 2, 3], dtype=tf.int32)
     weight_values = utils.get_balanced_loss_weights_multiclass(
         labels=labels).numpy()
     self.assertAllClose(np.sum(weight_values), np.array(7.0))
     # Indices 0, 4 and 6 are the first entries labelled 1, 2 and 3.
     weight_label_1 = weight_values[0]
     weight_label_2 = weight_values[4]
     weight_label_3 = weight_values[6]
     self.assertAllClose(weight_label_1 * 2.0, weight_label_2)
     self.assertAllClose(weight_label_2 * 2.0, weight_label_3)