コード例 #1
0
ファイル: test_pytree.py プロジェクト: zacker150/pytorch
 def test_treespec_equality(self):
     """Exercise the ``==`` and ``!=`` operators on LeafSpec/TreeSpec.

     Specs built from the same node type, context and children must compare
     equal; specs that differ only in their node type must compare unequal
     under both operators.
     """
     def node(container_type, *child_specs):
         # Build a TreeSpec with no context; a fresh children list per call.
         return TreeSpec(container_type, None, list(child_specs))

     # Two independently constructed leaf specs must be equal.
     self.assertTrue(LeafSpec() == LeafSpec())
     # Empty list nodes, then list nodes holding a single leaf.
     self.assertTrue(node(list) == node(list))
     self.assertTrue(node(list, LeafSpec()) == node(list, LeafSpec()))

     # Differing only in node type: `==` must be False and `!=` must be True.
     self.assertFalse(node(tuple) == node(list))
     self.assertTrue(node(tuple) != node(list))
コード例 #2
0
ファイル: test_pytree.py プロジェクト: zacker150/pytorch
        def run_test_with_leaf(leaf):
            """Check that a single leaf flattens to ``[leaf]`` and round-trips.

            ``self`` is captured from the enclosing test method.
            """
            flat, spec = tree_flatten(leaf)
            # Expect a one-element value list and a bare LeafSpec.
            self.assertEqual(flat, [leaf])
            self.assertEqual(spec, LeafSpec())

            # Rebuilding from the flat values must reproduce the leaf.
            self.assertEqual(tree_unflatten(flat, spec), leaf)
コード例 #3
0
ファイル: test_pytree.py プロジェクト: zacker150/pytorch
        def run_test(tup):
            """Round-trip a tuple of leaves through tree_flatten/tree_unflatten.

            Every element of ``tup`` is expected to be a leaf, so the expected
            spec is a ``tuple`` node with one ``LeafSpec`` child per element.
            ``self`` is captured from the enclosing test method.
            """
            expected_spec = TreeSpec(tuple, None, [LeafSpec() for _ in tup])
            values, treespec = tree_flatten(tup)
            # The flat values must be a list even though the input is a tuple.
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(tup))
            self.assertEqual(treespec, expected_spec)

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, tup)
            # Unflattening must restore the original container type.
            self.assertIsInstance(unflattened, tuple)
コード例 #4
0
ファイル: test_pytree.py プロジェクト: zacker150/pytorch
        def run_test(lst):
            """Round-trip a list of leaves through tree_flatten/tree_unflatten.

            Every element of ``lst`` is expected to be a leaf, so the expected
            spec is a ``list`` node with one ``LeafSpec`` child per element.
            ``self`` is captured from the enclosing test method.
            """
            expected_spec = TreeSpec(list, None, [LeafSpec() for _ in lst])
            values, treespec = tree_flatten(lst)
            self.assertIsInstance(values, list)
            self.assertEqual(values, lst)
            self.assertEqual(treespec, expected_spec)

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, lst)
            # Unflattening must restore the original container type.
            self.assertIsInstance(unflattened, list)
コード例 #5
0
ファイル: test_pytree.py プロジェクト: zacker150/pytorch
        def run_test(tup):
            """Round-trip a dict of leaves through tree_flatten/tree_unflatten.

            Despite its name, ``tup`` holds a ``dict``. The name is kept so
            existing callers are unaffected. The expected spec uses the dict's
            keys (in iteration order) as context and one ``LeafSpec`` per
            value. ``self`` is captured from the enclosing test method.
            """
            expected_spec = TreeSpec(dict, list(tup.keys()),
                                     [LeafSpec() for _ in tup.values()])
            values, treespec = tree_flatten(tup)
            # Only the values are flattened; the keys live in the spec context.
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(tup.values()))
            self.assertEqual(treespec, expected_spec)

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, tup)
            # Unflattening must restore the original container type.
            self.assertIsInstance(unflattened, dict)
コード例 #6
0
ファイル: test_pytree.py プロジェクト: huaxz1986/pytorch
        def run_test(odict):
            """Round-trip an OrderedDict of leaves through the pytree API.

            The expected spec is an ``OrderedDict`` node whose context is the
            key list (in insertion order), with one ``LeafSpec`` per value.
            ``self`` is captured from the enclosing test method.
            """
            expected_spec = TreeSpec(OrderedDict, list(odict.keys()),
                                     [LeafSpec() for _ in odict.values()])
            values, treespec = tree_flatten(odict)
            # Only the values are flattened; the keys live in the spec context.
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(odict.values()))
            self.assertEqual(treespec, expected_spec)

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, odict)
            # Must come back as an OrderedDict, not a plain dict.
            self.assertIsInstance(unflattened, OrderedDict)