def transposed(*args):
  """Back-propagate output cotangents through the closed-over ``jaxpr``.

  ``args`` holds ``num_res`` residual values followed by the output
  cotangents. The residuals are passed to ``ad.backward_pass`` as-is, and one
  ``ad.UndefinedPrimal`` per entry of ``primal_avals`` is appended after them.

  Returns:
    One cotangent per entry of ``primal_avals``, each passed through
    ``ad.instantiate_zeros_aval`` with the matching aval.
  """
  residuals, out_cotangents = split_list(args, [num_res])
  undefined = [ad.UndefinedPrimal(aval) for aval in primal_avals]
  all_cotangents = ad.backward_pass(
      jaxpr.jaxpr, reduce_axes, False, jaxpr.consts,
      residuals + undefined, out_cotangents)
  # The cotangents of the residual inputs are dropped; only the primal ones
  # are returned to the caller.
  _, primal_cotangents = split_list(all_cotangents, [num_res])
  return _map(ad.instantiate_zeros_aval, primal_avals, primal_cotangents)
def transposed(*cbar_bbar_res):
  """Run the backward pass of the closed-over ``jaxpr``.

  ``cbar_bbar_res`` is the concatenation of ``num_c`` incoming cotangents for
  the ``c`` inputs, ``num_b`` cotangents for the jaxpr outputs, and the
  residual values. The first ``num_c + num_a`` jaxpr inputs are passed as
  ``ad.undefined_primal``; the residuals follow them as known values.

  Returns:
    The ``c`` cotangents (incoming cotangents summed with the ones produced by
    the backward pass via ``ad.add_tangents``) followed by the ``a``
    cotangents, each passed through ``ad.instantiate_zeros_aval`` with
    ``c_avals`` / ``a_avals`` respectively.
  """
  incoming_c_ct, out_ct, residuals = split_list(cbar_bbar_res, [num_c, num_b])
  num_linear = num_c + num_a
  primal_args = [ad.undefined_primal] * num_linear + residuals
  _, linear_ct = ad.backward_pass(
      jaxpr.jaxpr, jaxpr.literals, (), primal_args, out_ct)
  # Trailing entries (cotangents for the residual inputs) are discarded.
  produced_c_ct, a_ct, _ = split_list(linear_ct, [num_c, num_a])
  # Keep the original call order (a first, then c): these helpers may emit
  # operations when this function is traced.
  a_result = _map(ad.instantiate_zeros_aval, a_avals, a_ct)
  summed_c_ct = _map(ad.add_tangents, incoming_c_ct, produced_c_ct)
  c_result = _map(ad.instantiate_zeros_aval, c_avals, summed_c_ct)
  return c_result + a_result
def transposed(*args):
  """Transpose the closed-over ``jaxpr`` after partially evaluating it.

  ``args`` is unflattened with ``treedef`` into ``(in_primals, out_cts)``.
  Inputs for which ``ad.is_undefined_primal`` holds become unknown partial
  values (from their ``aval``); all others become known partial values.
  ``core.eval_jaxpr`` on ``jaxpr`` (with no consts) is traced through
  ``pe.trace_to_jaxpr`` with those partial values, and the resulting jaxpr is
  run backwards with every one of its inputs marked ``ad.UndefinedPrimal``.

  Side effect: stores the tree definition of the computed cotangents in
  ``cell.treedef``.

  Returns:
    The flattened input cotangents.
  """
  in_primals, out_cts = tree_unflatten(treedef, args)
  in_pvals = []
  for primal in in_primals:
    if ad.is_undefined_primal(primal):
      in_pvals.append(pe.PartialVal.unknown(primal.aval))
    else:
      in_pvals.append(pe.PartialVal.known(primal))
  eval_fn = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ()))
  linear_jaxpr, _, linear_consts = pe.trace_to_jaxpr(eval_fn, in_pvals, False)
  placeholders = [ad.UndefinedPrimal(var.aval) for var in linear_jaxpr.invars]
  nested_cts = ad.backward_pass(
      linear_jaxpr, reduce_axes, False, linear_consts, placeholders, out_cts)
  flat_cts, cell.treedef = tree_flatten(nested_cts)
  return flat_cts
def _custom_jvp_call_jaxpr_transpose(reduce_axes, cts, *args, fun_jaxpr,
                                     jvp_jaxpr_thunk, num_consts):
  """Transpose ``fun_jaxpr`` by running ``ad.backward_pass`` over it.

  ``jvp_jaxpr_thunk`` and ``num_consts`` are accepted but unused: only the
  primal jaxpr (``fun_jaxpr.jaxpr`` with ``fun_jaxpr.consts``) is transposed,
  with ``args`` as the primal inputs and ``cts`` as the output cotangents.

  Returns:
    Whatever ``ad.backward_pass`` returns for these inputs.
  """
  del num_consts, jvp_jaxpr_thunk  # not needed to transpose the primal jaxpr
  closed = fun_jaxpr
  # NOTE(review): this passes five positional arguments, whereas other
  # ``ad.backward_pass`` calls that also take ``reduce_axes`` pass an extra
  # ``False`` before the consts — confirm which signature ``ad`` provides here.
  return ad.backward_pass(closed.jaxpr, reduce_axes, closed.consts, args, cts)
def transposed(res, b_bar):
  """Back-propagate ``b_bar`` through the closed-over ``jaxpr``.

  The primal arguments are ``(res, None)``: the residual ``res`` is known and
  ``None`` presumably marks the second input as undefined (linear) in this
  version of ``ad.backward_pass`` — TODO confirm. Only the cotangent for the
  second input is returned.

  Returns:
    That cotangent, passed through ``ad.instantiate_zeros_aval`` with
    ``jaxpr.in_avals[1]``.
  """
  primal_args = (res, None)
  _, input_cts = ad.backward_pass(
      jaxpr.jaxpr, jaxpr.literals, (), primal_args, b_bar)
  # Tuple unpacking (rather than indexing) also enforces exactly two entries.
  _, a_ct = input_cts
  return ad.instantiate_zeros_aval(jaxpr.in_avals[1], a_ct)