Esempio n. 1
0
 def ildj_rule(incells, outcells, *, in_tree, out_tree, num_consts,
               **_):
     """Propagation rule that runs the function forward or inverts it.

     The first `num_consts` entries of `incells` are split off and passed
     through untouched. Depending on which of the remaining cells report
     `top()`, one of three things happens:

     * every outcell is `top()` and at least one incell is not: the
       closed-over `f_ildj` is called to compute new input values and ILDJ
       terms from the output values/ILDJs;
     * every incell is `top()` and at least one outcell is not: the
       closed-over `self` is called on the input values to produce outcells;
     * otherwise: the cells are returned unchanged.

     NOTE(review): `top()` presumably means "this cell's value is known" --
     confirm against the cell class definition, which is not visible here.

     Args:
       incells: flat list of cells; constants first, then the function inputs.
       outcells: flat list of cells for the function outputs.
       in_tree: treedef used to unflatten the (non-constant) input values.
       out_tree: treedef used to unflatten output values and output ILDJs.
       num_consts: number of leading entries of `incells` that are constants.
       **_: any further keyword parameters are ignored.

     Returns:
       A triple `(incells, outcells, None)` with the (possibly updated) cells;
       the returned incells always start with the constant cells.
     """
     # First incell is a wrapped function because prim is a call primitive.
     # NOTE(review): the code below only peels off the `num_consts` leading
     # cells; the remark above may be stale -- confirm.
     const_incells, incells = jax_util.split_list(incells, [num_consts])
     if (all(outcell.top() for outcell in outcells)
             and any(not incell.top() for incell in incells)):
         # Inverse direction: outputs are available, some input is missing.
         flat_outvals = [outcell.val for outcell in outcells]
         flat_outildjs = [outcell.ildj for outcell in outcells]
         outvals = tree_util.tree_unflatten(out_tree, flat_outvals)
         outildjs = tree_util.tree_unflatten(out_tree, flat_outildjs)
         # Inputs whose cell is not `top()` are handed to `f_ildj` as `None`.
         flat_invals = [
             None if not incell.top() else incell.val
             for incell in incells
         ]
         invals = tree_util.tree_unflatten(in_tree, flat_invals)
         try:
             new_invals, new_ildjs = f_ildj(invals, outvals, outildjs)
         except NonInvertibleError:
             # Inversion is not possible; hand every cell back unchanged.
             return const_incells + incells, outcells, None
         # We need to flatten the output from `f_ildj` using
         # `tree_util.tree_flatten` but if the user returns `None` (when
         # inversion is not possible), JAX will remove `None`s from the flattened
         # version and the number of `new_incells` will not match the old
         # `incells`. We use the private `_replace_nones` feature in JAX to
         # replace it with a sentinel that won't be removed when flattening.
         none_ = object()
         new_invals = tree_util._replace_nones(none_, new_invals)  # pylint: disable=protected-access
         new_ildjs = tree_util._replace_nones(none_, new_ildjs)  # pylint: disable=protected-access
         new_flat_invals = tree_util.tree_leaves(new_invals)
         new_flat_ildjs = tree_util.tree_leaves(new_ildjs)
         # One slice per flattened (value, ildj) pair; note this is built for
         # every position, including those holding the `none_` sentinel.
         inslices = [
             NDSlice.new(inval, ildj)
             for inval, ildj in zip(new_flat_invals, new_flat_ildjs)
         ]
         new_incells = []
         for new_flat_inval, old_incell, inslice in zip(
                 new_flat_invals, incells, inslices):
             if new_flat_inval is not none_:
                 # `f_ildj` produced a value here: build a fresh cell from it.
                 new_incells.append(
                     InverseAndILDJ(old_incell.aval, [inslice]))
             else:
                 # `f_ildj` returned `None` here: keep the previous cell.
                 new_incells.append(old_incell)
         return const_incells + new_incells, outcells, None
     elif (all(incell.top() for incell in incells)
           and any(not outcell.top() for outcell in outcells)):
         # Forward direction: all inputs are available, some output is missing.
         flat_invals = [incell.val for incell in incells]
         invals = tree_util.tree_unflatten(in_tree, flat_invals)
         outvals = self(*invals)
         flat_outvals = tree_util.tree_leaves(outvals)
         outcells = [
             InverseAndILDJ.new(outval) for outval in flat_outvals
         ]
         return const_incells + incells, outcells, None
     # Neither direction applies; return the cells as they came in.
     return const_incells + incells, outcells, None
Esempio n. 2
0
def apply_flat_fun_nokwargs(fun, io_tree, py_args):
  """Apply a flat function to the leaves of a pytree of arguments.

  Args:
    fun: callable invoked with the flattened leaves of ``py_args`` as
      positional arguments; its result is passed to ``tree_unflatten``.
    io_tree: pair ``(expected_input_treedef, output_treedef)``.
    py_args: pytree of arguments to flatten.

  Returns:
    The result of ``fun`` unflattened with the output treedef.

  Raises:
    TypeError: if the structure of ``py_args`` differs from the expected
      input treedef.
  """
  expected_in_tree, out_tree = io_tree
  leaves, actual_in_tree = tree_flatten(py_args)
  if actual_in_tree != expected_in_tree:
    message = "Expected {}, got {}".format(expected_in_tree, actual_in_tree)
    raise TypeError(message)
  return tree_unflatten(out_tree, fun(*leaves))
Esempio n. 3
0
def apply_flat_fun(fun, io_tree, *py_args):
    """Apply a flat function to positional arguments given as pytrees.

    The positional arguments are flattened together with an empty kwargs
    dict, i.e. as the pytree ``(py_args, {})``.

    Args:
        fun: callable invoked with the flattened leaves as positional
            arguments; its result is passed to ``tree_unflatten``.
        io_tree: pair ``(expected_input_treedef, output_treedef)``.
        *py_args: positional arguments to flatten.

    Returns:
        The result of ``fun`` unflattened with the output treedef.

    Raises:
        TypeError: if the structure of ``(py_args, {})`` differs from the
            expected input treedef.
    """
    expected_in_tree, out_tree = io_tree
    leaves, actual_in_tree = tree_flatten((py_args, {}))
    if actual_in_tree != expected_in_tree:
        raise TypeError(f"Expected {expected_in_tree}, got {actual_in_tree}")
    flat_result = fun(*leaves)
    return tree_unflatten(out_tree, flat_result)
Esempio n. 4
0
 def dex_fun(*args, **kwargs):
   """Trace `fun` for the given arguments and run it via `dex_call_p`.

   The arguments are flattened as the pytree `(args, kwargs)`, abstracted
   with `abstractify`, turned into a jaxpr with `make_jaxpr`, and the jaxpr
   is bound together with its constants and the flat arguments. The flat
   result is unflattened with the output treedef from `make_jaxpr`.
   """
   flat_inputs, in_tree = tree_flatten((args, kwargs))
   avals, abstract_tree, keep_flags = abstractify(args, kwargs)
   # Both flattenings must agree on the argument structure.
   assert in_tree == abstract_tree
   jaxpr, consts, out_tree = make_jaxpr(
       fun, in_tree, tuple(avals), tuple(keep_flags))
   flat_outputs = dex_call_p.bind(*consts, *flat_inputs, jaxpr=jaxpr)
   return tree_unflatten(out_tree, flat_outputs)
Esempio n. 5
0
def flatten_fun_nokwargs2(in_tree, *args_flat):
  """Generator adapting a flat call to a function returning ``(ans, aux)``.

  First yields ``(py_args, {})``, where ``py_args`` is ``args_flat``
  unflattened with ``in_tree``. The value sent back must be a two-element
  list or tuple ``(ans, aux)``; the generator then yields
  ``((ans_leaves, aux_leaves), (ans_treedef, aux_treedef))``.

  Raises:
    TypeError: if the value sent back is not a two-element list/tuple.
  """
  result = yield tree_unflatten(in_tree, args_flat), {}
  is_two_element_seq = isinstance(result, (list, tuple)) and len(result) == 2
  if not is_two_element_seq:
    raise TypeError("expected function with aux output to return a two-element "
                    f"tuple, but got type {type(result)} with value {repr(result)}")
  primary, auxiliary = result
  primary_leaves, primary_tree = tree_flatten(primary)
  auxiliary_leaves, auxiliary_tree = tree_flatten(auxiliary)
  yield (primary_leaves, auxiliary_leaves), (primary_tree, auxiliary_tree)
Esempio n. 6
0
  def testKdePyTree(self):
    """gaussian_kde round-trips through flatten/unflatten and works under jit."""
    dtype = np.float32
    sample = jtu.rand_default(self.rng())
    dataset = sample((3, 15), dtype)
    points = sample((3, 12), dtype)
    kde = lsp_stats.gaussian_kde(dataset)

    # Flatten and rebuild the object, then compare leaf by leaf.
    flat, structure = tree_util.tree_flatten(kde)
    rebuilt = tree_util.tree_unflatten(structure, flat)
    tree_util.tree_map(self.assertAllClose, kde, rebuilt)

    # Passing the object as an argument to a jitted function must give the
    # same result as calling its method directly.
    @jax.jit
    def evaluate_kde(kde, x):
      return kde.evaluate(x)

    self.assertAllClose(evaluate_kde(kde, points), kde.evaluate(points))
Esempio n. 7
0
def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False):
    """Expand a tree-prefix axis specification to one entry per leaf.

    ``axis_tree`` is a pytree whose leaves are axis entries (``None`` counts
    as a leaf here) and which must be a tree prefix of ``treedef``. Each
    entry is repeated once for every leaf of the matching subtree of
    ``treedef`` and the flat list of entries is returned.

    Args:
        name: label used for the specification in error messages.
        treedef: treedef of the value the specification applies to.
        axis_tree: the axis specification (a tree prefix of ``treedef``).
        kws: if true and the prefix match fails, the error message is
            narrowed to the first child of ``treedef`` and the first element
            of ``axis_tree`` (the positional arguments).
        tupled_args: if true, a hint about wrapping the specification in a
            tuple is appended to the error message.

    Returns:
        A list with ``treedef.num_leaves`` entries.

    Raises:
        ValueError: if ``axis_tree`` is not a tree prefix of ``treedef``.
    """
    # TODO(mattjj,phawkins): improve this implementation

    # Stand-in for `None` so that `None` entries survive as leaves.
    none_marker = object()
    placeholder = tree_unflatten(treedef, [object()] * treedef.num_leaves)
    collected = []

    def add_leaves(entry, subtree):
        # Repeat the prefix entry once per leaf of the matching subtree.
        subtree_leaves = tree_flatten(subtree)[0]
        collected.extend([entry] * len(subtree_leaves))

    try:
        tree_map(add_leaves, _replace_nones(none_marker, axis_tree),
                 placeholder)
    except ValueError:
        if kws:
            # If keyword arguments are included in the tree, adapt the error
            # message to be only about the positional arguments.
            treedef, leaf = treedef_children(treedef)
            assert treedef_is_leaf(leaf)
            axis_tree, _ = axis_tree
        hint = ""
        if tupled_args:
            hint += (
                f" Note that {name} that are non-trivial pytrees should always be "
                f"wrapped in a tuple representing the argument list.")
            if len(treedef.children()) == 1:
                # Check whether wrapping the spec in a singleton tuple would
                # have matched; only then is the extra hint relevant.
                try:
                    flatten_axes(name, treedef, (axis_tree, ))
                    wrapping_helps = True
                except ValueError:
                    wrapping_helps = False  # That's not the issue.
                if wrapping_helps:
                    hint += (
                        f" In particular, you're passing in a single argument which "
                        f"means that {name} might need to be wrapped in "
                        f"a singleton tuple.")
        raise ValueError(f"{name} specification must be a tree prefix of the "
                         f"corresponding value, got specification {axis_tree} "
                         f"for value tree {treedef}.{hint}") from None
    # Map the stand-in back to `None`.
    axes = [None if entry is none_marker else entry for entry in collected]
    assert len(axes) == treedef.num_leaves
    return axes
Esempio n. 8
0
def ravel_pytree(pytree):
  """Flatten a pytree of arrays into a single 1D array.

  Args:
    pytree: a pytree of arrays and scalars to ravel.

  Returns:
    A pair ``(flat, unravel)``. ``flat`` is a 1D array holding the flattened,
    concatenated leaf values; its dtype is determined by promoting the dtypes
    of the leaves. ``unravel`` is a callable that turns a 1D vector of the
    same length back into a pytree with the same structure as the input
    ``pytree``. By convention, if the input pytree is empty (i.e. has no
    leaves), ``flat`` is an empty 1D array of dtype float32.

  For details on dtype promotion, see
  https://jax.readthedocs.io/en/latest/type_promotion.html.

  """
  leaves, treedef = tree_flatten(pytree)
  flat, unravel_list = _ravel_list(leaves)

  def unravel_pytree(flat):
    # Split the vector back into leaves, then rebuild the original structure.
    return tree_unflatten(treedef, unravel_list(flat))

  return flat, unravel_pytree
Esempio n. 9
0
def flatten_fun_nokwargs(in_tree, *args_flat):
  """Generator adapting a flat call to a function taking pytree arguments.

  First yields ``(py_args, {})``, where ``py_args`` is ``args_flat``
  unflattened with ``in_tree``; the value sent back is flattened and the
  resulting ``(leaves, treedef)`` pair is yielded.
  """
  positional = tree_unflatten(in_tree, args_flat)
  result = yield positional, {}
  result_leaves, result_tree = tree_flatten(result)
  yield result_leaves, result_tree
Esempio n. 10
0
def flatten_fun(in_tree, *args_flat):
  """Generator adapting a flat call to a function taking args and kwargs.

  ``args_flat`` is unflattened with ``in_tree`` into a pair
  ``(py_args, py_kwargs)``, which is yielded first; the value sent back is
  flattened and the resulting ``(leaves, treedef)`` pair is yielded.
  """
  positional, keyword = tree_unflatten(in_tree, args_flat)
  result = yield positional, keyword
  result_leaves, result_tree = tree_flatten(result)
  yield result_leaves, result_tree
Esempio n. 11
0
def flatten_fun_for_vmap(in_tree, *args_flat):
  """Variant of the args/kwargs flattening generator that uses `is_vmappable`.

  ``args_flat`` is unflattened with ``in_tree`` into a pair
  ``(py_args, py_kwargs)``, which is yielded first; the value sent back is
  flattened with ``is_leaf=is_vmappable`` and the resulting
  ``(leaves, treedef)`` pair is yielded.
  """
  positional, keyword = tree_unflatten(in_tree, args_flat)
  result = yield positional, keyword
  result_leaves, result_tree = tree_flatten(result, is_leaf=is_vmappable)
  yield result_leaves, result_tree