def read_shaped(v):
    """Return the abstract value describing a jaxpr atom ``v``.

    Literals are described from their payload ``v.val``: a Python ``float`` or
    ``int`` payload yields a rank-0 ``ShapedArray`` whose dtype argument is the
    payload's Python type; any other payload is described by its own ``shape``
    and ``dtype`` attributes. Anything that is not a ``core.Literal`` is looked
    up in ``abstract``, which is taken from the surrounding scope.
    """
    if not isinstance(v, core.Literal):
        return abstract[v]

    payload = v.val
    if isinstance(payload, (float, int)):
        # Plain Python scalars have no ``shape``/``dtype`` attributes.
        return ShapedArray((), type(payload))
    return ShapedArray(payload.shape, payload.dtype)
def _contact_points_abstract_eval(*args):
    """Validate the operands and describe the three outputs.

    Args:
      *args: abstract operands exposing ``dtype`` and ``shape``.

    Returns:
      A 3-tuple of ``ShapedArray``: two ``float64`` arrays and one ``int32``
      array, all with the shape shared by the operands.

    Raises:
      ValueError: if any operand is not ``float64``, or if the operands do not
        all share the first operand's shape.
    """
    for operand in args:
        if operand.dtype != np.float64:
            raise ValueError("float64 precision is required")

    # The first operand fixes the reference shape for all the others.
    shape = args[0].shape
    for operand in args[1:]:
        if operand.shape != shape:
            raise ValueError("Dimension mismatch")

    first = ShapedArray(shape, np.float64)
    second = ShapedArray(shape, np.float64)
    third = ShapedArray(shape, np.int32)
    return (first, second, third)
def eigh_abstract_eval(operand, lower):
    """Abstract evaluation rule for the symmetric eigendecomposition.

    Args:
      operand: abstract value of the input. A ``ShapedArray`` operand must have
        shape ``[..., n, n]``.
      lower: not used when computing the output description.

    Returns:
      ``core.AbstractTuple((v, w))``. For a ``ShapedArray`` operand, ``v`` has
      shape ``[..., n, n]`` and ``w`` has shape ``[..., n]``, both with the
      operand's dtype. For any other operand, both entries are ``operand``.

    Raises:
      ValueError: if a ``ShapedArray`` operand has fewer than two dimensions or
        its two trailing dimensions differ.
    """
    if not isinstance(operand, ShapedArray):
        return core.AbstractTuple((operand, operand))

    if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
        raise ValueError(
            "Argument to symmetric eigendecomposition must have shape [..., n, n]")

    leading = operand.shape[:-2]
    size = operand.shape[-1]
    v_aval = ShapedArray(leading + (size, size), operand.dtype)
    w_aval = ShapedArray(leading + (size,), operand.dtype)
    return core.AbstractTuple((v_aval, w_aval))
def eig_abstract_eval(operand):
    """Abstract evaluation rule for the nonsymmetric eigendecomposition.

    Args:
      operand: abstract value of the input. A ``ShapedArray`` operand must have
        shape ``[..., n, n]``.

    Returns:
      ``core.AbstractTuple((w, vl, vr))``. For a ``ShapedArray`` operand,
      ``vl`` and ``vr`` are the same ``[..., n, n]`` abstract array with the
      operand's dtype, and ``w`` has shape ``[..., n]`` with dtype
      ``lax.lax._complex_basetype(operand.dtype)``. For any other operand, all
      three entries are ``operand``.

    Raises:
      ValueError: if a ``ShapedArray`` operand has fewer than two dimensions or
        its two trailing dimensions differ.
    """
    if not isinstance(operand, ShapedArray):
        return core.AbstractTuple((operand, operand, operand))

    if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
        raise ValueError("Argument to nonsymmetric eigendecomposition must have "
                         "shape [..., n, n], got shape {}".format(operand.shape))

    leading = operand.shape[:-2]
    size = operand.shape[-1]
    vl = vr = ShapedArray(leading + (size, size), operand.dtype)
    # NOTE(review): `w` is given the base type returned by
    # `_complex_basetype`; confirm this matches the dtype actually produced by
    # the lowering for nonsymmetric inputs.
    w = ShapedArray(leading + (size,),
                    lax.lax._complex_basetype(operand.dtype))
    return core.AbstractTuple((w, vl, vr))
def eigh_abstract_eval(operand, lower):
    """Abstract evaluation rule for the symmetric eigendecomposition.

    Args:
      operand: abstract value of the input. A ``ShapedArray`` operand must have
        shape ``[..., n, n]``.
      lower: not used when computing the output description.

    Returns:
      A pair ``(v, w)``. For a ``ShapedArray`` operand, ``v`` has shape
      ``[..., n, n]`` with the operand's dtype and ``w`` has shape ``[..., n]``
      with dtype ``lax.lax._complex_basetype(operand.dtype)``. For any other
      operand, both entries are ``operand`` itself.

    Raises:
      ValueError: if a ``ShapedArray`` operand has fewer than two dimensions or
        its two trailing dimensions differ.
    """
    if not isinstance(operand, ShapedArray):
        return operand, operand

    if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
        # The two literals are concatenated by the parser, so the separating
        # space has to be part of one of them.
        raise ValueError(
            "Argument to symmetric eigendecomposition must have shape [..., n, n], "
            "got shape {}".format(operand.shape))

    batch_dims = operand.shape[:-2]
    n = operand.shape[-1]
    v = ShapedArray(batch_dims + (n, n), operand.dtype)
    w = ShapedArray(batch_dims + (n,), lax.lax._complex_basetype(operand.dtype))
    return v, w
def svd_abstract_eval(operand, full_matrices, compute_uv):
    """Abstract evaluation rule for the singular value decomposition.

    Args:
      operand: ``ShapedArray`` of shape ``[..., m, n]``.
      full_matrices: when truthy, ``u`` is ``[..., m, m]`` and ``vt`` is
        ``[..., n, n]``; otherwise the inner dimension is ``min(m, n)``.
      compute_uv: not used when computing the output description; all three
        outputs are always described.

    Returns:
      A triple ``(s, u, vt)`` of ``ShapedArray``. ``s`` has shape
      ``[..., min(m, n)]`` and dtype
      ``lax.lax._complex_basetype(operand.dtype)``; ``u`` and ``vt`` keep the
      operand's dtype.

    Raises:
      NotImplementedError: if ``operand`` is not a ``ShapedArray``.
      ValueError: if ``operand`` has fewer than two dimensions.
    """
    if not isinstance(operand, ShapedArray):
        raise NotImplementedError
    if operand.ndim < 2:
        raise ValueError("Argument to singular value decomposition must have ndims >= 2")

    leading = operand.shape[:-2]
    m, n = operand.shape[-2], operand.shape[-1]
    s = ShapedArray(leading + (min(m, n),),
                    lax.lax._complex_basetype(operand.dtype))
    u_cols = m if full_matrices else min(m, n)
    u = ShapedArray(leading + (m, u_cols), operand.dtype)
    vt_rows = n if full_matrices else min(m, n)
    vt = ShapedArray(leading + (vt_rows, n), operand.dtype)
    return s, u, vt
def _lu_abstract_eval(operand):
    """Abstract evaluation rule for the LU decomposition.

    Args:
      operand: abstract value of the input; a ``ShapedArray`` operand must have
        shape ``[..., m, n]``.

    Returns:
      A triple ``(operand, pivot, perm)``. For a ``ShapedArray`` operand,
      ``pivot`` is an ``int32`` array of shape ``[..., min(m, n)]`` and
      ``perm`` is an ``int32`` array of shape ``[..., m]``. For any other
      operand, all three entries are ``operand``.

    Raises:
      ValueError: if a ``ShapedArray`` operand has fewer than two dimensions.
    """
    if not isinstance(operand, ShapedArray):
        return operand, operand, operand

    if operand.ndim < 2:
        raise ValueError("Argument to LU decomposition must have ndims >= 2")

    leading = operand.shape[:-2]
    rows, cols = operand.shape[-2], operand.shape[-1]
    pivot = ShapedArray(leading + (min(rows, cols),), jnp.int32)
    perm = ShapedArray(leading + (rows,), jnp.int32)
    return operand, pivot, perm
def qr_abstract_eval(operand, full_matrices):
    """Abstract evaluation rule for the QR decomposition.

    Args:
      operand: abstract value of the input; a ``ShapedArray`` operand must have
        shape ``[..., m, n]``.
      full_matrices: when truthy the inner dimension ``k`` is ``m``, otherwise
        it is ``min(m, n)``.

    Returns:
      A pair ``(q, r)``. For a ``ShapedArray`` operand, ``q`` has shape
      ``[..., m, k]`` and ``r`` has shape ``[..., k, n]``, both with the
      operand's dtype. For any other operand, both entries are ``operand``.

    Raises:
      ValueError: if a ``ShapedArray`` operand has fewer than two dimensions.
    """
    if not isinstance(operand, ShapedArray):
        return operand, operand

    if operand.ndim < 2:
        raise ValueError("Argument to QR decomposition must have ndims >= 2")

    leading = operand.shape[:-2]
    rows, cols = operand.shape[-2], operand.shape[-1]
    if full_matrices:
        inner = rows
    else:
        inner = min(rows, cols)
    q = ShapedArray(leading + (rows, inner), operand.dtype)
    r = ShapedArray(leading + (inner, cols), operand.dtype)
    return q, r
def eig_abstract_eval(operand):
    """Abstract evaluation rule for the nonsymmetric eigendecomposition.

    Args:
      operand: ``ShapedArray`` of shape ``[..., n, n]``.

    Returns:
      A triple ``(w, vl, vr)`` of ``ShapedArray``. ``w`` has shape
      ``[..., n]``; ``vl`` and ``vr`` are the same ``[..., n, n]`` abstract
      array. The dtype starts as ``complex64`` when
      ``onp.finfo(operand.dtype).bits`` is 32 and ``complex128`` otherwise,
      and is then passed through ``xb.canonicalize_dtype``.

    Raises:
      NotImplementedError: if ``operand`` is not a ``ShapedArray``.
      ValueError: if ``operand`` has fewer than two dimensions or its two
        trailing dimensions differ.
    """
    if not isinstance(operand, ShapedArray):
        raise NotImplementedError

    if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
        raise ValueError("Argument to nonsymmetric eigendecomposition must have "
                         "shape [..., n, n], got shape {}".format(operand.shape))

    leading = operand.shape[:-2]
    size = operand.shape[-1]
    if onp.finfo(operand.dtype).bits == 32:
        complex_dtype = onp.complex64
    else:
        complex_dtype = onp.complex128
    complex_dtype = xb.canonicalize_dtype(complex_dtype)
    vl = vr = ShapedArray(leading + (size, size), complex_dtype)
    w = ShapedArray(leading + (size,), complex_dtype)
    return w, vl, vr
def svd_abstract_eval(operand, full_matrices, compute_uv):
    """Abstract evaluation rule for the singular value decomposition.

    Args:
      operand: abstract value of the input; a ``ShapedArray`` operand must have
        shape ``[..., m, n]``.
      full_matrices: when truthy, ``u`` is ``[..., m, m]`` and ``vt`` is
        ``[..., n, n]``; otherwise the inner dimension is ``min(m, n)``.
      compute_uv: not used when computing the output description.

    Returns:
      ``core.AbstractTuple((s, u, vt))``. For a ``ShapedArray`` operand, ``s``
      has shape ``[..., min(m, n)]`` and all three keep the operand's dtype.
      For any other operand, all three entries are ``operand``.

    Raises:
      ValueError: if a ``ShapedArray`` operand has fewer than two dimensions.
    """
    if not isinstance(operand, ShapedArray):
        return core.AbstractTuple((operand, operand, operand))

    if operand.ndim < 2:
        raise ValueError("Argument to singular value decomposition must have ndims >= 2")

    leading = operand.shape[:-2]
    m, n = operand.shape[-2], operand.shape[-1]
    s = ShapedArray(leading + (min(m, n),), operand.dtype)
    u_cols = m if full_matrices else min(m, n)
    u = ShapedArray(leading + (m, u_cols), operand.dtype)
    vt_rows = n if full_matrices else min(m, n)
    vt = ShapedArray(leading + (vt_rows, n), operand.dtype)
    return core.AbstractTuple((s, u, vt))
def _kepler_abstract_eval(M, ecc):
    """Validate the two operands and describe the two outputs.

    Args:
      M: abstract operand exposing ``dtype`` and ``shape``.
      ecc: abstract operand exposing ``dtype`` and ``shape``.

    Returns:
      A pair containing the same ``float64`` ``ShapedArray`` twice, with the
      shape of ``M``.

    Raises:
      ValueError: if either operand is not ``float64`` or the shapes differ.
    """
    for operand in (M, ecc):
        if operand.dtype != np.float64:
            raise ValueError("float64 precision is required")
    if M.shape != ecc.shape:
        raise ValueError("Dimension mismatch")

    # Both outputs share one abstract description.
    out_aval = ShapedArray(M.shape, np.float64)
    return (out_aval, out_aval)
def _psum_translation_rule(c, *args, replica_groups=None, platform=None):
    """Build the XLA ops for a sum all-reduce over several operands.

    Args:
      c: XLA builder; operand shapes are read through ``c.get_shape``.
      *args: XLA operands to reduce.
      replica_groups: replica grouping, converted with
        ``xc.make_replica_groups``.
      platform: for ``"cpu"`` and ``"tpu"`` the work is delegated to
        ``_notuple_psum_translation_rule``.

    Returns:
      On the delegated path, whatever ``_notuple_psum_translation_rule``
      returns. Otherwise an XLA tuple holding one reduced value per operand,
      in the original operand order.
    """
    if platform in ("cpu", "tpu"):
        return _notuple_psum_translation_rule(c, *args,
                                              replica_groups=replica_groups)

    # XLA's tuple all-reduce doesn't support different dtypes in the same
    # all-reduce, so operands are bucketed by dtype and one all-reduce is
    # issued per bucket. Each bucket remembers its operands' positions.
    args_by_type = collections.defaultdict(lambda: ([], []))
    for position, operand in enumerate(args):
        positions, bucket = args_by_type[c.get_shape(operand).numpy_dtype()]
        positions.append(position)
        bucket.append(operand)

    # Results are written back into their original slots.
    out = [None] * len(args)
    replica_groups_protos = xc.make_replica_groups(replica_groups)

    for dtype, (positions, bucket) in sorted(args_by_type.items()):
        is_complex = dtypes.issubdtype(dtype, onp.complexfloating)
        count = len(bucket)
        if is_complex:
            # Complex operands are reduced as real parts followed by
            # imaginary parts and recombined afterwards.
            bucket = ([xops.Real(x) for x in bucket] +
                      [xops.Imag(x) for x in bucket])
        scalar = ShapedArray((), c.get_shape(bucket[0]).numpy_dtype())
        computation = xla.primitive_subcomputation(lax.add_p, scalar, scalar)
        all_reduce = xops.AllReduce(xops.Tuple(c, bucket), computation,
                                    replica_groups_protos, None, None)
        if is_complex:
            reduced = [
                xops.Complex(xops.GetTupleElement(all_reduce, k),
                             xops.GetTupleElement(all_reduce, count + k))
                for k in range(count)
            ]
        else:
            reduced = [xops.GetTupleElement(all_reduce, k) for k in range(count)]
        for position, value in zip(positions, reduced):
            out[position] = value

    return xops.Tuple(c, out)
def _quad_solution_vector_abstract_eval(b, r):
    """Validate the two operands and describe the three outputs.

    Args:
      b: abstract operand exposing ``dtype`` and ``shape``.
      r: abstract operand exposing ``dtype`` and ``shape``.

    Returns:
      A triple containing the same ``float64`` ``ShapedArray`` three times; its
      shape is the shape of ``b`` with a trailing axis of length 3 appended.

    Raises:
      ValueError: if either operand is not ``float64`` or the shapes differ.
    """
    for operand in (b, r):
        if operand.dtype != np.float64:
            raise ValueError("float64 precision is required")
    if b.shape != r.shape:
        raise ValueError("Dimension mismatch")

    extended_shape = tuple(b.shape) + (3,)
    out_aval = ShapedArray(extended_shape, np.float64)
    return (out_aval, out_aval, out_aval)
def _allreduce_translation_rule(prim, c, val, replica_groups, backend=None):
    """Build an XLA all-reduce of ``val`` that combines elements with ``prim``.

    Args:
      prim: binary primitive used to build the scalar reduction computation.
      c: XLA builder.
      val: XLA operand to reduce.
      replica_groups: forwarded to ``c.AllReduce``.
      backend: forwarded to ``xla.primitive_subcomputation``.

    Returns:
      The result of ``c.AllReduce``.
    """
    # The reduction is a scalar computation with the operand's element type.
    scalar = ShapedArray((), c.GetShape(val).numpy_dtype())
    reducer = xla.primitive_subcomputation(prim, scalar, scalar, backend=backend)
    return c.AllReduce(val, reducer, replica_groups=replica_groups)
def _allreduce_translation_rule(prim, c, val, *, axis_name, axis_index_groups,
                                axis_env, platform):
    """Build an XLA all-reduce of ``val`` that combines elements with ``prim``.

    Args:
      prim: binary primitive used to build the scalar reduction computation.
      c: XLA builder; the operand shape is read through ``c.get_shape``.
      val: XLA operand to reduce.
      axis_name: forwarded to ``_replica_groups``.
      axis_index_groups: forwarded to ``_replica_groups``.
      axis_env: forwarded to ``_replica_groups``.
      platform: accepted but not used here.

    Returns:
      The result of ``xops.AllReduce``.
    """
    groups = _replica_groups(axis_env, axis_name, axis_index_groups)
    # The reduction is a scalar computation with the operand's element type.
    scalar = ShapedArray((), c.get_shape(val).numpy_dtype())
    reducer = xla.primitive_subcomputation(prim, scalar, scalar)
    return xops.AllReduce(val, reducer, xc.make_replica_groups(groups),
                          None, None)
def fft_abstract_eval(x, fft_type, fft_lengths):
    """Abstract evaluation rule for FFT: output matches the input description.

    Args:
      x: abstract operand; must have dtype ``complex64``.
      fft_type: not used when computing the output description.
      fft_lengths: not used when computing the output description.

    Returns:
      A ``ShapedArray`` with the shape and dtype of ``x``.

    Raises:
      TypeError: if ``x`` does not have a complex floating dtype.
      NotImplementedError: if ``x`` is complex but not ``complex64``.
    """
    in_dtype = x.dtype
    if not dtypes.issubdtype(in_dtype, onp.complexfloating):
        raise TypeError("FFT requires complex inputs, got {}.".format(
            in_dtype.name))
    if in_dtype != onp.complex64:
        template = "FFT is only implemented for complex64 types, got {}."
        raise NotImplementedError(template.format(in_dtype.name))
    return ShapedArray(x.shape, in_dtype)
def layer_abstract_eval(*avals):
    """Abstractly evaluate initialization followed by application.

    ``init_fun`` and ``apply_fun`` are taken from the surrounding scope. The
    key argument is represented by an abstract ``uint32`` array of shape
    ``(2,)``.

    Args:
      *avals: abstract values for the layer inputs.

    Returns:
      The result of ``pe.abstract_eval_fun`` for the composed function.
    """
    def init_and_apply(key, *inputs):
        # Parameters are created from the key and immediately applied.
        return apply_fun(init_fun(key, *inputs), *inputs)

    key_aval = ShapedArray((2, ), 'uint32')
    return pe.abstract_eval_fun(init_and_apply, key_aval, *avals)
def _grid_trace_shape(fn, *args, **kwargs):
    """Abstractly evaluate ``fn`` and return the shape of its output.

    Args:
      fn: function to evaluate.
      *args: positional arguments; every ``np.ndarray`` is replaced by a
        ``ShapedArray`` with the same shape and dtype, everything else is
        passed through unchanged.
      **kwargs: forwarded to ``pe.abstract_eval_fun`` unchanged.

    Returns:
      The ``shape`` attribute of the abstract evaluation result.
    """
    shaped_args = [
        ShapedArray(tuple(arg.shape), arg.dtype)
        if isinstance(arg, np.ndarray) else arg
        for arg in args
    ]
    result = pe.abstract_eval_fun(fn, *shaped_args, **kwargs)
    return result.shape
def while_loop(cond_fun, body_fun, init_val):
    """Apply ``body_fun`` to a carried value for as long as ``cond_fun`` holds.

    In brief, the type signature is

    .. code-block:: haskell

      while_loop :: (a -> Bool) -> (a -> a) -> a -> a

    and the semantics are those of this Python reference implementation::

      def while_loop(cond_fun, body_fun, init_val):
        val = init_val
        while cond_fun(val):
          val = body_fun(val)
        return val

    In contrast to the reference implementation, ``while_loop`` is a JAX
    primitive that is lowered to a single XLA While HLO. This helps keep
    compilation times down for jit-compiled functions, because native Python
    loops inside an ``@jit`` function are unrolled and produce large XLA
    computations.

    A further difference from Python-native loops is that ``while_loop`` is
    not reverse-mode differentiable, since XLA computations need static bounds
    on their memory requirements.

    Args:
      cond_fun: function of type ``a -> Bool``.
      body_fun: function of type ``a -> a``.
      init_val: value of type ``a``; may be a scalar, an array, or any pytree
        (nested Python tuple/list/dict) of those. It is the initial loop carry.

    Returns:
      The output of the final iteration of ``body_fun``, of type ``a``.

    Raises:
      TypeError: if ``cond_fun`` returns a pytree or anything other than a
        boolean scalar, or if the pytree structure returned by ``body_fun``
        does not match that of ``init_val``.
    """
    # Wrap the carry in a 1-tuple so that it is flattened as the single
    # positional argument of both functions.
    init_vals, in_tree = tree_flatten((init_val, ))
    init_avals = tuple(_map(_abstractify, init_vals))

    cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(
        cond_fun, in_tree, init_avals)
    body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
        body_fun, in_tree, init_avals)

    if not treedef_is_leaf(cond_tree):
        msg = "cond_fun must return a boolean scalar, but got pytree {}."
        raise TypeError(msg.format(cond_tree))

    expected_cond_avals = [ShapedArray((), onp.bool_)]
    if cond_jaxpr.out_avals != expected_cond_avals:
        msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
        raise TypeError(msg.format(cond_jaxpr.out_avals))

    carry_trees = treedef_children(in_tree)
    if not carry_trees == [body_tree]:
        msg = "body_fun output pytree structure must match init_val, got {} and {}."
        raise TypeError(msg.format(body_tree, carry_trees[0]))

    operands = itertools.chain(cond_consts, body_consts, init_vals)
    outs = while_p.bind(*operands,
                        cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr,
                        body_nconsts=len(body_consts), body_jaxpr=body_jaxpr)
    return tree_unflatten(body_tree, outs)
def fft_abstract_eval(x, fft_type, fft_lengths):
    """Abstract evaluation rule for FFT, covering the real-input variants.

    Args:
      x: abstract operand exposing ``shape`` and ``dtype``.
      fft_type: an ``xla_client.FftType`` value.
      fft_lengths: tuple of transform lengths; it replaces the same number of
        trailing dimensions of ``x`` for ``RFFT`` and ``IRFFT``.

    Returns:
      A ``ShapedArray``. For ``RFFT`` the trailing dimensions become
      ``fft_lengths`` with the last one replaced by
      ``fft_lengths[-1] // 2 + 1`` and the dtype is ``_complex_dtype(x.dtype)``.
      For ``IRFFT`` the trailing dimensions become ``fft_lengths`` and the
      dtype is ``_real_dtype(x.dtype)``. For every other type, the shape and
      dtype of ``x`` are kept.
    """
    if fft_type == xla_client.FftType.RFFT:
        leading = x.shape[:-len(fft_lengths)]
        half_spectrum = (fft_lengths[-1] // 2 + 1, )
        out_shape = leading + fft_lengths[:-1] + half_spectrum
        return ShapedArray(out_shape, _complex_dtype(x.dtype))

    if fft_type == xla_client.FftType.IRFFT:
        out_shape = x.shape[:-len(fft_lengths)] + fft_lengths
        return ShapedArray(out_shape, _real_dtype(x.dtype))

    return ShapedArray(x.shape, x.dtype)
def _axis_index_bind(*, axis_name):
    """Create an unknown ``int32`` scalar tracer for ``axis_index_p``.

    The frame for ``axis_name`` is looked up in the thread-local dynamic axis
    environment and its ``pmap_trace`` is used to build the tracer. The tracer
    is given an equation recipe with no inputs that binds ``axis_index_p``.

    Args:
      axis_name: key into the dynamic axis environment.

    Returns:
      The newly created ``pe.JaxprTracer``.
    """
    frame = pxla._thread_local_state.dynamic_axis_env[axis_name]
    scalar_int = ShapedArray((), np.int32)
    tracer = pe.JaxprTracer(frame.pmap_trace,
                            pe.PartialVal.unknown(scalar_int), None)
    # The recipe lists the tracer as its own output, so it can only be
    # attached after the tracer exists.
    tracer.recipe = pe.new_eqn_recipe([], [tracer], axis_index_p,
                                      {"axis_name": axis_name},
                                      source_info_util.current())
    return tracer
def eig_abstract_eval(operand, *, compute_left_eigenvectors,
                      compute_right_eigenvectors):
    """Abstract evaluation rule for the nonsymmetric eigendecomposition.

    Args:
      operand: ``ShapedArray`` of shape ``[..., n, n]``.
      compute_left_eigenvectors: when truthy, ``vl`` is included in the output.
      compute_right_eigenvectors: when truthy, ``vr`` is included in the output.

    Returns:
      A tuple starting with ``w`` (shape ``[..., n]``), followed by ``vl``
      and/or ``vr`` (shape ``[..., n, n]``) when requested. The dtype starts as
      ``complex64`` when ``dtypes.finfo(operand.dtype).bits`` is 32 and
      ``complex128`` otherwise, and is then passed through
      ``dtypes.canonicalize_dtype``.

    Raises:
      NotImplementedError: if ``operand`` is not a ``ShapedArray``.
      ValueError: if ``operand`` has fewer than two dimensions or its two
        trailing dimensions differ.
    """
    if not isinstance(operand, ShapedArray):
        raise NotImplementedError

    if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
        raise ValueError("Argument to nonsymmetric eigendecomposition must have "
                         "shape [..., n, n], got shape {}".format(operand.shape))

    leading = operand.shape[:-2]
    size = operand.shape[-1]
    if dtypes.finfo(operand.dtype).bits == 32:
        complex_dtype = np.complex64
    else:
        complex_dtype = np.complex128
    complex_dtype = dtypes.canonicalize_dtype(complex_dtype)

    vl = vr = ShapedArray(leading + (size, size), complex_dtype)
    w = ShapedArray(leading + (size,), complex_dtype)

    optional = ((compute_left_eigenvectors, vl),
                (compute_right_eigenvectors, vr))
    return (w,) + tuple(aval for wanted, aval in optional if wanted)
def _abstract_eval(spec, *args):
    """Check the operands against ``spec`` and describe the outputs.

    Args:
      spec: mapping with the keys ``"get_dims"`` (callable applied to the
        operand shapes), ``"inputs"``, ``"outputs"`` and ``"extra_outputs"``.
        Input entries provide ``"name"`` and ``"shape"``; every entry may
        provide ``"dtype"``, which defaults to ``np.float64``. Each ``"shape"``
        is an expression string evaluated with the names returned by
        ``"get_dims"``.
      *args: abstract operands exposing ``dtype`` and ``shape``, paired with
        ``spec["inputs"]`` in order.

    Returns:
      A tuple with one ``ShapedArray`` per entry of ``spec["outputs"]``
      followed by ``spec["extra_outputs"]``.

    Raises:
      ValueError: if an operand's dtype or shape differs from its spec entry.
    """
    vals = spec["get_dims"](*(a.shape for a in args))

    # NOTE(review): shape expressions go through eval(); this is only safe if
    # the spec is trusted -- confirm specs never come from external input.
    for entry, arg in zip(spec["inputs"], args):
        expected_dtype = entry.get("dtype", np.float64)
        if arg.dtype != expected_dtype:
            raise ValueError(
                f"Invalid dtype for '{entry['name']}'; "
                f"expected {expected_dtype}, got {arg.dtype}")
        expected_shape = eval(entry["shape"], dict(vals))
        if arg.shape != expected_shape:
            raise ValueError(f"Invalid shape for '{entry['name']}'; "
                             f"expected {expected_shape}, got {arg.shape}")

    out_avals = []
    for entry in spec["outputs"] + spec["extra_outputs"]:
        out_shape = eval(entry["shape"], dict(vals))
        out_avals.append(ShapedArray(out_shape, entry.get("dtype", np.float64)))
    return tuple(out_avals)
def abstract_call(*inputs):
    """Trace ``self._init_and_apply`` abstractly and return its flat outputs.

    ``self`` is taken from the surrounding scope. An abstract ``uint32`` array
    of shape ``(2,)`` is prepended to ``inputs`` as the first argument. As a
    side effect, the output tree produced by flattening the function is stored
    in ``self._cached_out_tree``.

    Args:
      *inputs: abstract values for the inputs.

    Returns:
      The first component of ``unzip2`` applied to the traced partial outputs.
    """
    with_key = (ShapedArray((2, ), 'uint32'), ) + inputs
    flat_args, args_tree = jax.tree_flatten(with_key)

    flattened = jax.flatten_fun_nokwargs(self._init_and_apply, args_tree)
    flat_fun, self._cached_out_tree = flattened

    # Each flat argument becomes a partial value of the form (aval, unit).
    partial_args = [PartialVal((aval, jc.unit)) for aval in flat_args]
    _, partial_outs, _ = trace_to_jaxpr(flat_fun, partial_args,
                                        instantiate=True)
    out_avals, _ = unzip2(partial_outs)
    return out_avals
def _canonicalize_displacement_or_metric(displacement_or_metric):
    """Classify the function by the rank of its abstractly evaluated output.

    The function is abstractly evaluated on two ``(1, dim)`` ``f32`` operands
    with ``t=0`` for ``dim`` in ``range(4)``; the first ``dim`` that does not
    raise ``ValueError`` decides the result.

    Args:
      displacement_or_metric: callable taking two position arrays and a ``t``
        keyword argument.

    Returns:
      ``displacement_or_metric`` itself when its abstract output has two
      dimensions, otherwise ``space.metric(displacement_or_metric)``.

    Raises:
      ValueError: if every tried dimension raised ``ValueError``.
    """
    # NOTE(review): range(4) also tries dim == 0; confirm that is intended.
    for dim in range(4):
        try:
            R = ShapedArray((1, dim), f32)
            dR_or_dr = pe.abstract_eval_fun(displacement_or_metric, R, R, t=0)
            if len(dR_or_dr.shape) == 2:
                return displacement_or_metric
            return space.metric(displacement_or_metric)
        except ValueError:
            # This dimension is not supported by the function; try the next.
            continue
    # The literals are concatenated by the parser, so the separating space
    # has to be part of one of them.
    raise ValueError(
        'Canonicalize displacement not implemented for spatial dimension larger '
        'than 4.')
def _masked_scan_jaxpr(jaxpr, num_consts, num_carry):
    """Wrap a scan body so that carry updates stop after a dynamic length.

    The wrapped function takes ``dynamic_length``, the constants, a counter
    ``i``, the carry and the per-step inputs. It runs the original body, keeps
    each new carry value only where ``i < dynamic_length`` (otherwise the old
    carry value is kept), and returns ``i + 1`` followed by the selected carry
    and the per-step outputs.

    Args:
      jaxpr: typed jaxpr of the original body; its ``in_avals`` are laid out as
        constants, carry, per-step inputs.
      num_consts: number of leading constant inputs.
      num_carry: number of carry values.

    Returns:
      The result of ``_make_typed_jaxpr`` for the wrapped function, with two
      extra ``int64`` scalar inputs: one in front of the constants and one in
      front of the carry.
    """
    body = core.jaxpr_as_fun(jaxpr)

    @lu.wrap_init
    def masked(*args):
        [dynamic_length], consts, [i], carry, xs = split_list(
            args, [1, num_consts, 1, num_carry])
        results = body(*(consts + carry + xs))
        updated, ys = split_list(results, [num_carry])
        # The predicate is rebuilt for every carry element, as in the
        # original, so the traced operations stay the same.
        selected = [lax.select(i < dynamic_length, new_value, old_value)
                    for new_value, old_value in zip(updated, carry)]
        return [i + 1] + selected + ys

    const_avals, carry_avals, x_avals = split_list(
        jaxpr.in_avals, [num_consts, num_carry])
    scalar_int = ShapedArray((), onp.int64)
    in_avals = [scalar_int] + const_avals + [scalar_int] + carry_avals + x_avals
    return _make_typed_jaxpr(masked, in_avals)
def canonicalize_displacement_or_metric(displacement_or_metric):
    """Classify the function by the rank of its abstractly evaluated output.

    The function is shape-evaluated on two ``(dim,)`` ``f32`` operands with
    ``t=0`` for ``dim`` in 1, 2, 3; the first ``dim`` that raises neither
    ``TypeError`` nor ``ValueError`` decides the result.

    Args:
      displacement_or_metric: callable taking two position arrays and a ``t``
        keyword argument.

    Returns:
      ``displacement_or_metric`` itself when its abstract output is a scalar,
      otherwise ``metric(displacement_or_metric)``.

    Raises:
      ValueError: if every tried dimension raised ``TypeError`` or
        ``ValueError``.
    """
    for dim in range(1, 4):
        try:
            R = ShapedArray((dim,), f32)
            dR_or_dr = eval_shape(displacement_or_metric, R, R, t=0)
            if len(dR_or_dr.shape) == 0:
                return displacement_or_metric
            return metric(displacement_or_metric)
        except (TypeError, ValueError):
            # This dimension is not supported by the function; try the next.
            continue
    # The literals are concatenated by the parser, so the separating space
    # has to be part of one of them.
    raise ValueError(
        'Canonicalize displacement not implemented for spatial dimension larger '
        'than 4.')
def testReshardInput(self):
    """pmap over a manually built ShardedDeviceArray yields six buffers.

    The input is assembled by hand from four device buffers with a sharding
    that, per the original inline comment, is wrong for the subsequent pmap.
    The test is skipped when fewer than six devices are available.
    """
    if xla_bridge.device_count() < 6:
        raise SkipTest("testReshardInput requires 6 devices")

    # Manually construct a ShardedDeviceArray with the wrong sharding for the
    # subsequent pmap.
    shard_shape = (3, 2)
    shard = np.arange(np.prod(shard_shape)).reshape(shard_shape)
    devices = xla_bridge.devices()[:4]
    bufs = [xla.device_put(shard, device) for device in devices]
    spec = pxla.ShardingSpec(shards_per_axis=(2, 2),
                             is_axis_materialized=(True, True),
                             replication_factor=2)
    arr = pxla.ShardedDeviceArray(ShapedArray((6, 4), shard.dtype), spec, bufs)

    add_one = pmap(lambda x: x + 1)
    result = add_one(arr)

    self.assertAllClose(result, arr + 1, check_dtypes=True)
    self.assertEqual(len(result.device_buffers), 6)
def _displacement_or_metric_to_metric_sq(displacement_or_metric):
    """Build a squared-distance function from a displacement or a metric.

    The function is shape-evaluated on two ``(dim,)`` ``f32`` operands with
    ``t=0`` for ``dim`` in 1, 2, 3; the first ``dim`` that raises neither
    ``TypeError`` nor ``ValueError`` decides the result.

    Args:
      displacement_or_metric: callable taking two position arrays plus keyword
        arguments.

    Returns:
      A callable ``(Ra, Rb, **kwargs)``. When the abstract output is a scalar,
      it squares the output of ``displacement_or_metric``; otherwise it passes
      that output to ``space.square_distance``.

    Raises:
      ValueError: if every tried dimension raised ``TypeError`` or
        ``ValueError``.
    """
    for dim in range(1, 4):
        try:
            R = ShapedArray((dim,), f32)
            dR_or_dr = eval_shape(displacement_or_metric, R, R, t=0)
            if len(dR_or_dr.shape) == 0:
                return lambda Ra, Rb, **kwargs: \
                    displacement_or_metric(Ra, Rb, **kwargs) ** 2
            return lambda Ra, Rb, **kwargs: space.square_distance(
                displacement_or_metric(Ra, Rb, **kwargs))
        except (TypeError, ValueError):
            # This dimension is not supported by the function; try the next.
            continue
    # The literals are concatenated by the parser, so the separating space
    # has to be part of one of them.
    raise ValueError(
        'Canonicalize displacement not implemented for spatial dimension larger '
        'than 4.')
def omnistaging_disabler() -> None:
    """Re-register the ``psum_p`` and ``axis_index_p`` rules in place.

    Everything here is a side effect on module-level registries: the bind,
    impl and parallel pure rule of ``psum_p``, and the custom bind, abstract
    eval and XLA translation rule of ``axis_index_p``. Presumably invoked when
    omnistaging is switched off (inferred from the name; confirm at the call
    site).
    """
    # NOTE(review): `axis_index` is declared global but not assigned anywhere
    # in the visible body.
    global axis_index

    # Bind psum_p through the base Primitive.bind and evaluate it with the
    # generic parallel-primitive implementation.
    psum_p.bind = partial(core.Primitive.bind, psum_p)
    psum_p.def_impl(partial(pxla.apply_parallel_primitive, psum_p))  # type: ignore
    # Pure rule: every operand is scaled by the product of `shape`.
    pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (
        x * prod(shape) for x in args)  # type: ignore

    def _axis_index_bind(*, axis_name):
        # Look up the frame for `axis_name`, then record the sizes of all
        # frames up to and including it along with the total replica count.
        dynamic_axis_env = pxla._thread_local_state.dynamic_axis_env
        frame = dynamic_axis_env[axis_name]
        sizes = dynamic_axis_env.sizes[:dynamic_axis_env.index(frame) + 1]
        nreps = dynamic_axis_env.nreps
        trace = frame.pmap_trace

        # The result is an unknown int32 scalar tracer whose recipe has no
        # inputs and binds axis_index_p with the parameters gathered above.
        out_aval = ShapedArray((), np.int32)
        out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
        eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p,
                                dict(nreps=nreps, sizes=sizes, axis_name=axis_name),
                                source_info_util.current())
        out_tracer.recipe = eqn

        return out_tracer

    def _axis_index_translation_rule(c, nreps, sizes, axis_name):
        # index = (replica_id // (nreps // prod(sizes))) % sizes[-1],
        # computed in uint32 and then converted to int32.
        div = xb.constant(c, np.array(nreps // prod(sizes), dtype=np.uint32))
        mod = xb.constant(c, np.array(sizes[-1], dtype=np.uint32))
        unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
        return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))

    axis_index_p.def_custom_bind(_axis_index_bind)
    # The abstract result is always an int32 scalar, whatever the arguments.
    axis_index_p.def_abstract_eval(lambda *args, **params: ShapedArray(
        (), np.int32))
    xla.translations[axis_index_p] = _axis_index_translation_rule