def test_long_clusters(self):
    """Cluster two groups of hashes at opposite ends of the bit range.

    For every hash width ``bits`` in [1, 62], half of the 100 hashes are
    ``0`` and the other half are ``2**bits - 1`` (all ``bits`` low bits
    set). The groups are shuffled along the last axis. ``cluster_cpu`` is
    called with 2 centroid slots, 10 iterations and ``bits``. The test
    expects the two centroids to be exactly ``0`` and ``2**bits - 1``,
    each with 50 members.
    """
    n_points = 100
    half = n_points // 2
    for bits in range(1, 63):
        # subTest reports which bit width failed instead of stopping at the
        # first failure with no indication of the offending ``bits``.
        with self.subTest(bits=bits):
            top = 2**bits - 1  # at most 2**62 - 1, so it fits in int64
            hashes = torch.cat([
                torch.zeros(half, dtype=torch.int64),
                torch.full((half,), top, dtype=torch.int64),
            ]).view(1, 1, n_points)[:, :, torch.randperm(n_points)]
            # Give dtype explicitly: the result dtype of torch.full with an
            # integer fill and no dtype has changed across torch releases.
            lengths = torch.full((1,), n_points, dtype=torch.int32)
            centroids = torch.empty(1, 1, 2, dtype=torch.int64)
            clusters = torch.empty(1, 1, n_points, dtype=torch.int32)
            counts = torch.empty(1, 1, 2, dtype=torch.int32)
            cluster_cpu(
                hashes, lengths, centroids, clusters, counts, 10, bits
            )
            self.assertEqual(
                tuple(sorted(centroids.numpy().ravel().tolist())),
                (0, top),
            )
            self.assertTrue(torch.all(counts == half))
def test_power_of_2_clusters(self):
    """Eight single-bit hash groups must come back as eight clusters.

    Ten copies of each power of two 1, 2, ..., 128 are shuffled into one
    sequence of length 80. ``cluster_cpu`` is called with 8 centroid slots,
    2000 iterations and 8 bits. Every power of two must appear as a
    centroid, with 10 members each.
    """
    powers = [1 << shift for shift in range(8)]
    groups = [torch.full((10, ), value, dtype=torch.int64) for value in powers]
    shuffle = torch.randperm(80)
    hashes = torch.cat(groups).view(1, 1, 80)[:, :, shuffle]

    lengths = torch.full((1, ), 80, dtype=torch.int32)
    centroids = torch.empty(1, 1, 8, dtype=torch.int64)
    clusters = torch.empty(1, 1, 80, dtype=torch.int32)
    counts = torch.empty(1, 1, 8, dtype=torch.int32)

    cluster_cpu(hashes, lengths, centroids, clusters, counts, 2000, 8)

    found = tuple(sorted(centroids.numpy().ravel().tolist()))
    self.assertEqual(found, tuple(powers))
    self.assertTrue(torch.all(counts == 10))
def test_many_sequences(self):
    """Cluster the same 0/255 split across a 5 x 3 grid of sequences.

    A single shuffled sequence of 50 zeros and 50 values of 255 is tiled to
    shape (5, 3, 100). ``lengths`` holds one entry per index of the first
    dimension. Every (i, j) slot must end up with centroids 0 and 255 and
    50 members per cluster.
    """
    n_outer, n_inner, seq_len = 5, 3, 100

    low = torch.zeros(50, dtype=torch.int64)
    high = torch.full((50, ), 255, dtype=torch.int64)
    single = torch.cat([low, high]).view(1, 1, seq_len)
    single = single[:, :, torch.randperm(seq_len)]
    hashes = single.repeat(n_outer, n_inner, 1)

    lengths = torch.full((n_outer, ), seq_len, dtype=torch.int32)
    centroids = torch.empty(n_outer, n_inner, 2, dtype=torch.int64)
    clusters = torch.empty(n_outer, n_inner, seq_len, dtype=torch.int32)
    counts = torch.empty(n_outer, n_inner, 2, dtype=torch.int32)

    cluster_cpu(hashes, lengths, centroids, clusters, counts, 10, 8)

    smallest, _ = centroids.min(2)
    largest, _ = centroids.max(2)
    self.assertTrue(torch.all(smallest == 0))
    self.assertTrue(torch.all(largest == 255))
    self.assertTrue(torch.all(counts == 50))
def test_two_clusters(self):
    """Split a shuffled 50/50 mix of 0 and 255 into two clusters.

    ``cluster_cpu`` is called with 2 centroid slots, 10 iterations and
    8 bits. The centroids must be exactly {0, 255}, with 50 members each.
    """
    seq_len = 100
    zeros = torch.zeros(50, dtype=torch.int64)
    maxed = torch.full((50, ), 255, dtype=torch.int64)
    ordered = torch.cat([zeros, maxed]).view(1, 1, seq_len)
    hashes = ordered[:, :, torch.randperm(seq_len)]

    lengths = torch.full((1, ), seq_len, dtype=torch.int32)
    centroids = torch.empty(1, 1, 2, dtype=torch.int64)
    clusters = torch.empty(1, 1, seq_len, dtype=torch.int32)
    counts = torch.empty(1, 1, 2, dtype=torch.int32)

    cluster_cpu(hashes, lengths, centroids, clusters, counts, 10, 8)

    centroid_values = sorted(centroids.numpy().ravel().tolist())
    self.assertEqual(tuple(centroid_values), (0, 255))
    self.assertTrue(torch.all(counts == 50))