Example #1
    def step(bij, x, x_event_ndims, increased_dof, **kwargs):  # pylint: disable=missing-docstring
      # Transform inputs for the next bijector.
      y = forward_fn(bij, x, **kwargs) if compute_x_values else None
      y_event_ndims = forward_event_ndims_fn(bij, x_event_ndims, **kwargs)

      # Check if the inputs to this bijector have increased degrees of freedom
      # due to some upstream bijector. We assume that the upstream bijector
      # produced a valid LDJ, but this one does not (unless LDJ is 0, in which
      # case it doesn't matter).
      increased_dof = ps.reduce_any(nest.flatten(increased_dof))
      if compute_x_values and self.validate_event_size:
        assertions = [
            self._maybe_warn_increased_dof(
                component_name=bij.name, increased_dof=increased_dof)
        ]
        increased_dof |= (_event_size(y, y_event_ndims)
                          > _event_size(x, x_event_ndims))
      else:
        assertions = []

      y = nest_util.broadcast_structure(y_event_ndims, y)
      increased_dof = nest_util.broadcast_structure(y_event_ndims,
                                                    increased_dof)
      bijectors_with_metadata.append(
          BijectorWithMetadata(
              bijector=bij,
              x=x,
              x_event_ndims=x_event_ndims,
              kwargs=kwargs,
              assertions=assertions,
          ))
      return y, y_event_ndims, increased_dof
Example #2
    def __init__(self,
                 bijectors=None,
                 validate_args=False,
                 parameters=None,
                 name=None):
        """Instantiates `JointMap` bijector.

    Args:
      bijectors: Structure of bijector instances to apply in parallel.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      parameters: Locals dict captured by subclass constructor, to be used for
        copy/slice re-instantiation operators.
      name: Python `str`, name given to ops managed by this object. Default:
        E.g., ```
          JointMap([Exp(), Softplus()]).name == "jointmap_of_exp_and_softplus"
        ```.

    Raises:
      ValueError: if bijectors have different dtypes.
    """
        parameters = dict(locals()) if parameters is None else parameters

        if not bijectors:
            raise ValueError('`bijectors` must not be empty.')

        if name is None:
            name = ('jointmap_of_' +
                    '_and_'.join([b.name for b in nest.flatten(bijectors)]))
            name = name.replace('/', '')
        with tf.name_scope(name) as name:
            # Structured dtypes are based on the non-wrapped input.
            # Keep track of the non-wrapped structure of bijectors to correctly
            # wrap inputs/outputs in _walk methods.
            self._nested_structure = self._no_dependency(
                nest.map_structure(lambda b: None, bijectors))

            super(JointMap, self).__init__(
                bijectors=bijectors,
                validate_args=validate_args,
                parameters=parameters,
                name=name,
                # JointMap and other bijectors that operate independently on
                # parts of structured inputs do not have statically-known
                # `min_event_ndims`. Infer the input/output structures, and fill them
                # with `None`.
                forward_min_event_ndims=nest.map_structure(
                    lambda b: nest_util.broadcast_structure(  # pylint: disable=g-long-lambda
                        b.forward_min_event_ndims, None),
                    bijectors),
                inverse_min_event_ndims=nest.map_structure(
                    lambda b: nest_util.broadcast_structure(  # pylint: disable=g-long-lambda
                        b.inverse_min_event_ndims, None),
                    bijectors),
            )
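A minimal usage sketch of the constructor above, assuming the usual `tfp.bijectors` alias `tfb`; the printed name follows the default naming rule documented in the docstring.

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

# JointMap applies each bijector to the corresponding part of a structured input.
jm = tfb.JointMap([tfb.Exp(), tfb.Softplus()])
y = jm.forward([tf.math.log(2.), tf.constant(0.5)])  # [Exp(log 2), Softplus(0.5)]
print(jm.name)  # 'jointmap_of_exp_and_softplus', per the default naming above.
```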
Example #3
    def step(bij, y, y_event_ndims, increased_dof=False, **kwargs):  # pylint: disable=missing-docstring
      nonlocal ldj_sum

      # Compute the LDJ for this step, and add it to the rolling sum.
      component_ldj = tf.convert_to_tensor(
          bij.inverse_log_det_jacobian(y, y_event_ndims, **kwargs),
          dtype_hint=ldj_sum.dtype)

      if not dtype_util.is_floating(component_ldj.dtype):
        raise TypeError(('Nested bijector "{}" of Composition "{}" returned '
                         'LDJ with a non-floating dtype: {}')
                        .format(bij.name, self.name, component_ldj.dtype))
      ldj_sum = _max_precision_sum(ldj_sum, component_ldj)

      # Transform inputs for the next bijector.
      x = bij.inverse(y, **kwargs)
      x_event_ndims = bij.inverse_event_ndims(y_event_ndims, **kwargs)

      # Check if the inputs to this bijector have increased degrees of freedom
      # due to some upstream bijector. We assume that the upstream bijector
      # produced a valid LDJ, but this one does not (unless LDJ is 0, in which
      # case it doesn't matter).
      increased_dof = ps.reduce_any(nest.flatten(increased_dof))
      if self.validate_event_size:
        assertions.append(self._maybe_warn_increased_dof(
            component_name=bij.name,
            component_ldj=component_ldj,
            increased_dof=increased_dof))
        increased_dof |= (_event_size(x, x_event_ndims)
                          > _event_size(y, y_event_ndims))

      increased_dof = nest_util.broadcast_structure(x, increased_dof)
      return x, x_event_ndims, increased_dof
Example #4
    def from_shape(cls, shape=(), independent_chain_ndims=1, dtype=tf.float32):
        """Starts an empty `RunningPotentialScaleReduction` from metadata.

    Args:
      shape: Python `Tuple` or `TensorShape` representing the shape of incoming
        samples. Using a collection implies that future samples will mimic that
        exact structure. This is useful to supply if the
        `RunningPotentialScaleReduction` will be carried by a `tf.while_loop`,
        so that broadcasting does not change the shape across loop iterations.
      independent_chain_ndims: Integer or Integer type `Tensor` with value
        `>= 1` giving the number of leading dimensions holding independent
        chain results to be tested for convergence. Using a collection
        implies that future samples will mimic that exact structure.
      dtype: Dtype of incoming samples and the resulting statistics.
        By default, the dtype is `tf.float32`. Any integer dtypes will be
        cast to corresponding floats (i.e. `tf.int32` will be cast to
        `tf.float32`), since the intermediate calculations require
        floating-point division.

    Returns:
      state: `RunningPotentialScaleReduction` representing a stream
        of no inputs.
    """
        dtype = tf.nest.map_structure(_float_dtype_like, dtype)

        dtype = nest_util.broadcast_structure(independent_chain_ndims, dtype)
        chain_variances = nest.map_structure_up_to(independent_chain_ndims,
                                                   RunningVariance.from_shape,
                                                   shape,
                                                   dtype,
                                                   check_types=False)
        return cls(chain_variances, independent_chain_ndims)
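A hedged sketch of calling the classmethod above, assuming the enclosing class is exposed as `tfp.experimental.stats.RunningPotentialScaleReduction`; the shape and chain count are illustrative.

```python
import tensorflow as tf
import tensorflow_probability as tfp

# Start an empty R-hat stream for 10 independent chains of 3-dimensional samples;
# the leading dimension of `shape` holds the chains (independent_chain_ndims=1).
rhat_stream = tfp.experimental.stats.RunningPotentialScaleReduction.from_shape(
    shape=(10, 3), independent_chain_ndims=1, dtype=tf.float32)
```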
Example #5
    def one_step(self,
                 new_chain_state,
                 current_reducer_state,
                 previous_kernel_results,
                 axis=None):
        """Update the `current_reducer_state` with a new chain state.

    Chunking semantics are similar to those of batching and are specified by the
    `axis` parameter. If chunking is enabled (axis is not `None`), all elements
    along the specified `axis` will be treated as separate samples. If a
    single scalar value is provided for a non-scalar sample structure, that
    value will be used for all elements in the structure. If not, an identical
    structure must be provided.

    Args:
      new_chain_state: A (possibly nested) structure of incoming chain state(s)
        with shape and dtype compatible with those used to initialize the
        `current_reducer_state`.
      current_reducer_state: `CovarianceReducerState`s representing the current
        state of the running covariance.
      previous_kernel_results: A (possibly nested) structure of `Tensor`s
        representing internal calculations made in a related
        `TransitionKernel`.
      axis: If chunking is desired, this is a (possibly nested) structure of
        integers that specifies the axis with chunked samples. For individual
        samples, set this to `None`. By default, samples are not chunked
        (`axis` is None).

    Returns:
      new_reducer_state: `CovarianceReducerState` with updated running
        statistics. Its `cov_state` field has an identical structure to the
        results of `self.transform_fn`. Each of the individual values in that
        structure subsequently mimics the structure of `current_reducer_state`.
    """
        with tf.name_scope(
                mcmc_util.make_name(self.name, 'covariance_reducer',
                                    'one_step')):
            cov_streams = _prepare_args(current_reducer_state.init_structure,
                                        self.event_ndims)
            new_chain_state = tf.nest.map_structure(tf.convert_to_tensor,
                                                    new_chain_state)
            previous_kernel_results = tf.nest.map_structure(
                tf.convert_to_tensor, previous_kernel_results)
            fn_results = tf.nest.map_structure(
                lambda fn: fn(new_chain_state, previous_kernel_results),
                self.transform_fn,
            )
            if not nest.is_nested(axis):
                axis = nest_util.broadcast_structure(fn_results, axis)
            running_cov_state = nest.map_structure_up_to(
                current_reducer_state.init_structure,
                lambda strm, *args: strm.update(*args),
                cov_streams,
                current_reducer_state.cov_state,
                fn_results,
                axis,
                check_types=False,
            )
            return CovarianceReducerState(current_reducer_state.init_structure,
                                          running_cov_state)
Example #6
    def update(self, state, new_sample):
        """Update the `RunningPotentialScaleReductionState` with a new sample.

    Args:
      state: `RunningPotentialScaleReductionState` that represents the
        current state of running statistics.
      new_sample: Incoming `Tensor` sample or (possibly nested) collection of
        `Tensor`s with shape and dtype compatible with those used to form the
        `RunningPotentialScaleReductionState`.

    Returns:
      state: `RunningPotentialScaleReductionState` with updated calculations.
    """
        def _update_for_one_state(shape, dtype, chain_var, new_sample):
            """Updates the running variance for one group of Markov chains."""
            # TODO(axch): chunking could be reasonably added here by accepting and
            # including the chunked axis to the running variance object
            var_stream = RunningVariance(shape, dtype=dtype)
            return var_stream.update(chain_var, new_sample)

        broadcasted_dtype = nest_util.broadcast_structure(
            self.independent_chain_ndims, self.dtype)
        updated_chain_vars = nest.map_structure_up_to(
            self.independent_chain_ndims,
            _update_for_one_state,
            self.shape,
            broadcasted_dtype,
            state.chain_var,
            new_sample,
            check_types=False)
        return RunningPotentialScaleReductionState(updated_chain_vars)
Example #7
    def test_specifying_distribution_type(self,
                                          event_shape,
                                          base_distribution_cls,
                                          is_stateless=JAX_MODE):
        init_seed, sample_seed = samplers.split_seed(
            test_util.test_seed(sampler_type='stateless'), n=2)
        surrogate_posterior = self._initialize_surrogate(
            'build_factored_surrogate_posterior',
            is_stateless=is_stateless,
            seed=init_seed,
            event_shape=event_shape,
            base_distribution_cls=base_distribution_cls,
            validate_args=True)

        # Test that the surrogate uses the expected distribution types.
        if tf.nest.is_nested(surrogate_posterior.event_shape):
            ds, _ = surrogate_posterior.sample_distributions(seed=sample_seed)
        else:
            ds = [surrogate_posterior]
        for cls, d in zip(
                nest_util.broadcast_structure(ds, base_distribution_cls), ds):
            d = _as_concrete_instance(d)
            while isinstance(d, tfd.Independent):
                d = _as_concrete_instance(d.distribution)
            self.assertIsInstance(d, cls)
Example #8
def _prepare_args(target, event_ndims):
    """Creates a structure of `RunningCovariance`s based on inferred metadata.

  Metadata required to create a `RunningCovariance` object (`shape`, `dtype`,
  and `event_ndims` of incoming chain states) will be inferred from the
  `target`. Using that information, an identical structure of
  `RunningCovariance`s to `target` will be returned.

  Args:
    target: A (possibly nested) structure of `Tensor`s or Python
      `list`s of `Tensor`s representing the current state(s) of the Markov
      chain(s). It is used to infer the shape and dtype of future samples.
    event_ndims: A (possibly nested) structure of integers. Defines
        the number of inner-most dimensions that represent the event shape.
        Must be either a singleton or of the same shape as `target`.

  Returns:
    cov_streams: Structure of `sample_stats.RunningCovariance` matching
      the shape of `target`.
  """

    shape = tf.nest.map_structure(lambda target: target.shape, target)
    dtype = tf.nest.map_structure(lambda target: target.dtype, target)
    if event_ndims is None:
        event_ndims = tf.nest.map_structure(ps.rank, target)
    elif not nest.is_nested(event_ndims):
        event_ndims = nest_util.broadcast_structure(target, event_ndims)
    return nest.map_structure_up_to(
        target,
        sample_stats.RunningCovariance,
        shape,
        event_ndims,
        dtype,
        check_types=False,
    )
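A usage sketch for the helper above; the two-part chain state is hypothetical. A single integer `event_ndims` is repeated across the structure via `broadcast_structure`, so one `RunningCovariance` stream is created per state part.

```python
import tensorflow as tf

# Hypothetical two-part chain state (8 chains; one vector part, one matrix part).
chain_state = [tf.zeros([8, 3]), tf.zeros([8, 5, 5])]

# The scalar event_ndims=1 is broadcast to [1, 1], one value per state part.
cov_streams = _prepare_args(chain_state, event_ndims=1)
```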
Example #9
 def step_broadcast(step_size):
     # Only apply the bijector to nested step sizes or non-scalar batches.
     if tf.nest.is_nested(step_size):
         return step_bijector(
             nest_util.broadcast_structure(
                 pinned_model.event_shape_tensor(), step_size))
     else:
         return step_size
Example #10
    def one_step(self,
                 new_chain_state,
                 current_reducer_state,
                 previous_kernel_results=None,
                 axis=None):
        """Update the `current_reducer_state` with a new chain state.

    Chunking semantics are specified by the `axis` parameter. If chunking is
    enabled (axis is not `None`), all elements along the specified `axis` will
    be treated as separate samples. If a single scalar value is provided for a
    non-scalar sample structure, that value will be used for all elements in the
    structure. If not, an identical structure must be provided.

    Args:
      new_chain_state: A (possibly nested) structure of incoming chain state(s)
        with shape and dtype compatible with those used to initialize the
        `current_reducer_state`.
      current_reducer_state: `ExpectationsReducerState` representing the current
        reducer state.
      previous_kernel_results: A (possibly nested) structure of `Tensor`s
        representing internal calculations made in a related
        `TransitionKernel`.
      axis: If chunking is desired, this is a (possibly nested) structure of
        integers that specifies the axis with chunked samples. For individual
        samples, set this to `None`. By default, samples are not chunked
        (`axis` is None).

    Returns:
      new_reducer_state: `ExpectationsReducerState` with updated running
        statistics. It tracks a running total and the number of processed
        samples.
    """
        with tf.name_scope(
                mcmc_util.make_name(self.name, 'expectations_reducer',
                                    'one_step')):
            new_chain_state = tf.nest.map_structure(tf.convert_to_tensor,
                                                    new_chain_state)
            if previous_kernel_results is not None:
                previous_kernel_results = tf.nest.map_structure(
                    tf.convert_to_tensor,
                    previous_kernel_results,
                    expand_composites=True)
            fn_results = tf.nest.map_structure(
                lambda fn: fn(new_chain_state, previous_kernel_results),
                self.transform_fn)
            if not nest.is_nested(axis):
                axis = nest_util.broadcast_structure(fn_results, axis)

            def update(fn_results, state, axis):
                return state.update(fn_results, axis=axis)

            return ExpectationsReducerState(
                nest.map_structure(update,
                                   fn_results,
                                   current_reducer_state.expectation_state,
                                   axis,
                                   check_types=False))
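The chunking semantics described in the docstring can be sketched as follows. This assumes the enclosing class is `tfp.experimental.mcmc.ExpectationsReducer` and that its `initialize` method accepts an example chain state; both are assumptions not shown in this snippet.

```python
import tensorflow as tf
import tensorflow_probability as tfp

reducer = tfp.experimental.mcmc.ExpectationsReducer()
state = reducer.initialize(tf.zeros([3]))  # Assumed `initialize` signature.

# With axis=0, the 5 rows below are folded in as 5 separate (chunked) samples
# rather than as a single sample of shape [5, 3].
state = reducer.one_step(tf.random.normal([5, 3]), state, axis=0)
```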
Example #11
def _canonicalize_event_ndims(target, event_ndims):
    """Returns `event_ndims` shaped parallel to `target`, repeating as needed."""
    # This is only here to support the possibility of different event_ndims across
    # different Tensors in the target structure.  Otherwise, event_ndims could
    # just be an integer (or None) and wouldn't need to be canonicalized to a
    # structure.
    if not nest.is_nested(event_ndims):
        return nest_util.broadcast_structure(target, event_ndims)
    else:
        return event_ndims
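A small illustrative sketch of the helper above (the target structure is hypothetical): a non-nested `event_ndims` is repeated to parallel `target`, while an already-nested value is returned unchanged.

```python
import tensorflow as tf

target = {'loc': tf.zeros([3]), 'scale': tf.zeros([3])}

_canonicalize_event_ndims(target, 1)
# ~> {'loc': 1, 'scale': 1}

_canonicalize_event_ndims(target, {'loc': 1, 'scale': 0})
# ~> {'loc': 1, 'scale': 0} (passed through as-is)
```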
Example #12
  def _get_bijectors_with_metadata(self,
                                   x,
                                   event_ndims,
                                   forward=True,
                                   **kwargs):
    """Trace bijectors + metadata forward/backward."""
    bijectors_with_metadata = []

    if forward:
      forward_fn = lambda bij, *args, **kwargs: bij.forward(*args, **kwargs)
      forward_event_ndims_fn = (
          lambda bij, *args, **kwargs: bij.forward_event_ndims(*args, **kwargs))
      walk_forward_fn = self._call_walk_forward
    else:
      forward_fn = lambda bij, *args, **kwargs: bij.inverse(*args, **kwargs)
      forward_event_ndims_fn = (
          lambda bij, *args, **kwargs: bij.inverse_event_ndims(*args, **kwargs))
      walk_forward_fn = self._call_walk_inverse

    def step(bij, x, x_event_ndims, increased_dof, **kwargs):  # pylint: disable=missing-docstring
      # Transform inputs for the next bijector.
      y = forward_fn(bij, x, **kwargs)
      y_event_ndims = forward_event_ndims_fn(bij, x_event_ndims, **kwargs)

      # Check if the inputs to this bijector have increased degrees of freedom
      # due to some upstream bijector. We assume that the upstream bijector
      # produced a valid LDJ, but this one does not (unless LDJ is 0, in which
      # case it doesn't matter).
      increased_dof = ps.reduce_any(nest.flatten(increased_dof))
      if self.validate_event_size:
        assertions = [
            self._maybe_warn_increased_dof(
                component_name=bij.name, increased_dof=increased_dof)
        ]
        increased_dof |= (_event_size(y, y_event_ndims)
                          > _event_size(x, x_event_ndims))
      else:
        assertions = []

      increased_dof = nest_util.broadcast_structure(y, increased_dof)
      bijectors_with_metadata.append(
          BijectorWithMetadata(
              bijector=bij,
              x=x,
              x_event_ndims=x_event_ndims,
              kwargs=kwargs,
              assertions=assertions,
          ))
      return y, y_event_ndims, increased_dof

    increased_dof = nest_util.broadcast_structure(event_ndims, False)
    walk_forward_fn(step, x, event_ndims, increased_dof, **kwargs)
    return bijectors_with_metadata
Example #13
    def initialize(self):
        """Initializes an empty `RunningPotentialScaleReductionState`.

    Returns:
      state: `RunningPotentialScaleReductionState` representing a stream
        of no inputs.
    """
        broadcasted_dtype = nest_util.broadcast_structure(
            self.independent_chain_ndims, self.dtype)
        chain_var = nest.map_structure_up_to(self.independent_chain_ndims,
                                             RunningVariance.from_shape,
                                             self.shape,
                                             broadcasted_dtype,
                                             check_types=False)
        return RunningPotentialScaleReductionState(chain_var)
Example #14
    def _forward_log_det_jacobian(self, x, event_ndims, **kwargs):
        # Container for accumulated LDJ.
        ldj_sum = tf.zeros([], dtype=tf.float32)
        # Container for accumulated assertions.
        assertions = []

        def step(bij, x, x_event_ndims, increased_dof, **kwargs):  # pylint: disable=missing-docstring
            nonlocal ldj_sum

            # Compute the LDJ for this step, and add it to the rolling sum.
            component_ldj = tf.convert_to_tensor(bij.forward_log_det_jacobian(
                x, x_event_ndims, **kwargs),
                                                 dtype_hint=ldj_sum.dtype)

            if not dtype_util.is_floating(component_ldj.dtype):
                raise TypeError(
                    ('Nested bijector "{}" of Composition "{}" returned '
                     'LDJ with a non-floating dtype: {}').format(
                         bij.name, self.name, component_ldj.dtype))
            ldj_sum = _max_precision_sum(ldj_sum, component_ldj)

            # Transform inputs for the next bijector.
            y = bij.forward(x, **kwargs)
            y_event_ndims = bij.forward_event_ndims(x_event_ndims, **kwargs)

            # Check if the inputs to this bijector have increased degrees of freedom
            # due to some upstream bijector. We assume that the upstream bijector
            # produced a valid LDJ, but this one does not (unless LDJ is 0, in which
            # case it doesn't matter).
            increased_dof = ps.reduce_any(nest.flatten(increased_dof))
            if self.validate_event_size:
                assertions.append(
                    self._maybe_warn_increased_dof(
                        component_name=bij.name,
                        component_ldj=component_ldj,
                        increased_dof=increased_dof))
                increased_dof |= (_event_size(y, y_event_ndims) > _event_size(
                    x, x_event_ndims))

            increased_dof = nest_util.broadcast_structure(y, increased_dof)
            return y, y_event_ndims, increased_dof

        increased_dof = nest_util.broadcast_structure(event_ndims, False)
        self._call_walk_forward(step, x, event_ndims, increased_dof, **kwargs)
        with tf.control_dependencies([x for x in assertions if x is not None]):
            return tf.identity(ldj_sum, name='fldj')
Example #15
    def test_specifying_distribution_type(self, event_shape,
                                          base_distribution_cls):
        surrogate_posterior = (
            tfp.experimental.vi.build_factored_surrogate_posterior(
                event_shape=event_shape,
                base_distribution_cls=base_distribution_cls,
                validate_args=True))

        # Test that the surrogate uses the expected distribution types.
        if tf.nest.is_nested(surrogate_posterior.event_shape):
            ds, _ = surrogate_posterior.sample_distributions()
        else:
            ds = [surrogate_posterior]
        for cls, d in zip(
                nest_util.broadcast_structure(ds, base_distribution_cls), ds):
            while isinstance(d, tfd.Independent):
                d = d.distribution
            self.assertIsInstance(d, cls)
Example #16
    def initialize(self):
        """Initializes an empty `RunningPotentialScaleReductionState`.

    Returns:
      state: `RunningPotentialScaleReductionState` representing a stream
        of no inputs.
    """
        def _initialize_for_one_state(shape, dtype):
            """Initializes a running variance state for one group of Markov chains."""
            var_stream = RunningVariance(shape, dtype=dtype)
            return var_stream.initialize()

        broadcasted_dtype = nest_util.broadcast_structure(
            self.independent_chain_ndims, self.dtype)
        chain_var = nest.map_structure_up_to(self.independent_chain_ndims,
                                             _initialize_for_one_state,
                                             self.shape,
                                             broadcasted_dtype,
                                             check_types=False)
        return RunningPotentialScaleReductionState(chain_var)
Example #17
  def _get_bijectors_with_metadata(self,
                                   x=None,
                                   event_ndims=None,
                                   forward=True,
                                   pack_as_original_structure=False,
                                   **kwargs):
    """Trace bijectors + metadata forward/backward."""
    bijectors_with_metadata = []

    compute_x_values = x is not None
    if event_ndims is None:
      event_ndims = self.forward_min_event_ndims

    if forward:
      forward_fn = lambda bij, *args, **kwargs: bij.forward(*args, **kwargs)
      forward_event_ndims_fn = (
          lambda bij, *args, **kwargs: bij.forward_event_ndims(*args, **kwargs))
      walk_forward_fn = self._call_walk_forward
    else:
      forward_fn = lambda bij, *args, **kwargs: bij.inverse(*args, **kwargs)
      forward_event_ndims_fn = (
          lambda bij, *args, **kwargs: bij.inverse_event_ndims(*args, **kwargs))
      walk_forward_fn = self._call_walk_inverse

    def step(bij, x, x_event_ndims, increased_dof, **kwargs):  # pylint: disable=missing-docstring
      # Transform inputs for the next bijector.
      y = forward_fn(bij, x, **kwargs) if compute_x_values else None
      y_event_ndims = forward_event_ndims_fn(bij, x_event_ndims, **kwargs)

      # Check if the inputs to this bijector have increased degrees of freedom
      # due to some upstream bijector. We assume that the upstream bijector
      # produced a valid LDJ, but this one does not (unless LDJ is 0, in which
      # case it doesn't matter).
      increased_dof = ps.reduce_any(nest.flatten(increased_dof))
      if compute_x_values and self.validate_event_size:
        assertions = [
            self._maybe_warn_increased_dof(
                component_name=bij.name, increased_dof=increased_dof)
        ]
        increased_dof |= (_event_size(y, y_event_ndims)
                          > _event_size(x, x_event_ndims))
      else:
        assertions = []

      y = nest_util.broadcast_structure(y_event_ndims, y)
      increased_dof = nest_util.broadcast_structure(y_event_ndims,
                                                    increased_dof)
      bijectors_with_metadata.append(
          BijectorWithMetadata(
              bijector=bij,
              x=x,
              x_event_ndims=x_event_ndims,
              kwargs=kwargs,
              assertions=assertions,
          ))
      return y, y_event_ndims, increased_dof

    x = nest_util.broadcast_structure(event_ndims, x)
    increased_dof = nest_util.broadcast_structure(event_ndims, False)
    walk_forward_fn(step, x, event_ndims, increased_dof, **kwargs)

    if pack_as_original_structure:
      bijector_map = {id(bm.bijector): bm for bm in bijectors_with_metadata}
      bijectors_with_metadata = tf.nest.map_structure(
          lambda b: bijector_map[id(b)], self.bijectors)
    return bijectors_with_metadata
Example #18
def effective_sample_size(states,
                          filter_threshold=0.,
                          filter_beyond_lag=None,
                          filter_beyond_positive_pairs=False,
                          cross_chain_dims=None,
                          validate_args=False,
                          name=None):
  """Estimate a lower bound on effective sample size for each independent chain.

  Roughly speaking, "effective sample size" (ESS) is the size of an iid sample
  with the same variance as `state`.

  More precisely, given a stationary sequence of possibly correlated random
  variables `X_1, X_2, ..., X_N`, identically distributed, ESS is the
  number such that

  ```
  Variance{ N**-1 * Sum{X_i} } = ESS**-1 * Variance{ X_1 }.
  ```

  If the sequence is uncorrelated, `ESS = N`.  If the sequence is positively
  auto-correlated, `ESS` will be less than `N`. If there are negative
  correlations, then `ESS` can exceed `N`.

  Some math shows that, with `R_k` the auto-correlation sequence,
  `R_k := Covariance{X_1, X_{1+k}} / Variance{X_1}`, we have

  ```
  ESS(N) =  N / [ 1 + 2 * ( (N - 1) / N * R_1 + ... + 1 / N * R_{N-1}  ) ]
  ```

  This function estimates the above by first estimating the auto-correlation.
  Since `R_k` must be estimated using only `N - k` samples, it becomes
  progressively noisier for larger `k`.  For this reason, the summation over
  `R_k` should be truncated at some number `filter_beyond_lag < N`. This
  function provides two methods to perform this truncation.

  * `filter_threshold` -- since many MCMC methods generate chains where `R_k >
    0`, a reasonable criterion is to truncate at the first index where the
    estimated auto-correlation becomes negative. This method does not estimate
    the `ESS` of super-efficient chains (where `ESS > N`) correctly.

  * `filter_beyond_positive_pairs` -- reversible MCMC chains produce
    an auto-correlation sequence with the property that pairwise sums of the
    elements of that sequence are positive [Geyer][1], i.e.
    `R_{2k} + R_{2k + 1} > 0` for `k in {0, ..., N/2}`. Deviations are only
    possible due to noise. This method truncates the auto-correlation sequence
    where the pairwise sums become non-positive.

  The arguments `filter_beyond_lag`, `filter_threshold` and
  `filter_beyond_positive_pairs` are filters intended to remove noisy tail terms
  from `R_k`.  You can combine `filter_beyond_lag` with `filter_threshold` or
  `filter_beyond_positive_pairs`. E.g., combining `filter_beyond_lag` and
  `filter_beyond_positive_pairs` means that terms are removed if they were to be
  filtered under the `filter_beyond_lag` OR `filter_beyond_positive_pairs`
  criteria.

  This function can also compute cross-chain ESS following
  [Vehtari et al. (2019)][2] by specifying the `cross_chain_dims` argument.
  Cross-chain ESS takes into account the cross-chain variance to reduce the ESS
  in cases where the chains are not mixing well. In general, this will be a
  smaller number than computing the ESS for individual chains and then summing
  them. In an extreme case where the chains have fallen into K non-mixing modes,
  this function will return ESS ~ K. Even when chains are mixing well it is
  still preferable to compute cross-chain ESS via this method because it will
  reduce the noise in the estimate of `R_k`, reducing the need for truncation.

  Args:
    states: `Tensor` or Python structure of `Tensor` objects.  Dimension zero
      should index identically distributed states.
    filter_threshold: `Tensor` or Python structure of `Tensor` objects.  Must
      broadcast with `state`.  The sequence of auto-correlations is truncated
      after the first appearance of a term less than `filter_threshold`.
      Setting to `None` means we use no threshold filter.  Since `|R_k| <= 1`,
      setting to any number less than `-1` has the same effect. Ignored if
      `filter_beyond_positive_pairs` is `True`.
    filter_beyond_lag: `Tensor` or Python structure of `Tensor` objects.  Must
      be `int`-like and scalar valued.  The sequence of auto-correlations is
      truncated to this length.  Setting to `None` means we do not filter based
      on the size of lags.
    filter_beyond_positive_pairs: Python boolean. If `True`, only consider the
      initial auto-correlation sequence where the pairwise sums are positive.
    cross_chain_dims: An integer `Tensor` or a structure of integer `Tensor`s
      corresponding to each state component, specifying which dimensions of
      `states` to treat as independent chains over which ESS is summed. If a
      list of `states` is provided, this argument should also be a list of the
      same length. If `None`, no summation is performed. Note this requires at
      least 2 chains.
    validate_args: Whether to add runtime checks of argument validity. If False,
      and arguments are incorrect, correct behavior is not guaranteed.
    name:  `String` name to prepend to created ops.

  Returns:
    ess: `Tensor` structure parallel to `states`.  The effective sample size of
      each component of `states`.  If `cross_chain_dims` is None, the shape will
      be `states.shape[1:]`. Otherwise, the shape is `tf.reduce_mean(states,
      cross_chain_dims).shape[1:]`.

  Raises:
    ValueError: If `states` and `filter_threshold` or `states` and
      `filter_beyond_lag` are both structures of different shapes.
    ValueError: If `cross_chain_dims` is not `None` and there are less than 2
      chains.

  #### Examples

  We use ESS to estimate standard error.

  ```
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  target = tfd.MultivariateNormalDiag(scale_diag=[1., 2.])

  # Get 1000 states from one chain.
  states = tfp.mcmc.sample_chain(
      num_burnin_steps=200,
      num_results=1000,
      current_state=tf.constant([0., 0.]),
      trace_fn=None,
      kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=target.log_prob,
        step_size=0.05,
        num_leapfrog_steps=20))
  print(states.shape)
  ==> (1000, 2)

  ess = effective_sample_size(states, filter_beyond_positive_pairs=True)
  print(ess.shape)
  ==> (2,)

  mean, variance = tf.nn.moments(states, axes=0)
  standard_error = tf.sqrt(variance / ess)
  ```

  #### References

  [1]: Charles J. Geyer, Practical Markov chain Monte Carlo (with discussion).
       Statistical Science, 7:473-511, 1992.

  [2]: Aki Vehtari, Andrew Gelman, Daniel Simpson, Bob Carpenter, Paul-Christian
       Burkner. Rank-normalization, folding, and localization: An improved R-hat
       for assessing convergence of MCMC, 2019. Retrieved from
       http://arxiv.org/abs/1903.08008
  """
  if cross_chain_dims is None:
    cross_chain_dims = nest_util.broadcast_structure(states, None)
  filter_beyond_lag = nest_util.broadcast_structure(states, filter_beyond_lag)
  filter_threshold = nest_util.broadcast_structure(states, filter_threshold)
  filter_beyond_positive_pairs = nest_util.broadcast_structure(
      states, filter_beyond_positive_pairs)

  # Process items, one at a time.
  def single_state(*args):
    return _effective_sample_size_single_state(
        *args, validate_args=validate_args)
  with tf.name_scope('effective_sample_size' if name is None else name):
    return nest.map_structure_up_to(
        states,
        single_state,
        states, filter_beyond_lag, filter_threshold,
        filter_beyond_positive_pairs, cross_chain_dims)
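The broadcasting at the top of the function body also makes the structured-states case work with scalar filter settings; a hedged sketch (state names and shapes are illustrative):

```python
import tensorflow as tf
import tensorflow_probability as tfp

states = {'loc': tf.random.normal([1000, 4]),
          'scale': tf.random.normal([1000, 4])}

# The scalar filter arguments are broadcast across both state parts.
ess = tfp.mcmc.effective_sample_size(states, filter_beyond_positive_pairs=True)
# `ess` parallels `states`: one Tensor of shape [4] per part.
```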
Example #19
def init_near_unconstrained_zero(
    model=None, constraining_bijector=None, event_shapes=None,
    event_shape_tensors=None, batch_shapes=None, batch_shape_tensors=None,
    dtypes=None):
  """Returns an initialization Distribution for starting a Markov chain.

  This initialization scheme follows Stan: we sample every latent
  independently, uniformly from -2 to 2 in its unconstrained space,
  and then transform into constrained space to construct an initial
  state that can be passed to `sample_chain` or other MCMC drivers.

  The argument signature is arranged to let the user pass either a
  `JointDistribution` describing their model, if it's in that form, or
  the essential information necessary for the sampling, namely a
  bijector (from unconstrained to constrained space) and the desired
  shape and dtype of each sample (specified in constrained space).

  Note: As currently implemented, this function has the limitation
  that the batch shape of the supplied model is ignored, but that
  could probably be generalized if needed.

  Args:
    model: A `Distribution` (typically a `JointDistribution`) giving the
      model to be initialized.  If supplied, it is queried for
      its default event space bijector, its event shape, and its dtype.
      If not supplied, those three elements must be supplied instead.
    constraining_bijector: A (typically multipart) `Bijector` giving
      the mapping from unconstrained to constrained space.  If
      supplied together with a `model`, acts as an override.  A nested
      structure of `Bijector`s is accepted, and interpreted as
      applying in parallel to a corresponding structure of state parts
      (see `JointMap` for details).
    event_shapes: A structure of shapes giving the (unconstrained)
      event space shape of the desired samples.  Must be an acceptable
      input to `constraining_bijector.inverse_event_shape`.  If
      supplied together with `model`, acts as an override.
    event_shape_tensors: A structure of tensors giving the (unconstrained)
      event space shape of the desired samples.  Must be an acceptable
      input to `constraining_bijector.inverse_event_shape_tensor`.  If
      supplied together with `model`, acts as an override. Required if any of
      `event_shapes` are not fully-defined.
    batch_shapes: A structure of shapes giving the batch shape of the desired
      samples.  If supplied together with `model`, acts as an override.  If
      unspecified, we assume scalar batch `[]`.
    batch_shape_tensors: A structure of tensors giving the batch shape of the
      desired samples.  If supplied together with `model`, acts as an override.
      Required if any of `batch_shapes` are not fully-defined.
    dtypes: A structure of dtypes giving the (unconstrained) dtypes of
      the desired samples.  Must be an acceptable input to
      `constraining_bijector.inverse_dtype`.  If supplied together
      with `model`, acts as an override.

  Returns:
    init_dist: A `Distribution` representing the initialization
      distribution, in constrained space.  Samples from this
      `Distribution` are valid initial states for a Markov chain
      targeting the model.

  #### Example

  Initialize 100 chains from the unconstrained -2, 2 distribution
  for a model expressed as a `JointDistributionCoroutine`:

  ```python
  @tfp.distributions.JointDistributionCoroutine
  def model():
    ...

  init_dist = tfp.experimental.mcmc.init_near_unconstrained_zero(model)
  states = tfp.mcmc.sample_chain(
    current_state=init_dist.sample(100, seed=[4, 8]),
    ...)
  ```

  """
  # Canonicalize arguments into the parts we need, namely
  # the constraining_bijector, the event_shapes, and the dtypes.
  if model is not None:
    # Got a Distribution model; treat other arguments as overrides if
    # present.
    if constraining_bijector is None:
      # pylint: disable=protected-access
      constraining_bijector = model.experimental_default_event_space_bijector()
    if event_shapes is None:
      event_shapes = model.event_shape
    if event_shape_tensors is None:
      event_shape_tensors = model.event_shape_tensor()
    if dtypes is None:
      dtypes = model.dtype
    if batch_shapes is None:
      batch_shapes = nest_util.broadcast_structure(dtypes, model.batch_shape)
    if batch_shape_tensors is None:
      batch_shape_tensors = nest_util.broadcast_structure(
          dtypes, model.batch_shape_tensor())

  else:
    if constraining_bijector is None or event_shapes is None or dtypes is None:
      msg = ('Must pass either a Distribution (typically a JointDistribution), '
             'or a bijector, a structure of event shapes, and a '
             'structure of dtypes')
      raise ValueError(msg)
    event_shapes_fully_defined = all(tensorshape_util.is_fully_defined(s)
                                     for s in tf.nest.flatten(event_shapes))
    if not event_shapes_fully_defined and event_shape_tensors is None:
      raise ValueError('Must specify `event_shape_tensors` when `event_shapes` '
                       f'are not fully-defined: {event_shapes}')
    if batch_shapes is None:
      batch_shapes = tf.TensorShape([])
    batch_shapes = nest_util.broadcast_structure(dtypes, batch_shapes)
    batch_shapes_fully_defined = all(tensorshape_util.is_fully_defined(s)
                                     for s in tf.nest.flatten(batch_shapes))
    if batch_shape_tensors is None:
      if not batch_shapes_fully_defined:
        raise ValueError(
            'Must specify `batch_shape_tensors` when `batch_shapes` are not '
            f'fully-defined: {batch_shapes}')
      batch_shape_tensors = tf.nest.map_structure(
          tf.convert_to_tensor, batch_shapes)

  # Interpret a structure of Bijectors as the joint multipart bijector.
  if not isinstance(constraining_bijector, tfb.Bijector):
    constraining_bijector = tfb.JointMap(constraining_bijector)

  # Actually initialize
  def one_term(event_shape, event_shape_tensor, batch_shape, batch_shape_tensor,
               dtype):
    if not tensorshape_util.is_fully_defined(event_shape):
      event_shape = event_shape_tensor
    result = tfd.Sample(
        tfd.Uniform(low=tf.constant(-2., dtype=dtype),
                    high=tf.constant(2., dtype=dtype)),
        sample_shape=event_shape)
    if not tensorshape_util.is_fully_defined(batch_shape):
      batch_shape = batch_shape_tensor
      needs_bcast = True
    else:  # Only batch broadcast when batch ndims > 0.
      needs_bcast = bool(tensorshape_util.as_list(batch_shape))
    if needs_bcast:
      result = tfd.BatchBroadcast(result, batch_shape)
    return result

  inv_shapes = constraining_bijector.inverse_event_shape(event_shapes)
  if event_shape_tensors is not None:
    inv_shape_tensors = constraining_bijector.inverse_event_shape_tensor(
        event_shape_tensors)
  else:
    inv_shape_tensors = tf.nest.map_structure(lambda _: None, inv_shapes)
  inv_dtypes = constraining_bijector.inverse_dtype(dtypes)
  terms = tf.nest.map_structure(
      one_term, inv_shapes, inv_shape_tensors, batch_shapes,
      batch_shape_tensors, inv_dtypes)
  unconstrained = tfb.pack_sequence_as(inv_shapes)(
      tfd.JointDistributionSequential(tf.nest.flatten(terms)))
  return tfd.TransformedDistribution(
      unconstrained, bijector=constraining_bijector)
Example #20
def _infer_min_event_ndims(bijectors):
    """Computes `min_event_ndims` for a sequence of bijectors."""
    # Find the index of the first bijector with statically-known min_event_ndims.
    try:
        idx = next(i for i, b in enumerate(bijectors)
                   if b.has_static_min_event_ndims)
    except StopIteration:
        # If none of the nested bijectors have static min_event_ndims, give up
        # and return tail-structures filled with `None`.
        return (nest_util.broadcast_structure(
            bijectors[-1].forward_min_event_ndims, None),
                nest_util.broadcast_structure(
                    bijectors[0].inverse_min_event_ndims, None))

    # Accumulator tracking the maximum value of "min_event_ndims - ndims".
    rolling_offset = 0

    def update_event_ndims(input_event_ndims, input_min_event_ndims,
                           output_min_event_ndims):
        """Returns output_event_ndims and updates rolling_offset as needed."""
        nonlocal rolling_offset
        ldj_reduce_ndims = bijector_lib.ldj_reduction_ndims(
            input_event_ndims, input_min_event_ndims)
        # Update rolling_offset when batch_ndims are negative.
        rolling_offset = ps.maximum(rolling_offset, -ldj_reduce_ndims)
        return nest.map_structure(lambda nd: ldj_reduce_ndims + nd,
                                  output_min_event_ndims)

    def sanitize_event_ndims(event_ndims):
        """Updates `rolling_offset` when event_ndims are negative."""
        nonlocal rolling_offset
        max_missing_ndims = -ps.reduce_min(nest.flatten(event_ndims))
        rolling_offset = ps.maximum(rolling_offset, max_missing_ndims)
        return event_ndims

    # Wrappers for Bijector.forward_event_ndims and Bijector.inverse_event_ndims
    # that recursively walk into Composition bijectors when static min_event_ndims
    # is not available.

    def update_f_event_ndims(bij, event_ndims):
        event_ndims = nest_util.coerce_structure(bij.inverse_min_event_ndims,
                                                 event_ndims)
        if bij.has_static_min_event_ndims:
            return update_event_ndims(
                input_event_ndims=event_ndims,
                input_min_event_ndims=bij.inverse_min_event_ndims,
                output_min_event_ndims=bij.forward_min_event_ndims)
        elif isinstance(bij, composition.Composition):
            return bij._call_walk_inverse(update_f_event_ndims, event_ndims)  # pylint: disable=protected-access
        else:
            return sanitize_event_ndims(bij.inverse_event_ndims(event_ndims))

    def update_i_event_ndims(bij, event_ndims):
        event_ndims = nest_util.coerce_structure(bij.forward_min_event_ndims,
                                                 event_ndims)
        if bij.has_static_min_event_ndims:
            return update_event_ndims(
                input_event_ndims=event_ndims,
                input_min_event_ndims=bij.forward_min_event_ndims,
                output_min_event_ndims=bij.inverse_min_event_ndims)
        elif isinstance(bij, composition.Composition):
            return bij._call_walk_forward(update_i_event_ndims, event_ndims)  # pylint: disable=protected-access
        else:
            return sanitize_event_ndims(bij.forward_event_ndims(event_ndims))

    # Initialize event_ndims to the first statically-known min_event_ndims in
    # the Chain of bijectors.
    f_event_ndims = i_event_ndims = bijectors[idx].inverse_min_event_ndims
    for b in bijectors[idx:]:
        f_event_ndims = update_f_event_ndims(b, f_event_ndims)
    for b in reversed(bijectors[:idx]):
        i_event_ndims = update_i_event_ndims(b, i_event_ndims)

    # Shift both event_ndims to satisfy min_event_ndims for nested components.
    return (nest.map_structure(lambda nd: rolling_offset + nd, f_event_ndims),
            nest.map_structure(lambda nd: rolling_offset + nd, i_event_ndims))
Example #21
    def __init__(self,
                 output_structure,
                 input_structure=None,
                 name='restructure'):
        """Creates a `Restructure` bijector.

    Args:
      output_structure: A tf.nest-compatible structure of tokens describing the
        output of `forward` (equivalently, the input of `inverse`).
      input_structure: A tf.nest-compatible structure of tokens describing the
        input to `forward`. If unspecified, a default structure is inferred from
        `output_structure`: a `list` when the tokens are contiguous 0-based
        integers, or a `dict` keyed by the tokens when they are strings.
      name: Name of this bijector.
    Raises:
      ValueError: If tokens are duplicated, or a required default structure
        cannot be inferred.
    """
        parameters = dict(locals())

        # Get the flat set of tokens, making sure they're unique.
        output_tokens = unique_token_set(output_structure)

        # Create a default input_structure when it isn't provided.
        if input_structure is None:
            # If all tokens are strings, assume input is a dict.
            if all(isinstance(tok, six.string_types) for tok in output_tokens):
                input_structure = {token: token for token in output_tokens}

            # If tokens are contiguous 0-based ints, return a list.
            elif (all(
                    isinstance(tok, six.integer_types)
                    for tok in output_tokens)
                  and output_tokens == set(range(len(output_tokens)))):
                input_structure = list(range(len(output_tokens)))

            # Otherwise, we cannot infer a default structure.
            else:
                raise ValueError(
                    ('Tokens in output_structure must be all strings or '
                     'contiguous 0-based indices when input_structure '
                     'is not specified. Saw: {}').format(output_tokens))

        # If input_structure _is_ provided, make sure tokens are unique
        # and that they match the output_structure tokens.
        else:
            input_tokens = unique_token_set(input_structure)
            if input_tokens != output_tokens:
                raise ValueError(
                    ('The `input_structure` tokens must match the '
                     '`output_structure` tokens exactly. Missing from '
                     '`input_structure`: {}. Missing from '
                     '`output_structure`: {}.').format(
                         output_tokens - input_tokens,
                         input_tokens - output_tokens))

        self._input_structure = self._no_dependency(input_structure)
        self._output_structure = self._no_dependency(output_structure)
        super(Restructure, self).__init__(
            forward_min_event_ndims=nest_util.broadcast_structure(
                self._input_structure, 0),
            inverse_min_event_ndims=nest_util.broadcast_structure(
                self._output_structure, 0),
            is_constant_jacobian=True,
            validate_args=False,
            parameters=parameters,
            name=name)
Example #22
 def testBroadcastStructure(self, from_structure, to_structure, expected):
     ret = nest_util.broadcast_structure(to_structure, from_structure)
     self.assertAllEqual(expected, ret)
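The test above exercises `nest_util.broadcast_structure(to_structure, from_structure)`, the helper used throughout these examples. A short behavioral sketch, assuming the internal import path below (the module is not part of the public TFP API):

```python
from tensorflow_probability.python.internal import nest_util

# A non-nested value is repeated to parallel the target structure...
nest_util.broadcast_structure({'a': 0, 'b': 0}, 3.)
# ~> {'a': 3.0, 'b': 3.0}

# ...which is how scalar settings (event_ndims, axis, ESS filter arguments, etc.)
# are expanded to match structured chain states in the surrounding examples.
nest_util.broadcast_structure([0, 0, 0], None)
# ~> [None, None, None]
```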
Example #23
def _factored_surrogate_posterior(  # pylint: disable=dangerous-default-value
        event_shape=None,
        bijector=None,
        batch_shape=(),
        base_distribution_cls=normal.Normal,
        initial_parameters={'scale': 1e-2},
        dtype=tf.float32,
        validate_args=False,
        name=None):
    """Builds a joint variational posterior that factors over model variables.

  By default, this method creates an independent trainable Normal distribution
  for each variable, transformed using a bijector (if provided) to
  match the support of that variable. This makes extremely strong
  assumptions about the posterior: that it is approximately normal (or
  transformed normal), and that all model variables are independent.

  Args:
    event_shape: `Tensor` shape, or nested structure of `Tensor` shapes,
      specifying the event shape(s) of the posterior variables.
    bijector: Optional `tfb.Bijector` instance, or nested structure of such
      instances, defining support(s) of the posterior variables. The structure
      must match that of `event_shape` and may contain `None` values. A
      posterior variable will be modeled as
      `tfd.TransformedDistribution(underlying_dist, bijector)` if a
      corresponding constraining bijector is specified, otherwise it is modeled
      as supported on the unconstrained real line.
    batch_shape: The `batch_shape` of the output distribution.
      Default value: `()`.
    base_distribution_cls: Subclass of `tfd.Distribution` that is instantiated
      and optionally transformed by the bijector to define the component
      distributions. May optionally be a structure of such subclasses
      matching `event_shape`.
      Default value: `tfd.Normal`.
    initial_parameters: Optional `str : Tensor` dictionary specifying initial
      values for some or all of the base distribution's trainable parameters,
      or a Python `callable` with signature
      `value = parameter_init_fn(parameter_name, shape, dtype, seed,
      constraining_bijector)`, passed to `tfp.experimental.util.make_trainable`.
      May optionally be a structure matching `event_shape` of such dictionaries
      and/or callables. Dictionary entries that do not correspond to parameter
      names are ignored.
      Default value: `{'scale': 1e-2}` (ignored when `base_distribution` does
        not have a `scale` parameter).
    dtype: Optional float `dtype` for trainable parameters. May
      optionally be a structure of such `dtype`s matching `event_shape`.
      Default value: `tf.float32`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_factored_surrogate_posterior').
  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.

  ### Examples

  Consider a Gamma model with unknown parameters, expressed as a joint
  Distribution:

  ```python
  Root = tfd.JointDistributionCoroutine.Root
  def model_fn():
    concentration = yield Root(tfd.Exponential(1.))
    rate = yield Root(tfd.Exponential(1.))
    y = yield tfd.Sample(tfd.Gamma(concentration=concentration, rate=rate),
                         sample_shape=4)
  model = tfd.JointDistributionCoroutine(model_fn)
  ```

  Let's use variational inference to approximate the posterior over the
  data-generating parameters for some observed `y`. We'll build a
  surrogate posterior distribution by specifying the shapes of the latent
  `rate` and `concentration` parameters, and that both are constrained to
  be positive.

  ```python
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=model.event_shape_tensor()[:-1],  # Omit the observed `y`.
    bijector=[tfb.Softplus(),   # Concentration is positive.
              tfb.Softplus()])  # Rate is positive.
  ```

  This creates a trainable joint distribution, defined by variables in
  `surrogate_posterior.trainable_variables`. We use `fit_surrogate_posterior`
  to fit this distribution by minimizing a divergence to the true posterior.

  ```python
  y = [0.2, 0.5, 0.3, 0.7]
  losses = tfp.vi.fit_surrogate_posterior(
    lambda concentration, rate: model.log_prob([concentration, rate, y]),
    surrogate_posterior=surrogate_posterior,
    num_steps=100,
    optimizer=tf.optimizers.Adam(0.1),
    sample_size=10)

  # After optimization, samples from the surrogate will approximate
  # samples from the true posterior.
  samples = surrogate_posterior.sample(100)
  posterior_mean = [tf.reduce_mean(x) for x in samples]     # mean ~= [1.1, 2.1]
  posterior_std = [tf.math.reduce_std(x) for x in samples]  # std  ~= [0.3, 0.8]
  ```

  If we wanted to initialize the optimization at a specific location, we can
  specify initial parameters when we build the surrogate posterior. Note that
  these parameterize the distribution(s) over unconstrained values,
  so we need to transform our desired constrained locations using the inverse
  of the constraining bijector(s).

  ```python
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=tf.nest.map_structure(tf.shape, initial_loc),
    bijector={'concentration': tfb.Softplus(),   # Concentration is positive.
              'rate': tfb.Softplus()},  # Rate is positive.
    initial_parameters={
      'concentration': {'loc': tfb.Softplus().inverse(0.4), 'scale': 1e-2},
      'rate': {'loc': tfb.Softplus().inverse(0.2), 'scale': 1e-2}})
  ```

  """
    with tf.name_scope(name or 'build_factored_surrogate_posterior'):
        # Convert event shapes to Tensors.
        shallow_structure = _get_event_shape_shallow_structure(event_shape)
        event_shape = nest.map_structure_up_to(
            shallow_structure,
            lambda s: tf.convert_to_tensor(s, dtype=tf.int32), event_shape)

        if nest.is_nested(bijector):
            event_space_bijector = joint_map.JointMap(
                nest.map_structure(
                    lambda b: identity.Identity() if b is None else b,
                    nest_util.coerce_structure(event_shape, bijector)),
                validate_args=validate_args)
        else:
            event_space_bijector = bijector

        if event_space_bijector is None:
            unconstrained_event_shape = event_shape
        else:
            unconstrained_event_shape = (
                event_space_bijector.inverse_event_shape_tensor(event_shape))
        unconstrained_batch_and_event_shape = tf.nest.map_structure(
            lambda s: ps.concat([batch_shape, s], axis=0),
            unconstrained_event_shape)

        base_distribution_cls = nest_util.broadcast_structure(
            event_shape, base_distribution_cls)
        try:
            # Check that we have initial parameters for each event part.
            nest.assert_shallow_structure(event_shape, initial_parameters)
        except (ValueError, TypeError):
            # If not, broadcast the parameters to match the event structure.
            # We do this manually rather than using `nest_util.broadcast_structure`
            # because the initial parameters can themselves be structures (dicts).
            initial_parameters = nest.map_structure(
                lambda x: initial_parameters, event_shape)

        unconstrained_trainable_distributions = yield from (
            nest_util.map_structure_coroutine(
                trainable._make_trainable,  # pylint: disable=protected-access
                cls=base_distribution_cls,
                initial_parameters=initial_parameters,
                batch_and_event_shape=unconstrained_batch_and_event_shape,
                parameter_dtype=nest_util.broadcast_structure(
                    event_shape, dtype),
                _up_to=event_shape))
        unconstrained_trainable_distribution = (
            joint_distribution_util.
            independent_joint_distribution_from_structure(
                unconstrained_trainable_distributions,
                batch_ndims=ps.rank_from_shape(batch_shape),
                validate_args=validate_args))
        if event_space_bijector is None:
            return unconstrained_trainable_distribution
        return transformed_distribution.TransformedDistribution(
            unconstrained_trainable_distribution, event_space_bijector)
Example #24
 def step_broadcast(step_size):
   return step_bijector(
       nest_util.broadcast_structure(pinned_model.event_shape_tensor(),
                                     step_size))
Example #25
    def __init__(self,
                 output_structure,
                 input_structure=None,
                 name='restructure'):
        """Converts between nested structures of Tensor.

    This is useful when constructing non-trivial chains of multipart bijectors.
    It partitions inputs into different logical "blocks", which may be fed as
    arguments to downstream multipart bijectors.

    Example Usage:
      ```python

      # What restructure does:
      restructure = Restructure({
        'foo': [0, 1],
        'bar': [3, 2],
        'baz': [4, 5, 6]
      })

      # Note that x is a *python-list* of tensors.
      # To permute elements of an individual Tensor, see `tfb.Permute`.
      x = [1, 2, 4, 8, 16, 32, 64]

      assert restructure.forward(x) == {
          'foo': [1, 2],
          'bar': [8, 4],
          'baz': [16, 32, 64]
      }

      # Where restructure is useful:
      complex_bijector = Chain([
        # Apply different transformations to each block.
        JointMap({
          'foo': ScaleMatVecLinearOperator(...),  # Operates on the full block
          'bar': ScaleMatVecLinearOperator(...),  # Operates on the full block
          'baz': [Exp(), Scale(10.), Shift(-1.)]  # Different bijectors for each
        }),
        # Group the tensor into logical blocks.
        Restructure({
          'foo': [0, 1],
          'bar': [3, 2],
          'baz': [4, 5, 6],
        }),
        # Split an input tensor into 7 chunks.
        Split([2, 4, 6, 8, 10, 12, 14])
      ])
      ```

    Args:
      output_structure: A tf.nest-compatible structure of tokens describing the
        output of `forward` (equivalently, the input of `inverse`).
      input_structure: A tf.nest-compatible structure of tokens describing the
        input to `forward`. If unspecified, a default structure is inferred from
        `output_structure`: a `list` when the tokens are contiguous 0-based
        integers, or a `dict` keyed by the tokens when they are strings.
      name: Name of this bijector.
    Raises:
      ValueError: If tokens are duplicated, or a required default structure
        cannot be inferred.
    """
        parameters = dict(locals())

        # Get the flat set of tokens, making sure they're unique.
        output_tokens = unique_token_set(output_structure)

        # Create a default input_structure when it isn't provided.
        if input_structure is None:
            # If all tokens are strings, assume input is a dict.
            if all(isinstance(tok, six.string_types) for tok in output_tokens):
                input_structure = {token: token for token in output_tokens}

            # If tokens are contiguous 0-based ints, return a list.
            elif (all(
                    isinstance(tok, six.integer_types)
                    for tok in output_tokens)
                  and output_tokens == set(range(len(output_tokens)))):
                input_structure = list(range(len(output_tokens)))

            # Otherwise, we cannot infer a default structure.
            else:
                raise ValueError(
                    ('Tokens in output_structure must be all strings or '
                     'contiguous 0-based indices when input_structure '
                     'is not specified. Saw: {}').format(output_tokens))

        # If input_structure _is_ provided, make sure tokens are unique
        # and that they match the output_structure tokens.
        else:
            input_tokens = unique_token_set(input_structure)
            if input_tokens != output_tokens:
                raise ValueError(
                    ('The `input_structure` tokens must match the '
                     '`output_structure` tokens exactly. Missing from '
                     '`input_structure`: {}. Missing from '
                     '`output_structure`: {}.').format(
                         output_tokens - input_tokens,
                         input_tokens - output_tokens))

        self._input_structure = self._no_dependency(input_structure)
        self._output_structure = self._no_dependency(output_structure)
        super(Restructure, self).__init__(
            forward_min_event_ndims=nest_util.broadcast_structure(
                self._input_structure, None),
            inverse_min_event_ndims=nest_util.broadcast_structure(
                self._output_structure, None),
            is_constant_jacobian=True,
            validate_args=False,
            parameters=parameters,
            name=name)