def _maybe_tracer_tuple_to_abstract_tuple(tup):
  """Recursively turn a ``pe.JaxprTracerTuple`` into a ``core.AbstractTuple``.

  Abstract values are returned as-is, ``None`` maps to the empty
  ``AbstractTuple``, and any other object raises ``TypeError``.
  """
  if isinstance(tup, pe.JaxprTracerTuple):
    converted = [_maybe_tracer_tuple_to_abstract_tuple(elt) for elt in tup]
    return core.AbstractTuple(converted)
  if isinstance(tup, core.AbstractValue):
    return tup
  if tup is None:
    return core.AbstractTuple(())
  raise TypeError(tup)
def _scan_transpose(ct, consts, init, xs, forward, length, jaxpr):
  """Transpose rule for ``scan_p``.

  Only one input configuration is supported: ``consts``, ``init`` and the
  ``a`` half of ``xs = (a, res)`` must be ``None``, while the residuals
  ``res`` must be present. (Presumably ``None`` marks the linear/undefined
  inputs in this version's transpose convention -- NOTE(review): confirm
  against ``ad``.) The loop body is rearranged, transposed and re-plumbed.
  The transposed loop then runs as another ``scan_p`` in the opposite
  direction (``forward=not forward``).

  Args:
    ct: cotangent of the scan output ``(c, [b])``. Symbolic zeros are
      instantiated with ``ad.instantiate_zeros_aval`` before use.
    consts: must be ``None``.
    init: must be ``None``.
    xs: a tuple ``(None, res)``.
    forward: direction of the original scan.
    length: number of scan iterations.
    jaxpr: TypedJaxpr of the body, typed ``d -> c -> (a, res) -> (c, b)``.

  Returns:
    ``(ct_consts, ct_init, (ct_as, None))``: cotangents for the consts, the
    initial carry and the ``a`` part of xs. The residual slot gets ``None``.
  """
  assert consts is None and init is None
  assert type(xs) is tuple
  a, res = xs
  assert a is None and res is not None

  # jaxpr :: d -> c -> (a, res) -> (c, b)
  # jaxpr_lifted :: res -> (d, c, a) -> (c, b)
  # jaxpr_lifted_trans :: res -> (CT c, CT b) -> (CT d, CT c, CT a)
  # jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT c, CT d), CT a)
  assert type(jaxpr.jaxpr.invars[2]) is tuple  # assume restructuring
  # Move the residuals to the front as a separate (non-linear) binder, so the
  # transposition only acts on (d, c, a).
  jaxpr_lifted = rearrange_binders(
      lambda d, c, a_res: (a_res[1], (d, c, a_res[0])), jaxpr)
  jaxpr_lifted_trans = _transpose_jaxpr(jaxpr_lifted)
  # Rewire into scan-body form. CT d becomes part of the carry so it can be
  # summed across iterations.
  jaxpr_trans = _move_stuff_and_add_add(jaxpr_lifted_trans)

  c_aval, b_aval = jaxpr.out_aval
  d_aval, c_aval2, _ = jaxpr.in_avals
  assert c_aval == c_aval2
  # The ys cotangent has an extra leading axis of size `length`.
  bs_aval = _promote_aval_rank(length, b_aval)
  # The accumulated cotangent of the consts starts at zero.
  ct_d = ad_util.zeros_like_aval(d_aval)
  ct_c, ct_bs = ad.instantiate_zeros_aval(core.AbstractTuple((c_aval, bs_aval)), ct)
  carry_ct = core.pack((ct_c, ct_d))

  # jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT c, CT d), CT a)
  # NOTE(review): this check runs unconditionally (not gated by
  # core.skip_checks as in _move_stuff_and_add_add).
  core.check_jaxpr(jaxpr_trans.jaxpr)
  unit_aval, (ct_c_aval, ct_d_aval), (ct_b_aval, _) = jaxpr_trans.in_avals
  # The concrete carry cotangents must fit the transposed body's input types.
  assert core.lattice_join(ct_c_aval, core.get_aval(ct_c)) == ct_c_aval
  assert core.lattice_join(ct_d_aval, core.get_aval(ct_d)) == ct_d_aval
  # Run the transposed body in the reverse direction. CT b and the residuals
  # are scanned over together as the xs input.
  out = scan_p.bind(
      core.unit, carry_ct, core.pack((ct_bs, res)),
      forward=not forward, length=length, jaxpr=jaxpr_trans)
  (ct_init, ct_consts), ct_as = out
  return ct_consts, ct_init, (ct_as, None)
def _move_stuff_and_add_add(typed_jaxpr):
  """Re-plumb a transposed scan body into the ``scan_p`` body signature.

  Converts a jaxpr typed ``res -> (CT c, CT b) -> (CT d, CT c, CT a)`` into
  one typed ``* -> (CT c, CT d) -> (CT b, res) -> ((CT c, CT d), CT a)``:

  * ``CT d`` is moved into the carry.
  * The body's freshly produced ``CT d`` is added (``_add_any_eqn``) to the
    incoming carried ``CT d``, so the carry accumulates it across iterations.
  * ``CT b`` and ``res`` are grouped as the scanned-over input.

  Args:
    typed_jaxpr: TypedJaxpr of the transposed body. Its invars must not
      themselves be tuples (asserted).

  Returns:
    A new ``core.TypedJaxpr`` with the rearranged signature and avals.
  """
  # jaxpr_lifted_trans :: res -> (CT c, CT b) -> (CT d, CT c, CT a)
  # jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT c, CT d), CT a)

  res_aval, (CTc_aval, CTb_aval) = typed_jaxpr.in_avals
  CTd_aval, CTc_aval2, CTa_aval = typed_jaxpr.out_aval
  assert CTc_aval == CTc_aval2
  # The first input is the empty tuple (the "*" consts slot above).
  in_avals = (core.AbstractTuple(()), core.AbstractTuple((CTc_aval, CTd_aval)),
              core.AbstractTuple((CTb_aval, res_aval)))
  out_aval = core.AbstractTuple((core.AbstractTuple((CTc_aval, CTd_aval)),
                                 CTa_aval))

  # invars/eqns/outvar are rebound below (not mutated in place). The input
  # jaxpr's lists are therefore left untouched.
  # NOTE(review): whether copy() is shallow or deep is not visible here.
  jaxpr = typed_jaxpr.jaxpr.copy()
  # assume the jaxpr isn't restructuring any inputs
  assert not any(type(invar) is tuple for invar in jaxpr.invars)

  # munge input side
  CTc_in = _scan_newvar()
  CTb_in = _scan_newvar()
  CTd_in = _scan_newvar()
  res_in, CTc_CTb_in = jaxpr.invars
  jaxpr.invars = ((), (CTc_in, CTd_in), (CTb_in, res_in))
  # Rebuild the original (CT c, CT b) input tuple from the new binders so the
  # existing equations keep working unchanged.
  jaxpr.eqns = (
      [pe._pack_eqn([CTc_in, CTb_in], CTc_CTb_in)] +
      jaxpr.eqns)

  # munge output side
  CTd_new = _scan_newvar()
  CTd_sum = _scan_newvar()
  CTc = _scan_newvar()
  CTa = _scan_newvar()
  partial_out = _scan_newvar()
  outvar = _scan_newvar()
  # Unpack (CT d, CT c, CT a), then add the new CT d to the carried CT d.
  # Finally repack as ((CT c, CT d_sum), CT a).
  jaxpr.eqns = (
      jaxpr.eqns +
      [pe._unpack_eqn(jaxpr.outvar, [CTd_new, CTc, CTa]),
       _add_any_eqn(CTd_sum, CTd_new, CTd_in),
       pe._pack_eqn([CTc, CTd_sum], partial_out),
       pe._pack_eqn([partial_out, CTa], outvar)])
  jaxpr.outvar = outvar

  # TODO(mattjj): add a check_typed_jaxpr and use it here
  # Short-circuit: check_jaxpr only runs when core.skip_checks is falsy.
  core.skip_checks or core.check_jaxpr(jaxpr)
  return core.TypedJaxpr(jaxpr, typed_jaxpr.literals, in_avals, out_aval)
def lu_abstract_eval(operand):
  """Abstract evaluation rule for the LU decomposition.

  Returns an AbstractTuple ``(lu, pivots)``:

  * ``lu`` is always the operand's own aval.
  * For a ShapedArray operand of shape ``[..., m, n]``, ``pivots`` is an
    int32 array of shape ``[..., min(m, n)]``.
  * For any other abstract value, ``pivots`` is the operand itself.

  Raises:
    ValueError: if a ShapedArray operand has fewer than two dimensions.
  """
  if not isinstance(operand, ShapedArray):
    return core.AbstractTuple((operand, operand))
  if operand.ndim < 2:
    raise ValueError("Argument to LU decomposition must have ndims >= 2")
  rows, cols = operand.shape[-2:]
  pivot_shape = operand.shape[:-2] + (min(rows, cols),)
  return core.AbstractTuple((operand, ShapedArray(pivot_shape, np.int32)))
def eigh_abstract_eval(operand, lower):
  """Abstract evaluation rule for the symmetric/Hermitian eigendecomposition.

  For a ShapedArray operand of shape ``[..., n, n]``, returns an AbstractTuple
  ``(v, w)``:

  * ``v`` holds the eigenvectors, with shape ``[..., n, n]`` and the
    operand's dtype.
  * ``w`` holds the eigenvalues, with shape ``[..., n]``. The eigenvalues of a
    Hermitian matrix are real, so ``w`` uses the operand's real base dtype
    (e.g. float32 for complex64 input) rather than a complex dtype.

  Any other abstract value is passed through unchanged for both outputs.
  ``lower`` does not affect the result avals.

  Raises:
    ValueError: if a ShapedArray operand is not a (batch of) square matrix.
  """
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
      raise ValueError(
        "Argument to symmetric eigendecomposition must have shape [..., n, n]")

    batch_dims = operand.shape[:-2]
    n = operand.shape[-1]
    v = ShapedArray(batch_dims + (n, n), operand.dtype)
    # Previously w used operand.dtype, which claimed complex eigenvalues for
    # complex input. _complex_basetype (also used by eig_abstract_eval) is
    # expected to map complex dtypes to their real base and leave real dtypes
    # unchanged. NOTE(review): confirm the eigh translation/JVP rules agree.
    w = ShapedArray(batch_dims + (n,), lax.lax._complex_basetype(operand.dtype))
  else:
    v, w = operand, operand
  return core.AbstractTuple((v, w))
def eig_abstract_eval(operand):
  """Abstract evaluation rule for the nonsymmetric eigendecomposition.

  For a ShapedArray operand of shape ``[..., n, n]``, returns an AbstractTuple
  ``(w, vl, vr)``:

  * ``w`` has shape ``[..., n]`` and dtype
    ``lax.lax._complex_basetype(operand.dtype)``.
  * ``vl`` and ``vr`` are one shared aval of shape ``[..., n, n]`` with the
    operand's dtype.

  Any other abstract value is passed through unchanged for all three outputs.

  NOTE(review): eigenvalues/eigenvectors of a general matrix are complex in
  general, yet ``w`` uses the base dtype and ``vl``/``vr`` the operand dtype.
  Confirm this matches what the eig translation rule actually produces.

  Raises:
    ValueError: if a ShapedArray operand is not a (batch of) square matrix.
  """
  if not isinstance(operand, ShapedArray):
    return core.AbstractTuple((operand, operand, operand))
  shape = operand.shape
  if operand.ndim < 2 or shape[-2] != shape[-1]:
    raise ValueError("Argument to nonsymmetric eigendecomposition must have "
                     "shape [..., n, n], got shape {}".format(shape))
  batch, size = shape[:-2], shape[-1]
  vectors = ShapedArray(batch + (size, size), operand.dtype)
  values = ShapedArray(batch + (size,), lax.lax._complex_basetype(operand.dtype))
  return core.AbstractTuple((values, vectors, vectors))
def qr_abstract_eval(operand, full_matrices):
  """Abstract evaluation rule for the QR decomposition.

  For a ShapedArray operand of shape ``[..., m, n]``, returns an AbstractTuple
  ``(q, r)``:

  * ``q`` has shape ``[..., m, k]``.
  * ``r`` has shape ``[..., k, n]``.
  * ``k`` is ``m`` when ``full_matrices`` is true, else ``min(m, n)``.

  Both outputs keep the operand dtype. Any other abstract value is passed
  through unchanged for both outputs.

  Raises:
    ValueError: if a ShapedArray operand has fewer than two dimensions.
  """
  if not isinstance(operand, ShapedArray):
    return core.AbstractTuple((operand, operand))
  if operand.ndim < 2:
    raise ValueError("Argument to QR decomposition must have ndims >= 2")
  batch = operand.shape[:-2]
  rows, cols = operand.shape[-2:]
  inner = rows if full_matrices else min(rows, cols)
  dtype = operand.dtype
  q_aval = ShapedArray(batch + (rows, inner), dtype)
  r_aval = ShapedArray(batch + (inner, cols), dtype)
  return core.AbstractTuple((q_aval, r_aval))
def svd_abstract_eval(operand, full_matrices, compute_uv):
  """Abstract evaluation rule for the singular value decomposition.

  For a ShapedArray operand of shape ``[..., m, n]``, returns an AbstractTuple
  ``(s, u, vt)``:

  * ``s`` has shape ``[..., min(m, n)]``. Singular values are always real,
    so ``s`` uses the operand's real base dtype (e.g. float32 for complex64
    input).
  * ``u`` has shape ``[..., m, m]`` when ``full_matrices`` is true, else
    ``[..., m, min(m, n)]``, and keeps the operand dtype.
  * ``vt`` has shape ``[..., n, n]`` when ``full_matrices`` is true, else
    ``[..., min(m, n), n]``, and keeps the operand dtype.

  Any other abstract value is passed through unchanged for all three outputs.
  ``compute_uv`` is not consulted here: ``u`` and ``vt`` are always
  described.

  Raises:
    ValueError: if a ShapedArray operand has fewer than two dimensions.
  """
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2:
      raise ValueError("Argument to singular value decomposition must have ndims >= 2")

    batch_dims = operand.shape[:-2]
    m = operand.shape[-2]
    n = operand.shape[-1]
    k = min(m, n)
    # Previously s used operand.dtype, which claimed complex singular values
    # for complex input. _complex_basetype (also used by eig_abstract_eval) is
    # expected to map complex dtypes to their real base and leave real dtypes
    # unchanged. NOTE(review): confirm the svd translation rule agrees.
    s = ShapedArray(batch_dims + (k,), lax.lax._complex_basetype(operand.dtype))
    u = ShapedArray(batch_dims + (m, m if full_matrices else k), operand.dtype)
    vt = ShapedArray(batch_dims + (n if full_matrices else k, n), operand.dtype)
  else:
    s = operand
    u = operand
    vt = operand
  return core.AbstractTuple((s, u, vt))
def _scan_partial_eval(trace, *tracers, **kwargs):
  """Partial-evaluation rule for ``scan_p``.

  Splits the scan body with ``pe.partial_eval_jaxpr`` into two parts:

  * ``jaxpr_1`` is run now via ``scan_p.bind`` on the known inputs. It
    produces the known carry, the known ys and per-iteration residuals.
  * ``jaxpr_2`` is recorded as a ``scan_p`` equation whose scanned-over input
    is ``(xs, residuals)``.

  Which carry components are unknown is found by fixed-point iteration. It
  starts from the unknown-ness of ``init`` and repeatedly joins in the
  unknown-ness of the body's carry output until the two agree. After 1000
  rounds without convergence, ``FixedPointError`` is raised.

  Args:
    trace: the partial-evaluation trace.
    *tracers: exactly three tracers, ``(consts, init, xs)``.
    **kwargs: exactly ``jaxpr``, ``length`` and ``forward`` (asserted).

  Returns:
    A ``pe.JaxprTracer`` whose partial value is ``(out_pv, out_const)``,
    where ``out_const`` packs the known ``(carry, ys)`` computed here.
  """
  jaxpr = kwargs.pop('jaxpr')
  length = kwargs.pop('length')
  forward = kwargs.pop('forward')
  assert not kwargs
  in_pvs, _ = unzip2([t.pval for t in tracers])
  # Per-input unknown-ness markers as computed by pe.unknown.
  sc_consts, sc_init, sc_xs = map(pe.unknown, in_pvs)

  sc_carry = sc_init
  # NOTE(review): loop index `i` and `sc_ys` are unused; the loop only bounds
  # the number of fixed-point rounds.
  for i in range(1000):
    second_components = (sc_consts, sc_carry, sc_xs)
    jaxpr_1, jaxpr_2, sc_out = pe.partial_eval_jaxpr(jaxpr, second_components,
                                                     instantiate=(sc_carry, False))
    sc_carry_out, sc_ys = sc_out
    if sc_carry_out == sc_carry:
      break
    else:
      sc_carry = _binary_lattice_join(sc_carry, sc_carry_out)
  else:
    # for/else: reached only when the loop never hit `break`.
    raise FixedPointError

  consts_tracer, init_tracer, xs_tracer = tracers
  # Make init's unknown-ness match the fixed-point carry before splitting.
  lifted_init_tracer = _lift_tracer(trace, init_tracer, sc_carry)
  lifted_tracers = consts_tracer, lifted_init_tracer, xs_tracer
  # NOTE(review): this rebinding of in_pvs is not read afterwards; only
  # in_consts is used.
  in_pvs, in_consts = unzip2([t.pval for t in lifted_tracers])

  carry_aval, y_aval = jaxpr.out_aval
  ys_aval = _promote_aval_rank(length, y_aval)
  out_aval = core.AbstractTuple((carry_aval, ys_aval))
  out_pv = _put_known_pvs(sc_out, out_aval)

  # Evaluate the known part now. Residuals come out of jaxpr_1's ys slot,
  # one slice per iteration.
  out_carry, (ys, residuals) = scan_p.bind(*in_consts, forward=forward,
                                           length=length, jaxpr=jaxpr_1)
  out_const = core.pack((out_carry, ys))
  residuals_tracer = trace.new_instantiated_const(core.pack(residuals))
  d, c, a = lifted_tracers
  # The unknown scan consumes the residuals alongside xs.
  new_tracers = (d, c, (a, residuals_tracer))
  # Positional fields follow core.JaxprEqn's definition (not visible here).
  eqn = core.JaxprEqn(new_tracers, None, scan_p, (), True, False,
                      dict(forward=forward, length=length, jaxpr=jaxpr_2))
  return pe.JaxprTracer(trace, pe.PartialVal((out_pv, out_const)), eqn)
def _tscan(f, a, bs, fields=(0, )):
  """Scan ``f`` over the leading axis of ``bs`` while carrying ``a``.

  Behaves like ``jax.lax.scan``, plus an extra ``fields`` argument. It selects
  which fields of ``a``'s structure are kept, by default only the first one.
  The remaining fields are filled with None.
  """
  # Note: code is copied and modified from lax.scan implementation in
  # [JAX](https://github.com/google/jax) to support the additional `fields`
  # arg. Original code has the following copyright:
  #
  # Copyright 2018 Google LLC
  #
  # Licensed under the Apache License, Version 2.0 (the "License")

  # Flatten the carry, the scanned inputs and the field selector into flat
  # jaxtuples, keeping the tree structures needed to rebuild the output.
  flat_a, a_tree = pytree_to_flatjaxtuple(a)
  flat_bs, b_tree = pytree_to_flatjaxtuple(bs)
  flat_fields, _ = pytree_to_flatjaxtuple(fields)
  flat_f, out_tree = pytree_fun_to_flatjaxtuple_fun(wrap_init(f),
                                                    (a_tree, b_tree))

  carry_aval, _ = lax._abstractify(flat_a)
  stacked_aval, _ = lax._abstractify(flat_bs)
  # One scan step sees each array of `bs` without its leading axis.
  slice_aval = core.AbstractTuple(
      [ShapedArray(arr_aval.shape[1:], arr_aval.dtype)
       for arr_aval in stacked_aval])

  # Trace the flattened function with fully unknown inputs to obtain its jaxpr.
  in_pvals = (partial_eval.PartialVal((carry_aval, core.unit)),
              partial_eval.PartialVal((slice_aval, core.unit)))
  jaxpr, pval_out, consts = partial_eval.trace_to_jaxpr(flat_f, in_pvals)
  aval_out, _ = pval_out

  result = tscan_p.bind(flat_a, flat_bs, flat_fields, core.pack(consts),
                        aval_out=aval_out, jaxpr=jaxpr)
  # out_tree() is only read after tracing, presumably once it has been filled in.
  return tree_unflatten(out_tree(), result)
def scan(f, init, xs):
  """Scan a function over leading array axes while carrying along state.

  The type signature in brief is

  .. code-block:: haskell

    scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

  where we use [t] here to denote the type t with an additional leading axis.
  That is, if t is an array type then [t] represents the type with an additional
  leading axis, and if t is a pytree (container) type with array leaves then [t]
  represents the type with the same pytree structure and corresponding leaves
  each with an additional leading axis.

  When both ``a`` and ``b`` are array types, the semantics of ``scan`` are given
  by this Python implementation::

    def scan(f, init, xs):
      carry = init
      ys = []
      for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
      return carry, np.stack(ys)

  Unlike that Python version, both ``a`` and ``b`` may be arbitrary pytree
  types, and so multiple arrays can be scanned over at once and produce multiple
  output arrays.

  Also unlike that Python version, ``scan`` is a JAX primitive and is lowered to
  a single XLA While HLO. That makes it useful for reducing compilation times
  for jit-compiled functions, since native Python loop constructs in an ``@jit``
  function are unrolled, leading to large XLA computations.

  Args:
    f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
      that ``f`` accepts two arguments where the first is a value of the loop
      carry and the second is a slice of ``xs`` along its leading axis, and that
      ``f`` returns a pair where the first element represents a new value for
      the loop carry and the second represents a slice of the output.
    init: an initial loop carry value of type ``c``, which can be a scalar,
      array, or any pytree (nested Python tuple/list/dict) thereof, representing
      the initial loop carry value.
    xs: the value of type ``[a]`` over which to scan along the leading axis,
      where ``[a]`` can be an array or any pytree (nested Python
      tuple/list/dict) thereof with consistent leading axis sizes.

  Returns:
    A pair of type ``(c, [b])`` where the first element represents the final
    loop carry value and the second element represents the stacked outputs of
    the second output of ``f`` when scanned over the leading axis of the inputs.

  Raises:
    TypeError: if ``f`` does not return a pair, or if the carry it returns
      does not have the same abstract type as the carry it was given.
  """
  # Flatten init/xs into jaxtuple trees and wrap f to work on them. out_tree
  # is only read after tracing, when the output structure is known.
  (init, xs), in_trees = unzip2(map(pytree_to_jaxtupletree, (init, xs)))
  f, out_tree = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(f), in_trees)

  # Chained assignment: carry_pval is the whole pair returned by _abstractify,
  # and carry_aval is its first component.
  carry_pval = carry_aval, _ = _abstractify(init)
  xs_aval, _ = _abstractify(xs)
  # f is traced on a single slice of xs, i.e. without the leading axis.
  x_aval = _demote_aval_rank(xs_aval)
  x_pval = pe.PartialVal((x_aval, core.unit))
  jaxpr, pval_out, consts = pe.trace_to_jaxpr(
      f, (carry_pval, x_pval), instantiate=True)
  pv_out, const_out = pval_out
  # instantiate=True means the whole output is abstract, with no known part.
  assert isinstance(pv_out, core.AbstractValue) and const_out == core.unit
  if not isinstance(pv_out, core.AbstractTuple) or len(pv_out) != 2:
    msg = ("scanned function must have signature `c -> a -> (c, b)`, but the "
           "output was not a pair: got type {}.")
    raise TypeError(msg.format(pv_out))
  carry_aval_out, y_aval = pv_out
  # The loop carry must have a fixed type across iterations.
  if carry_aval != carry_aval_out:
    msg = ("scanned function carry output does not match carry input: "
           "input carry is {} and output carry is {}.")
    raise TypeError(msg.format(carry_aval, carry_aval_out))

  # Turn the closed-over constants into an explicit first input of the body.
  lifted_jaxpr = pe._closure_convert_jaxpr(jaxpr)
  consts_aval, _ = _abstractify(core.pack(consts))
  in_avals = (consts_aval, carry_aval, x_aval)
  out_aval = core.AbstractTuple((carry_aval, y_aval))
  jaxpr = core.TypedJaxpr(lifted_jaxpr, (), in_avals, out_aval)
  length = _leading_dim_size(xs)
  out = scan_p.bind(core.pack(consts), init, xs,
                    forward=True, length=length, jaxpr=jaxpr)
  return build_tree(out_tree(), out)
def _promote_aval_rank(n, xs):
  """Prepend a leading axis of size ``n`` to every array leaf of ``xs``.

  ``xs`` must be an abstract value. AbstractTuples are handled recursively;
  any other leaf is rebuilt as a ShapedArray of shape ``(n,) + shape``.
  """
  assert isinstance(xs, core.AbstractValue)
  if not isinstance(xs, core.AbstractTuple):
    return ShapedArray((n,) + xs.shape, xs.dtype)
  return core.AbstractTuple([_promote_aval_rank(n, elt) for elt in xs])
def _demote_aval_rank(xs):
  """Drop the leading axis from every array leaf of ``xs``.

  ``xs`` must be an abstract value. AbstractTuples are handled recursively;
  any other leaf is rebuilt as a ShapedArray of shape ``shape[1:]``.
  """
  assert isinstance(xs, core.AbstractValue)
  if not isinstance(xs, core.AbstractTuple):
    return ShapedArray(xs.shape[1:], xs.dtype)
  return core.AbstractTuple([_demote_aval_rank(elt) for elt in xs])