def _scan_polymorphic_shape_rule(shape_exprs, forward, length, jaxpr,
                                 num_consts, num_carry, linear):
  """Compute the output shape expressions of a scan.

  The carry outputs reuse the shape expressions of the initial-carry inputs,
  and every per-iteration output gets `length` prepended to its aval shape.
  `forward` and `linear` are accepted for signature parity but not read.

  Returns:
    A list: carry shape expressions followed by stacked-output ShapeExprs.
  """
  # Inputs are laid out as consts, then carry, then xs; only the carry group
  # contributes to the result.
  _, carry_shexprs, _ = split_list(shape_exprs, [num_consts, num_carry])
  _, per_step_avals = split_list(jaxpr.out_avals, [num_carry])
  stacked_shexprs = []
  for aval in per_step_avals:
    stacked_shexprs.append(ShapeExpr(length, *aval.shape))
  return carry_shexprs + stacked_shexprs
def _scan_jvp(primals, tangents, forward, length, jaxpr, num_consts, num_carry,
              linear):
  """JVP rule for scan.

  Iterates to a fixed point on which carry tangents are nonzero, builds the
  JVP jaxpr for that pattern, and binds a single scan that computes primal and
  tangent outputs together.

  Returns:
    `(primals_out, tangents_out)`; tangent outputs that are known to be zero
    are represented by `ad_util.zero`.

  Raises:
    FixedPointError: if the nonzero pattern does not stabilize in 1000 rounds.
  """
  n_xs = len(jaxpr.in_avals) - num_carry - num_consts
  n_ys = len(jaxpr.out_avals) - num_carry
  tangent_nz = [t is not ad_util.zero for t in tangents]
  const_nz, init_nz, xs_nz = split_list(tangent_nz, [num_consts, num_carry])

  # A carry tangent that becomes nonzero after one step must be treated as
  # nonzero on entry too, so repeat until the carry pattern stops changing.
  carry_nz = init_nz
  for _ in range(1000):
    nonzeros = const_nz + carry_nz + xs_nz
    jaxpr_jvp, nonzeros_out = ad.jvp_jaxpr(
        jaxpr, nonzeros, instantiate=carry_nz + [False] * n_ys)
    carry_nz_out = nonzeros_out[:num_carry]
    if carry_nz_out == carry_nz:
      break
    carry_nz = carry_nz_out
  else:
    raise FixedPointError

  # Materialize symbolic zeros wherever the fixed point says "nonzero".
  dense_tangents = []
  for x, t, nz in zip(primals, tangents, nonzeros):
    if t is ad_util.zero and nz:
      t = ad.instantiate_zeros(x, t)
    dense_tangents.append(t)

  consts, init, xs = split_list(primals, [num_consts, num_carry])
  tangent_groups = split_list(dense_tangents, [num_consts, num_carry])
  consts_dot, init_dot, xs_dot = _map(_prune_zeros, tangent_groups)

  n_consts_dot, n_init_dot, n_xs_dot = len(consts_dot), len(init_dot), len(xs_dot)
  jaxpr_jvp_rearranged = ad.rearrange_binders(
      jaxpr_jvp,
      [num_consts, num_carry, n_xs], [n_consts_dot, n_init_dot, n_xs_dot],
      [num_carry, n_ys], [n_init_dot, sum(nonzeros_out) - n_init_dot])

  # Tangent inputs are always linear; primal inputs keep their original flags.
  consts_linear, init_linear, xs_linear = split_list(
      linear, [num_consts, num_carry])
  jaxpr_jvp_linear = (consts_linear + [True] * n_consts_dot
                      + init_linear + [True] * n_init_dot
                      + xs_linear + [True] * n_xs_dot)

  operands = consts + consts_dot + init + init_dot + xs + xs_dot
  out_flat = scan_p.bind(
      *operands, forward=forward, length=length, jaxpr=jaxpr_jvp_rearranged,
      num_consts=num_consts + n_consts_dot, num_carry=num_carry + n_init_dot,
      linear=jaxpr_jvp_linear)

  carry, carry_dot, ys, ys_dot = split_list(
      out_flat, [num_carry, n_init_dot, n_ys])
  primals_out = carry + ys
  # Re-insert symbolic zeros at the positions that were pruned.
  nonzero_tangents = iter(carry_dot + ys_dot)
  tangents_out = [next(nonzero_tangents) if nz else ad_util.zero
                  for nz in nonzeros_out]
  return primals_out, tangents_out
def _scan_sparse(spenv, *spvalues, jaxpr, num_consts, num_carry, **params):
  """Sparsify rule for scan; only scans without scanned-over inputs are handled.

  Sparsifies the body jaxpr, flattens the const/carry spvalues to arrays,
  expands the per-argument `linear` flags to match, and binds `lax.scan_p`.

  Raises:
    NotImplementedError: if any `xs` spvalues are present.
  """
  spvalue_groups = split_list(spvalues, [num_consts, num_carry])
  const_spvalues, carry_spvalues, xs_spvalues = spvalue_groups
  if xs_spvalues:
    # TODO(jakevdp): we don't want to pass xs_spvalues, we want to pass one row
    # of xs spvalues. How to do this?
    raise NotImplementedError("sparse rule for scan with x values.")
  sp_jaxpr, _ = _sparsify_jaxpr(
      spenv, jaxpr, *const_spvalues, *carry_spvalues, *xs_spvalues)

  flat_consts, _ = tree_flatten(spvalues_to_arrays(spenv, const_spvalues))
  flat_carry, carry_tree = tree_flatten(
      spvalues_to_arrays(spenv, carry_spvalues))
  flat_xs, xs_tree = tree_flatten(spvalues_to_arrays(spenv, xs_spvalues))

  # params['linear'] has one entry per arg; expand it to match the sparsified
  # args.
  linear_groups = split_list(params.pop('linear'), [num_consts, num_carry])
  expanded_linear = []
  for group_spvalues, group_linear in zip(spvalue_groups, linear_groups):
    expanded_linear.extend(
        _duplicate_for_sparse_spvalues(group_spvalues, group_linear))
  sp_linear = tuple(expanded_linear)

  n_carry = len(flat_carry)
  out = lax.scan_p.bind(
      *flat_consts, *flat_carry, *flat_xs, jaxpr=sp_jaxpr, linear=sp_linear,
      num_consts=len(flat_consts), num_carry=n_carry, **params)
  carry_out = tree_unflatten(carry_tree, out[:n_carry])
  xs_out = tree_unflatten(xs_tree, out[n_carry:])
  return arrays_to_spvalues(spenv, carry_out + xs_out)
def _cond_batching_rule(args, dims, true_jaxpr, false_jaxpr, true_nconsts,
                        false_nconsts):
  """Batching rule for cond.

  Args:
    args: flat operands `[pred, *true_consts, *true_ops, *false_consts,
      *false_ops]`.
    dims: per-operand batch dimension, or `batching.not_mapped`.
    true_jaxpr, false_jaxpr: the branch jaxprs.
    true_nconsts, false_nconsts: number of leading consts for each branch.

  Returns:
    `(outs, out_dims)`. When the predicate is batched both branches are
    evaluated and merged with a select, so all outputs are batched along
    axis 0. Otherwise a single `cond_p` is bound on the batched branches.
  """
  # Read the batch size from the ORIGINAL operands. `dims` indexes the
  # un-moved arrays; after the moveaxis below the mapped axis is always 0, so
  # `x.shape[d]` would pick the wrong dimension whenever d != 0.
  size, = {x.shape[d] for x, d in zip(args, dims)
           if d is not batching.not_mapped}

  # TODO: maybe avoid moving arg axes to front if we're promoting to select?
  args = [batching.moveaxis(x, d, 0)
          if d is not batching.not_mapped and d != 0 else x
          for x, d in zip(args, dims)]
  true_nops = len(true_jaxpr.in_avals) - true_nconsts
  (pred,), true_consts, true_ops, false_consts, false_ops = split_list(
      args, [1, true_nconsts, true_nops, false_nconsts])

  orig_bat = [d is not batching.not_mapped for d in dims]
  (pred_bat,), tconst_bat, t_bat, fconst_bat, f_bat = split_list(
      orig_bat, [1, true_nconsts, true_nops, false_nconsts])

  # First pass: discover which outputs each branch would batch. An output must
  # be batched in both branches if it is batched in either.
  _, true_out_bat = batching.batch_jaxpr(
      true_jaxpr, size, tconst_bat + t_bat, False)
  _, false_out_bat = batching.batch_jaxpr(
      false_jaxpr, size, fconst_bat + f_bat, False)
  out_bat = [a or b for a, b in zip(true_out_bat, false_out_bat)]

  # Second pass: batch both branches with the agreed output pattern.
  true_jaxpr_batched, _ = batching.batch_jaxpr(
      true_jaxpr, size, tconst_bat + t_bat, out_bat)
  false_jaxpr_batched, _ = batching.batch_jaxpr(
      false_jaxpr, size, fconst_bat + f_bat, out_bat)

  if pred_bat:
    # A batched predicate cannot drive a single conditional: run both
    # branches and select elementwise.
    true_out = core.jaxpr_as_fun(true_jaxpr_batched)(
        *(true_consts + true_ops))
    false_out = core.jaxpr_as_fun(false_jaxpr_batched)(
        *(false_consts + false_ops))
    true_out = [batching.broadcast(x, size, 0) if not b else x
                for x, b in zip(true_out, out_bat)]
    false_out = [batching.broadcast(x, size, 0) if not b else x
                 for x, b in zip(false_out, out_bat)]
    outs = [_cond_pred_bcast_select(pred, t, f)
            for t, f in zip(true_out, false_out)]
    return outs, [0] * len(true_out)
  else:
    out_dims = [0 if b else batching.not_mapped for b in out_bat]
    return cond_p.bind(
        *itertools.chain([pred], true_consts, true_ops,
                         false_consts, false_ops),
        true_jaxpr=true_jaxpr_batched, false_jaxpr=false_jaxpr_batched,
        true_nconsts=len(true_consts),
        false_nconsts=len(false_consts)), out_dims
def _scan_transpose(cts, *args, **kwargs):
  """Transpose rule for scan.

  Only scans whose consts and carry are all linear, and whose linear `xs`
  come before the nonlinear residuals, are supported.

  Returns:
    Cotangents for consts, init and linear xs, followed by one `None` per
    residual.

  Raises:
    NotImplementedError: if the linearity pattern is not the supported one.
  """
  forward, length, num_consts, num_carry, jaxpr, linear = split_dict(
      kwargs, ["forward", "length", "num_consts", "num_carry", "jaxpr",
               "linear"])

  # we can only transpose scans for which the nonlinear values appear in xs
  consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry])
  num_lin = sum(xs_lin)
  if not (all(consts_lin) and all(init_lin) and all(xs_lin[:num_lin])):
    raise NotImplementedError

  _, _, _, res = split_list(args, [num_consts, num_carry, num_lin])
  assert all(r is not ad.undefined_primal for r in res)

  carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
  ys_avals = _map(partial(_promote_aval_rank, length), y_avals)
  carry_cts, ys_cts = split_list(cts, [num_carry])
  carry_cts = _map(ad.instantiate_zeros_aval, carry_avals, carry_cts)
  ys_cts = _map(ad.instantiate_zeros_aval, ys_avals, ys_cts)
  consts_cts = _map(ad_util.zeros_like_aval, jaxpr.in_avals[:num_consts])

  #       jaxpr :: [T d] -> [T c] -> [T a, res] -> ([T c], [T b])
  # jaxpr_trans :: [] -> [CT d, CT c] -> [CT b, res] -> ([CT d, CT c], [CT a])
  jaxpr_trans = _transpose_jaxpr(num_consts, len(res), jaxpr)
  num_ct_inputs = len(consts_cts) + len(carry_cts) + len(ys_cts)
  linear_trans = [True] * num_ct_inputs + [False] * len(res)

  # The transposed scan runs in the opposite direction and carries the const
  # cotangents together with the carry cotangents.
  operands = consts_cts + carry_cts + ys_cts + res
  outs = scan_p.bind(
      *operands, forward=not forward, length=length, jaxpr=jaxpr_trans,
      num_consts=0, num_carry=num_consts + num_carry, linear=linear_trans)
  out_consts_cts, out_init_cts, out_xs_cts = split_list(
      outs, [num_consts, num_carry])
  return out_consts_cts + out_init_cts + out_xs_cts + [None] * len(res)
def scan_bind(*args, **kwargs):
  """Sanity-check scan operands against the body jaxpr, then bind `scan_p`.

  Checks, via assertions, that `linear` has one entry per operand, that the
  consts and initial carry match the jaxpr input avals, and that the jaxpr's
  carry outputs match its carry inputs. Also runs `core.check_jaxpr`.
  """
  forward, length, num_consts, num_carry, jaxpr, linear = split_dict(
      kwargs, ["forward", "length", "num_consts", "num_carry", "jaxpr",
               "linear"])
  consts, init, xs = split_list(args, [num_consts, num_carry])
  assert len(linear) == len(args)

  # check that args match input types
  in_aval_groups = split_list(jaxpr.in_avals, [num_consts, num_carry])
  consts_avals, init_avals, x_avals = in_aval_groups
  xs_avals = _map(partial(_promote_aval_rank, length), x_avals)
  assert all(_map(typecheck, consts_avals, consts))
  assert all(_map(typecheck, init_avals, init))
  # NOTE: the check on the scanned-over inputs is disabled in the original:
  # assert all(_map(typecheck, xs_avals, xs))

  # check that output carry type matches input carry type
  carry_avals, _ = split_list(jaxpr.out_avals, [num_carry])
  assert all(_map(typematch, init_avals, carry_avals))

  # check that the data flow is sensible
  core.check_jaxpr(jaxpr.jaxpr)

  bind_params = dict(forward=forward, length=length, jaxpr=jaxpr,
                     num_consts=num_consts, num_carry=num_carry,
                     linear=linear)
  return core.Primitive.bind(scan_p, *args, **bind_params)
def _while_loop_batching_rule(args, dims, cond_nconsts, cond_jaxpr,
                              body_nconsts, body_jaxpr):
  """Batching rule for while_loop.

  Iterates to a fixed point on which carry elements are batched (a batched
  predicate forces every carry element to be batched), then binds `while_p`
  on the batched cond/body jaxprs.

  Returns:
    `(outs, out_bdims)` where batched carry outputs are mapped along axis 0.

  Raises:
    FixedPointError: if the batched pattern does not stabilize in 1000 rounds.
  """
  mapped_sizes = set()
  for x, d in zip(args, dims):
    if d is not batching.not_mapped:
      mapped_sizes.add(x.shape[d])
  size, = mapped_sizes

  orig_batched = [d is not batching.not_mapped for d in dims]
  cconst_bat, bconst_bat, init_bat = split_list(
      orig_batched, [cond_nconsts, body_nconsts])

  carry_bat = init_bat
  for _ in range(1000):
    body_jaxpr_batched, carry_bat_out = batching.batch_jaxpr(
        body_jaxpr, size, bconst_bat + carry_bat, instantiate=carry_bat)
    cond_jaxpr_batched, (pred_bat,) = batching.batch_jaxpr(
        cond_jaxpr, size, cconst_bat + carry_bat, instantiate=False)
    # If the predicate is batched, every carry element becomes batched.
    carry_bat_out = _map(partial(operator.or_, pred_bat), carry_bat_out)
    if carry_bat_out == carry_bat:
      break
    carry_bat = carry_bat_out
  else:
    raise FixedPointError

  num_all_consts = cond_nconsts + body_nconsts
  consts, init = split_list(args, [num_all_consts])
  const_dims, init_dims = split_list(dims, [num_all_consts])

  new_consts = []
  for x, d in zip(consts, const_dims):
    if d is not batching.not_mapped and d != 0:
      x = batching.moveaxis(x, d, 0)
    new_consts.append(x)

  new_init = []
  for x, d, was_bat, now_bat in zip(init, init_dims, init_bat, carry_bat):
    if not now_bat:
      new_init.append(x)
    elif was_bat:
      new_init.append(batching.moveaxis(x, d, 0))
    else:
      # Newly batched by the fixed point: materialize the batch axis.
      new_init.append(batching.broadcast(x, size, 0))

  outs = while_p.bind(
      *(new_consts + new_init),
      cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr_batched,
      body_nconsts=body_nconsts, body_jaxpr=body_jaxpr_batched)
  out_bdims = [0 if b else batching.not_mapped for b in carry_bat]
  return outs, out_bdims
def body_fun(i, vals):
  """One loop step of the scan: slice xs, run the jaxpr, write the ys back.

  `vals` is the flat loop state: the carry followed by the output buffers.
  Reads `forward`, `length`, `num_carry`, `consts`, `xs`, `x_avals`,
  `y_avals` and `jaxpr` from the enclosing scope.
  """
  # When scanning in reverse, iteration i touches position length - i - 1.
  step = i if forward else length - i - 1
  carry, ys = split_list(vals, [num_carry])
  x_slices = _map(partial(_index_array, step), x_avals, xs)
  step_out = core.jaxpr_as_fun(jaxpr)(*(consts + carry + x_slices))
  new_carry, y_step = split_list(step_out, [num_carry])
  new_ys = _map(partial(_update_array, step), y_avals, ys, y_step)
  return new_carry + new_ys
def masked(*args):
  """Run `fun` for one step, freezing the carry once past `dynamic_length`.

  Operands are `[dynamic_length, *consts, i, *carry, *xs]`. The step counter
  `i` is incremented and returned as the first output, ahead of the carry
  and the per-step outputs. Reads `fun`, `num_consts` and `num_carry` from
  the enclosing scope.
  """
  (dynamic_length,), consts, (i,), carry, xs = split_list(
      args, [1, num_consts, 1, num_carry])
  out = fun(*(consts + carry + xs))
  stepped_carry, ys = split_list(out, [num_carry])
  kept_carry = []
  for new_c, c in zip(stepped_carry, carry):
    # Keep the previous carry for steps at or beyond the logical length.
    kept_carry.append(lax.select(i < dynamic_length, new_c, c))
  return [i + 1] + kept_carry + ys
def _while_loop_translation_rule(c, axis_env, *args, **kwargs):
  """XLA translation rule for while_loop.

  Packs `[cond_consts, body_consts, init_vals]` into one tuple carried by an
  HLO While, builds cond/body computations that unpack it, and returns a
  tuple of the final loop values. When the predicate aval is non-scalar
  (batched), the loop continues while any predicate element is true, and the
  body only updates the carry where the predicate holds.
  """
  cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts = split_dict(
      kwargs, ["cond_jaxpr", "body_jaxpr", "cond_nconsts", "body_nconsts"])
  cond_consts, body_consts, init_vals = split_list(
      args, [cond_nconsts, body_nconsts])
  batched = bool(cond_jaxpr.out_avals[0].shape)
  num_elts = len(args)

  # Since jaxprs don't have tuples and have multiple return values, but we need
  # the HLO While loop to take a single tuple input and output a single boolean
  # (for the cond computation) or a single tuple output (for the body
  # computation), we build XLA computations that handle the tuple munging before
  # generating a Call into the computations formed from the jaxprs.

  init_carry = c.Tuple(*(cond_consts + body_consts + init_vals))

  # ---- cond computation ----
  cond_c = xb.make_computation_builder("cond_computation")
  cond_carry = cond_c.ParameterWithShape(c.GetShape(init_carry))
  cond_elts = [cond_c.GetTupleElement(cond_carry, idx)
               for idx in range(num_elts)]
  cond_x, _, cond_z = split_list(cond_elts, [cond_nconsts, body_nconsts])
  cond_subcomp = xla.jaxpr_computation(
      cond_jaxpr.jaxpr, axis_env, cond_jaxpr.literals, (),
      *_map(cond_c.GetShape, cond_x + cond_z))
  cond_outs = cond_c.Call(cond_subcomp, cond_x + cond_z)
  pred = cond_c.GetTupleElement(cond_outs, 0)
  if batched:
    # Reduce the batched predicate with logical-or over all of its axes.
    scalar = xla_client.Shape.array_shape(onp.dtype(onp.bool_), ())
    or_ = xla.primitive_computation(lax.or_p, scalar, scalar)
    pred = cond_c.Reduce(pred, cond_c.Constant(onp.array(False)), or_,
                         list(range(cond_jaxpr.out_avals[0].ndim)))

  # ---- body computation ----
  body_c = xb.make_computation_builder("body_computation")
  body_carry = body_c.ParameterWithShape(c.GetShape(init_carry))
  body_elts = [body_c.GetTupleElement(body_carry, idx)
               for idx in range(num_elts)]
  x, y, z = split_list(body_elts, [cond_nconsts, body_nconsts])
  body_subcomp = xla.jaxpr_computation(
      body_jaxpr.jaxpr, axis_env, body_jaxpr.literals, (),
      *_map(body_c.GetShape, y + z))
  body_out = body_c.Call(body_subcomp, y + z)
  new_z = [body_c.GetTupleElement(body_out, idx)
           for idx in range(len(init_vals))]
  if batched:
    # Re-evaluate the predicate inside the body and keep the old carry for
    # batch elements whose predicate is already false.
    body_cond_subcomp = xla.jaxpr_computation(
        cond_jaxpr.jaxpr, axis_env, cond_jaxpr.literals, (),
        *_map(body_c.GetShape, x + z))
    body_cond_outs = body_c.Call(body_cond_subcomp, x + z)
    body_pred = body_c.GetTupleElement(body_cond_outs, 0)
    new_z = _map(partial(_pred_bcast_select, body_c, body_pred), new_z, z)
    assert _map(body_c.GetShape, new_z) == _map(body_c.GetShape, z)  # no broadcast
  new_carry = body_c.Tuple(*(x + y + new_z))

  ans = c.While(cond_c.Build(pred), body_c.Build(new_carry), init_carry)
  ans_elts = [c.GetTupleElement(ans, idx) for idx in range(num_elts)]
  _, _, final_vals = split_list(ans_elts, [cond_nconsts, body_nconsts])
  return c.Tuple(*final_vals)
def _scan_batching_rule(args, dims, forward, length, jaxpr, num_consts,
                        num_carry, linear):
  """Batching rule for scan.

  Iterates to a fixed point on which carry elements are batched, moves batch
  axes to position 0 for consts/carry and position 1 for xs, and binds
  `scan_p` on the batched jaxpr.

  Returns:
    `(outs, out_bdims)`: batched carry outputs are mapped along axis 0 and
    batched stacked outputs along axis 1.

  Raises:
    FixedPointError: if the batched pattern does not stabilize in 1000 rounds.
  """
  num_ys = len(jaxpr.out_avals) - num_carry
  mapped_sizes = set()
  for x, d in zip(args, dims):
    if d is not batching.not_mapped:
      mapped_sizes.add(x.shape[d])
  size, = mapped_sizes

  orig_batched = [d is not batching.not_mapped for d in dims]
  const_batched, init_batched, xs_batched = split_list(
      orig_batched, [num_consts, num_carry])

  carry_batched = init_batched
  for _ in range(1000):
    in_batched = const_batched + carry_batched + xs_batched
    jaxpr_batched, batched_out = batching.batch_jaxpr(
        jaxpr, size, in_batched,
        instantiate=carry_batched + [False] * num_ys)
    carry_batched_out = batched_out[:num_carry]
    if carry_batched_out == carry_batched:
      break
    carry_batched = carry_batched_out
  else:
    raise FixedPointError
  ys_batched = batched_out[num_carry:]

  consts, init, xs = split_list(args, [num_consts, num_carry])
  consts_bdims, init_bdims, xs_bdims = split_list(
      dims, [num_consts, num_carry])

  new_consts = []
  for x, d in zip(consts, consts_bdims):
    if d is not batching.not_mapped and d != 0:
      x = batching.moveaxis(x, d, 0)
    new_consts.append(x)

  new_init = []
  for x, d, was_batched, now_batched in zip(init, init_bdims, init_batched,
                                            carry_batched):
    if not now_batched:
      new_init.append(x)
    elif was_batched:
      new_init.append(batching.moveaxis(x, d, 0))
    else:
      # Newly batched by the fixed point: materialize the batch axis.
      new_init.append(batching.broadcast(x, size, 0))

  # Axis 0 of each x is the scanned axis, so the batch axis goes to 1.
  new_xs = []
  for x, d in zip(xs, xs_bdims):
    if d is not batching.not_mapped and d != 1:
      x = batching.moveaxis(x, d, 1)
    new_xs.append(x)

  outs = scan_p.bind(
      *(new_consts + new_init + new_xs), forward=forward, length=length,
      jaxpr=jaxpr_batched, num_consts=num_consts, num_carry=num_carry,
      linear=linear)
  carry_bdims = [0 if b else batching.not_mapped for b in carry_batched]
  ys_bdims = [1 if b else batching.not_mapped for b in ys_batched]
  return outs, carry_bdims + ys_bdims
def transposed(*cbar_bbar_res):
  """Backward pass of one scan step.

  Operands are `[*c_bar, *b_bar, *res]`. Runs `ad.backward_pass` on the
  closed-over `jaxpr` with undefined primals for the `num_c + num_a` linear
  inputs, accumulates the new carry cotangents into the incoming ones, and
  returns `c_bar + a_bar` with symbolic zeros instantiated.
  """
  carry_ct, out_ct, residuals = split_list(cbar_bbar_res, [num_c, num_b])
  linear_placeholders = [ad.undefined_primal] * (num_c + num_a)
  _, in_cts = ad.backward_pass(
      jaxpr.jaxpr, jaxpr.literals, (), linear_placeholders + residuals,
      out_ct)
  carry_ct_update, a_ct, _ = split_list(in_cts, [num_c, num_a])
  a_ct = _map(ad.instantiate_zeros_aval, a_avals, a_ct)
  accumulated = _map(ad.add_tangents, carry_ct, carry_ct_update)
  carry_ct = _map(ad.instantiate_zeros_aval, c_avals, accumulated)
  return carry_ct + a_ct
def _scan_polymorphic_shape_rule(shape_exprs, forward, length, jaxpr,
                                 num_consts, num_carry, linear):
  """Compute the output shape expressions of a scan.

  Polymorphic (`Id`) dimensions are only allowed in the leading dimension of
  the scanned-over inputs. The carry outputs reuse the initial-carry shape
  expressions; each per-iteration output gets `length` prepended to its aval
  shape.

  Raises:
    NotImplementedError: if an `Id` dimension appears in a const, in the
      carry, or in a non-leading dimension of an `xs` input.
  """
  const_shexprs, init_shexprs, xs_shexprs = split_list(
      shape_exprs, [num_consts, num_carry])

  def _has_id(dim_exprs):
    # Exact type match, as in the original check.
    return any(type(d) is Id for d in dim_exprs)

  if (any(_has_id(shexpr) for shexpr in const_shexprs)
      or any(_has_id(shexpr) for shexpr in init_shexprs)
      or any(_has_id(shexpr[1:]) for shexpr in xs_shexprs)):
    raise NotImplementedError

  _, per_step_avals = split_list(jaxpr.out_avals, [num_carry])
  stacked_shexprs = []
  for aval in per_step_avals:
    stacked_shexprs.append(ShapeExpr(length, *aval.shape))
  return init_shexprs + stacked_shexprs
def _scan_masking_rule(shape_envs, padded_vals, shape_exprs, forward, length,
                       jaxpr, num_consts, num_carry, linear):
  """Masking rule for scan.

  Runs the scan over the full padded length, threading the logical length as
  an extra const and a step counter as an extra carry element. The counter is
  dropped from the returned values.

  Returns:
    `(out_vals, out_shape)`.
  """
  out_shape = _scan_polymorphic_shape_rule(
      shape_exprs, forward, length, jaxpr, num_consts, num_carry, linear)
  dynamic_length = masking.eval_dim_expr(shape_envs.logical, length)
  masked_jaxpr = _masked_scan_jaxpr(jaxpr, num_consts, num_carry)

  consts, init, xs = split_list(padded_vals, [num_consts, num_carry])
  # All scanned-over inputs must agree on their padded leading size.
  max_length, = {x.shape[0] for x in xs}
  const_linear, init_linear, xs_linear = split_list(
      linear, [num_consts, num_carry])

  # Extra operands: dynamic_length as the first const, 0 as the first carry.
  operands = [dynamic_length] + consts + [0] + init + xs
  masked_linear = ([False] + const_linear + [False] + init_linear
                   + xs_linear)
  out_vals = scan_p.bind(
      *operands, forward=forward, length=max_length, jaxpr=masked_jaxpr,
      num_consts=1 + num_consts, num_carry=1 + num_carry,
      linear=masked_linear)
  return out_vals[1:], out_shape
def _scan_partial_eval(trace, *tracers, **kwargs):
  """Partial evaluation rule for scan.

  Iterates to a fixed point on which carry elements are unknown, splits the
  body into a known jaxpr (bound immediately) and an unknown jaxpr (staged
  into a new equation that also consumes the residuals).

  Returns:
    Output tracers whose recipe is the staged scan equation.

  Raises:
    FixedPointError: if the unknown pattern does not stabilize in 1000 rounds.
  """
  forward, length, num_consts, num_carry, jaxpr, linear = split_dict(
      kwargs, ["forward", "length", "num_consts", "num_carry", "jaxpr",
               "linear"])
  num_ys = len(jaxpr.out_avals) - num_carry

  tracer_uk = [t.pval[0] is not None for t in tracers]
  const_uk, init_uk, xs_uk = split_list(tracer_uk, [num_consts, num_carry])

  carry_uk = init_uk
  for _ in range(1000):
    unknowns = const_uk + carry_uk + xs_uk
    jaxpr_1, jaxpr_2, out_uk = pe.partial_eval_jaxpr(
        jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys)
    carry_uk_out = out_uk[:num_carry]
    if carry_uk_out == carry_uk:
      break
    carry_uk = carry_uk_out
  else:
    raise FixedPointError

  # Known scan: unknown inputs are replaced by unit placeholders.
  in_consts = []
  for uk, t in zip(unknowns, tracers):
    in_consts.append(core.unit if uk else t.pval[1])
  # Unknown scan: known inputs are replaced by unit literals.
  new_tracers = []
  for uk, t in zip(unknowns, tracers):
    new_tracers.append(trace.instantiate_const(t) if uk
                       else trace.new_instantiated_literal(core.unit))

  carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
  ys_avals = _map(partial(_promote_aval_rank, length), y_avals)
  out_pvs = [aval if uk else None
             for aval, uk in zip(carry_avals + ys_avals, out_uk)]

  linear_1 = [lin or uk for uk, lin in zip(unknowns, linear)]
  out_flat = scan_p.bind(
      *in_consts, forward=forward, length=length, jaxpr=jaxpr_1,
      num_consts=num_consts, num_carry=num_carry, linear=linear_1)
  out_carry, ys, residuals = split_list(out_flat, [num_carry, num_ys])
  residual_tracers = _map(trace.new_instantiated_const, residuals)
  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None)
                 for pv, const in zip(out_pvs, out_carry + ys)]

  # Residuals are never linear in the unknown scan.
  linear_2 = ([lin or not uk for uk, lin in zip(unknowns, linear)]
              + [False] * len(residual_tracers))
  eqn_params = dict(forward=forward, length=length, jaxpr=jaxpr_2,
                    num_consts=num_consts, num_carry=num_carry,
                    linear=linear_2)
  eqn = pe.new_jaxpr_eqn(new_tracers + residual_tracers, out_tracers, scan_p,
                         (), eqn_params)
  for tracer in out_tracers:
    tracer.recipe = eqn
  return out_tracers
def initial_ildj(incells, outcells, *, jaxpr, num_consts, **_):
  """Propagate inverse/ILDJ cells through `jaxpr`.

  The first `num_consts` incells are passed to `propagate.propagate` as
  constants. Returns the const cells plus the cells read back for the
  jaxpr's invars, the cells read back for its outvars, and the state returned
  by `propagate.propagate`.
  """
  const_cells, arg_cells = jax_util.split_list(incells, [num_consts])
  env, state = propagate.propagate(
      InverseAndILDJ, ildj_registry, jaxpr, const_cells, arg_cells,
      outcells)  # pytype: disable=wrong-arg-types
  updated_incells = []
  for invar in jaxpr.invars:
    updated_incells.append(env.read(invar))
  updated_outcells = []
  for outvar in jaxpr.outvars:
    updated_outcells.append(env.read(outvar))
  return const_cells + updated_incells, updated_outcells, state
def _cond_translation_rule(c, axis_env, pred, *args, **kwargs):
  """XLA translation rule for cond.

  Packs each branch's consts and operands into a tuple, builds one XLA
  computation per branch that unpacks the tuple and calls the branch jaxpr,
  and emits a `Conditional` on `pred`. An optional `backend` keyword is
  forwarded to `xla.jaxpr_computation`.
  """
  backend = kwargs.pop("backend", None)
  true_jaxpr, false_jaxpr, true_nconsts, false_nconsts = split_dict(
      kwargs, ["true_jaxpr", "false_jaxpr", "true_nconsts", "false_nconsts"])
  true_nops = len(true_jaxpr.in_avals) - true_nconsts
  true_consts, true_ops, false_consts, false_ops = split_list(
      args, [true_nconsts, true_nops, false_nconsts])

  def make_computation(name, jaxpr, op_shape):
    # Build a computation taking one tuple parameter of shape `op_shape`.
    builder = xb.make_computation_builder(name)
    operand = builder.ParameterWithShape(op_shape)
    elts = [builder.GetTupleElement(operand, idx)
            for idx in range(len(jaxpr.in_avals))]
    subcomp = xla.jaxpr_computation(
        jaxpr.jaxpr, backend, axis_env, jaxpr.literals, (),
        *_map(builder.GetShape, elts))
    return builder.Build(builder.Call(subcomp, elts))

  true_op = c.Tuple(*(true_consts + true_ops))
  true_c = make_computation("true_comp", true_jaxpr, c.GetShape(true_op))

  false_op = c.Tuple(*(false_consts + false_ops))
  false_c = make_computation("false_comp", false_jaxpr, c.GetShape(false_op))

  return c.Conditional(pred, true_op, true_c, false_op, false_c)
def go_scan(body, length, xs, init, consts, reverse):
  """Interpret a scan by running `body` once per element of the zipped `xs`.

  If `xs` is None, `[None] * length` is used instead. When `reverse` is set,
  each entry of `xs` is passed through `_reverse` first and `reverse` is also
  forwarded to `_stack`.

  Returns:
    A list holding the final carry followed by the stacked outputs.
  """
  num_carry = len(init)
  if xs is None:
    xs = [None] * length
  if reverse:
    xs = [_reverse(x) for x in xs]

  carry = init
  per_step_ys = []
  for x in _zip(xs):
    step_out = _interpret_jaxpr(body, (), *consts, *carry, *x)
    carry, y = split_list(step_out, [num_carry])
    per_step_ys.append(y)

  _, yavals = split_list(body.outvars, [num_carry])

  def _stack_column(yaval, column):
    return _stack(yaval, column, reverse)

  # Transpose the per-step outputs into one column per output variable.
  stacked = list(map(_stack_column, yavals, zip(*per_step_ys)))
  return [*carry, *stacked]
def _while_sparse(spenv, *argspecs, cond_jaxpr, cond_nconsts, body_jaxpr,
                  body_nconsts):
  """Sparsify rule for while_loop.

  Sparsifies the cond and body jaxprs, flattens all argspecs to arrays, binds
  `lax.while_p` on them, and rebuilds argspecs from the flat outputs using the
  body's output tree.
  """
  cond_const_argspecs, body_const_argspecs, init_val_argspecs = split_list(
      argspecs, [cond_nconsts, body_nconsts])

  cond_sp_jaxpr, _ = _sparsify_jaxpr(
      spenv, cond_jaxpr, *cond_const_argspecs, *init_val_argspecs)
  body_sp_jaxpr, out_tree = _sparsify_jaxpr(
      spenv, body_jaxpr, *body_const_argspecs, *init_val_argspecs)

  flat_cond_consts, _ = tree_flatten(
      argspecs_to_arrays(spenv, cond_const_argspecs))
  flat_body_consts, _ = tree_flatten(
      argspecs_to_arrays(spenv, body_const_argspecs))
  flat_init_vals, _ = tree_flatten(
      argspecs_to_arrays(spenv, init_val_argspecs))

  # The const counts are recomputed because flattening can change the number
  # of operands per group.
  bind_params = dict(
      cond_nconsts=len(flat_cond_consts), cond_jaxpr=cond_sp_jaxpr,
      body_nconsts=len(flat_body_consts), body_jaxpr=body_sp_jaxpr)
  out_flat = lax.while_p.bind(
      *flat_cond_consts, *flat_body_consts, *flat_init_vals, **bind_params)
  return arrays_to_argspecs(spenv, tree_unflatten(out_tree, out_flat))
def converted_fun(y, t, *hconsts_args):
  """Evaluate the closed-over `jaxpr` with hoisted consts passed as arguments.

  The first `num_consts` entries of `hconsts_args` are hoisted constants,
  merged with `closure_consts` via `merge`; the remainder are ordinary
  arguments. The flattened `(y, t, *args)` must have the tree structure
  `in_tree`.
  """
  hoisted_consts, extra_args = split_list(hconsts_args, [num_consts])
  all_consts = merge(closure_consts, hoisted_consts)
  flat_args, actual_in_tree = tree_flatten((y, t) + tuple(extra_args))
  assert in_tree == actual_in_tree
  flat_out = core.eval_jaxpr(jaxpr, all_consts, *flat_args)
  return tree_unflatten(out_tree, flat_out)
def initial_ildj(incells, outcells, *, jaxpr, num_consts, **_):
  """Propagate inverse/ILDJ cells through `jaxpr`.

  The first `num_consts` incells are passed to `propagate.propagate` as
  constants. Returns the const cells plus the cells read back for the
  jaxpr's invars, the cells read back for its outvars, and `None` as state.
  """
  const_cells, arg_cells = jax_util.split_list(incells, [num_consts])
  env = propagate.propagate(
      InverseAndILDJ, ildj_registry, jaxpr, const_cells, arg_cells, outcells)
  updated_incells = []
  for invar in jaxpr.invars:
    updated_incells.append(env.read(invar))
  updated_outcells = []
  for outvar in jaxpr.outvars:
    updated_outcells.append(env.read(outvar))
  return const_cells + updated_incells, updated_outcells, None
def random_variable_log_prob(flat_incells, val, *, num_consts, in_tree, **_):
  """Registers Oryx distributions with the log_prob transformation.

  Args:
    flat_incells: flat input cells; the first `num_consts` are dropped.
    val: the value whose log-probability is requested.
    num_consts: number of leading constant cells.
    in_tree: tree structure used to rebuild the distribution from the
      remaining cells (the distribution is the second unflattened item).

  Returns:
    `dist.log_prob(val)`, or `None` if any `InverseAndILDJ` cell after the
    first is not yet known (`cell.top()` is false).
  """
  _, flat_incells = jax_util.split_list(flat_incells, [num_consts])
  _, dist = tree_util.tree_unflatten(in_tree, flat_incells)
  # The filter must look at each iterated cell. Testing `val` (the function
  # argument) would either skip every cell or call `.top()` on every entry,
  # including ones that are not cells.
  if any(not cell.top() for cell in flat_incells[1:]
         if isinstance(cell, InverseAndILDJ)):
    return None
  return dist.log_prob(val)
def _custom_cell_scan_impl(flat_cell, *args, **kwargs):
  """lax_control_flow._scan_impl, but allowing for a custom cell function.

  Traces `flat_cell` on the abstractified arguments of the first iteration
  (consts, initial carry, and index-0 slices of the xs) to get a new body
  jaxpr, then binds `scan_p` with that jaxpr, its consts, and every operand
  marked non-linear.
  """
  reverse, length, num_consts, num_carry, jaxpr, linear, unroll = split_dict(
      kwargs, ["reverse", "length", "num_consts", "num_carry", "jaxpr",
               "linear", "unroll"])
  consts, init, xs = split_list(args, [num_consts, num_carry])
  _, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
  # `list(...)` keeps the concatenation valid whether `map` is the builtin
  # (which returns an iterator) or a list-returning alias such as a safe_map
  # helper. NOTE(review): confirm which `map` this module has in scope.
  first_xs = list(map(partial(_index_array, 0), x_avals, xs))
  cell_args = consts + init + first_xs

  jaxpr, new_consts = _flat_initial_style_jaxpr(wrap_init(flat_cell),
                                                _abstractified(cell_args))

  args = list(new_consts) + init + xs
  kwargs['jaxpr'] = jaxpr
  kwargs['num_consts'] = len(new_consts)
  kwargs['linear'] = (False,) * len(args)
  return scan_p.bind(*args, **kwargs)
def random_variable_log_prob(flat_invals, val, **params):
  """Registers Oryx distributions with the log_prob transformation.

  Drops the leading constants (everything beyond the last
  `params['num_args']` entries), rebuilds the distribution from
  `params['in_tree']`, and returns `dist.log_prob(val)`. Returns `None` if
  any `InverseAndILDJ` input after the first is still unknown.
  """
  num_consts = len(flat_invals) - params['num_args']
  _, arg_invals = jax_util.split_list(flat_invals, [num_consts])
  _, dist = tree_util.tree_unflatten(params['in_tree'], arg_invals)
  for inval in arg_invals[1:]:
    if isinstance(inval, InverseAndILDJ) and inval.is_unknown():
      return None
  return dist.log_prob(val)
def ildj_rule(incells, outcells, *, in_tree, out_tree, num_args, **_):
  """Inverse/ILDJ propagation rule for a custom-inverse call primitive.

  If all outputs are known and some input is not, the closed-over `f_ildj`
  is called to recover inputs; if all inputs are known and some output is
  not, the closed-over `self` is called forward. Otherwise the cells are
  returned unchanged. The returned state is always `None`.
  """
  # First incell is a wrapped function because prim is a call primitive.
  incells = incells[1:]
  num_consts = len(incells) - num_args
  const_incells, incells = jax_util.split_list(incells, [num_consts])
  out_tree = out_tree()

  outputs_known = all(outcell.top() for outcell in outcells)
  inputs_known = all(incell.top() for incell in incells)

  if outputs_known and not inputs_known:
    # Inverse direction: hand known outputs (values and ILDJs) to `f_ildj`.
    outvals = tree_util.tree_unflatten(
        out_tree, [outcell.val for outcell in outcells])
    outildjs = tree_util.tree_unflatten(
        out_tree, [outcell.ildj for outcell in outcells])
    flat_invals = [incell.val if incell.top() else None
                   for incell in incells]
    invals = tree_util.tree_unflatten(in_tree, flat_invals)
    try:
      new_invals, new_ildjs = f_ildj(invals, outvals, outildjs)
    except NonInvertibleError:
      return const_incells + incells, outcells, None
    # We need to flatten the output from `f_ildj` using
    # `tree_util.tree_flatten` but if the user returns `None` (when
    # inversion is not possible), JAX will remove `None`s from the flattened
    # version and the number of `new_incells` will not match the old
    # `incells`. We use the private `_replace_nones` feature in JAX to
    # replace it with a sentinel that won't be removed when flattening.
    none_ = object()
    new_invals = tree_util._replace_nones(none_, new_invals)  # pylint: disable=protected-access
    new_ildjs = tree_util._replace_nones(none_, new_ildjs)  # pylint: disable=protected-access
    new_flat_invals = tree_util.tree_leaves(new_invals)
    new_flat_ildjs = tree_util.tree_leaves(new_ildjs)
    inslices = [NDSlice.new(inval, ildj)
                for inval, ildj in zip(new_flat_invals, new_flat_ildjs)]
    # Keep the old cell wherever the user reported "no inverse" (sentinel).
    new_incells = [
        InverseAndILDJ(old_incell.aval, [inslice])
        if new_flat_inval is not none_ else old_incell
        for new_flat_inval, old_incell, inslice in zip(
            new_flat_invals, incells, inslices)
    ]
    return const_incells + new_incells, outcells, None

  if inputs_known and not outputs_known:
    # Forward direction: evaluate the function on the known inputs.
    invals = tree_util.tree_unflatten(
        in_tree, [incell.val for incell in incells])
    flat_outvals = tree_util.tree_leaves(self(*invals))
    outcells = [InverseAndILDJ.new(outval) for outval in flat_outvals]
    return const_incells + incells, outcells, None

  return const_incells + incells, outcells, None
def _scan_impl(*args, **kwargs):
  """Implement scan as a `fori_loop` over `length` steps.

  The loop state is the carry followed by output buffers created with
  `_empty_array`; each step indexes the xs, runs the body jaxpr, and writes
  that step's outputs into the buffers.
  """
  forward, length, num_consts, num_carry, jaxpr, linear = split_dict(
      kwargs, ["forward", "length", "num_consts", "num_carry", "jaxpr",
               "linear"])
  consts, init, xs = split_list(args, [num_consts, num_carry])
  _, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
  _, y_avals = split_list(jaxpr.out_avals, [num_carry])

  def body_fun(i, vals):
    # When scanning in reverse, iteration i touches position length - i - 1.
    step = i if forward else length - i - 1
    carry, ys = split_list(vals, [num_carry])
    x_slices = _map(partial(_index_array, step), x_avals, xs)
    step_out = core.jaxpr_as_fun(jaxpr)(*(consts + carry + x_slices))
    new_carry, y_step = split_list(step_out, [num_carry])
    new_ys = _map(partial(_update_array, step), y_avals, ys, y_step)
    return new_carry + new_ys

  output_buffers = _map(partial(_empty_array, length), y_avals)
  return fori_loop(0, length, body_fun, init + output_buffers)
def _reap_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr,
                     cond_nconsts, body_nconsts):
  """Reaps the body of a while loop to get the reaps of the final iteration.

  Builds a new while loop whose carry is `(original carry, reaped values)`,
  binds it, and sows the reaped values of the final iteration.

  Raises:
    ValueError: if any value sown in the body does not use 'clobber' mode.
  """
  cond_const_tracers, body_const_tracers, init_tracers = jax_util.split_list(
      tracers, [cond_nconsts, body_nconsts])
  _, init_avals = tree_util.tree_map(
      lambda x: x.aval, (body_const_tracers, init_tracers))
  cond_const_vals, body_const_vals, init_vals = tree_util.tree_map(
      lambda x: x.val,
      (cond_const_tracers, body_const_tracers, init_tracers))

  settings = trace_util.get_dynamic_context(trace).settings
  body_metadata = _get_harvest_metadata(
      body_jaxpr, settings, *(body_const_tracers + init_tracers))
  for k, meta in body_metadata.items():
    if meta['mode'] != 'clobber':
      raise ValueError(
          f'Must use clobber mode for \'{k}\' inside of a `while_loop`.')
  reap_avals = {name: meta['aval'] for name, meta in body_metadata.items()}

  cond_fun = jax_core.jaxpr_as_fun(cond_jaxpr)
  body_fun = jax_core.jaxpr_as_fun(body_jaxpr)
  reap_settings = {
      'tag': settings.tag,
      'allowlist': settings.allowlist,
      'blocklist': settings.blocklist,
      'exclusive': settings.exclusive,
  }

  def new_cond(carry, _):
    # The reaped values do not take part in the loop condition.
    return cond_fun(*(cond_const_vals + carry))

  def new_body(carry, _):
    # Previous reaps are discarded; only this iteration's reaps are carried.
    carry, reaps = call_and_reap(
        body_fun, **reap_settings)(*(body_const_vals + carry))
    return (carry, reaps)

  new_in_avals, new_in_tree = tree_util.tree_flatten(
      (init_avals, reap_avals))
  flat_in_avals = tuple(new_in_avals)
  new_cond_jaxpr, cond_consts, _ = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
      new_cond, new_in_tree, flat_in_avals)
  new_body_jaxpr, body_consts, out_tree = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
      new_body, new_in_tree, flat_in_avals)

  # Zero placeholders seed the reap slots of the carry.
  dummy_reap_vals = tree_util.tree_map(
      lambda x: jnp.zeros(x.shape, x.dtype), reap_avals)
  new_in_vals = tree_util.tree_leaves((init_vals, dummy_reap_vals))
  out = lax.while_p.bind(
      *(cond_consts + body_consts + new_in_vals),
      cond_nconsts=len(cond_consts),
      body_nconsts=len(body_consts),
      cond_jaxpr=new_cond_jaxpr,
      body_jaxpr=new_body_jaxpr)
  out = jax_util.safe_map(trace.pure, out)
  out, reaps = tree_util.tree_unflatten(out_tree, out)
  for name, value in reaps.items():
    sow(value, name=name, tag=settings.tag,
        mode=body_metadata[name]['mode'])
  return out
def new_body(carry, x):
  """Scan body that plants values into the closed-over `body_fun`.

  `x` is a pair `(scanned inputs, per-step plants)`. The per-step plants are
  merged with the closed-over `clobber_plants`, which win on key collisions.
  Returns `(carry_out, y)` split at `num_carry`.
  """
  scanned, step_plants = x
  all_plants = dict(step_plants)
  all_plants.update(clobber_plants)
  flat_values = const_vals + tree_util.tree_leaves((carry, scanned))
  planted_body = plant(
      body_fun,
      tag=settings.tag,
      allowlist=settings.allowlist,
      blocklist=settings.blocklist,
      exclusive=settings.exclusive)
  out = planted_body(all_plants, *flat_values)
  carry_out, y = jax_util.split_list(out, [num_carry])
  return carry_out, y
def _cond_impl(pred, *args, **kwargs):
  """Evaluate a cond eagerly by running exactly one branch.

  Args:
    pred: predicate; truthy selects the true branch.
    *args: flat operands `[*true_consts, *true_ops, *false_consts,
      *false_ops]`.
    **kwargs: `true_jaxpr`, `false_jaxpr`, `true_nconsts`, `false_nconsts`.

  Returns:
    The outputs of the selected branch jaxpr.
  """
  true_jaxpr, false_jaxpr, true_nconsts, false_nconsts = split_dict(
      kwargs, ["true_jaxpr", "false_jaxpr", "true_nconsts", "false_nconsts"])
  # `in_avals` covers consts AND operands, so the operand count is the
  # remainder after the consts. Using the full length here would swallow the
  # false branch's arguments whenever true_nconsts > 0. This matches the
  # batching and translation rules for cond.
  true_nops = len(true_jaxpr.in_avals) - true_nconsts
  true_consts, true_ops, false_consts, false_ops = split_list(
      args, [true_nconsts, true_nops, false_nconsts])

  if pred:
    return core.jaxpr_as_fun(true_jaxpr)(*(true_consts + true_ops))
  else:
    return core.jaxpr_as_fun(false_jaxpr)(*(false_consts + false_ops))
def bijector_ildj_rule(incells, outcells, *, in_tree, num_consts, direction,
                       num_bijector, **_):
  """Inverse/ILDJ rule for bijectors.

  Args:
    incells: `[*const_cells, *bijector_cells, *arg_cells]`.
    outcells: a single-element list holding the output cell.
    in_tree: tree structure of `(bijector, argument)`.
    num_consts: number of leading constant cells.
    direction: `'forward'` or `'inverse'`; selects which bijector method is
      treated as the forward function.
    num_bijector: number of flat cells that make up the bijector.

  Returns:
    `(new_incells, new_outcells, None)`. The cells come back unchanged when
    the bijector is not fully known or when neither direction can make
    progress.

  Raises:
    ValueError: if `direction` is neither `'forward'` nor `'inverse'`.
  """
  const_incells, flat_incells = jax_util.split_list(incells, [num_consts])
  flat_bijector_cells, arg_incells = jax_util.split_list(
      flat_incells, [num_bijector])
  if any(not cell.top() for cell in flat_bijector_cells):
    return (const_incells + flat_incells, outcells, None)
  # Wrap the cells in proxies before unflattening to recover the argument
  # cell. NOTE(review): presumably the proxy keeps each cell as a tree leaf.
  flat_inproxies = safe_map(_CellProxy, flat_incells)
  _, inproxy = tree_util.tree_unflatten(in_tree, flat_inproxies)
  bijector_vals = [cell.val for cell in flat_bijector_cells]
  bijector, _ = tree_util.tree_unflatten(
      in_tree, bijector_vals + [None] * len(arg_incells))
  if direction == 'forward':
    forward_func = bijector.forward
    inv_func = bijector.inverse
    ildj_func = bijector.inverse_log_det_jacobian
  elif direction == 'inverse':
    forward_func = bijector.inverse
    inv_func = bijector.forward
    ildj_func = bijector.forward_log_det_jacobian
  else:
    raise ValueError('Bijector direction must be '
                     '"forward" or "inverse".')

  outcell, = outcells
  incell = inproxy.cell
  # Default to "no progress" so both results are defined on every path; the
  # original left `new_outcells` unbound when neither branch fired.
  new_arg_incells = arg_incells
  new_outcells = outcells
  if incell.bottom() and not outcell.bottom():
    # Known output, unknown input: invert and accumulate the ILDJ.
    val, ildj = outcell.val, outcell.ildj
    inildj = ildj + ildj_func(val, np.ndim(val))
    ndslice = NDSlice.new(inv_func(val), inildj)
    new_arg_incells = [InverseAndILDJ(incell.aval, [ndslice])]
  elif outcell.is_unknown() and not incell.is_unknown():
    # Known input, unknown output: run the bijector forward.
    new_outcells = [InverseAndILDJ.new(forward_func(incell.val))]
  # Append only the argument cells. `flat_incells` already starts with the
  # bijector cells, so appending it would duplicate them.
  new_incells = flat_bijector_cells + new_arg_incells
  return (const_incells + new_incells, new_outcells, None)