def test_forward(self):
    N = 5
    H = 2
    L = 100
    S = 100
    E = 32
    C = 10
    I = 10
    B = 32
    k = 5

    Q = torch.randn(N, H, L, E).to(self.device)
    K = torch.randn(N, H, S, E).to(self.device)
    lengths = torch.full((N,), L, dtype=torch.int32).to(self.device)
    groups, counts = cluster_queries(Q, lengths, C, I, B)
    Q_grouped = aggregate(Q, groups, 1 / counts.float())
    QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
    _, topk = torch.topk(QK, k, dim=-1)
    topk = topk.contiguous()
    topk_broadcast = broadcast(
        topk.float(),
        groups,
        torch.ones_like(counts, dtype=torch.float32),
        torch.zeros((N, H, L, k), device=Q.device)
    )

    weights = torch.rand(N, H, L, k).to(self.device).requires_grad_(True)
    values = torch.randn(N, H, S, E).to(self.device).requires_grad_(True)
    values_selected = values[
        torch.arange(N).view(N, 1, 1, 1).to(self.device),
        torch.arange(H).view(1, H, 1, 1).to(self.device),
        topk_broadcast.long()
    ]
    output = (weights.unsqueeze(-1) * values_selected).sum(-2)
    output_hat = clustered_sparse_weighted_average(weights, values, topk, groups)

    self.assertLess(torch.abs(output - output_hat).max(), 1e-4)

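# For reference, the dense check above can be captured in a standalone helper.
# This is a minimal pure-PyTorch sketch of the computation the fused
# clustered_sparse_weighted_average kernel performs; the name
# `naive_clustered_weighted_average` is ours (not part of the library) and it
# assumes `topk_broadcast` is the per-query index of shape (N, H, L, k),
# already broadcast from clusters to queries, with long dtype.
def naive_clustered_weighted_average(weights, values, topk_broadcast):
    N, H, L, k = weights.shape
    # Gather the k selected value vectors for every query.
    values_selected = values[
        torch.arange(N, device=values.device).view(N, 1, 1, 1),
        torch.arange(H, device=values.device).view(1, H, 1, 1),
        topk_broadcast
    ]  # (N, H, L, k, E)
    # Reduce them with the attention weights.
    return (weights.unsqueeze(-1) * values_selected).sum(-2)  # (N, H, L, E)
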
def test_broadcast(self):
    N = 10
    H = 3
    L = 500
    S = 500
    E = 32
    C = 4
    I = 5
    B = 16

    Q = torch.randn(N, H, L, E).cuda()
    lengths = torch.full((N,), L, dtype=torch.int32).cuda()
    lengths[1] = 400
    lengths[3] = 200
    lengths[7] = 450
    lengths[8] = 150
    groups, counts = cluster_queries(Q, lengths, C, I, B)
    Q_grouped = aggregate(Q, groups, 1 / counts.float())
    K = torch.randn(N, H, S, E).cuda()
    QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
    V = torch.randn(N, H, S, E).cuda()
    A = F.softmax(QK, dim=-1)
    V_new = torch.einsum("nhls,nhse->nhle", A, V)

    V_broadcast = torch.zeros((N, H, L, E), dtype=V_new.dtype).cuda()
    V_broadcast = clustered_broadcast(V_new, groups, counts, lengths, V_broadcast)
    V_broadcast_2 = broadcast(
        V_new,
        groups,
        torch.ones_like(counts, dtype=torch.float32),
        torch.zeros((N, H, L, E), device=Q.device)
    )
    self.assertLess(torch.max(torch.abs(V_broadcast_2 - V_broadcast)), 1e-4)

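# Semantically, the broadcast under test copies each cluster's aggregated
# vector back to every query of that cluster. A minimal dense sketch of that
# (the helper name is ours; it assumes `groups` holds cluster ids in [0, C)
# for every query and ignores the padding/length handling done by the kernel):
def naive_broadcast(V_new, groups):
    # V_new: (N, H, C, E) per-cluster vectors; groups: (N, H, L) cluster ids.
    index = groups.long().unsqueeze(-1).expand(-1, -1, -1, V_new.shape[-1])
    return torch.gather(V_new, 2, index)  # (N, H, L, E)
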
def test_simple_grad(self):
    N = 2
    H = 2
    L = 1000
    E = 32
    S = 1000
    k = 32
    C = 50
    I = 5
    B = 16

    Q = torch.randn(N, H, L, E).to(self.device).requires_grad_(True)
    K = torch.randn(N, H, S, E).to(self.device).requires_grad_(True)
    lengths = torch.full((N,), L).int().to(self.device)
    groups, counts = cluster_queries(Q, lengths, C, I, B)
    Q_grouped = aggregate(Q, groups, 1 / counts.float())
    QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
    _, topk = torch.topk(QK, k, dim=-1)
    topk = topk.contiguous()
    topk_broadcast = broadcast(
        topk.float(),
        groups,
        torch.ones_like(counts, dtype=torch.float32),
        torch.zeros((N, H, L, k), device=Q.device)
    )

    self._zero_grad(Q, K)
    QK_full = torch.einsum("nhle,nhse->nhls", Q, K)
    QK_selected = QK_full[
        torch.arange(N).view(N, 1, 1, 1).to(self.device),
        torch.arange(H).view(1, H, 1, 1).to(self.device),
        torch.arange(L).view(1, 1, L, 1).to(self.device),
        topk_broadcast.long()
    ]
    QK_selected.sum().backward()
    grad = [torch.clone(Q.grad), torch.clone(K.grad)]

    self._zero_grad(Q, K)
    QK_selected_hat = clustered_sparse_dot_product(
        Q, K, topk, groups, counts, lengths
    )
    QK_selected_hat.sum().backward()
    grad_hat = [torch.clone(Q.grad), torch.clone(K.grad)]

    self.assertLess(
        torch.abs(QK_selected - QK_selected_hat).max(),
        1e-4
    )
    for g1, g2 in zip(grad, grad_hat):
        self.assertLess(
            torch.abs(g1 - g2).max(),
            1e-4
        )

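# The point of clustered_sparse_dot_product is that it never materializes the
# full (L, S) score matrix the reference above builds. A minimal gather-based
# sketch of the same result (helper name is ours; it assumes a per-query index
# `topk_broadcast` of shape (N, H, L, k) and no length masking):
def naive_sparse_dot_product(Q, K, topk_broadcast):
    N, H, L, k = topk_broadcast.shape
    # Gather the k selected key vectors for every query...
    K_selected = K[
        torch.arange(N, device=Q.device).view(N, 1, 1, 1),
        torch.arange(H, device=Q.device).view(1, H, 1, 1),
        topk_broadcast
    ]  # (N, H, L, k, E)
    # ...and take per-query dot products, O(L*k*E) instead of O(L*S*E).
    return (Q.unsqueeze(-2) * K_selected).sum(-1)  # (N, H, L, k)
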
def test_simple_product(self):
    N = 2
    H = 2
    L = 1000
    E = 32
    S = 1000
    k = 32
    C = 50
    I = 5
    B = 16

    Q = torch.randn(N, H, L, E).to(self.device)
    K = torch.randn(N, H, S, E).to(self.device)
    lengths = torch.full((N,), L).int().to(self.device)
    groups, counts = cluster_queries(Q, lengths, C, I, B)
    Q_grouped = aggregate(Q, groups, 1 / counts.float())
    QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
    _, topk = torch.topk(QK, k, dim=-1)
    topk = topk.contiguous()

    products = clustered_sparse_dot_product(Q, K, topk, groups, counts, lengths)
    topk_broadcast = broadcast(
        topk.float(),
        groups,
        torch.ones_like(counts, dtype=torch.float32),
        torch.zeros((N, H, L, k), device=Q.device)
    )
    all_products = torch.einsum("nhle,nhse->nhls", Q, K)
    products_2 = all_products[
        torch.arange(N).view(N, 1, 1, 1),
        torch.arange(H).view(1, H, 1, 1),
        torch.arange(L).view(1, 1, L, 1),
        topk_broadcast.long()
    ]
    self.assertLess(torch.max(torch.abs(products_2 - products)), 1e-4)

def test_masked_simple_grad(self):
    N = 4
    H = 2
    L = 100
    E = 64
    S = 100
    k = 32
    C = 5
    I = 5
    B = 16
    for i in range(30):
        C = np.random.randint(10, 500)
        L = np.random.randint(C, 2000)
        E = np.random.randint(10, 128)
        S = np.random.randint(100, 1000)
        k = np.random.randint(10, 64)
        if os.getenv("VERBOSE_TESTS", ""):
            print(("Testing Masked: N H L S E C k: "
                   "{} {} {} {} {} {} {}").format(N, H, L, S, E, C, k))
        Q = torch.randn(N, H, L, E).to(self.device).requires_grad_(True)
        K = torch.randn(N, H, S, E).to(self.device).requires_grad_(True)
        lengths = np.random.randint(C, L + 1, N)
        lengths = torch.tensor(lengths, dtype=torch.int32).to(self.device)
        query_lengths = LengthMask(lengths, max_len=L)
        groups, counts = cluster_queries(Q, lengths, C, I, B)
        Q_grouped = aggregate(Q, groups, 1 / counts.float())
        QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
        _, topk = torch.topk(QK, k, dim=-1)
        topk = topk.contiguous()
        topk_broadcast = broadcast(
            topk.float(),
            groups,
            torch.ones_like(counts, dtype=torch.float32),
            torch.zeros((N, H, L, k), device=Q.device)
        )

        self._zero_grad(Q, K)
        QK_full = torch.einsum("nhle,nhse->nhls", Q, K)
        QK_selected = QK_full[
            torch.arange(N).view(N, 1, 1, 1).to(self.device),
            torch.arange(H).view(1, H, 1, 1).to(self.device),
            torch.arange(L).view(1, 1, L, 1).to(self.device),
            topk_broadcast.long()
        ]
        QK_selected = QK_selected * query_lengths.float_matrix[:, None, :, None]
        QK_selected.sum().backward()
        grad = [torch.clone(Q.grad), torch.clone(K.grad)]

        self._zero_grad(Q, K)
        QK_selected_hat = sparse_product(Q, K, groups, topk, counts, lengths)
        QK_selected_hat.sum().backward()
        grad_hat = [torch.clone(Q.grad), torch.clone(K.grad)]

        self.assertLess(
            torch.abs(QK_selected - QK_selected_hat).max(),
            1e-4
        )
        for g1, g2 in zip(grad, grad_hat):
            self.assertLess(torch.abs(g1 - g2).max(), 1e-3)

def test_small_forward_backward(self):
    N = 12
    H = 8
    L = 2000
    S = 2000
    E = 32
    k = 32
    C = 100
    I = 10
    B = 32

    Q = torch.randn(N, H, L, E).to(self.device)
    K = torch.randn(N, H, S, E).to(self.device)
    lengths = torch.full((N,), L, dtype=torch.int32).to(self.device)
    groups, counts = cluster_queries(Q, lengths, C, I, B)
    Q_grouped = aggregate(Q, groups, 1 / counts.float())
    QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
    _, topk = torch.topk(QK, k, dim=-1)
    topk = topk.contiguous()
    topk_broadcast = broadcast(
        topk.float(),
        groups,
        torch.ones_like(counts, dtype=torch.float32),
        torch.zeros((N, H, L, k), device=Q.device)
    )

    weights = torch.rand(N, H, L, k).to(self.device).requires_grad_(True)
    values = torch.randn(N, H, S, E).to(self.device).requires_grad_(True)
    self._zero_grad(weights, values)

    n_runs = 20
    s = time.time()
    for i in range(n_runs):
        output_hat = clustered_sparse_weighted_average(
            weights, values, topk, groups, counts)
        output_hat.sum().backward()
    e = time.time()
    t_sparse = (e - s) / n_runs
    print('T_sparse Forward Backward: {}'.format(t_sparse))

def test_difficult_grad(self):
    N = 12
    H = 5
    I = 5
    B = 16
    for exp in range(30):
        C = np.random.randint(10, 500)
        L = np.random.randint(C, 2000)
        E = np.random.randint(10, 128)
        S = np.random.randint(100, 1000)
        k = np.random.randint(10, 64)
        if os.getenv("VERBOSE_TESTS", ""):
            print(("Testing: N H L S E C k: "
                   "{} {} {} {} {} {} {}").format(N, H, L, S, E, C, k))
        Q = torch.randn(N, H, L, E).to(self.device)
        K = torch.randn(N, H, S, E).to(self.device)
        Q.requires_grad = True
        K.requires_grad = True
        lengths = torch.full((N,), L, dtype=torch.int32).to(self.device)
        groups, counts = cluster_queries(Q, lengths, C, I, B)
        Q_grouped = aggregate(Q, groups, 1 / counts.float())
        QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
        _, topk = torch.topk(QK, k, dim=-1)
        topk = topk.contiguous()
        topk_broadcast = broadcast(
            topk.float(),
            groups,
            torch.ones_like(counts, dtype=torch.float32),
            torch.zeros((N, H, L, k), device=Q.device)
        )

        self._zero_grad(Q, K)
        QK_full = torch.einsum("nhle,nhse->nhls", Q, K)
        QK_selected = QK_full[
            torch.arange(N).view(N, 1, 1, 1).to(self.device),
            torch.arange(H).view(1, H, 1, 1).to(self.device),
            torch.arange(L).view(1, 1, L, 1).to(self.device),
            topk_broadcast.long()
        ]
        QK_selected.sum().backward()
        grad = [torch.clone(Q.grad), torch.clone(K.grad)]

        self._zero_grad(Q, K)
        QK_selected_hat = sparse_product(Q, K, groups, topk, counts, lengths)
        QK_selected_hat.sum().backward()
        grad_hat = [torch.clone(Q.grad), torch.clone(K.grad)]

        self.assertLess(
            torch.abs(QK_selected - QK_selected_hat).max(),
            1e-4
        )
        for g1, g2 in zip(grad, grad_hat):
            self.assertLess(torch.abs(g1 - g2).max(), 1e-3)

def test_small_forward(self):
    N = 12
    H = 8
    L = 2000
    S = 2000
    E = 32
    k = 32
    C = 100
    I = 10
    B = 32

    Q = torch.randn(N, H, L, E).to(self.device)
    K = torch.randn(N, H, S, E).to(self.device)
    lengths = torch.full((N,), L, dtype=torch.int32).to(self.device)
    groups, counts = cluster_queries(Q, lengths, C, I, B)

    # Sort the queries by cluster so each cluster is contiguous in memory.
    sorted_g, sorted_gi = torch.sort(groups.view(N * H, -1), dim=-1)
    sorted_rev_gi = torch.argsort(sorted_gi, dim=-1)
    q_offset = torch.arange(N * H, device=Q.device).unsqueeze(-1) * L
    q_flat = (sorted_gi + q_offset).reshape(-1)
    s_queries = Q.reshape(-1, E).index_select(0, q_flat).view(N, H, L, E)

    Q_grouped = aggregate(s_queries, sorted_g.view(N, H, L), 1 / counts.float())
    QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
    _, topk = torch.topk(QK, k, dim=-1)
    topk = topk.contiguous()
    topk_broadcast = broadcast(
        topk.float(),
        groups,
        torch.ones_like(counts, dtype=torch.float32),
        torch.zeros((N, H, L, k), device=Q.device)
    )
    weights_sorted = clustered_sparse_dot_product(
        s_queries, K, topk, groups, counts, lengths)

    # Undo the sort to get the weights back in the original query order.
    q_rev_flat = (sorted_rev_gi + q_offset).reshape(-1)
    weights = weights_sorted.reshape(-1, k).index_select(0, q_rev_flat).view(
        N, H, L, k)
    values = torch.randn(N, H, S, E).to(self.device)

    # Warm up before timing.
    for i in range(2000):
        output_hat = clustered_sparse_weighted_average(
            weights_sorted, values, topk, groups, counts)
    s = torch.cuda.Event(enable_timing=True)
    e = torch.cuda.Event(enable_timing=True)
    s.record()
    output_hat = clustered_sparse_weighted_average(
        weights, values, topk, groups, counts)
    e.record()
    torch.cuda.synchronize()
    t_sparse = s.elapsed_time(e)
    print('T_sparse Forward: {}'.format(t_sparse))

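# The sorted tests above all use the same permutation trick: reorder queries
# so each cluster is contiguous, run the kernel, then undo the reordering.
# A minimal sketch of that round trip (helper names are ours; `x` is any
# (N, H, L, E)-shaped tensor aligned with `groups`):
def sort_queries_by_group(x, groups):
    N, H, L = groups.shape
    E = x.shape[-1]
    sorted_g, sorted_gi = torch.sort(groups.view(N * H, L), dim=-1)
    offset = torch.arange(N * H, device=x.device).unsqueeze(-1) * L
    flat = (sorted_gi + offset).reshape(-1)
    rev_flat = (torch.argsort(sorted_gi, dim=-1) + offset).reshape(-1)
    x_sorted = x.reshape(-1, E).index_select(0, flat).view(N, H, L, E)
    return x_sorted, sorted_g.view(N, H, L), rev_flat

def unsort_queries(x_sorted, rev_flat):
    # Applying the inverse permutation restores the original query order,
    # i.e. unsort_queries(sort_queries_by_group(x, g)[0], rev) equals x.
    N, H, L, E = x_sorted.shape
    return x_sorted.reshape(-1, E).index_select(0, rev_flat).view(N, H, L, E)
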
def test_small_benchmark(self):
    N = 12
    H = 8
    L = 1000
    E = 32
    S = 1000
    k = 32
    C = 100
    I = 10
    B = 32

    Q = torch.randn(N, H, L, E).to(self.device)
    K = torch.randn(N, H, S, E).to(self.device)
    lengths = torch.full((N,), L).int().to(self.device)
    groups, counts = cluster_queries(Q, lengths, C, I, B)
    Q_grouped = aggregate(Q, groups, 1 / counts.float())
    QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
    _, topk = torch.topk(QK, k, dim=-1)
    topk = topk.contiguous()

    # Warm up the clustered kernel before timing.
    products = clustered_sparse_dot_product(Q, K, topk, groups, counts, lengths)
    n_runs = 10
    s = time.time()
    for i in range(n_runs):
        products = clustered_sparse_dot_product(Q, K, topk, groups,
                                                counts, lengths)
    e = time.time()
    t_sc = (e - s) / n_runs

    topk_broadcast = broadcast(
        topk.float(),
        groups,
        torch.ones_like(counts, dtype=torch.float32),
        torch.zeros((N, H, L, k), device=Q.device)
    )
    s = time.time()
    for i in range(n_runs):
        products = sparse_dot_product(Q, K, topk_broadcast.long())
    e = time.time()
    t_s = (e - s) / n_runs

    s = time.time()
    for i in range(n_runs):
        torch.einsum("nhle,nhse->nhls", Q, K)
    e = time.time()
    t_f = (e - s) / n_runs
    print("Sparse_Clustered: {}, Sparse: {}, Full: {}".format(
        t_sc, t_s, t_f))

def test_simple_product(self):
    N = 2
    H = 2
    L = 100
    E = 32
    S = 50
    k = 32
    C = 5
    I = 5
    B = 16
    for i in range(20):
        # NOTE: the random draws are overridden just below, so every
        # iteration currently runs with fixed k and E.
        k = np.random.randint(10, S)
        E = np.random.randint(10, 129)
        k = 32
        E = 32
        if os.getenv("VERBOSE_TESTS", ""):
            print(("Testing: N H L S E C k: "
                   "{} {} {} {} {} {} {}").format(N, H, L, S, E, C, k))
        Q = torch.randn(N, H, L, E).to(self.device)
        K = torch.randn(N, H, S, E).to(self.device)
        lengths = torch.full((N,), L, dtype=torch.int32).to(self.device)
        lengths[1] = 10  # make the second sequence much shorter
        groups, counts = cluster_queries(Q, lengths, C, I, B)
        Q_grouped = aggregate(Q, groups, 1 / counts.float())
        QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
        _, topk = torch.topk(QK, k, sorted=False, dim=-1)
        topk = topk.contiguous()
        products, Q_grouped_alt = sparse_product(
            Q, K, groups, topk, counts, lengths, k, Q_grouped)
        topk_broadcast = broadcast(
            topk.float(),
            groups,
            torch.ones_like(counts, dtype=torch.float32),
            torch.zeros((N, H, L, k), device=Q.device)
        )
        all_products = torch.einsum("nhle,nhse->nhls", Q, K)
        products_2 = all_products[
            torch.arange(N).view(N, 1, 1, 1),
            torch.arange(H).view(1, H, 1, 1),
            torch.arange(L).view(1, 1, L, 1),
            topk_broadcast.long()
        ]
        # Compare only the valid (unpadded) queries of each sequence.
        for n in range(N):
            p_1 = products[n, :, :lengths[n], :]
            p_2 = products_2[n, :, :lengths[n], :]
            self.assertLess(torch.max(torch.abs(p_2 - p_1)), 1e-4)

def test_broadcast_difficult(self):
    N = 10
    H = 3
    E = 64
    I = 5
    B = 16
    for exp in range(20):
        S = np.random.randint(100, 1000)
        C = np.random.randint(10, 500)
        L = np.random.randint(C, 2000)
        E = np.random.randint(10, 160)
        lengths = torch.tensor(
            np.random.randint(C, L + 1, N),
            dtype=torch.int32
        ).cpu()
        if os.getenv("VERBOSE_TESTS", ""):
            print(("Test: N H L S E C: "
                   "{} {} {} {} {} {}").format(N, H, L, S, E, C))
        Q = torch.randn(N, H, L, E).cpu()
        groups, counts = cluster_queries(Q, lengths, C, I, B)
        Q_grouped = aggregate(Q, groups, 1 / counts.float())
        K = torch.randn(N, H, S, E).cpu()
        QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
        V = torch.randn(N, H, S, E).cpu()
        A = F.softmax(QK, dim=-1)
        V_new = torch.einsum("nhls,nhse->nhle", A, V)

        V_broadcast_2 = broadcast(
            V_new,
            groups,
            torch.ones_like(counts, dtype=torch.float32),
            torch.zeros((N, H, L, E), device=Q.device)
        )
        V_broadcast = try_sorted_broadcast(
            Q, K, V, groups, counts, lengths, Q_grouped
        )
        self.assertLess(
            torch.max(torch.abs(V_broadcast_2 - V_broadcast)),
            1e-4
        )

def test_difficult_product(self):
    N = 12
    H = 5
    I = 5
    B = 16
    for exp in range(30):
        C = np.random.randint(10, 500)
        L = np.random.randint(C, 2000)
        E = np.random.randint(10, 128)
        S = np.random.randint(100, 1000)
        k = np.random.randint(10, 64)
        if os.getenv("VERBOSE_TESTS", ""):
            print("Testing: N H L S E C k: {} {} {} {} {} {} {}".format(
                N, H, L, S, E, C, k))
        Q = torch.randn(N, H, L, E).to(self.device)
        K = torch.randn(N, H, S, E).to(self.device)
        lengths = torch.full((N,), L, dtype=torch.int32).to(self.device)
        groups, counts = cluster_queries(Q, lengths, C, I, B)
        Q_grouped = aggregate(Q, groups, 1 / counts.float())
        QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
        _, topk = torch.topk(QK, k, dim=-1)
        topk = topk.contiguous()
        products, _ = sparse_product(
            Q, K, groups, topk, counts, lengths, k, Q_grouped)
        topk_broadcast = broadcast(
            topk.float(),
            groups,
            torch.ones_like(counts, dtype=torch.float32),
            torch.zeros((N, H, L, k), device=Q.device)
        )
        all_products = torch.einsum("nhle,nhse->nhls", Q, K)
        products_2 = all_products[
            torch.arange(N).view(N, 1, 1, 1),
            torch.arange(H).view(1, H, 1, 1),
            torch.arange(L).view(1, 1, L, 1),
            topk_broadcast.long()
        ]
        self.assertLess(torch.max(torch.abs(products_2 - products)), 1e-4)

def test_broadcast_full(self):
    N = 3
    H = 2
    L = 400
    S = 100
    E = 256
    C = 211
    I = 5
    B = 16
    for exp in range(50):
        Q = torch.randn(N, H, L, E).cpu()
        lengths = torch.full((N,), L, dtype=torch.int32).cpu()
        lengths[0] = np.random.randint(C, L + 1)
        lengths[1] = np.random.randint(C, L + 1)
        groups, counts = cluster_queries(Q, lengths, C, I, B)
        Q_grouped = aggregate(Q, groups, 1 / counts.float())
        K = torch.randn(N, H, S, E).cpu()
        QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
        V = torch.randn(N, H, S, E).cpu()
        A = F.softmax(QK, dim=-1)
        V_new = torch.einsum("nhls,nhse->nhle", A, V)

        V_broadcast_2 = broadcast(
            V_new,
            groups,
            torch.ones_like(counts, dtype=torch.float32),
            torch.zeros((N, H, L, E), device=Q.device)
        )
        V_broadcast = try_sorted_broadcast(
            Q, K, V, groups, counts, lengths, Q_grouped
        )
        self.assertLess(
            torch.max(torch.abs(V_broadcast_2 - V_broadcast)),
            1e-4
        )

def test_small_forward(self):
    N = 12
    H = 8
    L = 2000
    S = 2000
    E = 32
    k = 32
    C = 100
    I = 10
    B = 32

    Q = torch.randn(N, H, L, E).to(self.device)
    K = torch.randn(N, H, S, E).to(self.device)
    lengths = torch.full((N,), L, dtype=torch.int32).to(self.device)
    groups, counts = cluster_queries(Q, lengths, C, I, B)
    Q_grouped = aggregate(Q, groups, 1 / counts.float())
    QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
    _, topk = torch.topk(QK, k, dim=-1)
    topk = topk.contiguous()
    topk_broadcast = broadcast(
        topk.float(),
        groups,
        torch.ones_like(counts, dtype=torch.float32),
        torch.zeros((N, H, L, k), device=Q.device)
    )

    weights = torch.rand(N, H, L, k).to(self.device).requires_grad_(True)
    values = torch.randn(N, H, S, E).to(self.device).requires_grad_(True)

    # Warm up before timing.
    for i in range(2000):
        output_hat = clustered_sparse_weighted_average(
            weights, values, topk, groups)
    s = torch.cuda.Event(enable_timing=True)
    e = torch.cuda.Event(enable_timing=True)
    s.record()
    output_hat = clustered_sparse_weighted_average(weights, values, topk, groups)
    e.record()
    torch.cuda.synchronize()
    t_sparse = s.elapsed_time(e)
    print('T_sparse Forward: {}'.format(t_sparse))

def test_broadcast_benchmark(self):
    N = 12
    H = 8
    L = 1000
    S = 1000
    E = 64
    D = 64
    C = 200
    I = 5
    B = 63

    Q = torch.randn(N, H, L, E).cuda()
    lengths = torch.full((N,), L, dtype=torch.int32).cuda()
    groups, counts = cluster_queries(Q, lengths, C, I, B)

    # Sort the queries by cluster and keep the inverse permutation around.
    sorted_g, sorted_gi = torch.sort(groups.view(N * H, -1), dim=-1)
    sorted_rev_gi = torch.argsort(sorted_gi, dim=-1)
    q_offset = torch.arange(N * H, device=Q.device).unsqueeze(-1) * L
    q_flat = (sorted_gi + q_offset).reshape(-1)

    Q_grouped = aggregate(Q, groups, 1 / counts.float())
    K = torch.randn(N, H, S, E).cuda()
    QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
    V = torch.randn(N, H, S, E).cuda()
    A = F.softmax(QK, dim=-1)
    V_new = torch.einsum("nhls,nhse->nhle", A, V)
    V_broadcast = torch.zeros((N, H, L, E), dtype=V_new.dtype).cuda()

    factors = torch.ones_like(counts, dtype=torch.float32)
    V_sorted_broadcast = clustered_broadcast(
        V_new, sorted_g.view(N, H, L), counts, factors, V_broadcast
    )
    q_rev_flat = (sorted_rev_gi + q_offset).reshape(-1)
    V_broadcast = V_sorted_broadcast.reshape(-1, D).index_select(
        0, q_rev_flat).view(N, H, L, D)

    # Warm up the clustered broadcast before timing.
    for i in range(2000):
        factors = torch.ones_like(counts, dtype=torch.float32)
        V_sorted_broadcast = clustered_broadcast(
            V_new, sorted_g.view(N, H, L), counts, factors, V_broadcast
        )
        q_rev_flat = (sorted_rev_gi + q_offset).reshape(-1)
        V_broadcast = V_sorted_broadcast.reshape(-1, D).index_select(
            0, q_rev_flat).view(N, H, L, D)
    s = torch.cuda.Event(enable_timing=True)
    e = torch.cuda.Event(enable_timing=True)
    s.record()
    factors = torch.ones_like(counts, dtype=torch.float32)
    V_sorted_broadcast = clustered_broadcast(
        V_new, sorted_g.view(N, H, L), counts, factors, V_broadcast
    )
    q_rev_flat = (sorted_rev_gi + q_offset).reshape(-1)
    V_broadcast = V_sorted_broadcast.reshape(-1, D).index_select(
        0, q_rev_flat).view(N, H, L, D)
    e.record()
    torch.cuda.synchronize()
    t_broadcast = s.elapsed_time(e)

    # Warm up the reference broadcast before timing.
    for i in range(200):
        V_broadcast_2 = broadcast(
            V_new,
            groups,
            torch.ones_like(counts, dtype=torch.float32),
            torch.zeros((N, H, L, E), device=Q.device)
        )
    s = torch.cuda.Event(enable_timing=True)
    e = torch.cuda.Event(enable_timing=True)
    s.record()
    V_broadcast_2 = broadcast(
        V_new,
        groups,
        torch.ones_like(counts, dtype=torch.float32),
        torch.zeros((N, H, L, E), device=Q.device)
    )
    e.record()
    torch.cuda.synchronize()
    t_broadcast_2 = s.elapsed_time(e)
    print("B1: {}, B2: {}".format(t_broadcast, t_broadcast_2))

def test_small_benchmark(self):
    N = 12
    H = 8
    L = 1000
    E = 32
    S = 1000
    k = 32
    C = 100
    I = 10
    B = 32

    Q = torch.randn(N, H, L, E).to(self.device)
    K = torch.randn(N, H, S, E).to(self.device)
    lengths = torch.full((N,), L, dtype=torch.int32).to(self.device)
    groups, counts = cluster_queries(Q, lengths, C, I, B)
    Q_grouped = aggregate(Q, groups, 1 / counts.float())
    QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
    _, topk = torch.topk(QK, k, dim=-1)
    topk = topk.contiguous()

    # Warm up the clustered kernel before timing.
    products = clustered_sparse_dot_product(Q, K, topk, groups, counts, lengths)
    for i in range(1000):
        products = clustered_sparse_dot_product(
            Q, K, topk, groups, counts, lengths
        )
    torch.cuda.synchronize()
    s = torch.cuda.Event(enable_timing=True)
    e = torch.cuda.Event(enable_timing=True)
    s.record()
    products = clustered_sparse_dot_product(
        Q, K, topk, groups, counts, lengths
    )
    e.record()
    torch.cuda.synchronize()
    t_sc = s.elapsed_time(e)

    topk_broadcast = broadcast(
        topk.float(),
        groups,
        torch.ones_like(counts, dtype=torch.float32),
        torch.zeros((N, H, L, k), device=Q.device)
    )
    for i in range(1000):
        products = sparse_dot_product(
            Q, K, topk_broadcast.long()
        )
    torch.cuda.synchronize()
    s = torch.cuda.Event(enable_timing=True)
    e = torch.cuda.Event(enable_timing=True)
    s.record()
    products_s = sparse_dot_product(
        Q, K, topk_broadcast.long()
    )
    e.record()
    torch.cuda.synchronize()
    t_s = s.elapsed_time(e)

    for i in range(1000):
        torch.einsum("nhle,nhse->nhls", Q, K)
    s = torch.cuda.Event(enable_timing=True)
    e = torch.cuda.Event(enable_timing=True)
    s.record()
    torch.einsum("nhle,nhse->nhls", Q, K)
    e.record()
    torch.cuda.synchronize()
    t_f = s.elapsed_time(e)
    print("Sparse_Clustered: {}, Sparse: {}, Full: {}".format(t_sc, t_s, t_f))

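# The CUDA-event timing boilerplate repeated in the benchmarks above can be
# factored out. A sketch of such a utility (our own helper, not part of the
# library; it assumes `fn` launches CUDA work and returns its result):
def cuda_time(fn, n_warmup=10):
    for _ in range(n_warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end)  # time in milliseconds
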
def test_broadcast_benchmark(self):
    N = 12
    H = 8
    L = 1000
    S = 1000
    E = 32
    C = 200
    I = 5
    B = 32

    Q = torch.randn(N, H, L, E).cuda()
    lengths = torch.full((N,), L, dtype=torch.int32).cuda()
    groups, counts = cluster_queries(Q, lengths, C, I, B)
    Q_grouped = aggregate(Q, groups, 1 / counts.float())
    K = torch.randn(N, H, S, E).cuda()
    QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
    V = torch.randn(N, H, S, E).cuda()
    A = F.softmax(QK, dim=-1)
    V_new = torch.einsum("nhls,nhse->nhle", A, V)
    V_broadcast = torch.zeros((N, H, L, E), dtype=V_new.dtype).cuda()

    # Check correctness against the reference broadcast first.
    V_broadcast = clustered_broadcast(V_new, groups, counts, lengths, V_broadcast)
    V_broadcast_2 = broadcast(
        V_new,
        groups,
        torch.ones_like(counts, dtype=torch.float32),
        torch.zeros((N, H, L, E), device=Q.device)
    )
    self.assertLess(torch.max(torch.abs(V_broadcast_2 - V_broadcast)), 1e-4)

    # Warm up the clustered broadcast before timing.
    for i in range(2000):
        V_broadcast = clustered_broadcast(V_new, groups, counts, lengths,
                                          V_broadcast)
    s = torch.cuda.Event(enable_timing=True)
    e = torch.cuda.Event(enable_timing=True)
    s.record()
    V_broadcast = clustered_broadcast(V_new, groups, counts, lengths, V_broadcast)
    e.record()
    torch.cuda.synchronize()
    t_broadcast = s.elapsed_time(e)

    # Warm up the reference broadcast before timing.
    for i in range(200):
        V_broadcast_2 = broadcast(
            V_new,
            groups,
            torch.ones_like(counts, dtype=torch.float32),
            torch.zeros((N, H, L, E), device=Q.device)
        )
    s = torch.cuda.Event(enable_timing=True)
    e = torch.cuda.Event(enable_timing=True)
    s.record()
    V_broadcast_2 = broadcast(
        V_new,
        groups,
        torch.ones_like(counts, dtype=torch.float32),
        torch.zeros((N, H, L, E), device=Q.device)
    )
    e.record()
    torch.cuda.synchronize()
    t_broadcast_2 = s.elapsed_time(e)
    print("B1: {}, B2: {}".format(t_broadcast, t_broadcast_2))

def test_forward(self):
    N = 6
    H = 5
    L = 100
    S = 100
    E = 32
    C = 10
    I = 10
    B = 32
    k = 5
    for exp in range(30):
        C = np.random.randint(10, 500)
        L = np.random.randint(C, 2000)
        E = np.random.randint(10, 128)
        S = np.random.randint(100, 1000)
        k = np.random.randint(10, 64)
        if os.getenv("VERBOSE_TESTS", ""):
            print(("Testing: N H L S E C k: "
                   "{} {} {} {} {} {} {}").format(N, H, L, S, E, C, k))
        Q = torch.randn(N, H, L, E).to(self.device)
        K = torch.randn(N, H, S, E).to(self.device)
        lengths = torch.full((N,), L, dtype=torch.int32).to(self.device)
        groups, counts = cluster_queries(Q, lengths, C, I, B)

        # Sort the queries by cluster so the kernel sees contiguous clusters.
        sorted_g, sorted_gi = torch.sort(groups.view(N * H, -1), dim=-1)
        sorted_rev_gi = torch.argsort(sorted_gi, dim=-1)
        q_offset = torch.arange(N * H, device=Q.device).unsqueeze(-1) * L
        q_flat = (sorted_gi + q_offset).reshape(-1)
        s_queries = Q.reshape(-1, E).index_select(0, q_flat).view(N, H, L, E)

        Q_grouped = aggregate(s_queries, sorted_g.view(N, H, L),
                              1 / counts.float())
        QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
        _, topk = torch.topk(QK, k, dim=-1)
        topk = topk.contiguous()
        topk_broadcast = broadcast(
            topk.float(),
            groups,
            torch.ones_like(counts, dtype=torch.float32),
            torch.zeros((N, H, L, k), device=Q.device)
        )
        weights_sorted = clustered_sparse_dot_product(
            s_queries, K, topk, groups, counts, lengths)

        # Undo the sort to get the weights back in the original query order.
        q_rev_flat = (sorted_rev_gi + q_offset).reshape(-1)
        weights = weights_sorted.reshape(-1, k).index_select(
            0, q_rev_flat).view(N, H, L, k)

        values = torch.randn(N, H, S, E).to(self.device).requires_grad_(True)
        values_selected = values[
            torch.arange(N).view(N, 1, 1, 1).to(self.device),
            torch.arange(H).view(1, H, 1, 1).to(self.device),
            topk_broadcast.long()
        ]
        output = (weights.unsqueeze(-1) * values_selected).sum(-2)
        output_hat_sorted = clustered_sparse_weighted_average(
            weights_sorted, values, topk, groups, counts)
        output_hat = output_hat_sorted.reshape(-1, E).index_select(
            0, q_rev_flat).view(N, H, L, E)
        self.assertLess(torch.abs(output_hat - output).max(), 1e-3)

def test_correctness_masked(self):
    N = 12
    H = 6
    L = 1000
    S = 1000
    E = 32
    k = 32
    C = 100
    I = 10
    B = 32
    for exp in range(30):
        N = np.random.randint(1, 6)
        H = np.random.randint(1, 8)
        C = np.random.randint(10, 500)
        L = np.random.randint(C, 2000)
        E = np.random.randint(10, 128)
        S = np.random.randint(100, 1000)
        k = np.random.randint(10, 64)
        if os.getenv("VERBOSE_TESTS", ""):
            print(("Testing Masked: N H L S E C k: "
                   "{} {} {} {} {} {} {}").format(N, H, L, S, E, C, k))
        Q = torch.randn(N, H, L, E).to(self.device)
        K = torch.randn(N, H, S, E).to(self.device)
        lengths = np.random.randint(C, L + 1, N)
        lengths = torch.tensor(lengths, dtype=torch.int32).to(self.device)
        lengths[0] = L
        query_lengths = LengthMask(lengths, max_len=L)
        groups, counts = cluster_queries(Q, lengths, C, I, B)

        # Sort the queries by cluster and keep the inverse permutation.
        sorted_g, sorted_gi = torch.sort(groups.view(N * H, -1), dim=-1)
        sorted_rev_gi = torch.argsort(sorted_gi, dim=-1)
        q_offset = torch.arange(N * H, device=Q.device).unsqueeze(-1) * L
        q_flat = (sorted_gi + q_offset).reshape(-1)
        s_queries = Q.reshape(-1, E).index_select(0, q_flat).view(N, H, L, E)

        Q_grouped = aggregate(s_queries, sorted_g.view(N, H, L),
                              1 / counts.float())
        QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
        _, topk = torch.topk(QK, k, dim=-1)
        topk = topk.contiguous()
        topk_broadcast = broadcast(
            topk.float(),
            groups,
            torch.ones_like(counts, dtype=torch.float32),
            torch.zeros((N, H, L, k), device=Q.device)
        )

        weights_sorted = torch.rand(N, H, L, k).to(
            self.device).requires_grad_(True)
        weights_sorted.retain_grad()
        q_rev_flat = (sorted_rev_gi + q_offset).reshape(-1)
        weights = torch.clone(
            weights_sorted.reshape(-1, k).index_select(0, q_rev_flat).view(
                N, H, L, k))
        weights.retain_grad()
        values = torch.randn(N, H, S, E).to(self.device).requires_grad_(True)
        self._zero_grad(weights, values)

        # Reference: gather the values explicitly and mask padded queries.
        values_selected = values[
            torch.arange(N).view(N, 1, 1, 1).to(self.device),
            torch.arange(H).view(1, H, 1, 1).to(self.device),
            topk_broadcast.long()
        ]
        output = (weights.unsqueeze(-1) * values_selected).sum(-2)
        output = output * query_lengths.float_matrix[:, None, :, None]
        output.sum().backward()
        grad = [torch.clone(weights.grad), torch.clone(values.grad)]

        self._zero_grad(weights_sorted, values)
        self._zero_grad(weights, values)
        output_hat_sorted = clustered_sparse_weighted_average(
            weights_sorted, values, topk, groups, counts)
        output_hat = output_hat_sorted.reshape(-1, E).index_select(
            0, q_rev_flat).view(N, H, L, E)
        self.assertLess(torch.abs(output - output_hat).max(), 1e-4)

        output_hat.sum().backward()
        weights_grad_sorted = torch.clone(weights_sorted.grad)
        weights_grad = torch.clone(
            weights_grad_sorted.reshape(-1, k).index_select(
                0, q_rev_flat).view(N, H, L, k))
        grad_hat = [weights_grad, torch.clone(values.grad)]
        for g1, g2 in zip(grad, grad_hat):
            self.assertLess(torch.abs(g1 - g2).max(), 1e-3)