def flatten_fun_nokwargs2(in_tree, *args_flat):
  """Generator adapting a flat-input function whose result is ``(ans, aux)``.

  First yields the positional arguments rebuilt from ``args_flat`` using
  ``in_tree``, together with an empty kwargs dict. It is then sent the wrapped
  function's result, which must be a two-element list or tuple ``(ans, aux)``.
  Finally it yields ``((ans_leaves, aux_leaves), (ans_treedef, aux_treedef))``.

  Raises:
    TypeError: if the value sent back is not a two-element list or tuple.
  """
  result = yield tree_unflatten(in_tree, args_flat), {}
  is_pair = isinstance(result, (list, tuple)) and len(result) == 2
  if not is_pair:
    raise TypeError("expected function with aux output to return a two-element "
                    f"tuple, but got type {type(result)} with value {repr(result)}")
  main_out, aux_out = result
  # Flatten the primary output before the auxiliary one, the same order as
  # the tuple yielded below.
  main_leaves, main_treedef = tree_flatten(main_out)
  aux_leaves, aux_treedef = tree_flatten(aux_out)
  yield (main_leaves, aux_leaves), (main_treedef, aux_treedef)
def apply_flat_fun_nokwargs(fun, io_tree, py_args):
  """Call a flat function on a pytree of positional arguments.

  Args:
    fun: callable taking the flattened leaves of ``py_args`` as positional
      arguments and returning the output leaves, which are rebuilt with
      ``out_tree``.
    io_tree: pair ``(in_tree, out_tree)`` of treedefs.
    py_args: pytree of arguments whose structure must equal ``in_tree``.

  Returns:
    ``fun``'s output leaves unflattened according to ``out_tree``.

  Raises:
    TypeError: if ``py_args`` does not have the expected tree structure.
  """
  expected_tree, out_tree = io_tree
  leaves, actual_tree = tree_flatten(py_args)
  if actual_tree != expected_tree:
    raise TypeError(f"Expected {expected_tree}, got {actual_tree}")
  return tree_unflatten(out_tree, fun(*leaves))
def apply_flat_fun(fun, io_tree, *py_args):
  """Call a flat function on positional arguments, with empty keyword args.

  The arguments are flattened as the pytree ``(py_args, {})``, so
  ``io_tree[0]`` must describe an ``(args, kwargs)`` pair with no kwargs.

  Args:
    fun: callable taking the flattened leaves as positional arguments and
      returning the output leaves, which are rebuilt with ``out_tree``.
    io_tree: pair ``(in_tree, out_tree)`` of treedefs.
    *py_args: positional pytree arguments.

  Returns:
    ``fun``'s output leaves unflattened according to ``out_tree``.

  Raises:
    TypeError: if the arguments do not have the expected tree structure.
  """
  expected_tree, out_tree = io_tree
  leaves, actual_tree = tree_flatten((py_args, {}))
  if actual_tree != expected_tree:
    raise TypeError(f"Expected {expected_tree}, got {actual_tree}")
  outputs = fun(*leaves)
  return tree_unflatten(out_tree, outputs)
def dex_fun(*args, **kwargs):
  """Trace ``fun`` for these arguments and evaluate it through ``dex_call_p``.

  NOTE(review): this is a closure; ``fun``, ``abstractify``, ``make_jaxpr``
  and ``dex_call_p`` come from an enclosing scope not visible here.

  Returns:
    The flat results of ``dex_call_p.bind`` unflattened with the output
    treedef reported by ``make_jaxpr``.
  """
  flat_args, args_tree = tree_flatten((args, kwargs))
  avals, avals_tree, keep = abstractify(args, kwargs)
  # Both flatten the same (args, kwargs) pytree, so the treedefs must agree.
  assert args_tree == avals_tree
  jaxpr, consts, out_tree = make_jaxpr(fun, args_tree, tuple(avals), tuple(keep))
  # Constants are passed ahead of the flattened user arguments.
  results = dex_call_p.bind(*consts, *flat_args, jaxpr=jaxpr)
  return tree_unflatten(out_tree, results)
def testKdePyTree(self):
  """gaussian_kde round-trips through pytree flattening and works under jit."""
  @jax.jit
  def jitted_evaluate(kde, x):
    return kde.evaluate(x)

  rand = jtu.rand_default(self.rng())
  # Draw the dataset before the query points to keep the RNG sequence fixed.
  dataset = rand((3, 15), np.float32)
  x = rand((3, 12), np.float32)
  kde = lsp_stats.gaussian_kde(dataset)

  # Flattening then unflattening must reproduce matching leaves.
  leaves, treedef = tree_util.tree_flatten(kde)
  rebuilt = tree_util.tree_unflatten(treedef, leaves)
  tree_util.tree_map(self.assertAllClose, kde, rebuilt)

  # Passing the KDE object through jit must give the same result as eager.
  self.assertAllClose(jitted_evaluate(kde, x), kde.evaluate(x))
def abstractify(args, kwargs):
  """Build input avals for the flattened ``(args, kwargs)`` pytree.

  NOTE(review): reads ``abstracted_axes`` (and ``map``, ``shaped_abstractify``,
  ``broadcast_prefix``, ``DBIdx``, ``core``) from an enclosing/module scope
  that is not visible here.

  Returns:
    A triple ``(in_avals, in_tree, keep_inputs)`` where ``in_tree`` is the
    treedef of ``(args, kwargs)``.

    * If ``abstracted_axes`` is None: ``in_avals`` is ``shaped_abstractify``
      applied to each flat leaf, and ``keep_inputs`` is all True.
    * Otherwise: ``in_avals`` is one ``DBIdx`` per distinct axis name used in
      the specs, followed by one aval per flat leaf. A leaf with an empty spec
      gets ``shaped_abstractify``. Any other leaf gets a ``core.DShapedArray``
      whose named dimensions are replaced by that name's ``DBIdx``.
      ``keep_inputs`` is False for the leading ``DBIdx`` entries and True for
      the per-leaf avals.
  """
  flat_args, in_tree = tree_flatten((args, kwargs))
  if abstracted_axes is None:
    return map(shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
  else:
    # TODO this is for dynamic shapes, replace w/ utilities in jax/api.py
    # NOTE(review): specs are broadcast against `args` only, but `flat_args`
    # also contains the kwargs leaves. Non-empty kwargs look unsupported on
    # this path (the map below would see mismatched lengths); confirm.
    axes_specs = broadcast_prefix(abstracted_axes, args)
    sizes: Dict[Hashable, int] = {}  # for error checking
    counts = it.count()
    # Lazily assigns a fresh DBIdx (0, 1, 2, ...) to each axis name on first
    # use. NOTE(review): the values are DBIdx objects, so the `int` value
    # annotation looks inaccurate.
    env: Dict[Hashable, int] = defaultdict(lambda: DBIdx(next(counts)))
    def make_aval(arg, spec):
      # `spec` maps dimension index -> axis name; an empty spec means the
      # leaf has no abstracted axes.
      if not spec:
        return shaped_abstractify(arg)
      # Checks that each axis name has a consistent size across leaves.
      # `setdefault` records the first size seen for a name. NOTE: as an
      # `assert`, this check (including the recording) is skipped under -O.
      assert all(arg.shape[i] == sizes.setdefault(name, arg.shape[i])
                 for i, name in spec.items())
      shape = [env[spec[i]] if i in spec else d for i, d in enumerate(arg.shape)]
      return core.DShapedArray(tuple(shape), arg.dtype, False)
    in_avals = map(make_aval, flat_args, axes_specs)
    # NOTE(review): `env` is populated only while `make_aval` runs, so
    # `len(env)` and `env.values()` below are correct only if `map` is eager
    # (e.g. a list-returning safe_map). With the builtin lazy `map`, `env`
    # would still be empty here.
    keep_inputs = [False] * len(env) + [True] * len(flat_args)
    return [*env.values(), *in_avals], in_tree, keep_inputs
def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False):
  """Expand a prefix axis-spec tree to ``treedef`` and return it flattened.

  Args:
    name: label for the spec (e.g. an argument name); used only in error
      messages.
    treedef: treedef of the value the spec applies to.
    axis_tree: pytree with integer/None leaves that must be a tree prefix of
      ``treedef``.
    kws: if True, ``treedef`` and ``axis_tree`` are treated as
      ``(positional, keyword)`` pairs. On error, the message then refers only
      to the positional part.
    tupled_args: if True, adds hints to the error message about wrapping the
      spec in a tuple.

  Returns:
    A list with one entry per leaf of ``treedef`` (``treedef.num_leaves``
    long). Each entry is the axis from the enclosing spec leaf, or None where
    that spec leaf was None.

  Raises:
    ValueError: if ``axis_tree`` is not a tree prefix of ``treedef``.
  """
  # given an axis spec tree axis_tree (a pytree with integers and Nones at the
  # leaves, i.e. the Nones are to be considered leaves) that is a tree prefix of
  # the given treedef, build a complete axis spec tree with the same structure
  # and return the flattened result
  # TODO(mattjj,phawkins): improve this implementation
  # Unique marker standing in for None spec leaves. Presumably
  # `_replace_nones` substitutes it so that Nones behave as leaves.
  proxy = object()
  # A stand-in value with the same structure as `treedef`, so tree_map can
  # match the spec tree against it.
  dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves)
  axes = []
  # For each spec leaf `i`, `x` is the corresponding subtree of `dummy`.
  # Repeat `i` once for every leaf in that subtree.
  add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0]))
  try:
    tree_map(add_leaves, _replace_nones(proxy, axis_tree), dummy)
  except ValueError:
    if kws:
      # if keyword arguments are included in the tree, we adapt the error
      # message to be only about the positional arguments
      treedef, leaf = treedef_children(treedef)
      assert treedef_is_leaf(leaf)
      axis_tree, _ = axis_tree
    hint = ""
    if tupled_args:
      hint += (
          f" Note that {name} that are non-trivial pytrees should always be "
          f"wrapped in a tuple representing the argument list.")
      if len(treedef.children()) == 1:
        # Retry with the spec wrapped in a singleton tuple. If that succeeds,
        # the missing wrapper was the likely mistake.
        try:
          flatten_axes(name, treedef, (axis_tree, ))
        except ValueError:
          pass  # That's not the issue.
        else:
          hint += (
              f" In particular, you're passing in a single argument which "
              f"means that {name} might need to be wrapped in "
              f"a singleton tuple.")
    raise ValueError(f"{name} specification must be a tree prefix of the "
                     f"corresponding value, got specification {axis_tree} "
                     f"for value tree {treedef}.{hint}") from None
  # Map the None-marker back to None.
  axes = [None if a is proxy else a for a in axes]
  assert len(axes) == treedef.num_leaves
  return axes
def ravel_pytree(pytree):
  """Ravel (i.e. flatten) a pytree of arrays down to a 1D array.

  Args:
    pytree: a pytree of arrays and scalars to ravel.

  Returns:
    A pair ``(flat, unravel)``:

    * ``flat`` is a 1D array holding the flattened and concatenated leaf
      values. Its dtype is determined by promoting the dtypes of the leaf
      values.
    * ``unravel`` is a callable that maps a 1D vector of the same length back
      to a pytree with the same structure as the input ``pytree``.

    If the input pytree is empty (i.e. has no leaves), then as a convention a
    1D empty array of dtype float32 is returned as ``flat``.

    For details on dtype promotion, see
    https://jax.readthedocs.io/en/latest/type_promotion.html.
  """
  leaves, structure = tree_flatten(pytree)
  flat, unravel_leaves = _ravel_list(leaves)
  return flat, lambda vec: tree_unflatten(structure, unravel_leaves(vec))
def flatten_fun_nokwargs(in_tree, *args_flat):
  """Generator adapting a flat-input function that takes positional args only.

  Yields the positional arguments rebuilt from ``args_flat`` using
  ``in_tree``, together with an empty kwargs dict. It is then sent the wrapped
  function's result, and yields ``tree_flatten`` of that result as
  ``(leaves, treedef)``.
  """
  result = yield tree_unflatten(in_tree, args_flat), {}
  yield tree_flatten(result)
def flatten_fun(in_tree, *args_flat):
  """Generator adapting a flat-input function that takes args and kwargs.

  ``in_tree`` must describe an ``(args, kwargs)`` pair. The generator yields
  that pair rebuilt from ``args_flat``. It is then sent the wrapped function's
  result, and yields ``tree_flatten`` of that result as ``(leaves, treedef)``.
  """
  unflattened = tree_unflatten(in_tree, args_flat)
  positional, keyword = unflattened
  result = yield positional, keyword
  yield tree_flatten(result)
def flatten_fun_for_vmap(in_tree, *args_flat):
  """Like ``flatten_fun``, but flattens outputs treating vmappables as leaves.

  ``in_tree`` must describe an ``(args, kwargs)`` pair. The generator yields
  that pair rebuilt from ``args_flat``. It is then sent the wrapped function's
  result and yields ``(leaves, treedef)`` for it, flattened with
  ``is_leaf=is_vmappable``.
  """
  positional, keyword = tree_unflatten(in_tree, args_flat)
  result = yield positional, keyword
  out_leaves, out_tree = tree_flatten(result, is_leaf=is_vmappable)
  yield out_leaves, out_tree