def test_unimplemented_interpreter_rules(self):
  foo_p = Primitive('foo')
  def foo(x):
    return foo_p.bind(x)

  jtu.check_raises(lambda: foo(1.0), NotImplementedError,
                   "Evaluation rule for 'foo' not implemented")

  jtu.check_raises(lambda: jit(foo)(1.0), NotImplementedError,
                   "Abstract evaluation for 'foo' not implemented")

  jtu.check_raises(lambda: grad(foo)(1.0), NotImplementedError,
                   "Forward-mode differentiation rule for 'foo' not implemented")

  foo_p.def_abstract_eval(lambda x: x)

  jtu.check_raises(lambda: jit(foo)(1.0), NotImplementedError,
                   "XLA translation rule for 'foo' not implemented")

  foo_p.def_impl(lambda x: x)
  defjvp(foo_p, lambda g, x: foo(g))

  jtu.check_raises(lambda: grad(foo)(1.0), NotImplementedError,
                   "Reverse-mode differentiation rule for 'foo' not implemented")
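# Hedged sketch (ours, not part of the test above): the registrations that
# would satisfy each check in turn, written against the same legacy Primitive
# API; `double_p` and `double` are our names.
from jax.core import Primitive
from jax.interpreters import ad

double_p = Primitive('double')

def double(x):
  return double_p.bind(x)

double_p.def_impl(lambda x: 2 * x)           # evaluation rule
double_p.def_abstract_eval(lambda x: x)      # abstract evaluation rule
ad.defjvp(double_p, lambda g, x: double(g))  # forward mode; reverse mode also
                                             # needs a transpose rule registered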
def test_defvjp_all_higher_order_revmode(self):
  foo_p = Primitive('foo')
  def foo(x):
    return 2. * foo_p.bind(x)

  ad.defvjp_all(foo_p, lambda x: (x**2, lambda g: (g * x ** 2,)))
  ans = api.grad(api.grad(foo))(3.)
  self.assertAllClose(ans, 2 * 2 * 3., check_dtypes=False)
def find_callback(
    prim: core.Primitive,
    vals: Sequence[core.Tracer],
    params: Dict[str, Any]) -> Union[core.Tracer, Sequence[core.Tracer]]:
  vals = prim.bind(*vals, **params)
  _contains_query(vals, queries)
  return vals
def rewrite_callback(
    prim: core.Primitive,
    vals: Sequence[core.Tracer],
    params: Dict[str, Any]) -> Union[core.Tracer, Sequence[core.Tracer]]:
  if prim in rules:
    return rules[prim](*vals, **params)
  return prim.bind(*vals, **params)
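# Hedged sketch (ours) of the `rules` table rewrite_callback consults: it maps
# a primitive to a replacement that is called instead of binding the primitive.
# The exp-to-identity rewrite below is purely illustrative.
from jax import lax

rules = {
    lax.exp_p: lambda x, **params: x,  # e.g. replace exp(x) by x while tracing
}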
def process_primitive(self, primitive: core.Primitive,
                      tracers: Sequence[TensorFlowTracer],
                      params) -> TensorFlowTracer:
  impl = self.get_primitive_impl(primitive)
  args_tf: Sequence[TfValOrUnit] = [t.val for t in tracers]
  # impl takes core.unit and returns core.unit when needed.
  val_out: TfValOrUnit = impl(*args_tf, **params)
  if primitive.multiple_results:
    out = util.safe_map(functools.partial(TensorFlowTracer, self),
                        val_out)  # type: ignore
  else:
    out = TensorFlowTracer(self, val_out)

  # Check that the impl rule returned a value of expected shape and dtype.
  if not core.skip_checks:
    expected_out_aval: core.AbstractValue = primitive.abstract_eval(
        *[t.aval for t in tracers], **params)
    if primitive.multiple_results:
      for o, expected_aval in zip(out, expected_out_aval):  # type: ignore
        assert o.aval == expected_aval, (
            f"{primitive}: out.aval = {o.aval}; expected {expected_aval}")
    else:
      assert out.aval == expected_out_aval, (  # type: ignore
          f"{primitive}: out.aval = {out.aval}; expected {expected_out_aval}")
  return out  # type: ignore
def test_defvjp_all_const(self):
  foo_p = Primitive('foo')
  def foo(x):
    return foo_p.bind(x)

  ad.defvjp_all(foo_p, lambda x: (x**2, lambda g: (12.,)))
  val_ans, grad_ans = api.value_and_grad(foo)(3.)
  self.assertAllClose(val_ans, 9., check_dtypes=False)
  self.assertAllClose(grad_ans, 12., check_dtypes=True)
def test_defvjp_all(self):
  foo_p = Primitive('foo')
  def foo(x):
    return 2. * foo_p.bind(x)

  ad.defvjp_all(foo_p, lambda x: (x**2, lambda g: (4 * g * np.sin(x),)))
  val_ans, grad_ans = api.value_and_grad(foo)(3.)
  self.assertAllClose(val_ans, 2 * 3.**2, check_dtypes=False)
  self.assertAllClose(grad_ans, 4 * 2 * onp.sin(3.), check_dtypes=False)
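# Hedged sketch (ours) of the ad.defvjp_all contract exercised by these tests:
# the rule maps primal inputs to (primal_out, vjp), where vjp takes an output
# cotangent and returns a tuple of input cotangents. `square_p` is our name.
square_p = Primitive('square')

def square(x):
  return square_p.bind(x)

ad.defvjp_all(square_p, lambda x: (x * x, lambda g: (2. * x * g,)))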
def _process_primitive(self, primitive: Primitive, flat_inputs, kwargs):
    if primitive in InitTrace._rules:
        return InitTrace._rules[primitive](self)(flat_inputs, kwargs)
    if isinstance(primitive, parametrized):
        return self.process_parametrized(primitive, *flat_inputs, **kwargs)
    return primitive.bind(*flat_inputs, **kwargs)
def decorator(fenics_function: Callable) -> Callable:
    @functools.wraps(fenics_function)
    def jax_solve_eval(*args):
        return jax_solve_eval_p.bind(*args)

    jax_solve_eval_p = Primitive("jax_solve_eval")
    jax_solve_eval_p.def_impl(lambda *args: solve_eval(
        fenics_function, fenics_templates, *args)[0])
    jax_solve_eval_p.def_abstract_eval(
        lambda *args: jax.abstract_arrays.make_shaped_array(
            solve_eval(fenics_function, fenics_templates, *args)[0]))

    def jax_solve_eval_batch(vector_arg_values, batch_axes):
        assert len(set(batch_axes)) == 1  # all batch axes must be the same
        # Batch axis is assumed to be zero; rewrite for the general case?
        assert batch_axes[0] == 0
        # Compute the function row-by-row.
        res = np.asarray([
            jax_solve_eval(*(vector_arg_values[j][i]
                             for j in range(len(batch_axes))))
            for i in range(vector_arg_values[0].shape[0])
        ])
        return res, batch_axes[0]

    jax.interpreters.batching.primitive_batchers[
        jax_solve_eval_p] = jax_solve_eval_batch

    # @trace("djax_solve_eval")
    def djax_solve_eval(*args):
        return djax_solve_eval_p.bind(*args)

    djax_solve_eval_p = Primitive("djax_solve_eval")
    # djax_solve_eval_p.multiple_results = True
    djax_solve_eval_p.def_impl(lambda *args: vjp_solve_eval(
        fenics_function, fenics_templates, *args))

    defvjp_all(jax_solve_eval_p, djax_solve_eval)

    return jax_solve_eval
def decorator(fenics_function: Callable) -> Callable:
    def jax_assemble_eval(*args):
        return jax_assemble_eval_p.bind(*args)

    jax_assemble_eval_p = Primitive("jax_assemble_eval")
    jax_assemble_eval_p.def_impl(
        lambda *args: assemble_eval(fenics_function, fenics_templates, *args)[0]
    )
    jax_assemble_eval_p.def_abstract_eval(
        lambda *args: jax.abstract_arrays.make_shaped_array(
            assemble_eval(fenics_function, fenics_templates, *args)[0]
        )
    )

    def jax_assemble_eval_batch(vector_arg_values, batch_axes):
        assert len(set(batch_axes)) == 1  # all batch axes must be the same
        # Batch axis is assumed to be zero; rewrite for the general case?
        assert batch_axes[0] == 0
        res = list(map(jax_assemble_eval, *vector_arg_values))
        res = np.asarray(res)
        return res, batch_axes[0]

    batching.primitive_batchers[jax_assemble_eval_p] = jax_assemble_eval_batch

    # @trace("djax_assemble_eval")
    def djax_assemble_eval(*args):
        return djax_assemble_eval_p.bind(*args)

    djax_assemble_eval_p = Primitive("djax_assemble_eval")
    # djax_assemble_eval_p.multiple_results = True
    djax_assemble_eval_p.def_impl(
        lambda *args: vjp_assemble_eval(fenics_function, fenics_templates, *args)
    )

    defvjp_all(jax_assemble_eval_p, djax_assemble_eval)

    return jax_assemble_eval
def default_process_primitive(
    self,
    primitive: jax_core.Primitive,
    tracers: List['HarvestTracer'],
    params: Dict[str, Any]) -> Union['HarvestTracer', List['HarvestTracer']]:
  context = trace_util.get_dynamic_context(self)
  vals = [t.val for t in tracers]
  if primitive is sow_p:
    outvals = context.process_sow(*vals, **params)
    return jax_util.safe_map(self.pure, outvals)
  outvals = primitive.bind(*vals, **params)
  if not primitive.multiple_results:
    outvals = [outvals]
  out_tracers = jax_util.safe_map(self.pure, outvals)
  if primitive.multiple_results:
    return out_tracers
  return out_tracers[0]
def test_defvjp_all_multiple_arguments(self):
  # Also tests passing in symbolic zero tangents, because in one case we
  # differentiate with respect to only the first argument.
  foo_p = Primitive('foo')
  def foo(x, y):
    return foo_p.bind(x, y)

  def vjpfun(x, y):
    out = x**2 + y**3
    vjp = lambda g: (g + x + y, g * x * 9.)
    return out, vjp

  ad.defvjp_all(foo_p, vjpfun)

  val_ans, grad_ans = api.value_and_grad(foo)(3., 4.)
  self.assertAllClose(val_ans, 3.**2 + 4.**3, check_dtypes=False)
  self.assertAllClose(grad_ans, 1. + 3. + 4., check_dtypes=False)

  ans = api.grad(foo, (0, 1))(3., 4.)
  self.assertAllClose(ans, (1. + 3. + 4., 1. * 3. * 9.), check_dtypes=False)
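# Hedged follow-up sketch (ours): differentiating with respect to argument 0
# alone, as value_and_grad does above, is what produces the symbolic zero
# tangent for argument 1 that the comment in the test refers to.
g0 = api.grad(foo, 0)(3., 4.)  # only the cotangent for the first input is kept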
def decorator(fenics_function: Callable) -> Callable:
    @functools.wraps(fenics_function)
    @custom_vjp
    def jax_fem_eval(*args):
        return jax_fem_eval_p.bind(*args)

    jax_fem_eval_p = Primitive("jax_fem_eval")
    jax_fem_eval_p.def_impl(
        lambda *args: evaluate_primal(fenics_function, fenics_templates, *args)[0]
    )
    jax_fem_eval_p.def_abstract_eval(
        lambda *args: jax.abstract_arrays.make_shaped_array(
            evaluate_primal(fenics_function, fenics_templates, *args)[0]
        )
    )

    def jax_fem_eval_batch(vector_arg_values, batch_axes):
        assert len(set(batch_axes)) == 1  # all batch axes must be the same
        # Batch axis is assumed to be zero; rewrite for the general case?
        assert batch_axes[0] == 0
        res = list(map(jax_fem_eval, *vector_arg_values))
        res = np.asarray(res)
        return res, batch_axes[0]

    jax.interpreters.batching.primitive_batchers[jax_fem_eval_p] = jax_fem_eval_batch

    def primal(*args):
        numpy_output, fenics_output, fenics_inputs, tape = evaluate_primal(
            fenics_function, fenics_templates, *args
        )
        return (
            numpy_output,
            (PyadjointMetadata(fenics_output, fenics_inputs, tape), args),
        )

    def pullback(aux_args, g):
        pb_fn = get_pullback_function(fenics_function, fenics_templates)
        # get_pullback_function returns a list, but custom_vjp requires a tuple.
        return tuple(pb_fn(aux_args, g))

    jax_fem_eval.defvjp(primal, pullback)

    return jax_fem_eval
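# Hedged toy sketch (ours) of the custom_vjp pattern used above, with the
# FEniCS machinery stripped out: a forward function returning residuals and a
# backward function returning a tuple of input cotangents.
from jax import custom_vjp

@custom_vjp
def f(x):
    return x * x

def f_fwd(x):
    return f(x), (x,)     # primal output plus residuals for the backward pass

def f_bwd(res, g):
    (x,) = res
    return (2. * x * g,)  # one cotangent per positional argument

f.defvjp(f_fwd, f_bwd)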
    v, w = operand, operand
  return core.AbstractTuple((v, w))

def eigh_cpu_translation_rule(c, operand, lower):
  shape = c.GetShape(operand)
  dtype = shape.element_type().type
  if len(shape.dimensions()) == 2 and dtype in _cpu_lapack_types:
    out = lapack.jax_syevd(c, operand, lower=lower)
    return c.Tuple(c.GetTupleElement(out, 0), c.GetTupleElement(out, 1))
  else:
    raise NotImplementedError(
        "Only unbatched eigendecomposition is implemented on CPU")

eigh_p = Primitive('eigh')
eigh_p.def_impl(eigh_impl)
eigh_p.def_abstract_eval(eigh_abstract_eval)
xla.translations[eigh_p] = eigh_translation_rule
xla.backend_specific_translations['Host'][eigh_p] = eigh_cpu_translation_rule

triangular_solve_dtype_rule = partial(
    binop_dtype_rule, _input_dtype,
    (_float | _complex, _float | _complex), 'triangular_solve')

def triangular_solve_shape_rule(a, b, left_side=False, **unused_kwargs):
  if a.ndim < 2:
    msg = "triangular_solve requires a.ndim to be at least 2, got {}."
    raise TypeError(msg.format(a.ndim))
  if a.shape[-1] != a.shape[-2]:
            == cosu.dtype == dt_tracer.dtype:
        raise ValueError('all inputs must have the same dtype')
    return superbee_p.bind(var, u_wgrid, v_wgrid, w_wgrid, maskW,
                           dxt, dyt, dzw, cost, cosu, dt_tracer)


def superbee_impl(*args, **kwargs):
    return xla.apply_primitive(superbee_p, *args, **kwargs)


def _superbee_gpu_translation_rule(computation_builder, var, u_wgrid, v_wgrid,
                                   w_wgrid, maskW, dxt, dyt, dzw, cost, cosu,
                                   dt_tracer):
    return _superbee(computation_builder, var, u_wgrid, v_wgrid, w_wgrid,
                     maskW, dxt, dyt, dzw, cost, cosu, dt_tracer)


def superbee_abstract_eval(var, u_wgrid, v_wgrid, w_wgrid, maskW, dxt, dyt,
                           dzw, cost, cosu, dt_tracer):
    aarr = abstract_arrays.ShapedArray(var.shape, var.dtype)
    return (aarr,) * 3


superbee_p = Primitive('superbee')
superbee_p.multiple_results = True
superbee_p.def_impl(superbee_impl)
superbee_p.def_abstract_eval(superbee_abstract_eval)
xla.backend_specific_translations['gpu'][superbee_p] = _superbee_gpu_translation_rule
from jax._src import traceback_util
traceback_util.register_exclusion(__file__)

Array = Any
map = safe_map

jaxval_adders: Dict[type, Callable] = {}

def add_jaxvals(x, y):
  return add_jaxvals_p.bind(x, y)

add_jaxvals_p: Primitive = Primitive('add_any')
add_any_p = add_jaxvals_p

@add_jaxvals_p.def_impl
def add_impl(xs, ys):
  return jaxval_adders[type(xs)](xs, ys)

@add_jaxvals_p.def_abstract_eval
def add_abstract(xs, ys):
  return lattice_join(xs, ys)

jaxval_zeros_likers: Dict[type, Array] = {}
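# Hedged usage sketch (ours): add_jaxvals dispatches on the Python type of its
# first argument through jaxval_adders. Real JAX populates this table
# elsewhere; the registration below is purely illustrative.
import numpy as onp

jaxval_adders[onp.ndarray] = onp.add
assert (add_jaxvals(onp.ones(3), onp.ones(3)) == 2.0).all()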
      i_old += 1
    elif x.shape[i_old] > 1 and new_sizes[i_new] <= 1:
      i_new += 1
    else:
      assert x.shape[i_old] == new_sizes[i_new]
      i_old += 1
      i_new += 1
  return out_indices,

op2ind[lax.reshape_p] = reshape2ind

# Custom Primitive
tex_var_p = Primitive('tex_var')
tex_var_p.def_impl(lambda x, *depends_on, **kwargs: x)
tex_var_p.def_abstract_eval(lambda x, *depends_on, **kwargs: x)

def tex_var_jvp(primals, tangents, name, is_alias):
  primal_out = tex_var(primals[0], name, is_alias, primals[1:])
  tangent_out = tex_var(tangents[0], 'd' + name, False, primals[1:])
  return primal_out, tangent_out

ad.primitive_jvps[tex_var_p] = tex_var_jvp

def tex_var_transpose(ct, x, *depends_on, **params):
  del x
  ct = tex_var(ct, '\\delta ' + params['name'][1:], False, depends_on)
  return (ct,) + (ad.Zero,) * len(depends_on)
import numpy as np
from jax.util import safe_map, safe_zip
from jax.core import Primitive
from jax.interpreters import ad, xla, batching, numpy_eval
from jax.lax import dynamic_update_slice_p

map = safe_map
zip = safe_zip

inplace_dynamic_update_slice_p = Primitive('inplace_dynamic_update_slice')
inplace_dynamic_update_slice_p.def_impl(dynamic_update_slice_p.impl)
inplace_dynamic_update_slice_p.def_abstract_eval(dynamic_update_slice_p.abstract_eval)
for rules in [xla.translations, ad.primitive_jvps, ad.primitive_transposes,
              batching.primitive_batchers]:
  rules[inplace_dynamic_update_slice_p] = rules[dynamic_update_slice_p]

def _numpy_inplace_dynamic_update_slice(operand, update, *start_indices):
  slices = tuple(map(slice, start_indices, np.add(start_indices, update.shape)))
  operand[slices] = update
  return operand

numpy_eval.np_impl[inplace_dynamic_update_slice_p] = \
    _numpy_inplace_dynamic_update_slice

def inplace_dynamic_update_slice(operand, update, start_indices):
  return inplace_dynamic_update_slice_p.bind(operand, update, *start_indices)
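# Hedged usage sketch (ours): under numpy_eval the registered impl mutates the
# backing numpy buffer in place; under normal tracing the primitive behaves
# like lax.dynamic_update_slice.
x = np.zeros((4, 4))
u = np.ones((2, 2))
y = inplace_dynamic_update_slice(x, u, (1, 1))  # writes u into x at offset (1, 1)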
def decorator(fenics_function: Callable) -> Callable:
    @functools.wraps(fenics_function)
    def jax_solve_eval(*args):
        return jax_solve_eval_p.bind(*args)

    jax_solve_eval_p = Primitive("jax_solve_eval")
    jax_solve_eval_p.def_impl(lambda *args: solve_eval(
        fenics_function, fenics_templates, *args)[0])
    jax_solve_eval_p.def_abstract_eval(
        lambda *args: jax.abstract_arrays.make_shaped_array(
            solve_eval(fenics_function, fenics_templates, *args)[0]))

    def jax_solve_eval_batch(vector_arg_values, batch_axes):
        assert len(set(batch_axes)) == 1  # all batch axes must be the same
        # Batch axis is assumed to be zero; rewrite for the general case?
        assert batch_axes[0] == 0
        res = list(map(jax_solve_eval, *vector_arg_values))
        res = np.asarray(res)
        return res, batch_axes[0]

    jax.interpreters.batching.primitive_batchers[
        jax_solve_eval_p] = jax_solve_eval_batch

    # @trace("jvp_jax_solve_eval")
    def jvp_jax_solve_eval(ps, ts):
        return jvp_jax_solve_eval_p.bind(ps, ts)

    jvp_jax_solve_eval_p = Primitive("jvp_jax_solve_eval")
    jvp_jax_solve_eval_p.multiple_results = True
    jvp_jax_solve_eval_p.def_impl(lambda ps, ts: jvp_solve_eval(
        fenics_function, fenics_templates, ps, ts))

    jax.interpreters.ad.primitive_jvps[jax_solve_eval_p] = jvp_jax_solve_eval

    # TODO: a JAX tracer goes inside the FEniCS wrappers and a zero array is
    # returned, because the FEniCS-to-NumPy conversion works only for concrete
    # arrays.
    vjp_jax_solve_eval_p = Primitive("vjp_jax_solve_eval")
    vjp_jax_solve_eval_p.def_impl(lambda ct, *args: vjp_solve_eval(
        fenics_function, fenics_templates, *args)[1](ct))

    jax.interpreters.ad.primitive_transposes[
        jax_solve_eval_p] = vjp_jax_solve_eval_p

    return jax_solve_eval
from ..utils import (
    HashableMPIType,
    default_primitive_impl,
    to_dtype_handle,
    to_mpi_handle,
    unpack_hashable,
    wrap_as_hashable,
    xla_constant_intc,
    xla_constant_uintptr,
)
from ..decorators import translation_rule_cpu, translation_rule_gpu
from ..validation import enforce_types
from ..comm import get_default_comm

# The Jax primitive
mpi_gather_p = Primitive("gather_mpi")  # Create the primitive
mpi_gather_impl = default_primitive_impl(mpi_gather_p)


# This function applies the primitive to an AST
@enforce_types(
    root=(_np.integer),
    comm=(type(None), _MPI.Intracomm, HashableMPIType),
    token=(type(None), xla.Token, core.Tracer),
)
def gather(
    x,
    root,
    *,
    comm=None,
    token=None,
def vjp_solve_eval(fenics_function: Callable,
                   fenics_templates: FenicsVariable,
                   *args: np.array) -> Tuple[np.array, Callable]:
    """Computes the gradients of the output with respect to the input.

    Input:
        fenics_function (callable): FEniCS function to be executed during
            the forward pass.
        args (tuple): JAX array representation of the input to fenics_function.

    Output:
        A pair where the first element is the value of fun applied to the
        arguments and the second element is a Python callable representing the
        VJP map from output cotangents to input cotangents. The returned VJP
        function must accept a value with the same shape as the value of fun
        applied to the arguments and must return a tuple with length equal to
        the number of positional arguments to fun.
    """
    numpy_output, fenics_solution, residual_form, fenics_inputs, bcs = solve_eval(
        fenics_function, fenics_templates, *args)

    # @trace("vjp_fun1")
    def vjp_fun1(g):
        return vjp_fun1_p.bind(g)

    vjp_fun1_p = Primitive("vjp_fun1")
    vjp_fun1_p.multiple_results = True
    vjp_fun1_p.def_impl(lambda g: tuple(
        vjp if vjp is not None else jax.ad_util.zeros_like_jaxval(args[i])
        for i, vjp in enumerate(
            vjp_solve_eval_impl(g, fenics_solution, residual_form,
                                fenics_inputs, bcs))))

    # @trace("vjp_fun1_abstract_eval")
    def vjp_fun1_abstract_eval(g):
        if len(args) > 1:
            return tuple(
                jax.abstract_arrays.ShapedArray(arg.shape, arg.dtype)
                for arg in args)
        else:
            return (jax.abstract_arrays.ShapedArray((1, *args[0].shape),
                                                    args[0].dtype),)

    vjp_fun1_p.def_abstract_eval(vjp_fun1_abstract_eval)

    # @trace("vjp_fun1_batch")
    def vjp_fun1_batch(vector_arg_values, batch_axes):
        """Computes the batched version of the primitive.

        This must be a JAX-traceable function.

        Args:
            vector_arg_values: a tuple of arguments, each being a tensor of
                matching shape.
            batch_axes: the axes that are being batched. See vmap documentation.

        Returns:
            A tuple of the result, and the result axis that was batched.
        """
        # _trace("Using vjp_fun1 to compute the batch:")
        # Batch axis is assumed to be zero; rewrite for the general case?
        assert batch_axes[0] == 0
        # Compute the function row-by-row.
        res = [
            vjp_fun1(vector_arg_values[0][i])
            for i in range(vector_arg_values[0].shape[0])
        ]
        # Transpose the resulting list.
        res_T = list(itertools.zip_longest(*res))
        return tuple(map(np.vstack, res_T)), (batch_axes[0],) * len(args)

    jax.interpreters.batching.primitive_batchers[vjp_fun1_p] = vjp_fun1_batch

    return numpy_output, vjp_fun1
def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors,
                 compute_right_eigenvectors):
  if compute_left_eigenvectors or compute_right_eigenvectors:
    raise NotImplementedError(
        'The derivatives of eigenvectors are not implemented, only '
        'eigenvalues. See '
        'https://github.com/google/jax/issues/2748 for discussion.')
  # Formula for derivative of eigenvalues w.r.t. a is eqn 4.60 in
  # https://arxiv.org/abs/1701.00392
  a, = primals
  da, = tangents
  l, v = eig(a, compute_left_eigenvectors=False)
  return [l], [jnp.sum(_solve(v, da.astype(v.dtype)) * _T(v), -1)]

eig_p = Primitive('eig')
eig_p.multiple_results = True
eig_p.def_impl(eig_impl)
eig_p.def_abstract_eval(eig_abstract_eval)
xla.translations[eig_p] = eig_translation_rule
xla.backend_specific_translations['cpu'][eig_p] = eig_cpu_translation_rule
batching.primitive_batchers[eig_p] = eig_batching_rule
ad.primitive_jvps[eig_p] = eig_jvp_rule


# Symmetric/Hermitian eigendecomposition

def eigh_impl(operand, lower):
  v, w = xla.apply_primitive(eigh_p, operand, lower=lower)
  return v, w
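# Hedged note (ours): concretely, eig_jvp_rule above computes
# dl_i = (V^{-1} dA V)_{ii}, the diagonal of V^{-1} dA V, by evaluating
# jnp.sum(_solve(v, da) * _T(v), -1) instead of materializing the full product.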
  scale = 1 / prod(fft_lengths)
  out = scale * mask * x
  assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
  # Use JAX's convention for complex gradients
  # https://github.com/google/jax/issues/6223#issuecomment-807740707
  return lax.conj(out)

def fft_transpose_rule(t, operand, fft_type, fft_lengths):
  if fft_type == xla_client.FftType.RFFT:
    result = _rfft_transpose(t, fft_lengths)
  elif fft_type == xla_client.FftType.IRFFT:
    result = _irfft_transpose(t, fft_lengths)
  else:
    result = fft(t, fft_type, fft_lengths)
  return result,

def fft_batching_rule(batched_args, batch_dims, fft_type, fft_lengths):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return fft(x, fft_type, fft_lengths), 0

fft_p = Primitive('fft')
fft_p.def_impl(fft_impl)
fft_p.def_abstract_eval(fft_abstract_eval)
xla.translations[fft_p] = fft_translation_rule
ad.deflinear2(fft_p, fft_transpose_rule)
batching.primitive_batchers[fft_p] = fft_batching_rule
if pocketfft:
  xla.backend_specific_translations['cpu'][fft_p] = pocketfft.pocketfft
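# Hedged toy sketch (ours) of the ad.deflinear2 pattern used for fft_p: for a
# linear primitive it is enough to register a transpose rule, and JAX derives
# the JVP from linearity. `scale2_p` is our name.
scale2_p = Primitive('scale2')
scale2_p.def_impl(lambda x: 2.0 * x)
scale2_p.def_abstract_eval(lambda x: x)
ad.deflinear2(scale2_p, lambda ct, x: (2.0 * ct,))  # transpose of x -> 2x is ct -> 2ct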
def _random_key_abstract_eval(*args, **params):
    assert len(args) == 0
    assert len(params) == 0
    return ShapedArray((2,), 'uint32')


def _random_key_impl(*args, **params):
    assert len(args) == 0
    assert len(params) == 0
    raise ValueError("This parametrized function is randomized and therefore requires "
                     "a random key when applied, i.e. `apply(*inputs, key=PRNGKey(0))`.")


random_key_p = Primitive("random_key")
random_key_p.def_custom_bind(bind(random_key_p))
random_key_p.def_impl(_random_key_impl)
random_key_p.def_abstract_eval(_random_key_abstract_eval)


def random_key():
    """When called inside a parametrized function, this will return a unique
    random key derived from the `key` argument of `apply` or `init_parameters`."""
    return random_key_p.bind()


class parametrized(Primitive):
    """Represents a parametrized function, providing an `init_parameters` function
    for bundled initialization of all parameters, and an `apply` function to
    evaluate the function given these bundled parameters.
def qr_translation_rule(c, operand, full_matrices):
  return c.QR(operand, full_matrices=full_matrices)

def qr_abstract_eval(operand, full_matrices):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2:
      raise ValueError("Argument to QR decomposition must have ndims >= 2")
    batch_dims = operand.shape[:-2]
    m = operand.shape[-2]
    n = operand.shape[-1]
    k = m if full_matrices else min(m, n)
    q = ShapedArray(batch_dims + (m, k), operand.dtype)
    r = ShapedArray(batch_dims + (k, n), operand.dtype)
  else:
    q = operand
    r = operand
  return core.AbstractTuple((q, r))

def qr_dtype_rule(operand, full_matrices=True):
  return operand.dtype

qr_p = Primitive('qr')
qr_p.def_impl(qr_impl)
qr_p.def_abstract_eval(qr_abstract_eval)
xla.translations[qr_p] = qr_translation_rule
from jax.lib import xla_client
from jax.interpreters import xla

from ..utils import (
    to_mpi_ptr,
    _unpack_builder,
    _ops,
    _constant_s32_scalar,
    _constant_u64_scalar,
    dtype_ptr,
)
from ..warn import warn_missing_omnistaging

# The Jax primitive
mpi_recv_p = Primitive("recv_mpi")  # Create the primitive


# This function applies the primitive to an AST
def Recv(
    x,
    source=_MPI.ANY_SOURCE,
    tag=_MPI.ANY_TAG,
    comm=_MPI.COMM_WORLD,
    status=None,
    token=None,
):
    """
    Recv(x, source=_MPI.ANY_SOURCE, tag=_MPI.ANY_TAG, comm=_MPI.COMM_WORLD,
         status=None, token=None)

    Receives the input `x` from the target rank `source` using the
    communicator `comm`
from .. import create_token
from ..utils import (
    to_mpi_ptr,
    _unpack_builder,
    _ops,
    _constant_s32_scalar,
    _constant_u64_scalar,
    dtype_ptr,
)
from ..warn import warn_missing_omnistaging

# The Jax primitive
mpi_allreduce_p = Primitive("allreduce_mpi")  # Create the primitive


# This function applies the primitive to an AST
def Allreduce(x, op, comm=_MPI.COMM_WORLD, token=None):
    """
    Allreduce(x, op, comm=_MPI.COMM_WORLD, token=None)

    Performs the Allreduce operation `op` on the input `x` using the
    communicator `comm`, which defaults to the world communicator.
    An optional token can be passed, which is used to force JAX to execute
    MPI operations in the correct order.

    Arguments:
        x: Array or scalar input.
        op: The reduction operation `MPI.Op` (e.g.: MPI.SUM)
        operands=(a, b, c, d),
        shape_with_layout=shape,
        operand_shapes_with_layout=(shape,) * 4,
        opaque=opaque)


def tridiag(a, b, c, d):
    if not a.shape == b.shape == c.shape == d.shape:
        raise ValueError('all inputs must have identical shape')
    if not a.dtype == b.dtype == c.dtype == d.dtype:
        raise ValueError('all inputs must have the same dtype')
    return tridiag_p.bind(a, b, c, d)  # transpose(res, (0, 2, 1))


def tridiag_impl(*args, **kwargs):
    return xla.apply_primitive(tridiag_p, *args, **kwargs)


def _tridiag_gpu_translation_rule(computation_builder, a, b, c, d):
    return _tridiag(computation_builder, a, b, c, d)


def tridiag_abstract_eval(a, b, c, d):
    return abstract_arrays.ShapedArray(a.shape, a.dtype)


tridiag_p = Primitive('tridiag')
tridiag_p.def_impl(tridiag_impl)
tridiag_p.def_abstract_eval(tridiag_abstract_eval)
xla.backend_specific_translations['gpu'][tridiag_p] = _tridiag_gpu_translation_rule
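# Hedged usage sketch (ours): tridiag dispatches a batch of tridiagonal solves
# to the GPU custom call wrapped by _tridiag; shapes below are illustrative
# only, since the actual layout is fixed by that kernel.
# a, b, c, d = (jnp.ones((4, 8)) for _ in range(4))
# x = tridiag(a, b, c, d)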
                            _nan_like(c, w))
  vl = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), vl,
                            _nan_like(c, vl))
  vr = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), vr,
                            _nan_like(c, vr))
  return xops.Tuple(c, [w, vl, vr])

def eig_batching_rule(batched_args, batch_dims):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return eig_p.bind(x), (0, 0, 0)

eig_p = Primitive('eig')
eig_p.multiple_results = True
eig_p.def_impl(eig_impl)
eig_p.def_abstract_eval(eig_abstract_eval)
xla.translations[eig_p] = eig_translation_rule
xla.backend_specific_translations['cpu'][eig_p] = eig_cpu_translation_rule
batching.primitive_batchers[eig_p] = eig_batching_rule


# Symmetric/Hermitian eigendecomposition

def eigh_impl(operand, lower):
  v, w = xla.apply_primitive(eigh_p, operand, lower=lower)
  return v, w
  grads = vmap(_standard_gamma_grad_one)(samples, alphas)
  return grads.reshape(alpha.shape)

def _standard_gamma_abstract_eval(key, alpha, jaxpr, aval, consts):
  return lax.maybe_tracer_tuple_to_abstract_tuple(aval)

def _standard_gamma_translate(c, key, alpha, jaxpr, aval, consts):
  xla_computation = xla.jaxpr_computation(jaxpr, consts, (), c.GetShape(key),
                                          c.GetShape(alpha))
  return c.Call(xla_computation, (key, alpha))

# define primitive
standard_gamma_p = Primitive('standard_gamma')
standard_gamma_p.def_impl(partial(xla.apply_primitive, standard_gamma_p))
standard_gamma_p.def_abstract_eval(_standard_gamma_abstract_eval)
xla.translations[standard_gamma_p] = _standard_gamma_translate
ad.defjvp2(standard_gamma_p, None,
           lambda tangent, sample, key, alpha, **kwargs:
               tangent * _standard_gamma_grad(sample, alpha))

@partial(jit, static_argnums=(2, 3))
def standard_gamma(key, alpha, shape=(), dtype=np.float32):
  shape = shape or np.shape(alpha)
  alpha = lax.convert_element_type(alpha, dtype)
  if np.shape(alpha) != shape:
    alpha = np.broadcast_to(alpha, shape)
  jaxpr, out, consts = partial_eval.trace_unwrapped_to_jaxpr(