def test_jit_pytree_return(self):
  """Check that an IREE-jitted function can return a nested pytree.

  The jitted function applies ``jnp.sqrt`` to every leaf of a nested structure
  of dicts, lists and tuples. The result must have the same tree structure as
  the plain-JAX reference, and each leaf must match the reference within
  ``TOLERANCE``.
  """

  @iree.jax.jit
  def apply_sqrt(pytree):
    # Use the stable `jax.tree_util` spelling; the top-level `jax.tree_map`
    # alias is deprecated and absent from newer JAX releases.
    return jax.tree_util.tree_map(jnp.sqrt, pytree)

  np.random.seed(0)
  input_tree = {
      "a": [
          normal((2, 3)),
          {
              "b": normal(3)
          },
      ],
      "c": (
          {
              "d": [normal(2), normal(3)]
          },
          (normal(1), normal(4)),
      )
  }
  # NOTE(review): `normal` presumably draws from a standard normal, so about
  # half the entries would be negative and sqrt would produce NaN. Because
  # `assert_allclose` treats matching NaNs as equal, those entries would go
  # unchecked. Take absolute values so every element is compared numerically.
  # The random draws (and the tree structure) are unchanged.
  input_tree = jax.tree_util.tree_map(np.abs, input_tree)

  expected = jax.tree_util.tree_map(jnp.sqrt, input_tree)
  expected_arrays, expected_tree = jax.tree_util.tree_flatten(expected)
  result = apply_sqrt(input_tree)
  result_arrays, result_tree = jax.tree_util.tree_flatten(result)

  # Matching treedefs also guarantee equal leaf counts, so zip below pairs
  # every leaf without truncation.
  self.assertEqual(expected_tree, result_tree)
  for expected_array, result_array in zip(expected_arrays, result_arrays):
    np.testing.assert_allclose(expected_array, result_array, **TOLERANCE)
def apply_sqrt(pytree):
  """Apply ``jnp.sqrt`` to every leaf of ``pytree``.

  Args:
    pytree: A JAX pytree, e.g. nested dicts, lists or tuples whose leaves are
      arrays or scalars.

  Returns:
    A pytree with the same structure as ``pytree``, whose leaves are the
    element-wise square roots of the input leaves.
  """
  # Use the stable `jax.tree_util` spelling; the top-level `jax.tree_map`
  # alias is deprecated and absent from newer JAX releases.
  return jax.tree_util.tree_map(jnp.sqrt, pytree)