    def test_simple_grad(self):
        N = 2
        H = 4
        L = 100
        S = 100
        E = 32
        k = 10
        Q = torch.randn(N, H, L, E).to(self.device).requires_grad_(True)
        K = torch.randn(N, H, S, E).to(self.device).requires_grad_(True)
        # Sorted random indices; clamp so rounding cannot produce an
        # out-of-range index equal to S.
        topk = torch.round(torch.cumsum(torch.rand(N, H, L, k) * 10,
                                        dim=-1)).long().clamp_(max=S - 1)
        topk = topk.to(self.device)

        self._zero_grad(Q, K)
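        # Reference: compute the full L x S score matrix, then gather the
        # selected entries with advanced indexing.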
        QK_full = torch.einsum("nhle,nhse->nhls", Q, K)
        QK_selected = QK_full[torch.arange(N).view(N, 1, 1, 1).to(self.device),
                              torch.arange(H).view(1, H, 1, 1).to(self.device),
                              torch.arange(L).view(1, 1, L, 1).to(self.device),
                              topk]
        QK_selected.sum().backward()
        grad = [torch.clone(Q.grad), torch.clone(K.grad)]

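        # Same computation through the sparse kernel; outputs and
        # gradients should match the reference.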
        self._zero_grad(Q, K)
        QK_selected_hat = sparse_dot_product(Q, K, topk)
        QK_selected_hat.sum().backward()
        grad_hat = [torch.clone(Q.grad), torch.clone(K.grad)]

        self.assertLess(torch.abs(QK_selected - QK_selected_hat).max(), 1e-4)
        for g1, g2 in zip(grad, grad_hat):
            self.assertLess(torch.abs(g1 - g2).max(), 1e-4)
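
    # The tests call self._zero_grad(...), which is defined elsewhere in
    # the test class. A minimal sketch, assuming it simply resets any
    # accumulated gradients (hypothetical, not the library's definition):
    def _zero_grad(self, *tensors):
        for t in tensors:
            if t.grad is not None:
                t.grad.zero_()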

    def test_small_benchmark(self):
        N = 12
        H = 8
        L = 1000
        S = 1000
        E = 32
        k = 32
        X = torch.randn(N, H, L, E)
        Y = torch.randn(N, H, S, E)
        # Sorted random indices; clamp since the cumulative sum can
        # exceed S - 1 here (32 steps of up to 40 each).
        topk = torch.cumsum(torch.rand(N, H, L, k) * 40,
                            dim=-1).long().clamp_(max=S - 1)

        n_runs = 10
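        # Average wall-clock time over n_runs for the sparse product.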
        s = time.time()
        for run in range(n_runs):
            products = sparse_dot_product(
                X,
                Y,
                topk,
            )
        e = time.time()
        t_s = (e - s) / n_runs

        s = time.time()
        for run in range(n_runs):
            torch.einsum("nhle,nhse->nhls", X, Y)
        e = time.time()
        t_f = (e - s) / n_runs
        print("Sparse: {}, Full: {}, F/S: {}".format(t_s, t_f, t_f / t_s))

    def test_benchmark_forward_backward(self):
        N = 12
        H = 8
        L = 1024
        S = 1024
        E = 32
        k = 32
        Q = torch.randn(N, H, L, E).to(self.device).requires_grad_(True)
        K = torch.randn(N, H, S, E).to(self.device).requires_grad_(True)
        # Sorted random indices; clamp in case rounding pushes one to S.
        topk = torch.round(
            torch.cumsum(torch.rand(N, H, L, k) * (S // k),
                         dim=-1)).long().clamp_(max=S - 1).to(self.device)
        n_runs = 10
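        # Time a dense forward + backward pass.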
        self._zero_grad(Q, K)
        s = time.time()
        for i in range(n_runs):
            QK = torch.einsum("nhle,nhse->nhls", Q, K)
            QK.sum().backward()
        e = time.time()
        t_full = (e - s) / n_runs

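        # Repeat for the sparse kernel.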
        self._zero_grad(Q, K)
        s = time.time()
        for i in range(n_runs):
            QK = sparse_dot_product(Q, K, topk)
            QK.sum().backward()
        e = time.time()
        t_sparse = (e - s) / n_runs
        print("Benchmark Forward-Backward: T_Full: {}, T_Sparse: {}".format(
            t_full, t_sparse))
Example #4
    def test_simple_product(self):
        X = torch.randn(10, 4, 100, 32).cuda()
        Y = torch.randn(10, 4, 100, 32).cuda()

        # Pick the top-10 keys per query from a random score matrix.
        A = torch.randn(10, 4, 100, 100).to(X.device)
        _, topk = torch.topk(A, 10, dim=-1)
        topk = topk.contiguous()

        products = sparse_dot_product(
            X,
            Y,
            topk,
        )
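        # Gather the same entries from the dense product for comparison.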
        all_products = torch.einsum("nhle,nhse->nhls", X, Y)

        self.assertLess(
            torch.max(
                torch.abs(products -
                          all_products[torch.arange(10).view(10, 1, 1, 1),
                                       torch.arange(4).view(1, 4, 1, 1),
                                       torch.arange(100).view(1, 1, 100, 1),
                                       topk])), 1e-4)

    def test_benchmark_backward(self):
        N = 12
        H = 8
        L = 1024
        S = 1024
        E = 32
        k = 32
        Q = torch.randn(N, H, L, E).to(self.device).requires_grad_(True)
        K = torch.randn(N, H, S, E).to(self.device).requires_grad_(True)
        # Sorted random indices; clamp in case rounding pushes one to S.
        topk = torch.round(
            torch.cumsum(torch.rand(N, H, L, k) * (S // k),
                         dim=-1)).long().clamp_(max=S - 1).to(self.device)

        self._zero_grad(Q, K)
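        # Warm up the dense path so compilation and caching are excluded
        # from the timed run.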
        for i in range(2000):
            QK = torch.einsum("nhle,nhse->nhls", Q, K)
            QK.sum().backward()
        self._zero_grad(Q, K)

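        # Time a single dense backward pass with CUDA events.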
        s = torch.cuda.Event(enable_timing=True)
        e = torch.cuda.Event(enable_timing=True)
        QK = torch.einsum("nhle,nhse->nhls", Q, K)
        s.record()
        QK.sum().backward()
        e.record()
        torch.cuda.synchronize()
        t_full = s.elapsed_time(e)

        self._zero_grad(Q, K)
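        # Warm up the sparse path the same way.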
        for i in range(2000):
            QK = sparse_dot_product(Q, K, topk)
            QK.sum().backward()
        self._zero_grad(Q, K)

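        # Time a single sparse backward pass.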
        s = torch.cuda.Event(enable_timing=True)
        e = torch.cuda.Event(enable_timing=True)
        QK = sparse_dot_product(Q, K, topk)
        s.record()
        QK.sum().backward()
        e.record()
        torch.cuda.synchronize()
        t_sparse = s.elapsed_time(e)
        print("Benchmark Backward: T_Full: {}, T_Sparse: {}".format(
            t_full, t_sparse))
Example #6
    def test_small_benchmark(self):
        N = 12
        H = 8
        L = 1000
        S = 1000
        E = 32
        k = 32
        X = torch.randn(N, H, L, E).cuda()
        Y = torch.randn(N, H, S, E).cuda()

        A = torch.randn(N, H, L, S).to(X.device)
        topk_v, topk = torch.topk(A, k, dim=-1)
        topk = topk.contiguous()

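        # Warm up, then time a single sparse call with CUDA events.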
        for i in range(1000):
            products = sparse_dot_product(
                X,
                Y,
                topk,
            )
        torch.cuda.synchronize()
        s = torch.cuda.Event(enable_timing=True)
        e = torch.cuda.Event(enable_timing=True)
        s.record()
        products = sparse_dot_product(
            X,
            Y,
            topk,
        )
        e.record()
        torch.cuda.synchronize()
        t_s = s.elapsed_time(e)
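        # Repeat the procedure for the dense einsum.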
        for i in range(1000):
            torch.einsum("nhle,nhse->nhls", X, Y)
        s = torch.cuda.Event(enable_timing=True)
        e = torch.cuda.Event(enable_timing=True)
        s.record()
        torch.einsum("nhle,nhse->nhls", X, Y)
        e.record()
        torch.cuda.synchronize()
        t_f = s.elapsed_time(e)
        print("Sparse: {}, Full: {}, F/S: {}".format(t_s, t_f, t_f / t_s))
Example #7
    def test_small_benchmark(self):
        N = 12
        H = 8
        L = 1000
        E = 32
        S = 1000
        k = 32
        C = 100
        I = 10
        B = 32

        Q = torch.randn(N, H, L, E).to(self.device)
        K = torch.randn(N, H, S, E).to(self.device)
        lengths = torch.full((N, ), L).int().to(self.device)
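        # Cluster the queries, aggregate per-cluster centroids, and pick
        # the top-k keys for each cluster.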
        groups, counts = cluster_queries(Q, lengths, C, I, B)
        Q_grouped = aggregate(Q, groups, 1 / counts.float())
        QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
        _, topk = torch.topk(QK, k, dim=-1)
        topk = topk.contiguous()
        # First call doubles as a warm-up before timing.
        products = clustered_sparse_dot_product(Q, K, topk, groups, counts,
                                                lengths)

        n_runs = 10
        s = time.time()
        for i in range(n_runs):
            products = clustered_sparse_dot_product(Q, K, topk, groups, counts,
                                                    lengths)
        e = time.time()
        t_sc = (e - s) / n_runs

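        # Broadcast the per-cluster top-k back to individual queries so
        # the plain sparse kernel runs on the same indices.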
        topk_broadcast = broadcast(
            topk.float(), groups, torch.ones_like(counts, dtype=torch.float32),
            torch.zeros((N, H, L, k), device=Q.device))

        s = time.time()
        for i in range(n_runs):
            products = sparse_dot_product(Q, K, topk_broadcast.long())
        e = time.time()
        t_s = (e - s) / n_runs

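        # Dense einsum baseline.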
        s = time.time()
        for i in range(n_runs):
            torch.einsum("nhle,nhse->nhls", Q, K)
        e = time.time()
        t_f = (e - s) / n_runs
        print("Sparse_Clustered: {}, Sparse: {}, Full: {}".format(
            t_sc, t_s, t_f))
Example #8
    def test_single_query(self):
        X = torch.randn(1, 1, 1, 32).cuda()
        Y = torch.randn(1, 1, 100, 32).cuda()
        topk = (torch.cumsum(torch.rand(1, 1, 1, 10) * 10,
                             dim=-1)).long().cuda()

        products = sparse_dot_product(
            X,
            Y,
            topk,
        )
        all_products = torch.einsum("nhle,nhse->nhls", X, Y)

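        # The selected products must match the corresponding dense entries.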
        self.assertLess(
            torch.max(
                torch.abs(products.squeeze() -
                          all_products[0, 0, 0, topk[0, 0, 0]])), 1e-4)

    def test_simple_product(self):
        X = torch.randn(10, 4, 100, 32)
        Y = torch.randn(10, 4, 100, 32)
        topk = (torch.cumsum(torch.rand(10, 4, 100, 10) * 10, dim=-1)).long()

        products = sparse_dot_product(
            X,
            Y,
            topk,
        )

        all_products = torch.einsum("nhle,nhse->nhls", X, Y)
        self.assertLess(
            torch.max(
                torch.abs(products -
                          all_products[torch.arange(10).view(10, 1, 1, 1),
                                       torch.arange(4).view(1, 4, 1, 1),
                                       torch.arange(100).view(1, 1, 100, 1),
                                       topk])), 1e-4)
Example #10
    def test_small_benchmark(self):
        N = 12
        H = 8
        L = 1000
        E = 32
        S = 1000
        k = 32
        C = 100
        I = 10
        B = 32

        Q = torch.randn(N, H, L, E).to(self.device)
        K = torch.randn(N, H, S, E).to(self.device)
        lengths = torch.full((N,), L, dtype=torch.int32).to(self.device)
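        # Cluster the queries, aggregate per-cluster centroids, and pick
        # the top-k keys for each cluster.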
        groups, counts = cluster_queries(Q, lengths, C, I, B)
        Q_grouped = aggregate(Q, groups, 1/counts.float())
        QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
        _, topk = torch.topk(QK, k, dim=-1)
        topk = topk.contiguous()
        # Warm-up call; the loop below finishes warming the kernel.
        products = clustered_sparse_dot_product(Q, K, topk, groups, counts, lengths)
        for i in range(1000):
            products = clustered_sparse_dot_product(
                Q,
                K,
                topk,
                groups,
                counts,
                lengths
            )

        torch.cuda.synchronize()
        s = torch.cuda.Event(enable_timing=True)
        e = torch.cuda.Event(enable_timing=True)
        s.record()
        products = clustered_sparse_dot_product(
            Q,
            K,
            topk,
            groups,
            counts,
            lengths
        )
        e.record()
        torch.cuda.synchronize()
        t_sc = s.elapsed_time(e)
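
        # Broadcast the per-cluster top-k back to individual queries for
        # the plain sparse kernel.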
        topk_broadcast = broadcast(
            topk.float(),
            groups,
            torch.ones_like(counts, dtype=torch.float32),
            torch.zeros((N, H, L, k), device=Q.device)
        )

        for i in range(1000):
            products = sparse_dot_product(
                Q,
                K,
                topk_broadcast.long()
            )
        torch.cuda.synchronize()
        s = torch.cuda.Event(enable_timing=True)
        e = torch.cuda.Event(enable_timing=True)
        s.record()
        products_s = sparse_dot_product(
            Q,
            K,
            topk_broadcast.long(),
        )
        e.record()
        torch.cuda.synchronize()
        t_s = s.elapsed_time(e)

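        # Dense einsum baseline.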
        for i in range(1000):
            torch.einsum("nhle,nhse->nhls", Q, K)
        s = torch.cuda.Event(enable_timing=True)
        e = torch.cuda.Event(enable_timing=True)
        s.record()
        torch.einsum("nhle,nhse->nhls", Q, K)
        e.record()
        torch.cuda.synchronize()
        t_f = s.elapsed_time(e)
        print("Sparse_Clustered: {}, Sparse: {}, Full: {}".format(t_sc, t_s, t_f))