def test_anchor_positive_triplet_mask():
    """Test function _get_anchor_positive_triplet_mask.

    Builds random integer labels and checks that the mask returned for them
    equals a reference in which ``mask[i, j]`` is True exactly when
    ``i != j`` and ``labels[i] == labels[j]``.
    """
    num_data = 64
    num_classes = 10

    # Local, fixed-seed generator so any failing label draw is reproducible
    # and the test does not depend on (or disturb) global RNG state.
    rng = np.random.RandomState(0)
    labels = rng.randint(0, num_classes, size=num_data)

    # Reference mask: off-diagonal pairs (distinct indices) sharing a label.
    same_label = labels[:, None] == labels[None, :]
    distinct = ~np.eye(num_data, dtype=bool)
    mask_np = same_label & distinct

    mask_pt = _get_anchor_positive_triplet_mask(torch.as_tensor(labels))
    # Move to host memory before converting, in case the mask was produced on
    # a GPU. The mask is boolean, so compare exactly instead of with a tolerance.
    assert np.array_equal(mask_np, mask_pt.cpu().numpy())
# Example #2
def test_anchor_positive_triplet_mask():
    """Test function _get_anchor_positive_triplet_mask.

    Builds random integer-valued float32 labels, evaluates the returned mask
    through a ``tf.Session``, and checks it equals a reference in which
    ``mask[i, j]`` is True exactly when ``i != j`` and
    ``labels[i] == labels[j]``.
    """
    num_data = 64
    num_classes = 10

    # Local, fixed-seed generator so any failing label draw is reproducible
    # and the test does not depend on (or disturb) global RNG state.
    rng = np.random.RandomState(0)
    labels = rng.randint(0, num_classes, size=num_data).astype(np.float32)

    # Reference mask: off-diagonal pairs (distinct indices) sharing a label.
    # The float32 labels hold small integers, so equality here is exact.
    same_label = labels[:, None] == labels[None, :]
    distinct = ~np.eye(num_data, dtype=bool)
    mask_np = same_label & distinct

    mask_tf = _get_anchor_positive_triplet_mask(labels)
    with tf.Session() as sess:
        mask_tf_val = sess.run(mask_tf)

    # The mask is boolean, so compare exactly instead of with a tolerance.
    assert np.array_equal(mask_np, mask_tf_val)
# Example #3
#
# """Test the pairwise distances function."""
# num_data = 64
# feat_dim = 6
#
# embeddings = np.random.randn(num_data, feat_dim).astype(np.float32)
# embeddings[1] = embeddings[0]  # to get distance 0
#
# with tf.Session() as sess:
#     # for squared in [True, False]:
#     #     res_np = pairwise_distance_np(embeddings, squared=squared)
#     #     res_tf = sess.run(_pairwise_distances(embeddings, squared=squared))
#     #     assert np.allclose(res_np, res_tf)
#     res_tf = sess.run(_pairwise_distances(embeddings, squared=False))
#     print(res_tf.shape)
# Test function _get_anchor_positive_triplet_mask: for a small batch, the
# returned mask must be square with one row and one column per label.
num_data = 6
num_classes = 10

labels = np.random.randint(0, num_classes, size=num_data).astype(np.float32)

# Reference mask (1.0 where the indices differ and the labels match, else 0.0).
# NOTE(review): mask_np is never compared against mask_tf below. A value check
# would need the tensor evaluated (e.g. through a tf.Session, as the other
# examples do) — TODO.
mask_np = (
    (labels[:, None] == labels[None, :]) & ~np.eye(num_data, dtype=bool)
).astype(float)

mask_tf = _get_anchor_positive_triplet_mask(labels)

# The expected shape follows num_data. The previous hard-coded [64, 64] could
# never match a batch of 6 labels.
assert mask_tf.shape == [num_data, num_data]