예제 #1
0
def _linear_call_transpose_rule(cts, *args, callee, transpose,
                                num_callee_consts, num_transpose_consts,
                                num_res):
    """Transpose a ``linear_call`` by binding it again with the roles swapped.

    The flat ``args`` are laid out as callee constants, transpose constants,
    residual operands and finally linear operands. The new call is bound with
    ``transpose`` acting as the callee (and ``callee`` as the transpose), and
    the incoming cotangents take the place of the linear operands.

    Args:
        cts: incoming cotangents; any entry whose type is exactly ``Zero`` is
            replaced by ``zeros_like_aval`` of the matching aval.
        *args: flat operands in the layout described above.
        callee: the original callee; passed as ``transpose`` of the new call.
        transpose: the original transpose; passed as ``callee`` of the new
            call. The tail of its ``in_avals`` (after its constants and the
            residuals) provides the avals paired with ``cts``.
        num_callee_consts: number of leading callee constants in ``args``.
        num_transpose_consts: number of transpose constants that follow.
        num_res: number of residual operands that follow the constants.

    Returns:
        A list with ``None`` for every constant and residual position,
        followed by the outputs of the swapped ``linear_call_p.bind``.
    """
    callee_consts, transpose_consts, residuals, linear_operands = split_list(
        args, [num_callee_consts, num_transpose_consts, num_res])
    _, _, ct_avals = split_list(transpose.in_avals,
                                [num_transpose_consts, num_res])

    # The linear operands must all be undefined primals, and none of the
    # residuals may be.
    assert all(ad.is_undefined_primal(x) for x in linear_operands)
    assert not any(ad.is_undefined_primal(x) for x in residuals)

    # Swap each symbolic Zero cotangent for zeros built from its aval.
    dense_cts = []
    for ct, aval in zip(cts, ct_avals):
        dense_cts.append(zeros_like_aval(aval) if type(ct) is Zero else ct)

    # Callee and transpose trade places, and so do their constant counts.
    swapped_params = dict(
        callee=transpose,
        transpose=callee,
        num_callee_consts=len(transpose_consts),
        num_transpose_consts=len(callee_consts),
        num_res=len(residuals),
    )
    cts_out = linear_call_p.bind(*transpose_consts, *callee_consts,
                                 *residuals, *dense_cts, **swapped_params)

    num_nonlinear = num_callee_consts + num_transpose_consts + num_res
    return [None] * num_nonlinear + cts_out
예제 #2
0
def _flatten_bwd(in_tree, in_avals, out_trees, *args):
    """Generator that adapts a pytree-level backward rule to flat arguments.

    The generator runs in two steps. It first yields
    ``((py_res, py_cts_out), {})``, i.e. the residuals and output cotangents
    rebuilt as pytrees, and expects the rule's result to be sent back. It then
    yields a flat list of input cotangents, one per entry paired from
    ``in_avals``.

    Args:
        in_tree: pytree structure of the primal function's arguments.
        in_avals: avals paired position-by-position with the flat input
            cotangents; used to build zeros where the rule returned ``None``.
        out_trees: zero-argument callable returning ``(out_tree, res_tree)``.
        *args: flat residuals (``res_tree.num_leaves`` of them) followed by
            the flat output cotangents.

    Raises:
        TypeError: if the value sent back is not a tuple, or if mapping over
            it together with the input structure raises ``ValueError``.
    """
    out_tree, res_tree = out_trees()
    flat_res, flat_cts_out = split_list(args, [res_tree.num_leaves])
    py_res = tree_unflatten(res_tree, flat_res)
    py_cts_out = tree_unflatten(out_tree, flat_cts_out)
    py_cts_in = yield (py_res, py_cts_out), {}

    # A None in py_cts_in means the rule produced no cotangent for that
    # argument. Each such None is swapped for a non-pytree sentinel, which is
    # then repeated once per leaf of the corresponding subtree of in_tree and
    # finally replaced by zeros in the list yielded at the end.
    zero = object()  # non-pytree sentinel standing in for None
    dummy = tree_unflatten(in_tree, [object()] * in_tree.num_leaves)
    cts_in_flat = []

    def append_cts(ct, dummy_subtree):
        # Repeat `ct` once for every leaf of the matching input subtree.
        leaves, _ = tree_flatten(dummy_subtree)
        cts_in_flat.extend([ct] * len(leaves))

    try:
        if not isinstance(py_cts_in, tuple):
            # Routed through the same handler as a structure mismatch.
            raise ValueError
        cts_with_sentinels = tuple(
            zero if ct is None else ct for ct in py_cts_in)
        tree_multimap(append_cts, cts_with_sentinels, dummy)
    except ValueError:
        vjp_out_tree = tree_flatten(py_cts_in)[1]
        template = (
            "Custom VJP rule must produce an output with the same container "
            "(pytree) structure as the args tuple of the primal function, "
            "and in particular must produce a tuple of length equal to the "
            "number of arguments to the primal function, but got VJP output "
            "structure {} for primal input structure {}.")
        raise TypeError(template.format(vjp_out_tree, in_tree)) from None

    flat_cts_in = []
    for aval, ct in zip(in_avals, cts_in_flat):
        if ct is zero:
            ct = zeros_like_aval(aval.at_least_vspace())
        flat_cts_in.append(ct)
    yield flat_cts_in
예제 #3
0
def _scan_transpose(ct, consts, init, xs, forward, length, jaxpr):
  """Transpose a scan by binding a new scan over the transposed body.

  The body jaxpr is rearranged so the residuals come first, transposed, and
  then reshaped into a scan body again. The resulting scan is bound with the
  opposite direction (``not forward``) and the same ``length``.

  Args:
    ct: cotangent of the scan output; it is instantiated against the abstract
      tuple ``(c_aval, bs_aval)`` and unpacked into ``ct_c`` and ``ct_bs``.
    consts: must be None (asserted).
    init: must be None (asserted).
    xs: a tuple ``(a, res)`` in which ``a`` must be None and ``res`` must not
      be None (asserted). NOTE(review): None presumably marks the inputs being
      transposed with respect to -- confirm against the caller's convention.
    forward: direction flag of the original scan; negated for the new scan.
    length: forwarded to the new scan and used to promote ``b_aval`` to
      ``bs_aval``.
    jaxpr: the scan body; see the type sketch in the body comments.

  Returns:
    ``(ct_consts, ct_init, (ct_as, None))``, in the same order as the
    ``consts``, ``init`` and ``xs`` parameters.
  """
  # Only the (consts, init, a) inputs are transposed; res must be available.
  assert consts is None and init is None
  assert type(xs) is tuple
  a, res = xs
  assert a is None and res is not None

  # Type sketch of the body and of each derived jaxpr (CT x = cotangent of x):
  # jaxpr :: d -> c -> (a, res) ->  (c, b)
  # jaxpr_lifted :: res -> (d, c, a) -> (c, b)
  # jaxpr_lifted_trans :: res -> (CT c, CT b) -> (CT d, CT c, CT a)
  # jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT c, CT d), CT a)
  assert type(jaxpr.jaxpr.invars[2]) is tuple  # assume restructuring
  # Pull res out as the first argument and group (d, c, a) as the second.
  jaxpr_lifted = rearrange_binders(
      lambda d, c, a_res: (a_res[1], (d, c, a_res[0])), jaxpr)
  jaxpr_lifted_trans = _transpose_jaxpr(jaxpr_lifted)
  jaxpr_trans = _move_stuff_and_add_add(jaxpr_lifted_trans)

  c_aval, b_aval = jaxpr.out_aval
  d_aval, c_aval2, _ = jaxpr.in_avals
  # The carry aval must be the same on the input and output side of the body.
  assert c_aval == c_aval2
  bs_aval = _promote_aval_rank(length, b_aval)
  # The cotangent for d starts from zeros and travels in the carry with ct_c.
  ct_d = ad_util.zeros_like_aval(d_aval)
  ct_c, ct_bs = ad.instantiate_zeros_aval(core.AbstractTuple((c_aval, bs_aval)), ct)
  carry_ct = core.pack((ct_c, ct_d))

  # jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT c, CT d), CT a)
  core.check_jaxpr(jaxpr_trans.jaxpr)
  # This destructuring also checks that in_avals has the expected nesting.
  unit_aval, (ct_c_aval, ct_d_aval), (ct_b_aval, _) = jaxpr_trans.in_avals
  # Joining each initial carry value's aval with the expected one must leave
  # the expected aval unchanged.
  assert core.lattice_join(ct_c_aval, core.get_aval(ct_c)) == ct_c_aval
  assert core.lattice_join(ct_d_aval, core.get_aval(ct_d)) == ct_d_aval

  # Bind the new scan: unit in the first slot, (ct_c, ct_d) as the carry, and
  # (ct_bs, res) as the scanned-over inputs, running in the opposite direction.
  out = scan_p.bind(
      core.unit, carry_ct, core.pack((ct_bs, res)),
      forward=not forward, length=length, jaxpr=jaxpr_trans)
  # The final carry holds (CT c, CT d), i.e. the cotangents for init and consts.
  (ct_init, ct_consts), ct_as = out
  return ct_consts, ct_init, (ct_as, None)