Example 1
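    # Context assumed but not shown in this snippet: import torch, import
    # numpy as np, the compiled fused-attention extension bound as mha, a
    # pure-PyTorch reference py_mha, and an enclosing unittest.TestCase
    # subclass (hence self.assertTrue).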
    def run_test(self, s: int, b: int, zero_tensors: bool):
        print(f'Test s={s} b={b}, zero_tensors={zero_tensors}')

        torch.manual_seed(1234)
        torch.cuda.manual_seed(1234)

        dtype = torch.float16
        device = torch.device('cuda')

        h = 16
        d = 64

        # Per-sequence lengths (all equal to s here) and their cumulative
        # offsets [0, s, 2*s, ..., b*s], which the fused kernel uses to index
        # the packed token dimension.
        slens = [s] * b
        a = torch.tensor(np.array([0] + slens), dtype=torch.int32)
        amask = torch.ones(b, h, s, s, dtype=dtype, device=device)
        seqlens = torch.tensor(slens, dtype=torch.int32, device=device)
        cu_seqlens = torch.cumsum(a, 0).to(dtype=torch.int32, device=device)
        total = cu_seqlens[-1].item()

        # Packed QKV input: (batch, seqlen, heads, 3, head_dim) in fp16.
        qkv = torch.randn((b, s, h, 3, d), device=device, dtype=dtype)

        # Rearranged to (total_tokens, 3, heads, head_dim), the flattened
        # layout passed to mha.fwd.
        qkv_vs = qkv.permute(0, 1, 3, 2, 4).contiguous().view(b * s, 3, h, d)

        qkv.requires_grad = True

        # b < 4 flips the extra boolean argument, presumably selecting the
        # non-looping ("_nl") kernel variant used for small batches
        # (cf. mha.bwd_nl in the backward check below).
        if b < 4:
            ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, True,
                              zero_tensors, None)
        else:
            ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, False,
                              zero_tensors, None)
        ctx = ctx.view(b, s, h, d)

        ctx_ref = py_mha(qkv, amask, b, s, h, d)
        self.assertTrue(torch.allclose(ctx_ref.float(), ctx.float(),
                                       atol=1e-3))

        # Simple MSE-style loss against random targets to drive the backward pass.
        labels = torch.randn_like(ctx_ref)
        diff = ctx_ref - labels
        loss = (diff * diff).sum() / b
        loss.backward()

        # The two permutes cancel, so dw2 is a contiguous, detached copy of the
        # reference output's gradient in (b, s, h, d) layout (ctx_ref must have
        # retained its grad for this to be populated on a non-leaf tensor).
        dw = ctx_ref.grad.permute(0, 2, 1, 3)

        dw2 = dw.permute(0, 2, 1, 3).clone().detach().contiguous()

        if b < 4:
            dqkv2, _, _ = mha.bwd_nl(dw2, qkv_vs, S_, cu_seqlens, 0.0, s,
                                     zero_tensors)
        else:
            dqkv2, _ = mha.bwd(dw2, qkv_vs, S_, cu_seqlens, 0.0, s,
                               zero_tensors)

        dqkv2 = dqkv2.permute(0, 2, 1, 3).view(b, s, h, 3, d)

        self.assertTrue(
            torch.allclose(qkv.grad.float(), dqkv2.float(), atol=1e-3))
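
Example 1 checks the fused kernel against a pure-PyTorch reference, py_mha, which is not shown in the snippet. The following is a minimal sketch of what such a reference could look like, assuming standard scaled dot-product attention with an additive mask and an output that retains its gradient (so that ctx_ref.grad is populated after the backward pass). The name py_mha_ref and every detail below are illustrative assumptions, not the original implementation.

import math
import torch

def py_mha_ref(qkv, amask, b, s, h, d):
    # qkv: (b, s, h, 3, d); split into Q, K, V, each viewed as (b, h, s, d).
    q = qkv[:, :, :, 0, :].permute(0, 2, 1, 3)
    k = qkv[:, :, :, 1, :].permute(0, 2, 1, 3)
    v = qkv[:, :, :, 2, :].permute(0, 2, 1, 3)
    # Scaled dot-product scores, with a large negative bias wherever amask == 0.
    scores = torch.matmul(q.float(), k.transpose(-2, -1).float()) / math.sqrt(d)
    scores = scores + (1.0 - amask.float()) * -10000.0
    attn = torch.softmax(scores, dim=-1).to(qkv.dtype)
    # Back to (b, s, h, d) to match the fused kernel's output layout.
    out = torch.matmul(attn, v).permute(0, 2, 1, 3).contiguous()
    # Keep the gradient on this non-leaf output so the test can read its .grad
    # after loss.backward() (only possible when qkv itself requires grad).
    if out.requires_grad:
        out.retain_grad()
    return out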
Example 2
def forward(ctx, qkv, cu_seqlens, seqlens, p_dropout, max_s, is_training):
    context, S_dmask = mha.fwd(qkv, cu_seqlens, seqlens, p_dropout, max_s,
                               is_training, None)
    ctx.save_for_backward(qkv, S_dmask)
    ctx.cu_seqlens = cu_seqlens
    ctx.seqlens = seqlens
    ctx.p_dropout = p_dropout
    ctx.max_s = max_s
    return context
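
Example 2's forward takes both the cumulative offsets (cu_seqlens) and the raw per-sequence lengths (seqlens). In Example 1 both are built from the same slens list; given only the cumulative offsets, the per-sequence lengths can be recovered by differencing, as in this small illustrative snippet:

# Recover per-sequence lengths from cumulative offsets.
# e.g. b = 3, s = 128: cu_seqlens = [0, 128, 256, 384] -> seqlens = [128, 128, 128]
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]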
Example 3
def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training):
    batch_size = cu_seqlens.numel() - 1
    if batch_size < 4:
        context, S_dmask = mha.fwd_nl(qkv, cu_seqlens, p_dropout, max_s,
                                      is_training, None)
    else:
        context, S_dmask = mha.fwd(qkv, cu_seqlens, p_dropout, max_s,
                                   is_training, None)
    ctx.save_for_backward(qkv, S_dmask)
    ctx.cu_seqlens = cu_seqlens
    ctx.p_dropout = p_dropout
    ctx.max_s = max_s
    return context
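
The ctx argument and the call to ctx.save_for_backward indicate that these forward methods are static methods on a torch.autograd.Function subclass. Below is a sketch of how Example 3's version would typically be wrapped and invoked; the class name FMHAFun is assumed here for illustration.

import torch

class FMHAFun(torch.autograd.Function):  # class name assumed for illustration
    @staticmethod
    def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training):
        # Body exactly as in Example 3.
        ...

# Typical call site: autograd.Function subclasses are invoked via .apply, which
# routes the arguments to forward() and records the node in the autograd graph.
#     context = FMHAFun.apply(qkv, cu_seqlens, p_dropout, max_s, is_training)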
Example 4
def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training,
            zero_tensors):
    batch_size = cu_seqlens.numel() - 1
    if batch_size < 4:
        # Presumably the non-looping kernel is compiled for a fixed maximum
        # sequence length, so max_s is pinned to 512 on this path.
        max_s = 512
        context, S_dmask = mha.fwd_nl(qkv, cu_seqlens, p_dropout, max_s,
                                      is_training, True, zero_tensors,
                                      None)
    else:
        context, S_dmask = mha.fwd(qkv, cu_seqlens, p_dropout, max_s,
                                   is_training, False, zero_tensors, None)
    ctx.save_for_backward(qkv, S_dmask)
    ctx.cu_seqlens = cu_seqlens
    ctx.p_dropout = p_dropout
    ctx.max_s = max_s
    ctx.zero_tensors = zero_tensors
    return context
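
Example 4's forward saves everything a backward pass needs (qkv, S_dmask, cu_seqlens, p_dropout, max_s, zero_tensors), but the backward itself is not shown. Below is a sketch that mirrors the same batch-size split and reuses the mha.bwd_nl / mha.bwd call shapes visible in Example 1; the exact signature and return handling of the real backward are assumptions.

@staticmethod
def backward(ctx, dout):
    qkv, S_dmask = ctx.saved_tensors
    batch_size = ctx.cu_seqlens.numel() - 1
    dout = dout.contiguous()  # Example 1 likewise passes a contiguous gradient
    if batch_size < 4:
        # Small-batch (non-looping) path; the extra outputs are discarded.
        dqkv, _, _ = mha.bwd_nl(dout, qkv, S_dmask, ctx.cu_seqlens,
                                ctx.p_dropout, ctx.max_s, ctx.zero_tensors)
    else:
        dqkv, _ = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens,
                          ctx.p_dropout, ctx.max_s, ctx.zero_tensors)
    # One return value per forward input after ctx; only qkv receives a gradient.
    return dqkv, None, None, None, None, None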