def test_unimplemented_interpreter_rules(self):
  foo_p = Primitive('foo')
  def foo(x):
    return foo_p.bind(x)

  jtu.check_raises(lambda: foo(1.0), NotImplementedError,
                   "Evaluation rule for 'foo' not implemented")

  jtu.check_raises(lambda: jit(foo)(1.0), NotImplementedError,
                   "Abstract evaluation for 'foo' not implemented")

  jtu.check_raises(
      lambda: grad(foo)(1.0), NotImplementedError,
      "Forward-mode differentiation rule for 'foo' not implemented")

  foo_p.def_abstract_eval(lambda x: x)

  jtu.check_raises(lambda: jit(foo)(1.0), NotImplementedError,
                   "XLA translation rule for 'foo' not implemented")

  foo_p.def_impl(lambda x: x)
  defjvp(foo_p, lambda g, x: foo(g))

  jtu.check_raises(
      lambda: grad(foo)(1.0), NotImplementedError,
      "Reverse-mode differentiation rule for 'foo' not implemented")
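For reference, here is a minimal sketch (not part of the test above) of a primitive with enough rules registered that none of these errors fires for eager evaluation or `grad`; the names `sq_p`/`sq` are made up, and `defjvp` is assumed to come from `jax.interpreters.ad` as in the test.

# Minimal sketch (assumed imports: jax.core.Primitive, jax.interpreters.ad.defjvp).
from jax.core import Primitive
from jax.interpreters import ad

sq_p = Primitive('sq')
sq_p.def_impl(lambda x: x * x)            # evaluation rule: used by eager calls
sq_p.def_abstract_eval(lambda x: x)       # abstract rule: aval passes through unchanged
ad.defjvp(sq_p, lambda g, x: 2 * g * x)   # forward-mode rule: d(x*x) = 2*x*dx

def sq(x):
  return sq_p.bind(x)

# grad(sq)(3.0) works because the tangent rule is built from transposable mul ops;
# jit(sq) would still fail without an XLA translation/lowering rule, exactly as
# the test above checks.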
def decorator(fenics_function: Callable) -> Callable:
    @functools.wraps(fenics_function)
    @custom_vjp
    def jax_fem_eval(*args):
        return jax_fem_eval_p.bind(*args)

    jax_fem_eval_p = Primitive("jax_fem_eval")
    jax_fem_eval_p.def_impl(
        lambda *args: evaluate_primal(fenics_function, fenics_templates, *args)[0]
    )
    jax_fem_eval_p.def_abstract_eval(
        lambda *args: jax.abstract_arrays.make_shaped_array(
            evaluate_primal(fenics_function, fenics_templates, *args)[0]
        )
    )

    def jax_fem_eval_batch(vector_arg_values, batch_axes):
        assert len(set(batch_axes)) == 1  # all batch axes must be the same
        assert batch_axes[0] == 0  # only batch axis 0 is supported for now
        res = list(map(jax_fem_eval, *vector_arg_values))
        res = np.asarray(res)
        return res, batch_axes[0]

    jax.interpreters.batching.primitive_batchers[jax_fem_eval_p] = jax_fem_eval_batch

    def primal(*args):
        numpy_output, fenics_output, fenics_inputs, tape = evaluate_primal(
            fenics_function, fenics_templates, *args
        )
        return (
            numpy_output,
            (PyadjointMetadata(fenics_output, fenics_inputs, tape), args),
        )

    def pullback(aux_args, g):
        pb_fn = get_pullback_function(fenics_function, fenics_templates)
        # get_pullback_function returns a list, but JAX expects a tuple of cotangents
        return tuple(pb_fn(aux_args, g))

    jax_fem_eval.defvjp(primal, pullback)

    return jax_fem_eval
def decorator(fenics_function: Callable) -> Callable:
    @functools.wraps(fenics_function)
    def jax_solve_eval(*args):
        return jax_solve_eval_p.bind(*args)

    jax_solve_eval_p = Primitive("jax_solve_eval")
    jax_solve_eval_p.def_impl(
        lambda *args: solve_eval(fenics_function, fenics_templates, *args)[0])
    jax_solve_eval_p.def_abstract_eval(
        lambda *args: jax.abstract_arrays.make_shaped_array(
            solve_eval(fenics_function, fenics_templates, *args)[0]))

    def jax_solve_eval_batch(vector_arg_values, batch_axes):
        assert len(set(batch_axes)) == 1  # all batch axes must be the same
        assert batch_axes[0] == 0  # only batch axis 0 is supported for now
        # compute the function row-by-row
        res = np.asarray([
            jax_solve_eval(*(vector_arg_values[j][i]
                             for j in range(len(batch_axes))))
            for i in range(vector_arg_values[0].shape[0])
        ])
        return res, batch_axes[0]

    jax.interpreters.batching.primitive_batchers[
        jax_solve_eval_p] = jax_solve_eval_batch

    # @trace("djax_solve_eval")
    def djax_solve_eval(*args):
        return djax_solve_eval_p.bind(*args)

    djax_solve_eval_p = Primitive("djax_solve_eval")
    # djax_solve_eval_p.multiple_results = True
    djax_solve_eval_p.def_impl(lambda *args: vjp_solve_eval(
        fenics_function, fenics_templates, *args))

    defvjp_all(jax_solve_eval_p, djax_solve_eval)
    return jax_solve_eval
def decorator(fenics_function: Callable) -> Callable:
    def jax_assemble_eval(*args):
        return jax_assemble_eval_p.bind(*args)

    jax_assemble_eval_p = Primitive("jax_assemble_eval")
    jax_assemble_eval_p.def_impl(
        lambda *args: assemble_eval(fenics_function, fenics_templates, *args)[0]
    )
    jax_assemble_eval_p.def_abstract_eval(
        lambda *args: jax.abstract_arrays.make_shaped_array(
            assemble_eval(fenics_function, fenics_templates, *args)[0]
        )
    )

    def jax_assemble_eval_batch(vector_arg_values, batch_axes):
        assert len(set(batch_axes)) == 1  # all batch axes must be the same
        assert batch_axes[0] == 0  # only batch axis 0 is supported for now
        res = list(map(jax_assemble_eval, *vector_arg_values))
        res = np.asarray(res)
        return res, batch_axes[0]

    batching.primitive_batchers[jax_assemble_eval_p] = jax_assemble_eval_batch

    # @trace("djax_assemble_eval")
    def djax_assemble_eval(*args):
        return djax_assemble_eval_p.bind(*args)

    djax_assemble_eval_p = Primitive("djax_assemble_eval")
    # djax_assemble_eval_p.multiple_results = True
    djax_assemble_eval_p.def_impl(
        lambda *args: vjp_assemble_eval(fenics_function, fenics_templates, *args)
    )

    defvjp_all(jax_assemble_eval_p, djax_assemble_eval)
    return jax_assemble_eval
def decorator(fenics_function: Callable) -> Callable: @functools.wraps(fenics_function) def jax_solve_eval(*args): return jax_solve_eval_p.bind(*args) jax_solve_eval_p = Primitive("jax_solve_eval") jax_solve_eval_p.def_impl( lambda *args: solve_eval(fenics_function, fenics_templates, *args)[0] ) jax_solve_eval_p.def_abstract_eval( lambda *args: jax.abstract_arrays.make_shaped_array( solve_eval(fenics_function, fenics_templates, *args)[0] ) ) def jax_solve_eval_batch(vector_arg_values, batch_axes): assert len(set(batch_axes)) == 1 # assert that all batch axes are same assert ( batch_axes[0] == 0 ) # assert that batch axis is zero, need to rewrite for a general case? res = list(map(jax_solve_eval, *vector_arg_values)) res = np.asarray(res) return res, batch_axes[0] jax.batching.primitive_batchers[jax_solve_eval_p] = jax_solve_eval_batch # @trace("jvp_jax_solve_eval") def jvp_jax_solve_eval(ps, ts): return jvp_jax_solve_eval_p.bind(ps, ts) jvp_jax_solve_eval_p = Primitive("jvp_jax_solve_eval") jvp_jax_solve_eval_p.multiple_results = True jvp_jax_solve_eval_p.def_impl( lambda ps, ts: jvp_solve_eval(fenics_function, fenics_templates, ps, ts) ) jax.interpreters.ad.primitive_jvps[jax_solve_eval_p] = jvp_jax_solve_eval # TODO: JAX Tracer goes inside fenics wrappers and zero array is returned # because fenics numpy conversion works only for concrete arrays vjp_jax_solve_eval_p = Primitive("vjp_jax_solve_eval") vjp_jax_solve_eval_p.def_impl( lambda ct, *args: vjp_solve_eval(fenics_function, fenics_templates, *args)[ 1 ](ct) ) jax.interpreters.ad.primitive_transposes[ jax_solve_eval_p ] = vjp_jax_solve_eval_p return jax_solve_eval
        shape=sh,
        opaque=descriptor,
        has_side_effect=True,
    )


# This function evaluates only the shapes during AST construction
def mpi_scatter_abstract_eval(x, token, root, comm):
    comm = unpack_hashable(comm)
    rank = comm.Get_rank()
    if rank == root:
        out_shape = x.shape[1:]
    else:
        out_shape = x.shape
    return (
        abstract_arrays.ShapedArray(out_shape, x.dtype),
        core.abstract_token,
    )


mpi_scatter_p.multiple_results = True
mpi_scatter_p.def_impl(mpi_scatter_impl)
mpi_scatter_p.def_abstract_eval(mpi_scatter_abstract_eval)

# assign to the primitive the correct encoder
xla.backend_specific_translations["cpu"][mpi_scatter_p] = mpi_scatter_xla_encode_cpu
xla.backend_specific_translations["gpu"][mpi_scatter_p] = mpi_scatter_xla_encode_gpu
  scale = 1 / prod(fft_lengths)
  out = scale * mask * x
  assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
  # Use JAX's convention for complex gradients
  # https://github.com/google/jax/issues/6223#issuecomment-807740707
  return lax.conj(out)

def fft_transpose_rule(t, operand, fft_type, fft_lengths):
  if fft_type == xla_client.FftType.RFFT:
    result = _rfft_transpose(t, fft_lengths)
  elif fft_type == xla_client.FftType.IRFFT:
    result = _irfft_transpose(t, fft_lengths)
  else:
    result = fft(t, fft_type, fft_lengths)
  return result,

def fft_batching_rule(batched_args, batch_dims, fft_type, fft_lengths):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return fft(x, fft_type, fft_lengths), 0

fft_p = Primitive('fft')
fft_p.def_impl(fft_impl)
fft_p.def_abstract_eval(fft_abstract_eval)
xla.translations[fft_p] = fft_translation_rule
ad.deflinear2(fft_p, fft_transpose_rule)
batching.primitive_batchers[fft_p] = fft_batching_rule
if pocketfft:
  xla.backend_specific_translations['cpu'][fft_p] = pocketfft.pocketfft
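The `ad.deflinear2` call above registers both forward- and reverse-mode rules at once: for a linear primitive, JAX derives the JVP from the primitive itself and only needs an explicit transpose. A minimal sketch of that pattern for a made-up linear primitive `scale2_p` (y = 2x; all names here are assumptions, not part of the code above):

# Minimal sketch of the deflinear2 pattern, assuming jax.core and jax.interpreters.ad.
from jax.core import Primitive
from jax.interpreters import ad

scale2_p = Primitive('scale2')
scale2_p.def_impl(lambda x: 2 * x)
scale2_p.def_abstract_eval(lambda x: x)
# Transpose of y = 2x maps the cotangent ct back to 2*ct; the rule returns one
# cotangent per primal input.
ad.deflinear2(scale2_p, lambda ct, x: (2 * ct,))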
        operands=(
            sendbuf,
            token,
        ),
        shape=sh,
        opaque=descriptor,
        has_side_effect=True,
    )


# This function evaluates only the shapes during AST construction
def mpi_allgather_abstract_eval(x, token, comm):
    comm = unpack_hashable(comm)
    size = comm.Get_size()
    out_shape = (size, *x.shape)
    return (
        abstract_arrays.ShapedArray(out_shape, x.dtype),
        abstract_arrays.abstract_token,
    )


mpi_allgather_p.multiple_results = True
mpi_allgather_p.def_impl(mpi_allgather_impl)
mpi_allgather_p.def_abstract_eval(mpi_allgather_abstract_eval)

# assign to the primitive the correct encoder
xla.backend_specific_translations["cpu"][mpi_allgather_p] = mpi_allgather_xla_encode_cpu
xla.backend_specific_translations["gpu"][mpi_allgather_p] = mpi_allgather_xla_encode_gpu
  # Use JAX's convention for complex gradients
  # https://github.com/google/jax/issues/6223#issuecomment-807740707
  return lax.conj(out)

def _fft_transpose_rule(t, operand, fft_type, fft_lengths):
  if fft_type == xla_client.FftType.RFFT:
    result = _rfft_transpose(t, fft_lengths)
  elif fft_type == xla_client.FftType.IRFFT:
    result = _irfft_transpose(t, fft_lengths)
  else:
    result = fft(t, fft_type, fft_lengths)
  return result,

def _fft_batching_rule(batched_args, batch_dims, fft_type, fft_lengths):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return fft(x, fft_type, fft_lengths), 0

fft_p = Primitive('fft')
fft_p.def_impl(_fft_impl)
fft_p.def_abstract_eval(fft_abstract_eval)
mlir.register_lowering(fft_p, _fft_lowering)
ad.deflinear2(fft_p, _fft_transpose_rule)
batching.primitive_batchers[fft_p] = _fft_batching_rule
if pocketfft:
  mlir.register_lowering(fft_p, _fft_lowering_cpu, platform='cpu')
import numpy as np

from jax.util import safe_map, safe_zip
from jax.core import Primitive
from jax.interpreters import ad, xla, batching, numpy_eval
from jax.lax import dynamic_update_slice_p

map = safe_map
zip = safe_zip

inplace_dynamic_update_slice_p = Primitive('inplace_dynamic_update_slice')
inplace_dynamic_update_slice_p.def_impl(dynamic_update_slice_p.impl)
inplace_dynamic_update_slice_p.def_abstract_eval(dynamic_update_slice_p.abstract_eval)
for rules in [xla.translations, ad.primitive_jvps, ad.primitive_transposes,
              batching.primitive_batchers]:
  rules[inplace_dynamic_update_slice_p] = rules[dynamic_update_slice_p]

def _numpy_inplace_dynamic_update_slice(operand, update, *start_indices):
  slices = tuple(map(slice, start_indices, np.add(start_indices, update.shape)))
  operand[slices] = update
  return operand

numpy_eval.np_impl[inplace_dynamic_update_slice_p] = \
    _numpy_inplace_dynamic_update_slice

def inplace_dynamic_update_slice(operand, update, start_indices):
  return inplace_dynamic_update_slice_p.bind(operand, update, *start_indices)
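A hypothetical usage sketch of the wrapper defined above: start indices are passed as a sequence and unpacked into `bind`, mirroring `lax.dynamic_update_slice`.

# Hypothetical usage sketch (assumes the definitions above are importable; the
# shapes and indices here are made up).
import jax.numpy as jnp

operand = jnp.zeros((4, 4))
update = jnp.ones((2, 2))
# Under the standard interpreters this behaves exactly like
# lax.dynamic_update_slice; only the numpy_eval interpreter mutates `operand`.
out = inplace_dynamic_update_slice(operand, update, (1, 1))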
        shape=sh,
        opaque=descriptor,
        has_side_effect=True,
    )


# This function evaluates only the shapes during AST construction
def mpi_gather_abstract_eval(x, token, root, comm):
    comm = unpack_hashable(comm)
    rank = comm.Get_rank()
    size = comm.Get_size()
    if rank == root:
        out_shape = (size, *x.shape)
    else:
        out_shape = x.shape
    return (
        abstract_arrays.ShapedArray(out_shape, x.dtype),
        core.abstract_token,
    )


mpi_gather_p.multiple_results = True
mpi_gather_p.def_impl(mpi_gather_impl)
mpi_gather_p.def_abstract_eval(mpi_gather_abstract_eval)

# assign to the primitive the correct encoder
xla.backend_specific_translations["cpu"][mpi_gather_p] = mpi_gather_xla_encode_cpu
xla.backend_specific_translations["gpu"][mpi_gather_p] = mpi_gather_xla_encode_gpu
def get_pullback_function(
    fenics_function: Callable, fenics_templates: Collection[BackendVariable]
) -> Callable:
    """Computes the gradients of the output with respect to the input.

    Input:
        fenics_function (callable): FEniCS function to be executed during the forward pass

    Output:
        A Python callable representing the VJP map from output cotangents to input
        cotangents. The returned VJP function must accept a value with the same shape
        as the value of `fenics_function` applied to the arguments, and must return a
        tuple with length equal to the number of positional arguments to `fenics_function`.
    """

    # @trace("vjp_fun1")
    def vjp_fun1(aux_args, g):
        return vjp_fun1_p.bind(aux_args, g)

    def vjp_fun1_p_impl(aux_args, g):
        fe_aux, args = aux_args
        fenics_output, fenics_inputs, tape = (
            fe_aux.fenics_output,
            fe_aux.fenics_inputs,
            fe_aux.tape,
        )
        return tuple(
            vjp if vjp is not None else jax.ad_util.zeros_like_jaxval(args[i])
            for i, vjp in enumerate(
                evaluate_pullback(fenics_output, fenics_inputs, tape, g)
            )
        )

    vjp_fun1_p = Primitive("vjp_fun1")
    vjp_fun1_p.multiple_results = True
    vjp_fun1_p.def_impl(vjp_fun1_p_impl)

    # @trace("vjp_fun1_abstract_eval")
    def vjp_fun1_abstract_eval(aux_args, g):
        _, args = aux_args
        if len(args) > 1:
            return tuple(
                jax.abstract_arrays.ShapedArray(arg.shape, arg.dtype) for arg in args
            )
        else:
            return (
                jax.abstract_arrays.ShapedArray((1, *args[0].shape), args[0].dtype),
            )

    vjp_fun1_p.def_abstract_eval(vjp_fun1_abstract_eval)

    # @trace("vjp_fun1_batch")
    def vjp_fun1_batch(vector_arg_values, batch_axes):
        """Computes the batched version of the primitive.

        This must be a JAX-traceable function.

        Args:
            vector_arg_values: a tuple of arguments, each being a tensor of matching shape.
            batch_axes: the axes that are being batched. See vmap documentation.

        Returns:
            a tuple of the result, and the result axis that was batched.
        """
        # _trace("Using vjp_fun1 to compute the batch:")
        _, args = vector_arg_values[0]
        # only (None, 0) batch axes are supported for now
        assert batch_axes[0] is None
        assert batch_axes[1] == 0
        # apply the function row-by-row
        vjp_fun1_partial = functools.partial(vjp_fun1, vector_arg_values[0])
        res = list(map(vjp_fun1_partial, *(vector_arg_values[1],)))
        # transpose the resulting list
        res_T = list(itertools.zip_longest(*res))
        return tuple(map(np.vstack, res_T)), (batch_axes[1],) * len(args)

    jax.interpreters.batching.primitive_batchers[vjp_fun1_p] = vjp_fun1_batch

    return vjp_fun1
def decorator(fenics_function: Callable) -> Callable:
    @functools.wraps(fenics_function)
    def jax_fem_eval(*args):
        return jax_fem_eval_p.bind(*args)

    jax_fem_eval_p = Primitive("jax_fem_eval")

    def jax_fem_eval_p_impl(*args):
        args = (
            jax_to_fenics_numpy(arg, ft) for arg, ft in zip(args, fenics_templates)
        )
        return evaluate_primal(fenics_function, fenics_templates, *args)[0]

    jax_fem_eval_p.def_impl(jax_fem_eval_p_impl)

    def jax_fem_eval_p_abstract_eval(*args):
        args = (
            jax_to_fenics_numpy(arg, ft) for arg, ft in zip(args, fenics_templates)
        )
        return jax.abstract_arrays.make_shaped_array(
            evaluate_primal(fenics_function, fenics_templates, *args)[0]
        )

    jax_fem_eval_p.def_abstract_eval(jax_fem_eval_p_abstract_eval)

    def jax_fem_eval_batch(vector_arg_values, batch_axes):
        assert len(set(batch_axes)) == 1  # all batch axes must be the same
        assert batch_axes[0] == 0  # only batch axis 0 is supported for now
        res = list(map(jax_fem_eval, *vector_arg_values))
        res = np.asarray(res)
        return res, batch_axes[0]

    jax.interpreters.batching.primitive_batchers[jax_fem_eval_p] = jax_fem_eval_batch

    # @trace("jvp_jax_fem_eval")
    def jvp_jax_fem_eval(ps, ts):
        return jvp_jax_fem_eval_p.bind(ps, ts)

    jvp_jax_fem_eval_p = Primitive("jvp_jax_fem_eval")
    jvp_jax_fem_eval_p.multiple_results = True

    def jvp_jax_fem_eval_impl(primals, tangents):
        primals = (
            jax_to_fenics_numpy(p, ft) for p, ft in zip(primals, fenics_templates)
        )
        numpy_output, fenics_output, fenics_inputs, tape = evaluate_primal(
            fenics_function, fenics_templates, *primals
        )
        tangents = (
            jax_to_fenics_numpy(t, ft) for t, ft in zip(tangents, fenics_templates)
        )
        dnumpy_output = evaluate_pushforward(
            fenics_output, fenics_inputs, tape, tangents
        )
        return numpy_output, dnumpy_output

    jvp_jax_fem_eval_p.def_impl(jvp_jax_fem_eval_impl)
    jax.interpreters.ad.primitive_jvps[jax_fem_eval_p] = jvp_jax_fem_eval

    # TODO: the JAX tracer goes inside the FEniCS wrappers and a zero array is
    # returned, because the FEniCS/numpy conversion works only for concrete arrays
    # vjp_jax_fem_eval_p = Primitive("vjp_jax_fem_eval")
    # vjp_jax_fem_eval_p.def_impl(
    #     lambda ct, *args: vjp_fem_eval(fenics_function, fenics_templates, *args)[1](
    #         ct
    #     )
    # )
    # jax.interpreters.ad.primitive_transposes[jax_fem_eval_p] = vjp_jax_fem_eval_p

    return jax_fem_eval
def PmapPrimitive(name):
  prim = Primitive(name)
  prim.def_impl(partial(_unbound_name_error, name))
  prim.def_abstract_eval(lambda x, *args, **kwargs: x)
  return prim
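A hypothetical usage sketch: the factory above yields a primitive that raises a name-binding error when evaluated outside its intended collective context, while tracing passes the operand's abstract value through unchanged. The names `psum_placeholder` and `axis_name` below are made up for illustration.

# Hypothetical usage sketch (made-up names).
psum_placeholder_p = PmapPrimitive('psum_placeholder')

def psum_placeholder(x, axis_name):
  # Evaluating this eagerly calls _unbound_name_error(name, ...) via def_impl;
  # under the intended transform the primitive is given a real rule instead,
  # and the abstract rule above keeps x's shape/dtype during tracing.
  return psum_placeholder_p.bind(x, axis_name=axis_name)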
    )

    return xla_client.ops.CustomCall(
        c,
        b"mpi_scan",
        operands=(
            x,
            token,
        ),
        shape=sh,
        opaque=descriptor,
        has_side_effect=True,
    )


# This function evaluates only the shapes during AST construction
def mpi_scan_abstract_eval(xs, token, op, comm):
    return (
        abstract_arrays.ShapedArray(xs.shape, xs.dtype),
        core.abstract_token,
    )


mpi_scan_p.multiple_results = True
mpi_scan_p.def_impl(mpi_scan_impl)
mpi_scan_p.def_abstract_eval(mpi_scan_abstract_eval)

# assign to the primitive the correct encoder
xla.backend_specific_translations["cpu"][mpi_scan_p] = mpi_scan_xla_encode_cpu
xla.backend_specific_translations["gpu"][mpi_scan_p] = mpi_scan_xla_encode_gpu
        operands=(a, b, c, d),
        shape_with_layout=shape,
        operand_shapes_with_layout=(shape,) * 4,
        opaque=opaque)


def tridiag(a, b, c, d):
    if not a.shape == b.shape == c.shape == d.shape:
        raise ValueError('all inputs must have identical shape')
    if not a.dtype == b.dtype == c.dtype == d.dtype:
        raise ValueError('all inputs must have the same dtype')
    return tridiag_p.bind(a, b, c, d)  # transpose(res, (0, 2, 1))


def tridiag_impl(*args, **kwargs):
    return xla.apply_primitive(tridiag_p, *args, **kwargs)


def _tridiag_gpu_translation_rule(computation_builder, a, b, c, d):
    return _tridiag(computation_builder, a, b, c, d)


def tridiag_abstract_eval(a, b, c, d):
    return abstract_arrays.ShapedArray(a.shape, a.dtype)


tridiag_p = Primitive('tridiag')
tridiag_p.def_impl(tridiag_impl)
tridiag_p.def_abstract_eval(tridiag_abstract_eval)
xla.backend_specific_translations['gpu'][tridiag_p] = _tridiag_gpu_translation_rule
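A hypothetical usage sketch; the shapes below are made up, and because only a 'gpu' translation rule is registered above, running this on CPU would fail with a missing-translation error.

# Hypothetical usage sketch (made-up shapes; requires the GPU backend).
import jax.numpy as jnp

a = jnp.ones((8, 16))        # lower diagonal
b = jnp.full((8, 16), 4.0)   # main diagonal
c = jnp.ones((8, 16))        # upper diagonal
d = jnp.ones((8, 16))        # right-hand side
x = tridiag(a, b, c, d)      # solve the batch of tridiagonal systems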
    return grads.reshape(alpha.shape)


def _standard_gamma_abstract_eval(key, alpha, jaxpr, aval, consts):
    return lax.maybe_tracer_tuple_to_abstract_tuple(aval)


def _standard_gamma_translate(c, key, alpha, jaxpr, aval, consts):
    xla_computation = xla.jaxpr_computation(jaxpr, consts, (),
                                            c.GetShape(key), c.GetShape(alpha))
    return c.Call(xla_computation, (key, alpha))


# define primitive
standard_gamma_p = Primitive('standard_gamma')
standard_gamma_p.def_impl(partial(xla.apply_primitive, standard_gamma_p))
standard_gamma_p.def_abstract_eval(_standard_gamma_abstract_eval)
xla.translations[standard_gamma_p] = _standard_gamma_translate
ad.defjvp2(
    standard_gamma_p, None,
    lambda tangent, sample, key, alpha, **kwargs:
        tangent * _standard_gamma_grad(sample, alpha))


@partial(jit, static_argnums=(2, 3))
def standard_gamma(key, alpha, shape=(), dtype=np.float32):
    shape = shape or np.shape(alpha)
    alpha = lax.convert_element_type(alpha, dtype)
    if np.shape(alpha) != shape:
        alpha = np.broadcast_to(alpha, shape)
    jaxpr, out, consts = partial_eval.trace_unwrapped_to_jaxpr(
        _standard_gamma_impl, tuple(lax._abstractify(o) for o in (key, alpha)))
    out = _ops.CustomCall(
        c,
        b"mpi_send",
        operands=(
            _nitems,
            x,
            _constant_s32_scalar(c, dest),
            _constant_s32_scalar(c, tag),
            _constant_u64_scalar(c, to_mpi_ptr(comm)),
            _constant_u64_scalar(c, _dtype_ptr),
            token,
        ),
        shape=sh,
        has_side_effect=True,
    )

    return xla_client.ops.GetTupleElement(out, 0)


# This function evaluates only the shapes during AST construction
def mpi_send_abstract_eval(xs, token, dest, tag, comm):
    return abstract_arrays.abstract_token


# mpi_send_p.multiple_results = True
mpi_send_p.def_impl(mpi_send_impl)
mpi_send_p.def_abstract_eval(mpi_send_abstract_eval)

# assign to the primitive the correct encoder
xla.backend_specific_translations["cpu"][mpi_send_p] = mpi_send_xla_encode
    return xla_client.ops.CustomCall(
        c,
        b"mpi_alltoall",
        operands=(
            x,
            token,
        ),
        shape=sh,
        opaque=descriptor,
        has_side_effect=True,
    )


# This function evaluates only the shapes during AST construction
def mpi_alltoall_abstract_eval(xs, token, comm):
    return (
        abstract_arrays.ShapedArray(xs.shape, xs.dtype),
        core.abstract_token,
    )


mpi_alltoall_p.multiple_results = True
mpi_alltoall_p.def_impl(mpi_alltoall_impl)
mpi_alltoall_p.def_abstract_eval(mpi_alltoall_abstract_eval)

# assign to the primitive the correct encoder
xla.backend_specific_translations["cpu"][mpi_alltoall_p] = mpi_alltoall_xla_encode_cpu
xla.backend_specific_translations["gpu"][mpi_alltoall_p] = mpi_alltoall_xla_encode_gpu
def vjp_solve_eval(fenics_function: Callable,
                   fenics_templates: FenicsVariable,
                   *args: np.array) -> Tuple[np.array, Callable]:
    """Computes the gradients of the output with respect to the input.

    Input:
        fenics_function (callable): FEniCS function to be executed during the forward pass
        args (tuple): JAX array representation of the input to fenics_function

    Output:
        A pair where the first element is the value of `fenics_function` applied to the
        arguments and the second element is a Python callable representing the VJP map
        from output cotangents to input cotangents. The returned VJP function must
        accept a value with the same shape as the value of `fenics_function` applied to
        the arguments, and must return a tuple with length equal to the number of
        positional arguments to `fenics_function`.
    """

    numpy_output, fenics_solution, residual_form, fenics_inputs, bcs = solve_eval(
        fenics_function, fenics_templates, *args)

    # @trace("vjp_fun1")
    def vjp_fun1(g):
        return vjp_fun1_p.bind(g)

    vjp_fun1_p = Primitive("vjp_fun1")
    vjp_fun1_p.multiple_results = True
    vjp_fun1_p.def_impl(lambda g: tuple(
        vjp if vjp is not None else jax.ad_util.zeros_like_jaxval(args[i])
        for i, vjp in enumerate(
            vjp_solve_eval_impl(g, fenics_solution, residual_form,
                                fenics_inputs, bcs))))

    # @trace("vjp_fun1_abstract_eval")
    def vjp_fun1_abstract_eval(g):
        if len(args) > 1:
            return tuple(
                jax.abstract_arrays.ShapedArray(arg.shape, arg.dtype)
                for arg in args)
        else:
            return (jax.abstract_arrays.ShapedArray((1, *args[0].shape),
                                                    args[0].dtype),)

    vjp_fun1_p.def_abstract_eval(vjp_fun1_abstract_eval)

    # @trace("vjp_fun1_batch")
    def vjp_fun1_batch(vector_arg_values, batch_axes):
        """Computes the batched version of the primitive.

        This must be a JAX-traceable function.

        Args:
            vector_arg_values: a tuple of arguments, each being a tensor of matching shape.
            batch_axes: the axes that are being batched. See vmap documentation.

        Returns:
            a tuple of the result, and the result axis that was batched.
        """
        # _trace("Using vjp_fun1 to compute the batch:")
        assert batch_axes[0] == 0  # only batch axis 0 is supported for now
        # compute the function row-by-row
        res = [
            vjp_fun1(vector_arg_values[0][i])
            for i in range(vector_arg_values[0].shape[0])
        ]
        # transpose the resulting list
        res_T = list(itertools.zip_longest(*res))
        return tuple(map(np.vstack, res_T)), (batch_axes[0],) * len(args)

    jax.interpreters.batching.primitive_batchers[vjp_fun1_p] = vjp_fun1_batch

    return numpy_output, vjp_fun1
    elif x.shape[i_old] > 1 and new_sizes[i_new] <= 1:
      i_new += 1
    else:
      assert x.shape[i_old] == new_sizes[i_new]
      i_old += 1
      i_new += 1
  return out_indices,

op2ind[lax.reshape_p] = reshape2ind

# Custom Primitive
tex_var_p = Primitive('tex_var')
tex_var_p.def_impl(lambda x, *depends_on, **kwargs: x)
tex_var_p.def_abstract_eval(lambda x, *depends_on, **kwargs: x)

def tex_var_jvp(primals, tangents, name, is_alias):
  primal_out = tex_var(primals[0], name, is_alias, primals[1:])
  tangent_out = tex_var(tangents[0], 'd' + name, False, primals[1:])
  return primal_out, tangent_out
ad.primitive_jvps[tex_var_p] = tex_var_jvp

def tex_var_transpose(ct, x, *depends_on, **params):
  del x
  ct = tex_var(ct, '\\delta ' + params['name'][1:], False, depends_on)
  return (ct,) + (ad.Zero,) * len(depends_on)
ad.primitive_transposes[tex_var_p] = tex_var_transpose
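The rules above call a `tex_var` wrapper that is not shown. A plausible sketch, assuming it simply binds the primitive with the variable name and alias flag as parameters (the default values here are assumptions; the real signature may differ):

# Plausible sketch of the wrapper used by the JVP/transpose rules above.
def tex_var(x, name, is_alias=False, depends_on=()):
  # depends_on carries explicit dependencies as extra positional operands,
  # matching the *depends_on signature of the impl and abstract rules.
  return tex_var_p.bind(x, *depends_on, name=name, is_alias=is_alias)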
    return xla_client.ops.CustomCall(
        c,
        b"mpi_reduce",
        operands=(
            x,
            token,
        ),
        shape=sh,
        opaque=descriptor,
        has_side_effect=True,
    )


# This function evaluates only the shapes during AST construction
def mpi_reduce_abstract_eval(xs, token, op, root, comm):
    return (
        abstract_arrays.ShapedArray(xs.shape, xs.dtype),
        abstract_arrays.abstract_token,
    )


mpi_reduce_p.multiple_results = True
mpi_reduce_p.def_impl(mpi_reduce_impl)
mpi_reduce_p.def_abstract_eval(mpi_reduce_abstract_eval)

# assign to the primitive the correct encoder
xla.backend_specific_translations["cpu"][mpi_reduce_p] = mpi_reduce_xla_encode_cpu
xla.backend_specific_translations["gpu"][mpi_reduce_p] = mpi_reduce_xla_encode_gpu
            == cosu.dtype == dt_tracer.dtype:
        raise ValueError('all inputs must have the same dtype')
    return superbee_p.bind(var, u_wgrid, v_wgrid, w_wgrid, maskW, dxt, dyt,
                           dzw, cost, cosu, dt_tracer)


def superbee_impl(*args, **kwargs):
    return xla.apply_primitive(superbee_p, *args, **kwargs)


def _superbee_gpu_translation_rule(computation_builder, var, u_wgrid, v_wgrid,
                                   w_wgrid, maskW, dxt, dyt, dzw, cost, cosu,
                                   dt_tracer):
    return _superbee(computation_builder, var, u_wgrid, v_wgrid, w_wgrid,
                     maskW, dxt, dyt, dzw, cost, cosu, dt_tracer)


def superbee_abstract_eval(var, u_wgrid, v_wgrid, w_wgrid, maskW, dxt, dyt,
                           dzw, cost, cosu, dt_tracer):
    aarr = abstract_arrays.ShapedArray(var.shape, var.dtype)
    return (aarr,) * 3


superbee_p = Primitive('superbee')
superbee_p.multiple_results = True
superbee_p.def_impl(superbee_impl)
superbee_p.def_abstract_eval(superbee_abstract_eval)
xla.backend_specific_translations['gpu'][superbee_p] = _superbee_gpu_translation_rule
                 compute_right_eigenvectors):
  if compute_left_eigenvectors or compute_right_eigenvectors:
    raise NotImplementedError(
        'The derivatives of eigenvectors are not implemented, only '
        'eigenvalues. See '
        'https://github.com/google/jax/issues/2748 for discussion.')

  # Formula for derivative of eigenvalues w.r.t. a is eqn 4.60 in
  # https://arxiv.org/abs/1701.00392
  a, = primals
  da, = tangents
  l, v = eig(a, compute_left_eigenvectors=False)
  return [l], [jnp.sum(_solve(v, da.astype(v.dtype)) * _T(v), -1)]

eig_p = Primitive('eig')
eig_p.multiple_results = True
eig_p.def_impl(eig_impl)
eig_p.def_abstract_eval(eig_abstract_eval)
xla.translations[eig_p] = eig_translation_rule
xla.backend_specific_translations['cpu'][eig_p] = eig_cpu_translation_rule
batching.primitive_batchers[eig_p] = eig_batching_rule
ad.primitive_jvps[eig_p] = eig_jvp_rule


# Symmetric/Hermitian eigendecomposition

def eigh_impl(operand, lower):
  v, w = xla.apply_primitive(eigh_p, operand, lower=lower)
  return v, w

def eigh_translation_rule(c, operand, lower):
  shape = c.get_shape(operand)
def qr_translation_rule(c, operand, full_matrices):
  return c.QR(operand, full_matrices=full_matrices)

def qr_abstract_eval(operand, full_matrices):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2:
      raise ValueError("Argument to QR decomposition must have ndims >= 2")
    batch_dims = operand.shape[:-2]
    m = operand.shape[-2]
    n = operand.shape[-1]
    k = m if full_matrices else min(m, n)
    q = ShapedArray(batch_dims + (m, k), operand.dtype)
    r = ShapedArray(batch_dims + (k, n), operand.dtype)
  else:
    q = operand
    r = operand
  return core.AbstractTuple((q, r))

def qr_dtype_rule(operand, full_matrices=True):
  return operand.dtype

qr_p = Primitive('qr')
qr_p.def_impl(qr_impl)
qr_p.def_abstract_eval(qr_abstract_eval)
xla.translations[qr_p] = qr_translation_rule
c, b"mpi_sendrecv", operands=( sendbuf, token, ), shape=sh, opaque=descriptor, has_side_effect=True, ) # This function evaluates only the shapes during AST construction def mpi_sendrecv_abstract_eval(sendbuf, recvbuf, token, source, dest, sendtag, recvtag, comm, status): return ( abstract_arrays.ShapedArray(recvbuf.shape, recvbuf.dtype), abstract_arrays.abstract_token, ) mpi_sendrecv_p.multiple_results = True mpi_sendrecv_p.def_impl(mpi_sendrecv_impl) mpi_sendrecv_p.def_abstract_eval(mpi_sendrecv_abstract_eval) # assign to the primitive the correct encoder xla.backend_specific_translations["cpu"][ mpi_sendrecv_p] = mpi_sendrecv_xla_encode_cpu xla.backend_specific_translations["gpu"][ mpi_sendrecv_p] = mpi_sendrecv_xla_encode_gpu
        abstract_arrays.abstract_token,
    )


def mpi_allreduce_value_and_jvp(in_args, tan_args, op, comm):
    x, token = in_args
    x_tan, token_tan = tan_args

    res = Allreduce(x, token=token, op=op, comm=comm)

    # identify the correct adjoint
    if op == _MPI.SUM:
        jvp = (x_tan, token_tan)
    else:
        raise NotImplementedError(
            "The adjoint of allreduce for {} operation is not defined".format(op))

    return (res, jvp)


mpi_allreduce_p.multiple_results = True
mpi_allreduce_p.def_impl(mpi_allreduce_impl)
mpi_allreduce_p.def_abstract_eval(mpi_allreduce_abstract_eval)
ad.primitive_jvps[mpi_allreduce_p] = mpi_allreduce_value_and_jvp

# assign to the primitive the correct encoder
xla.backend_specific_translations["cpu"][mpi_allreduce_p] = mpi_allreduce_xla_encode
    assert len(params) == 0
    return ShapedArray((2,), 'uint32')


def _random_key_impl(*args, **params):
    assert len(args) == 0
    assert len(params) == 0
    raise ValueError("This parametrized function is randomized and therefore requires "
                     "a random key when applied, i.e. `apply(*inputs, key=PRNGKey(0))`.")


random_key_p = Primitive("random_key")
random_key_p.def_custom_bind(bind(random_key_p))
random_key_p.def_impl(_random_key_impl)
random_key_p.def_abstract_eval(_random_key_abstract_eval)


def random_key():
    """When called inside a parametrized function, this will return a unique random key
    derived from the `key` argument of `apply` or `init_parameters`."""
    return random_key_p.bind()


class parametrized(Primitive):
    """Represents a parametrized function, providing an `init_parameters` function
    for bundled initialization of all parameters, and an `apply` function to evaluate
    the function given these bundled parameters.

    For example, if a dense neural network layer is defined via:
                            _nan_like(c, vl))
  vr = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), vr,
                            _nan_like(c, vr))
  return xops.Tuple(c, [w, vl, vr])

def eig_batching_rule(batched_args, batch_dims):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return eig_p.bind(x), (0, 0, 0)

eig_p = Primitive('eig')
eig_p.multiple_results = True
eig_p.def_impl(eig_impl)
eig_p.def_abstract_eval(eig_abstract_eval)
xla.translations[eig_p] = eig_translation_rule
xla.backend_specific_translations['cpu'][eig_p] = eig_cpu_translation_rule
batching.primitive_batchers[eig_p] = eig_batching_rule


# Symmetric/Hermitian eigendecomposition

def eigh_impl(operand, lower):
  v, w = xla.apply_primitive(eigh_p, operand, lower=lower)
  return v, w

def eigh_translation_rule(c, operand, lower):
  shape = c.GetShape(operand)
        _nitems,
        _constant_s32_scalar(c, source),
        _constant_s32_scalar(c, tag),
        _constant_u64_scalar(c, to_mpi_ptr(comm)),
        _constant_u64_scalar(c, _dtype_ptr),
        _constant_u64_scalar(c, _status),
        token,
    )

    out = _ops.CustomCall(
        c,
        b"mpi_recv",
        operands=operands,
        shape=sh,
        has_side_effect=True,
    )

    return out


# This function evaluates only the shapes during AST construction
def mpi_recv_abstract_eval(xs, token, source, tag, comm, status):
    return (
        abstract_arrays.ShapedArray(xs.shape, xs.dtype),
        abstract_arrays.abstract_token,
    )


mpi_recv_p.multiple_results = True
mpi_recv_p.def_impl(mpi_recv_impl)
mpi_recv_p.def_abstract_eval(mpi_recv_abstract_eval)

# assign to the primitive the correct encoder
xla.backend_specific_translations["cpu"][mpi_recv_p] = mpi_recv_xla_encode