def _flatten_bwd(in_tree, in_avals, out_trees, *args):
  """Run a user's pytree-level custom VJP backward rule on flat arguments.

  This is a generator with a send/receive protocol. It first yields
  ``((py_res, py_cts_out), {})``, which holds the residuals and output
  cotangents rebuilt as pytrees. It expects the rule's result (a tuple with
  one entry per primal argument, ``None`` meaning "no cotangent") to be sent
  back in. It then yields a flat list of input cotangents aligned with
  ``in_avals``.

  Args:
    in_tree: treedef of the primal function's argument tuple.
    in_avals: abstract values of the flat primal inputs; used to build zeros
      for arguments whose cotangent was reported as ``None``.
    out_trees: thunk returning ``(out_tree, res_tree)``, the treedefs of the
      primal outputs and of the residuals.
    *args: flat residuals followed by flat output cotangents.

  Raises:
    TypeError: if the value sent back is not a tuple, or if ``tree_multimap``
      rejects its structure against ``in_tree`` (signalled by ValueError).
  """
  out_tree, res_tree = out_trees()
  res_flat, cts_out_flat = split_list(args, [res_tree.num_leaves])
  py_res = tree_unflatten(res_tree, res_flat)
  py_cts_out = tree_unflatten(out_tree, cts_out_flat)
  py_cts_in = yield (py_res, py_cts_out), {}

  # `None` entries are swapped for a private sentinel. The sentinel is not a
  # pytree node, so tree_multimap pairs it with the whole matching subtree of
  # `template` (a copy of in_tree's structure) and it is repeated once per
  # leaf. Sentinel slots become zeros at the end.
  missing = object()
  template = tree_unflatten(in_tree, [object()] * in_tree.num_leaves)
  cts_in_flat = []

  def _expand(ct, subtree):
    # Append `ct` once for every leaf of the paired template subtree.
    cts_in_flat.extend([ct] * len(tree_flatten(subtree)[0]))

  try:
    if not isinstance(py_cts_in, tuple):
      raise ValueError
    tagged = tuple(missing if ct is None else ct for ct in py_cts_in)
    tree_multimap(_expand, tagged, template)
  except ValueError:
    _, in_tree2 = tree_flatten(py_cts_in)
    msg = ("Custom VJP rule must produce an output with the same container "
           "(pytree) structure as the args tuple of the primal function, "
           "and in particular must produce a tuple of length equal to the "
           "number of arguments to the primal function, but got VJP output "
           "structure {} for primal input structure {}.")
    raise TypeError(msg.format(in_tree2, in_tree)) from None

  cts_in = []
  for aval, ct in zip(in_avals, cts_in_flat):
    if ct is missing:
      ct = zeros_like_aval(aval.at_least_vspace())
    cts_in.append(ct)
  yield cts_in
def _linear_call_transpose_rule(cts, *args, callee, transpose,
                                num_callee_consts, num_transpose_consts,
                                num_res):
  """Transpose a `linear_call_p` application by swapping callee and transpose.

  ``args`` is laid out as callee consts, transpose consts, residuals, then
  linear operands. Every linear operand must be an undefined primal and
  every residual must be a known value (checked by assertions). Symbolic
  ``Zero`` cotangents are instantiated using the avals of ``transpose``'s
  cotangent inputs. ``linear_call_p`` is then bound again with ``callee`` and
  ``transpose`` (and their consts) exchanged.

  Returns:
    ``None`` for every const and residual input, followed by the outputs of
    the re-bound ``linear_call_p`` as the linear operands' cotangents.
  """
  (callee_consts, transpose_consts,
   residuals, linear_operands) = split_list(
       args, [num_callee_consts, num_transpose_consts, num_res])
  # `transpose` takes (its consts, residuals, cotangents); keep the last group.
  _tconst_avals, _res_avals, ct_avals = split_list(
      transpose.in_avals, [num_transpose_consts, num_res])

  assert all(map(ad.is_undefined_primal, linear_operands))
  assert not any(map(ad.is_undefined_primal, residuals))

  dense_cts = []
  for ct, ct_aval in zip(cts, ct_avals):
    dense_cts.append(zeros_like_aval(ct_aval) if type(ct) is Zero else ct)

  transposed_out = linear_call_p.bind(
      *transpose_consts, *callee_consts, *residuals, *dense_cts,
      callee=transpose,
      transpose=callee,
      num_callee_consts=len(transpose_consts),
      num_transpose_consts=len(callee_consts),
      num_res=len(residuals))

  num_nonlinear_inputs = num_callee_consts + num_transpose_consts + num_res
  return [None] * num_nonlinear_inputs + transposed_out
def custom_transpose_transpose_rule(
    cts, *args, out_types, res_tree, lin_tree, out_tree, **params):
  """Transpose rule that delegates to a user-supplied custom transpose.

  ``args`` is the flattened pair ``(res_arg, lin_arg)`` described by
  ``res_tree`` and ``lin_tree``. The transpose callable is built from
  ``params['transpose_jaxpr_thunk']`` when that key is present; otherwise it
  is ``params['transpose']``. Symbolic ``Zero`` cotangents are instantiated
  from their own avals. The callable is then invoked as
  ``transpose(res_arg, ct_out)`` and its output structure is checked against
  ``lin_tree``.

  Returns:
    ``None`` for every residual leaf, followed by the flat linear-input
    cotangents. The rule's output is expanded to the full structure of
    ``lin_tree`` via ``tree_broadcast``, with ``None`` treated as a leaf.
  """
  # `out_types` is part of the signature but is not used in this rule.
  if 'transpose_jaxpr_thunk' not in params:
    assert 'call' in params
    transpose = params['transpose']
  else:
    assert 'call_jaxpr' in params
    transpose = make_transpose_from_thunk(
        params['transpose_jaxpr_thunk'], lin_tree)

  # TODO(frostig,mattjj): `lin_arg` indicates the inputs with respect
  # to which we are transposing (via `ad.is_undefined_primal`).
  # Consider passing this information to the custom transpose rule?
  res_arg, _ = tree_unflatten(treedef_tuple((res_tree, lin_tree)), args)
  res_leaves = tree_leaves(res_arg)
  assert not any(map(ad.is_undefined_primal, res_leaves))

  dense_cts = [
      ad_util.zeros_like_aval(ct.aval) if type(ct) is ad_util.Zero else ct
      for ct in cts
  ]
  ct_lin = transpose(res_arg, tree_unflatten(out_tree, dense_cts))
  check_transpose_rule_trees(transpose, lin_tree, tree_structure(ct_lin))

  def is_none(x):
    return x is None

  ct_lin_full = tree_broadcast(lin_tree, ct_lin, is_leaf=is_none)
  ct_lin_flat, _ = tree_flatten(ct_lin_full, is_leaf=is_none)
  return [None] * len(res_leaves) + ct_lin_flat
def fix_float0(arg_jax, ct_arg_jax):
  """Make a cotangent agree with the tangent dtype implied by its primal.

  The tangent dtype is derived from ``arg_jax``'s dtype via
  ``core.primal_dtype_to_tangent_dtype``. When it differs from
  ``ct_arg_jax.dtype``, a zeros value of that tangent dtype with
  ``arg_jax``'s shape is returned. Otherwise ``ct_arg_jax`` is returned
  unchanged. (Presumably this targets primals whose tangent dtype is float0,
  per the name; confirm.)
  """
  # `arg_jax` may be a scalar, hence result_type / np.shape rather than
  # reading `.dtype` / `.shape` attributes.
  primal_dtype = dtypes.result_type(arg_jax)
  tangent_dtype = core.primal_dtype_to_tangent_dtype(primal_dtype)
  if tangent_dtype != ct_arg_jax.dtype:
    zeros_aval = core.ShapedArray(np.shape(arg_jax), tangent_dtype)
    return ad_util.zeros_like_aval(zeros_aval)
  return ct_arg_jax
def instantiate_zeros(tangent):
  """Replace a symbolic ``Zero`` tangent with a concrete zeros value.

  Non-``Zero`` tangents are returned unchanged. If the ``Zero``'s ``aval`` is a
  ``Tracer``, that tracer is returned as-is instead of building zeros.
  """
  if type(tangent) is not Zero:
    return tangent
  aval = tangent.aval
  if isinstance(aval, Tracer):
    # NOTE(review): returns the Tracer stored in `aval` directly; the reason a
    # Zero carries a Tracer here is not visible in this block; confirm.
    return aval
  return zeros_like_aval(aval)
def custom_transpose_transpose_rule(
    cts, *args, call, rule, res_tree, lin_tree, out_tree):
  """Transpose rule that delegates to the user-supplied ``rule``.

  ``args`` is the flattened pair ``(res_arg, lin_arg)`` described by
  ``res_tree`` and ``lin_tree``. All ``lin_arg`` leaves must be undefined
  primals and no ``res_arg`` leaf may be (checked by assertions). Symbolic
  ``Zero`` cotangents are instantiated from ``call.out_avals``. ``rule`` is
  invoked as ``rule(res_arg, ct_out)`` and its output structure is checked
  against ``lin_tree``.

  Returns:
    ``None`` for every residual leaf, followed by the flat cotangents that
    ``rule`` produced.
  """
  res_arg, lin_arg = tree_unflatten(treedef_tuple((res_tree, lin_tree)), args)
  assert all(map(ad.is_undefined_primal, tree_leaves(lin_arg)))
  assert not any(map(ad.is_undefined_primal, tree_leaves(res_arg)))

  dense_cts = []
  for ct, ct_aval in zip(cts, call.out_avals):
    if type(ct) is ad_util.Zero:
      ct = ad_util.zeros_like_aval(ct_aval)
    dense_cts.append(ct)

  ct_lin = rule(res_arg, tree_unflatten(out_tree, dense_cts))
  ct_lin_flat, ct_lin_tree = tree_flatten(ct_lin)
  check_transpose_rule_trees(rule, lin_tree, ct_lin_tree)
  num_res_leaves = len(tree_leaves(res_arg))
  return [None] * num_res_leaves + ct_lin_flat
def instantiate_zeros_aval(aval, tangent):
  """Replace a symbolic ``Zero`` tangent with concrete zeros of ``aval``.

  Non-``Zero`` tangents are returned unchanged. A ``Zero``'s own aval must
  either be an ``AbstractUnit`` or equal ``aval`` (asserted).
  """
  if type(tangent) is not Zero:
    return tangent
  # Unit check first, so `==` is only evaluated for non-unit avals.
  assert type(tangent.aval) is core.AbstractUnit or tangent.aval == aval
  return zeros_like_aval(aval)
def instantiate_zeros(tangent):
  """Return concrete zeros for a symbolic ``Zero``; pass other tangents through.

  The zeros are built from the ``Zero``'s own ``aval``.
  """
  return zeros_like_aval(tangent.aval) if type(tangent) is Zero else tangent
def instantiate_zeros_aval(aval, tangent):
  """Replace a symbolic ``Zero`` tangent with concrete zeros of ``aval``.

  Non-``Zero`` tangents are returned unchanged. A ``Zero``'s own aval must
  equal ``aval`` (asserted).
  """
  if type(tangent) is not Zero:
    return tangent
  assert tangent.aval == aval
  return zeros_like_aval(aval)
def zeros_like_array(x):
  """Return zeros with ``x``'s shape and canonicalized result dtype.

  The aval is built without a ``weak_type`` argument.
  """
  canonical_dtype = dtypes.canonicalize_dtype(dtypes.result_type(x))
  shape = np.shape(x)
  return ad_util.zeros_like_aval(ShapedArray(shape, canonical_dtype))
def _zeros_like_python_scalar(t, x):
  """Return a weakly-typed scalar zero for Python scalar type ``t``.

  The dtype is looked up in ``dtypes.python_scalar_dtypes`` and then
  canonicalized. ``x`` is not read; presumably it exists to match a handler
  signature elsewhere (confirm).
  """
  scalar_dtype = dtypes.python_scalar_dtypes[t]
  zero_aval = core.ShapedArray(
      (), dtypes.canonicalize_dtype(scalar_dtype), weak_type=True)
  return ad_util.zeros_like_aval(zero_aval)
def zeros_like_array(x):
  """Return zeros with ``x``'s shape, canonicalized dtype, and weak-typedness.

  Both the dtype and the weak-type flag come from
  ``dtypes._lattice_result_type(x)``. The dtype is then canonicalized.
  """
  lattice_dtype, is_weak = dtypes._lattice_result_type(x)
  zero_aval = ShapedArray(np.shape(x), dtypes.canonicalize_dtype(lattice_dtype),
                          weak_type=is_weak)
  return ad_util.zeros_like_aval(zero_aval)
def _zeros_like_python_scalar(t, x):
  """Return a weakly-typed scalar zero for Python scalar type ``t``.

  The dtype is taken directly from ``dtypes.python_scalar_dtypes`` (no
  explicit canonicalization in this block). ``x`` is not read; presumably it
  exists to match a handler signature elsewhere (confirm).
  """
  return ad_util.zeros_like_aval(
      core.ShapedArray((), dtypes.python_scalar_dtypes[t], weak_type=True))