def ildj_rule(incells, outcells, *, in_tree, out_tree, num_consts, **_):
  """Update input/output cells of a custom-inverse call in either direction.

  Three cases are handled:

  * Backward: every output cell is known (``top()``) and at least one
    non-constant input cell is not. The user-supplied ``f_ildj`` is called
    with the known outputs and their ILDJ terms. Its results replace the
    unknown input cells.
  * Forward: every non-constant input cell is known and at least one output
    cell is not. ``self`` is called on the inputs and each output leaf is
    wrapped with ``InverseAndILDJ.new``.
  * Otherwise: the cells are returned unchanged.

  ``f_ildj`` and ``self`` are free names resolved from the surrounding scope.
  Presumably they are the enclosing custom-inverse object and its inverse
  function (NOTE(review): confirm against the enclosing definition).

  Args:
    incells: flat list of input cells. The first ``num_consts`` are constants
      and are always returned unchanged.
    outcells: flat list of output cells.
    in_tree: treedef used to unflatten the non-constant input values.
    out_tree: treedef used to unflatten the output values and ILDJs.
    num_consts: number of leading constant cells in ``incells``.
    **_: any other parameters are accepted and ignored.

  Returns:
    A ``(incells, outcells, None)`` triple containing the (possibly updated)
    cells.
  """
  # NOTE(review): a previous comment here said "First incell is a wrapped
  # function because prim is a call primitive."; the code only splits off the
  # leading `num_consts` constant cells -- confirm which is current.
  const_incells, incells = jax_util.split_list(incells, [num_consts])
  if (all(outcell.top() for outcell in outcells) and
      any(not incell.top() for incell in incells)):
    # Backward direction: all outputs known, some inputs unknown.
    flat_outvals = [outcell.val for outcell in outcells]
    flat_outildjs = [outcell.ildj for outcell in outcells]
    outvals = tree_util.tree_unflatten(out_tree, flat_outvals)
    outildjs = tree_util.tree_unflatten(out_tree, flat_outildjs)
    # Unknown inputs are handed to `f_ildj` as `None`.
    flat_invals = [
        None if not incell.top() else incell.val for incell in incells
    ]
    invals = tree_util.tree_unflatten(in_tree, flat_invals)
    try:
      new_invals, new_ildjs = f_ildj(invals, outvals, outildjs)
    except NonInvertibleError:
      # Inversion is not possible here: leave every cell as it was.
      return const_incells + incells, outcells, None
    # We need to flatten the output from `f_ildj` using
    # `tree_util.tree_flatten` but if the user returns `None` (when
    # inversion is not possible), JAX will remove `None`s from the flattened
    # version and the number of `new_incells` will not match the old
    # `incells`. We use the private `_replace_nones` feature in JAX to
    # replace it with a sentinel that won't be removed when flattening.
    none_ = object()
    new_invals = tree_util._replace_nones(none_, new_invals)  # pylint: disable=protected-access
    new_ildjs = tree_util._replace_nones(none_, new_ildjs)  # pylint: disable=protected-access
    new_flat_invals = tree_util.tree_leaves(new_invals)
    new_flat_ildjs = tree_util.tree_leaves(new_ildjs)
    # A slice is built for every position, including sentinel ones. The
    # sentinel slices are never used because of the check in the loop below.
    inslices = [
        NDSlice.new(inval, ildj)
        for inval, ildj in zip(new_flat_invals, new_flat_ildjs)
    ]
    new_incells = []
    # NOTE(review): `zip` truncates silently if `f_ildj` returns a structure
    # with a different leaf count than `in_tree`.
    for new_flat_inval, old_incell, inslice in zip(
        new_flat_invals, incells, inslices):
      if new_flat_inval is not none_:
        # `f_ildj` produced a value for this input: replace the cell.
        new_incells.append(
            InverseAndILDJ(old_incell.aval, [inslice]))
      else:
        # `f_ildj` returned None for this input: keep the existing cell.
        new_incells.append(old_incell)
    return const_incells + new_incells, outcells, None
  elif (all(incell.top() for incell in incells) and
        any(not outcell.top() for outcell in outcells)):
    # Forward direction: all inputs known, some outputs unknown. Evaluate the
    # function itself and wrap each output leaf in a fresh cell.
    flat_invals = [incell.val for incell in incells]
    invals = tree_util.tree_unflatten(in_tree, flat_invals)
    outvals = self(*invals)
    flat_outvals = tree_util.tree_leaves(outvals)
    outcells = [
        InverseAndILDJ.new(outval) for outval in flat_outvals
    ]
    return const_incells + incells, outcells, None
  # Nothing can be inferred in either direction.
  return const_incells + incells, outcells, None
def apply_flat_fun_nokwargs(fun, io_tree, py_args):
  """Call a leaf-level function on a pytree and rebuild its pytree output.

  Args:
    fun: callable that takes the flattened leaves of ``py_args`` positionally
      and returns a flat sequence of output leaves.
    io_tree: ``(in_tree, out_tree)`` pair of treedefs. ``py_args`` must
      flatten to ``in_tree``, and ``out_tree`` is used to unflatten the
      result of ``fun``.
    py_args: pytree of arguments to flatten before calling ``fun``.

  Returns:
    ``fun``'s output unflattened with ``out_tree``.

  Raises:
    TypeError: if ``py_args`` does not have the structure ``in_tree``.
  """
  expected_in, result_tree = io_tree
  leaves, actual_in = tree_flatten(py_args)
  if actual_in == expected_in:
    return tree_unflatten(result_tree, fun(*leaves))
  raise TypeError(f"Expected {expected_in}, got {actual_in}")
def apply_flat_fun(fun, io_tree, *py_args):
  """Call a leaf-level function on positional pytree args and rebuild output.

  The positional arguments are flattened as the pair ``(py_args, {})``, so
  ``in_tree`` is expected to describe an ``(args, kwargs)`` pair whose kwargs
  part is empty.

  Args:
    fun: callable that takes the flattened leaves positionally and returns a
      flat sequence of output leaves.
    io_tree: ``(in_tree, out_tree)`` pair of treedefs.
    *py_args: positional pytree arguments.

  Returns:
    ``fun``'s output unflattened with ``out_tree``.

  Raises:
    TypeError: if ``(py_args, {})`` does not have the structure ``in_tree``.
  """
  expected_in, result_tree = io_tree
  leaves, actual_in = tree_flatten((py_args, {}))
  if actual_in == expected_in:
    flat_result = fun(*leaves)
    return tree_unflatten(result_tree, flat_result)
  raise TypeError(f"Expected {expected_in}, got {actual_in}")
def dex_fun(*args, **kwargs):
  """Stage ``fun`` out through ``make_jaxpr`` and bind it via ``dex_call_p``.

  ``fun``, ``abstractify``, ``make_jaxpr`` and ``dex_call_p`` are resolved
  from the surrounding scope. The result leaves of the bind are unflattened
  with the output treedef returned by ``make_jaxpr``.
  """
  flat_args, args_tree = tree_flatten((args, kwargs))
  avals, avals_tree, keep = abstractify(args, kwargs)
  # Internal invariant: abstractify must see the same (args, kwargs) structure.
  assert args_tree == avals_tree
  jaxpr, consts, result_tree = make_jaxpr(
      fun, args_tree, tuple(avals), tuple(keep))
  # Constants are passed ahead of the flattened user arguments.
  flat_results = dex_call_p.bind(*consts, *flat_args, jaxpr=jaxpr)
  return tree_unflatten(result_tree, flat_results)
def flatten_fun_nokwargs2(in_tree, *args_flat):
  """Generator transformation for a function that returns ``(ans, aux)``.

  First yields the positional arguments rebuilt from ``args_flat`` with
  ``in_tree``, together with an empty kwargs dict. It then receives the
  wrapped function's result, which must be a two-element list or tuple
  ``(ans, aux)``. Finally it yields ``((ans_leaves, aux_leaves),
  (ans_treedef, aux_treedef))``.

  Raises:
    TypeError: if the received result is not a two-element list or tuple.
  """
  result = yield tree_unflatten(in_tree, args_flat), {}
  is_pair = isinstance(result, (list, tuple)) and len(result) == 2
  if not is_pair:
    raise TypeError("expected function with aux output to return a two-element "
                    f"tuple, but got type {type(result)} with value {repr(result)}")
  main_out, aux_out = result
  main_leaves, main_tree = tree_flatten(main_out)
  aux_leaves, aux_tree = tree_flatten(aux_out)
  yield (main_leaves, aux_leaves), (main_tree, aux_tree)
def testKdePyTree(self):
  # A gaussian_kde should round-trip through tree_flatten/tree_unflatten and
  # be accepted as an argument to a jitted function.
  rng = jtu.rand_default(self.rng())
  # Draw order matters for reproducibility: dataset first, then query points.
  dataset = rng((3, 15), np.float32)
  x = rng((3, 12), np.float32)
  kde = lsp_stats.gaussian_kde(dataset)

  # Flatten and rebuild, then compare the two objects leaf by leaf.
  leaves, treedef = tree_util.tree_flatten(kde)
  rebuilt = tree_util.tree_unflatten(treedef, leaves)
  tree_util.tree_map(lambda a, b: self.assertAllClose(a, b), kde, rebuilt)

  # Evaluating under jit should match eager evaluation.
  jitted_evaluate = jax.jit(lambda k, pts: k.evaluate(pts))
  self.assertAllClose(jitted_evaluate(kde, x), kde.evaluate(x))
def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False):
  """Expand a prefix axis-spec tree to one entry per leaf of ``treedef``.

  Given an axis spec tree ``axis_tree`` (a pytree with integers and Nones at
  the leaves, where the Nones are to be considered leaves) that is a tree
  prefix of ``treedef``, build a complete axis spec tree with the same
  structure as ``treedef`` and return the flattened result.

  Args:
    name: name of the spec. It is used only to build error messages.
    treedef: treedef of the value the spec applies to.
    axis_tree: the (prefix) axis specification.
    kws: only affects the error path. If True, ``treedef`` is split into its
      two children (the second is asserted to be a leaf) and ``axis_tree`` is
      unpacked as a pair. Only the first of each is kept, so the message
      refers to the positional part.
    tupled_args: only affects the error path. If True, append a hint about
      wrapping the spec in a tuple. If the positional tree has a single
      child and a singleton-tuple spec would succeed, add a hint suggesting
      that as well.

  Returns:
    A flat list with exactly ``treedef.num_leaves`` entries (asserted before
    returning). Each entry is the spec leaf covering that value leaf.

  Raises:
    ValueError: if ``axis_tree`` cannot be matched as a prefix of ``treedef``.
  """
  # TODO(mattjj,phawkins): improve this implementation
  # Unique sentinel substituted for the Nones in `axis_tree`, so they are
  # treated as leaves rather than as empty subtrees during the tree_map below.
  proxy = object()
  # A tree with `treedef`'s structure whose leaves are opaque objects. Each
  # spec leaf gets paired with the matching subtree of this tree.
  dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves)
  axes = []
  # For spec leaf `i` matched with subtree `x`, record `i` once per leaf of `x`.
  add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0]))
  try:
    tree_map(add_leaves, _replace_nones(proxy, axis_tree), dummy)
  except ValueError:
    # The structures could not be matched; build a descriptive error message.
    if kws:
      # if keyword arguments are included in the tree, we adapt the error
      # message to be only about the positional arguments
      treedef, leaf = treedef_children(treedef)
      assert treedef_is_leaf(leaf)
      axis_tree, _ = axis_tree
    hint = ""
    if tupled_args:
      hint += (
          f" Note that {name} that are non-trivial pytrees should always be "
          f"wrapped in a tuple representing the argument list.")
      if len(treedef.children()) == 1:
        # Suggest a singleton tuple only if that spec would actually succeed.
        try:
          flatten_axes(name, treedef, (axis_tree, ))
        except ValueError:
          pass  # That's not the issue.
        else:
          hint += (
              f" In particular, you're passing in a single argument which "
              f"means that {name} might need to be wrapped in "
              f"a singleton tuple.")
    raise ValueError(f"{name} specification must be a tree prefix of the "
                     f"corresponding value, got specification {axis_tree} "
                     f"for value tree {treedef}.{hint}") from None
  # Map the sentinel back to None.
  axes = [None if a is proxy else a for a in axes]
  assert len(axes) == treedef.num_leaves
  return axes
def ravel_pytree(pytree):
  """Ravel (i.e. flatten) a pytree of arrays down to a 1D array.

  Args:
    pytree: a pytree of arrays and scalars to ravel.

  Returns:
    A pair ``(flat, unravel_pytree)``:

    * ``flat`` is a 1D array holding the concatenated leaf values. Its dtype
      is obtained by promoting the dtypes of the leaves.
    * ``unravel_pytree`` is a callable that maps a 1D vector of the same
      length back to a pytree with the same structure as ``pytree``.

    If the input pytree is empty (i.e. has no leaves), then by convention a
    1D empty array of dtype float32 is returned as ``flat``.

    For details on dtype promotion, see
    https://jax.readthedocs.io/en/latest/type_promotion.html.
  """
  leaves, treedef = tree_flatten(pytree)
  flat, unravel_list = _ravel_list(leaves)

  def unravel_pytree(flat):
    # Split the vector back into per-leaf arrays, then restore the structure.
    return tree_unflatten(treedef, unravel_list(flat))

  return flat, unravel_pytree
def flatten_fun_nokwargs(in_tree, *args_flat):
  """Generator adapting a pytree-in/pytree-out function to flat leaves.

  First yields the positional arguments rebuilt from ``args_flat`` with
  ``in_tree``, together with an empty kwargs dict. It then receives the
  function's result and yields that result's ``(leaves, treedef)``.
  """
  result = yield tree_unflatten(in_tree, args_flat), {}
  yield tree_flatten(result)
def flatten_fun(in_tree, *args_flat):
  """Generator adapting an ``(args, kwargs)`` pytree function to flat leaves.

  ``in_tree`` must unflatten to a two-element ``(args, kwargs)`` structure.
  The generator yields that pair, receives the function's result, and yields
  the result's ``(leaves, treedef)``.
  """
  positional, keyword = tree_unflatten(in_tree, args_flat)
  result = yield positional, keyword
  yield tree_flatten(result)
def flatten_fun_for_vmap(in_tree, *args_flat):
  """Like a plain ``(args, kwargs)`` flattening generator, but vmap-aware.

  ``in_tree`` must unflatten to a two-element ``(args, kwargs)`` structure,
  which is yielded as a pair. The received result is flattened with
  ``is_leaf=is_vmappable``, so any subtree for which ``is_vmappable`` returns
  true is kept as a single leaf.
  """
  positional, keyword = tree_unflatten(in_tree, args_flat)
  result = yield positional, keyword
  yield tree_flatten(result, is_leaf=is_vmappable)