def _layer_to_args(layer, *args, **kwargs):
  """Flattens a layer and handles rng in kwargs.

  Args:
    layer: the layer to flatten with `tree_util.tree_flatten`.
    *args: positional arguments; an rng is prepended to them when one is
      supplied via `kwargs`.
    **kwargs: keyword arguments; an `rng` entry, if present and not None, is
      removed from them and moved to the front of `args`.

  Returns:
    A tuple `(flattened_layer, args, kwargs)`:
      flattened_layer: the list of leaves of `layer`.
      args: the positional arguments, with the rng as the first element when
        an rng was supplied.
      kwargs: a copy of the keyword arguments without `rng`, extended with
        `num_weights` (the number of leaves of `layer`), `in_tree` (the tree
        structure of `layer`) and `has_rng` (whether an rng was supplied).
  """
  kwargs = kwargs.copy()
  flattened_layer, in_tree = tree_util.tree_flatten(layer)
  # Metadata describing the flattened layer, so it can later be split off from
  # the other flat values and rebuilt with `tree_util.tree_unflatten`.
  kwargs['num_weights'] = len(flattened_layer)
  kwargs['in_tree'] = in_tree
  # An explicit `rng=None` is treated the same as not passing an rng at all.
  rng = kwargs.pop('rng', None)
  kwargs['has_rng'] = has_rng = rng is not None
  if has_rng:
    args = (rng, ) + args
  return flattened_layer, args, kwargs


# Higher-order primitive named 'layer_cau'; `custom_layer_cau_batch` below binds
# it directly when none of its inputs are batched.
layer_cau_p = primitive.HigherOrderPrimitive('layer_cau')


class NoneProxy:
  """Empty marker class; its only visible use is the `not_mapped` instance."""
  pass


# Module-level sentinel instance. NOTE(review): this is a different object from
# `batching.not_mapped`, which `custom_layer_cau_batch` compares against;
# presumably it stands in where `None` would be ambiguous — confirm with the
# code that consumes it (not visible in this chunk).
not_mapped = NoneProxy()


def custom_layer_cau_batch(trace, f, tracers, params):
  """Batching rule for layer_cau primitive to handle custom layers.

  Args:
    trace: the batching trace (not used in the visible part of this function).
    f: forwarded as the first argument of `layer_cau_p.bind`.
    tracers: objects exposing `.val` and `.batch_dim`.
    params: primitive parameters, forwarded to `bind`; must contain `in_tree`.

  Returns:
    When no input is batched, the result of `layer_cau_p.bind` on the
    unbatched values.
  """
  vals, dims = jax_util.unzip2((t.val, t.batch_dim) for t in tracers)
  # Nothing is batched: bind the primitive directly on the underlying values.
  if all(dim is batching.not_mapped for dim in dims):
    return layer_cau_p.bind(f, *vals, **params)
  args = tree_util.tree_unflatten(params['in_tree'], vals)
  # NOTE(review): the function ends here in this chunk — `args` is computed but
  # never used and the batched path implicitly returns None. The rest of the
  # batching logic appears to be missing; confirm against the original source.
# NOTE(review): the three functions below take `self`/`cls` and appear to be
# methods of a class whose header is not part of this chunk.
def _component_specs(self):
  """Returns the stored parameter specs (`self._param_specs`)."""
  return self._param_specs

def _serialize(self):
  """Encodes this object as the tuple `(version, clsid, param_specs, kwargs)`."""
  # Include default version 1 for now
  return 1, self._clsid, self._param_specs, self._kwargs

@classmethod
def _deserialize(cls, encoded):
  """Rebuilds an instance from a tuple in the format produced by `_serialize`.

  Args:
    encoded: a 4-tuple `(version, clsid, param_specs, kwargs)`.

  Returns:
    `cls(clsid, param_specs, kwargs)`.

  Raises:
    ValueError: if `version` is not 1, or if `clsid` is not in `_registry`.
  """
  version, clsid, param_specs, kwargs = encoded
  # Only encoding version 1 (the one `_serialize` writes) is understood.
  if version != 1:
    raise ValueError
  # `_registry` is defined outside this chunk; presumably a collection of
  # registered class ids — confirm.
  if clsid not in _registry:
    raise ValueError(clsid)
  return cls(clsid, param_specs, kwargs)


# Higher-order primitive for random variables. It is added to
# `unzip.block_registry` here and given a log_prob rule below.
random_variable_p = primitive.HigherOrderPrimitive('random_variable')
unzip.block_registry.add(random_variable_p)


def random_variable_log_prob_rule(flat_incells, flat_outcells, **params):
  """log_prob rule for `random_variable_p` (registered right below).

  Drops the first incell, passes the outcells through unchanged and returns
  `None` as the third element. NOTE(review): presumably `None` means this rule
  contributes no log-prob term of its own — confirm against the log_prob
  interpreter.
  """
  del params
  # First incell is the call primitive function
  return flat_incells[1:], flat_outcells, None
log_prob.log_prob_rules[random_variable_p] = random_variable_log_prob_rule


def random_variable_log_prob(flat_incells, val, **params):
  """Registers Oryx distributions with the log_prob transformation."""
  # NOTE(review): only the docstring of this function is present in this chunk,
  # so as written it always returns None; the implementation appears to be
  # missing — confirm against the original source.
def custom_inverse(f):
  """Decorates a function to enable defining a custom inverse.

  A `custom_inverse`-decorated function is semantically identical to the
  original except when it is inverted with `core.inverse`. By default,
  `core.inverse(custom_inverse(f))` will programmatically invert the body of
  `f`, but `f` has two additional methods that can override that behavior:
  `def_inverse_unary` and `def_inverse_and_ildj`.

  ## `def_inverse_unary`

  `def_inverse_unary` is applicable if `f` is a unary function.
  `def_inverse_unary` takes in an optional `f_inv` function, which is a unary
  function from the output of `f` to the input of `f`.

  Example:
  ```python
  @custom_inverse
  def add_one(x):
    return x + 1.
  add_one.def_inverse_unary(lambda x: x * 2)  # Define silly custom inverse.
  inverse(add_one)(2.)  # ==> 4.
  ```

  With a unary `f_inv` function, Oryx will automatically compute an inverse
  log-det Jacobian using `core.ildj(core.inverse(f_inv))`, but a user can also
  override the Jacobian term by providing the optional `f_ildj` keyword
  argument to `def_inverse_unary`.

  Example:
  ```python
  @custom_inverse
  def add_one(x):
    return x + 1.
  add_one.def_inverse_unary(lambda x: x * 2,
                            f_ildj=lambda x: jnp.ones_like(x))
  ildj(add_one)(2.)  # ==> 1.
  ```

  ## `def_inverse_and_ildj`

  A more general way of defining a custom inverse or ILDJ is to use
  `def_inverse_and_ildj`, which enables the user to invert functions with
  partially known inputs and outputs. Take an example like
  `add = lambda x, y: x + y`, which cannot be inverted with just the output,
  but can be inverted when just one input is known. `def_inverse_and_ildj`
  takes a single function `f_ildj` as an argument. `f_ildj` is a function
  from `invals` (a set of values corresponding to `f`'s inputs), `outvals`
  (a set of values corresponding to `f`'s outputs) and `out_ildjs` (a set of
  inverse diagonal log-Jacobian values for each of the `outvals`). If any are
  unknown, they will be `None`. `f_ildj` should return a tuple
  `(new_invals, new_inildjs)` which corresponds to known values of the inputs
  and any corresponding diagonal Jacobian values (which should be the same
  shape as `invals`). If these values cannot be computed (e.g. too many values
  are `None`) the user can raise a `NonInvertibleError`, which signals to Oryx
  to give up trying to invert the function for this set of values.

  Example:
  ```python
  @custom_inverse
  def add(x, y):
    return x + y

  def add_ildj(invals, outvals, out_ildjs):
    x, y = invals
    z = outvals
    z_ildj = out_ildjs
    if x is None and y is None:
      raise NonInvertibleError()
    if x is None:
      return (z - y, y), (z_ildj + jnp.zeros_like(z), jnp.zeros_like(z))
    if y is None:
      return (x, z - x), (jnp.zeros_like(z), z_ildj + jnp.zeros_like(z))
  add.def_inverse_and_ildj(add_ildj)
  inverse(partial(add, 1.))(2.)  # ==> 1.
  inverse(partial(add, 2.))(2.)  # ==> 0.
  ```

  Args:
    f: a function for which we'd like to define a custom inverse.

  Returns:
    A `CustomInverse` object whose inverse can be overridden with
    `def_inverse_unary` or `def_inverse_and_ildj`.
  """
  # `functools.partial` objects (and other callables without `__name__`) would
  # otherwise raise AttributeError here; fall back to a generic primitive name.
  name = getattr(f, '__name__', 'custom_inverse')
  return CustomInverse(f, primitive.HigherOrderPrimitive(name))
# NOTE(review): the three functions below take `self`/`cls` and appear to be
# methods of a class whose header is not part of this chunk.
def _component_specs(self):
  """Returns the stored parameter specs (`self._param_specs`)."""
  return self._param_specs

def _serialize(self):
  """Encodes this object as the tuple `(version, clsid, param_specs, kwargs)`."""
  # Include default version 1 for now
  return 1, self._clsid, self._param_specs, self._kwargs

@classmethod
def _deserialize(cls, encoded):
  """Rebuilds an instance from a tuple in the format produced by `_serialize`.

  Args:
    encoded: a 4-tuple `(version, clsid, param_specs, kwargs)`.

  Returns:
    `cls(clsid, param_specs, kwargs)`.

  Raises:
    ValueError: if `version` is not 1, or if `clsid` is not in `_registry`.
  """
  version, clsid, param_specs, kwargs = encoded
  # Only encoding version 1 (the one `_serialize` writes) is understood.
  if version != 1:
    raise ValueError
  # `_registry` is defined outside this chunk; presumably a collection of
  # registered class ids — confirm.
  if clsid not in _registry:
    raise ValueError(clsid)
  return cls(clsid, param_specs, kwargs)


# Higher-order primitive named 'bijector'.
bijector_p = primitive.HigherOrderPrimitive('bijector')


class _CellProxy:
  """Used to avoid recursing into cells when doing pytree flattening/unflattening.

  The class is not registered as a pytree node, so pytree utilities treat a
  wrapped cell as a single opaque leaf.
  """

  def __init__(self, cell):
    # The wrapped cell, retrievable after flattening/unflattening.
    self.cell = cell


def bijector_ildj_rule(incells, outcells, **params):
  """Inverse/ILDJ rule for bijectors."""
  # Drop the first incell. NOTE(review): presumably the called function of the
  # higher-order primitive — confirm.
  incells = incells[1:]
  # The first `num_consts` incells are constants; the last `num_args` are the
  # function arguments.
  num_consts = len(incells) - params['num_args']
  const_incells, flat_incells = jax_util.split_list(incells, [num_consts])
  # Wrap each argument cell so pytree utilities treat it as an opaque leaf.
  flat_inproxies = safe_map(_CellProxy, flat_incells)
  in_tree = params['in_tree']
  # NOTE(review): the function ends here in this chunk — the locals above are
  # never used and nothing is returned; the remainder of the rule appears to be
  # missing. Confirm against the original source.
end. """ import functools from jax import lax from jax import random from oryx.core import primitive from oryx.core.interpreters import log_prob from oryx.core.interpreters import propagate __all__ = [ 'make_plate', ] plate_p = primitive.HigherOrderPrimitive('plate') def plate_log_prob_rule(incells, outcells, *, plate, **params): incells, outcells, lp = propagate.call_rule(plate_p, incells, outcells, plate=plate, **params) return incells, outcells, lax.psum(lp, plate) log_prob.log_prob_rules[plate_p] = plate_log_prob_rule log_prob.log_prob_registry.add(plate_p)