def qr_jvp_rule(primals, tangents, full_matrices):
  """JVP rule for the reduced QR decomposition of tall/square matrices.

  See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.
  """
  x, = primals
  if full_matrices or np.shape(x)[-2] < np.shape(x)[-1]:
    # Only the reduced factorization of tall/square inputs is supported.
    raise NotImplementedError
  dx, = tangents
  q, r = qr_p.bind(x, full_matrices=False)
  # dX . R^{-1}: right-side triangular solve is the default.
  dx_rinv = triangular_solve(r, dx)
  qt_dx_rinv = np.matmul(_T(q), dx_rinv)
  strict_lower = np.tril(qt_dx_rinv, -1)
  # Skew-symmetric by construction: L - L^T.
  domega = strict_lower - _T(strict_lower)
  dq = np.matmul(q, domega - qt_dx_rinv) + dx_rinv
  dr = np.matmul(qt_dx_rinv - domega, r)
  return core.pack((q, r)), core.pack((dq, dr))
def _scan_jvp(primals, tangents, forward, length, jaxpr):
  """JVP rule for `scan`.

  The set of carry components with nonzero tangents can grow from one loop
  iteration to the next, so we first find a fixed point of the carry's
  nonzero pattern, then run a single scan over (primal, tangent) pairs with
  the jvp'd jaxpr.
  """
  consts, init, xs = primals
  consts_dot, init_dot, xs_dot = tangents
  consts_aval, carry_aval, x_aval = jaxpr.in_avals
  _, y_aval = jaxpr.out_aval

  consts_nonzeros = ad.get_nonzeros(consts_dot)
  init_nonzeros = ad.get_nonzeros(init_dot)
  xs_nonzeros = ad.get_nonzeros(xs_dot)  # same as x_nonzeros b/c arrays

  # Fixed-point iteration on the carry's tangent-nonzero pattern. The bound
  # of 1000 is an arbitrary safety cap; the for/else raises if the lattice
  # climb somehow fails to converge.
  carry_nonzeros = init_nonzeros
  for _ in range(1000):
    nonzeros = (consts_nonzeros, carry_nonzeros, xs_nonzeros)
    jaxpr_jvp, nonzeros_out = ad.jvp_jaxpr(jaxpr, nonzeros,
                                           instantiate=(carry_nonzeros, False))
    carry_nonzeros_out, ys_nonzeros = nonzeros_out
    if _binary_lattice_eq(carry_nonzeros_out, carry_nonzeros):
      break
    else:
      carry_nonzeros = _binary_lattice_join(carry_nonzeros_out, carry_nonzeros)
  else:
    raise FixedPointError

  # convert_zeros is like strip_zeros but uses explicit lattice information to
  # instantiate zeros in some cases, namely in init_dot based on the fixed point
  nonzero_init_dot = _convert_zeros(carry_nonzeros, init, init_dot)
  nonzero_consts_dot = _convert_zeros(consts_nonzeros, consts, consts_dot)
  nonzero_xs_dot = _convert_zeros(xs_nonzeros, xs, xs_dot)

  # Pair each primal with its (instantiated) tangent and run one scan over
  # the dual values using the jvp'd jaxpr.
  consts_dual = core.pack((consts, nonzero_consts_dot))
  init_dual = core.pack((init, nonzero_init_dot))
  xs_dual = core.pack((xs, nonzero_xs_dot))

  carry_out_dual, ys_dual = scan_p.bind(consts_dual, init_dual, xs_dual,
                                        forward=forward, length=length,
                                        jaxpr=jaxpr_jvp)

  # Re-insert symbolic zeros into the outputs where the fixed point says the
  # tangents are structurally zero.
  ys, ys_dot = ys_dual
  ys_dot = ad.put_zeros(ad.TangentTuple, ys_nonzeros, ys_dot)

  carry_out, carry_out_dot = carry_out_dual
  carry_out_dot = ad.put_zeros(ad.TangentTuple, carry_nonzeros_out,
                               carry_out_dot)
  return core.pack((carry_out, ys)), ad.TangentTuple((carry_out_dot, ys_dot))
def _tscan_impl(a, bs, fields, consts, aval_out, jaxpr):
  """Impl rule for `tscan`: scan `jaxpr` over `bs`, recording selected fields.

  Args:
    a: flat carry tuple.
    bs: flat tuple of scanned inputs; each element has a leading axis whose
      size is the scan length.
    fields: indices into the carry tuple whose per-step values are recorded.
    consts: constants captured by `jaxpr`.
    aval_out: abstract value of the output (part of the primitive signature;
      unused here).
    jaxpr: the jaxpr evaluated at each step.

  Returns:
    A packed tuple with stacked per-step histories at the positions named by
    `fields` and None everywhere else.
  """
  # All scanned inputs share the same leading axis; take the length from the
  # first one.
  length = tuple(bs)[0].shape[0]
  # Preallocate a zero-filled history buffer for each selected carry field.
  state = [
      lax.full((length, ) + a[i].shape, 0, lax._dtype(a[i])) for i in fields
  ]

  def body_fun(i, vals):
    a, state = vals
    # select i-th element from each b
    b = [lax.dynamic_index_in_dim(b, i, keepdims=False) for b in bs]
    a_out = core.eval_jaxpr(jaxpr, consts, (), a, core.pack(b))
    # select fields from a_out and update state
    state_out = [
        lax.dynamic_update_index_in_dim(s, a[None, ...], i, axis=0)
        for a, s in zip([tuple(a_out)[j] for j in fields], state)
    ]
    return a_out, state_out

  _, state = lax.fori_loop(0, length, body_fun, (a, state))

  # Set None for non-selected fields. (Idiom fix: enumerate replaces the
  # original zip(fields, range(len(fields))) — same pairing, clearer.)
  out = [None] * len(a)
  for i, field in enumerate(fields):
    out[field] = state[i]
  return core.pack(out)
def _update_arrays(i, aval, xs, x):
  """Write `x` as slice `i` of the stacked value `xs`, guided by `aval`.

  Recurses through tuple abstract values; at array leaves, reshapes `x`
  with a unit leading axis and splices it into `xs` along axis 0.
  """
  assert isinstance(aval, core.AbstractValue)
  if not isinstance(aval, core.AbstractTuple):
    # Leaf case: turn x into a length-1 slab and write it at index i.
    slab = lax.reshape(x, (1,) + onp.shape(x))
    return lax.dynamic_update_index_in_dim(xs, slab, i, axis=0)
  update = partial(_update_arrays, i)
  return core.pack(map(update, aval, xs, x))
def eigh_jvp_rule(primals, tangents, lower):
  """JVP rule for `eigh`, valid when the eigenvalues are distinct."""
  # Derivative for eigh in the simplest case of distinct eigenvalues.
  # This is classic nondegenerate perturbation theory, but also see
  # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
  # The general solution treating the case of degenerate eigenvalues is
  # considerably more complicated. Ambitious readers may refer to the general
  # methods below or refer to degenerate perturbation theory in physics.
  # https://www.win.tue.nl/analysis/reports/rana06-33.pdf and
  # https://people.orie.cornell.edu/aslewis/publications/99-clarke.pdf
  a, = primals
  a_dot, = tangents
  v, w = eigh_p.bind(symmetrize(a), lower=lower)
  if a_dot is ad_util.zero:
    # NOTE(review): TangentTuple is given two positional args here, whereas
    # _lu_jvp_rule constructs it from a single tuple argument — confirm the
    # intended TangentTuple constructor signature.
    return core.pack((v, w)), ad.TangentTuple(ad_util.zero, ad_util.zero)
  # for complex numbers we need eigenvalues to be full dtype of v, a:
  w = w.astype(a.dtype)
  eye_n = np.eye(a.shape[-1], dtype=a.dtype)
  # carefully build reciprocal delta-eigenvalue matrix, avoiding NaNs.
  # The eye_n added inside the reciprocal keeps the diagonal finite; the
  # outer subtraction zeroes it back out.
  Fmat = np.reciprocal(eye_n + w - w[..., np.newaxis]) - eye_n
  # eigh impl doesn't support batch dims, but future-proof the grad.
  dot = lax.dot if a.ndim == 2 else lax.batch_matmul
  vdag_adot_v = dot(dot(_H(v), a_dot), v)
  dv = dot(v, np.multiply(Fmat, vdag_adot_v))
  dw = np.diagonal(vdag_adot_v)
  return (v, w), (dv, dw)
def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
  """JVP rule for the singular value decomposition (reduced form only)."""
  A, = primals
  dA, = tangents
  s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

  if dA is ad_util.zero:
    return (core.pack((s, U, Vt)),
            ad.TangentTuple(ad_util.zero, ad_util.zero, ad_util.zero))

  if full_matrices:
    # TODO: implement full matrices case, documented here: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
    raise NotImplementedError(
        "Singular value decomposition JVP not implemented for full matrices")

  k = s.shape[-1]
  Ut, V = np.conj(U).T, np.conj(Vt).T
  s_dim = s[..., None, :]
  # Project the perturbation into the singular bases: dS = U^H dA V.
  dS = np.dot(np.dot(Ut, dA), V)
  ds = np.real(np.diag(dS))
  # F[i, j] = 1 / (s_j^2 - s_i^2) off-diagonal, 0 on the diagonal; the added
  # identity keeps the diagonal division finite before it is subtracted away.
  F = 1 / (np.square(s_dim) - np.square(s_dim.T) + np.eye(k)) - np.eye(k)
  dSS = s_dim * dS
  SdS = s_dim.T * dS
  dU = np.dot(U, F * (dSS + dSS.T))
  dV = np.dot(V, F * (SdS + SdS.T))

  m, n = A.shape[-2], A.shape[-1]
  # Rectangular inputs: add the contribution from the orthogonal complement
  # of the retained singular subspace.
  if m > n:
    dU = dU + np.dot(np.eye(m) - np.dot(U, Ut), np.dot(dA, V)) / s_dim
  if n > m:
    dV = dV + np.dot(np.eye(n) - np.dot(V, Vt), np.dot(np.conj(dA).T, U)) / s_dim

  return (s, U, Vt), (ds, dU, dV.T)
def _jaxtupletree_select(pred, on_true, on_false):
  """Elementwise select over a JaxTuple tree.

  Recurses through tuple values in lockstep and applies `lax.select` at
  array leaves. Raises TypeError for any other abstract value.
  """
  aval = core.get_aval(on_true)
  if isinstance(aval, UnshapedArray):
    return lax.select(pred, on_true, on_false)
  if type(aval) is core.AbstractTuple:
    select = partial(_jaxtupletree_select, pred)
    return core.pack(map(select, on_true, on_false))
  raise TypeError(aval)
def _lu_jvp_rule(primals, tangents):
  """JVP rule for LU decomposition with partial pivoting."""
  a, = primals
  a_dot, = tangents
  lu, pivots = lu_p.bind(a)

  if a_dot is ad_util.zero:
    return (core.pack(
        (lu, pivots)), ad.TangentTuple((ad_util.zero, ad_util.zero)))

  a_shape = np.shape(a)
  m, n = a_shape[-2:]
  dtype = lax.dtype(a)
  k = min(m, n)

  # Apply the same row permutation to the tangent that pivoting applied to
  # the primal, via batched fancy indexing.
  permutation = lu_pivots_to_permutation(pivots, m)
  batch_dims = a_shape[:-2]
  iotas = np.ix_(*(lax.iota(np.int32, b) for b in batch_dims + (1, )))
  x = a_dot[iotas[:-1] + (permutation, slice(None))]

  # Differentiation of Matrix Functionals Using Triangular Factorization
  # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
  #
  #     LU = A
  # ==> L'U + LU' = A'
  # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
  # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
  #     U' = triu(inv(L) . A' . inv(U)) . U

  # Extract L (unit lower-triangular, m x k padded to m x m) and U
  # (upper-triangular, k x n padded to n x n) from the packed `lu` factor.
  ndims = len(a_shape)
  l_padding = [(0, 0, 0)] * ndims
  l_padding[-1] = (0, m - k, 0)
  zero = np._constant_like(lu, 0)
  l = lax.pad(np.tril(lu[..., :, :k], -1), zero, l_padding)
  l = l + np.eye(m, m, dtype=dtype)

  u_eye = lax.pad(np.eye(n - k, n - k, dtype=dtype), zero,
                  ((k, 0, 0), (k, 0, 0)))
  u_padding = [(0, 0, 0)] * ndims
  u_padding[-2] = (0, n - k, 0)
  u = lax.pad(np.triu(lu[..., :k, :]), zero, u_padding) + u_eye

  # la = inv(L) . A', then lau = inv(L) . A' . inv(U).
  la = triangular_solve(l, x, left_side=True, transpose_a=False, lower=True,
                        unit_diagonal=True)
  lau = triangular_solve(u, la, left_side=False, transpose_a=False,
                         lower=False)

  l_dot = np.matmul(l, np.tril(lau, -1))
  u_dot = np.matmul(np.triu(lau), u)
  lu_dot = l_dot + u_dot
  # Pivots are integer-valued, hence their tangent is symbolically zero.
  return (lu, pivots), (lu_dot, ad_util.zero)
def _scan_partial_eval(trace, *tracers, **kwargs):
  """Partial-evaluation rule for `scan`.

  Finds a fixed point of the carry's unknown pattern (which components of
  the carry are unknown can grow across iterations), splits the jaxpr into a
  known part (run now) and an unknown part (staged out), and returns a
  tracer carrying the staged equation.
  """
  jaxpr = kwargs.pop('jaxpr')
  length = kwargs.pop('length')
  forward = kwargs.pop('forward')
  assert not kwargs
  in_pvs, _ = unzip2([t.pval for t in tracers])
  sc_consts, sc_init, sc_xs = map(pe.unknown, in_pvs)

  # Fixed-point iteration on the carry's unknown pattern; 1000 is an
  # arbitrary safety cap, matching the sibling scan rules. (Loop variable
  # renamed to `_` since it is unused.)
  sc_carry = sc_init
  for _ in range(1000):
    second_components = (sc_consts, sc_carry, sc_xs)
    jaxpr_1, jaxpr_2, sc_out = pe.partial_eval_jaxpr(jaxpr, second_components,
                                                     instantiate=(sc_carry,
                                                                  False))
    sc_carry_out, sc_ys = sc_out
    if sc_carry_out == sc_carry:
      break
    else:
      sc_carry = _binary_lattice_join(sc_carry, sc_carry_out)
  else:
    raise FixedPointError

  # Instantiate the init tracer wherever the fixed point says the carry is
  # unknown, so known/unknown structure matches across iterations.
  consts_tracer, init_tracer, xs_tracer = tracers
  lifted_init_tracer = _lift_tracer(trace, init_tracer, sc_carry)
  lifted_tracers = consts_tracer, lifted_init_tracer, xs_tracer
  in_pvs, in_consts = unzip2([t.pval for t in lifted_tracers])

  carry_aval, y_aval = jaxpr.out_aval
  ys_aval = _promote_aval_rank(length, y_aval)
  out_aval = core.AbstractTuple((carry_aval, ys_aval))
  out_pv = _put_known_pvs(sc_out, out_aval)

  # Run the known half now; its residuals feed the staged unknown half.
  out_carry, (ys, residuals) = scan_p.bind(*in_consts, forward=forward,
                                           length=length, jaxpr=jaxpr_1)
  out_const = core.pack((out_carry, ys))
  residuals_tracer = trace.new_instantiated_const(core.pack(residuals))
  d, c, a = lifted_tracers
  new_tracers = (d, c, (a, residuals_tracer))
  eqn = core.JaxprEqn(new_tracers, None, scan_p, (), True, False,
                      dict(forward=forward, length=length, jaxpr=jaxpr_2))
  return pe.JaxprTracer(trace, pe.PartialVal((out_pv, out_const)), eqn)
def _convert_zeros(instantiate, example, tangent):
  """Instantiate or strip symbolic zeros in `tangent` per `instantiate`.

  `instantiate` mirrors the tangent structure: a bool at leaves (True means
  materialize a zero like `example`), a tuple at tuple nodes.
  """
  kind = type(instantiate)
  if kind is bool:
    if instantiate:
      return ad.instantiate_zeros(example, tangent)
    if tangent is ad_util.zero:
      return core.unit
    raise TypeError(tangent)  # not clear if ever reachable
  if kind is tuple:
    if tangent is ad_util.zero:
      # Broadcast the symbolic zero across the tuple, then recurse.
      tangent = [ad_util.zero] * len(instantiate)
    elif type(tangent) is not ad.TangentTuple:
      raise TypeError(tangent)
    return core.pack(map(_convert_zeros, instantiate, example, tangent))
  raise TypeError(kind)
def _convert_zeros(convert_symbolic, example, tangent):
  """Replace symbolic zeros in `tangent`, guided by `convert_symbolic`.

  At a symbolic zero, either materialize zeros shaped like `example` or
  return the unit value; tuples recurse elementwise; anything else passes
  through unchanged.
  """
  if tangent is ad.zero:
    return ad.zeros_like_jaxval(example) if convert_symbolic else core.unit
  if type(tangent) is ad.TangentTuple:
    return core.pack(map(_convert_zeros, convert_symbolic, example, tangent))
  return tangent
def body_fun(i, vals): a, state = vals # select i-th element from each b b = [lax.dynamic_index_in_dim(b, i, keepdims=False) for b in bs] a_out = core.eval_jaxpr(jaxpr, consts, (), a, core.pack(b)) # select fields from a_out and update state state_out = [ lax.dynamic_update_index_in_dim(s, a[None, ...], i, axis=0) for a, s in zip([tuple(a_out)[j] for j in fields], state) ] return a_out, state_out
def _lu_python(x):
  """Default LU decomposition in Python, where no better version exists.

  Batched inputs are flattened to a single batch axis, decomposed with a
  vmap of the blocked kernel, and reshaped back to the original batch shape.
  """
  m, n = x.shape[-2:]
  batch_dims = x.shape[:-2]
  if not batch_dims:
    # Unbatched fast path.
    pivot, lu = _lu_blocked(x)
    return core.pack((lu, pivot))
  batch_size = onp.prod(batch_dims, dtype=onp.int64)
  pivot, lu = api.vmap(_lu_blocked)(lax.reshape(x, (batch_size, m, n)))
  pivot = lax.reshape(pivot, batch_dims + (min(m, n), ))
  lu = lax.reshape(lu, batch_dims + (m, n))
  return core.pack((lu, pivot))
def _lift_tracer(trace, tracer, is_unknown):
  """Instantiate `tracer` as a constant wherever `is_unknown` says to.

  `is_unknown` mirrors the tracer structure: bool at leaves, tuple at tuple
  nodes; any other type is an error.
  """
  kind = type(is_unknown)
  if kind is bool:
    return trace.instantiate_const(tracer) if is_unknown else tracer
  if kind is tuple:
    raised = map(trace.full_raise, tracer)
    lift = partial(_lift_tracer, trace)
    return core.pack(map(lift, raised, is_unknown))
  raise TypeError(kind)
def cond(pred, true_operand, true_fun, false_operand, false_fun):
  """Conditionally apply `true_fun` or `false_fun` based on `pred`.

  Both branches are traced to jaxprs, their output partial values are
  joined, and each jaxpr is revised so both branches produce the joined
  structure before binding the cond primitive.
  """
  def trace_jaxpr(fun, operand):
    # Flatten the operand pytree, wrap the function to match, and trace it.
    op_flat, in_tree = pytree_to_flatjaxtuple(operand)
    fun_flat, out_tree = pytree_fun_to_flatjaxtuple_fun(
        lu.wrap_init(fun), (in_tree, ))
    jaxpr, pvout, consts = pe.trace_to_jaxpr(fun_flat,
                                             (lax._abstractify(op_flat), ))
    return op_flat, jaxpr, consts, pvout, out_tree

  true_data = trace_jaxpr(true_fun, true_operand)
  true_op, true_jaxpr, true_consts, true_pval, true_tree = true_data
  false_data = trace_jaxpr(false_fun, false_operand)
  false_op, false_jaxpr, false_consts, false_pval, false_tree = false_data

  if true_tree() != false_tree():
    msg = "true_fun and false_fun outputs must have identical structure"
    raise TypeError(msg)

  # Join the two branches' partial values into one output structure.
  try:
    joined_pval = pe.join_pvals(true_pval, false_pval)
  except TypeError:
    msg = "could not merge true_fun and false_fun output pvals: {} and {}."
    raise TypeError(msg.format(true_pval, false_pval))
  # Revise each branch jaxpr so it produces the joined output structure.
  revis = _revise_cond_jaxpr(joined_pval, true_pval, true_jaxpr, true_consts)
  true_jaxpr, true_consts = revis
  revis = _revise_cond_jaxpr(joined_pval, false_pval, false_jaxpr,
                             false_consts)
  false_jaxpr, false_consts = revis
  aval_out, _ = joined_pval

  out = cond_p.bind(pred, true_op, core.pack(true_consts), false_op,
                    core.pack(false_consts), aval_out=aval_out,
                    true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  out = pe.merge_pvals(out, joined_pval)
  return tree_unflatten(true_tree(), out)
def _scan_transpose(ct, consts, init, xs, forward, length, jaxpr):
  """Transpose rule for `scan`.

  Transposes the scanned jaxpr and runs it as a new scan in the opposite
  direction, carrying cotangents for the carry and consts, and consuming
  the output cotangents alongside the residuals saved by partial eval.
  """
  assert consts is None and init is None
  assert type(xs) is tuple
  a, res = xs
  assert a is None and res is not None

  # jaxpr :: d -> c -> (a, res) -> (c, b)
  # jaxpr_lifted :: res -> (d, c, a) -> (c, b)
  # jaxpr_lifted_trans :: res -> (CT c, CT b) -> (CT d, CT c, CT a)
  # jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT c, CT d), CT a)
  assert type(jaxpr.jaxpr.invars[2]) is tuple  # assume restructuring
  jaxpr_lifted = rearrange_binders(
      lambda d, c, a_res: (a_res[1], (d, c, a_res[0])), jaxpr)
  jaxpr_lifted_trans = _transpose_jaxpr(jaxpr_lifted)
  jaxpr_trans = _move_stuff_and_add_add(jaxpr_lifted_trans)

  c_aval, b_aval = jaxpr.out_aval
  d_aval, c_aval2, _ = jaxpr.in_avals
  assert c_aval == c_aval2
  bs_aval = _promote_aval_rank(length, b_aval)
  # Start the consts cotangent at zero; instantiate carry/output cotangents.
  ct_d = ad_util.zeros_like_aval(d_aval)
  ct_c, ct_bs = ad.instantiate_zeros_aval(
      core.AbstractTuple((c_aval, bs_aval)), ct)
  carry_ct = core.pack((ct_c, ct_d))

  # jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT c, CT d), CT a)
  # Sanity-check the transposed jaxpr before binding it.
  core.check_jaxpr(jaxpr_trans.jaxpr)
  unit_aval, (ct_c_aval, ct_d_aval), (ct_b_aval, _) = jaxpr_trans.in_avals
  assert core.lattice_join(ct_c_aval, core.get_aval(ct_c)) == ct_c_aval
  assert core.lattice_join(ct_d_aval, core.get_aval(ct_d)) == ct_d_aval

  # Run the transposed scan in the reverse direction.
  out = scan_p.bind(core.unit, carry_ct, core.pack((ct_bs, res)),
                    forward=not forward, length=length, jaxpr=jaxpr_trans)
  (ct_init, ct_consts), ct_as = out
  return ct_consts, ct_init, (ct_as, None)
def lu_jvp_rule(primals, tangents):
  """JVP rule for LU decomposition with partial pivoting."""
  a, = primals
  a_dot, = tangents
  lu, pivots = lu_p.bind(a)

  a_shape = np.shape(a)
  m, n = a_shape[-2:]
  dtype = lax._dtype(a)
  k = min(m, n)

  # Apply the same row permutation to the tangent that pivoting applied to
  # the primal, via an explicit permutation matrix.
  # TODO(phawkins): use a gather rather than a matrix multiplication here.
  permutation = lu_pivots_to_permutation(pivots, m)
  p = np.array(permutation[:, None] == np.arange(m), dtype=dtype)
  x = np.matmul(p, a_dot)

  # Differentiation of Matrix Functionals Using Triangular Factorization
  # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
  #
  #     LU = A
  # ==> L'U + LU' = A'
  # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
  # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
  #     U' = triu(inv(L) . A' . inv(U)) . U

  # Extract L (unit lower-triangular) and U (upper-triangular) from the
  # packed `lu` factor, padded out to square shapes.
  ndims = len(a_shape)
  l_padding = [(0, 0, 0)] * ndims
  l_padding[-1] = (0, m - k, 0)
  zero = np._constant_like(lu, 0)
  l = lax.pad(np.tril(lu[..., :, :k], -1), zero, l_padding)
  l = l + np.eye(m, m, dtype=dtype)

  u_eye = lax.pad(np.eye(n - k, n - k, dtype=dtype), zero,
                  ((k, 0, 0), (k, 0, 0)))
  u_padding = [(0, 0, 0)] * ndims
  u_padding[-2] = (0, n - k, 0)
  u = lax.pad(np.triu(lu[..., :k, :]), zero, u_padding) + u_eye

  # la = inv(L) . A', then lau = inv(L) . A' . inv(U).
  la = triangular_solve(l, x, left_side=True, transpose_a=False, lower=True)
  lau = triangular_solve(u, la, left_side=False, transpose_a=False,
                         lower=False)

  l_dot = np.matmul(l, np.tril(lau, -1))
  u_dot = np.matmul(np.triu(lau), u)
  lu_dot = l_dot + u_dot
  # Pivots are integer-valued, hence their tangent is symbolically zero.
  return core.pack((lu, pivots)), ad.TangentTuple((lu_dot, ad_util.zero))
def _scan_impl(consts, init, xs, forward, length, jaxpr):
  """Evaluate `scan` by driving `jaxpr` through a fori_loop, stacking outputs."""
  _, _, x_aval = jaxpr.in_avals
  _, y_aval = jaxpr.out_aval
  stacked_y_aval = _promote_aval_rank(length, y_aval)

  def step(i, loop_state):
    # Iterate front-to-back or back-to-front depending on `forward`.
    idx = i if forward else length - i - 1
    carry, ys = loop_state
    x = _index_arrays(idx, x_aval, xs)
    new_carry, y = core.jaxpr_as_fun(jaxpr)(consts, carry, x)
    return (new_carry, _update_arrays(idx, y_aval, ys, y))

  empty_ys = _empty_arrays(stacked_y_aval)
  carry, ys = fori_loop(0, length, step, (init, empty_ys))
  return core.pack((carry, ys))
def _revise_cond_jaxpr(new_pval, old_pval, jaxpr, consts):
  """Adapt a cond branch jaxpr so its output matches the joined pval.

  After joining the two branches' output partial values, each branch may
  need to emit values at positions it previously treated as fully known
  constants. Returns a (possibly new) jaxpr and its constants.
  """
  new_pv, new_const = new_pval
  old_pv, old_const = old_pval
  if new_pv == old_pv:
    # we didn't move up the lattice by joining with the other side
    return jaxpr, consts
  elif old_pv is None:
    # we moved up the lattice from totally-known, so make a new jaxpr that
    # returns a single constant JaxTuple with elements that are constants
    # drawn from consts where new_pv is unknown
    assert not jaxpr.eqns and not consts
    outvar = pe.Var(0, "_cond")
    new_jaxpr = jaxpr.copy()
    new_jaxpr.constvars = [outvar]
    new_jaxpr.outvar = outvar
    new_consts = (core.pack([
        core.unit if pv is None else old_c
        for pv, old_c in zip(new_pv, old_const)
    ]), )
    return new_jaxpr, new_consts
  else:
    # we moved up the lattice, but not from totally-constant, so adapt the
    # japxr to return some new constants in places that are now unknown but
    # weren't before
    eqn = jaxpr.eqns[-1]
    assert eqn.primitive == core.pack_p
    assert len(eqn.outvars) == 1 and eqn.outvars[0] == jaxpr.outvar
    # Fresh constvars for positions newly promoted to unknown.
    newvar = pe.gensym("_cond")
    new_constvars, new_constvals = unzip2([
        (newvar(), c) for new, old, c in zip(new_pv, old_pv, old_const)
        if old is None and new is not None
    ])
    new_consts = consts + tuple(new_constvals)
    new_jaxpr = jaxpr.copy()
    new_jaxpr.constvars = tuple(jaxpr.constvars) + tuple(new_constvars)
    # Rebuild the final pack equation, substituting the new constvars (or
    # the unit var) at the promoted positions.
    newvars = iter(new_constvars)
    new_invars = [
        next(newvars) if old is None and new is not None else
        (core.unitvar if new is None and old is None else v)
        for new, old, v in zip(new_pv, old_pv, eqn.invars)
    ]
    new_jaxpr.eqns = (list(jaxpr.eqns[:-1]) +
                      [_pack_eqn(new_invars, jaxpr.outvar)])
    return new_jaxpr, new_consts
def testShardedDeviceTuple(self):
  """pmap of a tuple-packing function yields ShardedDeviceTuples that can be
  consumed again by pmap and by jit."""
  f = lambda x: core.pack((x, x))
  f = pmap(f)

  # One leading-axis slice per device.
  shape = (xla_bridge.device_count(), 4)
  x = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)

  # test that we can pass in and out ShardedDeviceTuples (and unpack them)
  y = f(x)
  self.assertIsInstance(y, pxla.ShardedDeviceTuple)
  self.assertIsInstance(y, core.JaxTuple)
  self.assertAllClose(y, (x, x), check_dtypes=False)
  # Feeding the tuple result straight back into the pmapped function.
  z = f(y)
  self.assertIsInstance(z, pxla.ShardedDeviceTuple)
  self.assertAllClose(z, (y, y), check_dtypes=True)

  # test that we can pass a ShardedDeviceTuple to a regular jit computation
  w = jit(lambda x: list(x)[0])(y)
  self.assertAllClose(w, x, check_dtypes=False)
def _scan_batching_rule(batched_args, batch_dims, forward, length, jaxpr):
  """Batching (vmap) rule for `scan`.

  Finds a fixed point of the carry's batched pattern (which carry components
  are batched can grow across iterations), then runs a single scan with the
  batched jaxpr.
  """
  consts, init, xs = batched_args
  consts_bdim, init_bdim, xs_bdim = batch_dims
  # All batched arguments must agree on a single batch size.
  sizes = lax._reduce(set.union,
                      map(batching.dimsize, batch_dims, batched_args))
  size = sizes.pop()
  assert not sizes

  consts_batched = batching.where_batched(consts_bdim)
  init_batched = batching.where_batched(init_bdim)
  xs_batched = batching.where_batched(xs_bdim)

  # Fixed-point iteration on the carry's batched pattern; 1000 is an
  # arbitrary safety cap, and the for/else raises on non-convergence.
  carry_batched = init_batched
  for _ in range(1000):
    which_batched = (consts_batched, carry_batched, xs_batched)
    jaxpr_batched, batched_out = batching.batch_jaxpr(
        jaxpr, size, which_batched, instantiate=(carry_batched, False))
    carry_batched_out, ys_batched = batched_out
    if _binary_lattice_eq(carry_batched_out, carry_batched):
      break
    else:
      carry_batched = _binary_lattice_join(carry_batched_out, carry_batched)
  else:
    raise FixedPointError

  # Materialize batch dimensions where the fixed point requires them:
  # axis 0 for consts and the carry, axis 1 for the scanned xs (axis 0 is
  # the scan axis).
  consts_batched = batching.instantiate_bdim(size, 0, consts_batched,
                                             consts_bdim, consts)
  init_batched = batching.instantiate_bdim(size, 0, carry_batched, init_bdim,
                                           init)
  xs_batched = batching.instantiate_bdim(size, 1, xs_batched, xs_bdim, xs)

  carry_out, ys = scan_p.bind(consts_batched, init_batched, xs_batched,
                              forward=forward, length=length,
                              jaxpr=jaxpr_batched)

  carry_out_bdim = batching.bools_to_bdims(0, carry_batched)
  ys_bdim = batching.bools_to_bdims(1, ys_batched)
  return core.pack((carry_out, ys)), (carry_out_bdim, ys_bdim)
def __init__(self, seed):
  """Create a new PRNG key from a scalar integer seed.

  The seed is split into a pair of uint32 words (high 32 bits, low 32
  bits), stored in `self.keypair`.

  Args:
    seed: a scalar integer value used to initialize the PRNG key.

  Raises:
    TypeError: if `seed` is not a scalar.
  """
  convert = lambda key: lax.convert_element_type(key, onp.uint32)
  if onp.shape(seed):
    raise TypeError("PRNGKey seed must be a scalar.")
  if isinstance(seed, (int, onp.ndarray)):
    # Special handling of raw integer values, which may be 64-bit even
    # when jax_enable_x64=False and we don't want to drop the top 32 bits
    k1 = convert(onp.bitwise_and(onp.right_shift(seed, 32), 0xFFFFFFFF))
  else:
    k1 = convert(lax.shift_right_logical(seed, 32))
  k2 = convert(lax.bitwise_and(seed, 0xFFFFFFFF))
  self.keypair = core.pack((k1, k2))
def _scan_init(rng, submodule_params_dict, consts, init, xs, forward, length,
               jaxpr, reuse, reuse_only):
  """Initialize parameters for a scanned parametrized cell, then run the scan.

  Traces one step on the first input slice to collect submodule parameters,
  then drives the full loop with those parameters applied at every step.
  Returns the packed (carry, stacked outputs) and the parameter dict.
  """
  # TODO update to jax==0.1.42
  assert len(consts) == 0
  _, _, x_aval = jaxpr.in_avals
  _, y_aval = jaxpr.out_aval
  ys_aval = _promote_aval_rank(length, y_aval)

  # Use the first input slice to initialize submodule parameters.
  x = _index_arrays(0, x_aval, xs)
  submodule_params_dict = _get_submodule_params(rng, jaxpr.jaxpr,
                                                jaxpr.literals, (),
                                                submodule_params_dict,
                                                consts, init, x,
                                                reuse=reuse,
                                                reuse_only=reuse_only)

  # At most one submodule is expected here; wrap its params in the
  # namedtuple form that `apply` consumes.
  if len(submodule_params_dict) == 0:
    submodule_params = ()
  else:
    primitive, = submodule_params_dict.keys()
    submodule_params = (primitive._params_namedtuple(
        submodule_params_dict[primitive]), )

  def body_fun(i, vals):
    # Walk forward or backward depending on the `forward` flag.
    idx = i if forward else length - i - 1
    carry, ys = vals
    x = _index_arrays(idx, x_aval, xs)
    cell = parametrized(jc.jaxpr_as_fun(jaxpr))
    carry_out, y = cell.apply(submodule_params, consts, carry, x)
    ys_out = _update_arrays(idx, y_aval, ys, y)
    return carry_out, ys_out

  ys_init = _empty_arrays(ys_aval)
  carry, ys = lax.fori_loop(0, length, body_fun, (init, ys_init))
  return jc.pack((carry, ys)), submodule_params_dict
def _scan_apply(submodule_params_iter, consts, init, xs, forward, length,
                jaxpr):
  """Apply a scanned parametrized cell with already-initialized parameters.

  Mirrors `_scan_init` but consumes parameters from
  `submodule_params_iter` instead of creating them.
  """
  # TODO update to jax==0.1.42
  _, _, x_aval = jaxpr.in_avals
  _, y_aval = jaxpr.out_aval
  ys_aval = _promote_aval_rank(length, y_aval)

  # TODO fix param sharing
  cell_params = (submodule_params_iter.get_params(None), ) if len(
      submodule_params_iter.submodule_params) > 0 else ()

  def body_fun(i, vals):
    # Walk forward or backward depending on the `forward` flag.
    idx = i if forward else length - i - 1
    carry, ys = vals
    x = _index_arrays(idx, x_aval, xs)
    cell = parametrized(jc.jaxpr_as_fun(jaxpr))
    carry_out, y = cell.apply(cell_params, consts, carry, x)
    ys_out = _update_arrays(idx, y_aval, ys, y)
    return carry_out, ys_out

  ys_init = _empty_arrays(ys_aval)
  carry, ys = lax.fori_loop(0, length, body_fun, (init, ys_init))
  return jc.pack((carry, ys))
def _tscan(f, a, bs, fields=(0, )): """ Works as jax.lax.scan but has additional `fields` argument to select only necessary fields from `a`'s structure. Defaults to selecting only the first field. Other fields will be filled by None. """ # Note: code is copied and modified from lax.scan implementation in # [JAX](https://github.com/google/jax) to support the additional `fields` # arg. Original code has the following copyright: # # Copyright 2018 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License") # convert pytree to flat jaxtuple a, a_tree = pytree_to_flatjaxtuple(a) bs, b_tree = pytree_to_flatjaxtuple(bs) fields, _ = pytree_to_flatjaxtuple(fields) f, out_tree = pytree_fun_to_flatjaxtuple_fun(wrap_init(f), (a_tree, b_tree)) # convert arrays to abstract values a_aval, _ = lax._abstractify(a) bs_aval, _ = lax._abstractify(bs) # convert bs to b b_aval = core.AbstractTuple( [ShapedArray(b.shape[1:], b.dtype) for b in bs_aval]) # convert abstract values to partial values (?) then evaluate to get jaxpr a_pval = partial_eval.PartialVal((a_aval, core.unit)) b_pval = partial_eval.PartialVal((b_aval, core.unit)) jaxpr, pval_out, consts = partial_eval.trace_to_jaxpr(f, (a_pval, b_pval)) aval_out, _ = pval_out consts = core.pack(consts) out = tscan_p.bind(a, bs, fields, consts, aval_out=aval_out, jaxpr=jaxpr) return tree_unflatten(out_tree(), out)
def while_loop(cond_fun, body_fun, init_val):
  """Call `body_fun` repeatedly in a loop while `cond_fun` is True.

  Arguments:
    cond_fun: pure function of type `T -> Bool`.
    body_fun: pure function of type `T -> T`.
    init_val: value of type `T`, a type that can be a scalar, array, or any
      (nested) Python tuple/list/dict thereof.

  Returns:
    The output from the final iteration of body_fun, of type `T`.

  The semantics of `while_loop` are given by this Python implementation::

    def while_loop(cond_fun, body_fun, init_val):
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val

  Unlike that pure Python version, `while_loop` is a JAX primitive and is
  lowered to a single XLA While HLO. That makes it useful for reducing
  compilation times for jit-compiled functions, since native Python loop
  constructs in an `@jit` function are unrolled, leading to large XLA
  computations.

  Another difference from using Python-native loop constructs is that
  `while_loop` is not (yet) reverse-mode differentiable because XLA
  computations require static bounds on memory requirements.
  """
  # Flatten the init value's pytree and wrap cond/body to operate on the
  # flattened representation.
  init_val_flat, in_tree = pytree_to_jaxtupletree(init_val)
  flat_body_fun, out_tree = pytree_fun_to_jaxtupletree_fun(
      lu.wrap_init(body_fun), (in_tree, ))
  flat_cond_fun, _ = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(cond_fun),
                                                    (in_tree, ))

  # Trace both functions on the abstracted loop state.
  pval_flat = lax._abstractify(init_val_flat)
  cond_jaxpr, _, cond_consts = pe.trace_to_jaxpr(flat_cond_fun, (pval_flat, ))
  body_jaxpr, pval_out, body_consts = pe.trace_to_jaxpr(
      flat_body_fun, (pval_flat, ))
  aval_out, _ = pval_out

  # We don't want to promote literal constants as loop arguments; there are
  # sometimes many of them. We pass tracers as loop arguments, but leave
  # nontracers as constants. We also sort the constants so the nontracers are
  # first.
  def split_tracers_and_nontracers(jaxpr, consts):
    tracer = []
    nontracer = []
    for x in zip(jaxpr.constvars, consts):
      # TODO(phawkins): We avoid treating DeviceArrays as constant literals so
      # we don't copy large arrays back to the host. We probably should relax
      # this and either always copy small constants, or opportunistically use
      # DeviceArray values for which we already know npy_value.
      not_literal_const = isinstance(x[1], (core.Tracer, xla.DeviceArray))
      (tracer if not_literal_const else nontracer).append(x)
    tracer_vars, tracer_consts = unzip2(tracer)
    nontracer_vars, nontracer_consts = unzip2(nontracer)
    # Nontracer constvars come first, matching the reordered consts.
    return nontracer_vars + tracer_vars, nontracer_consts, tracer_consts

  cond_split = split_tracers_and_nontracers(cond_jaxpr, cond_consts)
  cond_jaxpr.constvars, cond_nontracer_consts, cond_tracer_consts = cond_split
  body_split = split_tracers_and_nontracers(body_jaxpr, body_consts)
  body_jaxpr.constvars, body_nontracer_consts, body_tracer_consts = body_split

  if out_tree() != in_tree:
    raise TypeError(
        "body_fun input and output must have identical structure")
  # Tracer constants are passed as operands; nontracer constants ride along
  # as opaque parameters of the primitive.
  out_flat = while_p.bind(
      init_val_flat, core.pack(cond_tracer_consts),
      core.pack(body_tracer_consts),
      cond_consts=lax._OpaqueParam(cond_nontracer_consts),
      body_consts=lax._OpaqueParam(body_nontracer_consts),
      aval_out=aval_out, cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
  return build_tree(out_tree(), out_flat)
def fun(x):
  """Pack x into nested tuples, then unpack back down to x."""
  inner = pack((x, x))
  outer = pack((inner, x))
  first, _ = outer
  result, _ = first
  return result
def foo(x):
  # Pack x with the enclosing y, then reduce the pair with tup_add.
  pair = core.pack((x, y))
  return np.tup_add(pair)
def bar(y):
  # Round-trip (x, y) through core.pack, then combine the unpacked parts.
  packed = core.pack((x, y))
  a, b = packed
  return np.sin(a * b)
def foo(x):
  # Packing then unpacking must preserve the identity of the second element.
  first, second = core.pack((x, y))
  assert second is y, (second, y)
  return first