Example #1
0
    def test_jit_pytree_return(self):
        """Check that an ``iree.jax.jit`` function can return a nested pytree.

        ``apply_sqrt`` maps ``jnp.sqrt`` over every leaf of a nest of dicts,
        lists and tuples. The jitted result must have the same tree structure
        as the eager JAX reference, and its leaves must match numerically
        within ``TOLERANCE``.
        """

        @iree.jax.jit
        def apply_sqrt(pytree):
            # Use ``jax.tree_util.tree_map`` rather than the ``jax.tree_map``
            # alias, which is deprecated and removed in recent JAX releases.
            return jax.tree_util.tree_map(jnp.sqrt, pytree)

        # Seed the global NumPy RNG so the inputs are reproducible.
        # NOTE(review): assumes ``normal`` draws from ``np.random``; confirm.
        np.random.seed(0)
        input_tree = {
            "a": [
                normal((2, 3)),
                {
                    "b": normal(3)
                },
            ],
            "c": (
                {
                    "d": [normal(2), normal(3)]
                },
                (normal(1), normal(4)),
            )
        }

        expected = jax.tree_util.tree_map(jnp.sqrt, input_tree)
        expected_arrays, expected_tree = jax.tree_util.tree_flatten(expected)
        result = apply_sqrt(input_tree)
        result_arrays, result_tree = jax.tree_util.tree_flatten(result)

        # Identical tree definitions also imply equal leaf counts, so the zip
        # below cannot silently drop leaves.
        self.assertEqual(expected_tree, result_tree)
        for expected_array, result_array in zip(expected_arrays,
                                                result_arrays):
            # assert_allclose takes (actual, desired). With this order, rtol
            # is scaled by the reference value and failure messages label the
            # arrays correctly. TOLERANCE presumably holds rtol/atol kwargs.
            np.testing.assert_allclose(result_array, expected_array,
                                       **TOLERANCE)
Example #2
0
 def apply_sqrt(pytree):
     """Return a pytree of the same structure with ``jnp.sqrt`` applied to every leaf."""
     # ``jax.tree_util.tree_map`` replaces the deprecated ``jax.tree_map``
     # alias, which has been removed in recent JAX releases.
     return jax.tree_util.tree_map(jnp.sqrt, pytree)