Exemplo n.º 1
0
def _flatten_bwd(in_tree, in_avals, out_trees, *args):
    """Generator adapting a pytree-level backward rule to flat arguments.

    ``args`` holds ``res_tree.num_leaves`` residual leaves followed by the
    output-cotangent leaves. Both groups are unflattened and yielded as
    ``((py_res, py_cts_out), {})``. The value sent back into the generator is
    treated as the rule's input cotangents.

    A ``None`` entry in that result means "no cotangent" for the whole
    corresponding subtree of ``in_tree``; each leaf of such a subtree becomes
    ``zeros_like_aval(aval.at_least_vspace())`` in the final flat list, which
    is the second value yielded.

    Raises:
        TypeError: if the rule's result is not a tuple, or if mapping it
            against the structure of ``in_tree`` raises ``ValueError``.
    """
    out_tree, res_tree = out_trees()
    flat_res, flat_cts_out = split_list(args, [res_tree.num_leaves])
    residuals = tree_unflatten(res_tree, flat_res)
    cotangents_out = tree_unflatten(out_tree, flat_cts_out)
    rule_result = yield (residuals, cotangents_out), {}

    # A unique non-pytree marker stands in for each None produced by the rule
    # so it can be expanded over a whole input subtree and recognised later.
    missing = object()
    # Skeleton with the structure of in_tree; only its shape matters.
    template = tree_unflatten(in_tree, [object()] * in_tree.num_leaves)
    flat_cts_in = []

    def record(ct, subtree):
        # Repeat `ct` once per leaf of the matching input subtree.
        leaves = tree_flatten(subtree)[0]
        flat_cts_in.extend([ct] * len(leaves))

    try:
        if not isinstance(rule_result, tuple):
            raise ValueError
        marked = tuple(missing if ct is None else ct for ct in rule_result)
        tree_multimap(record, marked, template)
    except ValueError:
        _, got_tree = tree_flatten(rule_result)
        template_msg = (
            "Custom VJP rule must produce an output with the same container "
            "(pytree) structure as the args tuple of the primal function, "
            "and in particular must produce a tuple of length equal to the "
            "number of arguments to the primal function, but got VJP output "
            "structure {} for primal input structure {}.")
        raise TypeError(template_msg.format(got_tree, in_tree)) from None

    materialized = []
    for aval, ct in zip(in_avals, flat_cts_in):
        if ct is missing:
            materialized.append(zeros_like_aval(aval.at_least_vspace()))
        else:
            materialized.append(ct)
    yield materialized
Exemplo n.º 2
0
def _linear_call_transpose_rule(cts, *args, callee, transpose,
                                num_callee_consts, num_transpose_consts,
                                num_res):
    """Transpose rule that re-binds ``linear_call_p`` with roles swapped.

    ``args`` is split, in order, into callee constants, transpose constants,
    residual operands and linear operands. The primitive is bound again with
    ``transpose`` as the callee and ``callee`` as the transpose, passing the
    transpose constants first and the incoming cotangents as the trailing
    operands.

    Returns:
        A list with ``None`` for every constant and residual operand,
        followed by the outputs of the swapped bind.
    """
    sizes = [num_callee_consts, num_transpose_consts, num_res]
    f_consts, t_consts, operands_res, operands_lin = split_list(args, sizes)
    _, _, cts_avals = split_list(transpose.in_avals,
                                 [num_transpose_consts, num_res])

    # Only the linear operands may be undefined primals.
    assert all(ad.is_undefined_primal(x) for x in operands_lin)
    assert all(not ad.is_undefined_primal(x) for x in operands_res)

    # Replace symbolic Zero cotangents with concrete zeros of the matching aval.
    dense_cts = []
    for ct, aval in zip(cts, cts_avals):
        if type(ct) is Zero:
            ct = zeros_like_aval(aval)
        dense_cts.append(ct)

    cts_out = linear_call_p.bind(
        *t_consts,
        *f_consts,
        *operands_res,
        *dense_cts,
        callee=transpose,
        transpose=callee,
        num_callee_consts=len(t_consts),
        num_transpose_consts=len(f_consts),
        num_res=len(operands_res),
    )

    n_nonlinear = num_callee_consts + num_transpose_consts + num_res
    return [None] * n_nonlinear + cts_out
Exemplo n.º 3
0
def custom_transpose_transpose_rule(cts, *args, out_types, res_tree, lin_tree,
                                    out_tree, **params):
    """Apply a user-supplied transpose to the incoming cotangents.

    The transpose callable comes from ``params['transpose_jaxpr_thunk']``
    (wrapped by ``make_transpose_from_thunk``) when that key is present, and
    from ``params['transpose']`` otherwise. ``args`` is unflattened into a
    ``(res_arg, lin_arg)`` pair; only ``res_arg`` is used.

    Returns:
        A list with ``None`` for every leaf of ``res_arg``, followed by the
        flattened result of the transpose after broadcasting it against
        ``lin_tree`` (``None`` values are kept as leaves).
    """
    if 'transpose_jaxpr_thunk' in params:
        assert 'call_jaxpr' in params
        thunk = params['transpose_jaxpr_thunk']
        transpose = make_transpose_from_thunk(thunk, lin_tree)
    else:
        assert 'call' in params
        transpose = params['transpose']

    call_in_tree = treedef_tuple((res_tree, lin_tree))

    # TODO(frostig,mattjj): `lin_arg` indicates the inputs with respect
    # to which we are transposing (via `ad.is_undefined_primal`).
    # Consider passing this information to the custom transpose rule?

    # The linear half of the arguments is intentionally discarded.
    res_arg, _ = tree_unflatten(call_in_tree, args)
    assert all(not ad.is_undefined_primal(x) for x in tree_leaves(res_arg))

    def is_none(x):
        return x is None

    # Replace symbolic Zero cotangents with concrete zeros of their own aval.
    dense_cts = []
    for ct in cts:
        if type(ct) is ad_util.Zero:
            ct = ad_util.zeros_like_aval(ct.aval)
        dense_cts.append(ct)

    ct_lin = transpose(res_arg, tree_unflatten(out_tree, dense_cts))
    check_transpose_rule_trees(transpose, lin_tree, tree_structure(ct_lin))
    broadcasted = tree_broadcast(lin_tree, ct_lin, is_leaf=is_none)
    ct_lin_flat, _ = tree_flatten(broadcasted, is_leaf=is_none)
    res_padding = [None] * len(tree_leaves(res_arg))
    return res_padding + ct_lin_flat
Exemplo n.º 4
0
 def fix_float0(arg_jax, ct_arg_jax):
     """Return a cotangent whose dtype is the tangent dtype of ``arg_jax``.

     ``ct_arg_jax`` is returned unchanged when its dtype already equals the
     tangent dtype derived from ``arg_jax``. Otherwise zeros with the shape
     of ``arg_jax`` and the tangent dtype are returned instead.
     """
     primal_dtype = dtypes.result_type(arg_jax)  # May be scalar
     expected_dtype = core.primal_dtype_to_tangent_dtype(primal_dtype)
     mismatch = expected_dtype != ct_arg_jax.dtype
     if not mismatch:
         return ct_arg_jax
     replacement_aval = core.ShapedArray(np.shape(arg_jax), expected_dtype)
     return ad_util.zeros_like_aval(replacement_aval)
Exemplo n.º 5
0
def instantiate_zeros(tangent):
    """Turn a symbolic ``Zero`` tangent into a concrete value.

    Anything whose exact type is not ``Zero`` is returned untouched. For a
    ``Zero`` whose ``aval`` is itself a ``Tracer``, that tracer is returned;
    otherwise the result is ``zeros_like_aval`` of the ``aval``.
    """
    if type(tangent) is not Zero:
        return tangent
    aval = tangent.aval
    if isinstance(aval, Tracer):
        return aval
    return zeros_like_aval(aval)
Exemplo n.º 6
0
def custom_transpose_transpose_rule(cts, *args, call, rule, res_tree, lin_tree,
                                    out_tree):
    """Apply the user-supplied transpose ``rule`` to the incoming cotangents.

    ``args`` is unflattened into ``(res_arg, lin_arg)``. Every leaf of
    ``lin_arg`` must be an undefined primal and no leaf of ``res_arg`` may
    be one. Symbolic ``Zero`` cotangents are replaced by zeros built from
    the matching entry of ``call.out_avals`` before ``rule`` is called.

    Returns:
        A list with ``None`` for every leaf of ``res_arg``, followed by the
        flattened result of ``rule``, whose tree structure is checked
        against ``lin_tree``.
    """
    res_arg, lin_arg = tree_unflatten(treedef_tuple((res_tree, lin_tree)),
                                      args)
    assert all(ad.is_undefined_primal(x) for x in tree_leaves(lin_arg))
    assert all(not ad.is_undefined_primal(x) for x in tree_leaves(res_arg))

    dense_cts = []
    for ct, ct_aval in zip(cts, call.out_avals):
        if type(ct) is ad_util.Zero:
            ct = ad_util.zeros_like_aval(ct_aval)
        dense_cts.append(ct)

    ct_lin = rule(res_arg, tree_unflatten(out_tree, dense_cts))
    ct_lin_flat, ct_lin_tree = tree_flatten(ct_lin)
    check_transpose_rule_trees(rule, lin_tree, ct_lin_tree)
    res_padding = [None] * len(tree_leaves(res_arg))
    return res_padding + ct_lin_flat
Exemplo n.º 7
0
Arquivo: ad.py Projeto: jbampton/jax
def instantiate_zeros_aval(aval, tangent):
  """Materialize a symbolic ``Zero`` tangent as zeros of ``aval``.

  Anything whose exact type is not ``Zero`` is returned untouched. For a
  ``Zero``, its own ``aval`` must either be a ``core.AbstractUnit`` or equal
  the given ``aval`` (checked with an assert).
  """
  if type(tangent) is not Zero:
    return tangent
  assert type(tangent.aval) is core.AbstractUnit or tangent.aval == aval
  return zeros_like_aval(aval)
Exemplo n.º 8
0
Arquivo: ad.py Projeto: jbampton/jax
def instantiate_zeros(tangent):
  """Return ``tangent``, materializing it first if it is a symbolic ``Zero``.

  A value whose exact type is ``Zero`` is replaced by
  ``zeros_like_aval(tangent.aval)``; anything else is returned untouched.
  """
  return zeros_like_aval(tangent.aval) if type(tangent) is Zero else tangent
Exemplo n.º 9
0
def instantiate_zeros_aval(aval, tangent):
  """Materialize a symbolic ``Zero`` tangent as zeros of ``aval``.

  Anything whose exact type is not ``Zero`` is returned untouched. For a
  ``Zero``, its own ``aval`` must equal the given ``aval`` (checked with an
  assert) before the zeros are built.
  """
  if type(tangent) is Zero:
    assert tangent.aval == aval
    tangent = zeros_like_aval(aval)
  return tangent
Exemplo n.º 10
0
def zeros_like_array(x):
    """Build zeros with the shape of ``x`` and its canonicalized dtype.

    The dtype is ``dtypes.result_type(x)`` passed through
    ``dtypes.canonicalize_dtype``; the shape is ``np.shape(x)``.
    """
    canonical = dtypes.canonicalize_dtype(dtypes.result_type(x))
    return ad_util.zeros_like_aval(ShapedArray(np.shape(x), canonical))
Exemplo n.º 11
0
def _zeros_like_python_scalar(t, x):
    """Build a weakly-typed scalar zero for the Python scalar type ``t``.

    The dtype is looked up in ``dtypes.python_scalar_dtypes`` by ``t`` and
    then canonicalized. ``x`` is accepted but not used.
    """
    scalar_dtype = dtypes.canonicalize_dtype(dtypes.python_scalar_dtypes[t])
    scalar_aval = core.ShapedArray((), scalar_dtype, weak_type=True)
    return ad_util.zeros_like_aval(scalar_aval)
Exemplo n.º 12
0
def zeros_like_array(x):
    """Build zeros with the shape, canonical dtype and weak-type flag of ``x``.

    The dtype and weak-type flag come from ``dtypes._lattice_result_type(x)``;
    the dtype is canonicalized before the abstract value is constructed.
    """
    raw_dtype, is_weak = dtypes._lattice_result_type(x)
    canonical = dtypes.canonicalize_dtype(raw_dtype)
    return ad_util.zeros_like_aval(
        ShapedArray(np.shape(x), canonical, weak_type=is_weak))
Exemplo n.º 13
0
def _zeros_like_python_scalar(t, x):
    """Build a weakly-typed scalar zero for the Python scalar type ``t``.

    The dtype is taken directly from ``dtypes.python_scalar_dtypes[t]``.
    ``x`` is accepted but not used.
    """
    scalar_aval = core.ShapedArray(
        (),
        dtypes.python_scalar_dtypes[t],
        weak_type=True,
    )
    return ad_util.zeros_like_aval(scalar_aval)