def test_nested_constructor(self):
    """Call ``utils.gen_nested_tensor`` once per constructor and level.

    The generated tensors are discarded; nothing is asserted.

    NOTE(review): with the upper bound fixed at 1, ``range(1, 1)`` is empty,
    so no tensor is ever generated and this test currently exercises nothing
    beyond iterating ``_iter_constructors()``. Another ``test_nested_constructor``
    appears nearby; if both live in the same class, the later one shadows this.
    """
    # TODO: Currently only supporting nested dim 1
    upper = 1
    for make in _iter_constructors():
        # TODO: Shouldn't be constructable
        for level in range(1, upper):
            utils.gen_nested_tensor(level, level, 3, constructor=make)
def test_pin_memory(self):
    """Check that ``pin_memory`` returns a pinned copy and leaves inputs unpinned.

    Covers two nested tensors:
    - one generated with ``utils.gen_nested_tensor``;
    - one built from plain tensors with ``nestedtensor.as_nested_tensor``.

    Using ``unbind``, it then checks that only the constituents of the pinned
    copy report ``is_pinned()``.
    """
    # Pinned (page-locked) host memory needs a CUDA runtime. On CPU-only builds
    # pin_memory() raises, so skip instead of reporting an error.
    if not torch.cuda.is_available():
        self.skipTest("pin_memory requires CUDA")

    # Check if it can be applied widely
    nt = utils.gen_nested_tensor(1, 4, 3)
    nt1 = nt.pin_memory()
    # Make sure it's actually a copy
    self.assertFalse(nt.is_pinned())
    self.assertTrue(nt1.is_pinned())

    a1 = torch.randn(1, 2)
    a2 = torch.randn(2, 3)
    nt2 = nestedtensor.as_nested_tensor([a1, a2])
    self.assertFalse(a1.is_pinned())
    self.assertFalse(a2.is_pinned())

    # Double check property transfers
    nt3 = nt2.pin_memory()
    self.assertFalse(nt2.is_pinned())
    self.assertTrue(nt3.is_pinned())

    # Check whether pinned memory is applied to constituents
    # and relevant constituents only.
    a3, a4 = nt3.unbind()
    a5, a6 = nt2.unbind()
    # The original inputs must remain unpinned after pinning the nested copy.
    self.assertFalse(a1.is_pinned())
    self.assertFalse(a2.is_pinned())
    self.assertTrue(a3.is_pinned())
    self.assertTrue(a4.is_pinned())
    self.assertFalse(a5.is_pinned())
    self.assertFalse(a6.is_pinned())
def test_nested_constructor(self):
    """Smoke-test ``utils.gen_nested_tensor`` with every constructor.

    For each constructor, calls the generator with ``level`` in 1..2 as its
    first two arguments. Results are discarded, so this only checks that
    generation does not raise.
    """
    limit = 3
    for make in _iter_constructors():
        # TODO: Shouldn't be constructable
        for level in range(1, limit):
            utils.gen_nested_tensor(level, level, 3, constructor=make)
def test_tensor_mask(self):
    """Round-trip a generated nested tensor through ``to_tensor_mask``.

    The (tensor, mask) pair is turned back into a nested tensor twice:
    - once passing the original ``nested_dim`` explicitly;
    - once relying on the default.

    Both results must equal the source.
    """
    source = utils.gen_nested_tensor(2, 2, 2, size_low=1, size_high=2)
    data, mask = source.to_tensor_mask()

    depth = source.nested_dim()
    rebuilt_explicit = nestedtensor.nested_tensor_from_tensor_mask(
        data, mask, nested_dim=depth)
    self.assertEqual(source, rebuilt_explicit)

    rebuilt_default = nestedtensor.nested_tensor_from_tensor_mask(data, mask)
    self.assertEqual(source, rebuilt_default)
def test_nested_dim(self):
    """Check ``nested_dim()`` for every constructor at nesting levels 1 through 4.

    Level 1 uses a flat list holding one scalar tensor. Levels 2-4 come from
    ``utils.gen_nested_tensor(i, i, 3)``, whose ``nested_dim()`` is asserted
    to be ``i``.
    """
    for constructor in _iter_constructors():
        # A flat list of tensors is expected to be one level of nesting;
        # without this assertion the constructed value would go unchecked.
        nt = constructor([torch.tensor(3)])
        self.assertEqual(nt.nested_dim(), 1)
        for i in range(2, 5):
            nt = utils.gen_nested_tensor(i, i, 3, constructor=constructor)
            self.assertEqual(nt.nested_dim(), i)
def test_nested_dim(self):
    """Check ``nested_dim()`` at nesting levels 1 through 4.

    Level 1 uses ``nestedtensor.nested_tensor`` on a flat list holding one
    scalar tensor. Levels 2-4 come from ``utils.gen_nested_tensor(i, i, 3)``,
    whose ``nested_dim()`` is asserted to be ``i``.
    """
    # A flat list of tensors is expected to be one level of nesting;
    # without this assertion the constructed value would go unchecked.
    nt = nestedtensor.nested_tensor([torch.tensor(3)])
    self.assertEqual(nt.nested_dim(), 1)
    for i in range(2, 5):
        nt = utils.gen_nested_tensor(i, i, 3)
        self.assertEqual(nt.nested_dim(), i)
def test_nested(self):
    """Smoke-test ``to_tensor_mask`` on a generated nested tensor.

    NOTE(review): in the code visible here, ``tensor`` and ``mask`` are never
    inspected, so this only checks that the conversion does not raise —
    consider asserting on the result.
    """
    nt = utils.gen_nested_tensor(2, 2, 2)
    tensor, mask = nt.to_tensor_mask()