Example #1
def test_unweighted_soft():
    num_data = 20
    feat_dim = 6
    margin = 1.0
    num_classes = 4

    embedding = np.random.rand(num_data, feat_dim).astype(np.float32)
    labels = np.random.randint(0, num_classes, size=(num_data))

    # Compute the reference loss in NumPy.
    loss_np = triplet_hard_loss_np(labels, embedding, margin, soft=True)

    # Compute the loss in TF.
    y_true = tf.constant(labels)
    y_pred = tf.constant(embedding)
    cce_obj = triplet.TripletHardLoss(soft=True)
    loss = cce_obj(y_true, y_pred)
    np.testing.assert_allclose(loss, loss_np, rtol=1e-6, atol=1e-6)
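
The snippets on this page assume NumPy, TensorFlow, and the tensorflow_addons triplet loss module are imported, and they lean on triplet_hard_loss_np and test_utils, which are private helpers from the tensorflow_addons test suite rather than public API. Below is a minimal, hypothetical sketch of the assumed imports and of what a batch-hard NumPy reference in the spirit of triplet_hard_loss_np could look like (using plain Euclidean distances; the real helper may differ in details such as the distance metric):

import numpy as np
import tensorflow as tf
from tensorflow_addons.losses import triplet


def triplet_hard_loss_ref(labels, embedding, margin=1.0, soft=False):
    # Pairwise Euclidean distance matrix between all embeddings.
    diff = embedding[:, None, :] - embedding[None, :, :]
    dists = np.sqrt(np.sum(diff ** 2, axis=-1))

    losses = []
    for i in range(len(labels)):
        positives = labels == labels[i]
        positives[i] = False            # exclude the anchor itself
        negatives = labels != labels[i]
        # Assumes every anchor has at least one positive and one negative.
        hardest_pos = dists[i][positives].max()
        hardest_neg = dists[i][negatives].min()
        if soft:
            # Soft margin: log(1 + exp(d_pos - d_neg)).
            losses.append(np.log1p(np.exp(hardest_pos - hardest_neg)))
        else:
            losses.append(max(hardest_pos - hardest_neg + margin, 0.0))
    return np.mean(losses)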
Example #2
    def test_unweighted_soft(self):
        num_data = 20
        feat_dim = 6
        margin = 1.0
        num_classes = 4

        embedding = np.random.rand(num_data, feat_dim).astype(np.float32)
        labels = np.random.randint(0, num_classes, size=(num_data))

        # Compute the reference loss in NumPy.
        loss_np = triplet_hard_loss_np(labels, embedding, margin, soft=True)

        # Compute the loss in TF.
        y_true = tf.constant(labels)
        y_pred = tf.constant(embedding)
        cce_obj = triplet.TripletHardLoss(soft=True)
        loss = cce_obj(y_true, y_pred)
        self.assertAlmostEqual(self.evaluate(loss), loss_np, 3)
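
Example #2 is the same comparison written as a method of a tf.test.TestCase, which is where self.evaluate and assertAlmostEqual come from. A minimal sketch of the scaffolding the method assumes (the class name is hypothetical, and the real test class in tensorflow_addons may carry extra decorators):

import tensorflow as tf


class TripletHardLossTest(tf.test.TestCase):
    def test_unweighted_soft(self):
        ...  # body as in Example #2 above


if __name__ == "__main__":
    tf.test.main()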
Example #3
def test_hard_tripled_loss_angular(dtype, soft, dist_func, dist_metric):
    num_data = 20
    feat_dim = 6
    margin = 1.0
    num_classes = 4

    embedding = np.random.rand(num_data, feat_dim).astype(np.float32)
    labels = np.random.randint(0, num_classes, size=(num_data))

    # Compute the loss in NP.
    loss_np = triplet_hard_loss_np(labels, embedding, margin, dist_func, soft)

    # Compute the loss in TF.
    y_true = tf.constant(labels)
    y_pred = tf.constant(embedding, dtype=dtype)
    cce_obj = triplet.TripletHardLoss(soft=soft, distance_metric=dist_metric)
    loss = cce_obj(y_true, y_pred)
    test_utils.assert_allclose_according_to_type(loss.numpy(), loss_np)
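
Example #3 receives dtype, soft, dist_func and dist_metric from a pytest parametrization that is not shown on this page. A hedged sketch of what such a decorator stack could look like, with a hypothetical angular_distance_np helper standing in for the NumPy distance function (the actual parameter combinations in the tensorflow_addons test suite may differ):

import numpy as np
import pytest
import tensorflow as tf


def angular_distance_np(a, b):
    # A simple angular stand-in: one minus cosine similarity.
    return 1.0 - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))


@pytest.mark.parametrize("dtype", [tf.float32, tf.float16])
@pytest.mark.parametrize("soft", [True, False])
@pytest.mark.parametrize(
    "dist_func, dist_metric", [(angular_distance_np, "angular")]
)
def test_hard_tripled_loss_angular(dtype, soft, dist_func, dist_metric):
    ...  # body as shown above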
Example #4
def test_serialization_hard():
    loss = triplet.TripletHardLoss()
    # Round-tripping through the Keras serialization API should not raise.
    tf.keras.losses.deserialize(tf.keras.losses.serialize(loss))
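
The serialization test above only checks that the round trip through tf.keras.losses.serialize / tf.keras.losses.deserialize does not raise. A hedged sketch of how one could additionally verify that the restored object keeps its class and configuration (the constructor arguments here are only an illustration):

import tensorflow as tf
from tensorflow_addons.losses import triplet

loss = triplet.TripletHardLoss(soft=True)
config = tf.keras.losses.serialize(loss)
restored = tf.keras.losses.deserialize(config)

# The restored loss should be the same class and serialize back to the same config.
assert isinstance(restored, triplet.TripletHardLoss)
assert tf.keras.losses.serialize(restored) == config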