コード例 #1
0
ファイル: test_nestedtensor.py プロジェクト: alvgaona/pytorch
    def test_repr_string(self):
        """str() and repr() of a nested tensor both render the expected layout."""
        # Each case pairs the components handed to nested_tensor with the
        # exact text that both str() and repr() must produce for them.
        cases = [
            ([], "nested_tensor([" "\n\n])"),
            ([torch.tensor(1.0)], "nested_tensor([" "\n  tensor(1.)" "\n])"),
            (
                [torch.tensor([[1, 2]]), torch.tensor([[4, 5]])],
                ("nested_tensor(["
                 "\n  tensor([[1, 2]])"
                 ","
                 "\n  tensor([[4, 5]])"
                 "\n])"),
            ),
        ]
        for components, rendered in cases:
            nt = nested_tensor(components)
            self.assertEqual(str(nt), rendered)
            self.assertEqual(repr(nt), rendered)
コード例 #2
0
 def test_embedding(self, device):
     """Embedding a nested tensor of index sequences matches embedding each
     sequence on its own."""
     # Eight int64 index sequences whose lengths are drawn from [5, 50).
     seq_lens = torch.randint(5, 50, (8,))
     inputs = [
         torch.randint(100, (seq_len,), device=device, dtype=torch.int64)
         for seq_len in seq_lens
     ]
     nt = torch.nested_tensor(inputs, device=device, dtype=torch.int64)
     embedding = torch.nn.Embedding(100, 8, device=device)
     outputs = embedding(nt).unbind()
     for idx, seq in enumerate(inputs):
         self.assertEqual(embedding(seq), outputs[idx])
コード例 #3
0
    def test_to_then_from_padded_tensor_no_transform0213(self, device, dtype):
        """Padding a nested tensor, then rebuilding it from the padded result
        with the original as the shape example, round-trips every component
        and keeps the device."""
        base = torch.randn(4, 4, 4, device=device, dtype=dtype)
        components = list(torch.unbind(base))
        # Trim the last row of the first component so the entries are ragged.
        components[0] = components[0][:-1]
        nt = torch.nested_tensor(components, device=device, dtype=dtype)
        rebuilt = torch._nested_from_padded_and_nested_example(
            nt.to_padded_tensor(0), nt)

        for original, roundtripped in zip(nt.unbind(), rebuilt.unbind()):
            self.assertEqual(original, roundtripped)
        self.assertEqual(nt.device, rebuilt.device)
コード例 #4
0
 def _test(size):
     # Closure: `device`, `dtype` and `self` come from the enclosing test.
     # Nested layer norm must match torch.nn.LayerNorm applied per component.
     t0 = torch.randn(2, size, device=device, dtype=dtype, requires_grad=False)
     t1 = torch.randn(2, size, device=device, dtype=dtype, requires_grad=False)
     ts = [t0, t1] * 2
     nt = torch.nested_tensor(ts, device=device, dtype=dtype)
     ln = torch.nn.LayerNorm(size, device=device, dtype=dtype)
     nested_out = nt._nested_tensor_layer_norm(ln.weight, ln.bias, 1e-5)
     for got, component in zip(nested_out.unbind(), ts):
         want = ln(component.reshape(1, -1, size).squeeze(0))
         self.assertEqual(got, want)
コード例 #5
0
ファイル: test_nestedtensor.py プロジェクト: alvgaona/pytorch
 def random_nt(self, device, dtype, num_tensors, max_dims, min_dims=None):
     """Return a nested tensor of ``num_tensors`` random components.

     The size of each component along dimension ``d`` is drawn with
     ``torch.randint(low=min_dims[d], high=max_dims[d])``, so ``max_dims`` is
     exclusive. ``min_dims`` defaults to all zeros, which means a component
     may have zero-length dimensions.
     """
     lows = tuple([0] * len(max_dims)) if min_dims is None else min_dims
     components = []
     for _ in range(num_tensors):
         shape = tuple(
             torch.randint(low=lo, high=hi, size=(1, )).item()
             for lo, hi in zip(lows, max_dims)
         )
         components.append(torch.randn(shape, device=device, dtype=dtype))
     return torch.nested_tensor(components, device=device, dtype=dtype)
コード例 #6
0
 def test_reshape(self, device, dtype):
     """Check nested ``reshape`` error messages, and compare valid reshapes
     with the same reshape applied to the zero-padded dense tensor."""
     nt = self.random_nt(device, dtype, 4, (4, 4))
     # error case: empty shape
     self.assertRaisesRegex(RuntimeError,
                            r"shape '\[\]' is invalid for a nested tensor",
                            lambda: nt.reshape(()))
     # error case: empty nested tensor
     nt_empty = torch.nested_tensor([])
     self.assertRaisesRegex(RuntimeError,
                            "empty nested tensor cannot be reshaped",
                            lambda: nt_empty.reshape(-1))
     # error case: invalid proposed shape for underlying tensors
     self.assertRaisesRegex(RuntimeError, r"invalid shape dimension -2",
                            lambda: nt.reshape(-2, 2, 3))
     self.assertRaisesRegex(
         RuntimeError,
         r"shape '\[.*\]' is invalid for input of size [0-9]+",
         lambda: nt.reshape(4, 2, 3))
     # normal case
     x0 = torch.randn((2, 20), device=device, dtype=dtype)
     x1 = torch.randn((3, 20), device=device, dtype=dtype)
     # Pass device/dtype explicitly, as the other tests do, rather than
     # relying on them being inferred from the components.
     nt = torch.nested_tensor([x0, x1], device=device, dtype=dtype)
     pt = nt.to_padded_tensor(0.0)
     self.assertRaisesRegex(
         RuntimeError,
         r"for now reshape cannot change the implicit batch dimension",
         lambda: nt.transpose(-1, -2).reshape(40, -1))
     # inherit only the ragged dimension
     # (2, 20) -> (2, 5, 4)
     # (3, 20) -> (3, 5, 4)
     nt1 = nt.reshape(2, -1, 5, 4)
     # padded: (2, 3, 20) -> (2, 3, 5, 4)
     pt1 = pt.reshape(2, -1, 5, 4)
     self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1)
     # also inherit regular dimension
     # (2, 5, 4) -> (2, 5, 2, 2)
     # (3, 5, 4) -> (3, 5, 2, 2)
     nt2 = nt1.reshape(2, -1, -1, 2, 2)
     # padded: (2, 3, 5, 4) -> (2, 3, 5, 2, 2)
     pt2 = pt1.reshape(2, -1, 5, 2, 2)
     self.assertEqual(noncontiguous_to_padded_tensor(nt2), pt2)
コード例 #7
0
 def test_bmm(self, device, dtype):
     """Check nested ``bmm`` error messages, and compare a valid product with
     ``bmm`` on the zero-padded dense operands.

     Every operand is created on the parametrized ``device``/``dtype``, so
     each parametrization actually exercises its own configuration.
     """
     def rand(*shape):
         # Random dense component on the parametrized device/dtype.
         return torch.randn(*shape, device=device, dtype=dtype)

     def nested(*components):
         # Nested tensor on the parametrized device/dtype.
         return torch.nested_tensor(list(components), device=device, dtype=dtype)

     # error case: not 3D tensors
     nt0 = nested()
     nt1 = nested(rand(2), rand(3))
     nt2 = nested(rand(2, 4), rand(3, 4))
     not_3d_cases = [
         ("batch1 must be a 3D tensor", nt0, nt0),
         ("batch1 must be a 3D tensor", nt0, nt1),
         ("batch1 must be a 3D tensor", nt0, nt2),
         ("batch1 must be a 3D tensor", nt1, nt0),
         ("batch1 must be a 3D tensor", nt1, nt1),
         ("batch1 must be a 3D tensor", nt1, nt2),
         ("batch2 must be a 3D tensor", nt2, nt0),
         ("batch2 must be a 3D tensor", nt2, nt1),
     ]
     for msg, lhs, rhs in not_3d_cases:
         with self.assertRaisesRegex(RuntimeError, msg):
             lhs.bmm(rhs)
     # error case: incompatible batch size
     nt0 = nested(rand(2, 4), rand(3, 4))
     nt1 = nested(rand(4, 6), rand(4, 5), rand(4, 7))
     with self.assertRaisesRegex(
             RuntimeError,
             "Expected size for the 1st dimension of batch2 tensor to be: 2 but got: 3."):
         nt0.bmm(nt1)
     with self.assertRaisesRegex(
             RuntimeError,
             "Expected size for the 1st dimension of batch2 tensor to be: 3 but got: 2."):
         nt1.bmm(nt0)
     # error case: underlying matrices cannot be multiplied
     nt0 = nested(rand(2, 4), rand(3, 4))
     with self.assertRaisesRegex(
             RuntimeError,
             r"0-th nested matrices in batch cannot be multiplied \(2x4 and 2x4\)"):
         nt0.bmm(nt0)
     # normal nested tensor
     nt0 = nested(rand(2, 4), rand(3, 7))
     nt1 = nested(rand(4, 6), rand(7, 5))
     actual = nt0.bmm(nt1)
     expect = nt0.to_padded_tensor(0.0).bmm(nt1.to_padded_tensor(0.0))
     self.assertEqual(actual.to_padded_tensor(0.0), expect)
コード例 #8
0
 def test_softmax(self, device, dtype):
     """Nested softmax covers these cases: dims that are rejected or out of
     range, agreement between the module and functional forms, equivalence
     with softmax over a -inf-padded dense tensor, and the empty and
     scalar-component edge cases."""
     softmax = torch.nn.functional.softmax
     ntensors = 4
     nt = self.random_nt(device, dtype, ntensors, (4, 4))

     # error case: softmax across the nested dimension (0 and its alias -3)
     for dim in (0, -3):
         with self.assertRaisesRegex(
                 RuntimeError, "Cannot apply softmax across nested dimension 0"):
             softmax(nt, dim)
     # error case: dimension out of range
     for dim in (3, -4):
         with self.assertRaises(IndexError):
             softmax(nt, dim)

     # normal case: module and functional forms agree, and both match
     # softmax over the dense tensor padded with -inf
     via_module = torch.nn.Softmax(1)(nt)
     via_functional = softmax(nt, 1)
     self.nt_equal(via_module, via_functional)
     padded = nt.to_padded_tensor(float("-inf"))
     # A slice made entirely of padding yields 0.0 / 0.0 = nan; physically
     # that should be 0.0, so map nan to 0.0 before comparing.
     expect = softmax(padded, 1).nan_to_num_(0.0)
     self.assertEqual(via_module.to_padded_tensor(0.0), expect)

     # edge case: empty nested tensor
     empty = torch.nested_tensor([])
     self.nt_equal(empty, softmax(empty, 1))
     # edge case: nesting scalars
     scalars = torch.nested_tensor([torch.tensor(0.0), torch.tensor(1.0)])
     with self.assertRaises(RuntimeError):
         softmax(scalars, 0)
     with self.assertRaises(IndexError):
         softmax(scalars, 1)
コード例 #9
0
 def test_to_padded_tensor_dim4(self, device, dtype):
     """to_padded_tensor on 3-D components pads each one up to the shape of
     ts[2] (the largest component along every dimension) with ``pad``."""
     ts = [
         torch.randn(16, 21, 13, device=device, dtype=dtype),
         torch.randn(24, 32, 14, device=device, dtype=dtype),
         torch.randn(40, 53, 16, device=device, dtype=dtype),
     ]
     nt = torch.nested_tensor(ts, device=device, dtype=dtype)
     pad = 42
     template = ts[2]
     rows = []
     for component in ts:
         row = torch.ones_like(template) * pad
         # Overwrite the leading corner of the padded block with the data.
         row[tuple(slice(0, extent) for extent in component.shape)].copy_(component)
         rows.append(row)
     correct_output = torch.stack(rows)
     self.assertEqual(nt.to_padded_tensor(pad), correct_output)
コード例 #10
0
    def test_to_padded_tensor_simple(self, device, dtype):
        """Padding with 0 or 1 fills the trimmed last row of the first
        component, and the result keeps the requested device and dtype."""
        t = torch.randn(4, 4, 4, device=device, dtype=dtype)
        ts = list(torch.unbind(t))
        ts[0] = ts[0][:-1]
        nt = torch.nested_tensor(ts, device=device, dtype=dtype)
        for padding_value in (0, 1):
            padded = nt.to_padded_tensor(padding_value)

            # Only the trimmed last row of component 0 is padding.
            fill = torch.zeros_like if padding_value == 0 else torch.ones_like
            correct_output = t.clone()
            correct_output[0][-1] = fill(correct_output[0][-1])

            self.assertEqual(padded, correct_output)
            self.assertEqual(padded.device, torch.device(device))
            self.assertEqual(padded.dtype, dtype)
コード例 #11
0
    def test_nested_tensor_linear_backward(self):
        """Backward through ``linear`` on a nested tensor fills gradients for
        the trainable weight and bias. The nested components were created with
        ``requires_grad=False`` and must not receive a gradient."""
        a = torch.randn(1, 2, requires_grad=False)
        b = torch.randn(2, 2, requires_grad=False)
        c = torch.randn(3, 2, requires_grad=False)

        weight = torch.randn(2, 2, requires_grad=True)
        bias = torch.randn(2, requires_grad=True)
        nt = torch.nested_tensor([a, b, c])

        # Use the public torch.nn.functional entry point rather than the `F`
        # alias that torch.functional happens to import.
        out = torch.nn.functional.linear(nt, weight, bias)

        out.backward(out.clone())

        # unittest assertions, unlike bare `assert`, are not stripped under
        # `python -O`, so the test cannot pass vacuously.
        self.assertIsNotNone(weight.grad)
        self.assertIsNotNone(bias.grad)

        self.assertIsNone(a.grad)
        self.assertIsNone(b.grad)
        self.assertIsNone(c.grad)
コード例 #12
0
 def test_nested_tensor(self):
     """nested_tensor raises TypeError for a list of plain floats, a dense
     tensor, or a bare float."""
     rejected = ([3.0], torch.tensor([3.0]), 4.0)
     for bad_arg in rejected:
         with self.assertRaises(TypeError):
             nested_tensor(bad_arg)
コード例 #13
0
 def _test_fn(unbind_fn):
     # Closure over the enclosing test's `self`. `unbind_fn(nt, 1)` on a
     # nested tensor of ragged 2-D components is expected to raise.
     nt = nested_tensor([torch.rand(3, 2), torch.rand(2, 3)])
     with self.assertRaises(RuntimeError):
         unbind_fn(nt, 1)
コード例 #14
0
 def test_nested_tensor_mul_in_place(self, device, dtype):
     """In-place ``*=`` between nested tensors matches multiplying the
     components pairwise."""
     lhs, rhs = self.random_nt_pair(device, dtype, 4, (4, 4))
     expected = torch.nested_tensor(list(map(torch.mul, lhs.unbind(), rhs.unbind())))
     lhs *= rhs
     self.nt_equal(expected, lhs)
コード例 #15
0
 def test_nested_tensor_add(self, device, dtype):
     """Nested ``+`` matches adding the components pairwise."""
     lhs, rhs = self.random_nt_pair(device, dtype, 4, (4, 4))
     expected = torch.nested_tensor(list(map(torch.add, lhs.unbind(), rhs.unbind())))
     self.nt_equal(expected, lhs + rhs)
コード例 #16
0
 def test_device_checks(self, device):
     """An empty nested tensor reports ``is_cuda`` according to its device."""
     empty = torch.nested_tensor([], device=device)
     self.assertEqual(empty.is_cuda, "cuda" in str(device))
コード例 #17
0
 def test_dropout(self, device, dtype):
     """Exercise nested dropout. Covers an empty input, invalid
     probabilities, p=0 and p=1, and element-wise checks for p=0.2
     (functional, module vs. functional under a frozen RNG state, in-place)."""
     ntensors = 4
     p = 0.2

     def fill_expected(actual, reference):
         # Derive the expected dropout(p) output from the observed one and
         # write it into `reference` in place. Positions zeroed in `actual`
         # become 0; every other element is divided by (1 - p).
         for i in range(ntensors):
             got = actual[i].view(-1)
             want = reference[i].view(-1)
             for j in range(got.shape[0]):
                 if got[j].item() == 0.0:
                     want[j] = 0.0
                 else:
                     want[j] /= 1.0 - p

     # edge case: empty nested tensor
     empty = torch.nested_tensor([])
     self.nt_equal(empty, torch.nn.functional.dropout(empty, 0.5))
     # normal nested tensor
     nt = self.random_nt(device, dtype, ntensors, (4, 4))
     # edge case: invalid dropout probabilities
     for bad_p in (-0.1, 1.1):
         with self.assertRaises(ValueError):
             torch.nn.Dropout(bad_p)
     for bad_p in (-0.1, 1.1):
         with self.assertRaises(ValueError):
             torch.nn.functional.dropout(nt, bad_p)
     # edge case: no dropout
     y0 = torch.nn.Dropout(0.0)(nt)
     y1 = torch.nn.functional.dropout(nt, 0.0)
     self.nt_equal(nt, y0)
     self.nt_equal(nt, y1)
     # edge case: all dropout
     y0 = torch.nn.Dropout(1.0)(nt)
     y1 = torch.nn.functional.dropout(nt, 1.0)
     zeros = nt.clone()
     for i in range(ntensors):
         zeros[i].fill_(0.0)
     self.nt_equal(zeros, y0)
     self.nt_equal(zeros, y1)
     # normal case: functional dropout
     y = torch.nn.functional.dropout(nt, p)
     expect = nt.clone()
     fill_expected(y, expect)
     self.nt_equal(y, expect)
     # module and functional forms agree when started from the same RNG state
     with freeze_rng_state():
         y0 = torch.nn.Dropout(p)(nt)
     with freeze_rng_state():
         y1 = torch.nn.functional.dropout(nt, p)
     self.nt_equal(y0, y1)
     # inplace
     # In principle, once functional dropout is verified, in-place could just
     # be compared against it. In practice CUDA functional dropout has its own
     # implementation that skips `bernoulli_`, so CUDA functional and CUDA
     # in-place differ. That broke
     # `test_dropout_cuda_float64 (__main__.TestNestedTensorDeviceTypeCUDA)`
     # on `linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 2, 4, linux.4xlarge.nvidia.gpu)`,
     # so in-place is checked element-wise instead.
     expect = nt.clone()
     torch.nn.functional.dropout(nt, p, inplace=True)
     fill_expected(nt, expect)
     self.nt_equal(nt, expect)
コード例 #18
0
 def _create_nested_tensor_from_list(self, requires_grad=False):
     """Return a nested tensor of two random components shaped (1, 2) and
     (7, 8), each created with the given ``requires_grad`` flag."""
     shapes = ((1, 2), (7, 8))
     components = [torch.randn(*shape, requires_grad=requires_grad) for shape in shapes]
     return torch.nested_tensor(components)
コード例 #19
0
 def grad_test_func(a, b, c):
     """Pack ``a``, ``b`` and ``c`` into a nested tensor and return it
     zero-padded, so differentiating the result goes through
     ``to_padded_tensor``."""
     # Use a distinct name instead of rebinding the parameter `c`.
     nt = torch.nested_tensor([a, b, c])
     # This implicitly tests to_padded_tensor grads
     return nt.to_padded_tensor(0)
コード例 #20
0
 def test_to_padded_tensor_on_empty_tensor(self):
     """Padding an empty nested tensor yields an empty dense tensor."""
     padded = torch.nested_tensor([]).to_padded_tensor(4)
     self.assertEqual(padded, torch.tensor([]))
コード例 #21
0
 def grad_test_func(a, b, c, weight, bias=None):
     """Apply ``linear(weight, bias)`` to the nested tensor built from ``a``,
     ``b`` and ``c``, and return the zero-padded result."""
     nt = torch.nested_tensor([a, b, c])
     # Use the public torch.nn.functional entry point rather than the `F`
     # alias that torch.functional happens to import.
     d = torch.nn.functional.linear(nt, weight, bias)
     # This implicitly tests to_padded_tensor grads
     return d.to_padded_tensor(0)
コード例 #22
0
ファイル: test_native_mha.py プロジェクト: yuguo68/pytorch
    def _test_transform_bias_rescale_qkv_impl(self,
                                              device,
                                              dtype,
                                              use_nt,
                                              use_padding=False):
        """Check ``torch._transform_bias_rescale_qkv`` against a reference.

        For each (embed_dim, num_heads, bs, sl) configuration, a random input
        of shape (bs, sl, 3 * embed_dim) goes through the op together with
        the bias of a freshly created ``torch.nn.Linear``. The returned
        q/k/v are compared with ``simple_transform_bias_rescale_qkv`` below.

        Args:
            device: device on which the input and the Linear layer are created.
            dtype: dtype of the input and of the Linear layer.
            use_nt: if True, the op receives a nested tensor built from the
                batch rows of the input instead of the dense tensor.
            use_padding: if True, the last timestep of batch element 0 is set
                to -inf in the dense input and, when ``use_nt`` is set, is
                dropped from the nested input.
        """
        tests = [
            (64, 4, 16, 8),
            # dim_per_head = 12 does not divide evenly by CPU vectorization length of 8
            (24, 2, 4, 2),
            # Make sure CUDA can handle small input sizes
            (2, 2, 2, 2),
            # dim_per_head = 6 does not divide evenly by CUDA vectorization length of 4,
            # causes alignment issues
            (24, 4, 4, 2),
            (48, 4, 16, 8),
        ]
        for (embed_dim, num_heads, bs, sl) in tests:
            with self.subTest(embed_dim=embed_dim,
                              num_heads=num_heads,
                              bs=bs,
                              sl=sl):
                # Reseed for every configuration so the input and the Linear
                # layer's random initialization below are reproducible.
                torch.manual_seed(9343)
                # dense_x and x start as the same tensor, so the in-place -inf
                # write below is visible through dense_x. x may then be
                # rebound to a nested tensor.
                dense_x = x = (torch.randn(
                    bs, sl, 3 * embed_dim, device=device, dtype=dtype) * 10)
                if use_padding:
                    x[0][-1] = torch.full(x[0][-1].shape, float("-Inf"))
                if use_nt:
                    xs = list(torch.unbind(x))
                    if use_padding:
                        # Drop the padded timestep from the nested input.
                        xs[0] = xs[0][:-1]
                    x = torch.nested_tensor(xs, device=device, dtype=dtype)
                # Only qkv.bias is used by this test.
                qkv = torch.nn.Linear(embed_dim,
                                      3 * embed_dim,
                                      device=device,
                                      dtype=dtype)

                # We have to use inference_mode here because q/k/v are
                # all views of the same Tensor, which autograd doesn't
                # like. This is fine because this function is only
                # exposed to Python for purposes of writing this test.
                with torch.inference_mode():
                    (q, k, v) = torch._transform_bias_rescale_qkv(
                        x, qkv.bias, num_heads=num_heads)

                    # Reference: split the last dim into q/k/v, add the
                    # matching bias slices, scale q by 1/sqrt(head_dim), then
                    # reshape each to (bs, num_heads, T, head_dim), where
                    # head_dim = embed_dim // num_heads.
                    def simple_transform_bias_rescale_qkv(qkv, bias):
                        (q, k, v) = torch.split(qkv, embed_dim, dim=-1)
                        (q_bias, k_bias, v_bias) = torch.split(bias,
                                                               embed_dim,
                                                               dim=-1)

                        # With use_nt, round the time dimension up to a
                        # multiple of 8 and zero-fill the extra rows.
                        # Presumably this matches the padded layout the
                        # nested path of the op returns; TODO confirm.
                        def embiggen(x):
                            if not use_nt:
                                return x
                            b, t, d = x.size()
                            t = t + (8 - t % 8) % 8
                            newsize = (b, t, d)
                            new_x = torch.zeros(newsize,
                                                device=device,
                                                dtype=dtype)
                            new_x[:x.size()[0], :x.size()[1], :x.size()[2]] = x
                            return new_x

                        return tuple(
                            embiggen(x).reshape((bs, -1, num_heads,
                                                 embed_dim //
                                                 num_heads)).transpose(2, 1)
                            for x in (
                                (q + q_bias) /
                                math.sqrt(embed_dim // num_heads),
                                (k + k_bias),
                                (v + v_bias),
                            ))

                    correct_q, correct_k, correct_v = simple_transform_bias_rescale_qkv(
                        dense_x, qkv.bias)
                    if use_nt and use_padding:
                        # The nested input dropped the -inf timestep, while the
                        # dense reference still carries -inf there. Zero those
                        # entries to match (presumably) zero padding in the
                        # op's output.
                        for t in (correct_q, correct_k, correct_v):
                            t[t == float("-Inf")] = 0

                self.assertEqual(q.size(), correct_q.size())
                torch.testing.assert_close(q, correct_q)
                torch.testing.assert_close(k, correct_k)
                torch.testing.assert_close(v, correct_v)
コード例 #23
0
ファイル: test_native_mha.py プロジェクト: yuguo68/pytorch
    def _test_multihead_attention_impl(self,
                                       device,
                                       dtype,
                                       mode,
                                       use_nt,
                                       need_weights,
                                       average_attn_weights,
                                       use_padding=False,
                                       pad_all=False):
        """Compare ``torch._native_multi_head_attention`` with
        ``torch.nn.MultiheadAttention``, with both sharing the same
        projection weights.

        Args:
            device: device for inputs and modules.
            dtype: dtype for inputs and modules.
            mode: "self" (k = v = q), "encdec" (random k, v = k) or "generic"
                (random k and v); any other value fails the test.
            use_nt: if True, the native op receives nested tensors (padded
                timesteps dropped from q) and no key_padding_mask.
            need_weights: forwarded to both implementations. Attention weights
                are compared with assert_close when True, and with
                assertEqual otherwise.
            average_attn_weights: forwarded to both implementations.
            use_padding: if True, zero the last timestep of batch element 0
                in q and mark it in ``key_padding_mask``.
            pad_all: together with ``use_padding``, do this for every batch
                element instead of only element 0.
        """
        embed_dim = 64
        num_heads = 4
        bs = 16
        sl = 8

        q = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype) * 10
        # NOTE: `mask` is only bound when use_padding is True; every later use
        # below is guarded by use_padding.
        if use_padding:
            if pad_all:
                for q_i in q:
                    q_i[-1] = torch.zeros_like(q[0][-1],
                                               device=device,
                                               dtype=dtype)
                mask = torch.zeros(q.shape[:-1],
                                   device=device,
                                   dtype=torch.bool)
                for mask_i in mask:
                    mask_i[-1] = True
            else:
                q[0][-1] = torch.zeros_like(q[0][-1],
                                            device=device,
                                            dtype=dtype)
                mask = torch.zeros(q.shape[:-1],
                                   device=device,
                                   dtype=torch.bool)
                mask[0][-1] = True
        if mode == "self":
            k = q
            v = q
        elif mode == "encdec":
            k = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype) * 10
            v = k
        elif mode == "generic":
            k = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype) * 10
            v = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype) * 10
        else:
            self.fail(f"invalid mode `{mode}`!")

        qkv = torch.nn.Linear(embed_dim,
                              3 * embed_dim,
                              device=device,
                              dtype=dtype)
        proj = torch.nn.Linear(embed_dim,
                               embed_dim,
                               device=device,
                               dtype=dtype)

        # Reference module; its projections are overwritten with the same
        # parameters the native op receives, so both compute the same thing.
        pt = torch.nn.MultiheadAttention(embed_dim,
                                         num_heads,
                                         batch_first=True,
                                         device=device,
                                         dtype=dtype)
        pt.in_proj_weight = qkv.weight
        pt.in_proj_bias = qkv.bias
        pt.out_proj.weight = proj.weight
        pt.out_proj.bias = proj.bias

        # Thin wrapper that forwards to the native op with the shared weights.
        class NativeMHA(torch.nn.Module):
            def __init__(self, embed_dim, num_heads, qkv, proj):
                super().__init__()
                self.qkv = qkv
                self.proj = proj
                self.embed_dim = embed_dim
                self.num_heads = num_heads

            def forward(self, q, k, v, key_padding_mask):
                return torch._native_multi_head_attention(
                    q,
                    k,
                    v,
                    self.embed_dim,
                    self.num_heads,
                    self.qkv.weight,
                    self.qkv.bias,
                    self.proj.weight,
                    self.proj.bias,
                    key_padding_mask,
                    need_weights=need_weights,
                    average_attn_weights=average_attn_weights,
                )

        npt = NativeMHA(embed_dim=embed_dim,
                        num_heads=num_heads,
                        qkv=qkv,
                        proj=proj).to(dtype)

        # NOTE(review): both modules were already built with device=device, so
        # this move looks redundant. Also note that it only triggers for the
        # exact string "cuda", not e.g. "cuda:0". Confirm before removing.
        if device == "cuda":
            pt = pt.cuda()
            npt = npt.cuda()

        ypt, weight_pt = pt(
            q,
            k,
            v,
            need_weights=need_weights,
            average_attn_weights=average_attn_weights,
            key_padding_mask=mask if use_padding else None,
        )
        if use_nt:
            # Padding is expressed by dropping timesteps from q instead of
            # passing a mask; k/v for "encdec"/"generic" keep every timestep.
            qs = list(torch.unbind(q))
            if use_padding:
                if pad_all:
                    qs = [x[:-1] for x in qs]
                else:
                    qs[0] = qs[0][:-1]
            q = torch.nested_tensor(qs, device=device, dtype=dtype)
            if mode == "self":
                k = v = q
            elif mode == "encdec":
                k = torch.nested_tensor(torch.unbind(k),
                                        device=device,
                                        dtype=dtype)
                v = k
            else:
                k = torch.nested_tensor(torch.unbind(k),
                                        device=device,
                                        dtype=dtype)
                v = torch.nested_tensor(torch.unbind(v),
                                        device=device,
                                        dtype=dtype)

        ynpt, weight_npt = npt(
            q,
            k,
            v,
            key_padding_mask=mask if use_padding and not use_nt else None)
        if use_nt:
            ynpt = ynpt.to_padded_tensor(0)
            if pad_all:
                # Every q entry lost its last timestep, so the padded output is
                # one timestep shorter than ypt; copy it into a zero tensor
                # shaped like ypt.
                ynpt_final = torch.zeros_like(ypt)
                ynpt_final[:, :ynpt.shape[1], :] = ynpt
                ynpt = ynpt_final

        # Zero the last timestep of every batch entry of each given tensor.
        def do_pad_all(tensors):
            for t in tensors:
                for t_i in t:
                    t_i[-1] = torch.zeros_like(t_i[-1],
                                               device=device,
                                               dtype=dtype)

        # PyTorch implementation returns non-zero junk in the padding
        # locations; overwrite it so that the comparison works out.
        if use_padding:
            ypt[0][-1] = torch.zeros_like(ypt[0][-1],
                                          device=device,
                                          dtype=dtype)
            ynpt[0][-1] = torch.zeros_like(ynpt[0][-1],
                                           device=device,
                                           dtype=dtype)
            if pad_all:
                do_pad_all((ypt, ynpt))
            # Zero the last row of each TxT weight matrix
            if need_weights:
                if average_attn_weights:
                    weight_pt[0][-1] = torch.zeros_like(weight_pt[0][-1],
                                                        device=device,
                                                        dtype=dtype)
                    weight_npt[0][-1] = torch.zeros_like(weight_npt[0][-1],
                                                         device=device,
                                                         dtype=dtype)
                    if pad_all:
                        do_pad_all((weight_pt, weight_npt))
                else:
                    # NOTE(review): unlike the averaged branch, pad_all is not
                    # applied here, so only batch element 0 is zeroed. Confirm
                    # whether pad_all is ever combined with
                    # average_attn_weights=False.
                    for nh in range(num_heads):
                        weight_pt[0][nh][-1] = torch.zeros_like(
                            weight_pt[0][nh][-1], device=device, dtype=dtype)
                        weight_npt[0][nh][-1] = torch.zeros_like(
                            weight_npt[0][nh][-1], device=device, dtype=dtype)

        if dtype == torch.half:
            torch.testing.assert_close(ypt, ynpt, atol=1e-3, rtol=1e-3)
        else:
            # High rtol seems necessary for
            # test_native_multihead_attention_cpu_float32 on Windows,
            # otherwise 2e-4 would likely be fine.
            torch.testing.assert_close(ypt, ynpt, atol=2e-5, rtol=2e-3)

        if need_weights:
            torch.testing.assert_close(weight_pt, weight_npt)
        else:
            self.assertEqual(weight_pt, weight_npt)
            self.assertEqual(weight_pt, weight_npt)