Exemplo n.º 1
0
 def test_nested_constructor(self):
     """Smoke-test ``utils.gen_nested_tensor`` with every constructor.

     For each constructor yielded by ``_iter_constructors()`` this builds
     nested tensors for nested dims 1 through ``num_nested_tensor``
     (inclusive). Nothing is asserted; the test passes as long as
     construction does not raise.
     """
     for constructor in _iter_constructors():
         # TODO: Currently only supporting nested dim 1
         num_nested_tensor = 1
         # TODO: Shouldn't be constructable
         # The bound is inclusive so nested dim 1 is really exercised:
         # ``range(1, 1)`` is empty and previously made this loop a no-op.
         # A plain loop is used because the results are only built for
         # the side effect of construction.
         for i in range(1, num_nested_tensor + 1):
             utils.gen_nested_tensor(i, i, 3, constructor=constructor)
Exemplo n.º 2
0
    def test_pin_memory(self):
        """Verify ``pin_memory`` on nested tensors.

        ``pin_memory()`` must return a pinned result while its input stays
        unpinned. The plain tensors a nested tensor was built from must stay
        unpinned. Constituents unbound from the pinned result must be pinned,
        and those unbound from the unpinned input must not be.
        """
        # pin_memory should be applicable to a generated nested tensor.
        generated = utils.gen_nested_tensor(1, 4, 3)
        generated_pinned = generated.pin_memory()

        # The result is pinned while the tensor it came from stays unpinned.
        self.assertFalse(generated.is_pinned())
        self.assertTrue(generated_pinned.is_pinned())

        left = torch.randn(1, 2)
        right = torch.randn(2, 3)
        built = nestedtensor.as_nested_tensor([left, right])
        for source_tensor in (left, right):
            self.assertFalse(source_tensor.is_pinned())

        # Pinning a nested tensor built from plain tensors behaves the same.
        built_pinned = built.pin_memory()
        self.assertFalse(built.is_pinned())
        self.assertTrue(built_pinned.is_pinned())

        # Only the constituents of the pinned result are pinned: the original
        # plain tensors and the constituents of `built` stay unpinned.
        pinned_first, pinned_second = built_pinned.unbind()
        plain_first, plain_second = built.unbind()
        for source_tensor in (left, right):
            self.assertFalse(source_tensor.is_pinned())
        for part in (pinned_first, pinned_second):
            self.assertTrue(part.is_pinned())
        for part in (plain_first, plain_second):
            self.assertFalse(part.is_pinned())
Exemplo n.º 3
0
 def test_nested_constructor(self):
     """Smoke-test ``utils.gen_nested_tensor`` with every constructor.

     For each constructor yielded by ``_iter_constructors()`` this builds
     nested tensors for nested dims 1 through ``num_nested_tensor - 1``.
     Nothing is asserted; the test passes as long as construction does not
     raise.
     """
     for constructor in _iter_constructors():
         num_nested_tensor = 3
         # TODO: Shouldn't be constructable
         # A plain loop rather than a list comprehension: the tensors are only
         # built for the side effect of construction, so no throwaway list is
         # needed.
         for i in range(1, num_nested_tensor):
             utils.gen_nested_tensor(i, i, 3, constructor=constructor)
Exemplo n.º 4
0
 def test_tensor_mask(self):
     """Round-trip a nested tensor through its tensor/mask representation.

     The tensor and mask from ``to_tensor_mask()`` are converted back with
     ``nested_tensor_from_tensor_mask`` twice: once passing the original
     ``nested_dim`` explicitly and once without it. Both results must compare
     equal to the original.
     """
     original = utils.gen_nested_tensor(2, 2, 2, size_low=1, size_high=2)
     data, mask = original.to_tensor_mask()

     # Rebuild with the original nested dimension passed explicitly.
     with_dim = nestedtensor.nested_tensor_from_tensor_mask(
         data, mask, nested_dim=original.nested_dim())
     self.assertEqual(original, with_dim)

     # Rebuild without passing nested_dim at all.
     without_dim = nestedtensor.nested_tensor_from_tensor_mask(data, mask)
     self.assertEqual(original, without_dim)
Exemplo n.º 5
0
 def test_nested_dim(self):
     """Check ``nested_dim()`` of generated nested tensors for every constructor.

     For each constructor, ``utils.gen_nested_tensor(d, d, 3, ...)`` must
     report a nested dimension of ``d`` for ``d`` in 2..4. Each constructor
     is also called once on a list holding a single 0-dim tensor; that
     result is not inspected.
     """
     for constructor in _iter_constructors():
         # Smoke check only: construction must not raise; the value is unused.
         constructor([torch.tensor(3)])
         for depth in range(2, 5):
             generated = utils.gen_nested_tensor(
                 depth, depth, 3, constructor=constructor)
             self.assertEqual(generated.nested_dim(), depth)
Exemplo n.º 6
0
 def test_nested_dim(self):
     """Check that ``utils.gen_nested_tensor(d, d, 3)`` has nested dim ``d``.

     This is checked for ``d`` in 2..4. The initial ``nested_tensor`` call
     only checks that a list holding a single 0-dim tensor can be wrapped;
     its result is not inspected.
     """
     # Smoke check only: construction must not raise; the value is unused.
     nestedtensor.nested_tensor([torch.tensor(3)])
     for depth in (2, 3, 4):
         generated = utils.gen_nested_tensor(depth, depth, 3)
         self.assertEqual(generated.nested_dim(), depth)
Exemplo n.º 7
0
 def test_nested(self):
     nt = utils.gen_nested_tensor(2, 2, 2)
     tensor, mask = nt.to_tensor_mask()