Example #1
0

def _layer_to_args(layer, *args, **kwargs):
    """Flattens a layer and handles rng in kwargs."""
    kwargs = kwargs.copy()
    flattened_layer, in_tree = tree_util.tree_flatten(layer)
    kwargs['num_weights'] = len(flattened_layer)
    kwargs['in_tree'] = in_tree
    rng = kwargs.pop('rng', None)
    kwargs['has_rng'] = has_rng = rng is not None
    if has_rng:
        args = (rng, ) + args
    return flattened_layer, args, kwargs


# Higher-order primitive named 'layer_cau'; `custom_layer_cau_batch` below is
# written as its batching rule.
layer_cau_p = primitive.HigherOrderPrimitive('layer_cau')


class NoneProxy:
    """Sentinel type; an instance is bound to module-level `not_mapped`."""
    pass


# Module-level sentinel instance of `NoneProxy`.
# NOTE(review): this is a different object from `batching.not_mapped`, which
# is what `custom_layer_cau_batch` compares batch dims against — confirm which
# sentinel callers of this module expect.
not_mapped = NoneProxy()


def custom_layer_cau_batch(trace, f, tracers, params):
    """Batching rule for layer_cau primitive to handle custom layers."""
    vals, dims = jax_util.unzip2((t.val, t.batch_dim) for t in tracers)
    if all(dim is batching.not_mapped for dim in dims):
        return layer_cau_p.bind(f, *vals, **params)
    args = tree_util.tree_unflatten(params['in_tree'], vals)
Example #2
0
    def _component_specs(self):
        return self._param_specs

    def _serialize(self):
        # Include default version 1 for now
        return 1, self._clsid, self._param_specs, self._kwargs

    @classmethod
    def _deserialize(cls, encoded):
        """Rebuilds an instance from a `(version, clsid, param_specs, kwargs)` tuple.

        Args:
          encoded: a 4-tuple `(version, clsid, param_specs, kwargs)`.

        Returns:
          `cls(clsid, param_specs, kwargs)`.

        Raises:
          ValueError: if `version` is not 1, or if `clsid` is not in
            `_registry`.
        """
        version, clsid, param_specs, kwargs = encoded
        if version != 1:
            # Previously a bare `ValueError` with no message.
            raise ValueError(f'Unsupported encoding version: {version!r}')
        if clsid not in _registry:
            # Keep `clsid` as the sole argument, as before.
            raise ValueError(clsid)
        return cls(clsid, param_specs, kwargs)


# Higher-order primitive named 'random_variable'.
random_variable_p = primitive.HigherOrderPrimitive('random_variable')
# Add the primitive to `unzip.block_registry`; presumably this makes the unzip
# transformation treat it as an opaque block — confirm against `unzip`.
unzip.block_registry.add(random_variable_p)


def random_variable_log_prob_rule(flat_incells, flat_outcells, **params):
    """log_prob propagation rule for `random_variable_p`.

    This function does not register anything itself; it is installed in
    `log_prob.log_prob_rules` by a separate module-level statement.

    Args:
      flat_incells: input cells; the first entry is the call primitive's
        function and is dropped from the result.
      flat_outcells: output cells, returned unchanged.
      **params: primitive parameters; ignored.

    Returns:
      A tuple `(flat_incells[1:], flat_outcells, None)`.
    """
    del params  # Unused.
    # First incell is the call primitive function.
    return flat_incells[1:], flat_outcells, None


# Register `random_variable_log_prob_rule` as the log_prob rule for
# `random_variable_p`.
log_prob.log_prob_rules[random_variable_p] = random_variable_log_prob_rule


def random_variable_log_prob(flat_incells, val, **params):
    """Registers Oryx distributions with the log_prob transformation."""
Example #3
0
def custom_inverse(f):
    """Decorates a function to enable defining a custom inverse.

    A `custom_inverse`-decorated function is semantically identical to the
    original except when it is inverted with `core.inverse`. By default,
    `core.inverse(custom_inverse(f))` will programmatically invert the body of
    `f`, but `f` has two additional methods that can override that behavior:
    `def_inverse_unary` and `def_inverse_and_ildj`.

    ## `def_inverse_unary`

    `def_inverse_unary` is applicable if `f` is a unary function.
    `def_inverse_unary` takes in an optional `f_inv` function, which is a unary
    function from the output of `f` to the input of `f`.

    Example:
    ```python
    @custom_inverse
    def add_one(x):
      return x + 1.
    add_one.def_inverse_unary(lambda x: x * 2)  # Define silly custom inverse.
    inverse(add_one)(2.)  # ==> 4.
    ```

    With a unary `f_inv` function, Oryx will automatically compute an inverse
    log-det Jacobian using `core.ildj(core.inverse(f_inv))`, but a user can
    also override the Jacobian term by providing the optional `f_ildj` keyword
    argument to `def_inverse_unary`.

    Example:
    ```python
    @custom_inverse
    def add_one(x):
      return x + 1.
    add_one.def_inverse_unary(lambda x: x * 2, f_ildj=lambda x: jnp.ones_like(x))
    ildj(add_one)(2.)  # ==> 1.
    ```

    ## `def_inverse_and_ildj`

    A more general way of defining a custom inverse or ILDJ is to use
    `def_inverse_and_ildj`, which will enable the user to invert functions with
    partially known inputs and outputs. Take an example like
    `add = lambda x, y: x + y`, which cannot be inverted with just the output,
    but can be inverted when just one input is known. `def_inverse_and_ildj`
    takes a single function `f_ildj` as an argument. `f_ildj` is a function of
    `invals` (a set of values corresponding to `f`'s inputs), `outvals` (a set
    of values corresponding to `f`'s outputs) and `out_ildjs` (a set of inverse
    diagonal log-Jacobian values for each of the `outvals`). If any are unknown,
    they will be `None`. `f_ildj` should return a tuple
    `(new_invals, new_inildjs)` which corresponds to known values of the inputs
    and any corresponding diagonal Jacobian values (which should be the same
    shape as `invals`). If these values cannot be computed (e.g. too many values
    are `None`) the user can raise a `NonInvertibleError` which will signal to
    Oryx to give up trying to invert the function for this set of values.

    Example:
    ```python
    @custom_inverse
    def add(x, y):
      return x + y

    def add_ildj(invals, outvals, out_ildjs):
      x, y = invals
      z = outvals
      z_ildj = out_ildjs
      if x is None and y is None:
        raise NonInvertibleError()
      if x is None:
        return (z - y, y), (z_ildj + jnp.zeros_like(z), jnp.zeros_like(z))
      if y is None:
        return (x, z - x), (jnp.zeros_like(z), z_ildj + jnp.zeros_like(z))

    add.def_inverse_and_ildj(add_ildj)
    inverse(partial(add, 1.))(2.)  # ==> 1.
    ildj(partial(add, 1.))(2.)  # ==> 0.
    ```

    Args:
      f: a function for which we'd like to define a custom inverse.

    Returns:
      A `CustomInverse` object whose inverse can be overridden with
      `def_inverse_unary` or `def_inverse_and_ildj`.
    """
    # The backing higher-order primitive is named after `f`.
    return CustomInverse(f, primitive.HigherOrderPrimitive(f.__name__))
Example #4
0
    def _component_specs(self):
        return self._param_specs

    def _serialize(self):
        # Include default version 1 for now
        return 1, self._clsid, self._param_specs, self._kwargs

    @classmethod
    def _deserialize(cls, encoded):
        """Rebuilds an instance from a `(version, clsid, param_specs, kwargs)` tuple.

        Args:
          encoded: a 4-tuple `(version, clsid, param_specs, kwargs)`.

        Returns:
          `cls(clsid, param_specs, kwargs)`.

        Raises:
          ValueError: if `version` is not 1, or if `clsid` is not in
            `_registry`.
        """
        version, clsid, param_specs, kwargs = encoded
        if version != 1:
            # Previously a bare `ValueError` with no message.
            raise ValueError(f'Unsupported encoding version: {version!r}')
        if clsid not in _registry:
            # Keep `clsid` as the sole argument, as before.
            raise ValueError(clsid)
        return cls(clsid, param_specs, kwargs)


# Higher-order primitive named 'bijector'; `bijector_ildj_rule` below is its
# inverse/ILDJ rule.
bijector_p = primitive.HigherOrderPrimitive('bijector')


class _CellProxy:
    """Used for avoid recursing into cells when doing Pytree flattening/unflattening."""
    def __init__(self, cell):
        self.cell = cell


def bijector_ildj_rule(incells, outcells, **params):
    """Inverse/ILDJ rule for bijectors."""
    incells = incells[1:]
    num_consts = len(incells) - params['num_args']
    const_incells, flat_incells = jax_util.split_list(incells, [num_consts])
    flat_inproxies = safe_map(_CellProxy, flat_incells)
    in_tree = params['in_tree']
Example #5
0
end.
"""
import functools

from jax import lax
from jax import random

from oryx.core import primitive
from oryx.core.interpreters import log_prob
from oryx.core.interpreters import propagate

# Public API of this module.
__all__ = [
    'make_plate',
]

# Higher-order primitive named 'plate'.
plate_p = primitive.HigherOrderPrimitive('plate')


def plate_log_prob_rule(incells, outcells, *, plate, **params):
    """log_prob propagation rule for `plate_p`.

    Delegates to `propagate.call_rule` for `plate_p`, then sums the returned
    log-probability over the named axis `plate` with `lax.psum`.

    Args:
      incells: input cells.
      outcells: output cells.
      plate: axis name, forwarded to `call_rule` and used for the `psum`.
      **params: remaining primitive parameters, forwarded to `call_rule`.

    Returns:
      `(incells, outcells, lp)` as updated by `call_rule`, with `lp` summed
      over `plate`.
    """
    new_incells, new_outcells, local_lp = propagate.call_rule(
        plate_p, incells, outcells, plate=plate, **params)
    summed_lp = lax.psum(local_lp, plate)
    return new_incells, new_outcells, summed_lp


# Register `plate_log_prob_rule` as the log_prob rule for `plate_p`, and add
# `plate_p` to `log_prob.log_prob_registry`.
log_prob.log_prob_rules[plate_p] = plate_log_prob_rule
log_prob.log_prob_registry.add(plate_p)