def time_clustering(L, N, H, E, n_batches, n_attentions, k, n_buckets,
                    n_iterations, verbose):
    """Time repeated ``cluster`` calls on CUDA tensors and print the result.

    Hashes for ``L * N * H`` points are generated once and reused; ``cluster``
    is then invoked ``int(n_batches) * int(n_attentions)`` times with
    preallocated output buffers, and the elapsed wall-clock seconds are
    printed as ``"Time Elapsed: <seconds>"``.

    Args:
        L: Size of the last dimension of the ``(N, H, L)`` hash view; also
            the value every entry of the sequence-length tensor is set to.
        N: Size of the first dimension of the hash view and the length of
            the sequence-length tensor.
        H: Size of the second dimension of the hash view.
        E: Forwarded unchanged as the second argument of ``generate_hash``.
        n_batches: Outer repetition count (converted with ``int``).
        n_attentions: Inner repetition count (converted with ``int``).
        k: Size of the last dimension of the ``counts`` and ``centroids``
            buffers and the third dimension of the ``bitcounts`` buffer.
        n_buckets: Forwarded to ``generate_hash`` and passed to ``cluster``
            as ``bits``; also the last dimension of the ``bitcounts`` buffer.
        n_iterations: Passed to ``cluster`` as ``iterations``.
        verbose: Accepted for interface compatibility; not used here.

    Returns:
        None. The measured duration is only printed.
    """
    n_points = L * N * H

    hashes = torch.zeros(n_points, dtype=torch.int64).cuda()
    hashes = generate_hash(n_points, E, n_buckets, hashes).view(N, H, L)

    # Output buffers are allocated once so allocation cost is not timed.
    groups = torch.zeros((N, H, L), dtype=torch.int32).cuda()
    counts = torch.zeros((N, H, k), dtype=torch.int32).cuda()
    centroids = torch.zeros((N, H, k), dtype=torch.int64).cuda()
    distances = torch.zeros((N, H, L), dtype=torch.int32).cuda()
    cluster_bit_counts = torch.zeros(
        (N, H, k, n_buckets), dtype=torch.int32
    ).cuda()
    sequence_lengths = torch.ones((N,), dtype=torch.int32).cuda() * L

    # CUDA work is queued asynchronously: drain anything still pending from
    # the setup above so it is not charged to the timed region.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(int(n_batches)):
        for _ in range(int(n_attentions)):
            cluster(
                hashes,
                sequence_lengths,
                groups=groups,
                counts=counts,
                centroids=centroids,
                distances=distances,
                bitcounts=cluster_bit_counts,
                iterations=n_iterations,
                bits=n_buckets,
            )
    # Wait for all queued GPU work to finish before reading the clock;
    # otherwise only the launch overhead would be measured.
    torch.cuda.synchronize()
    duration = time.perf_counter() - start

    print("Time Elapsed: {}".format(duration))
def test_clustering_convergence(self):
    """Assert zero total distance when the cluster count equals the length.

    For every bit width from 1 to 9, ``2**bits`` hashes per (batch, head)
    are clustered into ``2**bits`` clusters on CUDA, and the ``distances``
    buffer handed to ``cluster`` must sum to zero afterwards.
    """
    batch, heads, dim = 50, 4, 32
    iters = 10

    for bits in range(1, 10):
        print('Testing convergence for {} bits'.format(bits))

        n_clusters = 2 ** bits
        seq_len = n_clusters  # as many points as clusters
        total = seq_len * batch * heads

        hashes = generate_hash(total, dim, bits)
        hashes = hashes.view(batch, heads, seq_len).cuda()
        lengths = torch.ones((batch, ), dtype=torch.int32).cuda() * seq_len
        distances = torch.zeros(
            (batch, heads, seq_len), dtype=torch.int32
        ).cuda()

        cluster(
            hashes,
            lengths,
            distances=distances,
            clusters=n_clusters,
            iterations=iters,
            bits=bits,
        )

        total_distance = distances.data.cpu().numpy().sum()
        self.assertEqual(total_distance, 0)
def test_benchmark_clustering(self):
    """Print the CUDA-event time of one ``cluster`` call for 10..63 bits.

    For each bit width the inputs and output buffers are rebuilt on CUDA,
    ``cluster`` is called 500 times untimed, and then a single call is
    bracketed by two ``torch.cuda.Event`` records whose elapsed time is
    printed.
    """
    batch, heads, seq_len = 12, 4, 1000
    dim = 32
    n_clusters = 100
    iters = 10
    total = seq_len * batch * heads
    warmup_calls = 500

    for bits in range(10, 64):
        hashes = generate_hash(total, dim, bits)
        hashes = hashes.view(batch, heads, seq_len).cuda()

        # Keyword arguments shared by the warm-up calls and the timed call.
        cluster_kwargs = dict(
            groups=torch.zeros(
                (batch, heads, seq_len), dtype=torch.int32
            ).cuda(),
            counts=torch.zeros(
                (batch, heads, n_clusters), dtype=torch.int32
            ).cuda(),
            centroids=torch.zeros(
                (batch, heads, n_clusters), dtype=torch.int64
            ).cuda(),
            distances=torch.zeros(
                (batch, heads, seq_len), dtype=torch.int32
            ).cuda(),
            bitcounts=torch.zeros(
                (batch, heads, n_clusters, bits), dtype=torch.int32
            ).cuda(),
            iterations=iters,
            bits=bits,
        )

        lengths = torch.ones((batch, ), dtype=torch.int32).cuda() * seq_len
        lengths.random_(1, seq_len + 1)

        for _ in range(warmup_calls):
            cluster(hashes, lengths, **cluster_kwargs)

        tic = torch.cuda.Event(enable_timing=True)
        toc = torch.cuda.Event(enable_timing=True)
        tic.record()
        cluster(hashes, lengths, **cluster_kwargs)
        toc.record()
        # Both events must have completed before the elapsed time is read.
        torch.cuda.synchronize()
        t_clustering = tic.elapsed_time(toc)

        print("Clustering with {} bits took {} time".format(
            bits, t_clustering))
def test_clustering(self):
    """Cluster 50 fresh hash batches on CUDA and verify each result.

    The output buffers are allocated once and reused. After every
    ``cluster`` call the hashes, groups, centroids and distances are
    flattened to ``(N * H, ...)`` NumPy arrays and handed to
    ``verify_distances`` together with the per-row lengths.
    """
    batch, heads, seq_len = 50, 4, 100
    dim = 32
    n_clusters = 20
    bits = 31
    iters = 10
    total = seq_len * batch * heads
    rows = batch * heads

    groups = torch.zeros((batch, heads, seq_len), dtype=torch.int32).cuda()
    counts = torch.zeros(
        (batch, heads, n_clusters), dtype=torch.int32
    ).cuda()
    centroids = torch.zeros(
        (batch, heads, n_clusters), dtype=torch.int64
    ).cuda()
    distances = torch.zeros(
        (batch, heads, seq_len), dtype=torch.int32
    ).cuda()
    bit_counts = torch.zeros(
        (batch, heads, n_clusters, bits), dtype=torch.int32
    ).cuda()
    lengths = torch.ones((batch, ), dtype=torch.int32).cuda() * seq_len

    for _ in range(50):
        hashes = generate_hash(total, dim, bits)
        hashes = hashes.view(batch, heads, seq_len).cuda()

        cluster(
            hashes,
            lengths,
            groups=groups,
            counts=counts,
            centroids=centroids,
            distances=distances,
            bitcounts=bit_counts,
            iterations=iters,
            bits=bits,
        )

        # One length per (batch, head) row, matching the flattened views.
        lengths_np = lengths.repeat_interleave(heads).cpu().numpy()
        hashes_np = hashes.view(rows, seq_len).cpu().numpy()
        groups_np = groups.view(rows, seq_len).cpu().numpy()
        distances_np = distances.view(rows, seq_len).cpu().numpy()
        centroids_np = centroids.view(rows, n_clusters).cpu().numpy()

        verify_distances(
            hashes_np,
            groups_np,
            centroids_np,
            distances_np,
            bits,
            lengths_np,
        )
def test_masked_clustering_convergence(self):
    """Assert zero distance with random per-sample lengths.

    For every bit width from 1 to 9, ``2**bits + 1`` hashes per
    (batch, head) are clustered into ``2**bits`` clusters on CUDA with
    sequence lengths drawn from ``1..L``. Distances from index
    ``length - 1`` onward are zeroed for each sample before the remaining
    total is asserted to be zero.
    """
    batch, heads, dim = 50, 4, 32
    iters = 10

    for bits in range(1, 10):
        print('Testing convergence for {} bits'.format(bits))

        n_clusters = 2 ** bits
        seq_len = n_clusters + 1
        total = seq_len * batch * heads

        hashes = generate_hash(total, dim, bits)
        hashes = hashes.view(batch, heads, seq_len).cuda()

        groups = torch.zeros(
            (batch, heads, seq_len), dtype=torch.int32
        ).cuda()
        counts = torch.zeros(
            (batch, heads, n_clusters), dtype=torch.int32
        ).cuda()
        centroids = torch.zeros(
            (batch, heads, n_clusters), dtype=torch.int64
        ).cuda()
        distances = torch.zeros(
            (batch, heads, seq_len), dtype=torch.int32
        ).cuda()
        bit_counts = torch.zeros(
            (batch, heads, n_clusters, bits), dtype=torch.int32
        ).cuda()

        # random_(L) draws from [0, L - 1]; the increment shifts to [1, L].
        lengths = torch.ones((batch, ), dtype=torch.int32).cuda() * seq_len
        lengths.random_(seq_len)
        lengths += 1

        cluster(
            hashes,
            lengths,
            groups=groups,
            counts=counts,
            centroids=centroids,
            distances=distances,
            bitcounts=bit_counts,
            iterations=iters,
            bits=bits,
        )

        lengths_np = lengths.cpu().numpy()
        distances_np = distances.cpu().numpy()
        # Ignore every position from (length - 1) onward for each sample.
        for sample, length in enumerate(lengths_np):
            distances_np[sample, :, length - 1:] = 0

        assert distances_np.sum() == 0
def test_benchmark_clustering(self):
    """Print the wall-clock time of 50 ``cluster`` calls for 10..63 bits.

    All tensors are created without moving them to another device. For
    each bit width the inputs and output buffers are rebuilt, then 50
    consecutive ``cluster`` calls are timed with ``time.time()``.
    """
    batch, heads, seq_len = 12, 4, 1000
    dim = 32
    n_clusters = 100
    iters = 10
    total = seq_len * batch * heads
    timed_calls = 50

    for bits in range(10, 64):
        hashes = generate_hash(total, dim, bits)
        hashes = hashes.view(batch, heads, seq_len)

        # Keyword arguments reused unchanged for every timed call.
        cluster_kwargs = dict(
            groups=torch.zeros((batch, heads, seq_len), dtype=torch.int32),
            counts=torch.zeros(
                (batch, heads, n_clusters), dtype=torch.int32
            ),
            centroids=torch.zeros(
                (batch, heads, n_clusters), dtype=torch.int64
            ),
            distances=torch.zeros(
                (batch, heads, seq_len), dtype=torch.int32
            ),
            bitcounts=torch.zeros(
                (batch, heads, n_clusters, bits), dtype=torch.int32
            ),
            iterations=iters,
            bits=bits,
        )

        lengths = torch.ones((batch, ), dtype=torch.int32) * seq_len
        lengths.random_(1, seq_len + 1)

        began = time.time()
        for _ in range(timed_calls):
            cluster(hashes, lengths, **cluster_kwargs)
        finished = time.time()
        t_clustering = finished - began

        print("Clustering with {} bits took {} time".format(
            bits, t_clustering))
def cluster_queries(Q, query_lengths, C, I, B):
    """Hash the queries against fresh planes and cluster the hashes.

    Args:
        Q: 4-dimensional tensor unpacked as ``(N, H, L, E)``.
        query_lengths: Forwarded unchanged as the second argument of
            ``cluster``.
        C: Passed to ``cluster`` as ``clusters``.
        I: Passed to ``cluster`` as ``iterations``.
        B: Number of planes (rows of the ``(B, E + 1)`` plane tensor);
            also passed to ``cluster`` as ``bits``.

    Returns:
        The ``(groups, counts)`` pair produced by ``cluster``.
    """
    n_batch, n_heads, seq_len, dim = Q.shape

    # Planes share Q's dtype and device; they are filled in place by
    # normal_ and then have their last column zeroed.
    planes = Q.new_empty((B, dim + 1))
    normal_(planes)
    planes[:, -1] = 0

    flat_queries = Q.view(n_batch * n_heads * seq_len, dim)
    hashes = compute_hashes(flat_queries, planes)
    hashes = hashes.view(n_batch, n_heads, seq_len)

    # Cluster the hashes and return the cluster index per query
    assignments, sizes = cluster(
        hashes,
        query_lengths,
        clusters=C,
        iterations=I,
        bits=B,
    )
    return assignments, sizes