def custom_inverse(f):
  """Wraps `f` so that its inverse and ILDJ can be user-defined.

  The returned object is semantically identical to `f` unless it is inverted
  with `core.inverse`. Without further configuration,
  `core.inverse(custom_inverse(f))` programmatically inverts the body of `f`.
  Two methods on the returned object override that default:
  `def_inverse_unary` and `def_inverse_and_ildj`.

  ## `def_inverse_unary`

  Applicable when `f` is a unary function. It takes an optional `f_inv`, a
  unary function mapping the output of `f` back to the input of `f`.

  Example:
  ```python
  @custom_inverse
  def add_one(x):
    return x + 1.
  add_one.def_inverse_unary(lambda x: x * 2)  # Define silly custom inverse.
  inverse(add_one)(2.)  # ==> 4.
  ```

  Given a unary `f_inv`, Oryx automatically computes the inverse log-det
  Jacobian using `core.ildj(core.inverse(f_inv))`. The Jacobian term can be
  overridden too, by passing the optional `f_ildj` keyword argument to
  `def_inverse_unary`.

  Example:
  ```python
  @custom_inverse
  def add_one(x):
    return x + 1.
  add_one.def_inverse_unary(lambda x: x * 2, f_ildj=lambda x: jnp.ones_like(x))
  ildj(add_one)(2.)  # ==> 1.
  ```

  ## `def_inverse_and_ildj`

  `def_inverse_and_ildj` is the more general way to define a custom inverse
  or ILDJ: it lets the user invert functions whose inputs and outputs are only
  partially known. For example, `add = lambda x, y: x + y` cannot be inverted
  from its output alone, but it can be inverted once one input is known.

  `def_inverse_and_ildj` takes a single function `f_ildj` as its argument.
  `f_ildj` is called as `f_ildj(invals, outvals, out_ildjs)` where:

  * `invals` is a set of values corresponding to `f`'s inputs,
  * `outvals` is a set of values corresponding to `f`'s outputs,
  * `out_ildjs` is a set of inverse diagonal log-Jacobian values, one for
    each of the `outvals`.

  Any entry that is unknown is passed as `None`.

  `f_ildj` should return a tuple `(new_invals, new_inildjs)` holding the known
  values of the inputs and their corresponding diagonal Jacobian values (which
  should be the same shape as `invals`). If these cannot be computed (e.g. too
  many values are `None`), `f_ildj` can raise a `NonInvertibleError`, which
  signals Oryx to give up inverting the function for this set of values.

  Example:
  ```python
  @custom_inverse
  def add(x, y):
    return x + y

  def add_ildj(invals, outvals, out_ildjs):
    x, y = invals
    z = outvals
    z_ildj = out_ildjs
    if x is None and y is None:
      raise NonInvertibleError()
    if x is None:
      return (z - y, y), (z_ildj + jnp.zeros_like(z), jnp.zeros_like(z))
    if y is None:
      return (x, z - x), (jnp.zeros_like(z), z_ildj + jnp.zeros_like(z))

  add.def_inverse_and_ildj(add_ildj)
  inverse(partial(add, 1.))(2.)  # ==> 1.
  inverse(partial(add, 2.))(2.)  # ==> 0.
  ```

  Args:
    f: a function for which we'd like to define a custom inverse.

  Returns:
    A `CustomInverse` object whose inverse can be overridden with
    `def_inverse_unary` or `def_inverse_and_ildj`.
  """
  # The primitive is named after the wrapped function.
  prim = primitive.InitialStylePrimitive(f.__name__)
  return CustomInverse(f, prim)
def _layer_to_args(layer, *args, **kwargs):
  """Flattens a layer and handles rng in kwargs.

  Args:
    layer: the object to flatten with `tree_util.tree_flatten`.
    *args: positional arguments to pass through; when an `rng` is supplied it
      is prepended to them.
    **kwargs: keyword arguments to pass through; an optional `rng` entry is
      removed from them.

  Returns:
    A tuple `(flattened_layer, args, kwargs)`. `flattened_layer` is the list of
    leaves produced by flattening `layer`. `kwargs` is a copy of the incoming
    keyword arguments without `rng`, extended with `num_weights` (the number
    of flattened leaves), `in_tree` (the tree structure of `layer`) and
    `has_rng` (whether a non-`None` `rng` was supplied).
  """
  kwargs = kwargs.copy()
  flattened_layer, in_tree = tree_util.tree_flatten(layer)
  kwargs['num_weights'] = len(flattened_layer)
  kwargs['in_tree'] = in_tree
  # `rng` is moved out of the keyword arguments; only a flag recording its
  # presence stays behind.
  rng = kwargs.pop('rng', None)
  kwargs['has_rng'] = has_rng = rng is not None
  if has_rng:
    # The rng becomes the leading positional argument.
    args = (rng, ) + args
  return flattened_layer, args, kwargs


layer_cau_p = primitive.InitialStylePrimitive('layer_cau')


# Empty placeholder class; its only visible use is the `not_mapped` instance
# created below.
class NoneProxy:
  pass


# NOTE(review): the batching rule visible below compares against
# `batching.not_mapped`, not this module-level `not_mapped` -- confirm where
# this instance is actually used.
not_mapped = NoneProxy()


def custom_layer_cau_batch(vals, dims, *, num_consts, in_tree, out_tree, kwargs, **params):
  """Batching rule for layer_cau primitive to handle custom layers."""
  if all(dim is batching.not_mapped for dim in dims):
    return layer_cau_p.bind(*vals, num_consts=num_consts,
from jax import util as jax_util from jax.interpreters import batching from oryx.core import ppl from oryx.core import primitive from oryx.core import trace_util from oryx.core.interpreters import harvest from oryx.core.interpreters import inverse from oryx.core.interpreters import log_prob from tensorflow_probability.substrates import jax as tfp tfd = tfp.distributions tfed = tfp.experimental.distribute InverseAndILDJ = inverse.core.InverseAndILDJ random_variable_p = primitive.InitialStylePrimitive('random_variable') def random_variable_log_prob_rule(flat_incells, flat_outcells, *, num_consts, in_tree, out_tree, batch_ndims, **_): """Registers Oryx distributions with the log_prob transformation.""" _, incells = jax_util.split_list(flat_incells, [num_consts]) val_incells = incells[1:] if not all(cell.top() for cell in val_incells): return flat_incells, flat_outcells, None if not all(cell.top() for cell in flat_outcells): return flat_incells, flat_outcells, None seed_flat_invals = [object()] + [cell.val for cell in val_incells] flat_outvals = [cell.val for cell in flat_outcells] _, dist = tree_util.tree_unflatten(in_tree, seed_flat_invals) outval = tree_util.tree_unflatten(out_tree, flat_outvals)
from tensorflow_probability.substrates import jax as tfp

__all__ = [ 'make_type', ]

# Short module-level aliases.
safe_map = jax_util.safe_map
tf = tfp.tf2jax
tfb = tfp.bijectors

_registry = {}

InverseAndILDJ = inverse.core.InverseAndILDJ
NDSlice = slc.NDSlice

bijector_p = primitive.InitialStylePrimitive('bijector')


class _CellProxy:
  """Used to avoid recursing into cells when doing Pytree flattening/unflattening.

  A plain holder object: the wrapped cell is stored on the `cell` attribute
  and is not otherwise touched by this class.
  """

  def __init__(self, cell):
    # The wrapped cell, stored as given.
    self.cell = cell


def bijector_ildj_rule(incells, outcells, *, in_tree, num_consts, direction, num_bijector, **_):
  """Inverse/ILDJ rule for bijectors."""
  const_incells, flat_incells = jax_util.split_list(incells, [num_consts])
  flat_bijector_cells, arg_incells = jax_util.split_list(
      flat_incells, [num_bijector])
  if any(not cell.top() for cell in flat_bijector_cells):