コード例 #1
0
    def test_simple_trace(self):
        # Function under test: elementwise sin(x) + cos(x), built with np
        # (jax.numpy in this file).
        def foo(x):
            sine = np.sin(x)
            cosine = np.cos(x)
            return sine + cosine

        # Abstract (3, 2) float32 input paired with `core.unit`; presumably
        # the "unknown value" form of PartialVal in this JAX version.
        abstract_input = ShapedArray((3, 2), onp.float32)
        pval = pe.PartialVal((abstract_input, core.unit))
        # Random concrete input of the same shape to evaluate against.
        concrete_input = onp.random.randn(3, 2)
        check_trace_eval(foo, (pval,), (concrete_input,), pval)
コード例 #2
0
from absl.testing import absltest
from absl.testing import parameterized

from jax import api
from jax import core
from jax import numpy as np
from jax import test_util as jtu
from jax.api import jvp, linearize, vjp, jit
from jax.lax import UnshapedArray, ShapedArray, ConcreteArray
from jax.tree_util import tree_flatten, tree_unflatten, tree_multimap, tree_reduce
from jax.util import partial
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla

# Module-level partial values that wrap only an abstract value, paired with
# `core.unit` (presumably the "unknown" form of PartialVal here):
#   _  -- float32 with no shape information (UnshapedArray)
#   __ -- float32 scalar (ShapedArray with shape ())
_, __ = (
    pe.PartialVal((aval, core.unit))
    for aval in (UnshapedArray(onp.float32), ShapedArray((), onp.float32))
)


def call(f, *args):
    """Run ``f`` on ``args`` through a ``jit``-wrapped version of ``f``."""
    jitted_f = jit(f)
    return jitted_f(*args)


def simple_fun(x, y):
    """Return ``sin(x * y)`` computed with ``np`` (jax.numpy in this file)."""
    product = x * y
    return np.sin(product)


def simple_fun_fanout(x, y):
    """Return ``sin(x * y) * x``; ``x`` is consumed twice (fan-out)."""
    wave = np.sin(x * y)
    return wave * x


def fun_with_call(x):
コード例 #3
0
 def DIABLED_test_print_jaxpr_compound(self):
     # Intentionally disabled: the name does not start with "test", so the
     # default unittest/absltest loader does not collect it. ("DIABLED" is a
     # typo for "DISABLED"; the name is kept so nothing referring to it breaks.)
     # TODO(dougalm): figure out what jaxpr-tracing api to expose and re-enable
     abstract_input = ShapedArray((2, 3), onp.float32)
     pv = pe.PartialVal((abstract_input, core.unit))
     traced = pe.trace_to_jaxpr(fun_with_call_closure, (pv,))
     # Only the first element of the trace result is printed.
     print(traced[0])
コード例 #4
0
ファイル: core_test.py プロジェクト: odashi/jax
from jax import core
from jax import numpy as jnp
from jax import test_util as jtu
from jax.api import jvp, linearize, vjp, jit
from jax.lax import UnshapedArray, ShapedArray, ConcreteArray
from jax.tree_util import tree_flatten, tree_unflatten, tree_multimap, tree_reduce, tree_leaves
from jax.util import partial
from jax.interpreters import partial_eval as pe

from jax.config import config

# Have JAX's config parse its flags through absl.
config.parse_flags_with_absl()

# Module-level partial values with no known constant part, only an abstract
# value:
#   _  -- float32 with no shape information (UnshapedArray)
#   __ -- float32 scalar (ShapedArray with shape ())
_, __ = (
    pe.PartialVal.unknown(aval)
    for aval in (UnshapedArray(np.float32), ShapedArray((), np.float32))
)


def call(f, *args):
    """Apply ``jit(f)`` to ``args`` and return whatever it produces."""
    compiled = jit(f)
    result = compiled(*args)
    return result


def simple_fun(x, y):
    """Return ``jnp.sin`` of the product ``x * y``."""
    xy = x * y
    return jnp.sin(xy)


def simple_fun_fanout(x, y):
    """Return ``sin(x * y) * x``, reusing ``x`` in two places (fan-out)."""
    sine_of_product = jnp.sin(x * y)
    return sine_of_product * x


def fun_with_call(x):