Esempio n. 1
0
def test_triplet_cluster_edge_case() -> None:
    """
    Check the edge case where all samples of a class share identical features.

    In that situation HardTripletsSampler and HardClusterSampler are
    expected to generate the same set of triplets (compared after
    deduplication, so order and repetitions do not matter).
    """
    features_dim = 128
    p, k = randint(2, 32), randint(2, 32)

    # Labels 0..p-1, one per class (deterministic, not random)
    unique_labels = torch.tensor(list(range(p)))
    # One random feature vector per class
    unique_features = torch.rand(size=(p, features_dim), dtype=torch.float)

    # Repeat the whole (label, feature) set k times: row i of ``features``
    # belongs to ``labels[i]`` and every class gets k identical samples.
    labels = unique_labels.repeat((k, ))
    features = unique_features.repeat((k, 1))

    hard_triplet_sampler = HardTripletsSampler()
    hard_cluster_sampler = HardClusterSampler()

    triplets = hard_triplet_sampler.sample(features, labels)
    cluster_triplets = hard_cluster_sampler.sample(features, labels)

    # Concatenate each sampler's output tensors along dim=1 so that a row
    # holds a whole triplet (assumes sample() returns tensors with matching
    # row counts -- TODO confirm), then deduplicate with torch.unique, which
    # also sorts the rows, making the comparison order-independent.
    triplets = torch.unique(torch.cat(triplets, dim=1), dim=0)
    cluster_triplets = torch.unique(torch.cat(cluster_triplets, dim=1), dim=0)

    # Compare shapes first: torch.allclose broadcasts its inputs, so a
    # different number of unique triplets would either raise instead of
    # failing the assertion, or be masked when one side has a single row.
    assert triplets.shape == cluster_triplets.shape
    assert torch.allclose(triplets, cluster_triplets, atol=1e-10)
Esempio n. 2
0
def test_cluster_sample_shapes(embed_dim: int, labels: TLabels,
                               expected_shape: List[Tuple[int]]) -> None:
    """
    Verify the shapes of the triplet tensors returned by HardClusterSampler.

    Args:
        embed_dim: dimensionality of every embedding in the random batch
        labels: class label of each sample in the batch (one per sample)
        expected_shape: expected shapes of the anchor, positive and
            negative tensors, in that order
    """
    cluster_sampler = HardClusterSampler()
    embeddings = torch.rand(size=(len(labels), embed_dim))

    # Strict unpacking on both sides: a wrong number of outputs or of
    # expected shapes fails loudly instead of being silently truncated.
    anchor, positive, negative = cluster_sampler.sample(embeddings, labels)
    expected_anchor, expected_pos, expected_neg = expected_shape

    pairs = (
        (anchor, expected_anchor),
        (positive, expected_pos),
        (negative, expected_neg),
    )
    for produced, wanted in pairs:
        assert produced.shape == wanted