def test_triplet_cluster_edge_case() -> None:
    """
    Compare HardTripletsSampler and HardClusterSampler on a trivial batch.

    The batch is built from ``p`` feature vectors (one per class), each
    repeated ``k`` times, so all samples of a class are identical. For such
    input both samplers are expected to generate the same triplets.
    """

    def _unique_rows(parts) -> torch.Tensor:
        # Join the sampled tensors along dim 1 and drop duplicated rows so
        # that the two samplers' outputs can be compared directly.
        return torch.unique(torch.cat(parts, dim=1), dim=0)

    features_dim = 128
    p, k = randint(2, 32), randint(2, 32)

    # One label per class: 0, 1, ..., p - 1
    class_labels = torch.arange(p)
    # One random feature vector per class
    class_features = torch.rand(size=(p, features_dim), dtype=torch.float)

    # Tile labels and features k times: every class gets k identical samples
    labels = class_labels.repeat((k,))
    features = class_features.repeat((k, 1))

    triplet_sampler = HardTripletsSampler()
    cluster_sampler = HardClusterSampler()

    sampled_by_triplet = triplet_sampler.sample(features, labels)
    sampled_by_cluster = cluster_sampler.sample(features, labels)

    expected_rows = _unique_rows(sampled_by_triplet)
    actual_rows = _unique_rows(sampled_by_cluster)

    assert torch.allclose(expected_rows, actual_rows, atol=1e-10)
def test_cluster_get_labels_mask(labels: List[int], expected: torch.Tensor) -> None:
    """
    Check the result of HardClusterSampler._get_labels_mask.

    Args:
        labels: list of labels -- input data for method _get_labels_mask
        expected: correct answer for the given labels
    """
    mask = HardClusterSampler()._get_labels_mask(labels)  # noqa: WPS437
    # Element-wise comparison; the test passes only if every element matches
    matches = mask == expected
    assert matches.all()
def test_cluster_count_inter_class_distances(mean_vectors, expected) -> None:
    """
    Check the result of HardClusterSampler._count_inter_class_distances.

    Args:
        mean_vectors: tensor of shape (p, embed_dim) -- mean vectors
            of classes in the batch
        expected: tensor of shape (p, p) -- expected distances
            from mean vectors of classes
    """
    cluster_sampler = HardClusterSampler()
    computed = cluster_sampler._count_inter_class_distances(  # noqa: WPS437
        mean_vectors
    )
    # Exact element-wise equality is required for every entry
    matches = computed == expected
    assert matches.all()
def test_cluster_count_intra_class_distances(
    features: torch.Tensor, expected: torch.Tensor
) -> None:
    """
    Check the result of HardClusterSampler._count_intra_class_distances.

    Args:
        features: tensor of shape (p, k, embed_dim) -- features grouped by
            labels, where p is a number of classes in the batch, k is a
            number of samples for each class, embed_dim is an embedding size
        expected: tensor of shape (p, k) -- expected distances from mean
            vectors of classes to corresponding features
    """
    cluster_sampler = HardClusterSampler()
    # Mean vectors are obtained by averaging the features over dimension 1
    computed = cluster_sampler._count_intra_class_distances(  # noqa: WPS437
        features, features.mean(1)
    )
    matches = computed == expected
    assert matches.all()
def test_cluster_sample_shapes(
    embed_dim: int, labels: TLabels, expected_shape: List[Tuple[int]]
) -> None:
    """
    Check the shapes of the triplet returned by HardClusterSampler.sample.

    Args:
        embed_dim: size of embedding
        labels: list of labels for samples in batch
        expected_shape: expected shapes of the output triplet, in the order
            anchor, positive, negative
    """
    cluster_sampler = HardClusterSampler()

    # One random feature vector per label in the batch
    random_features = torch.rand(size=(len(labels), embed_dim))

    anchor, positive, negative = cluster_sampler.sample(random_features, labels)
    want_anchor, want_positive, want_negative = expected_shape

    assert anchor.shape == want_anchor
    assert positive.shape == want_positive
    assert negative.shape == want_negative