def call(self, y_true, y_pred):
    """Semi-hard triplet loss using ``self.response_diffs`` for pair relations.

    For every (anchor, positive) pair the negative used is the closest
    negative that is still farther from the anchor than the positive
    (semi-hard). When no such negative exists, the farthest negative of the
    anchor is used instead. The hinge ``margin + D_ap - D_an`` is summed over
    the positive mask (``adjacency`` minus the identity) and divided by the
    mask's sum, clamped to at least 1 so an empty mask yields 0 rather than
    NaN.

    Side effect: forwards ``(y_true, y_pred)`` to ``self.sd.update_state``.

    Args:
        y_true: labels. Cast to float32; a rank-1 vector is reshaped to a
            single row ``(1, batch)`` before ``self.response_diffs`` sees it.
            NOTE(review): ``response_diffs`` presumably returns a boolean
            ``[batch, batch]`` "same group" matrix -- confirm against caller.
        y_pred: embedding tensor handed to the configured distance metric.

    Returns:
        Scalar loss, cast back to the embedding dtype when the embeddings
        were float16/bfloat16.
    """
    from tensorflow_addons.losses import metric_learning

    self.sd.update_state(y_true, y_pred)

    targets = tf.cast(
        tf.convert_to_tensor(y_true, name="labels"), dtype=tf.dtypes.float32
    )
    if len(targets.shape) == 1:
        # A flat label vector is treated as a single row.
        targets = tf.reshape(targets, (1, -1))

    raw = tf.convert_to_tensor(y_pred, name="embeddings")
    # Half-precision inputs are promoted so the distance math runs in float32.
    upcast = raw.dtype in (tf.dtypes.float16, tf.dtypes.bfloat16)
    work = tf.cast(raw, tf.dtypes.float32) if upcast else raw

    metric = self.distance_metric
    if metric == "angular":
        dist = metric_learning.angular_distance(work)
    elif metric == "L2" or metric == "squared-L2":
        dist = metric_learning.pairwise_distance(
            work, squared=(metric == "squared-L2")
        )
    else:
        # Anything else is treated as a callable returning the distance matrix.
        dist = metric(work)

    same = self.response_diffs(targets)
    different = tf.math.logical_not(same)
    n = tf.size(targets)

    # Row k*n + i of the tiled tensors pairs anchor i with candidate
    # positive k; column j flags negatives of i lying farther than D[i, k].
    dist_tiled = tf.tile(dist, [n, 1])
    anchor_positive = tf.reshape(tf.transpose(dist), [-1, 1])
    candidates = tf.math.logical_and(
        tf.tile(different, [n, 1]),
        tf.math.greater(dist_tiled, anchor_positive),
    )
    # has_semihard[i, k]: at least one semi-hard negative exists for (i, k).
    has_semihard = tf.transpose(
        tf.reshape(
            tf.math.greater(
                tf.math.reduce_sum(
                    tf.cast(candidates, dtype=tf.dtypes.float32),
                    1,
                    keepdims=True,
                ),
                0.0,
            ),
            [n, n],
        )
    )

    different_f = tf.cast(different, dtype=tf.dtypes.float32)
    candidates_f = tf.cast(candidates, dtype=tf.dtypes.float32)

    # Closest negative beyond D_ap, laid out as [anchor, positive].
    outside = tf.transpose(
        tf.reshape(_masked_minimum(dist_tiled, candidates_f), [n, n])
    )
    # Fallback: largest D_an per anchor, broadcast across all positives.
    inside = tf.tile(_masked_maximum(dist, different_f), [1, n])
    negatives = tf.where(has_semihard, outside, inside)

    hinge = tf.math.add(self.margin, dist - negatives)
    positives = (
        tf.cast(same, dtype=tf.dtypes.float32)
        - tf.linalg.diag(tf.ones([n]))
    )
    # Clamp to 1 so a batch without positive pairs gives 0 instead of NaN,
    # which would otherwise stop the whole model from training.
    denominator = tf.math.maximum(tf.math.reduce_sum(positives), 1.0)
    total = tf.math.reduce_sum(
        tf.math.maximum(tf.math.multiply(hinge, positives), 0.0)
    )
    loss = tf.math.truediv(total, denominator)

    return tf.cast(loss, raw.dtype) if upcast else loss
def call(self, y_true, y_pred):
    """Batch-hard triplet loss using ``self.response_diffs`` for pair relations.

    For every anchor, the hardest positive (largest anchor-positive distance)
    and the hardest negative (smallest anchor-negative distance) are mined.
    With ``self.soft`` the per-anchor loss is ``softplus(D_ap - D_an)``;
    otherwise it is ``max(D_ap - D_an + self.margin, 0)``. The result is
    averaged over anchors.

    Side effect: forwards ``(y_true, y_pred)`` to ``self.sd.update_state``.

    Args:
        y_true: labels. Cast to float32; a rank-1 vector is reshaped to a
            single row ``(1, batch)`` before ``self.response_diffs`` sees it.
            NOTE(review): ``response_diffs`` presumably returns a boolean
            ``[batch, batch]`` "same group" matrix -- confirm against caller.
        y_pred: embedding tensor handed to the configured distance metric.

    Returns:
        Scalar loss, cast back to the embedding dtype when the embeddings
        were float16/bfloat16.
    """
    from tensorflow_addons.losses import metric_learning

    self.sd.update_state(y_true, y_pred)

    labels = tf.cast(
        tf.convert_to_tensor(y_true, name="labels"), dtype=tf.dtypes.float32
    )
    if len(labels.shape) == 1:
        labels = tf.reshape(labels, (1, -1))

    embeddings = tf.convert_to_tensor(y_pred, name="embeddings")
    convert_to_float32 = (
        embeddings.dtype == tf.dtypes.float16
        or embeddings.dtype == tf.dtypes.bfloat16
    )
    precise_embeddings = (
        tf.cast(embeddings, tf.dtypes.float32)
        if convert_to_float32
        else embeddings
    )

    # Build the pairwise distance matrix.
    distance_metric = self.distance_metric
    if distance_metric == "L2":
        pdist_matrix = metric_learning.pairwise_distance(
            precise_embeddings, squared=False
        )
    elif distance_metric == "squared-L2":
        pdist_matrix = metric_learning.pairwise_distance(
            precise_embeddings, squared=True
        )
    elif distance_metric == "angular":
        pdist_matrix = metric_learning.angular_distance(precise_embeddings)
    else:
        pdist_matrix = distance_metric(precise_embeddings)

    # Fetch pairwise labels as adjacency matrix; invert to select negatives.
    adjacency = self.response_diffs(labels)
    adjacency_not = tf.math.logical_not(adjacency)
    adjacency = tf.cast(adjacency, dtype=tf.dtypes.float32)
    adjacency_not = tf.cast(adjacency_not, dtype=tf.dtypes.float32)

    # hard negatives: smallest D_an.
    hard_negatives = _masked_minimum(pdist_matrix, adjacency_not)

    batch_size = tf.size(labels)
    # adjacency is already float32; subtracting the identity drops self-pairs.
    mask_positives = adjacency - tf.linalg.diag(tf.ones([batch_size]))

    # hard positives: largest D_ap.
    hard_positives = _masked_maximum(pdist_matrix, mask_positives)

    if self.soft:
        # softplus(x) == log1p(exp(x)), but it stays finite for large x
        # (exp overflows float32 to inf above ~88).
        triplet_loss = tf.math.softplus(hard_positives - hard_negatives)
    else:
        triplet_loss = tf.maximum(
            hard_positives - hard_negatives + self.margin, 0.0
        )

    # Get final mean triplet loss.
    triplet_loss = tf.reduce_mean(triplet_loss)

    if convert_to_float32:
        return tf.cast(triplet_loss, embeddings.dtype)
    return triplet_loss
def call(self, y_true, y_pred):
    """Neighbourhood hit/miss loss over the batch's pairwise distances.

    For every anchor row ``i`` of the pairwise distance matrix ``D``, a radius
    ``mean(D[i]) - std(D[i]) / 2`` is computed. Other samples closer than
    that radius are the anchor's neighbours. Neighbours that
    ``self.response_diffs`` marks as matching are *hits*; the remaining
    neighbours are *misses*. The per-anchor loss is the sum of miss distances
    divided by ``n * total_misses``, minus the sum of hit distances divided
    by ``n * total_hits``. A zero denominator gives 0 (``divide_no_nan``).

    Side effect: forwards ``(y_true, y_pred)`` to ``self.sd.update_state``.

    Args:
        y_true: labels / responses. Cast to float32; a rank-1 vector is
            reshaped to a single row ``(1, batch)`` before being passed to
            ``self.response_diffs``. NOTE(review): ``response_diffs``
            presumably returns a boolean ``[batch, batch]`` matrix -- confirm.
        y_pred: embedding tensor handed to the configured distance metric.

    Returns:
        1-D tensor with one loss value per anchor row, cast back to the
        embedding dtype when the embeddings were float16/bfloat16.
    """
    from tensorflow_addons.losses import metric_learning

    self.sd.update_state(y_true, y_pred)

    labels = tf.cast(
        tf.convert_to_tensor(y_true, name="labels"), dtype=tf.dtypes.float32
    )
    if len(labels.shape) == 1:
        labels = tf.reshape(labels, (1, -1))

    embeddings = tf.convert_to_tensor(y_pred, name="embeddings")
    convert_to_float32 = (
        embeddings.dtype == tf.dtypes.float16
        or embeddings.dtype == tf.dtypes.bfloat16
    )
    precise_embeddings = (
        tf.cast(embeddings, tf.dtypes.float32)
        if convert_to_float32
        else embeddings
    )

    # Build the pairwise distance matrix.
    distance_metric = self.distance_metric
    if distance_metric == "L2":
        pdist_matrix = metric_learning.pairwise_distance(
            precise_embeddings, squared=False
        )
    elif distance_metric == "squared-L2":
        pdist_matrix = metric_learning.pairwise_distance(
            precise_embeddings, squared=True
        )
    elif distance_metric == "angular":
        pdist_matrix = metric_learning.angular_distance(precise_embeddings)
    else:
        pdist_matrix = distance_metric(precise_embeddings)

    # Take the batch size from the (square) distance matrix, not from
    # tf.shape(labels)[0]: a rank-1 label vector was reshaped to (1, batch)
    # above, which would make that value 1.
    batch_size = tf.shape(pdist_matrix)[0]

    # Fetch pairwise labels as adjacency matrix; invert to select misses.
    adjacency = self.response_diffs(labels)
    adjacency_not = tf.math.logical_not(adjacency)

    radii = (
        tf.reduce_mean(pdist_matrix, axis=1)
        - (tf.math.reduce_std(pdist_matrix, axis=1) / 2.)
    )
    neighbors = tf.math.less(pdist_matrix, tf.reshape(radii, (-1, 1)))

    # Exclude self-pairs with a mask rather than by subtracting the identity.
    # Subtracting yielded -1 entries whenever a sample was not counted as its
    # own neighbour (e.g. when the radius is negative).
    off_diagonal = tf.math.equal(tf.eye(batch_size), 0.0)
    neighbors = tf.math.logical_and(neighbors, off_diagonal)

    hits = tf.cast(
        tf.math.logical_and(neighbors, adjacency), tf.dtypes.float32
    )
    misses = tf.cast(
        tf.math.logical_and(neighbors, adjacency_not), tf.dtypes.float32
    )
    nhits = tf.reduce_sum(hits)
    nmisses = tf.reduce_sum(misses)
    n = tf.cast(batch_size, tf.dtypes.float32)

    hits_dists = tf.math.divide_no_nan(
        tf.multiply(pdist_matrix, hits), tf.math.multiply(n, nhits)
    )
    misses_dists = tf.math.divide_no_nan(
        tf.multiply(pdist_matrix, misses), tf.math.multiply(n, nmisses)
    )

    loss = tf.reduce_sum(tf.subtract(misses_dists, hits_dists), axis=1)

    if convert_to_float32:
        return tf.cast(loss, embeddings.dtype)
    return loss
def triplet_semihard_loss(
    y_true: TensorLike,
    y_pred: TensorLike,
    margin: FloatTensorLike = 1.0,
    distance_metric: Union[str, Callable] = "L2",
) -> tf.Tensor:
    """Computes the triplet loss with semi-hard negative mining.

    Args:
      y_true: 1-D integer `Tensor` with shape [batch_size] of
        multiclass integer labels.
      y_pred: 2-D float `Tensor` of embedding vectors. Embeddings should
        be l2 normalized.
      margin: Float, margin term in the loss definition.
      distance_metric: str or function, determines distance metric:
                       "L2" for l2-norm distance
                       "squared-L2" for squared l2-norm distance
                       "angular" for cosine similarity
                        A custom function taking the batch of embeddings and
                        returning the 2-D pairwise distance matrix can also
                        be passed here. e.g.

                        def custom_distance(batch):
                            return 1.0 - tf.linalg.matmul(
                                batch, batch, transpose_b=True)

                        triplet_semihard_loss(labels, embeddings,
                                              distance_metric=custom_distance)

    Returns:
      triplet_loss: float scalar with dtype of y_pred. It is 0 when the
        batch contains no positive (same-label, non-identical) pairs.
    """
    labels, embeddings = y_true, y_pred

    convert_to_float32 = (embeddings.dtype == tf.dtypes.float16
                          or embeddings.dtype == tf.dtypes.bfloat16)
    precise_embeddings = (tf.cast(embeddings, tf.dtypes.float32)
                          if convert_to_float32 else embeddings)

    # Reshape label tensor to [batch_size, 1].
    lshape = tf.shape(labels)
    labels = tf.reshape(labels, [lshape[0], 1])

    # Build pairwise distance matrix.
    if distance_metric == "L2":
        pdist_matrix = metric_learning.pairwise_distance(precise_embeddings,
                                                         squared=False)
    elif distance_metric == "squared-L2":
        pdist_matrix = metric_learning.pairwise_distance(precise_embeddings,
                                                         squared=True)
    elif distance_metric == "angular":
        pdist_matrix = metric_learning.angular_distance(precise_embeddings)
    else:
        pdist_matrix = distance_metric(precise_embeddings)

    # Build pairwise binary adjacency matrix.
    adjacency = tf.math.equal(labels, tf.transpose(labels))
    # Invert so we can select negatives only.
    adjacency_not = tf.math.logical_not(adjacency)

    batch_size = tf.size(labels)

    # Compute the mask. Row k*batch_size + i of the tiled matrix pairs
    # anchor i with candidate positive k; column j flags negatives of i
    # farther than D[i, k].
    pdist_matrix_tile = tf.tile(pdist_matrix, [batch_size, 1])
    mask = tf.math.logical_and(
        tf.tile(adjacency_not, [batch_size, 1]),
        tf.math.greater(pdist_matrix_tile,
                        tf.reshape(tf.transpose(pdist_matrix), [-1, 1])),
    )
    mask_final = tf.reshape(
        tf.math.greater(
            tf.math.reduce_sum(tf.cast(mask, dtype=tf.dtypes.float32),
                               1,
                               keepdims=True),
            0.0,
        ),
        [batch_size, batch_size],
    )
    mask_final = tf.transpose(mask_final)

    adjacency_not = tf.cast(adjacency_not, dtype=tf.dtypes.float32)
    mask = tf.cast(mask, dtype=tf.dtypes.float32)

    # negatives_outside: smallest D_an where D_an > D_ap.
    negatives_outside = tf.reshape(_masked_minimum(pdist_matrix_tile, mask),
                                   [batch_size, batch_size])
    negatives_outside = tf.transpose(negatives_outside)

    # negatives_inside: largest D_an.
    negatives_inside = tf.tile(_masked_maximum(pdist_matrix, adjacency_not),
                               [1, batch_size])
    semi_hard_negatives = tf.where(mask_final, negatives_outside,
                                   negatives_inside)

    loss_mat = tf.math.add(margin, pdist_matrix - semi_hard_negatives)

    mask_positives = tf.cast(adjacency, dtype=tf.dtypes.float32) - tf.linalg.diag(
        tf.ones([batch_size]))

    # In lifted-struct, the authors multiply 0.5 for upper triangular
    # in semihard, they take all positive pairs except the diagonal.
    # Clamp to 1: with no positive pairs the numerator is 0, so the loss is
    # 0 instead of NaN (0 / 0), which would otherwise poison training.
    num_positives = tf.math.maximum(tf.math.reduce_sum(mask_positives), 1.0)

    triplet_loss = tf.math.truediv(
        tf.math.reduce_sum(
            tf.math.maximum(tf.math.multiply(loss_mat, mask_positives), 0.0)),
        num_positives,
    )

    if convert_to_float32:
        return tf.cast(triplet_loss, embeddings.dtype)
    return triplet_loss
def triplet_hard_loss(
    y_true: TensorLike,
    y_pred: TensorLike,
    margin: FloatTensorLike = 1.0,
    soft: bool = False,
    distance_metric: Union[str, Callable] = "L2",
) -> tf.Tensor:
    """Computes the triplet loss with hard negative and hard positive mining.

    Args:
      y_true: 1-D integer `Tensor` with shape [batch_size] of
        multiclass integer labels.
      y_pred: 2-D float `Tensor` of embedding vectors. Embeddings should
        be l2 normalized.
      margin: Float, margin term in the loss definition.
      soft: Boolean, if set, use the soft margin version
        (``softplus(D_ap - D_an)`` instead of the hinge with ``margin``).
      distance_metric: str or function, determines distance metric:
                       "L2" for l2-norm distance
                       "squared-L2" for squared l2-norm distance
                       "angular" for cosine similarity
                        A custom function taking the batch of embeddings and
                        returning the 2-D pairwise distance matrix can also
                        be passed here. e.g.

                        def custom_distance(batch):
                            return 1.0 - tf.linalg.matmul(
                                batch, batch, transpose_b=True)

                        triplet_hard_loss(labels, embeddings,
                                          distance_metric=custom_distance)

    Returns:
      triplet_loss: float scalar with dtype of y_pred.
    """
    labels, embeddings = y_true, y_pred

    convert_to_float32 = (embeddings.dtype == tf.dtypes.float16
                          or embeddings.dtype == tf.dtypes.bfloat16)
    precise_embeddings = (tf.cast(embeddings, tf.dtypes.float32)
                          if convert_to_float32 else embeddings)

    # Reshape label tensor to [batch_size, 1].
    lshape = tf.shape(labels)
    labels = tf.reshape(labels, [lshape[0], 1])

    # Build pairwise distance matrix.
    if distance_metric == "L2":
        pdist_matrix = metric_learning.pairwise_distance(precise_embeddings,
                                                         squared=False)
    elif distance_metric == "squared-L2":
        pdist_matrix = metric_learning.pairwise_distance(precise_embeddings,
                                                         squared=True)
    elif distance_metric == "angular":
        pdist_matrix = metric_learning.angular_distance(precise_embeddings)
    else:
        pdist_matrix = distance_metric(precise_embeddings)

    # Build pairwise binary adjacency matrix.
    adjacency = tf.math.equal(labels, tf.transpose(labels))
    # Invert so we can select negatives only.
    adjacency_not = tf.math.logical_not(adjacency)
    adjacency_not = tf.cast(adjacency_not, dtype=tf.dtypes.float32)

    # hard negatives: smallest D_an.
    hard_negatives = _masked_minimum(pdist_matrix, adjacency_not)

    batch_size = tf.size(labels)

    adjacency = tf.cast(adjacency, dtype=tf.dtypes.float32)
    # adjacency is already float32; subtracting the identity drops self-pairs.
    mask_positives = adjacency - tf.linalg.diag(tf.ones([batch_size]))

    # hard positives: largest D_ap.
    hard_positives = _masked_maximum(pdist_matrix, mask_positives)

    if soft:
        # softplus(x) == log1p(exp(x)), but it stays finite for large x
        # (exp overflows float32 to inf above ~88).
        triplet_loss = tf.math.softplus(hard_positives - hard_negatives)
    else:
        triplet_loss = tf.maximum(hard_positives - hard_negatives + margin, 0.0)

    # Get final mean triplet loss.
    triplet_loss = tf.reduce_mean(triplet_loss)

    if convert_to_float32:
        return tf.cast(triplet_loss, embeddings.dtype)
    return triplet_loss
def triplet_semihard_loss(
    y_true: TensorLike,
    y_pred: TensorLike,
    margin: FloatTensorLike = 1.0,
    distance_metric: Union[str, Callable] = "L2",
) -> tf.Tensor:
    r"""Computes the triplet loss with semi-hard negative mining.

    Usage:

    >>> y_true = tf.convert_to_tensor([0, 0])
    >>> y_pred = tf.convert_to_tensor([[0.0, 1.0], [1.0, 0.0]])
    >>> tfa.losses.triplet_semihard_loss(y_true, y_pred, distance_metric="L2")
    <tf.Tensor: shape=(), dtype=float32, numpy=2.4142137>

    >>> # Calling with callable `distance_metric`
    >>> distance_metric = lambda x: tf.linalg.matmul(x, x, transpose_b=True)
    >>> tfa.losses.triplet_semihard_loss(y_true, y_pred, distance_metric=distance_metric)
    <tf.Tensor: shape=(), dtype=float32, numpy=1.0>

    Args:
      y_true: 1-D integer `Tensor` with shape `[batch_size]` of
        multiclass integer labels.
      y_pred: 2-D float `Tensor` of embedding vectors. Embeddings should
        be l2 normalized.
      margin: Float, margin term in the loss definition.
      distance_metric: `str` or a `Callable` that determines distance metric.
        Valid strings are "L2" for l2-norm distance,
        "squared-L2" for squared l2-norm distance,
        and "angular" for cosine similarity.

        A `Callable` should take a batch of embeddings as input and
        return the pairwise distance matrix.

    Returns:
      triplet_loss: float scalar with dtype of `y_pred`. It is 0 when the
        batch contains no positive (same-label, non-identical) pairs.
    """
    labels, embeddings = y_true, y_pred

    convert_to_float32 = (embeddings.dtype == tf.dtypes.float16
                          or embeddings.dtype == tf.dtypes.bfloat16)
    precise_embeddings = (tf.cast(embeddings, tf.dtypes.float32)
                          if convert_to_float32 else embeddings)

    # Reshape label tensor to [batch_size, 1].
    lshape = tf.shape(labels)
    labels = tf.reshape(labels, [lshape[0], 1])

    # Build pairwise distance matrix.
    if distance_metric == "L2":
        pdist_matrix = metric_learning.pairwise_distance(precise_embeddings,
                                                         squared=False)
    elif distance_metric == "squared-L2":
        pdist_matrix = metric_learning.pairwise_distance(precise_embeddings,
                                                         squared=True)
    elif distance_metric == "angular":
        pdist_matrix = metric_learning.angular_distance(precise_embeddings)
    else:
        pdist_matrix = distance_metric(precise_embeddings)

    # Build pairwise binary adjacency matrix.
    adjacency = tf.math.equal(labels, tf.transpose(labels))
    # Invert so we can select negatives only.
    adjacency_not = tf.math.logical_not(adjacency)

    batch_size = tf.size(labels)

    # Compute the mask. Row k*batch_size + i of the tiled matrix pairs
    # anchor i with candidate positive k; column j flags negatives of i
    # farther than D[i, k].
    pdist_matrix_tile = tf.tile(pdist_matrix, [batch_size, 1])
    mask = tf.math.logical_and(
        tf.tile(adjacency_not, [batch_size, 1]),
        tf.math.greater(pdist_matrix_tile,
                        tf.reshape(tf.transpose(pdist_matrix), [-1, 1])),
    )
    mask_final = tf.reshape(
        tf.math.greater(
            tf.math.reduce_sum(tf.cast(mask, dtype=tf.dtypes.float32),
                               1,
                               keepdims=True),
            0.0,
        ),
        [batch_size, batch_size],
    )
    mask_final = tf.transpose(mask_final)

    adjacency_not = tf.cast(adjacency_not, dtype=tf.dtypes.float32)
    mask = tf.cast(mask, dtype=tf.dtypes.float32)

    # negatives_outside: smallest D_an where D_an > D_ap.
    negatives_outside = tf.reshape(_masked_minimum(pdist_matrix_tile, mask),
                                   [batch_size, batch_size])
    negatives_outside = tf.transpose(negatives_outside)

    # negatives_inside: largest D_an.
    negatives_inside = tf.tile(_masked_maximum(pdist_matrix, adjacency_not),
                               [1, batch_size])
    semi_hard_negatives = tf.where(mask_final, negatives_outside,
                                   negatives_inside)

    loss_mat = tf.math.add(margin, pdist_matrix - semi_hard_negatives)

    mask_positives = tf.cast(adjacency, dtype=tf.dtypes.float32) - tf.linalg.diag(
        tf.ones([batch_size]))

    # In lifted-struct, the authors multiply 0.5 for upper triangular
    # in semihard, they take all positive pairs except the diagonal.
    # Clamp to 1: with no positive pairs the numerator is 0, so the loss is
    # 0 instead of NaN (0 / 0), which would otherwise poison training.
    num_positives = tf.math.maximum(tf.math.reduce_sum(mask_positives), 1.0)

    triplet_loss = tf.math.truediv(
        tf.math.reduce_sum(
            tf.math.maximum(tf.math.multiply(loss_mat, mask_positives), 0.0)),
        num_positives,
    )

    if convert_to_float32:
        return tf.cast(triplet_loss, embeddings.dtype)
    return triplet_loss
def triplet_hard_loss(
    y_true: TensorLike,
    y_pred: TensorLike,
    margin: FloatTensorLike = 1.0,
    soft: bool = False,
    distance_metric: Union[str, Callable] = "L2",
) -> tf.Tensor:
    r"""Computes the triplet loss with hard negative and hard positive mining.

    Usage:

    >>> y_true = tf.convert_to_tensor([0, 0])
    >>> y_pred = tf.convert_to_tensor([[0.0, 1.0], [1.0, 0.0]])
    >>> tfa.losses.triplet_hard_loss(y_true, y_pred, distance_metric="L2")
    <tf.Tensor: shape=(), dtype=float32, numpy=1.0>

    >>> # Calling with callable `distance_metric`
    >>> distance_metric = lambda x: tf.linalg.matmul(x, x, transpose_b=True)
    >>> tfa.losses.triplet_hard_loss(y_true, y_pred, distance_metric=distance_metric)
    <tf.Tensor: shape=(), dtype=float32, numpy=0.0>

    Args:
      y_true: 1-D integer `Tensor` with shape `[batch_size]` of
        multiclass integer labels.
      y_pred: 2-D float `Tensor` of embedding vectors. Embeddings should
        be l2 normalized.
      margin: Float, margin term in the loss definition.
      soft: Boolean, if set, use the soft margin version
        (`softplus(D_ap - D_an)` instead of the hinge with `margin`).
      distance_metric: `str` or a `Callable` that determines distance metric.
        Valid strings are "L2" for l2-norm distance,
        "squared-L2" for squared l2-norm distance,
        and "angular" for cosine similarity.

        A `Callable` should take a batch of embeddings as input and
        return the pairwise distance matrix.

    Returns:
      triplet_loss: float scalar with dtype of `y_pred`.
    """
    labels, embeddings = y_true, y_pred

    convert_to_float32 = (embeddings.dtype == tf.dtypes.float16
                          or embeddings.dtype == tf.dtypes.bfloat16)
    precise_embeddings = (tf.cast(embeddings, tf.dtypes.float32)
                          if convert_to_float32 else embeddings)

    # Reshape label tensor to [batch_size, 1].
    lshape = tf.shape(labels)
    labels = tf.reshape(labels, [lshape[0], 1])

    # Build pairwise distance matrix.
    if distance_metric == "L2":
        pdist_matrix = metric_learning.pairwise_distance(precise_embeddings,
                                                         squared=False)
    elif distance_metric == "squared-L2":
        pdist_matrix = metric_learning.pairwise_distance(precise_embeddings,
                                                         squared=True)
    elif distance_metric == "angular":
        pdist_matrix = metric_learning.angular_distance(precise_embeddings)
    else:
        pdist_matrix = distance_metric(precise_embeddings)

    # Build pairwise binary adjacency matrix.
    adjacency = tf.math.equal(labels, tf.transpose(labels))
    # Invert so we can select negatives only.
    adjacency_not = tf.math.logical_not(adjacency)
    adjacency_not = tf.cast(adjacency_not, dtype=tf.dtypes.float32)

    # hard negatives: smallest D_an.
    hard_negatives = _masked_minimum(pdist_matrix, adjacency_not)

    batch_size = tf.size(labels)

    adjacency = tf.cast(adjacency, dtype=tf.dtypes.float32)
    # adjacency is already float32; subtracting the identity drops self-pairs.
    mask_positives = adjacency - tf.linalg.diag(tf.ones([batch_size]))

    # hard positives: largest D_ap.
    hard_positives = _masked_maximum(pdist_matrix, mask_positives)

    if soft:
        # softplus(x) == log1p(exp(x)), but it stays finite for large x
        # (exp overflows float32 to inf above ~88).
        triplet_loss = tf.math.softplus(hard_positives - hard_negatives)
    else:
        triplet_loss = tf.maximum(hard_positives - hard_negatives + margin, 0.0)

    # Get final mean triplet loss.
    triplet_loss = tf.reduce_mean(triplet_loss)

    if convert_to_float32:
        return tf.cast(triplet_loss, embeddings.dtype)
    return triplet_loss