Exemplo n.º 1
0
    def test_unimplemented_interpreter_rules(self):
        """Check the NotImplementedError raised for each missing primitive rule.

        A fresh primitive ``foo`` starts with no rules at all. Each check below
        runs one interpreter (eager eval, ``jit``, ``grad``) and asserts which
        rule is reported missing. Rules are registered step by step, so the
        order of statements is significant.
        """
        foo_p = Primitive('foo')

        def foo(x):
            return foo_p.bind(x)

        # No impl registered: eager evaluation fails first.
        jtu.check_raises(lambda: foo(1.0), NotImplementedError,
                         "Evaluation rule for 'foo' not implemented")

        # No abstract-eval rule: jit cannot trace the primitive.
        jtu.check_raises(lambda: jit(foo)(1.0), NotImplementedError,
                         "Abstract evaluation for 'foo' not implemented")

        # No JVP rule: grad reports the missing forward-mode rule.
        jtu.check_raises(
            lambda: grad(foo)(1.0), NotImplementedError,
            "Forward-mode differentiation rule for 'foo' not implemented")

        foo_p.def_abstract_eval(lambda x: x)

        # Abstract eval now exists, so jit gets as far as the XLA translation.
        jtu.check_raises(lambda: jit(foo)(1.0), NotImplementedError,
                         "XLA translation rule for 'foo' not implemented")

        foo_p.def_impl(lambda x: x)
        defjvp(foo_p, lambda g, x: foo(g))

        # JVP exists but no transpose rule, so reverse mode still fails.
        jtu.check_raises(
            lambda: grad(foo)(1.0), NotImplementedError,
            "Reverse-mode differentiation rule for 'foo' not implemented")
Exemplo n.º 2
0
    def decorator(fenics_function: Callable) -> Callable:
        """Wrap ``fenics_function`` as a JAX primitive with a custom VJP.

        The returned callable binds a ``jax_fem_eval`` primitive. Its impl and
        abstract eval both run ``evaluate_primal``. Reverse-mode derivatives
        come from ``custom_vjp``: ``primal`` saves the pyadjoint metadata plus
        the inputs, and ``pullback`` maps output cotangents back to inputs.

        Args:
            fenics_function: the FEniCS function to wrap.

        Returns:
            The ``custom_vjp`` callable, carrying ``fenics_function``'s
            metadata via ``functools.wraps``.
        """

        @functools.wraps(fenics_function)
        @custom_vjp
        def jax_fem_eval(*args):
            return jax_fem_eval_p.bind(*args)

        jax_fem_eval_p = Primitive("jax_fem_eval")
        jax_fem_eval_p.def_impl(
            lambda *args: evaluate_primal(fenics_function, fenics_templates, *args)[0]
        )

        # NOTE(review): shape inference runs the full FEniCS forward evaluation.
        jax_fem_eval_p.def_abstract_eval(
            lambda *args: jax.abstract_arrays.make_shaped_array(
                evaluate_primal(fenics_function, fenics_templates, *args)[0]
            )
        )

        def jax_fem_eval_batch(vector_arg_values, batch_axes):
            # Only the case where every argument is batched along axis 0 is supported.
            assert len(set(batch_axes)) == 1  # assert that all batch axes are same
            assert (
                batch_axes[0] == 0
            )  # assert that batch axis is zero, need to rewrite for a general case?
            # Evaluate row by row and stack the per-row outputs.
            res = list(map(jax_fem_eval, *vector_arg_values))
            res = np.asarray(res)
            return res, batch_axes[0]

        jax.interpreters.batching.primitive_batchers[
            jax_fem_eval_p
        ] = jax_fem_eval_batch

        def primal(*args):
            numpy_output, fenics_output, fenics_inputs, tape = evaluate_primal(
                fenics_function, fenics_templates, *args
            )
            # Residuals handed to `pullback`: pyadjoint metadata plus the inputs.
            return (
                numpy_output,
                (PyadjointMetadata(fenics_output, fenics_inputs, tape), args),
            )

        # Build the pullback once per decorated function. Building it inside
        # `pullback` created a new Primitive, and registered another batching
        # rule, on every backward pass.
        pullback_fn = get_pullback_function(fenics_function, fenics_templates)

        def pullback(aux_args, g):
            # custom_vjp needs a tuple of input cotangents; the pullback returns a list.
            return tuple(pullback_fn(aux_args, g))

        jax_fem_eval.defvjp(primal, pullback)
        return jax_fem_eval
Exemplo n.º 3
0
    def decorator(fenics_function: Callable) -> Callable:
        """Wrap ``fenics_function`` as a JAX primitive evaluated by ``solve_eval``.

        The returned callable binds ``jax_solve_eval_p``. Its impl and abstract
        eval run ``solve_eval``. Reverse-mode derivatives are registered with
        ``defvjp_all`` through a second primitive whose impl calls
        ``vjp_solve_eval``.

        Args:
            fenics_function: the FEniCS function to wrap.

        Returns:
            A callable carrying ``fenics_function``'s name and docstring.
        """

        @functools.wraps(fenics_function)
        def jax_solve_eval(*args):
            return jax_solve_eval_p.bind(*args)

        jax_solve_eval_p = Primitive("jax_solve_eval")
        jax_solve_eval_p.def_impl(lambda *args: solve_eval(
            fenics_function, fenics_templates, *args)[0])

        # NOTE(review): shape inference runs the full forward solve.
        jax_solve_eval_p.def_abstract_eval(
            lambda *args: jax.abstract_arrays.make_shaped_array(
                solve_eval(fenics_function, fenics_templates, *args)[0]))

        def jax_solve_eval_batch(vector_arg_values, batch_axes):
            # Only the case where every argument is batched along axis 0 is supported.
            assert len(set(batch_axes)) == 1  # all batch axes must agree
            assert batch_axes[0] == 0  # TODO: support a general batch axis
            # Row by row: zip(*arrays) yields one tuple of rows per batch entry.
            res = np.asarray(
                [jax_solve_eval(*row) for row in zip(*vector_arg_values)])
            return res, batch_axes[0]

        # Use the full module path, as the sibling decorators do; the
        # `jax.batching` shortcut is not available in newer JAX releases.
        jax.interpreters.batching.primitive_batchers[
            jax_solve_eval_p] = jax_solve_eval_batch

        def djax_solve_eval(*args):
            return djax_solve_eval_p.bind(*args)

        # The impl returns vjp_solve_eval's (output, vjp function) pair.
        djax_solve_eval_p = Primitive("djax_solve_eval")
        djax_solve_eval_p.def_impl(lambda *args: vjp_solve_eval(
            fenics_function, fenics_templates, *args))

        defvjp_all(jax_solve_eval_p, djax_solve_eval)
        return jax_solve_eval
Exemplo n.º 4
0
    def decorator(fenics_function: Callable) -> Callable:
        """Wrap ``fenics_function`` as a JAX primitive evaluated by ``assemble_eval``.

        Registers impl, abstract-eval and batching rules for
        ``jax_assemble_eval_p``. Reverse-mode derivatives are registered with
        ``defvjp_all`` through a second primitive whose impl calls
        ``vjp_assemble_eval``.

        Args:
            fenics_function: the FEniCS function to wrap.

        Returns:
            A callable carrying ``fenics_function``'s name and docstring, as
            the other FEniCS wrappers do.
        """
        import functools  # local import keeps this block self-contained

        @functools.wraps(fenics_function)
        def jax_assemble_eval(*args):
            return jax_assemble_eval_p.bind(*args)

        jax_assemble_eval_p = Primitive("jax_assemble_eval")
        jax_assemble_eval_p.def_impl(
            lambda *args: assemble_eval(fenics_function, fenics_templates, *args)[0]
        )

        # NOTE(review): shape inference runs the full assembly.
        jax_assemble_eval_p.def_abstract_eval(
            lambda *args: jax.abstract_arrays.make_shaped_array(
                assemble_eval(fenics_function, fenics_templates, *args)[0]
            )
        )

        def jax_assemble_eval_batch(vector_arg_values, batch_axes):
            # Only the case where every argument is batched along axis 0 is supported.
            assert len(set(batch_axes)) == 1  # assert that all batch axes are same
            assert (
                batch_axes[0] == 0
            )  # assert that batch axis is zero, need to rewrite for a general case?
            # Evaluate row by row and stack the per-row outputs.
            res = list(map(jax_assemble_eval, *vector_arg_values))
            res = np.asarray(res)
            return res, batch_axes[0]

        batching.primitive_batchers[jax_assemble_eval_p] = jax_assemble_eval_batch

        def djax_assemble_eval(*args):
            return djax_assemble_eval_p.bind(*args)

        # The impl forwards to vjp_assemble_eval.
        djax_assemble_eval_p = Primitive("djax_assemble_eval")
        djax_assemble_eval_p.def_impl(
            lambda *args: vjp_assemble_eval(fenics_function, fenics_templates, *args)
        )

        defvjp_all(jax_assemble_eval_p, djax_assemble_eval)

        return jax_assemble_eval
Exemplo n.º 5
0
    def decorator(fenics_function: Callable) -> Callable:
        """Wrap ``fenics_function`` as a JAX primitive with JVP and transpose rules.

        Registers impl, abstract-eval, batching, JVP (``jvp_solve_eval``) and
        transpose (``vjp_solve_eval``) rules for ``jax_solve_eval_p``.

        Args:
            fenics_function: the FEniCS function to wrap.

        Returns:
            A callable carrying ``fenics_function``'s name and docstring.
        """

        @functools.wraps(fenics_function)
        def jax_solve_eval(*args):
            return jax_solve_eval_p.bind(*args)

        jax_solve_eval_p = Primitive("jax_solve_eval")
        jax_solve_eval_p.def_impl(
            lambda *args: solve_eval(fenics_function, fenics_templates, *args)[0]
        )

        # NOTE(review): shape inference runs the full forward solve.
        jax_solve_eval_p.def_abstract_eval(
            lambda *args: jax.abstract_arrays.make_shaped_array(
                solve_eval(fenics_function, fenics_templates, *args)[0]
            )
        )

        def jax_solve_eval_batch(vector_arg_values, batch_axes):
            # Only the case where every argument is batched along axis 0 is supported.
            assert len(set(batch_axes)) == 1  # assert that all batch axes are same
            assert (
                batch_axes[0] == 0
            )  # assert that batch axis is zero, need to rewrite for a general case?
            res = list(map(jax_solve_eval, *vector_arg_values))
            res = np.asarray(res)
            return res, batch_axes[0]

        # Use the full module path, as the sibling decorators do; the
        # `jax.batching` shortcut is not available in newer JAX releases.
        jax.interpreters.batching.primitive_batchers[
            jax_solve_eval_p
        ] = jax_solve_eval_batch

        def jvp_jax_solve_eval(ps, ts):
            return jvp_jax_solve_eval_p.bind(ps, ts)

        jvp_jax_solve_eval_p = Primitive("jvp_jax_solve_eval")
        jvp_jax_solve_eval_p.multiple_results = True
        jvp_jax_solve_eval_p.def_impl(
            lambda ps, ts: jvp_solve_eval(fenics_function, fenics_templates, ps, ts)
        )

        jax.interpreters.ad.primitive_jvps[jax_solve_eval_p] = jvp_jax_solve_eval

        # TODO: JAX Tracer goes inside fenics wrappers and zero array is returned
        # because fenics numpy conversion works only for concrete arrays
        vjp_jax_solve_eval_p = Primitive("vjp_jax_solve_eval")
        vjp_jax_solve_eval_p.def_impl(
            lambda ct, *args: vjp_solve_eval(fenics_function, fenics_templates, *args)[
                1
            ](ct)
        )

        # JAX calls a transpose rule as rule(ct, *primals, **params), so the
        # registered value must be callable. A Primitive object is not callable,
        # so register its `bind` method, which forwards (ct, *args) to the impl.
        jax.interpreters.ad.primitive_transposes[
            jax_solve_eval_p
        ] = vjp_jax_solve_eval_p.bind

        return jax_solve_eval
Exemplo n.º 6
0
        shape=sh,
        opaque=descriptor,
        has_side_effect=True,
    )


# Abstract evaluation: computes output shapes/dtypes only (used while tracing).
def mpi_scatter_abstract_eval(x, token, root, comm):
    """Return ``(ShapedArray, abstract_token)`` for ``mpi_scatter_p``.

    On the ``root`` rank the leading axis of ``x`` is dropped. On every other
    rank the output shape equals ``x.shape``. The dtype is always ``x.dtype``.
    """
    is_root = unpack_hashable(comm).Get_rank() == root
    shape = x.shape[1:] if is_root else x.shape
    return abstract_arrays.ShapedArray(shape, x.dtype), core.abstract_token


mpi_scatter_p.multiple_results = True  # outputs: (array, token)
mpi_scatter_p.def_impl(mpi_scatter_impl)
mpi_scatter_p.def_abstract_eval(mpi_scatter_abstract_eval)

# Register the platform-specific XLA encoders.
xla.backend_specific_translations["cpu"].update(
    {mpi_scatter_p: mpi_scatter_xla_encode_cpu}
)
xla.backend_specific_translations["gpu"].update(
    {mpi_scatter_p: mpi_scatter_xla_encode_gpu}
)
Exemplo n.º 7
0
  scale = 1 / prod(fft_lengths)
  out = scale * mask * x
  assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
  # Use JAX's convention for complex gradients
  # https://github.com/google/jax/issues/6223#issuecomment-807740707
  return lax.conj(out)

def fft_transpose_rule(t, operand, fft_type, fft_lengths):
  """Transpose rule for the linear ``fft`` primitive.

  RFFT and IRFFT have dedicated transpose helpers. Every other FFT type is
  transposed by applying ``fft`` again with the same type and lengths.
  Returns a 1-tuple holding the cotangent for ``operand``.
  """
  if fft_type == xla_client.FftType.RFFT:
    return (_rfft_transpose(t, fft_lengths),)
  if fft_type == xla_client.FftType.IRFFT:
    return (_irfft_transpose(t, fft_lengths),)
  return (fft(t, fft_type, fft_lengths),)

def fft_batching_rule(batched_args, batch_dims, fft_type, fft_lengths):
  """Batching rule: move the batch dimension to axis 0, then apply ``fft``."""
  (operand,) = batched_args
  (batch_dim,) = batch_dims
  leading = batching.moveaxis(operand, batch_dim, 0)
  return fft(leading, fft_type, fft_lengths), 0

fft_p = Primitive('fft')
fft_p.def_impl(fft_impl)
fft_p.def_abstract_eval(fft_abstract_eval)
# Differentiation via deflinear2 (linear primitive + transpose rule), and vmap.
ad.deflinear2(fft_p, fft_transpose_rule)
batching.primitive_batchers[fft_p] = fft_batching_rule
# Lowering: generic XLA translation, plus a CPU-specific one when pocketfft
# is available.
xla.translations[fft_p] = fft_translation_rule
if pocketfft:
  xla.backend_specific_translations['cpu'].update({fft_p: pocketfft.pocketfft})
Exemplo n.º 8
0
        operands=(
            sendbuf,
            token,
        ),
        shape=sh,
        opaque=descriptor,
        has_side_effect=True,
    )


# Abstract evaluation: computes output shapes/dtypes only (used while tracing).
def mpi_allgather_abstract_eval(x, token, comm):
    """Return the abstract outputs ``(ShapedArray, abstract_token)`` of allgather.

    Every rank receives one ``x``-shaped slice per rank, stacked along a new
    leading axis. The result has shape ``(comm_size, *x.shape)`` and dtype
    ``x.dtype``.
    """
    # Take the token from jax.core, where it is defined, matching the other MPI
    # abstract-eval rules (scatter/gather/scan/alltoall).
    from jax import core

    comm = unpack_hashable(comm)
    size = comm.Get_size()
    out_shape = (size, *x.shape)
    return (
        abstract_arrays.ShapedArray(out_shape, x.dtype),
        core.abstract_token,
    )


mpi_allgather_p.multiple_results = True  # outputs: (array, token)
mpi_allgather_p.def_impl(mpi_allgather_impl)
mpi_allgather_p.def_abstract_eval(mpi_allgather_abstract_eval)

# Register the platform-specific XLA encoders.
xla.backend_specific_translations["cpu"].update(
    {mpi_allgather_p: mpi_allgather_xla_encode_cpu}
)
xla.backend_specific_translations["gpu"].update(
    {mpi_allgather_p: mpi_allgather_xla_encode_gpu}
)
Exemplo n.º 9
0
    # Use JAX's convention for complex gradients
    # https://github.com/google/jax/issues/6223#issuecomment-807740707
    return lax.conj(out)


def _fft_transpose_rule(t, operand, fft_type, fft_lengths):
    """Transpose rule for the linear ``fft`` primitive.

    RFFT and IRFFT use dedicated transpose helpers. Every other FFT type is
    transposed by applying ``fft`` with the same type and lengths.
    Returns a 1-tuple holding the cotangent for ``operand``.
    """
    if fft_type == xla_client.FftType.RFFT:
        return (_rfft_transpose(t, fft_lengths),)
    if fft_type == xla_client.FftType.IRFFT:
        return (_irfft_transpose(t, fft_lengths),)
    return (fft(t, fft_type, fft_lengths),)


def _fft_batching_rule(batched_args, batch_dims, fft_type, fft_lengths):
    """Batching rule: move the batch dimension to axis 0, then apply ``fft``."""
    (operand,) = batched_args
    (batch_dim,) = batch_dims
    leading = batching.moveaxis(operand, batch_dim, 0)
    return fft(leading, fft_type, fft_lengths), 0


fft_p = Primitive('fft')
fft_p.def_impl(_fft_impl)
fft_p.def_abstract_eval(fft_abstract_eval)
# Differentiation via deflinear2 (linear primitive + transpose rule), and vmap.
ad.deflinear2(fft_p, _fft_transpose_rule)
batching.primitive_batchers[fft_p] = _fft_batching_rule
# MLIR lowering: the generic one, plus a CPU-specific one when pocketfft is
# available.
mlir.register_lowering(fft_p, _fft_lowering)
if pocketfft:
    mlir.register_lowering(fft_p, _fft_lowering_cpu, platform='cpu')
Exemplo n.º 10
0
import numpy as np
from jax.util import safe_map, safe_zip
from jax.core import Primitive
from jax.interpreters import ad, xla, batching, numpy_eval
from jax.lax import dynamic_update_slice_p

# Shadow the builtins with JAX's length-checking variants for this module.
map, zip = safe_map, safe_zip

inplace_dynamic_update_slice_p = Primitive('inplace_dynamic_update_slice')
# Reuse lax's out-of-place dynamic_update_slice rules for eager evaluation,
# shape inference, XLA translation, JVP, transpose and batching.
inplace_dynamic_update_slice_p.def_impl(dynamic_update_slice_p.impl)
inplace_dynamic_update_slice_p.def_abstract_eval(
    dynamic_update_slice_p.abstract_eval)
for rules in (xla.translations, ad.primitive_jvps, ad.primitive_transposes,
              batching.primitive_batchers):
  rules[inplace_dynamic_update_slice_p] = rules[dynamic_update_slice_p]

def _numpy_inplace_dynamic_update_slice(operand, update, *start_indices):
  """NumPy evaluation of ``inplace_dynamic_update_slice_p``.

  Writes ``update`` into ``operand`` in place at ``start_indices`` and returns
  ``operand`` itself. This matches ``lax.dynamic_update_slice`` (XLA
  DynamicUpdateSlice): each start index is clamped to
  ``[0, operand.shape[i] - update.shape[i]]``, so the update always fits.
  """
  slices = []
  for start, dim, size in zip(start_indices, operand.shape, update.shape):
    # Clamp like XLA. Unclamped, a start too close to the end gives a
    # too-short slice (shape-mismatch error), and a negative start wraps
    # around numpy-style.
    start = min(max(int(start), 0), max(dim - size, 0))
    slices.append(slice(start, start + size))
  operand[tuple(slices)] = update
  return operand

# NumPy-evaluation rule: the impl writes the update into the operand array in
# place (operand[slices] = update) instead of copying it.
numpy_eval.np_impl[inplace_dynamic_update_slice_p] = \
  _numpy_inplace_dynamic_update_slice

def inplace_dynamic_update_slice(operand, update, start_indices):
  """Bind ``inplace_dynamic_update_slice_p``, passing each start index as an operand."""
  indices = tuple(start_indices)
  return inplace_dynamic_update_slice_p.bind(operand, update, *indices)
Exemplo n.º 11
0
        shape=sh,
        opaque=descriptor,
        has_side_effect=True,
    )


# Abstract evaluation: computes output shapes/dtypes only (used while tracing).
def mpi_gather_abstract_eval(x, token, root, comm):
    """Return ``(ShapedArray, abstract_token)`` for ``mpi_gather_p``.

    The ``root`` rank receives one ``x``-shaped slice per rank, stacked along
    a new leading axis: ``(comm_size, *x.shape)``. Other ranks keep
    ``x.shape``. The dtype is always ``x.dtype``.
    """
    communicator = unpack_hashable(comm)
    rank, size = communicator.Get_rank(), communicator.Get_size()
    out_shape = (size, *x.shape) if rank == root else x.shape
    return abstract_arrays.ShapedArray(out_shape, x.dtype), core.abstract_token


mpi_gather_p.multiple_results = True  # outputs: (array, token)
mpi_gather_p.def_impl(mpi_gather_impl)
mpi_gather_p.def_abstract_eval(mpi_gather_abstract_eval)

# Register the platform-specific XLA encoders.
xla.backend_specific_translations["cpu"].update({mpi_gather_p: mpi_gather_xla_encode_cpu})
xla.backend_specific_translations["gpu"].update({mpi_gather_p: mpi_gather_xla_encode_gpu})
Exemplo n.º 12
0
def get_pullback_function(
    fenics_function: Callable, fenics_templates: Collection[BackendVariable]
) -> Callable:
    """Build the reverse-mode pullback (VJP map) for a wrapped FEniCS function.

    Each call creates a fresh multiple-results primitive named ``"vjp_fun1"``
    and registers its impl, abstract-eval and batching rules.

    Args:
        fenics_function: FEniCS function executed during the forward pass.
        fenics_templates: templates for the function's inputs.
        Neither argument is read by the current implementation.

    Returns:
        ``vjp_fun1(aux_args, g)``. ``aux_args`` is the ``(metadata, args)``
        residual pair saved by the forward pass; ``metadata`` exposes
        ``fenics_output``, ``fenics_inputs`` and ``tape``. ``g`` is the output
        cotangent. The result holds one cotangent per item produced by
        ``evaluate_pullback``. ``None`` items are replaced by zeros shaped like
        the matching entry of ``args``.
    """

    def vjp_fun1(aux_args, g):
        return pullback_p.bind(aux_args, g)

    def pullback_impl(aux_args, g):
        metadata, args = aux_args
        raw = evaluate_pullback(
            metadata.fenics_output, metadata.fenics_inputs, metadata.tape, g
        )
        return tuple(
            jax.ad_util.zeros_like_jaxval(args[pos]) if ct is None else ct
            for pos, ct in enumerate(raw)
        )

    def pullback_abstract_eval(aux_args, g):
        _, args = aux_args
        if len(args) <= 1:
            # Single input: the abstract output gets an extra leading axis of size 1.
            only = args[0]
            return (
                jax.abstract_arrays.ShapedArray((1, *only.shape), only.dtype),
            )
        return tuple(
            jax.abstract_arrays.ShapedArray(arg.shape, arg.dtype) for arg in args
        )

    def pullback_batch(vector_arg_values, batch_axes):
        """Batching rule: run the pullback once per row of the cotangent batch.

        The residuals (first operand) must be unbatched and the cotangent
        batched along axis 0. Per-row results are regrouped per input and
        stacked with ``np.vstack``. All outputs are batched along axis 0.
        """
        aux_args = vector_arg_values[0]
        _, args = aux_args
        assert batch_axes[0] is None  # residuals are not batched
        assert batch_axes[1] == 0  # only a leading batch axis on g is supported
        rows = [vjp_fun1(aux_args, g_row) for g_row in vector_arg_values[1]]
        # Regroup: one column per input, each column holding every row's result.
        columns = itertools.zip_longest(*rows)
        stacked = tuple(np.vstack(column) for column in columns)
        return stacked, (batch_axes[1],) * len(args)

    pullback_p = Primitive("vjp_fun1")
    pullback_p.multiple_results = True
    pullback_p.def_impl(pullback_impl)
    pullback_p.def_abstract_eval(pullback_abstract_eval)
    jax.interpreters.batching.primitive_batchers[pullback_p] = pullback_batch

    return vjp_fun1
Exemplo n.º 13
0
    def decorator(fenics_function: Callable) -> Callable:
        """Expose ``fenics_function`` to JAX as a primitive with a JVP rule.

        Registers eager, abstract-eval, batching and forward-mode (JVP) rules
        for a ``jax_fem_eval`` primitive. Each input is converted with
        ``jax_to_fenics_numpy`` against the template at the same position in
        ``fenics_templates`` before ``evaluate_primal`` runs. No transpose
        rule is registered yet (see the TODO below).

        Args:
            fenics_function: the FEniCS function to wrap.

        Returns:
            A callable carrying ``fenics_function``'s name and docstring.
        """

        def as_fenics(values):
            # Lazily convert each value using the template at the same position.
            return (
                jax_to_fenics_numpy(value, template)
                for value, template in zip(values, fenics_templates)
            )

        def primal_output(*args):
            # The first item of evaluate_primal's result is the NumPy output.
            return evaluate_primal(
                fenics_function, fenics_templates, *as_fenics(args)
            )[0]

        @functools.wraps(fenics_function)
        def jax_fem_eval(*args):
            return jax_fem_eval_p.bind(*args)

        jax_fem_eval_p = Primitive("jax_fem_eval")
        jax_fem_eval_p.def_impl(primal_output)
        # NOTE(review): shape inference runs the full forward evaluation.
        jax_fem_eval_p.def_abstract_eval(
            lambda *args: jax.abstract_arrays.make_shaped_array(primal_output(*args))
        )

        def jax_fem_eval_batch(vector_arg_values, batch_axes):
            # Supported case only: every argument batched along axis 0.
            assert len(set(batch_axes)) == 1  # assert that all batch axes are same
            assert (
                batch_axes[0] == 0
            )  # assert that batch axis is zero, need to rewrite for a general case?
            stacked = np.asarray(list(map(jax_fem_eval, *vector_arg_values)))
            return stacked, batch_axes[0]

        jax.interpreters.batching.primitive_batchers[
            jax_fem_eval_p
        ] = jax_fem_eval_batch

        def jvp_jax_fem_eval(ps, ts):
            return jvp_jax_fem_eval_p.bind(ps, ts)

        def jvp_jax_fem_eval_impl(primals, tangents):
            # evaluate_primal also returns the FEniCS output/inputs and tape
            # that evaluate_pushforward consumes.
            numpy_output, fenics_output, fenics_inputs, tape = evaluate_primal(
                fenics_function, fenics_templates, *as_fenics(primals)
            )
            dnumpy_output = evaluate_pushforward(
                fenics_output, fenics_inputs, tape, as_fenics(tangents)
            )
            return numpy_output, dnumpy_output

        jvp_jax_fem_eval_p = Primitive("jvp_jax_fem_eval")
        jvp_jax_fem_eval_p.multiple_results = True  # (output, output tangent)
        jvp_jax_fem_eval_p.def_impl(jvp_jax_fem_eval_impl)

        jax.interpreters.ad.primitive_jvps[jax_fem_eval_p] = jvp_jax_fem_eval

        # TODO: JAX Tracer goes inside fenics wrappers and zero array is returned
        # because fenics numpy conversion works only for concrete arrays.
        # Planned transpose rule, disabled until that is resolved:
        # vjp_jax_fem_eval_p = Primitive("vjp_jax_fem_eval")
        # vjp_jax_fem_eval_p.def_impl(
        #     lambda ct, *args: vjp_fem_eval(fenics_function, fenics_templates, *args)[1](
        #         ct
        #     )
        # )
        # jax.interpreters.ad.primitive_transposes[jax_fem_eval_p] = vjp_jax_fem_eval_p

        return jax_fem_eval
Exemplo n.º 14
0
def PmapPrimitive(name):
    """Create a primitive ``name`` whose abstract value is its first operand's.

    The eager impl is ``_unbound_name_error`` partially applied to ``name``.
    """
    def _first_operand_aval(x, *args, **kwargs):
        return x

    primitive = Primitive(name)
    primitive.def_impl(partial(_unbound_name_error, name))
    primitive.def_abstract_eval(_first_operand_aval)
    return primitive
Exemplo n.º 15
0
    )

    return xla_client.ops.CustomCall(
        c,
        b"mpi_scan",
        operands=(
            x,
            token,
        ),
        shape=sh,
        opaque=descriptor,
        has_side_effect=True,
    )


# Abstract evaluation: computes output shapes/dtypes only (used while tracing).
def mpi_scan_abstract_eval(xs, token, op, comm):
    """Return ``(ShapedArray like xs, abstract_token)``.

    ``op`` and ``comm`` do not affect the output shape.
    """
    out_aval = abstract_arrays.ShapedArray(xs.shape, xs.dtype)
    return out_aval, core.abstract_token


mpi_scan_p.multiple_results = True  # outputs: (array, token)
mpi_scan_p.def_impl(mpi_scan_impl)
mpi_scan_p.def_abstract_eval(mpi_scan_abstract_eval)

# Register the platform-specific XLA encoders.
xla.backend_specific_translations["cpu"].update({mpi_scan_p: mpi_scan_xla_encode_cpu})
xla.backend_specific_translations["gpu"].update({mpi_scan_p: mpi_scan_xla_encode_gpu})
Exemplo n.º 16
0
        operands=(a, b, c, d),
        shape_with_layout=shape,
        operand_shapes_with_layout=(shape, ) * 4,
        opaque=opaque)


def tridiag(a, b, c, d):
    """Bind ``tridiag_p`` on four operands that share one shape and dtype.

    Raises:
        ValueError: if the operands differ in shape (checked first) or in dtype.
    """
    others = (b, c, d)
    if not all(arr.shape == a.shape for arr in others):
        raise ValueError('all inputs must have identical shape')
    if not all(arr.dtype == a.dtype for arr in others):
        raise ValueError('all inputs must have the same dtype')
    return tridiag_p.bind(a, b, c, d)


def tridiag_impl(*args, **kwargs):
    """Eager evaluation of ``tridiag_p`` via ``xla.apply_primitive``."""
    result = xla.apply_primitive(tridiag_p, *args, **kwargs)
    return result


def _tridiag_gpu_translation_rule(computation_builder, a, b, c, d):
    """GPU translation for ``tridiag_p``: delegate to ``_tridiag``."""
    operands = (a, b, c, d)
    return _tridiag(computation_builder, *operands)


def tridiag_abstract_eval(a, b, c, d):
    """Shape rule: the output has ``a``'s shape and dtype."""
    shape, dtype = a.shape, a.dtype
    return abstract_arrays.ShapedArray(shape, dtype)


tridiag_p = Primitive('tridiag')
tridiag_p.def_impl(tridiag_impl)
tridiag_p.def_abstract_eval(tridiag_abstract_eval)
# Only a GPU translation is registered for this primitive.
xla.backend_specific_translations['gpu'].update(
    {tridiag_p: _tridiag_gpu_translation_rule})
Exemplo n.º 17
0
    return grads.reshape(alpha.shape)


def _standard_gamma_abstract_eval(key, alpha, jaxpr, aval, consts):
    """Abstract eval: convert the precomputed ``aval`` with ``maybe_tracer_tuple_to_abstract_tuple``."""
    result = lax.maybe_tracer_tuple_to_abstract_tuple(aval)
    return result


def _standard_gamma_translate(c, key, alpha, jaxpr, aval, consts):
    """XLA translation: build a sub-computation from ``jaxpr`` and call it on (key, alpha)."""
    key_shape = c.GetShape(key)
    alpha_shape = c.GetShape(alpha)
    sub_computation = xla.jaxpr_computation(jaxpr, consts, (), key_shape,
                                            alpha_shape)
    return c.Call(sub_computation, (key, alpha))


# Define the primitive with its eager, abstract and XLA translation rules.
standard_gamma_p = Primitive('standard_gamma')
standard_gamma_p.def_impl(partial(xla.apply_primitive, standard_gamma_p))
standard_gamma_p.def_abstract_eval(_standard_gamma_abstract_eval)
xla.translations[standard_gamma_p] = _standard_gamma_translate


def _standard_gamma_alpha_jvp(tangent, sample, key, alpha, **kwargs):
    """JVP w.r.t. ``alpha``: scale the tangent by ``_standard_gamma_grad(sample, alpha)``."""
    return tangent * _standard_gamma_grad(sample, alpha)


# ``key`` has no JVP rule (None); ``alpha`` uses the rule above.
ad.defjvp2(standard_gamma_p, None, _standard_gamma_alpha_jvp)


@partial(jit, static_argnums=(2, 3))
def standard_gamma(key, alpha, shape=(), dtype=np.float32):
    shape = shape or np.shape(alpha)
    alpha = lax.convert_element_type(alpha, dtype)
    if np.shape(alpha) != shape:
        alpha = np.broadcast_to(alpha, shape)
    jaxpr, out, consts = partial_eval.trace_unwrapped_to_jaxpr(
        _standard_gamma_impl, tuple(lax._abstractify(o) for o in (key, alpha)))
Exemplo n.º 18
0
    out = _ops.CustomCall(
        c,
        b"mpi_send",
        operands=(
            _nitems,
            x,
            _constant_s32_scalar(c, dest),
            _constant_s32_scalar(c, tag),
            _constant_u64_scalar(c, to_mpi_ptr(comm)),
            _constant_u64_scalar(c, _dtype_ptr),
            token,
        ),
        shape=sh,
        has_side_effect=True,
    )

    return xla_client.ops.GetTupleElement(out, 0)


# Abstract evaluation: computes output shapes/dtypes only (used while tracing).
def mpi_send_abstract_eval(xs, token, dest, tag, comm):
    """Return the abstract output of ``mpi_send_p``, which is only a token."""
    # Take the token from jax.core, where it is defined, matching the other MPI
    # abstract-eval rules (scatter/gather/scan/alltoall).
    from jax import core

    return core.abstract_token


# mpi_send's abstract eval returns only a token, so the primitive keeps the
# default single-result mode (multiple_results is deliberately not set).
mpi_send_p.def_impl(mpi_send_impl)
mpi_send_p.def_abstract_eval(mpi_send_abstract_eval)

# The encoder is registered for CPU only.
xla.backend_specific_translations["cpu"].update({mpi_send_p: mpi_send_xla_encode})
Exemplo n.º 19
0
    return xla_client.ops.CustomCall(
        c,
        b"mpi_alltoall",
        operands=(
            x,
            token,
        ),
        shape=sh,
        opaque=descriptor,
        has_side_effect=True,
    )


# Abstract evaluation: computes output shapes/dtypes only (used while tracing).
def mpi_alltoall_abstract_eval(xs, token, comm):
    """Return ``(ShapedArray like xs, abstract_token)``; the shape is unchanged."""
    out_aval = abstract_arrays.ShapedArray(xs.shape, xs.dtype)
    return out_aval, core.abstract_token


mpi_alltoall_p.multiple_results = True  # outputs: (array, token)
mpi_alltoall_p.def_impl(mpi_alltoall_impl)
mpi_alltoall_p.def_abstract_eval(mpi_alltoall_abstract_eval)

# Register the platform-specific XLA encoders.
xla.backend_specific_translations["cpu"].update(
    {mpi_alltoall_p: mpi_alltoall_xla_encode_cpu}
)
xla.backend_specific_translations["gpu"].update(
    {mpi_alltoall_p: mpi_alltoall_xla_encode_gpu}
)
Exemplo n.º 20
0
def vjp_solve_eval(fenics_function: Callable, fenics_templates: FenicsVariable,
                   *args: np.array) -> Tuple[np.array, Callable]:
    """Run the forward solve and return its output together with a VJP function.

    Args:
        fenics_function: FEniCS function executed during the forward pass.
        fenics_templates: templates forwarded to ``solve_eval``.
        *args: JAX/NumPy representations of the inputs to ``fenics_function``.

    Returns:
        ``(numpy_output, vjp_fun1)``. ``vjp_fun1(g)`` maps an output cotangent
        ``g`` to a tuple of input cotangents computed by
        ``vjp_solve_eval_impl``. ``None`` entries are replaced by zeros shaped
        like the matching entry of ``args``.

    Note:
        Every call creates a fresh ``"vjp_fun1"`` primitive and registers a
        batching rule for it, because the rules close over this call's
        solution data.
    """
    numpy_output, fenics_solution, residual_form, fenics_inputs, bcs = solve_eval(
        fenics_function, fenics_templates, *args)

    def vjp_fun1(g):
        return cotangent_p.bind(g)

    def cotangent_impl(g):
        raw = vjp_solve_eval_impl(g, fenics_solution, residual_form,
                                  fenics_inputs, bcs)
        return tuple(
            jax.ad_util.zeros_like_jaxval(args[pos]) if ct is None else ct
            for pos, ct in enumerate(raw))

    def cotangent_abstract_eval(g):
        if len(args) <= 1:
            # Single input: the abstract output gets an extra leading axis of size 1.
            only = args[0]
            return (jax.abstract_arrays.ShapedArray((1, *only.shape),
                                                    only.dtype), )
        return tuple(
            jax.abstract_arrays.ShapedArray(arg.shape, arg.dtype)
            for arg in args)

    def cotangent_batch(vector_arg_values, batch_axes):
        """Batching rule: evaluate ``vjp_fun1`` once per row of the batched ``g``.

        Only batching along axis 0 is supported. Per-row results are
        regrouped per input and stacked with ``np.vstack``.
        """
        assert batch_axes[0] == 0  # only a leading batch axis is supported
        g_batch = vector_arg_values[0]
        rows = [vjp_fun1(g_batch[row]) for row in range(g_batch.shape[0])]
        # Regroup: one column per input, each column holding every row's result.
        columns = itertools.zip_longest(*rows)
        stacked = tuple(np.vstack(column) for column in columns)
        return stacked, (batch_axes[0], ) * len(args)

    cotangent_p = Primitive("vjp_fun1")
    cotangent_p.multiple_results = True
    cotangent_p.def_impl(cotangent_impl)
    cotangent_p.def_abstract_eval(cotangent_abstract_eval)
    jax.interpreters.batching.primitive_batchers[cotangent_p] = cotangent_batch

    return numpy_output, vjp_fun1
Exemplo n.º 21
0
    elif x.shape[i_old] > 1 and new_sizes[i_new] <= 1:
      i_new += 1
    else:
      assert x.shape[i_old] == new_sizes[i_new]
      i_old += 1
      i_new += 1

  return out_indices,
# Register reshape2ind as op2ind's handler for lax.reshape_p.
op2ind[lax.reshape_p] = reshape2ind


# Custom Primitive


def _tex_var_identity(x, *depends_on, **kwargs):
  """Return ``x`` unchanged; ``depends_on`` operands and params are ignored."""
  return x


tex_var_p = Primitive('tex_var')
tex_var_p.def_impl(_tex_var_identity)
tex_var_p.def_abstract_eval(_tex_var_identity)


def tex_var_jvp(primals, tangents, name, is_alias):
  """JVP rule: label the primal with ``name`` and the tangent with ``'d' + name``.

  Both outputs keep the primal ``depends_on`` operands (``primals[1:]``). The
  tangent is never marked as an alias.
  """
  x, dx = primals[0], tangents[0]
  return (tex_var(x, name, is_alias, primals[1:]),
          tex_var(dx, 'd' + name, False, primals[1:]))
ad.primitive_jvps[tex_var_p] = tex_var_jvp


def tex_var_transpose(ct, x, *depends_on, **params):
  """Transpose rule: relabel the cotangent as ``\\delta`` + name without its first char.

  ``depends_on`` operands receive ``ad.Zero`` cotangents.
  """
  del x  # the primal input is not needed
  label = '\\delta ' + params['name'][1:]
  transposed = tex_var(ct, label, False, depends_on)
  return (transposed,) + (ad.Zero,) * len(depends_on)
ad.primitive_transposes[tex_var_p] = tex_var_transpose
Exemplo n.º 22
0
    return xla_client.ops.CustomCall(
        c,
        b"mpi_reduce",
        operands=(
            x,
            token,
        ),
        shape=sh,
        opaque=descriptor,
        has_side_effect=True,
    )


# Abstract evaluation: computes output shapes/dtypes only (used while tracing).
def mpi_reduce_abstract_eval(xs, token, op, root, comm):
    """Return ``(ShapedArray like xs, abstract_token)`` for ``mpi_reduce_p``.

    The output shape equals ``xs.shape`` on every rank.
    """
    # Take the token from jax.core, where it is defined, matching the other MPI
    # abstract-eval rules (scatter/gather/scan/alltoall).
    from jax import core

    return (
        abstract_arrays.ShapedArray(xs.shape, xs.dtype),
        core.abstract_token,
    )


mpi_reduce_p.multiple_results = True  # outputs: (array, token)
mpi_reduce_p.def_impl(mpi_reduce_impl)
mpi_reduce_p.def_abstract_eval(mpi_reduce_abstract_eval)

# Register the platform-specific XLA encoders.
xla.backend_specific_translations["cpu"].update(
    {mpi_reduce_p: mpi_reduce_xla_encode_cpu}
)
xla.backend_specific_translations["gpu"].update(
    {mpi_reduce_p: mpi_reduce_xla_encode_gpu}
)
Exemplo n.º 23
0
        == cosu.dtype == dt_tracer.dtype:
        raise ValueError('all inputs must have the same dtype')

    return superbee_p.bind(var, u_wgrid, v_wgrid, w_wgrid, maskW, dxt, dyt,
                           dzw, cost, cosu, dt_tracer)


def superbee_impl(*args, **kwargs):
    """Eager evaluation of ``superbee_p`` via ``xla.apply_primitive``."""
    result = xla.apply_primitive(superbee_p, *args, **kwargs)
    return result


def _superbee_gpu_translation_rule(computation_builder, var, u_wgrid, v_wgrid,
                                   w_wgrid, maskW, dxt, dyt, dzw, cost, cosu,
                                   dt_tracer):
    """GPU translation for ``superbee_p``: forward all operands to ``_superbee``."""
    operands = (var, u_wgrid, v_wgrid, w_wgrid, maskW, dxt, dyt, dzw, cost,
                cosu, dt_tracer)
    return _superbee(computation_builder, *operands)


def superbee_abstract_eval(var, u_wgrid, v_wgrid, w_wgrid, maskW, dxt, dyt,
                           dzw, cost, cosu, dt_tracer):
    """Shape rule: three outputs, each with ``var``'s shape and dtype.

    Only ``var`` is inspected; the other operands do not affect the result.
    """
    out = abstract_arrays.ShapedArray(var.shape, var.dtype)
    return out, out, out


superbee_p = Primitive('superbee')
superbee_p.multiple_results = True  # three outputs (see superbee_abstract_eval)
superbee_p.def_impl(superbee_impl)
superbee_p.def_abstract_eval(superbee_abstract_eval)
# Only a GPU translation is registered for this primitive.
xla.backend_specific_translations['gpu'].update(
    {superbee_p: _superbee_gpu_translation_rule})
Exemplo n.º 24
0
                 compute_right_eigenvectors):
  if compute_left_eigenvectors or compute_right_eigenvectors:
    raise NotImplementedError(
        'The derivatives of eigenvectors are not implemented, only '
        'eigenvalues. See '
        'https://github.com/google/jax/issues/2748 for discussion.')
  # Formula for derivative of eigenvalues w.r.t. a is eqn 4.60 in
  # https://arxiv.org/abs/1701.00392
  a, = primals
  da, = tangents
  l, v = eig(a, compute_left_eigenvectors=False)
  return [l], [jnp.sum(_solve(v, da.astype(v.dtype)) * _T(v), -1)]

eig_p = Primitive('eig')
eig_p.multiple_results = True
eig_p.def_impl(eig_impl)
eig_p.def_abstract_eval(eig_abstract_eval)
# Lowering: a generic translation plus a CPU-specific one.
xla.translations[eig_p] = eig_translation_rule
xla.backend_specific_translations['cpu'].update({eig_p: eig_cpu_translation_rule})
# Transformation rules.
ad.primitive_jvps[eig_p] = eig_jvp_rule
batching.primitive_batchers[eig_p] = eig_batching_rule


# Symmetric/Hermitian eigendecomposition

def eigh_impl(operand, lower):
  """Eager evaluation of ``eigh_p``; returns exactly its two outputs as a tuple."""
  outputs = xla.apply_primitive(eigh_p, operand, lower=lower)
  out_v, out_w = outputs
  return out_v, out_w

def eigh_translation_rule(c, operand, lower):
  shape = c.get_shape(operand)
Exemplo n.º 25
0
def qr_translation_rule(c, operand, full_matrices):
    """Lower ``qr_p`` by emitting the builder's ``QR`` operation.

    ``full_matrices`` is passed through by keyword.
    """
    builder = c
    lowered = builder.QR(operand, full_matrices=full_matrices)
    return lowered


def qr_abstract_eval(operand, full_matrices):
    """Abstract evaluation rule for ``qr_p``.

    For a ``ShapedArray`` operand of shape ``(..., m, n)`` the results are
    ``q`` with shape ``(..., m, k)`` and ``r`` with shape ``(..., k, n)``.
    Here ``k = m`` when ``full_matrices`` is true and ``k = min(m, n)``
    otherwise. Any other abstract value is returned unchanged as both
    outputs.

    Raises:
        ValueError: if a ``ShapedArray`` operand has fewer than two dimensions.
    """
    if not isinstance(operand, ShapedArray):
        return core.AbstractTuple((operand, operand))

    if operand.ndim < 2:
        raise ValueError(
            "Argument to QR decomposition must have ndims >= 2")

    shape = operand.shape
    batch_dims = shape[:-2]
    m, n = shape[-2:]
    if full_matrices:
        k = m
    else:
        k = min(m, n)

    dtype = operand.dtype
    q_aval = ShapedArray(batch_dims + (m, k), dtype)
    r_aval = ShapedArray(batch_dims + (k, n), dtype)
    return core.AbstractTuple((q_aval, r_aval))


def qr_dtype_rule(operand, full_matrices=True):
    """Return the dtype of the QR outputs, which is the operand's dtype.

    ``full_matrices`` is accepted for signature compatibility but does not
    affect the dtype.
    """
    del full_matrices  # irrelevant to the output dtype
    result_dtype = operand.dtype
    return result_dtype


# Register the QR primitive: eager evaluation via ``qr_impl`` (defined
# elsewhere), shape inference via ``qr_abstract_eval`` and a generic XLA
# lowering via ``qr_translation_rule``.
# NOTE(review): ``qr_dtype_rule`` is not referenced in these lines; confirm
# whether it is registered elsewhere or is dead code.
qr_p = Primitive('qr')
qr_p.def_impl(qr_impl)
qr_p.def_abstract_eval(qr_abstract_eval)
xla.translations[qr_p] = qr_translation_rule
Exemplo n.º 26
0
        c,
        b"mpi_sendrecv",
        operands=(
            sendbuf,
            token,
        ),
        shape=sh,
        opaque=descriptor,
        has_side_effect=True,
    )


# Shape-only evaluation used while tracing; no communication happens here.
def mpi_sendrecv_abstract_eval(sendbuf, recvbuf, token, source, dest, sendtag,
                               recvtag, comm, status):
    """Abstract evaluation rule for ``mpi_sendrecv_p``.

    The received data has the shape and dtype of ``recvbuf``. It is returned
    together with the abstract token value. All other arguments are ignored.
    """
    received = abstract_arrays.ShapedArray(recvbuf.shape, recvbuf.dtype)
    return received, abstract_arrays.abstract_token


# ``mpi_sendrecv_p`` returns multiple outputs: the received buffer and a token.
mpi_sendrecv_p.multiple_results = True
mpi_sendrecv_p.def_impl(mpi_sendrecv_impl)
mpi_sendrecv_p.def_abstract_eval(mpi_sendrecv_abstract_eval)

# Register the backend-specific XLA encoders for CPU and GPU.
xla.backend_specific_translations["cpu"][
    mpi_sendrecv_p] = mpi_sendrecv_xla_encode_cpu
xla.backend_specific_translations["gpu"][
    mpi_sendrecv_p] = mpi_sendrecv_xla_encode_gpu
Exemplo n.º 27
0
        abstract_arrays.abstract_token,
    )


def mpi_allreduce_value_and_jvp(in_args, tan_args, op, comm):
    """JVP rule for ``mpi_allreduce_p``.

    Args:
        in_args: primal inputs ``(x, token)``.
        tan_args: tangents ``(x_tan, token_tan)`` matching ``in_args``.
        op: the MPI reduction operation; only ``_MPI.SUM`` is supported.
        comm: the MPI communicator forwarded to ``Allreduce``.

    Returns:
        ``(primal_out, tangent_out)``. ``primal_out`` is the result of
        ``Allreduce`` applied to ``x``; ``tangent_out`` is
        ``(x_tan, token_tan)``.

    Raises:
        NotImplementedError: if ``op`` is not ``_MPI.SUM``. The check runs
            before any communication, so no collective is started only to be
            abandoned.
    """
    x, token = in_args
    x_tan, token_tan = tan_args

    # Fail fast: validate the reduction before issuing the collective, so an
    # unsupported op never triggers an Allreduce whose result is discarded.
    if op != _MPI.SUM:
        raise NotImplementedError(
            "The adjoint of allreduce for {} operation is not defined".format(
                op))

    res = Allreduce(x, token=token, op=op, comm=comm)

    # NOTE(review): the local tangent is returned unchanged. The forward
    # derivative of a cross-rank sum is the cross-rank sum of the tangents,
    # so this may need an Allreduce over ``x_tan``. Confirm the intended
    # differentiation convention before changing it.
    jvp = (x_tan, token_tan)

    return (res, jvp)


# ``mpi_allreduce_p`` returns multiple outputs (presumably the reduced value
# and a token; see its abstract-eval rule).
mpi_allreduce_p.multiple_results = True
mpi_allreduce_p.def_impl(mpi_allreduce_impl)
mpi_allreduce_p.def_abstract_eval(mpi_allreduce_abstract_eval)

# Forward-mode rule: computes the primal output and its tangent together.
ad.primitive_jvps[mpi_allreduce_p] = mpi_allreduce_value_and_jvp

# Register the XLA encoder; only a CPU encoder is registered in this block.
xla.backend_specific_translations["cpu"][
    mpi_allreduce_p] = mpi_allreduce_xla_encode
Exemplo n.º 28
0
    assert len(params) == 0

    return ShapedArray((2,), 'uint32')


def _random_key_impl(*args, **params):
    """Eager implementation of ``random_key_p``; it always fails.

    The primitive takes no operands and no parameters (checked by assertion).
    A concrete key only exists when a key is supplied to ``apply``, so
    evaluating the primitive directly raises ``ValueError``.
    """
    assert not args
    assert not params

    message = ("This parametrized function is randomized and therefore requires "
               "a random key when applied, i. e. `apply(*inputs, key=PRNGKey(0))`.")
    raise ValueError(message)


# Zero-operand primitive that yields a PRNG key inside a parametrized function.
# ``bind`` (defined elsewhere) supplies a custom bind implementation.
# NOTE(review): its exact semantics are not visible in this chunk.
random_key_p = Primitive("random_key")
random_key_p.def_custom_bind(bind(random_key_p))
# Eager evaluation always raises: a key is only available when a key is
# supplied to ``apply``.
random_key_p.def_impl(_random_key_impl)
random_key_p.def_abstract_eval(_random_key_abstract_eval)


def random_key():
    """Return a unique random key when called inside a parametrized function.

    The key is derived from the ``key`` argument passed to ``apply`` or
    ``init_parameters``.
    """
    key = random_key_p.bind()
    return key


class parametrized(Primitive):
    """Represents a parametrized function, providing an
    `init_parameters` function for bundled initialization of all parameters,
    and an `apply` function to evaluate the function given these bundled parameters.

    For example, if a dense neural network layer is defined via:
Exemplo n.º 29
0
                              _nan_like(c, vl))
    vr = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), vr,
                              _nan_like(c, vr))
    return xops.Tuple(c, [w, vl, vr])


def eig_batching_rule(batched_args, batch_dims):
    """Batching rule for ``eig_p``: move the batch axis to the front.

    ``eig_p`` takes a single operand, so ``batched_args`` and ``batch_dims``
    each hold exactly one entry. The operand's batch axis is moved to
    position 0, and all three outputs are reported as batched along axis 0.
    """
    (operand,) = batched_args
    (axis,) = batch_dims
    leading = batching.moveaxis(operand, axis, 0)
    out_batch_dims = (0, 0, 0)
    return eig_p.bind(leading), out_batch_dims


# Register the general eigendecomposition primitive. It returns multiple
# outputs and gets eager, abstract-eval, XLA and batching rules.
eig_p = Primitive('eig')
eig_p.multiple_results = True
eig_p.def_impl(eig_impl)
eig_p.def_abstract_eval(eig_abstract_eval)
# A generic XLA lowering plus a separately registered CPU-specific lowering.
xla.translations[eig_p] = eig_translation_rule
xla.backend_specific_translations['cpu'][eig_p] = eig_cpu_translation_rule
batching.primitive_batchers[eig_p] = eig_batching_rule

# Symmetric/Hermitian eigendecomposition


def eigh_impl(operand, lower):
    """Evaluate ``eigh_p`` eagerly via XLA and return its two outputs as a tuple.

    ``lower`` is forwarded as a keyword parameter of the primitive.
    """
    outputs = xla.apply_primitive(eigh_p, operand, lower=lower)
    first, second = outputs  # the primitive must yield exactly two results
    return (first, second)


def eigh_translation_rule(c, operand, lower):
    shape = c.GetShape(operand)
Exemplo n.º 30
0
        _nitems,
        _constant_s32_scalar(c, source),
        _constant_s32_scalar(c, tag),
        _constant_u64_scalar(c, to_mpi_ptr(comm)),
        _constant_u64_scalar(c, _dtype_ptr),
        _constant_u64_scalar(c, _status),
        token,
    )

    out = _ops.CustomCall(
        c, b"mpi_recv", operands=operands, shape=sh, has_side_effect=True,
    )

    return out


# Shape-only evaluation used while tracing; no message is received here.
def mpi_recv_abstract_eval(xs, token, source, tag, comm, status):
    """Abstract evaluation rule for ``mpi_recv_p``.

    The received data takes the shape and dtype of ``xs``, and the second
    output is the abstract token value. ``token``, ``source``, ``tag``,
    ``comm`` and ``status`` do not affect the result.
    """
    received = abstract_arrays.ShapedArray(xs.shape, xs.dtype)
    return received, abstract_arrays.abstract_token


# ``mpi_recv_p`` returns multiple outputs: the received buffer and a token.
mpi_recv_p.multiple_results = True
mpi_recv_p.def_impl(mpi_recv_impl)
mpi_recv_p.def_abstract_eval(mpi_recv_abstract_eval)

# Register the XLA encoder; only a CPU encoder is registered in this block.
xla.backend_specific_translations["cpu"][mpi_recv_p] = mpi_recv_xla_encode