# Imports shared by the snippets below. `sampling_utils` and
# `metric_learning_losses` are project-local modules; the import paths here
# are assumed and may need adjusting to your repository layout.
import numpy as np
import tensorflow as tf

import metric_learning_losses
import sampling_utils


def test_balanced_sample(self):
  features = self._get_features_example()
  instance_ids = self._get_instance_id_example()
  sampled_features, sampled_instance_ids, sampled_indices = (
      sampling_utils.balanced_sample(features=features,
                                     instance_ids=instance_ids,
                                     num_samples=20,
                                     max_instance_id=4))
  self.assertAllEqual(sampled_features.shape, np.array([2, 20, 3]))
  self.assertAllEqual(sampled_instance_ids.shape, np.array([2, 20]))
  self.assertAllEqual(sampled_indices.shape, np.array([2, 20]))
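
# A minimal, self-contained sketch (NumPy only) of the behavior the test
# above exercises: `balanced_sample` draws `num_samples` indices so that
# every instance id is represented (roughly) equally often. This is an
# illustrative reimplementation for a single unbatched example, not the
# library code; sampling with replacement is assumed so that small instances
# can still fill their quota.
def balanced_sample_sketch(instance_ids, num_samples, seed=0):
  rng = np.random.default_rng(seed)
  unique_ids = np.unique(instance_ids)
  picks = []
  for k in range(num_samples):
    target_id = unique_ids[k % len(unique_ids)]  # cycle through instance ids
    candidates = np.flatnonzero(instance_ids == target_id)
    picks.append(rng.choice(candidates))  # draw with replacement
  return np.asarray(picks)

ids = np.array([0, 0, 0, 0, 1, 1, 2])
indices = balanced_sample_sketch(ids, num_samples=6)
# ids[indices] now contains each id in {0, 1, 2} exactly twice.
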
def classification_loss_using_mask_iou_func(embeddings,
                                            logits,
                                            instance_ids,
                                            class_labels,
                                            num_samples,
                                            valid_mask=None,
                                            max_instance_id=None,
                                            similarity_strategy='dotproduct',
                                            is_balanced=True):
    """Classification loss using mask iou.

  Args:
    embeddings: A tf.float32 tensor of size [batch_size, n, f].
    logits: A tf.float32 tensor of size [batch_size, n, num_classes]. It is
      assumed that background is class 0.
    instance_ids: A tf.int32 tensor of size [batch_size, n].
    class_labels: A tf.int32 tensor of size [batch_size, n]. It is assumed
      that the background voxels are assigned to class 0.
    num_samples: An int determining the number of samples.
    valid_mask: A tf.bool tensor of size [batch_size, n] that is True when an
      element is valid and False if it needs to be ignored. By default the value
      is None which means it is not applied.
    max_instance_id: If set, instance ids larger than that value will be
      ignored. If not set, it will be computed from instance_ids tensor.
    similarity_strategy: Defines the method for computing similarity between
                         embedding vectors. Possible values are 'dotproduct' and
                         'distance'.
    is_balanced: If True, the per-voxel losses are re-weighted to have equal
      total weight for foreground vs. background voxels.

  Returns:
    A tf.float32 scalar loss tensor.
  """
  batch_size = embeddings.get_shape().as_list()[0]
  if batch_size is None:
    raise ValueError('Unknown batch size at graph construction time.')
  if max_instance_id is None:
    max_instance_id = tf.reduce_max(instance_ids)
  class_labels = tf.reshape(class_labels, [batch_size, -1, 1])
  sampled_embeddings, sampled_instance_ids, sampled_indices = (
      sampling_utils.balanced_sample(features=embeddings,
                                     instance_ids=instance_ids,
                                     num_samples=num_samples,
                                     valid_mask=valid_mask,
                                     max_instance_id=max_instance_id))
  losses = []
  for i in range(batch_size):
    embeddings_i = embeddings[i, :, :]
    instance_ids_i = instance_ids[i, :]
    class_labels_i = class_labels[i, :, :]
    logits_i = logits[i, :, :]
    sampled_embeddings_i = sampled_embeddings[i, :, :]
    sampled_instance_ids_i = sampled_instance_ids[i, :]
    sampled_indices_i = sampled_indices[i, :]
    sampled_class_labels_i = tf.gather(class_labels_i, sampled_indices_i)
    sampled_logits_i = tf.gather(logits_i, sampled_indices_i)
    if valid_mask is not None:
      valid_mask_i = valid_mask[i]
      embeddings_i = tf.boolean_mask(embeddings_i, valid_mask_i)
      instance_ids_i = tf.boolean_mask(instance_ids_i, valid_mask_i)
    loss_i = classification_loss_using_mask_iou_func_unbatched(
        embeddings=embeddings_i,
        instance_ids=instance_ids_i,
        sampled_embeddings=sampled_embeddings_i,
        sampled_instance_ids=sampled_instance_ids_i,
        sampled_class_labels=sampled_class_labels_i,
        sampled_logits=sampled_logits_i,
        similarity_strategy=similarity_strategy,
        is_balanced=is_balanced)
    losses.append(loss_i)
  return tf.math.reduce_mean(tf.stack(losses))
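
# A usage sketch with toy shapes, assuming `sampling_utils` and the unbatched
# helper `classification_loss_using_mask_iou_func_unbatched` referenced above
# are available. Random tensors stand in for real voxel embeddings, logits,
# and labels (batch_size=2, n=100 voxels, f=8 embedding dims, 5 classes with
# class 0 as background).
batch_size, n, f, num_classes = 2, 100, 8, 5
embeddings = tf.random.normal([batch_size, n, f])
logits = tf.random.normal([batch_size, n, num_classes])
instance_ids = tf.random.uniform([batch_size, n], maxval=4, dtype=tf.int32)
class_labels = tf.random.uniform(
    [batch_size, n], maxval=num_classes, dtype=tf.int32)
loss = classification_loss_using_mask_iou_func(
    embeddings=embeddings,
    logits=logits,
    instance_ids=instance_ids,
    class_labels=class_labels,
    num_samples=20,
    max_instance_id=4)  # scalar tf.float32
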
def npair_loss_func(embeddings,
                    instance_ids,
                    num_samples,
                    valid_mask=None,
                    max_instance_id=None,
                    similarity_strategy='dotproduct',
                    loss_strategy='softmax'):
  """N-pair metric learning loss for learning feature embeddings.

  Args:
    embeddings: A tf.float32 tensor of size [batch_size, n, f].
    instance_ids: A tf.int32 tensor of size [batch_size, n].
    num_samples: An int determining the number of samples.
    valid_mask: A tf.bool tensor of size [batch_size, n] that is True when an
      element is valid and False if it needs to be ignored. By default the value
      is None which means it is not applied.
    max_instance_id: If set, instance ids larger than that value will be
      ignored. If not set, it will be computed from instance_ids tensor.
    similarity_strategy: Defines the method for computing similarity between
                         embedding vectors. Possible values are 'dotproduct' and
                         'distance'.
    loss_strategy: Defines the type of loss. Possible values are 'softmax'
      and 'sigmoid'.

  Returns:
    A tf.float32 scalar loss tensor.
  """
  batch_size = embeddings.get_shape().as_list()[0]
  if batch_size is None:
    raise ValueError('Unknown batch size at graph construction time.')
  if max_instance_id is None:
    max_instance_id = tf.reduce_max(instance_ids)
  sampled_embeddings, sampled_instance_ids, _ = sampling_utils.balanced_sample(
      features=embeddings,
      instance_ids=instance_ids,
      num_samples=num_samples,
      valid_mask=valid_mask,
      max_instance_id=max_instance_id)
  losses = []
  for i in range(batch_size):
    sampled_instance_ids_i = sampled_instance_ids[i, :]
    sampled_embeddings_i = sampled_embeddings[i, :, :]
    min_ids_i = tf.math.reduce_min(sampled_instance_ids_i)
    max_ids_i = tf.math.reduce_max(sampled_instance_ids_i)
    target_i = tf.one_hot(
        sampled_instance_ids_i,
        depth=(max_instance_id + 1),
        dtype=tf.float32)

    # pylint: disable=cell-var-from-loop
    def npair_loss_i():
      return metric_learning_losses.npair_loss(
          embedding=sampled_embeddings_i,
          target=target_i,
          similarity_strategy=similarity_strategy,
          loss_strategy=loss_strategy)
    # pylint: enable=cell-var-from-loop

    loss_i = tf.cond(
        max_ids_i > min_ids_i, npair_loss_i,
        lambda: tf.constant(0.0, dtype=tf.float32))
    losses.append(loss_i)
  return tf.math.reduce_mean(tf.stack(losses))
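
# A usage sketch under the same toy assumptions as above: random embeddings
# and instance ids stand in for real network outputs and labels, and the
# project-local `metric_learning_losses.npair_loss` is assumed available.
embeddings = tf.random.normal([2, 100, 8])
instance_ids = tf.random.uniform([2, 100], maxval=4, dtype=tf.int32)
npair_loss = npair_loss_func(
    embeddings=embeddings,
    instance_ids=instance_ids,
    num_samples=20,
    max_instance_id=4,
    similarity_strategy='dotproduct',
    loss_strategy='softmax')  # scalar tf.float32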