コード例 #1
0
def test_hard_sampler_from_dist(distmats_and_labels) -> None:  # noqa: WPS442
    """
    Check triplets that HardTripletsSampler samples from distance matrices.

    For every (distance matrix, labels) pair the sampled ids are passed to
    check_triplets_are_hardest and check_triplets_consistency, and the
    number of anchors must equal the number of labels.

    Args:
        distmats_and_labels:
            list of distance matrices and valid labels
    """
    hard_sampler = HardTripletsSampler(norm_required=True)

    for dist_matrix, targets in distmats_and_labels:
        anchors, positives, negatives = hard_sampler._sample_from_distmat(  # noqa: WPS437
            distmat=dist_matrix,
            labels=targets,
        )

        # Keyword arguments shared by both checkers below.
        triplet_kwargs = {
            "ids_anchor": anchors,
            "ids_pos": positives,
            "ids_neg": negatives,
            "labels": targets,
        }

        check_triplets_are_hardest(distmat=dist_matrix, **triplet_kwargs)
        check_triplets_consistency(**triplet_kwargs)

        # One sampled triplet is expected per input label.
        assert len(targets) == len(anchors)
コード例 #2
0
def test_hard_sampler_manual() -> None:
    """
    Test HardTripletsSampler on a small hand-made example.

    Four samples with labels [0, 0, 1, 1] and a fixed distance matrix are
    given to the sampler; the set of returned (anchor, positive, negative)
    ids must match the manually written expectation.
    """
    hard_sampler = HardTripletsSampler(norm_required=True)

    sample_labels = [0, 0, 1, 1]
    pairwise_dists = torch.tensor([
        [0.0, 0.3, 0.2, 0.4],
        [0.3, 0.0, 0.4, 0.8],
        [0.2, 0.4, 0.0, 0.5],
        [0.4, 0.8, 0.5, 0.0],
    ])

    # Expected (anchor, positive, negative) ids. In each triplet the
    # negative is the differently-labelled sample with the smallest
    # distance to the anchor in the matrix above.
    expected_triplets = {(0, 1, 2), (1, 0, 2), (2, 3, 0), (3, 2, 0)}

    anchors, positives, negatives = hard_sampler._sample_from_distmat(  # noqa: WPS437
        distmat=pairwise_dists,
        labels=sample_labels,
    )
    sampled_triplets = {
        triplet for triplet in zip(anchors, positives, negatives)
    }

    check_triplets_consistency(
        ids_anchor=anchors,
        ids_pos=positives,
        ids_neg=negatives,
        labels=sample_labels,
    )

    assert len(sample_labels) == len(anchors)
    assert sampled_triplets == expected_triplets
コード例 #3
0
def test_triplet_cluster_edge_case() -> None:
    """
    Check an edge case where every class consists of identical samples:
    HardTripletsSampler and HardClusterSampler are expected to
    generate the same triplets.
    """
    embedding_size = 128
    # Number of classes and number of copies per class are drawn at random.
    num_classes = randint(2, 32)
    num_copies = randint(2, 32)

    # One label per class: 0 .. num_classes - 1
    class_labels = torch.tensor(list(range(num_classes)))
    # One random feature vector per class
    class_features = torch.rand(
        size=(num_classes, embedding_size), dtype=torch.float,
    )

    # Tile labels and features so each class row appears num_copies times
    labels = class_labels.repeat(num_copies)
    features = class_features.repeat(num_copies, 1)

    triplet_sampler = HardTripletsSampler()
    cluster_sampler = HardClusterSampler()

    sampled_by_triplet = triplet_sampler.sample(features, labels)
    sampled_by_cluster = cluster_sampler.sample(features, labels)

    # Join the tensors of each result along dim=1 so that every row holds
    # one whole triplet, then drop duplicate rows with torch.unique
    joined_triplet = torch.cat(sampled_by_triplet, dim=1)
    joined_cluster = torch.cat(sampled_by_cluster, dim=1)

    unique_triplet = torch.unique(joined_triplet, dim=0)
    unique_cluster = torch.unique(joined_cluster, dim=0)

    assert torch.allclose(unique_triplet, unique_cluster, atol=1e-10)
コード例 #4
0
def test_hard_sampler_from_features(features_and_labels) -> None:  # noqa: WPS442
    """
    Check triplets that HardTripletsSampler samples directly from features.

    For every (features, labels) pair the sampled ids are passed to
    check_triplets_consistency, and the number of anchors must equal the
    number of labels.

    Args:
        features_and_labels: features and valid labels
    """
    hard_sampler = HardTripletsSampler(norm_required=True)

    for batch_features, batch_labels in features_and_labels:
        anchors, positives, negatives = hard_sampler._sample(  # noqa: WPS437
            features=batch_features,
            labels=batch_labels,
        )

        check_triplets_consistency(
            ids_anchor=anchors,
            ids_pos=positives,
            ids_neg=negatives,
            labels=batch_labels,
        )

        # One sampled triplet is expected per input label.
        assert len(anchors) == len(batch_labels)