def run_test(self, s: int, b: int, zero_tensors: bool):
    print(f'Test s={s} b={b}, zero_tensors={zero_tensors}')
    torch.manual_seed(1234)
    torch.cuda.manual_seed(1234)

    dtype = torch.float16
    device = torch.device('cuda')
    h = 16
    d = 64
    slens = [s] * b
    a = torch.tensor(np.array([0] + slens), dtype=torch.int32)
    amask = torch.ones(b, h, s, s, dtype=dtype, device=device)
    seqlens = torch.tensor(slens, dtype=torch.int32, device=device)
    cu_seqlens = torch.cumsum(a, 0).to(dtype=torch.int32, device=device)
    total = cu_seqlens[-1].item()

    qkv = torch.randn((b, s, h, 3, d), device=device, dtype=dtype)
    qkv_vs = qkv.permute(0, 1, 3, 2, 4).contiguous().view(b * s, 3, h, d)
    qkv.requires_grad = True

    if b < 4:
        ctx, S_ = mha.fwd_nl(qkv_vs, cu_seqlens, 0.0, s, True, zero_tensors, None)
    else:
        ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, zero_tensors, None)

    ctx = ctx.view(b, s, h, d)
    ctx_ref = py_mha(qkv, amask, b, s, h, d)
    self.assertTrue(torch.allclose(ctx_ref.float(), ctx.float(), atol=1e-3))

    labels = torch.randn_like(ctx_ref)
    diff = ctx_ref - labels
    l = (diff * diff).sum() / b
    l.backward()

    dw = ctx_ref.grad.permute(0, 2, 1, 3)
    dw2 = dw.permute(0, 2, 1, 3).clone().detach().contiguous()

    if b < 4:
        dqkv2, _, _ = mha.bwd_nl(dw2, qkv_vs, S_, cu_seqlens, 0.0, s, zero_tensors)
    else:
        dqkv2, _ = mha.bwd(dw2, qkv_vs, S_, cu_seqlens, 0.0, s, zero_tensors)

    dqkv2 = dqkv2.permute(0, 2, 1, 3).view(b, s, h, 3, d)
    self.assertTrue(torch.allclose(qkv.grad.float(), dqkv2.float(), atol=1e-3))
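# py_mha is called above but not shown in this excerpt. The sketch below is a
# plausible pure-PyTorch reference attention, reconstructed only from how it is
# called (qkv of shape (b, s, h, 3, d), a (b, h, s, s) mask, and a returned
# context whose .grad is read after backward); the actual helper in the test
# file may differ in detail.
import math

def py_mha(qkv, amask, b, s, h, d):
    # Split the packed QKV tensor into per-head query/key/value of shape (b, h, s, d).
    qkv = qkv.view(b, s, h, 3, d)
    q = qkv[:, :, :, 0, :].permute(0, 2, 1, 3)
    k = qkv[:, :, :, 1, :].permute(0, 2, 1, 3)
    v = qkv[:, :, :, 2, :].permute(0, 2, 1, 3)
    # Scaled dot-product attention in fp32; positions where amask is 0 are
    # pushed to a large negative value before the softmax.
    p = torch.matmul(q.float(), k.float().transpose(-2, -1)) / math.sqrt(d)
    p = p + (1.0 - amask.float()) * -10000.0
    probs = torch.softmax(p, dim=-1).to(qkv.dtype)
    ctx = torch.matmul(probs, v)
    # Return (b, s, h, d) and retain the gradient of this non-leaf tensor,
    # since run_test reads ctx_ref.grad after the backward pass.
    ctx = ctx.permute(0, 2, 1, 3).contiguous()
    ctx.retain_grad()
    return ctx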
def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training):
    batch_size = cu_seqlens.numel() - 1
    if batch_size < 4:
        context, S_dmask = mha.fwd_nl(qkv, cu_seqlens, p_dropout, max_s, is_training, None)
    else:
        context, S_dmask = mha.fwd(qkv, cu_seqlens, p_dropout, max_s, is_training, None)
    ctx.save_for_backward(qkv, S_dmask)
    ctx.cu_seqlens = cu_seqlens
    ctx.p_dropout = p_dropout
    ctx.max_s = max_s
    return context
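# The forward above belongs to a torch.autograd.Function subclass (it stores
# tensors with ctx.save_for_backward for use in backward). Such a function is
# normally exposed through a thin nn.Module wrapper that calls it via .apply.
# The sketch below is illustrative only; the names FMHAFun and FMHALayer are
# hypothetical stand-ins, not taken from the original source.
import torch

class FMHALayer(torch.nn.Module):
    def __init__(self, p_dropout: float):
        super().__init__()
        self.p_dropout = p_dropout

    def forward(self, qkv, cu_seqlens, max_s):
        # Dropout is only applied in training mode; .apply routes the call
        # through the custom autograd Function so its backward is used.
        return FMHAFun.apply(qkv, cu_seqlens,
                             self.p_dropout if self.training else 0.0,
                             max_s, self.training)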
def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training, zero_tensors):
    batch_size = cu_seqlens.numel() - 1
    if batch_size < 4:
        max_s = 512
        context, S_dmask = mha.fwd_nl(qkv, cu_seqlens, p_dropout, max_s, is_training, True, zero_tensors, None)
    else:
        context, S_dmask = mha.fwd(qkv, cu_seqlens, p_dropout, max_s, is_training, False, zero_tensors, None)
    ctx.save_for_backward(qkv, S_dmask)
    ctx.cu_seqlens = cu_seqlens
    ctx.p_dropout = p_dropout
    ctx.max_s = max_s
    ctx.zero_tensors = zero_tensors
    return context
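# A sketch of the backward that would pair with the forward above. The call
# pattern for mha.bwd_nl / mha.bwd (argument order, the zero_tensors flag, and
# the number of returned values) is taken from run_test; the exact body in the
# original source may differ, and the extra outputs are ignored here.
@staticmethod
def backward(ctx, dout):
    qkv, S_dmask = ctx.saved_tensors
    batch_size = ctx.cu_seqlens.numel() - 1
    if batch_size < 4:
        dqkv, _, _ = mha.bwd_nl(dout, qkv, S_dmask, ctx.cu_seqlens,
                                ctx.p_dropout, ctx.max_s, ctx.zero_tensors)
    else:
        dqkv, _ = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens,
                          ctx.p_dropout, ctx.max_s, ctx.zero_tensors)
    # One return value per forward input: only qkv receives a gradient; the
    # remaining (non-tensor) arguments get None.
    return dqkv, None, None, None, None, None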