def _cond_jvp(primals, tangents, branches, linear):
    """Build the JVP of a ``cond_p`` application.

    Every branch is differentiated with ``ad.jvp_jaxpr`` and ``cond_p`` is
    re-bound over the primal operands followed by the non-zero operand
    tangents.

    Args:
      primals: the branch index followed by the branch operands.
      tangents: tangents matching ``primals``; the index tangent must be an
        ``ad_util.Zero``.
      branches: the branch jaxprs to differentiate.
      linear: per-operand linearity flags of the original ``cond_p``.

    Returns:
      ``(out_primals, out_tangents)`` where outputs whose tangent is zero in
      every branch carry an ``ad_util.Zero`` built from the primal output.
    """
    index, *operands = primals
    index_dot, *operand_dots = tangents

    # The branch index is never differentiated.
    assert type(index_dot) is ad_util.Zero
    operand_nz = [type(t) is not ad_util.Zero for t in operand_dots]

    # First pass: ask each branch which outputs get a non-zero tangent.
    per_branch_out_nz = []
    for branch in branches:
        _, branch_out_nz = ad.jvp_jaxpr(branch, operand_nz, instantiate=False)
        per_branch_out_nz.append(branch_out_nz)

    # An output tangent is materialised if *any* branch produces one, so
    # that all differentiated branches agree on their outputs.
    out_nz = [any(flags) for flags in zip(*per_branch_out_nz)]

    # Second pass: differentiate again, forcing the agreed-upon outputs.
    jvp_branches = tuple(
        ad.jvp_jaxpr(branch, operand_nz, instantiate=out_nz)[0]
        for branch in branches)

    live_dots = _prune_zeros(operand_dots)
    # Tangent operands are appended as linear inputs.
    jvp_linear = tuple(linear) + (True,) * len(live_dots)

    flat_out = cond_p.bind(
        index, *operands, *live_dots,
        branches=jvp_branches, linear=jvp_linear)
    out_primals, packed_tangents = split_list(flat_out, [len(out_nz)])

    # Re-insert symbolic zeros at the positions that had no tangent output.
    remaining = iter(packed_tangents)
    out_tangents = []
    for primal_out, has_tangent in zip(out_primals, out_nz):
        if has_tangent:
            out_tangents.append(next(remaining))
        else:
            out_tangents.append(ad_util.Zero.from_value(primal_out))
    return out_primals, out_tangents
def _scan_jvp(primals, tangents, forward, length, jaxpr, num_consts,
              num_carry, linear):
    """Build the JVP of a ``scan_p`` application over flat operands.

    ``primals`` and ``tangents`` are flat sequences laid out as
    ``consts + init + xs`` (split at ``num_consts`` and ``num_carry``).
    Zero tangents are represented by the ``ad_util.zero`` singleton.

    The set of carry components with non-zero tangents is found by
    iterating ``ad.jvp_jaxpr`` until the carry flags coming out equal the
    ones going in.

    Raises:
      FixedPointError: if the flags have not stabilised after 1000 rounds.

    Returns:
      ``(primals_out, tangents_out)``, each laid out as ``carry + ys``, with
      ``ad_util.zero`` in the tangent slots that have no tangent output.
    """
    num_xs = len(jaxpr.in_avals) - num_carry - num_consts
    num_ys = len(jaxpr.out_avals) - num_carry

    tangent_nz = [t is not ad_util.zero for t in tangents]
    const_nz, init_nz, xs_nz = split_list(tangent_nz, [num_consts, num_carry])

    # Fixed point on which carry tangents are non-zero.
    carry_nz = init_nz
    for _ in range(1000):
        nonzeros = const_nz + carry_nz + xs_nz
        jaxpr_jvp, nonzeros_out = ad.jvp_jaxpr(
            jaxpr, nonzeros, instantiate=carry_nz + [False] * num_ys)
        carry_nz_out = nonzeros_out[:num_carry]
        ys_nz = nonzeros_out[num_carry:]
        if carry_nz_out == carry_nz:
            break
        carry_nz = carry_nz_out
    else:
        raise FixedPointError

    # Materialise zeros for inputs the fixed point decided are non-zero.
    dense_tangents = []
    for primal, tangent, is_nz in zip(primals, tangents, nonzeros):
        if tangent is ad_util.zero and is_nz:
            tangent = ad.instantiate_zeros(primal, tangent)
        dense_tangents.append(tangent)

    consts, init, xs = split_list(primals, [num_consts, num_carry])
    tangent_groups = split_list(dense_tangents, [num_consts, num_carry])
    consts_dot, init_dot, xs_dot = _map(_prune_zeros, tangent_groups)

    n_consts_dot = len(consts_dot)
    n_init_dot = len(init_dot)
    n_xs_dot = len(xs_dot)

    # Interleave primal and tangent binders group by group.
    jaxpr_jvp_rearranged = ad.rearrange_binders(
        jaxpr_jvp,
        [num_consts, num_carry, num_xs],
        [n_consts_dot, n_init_dot, n_xs_dot],
        [num_carry, num_ys],
        [n_init_dot, sum(nonzeros_out) - n_init_dot])

    consts_linear, init_linear, xs_linear = split_list(
        linear, [num_consts, num_carry])
    # Every tangent operand is a linear input.
    jaxpr_jvp_linear = (
        consts_linear + [True] * n_consts_dot
        + init_linear + [True] * n_init_dot
        + xs_linear + [True] * n_xs_dot)

    operands = consts + consts_dot + init + init_dot + xs + xs_dot
    out_flat = scan_p.bind(
        *operands,
        forward=forward,
        length=length,
        jaxpr=jaxpr_jvp_rearranged,
        num_consts=num_consts + n_consts_dot,
        num_carry=num_carry + n_init_dot,
        linear=jaxpr_jvp_linear)

    carry, carry_dot, ys, ys_dot = split_list(
        out_flat, [num_carry, n_init_dot, num_ys])
    primals_out = carry + ys

    # Re-insert zeros where the differentiated jaxpr emitted no tangent.
    remaining = iter(carry_dot + ys_dot)
    tangents_out = []
    for has_tangent in nonzeros_out:
        tangents_out.append(next(remaining) if has_tangent else ad_util.zero)
    return primals_out, tangents_out
def _for_jvp(primals, tangents, *, jaxpr, nsteps, reverse, which_linear):
    """Build the JVP of a ``for_p`` application.

    Args:
      primals: primal operands of the ``for_p`` call.
      tangents: tangents matching ``primals``; zeros are ``ad_util.Zero``.
      jaxpr: the stateful loop-body jaxpr.
      nsteps: forwarded unchanged to ``for_p.bind``.
      reverse: forwarded unchanged to ``for_p.bind``.
      which_linear: tuple of per-operand linearity flags.

    Returns:
      ``(out_primals, out_tangents)``; positions whose tangent stays zero
      carry an ``ad_util.Zero`` built from the corresponding primal output.

    Raises:
      Exception: ``"Invalid fixpoint"`` if the non-zero flags have not
        stabilised after ``len(tangents)`` rounds.
    """
    nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
    # We need to find out which `Ref`s have nonzero tangents after running the
    # for loop. Ordinarily we do this with a fixed point on the body jaxpr but
    # a `for` body jaxpr is stateful and has no outputs. We therefore discharge
    # the state effect from the jaxpr and we will now have a "symmetric" jaxpr
    # where the inputs line up with the outputs. We use this discharged jaxpr
    # for the fixed point.
    discharged_jaxpr, body_consts = discharge_state(jaxpr, ())
    # NOTE(review): with zero tangents the loop body never runs and the
    # `else` branch raises "Invalid fixpoint" -- confirm that is intended.
    for _ in range(len(nonzero_tangents)):
        # The leading `False` is the flag for the body's first input, which
        # has no entry in `tangents`.
        _, out_nonzero_tangents = ad.jvp_jaxpr(
            core.ClosedJaxpr(discharged_jaxpr, body_consts),
            [False] + nonzero_tangents,
            instantiate=nonzero_tangents)
        if out_nonzero_tangents == nonzero_tangents:
            break
        # Materialise the result: it is concatenated with a list, compared
        # against a list and zipped over below, none of which works on the
        # one-shot iterator that the builtin `map` returns.
        nonzero_tangents = list(
            map(operator.or_, nonzero_tangents, out_nonzero_tangents))
    else:
        raise Exception("Invalid fixpoint")

    # Instantiate the zeros that the fixed point decided must be non-zero,
    # then drop the ones that remain symbolic.
    tangents = [
        ad.instantiate_zeros(t) if inst else t
        for t, inst in zip(tangents, nonzero_tangents)
    ]
    tangents = [t for t in tangents if type(t) is not ad_util.Zero]

    closed_jaxpr = core.ClosedJaxpr(jaxpr, ())
    jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, [False] + nonzero_tangents, [])
    jvp_jaxpr, jvp_consts = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts
    # Hoisted constants are non-linear; tangent operands are linear.
    jvp_which_linear = (
        (False,) * len(jvp_consts) + which_linear + (True,) * len(tangents))
    out_flat = for_p.bind(
        *jvp_consts, *primals, *tangents,
        jaxpr=jvp_jaxpr,
        nsteps=nsteps,
        reverse=reverse,
        which_linear=jvp_which_linear)
    # `out_flat` includes constant inputs into the `for_loop` which are
    # converted into outputs as well. We don't care about these in AD so we
    # throw them out.
    _, out_primals, out_tangents = split_list(
        out_flat, [len(jvp_consts), len(primals)])
    out_tangents_iter = iter(out_tangents)
    out_tangents = [
        next(out_tangents_iter) if nz else ad_util.Zero.from_value(p)
        for p, nz in zip(out_primals, nonzero_tangents)
    ]
    return out_primals, out_tangents
def _scan_jvp(primals, tangents, forward, length, jaxpr):
    """Build the JVP of a ``scan_p`` application over packed operands.

    ``primals`` and ``tangents`` are each a ``(consts, init, xs)`` triple.
    The non-zero structure of the carry tangent is found by iterating
    ``ad.jvp_jaxpr`` and joining the outgoing structure into the incoming
    one until the two are equal.

    Raises:
      FixedPointError: if the structure has not stabilised after 1000
        rounds.

    Returns:
      ``(core.pack((carry_out, ys)),
      ad.TangentTuple((carry_out_dot, ys_dot)))``.
    """
    consts, init, xs = primals
    consts_dot, init_dot, xs_dot = tangents
    # The unpackings below are unused but require a three-element input
    # signature and a two-element output signature.
    _consts_aval, _carry_aval, _x_aval = jaxpr.in_avals
    _, _y_aval = jaxpr.out_aval

    consts_nz = ad.get_nonzeros(consts_dot)
    init_nz = ad.get_nonzeros(init_dot)
    xs_nz = ad.get_nonzeros(xs_dot)  # same as x_nonzeros b/c arrays

    carry_nz = init_nz
    for _ in range(1000):
        in_nz = (consts_nz, carry_nz, xs_nz)
        jaxpr_jvp, out_nz = ad.jvp_jaxpr(
            jaxpr, in_nz, instantiate=(carry_nz, False))
        carry_nz_out, ys_nz = out_nz
        if carry_nz_out == carry_nz:
            break
        carry_nz = _binary_lattice_join(carry_nz_out, carry_nz)
    else:
        raise FixedPointError

    # convert_zeros is like strip_zeros but uses explicit lattice information
    # to instantiate zeros in some cases, namely in init_dot based on the
    # fixed point
    init_dot_nz = _convert_zeros(carry_nz, init, init_dot)
    consts_dot_nz = _convert_zeros(consts_nz, consts, consts_dot)
    xs_dot_nz = _convert_zeros(xs_nz, xs, xs_dot)

    # Pair each primal with its tangent so the scan sees a single operand
    # per group.
    consts_dual = core.pack((consts, consts_dot_nz))
    init_dual = core.pack((init, init_dot_nz))
    xs_dual = core.pack((xs, xs_dot_nz))

    carry_out_dual, ys_dual = scan_p.bind(
        consts_dual, init_dual, xs_dual,
        forward=forward, length=length, jaxpr=jaxpr_jvp)

    ys, ys_dot = ys_dual
    ys_dot = ad.put_zeros(ad.TangentTuple, ys_nz, ys_dot)
    carry_out, carry_out_dot = carry_out_dual
    carry_out_dot = ad.put_zeros(ad.TangentTuple, carry_nz_out, carry_out_dot)

    primals_out = core.pack((carry_out, ys))
    tangents_out = ad.TangentTuple((carry_out_dot, ys_dot))
    return primals_out, tangents_out
def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
    """Build the JVP of a ``remat_p`` application.

    The body ``jaxpr`` (which must have no constvars) is differentiated with
    ``ad.jvp_jaxpr``; constants of the result are turned into leading inputs
    and ``remat_p`` is re-bound over constants, primals and the non-zero
    tangents.

    Returns:
      ``(out_primals, out_tangents)``; outputs without a tangent carry an
      ``ad_util.Zero`` built from the corresponding primal output.
    """
    assert not jaxpr.constvars

    # One pass: record which tangents are non-zero and keep only those.
    in_nonzeros = []
    live_tangents = []
    for tangent in tangents:
        is_nonzero = type(tangent) is not ad_util.Zero
        in_nonzeros.append(is_nonzero)
        if is_nonzero:
            live_tangents.append(tangent)

    closed = core.ClosedJaxpr(jaxpr, ())
    closed_jvp, out_nonzeros = ad.jvp_jaxpr(closed, in_nonzeros, False)
    # `remat_p` takes an open jaxpr, so constants become explicit inputs.
    open_jvp = pe.convert_constvars_jaxpr(closed_jvp.jaxpr)

    flat_out = remat_p.bind(
        *closed_jvp.consts, *primals, *live_tangents,
        jaxpr=open_jvp,
        prevent_cse=prevent_cse,
        differentiated=differentiated,
        policy=policy)
    out_primals, packed_tangents = split_list(flat_out, [len(jaxpr.outvars)])

    # Re-insert symbolic zeros for outputs that have no tangent.
    remaining = iter(packed_tangents)
    out_tangents = []
    for primal_out, has_tangent in zip(out_primals, out_nonzeros):
        if has_tangent:
            out_tangents.append(next(remaining))
        else:
            out_tangents.append(ad_util.Zero.from_value(primal_out))
    return out_primals, out_tangents
def custom_vmap_jvp(primals, tangents, *, call, rule, in_tree, out_tree):
    """Build the JVP of a ``custom_vmap_p`` application.

    ``custom_vmap_p`` is re-bound over ``(*primals, *tangents)`` with the
    JVP of ``call`` as its body and ``jvp_of_rule_rule`` as its batching
    rule. The input and output trees become ``(tree, tree)`` pairs.

    Returns:
      ``(out_primals, out_tangents)``: the two halves of the bound outputs.
    """

    def jvp_of_rule_rule(axis_size, in_batched, primals, tangents):
        # Batching rule for the differentiated call: runs the user `rule`
        # under `jax.jvp`, vmapping over inputs that are batched in only
        # one of the primal or the tangent.
        primal_batched, tangent_batched = in_batched

        def _axis_if_primal_only(pb, tb):
            return 0 if pb and not tb else None

        def _axis_if_tangent_only(pb, tb):
            return 0 if tb and not pb else None

        # Inputs batched in both primal and tangent go to the user rule.
        both_batched = tree_map(operator.and_, primal_batched, tangent_batched)
        # The others get an explicit axis 0 for the outer vmap.
        primal_only_axes = tree_map(
            _axis_if_primal_only, primal_batched, tangent_batched)
        tangent_only_axes = tree_map(
            _axis_if_tangent_only, primal_batched, tangent_batched)

        rule_out_batched = lu.Store()
        flat_args, args_tree = tree_flatten((primals, tangents))
        flat_in_axes, axes_tree = tree_flatten(
            (primal_only_axes, tangent_only_axes),
            is_leaf=lambda x: x is None)

        # TODO(frostig): assert these also equal:
        #   treedef_tuple((in_tree, in_tree))
        # once https://github.com/google/jax/issues/9066 is fixed
        assert args_tree == axes_tree
        del axes_tree

        def to_jvp(*primals):
            out, out_batched = call_rule(
                rule, axis_size, both_batched, primals)
            check_vmap_rule_trees(
                rule, out_tree,
                tree_structure(out), tree_structure(out_batched))
            # Stash the rule's batching report for use after the jvp.
            rule_out_batched.store(out_batched)
            return out

        def to_vmap_over_extra_batched_dims(primals, tangents):
            return jax.jvp(to_jvp, primals, tangents)

        flat_fun, out_tree2 = flatten_fun_nokwargs(
            lu.wrap_init(to_vmap_over_extra_batched_dims), args_tree)

        flat_outs, flat_out_axes = vmap_unrestricted(
            flat_fun, *flat_args,
            in_axes=flat_in_axes,
            axis_name=core.no_axis_name,
            axis_size=axis_size)

        # Outputs are laid out as all primals followed by all tangents.
        total = len(flat_outs)
        assert total % 2 == 0
        half = total // 2
        axes_p, axes_t = flat_out_axes[:half], flat_out_axes[half:]

        fronted_ps = map(maybe_bdim_at_front, flat_outs[:half], axes_p)
        extra_batched_p = [axis is not not_mapped for axis in axes_p]
        fronted_ts = map(maybe_bdim_at_front, flat_outs[half:], axes_t)
        extra_batched_t = [axis is not not_mapped for axis in axes_t]

        out_ps, out_ts = tree_unflatten(
            out_tree2(), [*fronted_ps, *fronted_ts])
        out_extra_ps, out_extra_ts = tree_unflatten(
            out_tree2(), [*extra_batched_p, *extra_batched_t])

        # An output is batched if the rule said so or the outer vmap
        # mapped it.
        out_batched_ps = tree_map(
            operator.or_, rule_out_batched.val, out_extra_ps)
        out_batched_ts = tree_map(
            operator.or_, rule_out_batched.val, out_extra_ts)

        return (out_ps, out_ts), (out_batched_ps, out_batched_ts)

    dense_tangents = map(ad.instantiate_zeros, tangents)
    jvp_call, _ = ad.jvp_jaxpr(call, [True] * len(primals), True)
    paired_in_tree = treedef_tuple((in_tree, in_tree))
    paired_out_tree = treedef_tuple((out_tree, out_tree))
    outs = custom_vmap_p.bind(
        *primals, *dense_tangents,
        call=jvp_call,
        rule=jvp_of_rule_rule,
        in_tree=paired_in_tree,
        out_tree=paired_out_tree)
    assert len(outs) % 2 == 0, len(outs)
    half = len(outs) // 2
    out_primals, out_tangents = util.split_list(outs, [half])
    return out_primals, out_tangents