def ildj_rule(incells, outcells, *, in_tree, out_tree, num_consts, **_):
  """Propagation rule for a call primitive that carries a user-supplied inverse.

  `f_ildj` and `self` are not parameters; they come from the enclosing scope.
  `top()` is presumably the cell's "value is fully known" predicate -- confirm
  against the cell class.

  Two directions are handled:

  * Inverse: if every output cell is `top()` and at least one non-constant
    input cell is not, `f_ildj(invals, outvals, outildjs)` is called and must
    return `(new_invals, new_ildjs)`. Inputs that are not yet known are passed
    to `f_ildj` as `None`. An input position for which `f_ildj` returns `None`
    keeps its old cell. Every other position gets a new `InverseAndILDJ` cell
    wrapping an `NDSlice` of the returned value and ILDJ.
  * Forward: if every non-constant input cell is `top()` and at least one
    output cell is not, `self(*invals)` is re-evaluated. Fresh output cells are
    built from the flattened result.

  In every other case all cells are returned unchanged.

  Args:
    incells: Cells for the primitive's operands. The leading `num_consts`
      cells are split off and returned untouched, in front of the others.
    outcells: Cells for the primitive's results.
    in_tree: Treedef used to unflatten the non-constant input values.
    out_tree: Treedef used to unflatten the output values and ILDJs.
    num_consts: Number of leading constant cells in `incells`.
    **_: Remaining primitive parameters; ignored.

  Returns:
    A tuple `(incells, outcells, None)`. The constant cells come first in
    `incells`, followed by the (possibly updated) non-constant input cells.
    If `f_ildj` raises `NonInvertibleError`, all cells are returned unchanged.
  """
  # NOTE(review): an earlier comment here said "First incell is a wrapped
  # function because prim is a call primitive", but the code only splits off
  # the leading `num_consts` constant cells -- confirm which is current.
  const_incells, incells = jax_util.split_list(incells, [num_consts])

  if (all(outcell.top() for outcell in outcells) and
      any(not incell.top() for incell in incells)):
    # Inverse direction: recover inputs from the fully-known outputs.
    flat_outvals = [outcell.val for outcell in outcells]
    flat_outildjs = [outcell.ildj for outcell in outcells]
    outvals = tree_util.tree_unflatten(out_tree, flat_outvals)
    outildjs = tree_util.tree_unflatten(out_tree, flat_outildjs)
    # Unknown inputs are handed to `f_ildj` as `None`.
    flat_invals = [incell.val if incell.top() else None for incell in incells]
    invals = tree_util.tree_unflatten(in_tree, flat_invals)
    try:
      new_invals, new_ildjs = f_ildj(invals, outvals, outildjs)
    except NonInvertibleError:
      # The user-supplied inverse declared this case impossible.
      return const_incells + incells, outcells, None
    # We need to flatten the output from `f_ildj` using
    # `tree_util.tree_flatten`. If the user returns `None` (when inversion is
    # not possible), JAX removes `None`s from the flattened version, and the
    # number of `new_incells` would then not match the old `incells`. We use
    # the private `_replace_nones` feature in JAX to replace each `None` with
    # a sentinel that won't be removed when flattening.
    none_ = object()
    new_invals = tree_util._replace_nones(none_, new_invals)  # pylint: disable=protected-access
    new_ildjs = tree_util._replace_nones(none_, new_ildjs)  # pylint: disable=protected-access
    new_flat_invals = tree_util.tree_leaves(new_invals)
    new_flat_ildjs = tree_util.tree_leaves(new_ildjs)
    new_incells = []
    # `zip` truncates to the shortest of the three sequences, exactly as the
    # previous two-stage zip did.
    for new_flat_inval, new_flat_ildj, old_incell in zip(
        new_flat_invals, new_flat_ildjs, incells):
      if new_flat_inval is none_:
        # `f_ildj` could not recover this input; keep the existing cell.
        new_incells.append(old_incell)
      else:
        # Build the slice only for real values. A slice built around the
        # `none_` sentinel would be discarded anyway.
        inslice = NDSlice.new(new_flat_inval, new_flat_ildj)
        new_incells.append(InverseAndILDJ(old_incell.aval, [inslice]))
    return const_incells + new_incells, outcells, None

  if (all(incell.top() for incell in incells) and
      any(not outcell.top() for outcell in outcells)):
    # Forward direction: re-run the wrapped callable on the known inputs.
    flat_invals = [incell.val for incell in incells]
    invals = tree_util.tree_unflatten(in_tree, flat_invals)
    outvals = self(*invals)
    flat_outvals = tree_util.tree_leaves(outvals)
    new_outcells = [InverseAndILDJ.new(outval) for outval in flat_outvals]
    return const_incells + incells, new_outcells, None

  return const_incells + incells, outcells, None
def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False):
  """Expand a prefix tree of axis specs into one axis entry per leaf.

  `axis_tree` is a pytree whose leaves are integers or `None`; here `None` is
  treated as a leaf rather than as an empty subtree. It must be a tree prefix
  of the structure described by `treedef`. Every leaf of `axis_tree` is
  repeated once for each leaf of `treedef` lying underneath it.

  Args:
    name: Name of the specification, interpolated into error messages.
    treedef: Treedef of the full value tree.
    axis_tree: Prefix tree of axis specs.
    kws: When True and the prefix check fails, the error is reported in terms
      of positional arguments only. `treedef` is split into two children, the
      second of which is asserted to be a leaf treedef, and `axis_tree` is
      unpacked as a 2-tuple whose first element is kept.
    tupled_args: When True, the error message gains a hint about wrapping the
      specification in a tuple representing the argument list.

  Returns:
    A list of length `treedef.num_leaves` with the axis spec for each leaf in
    flattened order. `None` specs stay `None`.

  Raises:
    ValueError: If `axis_tree` is not a tree prefix of `treedef`.
  """
  # TODO(mattjj,phawkins): improve this implementation
  # `None` would vanish when flattening, so swap it for a unique marker first.
  none_marker = object()
  placeholder_tree = tree_unflatten(treedef, [object()] * treedef.num_leaves)
  flat_axes = []

  def _broadcast_axis(axis, subtree):
    # Emit one copy of `axis` for every leaf under this prefix position.
    leaf_count = len(tree_flatten(subtree)[0])
    flat_axes.extend([axis] * leaf_count)

  try:
    tree_map(_broadcast_axis, _replace_nones(none_marker, axis_tree),
             placeholder_tree)
  except ValueError:
    if kws:
      # Keyword arguments are part of the tree; phrase the error in terms of
      # the positional arguments only.
      treedef, kwargs_treedef = treedef_children(treedef)
      assert treedef_is_leaf(kwargs_treedef)
      axis_tree, _ = axis_tree
    hint = ""
    if tupled_args:
      hint += (
          f" Note that {name} that are non-trivial pytrees should always be "
          f"wrapped in a tuple representing the argument list.")
      if len(treedef.children()) == 1:
        # Would a singleton-tuple wrapper have fixed it? If so, say so.
        try:
          flatten_axes(name, treedef, (axis_tree,))
        except ValueError:
          pass  # That's not the issue.
        else:
          hint += (
              f" In particular, you're passing in a single argument which "
              f"means that {name} might need to be wrapped in "
              f"a singleton tuple.")
    raise ValueError(f"{name} specification must be a tree prefix of the "
                     f"corresponding value, got specification {axis_tree} "
                     f"for value tree {treedef}.{hint}") from None

  resolved = [None if axis is none_marker else axis for axis in flat_axes]
  assert len(resolved) == treedef.num_leaves
  return resolved