# Assumed imports for these snippets: torch and torch.nn.functional are standard;
# cluster_queries, aggregate, broadcast, clustered_aggregate and clustered_broadcast
# are the clustering helpers under test (their import path depends on the library version).
import torch
import torch.nn.functional as F


def try_sorted_broadcast(Q, K, V, groups, counts, lengths, Q_grouped_orig):
    N, H, L, E = Q.shape
    _, _, S, D = V.shape
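    # Sort the cluster assignments within each (batch, head) row so that queries
    # of the same cluster become contiguous; sorted_rev_gi is the inverse
    # permutation used later to restore the original query order.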
    sorted_g, sorted_gi = torch.sort(groups.view(N*H, -1), dim=-1)
    sorted_rev_gi = torch.argsort(sorted_gi, dim=-1)

    q_offset = torch.arange(N*H, device=Q.device).unsqueeze(-1) * L
    q_flat = (sorted_gi + q_offset).reshape(-1)

    # sorted queries, keys, values
    s_queries = Q.reshape(-1, E).index_select(0, q_flat).view(N, H, L, E)
    Q_grouped = clustered_aggregate(
        s_queries, sorted_g.view(N, H, L), 1 / counts.float(), lengths
    )
    assert torch.abs(Q_grouped_orig - Q_grouped).max().item() < 1e-4
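    # Attention is evaluated once per cluster centroid (Q_grouped has one row per
    # cluster); the per-cluster outputs are then broadcast back to every query
    # that belongs to that cluster.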
    QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
    A = F.softmax(QK, dim=-1)
    V_new = torch.einsum("nhls,nhse->nhle", A, V)
    V_broadcast = torch.zeros((N, H, L, E), dtype=V_new.dtype, device=V_new.device)
    factors = torch.ones_like(counts, dtype=torch.float)
    V_sorted_broadcast = clustered_broadcast(
        V_new, sorted_g.view(N, H, L), counts, factors, V_broadcast
    )
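    # Undo the sort: map the broadcast results from sorted query order back to
    # the original query order with the inverse permutation.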
    q_rev_flat = (sorted_rev_gi + q_offset).reshape(-1)
    V_broadcast_remap = V_sorted_broadcast.reshape(-1, D).index_select(
        0, q_rev_flat).view(N, H, L, D)
    return V_broadcast_remap
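
# A minimal usage sketch for try_sorted_broadcast, mirroring the shapes and the
# cluster_queries / aggregate calls used in the tests below (assumes a CUDA device;
# the helper name example_sorted_broadcast is ours, not part of the library).
def example_sorted_broadcast():
    N, H, L, S, E = 10, 3, 500, 500, 32
    C, I, B = 4, 5, 16  # clustering hyper-parameters, as in the tests below
    Q = torch.randn(N, H, L, E).cuda()
    K = torch.randn(N, H, S, E).cuda()
    V = torch.randn(N, H, S, E).cuda()
    lengths = torch.full((N,), L, dtype=torch.int32).cuda()
    groups, counts = cluster_queries(Q, lengths, C, I, B)
    Q_grouped = aggregate(Q, groups, 1 / counts.float())
    return try_sorted_broadcast(Q, K, V, groups, counts, lengths, Q_grouped)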
Example #2
    def test_broadcast(self):
        N = 10
        H = 3
        L = 500
        S = 500
        E = 32
        C = 4
        I = 5
        B = 16

        Q = torch.randn(N, H, L, E).cuda()
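        # Per-sequence lengths; a few sequences are shortened below to exercise
        # variable-length (padded) inputs.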
        lengths = torch.full((N, ), L, dtype=torch.int32).cuda()
        lengths[1] = 400
        lengths[3] = 200
        lengths[7] = 450
        lengths[8] = 150
        groups, counts = cluster_queries(Q, lengths, C, I, B)
        Q_grouped = aggregate(Q, groups, 1 / counts.float())
        K = torch.randn(N, H, S, E).cuda()
        QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)

        V = torch.randn(N, H, S, E).cuda()
        A = F.softmax(QK, dim=-1)
        V_new = torch.einsum("nhls,nhse->nhle", A, V)
        V_broadcast = torch.zeros((N, H, L, E), dtype=V_new.dtype).cuda()
        # V_broadcast = broadcast_clustered(V_new, groups, counts, lengths, V_broadcast)
        V_broadcast = clustered_broadcast(V_new, groups, counts, lengths,
                                          V_broadcast)

        V_broadcast_2 = broadcast(V_new, groups,
                                  torch.ones_like(counts, dtype=torch.float32),
                                  torch.zeros((N, H, L, E), device=Q.device))
        self.assertLess(torch.max(torch.abs(V_broadcast_2 - V_broadcast)),
                        1e-4)
Example #3
    def test_broadcast_benchmark(self):
        N = 12
        H = 8
        L = 1000
        S = 1000
        E = 32
        C = 200
        I = 5
        B = 32

        Q = torch.randn(N, H, L, E).cuda()
        lengths = torch.full((N, ), L, dtype=torch.int32).cuda()
        groups, counts = cluster_queries(Q, lengths, C, I, B)
        Q_grouped = aggregate(Q, groups, 1 / counts.float())
        K = torch.randn(N, H, S, E).cuda()
        QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)

        V = torch.randn(N, H, S, E).cuda()
        A = F.softmax(QK, dim=-1)
        V_new = torch.einsum("nhls,nhse->nhle", A, V)
        V_broadcast = torch.zeros((N, H, L, E), dtype=V_new.dtype).cuda()
        # V_broadcast = broadcast_clustered(V_new, groups, counts, lengths, V_broadcast)
        V_broadcast = clustered_broadcast(V_new, groups, counts, lengths,
                                          V_broadcast)

        V_broadcast_2 = broadcast(V_new, groups,
                                  torch.ones_like(counts, dtype=torch.float32),
                                  torch.zeros((N, H, L, E), device=Q.device))

        self.assertLess(torch.max(torch.abs(V_broadcast_2 - V_broadcast)),
                        1e-4)

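        # Warm-up iterations so CUDA kernels and caches are initialized before timing.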
        for i in range(2000):
            # V_broadcast = broadcast_clustered(V_new, groups, counts, lengths, V_broadcast)
            V_broadcast = clustered_broadcast(V_new, groups, counts, lengths,
                                              V_broadcast)

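        # Time a single clustered_broadcast call with CUDA events; synchronize
        # before reading the elapsed time.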
        s = torch.cuda.Event(enable_timing=True)
        e = torch.cuda.Event(enable_timing=True)
        s.record()
        V_broadcast = clustered_broadcast(V_new, groups, counts, lengths,
                                          V_broadcast)
        e.record()
        torch.cuda.synchronize()
        t_broadcast = s.elapsed_time(e)

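        # Warm up and then time the reference broadcast the same way.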
        for i in range(200):
            V_broadcast_2 = broadcast(
                V_new, groups, torch.ones_like(counts, dtype=torch.float32),
                torch.zeros((N, H, L, E), device=Q.device))

        s = torch.cuda.Event(enable_timing=True)
        e = torch.cuda.Event(enable_timing=True)
        s.record()
        V_broadcast_2 = broadcast(V_new, groups,
                                  torch.ones_like(counts, dtype=torch.float32),
                                  torch.zeros((N, H, L, E), device=Q.device))
        e.record()
        torch.cuda.synchronize()
        t_broadcast_2 = s.elapsed_time(e)

        print("B1: {}, B2: {}".format(t_broadcast, t_broadcast_2))

    def test_broadcast_benchmark(self):
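        # Variant of the benchmark above: queries are pre-sorted by cluster id,
        # clustered_broadcast runs on the sorted layout, and an index_select
        # remaps the result back to the original query order.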
        N = 12
        H = 8
        L = 1000
        S = 1000
        E = 64
        D = 64
        C = 200
        I = 5
        B = 63

        Q = torch.randn(N, H, L, E).cuda()
        lengths = torch.full((N,), L, dtype=torch.int32).cuda()
        groups, counts = cluster_queries(Q, lengths, C, I, B)
        sorted_g, sorted_gi = torch.sort(groups.view(N*H, -1), dim=-1)
        sorted_rev_gi = torch.argsort(sorted_gi, dim=-1)

        q_offset = torch.arange(N*H, device=Q.device).unsqueeze(-1) * L
        q_flat = (sorted_gi + q_offset).reshape(-1)

        Q_grouped = aggregate(Q, groups, 1/counts.float())
        K = torch.randn(N, H, S, E).cuda()
        QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)

        V = torch.randn(N, H, S, E).cuda()
        A = F.softmax(QK, dim=-1)
        V_new = torch.einsum("nhls,nhse->nhle", A, V)
        V_broadcast = torch.zeros((N, H, L, E), dtype=V_new.dtype).cuda()
        factors = torch.ones_like(counts, dtype=torch.float32)
        V_sorted_broadcast = clustered_broadcast(
            V_new, sorted_g.view(N, H, L), counts, factors, V_broadcast
        )
        q_rev_flat = (sorted_rev_gi + q_offset).reshape(-1)
        V_broadcast = V_sorted_broadcast.reshape(-1, D).index_select(
                0, q_rev_flat).view(N, H, L, D)

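        # Warm up, then time the sorted clustered_broadcast plus the remap back
        # to the original order.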
        for i in range(2000):
            factors = torch.ones_like(counts, dtype=torch.float32)
            V_sorted_broadcast = clustered_broadcast(
                V_new, sorted_g.view(N, H, L), counts, factors, V_broadcast
            )
            q_rev_flat = (sorted_rev_gi + q_offset).reshape(-1)
            V_broadcast = V_sorted_broadcast.reshape(-1, D).index_select(
                    0, q_rev_flat).view(N, H, L, D)

        s = torch.cuda.Event(enable_timing=True)
        e = torch.cuda.Event(enable_timing=True)
        s.record()
        factors = torch.ones_like(counts, dtype=torch.float32)
        V_sorted_broadcast = clustered_broadcast(
            V_new, sorted_g.view(N, H, L), counts, factors, V_broadcast
        )
        q_rev_flat = (sorted_rev_gi + q_offset).reshape(-1)
        V_broadcast = V_sorted_broadcast.reshape(-1, D).index_select(
                0, q_rev_flat).view(N, H, L, D)
        e.record()
        torch.cuda.synchronize()
        t_broadcast = s.elapsed_time(e)

        for i in range(200):
            V_broadcast_2 = broadcast(
                V_new,
                groups,
                torch.ones_like(counts, dtype=torch.float32),
                torch.zeros((N, H, L, E), device=Q.device)
            )

        s = torch.cuda.Event(enable_timing=True)
        e = torch.cuda.Event(enable_timing=True)
        s.record()
        V_broadcast_2 = broadcast(
            V_new,
            groups,
            torch.ones_like(counts, dtype=torch.float32),
            torch.zeros((N, H, L, E), device=Q.device)
        )
        e.record()
        torch.cuda.synchronize()
        t_broadcast_2 = s.elapsed_time(e)

        print("B1: {}, B2: {}".format(t_broadcast, t_broadcast_2))
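
# A minimal runner sketch (assumption: the test methods above live in a
# unittest.TestCase subclass, e.g. a hypothetical TestBroadcast, and the
# benchmarks require a CUDA device).
if __name__ == "__main__":
    import unittest
    if torch.cuda.is_available():
        unittest.main()
    else:
        print("CUDA not available; skipping the GPU broadcast tests.")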