Ejemplo n.º 1
0
    def __init__(
            self,
            values_with_sample_dim,
            batch_ndims=0,
            validate_args=False,
            name=None):
        """Builds an empirical distribution from a structure of samples.

        Args:
          values_with_sample_dim: nested structure of `Tensor`s, each of
            shape prefixed by `[num_samples, B1, ..., Bn]`, where
            `num_samples` as well as `B1, ..., Bn` are batch dimensions
            shared across all `Tensor`s.
          batch_ndims: optional scalar int `Tensor`, or structure matching
            `values_with_sample_dim` of scalar int `Tensor`s, specifying the
            number of batch dimensions. A non-nested value is replicated for
            every part of `values_with_sample_dim`. Used to determine the
            batch and event shapes of the distribution.
            Default value: `0`.
          validate_args: Python `bool` indicating whether to perform runtime
            checks that may have performance cost.
            Default value: `False`.
          name: Python `str` name for ops created by this distribution.
            Falls back to `'DeterministicEmpirical'` when falsy.
        """
        # Snapshot the constructor arguments before any other local is bound,
        # so that only the arguments end up in `parameters`.
        parameters = dict(locals())
        with tf.name_scope(name or 'DeterministicEmpirical') as name:

            # Work on copies so that external mutation of the passed-in
            # structures cannot break this instance.
            values_with_sample_dim = _copy_structure(values_with_sample_dim)
            batch_ndims = _copy_structure(batch_ndims)

            # tf.Module would wrap the passed-in values, and that wrapper
            # breaks JointDistributionNamed (and maybe other JDs). So the
            # values are stored twice: one reference that exists only for
            # tf.Module tracking, and one stored through `_no_dependency`,
            # which is the one this class reads.
            self._values_for_tracking = values_with_sample_dim
            self._values_with_sample_dim = self._no_dependency(
                values_with_sample_dim)

            # Replicate a non-nested `batch_ndims` so that it mirrors the
            # structure of the values.
            if not tf.nest.is_nested(batch_ndims):
                shared_batch_ndims = batch_ndims
                batch_ndims = tf.nest.map_structure(
                    lambda _: shared_batch_ndims, values_with_sample_dim)
            self._batch_ndims = batch_ndims

            # Minimum of `prefer_static.size0` over all parts (presumably the
            # size of each part's leading, i.e. sample, dimension).
            leading_sizes = [
                prefer_static.size0(part)
                for part in tf.nest.flatten(values_with_sample_dim)
            ]
            self._max_num_samples = prefer_static.reduce_min(leading_sizes)

            part_dtypes = tf.nest.map_structure(
                lambda part: part.dtype, self.values_with_sample_dim)
            super(DeterministicEmpirical, self).__init__(
                dtype=part_dtypes,
                reparameterization_type=(
                    reparameterization.FULLY_REPARAMETERIZED),
                validate_args=validate_args,
                allow_nan_stats=True,
                name=name)
            self._parameters = self._no_dependency(parameters)
Ejemplo n.º 2
0
def _update_forward_min_event_ndims(
    bij,
    downstream_quantities,
    get_forward_min_event_ndims=lambda b: b.forward_min_event_ndims,
    get_inverse_min_event_ndims=lambda b: b.inverse_min_event_ndims,
    inverse_event_ndims_fn=lambda b, nd: b.inverse_event_ndims(nd)):
  """Takes one backward step through the graph to infer min event ndims.

  Combines the event_ndims required by the bijector(s) downstream of `bij`
  with the requirements of `bij` itself, then pulls the result back through
  `bij` to obtain event_ndims that are valid at its input.

  Args:
    bij: local tfb.Bijector instance at the current graph node.
    downstream_quantities: Instance of
      `MinEventNdimsInferenceDownstreamQuantities` namedtuple, containing
      event_ndims that satisfy the bijector(s) downstream from `bij` in the
      graph. May be `None` if there are no such bijectors.
    get_forward_min_event_ndims: callable taking a bijector; may be
      overridden to swap forward/inverse direction.
    get_inverse_min_event_ndims: callable taking a bijector; may be
      overridden to swap forward/inverse direction.
    inverse_event_ndims_fn: callable taking a bijector and event_ndims; may
      be overridden to swap forward/inverse direction.

  Returns:
    downstream_quantities: Instance of
      `MinEventNdimsInferenceDownstreamQuantities` namedtuple containing
      event_ndims that satisfy `bij` and all downstream bijectors.
  """
  if downstream_quantities is None:
    # Leaf bijector: only its own requirements apply.
    return MinEventNdimsInferenceDownstreamQuantities(
        forward_min_event_ndims=get_forward_min_event_ndims(bij),
        parts_interact=bij._parts_interact)  # pylint: disable=protected-access

  def _offset_every_part(offset, structure):
    """Adds the same `offset` to each entry of `structure`."""
    return tf.nest.map_structure(lambda nd: offset + nd, structure)

  local_ndims = get_inverse_min_event_ndims(bij)
  downstream_ndims = nest_util.coerce_structure(
      local_ndims, downstream_quantities.forward_min_event_ndims)

  # Per-part difference between what downstream bijectors accept and what
  # this bijector can emit (equivalently, what `bij.inverse` accepts).
  # Negative entries mean the downstream requirement is lower than the local
  # one.
  excess_ranks = tf.nest.flatten(
      tf.nest.map_structure(
          lambda downstream_nd, local_nd: downstream_nd - local_nd,
          downstream_ndims,
          local_ndims))

  if downstream_quantities.parts_interact:
    # Downstream bijectors mix their parts, so every part of
    # `downstream_ndims` must be raised by the *same* rank (otherwise we
    # would induce event shape broadcasting). Hopefully this also avoids
    # event-shape broadcasting at the current bijector; if it does not, the
    # composition is invalid and the `inverse_event_ndims_fn` call below is
    # expected to raise.
    largest_deficiency = -ps.reduce_min([0] + excess_ranks)
    valid_output_ndims = _offset_every_part(
        largest_deficiency, downstream_ndims)
  elif bij._parts_interact:  # pylint: disable=protected-access
    # This bijector mixes its parts, so a valid input to `inverse` cannot
    # require event shape broadcasting: each part gets the same excess rank
    # above the local inverse_min_event_ndims, by construction.
    largest_excess = ps.reduce_max([0] + excess_ranks)
    valid_output_ndims = _offset_every_part(largest_excess, local_ndims)
  else:
    # All parts are independent, so the pointwise maximum suffices.
    valid_output_ndims = tf.nest.map_structure(
        ps.maximum, downstream_ndims, local_ndims)

  # Pull the desired output ndims back through the bijector to get the ndims
  # of a valid *input*.
  valid_input_ndims = inverse_event_ndims_fn(bij, valid_output_ndims)
  any_parts_interact = (
      downstream_quantities.parts_interact or
      bij._parts_interact)  # pylint: disable=protected-access
  return MinEventNdimsInferenceDownstreamQuantities(
      forward_min_event_ndims=valid_input_ndims,
      parts_interact=any_parts_interact)
Ejemplo n.º 3
0
 def sanitize_event_ndims(event_ndims):
     """Updates `rolling_offset` when event_ndims are negative."""
     nonlocal rolling_offset
     max_missing_ndims = -ps.reduce_min(nest.flatten(event_ndims))
     rolling_offset = ps.maximum(rolling_offset, max_missing_ndims)
     return event_ndims