Example #1
0
    def test_long_clusters(self):
        """Check two-cluster separation for every hash width from 1 to 62 bits.

        For each ``bits`` the input is fifty all-zero hashes plus fifty hashes
        equal to ``2**bits - 1`` (every low bit set), shuffled into a single
        length-100 sequence of shape (1, 1, 100). After ``cluster_cpu`` fills
        the preallocated output tensors, the sorted centroids must be exactly
        ``(0, 2**bits - 1)`` and both clusters must hold 50 members.
        """
        for bits in range(1, 63):
            # subTest reports the failing bit width and keeps checking the
            # remaining widths instead of aborting on the first failure.
            with self.subTest(bits=bits):
                all_ones = 2**bits - 1
                # Build the hashes directly as int64 (no float intermediate),
                # matching how the sibling tests construct their inputs.
                hashes = torch.cat([
                    torch.zeros(50, dtype=torch.int64),
                    torch.full((50, ), all_ones, dtype=torch.int64)
                ]).view(1, 1, 100)[:, :, torch.randperm(100)]
                # Explicit dtype: the dtype torch.full infers from an integer
                # fill value without ``dtype`` differs between PyTorch versions.
                lengths = torch.full((1, ), 100, dtype=torch.int32)
                centroids = torch.empty(1, 1, 2, dtype=torch.int64)
                clusters = torch.empty(1, 1, 100, dtype=torch.int32)
                counts = torch.empty(1, 1, 2, dtype=torch.int32)

                # NOTE(review): 10 presumably is the iteration count -- confirm
                # against cluster_cpu's signature; the last argument is the
                # hash bit width.
                cluster_cpu(
                    hashes,
                    lengths,
                    centroids,
                    clusters,
                    counts,
                    10,
                    bits
                )
                self.assertEqual(
                    tuple(sorted(centroids.numpy().ravel().tolist())),
                    (0, all_ones)
                )
                self.assertTrue(torch.all(counts == 50))
    def test_power_of_2_clusters(self):
        """Check that eight distinct single-bit hash values form eight clusters.

        Ten copies each of 1, 2, 4, ..., 128 are shuffled into one sequence of
        shape (1, 1, 80). After clustering into 8 groups with 8-bit hashes,
        the sorted centroids must be those eight powers of two and every
        cluster must contain exactly 10 members.
        """
        group_count = 8
        group_size = 10
        total = group_count * group_size
        powers = [1 << shift for shift in range(group_count)]

        blocks = [
            torch.full((group_size, ), value, dtype=torch.int64)
            for value in powers
        ]
        shuffle = torch.randperm(total)
        hashes = torch.cat(blocks).view(1, 1, total)[:, :, shuffle]

        lengths = torch.full((1, ), total, dtype=torch.int32)
        centroids = torch.empty(1, 1, group_count, dtype=torch.int64)
        clusters = torch.empty(1, 1, total, dtype=torch.int32)
        counts = torch.empty(1, 1, group_count, dtype=torch.int32)

        # NOTE(review): 2000 presumably is the iteration budget and 8 the hash
        # bit width -- confirm against cluster_cpu's signature.
        cluster_cpu(hashes, lengths, centroids, clusters, counts, 2000, 8)

        found = sorted(centroids.numpy().ravel().tolist())
        self.assertEqual(tuple(found), tuple(powers))
        self.assertTrue(torch.all(counts == group_size))
    def test_many_sequences(self):
        """Check that every sequence in a (5, 3, 100) batch is clustered correctly.

        One shuffled sequence of fifty 0-hashes and fifty 255-hashes is tiled
        to shape (5, 3, 100). For each of the 5 x 3 sequences, the smaller
        centroid must be 0, the larger must be 255, and each of the two
        clusters must hold 50 members.
        """
        seq_len = 100
        half = seq_len // 2

        low_half = torch.zeros(half).long()
        high_half = torch.full((half, ), 255, dtype=torch.int64)
        single = torch.cat([low_half, high_half]).view(1, 1, seq_len)
        single = single[:, :, torch.randperm(seq_len)]
        # Every tiled copy shares the same shuffled order.
        hashes = single.repeat(5, 3, 1)

        lengths = torch.full((5, ), seq_len, dtype=torch.int32)
        centroids = torch.empty(5, 3, 2, dtype=torch.int64)
        clusters = torch.empty(5, 3, seq_len, dtype=torch.int32)
        counts = torch.empty(5, 3, 2, dtype=torch.int32)

        cluster_cpu(hashes, lengths, centroids, clusters, counts, 10, 8)

        # Reduce over the centroid axis to get one min/max per sequence.
        smallest, _ = centroids.min(2)
        largest, _ = centroids.max(2)
        self.assertTrue(torch.all(smallest == 0))
        self.assertTrue(torch.all(largest == 255))
        self.assertTrue(torch.all(counts == half))
    def test_two_clusters(self):
        """Check that hash values 0 and 255 split into two 50-member clusters.

        Fifty 0-hashes and fifty 255-hashes are shuffled into one sequence of
        shape (1, 1, 100). The sorted centroids must be ``(0, 255)`` and both
        clusters must contain exactly 50 members.
        """
        zero_hashes = torch.zeros(50).long()
        full_hashes = torch.full((50, ), 255, dtype=torch.int64)
        order = torch.randperm(100)
        hashes = torch.cat([zero_hashes, full_hashes]).view(1, 1, 100)
        hashes = hashes[:, :, order]

        lengths = torch.full((1, ), 100, dtype=torch.int32)
        centroids = torch.empty(1, 1, 2, dtype=torch.int64)
        clusters = torch.empty(1, 1, 100, dtype=torch.int32)
        counts = torch.empty(1, 1, 2, dtype=torch.int32)

        cluster_cpu(hashes, lengths, centroids, clusters, counts, 10, 8)

        centroid_values = centroids.numpy().ravel().tolist()
        centroid_values.sort()
        self.assertEqual(tuple(centroid_values), (0, 255))
        self.assertTrue(torch.all(counts == 50))