def _scan_jvp(primals, tangents, forward, length, jaxpr, num_consts,
              num_carry, linear):
  """JVP rule for the `scan_p` primitive.

  Binds `scan_p` on a JVP-transformed body whose tangent binders are placed
  right after the primal binders of the same group.

  Args:
    primals: flat primal inputs, ordered consts, initial carry, xs.
    tangents: tangents matching `primals`; `ad_util.zero` marks a symbolic
      zero tangent.
    forward, length: forwarded unchanged to `scan_p.bind`.
    jaxpr: body jaxpr whose inputs are consts, carry, xs and whose outputs
      are carry, ys.
    num_consts: number of leading const inputs.
    num_carry: number of carry inputs (equal to the number of carry outputs).
    linear: per-input linearity flags, in the same order as `primals`.

  Returns:
    `(primals_out, tangents_out)`. `primals_out` is ordered carry then ys.
    `tangents_out` matches it, with `ad_util.zero` marking outputs whose
    tangent is known to be zero.

  Raises:
    FixedPointError: if the carry nonzero-tangent pattern fails to stabilize.
  """
  num_xs = len(jaxpr.in_avals) - num_carry - num_consts
  num_ys = len(jaxpr.out_avals) - num_carry
  nonzeros = [t is not ad_util.zero for t in tangents]
  const_nz, init_nz, xs_nz = split_list(nonzeros, [num_consts, num_carry])

  # Fixed point on which carry tangents are nonzero.
  # - A carry whose tangent is zero on entry can pick up a nonzero tangent
  #   after one trip through the body. The body must then be
  #   re-differentiated with that carry tangent treated as nonzero.
  # - The pattern is OR-joined, so it only grows.
  # - Assuming `ad.jvp_jaxpr` reports every instantiated output as nonzero,
  #   each non-converged iteration turns on at least one more carry.
  # - So at most `num_carry` updates plus one confirming pass are needed.
  carry_nz = init_nz
  for _ in range(num_carry + 1):
    nonzeros = const_nz + carry_nz + xs_nz
    jaxpr_jvp, nonzeros_out = ad.jvp_jaxpr(
        jaxpr, nonzeros, instantiate=carry_nz + [False] * num_ys)
    carry_nz_out = nonzeros_out[:num_carry]
    if carry_nz_out == carry_nz:
      break
    carry_nz = [old or new for old, new in zip(carry_nz, carry_nz_out)]
  else:
    raise FixedPointError

  # Materialize zero tangents for every input the JVP jaxpr was built to
  # treat as nonzero. This includes carries that became nonzero during the
  # fixed point.
  tangents = [ad.instantiate_zeros(x, t) if t is ad_util.zero and nz else t
              for x, t, nz in zip(primals, tangents, nonzeros)]

  consts, init, xs = split_list(primals, [num_consts, num_carry])
  all_tangents = split_list(tangents, [num_consts, num_carry])
  # `_prune_zeros` presumably drops the remaining symbolic zeros -- confirm.
  consts_dot, init_dot, xs_dot = _map(_prune_zeros, all_tangents)

  # Reorder the JVP jaxpr's binders to match the bind call below:
  # inputs  [consts, consts_dot, init, init_dot, xs, xs_dot]
  # outputs [carry, carry_dot, ys, ys_dot].
  jaxpr_jvp_rearranged = ad.rearrange_binders(
      jaxpr_jvp,
      [num_consts, num_carry, num_xs],
      [len(consts_dot), len(init_dot), len(xs_dot)],
      [num_carry, num_ys],
      [len(init_dot), sum(nonzeros_out) - len(init_dot)])

  # Primal inputs keep their original linearity flags; tangent inputs are
  # marked linear.
  consts_linear, init_linear, xs_linear = split_list(
      linear, [num_consts, num_carry])
  jaxpr_jvp_linear = (consts_linear + [True] * len(consts_dot)
                      + init_linear + [True] * len(init_dot)
                      + xs_linear + [True] * len(xs_dot))

  out_flat = scan_p.bind(
      *(consts + consts_dot + init + init_dot + xs + xs_dot),
      forward=forward, length=length, jaxpr=jaxpr_jvp_rearranged,
      num_consts=num_consts + len(consts_dot),
      num_carry=num_carry + len(init_dot),
      linear=jaxpr_jvp_linear)

  carry, carry_dot, ys, ys_dot = split_list(
      out_flat, [num_carry, len(init_dot), num_ys])
  primals_out = carry + ys
  # Re-insert symbolic zeros where the output tangent is known to be zero.
  nonzero_dots = iter(carry_dot + ys_dot)
  tangents_out = [next(nonzero_dots) if nz else ad_util.zero
                  for nz in nonzeros_out]
  return primals_out, tangents_out
def _for_jvp(primals, tangents, *, jaxpr, nsteps, reverse, which_linear):
  """JVP rule for the `for_p` primitive.

  Args:
    primals: primal inputs to the loop, one per body binder after the first.
      The first binder (presumably the loop index) has no primal here.
    tangents: tangents matching `primals`; `ad_util.Zero` instances mark
      symbolic zeros.
    jaxpr: the stateful loop-body jaxpr, which has no outputs.
    nsteps, reverse: forwarded unchanged to `for_p.bind`.
    which_linear: tuple of linearity flags for `primals`.

  Returns:
    `(out_primals, out_tangents)`. Entries of `out_tangents` that are known
    to be zero are `ad_util.Zero.from_value` of the matching output primal.

  Raises:
    Exception: "Invalid fixpoint" if the nonzero-tangent pattern fails to
      stabilize.
  """
  nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
  # We need to find out which `Ref`s have nonzero tangents after running the
  # for loop. Ordinarily we do this with a fixed point on the body jaxpr, but
  # a `for` body jaxpr is stateful and has no outputs. We therefore discharge
  # the state effect from the jaxpr, which gives a "symmetric" jaxpr whose
  # inputs line up with its outputs. We use this discharged jaxpr for the
  # fixed point.
  discharged_jaxpr, body_consts = discharge_state(jaxpr, ())
  discharged_closed = core.ClosedJaxpr(discharged_jaxpr, body_consts)
  # The leading False covers the body's first binder, which has no tangent.
  # Assuming `instantiate` keeps every already-nonzero tangent nonzero, each
  # non-converged iteration turns on at least one more tangent, so
  # len(nonzero_tangents) + 1 iterations always suffice. The `+ 1` also
  # avoids an empty `range`, which would fall straight through to `else` and
  # raise without ever testing for convergence.
  for _ in range(len(nonzero_tangents) + 1):
    _, out_nonzero_tangents = ad.jvp_jaxpr(
        discharged_closed, [False] + nonzero_tangents,
        instantiate=nonzero_tangents)
    if out_nonzero_tangents == nonzero_tangents:
      break
    nonzero_tangents = map(operator.or_, nonzero_tangents,
                           out_nonzero_tangents)
  else:
    raise Exception("Invalid fixpoint")

  # Materialize zeros for every tangent the fixed point decided is nonzero,
  # then keep only the concrete (non-Zero) tangents as extra loop inputs.
  tangents = [ad.instantiate_zeros(t) if inst else t
              for t, inst in zip(tangents, nonzero_tangents)]
  tangents = [t for t in tangents if type(t) is not ad_util.Zero]

  closed_jaxpr = core.ClosedJaxpr(jaxpr, ())
  jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, [False] + nonzero_tangents, [])
  jvp_jaxpr, jvp_consts = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts
  # The JVP jaxpr's consts are not linear. The primals keep their flags.
  # The tangents are linear.
  jvp_which_linear = ((False,) * len(jvp_consts) + which_linear
                      + (True,) * len(tangents))
  out_flat = for_p.bind(*jvp_consts, *primals, *tangents, jaxpr=jvp_jaxpr,
                        nsteps=nsteps, reverse=reverse,
                        which_linear=jvp_which_linear)
  # `out_flat` includes constant inputs into the `for_loop` which are
  # converted into outputs as well. We don't care about these in AD so we
  # throw them out.
  _, out_primals, out_tangents = split_list(
      out_flat, [len(jvp_consts), len(primals)])
  out_tangents_iter = iter(out_tangents)
  out_tangents = [
      next(out_tangents_iter) if nz else ad_util.Zero.from_value(p)
      for p, nz in zip(out_primals, nonzero_tangents)
  ]
  return out_primals, out_tangents
def _convert_zeros(instantiate, example, tangent): t = type(instantiate) if t is bool: if instantiate: return ad.instantiate_zeros(example, tangent) elif tangent is ad_util.zero: return core.unit else: raise TypeError(tangent) # not clear if ever reachable elif t is tuple: if type(tangent) is ad.TangentTuple: return core.pack(map(_convert_zeros, instantiate, example, tangent)) elif tangent is ad_util.zero: zeros = [ad_util.zero] * len(instantiate) return core.pack(map(_convert_zeros, instantiate, example, zeros)) else: raise TypeError(tangent) else: raise TypeError(t)