# Example #1
# 0
import numpy as onp
from absl.testing import absltest
from absl.testing import parameterized

from jax import api
from jax import core
from jax import numpy as np
from jax import test_util as jtu
from jax.api import jvp, linearize, vjp, jit
from jax.lax import UnshapedArray, ShapedArray, ConcreteArray
from jax.tree_util import tree_flatten, tree_unflatten, tree_multimap, tree_reduce
from jax.util import partial
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla

# Module-level abstract PartialVal fixtures built with the tuple-based
# PartialVal constructor: each wraps a pair whose first element is an
# abstract float32 value (numpy float32 via `onp`) and whose second element
# is core.unit (presumably the "no known constant" marker — confirm against
# this JAX version's partial_eval).
#   _  : only a dtype is known (UnshapedArray, no shape).
#   __ : a float32 scalar, shape ().
# NOTE(review): presumably used as shorthand inputs by tests later in the
# file — confirm.
_ = pe.PartialVal((UnshapedArray(onp.float32), core.unit))
__ = pe.PartialVal((ShapedArray((), onp.float32), core.unit))


def call(f, *args):
    """Invoke ``f`` on ``args`` through a ``jit``-compiled wrapper.

    Returns whatever the compiled function returns.
    """
    compiled = jit(f)
    return compiled(*args)


def simple_fun(x, y):
    """Return ``sin(x * y)`` computed with ``jax.numpy``."""
    product = x * y
    return np.sin(product)


def simple_fun_fanout(x, y):
    """Return ``sin(x * y) * x``.

    ``x`` is consumed twice (inside the sine and as the outer factor), so its
    value fans out to two uses.
    """
    wave = np.sin(x * y)
    return wave * x

# Example #2
# 0
from absl.testing import parameterized

from jax import core
from jax import numpy as jnp
from jax import test_util as jtu
from jax.api import jvp, linearize, vjp, jit
from jax.lax import UnshapedArray, ShapedArray, ConcreteArray
from jax.tree_util import tree_flatten, tree_unflatten, tree_multimap, tree_reduce, tree_leaves
from jax.util import partial
from jax.interpreters import partial_eval as pe

from jax.config import config

# Have jax.config parse its flags through absl; this runs once at import time
# as a module-level side effect.
config.parse_flags_with_absl()

# Module-level PartialVal fixtures, both marked unknown via
# PartialVal.unknown:
#   _  : float32 with only a dtype known (UnshapedArray, no shape).
#   __ : float32 scalar, shape ().
# Use `jnp.float32` rather than bare `np.float32`: this example imports
# jax.numpy only as `jnp`, and `np` is not imported here, so relying on it
# would only work by accident of an outer `np` binding.
_ = pe.PartialVal.unknown(UnshapedArray(jnp.float32))
__ = pe.PartialVal.unknown(ShapedArray((), jnp.float32))


def call(f, *args):
    """Run ``f(*args)`` under ``jit`` and return the result."""
    jitted_f = jit(f)
    result = jitted_f(*args)
    return result


def simple_fun(x, y):
    """Elementwise sine of the product ``x * y`` (via ``jax.numpy``)."""
    scaled = x * y
    result = jnp.sin(scaled)
    return result


def simple_fun_fanout(x, y):
    """Compute ``sin(x * y) * x``.

    ``x`` is used twice, so its value fans out to two consumers.
    """
    sine_of_product = jnp.sin(x * y)
    fanned = sine_of_product * x
    return fanned