def decorator(fenics_function: Callable) -> Callable:
    """Wrap ``fenics_function`` as the JAX primitive ``jax_solve_eval``.

    A fresh primitive is created and given an evaluation rule, an
    abstract-evaluation rule, a batching rule and a JVP rule, all delegating
    to ``solve_eval`` / ``jvp_solve_eval``.  ``fenics_templates`` is not a
    parameter of this function; it is a free variable read from the
    enclosing scope (presumably the factory that returns this decorator --
    confirm against the caller).

    Args:
        fenics_function: FEniCS function executed during the forward pass.

    Returns:
        A callable carrying ``fenics_function``'s metadata that binds the
        primitive to its positional arguments.
    """

    @functools.wraps(fenics_function)
    def jax_solve_eval(*args):
        return jax_solve_eval_p.bind(*args)

    jax_solve_eval_p = Primitive("jax_solve_eval")
    # Concrete evaluation: the primitive's value is the first element
    # returned by solve_eval.
    jax_solve_eval_p.def_impl(
        lambda *args: solve_eval(fenics_function, fenics_templates, *args)[0]
    )
    # Shape/dtype inference runs solve_eval and inspects its output, so each
    # abstract evaluation costs one full forward solve.
    jax_solve_eval_p.def_abstract_eval(
        lambda *args: jax.abstract_arrays.make_shaped_array(
            solve_eval(fenics_function, fenics_templates, *args)[0]
        )
    )

    def jax_solve_eval_batch(vector_arg_values, batch_axes):
        """Batching rule: evaluate the primitive once per batch element.

        Each argument may be batched along any axis or be unbatched (batch
        axis ``None``); unbatched arguments are reused for every element.
        The per-element results are stacked with ``np.asarray`` and reported
        as batched along axis 0.
        """
        # Length of the mapped dimension, read from the first batched argument.
        size = next(
            value.shape[axis]
            for value, axis in zip(vector_arg_values, batch_axes)
            if axis is not None
        )
        per_arg_rows = []
        for value, axis in zip(vector_arg_values, batch_axes):
            if axis is None:
                # Unbatched argument: pass the same value to every element.
                per_arg_rows.append([value] * size)
            elif axis == 0:
                # Leading-axis batch: iterate the rows directly.
                per_arg_rows.append(value)
            else:
                # i-th slice along ``axis``; the remaining axes keep their order.
                # Built eagerly (list, not generator) so ``value``/``axis`` are
                # bound for this iteration.
                per_arg_rows.append(
                    [value[(slice(None),) * axis + (i,)] for i in range(size)]
                )
        res = np.asarray(list(map(jax_solve_eval, *per_arg_rows)))
        return res, 0

    jax.interpreters.batching.primitive_batchers[jax_solve_eval_p] = jax_solve_eval_batch

    # @trace("jvp_jax_solve_eval")
    def jvp_jax_solve_eval(ps, ts):
        """JVP rule: bind the primals and tangents to the JVP primitive."""
        return jvp_jax_solve_eval_p.bind(ps, ts)

    jvp_jax_solve_eval_p = Primitive("jvp_jax_solve_eval")
    jvp_jax_solve_eval_p.multiple_results = True
    # NOTE(review): no abstract-evaluation rule is registered for this
    # primitive, so it presumably only works on concrete values -- confirm.
    jvp_jax_solve_eval_p.def_impl(
        lambda ps, ts: jvp_solve_eval(fenics_function, fenics_templates, ps, ts)
    )
    jax.interpreters.ad.primitive_jvps[jax_solve_eval_p] = jvp_jax_solve_eval

    # TODO: JAX Tracer goes inside fenics wrappers and zero array is returned
    # because fenics numpy conversion works only for concrete arrays.
    vjp_jax_solve_eval_p = Primitive("vjp_jax_solve_eval")
    vjp_jax_solve_eval_p.def_impl(
        lambda ct, *args: vjp_solve_eval(fenics_function, fenics_templates, *args)[1](ct)
    )
    # NOTE(review): the entry registered here is the Primitive object itself,
    # not a function that binds it; JAX presumably calls transpose rules as
    # functions. Kept unchanged to preserve existing behaviour -- confirm
    # whether this rule is ever reached.
    jax.interpreters.ad.primitive_transposes[jax_solve_eval_p] = vjp_jax_solve_eval_p

    return jax_solve_eval
vl = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), vl, _nan_like(c, vl)) vr = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), vr, _nan_like(c, vr)) return xops.Tuple(c, [w, vl, vr]) def eig_batching_rule(batched_args, batch_dims): x, = batched_args bd, = batch_dims x = batching.moveaxis(x, bd, 0) return eig_p.bind(x), (0, 0, 0) eig_p = Primitive('eig') eig_p.multiple_results = True eig_p.def_impl(eig_impl) eig_p.def_abstract_eval(eig_abstract_eval) xla.translations[eig_p] = eig_translation_rule xla.backend_specific_translations['cpu'][eig_p] = eig_cpu_translation_rule batching.primitive_batchers[eig_p] = eig_batching_rule # Symmetric/Hermitian eigendecomposition def eigh_impl(operand, lower): v, w = xla.apply_primitive(eigh_p, operand, lower=lower) return v, w def eigh_translation_rule(c, operand, lower):
abstract_arrays.abstract_token, ) def mpi_allreduce_value_and_jvp(in_args, tan_args, op, comm): x, token = in_args x_tan, token_tan = tan_args res = Allreduce(x, token=token, op=op, comm=comm) # Identify the correct adjoint if op == _MPI.SUM: jvp = (x_tan, token_tan) else: raise NotImplementedError( "The adjoint of allreduce for {} operation is not defined".format( op)) return (res, jvp) mpi_allreduce_p.multiple_results = True mpi_allreduce_p.def_impl(mpi_allreduce_impl) mpi_allreduce_p.def_abstract_eval(mpi_allreduce_abstract_eval) ad.primitive_jvps[mpi_allreduce_p] = mpi_allreduce_value_and_jvp # assign to the primitive the correct encoder xla.backend_specific_translations["cpu"][ mpi_allreduce_p] = mpi_allreduce_xla_encode
operands=( sendbuf, token, ), shape=sh, opaque=descriptor, has_side_effect=True, ) # This function evaluates only the shapes during AST construction def mpi_allgather_abstract_eval(x, token, comm): comm = unpack_hashable(comm) size = comm.Get_size() out_shape = (size, *x.shape) return ( abstract_arrays.ShapedArray(out_shape, x.dtype), abstract_arrays.abstract_token, ) mpi_allgather_p.multiple_results = True mpi_allgather_p.def_impl(mpi_allgather_impl) mpi_allgather_p.def_abstract_eval(mpi_allgather_abstract_eval) # assign to the primitive the correct encoder xla.backend_specific_translations["cpu"][ mpi_allgather_p] = mpi_allgather_xla_encode_cpu xla.backend_specific_translations["gpu"][ mpi_allgather_p] = mpi_allgather_xla_encode_gpu
== cosu.dtype == dt_tracer.dtype: raise ValueError('all inputs must have the same dtype') return superbee_p.bind(var, u_wgrid, v_wgrid, w_wgrid, maskW, dxt, dyt, dzw, cost, cosu, dt_tracer) def superbee_impl(*args, **kwargs): return xla.apply_primitive(superbee_p, *args, **kwargs) def _superbee_gpu_translation_rule(computation_builder, var, u_wgrid, v_wgrid, w_wgrid, maskW, dxt, dyt, dzw, cost, cosu, dt_tracer): return _superbee(computation_builder, var, u_wgrid, v_wgrid, w_wgrid, maskW, dxt, dyt, dzw, cost, cosu, dt_tracer) def superbee_abstract_eval(var, u_wgrid, v_wgrid, w_wgrid, maskW, dxt, dyt, dzw, cost, cosu, dt_tracer): aarr = abstract_arrays.ShapedArray(var.shape, var.dtype) return (aarr, ) * 3 superbee_p = Primitive('superbee') superbee_p.multiple_results = True superbee_p.def_impl(superbee_impl) superbee_p.def_abstract_eval(superbee_abstract_eval) xla.backend_specific_translations['gpu'][ superbee_p] = _superbee_gpu_translation_rule
return xla_client.ops.CustomCall( c, b"mpi_alltoall", operands=( x, token, ), shape=sh, opaque=descriptor, has_side_effect=True, ) # This function evaluates only the shapes during AST construction def mpi_alltoall_abstract_eval(xs, token, comm): return ( abstract_arrays.ShapedArray(xs.shape, xs.dtype), core.abstract_token, ) mpi_alltoall_p.multiple_results = True mpi_alltoall_p.def_impl(mpi_alltoall_impl) mpi_alltoall_p.def_abstract_eval(mpi_alltoall_abstract_eval) # assign to the primitive the correct encoder xla.backend_specific_translations["cpu"][ mpi_alltoall_p] = mpi_alltoall_xla_encode_cpu xla.backend_specific_translations["gpu"][ mpi_alltoall_p] = mpi_alltoall_xla_encode_gpu
_nitems, _constant_s32_scalar(c, source), _constant_s32_scalar(c, tag), _constant_u64_scalar(c, to_mpi_ptr(comm)), _constant_u64_scalar(c, _dtype_ptr), _constant_u64_scalar(c, _status), token, ) out = _ops.CustomCall( c, b"mpi_recv", operands=operands, shape=sh, has_side_effect=True, ) return out # This function evaluates only the shapes during AST construction def mpi_recv_abstract_eval(xs, token, source, tag, comm, status): return ( abstract_arrays.ShapedArray(xs.shape, xs.dtype), abstract_arrays.abstract_token, ) mpi_recv_p.multiple_results = True mpi_recv_p.def_impl(mpi_recv_impl) mpi_recv_p.def_abstract_eval(mpi_recv_abstract_eval) # assign to the primitive the correct encoder xla.backend_specific_translations["cpu"][mpi_recv_p] = mpi_recv_xla_encode
c, b"mpi_sendrecv", operands=( sendbuf, token, ), shape=sh, opaque=descriptor, has_side_effect=True, ) # This function evaluates only the shapes during AST construction def mpi_sendrecv_abstract_eval(sendbuf, recvbuf, token, source, dest, sendtag, recvtag, comm, status): return ( abstract_arrays.ShapedArray(recvbuf.shape, recvbuf.dtype), abstract_arrays.abstract_token, ) mpi_sendrecv_p.multiple_results = True mpi_sendrecv_p.def_impl(mpi_sendrecv_impl) mpi_sendrecv_p.def_abstract_eval(mpi_sendrecv_abstract_eval) # assign to the primitive the correct encoder xla.backend_specific_translations["cpu"][ mpi_sendrecv_p] = mpi_sendrecv_xla_encode_cpu xla.backend_specific_translations["gpu"][ mpi_sendrecv_p] = mpi_sendrecv_xla_encode_gpu
def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors, compute_right_eigenvectors): if compute_left_eigenvectors or compute_right_eigenvectors: raise NotImplementedError( 'The derivatives of eigenvectors are not implemented, only ' 'eigenvalues. See ' 'https://github.com/google/jax/issues/2748 for discussion.') # Formula for derivative of eigenvalues w.r.t. a is eqn 4.60 in # https://arxiv.org/abs/1701.00392 a, = primals da, = tangents l, v = eig(a, compute_left_eigenvectors=False) return [l], [jnp.sum(_solve(v, da.astype(v.dtype)) * _T(v), -1)] eig_p = Primitive('eig') eig_p.multiple_results = True eig_p.def_impl(eig_impl) eig_p.def_abstract_eval(eig_abstract_eval) xla.translations[eig_p] = eig_translation_rule xla.backend_specific_translations['cpu'][eig_p] = eig_cpu_translation_rule batching.primitive_batchers[eig_p] = eig_batching_rule ad.primitive_jvps[eig_p] = eig_jvp_rule # Symmetric/Hermitian eigendecomposition def eigh_impl(operand, lower): v, w = xla.apply_primitive(eigh_p, operand, lower=lower) return v, w def eigh_translation_rule(c, operand, lower):
def vjp_solve_eval(
    fenics_function: Callable,
    fenics_templates: FenicsVariable,
    *args: np.array,
) -> Tuple[np.array, Callable]:
    """Run the forward solve and return its output together with a VJP map.

    Args:
        fenics_function: FEniCS function executed during the forward pass.
        fenics_templates: templates forwarded unchanged to ``solve_eval``.
        *args: array representations of ``fenics_function``'s inputs.

    Returns:
        A pair ``(numpy_output, vjp_fun1)``.  ``numpy_output`` is the first
        value produced by ``solve_eval``.  ``vjp_fun1`` maps an output
        cotangent ``g`` to a tuple with one input cotangent per positional
        argument.  It is backed by a JAX primitive named ``"vjp_fun1"`` that
        also carries an abstract-evaluation rule and a batching rule
        (batching along axis 0 only).
    """
    (
        numpy_output,
        fenics_solution,
        residual_form,
        fenics_inputs,
        bcs,
    ) = solve_eval(fenics_function, fenics_templates, *args)

    pullback_p = Primitive("vjp_fun1")
    pullback_p.multiple_results = True

    # @trace("vjp_fun1")
    def vjp_fun1(g):
        return pullback_p.bind(g)

    def pullback_impl(g):
        # ``None`` cotangents are replaced by zeros shaped like the
        # corresponding positional argument.
        raw = vjp_solve_eval_impl(g, fenics_solution, residual_form, fenics_inputs, bcs)
        return tuple(
            jax.ad_util.zeros_like_jaxval(args[idx]) if ct is None else ct
            for idx, ct in enumerate(raw)
        )

    pullback_p.def_impl(pullback_impl)

    # @trace("vjp_fun1_abstract_eval")
    def pullback_abstract_eval(g):
        # NOTE(review): a single input is reported with an extra leading axis
        # of size 1, presumably matching what vjp_solve_eval_impl returns in
        # that case -- confirm.
        if len(args) <= 1:
            first = args[0]
            return (jax.abstract_arrays.ShapedArray((1, *first.shape), first.dtype),)
        return tuple(jax.abstract_arrays.ShapedArray(a.shape, a.dtype) for a in args)

    pullback_p.def_abstract_eval(pullback_abstract_eval)

    # @trace("vjp_fun1_batch")
    def pullback_batch(vector_arg_values, batch_axes):
        """Batching rule for the pullback primitive.

        Only a cotangent batched along axis 0 is accepted.  The pullback is
        evaluated row by row; the i-th outputs of all rows are stacked with
        ``np.vstack`` and every output is reported as batched along the input
        batch axis.
        """
        # Only leading-axis batching is implemented.
        assert batch_axes[0] == 0
        cotangent_batch = vector_arg_values[0]
        per_row = [vjp_fun1(cotangent_batch[k]) for k in range(cotangent_batch.shape[0])]
        # Transpose "rows of outputs" into "outputs of rows".
        stacked = tuple(np.vstack(column) for column in itertools.zip_longest(*per_row))
        return stacked, (batch_axes[0],) * len(args)

    jax.interpreters.batching.primitive_batchers[pullback_p] = pullback_batch

    return numpy_output, vjp_fun1
def get_pullback_function(
    fenics_function: Callable, fenics_templates: Collection[BackendVariable]
) -> Callable:
    """Build a JAX-bindable VJP (pullback) function.

    ``fenics_function`` and ``fenics_templates`` are accepted for interface
    compatibility but are not referenced in this body.

    Returns:
        ``vjp_fun1(aux_args, g)``, which binds a JAX primitive named
        ``"vjp_fun1"``.  ``aux_args`` is a pair ``(fe_aux, args)``:
        ``fe_aux`` exposes ``fenics_output``, ``fenics_inputs`` and ``tape``,
        and ``args`` are the primal inputs.  ``g`` is the output cotangent.
        The result holds one cotangent per entry of ``args``.  Entries for
        which ``evaluate_pullback`` yields ``None`` are replaced by zeros
        shaped like the corresponding input.
    """
    pullback_p = Primitive("vjp_fun1")
    pullback_p.multiple_results = True

    # @trace("vjp_fun1")
    def vjp_fun1(aux_args, g):
        return pullback_p.bind(aux_args, g)

    def pullback_impl(aux_args, g):
        fe_aux, args = aux_args
        cotangents = evaluate_pullback(
            fe_aux.fenics_output, fe_aux.fenics_inputs, fe_aux.tape, g
        )
        return tuple(
            jax.ad_util.zeros_like_jaxval(args[idx]) if ct is None else ct
            for idx, ct in enumerate(cotangents)
        )

    pullback_p.def_impl(pullback_impl)

    # @trace("vjp_fun1_abstract_eval")
    def pullback_abstract_eval(aux_args, g):
        _, args = aux_args
        # NOTE(review): a single input gets an extra leading axis of size 1,
        # presumably matching evaluate_pullback's output -- confirm.
        if len(args) <= 1:
            first = args[0]
            return (jax.abstract_arrays.ShapedArray((1, *first.shape), first.dtype),)
        return tuple(jax.abstract_arrays.ShapedArray(a.shape, a.dtype) for a in args)

    pullback_p.def_abstract_eval(pullback_abstract_eval)

    # @trace("vjp_fun1_batch")
    def pullback_batch(vector_arg_values, batch_axes):
        """Batching rule for the pullback primitive.

        ``aux_args`` must be unbatched and the cotangent batched along axis
        0.  The pullback runs once per cotangent row; the i-th outputs of all
        rows are stacked with ``np.vstack`` and reported as batched along the
        cotangent's batch axis.
        """
        aux_args = vector_arg_values[0]
        _, args = aux_args
        # aux_args must not be mapped.
        assert batch_axes[0] is None
        # Only leading-axis cotangent batching is implemented.
        assert batch_axes[1] == 0
        per_row = [vjp_fun1(aux_args, row) for row in vector_arg_values[1]]
        # Transpose "rows of outputs" into "outputs of rows".
        stacked = tuple(np.vstack(column) for column in itertools.zip_longest(*per_row))
        return stacked, (batch_axes[1],) * len(args)

    jax.interpreters.batching.primitive_batchers[pullback_p] = pullback_batch

    return vjp_fun1
def decorator(fenics_function: Callable) -> Callable:
    """Wrap ``fenics_function`` as the JAX primitive ``jax_fem_eval``.

    A fresh primitive is created and given an evaluation rule, an
    abstract-evaluation rule, a batching rule and a JVP rule.  Inputs are
    converted with ``jax_to_fenics_numpy`` (paired with ``fenics_templates``)
    before ``evaluate_primal`` / ``evaluate_pushforward`` run.
    ``fenics_templates`` is not a parameter of this function; it is a free
    variable read from the enclosing scope (presumably the factory returning
    this decorator -- confirm against the caller).

    Args:
        fenics_function: FEniCS function executed during the forward pass.

    Returns:
        A callable carrying ``fenics_function``'s metadata that binds the
        primitive to its positional arguments.
    """

    @functools.wraps(fenics_function)
    def jax_fem_eval(*args):
        return jax_fem_eval_p.bind(*args)

    jax_fem_eval_p = Primitive("jax_fem_eval")

    def jax_fem_eval_p_impl(*args):
        """Concrete evaluation: first element returned by evaluate_primal."""
        args = (
            jax_to_fenics_numpy(arg, ft) for arg, ft in zip(args, fenics_templates)
        )
        return evaluate_primal(fenics_function, fenics_templates, *args)[0]

    jax_fem_eval_p.def_impl(jax_fem_eval_p_impl)

    def jax_fem_eval_p_abstract_eval(*args):
        """Shape/dtype inference by running the full primal evaluation."""
        args = (
            jax_to_fenics_numpy(arg, ft) for arg, ft in zip(args, fenics_templates)
        )
        return jax.abstract_arrays.make_shaped_array(
            evaluate_primal(fenics_function, fenics_templates, *args)[0]
        )

    jax_fem_eval_p.def_abstract_eval(jax_fem_eval_p_abstract_eval)

    def jax_fem_eval_batch(vector_arg_values, batch_axes):
        """Batching rule: evaluate the primitive once per batch element.

        Each argument may be batched along any axis or be unbatched (batch
        axis ``None``); unbatched arguments are reused for every element.
        The per-element results are stacked with ``np.asarray`` and reported
        as batched along axis 0.
        """
        # Length of the mapped dimension, read from the first batched argument.
        size = next(
            value.shape[axis]
            for value, axis in zip(vector_arg_values, batch_axes)
            if axis is not None
        )
        per_arg_rows = []
        for value, axis in zip(vector_arg_values, batch_axes):
            if axis is None:
                # Unbatched argument: pass the same value to every element.
                per_arg_rows.append([value] * size)
            elif axis == 0:
                # Leading-axis batch: iterate the rows directly.
                per_arg_rows.append(value)
            else:
                # i-th slice along ``axis``; the remaining axes keep their order.
                # Built eagerly (list, not generator) so ``value``/``axis`` are
                # bound for this iteration.
                per_arg_rows.append(
                    [value[(slice(None),) * axis + (i,)] for i in range(size)]
                )
        res = np.asarray(list(map(jax_fem_eval, *per_arg_rows)))
        return res, 0

    jax.interpreters.batching.primitive_batchers[jax_fem_eval_p] = jax_fem_eval_batch

    # @trace("jvp_jax_fem_eval")
    def jvp_jax_fem_eval(ps, ts):
        """JVP rule: bind the primals and tangents to the JVP primitive."""
        return jvp_jax_fem_eval_p.bind(ps, ts)

    jvp_jax_fem_eval_p = Primitive("jvp_jax_fem_eval")
    jvp_jax_fem_eval_p.multiple_results = True

    def jvp_jax_fem_eval_impl(primals, tangents):
        """Evaluate the primal and push ``tangents`` forward through it.

        Returns ``(numpy_output, dnumpy_output)``.
        """
        primals = (
            jax_to_fenics_numpy(p, ft) for p, ft in zip(primals, fenics_templates)
        )
        numpy_output, fenics_output, fenics_inputs, tape = evaluate_primal(
            fenics_function, fenics_templates, *primals
        )
        tangents = (
            jax_to_fenics_numpy(t, ft) for t, ft in zip(tangents, fenics_templates)
        )
        # ``tangents`` is passed on as a lazy generator, not unpacked.
        dnumpy_output = evaluate_pushforward(
            fenics_output, fenics_inputs, tape, tangents
        )
        return numpy_output, dnumpy_output

    # NOTE(review): no abstract-evaluation rule is registered for the JVP
    # primitive, so it presumably only works on concrete values -- confirm.
    jvp_jax_fem_eval_p.def_impl(jvp_jax_fem_eval_impl)
    jax.interpreters.ad.primitive_jvps[jax_fem_eval_p] = jvp_jax_fem_eval

    # TODO: no transpose (reverse-mode) rule is registered for jax_fem_eval_p.
    # An earlier attempt registered a primitive whose impl computed
    # ``vjp_fem_eval(fenics_function, fenics_templates, *args)[1](ct)``, but a
    # JAX Tracer goes inside the fenics wrappers and a zero array is returned,
    # because fenics numpy conversion works only for concrete arrays.

    return jax_fem_eval
) return xla_client.ops.CustomCall( c, b"mpi_scan", operands=( x, token, ), shape=sh, opaque=descriptor, has_side_effect=True, ) # This function evaluates only the shapes during AST construction def mpi_scan_abstract_eval(xs, token, op, comm): return ( abstract_arrays.ShapedArray(xs.shape, xs.dtype), core.abstract_token, ) mpi_scan_p.multiple_results = True mpi_scan_p.def_impl(mpi_scan_impl) mpi_scan_p.def_abstract_eval(mpi_scan_abstract_eval) # assign to the primitive the correct encoder xla.backend_specific_translations["cpu"][mpi_scan_p] = mpi_scan_xla_encode_cpu xla.backend_specific_translations["gpu"][mpi_scan_p] = mpi_scan_xla_encode_gpu
shape=sh, opaque=descriptor, has_side_effect=True, ) # This function evaluates only the shapes during AST construction def mpi_scatter_abstract_eval(x, token, root, comm): comm = unpack_hashable(comm) rank = comm.Get_rank() if rank == root: out_shape = x.shape[1:] else: out_shape = x.shape return ( abstract_arrays.ShapedArray(out_shape, x.dtype), core.abstract_token, ) mpi_scatter_p.multiple_results = True mpi_scatter_p.def_impl(mpi_scatter_impl) mpi_scatter_p.def_abstract_eval(mpi_scatter_abstract_eval) # assign to the primitive the correct encoder xla.backend_specific_translations["cpu"][ mpi_scatter_p] = mpi_scatter_xla_encode_cpu xla.backend_specific_translations["gpu"][ mpi_scatter_p] = mpi_scatter_xla_encode_gpu
return xla_client.ops.CustomCall( c, b"mpi_bcast", operands=( x, token, ), shape=sh, opaque=descriptor, has_side_effect=True, ) # This function evaluates only the shapes during AST construction def mpi_bcast_abstract_eval(xs, token, root, comm): return ( abstract_arrays.ShapedArray(xs.shape, xs.dtype), abstract_arrays.abstract_token, ) mpi_bcast_p.multiple_results = True mpi_bcast_p.def_impl(mpi_bcast_impl) mpi_bcast_p.def_abstract_eval(mpi_bcast_abstract_eval) # assign to the primitive the correct encoder xla.backend_specific_translations["cpu"][ mpi_bcast_p] = mpi_bcast_xla_encode_cpu xla.backend_specific_translations["gpu"][ mpi_bcast_p] = mpi_bcast_xla_encode_gpu