import numpy as onp
from absl.testing import absltest
from absl.testing import parameterized

from jax import api
from jax import core
from jax import numpy as np
from jax import test_util as jtu
from jax.api import jvp, linearize, vjp, jit
from jax.lax import UnshapedArray, ShapedArray, ConcreteArray
from jax.tree_util import tree_flatten, tree_unflatten, tree_multimap, tree_reduce
from jax.util import partial
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla

# Module-level partial-value fixtures with an abstract float32 type paired
# with ``core.unit`` (presumably the "no known constant" marker in this
# version of ``partial_eval`` -- confirm). ``_`` carries no shape information
# (UnshapedArray); ``__`` is a rank-0 ShapedArray.
_ = pe.PartialVal((UnshapedArray(onp.float32), core.unit))
__ = pe.PartialVal((ShapedArray((), onp.float32), core.unit))


def call(f, *args):
  """Apply ``f`` to ``args`` through a ``jit``-wrapped version of ``f``."""
  jitted_f = jit(f)
  return jitted_f(*args)


def simple_fun(x, y):
  """Return ``sin(x * y)`` (``np`` here is ``jax.numpy``)."""
  product = x * y
  return np.sin(product)


def simple_fun_fanout(x, y):
  """Return ``sin(x * y) * x``; ``x`` is consumed twice (fan-out)."""
  sine = np.sin(x * y)
  return sine * x
# This header uses ``np.float32`` and reserves ``jnp`` for ``jax.numpy``, so
# ``np`` must be NumPy; bind it explicitly instead of relying on whatever an
# earlier line happened to bind to that name.
import numpy as np
from absl.testing import parameterized

from jax import core
from jax import numpy as jnp
from jax import test_util as jtu
from jax.api import jvp, linearize, vjp, jit
from jax.lax import UnshapedArray, ShapedArray, ConcreteArray
from jax.tree_util import tree_flatten, tree_unflatten, tree_multimap, tree_reduce, tree_leaves
from jax.util import partial
from jax.interpreters import partial_eval as pe

from jax.config import config

# Let absl command-line flags configure JAX for this test module.
config.parse_flags_with_absl()

# Module-level partial values whose value is unknown and which carry only an
# abstract float32 type: ``_`` has no shape information (UnshapedArray), and
# ``__`` is a rank-0 ShapedArray.
_ = pe.PartialVal.unknown(UnshapedArray(np.float32))
__ = pe.PartialVal.unknown(ShapedArray((), np.float32))


def call(f, *args):
  """Apply ``f`` to ``args`` through a ``jit``-wrapped version of ``f``."""
  return jit(f)(*args)


def simple_fun(x, y):
  """Return ``jnp.sin(x * y)``."""
  return jnp.sin(x * y)


def simple_fun_fanout(x, y):
  """Return ``jnp.sin(x * y) * x``; ``x`` is consumed twice (fan-out)."""
  return jnp.sin(x * y) * x