Exemplo n.º 1
0
def time_clustering(L, N, H, E,
                    n_batches, n_attentions,
                    k, n_buckets, n_iterations, verbose):
    """Time repeated ``cluster`` calls on one fixed set of hashes.

    Hashes are generated once with ``generate_hash`` and viewed as an
    (N, H, L) tensor. They are then clustered ``n_batches * n_attentions``
    times into preallocated CUDA output buffers. The elapsed wall-clock time
    is printed and returned.

    Args:
        L, N, H: Dimensions of the hash tensor, which is viewed as (N, H, L).
            Every entry of the per-sample length tensor is set to ``L``.
        E: Forwarded to ``generate_hash``.
        n_batches: Outer repetition count (converted with ``int``).
        n_attentions: Inner repetition count (converted with ``int``).
        k: Number of clusters; sizes the ``counts``/``centroids`` buffers.
        n_buckets: Number of hash bits, forwarded to ``cluster`` as ``bits``.
        n_iterations: Clustering iterations per ``cluster`` call.
        verbose: Currently unused; kept for interface compatibility.

    Returns:
        Elapsed time of the timed loop in seconds (float).
    """
    n_points = L * N * H
    hashes = torch.zeros(n_points, dtype=torch.int64).cuda()
    hashes = generate_hash(n_points, E, n_buckets, hashes).view(N, H, L)

    # Output buffers reused by every ``cluster`` call below.
    groups = torch.zeros((N, H, L), dtype=torch.int32).cuda()
    counts = torch.zeros((N, H, k), dtype=torch.int32).cuda()
    centroids = torch.zeros((N, H, k), dtype=torch.int64).cuda()
    distances = torch.zeros((N, H, L), dtype=torch.int32).cuda()
    cluster_bit_counts = torch.zeros((N, H, k, n_buckets),
                                     dtype=torch.int32).cuda()
    sequence_lengths = torch.ones((N,), dtype=torch.int32).cuda() * L

    # CUDA work is queued asynchronously. Drain everything queued so far
    # (hash generation, buffer initialisation) before starting the clock, and
    # wait for the last clustering call before stopping it. Otherwise the
    # measurement mostly reflects launch overhead.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(int(n_batches)):
        for _ in range(int(n_attentions)):
            cluster(
                hashes, sequence_lengths,
                groups=groups, counts=counts, centroids=centroids,
                distances=distances, bitcounts=cluster_bit_counts,
                iterations=n_iterations,
                bits=n_buckets
            )
    torch.cuda.synchronize()
    duration = time.perf_counter() - start
    print("Time Elapsed: {}".format(duration))
    return duration
Exemplo n.º 2
0
    def test_clustering_convergence(self):
        """Clustering into ``2**bits`` clusters must leave zero total distance.

        For every bit width from 1 to 9, the sequence length is set equal to
        the number of clusters ``2**bits``. Random hashes of shape
        (batch, heads, seq_len) are clustered on the GPU, and the distances
        written by ``cluster`` must sum to zero.
        """
        batch, heads, dim = 50, 4, 32
        n_iterations = 10

        for bits in range(1, 10):
            print('Testing convergence for {} bits'.format(bits))
            n_clusters = 2 ** bits
            seq_len = n_clusters
            total_points = seq_len * batch * heads

            hashes = generate_hash(total_points, dim, bits)
            hashes = hashes.view(batch, heads, seq_len).cuda()
            lengths = torch.ones((batch, ), dtype=torch.int32).cuda() * seq_len
            distances = torch.zeros((batch, heads, seq_len),
                                    dtype=torch.int32).cuda()

            cluster(
                hashes, lengths,
                distances=distances,
                clusters=n_clusters,
                iterations=n_iterations,
                bits=bits,
            )

            total_distance = distances.data.cpu().numpy().sum()
            self.assertEqual(total_distance, 0)
Exemplo n.º 3
0
    def test_benchmark_clustering(self):
        """Benchmark ``cluster`` on CUDA for hash widths from 10 to 63 bits.

        For each bit width:
        - random hashes of shape (N, H, L) and random per-sample lengths in
          [1, L] are clustered into ``k`` clusters using preallocated CUDA
          buffers;
        - 500 warm-up calls are made, then a single call is timed with CUDA
          events;
        - the elapsed time in milliseconds is printed.
        """
        N = 12
        H = 4
        L = 1000
        E = 32

        k = 100
        n_iterations = 10

        n_points = L * N * H
        for n_buckets in range(10, 64):
            hashes = generate_hash(n_points, E, n_buckets).view(N, H, L).cuda()
            groups = torch.zeros((N, H, L), dtype=torch.int32).cuda()
            counts = torch.zeros((N, H, k), dtype=torch.int32).cuda()
            centroids = torch.zeros((N, H, k), dtype=torch.int64).cuda()
            distances = torch.zeros((N, H, L), dtype=torch.int32).cuda()
            cluster_bit_counts = torch.zeros((N, H, k, n_buckets),
                                             dtype=torch.int32).cuda()
            sequence_lengths = torch.ones((N, ), dtype=torch.int32).cuda() * L
            # random_(1, L + 1) draws lengths uniformly from [1, L].
            sequence_lengths.random_(1, L + 1)

            # Shared keyword arguments for the warm-up and the timed call.
            cluster_kwargs = dict(
                groups=groups,
                counts=counts,
                centroids=centroids,
                distances=distances,
                bitcounts=cluster_bit_counts,
                iterations=n_iterations,
                bits=n_buckets,
            )

            # Warm-up so one-time costs presumably absorbed by early calls
            # do not distort the single timed call below.
            for _ in range(500):
                cluster(hashes, sequence_lengths, **cluster_kwargs)

            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            cluster(hashes, sequence_lengths, **cluster_kwargs)
            end_event.record()
            # Events are recorded asynchronously; wait for them before reading.
            torch.cuda.synchronize()
            t_clustering = start_event.elapsed_time(end_event)  # milliseconds

            print("Clustering with {} bits took {} time".format(
                n_buckets, t_clustering))
Exemplo n.º 4
0
    def test_clustering(self):
        """Cross-check ``cluster`` outputs with ``verify_distances``.

        Over 50 rounds, fresh random 31-bit hashes of shape (N, H, L) are
        clustered into ``k`` clusters, writing into the same preallocated
        CUDA buffers each round. The groups, centroids and distances are then
        flattened to (N * H, ...) NumPy arrays and handed to
        ``verify_distances`` together with per-row lengths.
        """
        N, H, L, E = 50, 4, 100, 32
        k, n_buckets, n_iterations = 20, 31, 10
        n_points = L * N * H
        rows = N * H

        def cuda_zeros(shape, dtype):
            # Zero-filled buffer created on the host, then moved to the GPU.
            return torch.zeros(shape, dtype=dtype).cuda()

        groups = cuda_zeros((N, H, L), torch.int32)
        counts = cuda_zeros((N, H, k), torch.int32)
        centroids = cuda_zeros((N, H, k), torch.int64)
        distances = cuda_zeros((N, H, L), torch.int32)
        cluster_bit_counts = cuda_zeros((N, H, k, n_buckets), torch.int32)
        sequence_lengths = torch.ones((N, ), dtype=torch.int32).cuda() * L

        for _ in range(50):
            hashes = generate_hash(n_points, E, n_buckets).view(N, H, L).cuda()

            cluster(
                hashes,
                sequence_lengths,
                groups=groups,
                counts=counts,
                centroids=centroids,
                distances=distances,
                bitcounts=cluster_bit_counts,
                iterations=n_iterations,
                bits=n_buckets,
            )

            # One length entry per (sample, head) row of the flattened arrays.
            row_lengths = sequence_lengths.repeat_interleave(H).cpu().numpy()
            verify_distances(
                hashes.view(rows, L).cpu().numpy(),
                groups.view(rows, L).cpu().numpy(),
                centroids.view(rows, k).cpu().numpy(),
                distances.view(rows, L).cpu().numpy(),
                n_buckets,
                row_lengths,
            )
Exemplo n.º 5
0
    def test_masked_clustering_convergence(self):
        """Check clustering convergence for sequences with random lengths.

        For each bit width from 1 to 9:
        - ``k = 2**n_buckets`` clusters are used with a padded length
          ``L = k + 1``;
        - every sample gets a random length in [1, L];
        - after clustering, each sample's distances from index ``length - 1``
          onward are zeroed, and the remaining distances must sum to zero.
        """
        N = 50
        H = 4
        E = 32

        n_iterations = 10

        for n_buckets in range(1, 10):
            print('Testing convergence for {} bits'.format(n_buckets))
            k = 2**n_buckets
            L = k + 1
            n_points = L * N * H

            hashes = generate_hash(n_points, E, n_buckets).view(N, H, L).cuda()

            groups = torch.zeros((N, H, L), dtype=torch.int32).cuda()
            counts = torch.zeros((N, H, k), dtype=torch.int32).cuda()
            centroids = torch.zeros((N, H, k), dtype=torch.int64).cuda()
            distances = torch.zeros((N, H, L), dtype=torch.int32).cuda()
            cluster_bit_counts = torch.zeros((N, H, k, n_buckets),
                                             dtype=torch.int32).cuda()
            sequence_lengths = torch.ones((N, ), dtype=torch.int32).cuda() * L
            # random_(L) draws from [0, L - 1]; shifting by one gives [1, L].
            sequence_lengths.random_(L)
            sequence_lengths += 1

            cluster(hashes,
                    sequence_lengths,
                    groups=groups,
                    counts=counts,
                    centroids=centroids,
                    distances=distances,
                    bitcounts=cluster_bit_counts,
                    iterations=n_iterations,
                    bits=n_buckets)

            lengths_np = sequence_lengths.cpu().numpy()
            distances_np = distances.cpu().numpy()
            # Exclude each sample's positions from index ``length - 1`` onward.
            # NOTE(review): this also drops the last valid position
            # (index length - 1); confirm whether ``length`` was intended.
            for n, length in enumerate(lengths_np):
                distances_np[n, :, length - 1:] = 0
            # Use the TestCase assertion (consistent with the other tests):
            # a bare ``assert`` is stripped under ``python -O``.
            self.assertEqual(distances_np.sum(), 0)
    def test_benchmark_clustering(self):
        """Benchmark 50 ``cluster`` calls for hash widths from 10 to 63 bits.

        Tensors are created without ``.cuda()``, so they live on the default
        device. For each bit width, random hashes of shape (N, H, L) and
        random per-sample lengths in [1, L] are clustered into ``k`` clusters
        50 times. The total elapsed time in seconds is printed.
        """
        N = 12
        H = 4
        L = 1000
        E = 32

        k = 100
        n_iterations = 10

        n_points = L * N * H
        for n_buckets in range(10, 64):
            hashes = generate_hash(n_points, E, n_buckets).view(N, H, L)
            groups = torch.zeros((N, H, L), dtype=torch.int32)
            counts = torch.zeros((N, H, k), dtype=torch.int32)
            centroids = torch.zeros((N, H, k), dtype=torch.int64)
            distances = torch.zeros((N, H, L), dtype=torch.int32)
            cluster_bit_counts = torch.zeros((N, H, k, n_buckets),
                                             dtype=torch.int32)
            sequence_lengths = torch.ones((N, ), dtype=torch.int32) * L
            # random_(1, L + 1) draws lengths uniformly from [1, L].
            sequence_lengths.random_(1, L + 1)

            cluster_kwargs = dict(
                groups=groups,
                counts=counts,
                centroids=centroids,
                distances=distances,
                bitcounts=cluster_bit_counts,
                iterations=n_iterations,
                bits=n_buckets,
            )

            # perf_counter is monotonic and high-resolution, unlike time.time,
            # which can jump when the system clock is adjusted.
            s = time.perf_counter()
            for _ in range(50):
                cluster(hashes, sequence_lengths, **cluster_kwargs)
            t_clustering = time.perf_counter() - s

            print("Clustering with {} bits took {} time".format(
                n_buckets, t_clustering))
Exemplo n.º 7
0
def cluster_queries(Q, query_lengths, C, I, B):
    """Hash the queries against random planes and cluster the hashes.

    Args:
        Q: Query tensor unpacked as ``N, H, L, E = Q.shape``. It may be
            non-contiguous; it is reshaped to (N * H * L, E) before hashing.
        query_lengths: Per-sample lengths forwarded to ``cluster``.
        C: Number of clusters (``clusters=`` of ``cluster``).
        I: Number of clustering iterations (``iterations=`` of ``cluster``).
        B: Number of hash bits; also the number of planes generated.

    Returns:
        The ``(groups, counts)`` pair produced by ``cluster``.
    """
    N, H, L, E = Q.shape
    # B planes with E + 1 coefficients each, allocated like Q (same
    # dtype/device) and filled in place by ``normal_``. The last column is
    # zeroed. NOTE(review): presumably this disables an offset term in
    # compute_hashes; confirm.
    planes = Q.new_empty((B, E + 1))
    normal_(planes)
    planes[:, -1] = 0
    # ``reshape`` (unlike ``view``) also accepts non-contiguous Q, e.g. after
    # a transpose. For contiguous tensors it returns a view, exactly as before.
    hashes = compute_hashes(Q.reshape(N * H * L, E), planes).reshape(N, H, L)
    # Cluster the hashes and return the cluster index per query
    groups, counts = cluster(hashes,
                             query_lengths,
                             clusters=C,
                             iterations=I,
                             bits=B)

    return groups, counts