def test_ntftm_multi_tensor_true_mask(self):
    """All-true masks of every dimensionality (3 down to 0) reconstruct
    the same nested tensor from a (3, 1, 1) float tensor.

    Fixes the `extected_nt1` typo in the local name and folds the four
    identical mask cases into one loop.
    """
    expected_nt1 = nt.nested_tensor(
        [torch.tensor([[1]]), torch.tensor([[2]]), torch.tensor([[3]])])
    tensor = torch.tensor([[[1]], [[2]], [[3]]], dtype=torch.float)
    # Any all-true mask, regardless of its dim, must select everything.
    all_true_masks = [
        torch.tensor([[[True]], [[True]], [[True]]]),  # mask dim 3
        torch.tensor([[True], [True], [True]]),        # mask dim 2
        torch.tensor([True, True, True]),              # mask dim 1
        torch.tensor(True),                            # mask dim 0
    ]
    for mask in all_true_masks:
        res_nt = nt.nested_tensor_from_tensor_mask(tensor, mask)
        TestCase.assertEqual(self, expected_nt1, res_nt)
def test_ntftm_empty2(self):
    """An empty 2-d tensor used as its own mask: nested_dim 1 flattens,
    nested_dim 2 keeps the outer structure, nested_dim 3 raises."""
    # NOTE(review): this test name is defined again later in this file;
    # that later definition shadows this one at class-creation time.
    source = torch.tensor([[], []])
    flat_expected = nt.nested_tensor([
        torch.tensor([]),
        torch.tensor([]),
    ])
    nested_expected = nt.nested_tensor(
        [nt.nested_tensor([]), nt.nested_tensor([])])
    # The default nested_dim behaves like nested_dim=1.
    out = nt.nested_tensor_from_tensor_mask(source, source)
    TestCase.assertEqual(self, out, flat_expected)
    out = nt.nested_tensor_from_tensor_mask(source, source, nested_dim=1)
    TestCase.assertEqual(self, out, flat_expected)
    out = nt.nested_tensor_from_tensor_mask(source, source, nested_dim=2)
    TestCase.assertEqual(self, out, nested_expected)
    # nested_dim beyond the tensor's dimensionality must fail.
    self.assertRaises(
        RuntimeError,
        lambda: nt.nested_tensor_from_tensor_mask(
            source, source, nested_dim=3))
def test_ntftm_none_passed(self):
    """Passing None for the data and/or the mask raises RuntimeError."""
    self.assertRaises(
        RuntimeError,
        lambda: nt.nested_tensor_from_tensor_mask(None, None))
    # A valid (empty) data tensor with a None mask is rejected as well.
    self.assertRaises(
        RuntimeError,
        lambda: nt.nested_tensor_from_tensor_mask(torch.tensor([]), None))
def test_ntftm_multi_scalars(self):
    """A 1-d int tensor with a true mask becomes a nested tensor of
    0-dim scalars; an extra leading dim keeps the row intact."""
    data = torch.tensor([1, 2, 3])
    scalars_nt = nt.nested_tensor(
        [torch.tensor(1), torch.tensor(2), torch.tensor(3)],
        dtype=torch.int64)
    # Scalar (0-dim) all-true mask.
    out = nt.nested_tensor_from_tensor_mask(data, torch.tensor(True))
    TestCase.assertEqual(self, out, scalars_nt)
    # 1-dim all-true mask behaves identically.
    out = nt.nested_tensor_from_tensor_mask(data, torch.tensor([True]))
    TestCase.assertEqual(self, out, scalars_nt)
    # nested_dim beyond the data's dimensionality is rejected.
    self.assertRaises(
        RuntimeError,
        lambda: nt.nested_tensor_from_tensor_mask(
            data, torch.tensor([True]), nested_dim=2))
    # Extra dim: a (1, 3) tensor yields a single entry holding the row.
    data = torch.tensor([[1, 2, 3]])
    out = nt.nested_tensor_from_tensor_mask(data, torch.tensor(True))
    TestCase.assertEqual(
        self, out,
        nt.nested_tensor([torch.tensor([1, 2, 3])], dtype=torch.int64))
def test_ntftm_test_multi_tensor_mix_mask2(self):
    """A mixed-truth 3-d mask unpacked at nested_dim 1, 2 and 3; one
    past the tensor's dimensionality raises."""
    # NOTE(review): this name is redefined later in the file, so this
    # earlier definition is shadowed at class creation.
    tensor = torch.tensor([[[1, 2, 3]], [[4, 0, 0]]])
    mask = torch.tensor([[[True, True, True]], [[True, False, False]]])
    expected_by_dim = {
        1: nt.nested_tensor(
            [torch.tensor([[1, 2, 3]]), torch.tensor([[4]])]),
        2: nt.nested_tensor([
            nt.nested_tensor([torch.tensor([1, 2, 3])]),
            nt.nested_tensor([torch.tensor([4])])
        ]),
        3: nt.nested_tensor([
            nt.nested_tensor([
                nt.nested_tensor(
                    [torch.tensor(1), torch.tensor(2), torch.tensor(3)])
            ]),
            nt.nested_tensor([nt.nested_tensor([torch.tensor(4)])])
        ]),
    }
    for dim, expected in expected_by_dim.items():
        res_nt = nt.nested_tensor_from_tensor_mask(
            tensor, mask, nested_dim=dim)
        TestCase.assertEqual(self, expected, res_nt)
    self.assertRaises(
        RuntimeError,
        lambda: nt.nested_tensor_from_tensor_mask(
            tensor, mask, nested_dim=4))
def test_tensor_mask(self):
    """Round-trip: to_tensor_mask followed by nested_tensor_from_tensor_mask
    reproduces the original nested tensor, with and without an explicit
    nested_dim.

    The local was renamed from `nt` so it no longer shadows the module
    alias `nt` used throughout this file.
    """
    original = utils.gen_nested_tensor(2, 2, 2, size_low=1, size_high=2)
    tensor, mask = original.to_tensor_mask()
    restored = nestedtensor.nested_tensor_from_tensor_mask(
        tensor, mask, nested_dim=original.nested_dim())
    self.assertEqual(original, restored)
    # The default nested_dim must also reproduce the original.
    restored_default = nestedtensor.nested_tensor_from_tensor_mask(
        tensor, mask)
    self.assertEqual(original, restored_default)
def test_ntftm_single_tensor_all_false_mask(self):
    """An all-false 1-d mask drops the only row, whatever its width."""
    all_false = torch.tensor([False])
    for data in (torch.tensor([[1]]), torch.tensor([[1, 2, 3]])):
        res_nt = nt.nested_tensor_from_tensor_mask(data, all_false)
        TestCase.assertEqual(self, res_nt, nt.nested_tensor([]))
def test_ntftm_single_tensor_all_true_mask(self):
    """Scalar and 1-d all-true masks both keep the single row intact."""
    data = torch.tensor([[1]])
    expected = nt.nested_tensor([torch.tensor([1])])
    for mask in (torch.tensor(True), torch.tensor([True])):
        res_nt = nt.nested_tensor_from_tensor_mask(data, mask)
        TestCase.assertEqual(self, res_nt, expected)
def test_ntftm_empty_error(self):
    """Empty data with a non-empty mask (and the reverse) must raise
    with an explanatory message."""
    # The misspelling "emtpy" in both regexes deliberately matches the
    # library's actual error text — do not "fix" it here.
    self.assertRaisesRegex(
        RuntimeError,
        "Data tensor can't be emtpy if a mask has values.",
        lambda: nt.nested_tensor_from_tensor_mask(
            torch.tensor([]), torch.tensor([True])))
    self.assertRaisesRegex(
        RuntimeError,
        "Mask tensor can't be emtpy if a data tensor has values.",
        lambda: nt.nested_tensor_from_tensor_mask(
            torch.tensor([1]), torch.tensor([])))
def test_ntftm_test_multi_tensor_mix_mask2(self):
    """Float variant of the mixed-mask case, checked at nested_dim=1;
    nested_dim=4 exceeds the tensor's dim and raises."""
    # NOTE(review): this redefines a test name declared earlier in the
    # file; this later definition is the one that actually runs.
    data = torch.tensor([[[1, 2, 3]], [[4, 0, 0]]], dtype=torch.float)
    mask = torch.tensor([[[True, True, True]], [[True, False, False]]])
    expected = nt.nested_tensor(
        [torch.tensor([[1, 2, 3]]), torch.tensor([[4]])])
    res_nt = nt.nested_tensor_from_tensor_mask(data, mask, nested_dim=1)
    TestCase.assertEqual(self, expected, res_nt)
    self.assertRaises(
        RuntimeError,
        lambda: nt.nested_tensor_from_tensor_mask(
            data, mask, nested_dim=4))
def test_ntftm_empty(self):
    """An empty 1-d tensor as both data and mask yields an empty nested
    tensor of nested_dim 1; nested_dim=2 raises."""
    empty = torch.tensor([])
    # The default nested_dim and an explicit nested_dim=1 agree.
    for kwargs in ({}, {"nested_dim": 1}):
        res_nt = nt.nested_tensor_from_tensor_mask(empty, empty, **kwargs)
        TestCase.assertEqual(self, res_nt, nt.nested_tensor([]))
        TestCase.assertEqual(self, res_nt.nested_dim(), 1)
    self.assertRaises(
        RuntimeError,
        lambda: nt.nested_tensor_from_tensor_mask(
            empty, empty, nested_dim=2))
def test_ntftm_multi_tensor_all_false_mask(self):
    """All-false masks of increasing dim: low-dim masks drop everything,
    a dim-2 mask keeps the outer entry with empty contents."""
    # NOTE(review): this name is redefined later in the file (with extra
    # cases); this earlier definition is shadowed.
    data = torch.tensor([[[1], [2], [3]]])
    # dim-1 masks: the result is completely empty.
    res_nt = nt.nested_tensor_from_tensor_mask(data, torch.tensor([False]))
    TestCase.assertEqual(self, res_nt, nt.nested_tensor([]))
    res_nt = nt.nested_tensor_from_tensor_mask(
        data, torch.tensor([False, False, False]))
    TestCase.assertEqual(self, res_nt, nt.nested_tensor([]))
    # dim-2 mask: one entry survives, holding an empty tensor.
    res_nt = nt.nested_tensor_from_tensor_mask(
        data, torch.tensor([[False], [False], [False]]))
    TestCase.assertEqual(
        self, res_nt,
        nt.nested_tensor([torch.tensor([], dtype=data.dtype)],
                         dtype=torch.int64))
def test_ntftm_single_scalar(self):
    """A one-element float tensor becomes a nested tensor holding one
    scalar; a (1, 1) tensor keeps its inner dimension."""
    data = torch.tensor([1], dtype=torch.float)
    expected = nt.nested_tensor([torch.tensor(1)])
    for mask in (torch.tensor(True), torch.tensor([True])):
        res_nt = nt.nested_tensor_from_tensor_mask(data, mask)
        TestCase.assertEqual(self, res_nt, expected)
    # Extra dim
    data = torch.tensor([[1]], dtype=torch.float)
    res_nt = nt.nested_tensor_from_tensor_mask(data, torch.tensor(True))
    TestCase.assertEqual(
        self, res_nt, nt.nested_tensor([torch.tensor([1])]))
def test_grad_nt_from_tensor_mask(self):
    """Sum parity with plain tensors while autograd through the nested
    tensor is disabled: requires_grad is False and backward() raises."""
    # NOTE(review): this name is redefined later in the file; that later
    # (autograd-enabled) version shadows this one. Local renamed from
    # `nt` to avoid shadowing the module alias.

    def some_func(x):
        return torch.sum(x**2 + x**3)

    t1 = torch.tensor([1., 2., 3., 4.], requires_grad=True)
    t2 = torch.tensor([1., 2., 3.], requires_grad=True)
    t3 = torch.tensor([1., 2.], requires_grad=True)
    res1 = some_func(t1)
    res2 = some_func(t2)
    res3 = some_func(t3)
    total_t_sum = res1 + res2 + res3
    res1.backward()
    res2.backward()
    res3.backward()
    nt_tensor = torch.tensor([[1., 2., 3., 4.],
                              [1., 2., 3., 0.],
                              [1., 2., 0., 0.]])  # , requires_grad=True)
    nt_mask = torch.tensor([[True, True, True, True],
                            [True, True, True, False],
                            [True, True, False, False]])
    nested = nestedtensor.nested_tensor_from_tensor_mask(nt_tensor, nt_mask)
    # self.assertTrue(nested.requires_grad)  # TODO: Re-enable under autograd
    self.assertFalse(nested.requires_grad)
    nt_sum_res = some_func(nested)
    # nt_sum_res.backward()  # TODO: Re-enable under autograd
    self.assertRaises(RuntimeError, lambda: nt_sum_res.backward())
    self.assertEqual(total_t_sum, nt_sum_res)
def test_ntftm_empty3(self):
    """All-false masks: a 1-d case empties the result entirely, a 2-d
    case preserves the outer nesting with empty children."""
    res_nt = nt.nested_tensor_from_tensor_mask(
        torch.tensor([0]), torch.tensor(False))
    TestCase.assertEqual(self, res_nt, nt.nested_tensor([]))
    expected = nt.nested_tensor(
        [nt.nested_tensor([]), nt.nested_tensor([])])
    res_nt = nt.nested_tensor_from_tensor_mask(
        torch.tensor([[0], [0]]),
        torch.tensor([[False], [False]]),
        nested_dim=expected.nested_dim())
    TestCase.assertEqual(self, res_nt, expected)
def test_ntftm_mask_dim_cuda(self):
    """Round-trip to_tensor_mask at every mask_dim for a deeply nested
    CUDA half-precision tensor."""

    def leaf():
        # Fresh (2, 4) fp16 CUDA tensor for each nested entry.
        return torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]],
                            dtype=torch.float16,
                            device='cuda',
                            requires_grad=False)

    source = nt.nested_tensor([
        nt.nested_tensor([
            nt.nested_tensor([leaf()]),
            nt.nested_tensor([leaf()]),
            nt.nested_tensor([leaf()]),
        ])
    ])
    for mask_dim in range(source.dim()):
        t, m = source.to_tensor_mask(mask_dim=mask_dim)
        res_nt = nt.nested_tensor_from_tensor_mask(
            t, m, nested_dim=source.nested_dim())
        TestCase.assertEqual(self, source, res_nt)
        TestCase.assertEqual(self, res_nt.nested_dim(), source.nested_dim())
def test_grad_nt_from_tensor_mask(self):
    """Gradients through a nested tensor built from tensor+mask match
    per-tensor autograd results."""
    # NOTE(review): redefines a name declared earlier in this file; this
    # later definition is the one the test runner executes. Local renamed
    # from `nt` to avoid shadowing the module alias.

    def some_func(x):
        return torch.sum(x**2 + x**3)

    plain = [
        torch.tensor([1., 2., 3., 4.], requires_grad=True),
        torch.tensor([1., 2., 3.], requires_grad=True),
        torch.tensor([1., 2.], requires_grad=True),
    ]
    results = [some_func(t) for t in plain]
    total_t_sum = results[0] + results[1] + results[2]
    for r in results:
        r.backward()
    nt_tensor = torch.tensor(
        [[1., 2., 3., 4.], [1., 2., 3., 0.], [1., 2., 0., 0.]],
        requires_grad=True)
    nt_mask = torch.tensor([[True, True, True, True],
                            [True, True, True, False],
                            [True, True, False, False]])
    nested = nestedtensor.nested_tensor_from_tensor_mask(nt_tensor, nt_mask)
    self.assertTrue(nested.requires_grad)
    nt_sum_res = some_func(nested)
    nt_sum_res.backward()
    self.assertEqual(total_t_sum, nt_sum_res)
    # d/dx (x^2 + x^3) = 2x + 3x^2 per element.
    self.assertEqual(nested[0].grad, torch.tensor([5., 16., 33., 56.]))
    self.assertEqual(nested[1].grad, torch.tensor([5., 16., 33.]))
    self.assertEqual(nested[2].grad, torch.tensor([5., 16.]))
def test_ntgtm_multi_tensor_mix_mask(self):
    """Rows whose mask entry is False are dropped from the result."""
    res_nt = nt.nested_tensor_from_tensor_mask(
        torch.tensor([[1], [2], [3], [4]]),
        torch.tensor([True, False, False, True]))
    expected = nt.nested_tensor([torch.tensor([1]), torch.tensor([4])])
    TestCase.assertEqual(self, expected, res_nt)
def test_ntgtm_scalar_with_empty_mix_mask(self):
    """A masked-out row vanishes at the default nesting but becomes an
    empty nested entry at nested_dim=2."""
    data = torch.tensor([[0], [11]])
    keep_second = torch.tensor([False, True])
    # Default nested_dim: the masked row is removed outright.
    res_nt = nt.nested_tensor_from_tensor_mask(data, keep_second)
    TestCase.assertEqual(
        self,
        nt.nested_tensor([torch.tensor([11], dtype=torch.long)]),
        res_nt)
    # nested_dim=2: the masked row is kept as an empty nested tensor.
    res_nt = nt.nested_tensor_from_tensor_mask(
        data, keep_second, nested_dim=2)
    expected = nt.nested_tensor([
        nt.nested_tensor([]),
        nt.nested_tensor([torch.tensor(11, dtype=torch.long)])
    ])
    TestCase.assertEqual(self, expected, res_nt)
def test_ntgtm_multi_scalar_mix_mask(self):
    """Scalar entries survive or vanish according to the boolean mask."""
    res_nt = nt.nested_tensor_from_tensor_mask(
        torch.tensor([1, 2, 3, 4]),
        torch.tensor([True, False, False, True]))
    TestCase.assertEqual(
        self,
        nt.nested_tensor([torch.tensor(1), torch.tensor(4)]),
        res_nt)
def test_ntftm_multi_tensor_all_false_mask2(self):
    """A fully false 3-d mask yields empty tensors at each nesting level."""
    data = torch.tensor([[[1], [2], [3]]])
    mask = torch.tensor([[[False], [False], [False]]])
    # Default nesting: one entry holding an empty (3, 0) tensor.
    res_nt = nt.nested_tensor_from_tensor_mask(data, mask)
    TestCase.assertEqual(
        self, res_nt,
        nt.nested_tensor([torch.empty((3, 0), dtype=data.dtype)]))
    # nested_dim=2: three empty 1-d tensors inside one nested entry.
    res_nt = nt.nested_tensor_from_tensor_mask(data, mask, nested_dim=2)
    inner = [torch.tensor([], dtype=data.dtype) for _ in range(3)]
    TestCase.assertEqual(
        self, res_nt, nt.nested_tensor([nt.nested_tensor(inner)]))
def test_ntftm_test_multi_tensor_mix_mask(self):
    """Padded float rows are trimmed down to their masked lengths."""
    data = torch.tensor([[1, 2, 3], [4, 0, 0]], dtype=torch.float)
    mask = torch.tensor([[True, True, True], [True, False, False]])
    expected = nt.nested_tensor(
        [torch.tensor([1, 2, 3]), torch.tensor([4])])
    res_nt = nt.nested_tensor_from_tensor_mask(data, mask, nested_dim=1)
    TestCase.assertEqual(self, expected, res_nt)
def test_ntftm_multi_tensor_scalar_true_mask(self):
    """A scalar True mask keeps every row, for both a (3, 1) tensor and
    a (3, 1, 1) extra-dim variant."""
    # NOTE(review): this name is redefined later in the file (with more
    # nested_dim cases); this earlier definition is shadowed.
    data = torch.tensor([[1], [2], [3]])
    true_mask = torch.tensor(True)
    res_nt = nt.nested_tensor_from_tensor_mask(data, true_mask)
    TestCase.assertEqual(
        self, res_nt,
        nt.nested_tensor(
            [torch.tensor([1]), torch.tensor([2]), torch.tensor([3])],
            dtype=data.dtype))
    # Extra dim
    data = torch.tensor([[[1]], [[2]], [[3]]])
    res_nt = nt.nested_tensor_from_tensor_mask(data, true_mask)
    expected = nt.nested_tensor(
        [torch.tensor([[1]]), torch.tensor([[2]]), torch.tensor([[3]])],
        dtype=data.dtype)
    TestCase.assertEqual(self, res_nt, expected)
def test_ntftm_multi_tensor_all_false_mask(self):
    """All-false masks across mask dims and nested_dims; an out-of-range
    nested_dim raises."""
    # NOTE(review): redefines a name declared earlier in this file; this
    # later definition is the one that actually runs.
    data = torch.tensor([[[1], [2], [3]]])
    # dim-1 mask at the default nesting and at nested_dim=2: all dropped.
    mask = torch.tensor([False])
    res_nt = nt.nested_tensor_from_tensor_mask(data, mask)
    TestCase.assertEqual(self, res_nt, nt.nested_tensor([]))
    res_nt = nt.nested_tensor_from_tensor_mask(data, mask, nested_dim=2)
    TestCase.assertEqual(self, res_nt, nt.nested_tensor([]))
    # A longer dim-1 mask behaves the same.
    res_nt = nt.nested_tensor_from_tensor_mask(
        data, torch.tensor([False, False, False]))
    TestCase.assertEqual(self, res_nt, nt.nested_tensor([]))
    # dim-2 mask: outer structure survives with empty contents.
    mask = torch.tensor([[False], [False], [False]])
    res_nt = nt.nested_tensor_from_tensor_mask(data, mask)
    TestCase.assertEqual(
        self, res_nt,
        nt.nested_tensor([torch.tensor([], dtype=data.dtype)]))
    res_nt = nt.nested_tensor_from_tensor_mask(data, mask, nested_dim=3)
    TestCase.assertEqual(
        self, res_nt, nt.nested_tensor([nt.nested_tensor([])]))
    # nested_dim beyond the data's dimensionality must raise.
    self.assertRaises(
        RuntimeError,
        lambda: nt.nested_tensor_from_tensor_mask(
            data, mask, nested_dim=4))
def test_ntftm_multi_tensor_scalar_true_mask(self):
    """Scalar True mask at every valid nested_dim, plus the failure
    boundary one past the tensor's dimensionality."""
    # NOTE(review): redefines a name declared earlier in this file; this
    # later definition is the one the test runner executes.
    data = torch.tensor([[1], [2], [3]])
    true_mask = torch.tensor(True)
    res_nt = nt.nested_tensor_from_tensor_mask(data, true_mask)
    TestCase.assertEqual(
        self, res_nt,
        nt.nested_tensor(
            [torch.tensor([1]), torch.tensor([2]), torch.tensor([3])]))
    # Extra dim: (3, 1, 1) tensor checked at default, 2 and 3.
    data = torch.tensor([[[1]], [[2]], [[3]]])
    expected_by_dim = {
        None: nt.nested_tensor(
            [torch.tensor([[1]]), torch.tensor([[2]]), torch.tensor([[3]])]),
        2: nt.nested_tensor([
            nt.nested_tensor([torch.tensor([1])]),
            nt.nested_tensor([torch.tensor([2])]),
            nt.nested_tensor([torch.tensor([3])])
        ]),
        3: nt.nested_tensor([
            nt.nested_tensor([nt.nested_tensor([torch.tensor(1)])]),
            nt.nested_tensor([nt.nested_tensor([torch.tensor(2)])]),
            nt.nested_tensor([nt.nested_tensor([torch.tensor(3)])])
        ]),
    }
    for dim, expected in expected_by_dim.items():
        if dim is None:
            res_nt = nt.nested_tensor_from_tensor_mask(data, true_mask)
        else:
            res_nt = nt.nested_tensor_from_tensor_mask(
                data, true_mask, nested_dim=dim)
        TestCase.assertEqual(self, res_nt, expected)
    self.assertRaises(
        RuntimeError,
        lambda: nt.nested_tensor_from_tensor_mask(
            data, true_mask, nested_dim=4))
def test_ntftm_empty2(self):
    """An empty 2-d tensor used as its own mask yields a nested tensor
    of two empty tensors, with and without an explicit nested_dim.

    The original repeated both assertion pairs verbatim; the duplicated
    copy was removed.
    """
    # NOTE(review): this redefines a test name declared earlier in the
    # file, shadowing the earlier (more thorough) version.
    tensor = torch.tensor([[], []])
    expected = nt.nested_tensor([
        torch.tensor([]),
        torch.tensor([]),
    ])
    res_nt = nt.nested_tensor_from_tensor_mask(tensor, tensor)
    TestCase.assertEqual(self, res_nt, expected)
    res_nt = nt.nested_tensor_from_tensor_mask(tensor, tensor, nested_dim=1)
    TestCase.assertEqual(self, res_nt, expected)
def test_to_padded_tensor(self):
    """to_tensor_mask round-trips exactly, and to_padded_tensor writes
    the padding value into every masked-out slot."""
    data1 = torch.tensor([[[0.8413, 0.7325, 0.0000, 0.0000],
                           [0.0000, 0.0000, 0.0000, 0.0000],
                           [0.0000, 0.0000, 0.0000, 0.0000]],
                          [[0.6334, 0.5473, 0.3273, 0.0564],
                           [0.3023, 0.6826, 0.3519, 0.1804],
                           [0.8431, 0.1645, 0.1821, 0.9185]]])
    mask1 = torch.tensor([[[True, True, False, False],
                           [False, False, False, False],
                           [False, False, False, False]],
                          [[True, True, True, True],
                           [True, True, True, True],
                           [True, True, True, True]]])
    nt2 = nt.nested_tensor_from_tensor_mask(data1, mask1)
    # Round-trip: tensor and mask both come back unchanged.
    data2, mask2 = nt2.to_tensor_mask()
    self.assertEqual(data1, data2)
    self.assertEqual(mask1, mask2)
    # Masked-out positions must come back as the padding value (-10).
    data3 = nt2.to_padded_tensor(padding=-10)
    data1 = data1 + ~mask1 * -10
    self.assertEqual(data1, data3)
def test_ntftm_test_multi_tensor_mix_mask3(self):
    """A (1, 3, 1, 2, 4) padded tensor with a mixed mask, unpacked at
    nested_dim 2 through 5; nested_dim 1 and 6 must raise.

    Each expected_nt* below spells out the structure produced at the
    corresponding nested_dim: deeper nested_dims trade tensor dims for
    levels of nesting, trimming padded values as the mask dictates.
    """
    # nested_dim=2: inner 3-d tensors, trimmed per mask.
    expected_nt2 = nt.nested_tensor([
        nt.nested_tensor([
            torch.tensor([[[1, 2, 3, 4], [5, 6, 7, 8]]]),
            torch.tensor([[[0, 0], [3, 4]]]),
            torch.tensor([[[1]]])
        ])
    ])
    # nested_dim=3: one more nesting level, 2-d leaves.
    expected_nt3 = nt.nested_tensor([
        nt.nested_tensor([
            nt.nested_tensor([torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])]),
            nt.nested_tensor([torch.tensor([[0, 0], [3, 4]])]),
            nt.nested_tensor([torch.tensor([[1]])]),
        ])
    ])
    # nested_dim=4: 1-d leaves; the fully masked row becomes empty.
    expected_nt4 = nt.nested_tensor([
        nt.nested_tensor([
            nt.nested_tensor([
                nt.nested_tensor([
                    torch.tensor([1, 2, 3, 4]),
                    torch.tensor([5, 6, 7, 8])
                ])
            ]),
            nt.nested_tensor([
                nt.nested_tensor(
                    [torch.tensor([0, 0]), torch.tensor([3, 4])])
            ]),
            nt.nested_tensor([
                nt.nested_tensor([
                    torch.tensor([1]),
                    torch.tensor([], dtype=torch.long)
                ])
            ])
        ])
    ])
    # nested_dim=5: fully nested, scalar leaves.
    expected_nt5 = nt.nested_tensor([
        nt.nested_tensor([
            nt.nested_tensor([
                nt.nested_tensor([
                    nt.nested_tensor([
                        torch.tensor(1),
                        torch.tensor(2),
                        torch.tensor(3),
                        torch.tensor(4)
                    ]),
                    nt.nested_tensor([
                        torch.tensor(5),
                        torch.tensor(6),
                        torch.tensor(7),
                        torch.tensor(8)
                    ]),
                ])
            ]),
            nt.nested_tensor([
                nt.nested_tensor([
                    nt.nested_tensor([torch.tensor(0), torch.tensor(0)]),
                    nt.nested_tensor([torch.tensor(3), torch.tensor(4)])
                ])
            ]),
            nt.nested_tensor([
                nt.nested_tensor([
                    nt.nested_tensor([torch.tensor(1)]),
                    nt.nested_tensor([])
                ])
            ])
        ])
    ])
    # Zero-padded source data and the mask selecting the valid values.
    tensor = torch.tensor([[
        [[[1, 2, 3, 4], [5, 6, 7, 8]]],
        [[[0, 0, 0, 0], [3, 4, 0, 0]]],
        [[[1, 0, 0, 0], [0, 0, 0, 0]]],
    ]])
    mask = torch.tensor([[[[[True, True, True, True],
                            [True, True, True, True]]],
                          [[[True, True, False, False],
                            [True, True, False, False]]],
                          [[[True, False, False, False],
                            [False, False, False, False]]]]])
    # nested_dim=1 is rejected for this tensor/mask combination.
    self.assertRaises(
        RuntimeError,
        lambda: nt.nested_tensor_from_tensor_mask(
            tensor, mask, nested_dim=1))
    res_nt = nt.nested_tensor_from_tensor_mask(tensor, mask, nested_dim=2)
    TestCase.assertEqual(self, expected_nt2, res_nt)
    res_nt = nt.nested_tensor_from_tensor_mask(tensor, mask, nested_dim=3)
    TestCase.assertEqual(self, expected_nt3, res_nt)
    res_nt = nt.nested_tensor_from_tensor_mask(tensor, mask, nested_dim=4)
    TestCase.assertEqual(self, expected_nt4, res_nt)
    res_nt = nt.nested_tensor_from_tensor_mask(tensor, mask, nested_dim=5)
    TestCase.assertEqual(self, expected_nt5, res_nt)
    # nested_dim=6 exceeds the tensor's dimensionality.
    self.assertRaises(
        RuntimeError,
        lambda: nt.nested_tensor_from_tensor_mask(
            tensor, mask, nested_dim=6))
def test_ntftm_single_scalar_error(self):
    """A 0-dim (scalar) data tensor is rejected outright."""
    self.assertRaisesRegex(
        RuntimeError,
        "Can't construct nested tensor from a scalar.",
        lambda: nt.nested_tensor_from_tensor_mask(
            torch.tensor(1), torch.tensor(True)))
def test_ntftm_single_scalar_mask_false(self):
    """A scalar False mask empties a one-element uint8 tensor."""
    res_nt = nt.nested_tensor_from_tensor_mask(
        torch.tensor([1], dtype=torch.uint8), torch.tensor(False))
    TestCase.assertEqual(self, res_nt, nt.nested_tensor([]))