def eig_abstract_eval(operand, *, compute_left_eigenvectors,
                      compute_right_eigenvectors):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
      raise ValueError(
          "Argument to nonsymmetric eigendecomposition must have "
          "shape [..., n, n], got shape {}".format(operand.shape))
    batch_dims = operand.shape[:-2]
    n = operand.shape[-1]
    dtype = (np.complex64 if dtypes.finfo(operand.dtype).bits == 32
             else np.complex128)
    dtype = dtypes.canonicalize_dtype(dtype)
    vl = vr = ShapedArray(batch_dims + (n, n), dtype)
    w = ShapedArray(batch_dims + (n,), dtype)
  else:
    raise NotImplementedError

  output = [w]
  if compute_left_eigenvectors:
    output.append(vl)
  if compute_right_eigenvectors:
    output.append(vr)
  return tuple(output)
def _lu_abstract_eval(operand):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2:
      raise ValueError("Argument to LU decomposition must have ndims >= 2")
    batch_dims = operand.shape[:-2]
    m = operand.shape[-2]
    n = operand.shape[-1]
    pivot = ShapedArray(batch_dims + (min(m, n),), jnp.int32)
    perm = ShapedArray(batch_dims + (m,), jnp.int32)
  else:
    pivot = operand
    perm = operand
  return operand, pivot, perm
def qr_abstract_eval(operand, full_matrices):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2:
      raise ValueError("Argument to QR decomposition must have ndims >= 2")
    batch_dims = operand.shape[:-2]
    m = operand.shape[-2]
    n = operand.shape[-1]
    k = m if full_matrices else min(m, n)
    q = ShapedArray(batch_dims + (m, k), operand.dtype)
    r = ShapedArray(batch_dims + (k, n), operand.dtype)
  else:
    q = operand
    r = operand
  return q, r
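# Hedged illustration (not from the original source): a tiny standalone helper
# that mirrors the shape rule implemented by qr_abstract_eval above, where for
# an m-by-n operand q is (m, k) and r is (k, n), with
# k = m if full_matrices else min(m, n). `_demo_qr_shapes` is hypothetical and
# exists only to make the rule concrete.
def _demo_qr_shapes(shape, full_matrices):
  *batch, m, n = shape
  k = m if full_matrices else min(m, n)
  return tuple(batch) + (m, k), tuple(batch) + (k, n)

assert _demo_qr_shapes((7, 5, 3), full_matrices=False) == ((7, 5, 3), (7, 3, 3))
assert _demo_qr_shapes((7, 5, 3), full_matrices=True) == ((7, 5, 5), (7, 5, 3))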
def eigh_abstract_eval(operand, lower):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
      raise ValueError(
          "Argument to symmetric eigendecomposition must have shape "
          "[..., n, n], got shape {}".format(operand.shape))
    batch_dims = operand.shape[:-2]
    n = operand.shape[-1]
    v = ShapedArray(batch_dims + (n, n), operand.dtype)
    w = ShapedArray(batch_dims + (n,),
                    lax_internal._complex_basetype(operand.dtype))
  else:
    v, w = operand, operand
  return v, w
def all_reduce(x):
  replica_groups_protos = xc.make_replica_groups(
      _replica_groups(axis_env, axis_name, axis_index_groups))
  scalar = ShapedArray((), c.get_shape(x).numpy_dtype())
  computation = xla.primitive_subcomputation(prim, scalar, scalar)
  return xops.AllReduce(x, computation, replica_groups_protos, None, None)
def _all_to_all_abstract_eval(x, axis_name, split_axis, concat_axis,
                              axis_index_groups):
  input_aval = raise_to_shaped(x)
  shape = list(input_aval.shape)
  size = shape.pop(split_axis)
  shape.insert(concat_axis, size)
  return ShapedArray(tuple(shape), input_aval.dtype, weak_type=False)
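# Hedged illustration (hypothetical helper, not part of the original module):
# the abstract eval above removes the split axis and re-inserts its size at the
# concat axis, leaving the dtype unchanged.
def _demo_all_to_all_shape(shape, split_axis, concat_axis):
  shape = list(shape)
  size = shape.pop(split_axis)
  shape.insert(concat_axis, size)
  return tuple(shape)

assert _demo_all_to_all_shape((8, 2, 4), split_axis=0, concat_axis=2) == (2, 4, 8)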
def omnistaging_disabler() -> None:
  global axis_index

  psum_p.bind = partial(core.Primitive.bind, psum_p)  # type: ignore
  psum_p.def_impl(partial(pxla.apply_parallel_primitive, psum_p))  # type: ignore
  pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (x * prod(shape) for x in args)  # type: ignore

  def _axis_index_bind(*, axis_name):
    dynamic_axis_env = pxla._thread_local_state.dynamic_axis_env
    frame = dynamic_axis_env[axis_name]
    sizes = dynamic_axis_env.sizes[:dynamic_axis_env.index(frame)+1]
    nreps = dynamic_axis_env.nreps
    trace = frame.pmap_trace

    out_aval = ShapedArray((), np.int32)
    out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
    eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p,
                            dict(nreps=nreps, sizes=sizes, axis_name=axis_name),
                            source_info_util.current())
    out_tracer.recipe = eqn

    return out_tracer

  def _axis_index_translation_rule(c, nreps, sizes, axis_name):
    div = xb.constant(c, np.array(nreps // prod(sizes), dtype=np.uint32))
    mod = xb.constant(c, np.array(sizes[-1], dtype=np.uint32))
    unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
    return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))

  axis_index_p.def_custom_bind(_axis_index_bind)
  axis_index_p.def_abstract_eval(
      lambda *args, **params: ShapedArray((), np.int32))
  xla.translations[axis_index_p] = _axis_index_translation_rule
def _allreduce_translation_rule(prim, c, val, *, axis_name, axis_index_groups,
                                axis_env, platform):
  replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups)
  dtype = c.get_shape(val).numpy_dtype()
  scalar = ShapedArray((), dtype)
  computation = xla.primitive_subcomputation(prim, scalar, scalar)
  replica_groups_protos = xc.make_replica_groups(replica_groups)
  return xops.AllReduce(val, computation, replica_groups_protos, None, None)
def svd_abstract_eval(operand, full_matrices, compute_uv):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2:
      raise ValueError(
          "Argument to singular value decomposition must have ndims >= 2")
    batch_dims = operand.shape[:-2]
    m = operand.shape[-2]
    n = operand.shape[-1]
    s = ShapedArray(batch_dims + (min(m, n),),
                    lax_internal._complex_basetype(operand.dtype))
    if compute_uv:
      u = ShapedArray(batch_dims + (m, m if full_matrices else min(m, n)),
                      operand.dtype)
      vt = ShapedArray(batch_dims + (n if full_matrices else min(m, n), n),
                       operand.dtype)
      return s, u, vt
    else:
      return s,
  else:
    raise NotImplementedError
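# Hedged illustration (hypothetical helper, written only for this note): the
# singular-value shape rule above, i.e. s has min(m, n) entries and, when
# compute_uv is set, u is (m, m) or (m, min(m, n)) and vt is (n, n) or
# (min(m, n), n) depending on full_matrices.
def _demo_svd_shapes(m, n, full_matrices):
  k = min(m, n)
  return (k,), (m, m if full_matrices else k), (n if full_matrices else k, n)

assert _demo_svd_shapes(5, 3, full_matrices=False) == ((3,), (5, 3), (3, 3))
assert _demo_svd_shapes(3, 5, full_matrices=True) == ((3,), (3, 3), (5, 5))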
def get_structure(eqn: Optional[JaxprEqn],
                  invals: List[Union[ShapedArray, AbstractValue]],
                  idx: int,
                  _s_rules: bool) -> Structure:
  if any(i is AbstractValue for i in invals):
    raise TypeError(invals)

  if eqn is None:
    # Identity function
    primitive = None
    cts_in = invals[0]
    assert idx == 0
  else:
    if len(eqn.outvars) != 1:
      raise NotImplementedError(eqn)

    cts_in = eqn.outvars[0].aval
    primitive = eqn.primitive

    assert len(invals) == len(eqn.invars)
    assert 0 <= idx < len(eqn.invars)

  if not isinstance(cts_in, ShapedArray):
    raise TypeError(cts_in)

  if primitive in STRUCTURE_RULES and _s_rules:
    structure = STRUCTURE_RULES[primitive](eqn, idx, invals, cts_in)
  else:
    # No simplification rule found.
    structure = Structure()

  # TODO(romann): can we avoid special-casing `reshape`s?
  if primitive == lax.reshape_p:
    cts_in = ShapedArray(invals[idx].shape, invals[idx].dtype)

  # Check that the number of traced output and input axes match.
  assert len(structure.in_trace) == len(structure.out_trace)

  # Check that input and output traced sizes are the same.
  out_trace_size = utils.size_at(cts_in, structure.out_trace)
  in_trace_size = utils.size_at(invals[idx], structure.in_trace)
  assert in_trace_size == out_trace_size

  # Check that the number of input/output diagonal axes match.
  assert len(structure.out_diagonal) == len(structure.in_diagonal)

  # Check that for each output diagonal axis every input axis is either `None`
  # or has the matching size, and that the `idx` input axis is not `None`.
  for out_d, in_d in zip(structure.out_diagonal, structure.in_diagonal):
    assert len(in_d) == len(invals)
    assert in_d[idx] is not None
    for ix, i in enumerate(in_d):
      if i is not None:
        assert invals[ix].shape[i] == cts_in.shape[out_d]

  return structure
def fft_abstract_eval(x, fft_type, fft_lengths):
  if fft_type == xla_client.FftType.RFFT:
    shape = (x.shape[:-len(fft_lengths)] + fft_lengths[:-1]
             + (fft_lengths[-1] // 2 + 1,))
    dtype = _complex_dtype(x.dtype)
  elif fft_type == xla_client.FftType.IRFFT:
    shape = x.shape[:-len(fft_lengths)] + fft_lengths
    dtype = _real_dtype(x.dtype)
  else:
    shape = x.shape
    dtype = x.dtype
  return ShapedArray(shape, dtype)
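# Hedged illustration (hypothetical helper, not from the original file) of the
# RFFT branch above: the trailing fft_lengths replace the trailing input dims,
# and the last length n contributes n // 2 + 1 complex outputs.
def _demo_rfft_shape(shape, fft_lengths):
  return (shape[:-len(fft_lengths)] + fft_lengths[:-1]
          + (fft_lengths[-1] // 2 + 1,))

assert _demo_rfft_shape((4, 8, 10), (8, 10)) == (4, 8, 6)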
def _reduce_window_sum_translation_rule(ctx, avals_in, avals_out, operand, *,
                                        window_dimensions, window_strides,
                                        padding, base_dilation,
                                        window_dilation):
  operand_aval, = avals_in
  scalar = ShapedArray((), operand_aval.dtype)
  return [xops.ReduceWindowWithGeneralPadding(
      operand,
      xla.pyval_to_ir_constant(ctx.builder, np.array(0, operand_aval.dtype)),
      xla.primitive_subcomputation(ctx.platform, ctx.axis_env, lax.add_p,
                                   scalar, scalar),
      window_dimensions, window_strides, base_dilation, window_dilation,
      padding)]
def _axis_index_bind(*, axis_name):
  dynamic_axis_env = pxla._thread_local_state.dynamic_axis_env
  frame = dynamic_axis_env[axis_name]
  sizes = dynamic_axis_env.sizes[:dynamic_axis_env.index(frame)+1]
  nreps = dynamic_axis_env.nreps
  trace = frame.pmap_trace

  out_aval = ShapedArray((), np.int32)
  out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
  eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p,
                          dict(nreps=nreps, sizes=sizes, axis_name=axis_name),
                          source_info_util.current())
  out_tracer.recipe = eqn

  return out_tracer
def _allreduce_translation_rule(prim, c, *args, axis_name, axis_index_groups,
                                axis_env, platform):
  if platform in ("cpu", "tpu"):
    return _notuple_allreduce_translation_rule(
        prim, c, *args, axis_name=axis_name,
        axis_index_groups=axis_index_groups, axis_env=axis_env,
        platform=platform)

  # XLA's tuple all-reduce doesn't support different dtypes in the same
  # allreduce. Instead, we perform one all-reduce for each argument input type.
  args_by_type = collections.defaultdict(lambda: ([], []))
  for i, arg in enumerate(args):
    indices, dtype_args = args_by_type[c.get_shape(arg).numpy_dtype()]
    indices.append(i)
    dtype_args.append(arg)

  # The outputs, in the original argument order.
  out = [None] * len(args)
  replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups)
  replica_groups_protos = xc.make_replica_groups(replica_groups)
  for dtype, (indices, dtype_args) in sorted(args_by_type.items()):
    is_complex = dtypes.issubdtype(dtype, np.complexfloating)
    n = len(dtype_args)
    if is_complex and prim is lax.add_p:
      # TODO(b/141575627): we handle complex-dtype sum-reduction directly as a
      # special case because it's not currently handled by XLA:GPU
      dtype_args = ([xops.Real(x) for x in dtype_args] +
                    [xops.Imag(x) for x in dtype_args])
    scalar = ShapedArray((), c.get_shape(dtype_args[0]).numpy_dtype())
    computation = xla.primitive_subcomputation(prim, scalar, scalar)
    all_reduce = xops.AllReduce(xops.Tuple(c, dtype_args), computation,
                                replica_groups_protos, None, None)
    if is_complex and prim is lax.add_p:
      xs = [xops.Complex(xops.GetTupleElement(all_reduce, i),
                         xops.GetTupleElement(all_reduce, n + i))
            for i in range(n)]
    else:
      xs = [xops.GetTupleElement(all_reduce, i) for i in range(n)]
    for i, x in zip(indices, xs):
      out[i] = x
  return xops.Tuple(c, out)
def _reduce_window_abstract_eval_rule(
    *avals, jaxpr, consts, window_dimensions, window_strides, padding,
    base_dilation, window_dilation):
  operand_avals, init_val_avals = util.split_list(avals, [len(avals) // 2])
  if any(o.dtype != iv.dtype for o, iv in zip(operand_avals, init_val_avals)):
    msg = ("reduce_window got inconsistent dtypes for operands and init_values:"
           " got operand dtypes {} and init_value dtypes {}.")
    raise TypeError(msg.format([o.dtype for o in operand_avals],
                               [iv.dtype for iv in init_val_avals]))
  if any(len(v.shape) != 0 for v in init_val_avals):
    msg = ("reduce_window expected init_values to be scalars but init_values "
           "have shapes {}.")
    raise TypeError(msg.format([v.shape for v in init_val_avals]))
  out_shape = _common_reduce_window_shape_rule(
      operand_avals[0], window_dimensions, window_strides, padding,
      base_dilation, window_dilation)
  return tuple(ShapedArray(out_shape, op.dtype) for op in operand_avals)
def _select_and_scatter_add_translation(
    ctx, avals_in, avals_out, source, operand, *, select_prim,
    window_dimensions, window_strides, padding, expand_padding):
  source_aval, operand_aval = avals_in
  c = ctx.builder
  dtype = operand_aval.dtype
  scalar = ShapedArray((), dtype)
  select = xla.primitive_subcomputation(
      ctx.platform, ctx.axis_env, select_prim, scalar, scalar)
  scatter = xla.primitive_subcomputation(
      ctx.platform, ctx.axis_env, lax.or_p if dtype == np.bool_ else lax.add_p,
      scalar, scalar)
  zero = xla.pyval_to_ir_constant(c, np.array(0, dtype))
  # TODO(b/161704903): remove this workaround when XLA:CPU bug is fixed.
  expand_padding = (expand_padding and
                    not all(lo == 0 and hi == 0 for (lo, hi) in padding))
  if expand_padding:
    original_padding = padding
    identity = (lax._get_max_identity if select_prim is lax.ge_p
                else lax._get_min_identity)
    pads = [(lo, hi, 0) for (lo, hi) in padding]
    operand = xops.Pad(operand, xla.pyval_to_ir_constant(c, identity(dtype)),
                       xc.make_padding_config(pads))
    padding = [(0, 0) for _ in padding]
  output = xops.SelectAndScatterWithGeneralPadding(
      operand, select, window_dimensions, window_strides, padding, source,
      zero, scatter)
  if expand_padding:
    start_indices = [lo for (lo, hi) in original_padding]
    stop_indices = [lo + d for ((lo, hi), d) in
                    zip(original_padding, operand_aval.shape)]
    output = xops.Slice(output, start_indices, stop_indices,
                        [1] * len(start_indices))
  return [output]
def sharded_aval(aval: core.ShapedArray,
                 sharding: Optional[xc.OpSharding]) -> core.ShapedArray:
  """Returns the new aval sharded based on sharding proto."""
  if sharding is None:
    return aval
  if (sharding.type == xc.OpSharding.Type.REPLICATED or
      sharding.type == xc.OpSharding.Type.MANUAL):
    return aval

  sharded_shape = []
  tile_rank = len(sharding.tile_assignment_dimensions)
  if sharding.replicate_on_last_tile_dim:
    tile_rank -= 1
  if sharding.last_tile_dims:
    tile_rank -= len(sharding.last_tile_dims)
  if tile_rank == 0:
    return aval

  for i in range(tile_rank):
    partitions = sharding.tile_assignment_dimensions[i]
    assert partitions > 0
    sharded_shape.append((aval.shape[i] + partitions - 1) // partitions)
  return aval.update(tuple(sharded_shape))
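# Hedged illustration (assumed example values, not from the original module):
# each tiled dimension above is ceil-divided by its partition count, so a
# size-10 axis split over 4 devices is reported as size 3 per shard.
def _demo_shard_dim(dim, partitions):
  return (dim + partitions - 1) // partitions

assert _demo_shard_dim(10, 4) == 3
assert _demo_shard_dim(12, 4) == 3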
  div = xb.constant(c, np.array(nreplicas * prod(axis_env.sizes[axis_pos+1:]),
                                dtype=np.uint32))
  mod = xb.constant(c, np.array(axis_env.sizes[axis_pos], dtype=np.uint32))
  unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
  return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))

def _axis_index_soft_pmap_rule(vals, mapped, chunk_size, *, axis_name):
  assert not vals and not mapped
  idx = axis_index(axis_name)  # type: ignore
  return idx * chunk_size + np.arange(chunk_size, dtype=np.int32), True

axis_index_p = core.Primitive('axis_index')
xla.parallel_translations[axis_index_p] = _axis_index_translation_rule
pxla.soft_pmap_rules[axis_index_p] = _axis_index_soft_pmap_rule  # type: ignore
axis_index_p.def_abstract_eval(
    lambda *args, **params: ShapedArray((), np.int32))
pxla.multi_host_supported_collectives.add(axis_index_p)

# Axis index doesn't get any arguments, so that the default bind would have no
# way to call into a data-dependency based trace such as vmap. Each trace that
# wants to bind an axis name has to additionally implement `process_axis_index`
# and put its main trace on the axis env stack.
def _axis_index_bind(*, axis_name):
  if not isinstance(axis_name, (tuple, list)):
    axis_name = (axis_name,)
  inner_size = 1
  index = 0
  for name in reversed(axis_name):
    frame = core.axis_frame(name)
    if frame.main_trace is not None:
      trace = frame.main_trace.with_cur_sublevel()
def _get_invals(idx, *xs):
  return [ShapedArray(x.shape, x.dtype) if idx == i else x
          for i, x in enumerate(xs)]
def _test_primitive(self, primitive: Optional[Primitive], shapes, dtype,
                    params):
  xs = _get_inputs(shapes, dtype)
  n = len(xs)
  eqn, f = _get_f_and_eqn(params, primitive, *xs)

  out = f(*xs)
  cts_in = ShapedArray(out.shape, out.dtype)

  argnums = tuple(range(n))
  js_fwd = jax.jacfwd(f, argnums)(*xs)
  js_rev = jax.jacrev(f, argnums)(*xs)

  for idx in range(n):
    if primitive == lax.conv_general_dilated_p and idx == 0:
      raise absltest.SkipTest('Jacobian of CNN wrt inputs not implemented.')

    if primitive == lax.div_p and idx == 1:
      raise absltest.SkipTest('Division is linear only in the first arg.')

    invals = _get_invals(idx, *xs)
    j_fwd, j_rev = js_fwd[idx], js_rev[idx]

    if primitive in rules.JACOBIAN_RULES:
      j_rule = rules.JACOBIAN_RULES[primitive](eqn, idx, invals, cts_in)
    else:
      warnings.warn(f'Jacobian rule for {primitive} at position {idx} not '
                    f'found.')
      j_rule = None

    with self.subTest(f'Jacobian ({idx})'):
      self._compare_jacobians(j_fwd, j_rev, j_rule, primitive)

    structure = rules.STRUCTURE_RULES[primitive](eqn, idx, invals, cts_in)

    j = j_fwd if j_rule is None else j_rule

    if primitive == lax.reshape_p:
      out_ndim = xs[0].ndim
      j = j.transpose(
          tuple(xs[0].ndim + i for i in onp.argsort(structure.in_trace)) +
          tuple(i for i in onp.argsort(structure.in_trace)))
      j = j.reshape(
          xs[0].shape +
          tuple(xs[0].shape[i] for i in onp.argsort(structure.in_trace)))
    else:
      out_ndim = out.ndim

    with self.subTest(f'Diagonal axes ({idx})'):
      for i, o in zip(structure.in_diagonal, structure.out_diagonal):
        self._assert_is_diagonal(j=j, axis1=out_ndim + i[idx], axis2=o,
                                 constant_diagonal=False)

    with self.subTest(f'Constant diagonal axes ({idx})'):
      for i, o in zip(structure.in_trace, structure.out_trace):
        self._assert_is_diagonal(j=j, axis1=out_ndim + i, axis2=o,
                                 constant_diagonal=True)

    with self.subTest(f'Input broadcast axes ({idx})'):
      for i in structure.in_broadcast:
        self._assert_constant(j=j, axis=i)

    with self.subTest(f'Output broadcast axes ({idx})'):
      for i in structure.out_broadcast:
        self._assert_constant(j=j, axis=i)
from jax import core
from jax import lax
from jax import numpy as jnp
from jax import test_util as jtu
from jax.abstract_arrays import make_shaped_array
from jax.api import jvp, linearize, vjp, jit, make_jaxpr
from jax.core import UnshapedArray, ShapedArray
from jax.tree_util import (tree_flatten, tree_unflatten, tree_multimap,
                           tree_reduce, tree_leaves)
from jax.util import partial
from jax.interpreters import partial_eval as pe

from jax.config import config
config.parse_flags_with_absl()

_ = pe.PartialVal.unknown(UnshapedArray(np.float32))
__ = pe.PartialVal.unknown(ShapedArray((), np.float32))


def call(f, *args):
  return jit(f)(*args)


def simple_fun(x, y):
  return jnp.sin(x * y)


def simple_fun_fanout(x, y):
  return jnp.sin(x * y) * x


def fun_with_call(x):
def aval(self):
  return ShapedArray(self.polymorphic_shape,
                     dtypes.canonicalize_dtype(self.dtype))
def _make_abstract_python_scalar(typ, val):
  return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val),
                     weak_type=True)
def _array_aval_from_xla_shape(xla_shape):
  # This function instantiates the assumption that we can map from XLA array
  # types to JAX array types.
  # TODO(mattjj): remove assumption can map XLA array types to JAX array types
  assert not xla_shape.is_tuple()
  return ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
def _all_gather_abstract_eval(x, *, all_gather_dimension, axis_name,
                              axis_index_groups, axis_size):
  x_aval = raise_to_shaped(x)
  new_shape = list(x_aval.shape)
  new_shape.insert(all_gather_dimension, axis_size)
  return ShapedArray(new_shape, x_aval.dtype)
def aval(self):
  return ShapedArray(self.polymorphic_shape, self.dtype)