def classification_loss_using_mask_iou_func_unbatched(
        embeddings, instance_ids, sampled_embeddings, sampled_instance_ids,
        sampled_class_labels, sampled_logits, similarity_strategy,
        is_balanced):
    """Classification loss with mask-IoU-scaled soft labels (unbatched).

    Every sampled embedding is used as a center to predict a soft mask over
    all points, which is thresholded at 0.5. The IoU between that mask and the
    sample's ground-truth instance mask scales the sample's one-hot label on
    class columns 1 and up; column 0 receives the remainder, clamped at zero.
    The resulting soft labels are passed to `classification_loss_fn`.

    Args:
      embeddings: A tf.float32 tensor of size [n, f].
      instance_ids: A tf.int32 tensor of size [n].
      sampled_embeddings: A tf.float32 tensor of size [num_samples, f].
      sampled_instance_ids: A tf.int32 tensor of size [num_samples].
      sampled_class_labels: A tf.int32 tensor of size [num_samples, 1].
      sampled_logits: A tf.float32 tensor of size [num_samples, num_classes].
      similarity_strategy: Method for computing similarity between embedding
        vectors. Possible values are 'dotproduct' and 'distance'.
      is_balanced: If True, weights obtained from
        `loss_utils.get_balanced_loss_weights_multiclass` (computed from
        `sampled_instance_ids`) are forwarded to the loss as `weights`.

    Returns:
      A tf.float32 loss scalar tensor.
    """
    # One predicted binary mask per sampled center.
    soft_masks = metric_learning_utils.embedding_centers_to_soft_masks(
        embedding=embeddings,
        centers=sampled_embeddings,
        similarity_strategy=similarity_strategy)
    hard_masks = tf.cast(tf.greater(soft_masks, 0.5), dtype=tf.float32)

    # Ground-truth mask of a sample: every point sharing its instance id.
    sample_ids_column = tf.expand_dims(sampled_instance_ids, axis=1)
    point_ids_row = tf.expand_dims(instance_ids, axis=0)
    target_masks = tf.cast(
        tf.equal(sample_ids_column, point_ids_row), dtype=tf.float32)

    mask_iou = instance_segmentation_utils.points_mask_pairwise_iou(
        masks1=hard_masks, masks2=target_masks)

    num_classes = sampled_logits.get_shape().as_list()[1]
    flat_labels = tf.reshape(sampled_class_labels, [-1])
    one_hot_labels = tf.one_hot(indices=flat_labels, depth=num_classes)

    # Scale the non-zero-class columns by each sample's IoU.
    iou_per_fg_class = tf.tile(
        tf.reshape(mask_iou, [-1, 1]), [1, num_classes - 1])
    fg_targets = one_hot_labels[:, 1:] * iou_per_fg_class

    # Column 0 takes whatever label mass is left, never below zero.
    fg_mass = tf.math.reduce_sum(fg_targets, axis=1, keepdims=True)
    bg_targets = tf.maximum(1.0 - fg_mass, 0.0)
    soft_targets = tf.concat([bg_targets, fg_targets], axis=1)

    if not is_balanced:
        return classification_loss_fn(
            logits=sampled_logits, labels=soft_targets)

    balanced_weights = loss_utils.get_balanced_loss_weights_multiclass(
        labels=sample_ids_column)
    return classification_loss_fn(
        logits=sampled_logits,
        labels=soft_targets,
        weights=balanced_weights)
 def test_points_mask_pairwise_iou(self):
     """Checks row-wise IoU of two stacks of binary point masks."""
     first_masks = tf.constant(
         [
             [0, 0, 0, 0, 0],
             [1, 1, 1, 1, 1],
             [1, 0, 1, 0, 1],
             [0, 1, 0, 1, 0],
         ],
         dtype=tf.int32)
     second_masks = tf.constant(
         [
             [0, 0, 0, 0, 0],
             [1, 1, 1, 1, 1],
             [1, 1, 1, 1, 0],
             [1, 0, 1, 1, 1],
         ],
         dtype=tf.int32)
     # Row by row: empty vs. empty, identical, 2 of 5 shared, 1 of 5 shared.
     expected = tf.constant([0, 1, 0.4, 0.2], dtype=tf.float32)

     actual = isu.points_mask_pairwise_iou(
         masks1=first_masks, masks2=second_masks)

     self.assertAllClose(actual.numpy(), expected.numpy())