Example #1
File: solve.py  Project: Orcuslc/jax-fenics
    def decorator(fenics_function: Callable) -> Callable:
        @functools.wraps(fenics_function)
        def jax_solve_eval(*args):
            return jax_solve_eval_p.bind(*args)

        jax_solve_eval_p = Primitive("jax_solve_eval")
        jax_solve_eval_p.def_impl(
            lambda *args: solve_eval(fenics_function, fenics_templates, *args)[0]
        )

        jax_solve_eval_p.def_abstract_eval(
            lambda *args: jax.abstract_arrays.make_shaped_array(
                solve_eval(fenics_function, fenics_templates, *args)[0]
            )
        )

        def jax_solve_eval_batch(vector_arg_values, batch_axes):
            assert len(set(batch_axes)) == 1  # assert that all batch axes are same
            assert (
                batch_axes[0] == 0
            )  # assert that batch axis is zero, need to rewrite for a general case?
            res = list(map(jax_solve_eval, *vector_arg_values))
            res = np.asarray(res)
            return res, batch_axes[0]

        jax.batching.primitive_batchers[jax_solve_eval_p] = jax_solve_eval_batch

        # @trace("jvp_jax_solve_eval")
        def jvp_jax_solve_eval(ps, ts):
            return jvp_jax_solve_eval_p.bind(ps, ts)

        jvp_jax_solve_eval_p = Primitive("jvp_jax_solve_eval")
        jvp_jax_solve_eval_p.multiple_results = True
        jvp_jax_solve_eval_p.def_impl(
            lambda ps, ts: jvp_solve_eval(fenics_function, fenics_templates, ps, ts)
        )

        jax.interpreters.ad.primitive_jvps[jax_solve_eval_p] = jvp_jax_solve_eval

        # TODO: JAX Tracer goes inside fenics wrappers and zero array is returned
        # because fenics numpy conversion works only for concrete arrays
        vjp_jax_solve_eval_p = Primitive("vjp_jax_solve_eval")
        vjp_jax_solve_eval_p.def_impl(
            lambda ct, *args: vjp_solve_eval(fenics_function, fenics_templates, *args)[
                1
            ](ct)
        )

        jax.interpreters.ad.primitive_transposes[
            jax_solve_eval_p
        ] = vjp_jax_solve_eval_p

        return jax_solve_eval
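
Example #1 registers a custom JAX Primitive around a FEniCS solve: a concrete impl, an abstract eval for shapes, a batching rule for vmap, a JVP rule, and a transpose rule. Below is a minimal, self-contained sketch of that same registration pattern with a hypothetical elementwise primitive ("double") standing in for the solver; it assumes a JAX version contemporary with these excerpts, where jax.core.Primitive and the ad/batching interpreter registries are importable as shown.

import jax
import jax.numpy as jnp
from jax import core
from jax.interpreters import ad, batching

double_p = core.Primitive("double")  # hypothetical stand-in primitive

def double(x):
    return double_p.bind(x)

# Concrete evaluation, used when no tracer is involved.
double_p.def_impl(lambda x: 2.0 * x)

# Shape/dtype propagation, used while tracing.
double_p.def_abstract_eval(lambda x: core.ShapedArray(x.shape, x.dtype))

# Forward-mode rule: d(2x) = 2 dx.
ad.primitive_jvps[double_p] = lambda primals, tangents: (double(primals[0]),
                                                         2.0 * tangents[0])

# Batching rule: the op is elementwise, so the batch axis passes through.
batching.primitive_batchers[double_p] = lambda args, dims: (double(args[0]), dims[0])

x = jnp.arange(3.0)
print(jax.vmap(double)(x))                    # [0. 2. 4.]
print(jax.jvp(double, (x,), (jnp.ones(3),)))  # ([0. 2. 4.], [2. 2. 2.])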
Example #2
File: lax_linalg.py  Project: yotarok/jax
    vl = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), vl,
                              _nan_like(c, vl))
    vr = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), vr,
                              _nan_like(c, vr))
    return xops.Tuple(c, [w, vl, vr])


def eig_batching_rule(batched_args, batch_dims):
    x, = batched_args
    bd, = batch_dims
    x = batching.moveaxis(x, bd, 0)
    return eig_p.bind(x), (0, 0, 0)


eig_p = Primitive('eig')
eig_p.multiple_results = True
eig_p.def_impl(eig_impl)
eig_p.def_abstract_eval(eig_abstract_eval)
xla.translations[eig_p] = eig_translation_rule
xla.backend_specific_translations['cpu'][eig_p] = eig_cpu_translation_rule
batching.primitive_batchers[eig_p] = eig_batching_rule

# Symmetric/Hermitian eigendecomposition


def eigh_impl(operand, lower):
    v, w = xla.apply_primitive(eigh_p, operand, lower=lower)
    return v, w


def eigh_translation_rule(c, operand, lower):
Example #3
File: allreduce.py  Project: kiminh/mpi4jax
        abstract_arrays.abstract_token,
    )


def mpi_allreduce_value_and_jvp(in_args, tan_args, op, comm):
    x, token = in_args
    x_tan, token_tan = tan_args

    res = Allreduce(x, token=token, op=op, comm=comm)

    # Identify the correct adjoint
    if op == _MPI.SUM:
        jvp = (x_tan, token_tan)
    else:
        raise NotImplementedError(
            "The adjoint of allreduce for {} operation is not defined".format(
                op))

    return (res, jvp)


mpi_allreduce_p.multiple_results = True
mpi_allreduce_p.def_impl(mpi_allreduce_impl)
mpi_allreduce_p.def_abstract_eval(mpi_allreduce_abstract_eval)

ad.primitive_jvps[mpi_allreduce_p] = mpi_allreduce_value_and_jvp

# assign to the primitive the correct encoder
xla.backend_specific_translations["cpu"][
    mpi_allreduce_p] = mpi_allreduce_xla_encode
Example #4
        operands=(
            sendbuf,
            token,
        ),
        shape=sh,
        opaque=descriptor,
        has_side_effect=True,
    )


# This function evaluates only the shapes during AST construction
def mpi_allgather_abstract_eval(x, token, comm):
    comm = unpack_hashable(comm)
    size = comm.Get_size()
    out_shape = (size, *x.shape)
    return (
        abstract_arrays.ShapedArray(out_shape, x.dtype),
        abstract_arrays.abstract_token,
    )


mpi_allgather_p.multiple_results = True
mpi_allgather_p.def_impl(mpi_allgather_impl)
mpi_allgather_p.def_abstract_eval(mpi_allgather_abstract_eval)

# assign to the primitive the correct encoder
xla.backend_specific_translations["cpu"][
    mpi_allgather_p] = mpi_allgather_xla_encode_cpu
xla.backend_specific_translations["gpu"][
    mpi_allgather_p] = mpi_allgather_xla_encode_gpu
Example #5
        == cosu.dtype == dt_tracer.dtype:
        raise ValueError('all inputs must have the same dtype')

    return superbee_p.bind(var, u_wgrid, v_wgrid, w_wgrid, maskW, dxt, dyt,
                           dzw, cost, cosu, dt_tracer)


def superbee_impl(*args, **kwargs):
    return xla.apply_primitive(superbee_p, *args, **kwargs)


def _superbee_gpu_translation_rule(computation_builder, var, u_wgrid, v_wgrid,
                                   w_wgrid, maskW, dxt, dyt, dzw, cost, cosu,
                                   dt_tracer):
    return _superbee(computation_builder, var, u_wgrid, v_wgrid, w_wgrid,
                     maskW, dxt, dyt, dzw, cost, cosu, dt_tracer)


def superbee_abstract_eval(var, u_wgrid, v_wgrid, w_wgrid, maskW, dxt, dyt,
                           dzw, cost, cosu, dt_tracer):

    aarr = abstract_arrays.ShapedArray(var.shape, var.dtype)
    return (aarr, ) * 3


superbee_p = Primitive('superbee')
superbee_p.multiple_results = True
superbee_p.def_impl(superbee_impl)
superbee_p.def_abstract_eval(superbee_abstract_eval)
xla.backend_specific_translations['gpu'][
    superbee_p] = _superbee_gpu_translation_rule
Example #6
    return xla_client.ops.CustomCall(
        c,
        b"mpi_alltoall",
        operands=(
            x,
            token,
        ),
        shape=sh,
        opaque=descriptor,
        has_side_effect=True,
    )


# This function evaluates only the shapes during AST construction
def mpi_alltoall_abstract_eval(xs, token, comm):
    return (
        abstract_arrays.ShapedArray(xs.shape, xs.dtype),
        core.abstract_token,
    )


mpi_alltoall_p.multiple_results = True
mpi_alltoall_p.def_impl(mpi_alltoall_impl)
mpi_alltoall_p.def_abstract_eval(mpi_alltoall_abstract_eval)

# assign to the primitive the correct encoder
xla.backend_specific_translations["cpu"][
    mpi_alltoall_p] = mpi_alltoall_xla_encode_cpu
xla.backend_specific_translations["gpu"][
    mpi_alltoall_p] = mpi_alltoall_xla_encode_gpu
Example #7
File: recv.py  Project: dionhaefner/mpi4jax
        _nitems,
        _constant_s32_scalar(c, source),
        _constant_s32_scalar(c, tag),
        _constant_u64_scalar(c, to_mpi_ptr(comm)),
        _constant_u64_scalar(c, _dtype_ptr),
        _constant_u64_scalar(c, _status),
        token,
    )

    out = _ops.CustomCall(
        c, b"mpi_recv", operands=operands, shape=sh, has_side_effect=True,
    )

    return out


# This function evaluates only the shapes during AST construction
def mpi_recv_abstract_eval(xs, token, source, tag, comm, status):
    return (
        abstract_arrays.ShapedArray(xs.shape, xs.dtype),
        abstract_arrays.abstract_token,
    )


mpi_recv_p.multiple_results = True
mpi_recv_p.def_impl(mpi_recv_impl)
mpi_recv_p.def_abstract_eval(mpi_recv_abstract_eval)

# assign to the primitive the correct encoder
xla.backend_specific_translations["cpu"][mpi_recv_p] = mpi_recv_xla_encode
Example #8
        c,
        b"mpi_sendrecv",
        operands=(
            sendbuf,
            token,
        ),
        shape=sh,
        opaque=descriptor,
        has_side_effect=True,
    )


# This function evaluates only the shapes during AST construction
def mpi_sendrecv_abstract_eval(sendbuf, recvbuf, token, source, dest, sendtag,
                               recvtag, comm, status):
    return (
        abstract_arrays.ShapedArray(recvbuf.shape, recvbuf.dtype),
        abstract_arrays.abstract_token,
    )


mpi_sendrecv_p.multiple_results = True
mpi_sendrecv_p.def_impl(mpi_sendrecv_impl)
mpi_sendrecv_p.def_abstract_eval(mpi_sendrecv_abstract_eval)

# assign to the primitive the correct encoder
xla.backend_specific_translations["cpu"][
    mpi_sendrecv_p] = mpi_sendrecv_xla_encode_cpu
xla.backend_specific_translations["gpu"][
    mpi_sendrecv_p] = mpi_sendrecv_xla_encode_gpu
Example #9
File: linalg.py  Project: nbswords/jax
def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors,
                 compute_right_eigenvectors):
  if compute_left_eigenvectors or compute_right_eigenvectors:
    raise NotImplementedError(
        'The derivatives of eigenvectors are not implemented, only '
        'eigenvalues. See '
        'https://github.com/google/jax/issues/2748 for discussion.')
  # Formula for derivative of eigenvalues w.r.t. a is eqn 4.60 in
  # https://arxiv.org/abs/1701.00392
  a, = primals
  da, = tangents
  l, v = eig(a, compute_left_eigenvectors=False)
  return [l], [jnp.sum(_solve(v, da.astype(v.dtype)) * _T(v), -1)]
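
# In matrix form, the eigenvalue tangent computed above is the diagonal of
# V^{-1} dA V, i.e. dλ_i = (V^{-1} dA V)_{ii} (eqn 4.60 of
# https://arxiv.org/abs/1701.00392, assuming distinct eigenvalues); the
# row-wise sum of _solve(v, da) * _T(v) evaluates exactly that diagonal
# without forming the full matrix product.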

eig_p = Primitive('eig')
eig_p.multiple_results = True
eig_p.def_impl(eig_impl)
eig_p.def_abstract_eval(eig_abstract_eval)
xla.translations[eig_p] = eig_translation_rule
xla.backend_specific_translations['cpu'][eig_p] = eig_cpu_translation_rule
batching.primitive_batchers[eig_p] = eig_batching_rule
ad.primitive_jvps[eig_p] = eig_jvp_rule


# Symmetric/Hermitian eigendecomposition

def eigh_impl(operand, lower):
  v, w = xla.apply_primitive(eigh_p, operand, lower=lower)
  return v, w

def eigh_translation_rule(c, operand, lower):
Example #10
def vjp_solve_eval(fenics_function: Callable, fenics_templates: FenicsVariable,
                   *args: np.array) -> Tuple[np.array, Callable]:
    """Computes the gradients of the output with respect to the input
    Input:
        fenics_function (callable): FEniCS function to be executed during the forward pass
        args (tuple): jax array representation of the input to fenics_function
    Output:
        A pair where the first element is the value of fun applied to the arguments and the second element
        is a Python callable representing the VJP map from output cotangents to input cotangents.
        The returned VJP function must accept a value with the same shape as the value of fun applied
        to the arguments and must return a tuple with length equal to the number of positional arguments to fun.
    """

    numpy_output, fenics_solution, residual_form, fenics_inputs, bcs = solve_eval(
        fenics_function, fenics_templates, *args)

    # @trace("vjp_fun1")
    def vjp_fun1(g):
        return vjp_fun1_p.bind(g)

    vjp_fun1_p = Primitive("vjp_fun1")
    vjp_fun1_p.multiple_results = True
    vjp_fun1_p.def_impl(lambda g: tuple(
        vjp if vjp is not None else jax.ad_util.zeros_like_jaxval(args[i])
        for i, vjp in enumerate(
            vjp_solve_eval_impl(g, fenics_solution, residual_form,
                                fenics_inputs, bcs))))

    # @trace("vjp_fun1_abstract_eval")
    def vjp_fun1_abstract_eval(g):
        if len(args) > 1:
            return tuple(
                (jax.abstract_arrays.ShapedArray(arg.shape, arg.dtype)
                 for arg in args))
        else:
            return (jax.abstract_arrays.ShapedArray((1, *args[0].shape),
                                                    args[0].dtype), )

    vjp_fun1_p.def_abstract_eval(vjp_fun1_abstract_eval)

    # @trace("vjp_fun1_batch")
    def vjp_fun1_batch(vector_arg_values, batch_axes):
        """Computes the batched version of the primitive.

        This must be a JAX-traceable function.

        Args:
            vector_arg_values: a tuple of arguments, each being a tensor of matching
            shape.
            batch_axes: the axes that are being batched. See vmap documentation.
        Returns:
            a tuple of the result, and the result axis that was batched.
        """
        # _trace("Using vjp_fun1 to compute the batch:")
        assert (
            batch_axes[0] == 0
        )  # assert that batch axis is zero, need to rewrite for a general case?
        # compute function row-by-row
        res = [
            vjp_fun1(vector_arg_values[0][i])
            for i in range(vector_arg_values[0].shape[0])
        ]
        # transpose resulting list
        res_T = list(itertools.zip_longest(*res))
        return tuple(map(np.vstack, res_T)), (batch_axes[0], ) * len(args)

    jax.interpreters.batching.primitive_batchers[vjp_fun1_p] = vjp_fun1_batch

    return numpy_output, vjp_fun1
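
The docstring of vjp_solve_eval above mirrors the contract of jax.vjp: return the primal output plus a callable that maps output cotangents to input cotangents, one entry per positional argument. A minimal sketch of that contract on an ordinary function (no FEniCS involved; fun and the inputs are purely illustrative):

import jax
import jax.numpy as jnp

def fun(x):
    return jnp.sin(x) * x

primal_out, vjp_fun = jax.vjp(fun, jnp.arange(3.0))
cotangent = jnp.ones_like(primal_out)    # same shape as fun's output
(input_cotangent,) = vjp_fun(cotangent)  # tuple with one entry per positional argument
print(primal_out, input_cotangent)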
Example #11
def get_pullback_function(
    fenics_function: Callable, fenics_templates: Collection[BackendVariable]
) -> Callable:
    """Computes the gradients of the output with respect to the input
    Input:
        fenics_function (callable): FEniCS function to be executed during the forward pass
    Output:
        A Python callable representing the VJP map from output cotangents to input cotangents.
        The returned VJP function must accept a value with the same shape as the value of fun applied
        to the arguments and must return a tuple with length equal to the number of positional arguments to fun.
    """

    # @trace("vjp_fun1")
    def vjp_fun1(aux_args, g):
        return vjp_fun1_p.bind(aux_args, g)

    def vjp_fun1_p_impl(aux_args, g):
        fe_aux, args = aux_args
        fenics_output, fenics_inputs, tape = (
            fe_aux.fenics_output,
            fe_aux.fenics_inputs,
            fe_aux.tape,
        )
        return tuple(
            vjp if vjp is not None else jax.ad_util.zeros_like_jaxval(args[i])
            for i, vjp in enumerate(
                evaluate_pullback(fenics_output, fenics_inputs, tape, g)
            )
        )

    vjp_fun1_p = Primitive("vjp_fun1")
    vjp_fun1_p.multiple_results = True
    vjp_fun1_p.def_impl(vjp_fun1_p_impl)

    # @trace("vjp_fun1_abstract_eval")
    def vjp_fun1_abstract_eval(aux_args, g):
        _, args = aux_args
        if len(args) > 1:
            return tuple(
                (jax.abstract_arrays.ShapedArray(arg.shape, arg.dtype) for arg in args)
            )
        else:
            return (
                jax.abstract_arrays.ShapedArray((1, *args[0].shape), args[0].dtype),
            )

    vjp_fun1_p.def_abstract_eval(vjp_fun1_abstract_eval)

    # @trace("vjp_fun1_batch")
    def vjp_fun1_batch(vector_arg_values, batch_axes):
        """Computes the batched version of the primitive.

        This must be a JAX-traceable function.

        Args:
            vector_arg_values: a tuple of arguments, each being a tensor of matching
            shape.
            batch_axes: the axes that are being batched. See vmap documentation.
        Returns:
            a tuple of the result, and the result axis that was batched.
        """
        # _trace("Using vjp_fun1 to compute the batch:")
        _, args = vector_arg_values[0]
        assert (
            batch_axes[0] is None
        )  # assert that batch axis is None, need to rewrite for a general case?
        assert (
            batch_axes[1] == 0
        )  # assert that batch axis is zero, need to rewrite for a general case?
        # apply function row-by-row
        vjp_fun1_partial = functools.partial(vjp_fun1, vector_arg_values[0])
        res = list(map(vjp_fun1_partial, *(vector_arg_values[1],)))
        # transpose resulting list
        res_T = list(itertools.zip_longest(*res))
        return tuple(map(np.vstack, res_T)), (batch_axes[1],) * len(args)

    jax.interpreters.batching.primitive_batchers[vjp_fun1_p] = vjp_fun1_batch

    return vjp_fun1
Example #12
    def decorator(fenics_function: Callable) -> Callable:
        @functools.wraps(fenics_function)
        def jax_fem_eval(*args):
            return jax_fem_eval_p.bind(*args)

        jax_fem_eval_p = Primitive("jax_fem_eval")

        def jax_fem_eval_p_impl(*args):
            args = (
                jax_to_fenics_numpy(arg, ft) for arg, ft in zip(args, fenics_templates)
            )
            return evaluate_primal(fenics_function, fenics_templates, *args)[0]

        jax_fem_eval_p.def_impl(jax_fem_eval_p_impl)

        def jax_fem_eval_p_abstract_eval(*args):
            args = (
                jax_to_fenics_numpy(arg, ft) for arg, ft in zip(args, fenics_templates)
            )
            return jax.abstract_arrays.make_shaped_array(
                evaluate_primal(fenics_function, fenics_templates, *args)[0]
            )

        jax_fem_eval_p.def_abstract_eval(jax_fem_eval_p_abstract_eval)

        def jax_fem_eval_batch(vector_arg_values, batch_axes):
            assert len(set(batch_axes)) == 1  # assert that all batch axes are same
            assert (
                batch_axes[0] == 0
            )  # assert that batch axis is zero, need to rewrite for a general case?
            res = list(map(jax_fem_eval, *vector_arg_values))
            res = np.asarray(res)
            return res, batch_axes[0]

        jax.interpreters.batching.primitive_batchers[
            jax_fem_eval_p
        ] = jax_fem_eval_batch

        # @trace("jvp_jax_fem_eval")
        def jvp_jax_fem_eval(ps, ts):
            return jvp_jax_fem_eval_p.bind(ps, ts)

        jvp_jax_fem_eval_p = Primitive("jvp_jax_fem_eval")
        jvp_jax_fem_eval_p.multiple_results = True

        def jvp_jax_fem_eval_impl(primals, tangents):
            primals = (
                jax_to_fenics_numpy(p, ft) for p, ft in zip(primals, fenics_templates)
            )
            numpy_output, fenics_output, fenics_inputs, tape = evaluate_primal(
                fenics_function, fenics_templates, *primals
            )

            tangents = (
                jax_to_fenics_numpy(t, ft) for t, ft in zip(tangents, fenics_templates)
            )
            dnumpy_output = evaluate_pushforward(
                fenics_output, fenics_inputs, tape, tangents
            )
            return numpy_output, dnumpy_output

        jvp_jax_fem_eval_p.def_impl(jvp_jax_fem_eval_impl)

        jax.interpreters.ad.primitive_jvps[jax_fem_eval_p] = jvp_jax_fem_eval

        # TODO: JAX Tracer goes inside fenics wrappers and zero array is returned
        # because fenics numpy conversion works only for concrete arrays
        # vjp_jax_fem_eval_p = Primitive("vjp_jax_fem_eval")
        # vjp_jax_fem_eval_p.def_impl(
        #     lambda ct, *args: vjp_fem_eval(fenics_function, fenics_templates, *args)[1](
        #         ct
        #     )
        # )

        # jax.interpreters.ad.primitive_transposes[jax_fem_eval_p] = vjp_jax_fem_eval_p

        return jax_fem_eval
Example #13
    )

    return xla_client.ops.CustomCall(
        c,
        b"mpi_scan",
        operands=(
            x,
            token,
        ),
        shape=sh,
        opaque=descriptor,
        has_side_effect=True,
    )


# This function evaluates only the shapes during AST construction
def mpi_scan_abstract_eval(xs, token, op, comm):
    return (
        abstract_arrays.ShapedArray(xs.shape, xs.dtype),
        core.abstract_token,
    )


mpi_scan_p.multiple_results = True
mpi_scan_p.def_impl(mpi_scan_impl)
mpi_scan_p.def_abstract_eval(mpi_scan_abstract_eval)

# assign to the primitive the correct encoder
xla.backend_specific_translations["cpu"][mpi_scan_p] = mpi_scan_xla_encode_cpu
xla.backend_specific_translations["gpu"][mpi_scan_p] = mpi_scan_xla_encode_gpu
Example #14
File: scatter.py  Project: mpi4jax/mpi4jax
        shape=sh,
        opaque=descriptor,
        has_side_effect=True,
    )


# This function evaluates only the shapes during AST construction
def mpi_scatter_abstract_eval(x, token, root, comm):
    comm = unpack_hashable(comm)
    rank = comm.Get_rank()
    if rank == root:
        out_shape = x.shape[1:]
    else:
        out_shape = x.shape

    return (
        abstract_arrays.ShapedArray(out_shape, x.dtype),
        core.abstract_token,
    )


mpi_scatter_p.multiple_results = True
mpi_scatter_p.def_impl(mpi_scatter_impl)
mpi_scatter_p.def_abstract_eval(mpi_scatter_abstract_eval)

# assign to the primitive the correct encoder
xla.backend_specific_translations["cpu"][
    mpi_scatter_p] = mpi_scatter_xla_encode_cpu
xla.backend_specific_translations["gpu"][
    mpi_scatter_p] = mpi_scatter_xla_encode_gpu
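
mpi_scatter_abstract_eval only describes the output shape and dtype: at the root the leading axis (one slice per rank) is dropped, elsewhere the shape is unchanged. The same shape-only evaluation can be observed with jax.eval_shape on an ordinary stand-in function (fake_scatter below is hypothetical, not part of mpi4jax):

import jax
import jax.numpy as jnp

def fake_scatter(x):
    # Shape-only stand-in: drop the leading "number of ranks" axis, as at the root.
    return x[0]

out = jax.eval_shape(fake_scatter, jnp.zeros((4, 8, 8)))
print(out.shape, out.dtype)  # (8, 8) float32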
Example #15
File: bcast.py  Project: yzhwang/mpi4jax
    return xla_client.ops.CustomCall(
        c,
        b"mpi_bcast",
        operands=(
            x,
            token,
        ),
        shape=sh,
        opaque=descriptor,
        has_side_effect=True,
    )


# This function evaluates only the shapes during AST construction
def mpi_bcast_abstract_eval(xs, token, root, comm):
    return (
        abstract_arrays.ShapedArray(xs.shape, xs.dtype),
        abstract_arrays.abstract_token,
    )


mpi_bcast_p.multiple_results = True
mpi_bcast_p.def_impl(mpi_bcast_impl)
mpi_bcast_p.def_abstract_eval(mpi_bcast_abstract_eval)

# assign to the primitive the correct encoder
xla.backend_specific_translations["cpu"][
    mpi_bcast_p] = mpi_bcast_xla_encode_cpu
xla.backend_specific_translations["gpu"][
    mpi_bcast_p] = mpi_bcast_xla_encode_gpu