def test_power_of_2_clusters(self):
    """Eight shuffled groups of ten identical power-of-two hashes form eight clusters.

    The input holds ten copies each of 1, 2, 4, ..., 128 (80 hashes in
    total) in a random order. Clustering into 8 centroids must recover
    exactly those eight values, with 10 members assigned to each.
    """
    n_points, n_clusters = 80, 8
    # Last two positional arguments of cluster(); presumably the iteration
    # budget and the hash bit width (matches the last dim of `bitcounts`)
    # -- confirm against cluster_cuda.
    n_iterations, n_bits = 2000, 8
    device = "cuda"

    # Ten copies of each power of two, shuffled, reshaped to (1, 1, 80).
    powers = torch.tensor([1 << i for i in range(n_clusters)], dtype=torch.int64)
    ordered = powers.repeat_interleave(10)
    shuffled = ordered[torch.randperm(n_points)]
    hashes = shuffled.view(1, 1, n_points).to(device)

    # Output / scratch buffers; their contents are left uninitialised
    # because cluster() is expected to fill them.
    lengths = torch.full((1,), n_points, dtype=torch.int32, device=device)
    centroids = torch.empty(1, 1, n_clusters, dtype=torch.int64, device=device)
    distances = torch.empty(1, 1, n_points, dtype=torch.int32, device=device)
    bitcounts = torch.empty(
        1, 1, n_clusters, n_bits, dtype=torch.int32, device=device)
    clusters = torch.empty(1, 1, n_points, dtype=torch.int32, device=device)
    counts = torch.empty(1, 1, n_clusters, dtype=torch.int32, device=device)

    cluster_cuda.cluster(hashes, lengths, centroids, distances, bitcounts,
                         clusters, counts, n_iterations, n_bits)

    expected = tuple(1 << i for i in range(n_clusters))
    found = tuple(sorted(centroids.flatten().tolist()))
    self.assertEqual(found, expected)
    self.assertTrue((counts == 10).all())
def test_many_sequences(self):
    """Clustering a (5, 3, 100) stack of identical sequences finds {0, 255} everywhere.

    Every one of the 5 x 3 sequences holds 50 zeros and 50 copies of 255
    in the same shuffled order. For each of them the two centroids must be
    0 and 255, and each centroid must have 50 members.
    """
    lead = (5, 3)  # leading axes; presumably (batch, heads) -- confirm
    n_points, n_clusters = 100, 2
    # Presumably the iteration budget and hash bit width -- confirm.
    n_iterations, n_bits = 10, 8
    device = "cuda"

    # One shuffled sequence of 50 x 0 followed by 50 x 255, tiled over `lead`.
    halves = torch.tensor([0, 255], dtype=torch.int64)
    ordered = halves.repeat_interleave(n_points // 2)
    sequence = ordered[torch.randperm(n_points)].view(1, 1, n_points)
    hashes = sequence.repeat(*lead, 1).to(device)

    # One length entry per element of the first leading axis, as the
    # original (5,)-shaped tensor did.
    lengths = torch.full((lead[0],), n_points, dtype=torch.int32, device=device)
    centroids = torch.empty(*lead, n_clusters, dtype=torch.int64, device=device)
    distances = torch.empty(*lead, n_points, dtype=torch.int32, device=device)
    bitcounts = torch.empty(
        *lead, n_clusters, n_bits, dtype=torch.int32, device=device)
    clusters = torch.empty(*lead, n_points, dtype=torch.int32, device=device)
    counts = torch.empty(*lead, n_clusters, dtype=torch.int32, device=device)

    cluster_cuda.cluster(hashes, lengths, centroids, distances, bitcounts,
                         clusters, counts, n_iterations, n_bits)

    lowest, _ = centroids.min(dim=-1)
    highest, _ = centroids.max(dim=-1)
    self.assertTrue((lowest == 0).all())
    self.assertTrue((highest == 255).all())
    self.assertTrue((counts == 50).all())
def test_two_clusters(self):
    """Fifty zeros and fifty 255s, shuffled, split into two clusters {0, 255}.

    Clustering the single (1, 1, 100) sequence into 2 centroids must yield
    exactly the values 0 and 255, with 50 members assigned to each.
    """
    n_points, n_clusters = 100, 2
    # Presumably the iteration budget and hash bit width -- confirm.
    n_iterations, n_bits = 10, 8
    device = "cuda"

    # 50 x 0 followed by 50 x 255, shuffled, reshaped to (1, 1, 100).
    halves = torch.tensor([0, 255], dtype=torch.int64)
    ordered = halves.repeat_interleave(n_points // 2)
    shuffled = ordered[torch.randperm(n_points)]
    hashes = shuffled.view(1, 1, n_points).to(device)

    lengths = torch.full((1,), n_points, dtype=torch.int32, device=device)
    centroids = torch.empty(1, 1, n_clusters, dtype=torch.int64, device=device)
    distances = torch.empty(1, 1, n_points, dtype=torch.int32, device=device)
    bitcounts = torch.empty(
        1, 1, n_clusters, n_bits, dtype=torch.int32, device=device)
    clusters = torch.empty(1, 1, n_points, dtype=torch.int32, device=device)
    counts = torch.empty(1, 1, n_clusters, dtype=torch.int32, device=device)

    cluster_cuda.cluster(hashes, lengths, centroids, distances, bitcounts,
                         clusters, counts, n_iterations, n_bits)

    found = tuple(sorted(centroids.flatten().tolist()))
    self.assertEqual(found, (0, 255))
    self.assertTrue((counts == 50).all())