def ildj_rule(incells, outcells, *, in_tree, out_tree, num_consts, **_):
    """Propagate inverse values and ILDJs through a custom-inverse call.

    NOTE(review): this reads ``self`` (called as the forward function) and
    ``f_ildj`` (the user-supplied inverse/ILDJ function) from the enclosing
    scope -- confirm against the definition site.

    The first ``num_consts`` entries of ``incells`` are split off and always
    returned unchanged. For the remaining input cells:

    * If every output cell is ``top()`` and at least one input cell is not,
      ``f_ildj(invals, outvals, outildjs)`` is called with outputs/ILDJs
      unflattened via ``out_tree`` and inputs unflattened via ``in_tree``
      (inputs that are not ``top()`` are passed as ``None``). Each input for
      which ``f_ildj`` returns a non-``None`` value becomes an
      ``InverseAndILDJ`` cell; inputs it returns ``None`` for keep their old
      cell. If ``f_ildj`` raises ``NonInvertibleError``, nothing changes.
    * If every input cell is ``top()`` and at least one output cell is not,
      the outputs are recomputed with ``self(*invals)`` and each flattened
      output is wrapped with ``InverseAndILDJ.new``.
    * Otherwise all cells are returned unchanged.

    Args:
      incells: constant cells followed by the call's input cells.
      outcells: the call's output cells.
      in_tree: treedef used to unflatten the non-constant input values.
      out_tree: treedef used to unflatten the output values and ILDJs.
      num_consts: number of leading ``incells`` entries that are constants.
      **_: any other primitive parameters; ignored.

    Returns:
      A tuple ``(incells, outcells, None)`` holding the possibly-updated
      cells, constants first in ``incells``.
    """
    # The leading `num_consts` incells are constants; they are passed through
    # untouched and re-prepended on every return path.
    const_incells, incells = jax_util.split_list(incells, [num_consts])
    if (all(outcell.top() for outcell in outcells)
            and any(not incell.top() for incell in incells)):
        # Outputs are known: ask the user's inverse for the unknown inputs.
        flat_outvals = [outcell.val for outcell in outcells]
        flat_outildjs = [outcell.ildj for outcell in outcells]
        outvals = tree_util.tree_unflatten(out_tree, flat_outvals)
        outildjs = tree_util.tree_unflatten(out_tree, flat_outildjs)
        flat_invals = [
            incell.val if incell.top() else None for incell in incells
        ]
        invals = tree_util.tree_unflatten(in_tree, flat_invals)
        try:
            new_invals, new_ildjs = f_ildj(invals, outvals, outildjs)
        except NonInvertibleError:
            return const_incells + incells, outcells, None
        # We need to flatten the output from `f_ildj` using
        # `tree_util.tree_flatten`, but if the user returns `None` (when
        # inversion is not possible), JAX removes `None`s from the flattened
        # version and the number of new incells would not match the old
        # `incells`. The private `_replace_nones` helper swaps each `None` for
        # a sentinel that survives flattening.
        none_ = object()
        new_invals = tree_util._replace_nones(none_, new_invals)  # pylint: disable=protected-access
        new_ildjs = tree_util._replace_nones(none_, new_ildjs)  # pylint: disable=protected-access
        new_flat_invals = tree_util.tree_leaves(new_invals)
        new_flat_ildjs = tree_util.tree_leaves(new_ildjs)
        new_incells = []
        for new_flat_inval, new_flat_ildj, old_incell in zip(
                new_flat_invals, new_flat_ildjs, incells):
            if new_flat_inval is none_:
                # `f_ildj` could not recover this input; keep the old cell.
                new_incells.append(old_incell)
                continue
            # Build the slice only for real values: the sentinel is a bare
            # `object()`, and a slice built from it would be discarded anyway.
            inslice = NDSlice.new(new_flat_inval, new_flat_ildj)
            new_incells.append(InverseAndILDJ(old_incell.aval, [inslice]))
        return const_incells + new_incells, outcells, None
    elif (all(incell.top() for incell in incells)
          and any(not outcell.top() for outcell in outcells)):
        # Inputs are known: recompute the outputs by running forward.
        flat_invals = [incell.val for incell in incells]
        invals = tree_util.tree_unflatten(in_tree, flat_invals)
        outvals = self(*invals)
        flat_outvals = tree_util.tree_leaves(outvals)
        outcells = [
            InverseAndILDJ.new(outval) for outval in flat_outvals
        ]
        return const_incells + incells, outcells, None
    return const_incells + incells, outcells, None
Beispiel #2
0
def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False):
    """Broadcast an axis-spec prefix tree over the leaves of ``treedef``.

    ``axis_tree`` is a pytree with integers and ``None`` at its leaves (here
    ``None`` counts as a leaf, not as an empty subtree) that must be a tree
    prefix of ``treedef``. Each leaf of ``axis_tree`` is repeated once for
    every leaf of the value subtree it covers, and the result is flattened.

    Args:
      name: name of the specification, used in error messages.
      treedef: treedef of the value the specification applies to.
      axis_tree: the axis specification; must be a tree prefix of ``treedef``.
      kws: set when ``treedef``/``axis_tree`` are ``(positional, keyword)``
        pairs. On failure, the error message is narrowed to the positional
        part when the keyword part of ``treedef`` is a leaf (e.g. empty).
      tupled_args: if True, append hints about wrapping the specification in
        a tuple representing the argument list.

    Returns:
      A list with one entry per leaf of ``treedef``: the ``axis_tree`` leaf
      that covers it (``None`` stays ``None``).

    Raises:
      ValueError: if ``axis_tree`` is not a tree prefix of ``treedef``.
    """
    # TODO(mattjj,phawkins): improve this implementation

    # Swap every None for a private proxy so it is matched as a leaf against
    # ``treedef`` instead of being treated as an empty subtree.
    proxy = object()
    dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves)
    axes = []

    def add_leaves(axis, subtree):
        # Repeat ``axis`` once per leaf of the value subtree it covers.
        axes.extend([axis] * len(tree_flatten(subtree)[0]))

    try:
        tree_map(add_leaves, _replace_nones(proxy, axis_tree), dummy)
    except ValueError:
        if kws:
            # If keyword arguments are included in the tree, adapt the error
            # message to be only about the positional arguments -- but only
            # when the keyword part is a leaf and so cannot be the culprit.
            # Asserting that instead would replace this ValueError with a bare
            # AssertionError whenever keyword arguments are passed.
            args_treedef, kwargs_treedef = treedef_children(treedef)
            if treedef_is_leaf(kwargs_treedef):
                treedef = args_treedef
                axis_tree, _ = axis_tree
        hint = ""
        if tupled_args:
            hint += (
                f" Note that {name} that are non-trivial pytrees should always be "
                f"wrapped in a tuple representing the argument list.")
            if len(treedef.children()) == 1:
                # Check whether wrapping the spec in a 1-tuple would fix it.
                try:
                    flatten_axes(name, treedef, (axis_tree, ))
                except ValueError:
                    pass  # That's not the issue.
                else:
                    hint += (
                        f" In particular, you're passing in a single argument which "
                        f"means that {name} might need to be wrapped in "
                        f"a singleton tuple.")
        raise ValueError(f"{name} specification must be a tree prefix of the "
                         f"corresponding value, got specification {axis_tree} "
                         f"for value tree {treedef}.{hint}") from None
    axes = [None if a is proxy else a for a in axes]
    assert len(axes) == treedef.num_leaves
    return axes