def test_simple_trace(self):
    """Run ``check_trace_eval`` on ``sin(x) + cos(x)`` for a (3, 2) float32 input.

    The partial value pairs a ``ShapedArray`` of shape (3, 2) and dtype
    float32 with ``core.unit``. The same partial value is passed both as the
    single input description and as the final argument of
    ``check_trace_eval``; the concrete argument is a random (3, 2) array.
    """

    def foo(x):
        return np.sin(x) + np.cos(x)

    aval = ShapedArray((3, 2), onp.float32)
    pval = pe.PartialVal((aval, core.unit))
    concrete_args = (onp.random.randn(3, 2),)
    check_trace_eval(foo, (pval,), concrete_args, pval)
from absl.testing import absltest
from absl.testing import parameterized
from jax import api
from jax import core
from jax import numpy as np
from jax import test_util as jtu
from jax.api import jvp, linearize, vjp, jit
from jax.lax import UnshapedArray, ShapedArray, ConcreteArray
from jax.tree_util import tree_flatten, tree_unflatten, tree_multimap, tree_reduce
from jax.util import partial
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla

# NOTE(review): `onp` is used below but is not imported in the visible lines;
# presumably `import numpy as onp` precedes this chunk -- confirm.

# Module-level partial value: a float32 UnshapedArray paired with core.unit.
_ = pe.PartialVal((UnshapedArray(onp.float32), core.unit))
# Module-level partial value: a float32 ShapedArray with shape () paired with
# core.unit.
__ = pe.PartialVal((ShapedArray((), onp.float32), core.unit))

# Wrap `f` with jit and immediately apply it to `args`.
def call(f, *args): return jit(f)(*args)

# sin of the product x * y (np here is jax.numpy, per the imports above).
def simple_fun(x, y): return np.sin(x * y)

# sin(x * y) multiplied by x, so x feeds two operations.
def simple_fun_fanout(x, y): return np.sin(x * y) * x

# The body of fun_with_call lies outside the visible chunk; header kept as-is.
def fun_with_call(x):
def DIABLED_test_print_jaxpr_compound(self):
    """Print the first result of tracing ``fun_with_call_closure``.

    The input is described by a partial value pairing a (2, 3) float32
    ``ShapedArray`` with ``core.unit``. Only element 0 of what
    ``pe.trace_to_jaxpr`` returns is printed; nothing is asserted.

    NOTE(review): the ``DIABLED_`` prefix (sic) presumably keeps the test
    runner from collecting this method -- confirm before renaming.
    """
    # TODO(dougalm): figure out what jaxpr-tracing api to expose and re-enable
    aval = ShapedArray((2, 3), onp.float32)
    pv = pe.PartialVal((aval, core.unit))
    traced = pe.trace_to_jaxpr(fun_with_call_closure, (pv,))
    print(traced[0])
from jax import core
from jax import numpy as jnp
from jax import test_util as jtu
from jax.api import jvp, linearize, vjp, jit
from jax.lax import UnshapedArray, ShapedArray, ConcreteArray
from jax.tree_util import tree_flatten, tree_unflatten, tree_multimap, tree_reduce, tree_leaves
from jax.util import partial
from jax.interpreters import partial_eval as pe

from jax.config import config
# Runs at import time, before the module-level values below are built.
config.parse_flags_with_absl()

# NOTE(review): `np` is used below but is not imported in the visible lines;
# presumably `import numpy as np` precedes this chunk -- confirm.

# Module-level partial value built by PartialVal.unknown from a float32
# UnshapedArray.
_ = pe.PartialVal.unknown(UnshapedArray(np.float32))
# Module-level partial value built by PartialVal.unknown from a float32
# ShapedArray with shape ().
__ = pe.PartialVal.unknown(ShapedArray((), np.float32))

# Wrap `f` with jit and immediately apply it to `args`.
def call(f, *args): return jit(f)(*args)

# sin of the product x * y, computed with jax.numpy (jnp).
def simple_fun(x, y): return jnp.sin(x * y)

# sin(x * y) multiplied by x, so x feeds two operations.
def simple_fun_fanout(x, y): return jnp.sin(x * y) * x

# The body of fun_with_call lies outside the visible chunk; header kept as-is.
def fun_with_call(x):