def run_test(self, s: int, b: int, zero_tensors: bool):
    print(f'Test s={s} b={b}, zero_tensors={zero_tensors}')

    torch.manual_seed(1234)
    torch.cuda.manual_seed(1234)

    dtype = torch.float16
    device = torch.device('cuda')

    h = 16  # number of heads
    d = 64  # head dimension

    # All b sequences share length s. cu_seqlens holds the cumulative
    # offsets [0, s, 2s, ...] that the kernels use to index the packed layout.
    slens = [s] * b
    a = torch.tensor(np.array([0] + slens), dtype=torch.int32)
    amask = torch.ones(b, h, s, s, dtype=dtype, device=device)
    seqlens = torch.tensor(slens, dtype=torch.int32, device=device)
    cu_seqlens = torch.cumsum(a, 0).to(dtype=torch.int32, device=device)
    total = cu_seqlens[-1].item()

    qkv = torch.randn((b, s, h, 3, d), device=device, dtype=dtype)

    # Pack into the (total_tokens, 3, h, d) layout the kernel expects.
    qkv_vs = qkv.permute(0, 1, 3, 2, 4).contiguous().view(b * s, 3, h, d)

    qkv.requires_grad = True

    # Small batches flip the sixth flag and pair with the bwd_nl path below.
    if b < 4:
        ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, True, zero_tensors, None)
    else:
        ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, False, zero_tensors, None)
    ctx = ctx.view(b, s, h, d)

    # Compare the fused forward against the pure-PyTorch reference.
    ctx_ref = py_mha(qkv, amask, b, s, h, d)
    self.assertTrue(torch.allclose(ctx_ref.float(), ctx.float(), atol=1e-3))

    # Drive a backward pass through the reference with a squared-error loss.
    labels = torch.randn_like(ctx_ref)
    diff = ctx_ref - labels
    l = (diff * diff).sum() / b
    l.backward()

    # Round-trip permute: dw2 is ctx_ref.grad made contiguous in (b, s, h, d).
    dw = ctx_ref.grad.permute(0, 2, 1, 3)
    dw2 = dw.permute(0, 2, 1, 3).clone().detach().contiguous()

    if b < 4:
        dqkv2, _, _ = mha.bwd_nl(dw2, qkv_vs, S_, cu_seqlens, 0.0, s, zero_tensors)
    else:
        dqkv2, _ = mha.bwd(dw2, qkv_vs, S_, cu_seqlens, 0.0, s, zero_tensors)

    # Unpack back to qkv's (b, s, h, 3, d) layout and compare gradients.
    dqkv2 = dqkv2.permute(0, 2, 1, 3).view(b, s, h, 3, d)
    self.assertTrue(torch.allclose(qkv.grad.float(), dqkv2.float(), atol=1e-3))
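
# The reference py_mha is not shown in this section. Below is a minimal
# sketch of what such a pure-PyTorch reference could look like, assuming
# standard scaled dot-product attention with an additive mask penalty.
# Only the call signature and the fact that run_test reads ctx_ref.grad
# are taken from the test above; the body is an assumption.
import math

def py_mha(qkv, amask, b, s, h, d):
    # qkv: (b, s, h, 3, d); split into q, k, v of shape (b, h, s, d).
    q = qkv[:, :, :, 0, :].permute(0, 2, 1, 3)
    k = qkv[:, :, :, 1, :].permute(0, 2, 1, 3)
    v = qkv[:, :, :, 2, :].permute(0, 2, 1, 3)
    # Scores in fp32 for accuracy; masked positions get a large negative bias.
    p = torch.matmul(q.float(), k.float().transpose(-2, -1)) / math.sqrt(d)
    p = p + (1.0 - amask.float()) * -10000.0
    probs = torch.softmax(p, dim=-1).to(qkv.dtype)
    # Back to (b, s, h, d) to match the fused kernel's output layout.
    out = torch.matmul(probs, v).permute(0, 2, 1, 3).contiguous()
    out.retain_grad()  # run_test reads ctx_ref.grad after l.backward()
    return out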
def backward(ctx, dout):
    qkv, S_dmask = ctx.saved_tensors
    batch_size = ctx.cu_seqlens.numel() - 1
    # Small batches use the dedicated bwd_nl kernel, mirroring the forward.
    if batch_size < 4:
        dqkv, dp, _ = mha.bwd_nl(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s)
    else:
        dqkv, dp = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s)
    # One gradient per forward input: only qkv receives one.
    return dqkv, None, None, None, None, None, None
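
# For context, a minimal sketch of the forward this backward would pair
# with inside a torch.autograd.Function (an assumption -- the real forward
# is not shown here). The class name and the is_training/set_zero flag
# names are placeholders; mha.fwd's positional signature is mirrored from
# run_test above, and the seven inputs line up with the seven gradients
# returned by backward.
class FMHAFun(torch.autograd.Function):
    @staticmethod
    def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training, set_zero, zero_tensors):
        context, S_dmask = mha.fwd(qkv, cu_seqlens, p_dropout, max_s,
                                   is_training, set_zero, zero_tensors, None)
        # Stash exactly what backward reads back.
        ctx.save_for_backward(qkv, S_dmask)
        ctx.cu_seqlens = cu_seqlens
        ctx.p_dropout = p_dropout
        ctx.max_s = max_s
        return context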
def backward(ctx, dout):
    qkv, S_dmask = ctx.saved_tensors
    # Variant for a build whose bwd also takes the per-sequence lengths and
    # has no separate small-batch path.
    dqkv, dp = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.seqlens, ctx.p_dropout, ctx.max_s)
    return dqkv, None, None, None, None, None, None
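
# The forward paired with this second variant would additionally stash
# ctx.seqlens = seqlens. A usage sketch for the first variant, reusing the
# hypothetical FMHAFun wrapper above with the packed layout from run_test:
#
#   qkv_packed = qkv.permute(0, 1, 3, 2, 4).contiguous().view(b * s, 3, h, d).requires_grad_()
#   out = FMHAFun.apply(qkv_packed, cu_seqlens, 0.0, s, True, False, False)
#   out.sum().backward()    # dispatches to the backward above
#   grad = qkv_packed.grad  # (total, 3, h, d), as returned by mha.bwd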