Example No. 1
    def test_unimplemented_interpreter_rules(self):
        foo_p = Primitive('foo')

        def foo(x):
            return foo_p.bind(x)

        jtu.check_raises(lambda: foo(1.0), NotImplementedError,
                         "Evaluation rule for 'foo' not implemented")

        jtu.check_raises(lambda: jit(foo)(1.0), NotImplementedError,
                         "Abstract evaluation for 'foo' not implemented")

        jtu.check_raises(
            lambda: grad(foo)(1.0), NotImplementedError,
            "Forward-mode differentiation rule for 'foo' not implemented")

        foo_p.def_abstract_eval(lambda x: x)

        jtu.check_raises(lambda: jit(foo)(1.0), NotImplementedError,
                         "XLA translation rule for 'foo' not implemented")

        foo_p.def_impl(lambda x: x)
        defjvp(foo_p, lambda g, x: foo(g))

        jtu.check_raises(
            lambda: grad(foo)(1.0), NotImplementedError,
            "Reverse-mode differentiation rule for 'foo' not implemented")
Example No. 2
  def test_defvjp_all_higher_order_revmode(self):
    foo_p = Primitive('foo')
    def foo(x): return 2. * foo_p.bind(x)

    ad.defvjp_all(foo_p, lambda x: (x**2, lambda g: (g * x ** 2,)))
    ans = api.grad(api.grad(foo))(3.)
    self.assertAllClose(ans, 2 * 2 * 3., check_dtypes=False)
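ad.defvjp_all, used here and in several of the tests below, installs a single rule that returns both the primal output and a VJP closure. A minimal sketch (hypothetical primitive, assuming the same legacy jax.interpreters.ad API as these tests):

from jax import core, grad
from jax.interpreters import ad

square_p = core.Primitive('square')        # hypothetical primitive

def square(x):
    return square_p.bind(x)

# The rule maps the primal inputs to (primal_out, vjp), where vjp takes the
# output cotangent g and returns a tuple of input cotangents.
ad.defvjp_all(square_p, lambda x: (x ** 2, lambda g: (2. * g * x,)))

print(grad(square)(3.0))                   # 6.0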
Example No. 3
 def find_callback(
         prim: core.Primitive, vals: Sequence[core.Tracer],
         params: Dict[str,
                      Any]) -> Union[core.Tracer, Sequence[core.Tracer]]:
     vals = prim.bind(*vals, **params)
     _contains_query(vals, queries)
     return vals
Example No. 4
 def rewrite_callback(
         prim: core.Primitive, vals: Sequence[core.Tracer],
         params: Dict[str,
                      Any]) -> Union[core.Tracer, Sequence[core.Tracer]]:
     if prim in rules:
         return rules[prim](*vals, **params)
     return prim.bind(*vals, **params)
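For illustration only (not from the original source), a rules table consumed by rewrite_callback could map a JAX primitive to a replacement implementation that is called with the same operands and parameters instead of prim.bind:

import jax.numpy as jnp
from jax import lax

rules = {
    # rewrite every sin into cos; invoked as rules[lax.sin_p](x) by rewrite_callback
    lax.sin_p: lambda x: jnp.cos(x),
}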
Example No. 5
    def process_primitive(self, primitive: core.Primitive,
                          tracers: Sequence[TensorFlowTracer],
                          params) -> TensorFlowTracer:
        impl = self.get_primitive_impl(primitive)
        args_tf: Sequence[TfValOrUnit] = [t.val for t in tracers]
        # impl takes core.unit and returns core.unit when needed.
        val_out: TfValOrUnit = impl(*args_tf, **params)
        if primitive.multiple_results:
            out = util.safe_map(functools.partial(TensorFlowTracer, self),
                                val_out)  # type: ignore
        else:
            out = TensorFlowTracer(self, val_out)

        # Check that the impl rule returned a value of expected shape and dtype
        if not core.skip_checks:
            expected_out_aval: core.AbstractValue = primitive.abstract_eval(
                *[t.aval for t in tracers], **params)
            if primitive.multiple_results:
                for o, expected_aval in zip(out,
                                            expected_out_aval):  # type: ignore
                    assert o.aval == expected_aval, (
                        f"{primitive}: out.aval = {o.aval}; expected {expected_aval}"
                    )
            else:
                assert out.aval == expected_out_aval, (  # type: ignore
                    f"{primitive}: out.aval = {out.aval}; expected {expected_out_aval}"
                )  # type: ignore
        return out  # type: ignore
Example No. 6
  def test_defvjp_all_const(self):
    foo_p = Primitive('foo')
    def foo(x): return foo_p.bind(x)

    ad.defvjp_all(foo_p, lambda x: (x**2, lambda g: (12.,)))
    val_ans, grad_ans = api.value_and_grad(foo)(3.)
    self.assertAllClose(val_ans, 9., check_dtypes=False)
    self.assertAllClose(grad_ans, 12., check_dtypes=True)
Example No. 7
  def test_defvjp_all(self):
    foo_p = Primitive('foo')
    def foo(x): return 2. * foo_p.bind(x)

    ad.defvjp_all(foo_p, lambda x: (x**2, lambda g: (4 * g * np.sin(x),)))
    val_ans, grad_ans = api.value_and_grad(foo)(3.)
    self.assertAllClose(val_ans, 2 * 3.**2, check_dtypes=False)
    self.assertAllClose(grad_ans, 4 * 2 * onp.sin(3.), check_dtypes=False)
Example No. 8
    def _process_primitive(self, primitive: Primitive, flat_inputs, kwargs):
        if primitive in InitTrace._rules:
            return InitTrace._rules[primitive](self)(flat_inputs, kwargs)

        if isinstance(primitive, parametrized):
            return self.process_parametrized(primitive, *flat_inputs, **kwargs)

        return primitive.bind(*flat_inputs, **kwargs)
Example No. 9
    def decorator(fenics_function: Callable) -> Callable:
        @functools.wraps(fenics_function)
        def jax_solve_eval(*args):
            return jax_solve_eval_p.bind(*args)

        jax_solve_eval_p = Primitive("jax_solve_eval")
        jax_solve_eval_p.def_impl(lambda *args: solve_eval(
            fenics_function, fenics_templates, *args)[0])

        jax_solve_eval_p.def_abstract_eval(
            lambda *args: jax.abstract_arrays.make_shaped_array(
                solve_eval(fenics_function, fenics_templates, *args)[0]))

        def jax_solve_eval_batch(vector_arg_values, batch_axes):
            assert len(
                set(batch_axes)) == 1  # assert that all batch axes are same
            assert (
                batch_axes[0] == 0
            )  # assert that batch axis is zero, need to rewrite for a general case?
            # compute function row-by-row
            res = np.asarray([
                jax_solve_eval(*(vector_arg_values[j][i]
                                 for j in range(len(batch_axes))))
                for i in range(vector_arg_values[0].shape[0])
            ])
            return res, batch_axes[0]

        jax.interpreters.batching.primitive_batchers[
            jax_solve_eval_p] = jax_solve_eval_batch

        # @trace("djax_solve_eval")
        def djax_solve_eval(*args):
            return djax_solve_eval_p.bind(*args)

        djax_solve_eval_p = Primitive("djax_solve_eval")
        # djax_solve_eval_p.multiple_results = True
        djax_solve_eval_p.def_impl(lambda *args: vjp_solve_eval(
            fenics_function, fenics_templates, *args))

        defvjp_all(jax_solve_eval_p, djax_solve_eval)
        return jax_solve_eval
Example No. 10
    def decorator(fenics_function: Callable) -> Callable:
        def jax_assemble_eval(*args):
            return jax_assemble_eval_p.bind(*args)

        jax_assemble_eval_p = Primitive("jax_assemble_eval")
        jax_assemble_eval_p.def_impl(
            lambda *args: assemble_eval(fenics_function, fenics_templates, *args)[0]
        )

        jax_assemble_eval_p.def_abstract_eval(
            lambda *args: jax.abstract_arrays.make_shaped_array(
                assemble_eval(fenics_function, fenics_templates, *args)[0]
            )
        )

        def jax_assemble_eval_batch(vector_arg_values, batch_axes):
            assert len(set(batch_axes)) == 1  # assert that all batch axes are same
            assert (
                batch_axes[0] == 0
            )  # assert that batch axis is zero, need to rewrite for a general case?
            res = list(map(jax_assemble_eval, *vector_arg_values))
            res = np.asarray(res)
            return res, batch_axes[0]

        batching.primitive_batchers[jax_assemble_eval_p] = jax_assemble_eval_batch

        # @trace("djax_assemble_eval")
        def djax_assemble_eval(*args):
            return djax_assemble_eval_p.bind(*args)

        djax_assemble_eval_p = Primitive("djax_assemble_eval")
        # djax_assemble_eval_p.multiple_results = True
        djax_assemble_eval_p.def_impl(
            lambda *args: vjp_assemble_eval(fenics_function, fenics_templates, *args)
        )

        defvjp_all(jax_assemble_eval_p, djax_assemble_eval)

        return jax_assemble_eval
Example No. 11
 def default_process_primitive(
     self, primitive: jax_core.Primitive, tracers: List['HarvestTracer'],
     params: Dict[str, Any]) -> Union['HarvestTracer', List['HarvestTracer']]:
   context = trace_util.get_dynamic_context(self)
   vals = [t.val for t in tracers]
   if primitive is sow_p:
     outvals = context.process_sow(*vals, **params)
     return jax_util.safe_map(self.pure, outvals)
   outvals = primitive.bind(*vals, **params)
   if not primitive.multiple_results:
     outvals = [outvals]
   out_tracers = jax_util.safe_map(self.pure, outvals)
   if primitive.multiple_results:
     return out_tracers
   return out_tracers[0]
Example No. 12
  def test_defvjp_all_multiple_arguments(self):
    # also tests passing in symbolic zero tangents b/c we differentiate wrt only
    # the first argument in one case

    foo_p = Primitive('foo')
    def foo(x, y): return foo_p.bind(x, y)

    def vjpfun(x, y):
      out = x**2 + y**3
      vjp = lambda g: (g + x + y, g * x * 9.)
      return out, vjp

    ad.defvjp_all(foo_p, vjpfun)
    val_ans, grad_ans = api.value_and_grad(foo)(3., 4.)
    self.assertAllClose(val_ans, 3.**2 + 4.**3, check_dtypes=False)
    self.assertAllClose(grad_ans, 1. + 3. + 4., check_dtypes=False)

    ans = api.grad(foo, (0, 1))(3., 4.)
    self.assertAllClose(ans, (1. + 3. + 4., 1. * 3. * 9.), check_dtypes=False)
Example No. 13
    def decorator(fenics_function: Callable) -> Callable:
        @functools.wraps(fenics_function)
        @custom_vjp
        def jax_fem_eval(*args):
            return jax_fem_eval_p.bind(*args)

        jax_fem_eval_p = Primitive("jax_fem_eval")
        jax_fem_eval_p.def_impl(
            lambda *args: evaluate_primal(fenics_function, fenics_templates, *args)[0]
        )

        jax_fem_eval_p.def_abstract_eval(
            lambda *args: jax.abstract_arrays.make_shaped_array(
                evaluate_primal(fenics_function, fenics_templates, *args)[0]
            )
        )

        def jax_fem_eval_batch(vector_arg_values, batch_axes):
            assert len(set(batch_axes)) == 1  # assert that all batch axes are same
            assert (
                batch_axes[0] == 0
            )  # assert that batch axis is zero, need to rewrite for a general case?
            res = list(map(jax_fem_eval, *vector_arg_values))
            res = np.asarray(res)
            return res, batch_axes[0]

        jax.interpreters.batching.primitive_batchers[
            jax_fem_eval_p
        ] = jax_fem_eval_batch

        def primal(*args):
            numpy_output, fenics_output, fenics_inputs, tape = evaluate_primal(
                fenics_function, fenics_templates, *args
            )
            return (
                numpy_output,
                (PyadjointMetadata(fenics_output, fenics_inputs, tape), args),
            )

        def pullback(aux_args, g):
            pb_fn = get_pullback_function(fenics_function, fenics_templates)
            # get_pullback_function returns a list, but the VJP must return a tuple of cotangents
            return tuple(pb_fn(aux_args, g))

        jax_fem_eval.defvjp(primal, pullback)
        return jax_fem_eval
Example No. 14
        v, w = operand, operand
    return core.AbstractTuple((v, w))


def eigh_cpu_translation_rule(c, operand, lower):
    shape = c.GetShape(operand)
    dtype = shape.element_type().type
    if len(shape.dimensions()) == 2 and dtype in _cpu_lapack_types:
        out = lapack.jax_syevd(c, operand, lower=lower)
        return c.Tuple(c.GetTupleElement(out, 0), c.GetTupleElement(out, 1))
    else:
        raise NotImplementedError(
            "Only unbatched eigendecomposition is implemented on CPU")


eigh_p = Primitive('eigh')
eigh_p.def_impl(eigh_impl)
eigh_p.def_abstract_eval(eigh_abstract_eval)
xla.translations[eigh_p] = eigh_translation_rule
xla.backend_specific_translations['Host'][eigh_p] = eigh_cpu_translation_rule

triangular_solve_dtype_rule = partial(binop_dtype_rule, _input_dtype,
                                      (_float | _complex, _float | _complex),
                                      'triangular_solve')


def triangular_solve_shape_rule(a, b, left_side=False, **unused_kwargs):
    if a.ndim < 2:
        msg = "triangular_solve requires a.ndim to be at least 2, got {}."
        raise TypeError(msg.format(a.ndim))
    if a.shape[-1] != a.shape[-2]:
Example No. 15
        == cosu.dtype == dt_tracer.dtype:
        raise ValueError('all inputs must have the same dtype')

    return superbee_p.bind(var, u_wgrid, v_wgrid, w_wgrid, maskW, dxt, dyt,
                           dzw, cost, cosu, dt_tracer)


def superbee_impl(*args, **kwargs):
    return xla.apply_primitive(superbee_p, *args, **kwargs)


def _superbee_gpu_translation_rule(computation_builder, var, u_wgrid, v_wgrid,
                                   w_wgrid, maskW, dxt, dyt, dzw, cost, cosu,
                                   dt_tracer):
    return _superbee(computation_builder, var, u_wgrid, v_wgrid, w_wgrid,
                     maskW, dxt, dyt, dzw, cost, cosu, dt_tracer)


def superbee_abstract_eval(var, u_wgrid, v_wgrid, w_wgrid, maskW, dxt, dyt,
                           dzw, cost, cosu, dt_tracer):

    aarr = abstract_arrays.ShapedArray(var.shape, var.dtype)
    return (aarr, ) * 3


superbee_p = Primitive('superbee')
superbee_p.multiple_results = True
superbee_p.def_impl(superbee_impl)
superbee_p.def_abstract_eval(superbee_abstract_eval)
xla.backend_specific_translations['gpu'][
    superbee_p] = _superbee_gpu_translation_rule
Example No. 16
from jax._src import traceback_util
traceback_util.register_exclusion(__file__)

Array = Any

map = safe_map

jaxval_adders: Dict[type, Callable] = {}


def add_jaxvals(x, y):
    return add_jaxvals_p.bind(x, y)


add_jaxvals_p: Primitive = Primitive('add_any')
add_any_p = add_jaxvals_p


@add_jaxvals_p.def_impl
def add_impl(xs, ys):
    return jaxval_adders[type(xs)](xs, ys)


@add_jaxvals_p.def_abstract_eval
def add_abstract(xs, ys):
    return lattice_join(xs, ys)


jaxval_zeros_likers: Dict[type, Array] = {}
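A hedged sketch (not in the original module) of how the jaxval_adders dispatch table above is meant to be used: entries map a concrete value type to a binary addition function, and add_impl looks up the adder by the type of its first argument.

# register a (hypothetical) adder for plain Python floats, then dispatch through it
jaxval_adders[float] = lambda x, y: x + y
print(add_impl(1.5, 2.5))   # 4.0, via jaxval_adders[float]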
Example No. 17
      i_old += 1
    elif x.shape[i_old] > 1 and new_sizes[i_new] <= 1:
      i_new += 1
    else:
      assert x.shape[i_old] == new_sizes[i_new]
      i_old += 1
      i_new += 1

  return out_indices,
op2ind[lax.reshape_p] = reshape2ind


# Custom Primitive


tex_var_p = Primitive('tex_var')
tex_var_p.def_impl(lambda x, *depends_on, **kwargs: x)
tex_var_p.def_abstract_eval(lambda x, *depends_on, **kwargs: x)


def tex_var_jvp(primals, tangents, name, is_alias):
  primal_out = tex_var(primals[0], name, is_alias, primals[1:])
  tangent_out = tex_var(tangents[0], 'd' + name, False, primals[1:])
  return primal_out, tangent_out
ad.primitive_jvps[tex_var_p] = tex_var_jvp


def tex_var_transpose(ct, x, *depends_on, **params):
  del x
  ct = tex_var(ct, '\\delta ' + params['name'][1:], False, depends_on)
  return (ct,) + (ad.Zero,) * len(depends_on)
Example No. 18
import numpy as np
from jax.util import safe_map, safe_zip
from jax.core import Primitive
from jax.interpreters import ad, xla, batching, numpy_eval
from jax.lax import dynamic_update_slice_p

map = safe_map
zip = safe_zip

inplace_dynamic_update_slice_p = Primitive('inplace_dynamic_update_slice')
inplace_dynamic_update_slice_p.def_impl(dynamic_update_slice_p.impl)
inplace_dynamic_update_slice_p.def_abstract_eval(dynamic_update_slice_p.abstract_eval)
for rules in [xla.translations, ad.primitive_jvps, ad.primitive_transposes,
              batching.primitive_batchers]:
  rules[inplace_dynamic_update_slice_p] = rules[dynamic_update_slice_p]

def _numpy_inplace_dynamic_update_slice(operand, update, *start_indices):
  slices = tuple(map(slice, start_indices, np.add(start_indices, update.shape)))
  operand[slices] = update
  return operand

numpy_eval.np_impl[inplace_dynamic_update_slice_p] = \
  _numpy_inplace_dynamic_update_slice

def inplace_dynamic_update_slice(operand, update, start_indices):
  return inplace_dynamic_update_slice_p.bind(operand, update, *start_indices)
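A hedged usage sketch (not from the original file), exercising the NumPy implementation registered above directly, outside of any JAX trace:

buf = np.zeros((4, 4))
upd = np.ones((2, 2))
out = _numpy_inplace_dynamic_update_slice(buf, upd, 1, 1)
assert out is buf                   # operand is mutated and returned in place
assert (buf[1:3, 1:3] == 1).all()   # update written starting at indices (1, 1)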
Example No. 19
    def decorator(fenics_function: Callable) -> Callable:
        @functools.wraps(fenics_function)
        def jax_solve_eval(*args):
            return jax_solve_eval_p.bind(*args)

        jax_solve_eval_p = Primitive("jax_solve_eval")
        jax_solve_eval_p.def_impl(lambda *args: solve_eval(
            fenics_function, fenics_templates, *args)[0])

        jax_solve_eval_p.def_abstract_eval(
            lambda *args: jax.abstract_arrays.make_shaped_array(
                solve_eval(fenics_function, fenics_templates, *args)[0]))

        def jax_solve_eval_batch(vector_arg_values, batch_axes):
            assert len(
                set(batch_axes)) == 1  # assert that all batch axes are same
            assert (
                batch_axes[0] == 0
            )  # assert that batch axis is zero, need to rewrite for a general case?
            res = list(map(jax_solve_eval, *vector_arg_values))
            res = np.asarray(res)
            return res, batch_axes[0]

        jax.interpreters.batching.primitive_batchers[
            jax_solve_eval_p] = jax_solve_eval_batch

        # @trace("jvp_jax_solve_eval")
        def jvp_jax_solve_eval(ps, ts):
            return jvp_jax_solve_eval_p.bind(ps, ts)

        jvp_jax_solve_eval_p = Primitive("jvp_jax_solve_eval")
        jvp_jax_solve_eval_p.multiple_results = True
        jvp_jax_solve_eval_p.def_impl(lambda ps, ts: jvp_solve_eval(
            fenics_function, fenics_templates, ps, ts))

        jax.interpreters.ad.primitive_jvps[
            jax_solve_eval_p] = jvp_jax_solve_eval

        # TODO: JAX Tracer goes inside fenics wrappers and zero array is returned
        # because fenics numpy conversion works only for concrete arrays
        vjp_jax_solve_eval_p = Primitive("vjp_jax_solve_eval")
        vjp_jax_solve_eval_p.def_impl(lambda ct, *args: vjp_solve_eval(
            fenics_function, fenics_templates, *args)[1](ct))

        jax.interpreters.ad.primitive_transposes[
            jax_solve_eval_p] = vjp_jax_solve_eval_p

        return jax_solve_eval
Example No. 20
from ..utils import (
    HashableMPIType,
    default_primitive_impl,
    to_dtype_handle,
    to_mpi_handle,
    unpack_hashable,
    wrap_as_hashable,
    xla_constant_intc,
    xla_constant_uintptr,
)
from ..decorators import translation_rule_cpu, translation_rule_gpu
from ..validation import enforce_types
from ..comm import get_default_comm

# The Jax primitive
mpi_gather_p = Primitive("gather_mpi")  # Create the primitive
mpi_gather_impl = default_primitive_impl(mpi_gather_p)


# This function applies the primitive to an AST
@enforce_types(
    root=(_np.integer),
    comm=(type(None), _MPI.Intracomm, HashableMPIType),
    token=(type(None), xla.Token, core.Tracer),
)
def gather(
    x,
    root,
    *,
    comm=None,
    token=None,
Example No. 21
def vjp_solve_eval(fenics_function: Callable, fenics_templates: FenicsVariable,
                   *args: np.array) -> Tuple[np.array, Callable]:
    """Computes the gradients of the output with respect to the input
    Input:
        fenics_function (callable): FEniCS function to be executed during the forward pass
        args (tuple): jax array representation of the input to fenics_function
    Output:
        A pair where the first element is the value of fun applied to the arguments and the second element
        is a Python callable representing the VJP map from output cotangents to input cotangents.
        The returned VJP function must accept a value with the same shape as the value of fun applied
        to the arguments and must return a tuple with length equal to the number of positional arguments to fun.
    """

    numpy_output, fenics_solution, residual_form, fenics_inputs, bcs = solve_eval(
        fenics_function, fenics_templates, *args)

    # @trace("vjp_fun1")
    def vjp_fun1(g):
        return vjp_fun1_p.bind(g)

    vjp_fun1_p = Primitive("vjp_fun1")
    vjp_fun1_p.multiple_results = True
    vjp_fun1_p.def_impl(lambda g: tuple(
        vjp if vjp is not None else jax.ad_util.zeros_like_jaxval(args[i])
        for i, vjp in enumerate(
            vjp_solve_eval_impl(g, fenics_solution, residual_form,
                                fenics_inputs, bcs))))

    # @trace("vjp_fun1_abstract_eval")
    def vjp_fun1_abstract_eval(g):
        if len(args) > 1:
            return tuple(
                (jax.abstract_arrays.ShapedArray(arg.shape, arg.dtype)
                 for arg in args))
        else:
            return (jax.abstract_arrays.ShapedArray((1, *args[0].shape),
                                                    args[0].dtype), )

    vjp_fun1_p.def_abstract_eval(vjp_fun1_abstract_eval)

    # @trace("vjp_fun1_batch")
    def vjp_fun1_batch(vector_arg_values, batch_axes):
        """Computes the batched version of the primitive.

        This must be a JAX-traceable function.

        Args:
            vector_arg_values: a tuple of arguments, each being a tensor of matching
            shape.
            batch_axes: the axes that are being batched. See vmap documentation.
        Returns:
            a tuple of the result, and the result axis that was batched.
        """
        # _trace("Using vjp_fun1 to compute the batch:")
        assert (
            batch_axes[0] == 0
        )  # assert that batch axis is zero, need to rewrite for a general case?
        # compute function row-by-row
        res = [
            vjp_fun1(vector_arg_values[0][i])
            for i in range(vector_arg_values[0].shape[0])
        ]
        # transpose resulting list
        res_T = list(itertools.zip_longest(*res))
        return tuple(map(np.vstack, res_T)), (batch_axes[0], ) * len(args)

    jax.interpreters.batching.primitive_batchers[vjp_fun1_p] = vjp_fun1_batch

    return numpy_output, vjp_fun1
Example No. 22
def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors,
                 compute_right_eigenvectors):
  if compute_left_eigenvectors or compute_right_eigenvectors:
    raise NotImplementedError(
        'The derivatives of eigenvectors are not implemented, only '
        'eigenvalues. See '
        'https://github.com/google/jax/issues/2748 for discussion.')
  # Formula for derivative of eigenvalues w.r.t. a is eqn 4.60 in
  # https://arxiv.org/abs/1701.00392
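  # Concretely, with A V = V diag(l) and distinct eigenvalues, dl_i = (V^{-1} dA V)_{ii};
  # the return below computes that diagonal as a row-wise sum of (V^{-1} dA) * V^T.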
  a, = primals
  da, = tangents
  l, v = eig(a, compute_left_eigenvectors=False)
  return [l], [jnp.sum(_solve(v, da.astype(v.dtype)) * _T(v), -1)]

eig_p = Primitive('eig')
eig_p.multiple_results = True
eig_p.def_impl(eig_impl)
eig_p.def_abstract_eval(eig_abstract_eval)
xla.translations[eig_p] = eig_translation_rule
xla.backend_specific_translations['cpu'][eig_p] = eig_cpu_translation_rule
batching.primitive_batchers[eig_p] = eig_batching_rule
ad.primitive_jvps[eig_p] = eig_jvp_rule


# Symmetric/Hermitian eigendecomposition

def eigh_impl(operand, lower):
  v, w = xla.apply_primitive(eigh_p, operand, lower=lower)
  return v, w
Example No. 23
  scale = 1 / prod(fft_lengths)
  out = scale * mask * x
  assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
  # Use JAX's convention for complex gradients
  # https://github.com/google/jax/issues/6223#issuecomment-807740707
  return lax.conj(out)

def fft_transpose_rule(t, operand, fft_type, fft_lengths):
  if fft_type == xla_client.FftType.RFFT:
    result = _rfft_transpose(t, fft_lengths)
  elif fft_type == xla_client.FftType.IRFFT:
    result = _irfft_transpose(t, fft_lengths)
  else:
    result = fft(t, fft_type, fft_lengths)
  return result,

def fft_batching_rule(batched_args, batch_dims, fft_type, fft_lengths):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return fft(x, fft_type, fft_lengths), 0

fft_p = Primitive('fft')
fft_p.def_impl(fft_impl)
fft_p.def_abstract_eval(fft_abstract_eval)
xla.translations[fft_p] = fft_translation_rule
ad.deflinear2(fft_p, fft_transpose_rule)
batching.primitive_batchers[fft_p] = fft_batching_rule
if pocketfft:
  xla.backend_specific_translations['cpu'][fft_p] = pocketfft.pocketfft
Example No. 24
def _random_key_abstract_eval(*args, **params):
    assert len(args) == 0
    assert len(params) == 0

    return ShapedArray((2,), 'uint32')


def _random_key_impl(*args, **params):
    assert len(args) == 0
    assert len(params) == 0

    raise ValueError("This parametrized function is randomized and therefore requires "
                     "a random key when applied, i. e. `apply(*inputs, key=PRNGKey(0))`.")


random_key_p = Primitive("random_key")
random_key_p.def_custom_bind(bind(random_key_p))
random_key_p.def_impl(_random_key_impl)
random_key_p.def_abstract_eval(_random_key_abstract_eval)


def random_key():
    """When called inside a parametrized function, this will return a unique random key derived from
    the `key` argument of `apply` or `init_parameters`."""
    return random_key_p.bind()


class parametrized(Primitive):
    """Represents a parametrized function, providing an
    `init_parameters` function for bundled initialization of all parameters,
    and an `apply` function to evaluate the function given these bundled parameters.
Example No. 25
def qr_translation_rule(c, operand, full_matrices):
    return c.QR(operand, full_matrices=full_matrices)


def qr_abstract_eval(operand, full_matrices):
    if isinstance(operand, ShapedArray):
        if operand.ndim < 2:
            raise ValueError(
                "Argument to QR decomposition must have ndims >= 2")
        batch_dims = operand.shape[:-2]
        m = operand.shape[-2]
        n = operand.shape[-1]
        k = m if full_matrices else min(m, n)
        q = ShapedArray(batch_dims + (m, k), operand.dtype)
        r = ShapedArray(batch_dims + (k, n), operand.dtype)
    else:
        q = operand
        r = operand
    return core.AbstractTuple((q, r))


def qr_dtype_rule(operand, full_matrices=True):
    return operand.dtype


qr_p = Primitive('qr')
qr_p.def_impl(qr_impl)
qr_p.def_abstract_eval(qr_abstract_eval)
xla.translations[qr_p] = qr_translation_rule
Example No. 26
from jax.lib import xla_client
from jax.interpreters import xla

from ..utils import (
    to_mpi_ptr,
    _unpack_builder,
    _ops,
    _constant_s32_scalar,
    _constant_u64_scalar,
    dtype_ptr,
)

from ..warn import warn_missing_omnistaging

# The Jax primitive
mpi_recv_p = Primitive("recv_mpi")  # Create the primitive


# This function applies the primitive to an AST
def Recv(
    x,
    source=_MPI.ANY_SOURCE,
    tag=_MPI.ANY_TAG,
    comm=_MPI.COMM_WORLD,
    status=None,
    token=None,
):
    """
    Recv(x, source=_MPI.ANY_SOURCE, tag=_MPI.ANY_TAG, comm=_MPI.COMM_WORLD, status=None, token=None)

    Receives the input `x` from the target rank `source` using the communicator `comm`
Example No. 27
from .. import create_token

from ..utils import (
    to_mpi_ptr,
    _unpack_builder,
    _ops,
    _constant_s32_scalar,
    _constant_u64_scalar,
    dtype_ptr,
)

from ..warn import warn_missing_omnistaging

# The Jax primitive
mpi_allreduce_p = Primitive("allreduce_mpi")  # Create the primitive


# This function applies the primitive to an AST
def Allreduce(x, op, comm=_MPI.COMM_WORLD, token=None):
    """
    Allreduce(x, op, comm=_MPI.COMM_WORLD, token=None)

    Performs the Allreduce operation `op` on the input `x` using the
    communicator `comm`, which defaults to the world communicator.
    An optional token can be passed, which is used to force jax to execute
    MPI operations in the correct order.

    Arguments:
        x: Array or scalar input.
        op: The reduction operation `MPI.Op` (e.g. MPI.SUM)
Example No. 28
        operands=(a, b, c, d),
        shape_with_layout=shape,
        operand_shapes_with_layout=(shape, ) * 4,
        opaque=opaque)


def tridiag(a, b, c, d):
    if not a.shape == b.shape == c.shape == d.shape:
        raise ValueError('all inputs must have identical shape')
    if not a.dtype == b.dtype == c.dtype == d.dtype:
        raise ValueError('all inputs must have the same dtype')
    return tridiag_p.bind(a, b, c, d)  #transpose(res, (0,2,1))


def tridiag_impl(*args, **kwargs):
    return xla.apply_primitive(tridiag_p, *args, **kwargs)


def _tridiag_gpu_translation_rule(computation_builder, a, b, c, d):
    return _tridiag(computation_builder, a, b, c, d)


def tridiag_abstract_eval(a, b, c, d):
    return abstract_arrays.ShapedArray(a.shape, a.dtype)


tridiag_p = Primitive('tridiag')
tridiag_p.def_impl(tridiag_impl)
tridiag_p.def_abstract_eval(tridiag_abstract_eval)
xla.backend_specific_translations['gpu'][
    tridiag_p] = _tridiag_gpu_translation_rule
Example No. 29
                             _nan_like(c, w))
    vl = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), vl,
                              _nan_like(c, vl))
    vr = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), vr,
                              _nan_like(c, vr))
    return xops.Tuple(c, [w, vl, vr])


def eig_batching_rule(batched_args, batch_dims):
    x, = batched_args
    bd, = batch_dims
    x = batching.moveaxis(x, bd, 0)
    return eig_p.bind(x), (0, 0, 0)


eig_p = Primitive('eig')
eig_p.multiple_results = True
eig_p.def_impl(eig_impl)
eig_p.def_abstract_eval(eig_abstract_eval)
xla.translations[eig_p] = eig_translation_rule
xla.backend_specific_translations['cpu'][eig_p] = eig_cpu_translation_rule
batching.primitive_batchers[eig_p] = eig_batching_rule

# Symmetric/Hermitian eigendecomposition


def eigh_impl(operand, lower):
    v, w = xla.apply_primitive(eigh_p, operand, lower=lower)
    return v, w

Example No. 30
    grads = vmap(_standard_gamma_grad_one)(samples, alphas)
    return grads.reshape(alpha.shape)


def _standard_gamma_abstract_eval(key, alpha, jaxpr, aval, consts):
    return lax.maybe_tracer_tuple_to_abstract_tuple(aval)


def _standard_gamma_translate(c, key, alpha, jaxpr, aval, consts):
    xla_computation = xla.jaxpr_computation(jaxpr, consts, (), c.GetShape(key),
                                            c.GetShape(alpha))
    return c.Call(xla_computation, (key, alpha))


# define primitive
standard_gamma_p = Primitive('standard_gamma')
standard_gamma_p.def_impl(partial(xla.apply_primitive, standard_gamma_p))
standard_gamma_p.def_abstract_eval(_standard_gamma_abstract_eval)
xla.translations[standard_gamma_p] = _standard_gamma_translate
ad.defjvp2(
    standard_gamma_p, None, lambda tangent, sample, key, alpha, **kwargs:
    tangent * _standard_gamma_grad(sample, alpha))


@partial(jit, static_argnums=(2, 3))
def standard_gamma(key, alpha, shape=(), dtype=np.float32):
    shape = shape or np.shape(alpha)
    alpha = lax.convert_element_type(alpha, dtype)
    if np.shape(alpha) != shape:
        alpha = np.broadcast_to(alpha, shape)
    jaxpr, out, consts = partial_eval.trace_unwrapped_to_jaxpr(