class TreeTest(jtu.JaxTestCase):
  """Round-trip and structural tests for pytree flattening utilities."""

  @parameterized.parameters(*PYTREES)
  def testRoundtrip(self, inputs):
    """tree_unflatten(tree_flatten(x)) must reproduce x."""
    leaves, treedef = tree_util.tree_flatten(inputs)
    self.assertEqual(tree_util.tree_unflatten(treedef, leaves), inputs)

  @parameterized.parameters(
      (tree_util.Partial(_dummy_func), ),
      (tree_util.Partial(_dummy_func, 1, 2), ),
      (tree_util.Partial(_dummy_func, x="a"), ),
      (tree_util.Partial(_dummy_func, 1, 2, 3, x=4, y=5), ),
  )
  def testRoundtripPartial(self, inputs):
    """A Partial survives a flatten/unflatten round trip."""
    leaves, treedef = tree_util.tree_flatten(inputs)
    rebuilt = tree_util.tree_unflatten(treedef, leaves)
    # functools.partial does not support equality comparisons:
    # https://stackoverflow.com/a/32786109/809705
    # so compare the attributes that define a partial one at a time.
    for attr in ("func", "args", "keywords"):
      self.assertEqual(getattr(rebuilt, attr), getattr(inputs, attr))

  @parameterized.parameters(*PYTREES)
  def testRoundtripViaBuild(self, inputs):
    """process_pytree(tuple, x) followed by build_tree reproduces x."""
    leaves, treedef = tree_util.process_pytree(tuple, inputs)
    self.assertEqual(tree_util.build_tree(treedef, leaves), inputs)

  def testChildren(self):
    """children() returns the treedefs of the top-level subtrees."""
    _, parent = tree_util.tree_flatten(((1, 2, 3), (4, )))
    expected = [tree_util.tree_flatten(sub)[1] for sub in ((0, 0, 0), (7, ))]
    self.assertEqual(expected, parent.children())
def testPartialFuncAttributeHasStableHash(self): # https://github.com/google/jax/issues/9429 fun = functools.partial(print, 1) p1 = tree_util.Partial(fun, 2) p2 = tree_util.Partial(fun, 2) self.assertEqual(fun, p1.func) self.assertEqual(p1.func, fun) self.assertEqual(p1.func, p2.func) self.assertEqual(hash(p1.func), hash(p2.func))
def _func_fwd(values):
  """Forward pass: run `cls(...).compute()` on a NumPy copy of `values`.

  `cls` and `kwargs` come from the enclosing scope.

  Returns:
    A pair of (the computed result as a jnp array with the dtype of `values`,
    a `tree_util.Partial` wrapping the computing object's `vjp` method).
  """
  input_dtype = values.dtype
  solver = cls(np.array(values), **kwargs)
  output = jnp.array(solver.compute(), dtype=input_dtype)
  return output, tree_util.Partial(solver.vjp)
def testPartialDoesNotMergeWithOtherPartials(self): def f(a, b, c): pass g = functools.partial(f, 2) h = tree_util.Partial(g, 3) self.assertEqual(h.args, (3, ))
def test_kohn_sham_neural_xc_density_mse_converge_tolerance(
    self, density_mse_converge_tolerance, expected_converged):
  """Jitted KS with a neural LDA reports the expected convergence flags.

  Args:
    density_mse_converge_tolerance: Tolerance forwarded to jit_scf.kohn_sham.
    expected_converged: Expected value of the returned `converged` field.
  """
  network = stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1))
  init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
      network)
  params_init = init_fn(rng=random.PRNGKey(0))
  neural_xc_fn = tree_util.Partial(xc_energy_density_fn, params=params_init)
  initial_density = self.num_electrons * utils.gaussian(
      grids=self.grids, center=0., sigma=0.5)
  states = jit_scf.kohn_sham(
      locations=self.locations,
      nuclear_charges=self.nuclear_charges,
      num_electrons=self.num_electrons,
      num_iterations=3,
      grids=self.grids,
      xc_energy_density_fn=neural_xc_fn,
      interaction_fn=utils.exponential_coulomb,
      initial_density=initial_density,
      density_mse_converge_tolerance=density_mse_converge_tolerance)
  np.testing.assert_array_equal(states.converged, expected_converged)
  # Every intermediate state must be consistent with the external potential.
  for state in scf.state_iterator(states):
    self._test_state(
        state,
        self._create_testing_external_potential(utils.exponential_coulomb))
def loss(flatten_params, initial_state, target_energy):
  """Squared error between one KS iteration's total energy and a target.

  `self`, `spec` and `xc_energy_density_fn` come from the enclosing scope.
  """
  params = np_utils.unflatten(spec, flatten_params)
  xc_fn = tree_util.Partial(xc_energy_density_fn, params=params)
  next_state = scf.kohn_sham_iteration(
      state=initial_state,
      num_electrons=self.num_electrons,
      xc_energy_density_fn=xc_fn,
      interaction_fn=utils.exponential_coulomb,
      enforce_reflection_symmetry=True)
  energy_error = next_state.total_energy - target_energy
  return energy_error ** 2
def loss(flatten_params, initial_state, target_density):
  """Grid-integrated L1 density error after one KS iteration.

  `self`, `spec` and `xc_energy_density_fn` come from the enclosing scope.
  """
  params = np_utils.unflatten(spec, flatten_params)
  xc_fn = tree_util.Partial(xc_energy_density_fn, params=params)
  next_state = scf.kohn_sham_iteration(
      state=initial_state,
      num_electrons=self.num_electrons,
      xc_energy_density_fn=xc_fn,
      interaction_fn=utils.exponential_coulomb,
      enforce_reflection_symmetry=False)
  abs_diff = jnp.abs(next_state.density - target_density)
  return jnp.sum(abs_diff) * utils.get_dx(self.grids)
def test_kohn_sham_iteration(
    self, interaction_fn, enforce_reflection_symmetry):
  """A single KS iteration with an LDA exchange functional passes checks."""
  start = self._create_testing_initial_state(interaction_fn)
  # Use 3d LDA exchange functional and zero correlation functional.
  lda_exchange = tree_util.Partial(
      lambda density: -0.73855 * density ** (1 / 3))
  next_state = scf.kohn_sham_iteration(
      state=start,
      num_electrons=self.num_electrons,
      xc_energy_density_fn=lda_exchange,
      interaction_fn=interaction_fn,
      enforce_reflection_symmetry=enforce_reflection_symmetry)
  self._test_state(next_state, start)
class TreeTest(jtu.JaxTestCase):
  """Flatten/unflatten round-trip tests for plain containers and Partial."""

  @parameterized.parameters(((1, 2), ), ([3], ), ({'a': 1, 'b': 2}, ))
  def testRoundtrip(self, inputs):
    """Plain containers survive tree_flatten followed by tree_unflatten."""
    xs, tree = tree_util.tree_flatten(inputs)
    actual = tree_util.tree_unflatten(tree, xs)
    self.assertEqual(actual, inputs)

  # This method must not share a name with testRoundtrip: a second definition
  # with the same name replaces the first in the class body, and the
  # container cases above would never run.
  @parameterized.parameters(
      (tree_util.Partial(_dummy_func), ),
      (tree_util.Partial(_dummy_func, 1, 2), ),
      (tree_util.Partial(_dummy_func, x='a'), ),
      (tree_util.Partial(_dummy_func, 1, 2, 3, x=4, y=5), ),
  )
  def testRoundtripPartial(self, inputs):
    """Partial objects survive tree_flatten followed by tree_unflatten."""
    xs, tree = tree_util.tree_flatten(inputs)
    actual = tree_util.tree_unflatten(tree, xs)
    # functools.partial does not support equality comparisons:
    # https://stackoverflow.com/a/32786109/809705
    self.assertEqual(actual.func, inputs.func)
    self.assertEqual(actual.args, inputs.args)
    self.assertEqual(actual.keywords, inputs.keywords)
def loss(flatten_params, target_energy):
  """Squared error between the final SCF total energy and a target.

  `self`, `spec` and `xc_energy_density_fn` come from the enclosing scope.
  """
  params = np_utils.unflatten(spec, flatten_params)
  xc_fn = tree_util.Partial(xc_energy_density_fn, params=params)
  states = scf.kohn_sham(
      locations=self.locations,
      nuclear_charges=self.nuclear_charges,
      num_electrons=self.num_electrons,
      num_iterations=3,
      grids=self.grids,
      xc_energy_density_fn=xc_fn,
      interaction_fn=utils.exponential_coulomb)
  energy_error = scf.get_final_state(states).total_energy - target_energy
  return energy_error ** 2
def test_kohn_sham_iteration_neural_xc(self, enforce_reflection_symmetry):
  """One jitted KS iteration with a small neural LDA functional passes checks."""
  mlp = stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1))
  init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(mlp)
  params_init = init_fn(rng=random.PRNGKey(0))
  start = self._create_testing_initial_state(utils.exponential_coulomb)
  neural_xc_fn = tree_util.Partial(xc_energy_density_fn, params=params_init)
  next_state = jit_scf.kohn_sham_iteration(
      state=start,
      num_electrons=self.num_electrons,
      xc_energy_density_fn=neural_xc_fn,
      interaction_fn=utils.exponential_coulomb,
      enforce_reflection_symmetry=enforce_reflection_symmetry)
  self._test_state(next_state, start)
def test_kohn_sham_convergence(
    self, density_mse_converge_tolerance, expected_converged):
  """Checks scf.kohn_sham's convergence flags for a given MSE tolerance.

  Args:
    density_mse_converge_tolerance: Density MSE threshold forwarded to
      scf.kohn_sham.
    expected_converged: Expected value of the returned `converged` field.
  """
  state = scf.kohn_sham(
      locations=self.locations,
      nuclear_charges=self.nuclear_charges,
      num_electrons=self.num_electrons,
      num_iterations=3,
      grids=self.grids,
      # Use 3d LDA exchange functional and zero correlation functional.
      xc_energy_density_fn=tree_util.Partial(
          lambda density: -0.73855 * density ** (1 / 3)),
      interaction_fn=utils.exponential_coulomb,
      density_mse_converge_tolerance=density_mse_converge_tolerance)
  # `converged` presumably holds boolean per-iteration flags; compare them
  # exactly instead of with a floating-point tolerance.
  np.testing.assert_array_equal(state.converged, expected_converged)
def test_kohn_sham(self, interaction_fn):
  """Every intermediate state of a 3-iteration KS run passes _test_state."""
  # Use 3d LDA exchange functional and zero correlation functional.
  lda_xc = tree_util.Partial(lambda density: -0.73855 * density ** (1 / 3))
  states = scf.kohn_sham(
      locations=self.locations,
      nuclear_charges=self.nuclear_charges,
      num_electrons=self.num_electrons,
      num_iterations=3,
      grids=self.grids,
      xc_energy_density_fn=lda_xc,
      interaction_fn=interaction_fn)
  for state in scf.state_iterator(states):
    self._test_state(
        state, self._create_testing_external_potential(interaction_fn))
def loss(flatten_params, target_density):
  """Grid-integrated L1 error of the final SCF density against a target.

  `self`, `spec` and `xc_energy_density_fn` come from the enclosing scope.
  """
  xc_fn = tree_util.Partial(
      xc_energy_density_fn,
      params=np_utils.unflatten(spec, flatten_params))
  states = scf.kohn_sham(
      locations=self.locations,
      nuclear_charges=self.nuclear_charges,
      num_electrons=self.num_electrons,
      num_iterations=3,
      grids=self.grids,
      xc_energy_density_fn=xc_fn,
      interaction_fn=utils.exponential_coulomb,
      # NOTE(review): a negative tolerance presumably keeps any iteration
      # from being flagged as converged; confirm against scf.kohn_sham.
      density_mse_converge_tolerance=-1)
  final_density = scf.get_final_state(states).density
  return jnp.sum(jnp.abs(final_density - target_density)) * utils.get_dx(
      self.grids)
def test_kohn_sham_neural_xc(self, interaction_fn):
  """A KS run with a small neural LDA functional yields consistent states."""
  mlp = stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1))
  init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(mlp)
  params_init = init_fn(rng=random.PRNGKey(0))
  neural_xc_fn = tree_util.Partial(xc_energy_density_fn, params=params_init)
  states = scf.kohn_sham(
      locations=self.locations,
      nuclear_charges=self.nuclear_charges,
      num_electrons=self.num_electrons,
      num_iterations=3,
      grids=self.grids,
      xc_energy_density_fn=neural_xc_fn,
      interaction_fn=interaction_fn)
  for state in scf.state_iterator(states):
    self._test_state(
        state, self._create_testing_external_potential(interaction_fn))
class TreeTest(jtu.JaxTestCase):
  """Pytree tests parameterized over TREES and LEAVES.

  Several tests skip themselves when the installed jaxlib lacks a treedef
  feature (see the "Test requires Jaxlib >= 0.1.23" skip messages).
  """

  @parameterized.parameters(*(TREES + LEAVES))
  def testRoundtrip(self, inputs):
    # Flattening then unflattening must reproduce the input.
    xs, tree = tree_util.tree_flatten(inputs)
    actual = tree_util.tree_unflatten(tree, xs)
    self.assertEqual(actual, inputs)

  @parameterized.parameters(*(TREES + LEAVES))
  def testRoundtripWithFlattenUpTo(self, inputs):
    _, tree = tree_util.tree_flatten(inputs)
    # Feature probe: skip when the treedef has no flatten_up_to method.
    if not hasattr(tree, "flatten_up_to"):
      self.skipTest("Test requires Jaxlib >= 0.1.23")
    xs = tree.flatten_up_to(inputs)
    actual = tree_util.tree_unflatten(tree, xs)
    self.assertEqual(actual, inputs)

  @parameterized.parameters(
      (tree_util.Partial(_dummy_func),),
      (tree_util.Partial(_dummy_func, 1, 2),),
      (tree_util.Partial(_dummy_func, x="a"),),
      (tree_util.Partial(_dummy_func, 1, 2, 3, x=4, y=5),),
  )
  def testRoundtripPartial(self, inputs):
    xs, tree = tree_util.tree_flatten(inputs)
    actual = tree_util.tree_unflatten(tree, xs)
    # functools.partial does not support equality comparisons:
    # https://stackoverflow.com/a/32786109/809705
    self.assertEqual(actual.func, inputs.func)
    self.assertEqual(actual.args, inputs.args)
    self.assertEqual(actual.keywords, inputs.keywords)

  @parameterized.parameters(*(TREES + LEAVES))
  def testRoundtripViaBuild(self, inputs):
    # Round trip through the private _process_pytree / build_tree pair.
    xs, tree = tree_util._process_pytree(tuple, inputs)
    actual = tree_util.build_tree(tree, xs)
    self.assertEqual(actual, inputs)

  def testChildren(self):
    # children() should return the treedefs of the two top-level tuples.
    _, tree = tree_util.tree_flatten(((1, 2, 3), (4,)))
    _, c0 = tree_util.tree_flatten((0, 0, 0))
    _, c1 = tree_util.tree_flatten((7,))
    # Feature probe: skip when `children` is not callable.
    if not callable(tree.children):
      self.skipTest("Test requires Jaxlib >= 0.1.23")
    self.assertEqual([c0, c1], tree.children())

  def testFlattenUpTo(self):
    # Each leaf position of the reference structure captures whatever
    # subtree sits at that position in the argument.
    _, tree = tree_util.tree_flatten([(1, 2), None, ATuple(foo=3, bar=7)])
    if not hasattr(tree, "flatten_up_to"):
      self.skipTest("Test requires Jaxlib >= 0.1.23")
    out = tree.flatten_up_to([({"foo": 7}, (3, 4)), None,
                              ATuple(foo=(11, 9), bar=None)])
    self.assertEqual(out, [{"foo": 7}, (3, 4), (11, 9), None])

  def testTreeMultimap(self):
    # Leaves of `x` are paired with the subtrees of `y` at the same position.
    x = ((1, 2), (3, 4, 5))
    y = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
    out = tree_util.tree_multimap(lambda *xs: tuple(xs), x, y)
    self.assertEqual(out, (((1, [3]), (2, None)),
                           ((3, {"foo": "bar"}), (4, 7), (5, [5, 6]))))

  @parameterized.parameters(*TREES)
  def testAllLeavesWithTrees(self, tree):
    # A flat list of leaves is all leaves; a list containing a tree is not.
    leaves = tree_util.tree_leaves(tree)
    self.assertTrue(tree_util.all_leaves(leaves))
    self.assertFalse(tree_util.all_leaves([tree]))

  @parameterized.parameters(*LEAVES)
  def testAllLeavesWithLeaves(self, leaf):
    self.assertTrue(tree_util.all_leaves([leaf]))

  @parameterized.parameters(*TREES)
  def testCompose(self, tree):
    treedef = tree_util.tree_structure(tree)
    inner_treedef = tree_util.tree_structure(["*", "*", "*"])
    composed_treedef = treedef.compose(inner_treedef)
    # Every outer leaf is replaced by a copy of the inner structure.
    expected_leaves = treedef.num_leaves * inner_treedef.num_leaves
    self.assertEqual(composed_treedef.num_leaves, expected_leaves)
    # Outer non-leaf nodes plus one full inner tree per outer leaf.
    expected_nodes = ((treedef.num_nodes - treedef.num_leaves) +
                      (inner_treedef.num_nodes * treedef.num_leaves))
    self.assertEqual(composed_treedef.num_nodes, expected_nodes)
    leaves = [1] * expected_leaves
    composed = tree_util.tree_unflatten(composed_treedef, leaves)
    self.assertEqual(leaves, tree_util.tree_leaves(composed))

  @parameterized.parameters(*TREES)
  def testTranspose(self, tree):
    outer_treedef = tree_util.tree_structure(tree)
    if not outer_treedef.num_leaves:
      self.skipTest("Skipping empty tree")
    inner_treedef = tree_util.tree_structure([1, 1, 1])
    # Replace each leaf x by [x, x, x]; transposing then gives three copies
    # of the original tree.
    nested = tree_util.tree_map(lambda x: [x, x, x], tree)
    actual = tree_util.tree_transpose(outer_treedef, inner_treedef, nested)
    self.assertEqual(actual, [tree, tree, tree])

  def testTransposeMismatchOuter(self):
    # The outer treedef has an extra "c" key that `tree` lacks.
    tree = {"a": [1, 2], "b": [3, 4]}
    outer_treedef = tree_util.tree_structure({"a": 1, "b": 2, "c": 3})
    inner_treedef = tree_util.tree_structure([1, 2])
    with self.assertRaisesRegex(TypeError, "Mismatch"):
      tree_util.tree_transpose(outer_treedef, inner_treedef, tree)

  def testTransposeMismatchInner(self):
    # The inner treedef expects 3 elements but each list in `tree` has 2.
    tree = {"a": [1, 2], "b": [3, 4]}
    outer_treedef = tree_util.tree_structure({"a": 1, "b": 2})
    inner_treedef = tree_util.tree_structure([1, 2, 3])
    with self.assertRaisesRegex(TypeError, "Mismatch"):
      tree_util.tree_transpose(outer_treedef, inner_treedef, tree)

  def testTransposeWithCustomObject(self):
    # Transposition also works through a user-registered pytree node type.
    outer_treedef = tree_util.tree_structure(FlatCache({"a": 1, "b": 2}))
    inner_treedef = tree_util.tree_structure([1, 2])
    expected = [FlatCache({"a": 3, "b": 5}), FlatCache({"a": 4, "b": 6})]
    actual = tree_util.tree_transpose(outer_treedef, inner_treedef,
                                      FlatCache({"a": [3, 4], "b": [5, 6]}))
    self.assertEqual(expected, actual)
class TreeTest(jtu.JaxTestCase):
  """Pytree flatten/unflatten tests parameterized over PYTREES.

  Some tests skip themselves when the installed jaxlib lacks a treedef
  feature (see the "Test requires Jaxlib >= 0.1.23" skip messages).
  """

  @parameterized.parameters(*PYTREES)
  def testRoundtrip(self, inputs):
    # Flattening then unflattening must reproduce the input.
    xs, tree = tree_util.tree_flatten(inputs)
    actual = tree_util.tree_unflatten(tree, xs)
    self.assertEqual(actual, inputs)

  @parameterized.parameters(*PYTREES)
  def testRoundtripWithFlattenUpTo(self, inputs):
    _, tree = tree_util.tree_flatten(inputs)
    # Feature probe: skip when the treedef has no flatten_up_to method.
    if not hasattr(tree, "flatten_up_to"):
      self.skipTest("Test requires Jaxlib >= 0.1.23")
    xs = tree.flatten_up_to(inputs)
    actual = tree_util.tree_unflatten(tree, xs)
    self.assertEqual(actual, inputs)

  @parameterized.parameters(
      (tree_util.Partial(_dummy_func),),
      (tree_util.Partial(_dummy_func, 1, 2),),
      (tree_util.Partial(_dummy_func, x="a"),),
      (tree_util.Partial(_dummy_func, 1, 2, 3, x=4, y=5),),
  )
  def testRoundtripPartial(self, inputs):
    xs, tree = tree_util.tree_flatten(inputs)
    actual = tree_util.tree_unflatten(tree, xs)
    # functools.partial does not support equality comparisons:
    # https://stackoverflow.com/a/32786109/809705
    self.assertEqual(actual.func, inputs.func)
    self.assertEqual(actual.args, inputs.args)
    self.assertEqual(actual.keywords, inputs.keywords)

  @parameterized.parameters(*PYTREES)
  def testRoundtripViaBuild(self, inputs):
    # Round trip through the private _process_pytree / build_tree pair.
    xs, tree = tree_util._process_pytree(tuple, inputs)
    actual = tree_util.build_tree(tree, xs)
    self.assertEqual(actual, inputs)

  def testChildren(self):
    # children() should return the treedefs of the two top-level tuples.
    _, tree = tree_util.tree_flatten(((1, 2, 3), (4,)))
    _, c0 = tree_util.tree_flatten((0, 0, 0))
    _, c1 = tree_util.tree_flatten((7,))
    # Feature probe: skip when `children` is not callable.
    if not callable(tree.children):
      self.skipTest("Test requires Jaxlib >= 0.1.23")
    self.assertEqual([c0, c1], tree.children())

  def testFlattenUpTo(self):
    # Each leaf position of the reference structure captures whatever
    # subtree sits at that position in the argument.
    _, tree = tree_util.tree_flatten([(1, 2), None, ATuple(foo=3, bar=7)])
    if not hasattr(tree, "flatten_up_to"):
      self.skipTest("Test requires Jaxlib >= 0.1.23")
    out = tree.flatten_up_to([({"foo": 7}, (3, 4)), None,
                              ATuple(foo=(11, 9), bar=None)])
    self.assertEqual(out, [{"foo": 7}, (3, 4), (11, 9), None])

  def testTreeMultimap(self):
    # Leaves of `x` are paired with the subtrees of `y` at the same position.
    x = ((1, 2), (3, 4, 5))
    y = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
    out = tree_util.tree_multimap(lambda *xs: tuple(xs), x, y)
    self.assertEqual(out, (((1, [3]), (2, None)),
                           ((3, {"foo": "bar"}), (4, 7), (5, [5, 6]))))
class TreeTest(jtu.JaxTestCase):
  """Pytree tests: round trips, Partial, is_leaf, compose, transpose, repr."""

  @parameterized.parameters(*(TREES + LEAVES))
  def testRoundtrip(self, inputs):
    # Flattening then unflattening must reproduce the input.
    xs, tree = tree_util.tree_flatten(inputs)
    actual = tree_util.tree_unflatten(tree, xs)
    self.assertEqual(actual, inputs)

  @parameterized.parameters(*(TREES + LEAVES))
  def testRoundtripWithFlattenUpTo(self, inputs):
    # Flattening a tree "up to" its own structure yields its leaves.
    _, tree = tree_util.tree_flatten(inputs)
    xs = tree.flatten_up_to(inputs)
    actual = tree_util.tree_unflatten(tree, xs)
    self.assertEqual(actual, inputs)

  @parameterized.parameters(
      (tree_util.Partial(_dummy_func),),
      (tree_util.Partial(_dummy_func, 1, 2),),
      (tree_util.Partial(_dummy_func, x="a"),),
      (tree_util.Partial(_dummy_func, 1, 2, 3, x=4, y=5),),
  )
  def testRoundtripPartial(self, inputs):
    xs, tree = tree_util.tree_flatten(inputs)
    actual = tree_util.tree_unflatten(tree, xs)
    # functools.partial does not support equality comparisons:
    # https://stackoverflow.com/a/32786109/809705
    self.assertEqual(actual.func, inputs.func)
    self.assertEqual(actual.args, inputs.args)
    self.assertEqual(actual.keywords, inputs.keywords)

  def testPartialDoesNotMergeWithOtherPartials(self):
    # Only the argument bound by the outer Partial should appear in .args.
    def f(a, b, c):
      pass
    g = functools.partial(f, 2)
    h = tree_util.Partial(g, 3)
    self.assertEqual(h.args, (3,))

  def testPartialFuncAttributeHasStableHash(self):
    # https://github.com/google/jax/issues/9429
    fun = functools.partial(print, 1)
    p1 = tree_util.Partial(fun, 2)
    p2 = tree_util.Partial(fun, 2)
    # Equality must hold in both directions, and equal funcs hash equally.
    self.assertEqual(fun, p1.func)
    self.assertEqual(p1.func, fun)
    self.assertEqual(p1.func, p2.func)
    self.assertEqual(hash(p1.func), hash(p2.func))

  @parameterized.parameters(*(TREES + LEAVES))
  def testRoundtripViaBuild(self, inputs):
    # Round trip through the _process_pytree / build_tree pair.
    xs, tree = _process_pytree(tuple, inputs)
    actual = tree_util.build_tree(tree, xs)
    self.assertEqual(actual, inputs)

  def testChildren(self):
    # children() should return the treedefs of the two top-level tuples.
    _, tree = tree_util.tree_flatten(((1, 2, 3), (4,)))
    _, c0 = tree_util.tree_flatten((0, 0, 0))
    _, c1 = tree_util.tree_flatten((7,))
    self.assertEqual([c0, c1], tree.children())

  def testTreedefTupleFromChildren(self):
    # https://github.com/google/jax/issues/7377
    # Rebuilding a tuple treedef from its children must preserve leaf and
    # node counts.
    tree = ((1, 2, (3, 4)), (5,))
    leaves, treedef1 = tree_util.tree_flatten(tree)
    treedef2 = tree_util.treedef_tuple(treedef1.children())
    self.assertEqual(treedef1.num_leaves, len(leaves))
    self.assertEqual(treedef1.num_leaves, treedef2.num_leaves)
    self.assertEqual(treedef1.num_nodes, treedef2.num_nodes)

  def testTreedefTupleComparesEqual(self):
    # https://github.com/google/jax/issues/9066
    self.assertEqual(tree_util.tree_structure((3,)),
                     tree_util.treedef_tuple((tree_util.tree_structure(3),)))

  def testFlattenUpTo(self):
    # Each leaf position of the reference structure captures whatever
    # subtree sits at that position in the argument.
    _, tree = tree_util.tree_flatten([(1, 2), None, ATuple(foo=3, bar=7)])
    out = tree.flatten_up_to([({"foo": 7}, (3, 4)), None,
                              ATuple(foo=(11, 9), bar=None)])
    self.assertEqual(out, [{"foo": 7}, (3, 4), (11, 9), None])

  def testTreeMultimap(self):
    # Leaves of `x` are paired with the subtrees of `y` at the same position.
    x = ((1, 2), (3, 4, 5))
    y = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
    out = tree_util.tree_multimap(lambda *xs: tuple(xs), x, y)
    self.assertEqual(out, (((1, [3]), (2, None)),
                           ((3, {"foo": "bar"}), (4, 7), (5, [5, 6]))))

  def testTreeMultimapWithIsLeafArgument(self):
    # With is_leaf treating lists as leaves, [3, 4, 5] is paired as a whole.
    x = ((1, 2), [3, 4, 5])
    y = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
    out = tree_util.tree_multimap(lambda *xs: tuple(xs), x, y,
                                  is_leaf=lambda n: isinstance(n, list))
    self.assertEqual(out, (((1, [3]), (2, None)),
                           (([3, 4, 5], ({"foo": "bar"}, 7, [5, 6])))))

  # Run the same checks through tree_leaves and through tree_flatten's
  # leaves output.
  @parameterized.parameters(
      tree_util.tree_leaves,
      lambda tree, is_leaf: tree_util.tree_flatten(tree, is_leaf)[0])
  def testFlattenIsLeaf(self, leaf_fn):
    x = [(1, 2), (3, 4), (5, 6)]
    leaves = leaf_fn(x, is_leaf=lambda t: False)
    self.assertEqual(leaves, [1, 2, 3, 4, 5, 6])
    leaves = leaf_fn(x, is_leaf=lambda t: isinstance(t, tuple))
    self.assertEqual(leaves, x)
    leaves = leaf_fn(x, is_leaf=lambda t: isinstance(t, list))
    self.assertEqual(leaves, [x])
    leaves = leaf_fn(x, is_leaf=lambda t: True)
    self.assertEqual(leaves, [x])
    # is_leaf must also stop descent at tuples nested inside lists and dicts.
    y = [[[(1,)], [[(2,)], {"a": (3,)}]]]
    leaves = leaf_fn(y, is_leaf=lambda t: isinstance(t, tuple))
    self.assertEqual(leaves, [(1,), (2,), (3,)])

  # Run the same checks through tree_structure and through tree_flatten's
  # treedef output.
  @parameterized.parameters(
      tree_util.tree_structure,
      lambda tree, is_leaf: tree_util.tree_flatten(tree, is_leaf)[1])
  def testStructureIsLeaf(self, structure_fn):
    x = [(1, 2), (3, 4), (5, 6)]
    treedef = structure_fn(x, is_leaf=lambda t: False)
    self.assertEqual(treedef.num_leaves, 6)
    treedef = structure_fn(x, is_leaf=lambda t: isinstance(t, tuple))
    self.assertEqual(treedef.num_leaves, 3)
    treedef = structure_fn(x, is_leaf=lambda t: isinstance(t, list))
    self.assertEqual(treedef.num_leaves, 1)
    treedef = structure_fn(x, is_leaf=lambda t: True)
    self.assertEqual(treedef.num_leaves, 1)
    y = [[[(1,)], [[(2,)], {"a": (3,)}]]]
    treedef = structure_fn(y, is_leaf=lambda t: isinstance(t, tuple))
    self.assertEqual(treedef.num_leaves, 3)

  @parameterized.parameters(*TREES)
  def testRoundtripIsLeaf(self, tree):
    # Treating tuples as leaves must still round-trip exactly.
    xs, treedef = tree_util.tree_flatten(
        tree, is_leaf=lambda t: isinstance(t, tuple))
    recon_tree = tree_util.tree_unflatten(treedef, xs)
    self.assertEqual(recon_tree, tree)

  @parameterized.parameters(*TREES)
  def testAllLeavesWithTrees(self, tree):
    # A flat list of leaves is all leaves; a list containing a tree is not.
    leaves = tree_util.tree_leaves(tree)
    self.assertTrue(tree_util.all_leaves(leaves))
    self.assertFalse(tree_util.all_leaves([tree]))

  @parameterized.parameters(*LEAVES)
  def testAllLeavesWithLeaves(self, leaf):
    self.assertTrue(tree_util.all_leaves([leaf]))

  @parameterized.parameters(*TREES)
  def testCompose(self, tree):
    treedef = tree_util.tree_structure(tree)
    inner_treedef = tree_util.tree_structure(["*", "*", "*"])
    composed_treedef = treedef.compose(inner_treedef)
    # Every outer leaf is replaced by a copy of the inner structure.
    expected_leaves = treedef.num_leaves * inner_treedef.num_leaves
    self.assertEqual(composed_treedef.num_leaves, expected_leaves)
    # Outer non-leaf nodes plus one full inner tree per outer leaf.
    expected_nodes = ((treedef.num_nodes - treedef.num_leaves) +
                      (inner_treedef.num_nodes * treedef.num_leaves))
    self.assertEqual(composed_treedef.num_nodes, expected_nodes)
    leaves = [1] * expected_leaves
    composed = tree_util.tree_unflatten(composed_treedef, leaves)
    self.assertEqual(leaves, tree_util.tree_leaves(composed))

  @parameterized.parameters(*TREES)
  def testTranspose(self, tree):
    outer_treedef = tree_util.tree_structure(tree)
    if not outer_treedef.num_leaves:
      self.skipTest("Skipping empty tree")
    inner_treedef = tree_util.tree_structure([1, 1, 1])
    # Replace each leaf x by [x, x, x]; transposing then gives three copies
    # of the original tree.
    nested = tree_util.tree_map(lambda x: [x, x, x], tree)
    actual = tree_util.tree_transpose(outer_treedef, inner_treedef, nested)
    self.assertEqual(actual, [tree, tree, tree])

  def testTransposeMismatchOuter(self):
    # The outer treedef has an extra "c" key that `tree` lacks.
    tree = {"a": [1, 2], "b": [3, 4]}
    outer_treedef = tree_util.tree_structure({"a": 1, "b": 2, "c": 3})
    inner_treedef = tree_util.tree_structure([1, 2])
    with self.assertRaisesRegex(TypeError, "Mismatch"):
      tree_util.tree_transpose(outer_treedef, inner_treedef, tree)

  def testTransposeMismatchInner(self):
    # The inner treedef expects 3 elements but each list in `tree` has 2.
    tree = {"a": [1, 2], "b": [3, 4]}
    outer_treedef = tree_util.tree_structure({"a": 1, "b": 2})
    inner_treedef = tree_util.tree_structure([1, 2, 3])
    with self.assertRaisesRegex(TypeError, "Mismatch"):
      tree_util.tree_transpose(outer_treedef, inner_treedef, tree)

  def testTransposeWithCustomObject(self):
    # Transposition also works through a user-registered pytree node type.
    outer_treedef = tree_util.tree_structure(FlatCache({"a": 1, "b": 2}))
    inner_treedef = tree_util.tree_structure([1, 2])
    expected = [FlatCache({"a": 3, "b": 5}), FlatCache({"a": 4, "b": 6})]
    actual = tree_util.tree_transpose(outer_treedef, inner_treedef,
                                      FlatCache({"a": [3, 4], "b": [5, 6]}))
    self.assertEqual(expected, actual)

  @parameterized.parameters([(*t, s) for t, s in zip(TREES, TREE_STRINGS)])
  def testStringRepresentation(self, tree, correct_string):
    """Checks that the string representation of a tree works."""
    # TREE_STRINGS entries are treated as regular expressions (assertRegex).
    treedef = tree_util.tree_structure(tree)
    self.assertRegex(str(treedef), correct_string)

  def testTreeDefWithEmptyDictStringRepresentation(self):
    self.assertEqual(str(tree_util.tree_structure({})), "PyTreeDef({})")
class TreeTest(jtu.JaxTestCase):
  """Pytree tests: round trips, Partial, is_leaf, compose, transpose, repr."""

  @parameterized.parameters(*(TREES + LEAVES))
  def testRoundtrip(self, inputs):
    # Flattening then unflattening must reproduce the input.
    xs, tree = tree_util.tree_flatten(inputs)
    actual = tree_util.tree_unflatten(tree, xs)
    self.assertEqual(actual, inputs)

  @parameterized.parameters(*(TREES + LEAVES))
  def testRoundtripWithFlattenUpTo(self, inputs):
    # Flattening a tree "up to" its own structure yields its leaves.
    _, tree = tree_util.tree_flatten(inputs)
    xs = tree.flatten_up_to(inputs)
    actual = tree_util.tree_unflatten(tree, xs)
    self.assertEqual(actual, inputs)

  @parameterized.parameters(
      (tree_util.Partial(_dummy_func),),
      (tree_util.Partial(_dummy_func, 1, 2),),
      (tree_util.Partial(_dummy_func, x="a"),),
      (tree_util.Partial(_dummy_func, 1, 2, 3, x=4, y=5),),
  )
  def testRoundtripPartial(self, inputs):
    xs, tree = tree_util.tree_flatten(inputs)
    actual = tree_util.tree_unflatten(tree, xs)
    # functools.partial does not support equality comparisons:
    # https://stackoverflow.com/a/32786109/809705
    self.assertEqual(actual.func, inputs.func)
    self.assertEqual(actual.args, inputs.args)
    self.assertEqual(actual.keywords, inputs.keywords)

  @parameterized.parameters(*(TREES + LEAVES))
  def testRoundtripViaBuild(self, inputs):
    # Round trip through the _process_pytree / build_tree pair.
    xs, tree = _process_pytree(tuple, inputs)
    actual = tree_util.build_tree(tree, xs)
    self.assertEqual(actual, inputs)

  def testChildren(self):
    # children() should return the treedefs of the two top-level tuples.
    _, tree = tree_util.tree_flatten(((1, 2, 3), (4,)))
    _, c0 = tree_util.tree_flatten((0, 0, 0))
    _, c1 = tree_util.tree_flatten((7,))
    self.assertEqual([c0, c1], tree.children())

  def testFlattenUpTo(self):
    # Each leaf position of the reference structure captures whatever
    # subtree sits at that position in the argument.
    _, tree = tree_util.tree_flatten([(1, 2), None, ATuple(foo=3, bar=7)])
    out = tree.flatten_up_to([({"foo": 7}, (3, 4)), None,
                              ATuple(foo=(11, 9), bar=None)])
    self.assertEqual(out, [{"foo": 7}, (3, 4), (11, 9), None])

  def testTreeMultimap(self):
    # Leaves of `x` are paired with the subtrees of `y` at the same position.
    x = ((1, 2), (3, 4, 5))
    y = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
    out = tree_util.tree_multimap(lambda *xs: tuple(xs), x, y)
    self.assertEqual(out, (((1, [3]), (2, None)),
                           ((3, {"foo": "bar"}), (4, 7), (5, [5, 6]))))

  def testTreeMultimapWithIsLeafArgument(self):
    # With is_leaf treating lists as leaves, [3, 4, 5] is paired as a whole.
    x = ((1, 2), [3, 4, 5])
    y = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
    out = tree_util.tree_multimap(lambda *xs: tuple(xs), x, y,
                                  is_leaf=lambda n: isinstance(n, list))
    self.assertEqual(out, (((1, [3]), (2, None)),
                           (([3, 4, 5], ({"foo": "bar"}, 7, [5, 6])))))

  def testFlattenIsLeaf(self):
    x = [(1, 2), (3, 4), (5, 6)]
    leaves, _ = tree_util.tree_flatten(x, is_leaf=lambda t: False)
    self.assertEqual(leaves, [1, 2, 3, 4, 5, 6])
    leaves, _ = tree_util.tree_flatten(
        x, is_leaf=lambda t: isinstance(t, tuple))
    self.assertEqual(leaves, x)
    leaves, _ = tree_util.tree_flatten(
        x, is_leaf=lambda t: isinstance(t, list))
    self.assertEqual(leaves, [x])
    leaves, _ = tree_util.tree_flatten(x, is_leaf=lambda t: True)
    self.assertEqual(leaves, [x])
    # is_leaf must also stop descent at tuples nested inside lists and dicts.
    y = [[[(1,)], [[(2,)], {"a": (3,)}]]]
    leaves, _ = tree_util.tree_flatten(
        y, is_leaf=lambda t: isinstance(t, tuple))
    self.assertEqual(leaves, [(1,), (2,), (3,)])

  @parameterized.parameters(*TREES)
  def testRoundtripIsLeaf(self, tree):
    # Treating tuples as leaves must still round-trip exactly.
    xs, treedef = tree_util.tree_flatten(
        tree, is_leaf=lambda t: isinstance(t, tuple))
    recon_tree = tree_util.tree_unflatten(treedef, xs)
    self.assertEqual(recon_tree, tree)

  @parameterized.parameters(*TREES)
  def testAllLeavesWithTrees(self, tree):
    # A flat list of leaves is all leaves; a list containing a tree is not.
    leaves = tree_util.tree_leaves(tree)
    self.assertTrue(tree_util.all_leaves(leaves))
    self.assertFalse(tree_util.all_leaves([tree]))

  @parameterized.parameters(*LEAVES)
  def testAllLeavesWithLeaves(self, leaf):
    self.assertTrue(tree_util.all_leaves([leaf]))

  @parameterized.parameters(*TREES)
  def testCompose(self, tree):
    treedef = tree_util.tree_structure(tree)
    inner_treedef = tree_util.tree_structure(["*", "*", "*"])
    composed_treedef = treedef.compose(inner_treedef)
    # Every outer leaf is replaced by a copy of the inner structure.
    expected_leaves = treedef.num_leaves * inner_treedef.num_leaves
    self.assertEqual(composed_treedef.num_leaves, expected_leaves)
    # Outer non-leaf nodes plus one full inner tree per outer leaf.
    expected_nodes = ((treedef.num_nodes - treedef.num_leaves) +
                      (inner_treedef.num_nodes * treedef.num_leaves))
    self.assertEqual(composed_treedef.num_nodes, expected_nodes)
    leaves = [1] * expected_leaves
    composed = tree_util.tree_unflatten(composed_treedef, leaves)
    self.assertEqual(leaves, tree_util.tree_leaves(composed))

  @parameterized.parameters(*TREES)
  def testTranspose(self, tree):
    outer_treedef = tree_util.tree_structure(tree)
    if not outer_treedef.num_leaves:
      self.skipTest("Skipping empty tree")
    inner_treedef = tree_util.tree_structure([1, 1, 1])
    # Replace each leaf x by [x, x, x]; transposing then gives three copies
    # of the original tree.
    nested = tree_util.tree_map(lambda x: [x, x, x], tree)
    actual = tree_util.tree_transpose(outer_treedef, inner_treedef, nested)
    self.assertEqual(actual, [tree, tree, tree])

  def testTransposeMismatchOuter(self):
    # The outer treedef has an extra "c" key that `tree` lacks.
    tree = {"a": [1, 2], "b": [3, 4]}
    outer_treedef = tree_util.tree_structure({"a": 1, "b": 2, "c": 3})
    inner_treedef = tree_util.tree_structure([1, 2])
    with self.assertRaisesRegex(TypeError, "Mismatch"):
      tree_util.tree_transpose(outer_treedef, inner_treedef, tree)

  def testTransposeMismatchInner(self):
    # The inner treedef expects 3 elements but each list in `tree` has 2.
    tree = {"a": [1, 2], "b": [3, 4]}
    outer_treedef = tree_util.tree_structure({"a": 1, "b": 2})
    inner_treedef = tree_util.tree_structure([1, 2, 3])
    with self.assertRaisesRegex(TypeError, "Mismatch"):
      tree_util.tree_transpose(outer_treedef, inner_treedef, tree)

  def testTransposeWithCustomObject(self):
    # Transposition also works through a user-registered pytree node type.
    outer_treedef = tree_util.tree_structure(FlatCache({"a": 1, "b": 2}))
    inner_treedef = tree_util.tree_structure([1, 2])
    expected = [FlatCache({"a": 3, "b": 5}), FlatCache({"a": 4, "b": 6})]
    actual = tree_util.tree_transpose(
        outer_treedef, inner_treedef, FlatCache({"a": [3, 4], "b": [5, 6]}))
    self.assertEqual(expected, actual)

  # Skipped on older jaxlib builds, per the skip reason below.
  @unittest.skipIf(lib._xla_extension_version < 17,
                   "Test requires jaxlib 0.1.66.")
  @parameterized.parameters([(*t, s) for t, s in zip(TREES, TREE_STRINGS)])
  def testStringRepresentation(self, tree, correct_string):
    """Checks that the string representation of a tree works."""
    # TREE_STRINGS entries are treated as regular expressions (assertRegex).
    treedef = tree_util.tree_structure(tree)
    self.assertRegex(str(treedef), correct_string)