Пример #1
0
def test_triplet_cluster_edge_case() -> None:
    """
    Check an edge case where every class consists of identical samples.

    Each of the p classes gets k copies of a single random feature vector.
    In this case HardTripletsSampler and HardClusterSampler are expected to
    generate the same set of triplets (compared after de-duplication).
    """
    features_dim = 128
    p, k = randint(2, 32), randint(2, 32)

    # Sequential class labels 0..p-1 (one per class)
    unique_labels = torch.arange(p)
    # One random feature vector per class
    unique_features = torch.rand(size=(p, features_dim), dtype=torch.float)

    # Tile the classes k times: every class ends up with k identical samples
    labels = unique_labels.repeat((k, ))
    features = unique_features.repeat((k, 1))

    hard_triplet_sampler = HardTripletsSampler()
    hard_cluster_sampler = HardClusterSampler()

    triplets = hard_triplet_sampler.sample(features, labels)
    cluster_triplets = hard_cluster_sampler.sample(features, labels)

    # Concatenate the sampled tensors along dim=1 so that each row holds a
    # whole triplet (assumes each output is (n_triplets, features_dim) --
    # TODO confirm). torch.unique(dim=0) then sorts the rows and drops
    # duplicates, so the comparison ignores ordering and repetition.
    triplets = torch.unique(torch.cat(triplets, dim=1), dim=0)
    cluster_triplets = torch.unique(torch.cat(cluster_triplets, dim=1), dim=0)

    # Check shapes explicitly: on a row-count mismatch allclose would raise an
    # unhelpful size error (or broadcast a single row) instead of failing clearly.
    assert triplets.shape == cluster_triplets.shape
    assert torch.allclose(triplets, cluster_triplets, atol=1e-10)
Пример #2
0
def test_cluster_get_labels_mask(labels: List[int], expected: torch.Tensor) -> None:
    """
    Test _get_labels_mask method of HardClusterSampler.

    Args:
        labels: list of labels -- input data for method _get_labels_mask
        expected: correct answer for labels input
    """
    sampler = HardClusterSampler()
    labels_mask = sampler._get_labels_mask(labels)  # noqa: WPS437
    # Elementwise == broadcasts, so a wrongly-shaped result could still pass
    # the .all() check; compare shapes explicitly first.
    assert labels_mask.shape == expected.shape
    assert (labels_mask == expected).all()
Пример #3
0
def test_cluster_count_inter_class_distances(mean_vectors: torch.Tensor,
                                             expected: torch.Tensor) -> None:
    """
    Test _count_inter_class_distances method of HardClusterSampler.

    Args:
        mean_vectors: tensor of shape (p, embed_dim) -- mean vectors of
        classes in the batch
        expected: tensor of shape (p, p) -- expected distances from mean
        vectors of classes
    """
    sampler = HardClusterSampler()
    distances = sampler._count_inter_class_distances(mean_vectors)  # noqa: WPS437
    # Elementwise == broadcasts, so check the shape explicitly first
    assert distances.shape == expected.shape
    assert (distances == expected).all()
Пример #4
0
def test_cluster_count_intra_class_distances(features: torch.Tensor,
                                             expected: torch.Tensor) -> None:
    """
    Test _count_intra_class_distances method of HardClusterSampler.

    Args:
        features: tensor of shape (p, k, embed_dim), where p is a number of
        classes in the batch, k is a number of samples for each class,
        embed_dim is an embedding size -- features grouped by labels
        expected: tensor of shape (p, k) -- expected distances from mean
        vectors of classes to corresponding features
    """
    sampler = HardClusterSampler()
    # Average over the k samples of each class -> (p, embed_dim)
    mean_vectors = features.mean(1)
    distances = sampler._count_intra_class_distances(  # noqa: WPS437
        features, mean_vectors,
    )
    # Elementwise == broadcasts, so check the shape explicitly first
    assert distances.shape == expected.shape
    assert (distances == expected).all()
Пример #5
0
def test_cluster_sample_shapes(embed_dim: int, labels: TLabels,
                               expected_shape: List[Tuple[int, ...]]) -> None:
    """
    Test output shapes in sample method of HardClusterSampler.

    Args:
        embed_dim: size of embedding
        labels: list of labels for samples in batch
        expected_shape: expected shapes of the three outputs of sample,
            in order (anchor, positive, negative)
    """
    sampler = HardClusterSampler()
    batch_size = len(labels)
    features = torch.rand(size=(batch_size, embed_dim))
    anchor, positive, negative = sampler.sample(features, labels)
    anchor_shape, pos_shape, neg_shape = expected_shape

    # torch.Size is a tuple subclass, so it compares equal to a plain tuple
    assert anchor.shape == anchor_shape
    assert positive.shape == pos_shape
    assert negative.shape == neg_shape