def test_repr_string(self):
    a = nested_tensor([])
    expected = "nested_tensor([" "\n\n])"
    self.assertEqual(str(a), expected)
    self.assertEqual(repr(a), expected)

    a = nested_tensor([torch.tensor(1.0)])
    expected = "nested_tensor([" "\n tensor(1.)" "\n])"
    self.assertEqual(str(a), expected)
    self.assertEqual(repr(a), expected)

    a = nested_tensor([torch.tensor([[1, 2]]), torch.tensor([[4, 5]])])
    expected = ("nested_tensor([" "\n tensor([[1, 2]])" ","
                "\n tensor([[4, 5]])" "\n])")
    self.assertEqual(str(a), expected)
    self.assertEqual(repr(a), expected)

def test_embedding(self, device):
    inputs = [
        torch.randint(100, (L,), device=device, dtype=torch.int64)
        for L in torch.randint(5, 50, (8,))
    ]
    x = torch.nested_tensor(inputs, device=device, dtype=torch.int64)
    emb = torch.nn.Embedding(100, 8, device=device)
    y = emb(x)
    ys = y.unbind()
    for i, inp in enumerate(inputs):
        self.assertEqual(emb(inp), ys[i])

def test_to_then_from_padded_tensor_no_transform0213(self, device, dtype):
    t = torch.randn(4, 4, 4, device=device, dtype=dtype)
    ts = list(torch.unbind(t))
    ts[0] = ts[0][:-1]
    nt = torch.nested_tensor(ts, device=device, dtype=dtype)
    padded = nt.to_padded_tensor(0)
    nt_to = torch._nested_from_padded_and_nested_example(padded, nt)
    for (t1, t2) in zip(nt.unbind(), nt_to.unbind()):
        self.assertEqual(t1, t2)
    self.assertEqual(nt.device, nt_to.device)

def _test(size):
    t0 = torch.randn(2, size, device=device, dtype=dtype, requires_grad=False)
    t1 = torch.randn(2, size, device=device, dtype=dtype, requires_grad=False)
    ts = [t0, t1, t0, t1]
    nt = torch.nested_tensor(ts, device=device, dtype=dtype)
    layer_norm = torch.nn.LayerNorm(size, device=device, dtype=dtype)
    nt_result = nt._nested_tensor_layer_norm(
        layer_norm.weight, layer_norm.bias, 1e-5
    )
    for (nt_subresult, t) in zip(nt_result.unbind(), ts):
        t_result = layer_norm(t.reshape(1, -1, size).squeeze(0))
        self.assertEqual(nt_subresult, t_result)

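# Illustrative sketch (not part of the original suite): layer norm over the
# last dimension of a nested tensor is equivalent to normalizing each
# constituent independently. A hypothetical reference helper, assuming `torch`
# is imported as in the rest of this file:
def _reference_nested_layer_norm(nt, weight, bias, eps=1e-5):
    normalized = []
    for t in nt.unbind():
        mean = t.mean(dim=-1, keepdim=True)
        # biased variance, matching torch.nn.LayerNorm
        var = t.var(dim=-1, unbiased=False, keepdim=True)
        normalized.append((t - mean) / torch.sqrt(var + eps) * weight + bias)
    return torch.nested_tensor(normalized)
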
def random_nt(self, device, dtype, num_tensors, max_dims, min_dims=None):
    if min_dims is None:
        min_dims = tuple([0] * len(max_dims))
    ts1 = []
    for _ in range(num_tensors):
        tensor_dims = tuple([
            torch.randint(low=min_dim, high=max_dim, size=(1,)).item()
            for (min_dim, max_dim) in zip(min_dims, max_dims)
        ])
        t1 = torch.randn(tensor_dims, device=device, dtype=dtype)
        ts1.append(t1)
    return torch.nested_tensor(ts1, device=device, dtype=dtype)

def test_reshape(self, device, dtype):
    nt = self.random_nt(device, dtype, 4, (4, 4))
    # error case: empty shape
    self.assertRaisesRegex(
        RuntimeError,
        r"shape '\[\]' is invalid for a nested tensor",
        lambda: nt.reshape(())
    )
    # error case: empty nested tensor
    nt_empty = torch.nested_tensor([])
    self.assertRaisesRegex(
        RuntimeError,
        "empty nested tensor cannot be reshaped",
        lambda: nt_empty.reshape(-1)
    )
    # error case: invalid proposed shape for underlying tensors
    self.assertRaisesRegex(
        RuntimeError,
        r"invalid shape dimension -2",
        lambda: nt.reshape(-2, 2, 3)
    )
    self.assertRaisesRegex(
        RuntimeError,
        r"shape '\[.*\]' is invalid for input of size [0-9]+",
        lambda: nt.reshape(4, 2, 3)
    )
    # normal case
    x0 = torch.randn((2, 20), device=device, dtype=dtype)
    x1 = torch.randn((3, 20), device=device, dtype=dtype)
    nt = torch.nested_tensor([x0, x1])
    pt = nt.to_padded_tensor(0.0)
    self.assertRaisesRegex(
        RuntimeError,
        r"for now reshape cannot change the implicit batch dimension",
        lambda: nt.transpose(-1, -2).reshape(40, -1)
    )
    # inherit only the ragged dimension
    # (2, 20) -> (2, 5, 4)
    # (3, 20) -> (3, 5, 4)
    nt1 = nt.reshape(2, -1, 5, 4)
    # (2, 3, 20) -> (2, 3, 5, 4)
    pt1 = pt.reshape(2, -1, 5, 4)
    self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1)
    # also inherit regular dimension
    nt2 = nt1.reshape(2, -1, -1, 2, 2)
    pt2 = pt1.reshape(2, -1, 5, 2, 2)
    self.assertEqual(noncontiguous_to_padded_tensor(nt2), pt2)

def test_bmm(self, device, dtype):
    # error case: not 3D tensors
    nt0 = torch.nested_tensor([])
    nt1 = torch.nested_tensor([torch.randn(2), torch.randn(3)])
    nt2 = torch.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))])
    self.assertRaisesRegex(RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt0))
    self.assertRaisesRegex(RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt1))
    self.assertRaisesRegex(RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt2))
    self.assertRaisesRegex(RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt0))
    self.assertRaisesRegex(RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt1))
    self.assertRaisesRegex(RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt2))
    self.assertRaisesRegex(RuntimeError, "batch2 must be a 3D tensor", lambda: nt2.bmm(nt0))
    self.assertRaisesRegex(RuntimeError, "batch2 must be a 3D tensor", lambda: nt2.bmm(nt1))
    # error case: incompatible batch size
    nt0 = torch.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))])
    nt1 = torch.nested_tensor([torch.randn((4, 6)),
                               torch.randn((4, 5)),
                               torch.randn((4, 7))])
    self.assertRaisesRegex(
        RuntimeError,
        "Expected size for the 1st dimension of batch2 tensor to be: 2 but got: 3.",
        lambda: nt0.bmm(nt1)
    )
    self.assertRaisesRegex(
        RuntimeError,
        "Expected size for the 1st dimension of batch2 tensor to be: 3 but got: 2.",
        lambda: nt1.bmm(nt0)
    )
    # error case: underlying matrices cannot be multiplied
    nt0 = torch.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))])
    self.assertRaisesRegex(
        RuntimeError,
        r"0-th nested matrices in batch cannot be multiplied \(2x4 and 2x4\)",
        lambda: nt0.bmm(nt0)
    )
    # normal nested tensor
    nt0 = torch.nested_tensor([torch.randn((2, 4)), torch.randn((3, 7))])
    nt1 = torch.nested_tensor([torch.randn((4, 6)), torch.randn((7, 5))])
    actual = nt0.bmm(nt1)
    expect = nt0.to_padded_tensor(0.0).bmm(nt1.to_padded_tensor(0.0))
    self.assertEqual(actual.to_padded_tensor(0.0), expect)

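# Illustrative sketch (not part of the original suite): bmm on nested tensors
# is an independent matrix product per constituent pair, which is what the
# zero-padded comparison above exploits. A hypothetical reference, assuming
# `torch` is imported as in the rest of this file:
def _reference_nested_bmm(nt0, nt1):
    # multiply the i-th matrix of nt0 with the i-th matrix of nt1
    return torch.nested_tensor(
        [a.mm(b) for (a, b) in zip(nt0.unbind(), nt1.unbind())]
    )
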
def test_softmax(self, device, dtype):
    # normal nested tensor
    ntensors = 4
    nt = self.random_nt(device, dtype, ntensors, (4, 4))
    # error case: softmax across nested dimension
    self.assertRaisesRegex(
        RuntimeError,
        "Cannot apply softmax across nested dimension 0",
        lambda: torch.nn.functional.softmax(nt, 0)
    )
    self.assertRaisesRegex(
        RuntimeError,
        "Cannot apply softmax across nested dimension 0",
        lambda: torch.nn.functional.softmax(nt, -3)
    )
    # error case: dimension out of range
    self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, 3))
    self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, -4))
    # normal case: should equal softmax of the tensor padded with -inf
    softmaxer = torch.nn.Softmax(1)
    y0 = softmaxer(nt)
    y1 = torch.nn.functional.softmax(nt, 1)
    self.nt_equal(y0, y1)
    pt = nt.to_padded_tensor(float("-inf"))
    # if an entire slice is padded, softmax returns 0.0 / 0.0 = nan;
    # however, physically speaking that should be 0.0
    expect = torch.nn.functional.softmax(pt, 1).nan_to_num_(0.0)
    self.assertEqual(y0.to_padded_tensor(0.0), expect)
    # edge case: empty nested tensor
    nt0 = torch.nested_tensor([])
    y = torch.nn.functional.softmax(nt0, 1)
    self.nt_equal(nt0, y)
    # edge case: nesting scalars
    nt1 = torch.nested_tensor([torch.tensor(0.0), torch.tensor(1.0)])
    self.assertRaises(RuntimeError, lambda: torch.nn.functional.softmax(nt1, 0))
    self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt1, 1))

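# Illustrative sketch (not part of the original suite): softmax over a
# non-nested dimension acts on each constituent independently, which is why
# padding with -inf (whose softmax contribution is 0) reproduces it above.
# A hypothetical reference, assuming `torch` is imported as in the rest of
# this file:
def _reference_nested_softmax(nt, dim):
    # dim is relative to the nested tensor; each constituent loses the
    # leading (batch) dimension, hence dim - 1
    return torch.nested_tensor(
        [torch.nn.functional.softmax(t, dim - 1) for t in nt.unbind()]
    )
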
def test_to_padded_tensor_dim4(self, device, dtype):
    ts = [
        torch.randn(16, 21, 13, device=device, dtype=dtype),
        torch.randn(24, 32, 14, device=device, dtype=dtype),
        torch.randn(40, 53, 16, device=device, dtype=dtype),
    ]
    nt = torch.nested_tensor(ts, device=device, dtype=dtype)
    pad = 42
    correct_output = []
    for t in ts:
        # ts[2] is the largest constituent in every dimension,
        # so each padded slice takes its shape
        next_output = torch.ones_like(ts[2]) * pad
        correct_output.append(next_output)
        next_output[:t.size(0), :t.size(1), :t.size(2)].copy_(t)
    correct_output = torch.stack(correct_output)
    padded = nt.to_padded_tensor(pad)
    self.assertEqual(padded, correct_output)

def test_to_padded_tensor_simple(self, device, dtype):
    t = torch.randn(4, 4, 4, device=device, dtype=dtype)
    ts = list(torch.unbind(t))
    ts[0] = ts[0][:-1]
    nt = torch.nested_tensor(ts, device=device, dtype=dtype)
    for padding_value in (0, 1):
        padded = nt.to_padded_tensor(padding_value)
        correct_output = t.clone()
        if padding_value == 0:
            correct_output[0][-1] = torch.zeros_like(correct_output[0][-1])
        else:
            correct_output[0][-1] = torch.ones_like(correct_output[0][-1])
        self.assertEqual(padded, correct_output)
        self.assertEqual(padded.device, torch.device(device))
        self.assertEqual(padded.dtype, dtype)

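# Illustrative sketch (not part of the original suite): to_padded_tensor pads
# every constituent up to the per-dimension maximum with padding_value. A
# hypothetical reference for constituents sharing the same ndim, assuming
# `torch` is imported as in the rest of this file:
def _reference_to_padded(nt, padding_value):
    tensors = nt.unbind()
    max_shape = [max(t.size(d) for t in tensors) for d in range(tensors[0].dim())]
    out = torch.full([len(tensors)] + max_shape, padding_value,
                     device=tensors[0].device, dtype=tensors[0].dtype)
    for i, t in enumerate(tensors):
        # copy each constituent into the top-left corner of its padded slice
        out[i][tuple(slice(0, s) for s in t.shape)].copy_(t)
    return out
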
def test_nested_tensor_linear_backward(self):
    a = torch.randn(1, 2, requires_grad=False)
    b = torch.randn(2, 2, requires_grad=False)
    c = torch.randn(3, 2, requires_grad=False)
    weight = torch.randn(2, 2, requires_grad=True)
    bias = torch.randn(2, requires_grad=True)
    nt = torch.nested_tensor([a, b, c])
    out = torch.functional.F.linear(nt, weight, bias)
    out.backward(out.clone())
    assert weight.grad is not None
    assert bias.grad is not None
    assert a.grad is None
    assert b.grad is None
    assert c.grad is None

def test_nested_tensor(self):
    self.assertRaises(TypeError, lambda: nested_tensor([3.0]))
    self.assertRaises(TypeError, lambda: nested_tensor(torch.tensor([3.0])))
    self.assertRaises(TypeError, lambda: nested_tensor(4.0))

def _test_fn(unbind_fn):
    a = torch.rand(3, 2)
    b = torch.rand(2, 3)
    nt = nested_tensor([a, b])
    self.assertRaises(RuntimeError, lambda: unbind_fn(nt, 1))

def test_nested_tensor_mul_in_place(self, device, dtype):
    (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
    ref = torch.nested_tensor([t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())])
    nt1 *= nt2
    self.nt_equal(ref, nt1)

def test_nested_tensor_add(self, device, dtype):
    (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
    ref = torch.nested_tensor([t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())])
    out = nt1 + nt2
    self.nt_equal(ref, out)

def test_device_checks(self, device):
    nt = torch.nested_tensor([], device=device)
    is_cuda = 'cuda' in str(device)
    self.assertEqual(nt.is_cuda, is_cuda)

def test_dropout(self, device, dtype):
    # edge case: empty nested tensor
    nt0 = torch.nested_tensor([])
    y = torch.nn.functional.dropout(nt0, 0.5)
    self.nt_equal(nt0, y)
    # normal nested tensor
    ntensors = 4
    nt = self.random_nt(device, dtype, ntensors, (4, 4))
    # edge case: invalid dropout probability
    self.assertRaises(ValueError, lambda: torch.nn.Dropout(-0.1))
    self.assertRaises(ValueError, lambda: torch.nn.Dropout(1.1))
    self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, -0.1))
    self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, 1.1))
    # edge case: no dropout
    dropouter = torch.nn.Dropout(0.0)
    y0 = dropouter(nt)
    y1 = torch.nn.functional.dropout(nt, 0.0)
    self.nt_equal(nt, y0)
    self.nt_equal(nt, y1)
    # edge case: all dropout
    dropouter = torch.nn.Dropout(1.0)
    y0 = dropouter(nt)
    y1 = torch.nn.functional.dropout(nt, 1.0)
    nt0 = nt.clone()
    for i in range(ntensors):
        nt0[i].fill_(0.0)
    self.nt_equal(nt0, y0)
    self.nt_equal(nt0, y1)
    # normal case: normal dropout
    p = 0.2
    y = torch.nn.functional.dropout(nt, p)
    expect = nt.clone()
    for i in range(ntensors):
        actual_tensor = y[i].view(-1)
        expect_tensor = expect[i].view(-1)
        for j in range(actual_tensor.shape[0]):
            if actual_tensor[j].item() == 0.0:
                expect_tensor[j] = 0.0
            else:
                expect_tensor[j] /= 1.0 - p
    self.nt_equal(y, expect)
    with freeze_rng_state():
        dropouter = torch.nn.Dropout(p)
        y0 = dropouter(nt)
    with freeze_rng_state():
        y1 = torch.nn.functional.dropout(nt, p)
    self.nt_equal(y0, y1)
    # inplace
    # In principle, having established the correctness of the functional path,
    # we could simply compare inplace against functional. In practice, the CUDA
    # functional path has its own implementation that skips `bernoulli_`, so
    # CUDA functional differs from CUDA inplace, causing a failure in
    # `test_dropout_cuda_float64 (__main__.TestNestedTensorDeviceTypeCUDA)`
    # on `linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 2, 4, linux.4xlarge.nvidia.gpu)`.
    expect = nt.clone()
    torch.nn.functional.dropout(nt, p, inplace=True)
    for i in range(ntensors):
        actual_tensor = nt[i].view(-1)
        expect_tensor = expect[i].view(-1)
        for j in range(actual_tensor.shape[0]):
            if actual_tensor[j].item() == 0.0:
                expect_tensor[j] = 0.0
            else:
                expect_tensor[j] /= 1.0 - p
    self.nt_equal(nt, expect)

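# Illustrative sketch (not part of the original suite): the element-wise check
# above relies on "inverted" dropout semantics, where surviving entries are
# rescaled by 1 / (1 - p) at training time so the expectation is unchanged.
# A hypothetical reference for a single dense constituent, assuming `torch` is
# imported as in the rest of this file:
def _reference_inverted_dropout(t, p):
    # keep each entry independently with probability 1 - p, then rescale
    keep_mask = torch.rand_like(t) >= p
    return t * keep_mask / (1.0 - p)
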
def _create_nested_tensor_from_list(self, requires_grad=False):
    return torch.nested_tensor([
        torch.randn(1, 2, requires_grad=requires_grad),
        torch.randn(7, 8, requires_grad=requires_grad)
    ])

def grad_test_func(a, b, c):
    c = torch.nested_tensor([a, b, c])
    # This implicitly tests to_padded_tensor grads
    return c.to_padded_tensor(0)

def test_to_padded_tensor_on_empty_tensor(self):
    nt = torch.nested_tensor([])
    empty = nt.to_padded_tensor(4)
    self.assertEqual(empty, torch.tensor([]))

def grad_test_func(a, b, c, weight, bias=None):
    nt = torch.nested_tensor([a, b, c])
    # This implicitly tests to_padded_tensor grads
    d = torch.functional.F.linear(nt, weight, bias)
    return d.to_padded_tensor(0)

def _test_transform_bias_rescale_qkv_impl(self, device, dtype, use_nt, use_padding=False):
    tests = [
        (64, 4, 16, 8),
        # dim_per_head = 12 does not divide evenly by CPU vectorization length of 8
        (24, 2, 4, 2),
        # Make sure CUDA can handle small input sizes
        (2, 2, 2, 2),
        # dim_per_head = 6 does not divide evenly by CUDA vectorization length of 4,
        # causes alignment issues
        (24, 4, 4, 2),
        (48, 4, 16, 8),
    ]
    for (embed_dim, num_heads, bs, sl) in tests:
        with self.subTest(embed_dim=embed_dim, num_heads=num_heads, bs=bs, sl=sl):
            torch.manual_seed(9343)
            dense_x = x = torch.randn(bs, sl, 3 * embed_dim, device=device, dtype=dtype) * 10
            if use_padding:
                x[0][-1] = torch.full(x[0][-1].shape, float("-Inf"))
            if use_nt:
                xs = list(torch.unbind(x))
                if use_padding:
                    xs[0] = xs[0][:-1]
                x = torch.nested_tensor(xs, device=device, dtype=dtype)
            qkv = torch.nn.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype)
            # We have to use inference_mode here because q/k/v are
            # all views of the same Tensor, which autograd doesn't
            # like. This is fine because this function is only
            # exposed to Python for purposes of writing this test.
            with torch.inference_mode():
                (q, k, v) = torch._transform_bias_rescale_qkv(
                    x, qkv.bias, num_heads=num_heads)

                def simple_transform_bias_rescale_qkv(qkv, bias):
                    (q, k, v) = torch.split(qkv, embed_dim, dim=-1)
                    (q_bias, k_bias, v_bias) = torch.split(bias, embed_dim, dim=-1)

                    def embiggen(x):
                        if not use_nt:
                            return x
                        b, t, d = x.size()
                        t = t + (8 - t % 8) % 8
                        newsize = (b, t, d)
                        new_x = torch.zeros(newsize, device=device, dtype=dtype)
                        new_x[:x.size()[0], :x.size()[1], :x.size()[2]] = x
                        return new_x
                    return tuple(
                        embiggen(x).reshape(
                            (bs, -1, num_heads, embed_dim // num_heads)
                        ).transpose(2, 1)
                        for x in (
                            (q + q_bias) / math.sqrt(embed_dim // num_heads),
                            (k + k_bias),
                            (v + v_bias),
                        ))
                correct_q, correct_k, correct_v = simple_transform_bias_rescale_qkv(
                    dense_x, qkv.bias)
                if use_nt and use_padding:
                    for t in (correct_q, correct_k, correct_v):
                        t[t == float("-Inf")] = 0
                self.assertEqual(q.size(), correct_q.size())
                torch.testing.assert_close(q, correct_q)
                torch.testing.assert_close(k, correct_k)
                torch.testing.assert_close(v, correct_v)

def _test_multihead_attention_impl(self, device, dtype, mode, use_nt,
                                   need_weights, average_attn_weights,
                                   use_padding=False, pad_all=False):
    embed_dim = 64
    num_heads = 4
    bs = 16
    sl = 8

    q = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype) * 10
    if use_padding:
        if pad_all:
            for q_i in q:
                q_i[-1] = torch.zeros_like(q[0][-1], device=device, dtype=dtype)
            mask = torch.zeros(q.shape[:-1], device=device, dtype=torch.bool)
            for mask_i in mask:
                mask_i[-1] = True
        else:
            q[0][-1] = torch.zeros_like(q[0][-1], device=device, dtype=dtype)
            mask = torch.zeros(q.shape[:-1], device=device, dtype=torch.bool)
            mask[0][-1] = True
    if mode == "self":
        k = q
        v = q
    elif mode == "encdec":
        k = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype) * 10
        v = k
    elif mode == "generic":
        k = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype) * 10
        v = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype) * 10
    else:
        self.fail(f"invalid mode `{mode}`!")

    qkv = torch.nn.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype)
    proj = torch.nn.Linear(embed_dim, embed_dim, device=device, dtype=dtype)

    pt = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True,
                                     device=device, dtype=dtype)
    pt.in_proj_weight = qkv.weight
    pt.in_proj_bias = qkv.bias
    pt.out_proj.weight = proj.weight
    pt.out_proj.bias = proj.bias

    class NativeMHA(torch.nn.Module):
        def __init__(self, embed_dim, num_heads, qkv, proj):
            super().__init__()
            self.qkv = qkv
            self.proj = proj
            self.embed_dim = embed_dim
            self.num_heads = num_heads

        def forward(self, q, k, v, key_padding_mask):
            return torch._native_multi_head_attention(
                q, k, v, self.embed_dim, self.num_heads,
                self.qkv.weight, self.qkv.bias,
                self.proj.weight, self.proj.bias,
                key_padding_mask,
                need_weights=need_weights,
                average_attn_weights=average_attn_weights,
            )

    npt = NativeMHA(embed_dim=embed_dim, num_heads=num_heads,
                    qkv=qkv, proj=proj).to(dtype)

    if device == "cuda":
        pt = pt.cuda()
        npt = npt.cuda()

    ypt, weight_pt = pt(
        q, k, v,
        need_weights=need_weights,
        average_attn_weights=average_attn_weights,
        key_padding_mask=mask if use_padding else None,
    )
    if use_nt:
        qs = list(torch.unbind(q))
        if use_padding:
            if pad_all:
                qs = [x[:-1] for x in qs]
            else:
                qs[0] = qs[0][:-1]
        q = torch.nested_tensor(qs, device=device, dtype=dtype)
        if mode == "self":
            k = v = q
        elif mode == "encdec":
            k = torch.nested_tensor(torch.unbind(k), device=device, dtype=dtype)
            v = k
        else:
            k = torch.nested_tensor(torch.unbind(k), device=device, dtype=dtype)
            v = torch.nested_tensor(torch.unbind(v), device=device, dtype=dtype)

    ynpt, weight_npt = npt(
        q, k, v,
        key_padding_mask=mask if use_padding and not use_nt else None)
    if use_nt:
        ynpt = ynpt.to_padded_tensor(0)
        if pad_all:
            ynpt_final = torch.zeros_like(ypt)
            ynpt_final[:, :ynpt.shape[1], :] = ynpt
            ynpt = ynpt_final

    def do_pad_all(tensors):
        for t in tensors:
            for t_i in t:
                t_i[-1] = torch.zeros_like(t_i[-1], device=device, dtype=dtype)

    # PyTorch implementation returns non-zero junk in the padding
    # locations; overwrite it so that the comparison works out.
    if use_padding:
        ypt[0][-1] = torch.zeros_like(ypt[0][-1], device=device, dtype=dtype)
        ynpt[0][-1] = torch.zeros_like(ynpt[0][-1], device=device, dtype=dtype)
        if pad_all:
            do_pad_all((ypt, ynpt))
        # Zero the last row of each TxT weight matrix
        if need_weights:
            if average_attn_weights:
                weight_pt[0][-1] = torch.zeros_like(weight_pt[0][-1], device=device, dtype=dtype)
                weight_npt[0][-1] = torch.zeros_like(weight_npt[0][-1], device=device, dtype=dtype)
                if pad_all:
                    do_pad_all((weight_pt, weight_npt))
            else:
                for nh in range(num_heads):
                    weight_pt[0][nh][-1] = torch.zeros_like(
                        weight_pt[0][nh][-1], device=device, dtype=dtype)
                    weight_npt[0][nh][-1] = torch.zeros_like(
                        weight_npt[0][nh][-1], device=device, dtype=dtype)

    if dtype == torch.half:
        torch.testing.assert_close(ypt, ynpt, atol=1e-3, rtol=1e-3)
    else:
        # High rtol seems necessary for
        # test_native_multihead_attention_cpu_float32 on Windows,
        # otherwise 2e-4 would likely be fine.
        torch.testing.assert_close(ypt, ynpt, atol=2e-5, rtol=2e-3)

    if need_weights:
        torch.testing.assert_close(weight_pt, weight_npt)
    else:
        self.assertEqual(weight_pt, weight_npt)