Ejemplo n.º 1
0
    def test_ntftm_multi_tensor_true_mask(self):
        """Check all-True masks of rank 3, 2, 1 and 0 against one expectation.

        For every mask rank, the nested tensor built from the (3, 1, 1) float
        data tensor is compared with a nested tensor of three 1x1 entries.
        """
        expected = nt.nested_tensor(
            [torch.tensor([[value]]) for value in (1, 2, 3)])

        data = torch.tensor([[[1]], [[2]], [[3]]], dtype=torch.float)

        all_true_masks = (
            torch.tensor([[[True]], [[True]], [[True]]]),  # mask dim 3
            torch.tensor([[True], [True], [True]]),  # mask dim 2
            torch.tensor([True, True, True]),  # mask dim 1
            torch.tensor(True),  # mask dim 0
        )
        for mask in all_true_masks:
            result = nt.nested_tensor_from_tensor_mask(data, mask)
            TestCase.assertEqual(self, expected, result)
    def test_ntftm_empty2(self):
        """Use a (2, 0) tensor as both data and mask.

        The default call and nested_dim=1 are compared with two empty
        tensors, nested_dim=2 with two empty nested tensors, and
        nested_dim=3 must raise RuntimeError.
        """
        empty_rows = torch.tensor([[], []])

        flat_expected = nt.nested_tensor(
            [torch.tensor([]), torch.tensor([])])
        deep_expected = nt.nested_tensor(
            [nt.nested_tensor([]), nt.nested_tensor([])])

        result = nt.nested_tensor_from_tensor_mask(empty_rows, empty_rows)
        TestCase.assertEqual(self, result, flat_expected)

        result = nt.nested_tensor_from_tensor_mask(
            empty_rows, empty_rows, nested_dim=1)
        TestCase.assertEqual(self, result, flat_expected)

        result = nt.nested_tensor_from_tensor_mask(
            empty_rows, empty_rows, nested_dim=2)
        TestCase.assertEqual(self, result, deep_expected)

        with self.assertRaises(RuntimeError):
            nt.nested_tensor_from_tensor_mask(
                empty_rows, empty_rows, nested_dim=3)
 def test_ntftm_none_passed(self):
     """None as the mask, or as both arguments, must raise RuntimeError."""
     with self.assertRaises(RuntimeError):
         nt.nested_tensor_from_tensor_mask(None, None)

     with self.assertRaises(RuntimeError):
         nt.nested_tensor_from_tensor_mask(torch.tensor([]), None)
Ejemplo n.º 4
0
    def test_ntftm_multi_scalars(self):
        """Build nested tensors of scalars from a 1-D integer tensor.

        A 0-dim True mask and a one-element True mask are both compared with
        the scalars 1, 2, 3; nested_dim=2 must raise RuntimeError. A (1, 3)
        data tensor with a 0-dim True mask is compared with one 3-vector.
        """
        def expected_scalars():
            # Fresh expectation per assertion, as in the inline originals.
            return nt.nested_tensor(
                [torch.tensor(value) for value in (1, 2, 3)],
                dtype=torch.int64)

        data = torch.tensor([1, 2, 3])

        scalar_mask = torch.tensor(True)
        result = nt.nested_tensor_from_tensor_mask(data, scalar_mask)
        TestCase.assertEqual(self, result, expected_scalars())

        vector_mask = torch.tensor([True])
        result = nt.nested_tensor_from_tensor_mask(data, vector_mask)
        TestCase.assertEqual(self, result, expected_scalars())

        with self.assertRaises(RuntimeError):
            nt.nested_tensor_from_tensor_mask(
                data, vector_mask, nested_dim=2)

        # Extra dim
        data = torch.tensor([[1, 2, 3]])
        scalar_mask = torch.tensor(True)
        result = nt.nested_tensor_from_tensor_mask(data, scalar_mask)
        expected_row = nt.nested_tensor(
            [torch.tensor([1, 2, 3])], dtype=torch.int64)
        TestCase.assertEqual(self, result, expected_row)
    def test_ntftm_test_multi_tensor_mix_mask2(self):
        """Rebuild a (2, 1, 3) padded tensor at nested_dim 1, 2 and 3.

        Each depth is compared with a hand-written nested tensor, and
        nested_dim=4 must raise RuntimeError.
        """
        depth1 = nt.nested_tensor(
            [torch.tensor([[1, 2, 3]]), torch.tensor([[4]])])

        depth2 = nt.nested_tensor([
            nt.nested_tensor([torch.tensor([1, 2, 3])]),
            nt.nested_tensor([torch.tensor([4])]),
        ])

        depth3 = nt.nested_tensor([
            nt.nested_tensor([
                nt.nested_tensor(
                    [torch.tensor(value) for value in (1, 2, 3)])
            ]),
            nt.nested_tensor([nt.nested_tensor([torch.tensor(4)])]),
        ])

        data = torch.tensor([[[1, 2, 3]], [[4, 0, 0]]])
        mask = torch.tensor([[[True, True, True]], [[True, False, False]]])

        expectations = (depth1, depth2, depth3)
        for nested_dim, expected in enumerate(expectations, start=1):
            result = nt.nested_tensor_from_tensor_mask(
                data, mask, nested_dim=nested_dim)
            TestCase.assertEqual(self, expected, result)

        with self.assertRaises(RuntimeError):
            nt.nested_tensor_from_tensor_mask(data, mask, nested_dim=4)
Ejemplo n.º 6
0
 def test_tensor_mask(self):
     """Round-trip a generated nested tensor through to_tensor_mask.

     The tensor/mask pair is converted back twice, once with the source's
     nested_dim passed explicitly and once with the default, and both
     results are compared with the source.
     """
     source = utils.gen_nested_tensor(2, 2, 2, size_low=1, size_high=2)
     data, mask = source.to_tensor_mask()

     with_depth = nestedtensor.nested_tensor_from_tensor_mask(
         data, mask, nested_dim=source.nested_dim())
     self.assertEqual(source, with_depth)

     default_depth = nestedtensor.nested_tensor_from_tensor_mask(data, mask)
     self.assertEqual(source, default_depth)
    def test_ntftm_single_tensor_all_false_mask(self):
        """A [False] mask on a one-row tensor is compared with an empty result.

        Checked for a (1, 1) and a (1, 3) data tensor.
        """
        for row in ([[1]], [[1, 2, 3]]):
            data = torch.tensor(row)
            mask = torch.tensor([False])
            result = nt.nested_tensor_from_tensor_mask(data, mask)
            TestCase.assertEqual(self, result, nt.nested_tensor([]))
    def test_ntftm_single_tensor_all_true_mask(self):
        """A True mask (0-dim, then 1-dim) keeps the single row of [[1]]."""
        data = torch.tensor([[1]])

        for mask_value in (True, [True]):
            mask = torch.tensor(mask_value)
            result = nt.nested_tensor_from_tensor_mask(data, mask)
            expected = nt.nested_tensor([torch.tensor([1])])
            TestCase.assertEqual(self, result, expected)
    def test_ntftm_empty_error(self):
        """An empty data or mask tensor paired with a non-empty one must fail.

        Both cases expect a RuntimeError whose message matches the given
        pattern.
        """
        empty_data = torch.tensor([])
        filled_mask = torch.tensor([True])
        with self.assertRaisesRegex(
                RuntimeError,
                "Data tensor can't be emtpy if a mask has values."):
            nt.nested_tensor_from_tensor_mask(empty_data, filled_mask)

        filled_data = torch.tensor([1])
        empty_mask = torch.tensor([])
        with self.assertRaisesRegex(
                RuntimeError,
                "Mask tensor can't be emtpy if a data tensor has values."):
            nt.nested_tensor_from_tensor_mask(filled_data, empty_mask)
Ejemplo n.º 10
0
    def test_ntftm_test_multi_tensor_mix_mask2(self):
        """Rebuild a (2, 1, 3) float padded tensor at nested_dim=1.

        The result is compared with rows [[1, 2, 3]] and [[4]]; nested_dim=4
        must raise RuntimeError.
        """
        expected = nt.nested_tensor(
            [torch.tensor([[1, 2, 3]]), torch.tensor([[4]])])

        data = torch.tensor([[[1, 2, 3]], [[4, 0, 0]]], dtype=torch.float)
        mask = torch.tensor([[[True, True, True]], [[True, False, False]]])

        result = nt.nested_tensor_from_tensor_mask(data, mask, nested_dim=1)
        TestCase.assertEqual(self, expected, result)

        with self.assertRaises(RuntimeError):
            nt.nested_tensor_from_tensor_mask(data, mask, nested_dim=4)
    def test_ntftm_empty(self):
        """Use an empty 1-D tensor as both data and mask.

        The default call and nested_dim=1 are each compared with an empty
        nested tensor and must report nested_dim() == 1; nested_dim=2 must
        raise RuntimeError.
        """
        empty = torch.tensor([])

        result = nt.nested_tensor_from_tensor_mask(empty, empty)
        TestCase.assertEqual(self, result, nt.nested_tensor([]))
        TestCase.assertEqual(self, result.nested_dim(), 1)

        result = nt.nested_tensor_from_tensor_mask(
            empty, empty, nested_dim=1)
        TestCase.assertEqual(self, result, nt.nested_tensor([]))
        TestCase.assertEqual(self, result.nested_dim(), 1)

        with self.assertRaises(RuntimeError):
            nt.nested_tensor_from_tensor_mask(empty, empty, nested_dim=2)
Ejemplo n.º 12
0
    def test_ntftm_multi_tensor_all_false_mask(self):
        """Apply all-False masks of growing shape to a (1, 3, 1) tensor.

        The [False] and [False, False, False] masks are compared with an
        empty nested tensor; the (3, 1) mask is compared with a nested
        tensor holding one empty tensor.
        """
        data = torch.tensor([[[1], [2], [3]]])

        for flags in ([False], [False, False, False]):
            mask = torch.tensor(flags)
            result = nt.nested_tensor_from_tensor_mask(data, mask)
            TestCase.assertEqual(self, result, nt.nested_tensor([]))

        mask = torch.tensor([[False], [False], [False]])
        result = nt.nested_tensor_from_tensor_mask(data, mask)
        expected = nt.nested_tensor(
            [torch.tensor([], dtype=data.dtype)], dtype=torch.int64)
        TestCase.assertEqual(self, result, expected)
Ejemplo n.º 13
0
    def test_ntftm_single_scalar(self):
        """Build a one-entry nested tensor from a single float value.

        A 1-D [1.] tensor with a 0-dim and then a 1-dim True mask is compared
        with one scalar entry; a (1, 1) tensor with a 0-dim True mask is
        compared with one length-1 vector entry.
        """
        data = torch.tensor([1], dtype=torch.float)

        for mask_value in (True, [True]):
            mask = torch.tensor(mask_value)
            result = nt.nested_tensor_from_tensor_mask(data, mask)
            expected = nt.nested_tensor([torch.tensor(1)])
            TestCase.assertEqual(self, result, expected)

        # Extra dim
        data = torch.tensor([[1]], dtype=torch.float)
        mask = torch.tensor(True)
        result = nt.nested_tensor_from_tensor_mask(data, mask)
        expected = nt.nested_tensor([torch.tensor([1])])
        TestCase.assertEqual(self, result, expected)
    def test_grad_nt_from_tensor_mask(self):
        """Compare sum(x**2 + x**3) over plain tensors and a nested tensor.

        With autograd currently disabled for this path, the nested tensor
        must not require grad and calling backward() on its result must
        raise RuntimeError. The forward value must still equal the sum over
        the three separate tensors.
        """
        def poly_sum(x):
            return torch.sum(x**2 + x**3)

        rows = ([1., 2., 3., 4.], [1., 2., 3.], [1., 2.])
        leaves = [torch.tensor(row, requires_grad=True) for row in rows]

        partials = [poly_sum(leaf) for leaf in leaves]
        reference_total = partials[0] + partials[1] + partials[2]

        for partial in partials:
            partial.backward()

        # requires_grad=True is intentionally left off the padded tensor.
        padded = torch.tensor([[1., 2., 3., 4.], [1., 2., 3., 0.],
                               [1., 2., 0., 0.]])
        valid = torch.tensor([[True, True, True, True],
                              [True, True, True, False],
                              [True, True, False, False]])

        nested = nestedtensor.nested_tensor_from_tensor_mask(padded, valid)
        # TODO: Re-enable under autograd (should become assertTrue).
        self.assertFalse(nested.requires_grad)

        nested_total = poly_sum(nested)
        # TODO: Re-enable under autograd (backward should then succeed).
        with self.assertRaises(RuntimeError):
            nested_total.backward()

        self.assertEqual(reference_total, nested_total)
    def test_ntftm_empty3(self):
        """Fully masked-out inputs are compared with empty nested tensors.

        A [0] tensor with a 0-dim False mask is compared with an empty nested
        tensor. A (2, 1) zero tensor with an all-False mask, rebuilt at the
        expectation's own nested_dim, is compared with two empty nested
        tensors.
        """
        data = torch.tensor([0])
        mask = torch.tensor(False)
        result = nt.nested_tensor_from_tensor_mask(data, mask)
        TestCase.assertEqual(self, result, nt.nested_tensor([]))

        data = torch.tensor([[0], [0]])
        mask = torch.tensor([[False], [False]])

        expected = nt.nested_tensor(
            [nt.nested_tensor([]), nt.nested_tensor([])])
        depth = expected.nested_dim()

        result = nt.nested_tensor_from_tensor_mask(
            data, mask, nested_dim=depth)
        TestCase.assertEqual(self, result, expected)
    def test_ntftm_mask_dim_cuda(self):
        """Round-trip a CUDA float16 nested tensor through every mask_dim.

        For each mask_dim in range(a.dim()), the nested tensor is converted
        with to_tensor_mask and rebuilt with the source's nested_dim; the
        result must equal the source and report the same nested_dim().

        The test is skipped when CUDA is unavailable, because every leaf is
        allocated with device='cuda'.
        """
        # Without this guard, a CPU-only machine errors out while building
        # the fixture instead of reporting a skip.
        if not torch.cuda.is_available():
            self.skipTest("CUDA not enabled.")

        def cuda_leaf():
            # One of three identical 2x4 half-precision leaves.
            return torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]],
                                dtype=torch.float16,
                                device='cuda',
                                requires_grad=False)

        a = nt.nested_tensor([
            nt.nested_tensor(
                [nt.nested_tensor([cuda_leaf()]) for _ in range(3)])
        ])

        for mask_dim in range(a.dim()):
            t, m = a.to_tensor_mask(mask_dim=mask_dim)
            res_nt = nt.nested_tensor_from_tensor_mask(
                t, m, nested_dim=a.nested_dim())
            TestCase.assertEqual(self, a, res_nt)
            TestCase.assertEqual(self, res_nt.nested_dim(), a.nested_dim())
Ejemplo n.º 17
0
    def test_grad_nt_from_tensor_mask(self):
        """Compare forward value and gradients of sum(x**2 + x**3).

        The nested tensor built from a padded tensor with requires_grad=True
        must itself require grad. After backward(), its forward value must
        equal the sum over three separate tensors, and each entry's grad is
        compared with hand-computed values.
        """
        def poly_sum(x):
            return torch.sum(x**2 + x**3)

        rows = ([1., 2., 3., 4.], [1., 2., 3.], [1., 2.])
        leaves = [torch.tensor(row, requires_grad=True) for row in rows]

        partials = [poly_sum(leaf) for leaf in leaves]
        reference_total = partials[0] + partials[1] + partials[2]

        for partial in partials:
            partial.backward()

        padded = torch.tensor(
            [[1., 2., 3., 4.], [1., 2., 3., 0.], [1., 2., 0., 0.]],
            requires_grad=True)
        valid = torch.tensor([[True, True, True, True],
                              [True, True, True, False],
                              [True, True, False, False]])

        nested = nestedtensor.nested_tensor_from_tensor_mask(padded, valid)
        self.assertTrue(nested.requires_grad)

        nested_total = poly_sum(nested)
        nested_total.backward()

        self.assertEqual(reference_total, nested_total)

        expected_grads = (
            torch.tensor([5., 16., 33., 56.]),
            torch.tensor([5., 16., 33.]),
            torch.tensor([5., 16.]),
        )
        for index, expected_grad in enumerate(expected_grads):
            self.assertEqual(nested[index].grad, expected_grad)
    def test_ntgtm_multi_tensor_mix_mask(self):
        """A per-row mask on a (4, 1) tensor is compared with rows [1], [4]."""
        data = torch.tensor([[1], [2], [3], [4]])
        keep_row = torch.tensor([True, False, False, True])

        kept_rows = [torch.tensor([1]), torch.tensor([4])]
        expected = nt.nested_tensor(kept_rows)

        result = nt.nested_tensor_from_tensor_mask(data, keep_row)
        TestCase.assertEqual(self, expected, result)
    def test_ntgtm_scalar_with_empty_mix_mask(self):
        """Mask out the first row of [[0], [11]] at two nesting depths.

        The default call is compared with the single row [11]; nested_dim=2
        is compared with an empty nested tensor followed by one holding the
        scalar 11.
        """
        data = torch.tensor([[0], [11]])
        mask = torch.tensor([False, True])

        flat_expected = nt.nested_tensor(
            [torch.tensor([11], dtype=torch.long)])

        deep_expected = nt.nested_tensor([
            nt.nested_tensor([]),
            nt.nested_tensor([torch.tensor(11, dtype=torch.long)]),
        ])

        result = nt.nested_tensor_from_tensor_mask(data, mask)
        TestCase.assertEqual(self, flat_expected, result)

        result = nt.nested_tensor_from_tensor_mask(data, mask, nested_dim=2)
        TestCase.assertEqual(self, deep_expected, result)
    def test_ntgtm_multi_scalar_mix_mask(self):
        """A per-element mask on [1, 2, 3, 4] is compared with scalars 1, 4."""
        data = torch.tensor([1, 2, 3, 4])
        keep = torch.tensor([True, False, False, True])

        kept_scalars = [torch.tensor(1), torch.tensor(4)]
        expected = nt.nested_tensor(kept_scalars)

        result = nt.nested_tensor_from_tensor_mask(data, keep)
        TestCase.assertEqual(self, expected, result)
    def test_ntftm_multi_tensor_all_false_mask2(self):
        """Apply a full-rank all-False mask to a (1, 3, 1) tensor.

        The default call is compared with one (3, 0) empty tensor;
        nested_dim=2 is compared with a nested tensor of three empty tensors.
        """
        data = torch.tensor([[[1], [2], [3]]])
        mask = torch.tensor([[[False], [False], [False]]])

        result = nt.nested_tensor_from_tensor_mask(data, mask)
        flat_expected = nt.nested_tensor(
            [torch.empty((3, 0), dtype=data.dtype)])
        TestCase.assertEqual(self, result, flat_expected)

        result = nt.nested_tensor_from_tensor_mask(data, mask, nested_dim=2)
        deep_expected = nt.nested_tensor([
            nt.nested_tensor(
                [torch.tensor([], dtype=data.dtype) for _ in range(3)])
        ])
        TestCase.assertEqual(self, result, deep_expected)
Ejemplo n.º 22
0
    def test_ntftm_test_multi_tensor_mix_mask(self):
        """Rebuild a (2, 3) float padded tensor at nested_dim=1.

        The result is compared with the rows [1, 2, 3] and [4].
        """
        data = torch.tensor([[1, 2, 3], [4, 0, 0]], dtype=torch.float)
        mask = torch.tensor([[True, True, True], [True, False, False]])

        rows = [torch.tensor([1, 2, 3]), torch.tensor([4])]
        expected = nt.nested_tensor(rows)

        result = nt.nested_tensor_from_tensor_mask(data, mask, nested_dim=1)
        TestCase.assertEqual(self, expected, result)
Ejemplo n.º 23
0
    def test_ntftm_multi_tensor_scalar_true_mask(self):
        """A 0-dim True mask keeps every row of a (3, 1) and (3, 1, 1) tensor.

        Each expectation is built with the data tensor's dtype.
        """
        mask = torch.tensor(True)

        data = torch.tensor([[1], [2], [3]])
        result = nt.nested_tensor_from_tensor_mask(data, mask)
        expected_rows = nt.nested_tensor(
            [torch.tensor([value]) for value in (1, 2, 3)],
            dtype=data.dtype)
        TestCase.assertEqual(self, result, expected_rows)

        # Extra dim
        data = torch.tensor([[[1]], [[2]], [[3]]])
        result = nt.nested_tensor_from_tensor_mask(data, mask)
        expected_matrices = nt.nested_tensor(
            [torch.tensor([[value]]) for value in (1, 2, 3)],
            dtype=data.dtype)
        TestCase.assertEqual(self, result, expected_matrices)
    def test_ntftm_multi_tensor_all_false_mask(self):
        """Apply all-False masks to a (1, 3, 1) tensor at several depths.

        - [False]: default and nested_dim=2 are compared with an empty
          nested tensor.
        - [False, False, False]: default is compared with an empty nested
          tensor.
        - (3, 1) mask: default is compared with one empty tensor,
          nested_dim=3 with one empty nested tensor, and nested_dim=4 must
          raise RuntimeError.
        """
        data = torch.tensor([[[1], [2], [3]]])

        mask = torch.tensor([False])
        result = nt.nested_tensor_from_tensor_mask(data, mask)
        TestCase.assertEqual(self, result, nt.nested_tensor([]))

        result = nt.nested_tensor_from_tensor_mask(data, mask, nested_dim=2)
        TestCase.assertEqual(self, result, nt.nested_tensor([]))

        mask = torch.tensor([False, False, False])
        result = nt.nested_tensor_from_tensor_mask(data, mask)
        TestCase.assertEqual(self, result, nt.nested_tensor([]))

        mask = torch.tensor([[False], [False], [False]])
        result = nt.nested_tensor_from_tensor_mask(data, mask)
        one_empty_tensor = nt.nested_tensor(
            [torch.tensor([], dtype=data.dtype)])
        TestCase.assertEqual(self, result, one_empty_tensor)

        result = nt.nested_tensor_from_tensor_mask(data, mask, nested_dim=3)
        one_empty_nested = nt.nested_tensor([nt.nested_tensor([])])
        TestCase.assertEqual(self, result, one_empty_nested)

        with self.assertRaises(RuntimeError):
            nt.nested_tensor_from_tensor_mask(data, mask, nested_dim=4)
    def test_ntftm_multi_tensor_scalar_true_mask(self):
        """Use a 0-dim True mask at every nesting depth of a (3, 1, 1) tensor.

        A (3, 1) tensor is checked with the default depth first. The
        (3, 1, 1) tensor is then checked with the default depth, nested_dim=2
        and nested_dim=3; nested_dim=4 must raise RuntimeError.
        """
        values = (1, 2, 3)
        mask = torch.tensor(True)

        data = torch.tensor([[1], [2], [3]])
        result = nt.nested_tensor_from_tensor_mask(data, mask)
        expected_rows = nt.nested_tensor(
            [torch.tensor([value]) for value in values])
        TestCase.assertEqual(self, result, expected_rows)

        # Extra dim
        data = torch.tensor([[[1]], [[2]], [[3]]])
        result = nt.nested_tensor_from_tensor_mask(data, mask)
        expected_depth1 = nt.nested_tensor(
            [torch.tensor([[value]]) for value in values])
        TestCase.assertEqual(self, result, expected_depth1)

        result = nt.nested_tensor_from_tensor_mask(data, mask, nested_dim=2)
        expected_depth2 = nt.nested_tensor([
            nt.nested_tensor([torch.tensor([value])]) for value in values
        ])
        TestCase.assertEqual(self, result, expected_depth2)

        result = nt.nested_tensor_from_tensor_mask(data, mask, nested_dim=3)
        expected_depth3 = nt.nested_tensor([
            nt.nested_tensor([nt.nested_tensor([torch.tensor(value)])])
            for value in values
        ])
        TestCase.assertEqual(self, result, expected_depth3)

        with self.assertRaises(RuntimeError):
            nt.nested_tensor_from_tensor_mask(data, mask, nested_dim=4)
Ejemplo n.º 26
0
    def test_ntftm_empty2(self):
        """Use a (2, 0) tensor as both data and mask.

        The default call and the nested_dim=1 call are each compared with
        two empty tensors. The pair of checks is executed twice, matching
        the original sequence of four assertions.
        """
        empty_rows = torch.tensor([[], []])

        expected = nt.nested_tensor([torch.tensor([]), torch.tensor([])])

        for _ in range(2):
            result = nt.nested_tensor_from_tensor_mask(
                empty_rows, empty_rows)
            TestCase.assertEqual(self, result, expected)

            result = nt.nested_tensor_from_tensor_mask(
                empty_rows, empty_rows, nested_dim=1)
            TestCase.assertEqual(self, result, expected)
Ejemplo n.º 27
0
 def test_to_padded_tensor(self):
     """Round-trip a (2, 3, 4) tensor/mask pair and check custom padding.

     The nested tensor built from the pair must convert back to the same
     data and mask. to_padded_tensor(padding=-10) is compared with the data
     plus -10 at every masked-out position.
     """
     data = torch.tensor([
         [[0.8413, 0.7325, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000]],
         [[0.6334, 0.5473, 0.3273, 0.0564],
          [0.3023, 0.6826, 0.3519, 0.1804],
          [0.8431, 0.1645, 0.1821, 0.9185]],
     ])
     mask = torch.tensor([
         [[True, True, False, False],
          [False, False, False, False],
          [False, False, False, False]],
         [[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]],
     ])

     nested = nt.nested_tensor_from_tensor_mask(data, mask)

     data_back, mask_back = nested.to_tensor_mask()
     self.assertEqual(data, data_back)
     self.assertEqual(mask, mask_back)

     padded = nested.to_padded_tensor(padding=-10)
     # Masked-out slots hold 0 in `data`, so adding -10 there gives the pad.
     expected_padded = data + ~mask * -10
     self.assertEqual(expected_padded, padded)
    def test_ntftm_test_multi_tensor_mix_mask3(self):
        """Rebuild a (1, 3, 1, 2, 4) padded tensor at nested_dim 1 through 6.

        nested_dim=1 and nested_dim=6 must raise RuntimeError; nested_dim 2
        to 5 are compared with hand-written nested tensors of matching depth.
        """
        def nested(*entries):
            # Shorthand for a nested tensor made of the given entries.
            return nt.nested_tensor(list(entries))

        def scalars(*values):
            # Nested tensor whose entries are 0-dim tensors (may be empty).
            return nt.nested_tensor([torch.tensor(v) for v in values])

        depth2 = nested(
            nested(
                torch.tensor([[[1, 2, 3, 4], [5, 6, 7, 8]]]),
                torch.tensor([[[0, 0], [3, 4]]]),
                torch.tensor([[[1]]]),
            )
        )

        depth3 = nested(
            nested(
                nested(torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])),
                nested(torch.tensor([[0, 0], [3, 4]])),
                nested(torch.tensor([[1]])),
            )
        )

        depth4 = nested(
            nested(
                nested(
                    nested(torch.tensor([1, 2, 3, 4]),
                           torch.tensor([5, 6, 7, 8]))
                ),
                nested(
                    nested(torch.tensor([0, 0]), torch.tensor([3, 4]))
                ),
                nested(
                    nested(torch.tensor([1]),
                           torch.tensor([], dtype=torch.long))
                ),
            )
        )

        depth5 = nested(
            nested(
                nested(
                    nested(scalars(1, 2, 3, 4), scalars(5, 6, 7, 8))
                ),
                nested(
                    nested(scalars(0, 0), scalars(3, 4))
                ),
                nested(
                    nested(scalars(1), scalars())
                ),
            )
        )

        data = torch.tensor([[
            [[[1, 2, 3, 4], [5, 6, 7, 8]]],
            [[[0, 0, 0, 0], [3, 4, 0, 0]]],
            [[[1, 0, 0, 0], [0, 0, 0, 0]]],
        ]])

        mask = torch.tensor([[
            [[[True, True, True, True], [True, True, True, True]]],
            [[[True, True, False, False], [True, True, False, False]]],
            [[[True, False, False, False], [False, False, False, False]]],
        ]])

        with self.assertRaises(RuntimeError):
            nt.nested_tensor_from_tensor_mask(data, mask, nested_dim=1)

        expectations = (depth2, depth3, depth4, depth5)
        for nested_dim, expected in enumerate(expectations, start=2):
            result = nt.nested_tensor_from_tensor_mask(
                data, mask, nested_dim=nested_dim)
            TestCase.assertEqual(self, expected, result)

        with self.assertRaises(RuntimeError):
            nt.nested_tensor_from_tensor_mask(data, mask, nested_dim=6)
 def test_ntftm_single_scalar_error(self):
     """A 0-dim data tensor must raise RuntimeError with a matching message."""
     scalar_data = torch.tensor(1)
     scalar_mask = torch.tensor(True)
     with self.assertRaisesRegex(
             RuntimeError, "Can't construct nested tensor from a scalar."):
         nt.nested_tensor_from_tensor_mask(scalar_data, scalar_mask)
    def test_ntftm_single_scalar_mask_false(self):
        """A 0-dim False mask on a uint8 [1] tensor is compared with empty."""
        data = torch.tensor([1], dtype=torch.uint8)
        reject_all = torch.tensor(False)

        result = nt.nested_tensor_from_tensor_mask(data, reject_all)
        expected = nt.nested_tensor([])
        TestCase.assertEqual(self, result, expected)