def test_hard_sampler_from_dist(distmats_and_labels) -> None:  # noqa: WPS442
    """
    Sample triplets straight from distance matrices and validate them.

    For every (distance matrix, labels) pair the sampled indices are passed
    to ``check_triplets_are_hardest`` and ``check_triplets_consistency``, and
    the number of anchors must equal the number of labels.

    Args:
        distmats_and_labels: list of distance matrices and valid labels
    """
    hard_sampler = HardTripletsSampler(norm_required=True)

    for distmat, labels in distmats_and_labels:
        anchors, positives, negatives = hard_sampler._sample_from_distmat(  # noqa: WPS437
            distmat=distmat,
            labels=labels,
        )

        # Both checkers take the same index/label arguments; the hardness
        # check additionally needs the distance matrix.
        shared_kwargs = {
            "ids_anchor": anchors,
            "ids_pos": positives,
            "ids_neg": negatives,
            "labels": labels,
        }
        check_triplets_are_hardest(distmat=distmat, **shared_kwargs)
        check_triplets_consistency(**shared_kwargs)

        assert len(labels) == len(anchors)
def test_hard_sampler_manual() -> None:
    """
    Compare the sampler output with a hand-written answer.

    Four items with labels ``[0, 0, 1, 1]`` and a fixed 4x4 distance matrix
    are sampled; the resulting (anchor, positive, negative) index triplets
    must match the expected set exactly.
    """
    dist_mat = torch.tensor([
        [0.0, 0.3, 0.2, 0.4],
        [0.3, 0.0, 0.4, 0.8],
        [0.2, 0.4, 0.0, 0.5],
        [0.4, 0.8, 0.5, 0.0],
    ])
    labels = [0, 0, 1, 1]

    # Expected (anchor, positive, negative) index triplets, one per item.
    expected_triplets = {(0, 1, 2), (1, 0, 2), (2, 3, 0), (3, 2, 0)}

    hard_sampler = HardTripletsSampler(norm_required=True)
    anchors, positives, negatives = hard_sampler._sample_from_distmat(  # noqa: WPS437
        distmat=dist_mat,
        labels=labels,
    )

    # Order of the sampled triplets is irrelevant, so compare as sets.
    sampled_triplets = {
        (anchor, positive, negative)
        for anchor, positive, negative in zip(anchors, positives, negatives)
    }

    check_triplets_consistency(
        ids_anchor=anchors,
        ids_pos=positives,
        ids_neg=negatives,
        labels=labels,
    )

    assert len(labels) == len(anchors)
    assert sampled_triplets == expected_triplets
def test_triplet_cluster_edge_case() -> None:
    """
    Edge case where every sample of a class has the same feature vector.

    In this trivial setup HardTripletsSampler and HardClusterSampler are
    expected to produce the same set of unique triplets.
    """
    features_dim = 128
    num_classes = randint(2, 32)
    samples_per_class = randint(2, 32)

    # One label and one random feature vector per class...
    class_labels = torch.tensor(list(range(num_classes)))
    class_features = torch.rand(
        size=(num_classes, features_dim), dtype=torch.float,
    )

    # ...each repeated so that all samples of a class are identical.
    labels = class_labels.repeat(samples_per_class)
    features = class_features.repeat(samples_per_class, 1)

    triplet_sampler = HardTripletsSampler()
    cluster_sampler = HardClusterSampler()

    sampled_by_triplets = triplet_sampler.sample(features, labels)
    sampled_by_clusters = cluster_sampler.sample(features, labels)

    # Join the tensors of each sampled result along dim=1 and keep only the
    # unique rows, so both results can be compared directly.
    unique_from_triplets, unique_from_clusters = (
        torch.unique(torch.cat(sampled, dim=1), dim=0)
        for sampled in (sampled_by_triplets, sampled_by_clusters)
    )

    assert torch.allclose(
        unique_from_triplets, unique_from_clusters, atol=1e-10,
    )
def test_hard_sampler_from_features(features_and_labels) -> None:  # noqa: WPS442
    """
    Sample triplets from raw features and validate their consistency.

    For every (features, labels) pair the sampled indices are passed to
    ``check_triplets_consistency`` and the number of anchors must equal the
    number of labels.

    Args:
        features_and_labels: features and valid labels
    """
    hard_sampler = HardTripletsSampler(norm_required=True)

    for features, labels in features_and_labels:
        sampled_ids = hard_sampler._sample(  # noqa: WPS437
            features=features,
            labels=labels,
        )
        anchors, positives, negatives = sampled_ids

        check_triplets_consistency(
            ids_anchor=anchors,
            ids_pos=positives,
            ids_neg=negatives,
            labels=labels,
        )

        assert len(anchors) == len(labels)