def test_treespec_equality(self):
    # The == and != operators are invoked explicitly and their boolean
    # results are checked with assertTrue / assertFalse.

    # Two freshly constructed LeafSpec objects.
    leaf_lhs, leaf_rhs = LeafSpec(), LeafSpec()
    self.assertTrue(leaf_lhs == leaf_rhs)

    # Identical constructor arguments with an empty spec list.
    empty_list_lhs = TreeSpec(list, None, [])
    empty_list_rhs = TreeSpec(list, None, [])
    self.assertTrue(empty_list_lhs == empty_list_rhs)

    # Identical constructor arguments, each with a single LeafSpec in the list.
    one_leaf_lhs = TreeSpec(list, None, [LeafSpec()])
    one_leaf_rhs = TreeSpec(list, None, [LeafSpec()])
    self.assertTrue(one_leaf_lhs == one_leaf_rhs)

    # Only the first argument differs (tuple vs. list): == must be false
    # and != must be true.
    tuple_spec = TreeSpec(tuple, None, [])
    list_spec = TreeSpec(list, None, [])
    self.assertFalse(tuple_spec == list_spec)
    self.assertTrue(tuple_spec != list_spec)
def run_test(tup):
    """Flatten ``tup``, check the flattened result, then round-trip it.

    ``self`` is not a parameter; it is read from the enclosing scope, which
    must supply the ``assert*`` methods used below.

    Args:
        tup: The tuple to flatten. The expected spec is built with one
            ``LeafSpec`` per element, and the flattened values are expected
            to equal ``list(tup)``.

    Raises:
        Whatever the ``self.assert*`` calls raise when a check fails.
    """
    expected_spec = TreeSpec(tuple, None, [LeafSpec() for _ in tup])

    values, treespec = tree_flatten(tup)
    # assertIsInstance names the offending object and expected type on
    # failure, whereas assertTrue(isinstance(...)) only says "False is not true".
    self.assertIsInstance(values, list)
    self.assertEqual(values, list(tup))
    self.assertEqual(treespec, expected_spec)

    # Rebuilding from the flattened pieces must give back an equal tuple.
    unflattened = tree_unflatten(values, treespec)
    self.assertEqual(unflattened, tup)
    self.assertIsInstance(unflattened, tuple)
def run_test(lst):
    """Flatten ``lst``, check the flattened result, then round-trip it.

    ``self`` is not a parameter; it is read from the enclosing scope, which
    must supply the ``assert*`` methods used below.

    Args:
        lst: The list to flatten. The expected spec is built with one
            ``LeafSpec`` per element, and the flattened values are expected
            to equal ``lst`` itself.

    Raises:
        Whatever the ``self.assert*`` calls raise when a check fails.
    """
    expected_spec = TreeSpec(list, None, [LeafSpec() for _ in lst])

    values, treespec = tree_flatten(lst)
    # assertIsInstance names the offending object and expected type on
    # failure, whereas assertTrue(isinstance(...)) only says "False is not true".
    self.assertIsInstance(values, list)
    self.assertEqual(values, lst)
    self.assertEqual(treespec, expected_spec)

    # Rebuilding from the flattened pieces must give back an equal list.
    unflattened = tree_unflatten(values, treespec)
    self.assertEqual(unflattened, lst)
    self.assertIsInstance(unflattened, list)
def run_test(tup):
    """Flatten a dict, check the flattened result, then round-trip it.

    ``self`` is not a parameter; it is read from the enclosing scope, which
    must supply the ``assert*`` methods used below.

    Args:
        tup: Despite the name, this is used as a mapping: ``.keys()`` and
            ``.values()`` are called on it and the round-tripped object is
            checked to be a ``dict``. The expected spec is built from the
            list of keys plus one ``LeafSpec`` per value.

    Raises:
        Whatever the ``self.assert*`` calls raise when a check fails.
    """
    expected_spec = TreeSpec(
        dict,
        list(tup.keys()),
        [LeafSpec() for _ in tup.values()],
    )

    values, treespec = tree_flatten(tup)
    # assertIsInstance names the offending object and expected type on
    # failure, whereas assertTrue(isinstance(...)) only says "False is not true".
    self.assertIsInstance(values, list)
    self.assertEqual(values, list(tup.values()))
    self.assertEqual(treespec, expected_spec)

    # Rebuilding from the flattened pieces must give back an equal dict.
    unflattened = tree_unflatten(values, treespec)
    self.assertEqual(unflattened, tup)
    self.assertIsInstance(unflattened, dict)
def run_test(odict):
    """Flatten an ``OrderedDict``, check the result, then round-trip it.

    ``self`` is not a parameter; it is read from the enclosing scope, which
    must supply the ``assert*`` methods used below.

    Args:
        odict: The ``OrderedDict`` to flatten. The expected spec is built
            from the list of its keys plus one ``LeafSpec`` per value, and
            the flattened values are expected to equal
            ``list(odict.values())``.

    Raises:
        Whatever the ``self.assert*`` calls raise when a check fails.
    """
    expected_spec = TreeSpec(
        OrderedDict,
        list(odict.keys()),
        [LeafSpec() for _ in odict.values()],
    )

    values, treespec = tree_flatten(odict)
    # assertIsInstance names the offending object and expected type on
    # failure, whereas assertTrue(isinstance(...)) only says "False is not true".
    self.assertIsInstance(values, list)
    self.assertEqual(values, list(odict.values()))
    self.assertEqual(treespec, expected_spec)

    # Rebuilding from the flattened pieces must give back an equal
    # OrderedDict (not merely a plain dict).
    unflattened = tree_unflatten(values, treespec)
    self.assertEqual(unflattened, odict)
    self.assertIsInstance(unflattened, OrderedDict)