Example #1
0
 def transposed(*args):
     """Run the backward pass of the closed-over ``jaxpr`` on ``args``.

     ``args`` is divided by ``split_list`` at ``num_res`` into residuals and
     output cotangents. The residuals, followed by one ``ad.UndefinedPrimal``
     per entry of ``primal_avals``, are handed to ``ad.backward_pass`` as the
     primal inputs. The result is split at ``num_res`` again and the leading
     part is discarded; the remaining cotangents are paired with
     ``primal_avals`` and passed through ``ad.instantiate_zeros_aval``.
     """
     residuals, out_cotangents = split_list(args, [num_res])
     placeholders = [ad.UndefinedPrimal(aval) for aval in primal_avals]
     all_cotangents = ad.backward_pass(
         jaxpr.jaxpr, reduce_axes, False, jaxpr.consts,
         residuals + placeholders, out_cotangents)
     # Only the entries after the first ``num_res`` are returned.
     _, primal_cotangents = split_list(all_cotangents, [num_res])
     return _map(ad.instantiate_zeros_aval, primal_avals, primal_cotangents)
Example #2
0
 def transposed(*cbar_bbar_res):
   """Run the backward pass of the closed-over ``jaxpr`` and merge cotangents.

   The flat inputs are split by ``split_list`` at ``[num_c, num_b]`` into
   ``c_bar``, ``b_bar`` and ``res``. ``ad.backward_pass`` is called with
   ``num_c + num_a`` copies of ``ad.undefined_primal`` followed by ``res`` as
   primals and ``b_bar`` as cotangents. Its second result is split at
   ``[num_c, num_a]``; the first piece is combined with the incoming ``c_bar``
   via ``ad.add_tangents``. Both groups are passed through
   ``ad.instantiate_zeros_aval`` and returned as ``c_bar + a_bar``.
   """
   c_bar_in, b_bar, res = split_list(cbar_bbar_res, [num_c, num_b])
   num_undefined = num_c + num_a
   primals = [ad.undefined_primal] * num_undefined + res
   backward_out = ad.backward_pass(
       jaxpr.jaxpr, jaxpr.literals, (), primals, b_bar)
   _, cbar_abar = backward_out
   c_bar_update, a_bar_raw, _ = split_list(cbar_abar, [num_c, num_a])
   a_bar = _map(ad.instantiate_zeros_aval, a_avals, a_bar_raw)
   # Accumulate the newly computed c cotangents onto the incoming ones.
   c_bar_sum = _map(ad.add_tangents, c_bar_in, c_bar_update)
   c_bar = _map(ad.instantiate_zeros_aval, c_avals, c_bar_sum)
   return c_bar + a_bar
Example #3
0
 def transposed(*args):
   """Trace the closed-over ``jaxpr`` on partial values and run its backward pass.

   ``args`` is unflattened with ``treedef`` into input primals and output
   cotangents. Each primal becomes an unknown ``pe.PartialVal`` (built from
   its ``aval``) when ``ad.is_undefined_primal`` is true for it, and a known
   one otherwise. The jaxpr evaluator is traced with those partial values,
   and ``ad.backward_pass`` is run on the traced jaxpr with one
   ``ad.UndefinedPrimal`` per input variable. The result is flattened; the
   resulting tree structure is stored on ``cell.treedef`` and the flat leaves
   are returned.
   """
   def _to_partial_val(x):
     if ad.is_undefined_primal(x):
       return pe.PartialVal.unknown(x.aval)
     return pe.PartialVal.known(x)

   in_primals, out_cts = tree_unflatten(treedef, args)
   partial_vals = [_to_partial_val(p) for p in in_primals]
   evaluator = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ()))
   traced_jaxpr, _, traced_consts = pe.trace_to_jaxpr(
       evaluator, partial_vals, False)
   placeholders = [ad.UndefinedPrimal(var.aval) for var in traced_jaxpr.invars]
   structured_cts = ad.backward_pass(
       traced_jaxpr, reduce_axes, False, traced_consts, placeholders, out_cts)
   flat_cts, cts_tree = tree_flatten(structured_cts)
   # Side effect: record the output tree structure on the closed-over cell.
   cell.treedef = cts_tree
   return flat_cts
Example #4
0
def _custom_jvp_call_jaxpr_transpose(reduce_axes, cts, *args, fun_jaxpr,
                                     jvp_jaxpr_thunk, num_consts):
    """Return ``ad.backward_pass`` applied to ``fun_jaxpr``.

    The pass is invoked with ``fun_jaxpr.jaxpr``, ``reduce_axes``,
    ``fun_jaxpr.consts``, the positional ``args`` tuple and ``cts``, in that
    order. ``jvp_jaxpr_thunk`` and ``num_consts`` are accepted but not used.
    """
    del jvp_jaxpr_thunk, num_consts  # Unused here.
    inner_jaxpr = fun_jaxpr.jaxpr
    consts = fun_jaxpr.consts
    return ad.backward_pass(inner_jaxpr, reduce_axes, consts, args, cts)
Example #5
0
 def transposed(res, b_bar):
   """Run the backward pass of the closed-over ``jaxpr`` and return ``a_bar``.

   ``ad.backward_pass`` is called with ``(res, None)`` as primals and
   ``b_bar`` as cotangents. The second element of the second result is
   extracted and passed through ``ad.instantiate_zeros_aval`` together with
   ``jaxpr.in_avals[1]``.
   """
   primals = (res, None)
   backward_out = ad.backward_pass(
       jaxpr.jaxpr, jaxpr.literals, (), primals, b_bar)
   # Only the nested second entry is needed; everything else is dropped.
   _, (_, a_cotangent) = backward_out
   return ad.instantiate_zeros_aval(jaxpr.in_avals[1], a_cotangent)