Exemplo n.º 1
0
def flatten_fun_nokwargs2(in_tree, *args_flat):
  """Generator adapting a flat-args function that returns ``(ans, aux)``.

  First yield: the positional args rebuilt from ``args_flat`` via ``in_tree``,
  paired with an empty kwargs dict. The value sent back must be a two-element
  list or tuple ``(ans, aux)``; anything else raises ``TypeError``.
  Second yield: ``((ans_leaves, aux_leaves), (ans_treedef, aux_treedef))``.
  """
  result = yield tree_unflatten(in_tree, args_flat), {}
  is_pair = isinstance(result, (list, tuple)) and len(result) == 2
  if not is_pair:
    raise TypeError("expected function with aux output to return a two-element "
                    f"tuple, but got type {type(result)} with value {repr(result)}")
  main_out, aux_out = result
  main_leaves, main_treedef = tree_flatten(main_out)
  aux_leaves, aux_treedef = tree_flatten(aux_out)
  yield (main_leaves, aux_leaves), (main_treedef, aux_treedef)
Exemplo n.º 2
0
def apply_flat_fun_nokwargs(fun, io_tree, py_args):
  """Call flat function ``fun`` on pytree ``py_args``; rebuild its output.

  ``io_tree`` is a pair ``(expected_in_treedef, out_treedef)``. ``py_args`` is
  flattened and its treedef must equal the expected one, otherwise
  ``TypeError`` is raised. ``fun`` receives the leaves positionally, and its
  result is unflattened with ``out_treedef``.
  """
  expected_tree, result_tree = io_tree
  leaves, actual_tree = tree_flatten(py_args)
  if actual_tree != expected_tree:
    raise TypeError(f"Expected {expected_tree}, got {actual_tree}")
  return tree_unflatten(result_tree, fun(*leaves))
Exemplo n.º 3
0
def apply_flat_fun(fun, io_tree, *py_args):
    """Call flat function ``fun`` on positional ``py_args``; rebuild its output.

    The arguments are flattened as the pair ``(py_args, {})`` (i.e. with an
    empty kwargs dict). Its treedef must equal the first element of
    ``io_tree``, otherwise ``TypeError`` is raised. The leaves are passed to
    ``fun`` positionally, and the result is unflattened with the second
    element of ``io_tree``.
    """
    expected_tree, result_tree = io_tree
    leaves, actual_tree = tree_flatten((py_args, {}))
    if actual_tree != expected_tree:
        raise TypeError(f"Expected {expected_tree}, got {actual_tree}")
    flat_result = fun(*leaves)
    return tree_unflatten(result_tree, flat_result)
Exemplo n.º 4
0
 def dex_fun(*args, **kwargs):
   """Trace ``fun`` for these inputs and bind ``dex_call_p`` on the result.

   The call's ``(args, kwargs)`` are flattened, abstract input values are
   computed by ``abstractify``, and ``make_jaxpr`` is asked for a jaxpr,
   its constants and the output treedef. The primitive is bound on the
   constants followed by the flat arguments, and the flat outputs are
   rebuilt into a pytree.
   """
   leaves, tree = tree_flatten((args, kwargs))
   avals, abstract_tree, keep = abstractify(args, kwargs)
   # Both flattenings must agree on structure.
   assert tree == abstract_tree
   jaxpr, consts, out_tree = make_jaxpr(
       fun, tree, tuple(avals), tuple(keep))
   outs = dex_call_p.bind(*consts, *leaves, jaxpr=jaxpr)
   return tree_unflatten(out_tree, outs)
Exemplo n.º 5
0
  def testKdePyTree(self):
    """gaussian_kde round-trips through tree_flatten/tree_unflatten and works under jit."""
    @jax.jit
    def evaluate_jit(estimator, points):
      return estimator.evaluate(points)

    dtype = np.float32
    rng = jtu.rand_default(self.rng())
    # Draw the dataset first, then the query points (same RNG order as before).
    samples = rng((3, 15), dtype)
    queries = rng((3, 12), dtype)
    kde = lsp_stats.gaussian_kde(samples)

    leaves, treedef = tree_util.tree_flatten(kde)
    rebuilt = tree_util.tree_unflatten(treedef, leaves)
    # tree_map checks the structures match and compares leaf by leaf.
    tree_util.tree_map(self.assertAllClose, kde, rebuilt)

    expected = kde.evaluate(queries)
    self.assertAllClose(evaluate_jit(kde, queries), expected)
Exemplo n.º 6
0
 def abstractify(args, kwargs):
   """Compute input avals for ``(args, kwargs)`` and which inputs to keep.

   Returns a triple ``(avals, in_tree, keep_inputs)``, where ``in_tree`` is
   the treedef of ``(args, kwargs)``.

   If ``abstracted_axes`` (from the enclosing scope) is None, every flat
   argument gets a ``shaped_abstractify`` aval and is kept.

   Otherwise ``abstracted_axes`` is broadcast over ``args``. Each spec maps
   an axis index to an axis name. Every distinct axis name gets one
   ``DBIdx`` (numbered in first-use order), and those indices appear as
   extra leading avals marked ``False`` in ``keep_inputs``. They are followed
   by one aval per flat argument, each marked ``True``.
   """
   flat_args, in_tree = tree_flatten((args, kwargs))
   if abstracted_axes is None:
     # list(): return a concrete list, like the other branch. This is a
     # harmless copy if `map` is an eager safe_map override in this module.
     avals = list(map(shaped_abstractify, flat_args))
     return avals, in_tree, [True] * len(flat_args)
   else:
     # TODO this is for dynamic shapes, replace w/ utilities in jax/api.py
     # NOTE(review): specs are broadcast over `args` only. If kwargs
     # contribute leaves, flat_args and axes_specs differ in length. Confirm
     # kwargs are unsupported here.
     axes_specs = broadcast_prefix(abstracted_axes, args)
     sizes: Dict[Hashable, int] = {}  # axis name -> first size seen (error checking)
     counts = it.count()
     # Each new axis name is assigned the next DBIdx on first lookup.
     env: Dict[Hashable, DBIdx] = defaultdict(lambda: DBIdx(next(counts)))
     def make_aval(arg, spec):
       if not spec:
         return shaped_abstractify(arg)
       # All axes sharing a name must have the same concrete size.
       for i, name in spec.items():
         first_size = sizes.setdefault(name, arg.shape[i])
         assert arg.shape[i] == first_size, (
             f"axis name {name!r} bound to inconsistent sizes "
             f"{first_size} and {arg.shape[i]}")
       shape = [env[spec[i]] if i in spec else d for i, d in enumerate(arg.shape)]
       return core.DShapedArray(tuple(shape), arg.dtype, False)
     # Must be evaluated eagerly: `env` is only populated while make_aval
     # runs, and len(env) / env.values() are read on the next lines. With a
     # lazy builtin `map`, env would still be empty here.
     in_avals = list(map(make_aval, flat_args, axes_specs))
     keep_inputs = [False] * len(env) + [True] * len(flat_args)
     return [*env.values(), *in_avals], in_tree, keep_inputs
Exemplo n.º 7
0
def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False):
    """Broadcast an axis-spec prefix tree into a flat per-leaf list of axes.

    Args:
      name: label for the spec, interpolated into the error and hint text.
      treedef: treedef of the full value tree the spec applies to.
      axis_tree: a tree prefix of ``treedef`` whose leaves are axis specs
        (integers or ``None``). ``None`` entries are treated as leaves.
      kws: if True, on error ``treedef`` and ``axis_tree`` are unpacked as
        ``(args, kwargs)`` pairs, so the message covers only the positional
        part.
      tupled_args: if True, the error message gets hints about wrapping the
        spec in a tuple representing the argument list.

    Returns:
      A list of ``treedef.num_leaves`` entries. Each entry is the
      ``axis_tree`` leaf that covers the corresponding value leaf, or
      ``None`` where that spec leaf was ``None``. This assumes
      ``_replace_nones`` swaps each ``None`` for ``proxy``, which is mapped
      back to ``None`` at the end.

    Raises:
      ValueError: if ``axis_tree`` is not a tree prefix of ``treedef``. The
        underlying exception context is suppressed (``from None``).
    """
    # given an axis spec tree axis_tree (a pytree with integers and Nones at the
    # leaves, i.e. the Nones are to be considered leaves) that is a tree prefix of
    # the given treedef, build a complete axis spec tree with the same structure
    # and return the flattened result
    # TODO(mattjj,phawkins): improve this implementation

    # Unique sentinel that stands in for None spec leaves during tree_map.
    proxy = object()
    # A value tree with treedef's structure and opaque leaves. It is used
    # only to count how many value leaves sit under each spec leaf.
    dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves)
    axes = []
    # For spec leaf i covering subtree x of dummy, record i once per leaf of x.
    add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0]))
    try:
        # A ValueError here is treated below as "axis_tree is not a prefix".
        tree_map(add_leaves, _replace_nones(proxy, axis_tree), dummy)
    except ValueError:
        if kws:
            # if keyword arguments are included in the tree, we adapt the error
            # message to be only about the positional arguments
            # (keep the positional treedef; the kwargs child must be a leaf treedef)
            treedef, leaf = treedef_children(treedef)
            assert treedef_is_leaf(leaf)
            axis_tree, _ = axis_tree
        hint = ""
        if tupled_args:
            hint += (
                f" Note that {name} that are non-trivial pytrees should always be "
                f"wrapped in a tuple representing the argument list.")
            if len(treedef.children()) == 1:
                # Retry with the spec wrapped in a 1-tuple. If that succeeds,
                # the missing tuple wrapper is the likely mistake.
                try:
                    flatten_axes(name, treedef, (axis_tree, ))
                except ValueError:
                    pass  # That's not the issue.
                else:
                    hint += (
                        f" In particular, you're passing in a single argument which "
                        f"means that {name} might need to be wrapped in "
                        f"a singleton tuple.")
        raise ValueError(f"{name} specification must be a tree prefix of the "
                         f"corresponding value, got specification {axis_tree} "
                         f"for value tree {treedef}.{hint}") from None
    # Map the sentinel back to None for spec leaves that were None.
    axes = [None if a is proxy else a for a in axes]
    assert len(axes) == treedef.num_leaves
    return axes
Exemplo n.º 8
0
def ravel_pytree(pytree):
  """Ravel (i.e. flatten) a pytree of arrays down to a 1D array.

  Args:
    pytree: a pytree of arrays and scalars to ravel.

  Returns:
    A pair ``(flat, unravel)``:

    * ``flat`` is a 1D array holding all leaf values, flattened and
      concatenated. Its dtype is the promotion of the leaf dtypes. By
      convention, a pytree with no leaves yields an empty 1D float32 array.
    * ``unravel`` is a callable that maps a 1D vector of the same length back
      to a pytree with the same structure as ``pytree``.

  For details on dtype promotion, see
  https://jax.readthedocs.io/en/latest/type_promotion.html.
  """
  leaves, treedef = tree_flatten(pytree)
  flat, unravel_leaves = _ravel_list(leaves)

  def unravel_pytree(flat_vector):
    # Split the vector back into leaves, then restore the original structure.
    return tree_unflatten(treedef, unravel_leaves(flat_vector))

  return flat, unravel_pytree
Exemplo n.º 9
0
def flatten_fun_nokwargs(in_tree, *args_flat):
  """Generator: yield rebuilt positional args with empty kwargs, then flatten the reply.

  First yield: ``(tree_unflatten(in_tree, args_flat), {})``. The value sent
  back is flattened, and ``(leaves, treedef)`` is yielded.
  """
  result = yield tree_unflatten(in_tree, args_flat), {}
  out_leaves, out_treedef = tree_flatten(result)
  yield out_leaves, out_treedef
Exemplo n.º 10
0
def flatten_fun(in_tree, *args_flat):
  """Generator: yield rebuilt ``(args, kwargs)``, then flatten the reply.

  ``in_tree`` must unflatten ``args_flat`` into an ``(args, kwargs)`` pair,
  which is yielded. The value sent back is flattened, and
  ``(leaves, treedef)`` is yielded.
  """
  positional, keyword = tree_unflatten(in_tree, args_flat)
  result = yield positional, keyword
  out_leaves, out_treedef = tree_flatten(result)
  yield out_leaves, out_treedef
Exemplo n.º 11
0
def flatten_fun_for_vmap(in_tree, *args_flat):
  """Like ``flatten_fun``, but flattens the reply with ``is_leaf=is_vmappable``.

  Yields the ``(args, kwargs)`` pair rebuilt from ``args_flat`` via
  ``in_tree``. It then yields the ``(leaves, treedef)`` flattening of the
  value sent back, stopping descent at nodes for which ``is_vmappable``
  returns True.
  """
  positional, keyword = tree_unflatten(in_tree, args_flat)
  result = yield positional, keyword
  out_leaves, out_treedef = tree_flatten(result, is_leaf=is_vmappable)
  yield out_leaves, out_treedef