Code example #1 (score: 0)
File: test_fmha.py — Project: jpool-nv/apex
    def run_test(self, s: int, b: int, zero_tensors: bool):
        """Check the fused MHA CUDA kernels against the pure-Python
        reference (`py_mha`), forward and backward, for a batch of ``b``
        sequences of equal length ``s``.

        Args:
            s: sequence length.
            b: batch size; batches smaller than 4 exercise the ``_nl``
               kernel variants (forward flag and ``bwd_nl``).
            zero_tensors: forwarded unchanged to the fused kernels.
        """
        print(f'Test s={s} b={b}, zero_tensors={zero_tensors}')

        # Fixed seeds so the fused and reference paths see identical data.
        torch.manual_seed(1234)
        torch.cuda.manual_seed(1234)

        dtype = torch.float16
        device = torch.device('cuda')

        h = 16  # number of attention heads
        d = 64  # per-head dimension

        slens = [s] * b
        a = torch.tensor(np.array([0] + slens), dtype=torch.int32)
        amask = torch.ones(b, h, s, s, dtype=dtype, device=device)
        # Cumulative sequence offsets [0, s, 2s, ...] as expected by the
        # packed (varlen-style) fused kernel interface.
        cu_seqlens = torch.cumsum(a, 0).to(dtype=torch.int32, device=device)

        qkv = torch.randn((b, s, h, 3, d), device=device, dtype=dtype)

        # Packed layout (total_tokens, 3, h, d) for the fused kernel; built
        # before enabling grad so the kernel input is grad-free.
        qkv_vs = qkv.permute(0, 1, 3, 2, 4).contiguous().view(b * s, 3, h, d)

        qkv.requires_grad = True

        # Small batches take the "_nl" kernel path (selected via the final
        # boolean flag); large batches take the default path.
        if b < 4:
            ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, True,
                              zero_tensors, None)
        else:
            ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, False,
                              zero_tensors, None)
        ctx = ctx.view(b, s, h, d)

        ctx_ref = py_mha(qkv, amask, b, s, h, d)
        self.assertTrue(torch.allclose(ctx_ref.float(), ctx.float(),
                                       atol=1e-3))

        # Scalar MSE-style loss to generate an upstream gradient.
        labels = torch.randn_like(ctx_ref)
        diff = ctx_ref - labels
        loss = (diff * diff).sum() / b  # was `l`: E741 ambiguous name
        loss.backward()

        # NOTE(review): reading .grad on ctx_ref (a non-leaf tensor)
        # assumes py_mha retains grad on its output — confirm.
        # The original applied permute(0, 2, 1, 3) twice; that permutation
        # is its own inverse, so dw2 is simply a contiguous detached copy
        # of the output gradient.
        dw2 = ctx_ref.grad.clone().detach().contiguous()

        if b < 4:
            dqkv2, _, _ = mha.bwd_nl(dw2, qkv_vs, S_, cu_seqlens, 0.0, s,
                                     zero_tensors)
        else:
            dqkv2, _ = mha.bwd(dw2, qkv_vs, S_, cu_seqlens, 0.0, s,
                               zero_tensors)

        # Unpack the packed gradient back to (b, s, h, 3, d) for comparison
        # with autograd's gradient on qkv.
        dqkv2 = dqkv2.permute(0, 2, 1, 3).view(b, s, h, 3, d)

        self.assertTrue(
            torch.allclose(qkv.grad.float(), dqkv2.float(), atol=1e-3))
Code example #2 (score: 0)
    def backward(ctx, dout):
        """Backward pass: dispatch to the fused MHA backward kernel.

        Batches smaller than 4 use the ``bwd_nl`` variant; larger batches
        use ``bwd``. Only the qkv gradient is returned; every other
        forward input gets ``None``.
        """
        qkv, S_dmask = ctx.saved_tensors
        batch_size = ctx.cu_seqlens.numel() - 1
        kernel = mha.bwd_nl if batch_size < 4 else mha.bwd
        # bwd_nl yields (dqkv, dp, _) while bwd yields (dqkv, dp); the
        # qkv gradient is the first element either way.
        results = kernel(dout, qkv, S_dmask, ctx.cu_seqlens,
                         ctx.p_dropout, ctx.max_s)
        dqkv = results[0]
        return dqkv, None, None, None, None, None, None
Code example #3 (score: 0)
    def backward(ctx, dout):
        """Backward pass through the fused MHA kernel (variant that also
        receives the per-sequence lengths ``ctx.seqlens``).

        Returns the qkv gradient followed by ``None`` for each remaining
        forward input.
        """
        qkv, S_dmask = ctx.saved_tensors
        dqkv, _dp = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens,
                            ctx.seqlens, ctx.p_dropout, ctx.max_s)
        return (dqkv,) + (None,) * 6