Example #1
0
 def test_select_false(self):
     """`tree_util.tree_select(False, on_true, on_false)` returns `on_false`.

     The check covers both the tree structure and every leaf value, so a
     result with missing, extra or re-nested leaves cannot pass unnoticed.
     """
     on_true = ((jnp.zeros(3), ), jnp.zeros(4))
     on_false = ((jnp.ones(3), ), jnp.ones(4))
     output = tree_util.tree_select(False, on_true, on_false)
     # zip() below stops at the shorter sequence, so compare structures first;
     # otherwise an output with fewer leaves than `on_false` would pass.
     expected_structure = jax.tree_util.tree_structure(on_false)
     actual_structure = jax.tree_util.tree_structure(output)
     assert expected_structure == actual_structure, (
         f"tree structure mismatch: {expected_structure} != {actual_structure}")
     for x, y in zip(jax.tree_util.tree_leaves(on_false),
                     jax.tree_util.tree_leaves(output)):
         np.testing.assert_array_equal(x, y)
Example #2
0
 def test_select_false(self):
     """A False predicate makes `tree_select` yield the second tree."""
     zeros_tree = ((jnp.zeros(3), ), jnp.zeros(4))
     ones_tree = ((jnp.ones(3), ), jnp.ones(4))
     # The selected tree must match `ones_tree` in both structure and values.
     chex.assert_tree_all_close(
         tree_util.tree_select(False, zeros_tree, ones_tree), ones_tree)