def _linear_call_transpose_rule(cts, *args, callee, transpose,
                                num_callee_consts, num_transpose_consts,
                                num_res):
  """Transpose rule for a ``linear_call_p`` application.

  ``args`` is laid out as ``[callee consts, transpose consts, residuals,
  linear operands]``. The transpose is expressed as another ``linear_call_p``
  bind in which the roles of ``callee`` and ``transpose`` (and of their
  respective consts) are swapped. The incoming cotangents take the place of
  the linear operands.

  Returns:
    A list with ``None`` for every callee const, transpose const and residual,
    followed by whatever the swapped bind returns (presumably one cotangent
    per linear operand; confirm against ``linear_call_p``'s impl).
  """
  sizes = [num_callee_consts, num_transpose_consts, num_res]
  callee_consts, transpose_consts, residuals, linear_ops = split_list(
      args, sizes)
  # The avals of transpose's trailing inputs describe the cotangents it
  # consumes.
  _, _, ct_avals = split_list(transpose.in_avals, [num_transpose_consts, num_res])

  # Linear operands must be undefined primals; residuals must be known values.
  assert all(ad.is_undefined_primal(op) for op in linear_ops)
  assert all(not ad.is_undefined_primal(r) for r in residuals)

  # Replace symbolic Zero cotangents with dense zeros of the matching aval.
  dense_cts = []
  for ct, aval in zip(cts, ct_avals):
    dense_cts.append(zeros_like_aval(aval) if type(ct) is Zero else ct)

  cts_out = linear_call_p.bind(
      *transpose_consts, *callee_consts, *residuals, *dense_cts,
      callee=transpose, transpose=callee,
      num_callee_consts=len(transpose_consts),
      num_transpose_consts=len(callee_consts),
      num_res=len(residuals))

  # Consts and residuals are non-linear inputs, so they get no cotangent.
  num_nonlinear = sum(sizes)
  return [None] * num_nonlinear + cts_out
def _flatten_bwd(in_tree, in_avals, out_trees, *args):
  """Adapt a pytree-level custom VJP backward rule to flat inputs/outputs.

  Generator-style transformation (presumably wrapped by a ``lu.transformation``
  helper elsewhere):

  1. It unflattens the flat ``args`` into ``(py_res, py_cts_out)`` using the
     trees returned by the ``out_trees`` thunk. These are yielded (with empty
     kwargs) to the wrapped user rule, and the rule's result is received back
     as ``py_cts_in``.
  2. It then yields a flat list of input cotangents, one per leaf of
     ``in_tree``. Wherever the user rule returned ``None`` for an argument,
     the list holds ``zeros_like_aval(aval.at_least_vspace())`` for each leaf
     of that argument, using the matching entry of ``in_avals``.

  Raises:
    TypeError: if ``py_cts_in`` is not a tuple, or if its structure cannot
      be matched against ``in_tree`` (signalled by a ``ValueError`` from the
      pytree utilities).
  """
  out_tree, res_tree = out_trees()
  # The first res_tree.num_leaves flat args are residuals; the rest are
  # output cotangents.
  res, cts_out = split_list(args, [res_tree.num_leaves])
  py_res = tree_unflatten(res_tree, res)
  py_cts_out = tree_unflatten(out_tree, cts_out)
  py_cts_in = yield (py_res, py_cts_out), {}
  # For each None in py_cts_in, indicating an argument for which the rule
  # produces no cotangent, we replace it with a pytree with the structure of the
  # corresponding subtree of in_tree and with leaves of a non-pytree sentinel
  # object. Those sentinel leaves are replaced with zeros (built from the
  # matching entry of in_avals) in the final yielded result.
  zero = object()  # non-pytree sentinel to replace Nones in py_cts_in
  # Placeholder tree with in_tree's structure; only its leaf counts are used.
  dummy = tree_unflatten(in_tree, [object()] * in_tree.num_leaves)
  cts_in_flat = []
  # Append x once per leaf of the corresponding dummy subtree, so a sentinel
  # standing in for a whole argument expands to one entry per leaf.
  # NOTE(review): relies on tree_multimap flattening `dummy` only up to the
  # structure of its first tree argument; confirm against the tree_util version.
  append_cts = lambda x, d: cts_in_flat.extend([x] * len(tree_flatten(d)[0]))
  try:
    if not isinstance(py_cts_in, tuple):
      raise ValueError
    tree_multimap(append_cts,
                  tuple(zero if ct is None else ct for ct in py_cts_in), dummy)
  except ValueError:
    # Structure mismatch: report the user's output structure against the
    # primal input structure.
    _, in_tree2 = tree_flatten(py_cts_in)
    msg = ("Custom VJP rule must produce an output with the same container "
           "(pytree) structure as the args tuple of the primal function, "
           "and in particular must produce a tuple of length equal to the "
           "number of arguments to the primal function, but got VJP output "
           "structure {} for primal input structure {}.")
    raise TypeError(msg.format(in_tree2, in_tree)) from None
  yield [zeros_like_aval(aval.at_least_vspace()) if ct is zero else ct
         for aval, ct in zip(in_avals, cts_in_flat)]
def _scan_transpose(ct, consts, init, xs, forward, length, jaxpr):
  """Transpose rule for ``scan_p`` (tuple-packed jaxpr formulation).

  Only the case checked by the asserts is supported:
  - ``consts`` and ``init`` are linear (``None`` here, i.e. undefined);
  - ``xs`` is a pair ``(a, res)`` where ``a`` is linear (``None``) and
    ``res`` is a known residual value.

  The body jaxpr is rearranged, transposed, and run as a new scan with the
  ``forward`` flag flipped. That scan carries ``(CT c, CT d)`` and scans over
  ``(CT b, res)``.

  Returns:
    Cotangents for ``(consts, init, xs)``. The cotangent for ``xs`` is
    ``(ct_as, None)``: ``None`` for the residual component, since it is not a
    linear input.
  """
  assert consts is None and init is None
  assert type(xs) is tuple
  a, res = xs
  assert a is None and res is not None

  # jaxpr :: d -> c -> (a, res) -> (c, b)
  # jaxpr_lifted :: res -> (d, c, a) -> (c, b)
  # jaxpr_lifted_trans :: res -> (CT c, CT b) -> (CT d, CT c, CT a)
  # jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT c, CT d), CT a)
  assert type(jaxpr.jaxpr.invars[2]) is tuple  # assume restructuring
  # Pull res out of the third binder so it becomes the (non-linear) first
  # argument and (d, c, a) become the linear ones.
  jaxpr_lifted = rearrange_binders(
      lambda d, c, a_res: (a_res[1], (d, c, a_res[0])), jaxpr)
  jaxpr_lifted_trans = _transpose_jaxpr(jaxpr_lifted)
  jaxpr_trans = _move_stuff_and_add_add(jaxpr_lifted_trans)

  c_aval, b_aval = jaxpr.out_aval
  d_aval, c_aval2, _ = jaxpr.in_avals
  assert c_aval == c_aval2
  # The per-iteration b output is stacked over `length` iterations.
  bs_aval = _promote_aval_rank(length, b_aval)
  # The consts cotangent accumulator starts at zero.
  ct_d = ad_util.zeros_like_aval(d_aval)
  # Materialize any symbolic-zero parts of the incoming cotangent.
  ct_c, ct_bs = ad.instantiate_zeros_aval(core.AbstractTuple((c_aval, bs_aval)), ct)
  carry_ct = core.pack((ct_c, ct_d))

  # jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT c, CT d), CT a)
  core.check_jaxpr(jaxpr_trans.jaxpr)
  # unit_aval and ct_b_aval are unpacked only to check the binder structure;
  # they are not used below.
  unit_aval, (ct_c_aval, ct_d_aval), (ct_b_aval, _) = jaxpr_trans.in_avals
  assert core.lattice_join(ct_c_aval, core.get_aval(ct_c)) == ct_c_aval
  assert core.lattice_join(ct_d_aval, core.get_aval(ct_d)) == ct_d_aval
  # Run the transposed body with the `forward` flag flipped.
  out = scan_p.bind(
      core.unit, carry_ct, core.pack((ct_bs, res)),
      forward=not forward, length=length, jaxpr=jaxpr_trans)
  (ct_init, ct_consts), ct_as = out
  return ct_consts, ct_init, (ct_as, None)