示例#1
0
class TreeTest(jtu.JaxTestCase):
    """Round-trip tests for ``tree_util`` flattening and ``Partial``."""

    @parameterized.parameters(*PYTREES)
    def testRoundtrip(self, inputs):
        """tree_unflatten(tree_flatten(x)) reproduces x."""
        leaves, treedef = tree_util.tree_flatten(inputs)
        self.assertEqual(tree_util.tree_unflatten(treedef, leaves), inputs)

    @parameterized.parameters(
        (tree_util.Partial(_dummy_func), ),
        (tree_util.Partial(_dummy_func, 1, 2), ),
        (tree_util.Partial(_dummy_func, x="a"), ),
        (tree_util.Partial(_dummy_func, 1, 2, 3, x=4, y=5), ),
    )
    def testRoundtripPartial(self, inputs):
        """A Partial keeps its func, args and keywords across a round trip."""
        leaves, treedef = tree_util.tree_flatten(inputs)
        rebuilt = tree_util.tree_unflatten(treedef, leaves)
        # functools.partial does not support equality comparisons, so the
        # components are compared one at a time:
        # https://stackoverflow.com/a/32786109/809705
        for attr in ("func", "args", "keywords"):
            self.assertEqual(getattr(rebuilt, attr), getattr(inputs, attr))

    @parameterized.parameters(*PYTREES)
    def testRoundtripViaBuild(self, inputs):
        """process_pytree followed by build_tree reproduces the input."""
        processed, treedef = tree_util.process_pytree(tuple, inputs)
        self.assertEqual(tree_util.build_tree(treedef, processed), inputs)

    def testChildren(self):
        """children() returns the treedef of each top-level subtree."""
        _, parent = tree_util.tree_flatten(((1, 2, 3), (4, )))
        expected = [
            tree_util.tree_flatten(sub)[1] for sub in ((0, 0, 0), (7, ))
        ]
        self.assertEqual(expected, parent.children())
示例#2
0
 def testPartialFuncAttributeHasStableHash(self):
   # https://github.com/google/jax/issues/9429
   fun = functools.partial(print, 1)
   p1 = tree_util.Partial(fun, 2)
   p2 = tree_util.Partial(fun, 2)
   self.assertEqual(fun, p1.func)
   self.assertEqual(p1.func, fun)
   self.assertEqual(p1.func, p2.func)
   self.assertEqual(hash(p1.func), hash(p2.func))
示例#3
0
 def _func_fwd(values):
   """Runs ``cls(...).compute()`` on a NumPy copy of ``values``.

   ``cls`` and ``kwargs`` come from the enclosing scope. Returns a pair: the
   computed result as a jnp array with the dtype of ``values``, and
   ``tree_util.Partial(obj.vjp)`` for the constructed object.
   """
   original_dtype = values.dtype
   instance = cls(np.array(values), **kwargs)
   output = jnp.array(instance.compute(), dtype=original_dtype)
   return output, tree_util.Partial(instance.vjp)
示例#4
0
    def testPartialDoesNotMergeWithOtherPartials(self):
        def f(a, b, c):
            pass

        g = functools.partial(f, 2)
        h = tree_util.Partial(g, 3)
        self.assertEqual(h.args, (3, ))
示例#5
0
    def test_kohn_sham_neural_xc_density_mse_converge_tolerance(
            self, density_mse_converge_tolerance, expected_converged):
        """Jitted KS with a neural LDA XC reports the expected convergence.

        Args:
          density_mse_converge_tolerance: Forwarded to ``jit_scf.kohn_sham``.
          expected_converged: Expected value of ``states.converged``.
        """
        network = stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1))
        init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
            network)
        params_init = init_fn(rng=random.PRNGKey(0))
        xc_fn = tree_util.Partial(xc_energy_density_fn, params=params_init)
        initial_density = self.num_electrons * utils.gaussian(
            grids=self.grids, center=0., sigma=0.5)

        states = jit_scf.kohn_sham(
            locations=self.locations,
            nuclear_charges=self.nuclear_charges,
            num_electrons=self.num_electrons,
            num_iterations=3,
            grids=self.grids,
            xc_energy_density_fn=xc_fn,
            interaction_fn=utils.exponential_coulomb,
            initial_density=initial_density,
            density_mse_converge_tolerance=density_mse_converge_tolerance)

        np.testing.assert_array_equal(states.converged, expected_converged)

        for state in scf.state_iterator(states):
            potential = self._create_testing_external_potential(
                utils.exponential_coulomb)
            self._test_state(state, potential)
示例#6
0
 def loss(flatten_params, initial_state, target_energy):
   """Squared error between one KS iteration's total energy and a target.

   ``self``, ``spec`` and ``xc_energy_density_fn`` come from the enclosing
   scope; ``flatten_params`` is unflattened with ``spec`` into XC params.
   """
   params = np_utils.unflatten(spec, flatten_params)
   state = scf.kohn_sham_iteration(
       state=initial_state,
       num_electrons=self.num_electrons,
       xc_energy_density_fn=tree_util.Partial(
           xc_energy_density_fn, params=params),
       interaction_fn=utils.exponential_coulomb,
       enforce_reflection_symmetry=True)
   energy_error = state.total_energy - target_energy
   return energy_error ** 2
示例#7
0
 def loss(flatten_params, initial_state, target_density):
   """Grid-weighted L1 distance between one KS iteration's density and a target.

   ``self``, ``spec`` and ``xc_energy_density_fn`` come from the enclosing
   scope; the sum of absolute differences is scaled by ``utils.get_dx``.
   """
   params = np_utils.unflatten(spec, flatten_params)
   state = scf.kohn_sham_iteration(
       state=initial_state,
       num_electrons=self.num_electrons,
       xc_energy_density_fn=tree_util.Partial(
           xc_energy_density_fn, params=params),
       interaction_fn=utils.exponential_coulomb,
       enforce_reflection_symmetry=False)
   absolute_error = jnp.sum(jnp.abs(state.density - target_density))
   return absolute_error * utils.get_dx(self.grids)
示例#8
0
 def test_kohn_sham_iteration(
     self, interaction_fn, enforce_reflection_symmetry):
   """Runs one KS iteration with LDA exchange and checks it via _test_state.

   Args:
     interaction_fn: Interaction function used for both the initial state and
       the iteration.
     enforce_reflection_symmetry: Forwarded to ``scf.kohn_sham_iteration``.
   """
   initial_state = self._create_testing_initial_state(interaction_fn)
   # Use 3d LDA exchange functional and zero correlation functional.
   lda_exchange = tree_util.Partial(
       lambda density: -0.73855 * density ** (1 / 3))
   next_state = scf.kohn_sham_iteration(
       state=initial_state,
       num_electrons=self.num_electrons,
       xc_energy_density_fn=lda_exchange,
       interaction_fn=interaction_fn,
       enforce_reflection_symmetry=enforce_reflection_symmetry)
   self._test_state(next_state, initial_state)
示例#9
0
class TreeTest(jtu.JaxTestCase):
    """Flatten/unflatten round-trip tests for plain containers and Partial."""

    @parameterized.parameters(((1, 2), ), ([3], ), ({'a': 1, 'b': 2}, ))
    def testRoundtrip(self, inputs):
        """Flattening then unflattening reproduces tuples, lists and dicts."""
        xs, tree = tree_util.tree_flatten(inputs)
        actual = tree_util.tree_unflatten(tree, xs)
        self.assertEqual(actual, inputs)

    # This method must not reuse the name ``testRoundtrip``. A second ``def``
    # with the same name replaces the one above in the class namespace, and
    # then the container round-trip cases never run.
    @parameterized.parameters(
        (tree_util.Partial(_dummy_func), ),
        (tree_util.Partial(_dummy_func, 1, 2), ),
        (tree_util.Partial(_dummy_func, x='a'), ),
        (tree_util.Partial(_dummy_func, 1, 2, 3, x=4, y=5), ),
    )
    def testRoundtripPartial(self, inputs):
        """A Partial keeps its func, args and keywords across a round trip."""
        xs, tree = tree_util.tree_flatten(inputs)
        actual = tree_util.tree_unflatten(tree, xs)
        # functools.partial does not support equality comparisons:
        # https://stackoverflow.com/a/32786109/809705
        self.assertEqual(actual.func, inputs.func)
        self.assertEqual(actual.args, inputs.args)
        self.assertEqual(actual.keywords, inputs.keywords)
示例#10
0
 def loss(flatten_params, target_energy):
   """Squared error of the final KS total energy against ``target_energy``.

   Runs a 3-iteration ``scf.kohn_sham`` with the XC functional whose params
   are unflattened from ``flatten_params`` via ``spec``. ``self``, ``spec``
   and ``xc_energy_density_fn`` come from the enclosing scope.
   """
   params = np_utils.unflatten(spec, flatten_params)
   states = scf.kohn_sham(
       locations=self.locations,
       nuclear_charges=self.nuclear_charges,
       num_electrons=self.num_electrons,
       num_iterations=3,
       grids=self.grids,
       xc_energy_density_fn=tree_util.Partial(
           xc_energy_density_fn, params=params),
       interaction_fn=utils.exponential_coulomb)
   energy = scf.get_final_state(states).total_energy
   return (energy - target_energy) ** 2
示例#11
0
 def test_kohn_sham_iteration_neural_xc(self, enforce_reflection_symmetry):
     """One jitted KS iteration with a freshly initialized neural LDA XC.

     Args:
       enforce_reflection_symmetry: Forwarded to
         ``jit_scf.kohn_sham_iteration``.
     """
     network = stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1))
     init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
         network)
     params_init = init_fn(rng=random.PRNGKey(0))
     initial_state = self._create_testing_initial_state(
         utils.exponential_coulomb)
     xc_fn = tree_util.Partial(xc_energy_density_fn, params=params_init)
     next_state = jit_scf.kohn_sham_iteration(
         state=initial_state,
         num_electrons=self.num_electrons,
         xc_energy_density_fn=xc_fn,
         interaction_fn=utils.exponential_coulomb,
         enforce_reflection_symmetry=enforce_reflection_symmetry)
     self._test_state(next_state, initial_state)
示例#12
0
 def test_kohn_sham_convergence(
     self, density_mse_converge_tolerance, expected_converged):
   """Checks the ``converged`` flags of a 3-iteration KS run.

   The run uses LDA exchange only.

   Args:
     density_mse_converge_tolerance: Forwarded to ``scf.kohn_sham``.
     expected_converged: Expected value of ``state.converged``.
   """
   state = scf.kohn_sham(
       locations=self.locations,
       nuclear_charges=self.nuclear_charges,
       num_electrons=self.num_electrons,
       num_iterations=3,
       grids=self.grids,
       # Use 3d LDA exchange functional and zero correlation functional.
       xc_energy_density_fn=tree_util.Partial(
           lambda density: -0.73855 * density ** (1 / 3)),
       interaction_fn=utils.exponential_coulomb,
       density_mse_converge_tolerance=density_mse_converge_tolerance)
   # ``converged`` holds convergence flags. Compare them exactly, as the
   # jitted neural-XC convergence test does. A floating-point tolerance is
   # meaningless for flag values.
   np.testing.assert_array_equal(state.converged, expected_converged)
示例#13
0
 def test_kohn_sham(self, interaction_fn):
   """Checks every state of a 3-iteration KS run with LDA exchange.

   Each state is validated by ``_test_state`` against the testing external
   potential for ``interaction_fn``.
   """
   # Use 3d LDA exchange functional and zero correlation functional.
   lda_exchange = tree_util.Partial(
       lambda density: -0.73855 * density ** (1 / 3))
   states = scf.kohn_sham(
       locations=self.locations,
       nuclear_charges=self.nuclear_charges,
       num_electrons=self.num_electrons,
       num_iterations=3,
       grids=self.grids,
       xc_energy_density_fn=lda_exchange,
       interaction_fn=interaction_fn)
   for state in scf.state_iterator(states):
     potential = self._create_testing_external_potential(interaction_fn)
     self._test_state(state, potential)
示例#14
0
 def loss(flatten_params, target_density):
     """Grid-weighted L1 distance between the final KS density and a target.

     ``self``, ``spec`` and ``xc_energy_density_fn`` come from the enclosing
     scope; ``flatten_params`` is unflattened with ``spec`` into XC params.
     """
     params = np_utils.unflatten(spec, flatten_params)
     states = scf.kohn_sham(
         locations=self.locations,
         nuclear_charges=self.nuclear_charges,
         num_electrons=self.num_electrons,
         num_iterations=3,
         grids=self.grids,
         xc_energy_density_fn=tree_util.Partial(
             xc_energy_density_fn, params=params),
         interaction_fn=utils.exponential_coulomb,
         # NOTE(review): a negative tolerance presumably means the run is
         # never marked converged; confirm against scf.kohn_sham.
         density_mse_converge_tolerance=-1)
     density_error = scf.get_final_state(states).density - target_density
     return jnp.sum(jnp.abs(density_error)) * utils.get_dx(self.grids)
示例#15
0
 def test_kohn_sham_neural_xc(self, interaction_fn):
     """Checks every state of a KS run with a freshly initialized neural LDA.

     Args:
       interaction_fn: Interaction function for the run and for building the
         reference external potential.
     """
     network = stax.serial(stax.Dense(8), stax.Elu, stax.Dense(1))
     init_fn, xc_energy_density_fn = neural_xc.local_density_approximation(
         network)
     params_init = init_fn(rng=random.PRNGKey(0))
     xc_fn = tree_util.Partial(xc_energy_density_fn, params=params_init)
     states = scf.kohn_sham(
         locations=self.locations,
         nuclear_charges=self.nuclear_charges,
         num_electrons=self.num_electrons,
         num_iterations=3,
         grids=self.grids,
         xc_energy_density_fn=xc_fn,
         interaction_fn=interaction_fn)
     for state in scf.state_iterator(states):
         potential = self._create_testing_external_potential(interaction_fn)
         self._test_state(state, potential)
示例#16
0
class TreeTest(jtu.JaxTestCase):
  """Tests for flattening, rebuilding, composing and transposing pytrees."""

  @parameterized.parameters(*(TREES + LEAVES))
  def testRoundtrip(self, inputs):
    """tree_unflatten(tree_flatten(x)) reproduces x for trees and leaves."""
    leaves, treedef = tree_util.tree_flatten(inputs)
    self.assertEqual(tree_util.tree_unflatten(treedef, leaves), inputs)

  @parameterized.parameters(*(TREES + LEAVES))
  def testRoundtripWithFlattenUpTo(self, inputs):
    """flatten_up_to against the input's own treedef is lossless."""
    treedef = tree_util.tree_flatten(inputs)[1]
    if not hasattr(treedef, "flatten_up_to"):
      self.skipTest("Test requires Jaxlib >= 0.1.23")
    pieces = treedef.flatten_up_to(inputs)
    self.assertEqual(tree_util.tree_unflatten(treedef, pieces), inputs)

  @parameterized.parameters(
      (tree_util.Partial(_dummy_func),),
      (tree_util.Partial(_dummy_func, 1, 2),),
      (tree_util.Partial(_dummy_func, x="a"),),
      (tree_util.Partial(_dummy_func, 1, 2, 3, x=4, y=5),),
  )
  def testRoundtripPartial(self, inputs):
    """A Partial keeps its func, args and keywords across a round trip."""
    leaves, treedef = tree_util.tree_flatten(inputs)
    rebuilt = tree_util.tree_unflatten(treedef, leaves)
    # functools.partial does not support equality comparisons, so the
    # components are compared one at a time:
    # https://stackoverflow.com/a/32786109/809705
    for attr in ("func", "args", "keywords"):
      self.assertEqual(getattr(rebuilt, attr), getattr(inputs, attr))

  @parameterized.parameters(*(TREES + LEAVES))
  def testRoundtripViaBuild(self, inputs):
    """_process_pytree followed by build_tree reproduces the input."""
    processed, treedef = tree_util._process_pytree(tuple, inputs)
    self.assertEqual(tree_util.build_tree(treedef, processed), inputs)

  def testChildren(self):
    """children() returns the treedef of each top-level subtree."""
    _, parent = tree_util.tree_flatten(((1, 2, 3), (4,)))
    expected = [tree_util.tree_flatten(sub)[1] for sub in ((0, 0, 0), (7,))]
    if not callable(parent.children):
      self.skipTest("Test requires Jaxlib >= 0.1.23")
    self.assertEqual(expected, parent.children())

  def testFlattenUpTo(self):
    """flatten_up_to stops at the leaves of the reference structure."""
    _, treedef = tree_util.tree_flatten([(1, 2), None, ATuple(foo=3, bar=7)])
    if not hasattr(treedef, "flatten_up_to"):
      self.skipTest("Test requires Jaxlib >= 0.1.23")
    target = [({"foo": 7}, (3, 4)), None, ATuple(foo=(11, 9), bar=None)]
    self.assertEqual(treedef.flatten_up_to(target),
                     [{"foo": 7}, (3, 4), (11, 9), None])

  def testTreeMultimap(self):
    """tree_multimap pairs each leaf of lhs with the matching subtree of rhs."""
    lhs = ((1, 2), (3, 4, 5))
    rhs = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
    zipped = tree_util.tree_multimap(lambda *vals: tuple(vals), lhs, rhs)
    expected = (((1, [3]), (2, None)),
                ((3, {"foo": "bar"}), (4, 7), (5, [5, 6])))
    self.assertEqual(zipped, expected)

  @parameterized.parameters(*TREES)
  def testAllLeavesWithTrees(self, tree):
    """A flattened leaf list passes all_leaves; [tree] is expected not to."""
    self.assertTrue(tree_util.all_leaves(tree_util.tree_leaves(tree)))
    self.assertFalse(tree_util.all_leaves([tree]))

  @parameterized.parameters(*LEAVES)
  def testAllLeavesWithLeaves(self, leaf):
    """A list holding a single leaf passes all_leaves."""
    self.assertTrue(tree_util.all_leaves([leaf]))

  @parameterized.parameters(*TREES)
  def testCompose(self, tree):
    """Composing with a 3-leaf list gives the expected leaf and node counts."""
    outer = tree_util.tree_structure(tree)
    inner = tree_util.tree_structure(["*", "*", "*"])
    composed = outer.compose(inner)
    n_leaves = outer.num_leaves * inner.num_leaves
    self.assertEqual(composed.num_leaves, n_leaves)
    # Interior nodes of the outer tree are kept; each outer leaf is replaced
    # by a full copy of the inner tree.
    n_nodes = ((outer.num_nodes - outer.num_leaves) +
               (inner.num_nodes * outer.num_leaves))
    self.assertEqual(composed.num_nodes, n_nodes)
    ones = [1] * n_leaves
    rebuilt = tree_util.tree_unflatten(composed, ones)
    self.assertEqual(ones, tree_util.tree_leaves(rebuilt))

  @parameterized.parameters(*TREES)
  def testTranspose(self, tree):
    """Transposing a tree of 3-lists yields a 3-list of trees."""
    outer = tree_util.tree_structure(tree)
    if not outer.num_leaves:
      self.skipTest("Skipping empty tree")
    inner = tree_util.tree_structure([1, 1, 1])
    tripled = tree_util.tree_map(lambda leaf: [leaf, leaf, leaf], tree)
    self.assertEqual(tree_util.tree_transpose(outer, inner, tripled),
                     [tree, tree, tree])

  def testTransposeMismatchOuter(self):
    """An outer treedef that does not match the value raises TypeError."""
    value = {"a": [1, 2], "b": [3, 4]}
    outer = tree_util.tree_structure({"a": 1, "b": 2, "c": 3})
    inner = tree_util.tree_structure([1, 2])
    with self.assertRaisesRegex(TypeError, "Mismatch"):
      tree_util.tree_transpose(outer, inner, value)

  def testTransposeMismatchInner(self):
    """An inner treedef that does not match the value raises TypeError."""
    value = {"a": [1, 2], "b": [3, 4]}
    outer = tree_util.tree_structure({"a": 1, "b": 2})
    inner = tree_util.tree_structure([1, 2, 3])
    with self.assertRaisesRegex(TypeError, "Mismatch"):
      tree_util.tree_transpose(outer, inner, value)

  def testTransposeWithCustomObject(self):
    """tree_transpose works through the custom FlatCache container."""
    outer = tree_util.tree_structure(FlatCache({"a": 1, "b": 2}))
    inner = tree_util.tree_structure([1, 2])
    expected = [FlatCache({"a": 3, "b": 5}), FlatCache({"a": 4, "b": 6})]
    nested = FlatCache({"a": [3, 4], "b": [5, 6]})
    self.assertEqual(expected, tree_util.tree_transpose(outer, inner, nested))
示例#17
0
class TreeTest(jtu.JaxTestCase):
  """Pytree round-trip, children, flatten_up_to and multimap tests."""

  @parameterized.parameters(*PYTREES)
  def testRoundtrip(self, inputs):
    """tree_unflatten(tree_flatten(x)) reproduces x."""
    leaves, treedef = tree_util.tree_flatten(inputs)
    self.assertEqual(tree_util.tree_unflatten(treedef, leaves), inputs)

  @parameterized.parameters(*PYTREES)
  def testRoundtripWithFlattenUpTo(self, inputs):
    """flatten_up_to against the input's own treedef is lossless."""
    treedef = tree_util.tree_flatten(inputs)[1]
    if not hasattr(treedef, "flatten_up_to"):
      self.skipTest("Test requires Jaxlib >= 0.1.23")
    pieces = treedef.flatten_up_to(inputs)
    self.assertEqual(tree_util.tree_unflatten(treedef, pieces), inputs)

  @parameterized.parameters(
      (tree_util.Partial(_dummy_func),),
      (tree_util.Partial(_dummy_func, 1, 2),),
      (tree_util.Partial(_dummy_func, x="a"),),
      (tree_util.Partial(_dummy_func, 1, 2, 3, x=4, y=5),),
  )
  def testRoundtripPartial(self, inputs):
    """A Partial keeps its func, args and keywords across a round trip."""
    leaves, treedef = tree_util.tree_flatten(inputs)
    rebuilt = tree_util.tree_unflatten(treedef, leaves)
    # functools.partial does not support equality comparisons, so the
    # components are compared one at a time:
    # https://stackoverflow.com/a/32786109/809705
    for attr in ("func", "args", "keywords"):
      self.assertEqual(getattr(rebuilt, attr), getattr(inputs, attr))

  @parameterized.parameters(*PYTREES)
  def testRoundtripViaBuild(self, inputs):
    """_process_pytree followed by build_tree reproduces the input."""
    processed, treedef = tree_util._process_pytree(tuple, inputs)
    self.assertEqual(tree_util.build_tree(treedef, processed), inputs)

  def testChildren(self):
    """children() returns the treedef of each top-level subtree."""
    _, parent = tree_util.tree_flatten(((1, 2, 3), (4,)))
    expected = [tree_util.tree_flatten(sub)[1] for sub in ((0, 0, 0), (7,))]
    if not callable(parent.children):
      self.skipTest("Test requires Jaxlib >= 0.1.23")
    self.assertEqual(expected, parent.children())

  def testFlattenUpTo(self):
    """flatten_up_to stops at the leaves of the reference structure."""
    _, treedef = tree_util.tree_flatten([(1, 2), None, ATuple(foo=3, bar=7)])
    if not hasattr(treedef, "flatten_up_to"):
      self.skipTest("Test requires Jaxlib >= 0.1.23")
    target = [({"foo": 7}, (3, 4)), None, ATuple(foo=(11, 9), bar=None)]
    self.assertEqual(treedef.flatten_up_to(target),
                     [{"foo": 7}, (3, 4), (11, 9), None])

  def testTreeMultimap(self):
    """tree_multimap pairs each leaf of lhs with the matching subtree of rhs."""
    lhs = ((1, 2), (3, 4, 5))
    rhs = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
    zipped = tree_util.tree_multimap(lambda *vals: tuple(vals), lhs, rhs)
    expected = (((1, [3]), (2, None)),
                ((3, {"foo": "bar"}), (4, 7), (5, [5, 6])))
    self.assertEqual(zipped, expected)
示例#18
0
class TreeTest(jtu.JaxTestCase):
  """Tests for pytree flattening, Partial, treedef algebra and transposition."""

  @parameterized.parameters(*(TREES + LEAVES))
  def testRoundtrip(self, inputs):
    """tree_unflatten(tree_flatten(x)) reproduces x for trees and leaves."""
    leaves, treedef = tree_util.tree_flatten(inputs)
    self.assertEqual(tree_util.tree_unflatten(treedef, leaves), inputs)

  @parameterized.parameters(*(TREES + LEAVES))
  def testRoundtripWithFlattenUpTo(self, inputs):
    """flatten_up_to against the input's own treedef is lossless."""
    treedef = tree_util.tree_flatten(inputs)[1]
    pieces = treedef.flatten_up_to(inputs)
    self.assertEqual(tree_util.tree_unflatten(treedef, pieces), inputs)

  @parameterized.parameters(
      (tree_util.Partial(_dummy_func),),
      (tree_util.Partial(_dummy_func, 1, 2),),
      (tree_util.Partial(_dummy_func, x="a"),),
      (tree_util.Partial(_dummy_func, 1, 2, 3, x=4, y=5),),
  )
  def testRoundtripPartial(self, inputs):
    """A Partial keeps its func, args and keywords across a round trip."""
    leaves, treedef = tree_util.tree_flatten(inputs)
    rebuilt = tree_util.tree_unflatten(treedef, leaves)
    # functools.partial does not support equality comparisons, so the
    # components are compared one at a time:
    # https://stackoverflow.com/a/32786109/809705
    for attr in ("func", "args", "keywords"):
      self.assertEqual(getattr(rebuilt, attr), getattr(inputs, attr))

  def testPartialDoesNotMergeWithOtherPartials(self):
    """Wrapping a functools.partial must not merge its args into Partial."""
    def three_arg_fn(a, b, c): pass
    inner = functools.partial(three_arg_fn, 2)
    outer = tree_util.Partial(inner, 3)
    self.assertEqual(outer.args, (3,))

  def testPartialFuncAttributeHasStableHash(self):
    """Partial.func compares equal both ways and hashes consistently."""
    # https://github.com/google/jax/issues/9429
    inner = functools.partial(print, 1)
    first = tree_util.Partial(inner, 2)
    second = tree_util.Partial(inner, 2)
    self.assertEqual(inner, first.func)
    self.assertEqual(first.func, inner)
    self.assertEqual(first.func, second.func)
    self.assertEqual(hash(first.func), hash(second.func))

  @parameterized.parameters(*(TREES + LEAVES))
  def testRoundtripViaBuild(self, inputs):
    """_process_pytree followed by build_tree reproduces the input."""
    processed, treedef = _process_pytree(tuple, inputs)
    self.assertEqual(tree_util.build_tree(treedef, processed), inputs)

  def testChildren(self):
    """children() returns the treedef of each top-level subtree."""
    _, parent = tree_util.tree_flatten(((1, 2, 3), (4,)))
    expected = [tree_util.tree_flatten(sub)[1] for sub in ((0, 0, 0), (7,))]
    self.assertEqual(expected, parent.children())

  def testTreedefTupleFromChildren(self):
    """treedef_tuple(children()) keeps the leaf and node counts."""
    # https://github.com/google/jax/issues/7377
    leaves, original = tree_util.tree_flatten(((1, 2, (3, 4)), (5,)))
    rebuilt = tree_util.treedef_tuple(original.children())
    self.assertEqual(original.num_leaves, len(leaves))
    self.assertEqual(original.num_leaves, rebuilt.num_leaves)
    self.assertEqual(original.num_nodes, rebuilt.num_nodes)

  def testTreedefTupleComparesEqual(self):
    """treedef_tuple of one leaf treedef equals the treedef of a 1-tuple."""
    # https://github.com/google/jax/issues/9066
    direct = tree_util.tree_structure((3,))
    from_tuple = tree_util.treedef_tuple((tree_util.tree_structure(3),))
    self.assertEqual(direct, from_tuple)

  def testFlattenUpTo(self):
    """flatten_up_to stops at the leaves of the reference structure."""
    _, treedef = tree_util.tree_flatten([(1, 2), None, ATuple(foo=3, bar=7)])
    target = [({"foo": 7}, (3, 4)), None, ATuple(foo=(11, 9), bar=None)]
    self.assertEqual(treedef.flatten_up_to(target),
                     [{"foo": 7}, (3, 4), (11, 9), None])

  def testTreeMultimap(self):
    """tree_multimap pairs each leaf of lhs with the matching subtree of rhs."""
    lhs = ((1, 2), (3, 4, 5))
    rhs = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
    zipped = tree_util.tree_multimap(lambda *vals: tuple(vals), lhs, rhs)
    expected = (((1, [3]), (2, None)),
                ((3, {"foo": "bar"}), (4, 7), (5, [5, 6])))
    self.assertEqual(zipped, expected)

  def testTreeMultimapWithIsLeafArgument(self):
    """With is_leaf, the list in lhs is a leaf matched to rhs's subtree."""
    lhs = ((1, 2), [3, 4, 5])
    rhs = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
    zipped = tree_util.tree_multimap(lambda *vals: tuple(vals), lhs, rhs,
                                     is_leaf=lambda n: isinstance(n, list))
    expected = (((1, [3]), (2, None)),
                ([3, 4, 5], ({"foo": "bar"}, 7, [5, 6])))
    self.assertEqual(zipped, expected)

  @parameterized.parameters(
      tree_util.tree_leaves,
      lambda tree, is_leaf: tree_util.tree_flatten(tree, is_leaf)[0])
  def testFlattenIsLeaf(self, leaf_fn):
    """is_leaf decides where leaf extraction stops descending."""
    x = [(1, 2), (3, 4), (5, 6)]
    cases = [
        (lambda t: False, [1, 2, 3, 4, 5, 6]),
        (lambda t: isinstance(t, tuple), x),
        (lambda t: isinstance(t, list), [x]),
        (lambda t: True, [x]),
    ]
    for is_leaf, expected in cases:
      self.assertEqual(leaf_fn(x, is_leaf=is_leaf), expected)

    y = [[[(1,)], [[(2,)], {"a": (3,)}]]]
    self.assertEqual(leaf_fn(y, is_leaf=lambda t: isinstance(t, tuple)),
                     [(1,), (2,), (3,)])

  @parameterized.parameters(
      tree_util.tree_structure,
      lambda tree, is_leaf: tree_util.tree_flatten(tree, is_leaf)[1])
  def testStructureIsLeaf(self, structure_fn):
    """is_leaf changes the leaf count of the resulting treedef."""
    x = [(1, 2), (3, 4), (5, 6)]
    cases = [
        (lambda t: False, 6),
        (lambda t: isinstance(t, tuple), 3),
        (lambda t: isinstance(t, list), 1),
        (lambda t: True, 1),
    ]
    for is_leaf, expected_leaves in cases:
      treedef = structure_fn(x, is_leaf=is_leaf)
      self.assertEqual(treedef.num_leaves, expected_leaves)

    y = [[[(1,)], [[(2,)], {"a": (3,)}]]]
    treedef = structure_fn(y, is_leaf=lambda t: isinstance(t, tuple))
    self.assertEqual(treedef.num_leaves, 3)

  @parameterized.parameters(*TREES)
  def testRoundtripIsLeaf(self, tree):
    """Round trip still holds when tuples are treated as leaves."""
    leaves, treedef = tree_util.tree_flatten(
        tree, is_leaf=lambda t: isinstance(t, tuple))
    self.assertEqual(tree_util.tree_unflatten(treedef, leaves), tree)

  @parameterized.parameters(*TREES)
  def testAllLeavesWithTrees(self, tree):
    """A flattened leaf list passes all_leaves; [tree] is expected not to."""
    self.assertTrue(tree_util.all_leaves(tree_util.tree_leaves(tree)))
    self.assertFalse(tree_util.all_leaves([tree]))

  @parameterized.parameters(*LEAVES)
  def testAllLeavesWithLeaves(self, leaf):
    """A list holding a single leaf passes all_leaves."""
    self.assertTrue(tree_util.all_leaves([leaf]))

  @parameterized.parameters(*TREES)
  def testCompose(self, tree):
    """Composing with a 3-leaf list gives the expected leaf and node counts."""
    outer = tree_util.tree_structure(tree)
    inner = tree_util.tree_structure(["*", "*", "*"])
    composed = outer.compose(inner)
    n_leaves = outer.num_leaves * inner.num_leaves
    self.assertEqual(composed.num_leaves, n_leaves)
    # Interior nodes of the outer tree are kept; each outer leaf is replaced
    # by a full copy of the inner tree.
    n_nodes = ((outer.num_nodes - outer.num_leaves) +
               (inner.num_nodes * outer.num_leaves))
    self.assertEqual(composed.num_nodes, n_nodes)
    ones = [1] * n_leaves
    rebuilt = tree_util.tree_unflatten(composed, ones)
    self.assertEqual(ones, tree_util.tree_leaves(rebuilt))

  @parameterized.parameters(*TREES)
  def testTranspose(self, tree):
    """Transposing a tree of 3-lists yields a 3-list of trees."""
    outer = tree_util.tree_structure(tree)
    if not outer.num_leaves:
      self.skipTest("Skipping empty tree")
    inner = tree_util.tree_structure([1, 1, 1])
    tripled = tree_util.tree_map(lambda leaf: [leaf, leaf, leaf], tree)
    self.assertEqual(tree_util.tree_transpose(outer, inner, tripled),
                     [tree, tree, tree])

  def testTransposeMismatchOuter(self):
    """An outer treedef that does not match the value raises TypeError."""
    value = {"a": [1, 2], "b": [3, 4]}
    outer = tree_util.tree_structure({"a": 1, "b": 2, "c": 3})
    inner = tree_util.tree_structure([1, 2])
    with self.assertRaisesRegex(TypeError, "Mismatch"):
      tree_util.tree_transpose(outer, inner, value)

  def testTransposeMismatchInner(self):
    """An inner treedef that does not match the value raises TypeError."""
    value = {"a": [1, 2], "b": [3, 4]}
    outer = tree_util.tree_structure({"a": 1, "b": 2})
    inner = tree_util.tree_structure([1, 2, 3])
    with self.assertRaisesRegex(TypeError, "Mismatch"):
      tree_util.tree_transpose(outer, inner, value)

  def testTransposeWithCustomObject(self):
    """tree_transpose works through the custom FlatCache container."""
    outer = tree_util.tree_structure(FlatCache({"a": 1, "b": 2}))
    inner = tree_util.tree_structure([1, 2])
    expected = [FlatCache({"a": 3, "b": 5}), FlatCache({"a": 4, "b": 6})]
    nested = FlatCache({"a": [3, 4], "b": [5, 6]})
    self.assertEqual(expected, tree_util.tree_transpose(outer, inner, nested))

  @parameterized.parameters([(*t, s) for t, s in zip(TREES, TREE_STRINGS)])
  def testStringRepresentation(self, tree, correct_string):
    """str(treedef) matches the expected regex for each tree."""
    self.assertRegex(str(tree_util.tree_structure(tree)), correct_string)

  def testTreeDefWithEmptyDictStringRepresentation(self):
    """The treedef of an empty dict prints as PyTreeDef({})."""
    empty_def = tree_util.tree_structure({})
    self.assertEqual(str(empty_def), "PyTreeDef({})")
示例#19
0
class TreeTest(jtu.JaxTestCase):
    """Tests for pytree flattening, treedef algebra and transposition."""

    @parameterized.parameters(*(TREES + LEAVES))
    def testRoundtrip(self, inputs):
        """tree_unflatten(tree_flatten(x)) reproduces x for trees and leaves."""
        leaves, treedef = tree_util.tree_flatten(inputs)
        self.assertEqual(tree_util.tree_unflatten(treedef, leaves), inputs)

    @parameterized.parameters(*(TREES + LEAVES))
    def testRoundtripWithFlattenUpTo(self, inputs):
        """flatten_up_to against the input's own treedef is lossless."""
        treedef = tree_util.tree_flatten(inputs)[1]
        pieces = treedef.flatten_up_to(inputs)
        self.assertEqual(tree_util.tree_unflatten(treedef, pieces), inputs)

    @parameterized.parameters(
        (tree_util.Partial(_dummy_func), ),
        (tree_util.Partial(_dummy_func, 1, 2), ),
        (tree_util.Partial(_dummy_func, x="a"), ),
        (tree_util.Partial(_dummy_func, 1, 2, 3, x=4, y=5), ),
    )
    def testRoundtripPartial(self, inputs):
        """A Partial keeps its func, args and keywords across a round trip."""
        leaves, treedef = tree_util.tree_flatten(inputs)
        rebuilt = tree_util.tree_unflatten(treedef, leaves)
        # functools.partial does not support equality comparisons, so the
        # components are compared one at a time:
        # https://stackoverflow.com/a/32786109/809705
        for attr in ("func", "args", "keywords"):
            self.assertEqual(getattr(rebuilt, attr), getattr(inputs, attr))

    @parameterized.parameters(*(TREES + LEAVES))
    def testRoundtripViaBuild(self, inputs):
        """_process_pytree followed by build_tree reproduces the input."""
        processed, treedef = _process_pytree(tuple, inputs)
        self.assertEqual(tree_util.build_tree(treedef, processed), inputs)

    def testChildren(self):
        """children() returns the treedef of each top-level subtree."""
        _, parent = tree_util.tree_flatten(((1, 2, 3), (4, )))
        expected = [
            tree_util.tree_flatten(sub)[1] for sub in ((0, 0, 0), (7, ))
        ]
        self.assertEqual(expected, parent.children())

    def testFlattenUpTo(self):
        """flatten_up_to stops at the leaves of the reference structure."""
        _, treedef = tree_util.tree_flatten(
            [(1, 2), None, ATuple(foo=3, bar=7)])
        target = [({"foo": 7}, (3, 4)), None, ATuple(foo=(11, 9), bar=None)]
        self.assertEqual(treedef.flatten_up_to(target),
                         [{"foo": 7}, (3, 4), (11, 9), None])

    def testTreeMultimap(self):
        """tree_multimap pairs each leaf of lhs with the matching subtree."""
        lhs = ((1, 2), (3, 4, 5))
        rhs = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
        zipped = tree_util.tree_multimap(lambda *vals: tuple(vals), lhs, rhs)
        expected = (((1, [3]), (2, None)),
                    ((3, {"foo": "bar"}), (4, 7), (5, [5, 6])))
        self.assertEqual(zipped, expected)

    def testTreeMultimapWithIsLeafArgument(self):
        """With is_leaf, the list in lhs is a leaf matched to rhs's subtree."""
        lhs = ((1, 2), [3, 4, 5])
        rhs = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
        zipped = tree_util.tree_multimap(
            lambda *vals: tuple(vals),
            lhs,
            rhs,
            is_leaf=lambda n: isinstance(n, list))
        expected = (((1, [3]), (2, None)),
                    ([3, 4, 5], ({"foo": "bar"}, 7, [5, 6])))
        self.assertEqual(zipped, expected)

    def testFlattenIsLeaf(self):
        """is_leaf decides where tree_flatten stops descending."""
        x = [(1, 2), (3, 4), (5, 6)]
        cases = [
            (lambda t: False, [1, 2, 3, 4, 5, 6]),
            (lambda t: isinstance(t, tuple), x),
            (lambda t: isinstance(t, list), [x]),
            (lambda t: True, [x]),
        ]
        for is_leaf, expected in cases:
            leaves, _ = tree_util.tree_flatten(x, is_leaf=is_leaf)
            self.assertEqual(leaves, expected)

        y = [[[(1, )], [[(2, )], {"a": (3, )}]]]
        leaves, _ = tree_util.tree_flatten(
            y, is_leaf=lambda t: isinstance(t, tuple))
        self.assertEqual(leaves, [(1, ), (2, ), (3, )])

    @parameterized.parameters(*TREES)
    def testRoundtripIsLeaf(self, tree):
        """Round trip still holds when tuples are treated as leaves."""
        leaves, treedef = tree_util.tree_flatten(
            tree, is_leaf=lambda t: isinstance(t, tuple))
        self.assertEqual(tree_util.tree_unflatten(treedef, leaves), tree)

    @parameterized.parameters(*TREES)
    def testAllLeavesWithTrees(self, tree):
        """A flattened leaf list passes all_leaves; [tree] should not."""
        self.assertTrue(tree_util.all_leaves(tree_util.tree_leaves(tree)))
        self.assertFalse(tree_util.all_leaves([tree]))

    @parameterized.parameters(*LEAVES)
    def testAllLeavesWithLeaves(self, leaf):
        """A list holding a single leaf passes all_leaves."""
        self.assertTrue(tree_util.all_leaves([leaf]))

    @parameterized.parameters(*TREES)
    def testCompose(self, tree):
        """Composing with a 3-leaf list gives the expected leaf/node counts."""
        outer = tree_util.tree_structure(tree)
        inner = tree_util.tree_structure(["*", "*", "*"])
        composed = outer.compose(inner)
        n_leaves = outer.num_leaves * inner.num_leaves
        self.assertEqual(composed.num_leaves, n_leaves)
        # Interior nodes of the outer tree are kept; each outer leaf is
        # replaced by a full copy of the inner tree.
        n_nodes = ((outer.num_nodes - outer.num_leaves) +
                   (inner.num_nodes * outer.num_leaves))
        self.assertEqual(composed.num_nodes, n_nodes)
        ones = [1] * n_leaves
        rebuilt = tree_util.tree_unflatten(composed, ones)
        self.assertEqual(ones, tree_util.tree_leaves(rebuilt))

    @parameterized.parameters(*TREES)
    def testTranspose(self, tree):
        """Transposing a tree of 3-lists yields a 3-list of trees."""
        outer = tree_util.tree_structure(tree)
        if not outer.num_leaves:
            self.skipTest("Skipping empty tree")
        inner = tree_util.tree_structure([1, 1, 1])
        tripled = tree_util.tree_map(lambda leaf: [leaf, leaf, leaf], tree)
        self.assertEqual(tree_util.tree_transpose(outer, inner, tripled),
                         [tree, tree, tree])

    def testTransposeMismatchOuter(self):
        """An outer treedef that does not match the value raises TypeError."""
        value = {"a": [1, 2], "b": [3, 4]}
        outer = tree_util.tree_structure({"a": 1, "b": 2, "c": 3})
        inner = tree_util.tree_structure([1, 2])
        with self.assertRaisesRegex(TypeError, "Mismatch"):
            tree_util.tree_transpose(outer, inner, value)

    def testTransposeMismatchInner(self):
        """An inner treedef that does not match the value raises TypeError."""
        value = {"a": [1, 2], "b": [3, 4]}
        outer = tree_util.tree_structure({"a": 1, "b": 2})
        inner = tree_util.tree_structure([1, 2, 3])
        with self.assertRaisesRegex(TypeError, "Mismatch"):
            tree_util.tree_transpose(outer, inner, value)

    def testTransposeWithCustomObject(self):
        """tree_transpose works through the custom FlatCache container."""
        outer = tree_util.tree_structure(FlatCache({"a": 1, "b": 2}))
        inner = tree_util.tree_structure([1, 2])
        expected = [FlatCache({"a": 3, "b": 5}), FlatCache({"a": 4, "b": 6})]
        nested = FlatCache({"a": [3, 4], "b": [5, 6]})
        actual = tree_util.tree_transpose(outer, inner, nested)
        self.assertEqual(expected, actual)

    @unittest.skipIf(lib._xla_extension_version < 17,
                     "Test requires jaxlib 0.1.66.")
    @parameterized.parameters([(*t, s) for t, s in zip(TREES, TREE_STRINGS)])
    def testStringRepresentation(self, tree, correct_string):
        """str(treedef) matches the expected regex for each tree."""
        self.assertRegex(str(tree_util.tree_structure(tree)), correct_string)