def test_balanced_sample(self):
  """Verifies the output shapes produced by sampling_utils.balanced_sample."""
  example_features = self._get_features_example()
  example_instance_ids = self._get_instance_id_example()
  sample_result = sampling_utils.balanced_sample(
      features=example_features,
      instance_ids=example_instance_ids,
      num_samples=20,
      max_instance_id=4)
  out_features, out_instance_ids, out_indices = sample_result
  # Batch of 2 examples, 20 samples drawn per example, 3 feature channels.
  self.assertAllEqual(out_features.shape, np.array([2, 20, 3]))
  self.assertAllEqual(out_instance_ids.shape, np.array([2, 20]))
  self.assertAllEqual(out_indices.shape, np.array([2, 20]))
def classification_loss_using_mask_iou_func(embeddings,
                                            logits,
                                            instance_ids,
                                            class_labels,
                                            num_samples,
                                            valid_mask=None,
                                            max_instance_id=None,
                                            similarity_strategy='dotproduct',
                                            is_balanced=True):
  """Classification loss using mask iou.

  Args:
    embeddings: A tf.float32 tensor of size [batch_size, n, f].
    logits: A tf.float32 tensor of size [batch_size, n, num_classes]. It is
      assumed that background is class 0.
    instance_ids: A tf.int32 tensor of size [batch_size, n].
    class_labels: A tf.int32 tensor of size [batch_size, n]. It is assumed
      that the background voxels are assigned to class 0.
    num_samples: An int determining the number of samples.
    valid_mask: A tf.bool tensor of size [batch_size, n] that is True when an
      element is valid and False if it needs to be ignored. By default the
      value is None which means it is not applied.
    max_instance_id: If set, instance ids larger than that value will be
      ignored. If not set, it will be computed from instance_ids tensor.
    similarity_strategy: Defines the method for computing similarity between
      embedding vectors. Possible values are 'dotproduct' and 'distance'.
    is_balanced: If True, the per-voxel losses are re-weighted to have equal
      total weight for foreground vs. background voxels.

  Returns:
    A tf.float32 scalar loss tensor.

  Raises:
    ValueError: If batch_size is unknown at graph construction time.
  """
  # A static batch size is required because the per-example loop below
  # unrolls at graph construction time.
  batch_size = embeddings.get_shape().as_list()[0]
  if batch_size is None:
    raise ValueError('Unknown batch size at graph construction time.')
  if max_instance_id is None:
    max_instance_id = tf.reduce_max(instance_ids)
  # Add a trailing singleton dim so per-voxel labels can be gathered below.
  class_labels = tf.reshape(class_labels, [batch_size, -1, 1])
  # Draw a balanced subset of voxels per example; sampled_indices index back
  # into the original n-length voxel axis.
  sampled_embeddings, sampled_instance_ids, sampled_indices = (
      sampling_utils.balanced_sample(
          features=embeddings,
          instance_ids=instance_ids,
          num_samples=num_samples,
          valid_mask=valid_mask,
          max_instance_id=max_instance_id))
  losses = []
  for i in range(batch_size):
    embeddings_i = embeddings[i, :, :]
    instance_ids_i = instance_ids[i, :]
    class_labels_i = class_labels[i, :, :]
    logits_i = logits[i, :]
    sampled_embeddings_i = sampled_embeddings[i, :, :]
    sampled_instance_ids_i = sampled_instance_ids[i, :]
    sampled_indices_i = sampled_indices[i, :]
    # Labels/logits for the sampled voxels, gathered from the full tensors.
    sampled_class_labels_i = tf.gather(class_labels_i, sampled_indices_i)
    sampled_logits_i = tf.gather(logits_i, sampled_indices_i)
    if valid_mask is not None:
      # Only the full (non-sampled) voxel set is masked; the sampler has
      # already restricted its choices to valid voxels via valid_mask.
      valid_mask_i = valid_mask[i]
      embeddings_i = tf.boolean_mask(embeddings_i, valid_mask_i)
      instance_ids_i = tf.boolean_mask(instance_ids_i, valid_mask_i)
    loss_i = classification_loss_using_mask_iou_func_unbatched(
        embeddings=embeddings_i,
        instance_ids=instance_ids_i,
        sampled_embeddings=sampled_embeddings_i,
        sampled_instance_ids=sampled_instance_ids_i,
        sampled_class_labels=sampled_class_labels_i,
        sampled_logits=sampled_logits_i,
        similarity_strategy=similarity_strategy,
        is_balanced=is_balanced)
    losses.append(loss_i)
  # Average the per-example losses into a scalar.
  return tf.math.reduce_mean(tf.stack(losses))
def npair_loss_func(embeddings,
                    instance_ids,
                    num_samples,
                    valid_mask=None,
                    max_instance_id=None,
                    similarity_strategy='dotproduct',
                    loss_strategy='softmax'):
  """N-pair metric learning loss for learning feature embeddings.

  Args:
    embeddings: A tf.float32 tensor of size [batch_size, n, f].
    instance_ids: A tf.int32 tensor of size [batch_size, n].
    num_samples: An int determining the number of samples.
    valid_mask: A tf.bool tensor of size [batch_size, n] that is True when an
      element is valid and False if it needs to be ignored. By default the
      value is None which means it is not applied.
    max_instance_id: If set, instance ids larger than that value will be
      ignored. If not set, it will be computed from instance_ids tensor.
    similarity_strategy: Defines the method for computing similarity between
      embedding vectors. Possible values are 'dotproduct' and 'distance'.
    loss_strategy: Defines the type of loss including 'softmax' or 'sigmoid'.

  Returns:
    A tf.float32 scalar loss tensor.

  Raises:
    ValueError: If batch_size is unknown at graph construction time.
  """
  # A static batch size is required because the per-example loop below
  # unrolls at graph construction time.
  batch_size = embeddings.get_shape().as_list()[0]
  if batch_size is None:
    raise ValueError('Unknown batch size at graph construction time.')
  if max_instance_id is None:
    max_instance_id = tf.reduce_max(instance_ids)
  sampled_embeddings, sampled_instance_ids, _ = sampling_utils.balanced_sample(
      features=embeddings,
      instance_ids=instance_ids,
      num_samples=num_samples,
      valid_mask=valid_mask,
      max_instance_id=max_instance_id)
  losses = []
  for i in range(batch_size):
    sampled_instance_ids_i = sampled_instance_ids[i, :]
    sampled_embeddings_i = sampled_embeddings[i, :, :]
    min_ids_i = tf.math.reduce_min(sampled_instance_ids_i)
    max_ids_i = tf.math.reduce_max(sampled_instance_ids_i)
    # One-hot instance targets; depth covers ids in [0, max_instance_id].
    target_i = tf.one_hot(
        sampled_instance_ids_i,
        depth=(max_instance_id + 1),
        dtype=tf.float32)
    # The closure deliberately captures this iteration's tensors; tf.cond
    # evaluates it before the next iteration rebinds them.
    # pylint: disable=cell-var-from-loop
    def npair_loss_i():
      return metric_learning_losses.npair_loss(
          embedding=sampled_embeddings_i,
          target=target_i,
          similarity_strategy=similarity_strategy,
          loss_strategy=loss_strategy)
    # pylint: enable=cell-var-from-loop
    # If the sample contains a single instance id, there are no negative
    # pairs, so the loss for this example is defined to be zero.
    loss_i = tf.cond(
        max_ids_i > min_ids_i,
        npair_loss_i,
        lambda: tf.constant(0.0, dtype=tf.float32))
    losses.append(loss_i)
  # Stack before reducing, matching the sibling loss functions in this file.
  return tf.math.reduce_mean(tf.stack(losses))