def _trace_to_jaxpr_with_refs(
    f, state_tree: PyTreeDef, state_avals: Sequence[core.AbstractValue]
) -> Tuple[core.Jaxpr, List[Any], PyTreeDef]:
  """Trace ``f`` to a jaxpr after flattening it over ``(leaf, state_tree)``.

  Args:
    f: callable to trace. It is wrapped with ``lu.wrap_init`` and flattened
      against a two-element tuple input tree whose first entry is a single
      leaf and whose second entry is ``state_tree``.
    state_tree: pytree structure of the second positional input.
    state_avals: abstract values passed to ``pe.trace_to_jaxpr_dynamic`` as
      the inputs of the flattened function.

  Returns:
    A ``(jaxpr, consts, out_tree)`` triple, where ``out_tree`` is the output
    pytree structure recorded while tracing.
  """
  wrapped = lu.wrap_init(f)
  in_tree = treedef_tuple((tree_structure(0), state_tree))
  flat_fun, out_tree_thunk = flatten_fun_nokwargs(wrapped, in_tree)
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, state_avals)
  # The thunk is only called after tracing has run the flattened function.
  return jaxpr, consts, out_tree_thunk()
def custom_transpose_transpose_rule(cts, *args, out_types, res_tree, lin_tree,
                                    out_tree, **params):
  """Transpose rule that delegates to the user-supplied transpose function.

  Args:
    cts: flat cotangents of the call outputs; ``ad_util.Zero`` entries are
      replaced with ``ad_util.zeros_like_aval`` of their aval.
    *args: flat call operands, laid out as ``(res_tree, lin_tree)``.
    out_types: accepted for interface compatibility; not read here.
    res_tree: pytree structure of the residual operands.
    lin_tree: pytree structure of the linear operands.
    out_tree: pytree structure of the call outputs.
    **params: either contains ``'transpose_jaxpr_thunk'`` (from which the
      transpose callable is built) or ``'transpose'`` (the callable itself).

  Returns:
    A flat list with ``None`` for every residual leaf, followed by the
    flattened cotangents for the linear inputs.
  """
  def _is_none(x):
    return x is None

  def _materialize(ct):
    if type(ct) is ad_util.Zero:
      return ad_util.zeros_like_aval(ct.aval)
    return ct

  if 'transpose_jaxpr_thunk' not in params:
    assert 'call' in params
    transpose = params['transpose']
  else:
    assert 'call_jaxpr' in params
    transpose = make_transpose_from_thunk(params['transpose_jaxpr_thunk'],
                                          lin_tree)

  # TODO(frostig,mattjj): `lin_arg` indicates the inputs with respect
  # to which we are transposing (via `ad.is_undefined_primal`).
  # Consider passing this information to the custom transpose rule?
  res_arg, _ = tree_unflatten(treedef_tuple((res_tree, lin_tree)), args)
  assert not any(ad.is_undefined_primal(leaf) for leaf in tree_leaves(res_arg))

  ct_out = tree_unflatten(out_tree, [_materialize(ct) for ct in cts])
  ct_lin = transpose(res_arg, ct_out)
  check_transpose_rule_trees(transpose, lin_tree, tree_structure(ct_lin))

  # `None` is treated as a leaf both when broadcasting against `lin_tree` and
  # when flattening the broadcast result.
  broadcast_ct_lin = tree_broadcast(lin_tree, ct_lin, is_leaf=_is_none)
  ct_lin_flat, _ = tree_flatten(broadcast_ct_lin, is_leaf=_is_none)

  res_cts = [None] * len(tree_leaves(res_arg))
  return res_cts + ct_lin_flat
def testTreedefTupleFromChildren(self):
  """A tuple treedef rebuilt from its children keeps leaf and node counts.

  Regression test for https://github.com/google/jax/issues/7377.
  """
  nested = ((1, 2, (3, 4)), (5,))
  flat, original_def = tree_util.tree_flatten(nested)
  rebuilt_def = tree_util.treedef_tuple(original_def.children())

  self.assertEqual(original_def.num_leaves, len(flat))
  self.assertEqual(original_def.num_leaves, rebuilt_def.num_leaves)
  self.assertEqual(original_def.num_nodes, rebuilt_def.num_nodes)
def testTreedefTupleFromChildren(self):
  """A tuple treedef rebuilt from its children keeps leaf and node counts.

  Regression test for https://github.com/google/jax/issues/7377.
  """
  # TODO(frostig): remove after the minimum jaxlib is 0.1.70 or newer.
  if jax.lib._xla_extension_version < 29:
    self.skipTest("fixed in future jaxlib")

  nested = ((1, 2, (3, 4)), (5,))
  flat, original_def = tree_util.tree_flatten(nested)
  rebuilt_def = tree_util.treedef_tuple(original_def.children())

  self.assertEqual(original_def.num_leaves, len(flat))
  self.assertEqual(original_def.num_leaves, rebuilt_def.num_leaves)
  self.assertEqual(original_def.num_nodes, rebuilt_def.num_nodes)
def custom_transpose_transpose_rule(cts, *args, call, rule, res_tree, lin_tree,
                                    out_tree):
  """Transpose rule that delegates to the user-supplied ``rule``.

  Args:
    cts: flat cotangents of the call outputs. ``ad_util.Zero`` entries are
      replaced with ``ad_util.zeros_like_aval`` of the matching entry of
      ``call.out_avals``.
    *args: flat call operands, laid out as ``(res_tree, lin_tree)``.
    call: the call being transposed; only its ``out_avals`` are read here.
    rule: callable invoked as ``rule(res_arg, ct_out)`` to produce the
      cotangents of the linear inputs.
    res_tree: pytree structure of the residual operands.
    lin_tree: pytree structure of the linear operands.
    out_tree: pytree structure of the call outputs.

  Returns:
    A flat list with ``None`` for every residual leaf, followed by the
    flattened output of ``rule``.
  """
  res_arg, lin_arg = tree_unflatten(treedef_tuple((res_tree, lin_tree)), args)
  # Every linear input must be an undefined primal, and no residual may be.
  assert all(ad.is_undefined_primal(leaf) for leaf in tree_leaves(lin_arg))
  assert not any(ad.is_undefined_primal(leaf) for leaf in tree_leaves(res_arg))

  concrete_cts = []
  for ct, ct_aval in zip(cts, call.out_avals):
    if type(ct) is ad_util.Zero:
      ct = ad_util.zeros_like_aval(ct_aval)
    concrete_cts.append(ct)

  ct_lin = rule(res_arg, tree_unflatten(out_tree, concrete_cts))
  ct_lin_flat, ct_lin_tree = tree_flatten(ct_lin)
  check_transpose_rule_trees(rule, lin_tree, ct_lin_tree)

  res_cts = [None] * len(tree_leaves(res_arg))
  return res_cts + ct_lin_flat
def linear_call(fun: Callable, fun_transpose: Callable, residual_args,
                linear_args):
  """Call a linear function, with a custom implementation for its transpose.

  The type signatures of ``fun`` and ``fun_transpose`` are:

  .. code-block:: haskell

    fun           :: r -> a -o b
    fun_transpose :: r -> b -o a

  where the ``-o`` arrow indicates a linear function, ``r`` is the
  residual input type and ``a`` is the linear input type.

  The functions ``fun`` and ``fun_transpose`` are coupled as
  transposes of one another. Specifically, the transpose of a
  ``linear_call`` primitive is another ``linear_call`` to
  ``fun_transpose``, with ``fun`` as its custom transposition.

  For example:

  >>> def f(r, x):
  ...   return x / r

  >>> def t(r, t):
  ...   return t / r

  >>> def div_add(x, denom):
  ...   return x + linear_call(f, t, denom, x)

  >>> def transpose(f, x_example):
  ...   def transposed(y):
  ...     x, = jax.linear_transpose(f, x_example)(y)
  ...     return x
  ...   return transposed

  >>> div_add(9., 3.)
  DeviceArray(12., dtype=float32, weak_type=True)

  >>> transpose(partial(div_add, denom=3.), 1.)(18.)  # custom
  DeviceArray(24., dtype=float32, weak_type=True)

  >>> transpose(lambda x: x + x / 3., 1.)(18.)  # reference
  DeviceArray(24., dtype=float32, weak_type=True)

  The above definition of ``f`` illustrates the purpose of a residual
  argument: division is linear in one of its inputs (the dividend
  ``x``) but not the other (the divisor ``r``).

  As another example:

  >>> def custom_id(x):
  ...   def f(_, x): return x
  ...   def t(_, t): return 7.
  ...   return linear_call(f, t, (), x)
  >>> custom_id(1.)
  1.0
  >>> transpose(custom_id, 1.)(1.)
  7.0
  >>> transpose(transpose(custom_id, 1.), 1.)(1.)
  1.0
  >>> transpose(transpose(transpose(custom_id, 1.), 1.), 1.)(1.)
  7.0

  Args:
    fun: a Python callable specifying a linear function. It should
      take two arguments: one of "residual" inputs (type ``r``),
      i.e. inputs in which the function is not necessarily linear, and
      one of "linear" inputs (type ``a``). It should return output
      whose components are linear in the linear input (type ``b``).
    fun_transpose: a Python callable specifying a structurally linear
      function that is the transpose of ``fun`` with respect to its
      linear inputs. Its first argument is the same residual inputs
      (``r``) as ``fun``. Its second argument is of type ``b``.
      Finally, its output is of type ``a`` and each of its components
      is linear in its second argument (the ``b`` inputs).
    residual_args: Argument in which ``fun`` and ``fun_transpose`` are
      not necessarily linear. Not involved in transposition.
    linear_args: Argument in which ``fun`` and ``fun_transpose`` are
      linear and with respect to which the two are transposes.

  Returns:
    The call result, i.e. ``fun(residual_args, linear_args)``.

  Raises:
    TypeError: if the output pytree structure of ``fun_transpose`` differs
      from the pytree structure of ``linear_args``.
  """
  def _trace_closed(flat_fun, in_avals):
    # Trace to a jaxpr and close it, handing the constants back separately.
    jaxpr, consts = _initial_style_jaxpr(flat_fun, in_avals)
    return _close_jaxpr(jaxpr), consts

  operands_res, res_tree = tree_flatten(residual_args)
  operands_lin, lin_tree = tree_flatten(linear_args)

  # NOTE(review): `res_avals` is unpacked twice below, so `map` here must
  # return a reusable sequence rather than a one-shot iterator -- presumably
  # the module rebinds `map`; confirm against the file header.
  res_avals = map(abstractify, operands_res)
  lin_avals = map(abstractify, operands_lin)

  # Trace the forward function over (residuals, linear inputs).
  callee_in_tree = treedef_tuple((res_tree, lin_tree))
  flat_callee, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun),
                                               callee_in_tree)
  callee_jaxpr, callee_consts = _trace_closed(flat_callee,
                                              (*res_avals, *lin_avals))
  out_avals = map(core.raise_to_shaped, callee_jaxpr.out_avals)
  # Only meaningful once the forward function has been traced above.
  callee_out_tree = out_tree()

  # Trace the transpose over (residuals, forward outputs).
  transpose_in_tree = treedef_tuple((res_tree, callee_out_tree))
  flat_transpose, t_out_tree = flatten_fun_nokwargs(
      lu.wrap_init(fun_transpose), transpose_in_tree)
  transpose_jaxpr, transpose_consts = _trace_closed(flat_transpose,
                                                    (*res_avals, *out_avals))

  if t_out_tree() != lin_tree:
    raise TypeError(
        'transpose output pytree structure must match that of linear inputs, '
        f'got output structure {t_out_tree()} '
        f'and input structure {lin_tree}.')

  flat_out = linear_call_p.bind(
      *callee_consts, *transpose_consts, *operands_res, *operands_lin,
      callee=callee_jaxpr,
      transpose=transpose_jaxpr,
      num_callee_consts=len(callee_consts),
      num_transpose_consts=len(transpose_consts),
      num_res=len(operands_res))

  return tree_unflatten(callee_out_tree, flat_out)
def custom_vmap_jvp(primals, tangents, *, call, rule, in_tree, out_tree):
  """JVP rule for ``custom_vmap_p``: re-binds the primitive on primals+tangents.

  The JVP of ``call`` is obtained from ``ad.jvp_jaxpr`` and bound through
  ``custom_vmap_p`` again, this time with ``jvp_of_rule_rule`` as the batching
  rule and with input/output trees that pair the original tree with itself
  (one copy for primals, one for tangents).

  Args:
    primals: flat primal operands.
    tangents: flat tangent operands; passed through ``ad.instantiate_zeros``
      before binding.
    call: the call being differentiated, as accepted by ``ad.jvp_jaxpr``.
    rule: the user batching rule for ``call``; invoked via ``call_rule``
      inside the derived rule.
    in_tree: pytree structure of the original inputs.
    out_tree: pytree structure of the original outputs.

  Returns:
    A pair ``(out_primals, out_tangents)`` obtained by splitting the bound
    outputs in half.
  """
  def jvp_of_rule_rule(axis_size, in_batched, primals, tangents):
    # Batching rule for the JVP'd call. `in_batched` is a pair of batched-ness
    # trees, one for the primals and one for the tangents.
    in_batched_ps, in_batched_ts = in_batched

    # Leafwise `and`: inputs whose primal and tangent are both batched.
    mutually_batched = tree_map(operator.and_, in_batched_ps, in_batched_ts)
    # Axis specs (0 or None) marking inputs batched on the primal side only,
    # and on the tangent side only. These become `in_axes` for the outer vmap
    # below.
    extra_batched_ps = tree_map(lambda pb, tb: 0 if pb and not tb else None,
                                in_batched_ps, in_batched_ts)
    extra_batched_ts = tree_map(lambda pb, tb: 0 if tb and not pb else None,
                                in_batched_ps, in_batched_ts)

    # Side channel: filled by `to_jvp` with the rule's reported output
    # batched-ness, read back after the vmap call completes.
    out_mutually_batched = lu.Store()
    flat_ps_ts, tree_ps_ts = tree_flatten((primals, tangents))
    # `None` entries must be kept as leaves so this flattening lines up with
    # the flattening of (primals, tangents) above.
    flat_extra_batched_ps_ts, tree_ps_ts2 = tree_flatten(
        (extra_batched_ps, extra_batched_ts),
        is_leaf=lambda x: x is None)

    # TODO(frostig): assert these also equal:
    #   treedef_tuple((in_tree, in_tree))
    # once https://github.com/google/jax/issues/9066 is fixed
    assert tree_ps_ts == tree_ps_ts2
    del tree_ps_ts2

    def to_jvp(*primals):
      # Applies the user rule using only the mutually-batched flags, checks the
      # structure of what it returned, and records the output batched-ness.
      out, out_batched = call_rule(rule, axis_size, mutually_batched, primals)
      check_vmap_rule_trees(
          rule, out_tree, tree_structure(out), tree_structure(out_batched))
      out_mutually_batched.store(out_batched)
      return out

    def to_vmap_over_extra_batched_dims(primals, tangents):
      return jax.jvp(to_jvp, primals, tangents)

    to_vmap_over_extra_batched_dims_flat, out_tree2 = flatten_fun_nokwargs(
        lu.wrap_init(to_vmap_over_extra_batched_dims),
        tree_ps_ts)

    # vmap the JVP over the inputs that are batched on one side only.
    flat_out_ps_ts, flat_out_axes = vmap_unrestricted(
        to_vmap_over_extra_batched_dims_flat, *flat_ps_ts,
        in_axes=flat_extra_batched_ps_ts,
        axis_name=core.no_axis_name, axis_size=axis_size)

    # The flat outputs are the primal outputs followed by an equal number of
    # tangent outputs; an odd count would be a bug.
    n, ragged = divmod(len(flat_out_ps_ts), 2)
    assert not ragged
    flat_out_ps, flat_out_ts = flat_out_ps_ts[:n], flat_out_ps_ts[n:]
    flat_out_axes_p, flat_out_axes_t = flat_out_axes[:n], flat_out_axes[n:]
    # NOTE(review): `maybe_bdim_at_front` presumably moves any batch dimension
    # to axis 0 -- its definition is not visible here; confirm.
    flat_out_ps = map(maybe_bdim_at_front, flat_out_ps, flat_out_axes_p)
    flat_out_extra_batched_ps = [d is not not_mapped for d in flat_out_axes_p]
    flat_out_ts = map(maybe_bdim_at_front, flat_out_ts, flat_out_axes_t)
    flat_out_extra_batched_ts = [d is not not_mapped for d in flat_out_axes_t]

    out_ps, out_ts = tree_unflatten(
        out_tree2(), [*flat_out_ps, *flat_out_ts])
    out_extra_batched_ps, out_extra_batched_ts = tree_unflatten(
        out_tree2(), [*flat_out_extra_batched_ps, *flat_out_extra_batched_ts])

    # An output counts as batched if the user rule reported it batched or the
    # outer vmap mapped it.
    out_batched_ps = tree_map(
        operator.or_, out_mutually_batched.val, out_extra_batched_ps)
    out_batched_ts = tree_map(
        operator.or_, out_mutually_batched.val, out_extra_batched_ts)

    return (out_ps, out_ts), (out_batched_ps, out_batched_ts)

  tangents = map(ad.instantiate_zeros, tangents)
  jvp_call, _ = ad.jvp_jaxpr(call, [True] * len(primals), True)
  # The JVP'd call takes (primals, tangents) and returns (primals, tangents),
  # so both trees are the original tree paired with itself.
  jvp_in_tree = treedef_tuple((in_tree, in_tree))
  jvp_out_tree = treedef_tuple((out_tree, out_tree))
  outs = custom_vmap_p.bind(
      *primals, *tangents,
      call=jvp_call, rule=jvp_of_rule_rule,
      in_tree=jvp_in_tree, out_tree=jvp_out_tree)
  assert len(outs) % 2 == 0, len(outs)
  out_primals, out_tangents = util.split_list(outs, [len(outs) // 2])
  return out_primals, out_tangents
def testTreedefTupleComparesEqual(self):
  """A treedef built by ``treedef_tuple`` equals the one from ``tree_structure``.

  Regression test for https://github.com/google/jax/issues/9066.
  """
  expected = tree_util.tree_structure((3,))
  leaf_def = tree_util.tree_structure(3)
  actual = tree_util.treedef_tuple((leaf_def,))
  self.assertEqual(expected, actual)
def custom_root(f, initial_guess, solve, tangent_solve, has_aux=False):
  """Differentiably solve for the roots of a function.

  This is a low-level routine, mostly intended for internal use in JAX.
  Gradients of custom_root() are defined with respect to closed-over variables
  from the provided function ``f`` via the implicit function theorem:
  https://en.wikipedia.org/wiki/Implicit_function_theorem

  Args:
    f: function for which to find a root. Should accept a single argument and
      return a tree of arrays with the same structure as its input.
    initial_guess: initial guess for a zero of f.
    solve: function to solve for the roots of f. Should take two positional
      arguments, f and initial_guess, and return a solution with the same
      structure as initial_guess such that ``f(solution) = 0``. In other words,
      the following is assumed to be true (but not checked)::

        solution = solve(f, initial_guess)
        error = f(solution)
        assert all(error == 0)

    tangent_solve: function to solve the tangent system. Should take two
      positional arguments, a linear function ``g`` (the function ``f``
      linearized at its root) and a tree of array(s) ``y`` with the same
      structure as initial_guess, and return a solution ``x`` such that
      ``g(x)=y``:

      - For scalar ``y``, use ``lambda g, y: y / g(1.0)``.
      - For vector ``y``, you could use a linear solve with the Jacobian, if
        dimensionality of ``y`` is not too large:
        ``lambda g, y: np.linalg.solve(jacobian(g)(y), y)``.
    has_aux: bool indicating whether the ``solve`` function returns auxiliary
      data like solver diagnostics as a second argument.

  Returns:
    The result of calling solve(f, initial_guess) with gradients defined via
    implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``.
  """
  # Flatten the guess as a one-element argument tuple so that the same tree can
  # describe the positional arguments of `f` and of `partial(solve, f)`.
  guess_flat, in_args_tree = tree_flatten((initial_guess,))
  guess_avals = tuple(_map(_abstractify, guess_flat))
  in_tree, = treedef_children(in_args_tree)

  # 1. The function whose root is sought.
  f_jaxpr, f_consts, f_out_tree = _initial_style_jaxpr(
      f, in_args_tree, guess_avals)
  _check_tree("f", "initial_guess", f_out_tree, in_tree, False)

  # 2. The forward solver, with `f` already bound as its first argument.
  solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr(
      partial(solve, f), in_args_tree, guess_avals)
  _check_tree("solve", "initial_guess", solution_tree, in_tree, has_aux)

  # 3. The tangent solver, applied to `f` linearized at `x`.
  def linearize_and_solve(x, b):
    _, f_jvp = jax.linearize(f, x)
    return tangent_solve(f_jvp, b)

  pair_tree = treedef_tuple((in_tree,) * 2)
  l_and_s_jaxpr, l_and_s_consts, tangent_out_tree = _initial_style_jaxpr(
      linearize_and_solve, pair_tree, guess_avals * 2)
  _check_tree("tangent_solve", "x", tangent_out_tree, in_tree, False)

  all_consts = [f_consts, solve_consts, l_and_s_consts]
  jaxprs = _RootTuple(f_jaxpr, solve_jaxpr, l_and_s_jaxpr)
  const_lengths = _RootTuple(*_map(len, all_consts))

  flat_operands = _flatten(all_consts) + guess_flat
  solution_flat = _custom_root(const_lengths, jaxprs, *flat_operands)
  return tree_unflatten(solution_tree, solution_flat)