Exemplo n.º 1
0
    def test_masked_simple_grad(self):
        """Compare sparse_product with a dense, masked reference.

        For 30 randomly drawn problem sizes the dense product of Q and K is
        gathered at the per-query top-k key indices and multiplied by
        ``query_lengths.float_matrix``.  The values, and the gradients with
        respect to Q and K, must match what ``sparse_product`` produces.

        Set the ``VERBOSE_TESTS`` environment variable to print the sizes
        drawn for every trial.
        """
        N = 4
        H = 2
        I = 5
        B = 16

        for _ in range(30):
            # L is drawn no smaller than C; the other sizes are independent.
            C = np.random.randint(10, 500)
            L = np.random.randint(C, 2000)
            E = np.random.randint(10, 128)
            S = np.random.randint(100, 1000)
            k = np.random.randint(10, 64)

            # Same switch as the other tests in this file (the variable name
            # used to be misspelled here, so the message never printed).
            if os.getenv("VERBOSE_TESTS", ""):
                print(("Testing Masked: N H L S E C k: "
                       "{} {} {} {} {} {} {}").format(N, H, L, S, E, C, k))

            Q = torch.randn(N, H, L, E).to(self.device).requires_grad_(True)
            K = torch.randn(N, H, S, E).to(self.device).requires_grad_(True)

            # Every sequence gets a random length in [C, L].
            lengths = np.random.randint(C, L + 1, N)
            lengths = torch.tensor(lengths, dtype=torch.int32).to(self.device)
            query_lengths = LengthMask(lengths, max_len=L)
            groups, counts = cluster_queries(Q, lengths, C, I, B)
            Q_grouped = aggregate(Q, groups, 1 / counts.float())
            QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
            _, topk = torch.topk(QK, k, dim=-1)
            topk = topk.contiguous()
            # Expand the per-group top-k indices to one row per query.
            topk_broadcast = broadcast(
                topk.float(), groups,
                torch.ones_like(counts, dtype=torch.float32),
                torch.zeros((N, H, L, k), device=Q.device))

            # Reference: dense product gathered at the selected keys.
            self._zero_grad(Q, K)
            QK_full = torch.einsum("nhle,nhse->nhls", Q, K)
            QK_selected = QK_full[
                torch.arange(N).view(N, 1, 1, 1).to(self.device),
                torch.arange(H).view(1, H, 1, 1).to(self.device),
                torch.arange(L).view(1, 1, L, 1).to(self.device),
                topk_broadcast.long()]

            # presumably float_matrix is 1 for valid queries and 0 for
            # padding -- TODO confirm against LengthMask
            mask = query_lengths.float_matrix[:, None, :, None]
            QK_selected = QK_selected * mask
            QK_selected.sum().backward()
            grad = [torch.clone(Q.grad), torch.clone(K.grad)]

            # Implementation under test.
            self._zero_grad(Q, K)
            QK_selected_hat = sparse_product(Q, K, groups, topk, counts,
                                             lengths)
            QK_selected_hat.sum().backward()
            grad_hat = [torch.clone(Q.grad), torch.clone(K.grad)]

            self.assertLess(
                torch.abs(QK_selected - QK_selected_hat).max(), 1e-4)
            for g1, g2 in zip(grad, grad_hat):
                self.assertLess(torch.abs(g1 - g2).max(), 1e-3)
Exemplo n.º 2
0
    def test_broadcast(self):
        """Check that clustered_broadcast agrees with the generic broadcast.

        Attention outputs computed per query group are expanded back to one
        row per query by both routines, on a batch in which four sequences
        are shorter than the maximum length.  The two results must agree to
        within 1e-4.  Every tensor is created with ``.cuda()``.
        """
        N, H = 10, 3
        L, S, E = 500, 500, 32
        C, I, B = 4, 5, 16

        Q = torch.randn(N, H, L, E).cuda()
        lengths = torch.full((N, ), L, dtype=torch.int32).cuda()
        # Shorten a few sequences so that the lengths are not all equal.
        for row, n_valid in ((1, 400), (3, 200), (7, 450), (8, 150)):
            lengths[row] = n_valid

        groups, counts = cluster_queries(Q, lengths, C, I, B)
        inv_counts = 1 / counts.float()
        grouped_Q = aggregate(Q, groups, inv_counts)
        K = torch.randn(N, H, S, E).cuda()
        scores = torch.einsum("nhle,nhse->nhls", grouped_Q, K)

        V = torch.randn(N, H, S, E).cuda()
        attention = F.softmax(scores, dim=-1)
        V_new = torch.einsum("nhls,nhse->nhle", attention, V)

        # Routine under test writes into a preallocated buffer.
        out_buffer = torch.zeros((N, H, L, E), dtype=V_new.dtype).cuda()
        clustered = clustered_broadcast(V_new, groups, counts, lengths,
                                        out_buffer)

        # Reference expansion with unit weights.
        unit_weights = torch.ones_like(counts, dtype=torch.float32)
        reference = broadcast(V_new, groups, unit_weights,
                              torch.zeros((N, H, L, E), device=Q.device))

        max_err = torch.max(torch.abs(reference - clustered))
        self.assertLess(max_err, 1e-4)
Exemplo n.º 3
0
    def test_forward(self):
        """Check clustered_sparse_weighted_average against a dense gather.

        The reference gathers the value vectors at the per-query top-k key
        indices, scales them by ``weights`` and sums over the k selected
        keys.  The clustered implementation must match it to within 1e-4.
        """
        N, H = 5, 2
        L, S, E = 100, 100, 32
        C, I, B = 10, 10, 32
        k = 5

        Q = torch.randn(N, H, L, E).to(self.device)
        K = torch.randn(N, H, S, E).to(self.device)
        lengths = torch.full((N, ), L, dtype=torch.int32).to(self.device)
        groups, counts = cluster_queries(Q, lengths, C, I, B)
        inv_counts = 1 / counts.float()
        grouped_Q = aggregate(Q, groups, inv_counts)
        scores = torch.einsum("nhle,nhse->nhls", grouped_Q, K)
        topk = torch.topk(scores, k, dim=-1)[1]
        topk = topk.contiguous()
        # Expand the per-group top-k indices to one row per query.
        per_query_topk = broadcast(
            topk.float(),
            groups,
            torch.ones_like(counts, dtype=torch.float32),
            torch.zeros((N, H, L, k), device=Q.device),
        )

        weights = torch.rand(N, H, L, k).to(self.device).requires_grad_(True)
        values = torch.randn(N, H, S, E).to(self.device).requires_grad_(True)

        # Dense reference: pick the selected value rows, then weight and sum.
        batch_idx = torch.arange(N).view(N, 1, 1, 1).to(self.device)
        head_idx = torch.arange(H).view(1, H, 1, 1).to(self.device)
        gathered = values[batch_idx, head_idx, per_query_topk.long()]
        expected = (weights.unsqueeze(-1) * gathered).sum(-2)

        actual = clustered_sparse_weighted_average(weights, values, topk,
                                                   groups)
        self.assertLess(torch.abs(expected - actual).max(), 1e-4)
Exemplo n.º 4
0
    def test_simple_product(self):
        """Check clustered_sparse_dot_product against a dense gather.

        The dense product of Q and K is gathered at the per-query top-k key
        indices and compared with the output of the clustered kernel; the
        largest absolute difference must stay below 1e-4.
        """
        N = 2
        H = 2
        L = 1000
        E = 32
        S = 1000
        k = 32
        C = 50
        I = 5
        B = 16

        Q = torch.randn(N, H, L, E).to(self.device)
        K = torch.randn(N, H, S, E).to(self.device)
        lengths = torch.full((N, ), L).int().to(self.device)
        groups, counts = cluster_queries(Q, lengths, C, I, B)
        Q_grouped = aggregate(Q, groups, 1 / counts.float())
        QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
        _, topk = torch.topk(QK, k, dim=-1)
        topk = topk.contiguous()
        products = clustered_sparse_dot_product(Q, K, topk, groups, counts,
                                                lengths)
        # Expand the per-group top-k indices to one row per query.
        topk_broadcast = broadcast(
            topk.float(), groups, torch.ones_like(counts, dtype=torch.float32),
            torch.zeros((N, H, L, k), device=Q.device))

        all_products = torch.einsum("nhle,nhse->nhls", Q, K)
        # Keep the index tensors on the same device as the data, as the
        # gradient tests in this file do.
        products_2 = all_products[
            torch.arange(N).view(N, 1, 1, 1).to(self.device),
            torch.arange(H).view(1, H, 1, 1).to(self.device),
            torch.arange(L).view(1, 1, L, 1).to(self.device),
            topk_broadcast.long()]

        self.assertLess(torch.max(torch.abs(products_2 - products)), 1e-4)
Exemplo n.º 5
0
    def test_simple_grad(self):
        """Check values and gradients of clustered_sparse_dot_product.

        A dense product gathered at the per-query top-k key indices serves
        as the reference.  Outputs and the gradients with respect to Q and K
        must agree with the clustered kernel to within 1e-4.
        """
        N, H = 2, 2
        L, S, E = 1000, 1000, 32
        k = 32
        C, I, B = 50, 5, 16

        Q = torch.randn(N, H, L, E).to(self.device).requires_grad_(True)
        K = torch.randn(N, H, S, E).to(self.device).requires_grad_(True)

        lengths = torch.full((N,), L).int().to(self.device)
        groups, counts = cluster_queries(Q, lengths, C, I, B)
        inv_counts = 1/counts.float()
        grouped_Q = aggregate(Q, groups, inv_counts)
        scores = torch.einsum("nhle,nhse->nhls", grouped_Q, K)
        topk = torch.topk(scores, k, dim=-1)[1]
        topk = topk.contiguous()
        # Expand the per-group top-k indices to one row per query.
        per_query_topk = broadcast(
            topk.float(), groups,
            torch.ones_like(counts, dtype=torch.float32),
            torch.zeros((N, H, L, k), device=Q.device))

        # Reference pass: dense product, then gather the selected entries.
        self._zero_grad(Q, K)
        dense = torch.einsum("nhle,nhse->nhls", Q, K)
        batch_idx = torch.arange(N).view(N, 1, 1, 1).to(self.device)
        head_idx = torch.arange(H).view(1, H, 1, 1).to(self.device)
        query_idx = torch.arange(L).view(1, 1, L, 1).to(self.device)
        expected = dense[batch_idx, head_idx, query_idx,
                         per_query_topk.long()]
        expected.sum().backward()
        expected_grads = [torch.clone(Q.grad), torch.clone(K.grad)]

        # Pass through the kernel under test.
        self._zero_grad(Q, K)
        actual = clustered_sparse_dot_product(Q, K, topk, groups, counts,
                                              lengths)
        actual.sum().backward()
        actual_grads = [torch.clone(Q.grad), torch.clone(K.grad)]

        self.assertLess(torch.abs(expected - actual).max(), 1e-4)
        for want, got in zip(expected_grads, actual_grads):
            self.assertLess(torch.abs(want - got).max(), 1e-4)
Exemplo n.º 6
0
    def test_small_forward_backward(self):
        """Time forward+backward of clustered_sparse_weighted_average.

        Runs 20 forward/backward passes and prints the mean wall-clock time
        per pass.  Nothing is asserted; this is a benchmark only.
        """
        N, H = 12, 8
        L, S, E = 2000, 2000, 32
        k = 32
        C, I, B = 100, 10, 32

        Q = torch.randn(N, H, L, E).to(self.device)
        K = torch.randn(N, H, S, E).to(self.device)
        lengths = torch.full((N, ), L, dtype=torch.int32).to(self.device)
        groups, counts = cluster_queries(Q, lengths, C, I, B)
        inv_counts = 1 / counts.float()
        grouped_Q = aggregate(Q, groups, inv_counts)
        scores = torch.einsum("nhle,nhse->nhls", grouped_Q, K)
        topk = torch.topk(scores, k, dim=-1)[1]
        topk = topk.contiguous()
        # Computed as in the correctness tests; not used by the timing below.
        topk_broadcast = broadcast(
            topk.float(),
            groups,
            torch.ones_like(counts, dtype=torch.float32),
            torch.zeros((N, H, L, k), device=Q.device),
        )
        weights = torch.rand(N, H, L, k).to(self.device).requires_grad_(True)
        values = torch.randn(N, H, S, E).to(self.device).requires_grad_(True)

        self._zero_grad(weights, values)
        n_runs = 20
        started = time.time()
        for _ in range(n_runs):
            output_hat = clustered_sparse_weighted_average(
                weights, values, topk, groups, counts)
            output_hat.sum().backward()
        finished = time.time()
        t_sparse = (finished - started) / n_runs
        print('T_sparse Forward Backward:{}'.format(t_sparse))
Exemplo n.º 7
0
    def test_benchmark_forward_backward(self):
        """Time one forward+backward pass: dense product vs clustered kernel.

        Each variant is warmed up with 2000 iterations and then a single
        pass is timed with CUDA events.  Both timings are printed; nothing
        is asserted.
        """
        N, H = 12, 8
        L, S, E = 1024, 1024, 32
        k = 32
        C, I, B = 100, 10, 32

        Q = torch.randn(N, H, L, E).to(self.device).requires_grad_(True)
        K = torch.randn(N, H, S, E).to(self.device).requires_grad_(True)
        lengths = torch.full((N,), L).int().to(self.device)
        groups, counts = cluster_queries(Q, lengths, C, I, B)
        inv_counts = 1/counts.float()
        grouped_Q = aggregate(Q, groups, inv_counts)
        QK = torch.einsum("nhle,nhse->nhls", grouped_Q, K)
        topk = torch.topk(QK, k, dim=-1)[1]
        topk = topk.contiguous()

        # Dense product: warm-up.
        self._zero_grad(Q, K)
        for _ in range(2000):
            QK = torch.einsum("nhle,nhse->nhls", Q, K)
            QK.sum().backward()
        self._zero_grad(Q, K)

        # Dense product: timed pass.
        start = torch.cuda.Event(enable_timing=True)
        stop = torch.cuda.Event(enable_timing=True)
        start.record()
        QK = torch.einsum("nhle,nhse->nhls", Q, K)
        QK.sum().backward()
        stop.record()
        torch.cuda.synchronize()
        t_full = start.elapsed_time(stop)

        # Clustered kernel: warm-up.
        self._zero_grad(Q, K)
        for _ in range(2000):
            QK = clustered_sparse_dot_product(Q, K, topk, groups, counts,
                                              lengths)
            QK.sum().backward()
        self._zero_grad(Q, K)

        # Clustered kernel: timed pass.
        start = torch.cuda.Event(enable_timing=True)
        stop = torch.cuda.Event(enable_timing=True)
        start.record()
        QK = clustered_sparse_dot_product(Q, K, topk, groups, counts,
                                          lengths)
        QK.sum().backward()
        stop.record()
        torch.cuda.synchronize()
        t_sparse = start.elapsed_time(stop)
        print("Benchmark Forward-Backward: T_Full: {}, T_Sparse: {}".format(t_full, t_sparse))
Exemplo n.º 8
0
    def test_difficult_grad(self):
        """Compare sparse_product with a dense reference on random sizes.

        For 30 randomly drawn problem sizes (all sequences at full length)
        the dense product of Q and K is gathered at the per-query top-k key
        indices.  Values must match ``sparse_product`` to within 1e-4 and
        the gradients with respect to Q and K to within 1e-3.

        Set the ``VERBOSE_TESTS`` environment variable to print the sizes
        drawn for every trial.
        """
        N = 12
        H = 5
        I = 5
        B = 16

        for _ in range(30):
            # L is drawn no smaller than C; the other sizes are independent.
            C = np.random.randint(10, 500)
            L = np.random.randint(C, 2000)
            E = np.random.randint(10, 128)
            S = np.random.randint(100, 1000)
            k = np.random.randint(10, 64)

            if os.getenv("VERBOSE_TESTS", ""):
                print(("Testing: N H L S E C k: "
                       "{} {} {} {} {} {} {}").format(N, H, L, S, E, C, k))

            Q = torch.randn(N, H, L, E).to(self.device).requires_grad_(True)
            K = torch.randn(N, H, S, E).to(self.device).requires_grad_(True)
            lengths = torch.full((N, ), L, dtype=torch.int32).to(self.device)
            groups, counts = cluster_queries(Q, lengths, C, I, B)
            Q_grouped = aggregate(Q, groups, 1 / counts.float())
            QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
            _, topk = torch.topk(QK, k, dim=-1)
            topk = topk.contiguous()
            # Expand the per-group top-k indices to one row per query.
            topk_broadcast = broadcast(
                topk.float(), groups,
                torch.ones_like(counts, dtype=torch.float32),
                torch.zeros((N, H, L, k), device=Q.device))

            # Reference: dense product gathered at the selected keys.
            self._zero_grad(Q, K)
            QK_full = torch.einsum("nhle,nhse->nhls", Q, K)
            QK_selected = QK_full[
                torch.arange(N).view(N, 1, 1, 1).to(self.device),
                torch.arange(H).view(1, H, 1, 1).to(self.device),
                torch.arange(L).view(1, 1, L, 1).to(self.device),
                topk_broadcast.long()]

            QK_selected.sum().backward()
            grad = [torch.clone(Q.grad), torch.clone(K.grad)]

            # Implementation under test.
            self._zero_grad(Q, K)
            QK_selected_hat = sparse_product(Q, K, groups, topk, counts,
                                             lengths)
            QK_selected_hat.sum().backward()
            grad_hat = [torch.clone(Q.grad), torch.clone(K.grad)]

            self.assertLess(
                torch.abs(QK_selected - QK_selected_hat).max(), 1e-4)
            for g1, g2 in zip(grad, grad_hat):
                self.assertLess(torch.abs(g1 - g2).max(), 1e-3)
Exemplo n.º 9
0
    def test_small_benchmark(self):
        """Time clustered_sparse_dot_product on group-sorted queries.

        Queries are reordered by ascending group id, the kernel is warmed up
        with 1000 calls and a single call is then timed with CUDA events.
        The timing is printed; nothing is asserted.
        """
        N, H = 12, 8
        L, S, E = 1024, 1024, 64
        k = 32
        C, I, B = 100, 10, 32

        Q = torch.randn(N, H, L, E).to(self.device)
        K = torch.randn(N, H, S, E).to(self.device)
        lengths = torch.full((N, ), L, dtype=torch.int32).to(self.device)
        groups, counts = cluster_queries(Q, lengths, C, I, B)

        # Sort queries by group id and keep the permutation that undoes it.
        sorted_g, sorted_gi = torch.sort(groups.view(N * H, -1), dim=-1)
        sorted_rev_gi = torch.argsort(sorted_gi, dim=-1)

        # Turn per-(batch, head) positions into row numbers of a flat view.
        q_offset = torch.arange(N * H, device=Q.device).unsqueeze(-1) * L
        q_flat = (sorted_gi + q_offset).reshape(-1)

        flat_Q = Q.reshape(-1, E)
        s_queries = flat_Q.index_select(0, q_flat).view(N, H, L, E)
        grouped_Q = aggregate(s_queries, sorted_g.view(N, H, L),
                              1 / counts.float())
        scores = torch.einsum("nhle,nhse->nhls", grouped_Q, K)
        topk = torch.topk(scores, k, dim=-1)[1]
        topk = topk.contiguous()

        # NOTE(review): the queries are sorted by group but the unsorted
        # `groups` tensor is handed to the kernel -- only timing is reported
        # here, confirm this is intended.
        products_sorted = clustered_sparse_dot_product(
            s_queries, K, topk, groups, counts, lengths)
        q_rev_flat = (sorted_rev_gi + q_offset).reshape(-1)
        products = products_sorted.reshape(-1, k).index_select(0, q_rev_flat)
        products = products.view(N, H, L, k)

        # Warm-up.
        for _ in range(1000):
            products_sorted = clustered_sparse_dot_product(
                s_queries, K, topk, groups, counts, lengths)

        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        stop = torch.cuda.Event(enable_timing=True)
        start.record()
        products_sorted = clustered_sparse_dot_product(
            s_queries, K, topk, groups, counts, lengths)
        stop.record()
        torch.cuda.synchronize()
        t_sc = start.elapsed_time(stop)

        # Restore the original query order (outside the timed region).
        products_sorted = products_sorted.reshape(-1, k)
        products_sorted = products_sorted.index_select(0, q_rev_flat)
        products_sorted = products_sorted.view(N, H, L, k)
        topk = topk.contiguous()
        print("Sparse_Clustered: {}".format(t_sc))
    def test_small_forward(self):
        """Time the forward pass of clustered_sparse_weighted_average.

        Attention weights come from clustered_sparse_dot_product on
        group-sorted queries.  The kernel is warmed up with 2000 calls and a
        single call is then timed with CUDA events.  The timing is printed;
        nothing is asserted.
        """
        N, H = 12, 8
        L, S, E = 2000, 2000, 32
        k = 32
        C, I, B = 100, 10, 32

        Q = torch.randn(N, H, L, E).to(self.device)
        K = torch.randn(N, H, S, E).to(self.device)
        lengths = torch.full((N, ), L, dtype=torch.int32).to(self.device)
        groups, counts = cluster_queries(Q, lengths, C, I, B)

        # Sort queries by group id and keep the permutation that undoes it.
        sorted_g, sorted_gi = torch.sort(groups.view(N * H, -1), dim=-1)
        sorted_rev_gi = torch.argsort(sorted_gi, dim=-1)

        # Turn per-(batch, head) positions into row numbers of a flat view.
        q_offset = torch.arange(N * H, device=Q.device).unsqueeze(-1) * L
        q_flat = (sorted_gi + q_offset).reshape(-1)
        flat_Q = Q.reshape(-1, E)
        s_queries = flat_Q.index_select(0, q_flat).view(N, H, L, E)
        grouped_Q = aggregate(s_queries, sorted_g.view(N, H, L),
                              1 / counts.float())

        scores = torch.einsum("nhle,nhse->nhls", grouped_Q, K)
        topk = torch.topk(scores, k, dim=-1)[1]
        topk = topk.contiguous()
        # Computed as in the correctness tests; not used by the timing below.
        topk_broadcast = broadcast(
            topk.float(),
            groups,
            torch.ones_like(counts, dtype=torch.float32),
            torch.zeros((N, H, L, k), device=Q.device),
        )

        weights_sorted = clustered_sparse_dot_product(
            s_queries, K, topk, groups, counts, lengths)
        q_rev_flat = (sorted_rev_gi + q_offset).reshape(-1)
        weights = weights_sorted.reshape(-1, k)
        weights = weights.index_select(0, q_rev_flat).view(N, H, L, k)

        values = torch.randn(N, H, S, E).to(self.device)
        # NOTE(review): the warm-up uses `weights_sorted` while the timed
        # call below uses `weights` -- confirm this is intended.
        for _ in range(2000):
            output_hat = clustered_sparse_weighted_average(
                weights_sorted, values, topk, groups, counts)

        start = torch.cuda.Event(enable_timing=True)
        stop = torch.cuda.Event(enable_timing=True)
        start.record()
        output_hat = clustered_sparse_weighted_average(
            weights, values, topk, groups, counts)
        stop.record()
        torch.cuda.synchronize()
        t_sparse = start.elapsed_time(stop)
        print('T_sparse Forward:{}'.format(t_sparse))
Exemplo n.º 11
0
    def test_small_benchmark(self):
        """Time three ways of computing query-key products.

        Prints the mean wall-clock time over 10 runs of the clustered sparse
        kernel, the plain sparse kernel fed with per-query top-k indices,
        and the full dense product.  Nothing is asserted.
        """
        N = 12
        H = 8
        L = 1000
        E = 32
        S = 1000
        k = 32
        C = 100
        I = 10
        B = 32

        Q = torch.randn(N, H, L, E).to(self.device)
        K = torch.randn(N, H, S, E).to(self.device)
        lengths = torch.full((N, ), L).int().to(self.device)
        groups, counts = cluster_queries(Q, lengths, C, I, B)
        Q_grouped = aggregate(Q, groups, 1 / counts.float())
        QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
        _, topk = torch.topk(QK, k, dim=-1)
        topk = topk.contiguous()
        # One untimed call before measuring (the kernel returns its own
        # output, so no buffer needs to be allocated beforehand).
        products = clustered_sparse_dot_product(Q, K, topk, groups, counts,
                                                lengths)

        n_runs = 10
        s = time.time()
        for _ in range(n_runs):
            products = clustered_sparse_dot_product(Q, K, topk, groups, counts,
                                                    lengths)
        e = time.time()
        t_sc = (e - s) / n_runs

        # Expand the per-group top-k indices to one row per query; this is
        # done outside the timed regions.
        topk_broadcast = broadcast(
            topk.float(), groups, torch.ones_like(counts, dtype=torch.float32),
            torch.zeros((N, H, L, k), device=Q.device))

        s = time.time()
        for _ in range(n_runs):
            products = sparse_dot_product(Q, K, topk_broadcast.long())
        e = time.time()
        t_s = (e - s) / n_runs

        s = time.time()
        for _ in range(n_runs):
            torch.einsum("nhle,nhse->nhls", Q, K)
        e = time.time()
        t_f = (e - s) / n_runs
        print("Sparse_Clustered: {}, Sparse: {}, Full: {}".format(
            t_sc, t_s, t_f))
Exemplo n.º 12
0
    def test_simple_product(self):
        """Check sparse_product on a batch with one short sequence.

        Over 20 trials the dense product of Q and K is gathered at the
        per-query top-k key indices and compared with ``sparse_product``.
        Only the first ``lengths[n]`` queries of every sequence are
        compared; the largest absolute difference must stay below 1e-4.

        Set the ``VERBOSE_TESTS`` environment variable to print the sizes
        used in every trial.
        """
        N = 2
        H = 2
        L = 100
        S = 50
        # k and E are pinned to 32 for every trial.
        k = 32
        E = 32
        C = 5
        I = 5
        B = 16

        for _ in range(20):
            if os.getenv("VERBOSE_TESTS", ""):
                print(("Testing: N H L S E C k: "
                       "{} {} {} {} {} {} {}").format(N, H, L, S, E, C, k))

            Q = torch.randn(N, H, L, E).to(self.device)
            K = torch.randn(N, H, S, E).to(self.device)
            lengths = torch.full((N, ), L, dtype=torch.int32).to(self.device)
            # The second sequence is much shorter than the maximum length.
            lengths[1] = 10
            groups, counts = cluster_queries(Q, lengths, C, I, B)
            Q_grouped = aggregate(Q, groups, 1 / counts.float())
            QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
            _, topk = torch.topk(QK, k, sorted=False, dim=-1)
            topk = topk.contiguous()
            products, _ = sparse_product(Q, K, groups, topk,
                                         counts, lengths, k,
                                         Q_grouped)
            # Expand the per-group top-k indices to one row per query.
            topk_broadcast = broadcast(
                topk.float(), groups,
                torch.ones_like(counts, dtype=torch.float32),
                torch.zeros((N, H, L, k), device=Q.device))

            all_products = torch.einsum("nhle,nhse->nhls", Q, K)
            # Keep the index tensors on the same device as the data.
            products_2 = all_products[
                torch.arange(N).view(N, 1, 1, 1).to(self.device),
                torch.arange(H).view(1, H, 1, 1).to(self.device),
                torch.arange(L).view(1, 1, L, 1).to(self.device),
                topk_broadcast.long()]
            # Compare only the valid (non-padded) queries of each sequence.
            for n in range(N):
                p_1 = products[n, :, :lengths[n], :]
                p_2 = products_2[n, :, :lengths[n], :]
                self.assertLess(torch.max(torch.abs(p_2 - p_1)), 1e-4)
    def test_broadcast_difficult(self):
        """Compare try_sorted_broadcast with broadcast on random CPU inputs.

        For 20 randomly drawn problem sizes and per-sequence lengths, the
        attention outputs computed per query group are expanded to one row
        per query by both routines; the results must agree to within 1e-4.

        Set the ``VERBOSE_TESTS`` environment variable to print the sizes
        drawn for every trial.
        """
        N, H = 10, 3
        I, B = 5, 16

        for _ in range(20):
            S = np.random.randint(100, 1000)
            C = np.random.randint(10, 500)
            L = np.random.randint(C, 2000)
            E = np.random.randint(10, 160)
            # Every sequence gets a random length in [C, L].
            raw_lengths = np.random.randint(C, L+1, N)
            lengths = torch.tensor(raw_lengths, dtype=torch.int32).cpu()
            if os.getenv("VERBOSE_TESTS", ""):
                print(("Test: N H L S E C: "
                       "{} {} {} {} {} {}").format(N, H, L, S, E, C))

            Q = torch.randn(N, H, L, E).cpu()
            groups, counts = cluster_queries(Q, lengths, C, I, B)
            Q_grouped = aggregate(Q, groups, 1/counts.float())
            K = torch.randn(N, H, S, E).cpu()
            scores = torch.einsum("nhle,nhse->nhls", Q_grouped, K)

            V = torch.randn(N, H, S, E).cpu()
            attention = F.softmax(scores, dim=-1)
            V_new = torch.einsum("nhls,nhse->nhle", attention, V)

            # Reference expansion with unit weights.
            unit_weights = torch.ones_like(counts, dtype=torch.float32)
            reference = broadcast(
                V_new, groups, unit_weights,
                torch.zeros((N, H, L, E), device=Q.device))

            candidate = try_sorted_broadcast(
                Q, K, V, groups, counts, lengths, Q_grouped)

            max_err = torch.max(torch.abs(reference - candidate))
            self.assertLess(max_err, 1e-4)
    def test_benchmark_forward_backward(self):
        """Time forward+backward: dense product vs clustered kernel.

        Each variant runs 10 forward/backward passes and the mean wall-clock
        time per pass is printed.  Nothing is asserted.
        """
        N, H = 12, 8
        L, S, E = 1024, 1024, 32
        k = 32
        C, I, B = 100, 10, 32

        Q = torch.randn(N, H, L, E).to(self.device).requires_grad_(True)
        K = torch.randn(N, H, S, E).to(self.device).requires_grad_(True)
        lengths = torch.full((N,), L, dtype=torch.int32).to(self.device)
        groups, counts = cluster_queries(Q, lengths, C, I, B)
        inv_counts = 1 / counts.float()
        grouped_Q = aggregate(Q, groups, inv_counts)
        QK = torch.einsum("nhle,nhse->nhls", grouped_Q, K)
        topk = torch.topk(QK, k, dim=-1)[1]
        topk = topk.contiguous()

        n_runs = 10

        # Dense product.
        self._zero_grad(Q, K)
        started = time.time()
        for _ in range(n_runs):
            QK = torch.einsum("nhle,nhse->nhls", Q, K)
            QK.sum().backward()
        finished = time.time()
        t_full = (finished - started) / n_runs

        # Clustered sparse kernel.
        self._zero_grad(Q, K)
        started = time.time()
        for _ in range(n_runs):
            QK = clustered_sparse_dot_product(Q, K, topk, groups, counts,
                                              lengths)
            QK.sum().backward()
        finished = time.time()
        t_sparse = (finished - started) / n_runs
        print("Benchmark Forward-Backward: T_Full: {}, T_Sparse: {}".format(t_full, t_sparse))
Exemplo n.º 15
0
    def test_difficult_product(self):
        """Check sparse_product against a dense gather on random sizes.

        For 30 randomly drawn problem sizes (all sequences at full length)
        the dense product of Q and K is gathered at the per-query top-k key
        indices and compared with ``sparse_product``; the largest absolute
        difference must stay below 1e-4.

        Set the ``VERBOSE_TESTS`` environment variable to print the sizes
        drawn for every trial.
        """
        N = 12
        H = 5
        I = 5
        B = 16

        for _ in range(30):
            # L is drawn no smaller than C; the other sizes are independent.
            C = np.random.randint(10, 500)
            L = np.random.randint(C, 2000)
            E = np.random.randint(10, 128)
            S = np.random.randint(100, 1000)
            k = np.random.randint(10, 64)

            if os.getenv("VERBOSE_TESTS", ""):
                print("Testing: N H L S E C k: {} {} {} {} {} {} {}".format(
                    N, H, L, S, E, C, k))
            Q = torch.randn(N, H, L, E).to(self.device)
            K = torch.randn(N, H, S, E).to(self.device)
            lengths = torch.full((N, ), L, dtype=torch.int32).to(self.device)
            groups, counts = cluster_queries(Q, lengths, C, I, B)

            Q_grouped = aggregate(Q, groups, 1 / counts.float())
            QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
            _, topk = torch.topk(QK, k, dim=-1)
            topk = topk.contiguous()

            products, _ = sparse_product(Q, K, groups, topk, counts, lengths,
                                         k, Q_grouped)
            # Expand the per-group top-k indices to one row per query.
            topk_broadcast = broadcast(
                topk.float(), groups,
                torch.ones_like(counts, dtype=torch.float32),
                torch.zeros((N, H, L, k), device=Q.device))

            all_products = torch.einsum("nhle,nhse->nhls", Q, K)
            # Keep the index tensors on the same device as the data, as the
            # gradient tests in this file do.
            products_2 = all_products[
                torch.arange(N).view(N, 1, 1, 1).to(self.device),
                torch.arange(H).view(1, H, 1, 1).to(self.device),
                torch.arange(L).view(1, 1, L, 1).to(self.device),
                topk_broadcast.long()]
            self.assertLess(torch.max(torch.abs(products_2 - products)), 1e-4)
    def test_broadcast_full(self):
        """Compare try_sorted_broadcast with broadcast on fixed-size inputs.

        Runs 50 trials on CPU tensors of fixed size; in each one the first
        two sequences get a random length in [C, L].  The attention outputs
        computed per query group are expanded to one row per query by both
        routines and must agree to within 1e-4.
        """
        N, H = 3, 2
        L, S, E = 400, 100, 256
        C, I, B = 211, 5, 16

        for _ in range(50):
            Q = torch.randn(N, H, L, E).cpu()
            lengths = torch.full((N,), L, dtype=torch.int32).cpu()
            for row in (0, 1):
                lengths[row] = np.random.randint(C, L+1)
            groups, counts = cluster_queries(Q, lengths, C, I, B)
            Q_grouped = aggregate(Q, groups, 1/counts.float())
            K = torch.randn(N, H, S, E).cpu()
            scores = torch.einsum("nhle,nhse->nhls", Q_grouped, K)

            V = torch.randn(N, H, S, E).cpu()
            attention = F.softmax(scores, dim=-1)
            V_new = torch.einsum("nhls,nhse->nhle", attention, V)

            # Reference expansion with unit weights.
            unit_weights = torch.ones_like(counts, dtype=torch.float32)
            reference = broadcast(
                V_new, groups, unit_weights,
                torch.zeros((N, H, L, E), device=Q.device))

            candidate = try_sorted_broadcast(
                Q, K, V, groups, counts, lengths, Q_grouped)

            max_err = torch.max(torch.abs(reference - candidate))
            self.assertLess(max_err, 1e-4)
def sparse_product(Q, K, groups, topk, counts, lengths):
    """Run ``clustered_sparse_dot_product`` on group-sorted queries.

    The queries of every (batch, head) pair are reordered so that their
    group ids are ascending, the clustered kernel is called on the reordered
    queries, and the rows of the result are permuted back so that they line
    up with the original query order.

    Arguments
    ---------
        Q: Tensor of shape (N, H, L, E) holding the queries.
        K: Keys, forwarded unchanged to the kernel.
        groups: Group id of every query; must be viewable as (N*H, L).
        topk: 4-dimensional tensor whose last dimension is k; forwarded to
              the kernel after being made contiguous.
        counts: Forwarded unchanged to the kernel.
        lengths: Forwarded unchanged to the kernel.

    Returns
    -------
        Tensor of shape (N, H, L, k) in the original query order.
    """
    N, H, L, E = Q.shape
    # Unpacking four values keeps the implicit "topk is 4-D" check.
    _, _, _, k = topk.shape

    # Permutation that sorts the queries of each (n, h) row by group id,
    # together with its inverse.
    sorted_g, sorted_gi = torch.sort(groups.view(N*H, -1), dim=-1)
    sorted_rev_gi = torch.argsort(sorted_gi, dim=-1)

    # Turn per-row indices into indices of the flattened (N*H*L, ...) view.
    q_offset = torch.arange(N*H, device=Q.device).unsqueeze(-1) * L
    q_flat = (sorted_gi + q_offset).reshape(-1)
    q_rev_flat = (sorted_rev_gi + q_offset).reshape(-1)

    s_queries = Q.reshape(-1, E).index_select(0, q_flat).view(N, H, L, E)

    products_sorted = clustered_sparse_dot_product(
        s_queries, K, topk.contiguous(), sorted_g.view(N, H, L), counts,
        lengths
    )

    # Undo the sort so that row i corresponds to query i again.
    return products_sorted.reshape(-1, k).index_select(
        0, q_rev_flat).view(N, H, L, k)
Exemplo n.º 18
0
    def test_small_forward(self):
        """Benchmark one forward call of clustered_sparse_weighted_average.

        Builds random queries/keys on ``self.device``, clusters the queries,
        selects the top-k keys per group, and then times a single call of
        ``clustered_sparse_weighted_average`` with CUDA events after 2000
        untimed warm-up calls.  The elapsed time is printed; nothing is
        asserted.
        """
        N = 12
        H = 8
        L = 2000
        S = 2000
        E = 32
        k = 32
        C = 100
        I = 10
        B = 32

        Q = torch.randn(N, H, L, E).to(self.device)
        K = torch.randn(N, H, S, E).to(self.device)
        lengths = torch.full((N, ), L, dtype=torch.int32).to(self.device)
        groups, counts = cluster_queries(Q, lengths, C, I, B)
        Q_grouped = aggregate(Q, groups, 1 / counts.float())
        QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
        _, topk = torch.topk(QK, k, dim=-1)
        topk = topk.contiguous()
        # No per-query broadcast of topk is built here: the benchmarked
        # kernel takes the per-group topk directly.

        weights = torch.rand(N, H, L, k).to(self.device).requires_grad_(True)
        values = torch.randn(N, H, S, E).to(self.device).requires_grad_(True)

        # Untimed warm-up.
        for _ in range(2000):
            clustered_sparse_weighted_average(weights, values, topk, groups)

        # Events are recorded on the CUDA stream, so the elapsed time is
        # only valid after the synchronize below.
        s = torch.cuda.Event(enable_timing=True)
        e = torch.cuda.Event(enable_timing=True)
        s.record()
        clustered_sparse_weighted_average(weights, values, topk, groups)
        e.record()
        torch.cuda.synchronize()
        t_sparse = s.elapsed_time(e)
        print('T_sparse Forward:{}'.format(t_sparse))
    def test_broadcast_benchmark(self):
        """Time the sorted clustered broadcast against the plain broadcast.

        The "sorted" pass calls ``clustered_broadcast`` with group ids sorted
        per (batch, head) row and then gathers the rows back into the
        original query order.  The reference pass calls ``broadcast`` with
        the unsorted group ids.  Each pass is warmed up and then timed once
        with CUDA events; both times are printed and nothing is asserted.
        """
        N, H, L, S = 12, 8, 1000, 1000
        E, D = 64, 64
        C, I, B = 200, 5, 63

        queries = torch.randn(N, H, L, E).cuda()
        lengths = torch.full((N,), L, dtype=torch.int32).cuda()
        groups, counts = cluster_queries(queries, lengths, C, I, B)
        sorted_groups, sort_idx = torch.sort(groups.view(N*H, -1), dim=-1)
        unsort_idx = torch.argsort(sort_idx, dim=-1)

        # Offsets turning per-row indices into flat (N*H*L) row indices.
        row_offset = torch.arange(N*H, device=queries.device).unsqueeze(-1) * L
        q_flat = (sort_idx + row_offset).reshape(-1)

        grouped_queries = aggregate(queries, groups, 1/counts.float())
        keys = torch.randn(N, H, S, E).cuda()
        scores = torch.einsum("nhle,nhse->nhls", grouped_queries, keys)

        values = torch.randn(N, H, S, E).cuda()
        attention = F.softmax(scores, dim=-1)
        V_new = torch.einsum("nhls,nhse->nhle", attention, values)

        def sorted_pass(out_buffer):
            # One full sorted broadcast: kernel call plus un-sort gather.
            factors = torch.ones_like(counts, dtype=torch.float32)
            sorted_out = clustered_broadcast(
                V_new, sorted_groups.view(N, H, L), counts, factors,
                out_buffer
            )
            q_rev_flat = (unsort_idx + row_offset).reshape(-1)
            return sorted_out.reshape(-1, D).index_select(
                0, q_rev_flat).view(N, H, L, D)

        def reference_pass():
            return broadcast(
                V_new,
                groups,
                torch.ones_like(counts, dtype=torch.float32),
                torch.zeros((N, H, L, E), device=queries.device)
            )

        def timed(fn):
            # Returns (elapsed milliseconds, result of fn()).
            start = torch.cuda.Event(enable_timing=True)
            stop = torch.cuda.Event(enable_timing=True)
            start.record()
            result = fn()
            stop.record()
            torch.cuda.synchronize()
            return start.elapsed_time(stop), result

        # 1 initial + 2000 warm-up passes; every pass receives the previous
        # pass' output as its output buffer.
        V_broadcast = torch.zeros((N, H, L, E), dtype=V_new.dtype).cuda()
        for _ in range(2001):
            V_broadcast = sorted_pass(V_broadcast)

        t_broadcast, V_broadcast = timed(lambda: sorted_pass(V_broadcast))

        for _ in range(200):
            V_broadcast_2 = reference_pass()

        t_broadcast_2, V_broadcast_2 = timed(reference_pass)

        print("B1: {}, B2: {}".format(t_broadcast, t_broadcast_2))
Exemplo n.º 20
0
    def test_small_benchmark(self):
        """Benchmark three ways of computing query/key products.

        Times, with CUDA events and after 1000 untimed warm-up calls each:

        * ``clustered_sparse_dot_product`` with the per-group top-k,
        * ``sparse_dot_product`` with the top-k broadcast to every query,
        * the dense ``einsum`` over all query/key pairs.

        The three timings are printed; nothing is asserted.
        """
        N = 12
        H = 8
        L = 1000
        E = 32
        S = 1000
        k = 32
        C = 100
        I = 10
        B = 32

        Q = torch.randn(N, H, L, E).to(self.device)
        K = torch.randn(N, H, S, E).to(self.device)
        lengths = torch.full((N,), L, dtype=torch.int32).to(self.device)
        groups, counts = cluster_queries(Q, lengths, C, I, B)
        Q_grouped = aggregate(Q, groups, 1/counts.float())
        QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
        _, topk = torch.topk(QK, k, dim=-1)
        topk = topk.contiguous()

        # --- clustered sparse product ---------------------------------
        # The kernel returns its own output tensor, so no buffer is
        # pre-allocated here.
        products = clustered_sparse_dot_product(
            Q, K, topk, groups, counts, lengths
        )
        for _ in range(1000):
            products = clustered_sparse_dot_product(
                Q, K, topk, groups, counts, lengths
            )

        torch.cuda.synchronize()
        s = torch.cuda.Event(enable_timing=True)
        e = torch.cuda.Event(enable_timing=True)
        s.record()
        products = clustered_sparse_dot_product(
            Q, K, topk, groups, counts, lengths
        )
        e.record()
        torch.cuda.synchronize()
        t_sc = s.elapsed_time(e)

        # --- plain sparse product --------------------------------------
        topk_broadcast = broadcast(
            topk.float(),
            groups,
            torch.ones_like(counts, dtype=torch.float32),
            torch.zeros((N, H, L, k), device=Q.device)
        )

        for _ in range(1000):
            products = sparse_dot_product(Q, K, topk_broadcast.long())
        torch.cuda.synchronize()
        s = torch.cuda.Event(enable_timing=True)
        e = torch.cuda.Event(enable_timing=True)
        s.record()
        products_s = sparse_dot_product(Q, K, topk_broadcast.long())
        e.record()
        torch.cuda.synchronize()
        t_s = s.elapsed_time(e)

        # --- dense product ----------------------------------------------
        for _ in range(1000):
            torch.einsum("nhle,nhse->nhls", Q, K)
        # Drain the warm-up work before timing, as the two sections above do.
        torch.cuda.synchronize()
        s = torch.cuda.Event(enable_timing=True)
        e = torch.cuda.Event(enable_timing=True)
        s.record()
        torch.einsum("nhle,nhse->nhls", Q, K)
        e.record()
        torch.cuda.synchronize()
        t_f = s.elapsed_time(e)
        print("Sparse_Clustered: {}, Sparse: {}, Full: {}".format(t_sc, t_s, t_f))
Exemplo n.º 21
0
    def test_broadcast_benchmark(self):
        """Verify and time ``clustered_broadcast`` against ``broadcast``.

        The outputs of both functions on the same per-group values must
        agree to within 1e-4 (max absolute error).  Afterwards each function
        is warmed up and timed once with CUDA events, and both timings are
        printed.
        """
        N, H, L, S, E = 12, 8, 1000, 1000, 32
        C, I, B = 200, 5, 32

        queries = torch.randn(N, H, L, E).cuda()
        lengths = torch.full((N, ), L, dtype=torch.int32).cuda()
        groups, counts = cluster_queries(queries, lengths, C, I, B)
        grouped_queries = aggregate(queries, groups, 1 / counts.float())
        keys = torch.randn(N, H, S, E).cuda()
        scores = torch.einsum("nhle,nhse->nhls", grouped_queries, keys)

        values = torch.randn(N, H, S, E).cuda()
        attention = F.softmax(scores, dim=-1)
        V_new = torch.einsum("nhls,nhse->nhle", attention, values)

        def clustered_pass(out_buffer):
            return clustered_broadcast(V_new, groups, counts, lengths,
                                       out_buffer)

        def reference_pass():
            unit_factors = torch.ones_like(counts, dtype=torch.float32)
            fresh_buffer = torch.zeros((N, H, L, E), device=queries.device)
            return broadcast(V_new, groups, unit_factors, fresh_buffer)

        def timed(fn):
            # Returns (elapsed milliseconds, result of fn()).
            start = torch.cuda.Event(enable_timing=True)
            stop = torch.cuda.Event(enable_timing=True)
            start.record()
            result = fn()
            stop.record()
            torch.cuda.synchronize()
            return start.elapsed_time(stop), result

        # Correctness check before any timing.
        V_broadcast = torch.zeros((N, H, L, E), dtype=V_new.dtype).cuda()
        V_broadcast = clustered_pass(V_broadcast)
        V_broadcast_2 = reference_pass()
        max_error = torch.max(torch.abs(V_broadcast_2 - V_broadcast))
        self.assertLess(max_error, 1e-4)

        # Warm-up: each pass gets the previous output as its output buffer.
        for _ in range(2000):
            V_broadcast = clustered_pass(V_broadcast)
        t_broadcast, V_broadcast = timed(lambda: clustered_pass(V_broadcast))

        for _ in range(200):
            V_broadcast_2 = reference_pass()
        t_broadcast_2, V_broadcast_2 = timed(reference_pass)

        print("B1: {}, B2: {}".format(t_broadcast, t_broadcast_2))
    def test_forward(self):
        """Check clustered_sparse_weighted_average against a dense gather.

        For 30 random problem sizes the queries are sorted by group id, the
        per-group top-k keys are selected, and the products returned by
        ``clustered_sparse_dot_product`` are used as weights.  The weighted
        average from ``clustered_sparse_weighted_average`` (permuted back to
        the original query order) must match an explicit gather-and-sum over
        the selected values to within 1e-3.

        The raw products are used as weights on both sides of the
        comparison; no softmax is applied.
        """
        N = 6
        H = 5
        I = 10
        B = 32

        for _ in range(30):
            # L, S, E, C and k are redrawn for every trial.
            C = np.random.randint(10, 500)
            L = np.random.randint(C, 2000)
            E = np.random.randint(10, 128)
            S = np.random.randint(100, 1000)
            k = np.random.randint(10, 64)
            if os.getenv("VERBOSE_TESTS", ""):
                print(("Testing: N H L S E C k: "
                       "{} {} {} {} {} {} {}").format(N, H, L, S, E, C, k))

            Q = torch.randn(N, H, L, E).to(self.device)
            K = torch.randn(N, H, S, E).to(self.device)
            lengths = torch.full((N, ), L, dtype=torch.int32).to(self.device)
            groups, counts = cluster_queries(Q, lengths, C, I, B)

            # Sort the queries of every (n, h) row by group id and keep the
            # inverse permutation to restore the original order later.
            sorted_g, sorted_gi = torch.sort(groups.view(N * H, -1), dim=-1)
            sorted_rev_gi = torch.argsort(sorted_gi, dim=-1)
            q_offset = torch.arange(N * H, device=Q.device).unsqueeze(-1) * L
            q_flat = (sorted_gi + q_offset).reshape(-1)
            q_rev_flat = (sorted_rev_gi + q_offset).reshape(-1)
            s_queries = Q.reshape(-1, E).index_select(0,
                                                      q_flat).view(N, H, L, E)

            Q_grouped = aggregate(s_queries, sorted_g.view(N, H, L),
                                  1 / counts.float())

            QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
            _, topk = torch.topk(QK, k, dim=-1)
            topk = topk.contiguous()
            topk_broadcast = broadcast(
                topk.float(), groups,
                torch.ones_like(counts, dtype=torch.float32),
                torch.zeros((N, H, L, k), device=Q.device))

            weights_sorted = clustered_sparse_dot_product(
                s_queries, K, topk, groups, counts, lengths)

            # Same weights, rows permuted back to the original query order.
            weights = weights_sorted.reshape(-1, k).index_select(
                0, q_rev_flat).view(N, H, L, k)
            values = torch.randn(N, H, S,
                                 E).to(self.device).requires_grad_(True)

            # Reference: gather the selected values per query and reduce.
            values_selected = values[
                torch.arange(N).view(N, 1, 1, 1).to(self.device),
                torch.arange(H).view(1, H, 1, 1).to(self.device),
                topk_broadcast.long()]
            output = (weights.unsqueeze(-1) * values_selected).sum(-2)

            output_hat_sorted = clustered_sparse_weighted_average(
                weights_sorted, values, topk, groups, counts)
            output_hat = output_hat_sorted.reshape(-1, E).index_select(
                0, q_rev_flat).view(N, H, L, E)

            self.assertLess(torch.abs(output_hat - output).max(), 1e-3)
Exemplo n.º 23
0
    def test_benchmark_backward(self):
        """Benchmark the backward pass of dense vs. clustered sparse products.

        Times ``QK.sum().backward()`` with CUDA events for

        * the dense ``einsum`` over all query/key pairs, and
        * ``clustered_sparse_dot_product`` on group-sorted queries followed
          by the gather that restores the original query order.

        Both timings are printed; nothing is asserted.
        """
        N = 12
        H = 8
        L = 1024
        S = 1024
        E = 64
        k = 32
        C = 100
        I = 10
        B = 32

        Q = torch.randn(N, H, L, E).to(self.device).requires_grad_(True)
        K = torch.randn(N, H, S, E).to(self.device).requires_grad_(True)
        lengths = torch.full((N, ), L, dtype=torch.int32).to(self.device)

        # --- dense baseline ----------------------------------------------
        self._zero_grad(Q, K)
        for _ in range(100):
            QK = torch.einsum("nhle,nhse->nhls", Q, K)
        self._zero_grad(Q, K)

        s = torch.cuda.Event(enable_timing=True)
        e = torch.cuda.Event(enable_timing=True)
        QK = torch.einsum("nhle,nhse->nhls", Q, K)
        # Only the backward pass sits between the two events.
        s.record()
        QK.sum().backward()
        e.record()
        torch.cuda.synchronize()
        t_full = s.elapsed_time(e)

        # --- clustered sparse product --------------------------------------
        self._zero_grad(Q, K)
        groups, counts = cluster_queries(Q, lengths, C, I, B)
        sorted_g, sorted_gi = torch.sort(groups.view(N * H, -1), dim=-1)
        sorted_rev_gi = torch.argsort(sorted_gi, dim=-1)

        q_offset = torch.arange(N * H, device=Q.device).unsqueeze(-1) * L
        q_flat = (sorted_gi + q_offset).reshape(-1)
        q_rev_flat = (sorted_rev_gi + q_offset).reshape(-1)

        s_queries = Q.reshape(-1, E).index_select(0, q_flat).view(N, H, L, E)

        Q_grouped = aggregate(Q, groups, 1 / counts.float())
        QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
        _, topk = torch.topk(QK, k, dim=-1)
        topk = topk.contiguous()

        # Untimed warm-up of the forward pass.
        for _ in range(100):
            QK = clustered_sparse_dot_product(s_queries, K, topk, groups,
                                              counts, lengths)
            QK = QK.reshape(-1, k).index_select(0, q_rev_flat).view(N, H, L, k)
        self._zero_grad(Q, K)

        s = torch.cuda.Event(enable_timing=True)
        e = torch.cuda.Event(enable_timing=True)
        # Use the sorted queries here as well: the warm-up does, and the
        # q_rev_flat gather below only makes sense for rows in sorted order.
        QK = clustered_sparse_dot_product(s_queries, K, topk, groups, counts,
                                          lengths)
        QK = QK.reshape(-1, k).index_select(0, q_rev_flat).view(N, H, L, k)
        s.record()
        QK.sum().backward()
        e.record()
        torch.cuda.synchronize()
        t_sparse = s.elapsed_time(e)
        print("Benchmark Backward: T_Full: {}, T_Sparse: {}".format(
            t_full, t_sparse))
    def test_correctness_masked(self):
        """Check clustered_sparse_weighted_average with variable lengths.

        For 30 random problem sizes, compares the output of
        ``clustered_sparse_weighted_average`` (and its gradients with
        respect to the weights and the values) against an explicit
        gather-and-sum reference that is multiplied by
        ``query_lengths.float_matrix``.  Outputs must agree to within 1e-4
        and gradients to within 1e-3.
        """
        # N, H, L, S, E, k and C are placeholders: all of them are redrawn
        # at the top of every trial.  Only I and B keep these values.
        N = 12
        H = 6
        L = 1000
        S = 1000
        E = 32
        k = 32
        C = 100
        I = 10
        B = 32
        for exp in range(30):
            N = np.random.randint(1, 6)
            H = np.random.randint(1, 8)
            C = np.random.randint(10, 500)
            L = np.random.randint(C, 2000)
            E = np.random.randint(10, 128)
            S = np.random.randint(100, 1000)
            k = np.random.randint(10, 64)

            if os.getenv("VERBOSE_TESTS", ""):
                print(("Testing Masked: N H L S E C k: "
                       "{} {} {} {} {} {} {}").format(N, H, L, S, E, C, k))

            Q = torch.randn(N, H, L, E).to(self.device)
            K = torch.randn(N, H, S, E).to(self.device)

            # Random per-sequence lengths in [C, L]; the first sequence is
            # forced to the full length L.
            lengths = np.random.randint(C, L + 1, N)
            lengths = torch.tensor(lengths, dtype=torch.int32).to(self.device)
            lengths[0] = L
            query_lengths = LengthMask(lengths, max_len=L)
            groups, counts = cluster_queries(Q, lengths, C, I, B)

            # Permutation sorting each (n, h) row of queries by group id,
            # and its inverse.
            sorted_g, sorted_gi = torch.sort(groups.view(N * H, -1), dim=-1)
            sorted_rev_gi = torch.argsort(sorted_gi, dim=-1)

            # Offsets turning per-row indices into flat (N*H*L) row indices.
            q_offset = torch.arange(N * H, device=Q.device).unsqueeze(-1) * L
            q_flat = (sorted_gi + q_offset).reshape(-1)
            s_queries = Q.reshape(-1, E).index_select(0,
                                                      q_flat).view(N, H, L, E)
            Q_grouped = aggregate(s_queries, sorted_g.view(N, H, L),
                                  1 / counts.float())

            QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
            _, topk = torch.topk(QK, k, dim=-1)
            topk = topk.contiguous()
            # Per-query copy of the per-group topk, used by the reference.
            topk_broadcast = broadcast(
                topk.float(), groups,
                torch.ones_like(counts, dtype=torch.float32),
                torch.zeros((N, H, L, k), device=Q.device))
            # Random weights, treated as being in group-sorted query order.
            weights_sorted = torch.rand(N, H, L,
                                        k).to(self.device).requires_grad_(True)
            weights_sorted.retain_grad()

            # The same weights with rows permuted back to the original query
            # order; the clone is a non-leaf, hence the retain_grad() below.
            q_rev_flat = (sorted_rev_gi + q_offset).reshape(-1)
            weights = torch.clone(
                weights_sorted.reshape(-1, k).index_select(0, q_rev_flat).view(
                    N, H, L, k))
            weights.retain_grad()

            values = torch.randn(N, H, S,
                                 E).to(self.device).requires_grad_(True)
            self._zero_grad(weights, values)
            # Reference: gather the selected values per query, weight and sum.
            values_selected = values[
                torch.arange(N).view(N, 1, 1, 1).to(self.device),
                torch.arange(H).view(1, H, 1, 1).to(self.device),
                topk_broadcast.long()]
            output = (weights.unsqueeze(-1) * values_selected).sum(-2)
            # NOTE(review): float_matrix presumably holds 1 for valid query
            # positions and 0 for padding -- confirm against LengthMask.
            output = output * query_lengths.float_matrix[:, None, :, None]
            output.sum().backward()
            grad = [torch.clone(weights.grad), torch.clone(values.grad)]

            # Reset gradients before the second backward pass; the reference
            # backward also reached weights_sorted through the clone.
            self._zero_grad(weights_sorted, values)
            self._zero_grad(weights, values)

            output_hat_sorted = clustered_sparse_weighted_average(
                weights_sorted, values, topk, groups, counts)
            # Permute the rows back to the original query order.
            output_hat = output_hat_sorted.reshape(-1, E).index_select(
                0, q_rev_flat).view(N, H, L, E)

            self.assertLess(torch.abs(output - output_hat).max(), 1e-4)
            output_hat.sum().backward()
            # Permute the sorted-order weight gradient the same way before
            # comparing it with the reference gradient.
            weights_grad_sorted = torch.clone(weights_sorted.grad)
            weights_grad = torch.clone(
                weights_grad_sorted.reshape(-1, k).index_select(
                    0, q_rev_flat).view(N, H, L, k))
            grad_hat = [weights_grad, torch.clone(values.grad)]
            for g1, g2 in zip(grad, grad_hat):
                self.assertLess(torch.abs(g1 - g2).max(), 1e-3)