예제 #1
0
 def _call_inverse_log_det_jacobian(self, y, event_ndims, name, **kwargs):
   """Computes the inverse log-det-Jacobian of the whole composition.

   Converts `y` to a (possibly nested) tensor using this bijector's forward
   dtype, coerces `event_ndims` to the structure of
   `self.inverse_min_event_ndims`, and delegates to
   `self._inverse_log_det_jacobian`.

   Args:
     y: Input to the inverse transformation; may be a nested structure.
     event_ndims: Event rank(s) of `y`; coerced to the structure of
       `self.inverse_min_event_ndims`.
     name: Name passed to `self._name_and_control_scope`.
     **kwargs: Forwarded to `self.forward_dtype` and
       `self._inverse_log_det_jacobian`.

   Returns:
     The value returned by `self._inverse_log_det_jacobian`.
   """
   with self._name_and_control_scope(name):
     target_dtype = self.forward_dtype(**kwargs)
     # With dtype checks disabled, the dtype is only a hint, not enforced.
     strict_dtype = None if bijector.SKIP_DTYPE_CHECKS else target_dtype
     y_tensor = nest_util.convert_to_nested_tensor(
         y,
         name='y',
         dtype_hint=target_dtype,
         dtype=strict_dtype,
         allow_packing=True)
     structured_ndims = nest_util.coerce_structure(
         self.inverse_min_event_ndims, event_ndims)
     return self._inverse_log_det_jacobian(
         y_tensor, structured_ndims, **kwargs)
예제 #2
0
 def update_i_event_ndims(bij, event_ndims):
     """Propagates `event_ndims` through `bij` in the forward direction.

     Args:
       bij: Bijector to step through.
       event_ndims: Event ndims of an input to `bij.forward`; coerced to the
         structure of `bij.forward_min_event_ndims`.

     Returns:
       The event ndims of the corresponding output of `bij.forward`.
     """
     forward_ndims = nest_util.coerce_structure(
         bij.forward_min_event_ndims, event_ndims)

     # Statically known minima: delegate to `update_event_ndims` with the
     # forward minima as input side and inverse minima as output side.
     if bij.has_static_min_event_ndims:
         return update_event_ndims(
             input_event_ndims=forward_ndims,
             input_min_event_ndims=bij.forward_min_event_ndims,
             output_min_event_ndims=bij.inverse_min_event_ndims)

     # Compositions recurse into their component bijectors via a forward walk.
     if isinstance(bij, composition.Composition):
         return bij._call_walk_forward(update_i_event_ndims, forward_ndims)  # pylint: disable=protected-access

     # Otherwise rely on the bijector's own `forward_event_ndims`.
     return sanitize_event_ndims(bij.forward_event_ndims(forward_ndims))
예제 #3
0
 def _call_inverse_log_det_jacobian(self, y, event_ndims, name, **kwargs):
   """Computes the inverse log-det-Jacobian of the whole composition.

   Converts `y` to a (possibly nested) tensor using this bijector's forward
   dtype, resolves a missing `event_ndims`, coerces it to the structure of
   `self.inverse_min_event_ndims`, and delegates to
   `self._inverse_log_det_jacobian`.

   Args:
     y: Input to the inverse transformation; may be a nested structure.
     event_ndims: Event rank(s) of `y`, or `None` to use
       `self.inverse_min_event_ndims` (only allowed when the minimum event
       ndims are static).
     name: Name passed to `self._name_and_control_scope`.
     **kwargs: Forwarded to `self.forward_dtype` and
       `self._inverse_log_det_jacobian`.

   Returns:
     The value returned by `self._inverse_log_det_jacobian`.

   Raises:
     ValueError: if `event_ndims` is `None` and the minimum event ndims are
       not static.
   """
   with self._name_and_control_scope(name):
     target_dtype = self.forward_dtype(**kwargs)
     # With dtype checks disabled, the dtype is only a hint, not enforced.
     strict_dtype = None if bijector.SKIP_DTYPE_CHECKS else target_dtype
     y_tensor = nest_util.convert_to_nested_tensor(
         y,
         name='y',
         dtype_hint=target_dtype,
         dtype=strict_dtype,
         allow_packing=True)
     # A default rank can only be inferred when min_event_ndims are static.
     if event_ndims is None and not self._has_static_min_event_ndims:
       raise ValueError('Composition bijector with non-static '
                        '`min_event_ndims` does not support '
                        '`event_ndims=None`. Please pass a value '
                        'for `event_ndims`.')
     if event_ndims is None:
       event_ndims = self.inverse_min_event_ndims
     structured_ndims = nest_util.coerce_structure(
         self.inverse_min_event_ndims, event_ndims)
     return self._inverse_log_det_jacobian(
         y_tensor, structured_ndims, **kwargs)
예제 #4
0
def _update_forward_min_event_ndims(
    bij,
    downstream_quantities,
    get_forward_min_event_ndims=lambda b: b.forward_min_event_ndims,
    get_inverse_min_event_ndims=lambda b: b.inverse_min_event_ndims,
    inverse_event_ndims_fn=lambda b, nd: b.inverse_event_ndims(nd)):
  """Infers `forward_min_event_ndims` by stepping backwards through `bij`.

  Args:
    bij: local tfb.Bijector instance at the current graph node.
    downstream_quantities: `MinEventNdimsInferenceDownstreamQuantities`
      namedtuple holding event_ndims that satisfy the bijector(s) downstream
      of `bij` in the graph, or `None` if there are no such bijectors.
    get_forward_min_event_ndims: callable; may be overridden to swap
      forward/inverse direction.
    get_inverse_min_event_ndims: callable; may be overridden to swap
      forward/inverse direction.
    inverse_event_ndims_fn: callable; may be overridden to swap
      forward/inverse direction.
  Returns:
    A `MinEventNdimsInferenceDownstreamQuantities` namedtuple holding
    event_ndims that satisfy `bij` and every downstream bijector, plus whether
    any of them lets parts interact.
  """
  # Leaf bijector: nothing downstream constrains it.
  if downstream_quantities is None:
    return MinEventNdimsInferenceDownstreamQuantities(
        forward_min_event_ndims=get_forward_min_event_ndims(bij),
        parts_interact=bij._parts_interact)  # pylint: disable=protected-access

  local_min_ndims = get_inverse_min_event_ndims(bij)
  needed_ndims = nest_util.coerce_structure(
      local_min_ndims, downstream_quantities.forward_min_event_ndims)

  # Per-part gap between what downstream bijectors need and what this
  # bijector minimally outputs. The chosen ndims must be a valid input for
  # downstream bijectors AND a valid output of `bij` (i.e. a valid input to
  # its inverse).
  ndims_gaps = tf.nest.flatten(
      tf.nest.map_structure(
          lambda needed, minimal: needed - minimal,
          needed_ndims,
          local_min_ndims))

  downstream_interacts = downstream_quantities.parts_interact
  if downstream_interacts:
    # Interacting downstream parts must all be raised by the same amount,
    # otherwise event-shape broadcasting would be induced. If this still
    # broadcasts at `bij`, the composition is invalid and the call to
    # `inverse_event_ndims_fn` below is expected to raise.
    deficiency = -ps.reduce_min([0] + ndims_gaps)
    valid_output_ndims = tf.nest.map_structure(
        lambda nd: deficiency + nd, needed_ndims)
  elif bij._parts_interact:  # pylint: disable=protected-access
    # This bijector mixes its parts, so its inverse input must not require
    # event-shape broadcasting: give every part the same excess rank above
    # the local inverse_min_event_ndims.
    excess = ps.reduce_max([0] + ndims_gaps)
    valid_output_ndims = tf.nest.map_structure(
        lambda nd: excess + nd, local_min_ndims)
  else:
    # Fully independent parts: the pointwise max is sufficient.
    valid_output_ndims = tf.nest.map_structure(
        ps.maximum, needed_ndims, local_min_ndims)

  # Pull the desired output ndims back through `bij` to get a valid *input*.
  pulled_back_ndims = inverse_event_ndims_fn(bij, valid_output_ndims)
  return MinEventNdimsInferenceDownstreamQuantities(
      forward_min_event_ndims=pulled_back_ndims,
      parts_interact=(
          downstream_interacts or
          bij._parts_interact))  # pylint: disable=protected-access
예제 #5
0
  def _call_walk_inverse(self, step_fn, *args, **kwargs):
    """Prepares args and calls `_walk_inverse`.

    Converts a tuple of structured positional arguments to a structure of
    argument tuples, and wraps `step_fn` to unpack inputs and re-pack
    returned values. This way, users may invoke walks using `map_structure`
    semantics, and the concrete `_walk` implementations can operate on a
    single structure of inputs (without worrying about tuple unpacking).

    For example, the `inverse` method looks roughly like:
    ```python

    MyComposition()._call_walk_inverse(
        lambda bij, y, **kwargs: bij.inverse(y, **kwargs),
        composite_inputs, **composite_kwargs)
    ```

    More complex methods may need to mutate external state from `step_fn`:
    ```python

    shape_trace = {}

    def trace_step(bijector, y_shape):
      shape_trace[bijector.name] = y_shape
      return bijector.inverse_event_shape(y_shape)

    # Calling this populates the `shape_trace` dictionary.
    MyComposition()._call_walk_inverse(trace_step, composite_y_shape)
    ```

    Args:
      step_fn: Callable applied to each wrapped bijector.
        Must accept a bijector instance followed by `len(args)` positional
        arguments whose structures match `bijector.inverse_min_event_ndims`.
        With a single positional argument, `step_fn` is passed to
        `_walk_inverse` unwrapped and returns one structure matching
        `bijector.forward_min_event_ndims`; with several, it must return a
        sequence of `len(args)` such structures.
      *args: Input arguments propagated to nested bijectors. Each one is
        first coerced to the structure of `self.inverse_min_event_ndims`.
      **kwargs: Keyword arguments forwarded to `_walk_inverse`.
    Returns:
      The transformed output. If multiple positional arguments are provided, a
        tuple of matching length will be returned.
    """
    args = tuple(nest_util.coerce_structure(self.inverse_min_event_ndims, y)
                 for y in args)

    # Single input: no packing is needed, so `step_fn` is used directly.
    if len(args) == 1:
      return self._walk_inverse(step_fn, *args, **kwargs)

    # Convert a tuple of structures to a structure of tuples. This
    # allows `_walk` methods to route aligned structures of inputs/outputs
    # independently, obviating the need for conditional tuple unpacking.
    packed_args = pack_structs_like(self.inverse_min_event_ndims, *args)

    def transform_wrapper(bij, packed_ys, **nested):
      # Unpack per-bijector inputs, apply the user step, and re-pack outputs
      # against the bijector's forward structure.
      ys = unpack_structs_like(bij.inverse_min_event_ndims, packed_ys)
      xs = step_fn(bij, *ys, **nested)
      return pack_structs_like(bij.forward_min_event_ndims, *xs)

    packed_result = self._walk_inverse(
        transform_wrapper, packed_args, **kwargs)
    return unpack_structs_like(self.forward_min_event_ndims, packed_result)
예제 #6
0
 def testCoerceStructureRaises(self, target, source, message):
     """Checks `coerce_structure(target, source)` raises matching `message`.

     Either a `ValueError` or a `TypeError` is accepted.
     """
     accepted_errors = (ValueError, TypeError)
     with self.assertRaisesRegex(accepted_errors, message):
         nest_util.coerce_structure(target, source)
예제 #7
0
 def testCoerceStructure(self, target, source, expect):
     """Checks `coerce_structure(target, source)` equals `expect` (nested)."""
     self.assertAllEqualNested(
         expect, nest_util.coerce_structure(target, source))
예제 #8
0
def _factored_surrogate_posterior(  # pylint: disable=dangerous-default-value
        event_shape=None,
        bijector=None,
        batch_shape=(),
        base_distribution_cls=normal.Normal,
        initial_parameters={'scale': 1e-2},
        dtype=tf.float32,
        validate_args=False,
        name=None):
    """Builds a joint variational posterior that factors over model variables.

    By default, this method creates an independent trainable Normal
    distribution for each variable, transformed using a bijector (if provided)
    to match the support of that variable. This makes extremely strong
    assumptions about the posterior: that it is approximately normal (or
    transformed normal), and that all model variables are independent.

    Args:
      event_shape: `Tensor` shape, or nested structure of `Tensor` shapes,
        specifying the event shape(s) of the posterior variables.
      bijector: Optional `tfb.Bijector` instance, or nested structure of such
        instances, defining support(s) of the posterior variables. The
        structure must match that of `event_shape` and may contain `None`
        values (treated as identity). A posterior variable will be modeled as
        `tfd.TransformedDistribution(underlying_dist, bijector)` if a
        corresponding constraining bijector is specified, otherwise it is
        modeled as supported on the unconstrained real line.
      batch_shape: The `batch_shape` of the output distribution.
        Default value: `()`.
      base_distribution_cls: Subclass of `tfd.Distribution` that is
        instantiated and optionally transformed by the bijector to define the
        component distributions. May optionally be a structure of such
        subclasses matching `event_shape`.
        Default value: `tfd.Normal`.
      initial_parameters: Optional `str : Tensor` dictionary specifying
        initial values for some or all of the base distribution's trainable
        parameters, or a Python `callable` with signature
        `value = parameter_init_fn(parameter_name, shape, dtype, seed,
        constraining_bijector)`, passed to
        `tfp.experimental.util.make_trainable`. May optionally be a structure
        matching `event_shape` of such dictionaries and/or callables.
        Dictionary entries that do not correspond to parameter names are
        ignored.
        Default value: `{'scale': 1e-2}` (ignored when `base_distribution_cls`
          does not have a `scale` parameter).
      dtype: Optional float `dtype` for trainable parameters. May
        optionally be a structure of such `dtype`s matching `event_shape`.
        Default value: `tf.float32`.
      validate_args: Python `bool`. Whether to validate input with asserts.
        This imposes a runtime cost. If `validate_args` is `False`, and the
        inputs are invalid, correct behavior is not guaranteed.
        Default value: `False`.
      name: Python `str` name prefixed to ops created by this function.
        Default value: `None` (i.e., 'build_factored_surrogate_posterior').
    Yields:
      *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
        These are intended to be consumed by
        `trainable_state_util.as_stateful_builder` and
        `trainable_state_util.as_stateless_builder` to define stateful and
        stateless variants respectively.
    Returns:
      (As the generator's return value.) The surrogate posterior: an
      independent joint distribution over the unconstrained variables,
      wrapped in a `TransformedDistribution` when a bijector is given.

    ### Examples

    Consider a Gamma model with unknown parameters, expressed as a joint
    Distribution:

    ```python
    Root = tfd.JointDistributionCoroutine.Root
    def model_fn():
      concentration = yield Root(tfd.Exponential(1.))
      rate = yield Root(tfd.Exponential(1.))
      y = yield tfd.Sample(tfd.Gamma(concentration=concentration, rate=rate),
                           sample_shape=4)
    model = tfd.JointDistributionCoroutine(model_fn)
    ```

    Let's use variational inference to approximate the posterior over the
    data-generating parameters for some observed `y`. We'll build a
    surrogate posterior distribution by specifying the shapes of the latent
    `concentration` and `rate` parameters, and that both are constrained to
    be positive.

    ```python
    surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
      event_shape=model.event_shape_tensor()[:-1],  # Omit the observed `y`.
      bijector=[tfb.Softplus(),   # Concentration is positive.
                tfb.Softplus()])  # Rate is positive.
    ```

    This creates a trainable joint distribution, defined by variables in
    `surrogate_posterior.trainable_variables`. We use
    `fit_surrogate_posterior` to fit this distribution by minimizing a
    divergence to the true posterior.

    ```python
    y = [0.2, 0.5, 0.3, 0.7]
    losses = tfp.vi.fit_surrogate_posterior(
      lambda concentration, rate: model.log_prob([concentration, rate, y]),
      surrogate_posterior=surrogate_posterior,
      num_steps=100,
      optimizer=tf.optimizers.Adam(0.1),
      sample_size=10)

    # After optimization, samples from the surrogate will approximate
    # samples from the true posterior.
    samples = surrogate_posterior.sample(100)
    posterior_mean = [tf.reduce_mean(x) for x in samples]     # mean ~= [1.1, 2.1]
    posterior_std = [tf.math.reduce_std(x) for x in samples]  # std  ~= [0.3, 0.8]
    ```

    If we wanted to initialize the optimization at a specific location, we
    can specify initial parameters when we build the surrogate posterior.
    Note that these parameterize the distribution(s) over unconstrained
    values, so we need to transform our desired constrained locations using
    the inverse of the constraining bijector(s).

    ```python
    initial_loc = {'concentration': 0.4, 'rate': 0.2}
    surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
      event_shape=tf.nest.map_structure(tf.shape, initial_loc),
      bijector={'concentration': tfb.Softplus(),   # Concentration is positive.
                'rate': tfb.Softplus()},           # Rate is positive.
      initial_parameters={
        'concentration': {'loc': tfb.Softplus().inverse(0.4), 'scale': 1e-2},
        'rate': {'loc': tfb.Softplus().inverse(0.2), 'scale': 1e-2}})
    ```

    """
    with tf.name_scope(name or 'build_factored_surrogate_posterior'):
        # Convert event shapes to Tensors.
        shallow_structure = _get_event_shape_shallow_structure(event_shape)
        event_shape = nest.map_structure_up_to(
            shallow_structure,
            lambda s: tf.convert_to_tensor(s, dtype=tf.int32), event_shape)

        # A nested structure of bijectors is combined into one JointMap;
        # `None` entries become identity transforms.
        if nest.is_nested(bijector):
            event_space_bijector = joint_map.JointMap(
                nest.map_structure(
                    lambda b: identity.Identity() if b is None else b,
                    nest_util.coerce_structure(event_shape, bijector)),
                validate_args=validate_args)
        else:
            event_space_bijector = bijector

        # Trainable base distributions live in the unconstrained space, so
        # their shapes are the inverse image of the constrained event shapes.
        if event_space_bijector is None:
            unconstrained_event_shape = event_shape
        else:
            unconstrained_event_shape = (
                event_space_bijector.inverse_event_shape_tensor(event_shape))
        unconstrained_batch_and_event_shape = tf.nest.map_structure(
            lambda s: ps.concat([batch_shape, s], axis=0),
            unconstrained_event_shape)

        base_distribution_cls = nest_util.broadcast_structure(
            event_shape, base_distribution_cls)
        try:
            # Check that we have initial parameters for each event part.
            nest.assert_shallow_structure(event_shape, initial_parameters)
        except (ValueError, TypeError):
            # If not, broadcast the parameters to match the event structure.
            # We do this manually rather than using `nest_util.broadcast_structure`
            # because the initial parameters can themselves be structures (dicts).
            # Note: this function does not itself mutate `initial_parameters`
            # (including the shared default dict); it only passes it along.
            initial_parameters = nest.map_structure(
                lambda x: initial_parameters, event_shape)

        unconstrained_trainable_distributions = yield from (
            nest_util.map_structure_coroutine(
                trainable._make_trainable,  # pylint: disable=protected-access
                cls=base_distribution_cls,
                initial_parameters=initial_parameters,
                batch_and_event_shape=unconstrained_batch_and_event_shape,
                parameter_dtype=nest_util.broadcast_structure(
                    event_shape, dtype),
                _up_to=event_shape))
        unconstrained_trainable_distribution = (
            joint_distribution_util.
            independent_joint_distribution_from_structure(
                unconstrained_trainable_distributions,
                batch_ndims=ps.rank_from_shape(batch_shape),
                validate_args=validate_args))
        if event_space_bijector is None:
            return unconstrained_trainable_distribution
        return transformed_distribution.TransformedDistribution(
            unconstrained_trainable_distribution, event_space_bijector)