def custom_inverse(f):
    """Decorates a function to enable defining a custom inverse.

    A `custom_inverse`-decorated function is semantically identical to the
    original except when it is inverted with `core.inverse`. By default,
    `core.inverse(custom_inverse(f))` will programmatically invert the body of
    `f`, but `f` has two additional methods that can override that behavior:
    `def_inverse_unary` and `def_inverse_and_ildj`.

    ## `def_inverse_unary`

    `def_inverse_unary` is applicable if `f` is a unary function.
    `def_inverse_unary` takes in an optional `f_inv` function, which is a unary
    function from the output of `f` to the input of `f`.

    Example:
    ```python
    @custom_inverse
    def add_one(x):
      return x + 1.
    add_one.def_inverse_unary(lambda x: x * 2)  # Define silly custom inverse.
    inverse(add_one)(2.)  # ==> 4.
    ```

    With a unary `f_inv` function, Oryx will automatically compute an inverse
    log-det Jacobian using `core.ildj(core.inverse(f_inv))`, but a user can
    also override the Jacobian term by providing the optional `f_ildj` keyword
    argument to `def_inverse_unary`.

    Example:
    ```python
    @custom_inverse
    def add_one(x):
      return x + 1.
    add_one.def_inverse_unary(lambda x: x * 2, f_ildj=lambda x: jnp.ones_like(x))
    ildj(add_one)(2.)  # ==> 1.
    ```

    ## `def_inverse_and_ildj`

    A more general way of defining a custom inverse or ILDJ is to use
    `def_inverse_and_ildj`, which will enable the user to invert functions with
    partially known inputs and outputs. Take an example like
    `add = lambda x, y: x + y`, which cannot be inverted with just the output,
    but can be inverted when just one input is known. `def_inverse_and_ildj`
    takes a single function `f_ildj` as an argument. `f_ildj` is a function from
    `invals` (a set of values corresponding to `f`'s inputs), `outvals` (a set
    of values corresponding to `f`'s outputs) and `out_ildjs` (a set of inverse
    diagonal log-Jacobian values for each of the `outvals`). If any are unknown,
    they will be `None`. `f_ildj` should return a tuple
    `(new_invals, new_inildjs)` which corresponds to known values of the inputs
    and any corresponding diagonal Jacobian values (which should be the same
    shape as `invals`). If these values cannot be computed (e.g. too many values
    are `None`) the user can raise a `NonInvertibleError` which will signal to
    Oryx to give up trying to invert the function for this set of values.

    Example:
    ```python
    @custom_inverse
    def add(x, y):
      return x + y

    def add_ildj(invals, outvals, out_ildjs):
      x, y = invals
      z = outvals
      z_ildj = out_ildjs
      if x is None and y is None:
        raise NonInvertibleError()
      if x is None:
        return (z - y, y), (z_ildj + jnp.zeros_like(z), jnp.zeros_like(z))
      if y is None:
        return (x, z - x), (jnp.zeros_like(z), z_ildj + jnp.zeros_like(z))

    add.def_inverse_and_ildj(add_ildj)
    inverse(partial(add, 1.))(2.)  # ==> 1.
    ildj(partial(add, 1.))(2.)  # ==> 0.
    ```

    Args:
      f: a function for which we'd like to define a custom inverse. It must
        expose a `__name__` attribute, which is used to name the primitive.

    Returns:
      A `CustomInverse` object whose inverse can be overridden with
      `def_inverse_unary` or `def_inverse_and_ildj`.
    """
    # The backing primitive is named after the decorated function.
    return CustomInverse(f, primitive.InitialStylePrimitive(f.__name__))
# Example #2
# 0

def _layer_to_args(layer, *args, **kwargs):
    """Flattens `layer` into leaves and folds bookkeeping into the kwargs.

    Args:
      layer: a pytree that is flattened with `tree_util.tree_flatten`.
      *args: positional arguments to forward; an `rng` keyword, when it is
        present and not `None`, is prepended to them.
      **kwargs: keyword arguments to forward. `rng` is removed, and
        `num_weights`, `in_tree` and `has_rng` are added.

    Returns:
      A `(flattened_layer, args, kwargs)` tuple, where `kwargs` is a new dict.
    """
    params = dict(kwargs)
    leaves, treedef = tree_util.tree_flatten(layer)
    params['num_weights'] = len(leaves)
    params['in_tree'] = treedef
    rng = params.pop('rng', None)
    has_rng = rng is not None
    params['has_rng'] = has_rng
    # Only a real (non-None) rng is threaded through as the leading argument.
    forwarded_args = (rng, *args) if has_rng else args
    return leaves, forwarded_args, params


# Primitive created under the name 'layer_cau'; the batching rule
# `custom_layer_cau_batch` in this file binds it via `layer_cau_p.bind(...)`.
layer_cau_p = primitive.InitialStylePrimitive('layer_cau')


class NoneProxy:
    """Empty marker class with no attributes or behavior of its own."""


# Module-level `NoneProxy` instance.
# NOTE(review): the name mirrors `batching.not_mapped`; presumably this is a
# non-None stand-in for it — confirm against the code that uses it.
not_mapped = NoneProxy()


def custom_layer_cau_batch(vals, dims, *, num_consts, in_tree, out_tree,
                           kwargs, **params):
    """Batching rule for layer_cau primitive to handle custom layers."""
    if all(dim is batching.not_mapped for dim in dims):
        return layer_cau_p.bind(*vals,
                                num_consts=num_consts,
# Example #3
# 0
from jax import util as jax_util
from jax.interpreters import batching
from oryx.core import ppl
from oryx.core import primitive
from oryx.core import trace_util
from oryx.core.interpreters import harvest
from oryx.core.interpreters import inverse
from oryx.core.interpreters import log_prob
from tensorflow_probability.substrates import jax as tfp

# Short aliases for submodules of the TFP JAX substrate.
tfd = tfp.distributions
tfed = tfp.experimental.distribute

# Local alias for `InverseAndILDJ`, reached through the inverse interpreter.
InverseAndILDJ = inverse.core.InverseAndILDJ

# Primitive created under the name 'random_variable'.
random_variable_p = primitive.InitialStylePrimitive('random_variable')


def random_variable_log_prob_rule(flat_incells, flat_outcells, *, num_consts,
                                  in_tree, out_tree, batch_ndims, **_):
    """Registers Oryx distributions with the log_prob transformation."""
    _, incells = jax_util.split_list(flat_incells, [num_consts])
    val_incells = incells[1:]
    if not all(cell.top() for cell in val_incells):
        return flat_incells, flat_outcells, None
    if not all(cell.top() for cell in flat_outcells):
        return flat_incells, flat_outcells, None
    seed_flat_invals = [object()] + [cell.val for cell in val_incells]
    flat_outvals = [cell.val for cell in flat_outcells]
    _, dist = tree_util.tree_unflatten(in_tree, seed_flat_invals)
    outval = tree_util.tree_unflatten(out_tree, flat_outvals)
# Example #4
# 0
from tensorflow_probability.substrates import jax as tfp

# Names exported by `from <module> import *`; only `make_type` is listed.
__all__ = [
    'make_type',
]

# Short aliases for helpers and submodules used throughout this module.
safe_map = jax_util.safe_map
tf = tfp.tf2jax
tfb = tfp.bijectors

# Module-level registry, initially empty.
# NOTE(review): what the keys and values are is not visible here — confirm.
_registry = {}

# Local aliases for names defined in other modules.
InverseAndILDJ = inverse.core.InverseAndILDJ
NDSlice = slc.NDSlice

# Primitive created under the name 'bijector'.
bijector_p = primitive.InitialStylePrimitive('bijector')


class _CellProxy:
    """Wrapper used to avoid recursing into cells during pytree flattening/unflattening.

    The wrapped cell is kept, unchanged, on the `cell` attribute.
    """

    def __init__(self, cell):
        """Stores `cell` on the proxy without inspecting or copying it."""
        self.cell = cell


def bijector_ildj_rule(incells, outcells, *, in_tree, num_consts, direction,
                       num_bijector, **_):
    """Inverse/ILDJ rule for bijectors."""
    const_incells, flat_incells = jax_util.split_list(incells, [num_consts])
    flat_bijector_cells, arg_incells = jax_util.split_list(
        flat_incells, [num_bijector])
    if any(not cell.top() for cell in flat_bijector_cells):