Пример #1
0
    def call(self, y_true, y_pred):
        """Compute the semi-hard triplet loss for one batch.

        Follows ``tfa.losses.triplet_semihard_loss``, except that the
        pairwise adjacency ("same group") matrix comes from
        ``self.response_diffs`` instead of label equality, and the number of
        positive pairs is clamped to at least 1 so a batch without positive
        pairs yields 0 instead of NaN.

        Args:
          y_true: Labels for the batch, cast to float32. A 1-D tensor is
            reshaped to a single row of shape ``(1, -1)`` before being passed
            to ``self.response_diffs``.
          y_pred: Embedding vectors, one row per example (presumably
            ``[batch_size, dim]`` -- TODO confirm against callers).

        Returns:
          Scalar loss tensor, cast back to the dtype of ``y_pred`` when that
          dtype is float16 or bfloat16.

        Side effects:
          Calls ``self.sd.update_state(y_true, y_pred)`` on every call.
        """
        from tensorflow_addons.losses import metric_learning

        self.sd.update_state(y_true, y_pred)

        labels = tf.cast(
            tf.convert_to_tensor(y_true, name="labels"),
            dtype=tf.dtypes.float32
        )
        if len(labels.shape) == 1:
            labels = tf.reshape(labels, (1, -1))

        embeddings = tf.convert_to_tensor(y_pred, name="embeddings")

        # Half-precision embeddings are upcast for the distance computations;
        # the result is cast back at the end.
        convert_to_float32 = (
            (embeddings.dtype == tf.dtypes.float16) or
            (embeddings.dtype == tf.dtypes.bfloat16)
        )
        precise_embeddings = (
            tf.cast(embeddings, tf.dtypes.float32)
            if convert_to_float32
            else embeddings
        )

        # Build the pairwise distance matrix.
        distance_metric = self.distance_metric

        if distance_metric == "L2":
            pdist_matrix = metric_learning.pairwise_distance(
                precise_embeddings, squared=False
            )

        elif distance_metric == "squared-L2":
            pdist_matrix = metric_learning.pairwise_distance(
                precise_embeddings, squared=True
            )

        elif distance_metric == "angular":
            pdist_matrix = metric_learning.angular_distance(precise_embeddings)

        else:
            # Any other value is treated as a callable that returns the
            # pairwise distance matrix of the embeddings.
            pdist_matrix = distance_metric(precise_embeddings)

        # Fetch pairwise labels as adjacency matrix.
        adjacency = self.response_diffs(labels)
        # Invert so we can select negatives only.
        adjacency_not = tf.math.logical_not(adjacency)

        # The distance matrix has one row per example, so its leading
        # dimension is the batch size. tf.size(labels) only matches that when
        # each example carries exactly one label value.
        batch_size = tf.shape(pdist_matrix)[0]

        # Compute the mask. Row (p * batch_size + a) of the tiled matrices
        # pairs anchor a with candidate positive p; the mask marks negatives n
        # of a whose distance D_an exceeds D_ap.
        pdist_matrix_tile = tf.tile(pdist_matrix, [batch_size, 1])
        mask = tf.math.logical_and(
            tf.tile(adjacency_not, [batch_size, 1]),
            tf.math.greater(
                pdist_matrix_tile,
                tf.reshape(tf.transpose(pdist_matrix), [-1, 1])
            ),
        )
        # mask_final[a, p]: does anchor a have any negative farther than p?
        mask_final = tf.reshape(
            tf.math.greater(
                tf.math.reduce_sum(
                    tf.cast(mask, dtype=tf.dtypes.float32),
                    1,
                    keepdims=True
                ),
                0.0,
            ),
            [batch_size, batch_size],
        )
        mask_final = tf.transpose(mask_final)

        adjacency_not = tf.cast(adjacency_not, dtype=tf.dtypes.float32)
        mask = tf.cast(mask, dtype=tf.dtypes.float32)

        # negatives_outside: smallest D_an where D_an > D_ap.
        negatives_outside = tf.reshape(
            _masked_minimum(pdist_matrix_tile, mask), [batch_size, batch_size]
        )
        negatives_outside = tf.transpose(negatives_outside)

        # negatives_inside: largest D_an.
        negatives_inside = tf.tile(
            _masked_maximum(pdist_matrix, adjacency_not), [1, batch_size]
        )
        semi_hard_negatives = tf.where(
            mask_final,
            negatives_outside,
            negatives_inside
        )

        loss_mat = tf.math.add(self.margin, pdist_matrix - semi_hard_negatives)

        # Positive pairs, excluding each anchor paired with itself.
        mask_positives = (
            tf.cast(adjacency, dtype=tf.dtypes.float32) -
            tf.linalg.diag(tf.ones([batch_size]))
        )

        # In lifted-struct, the authors multiply 0.5 for upper triangular
        #   in semihard, they take all positive pairs except the diagonal.
        # Clamping to at least 1 avoids a 0/0 NaN when the batch has no
        # positive pairs; the numerator is then 0, so the loss is 0.
        num_positives = tf.math.maximum(
            tf.math.reduce_sum(mask_positives),
            1.0
        )

        triplet_loss = tf.math.truediv(
            tf.math.reduce_sum(
                tf.math.maximum(
                    tf.math.multiply(loss_mat, mask_positives),
                    0.0
                )
            ),
            num_positives,
        )

        if convert_to_float32:
            return tf.cast(triplet_loss, embeddings.dtype)
        else:
            return triplet_loss
Пример #2
0
    def call(self, y_true, y_pred):
        """Compute the batch-hard triplet loss for one batch.

        For each anchor, picks the hardest positive (largest distance to an
        adjacent example other than itself) and the hardest negative (smallest
        distance to a non-adjacent example). Adjacency comes from
        ``self.response_diffs``. The per-anchor losses are averaged.

        Args:
          y_true: Labels for the batch, cast to float32. A 1-D tensor is
            reshaped to a single row of shape ``(1, -1)`` before being passed
            to ``self.response_diffs``.
          y_pred: Embedding vectors, one row per example (presumably
            ``[batch_size, dim]`` -- TODO confirm against callers).

        Returns:
          Scalar loss tensor, cast back to the dtype of ``y_pred`` when that
          dtype is float16 or bfloat16.

        Side effects:
          Calls ``self.sd.update_state(y_true, y_pred)`` on every call.
        """
        from tensorflow_addons.losses import metric_learning

        self.sd.update_state(y_true, y_pred)

        labels = tf.cast(
            tf.convert_to_tensor(y_true, name="labels"),
            dtype=tf.dtypes.float32
        )
        if len(labels.shape) == 1:
            labels = tf.reshape(labels, (1, -1))

        embeddings = tf.convert_to_tensor(y_pred, name="embeddings")

        # Half-precision embeddings are upcast for the distance computations;
        # the result is cast back at the end.
        convert_to_float32 = (
            (embeddings.dtype == tf.dtypes.float16) or
            (embeddings.dtype == tf.dtypes.bfloat16)
        )
        precise_embeddings = (
            tf.cast(embeddings, tf.dtypes.float32)
            if convert_to_float32
            else embeddings
        )

        # Build the pairwise distance matrix.
        distance_metric = self.distance_metric

        if distance_metric == "L2":
            pdist_matrix = metric_learning.pairwise_distance(
                precise_embeddings, squared=False
            )

        elif distance_metric == "squared-L2":
            pdist_matrix = metric_learning.pairwise_distance(
                precise_embeddings, squared=True
            )

        elif distance_metric == "angular":
            pdist_matrix = metric_learning.angular_distance(precise_embeddings)

        else:
            # Any other value is treated as a callable that returns the
            # pairwise distance matrix of the embeddings.
            pdist_matrix = distance_metric(precise_embeddings)

        # Fetch pairwise labels as adjacency matrix.
        adjacency = self.response_diffs(labels)

        # Invert so we can select negatives only.
        adjacency_not = tf.math.logical_not(adjacency)

        adjacency = tf.cast(adjacency, dtype=tf.dtypes.float32)
        adjacency_not = tf.cast(adjacency_not, dtype=tf.dtypes.float32)

        # hard negatives: smallest D_an.
        hard_negatives = _masked_minimum(pdist_matrix, adjacency_not)

        # The distance matrix has one row per example, so its leading
        # dimension is the batch size. tf.size(labels) only matches that when
        # each example carries exactly one label value.
        batch_size = tf.shape(pdist_matrix)[0]

        # Positive pairs, excluding each anchor paired with itself.
        mask_positives = adjacency - tf.linalg.diag(tf.ones([batch_size]))

        # hard positives: largest D_ap.
        hard_positives = _masked_maximum(pdist_matrix, mask_positives)

        if self.soft:
            # softplus(x) == log(1 + exp(x)), but it stays finite for large x
            # where the literal log1p(exp(x)) overflows to inf.
            triplet_loss = tf.math.softplus(hard_positives - hard_negatives)
        else:
            triplet_loss = tf.maximum(
                hard_positives - hard_negatives + self.margin,
                0.0
            )

        # Get final mean triplet loss
        triplet_loss = tf.reduce_mean(triplet_loss)

        if convert_to_float32:
            return tf.cast(triplet_loss, embeddings.dtype)
        else:
            return triplet_loss
Пример #3
0
    def call(self, y_true, y_pred):
        """Compute a neighbourhood hit/miss loss for one batch.

        For every anchor, a radius is set to the mean of its row of pairwise
        distances minus half of that row's standard deviation. Examples
        closer than the radius are neighbours. Adjacent neighbours (per
        ``self.response_diffs``, excluding the anchor itself) are "hits";
        non-adjacent neighbours are "misses". The loss per anchor is the
        normalised sum of miss distances minus the normalised sum of hit
        distances.

        Args:
          y_true: Labels for the batch, cast to float32. A 1-D tensor is
            reshaped to a single row of shape ``(1, -1)`` before being passed
            to ``self.response_diffs``.
          y_pred: Embedding vectors, one row per example (presumably
            ``[batch_size, dim]`` -- TODO confirm against callers).

        Returns:
          Per-example loss tensor of shape ``[batch_size]``, cast back to the
          dtype of ``y_pred`` when that dtype is float16 or bfloat16.

        Side effects:
          Calls ``self.sd.update_state(y_true, y_pred)`` on every call.
        """
        from tensorflow_addons.losses import metric_learning

        self.sd.update_state(y_true, y_pred)

        labels = tf.cast(
            tf.convert_to_tensor(y_true, name="labels"),
            dtype=tf.dtypes.float32
        )
        if len(labels.shape) == 1:
            labels = tf.reshape(labels, (1, -1))

        embeddings = tf.convert_to_tensor(y_pred, name="embeddings")

        # Half-precision embeddings are upcast for the distance computations;
        # the result is cast back at the end.
        convert_to_float32 = (
            (embeddings.dtype == tf.dtypes.float16) or
            (embeddings.dtype == tf.dtypes.bfloat16)
        )
        precise_embeddings = (
            tf.cast(embeddings, tf.dtypes.float32)
            if convert_to_float32
            else embeddings
        )

        # Build the pairwise distance matrix.
        distance_metric = self.distance_metric

        if distance_metric == "L2":
            pdist_matrix = metric_learning.pairwise_distance(
                precise_embeddings, squared=False
            )

        elif distance_metric == "squared-L2":
            pdist_matrix = metric_learning.pairwise_distance(
                precise_embeddings, squared=True
            )

        elif distance_metric == "angular":
            pdist_matrix = metric_learning.angular_distance(precise_embeddings)

        else:
            # Any other value is treated as a callable that returns the
            # pairwise distance matrix of the embeddings.
            pdist_matrix = distance_metric(precise_embeddings)

        # 1-D labels were reshaped to a single row above, so
        # tf.shape(labels)[0] would be 1 for them. The distance matrix always
        # has one row per example.
        batch_size = tf.shape(pdist_matrix)[0]

        # Fetch pairwise labels as adjacency matrix.
        adjacency = self.response_diffs(labels)
        # Invert so we can select negatives only.
        adjacency_not = tf.math.logical_not(adjacency)

        # Per-anchor neighbourhood radius: row mean minus half the row std.
        radii = (
            tf.reduce_mean(pdist_matrix, axis=1) -
            (tf.math.reduce_std(pdist_matrix, axis=1) / 2.)
        )
        neighbors = tf.math.less(pdist_matrix, tf.reshape(radii, (-1, 1)))

        # Adjacent neighbours, excluding each anchor paired with itself. The
        # diagonal is zeroed explicitly instead of subtracting the identity.
        # Subtracting would give -1 whenever an anchor falls outside its own
        # radius, which happens when the radius is <= 0.
        hits = tf.linalg.set_diag(
            tf.cast(
                tf.math.logical_and(neighbors, adjacency),
                tf.dtypes.float32
            ),
            tf.zeros([batch_size], dtype=tf.dtypes.float32),
        )

        misses = tf.cast(
            tf.math.logical_and(neighbors, adjacency_not),
            tf.dtypes.float32
        )

        nhits = tf.reduce_sum(hits)
        nmisses = tf.reduce_sum(misses)

        # Normalise by batch size times the total count; divide_no_nan makes
        # the term 0 when there are no hits (or no misses).
        n = tf.cast(batch_size, tf.dtypes.float32)
        hits_dists = tf.multiply(pdist_matrix, hits)
        hits_dists = tf.math.divide_no_nan(
            hits_dists,
            tf.math.multiply(n, nhits)
        )
        misses_dists = tf.multiply(pdist_matrix, misses)
        misses_dists = tf.math.divide_no_nan(
            misses_dists,
            tf.math.multiply(n, nmisses)
        )

        loss = tf.subtract(misses_dists, hits_dists)
        loss = tf.reduce_sum(loss, axis=1)

        if convert_to_float32:
            return tf.cast(loss, embeddings.dtype)
        else:
            return loss
Пример #4
0
def triplet_semihard_loss(
    y_true: TensorLike,
    y_pred: TensorLike,
    margin: FloatTensorLike = 1.0,
    distance_metric: Union[str, Callable] = "L2",
) -> tf.Tensor:
    """Computes the triplet loss with semi-hard negative mining.

    For every (anchor, positive) pair, the negative used is the closest
    negative that is still farther from the anchor than the positive. If no
    such negative exists, the farthest negative is used instead.

    Args:
      y_true: 1-D integer `Tensor` with shape [batch_size] of
        multiclass integer labels.
      y_pred: 2-D float `Tensor` of embedding vectors. Embeddings should
        be l2 normalized.
      margin: Float, margin term in the loss definition.
      distance_metric: str or function, determines distance metric:
        "L2" for l2-norm distance,
        "squared-L2" for squared l2-norm distance,
        "angular" for cosine similarity.
        A custom function that takes the batch of embeddings and returns
        a 2-D pairwise distance matrix can also be passed here, e.g.

            def custom_distance(batch):
                return 1 - tf.linalg.matmul(batch, batch, transpose_b=True)

            triplet_semihard_loss(labels, batch,
                                  distance_metric=custom_distance)

    Returns:
      triplet_loss: float scalar with dtype of y_pred. It is 0 when the batch
        contains no positive pairs, i.e. no two examples share a label.
    """

    labels, embeddings = y_true, y_pred

    # Half-precision embeddings are upcast for the distance computations;
    # the result is cast back at the end.
    convert_to_float32 = (embeddings.dtype == tf.dtypes.float16
                          or embeddings.dtype == tf.dtypes.bfloat16)
    precise_embeddings = (tf.cast(embeddings, tf.dtypes.float32)
                          if convert_to_float32 else embeddings)

    # Reshape label tensor to [batch_size, 1].
    lshape = tf.shape(labels)
    labels = tf.reshape(labels, [lshape[0], 1])

    # Build pairwise distance matrix.
    if distance_metric == "L2":
        pdist_matrix = metric_learning.pairwise_distance(precise_embeddings,
                                                         squared=False)

    elif distance_metric == "squared-L2":
        pdist_matrix = metric_learning.pairwise_distance(precise_embeddings,
                                                         squared=True)

    elif distance_metric == "angular":
        pdist_matrix = metric_learning.angular_distance(precise_embeddings)

    else:
        pdist_matrix = distance_metric(precise_embeddings)

    # Build pairwise binary adjacency matrix.
    adjacency = tf.math.equal(labels, tf.transpose(labels))
    # Invert so we can select negatives only.
    adjacency_not = tf.math.logical_not(adjacency)

    batch_size = tf.size(labels)

    # Compute the mask. Row (p * batch_size + a) of the tiled matrices pairs
    # anchor a with candidate positive p; the mask marks negatives n of a
    # whose distance D_an exceeds D_ap.
    pdist_matrix_tile = tf.tile(pdist_matrix, [batch_size, 1])
    mask = tf.math.logical_and(
        tf.tile(adjacency_not, [batch_size, 1]),
        tf.math.greater(pdist_matrix_tile,
                        tf.reshape(tf.transpose(pdist_matrix), [-1, 1])),
    )
    # mask_final[a, p]: does anchor a have any negative farther than p?
    mask_final = tf.reshape(
        tf.math.greater(
            tf.math.reduce_sum(tf.cast(mask, dtype=tf.dtypes.float32),
                               1,
                               keepdims=True),
            0.0,
        ),
        [batch_size, batch_size],
    )
    mask_final = tf.transpose(mask_final)

    adjacency_not = tf.cast(adjacency_not, dtype=tf.dtypes.float32)
    mask = tf.cast(mask, dtype=tf.dtypes.float32)

    # negatives_outside: smallest D_an where D_an > D_ap.
    negatives_outside = tf.reshape(_masked_minimum(pdist_matrix_tile, mask),
                                   [batch_size, batch_size])
    negatives_outside = tf.transpose(negatives_outside)

    # negatives_inside: largest D_an.
    negatives_inside = tf.tile(_masked_maximum(pdist_matrix, adjacency_not),
                               [1, batch_size])
    semi_hard_negatives = tf.where(mask_final, negatives_outside,
                                   negatives_inside)

    loss_mat = tf.math.add(margin, pdist_matrix - semi_hard_negatives)

    # Positive pairs, excluding each anchor paired with itself.
    mask_positives = tf.cast(adjacency,
                             dtype=tf.dtypes.float32) - tf.linalg.diag(
                                 tf.ones([batch_size]))

    # In lifted-struct, the authors multiply 0.5 for upper triangular
    #   in semihard, they take all positive pairs except the diagonal.
    num_positives = tf.math.reduce_sum(mask_positives)

    # divide_no_nan: with no positive pairs both numerator and denominator
    # are 0, and a plain division would return NaN instead of 0.
    triplet_loss = tf.math.divide_no_nan(
        tf.math.reduce_sum(
            tf.math.maximum(tf.math.multiply(loss_mat, mask_positives), 0.0)),
        num_positives,
    )

    if convert_to_float32:
        return tf.cast(triplet_loss, embeddings.dtype)
    else:
        return triplet_loss
Пример #5
0
def triplet_hard_loss(
    y_true: TensorLike,
    y_pred: TensorLike,
    margin: FloatTensorLike = 1.0,
    soft: bool = False,
    distance_metric: Union[str, Callable] = "L2",
) -> tf.Tensor:
    """Computes the triplet loss with hard negative and hard positive mining.

    For each anchor, the hardest positive (largest anchor-positive distance,
    excluding the anchor itself) and the hardest negative (smallest
    anchor-negative distance) are selected. The per-anchor losses are
    averaged.

    Args:
      y_true: 1-D integer `Tensor` with shape [batch_size] of
        multiclass integer labels.
      y_pred: 2-D float `Tensor` of embedding vectors. Embeddings should
        be l2 normalized.
      margin: Float, margin term in the loss definition. Not used when
        `soft` is True.
      soft: Boolean, if set, use the soft margin version
        `softplus(D_ap - D_an)` instead of `max(D_ap - D_an + margin, 0)`.
      distance_metric: str or function, determines distance metric:
        "L2" for l2-norm distance,
        "squared-L2" for squared l2-norm distance,
        "angular" for cosine similarity.
        A custom function that takes the batch of embeddings and returns
        a 2-D pairwise distance matrix can also be passed here, e.g.

            def custom_distance(batch):
                return 1 - tf.linalg.matmul(batch, batch, transpose_b=True)

            triplet_hard_loss(labels, batch, distance_metric=custom_distance)

    Returns:
      triplet_loss: float scalar with dtype of y_pred.
    """
    labels, embeddings = y_true, y_pred

    # Half-precision embeddings are upcast for the distance computations;
    # the result is cast back at the end.
    convert_to_float32 = (embeddings.dtype == tf.dtypes.float16
                          or embeddings.dtype == tf.dtypes.bfloat16)
    precise_embeddings = (tf.cast(embeddings, tf.dtypes.float32)
                          if convert_to_float32 else embeddings)

    # Reshape label tensor to [batch_size, 1].
    lshape = tf.shape(labels)
    labels = tf.reshape(labels, [lshape[0], 1])

    # Build pairwise distance matrix.
    if distance_metric == "L2":
        pdist_matrix = metric_learning.pairwise_distance(precise_embeddings,
                                                         squared=False)

    elif distance_metric == "squared-L2":
        pdist_matrix = metric_learning.pairwise_distance(precise_embeddings,
                                                         squared=True)

    elif distance_metric == "angular":
        pdist_matrix = metric_learning.angular_distance(precise_embeddings)

    else:
        pdist_matrix = distance_metric(precise_embeddings)

    # Build pairwise binary adjacency matrix.
    adjacency = tf.math.equal(labels, tf.transpose(labels))
    # Invert so we can select negatives only.
    adjacency_not = tf.cast(tf.math.logical_not(adjacency),
                            dtype=tf.dtypes.float32)
    adjacency = tf.cast(adjacency, dtype=tf.dtypes.float32)

    # hard negatives: smallest D_an.
    hard_negatives = _masked_minimum(pdist_matrix, adjacency_not)

    batch_size = tf.size(labels)

    # Positive pairs, excluding each anchor paired with itself.
    mask_positives = adjacency - tf.linalg.diag(tf.ones([batch_size]))

    # hard positives: largest D_ap.
    hard_positives = _masked_maximum(pdist_matrix, mask_positives)

    if soft:
        # softplus(x) == log(1 + exp(x)), but it stays finite for large x
        # where the literal log1p(exp(x)) overflows to inf.
        triplet_loss = tf.math.softplus(hard_positives - hard_negatives)
    else:
        triplet_loss = tf.maximum(hard_positives - hard_negatives + margin,
                                  0.0)

    # Get final mean triplet loss
    triplet_loss = tf.reduce_mean(triplet_loss)

    if convert_to_float32:
        return tf.cast(triplet_loss, embeddings.dtype)
    else:
        return triplet_loss
Пример #6
0
def triplet_semihard_loss(
    y_true: TensorLike,
    y_pred: TensorLike,
    margin: FloatTensorLike = 1.0,
    distance_metric: Union[str, Callable] = "L2",
) -> tf.Tensor:
    r"""Computes the triplet loss with semi-hard negative mining.

    Usage:

    >>> y_true = tf.convert_to_tensor([0, 0])
    >>> y_pred = tf.convert_to_tensor([[0.0, 1.0], [1.0, 0.0]])
    >>> tfa.losses.triplet_semihard_loss(y_true, y_pred, distance_metric="L2")
    <tf.Tensor: shape=(), dtype=float32, numpy=2.4142137>

    >>> # Calling with callable `distance_metric`
    >>> distance_metric = lambda x: tf.linalg.matmul(x, x, transpose_b=True)
    >>> tfa.losses.triplet_semihard_loss(y_true, y_pred, distance_metric=distance_metric)
    <tf.Tensor: shape=(), dtype=float32, numpy=1.0>

    Args:
      y_true: 1-D integer `Tensor` with shape `[batch_size]` of
        multiclass integer labels.
      y_pred: 2-D float `Tensor` of embedding vectors. Embeddings should
        be l2 normalized.
      margin: Float, margin term in the loss definition.
      distance_metric: `str` or a `Callable` that determines distance metric.
        Valid strings are "L2" for l2-norm distance,
        "squared-L2" for squared l2-norm distance,
        and "angular" for cosine similarity.

        A `Callable` should take a batch of embeddings as input and
        return the pairwise distance matrix.

    Returns:
      triplet_loss: float scalar with dtype of `y_pred`. It is 0 when the
        batch contains no positive pairs, i.e. no two examples share a label.
    """

    labels, embeddings = y_true, y_pred

    # Half-precision embeddings are upcast for the distance computations;
    # the result is cast back at the end.
    convert_to_float32 = (embeddings.dtype == tf.dtypes.float16
                          or embeddings.dtype == tf.dtypes.bfloat16)
    precise_embeddings = (tf.cast(embeddings, tf.dtypes.float32)
                          if convert_to_float32 else embeddings)

    # Reshape label tensor to [batch_size, 1].
    lshape = tf.shape(labels)
    labels = tf.reshape(labels, [lshape[0], 1])

    # Build pairwise distance matrix.
    if distance_metric == "L2":
        pdist_matrix = metric_learning.pairwise_distance(precise_embeddings,
                                                         squared=False)

    elif distance_metric == "squared-L2":
        pdist_matrix = metric_learning.pairwise_distance(precise_embeddings,
                                                         squared=True)

    elif distance_metric == "angular":
        pdist_matrix = metric_learning.angular_distance(precise_embeddings)

    else:
        pdist_matrix = distance_metric(precise_embeddings)

    # Build pairwise binary adjacency matrix.
    adjacency = tf.math.equal(labels, tf.transpose(labels))
    # Invert so we can select negatives only.
    adjacency_not = tf.math.logical_not(adjacency)

    batch_size = tf.size(labels)

    # Compute the mask. Row (p * batch_size + a) of the tiled matrices pairs
    # anchor a with candidate positive p; the mask marks negatives n of a
    # whose distance D_an exceeds D_ap.
    pdist_matrix_tile = tf.tile(pdist_matrix, [batch_size, 1])
    mask = tf.math.logical_and(
        tf.tile(adjacency_not, [batch_size, 1]),
        tf.math.greater(pdist_matrix_tile,
                        tf.reshape(tf.transpose(pdist_matrix), [-1, 1])),
    )
    # mask_final[a, p]: does anchor a have any negative farther than p?
    mask_final = tf.reshape(
        tf.math.greater(
            tf.math.reduce_sum(tf.cast(mask, dtype=tf.dtypes.float32),
                               1,
                               keepdims=True),
            0.0,
        ),
        [batch_size, batch_size],
    )
    mask_final = tf.transpose(mask_final)

    adjacency_not = tf.cast(adjacency_not, dtype=tf.dtypes.float32)
    mask = tf.cast(mask, dtype=tf.dtypes.float32)

    # negatives_outside: smallest D_an where D_an > D_ap.
    negatives_outside = tf.reshape(_masked_minimum(pdist_matrix_tile, mask),
                                   [batch_size, batch_size])
    negatives_outside = tf.transpose(negatives_outside)

    # negatives_inside: largest D_an.
    negatives_inside = tf.tile(_masked_maximum(pdist_matrix, adjacency_not),
                               [1, batch_size])
    semi_hard_negatives = tf.where(mask_final, negatives_outside,
                                   negatives_inside)

    loss_mat = tf.math.add(margin, pdist_matrix - semi_hard_negatives)

    # Positive pairs, excluding each anchor paired with itself.
    mask_positives = tf.cast(adjacency,
                             dtype=tf.dtypes.float32) - tf.linalg.diag(
                                 tf.ones([batch_size]))

    # In lifted-struct, the authors multiply 0.5 for upper triangular
    #   in semihard, they take all positive pairs except the diagonal.
    num_positives = tf.math.reduce_sum(mask_positives)

    # divide_no_nan: with no positive pairs both numerator and denominator
    # are 0, and a plain division would return NaN instead of 0.
    triplet_loss = tf.math.divide_no_nan(
        tf.math.reduce_sum(
            tf.math.maximum(tf.math.multiply(loss_mat, mask_positives), 0.0)),
        num_positives,
    )

    if convert_to_float32:
        return tf.cast(triplet_loss, embeddings.dtype)
    else:
        return triplet_loss
Пример #7
0
def triplet_hard_loss(
    y_true: TensorLike,
    y_pred: TensorLike,
    margin: FloatTensorLike = 1.0,
    soft: bool = False,
    distance_metric: Union[str, Callable] = "L2",
) -> tf.Tensor:
    r"""Computes the triplet loss with hard negative and hard positive mining.

    Usage:

    >>> y_true = tf.convert_to_tensor([0, 0])
    >>> y_pred = tf.convert_to_tensor([[0.0, 1.0], [1.0, 0.0]])
    >>> tfa.losses.triplet_hard_loss(y_true, y_pred, distance_metric="L2")
    <tf.Tensor: shape=(), dtype=float32, numpy=1.0>

    >>> # Calling with callable `distance_metric`
    >>> distance_metric = lambda x: tf.linalg.matmul(x, x, transpose_b=True)
    >>> tfa.losses.triplet_hard_loss(y_true, y_pred, distance_metric=distance_metric)
    <tf.Tensor: shape=(), dtype=float32, numpy=0.0>

    Args:
      y_true: 1-D integer `Tensor` with shape `[batch_size]` of
        multiclass integer labels.
      y_pred: 2-D float `Tensor` of embedding vectors. Embeddings should
        be l2 normalized.
      margin: Float, margin term in the loss definition. Not used when
        `soft` is True.
      soft: Boolean, if set, use the soft margin version
        `softplus(D_ap - D_an)` instead of `max(D_ap - D_an + margin, 0)`.
      distance_metric: `str` or a `Callable` that determines distance metric.
        Valid strings are "L2" for l2-norm distance,
        "squared-L2" for squared l2-norm distance,
        and "angular" for cosine similarity.

        A `Callable` should take a batch of embeddings as input and
        return the pairwise distance matrix.

    Returns:
      triplet_loss: float scalar with dtype of `y_pred`.
    """
    labels, embeddings = y_true, y_pred

    # Half-precision embeddings are upcast for the distance computations;
    # the result is cast back at the end.
    convert_to_float32 = (embeddings.dtype == tf.dtypes.float16
                          or embeddings.dtype == tf.dtypes.bfloat16)
    precise_embeddings = (tf.cast(embeddings, tf.dtypes.float32)
                          if convert_to_float32 else embeddings)

    # Reshape label tensor to [batch_size, 1].
    lshape = tf.shape(labels)
    labels = tf.reshape(labels, [lshape[0], 1])

    # Build pairwise distance matrix.
    if distance_metric == "L2":
        pdist_matrix = metric_learning.pairwise_distance(precise_embeddings,
                                                         squared=False)

    elif distance_metric == "squared-L2":
        pdist_matrix = metric_learning.pairwise_distance(precise_embeddings,
                                                         squared=True)

    elif distance_metric == "angular":
        pdist_matrix = metric_learning.angular_distance(precise_embeddings)

    else:
        pdist_matrix = distance_metric(precise_embeddings)

    # Build pairwise binary adjacency matrix.
    adjacency = tf.math.equal(labels, tf.transpose(labels))
    # Invert so we can select negatives only.
    adjacency_not = tf.cast(tf.math.logical_not(adjacency),
                            dtype=tf.dtypes.float32)
    adjacency = tf.cast(adjacency, dtype=tf.dtypes.float32)

    # hard negatives: smallest D_an.
    hard_negatives = _masked_minimum(pdist_matrix, adjacency_not)

    batch_size = tf.size(labels)

    # Positive pairs, excluding each anchor paired with itself.
    mask_positives = adjacency - tf.linalg.diag(tf.ones([batch_size]))

    # hard positives: largest D_ap.
    hard_positives = _masked_maximum(pdist_matrix, mask_positives)

    if soft:
        # softplus(x) == log(1 + exp(x)), but it stays finite for large x
        # where the literal log1p(exp(x)) overflows to inf.
        triplet_loss = tf.math.softplus(hard_positives - hard_negatives)
    else:
        triplet_loss = tf.maximum(hard_positives - hard_negatives + margin,
                                  0.0)

    # Get final mean triplet loss
    triplet_loss = tf.reduce_mean(triplet_loss)

    if convert_to_float32:
        return tf.cast(triplet_loss, embeddings.dtype)
    else:
        return triplet_loss