def _prepare_args(target_log_prob_fn,
                  state,
                  step_size,
                  momentum_distribution,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  maybe_expand=False,
                  state_gradients_are_stopped=False):
    """Helper which processes input args to meet list-like assumptions."""
    state_parts, _ = mcmc_util.prepare_state_parts(state, name='current_state')
    if state_gradients_are_stopped:
        state_parts = [tf.stop_gradient(x) for x in state_parts]
    target_log_prob, grads_target_log_prob = mcmc_util.maybe_call_fn_and_grads(
        target_log_prob_fn, state_parts, target_log_prob,
        grads_target_log_prob)
    step_sizes, _ = mcmc_util.prepare_state_parts(step_size,
                                                  dtype=target_log_prob.dtype,
                                                  name='step_size')

    # `momentum_distribution` defaults to `None`, but if
    # `store_parameters_in_results` is true it defaults to an empty list;
    # in both cases, build a standard-normal momentum distribution here.
    if momentum_distribution is None or isinstance(momentum_distribution,
                                                   list):
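        # `target_log_prob` has one scalar per batch member, so its rank is
        # the batch rank; a state part's remaining trailing dims are its
        # event dims.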
        batch_rank = ps.rank(target_log_prob)

        def _batched_isotropic_normal_like(state_part):
            event_ndims = ps.rank(state_part) - batch_rank
            return independent.Independent(
                # Standard normal matching the state part's dtype and shape.
                normal.Normal(ps.zeros_like(state_part), 1.),
                reinterpreted_batch_ndims=event_ndims)

        momentum_distribution = jds.JointDistributionSequential([
            _batched_isotropic_normal_like(state_part)
            for state_part in state_parts
        ])

    # The momentum will get "maybe listified" to zip with the state parts,
    # and this step makes sure that the momentum distribution will have the
    # same "maybe listified" underlying shape.
    if not mcmc_util.is_list_like(momentum_distribution.dtype):
        momentum_distribution = jds.JointDistributionSequential(
            [momentum_distribution])

    if len(step_sizes) == 1:
        step_sizes *= len(state_parts)
    if len(state_parts) != len(step_sizes):
        raise ValueError(
            'There should be exactly one `step_size`, or it should '
            'have the same length as `current_state`.')

    def maybe_flatten(x):
        return x if maybe_expand or mcmc_util.is_list_like(state) else x[0]

    return [
        maybe_flatten(state_parts),
        maybe_flatten(step_sizes),
        momentum_distribution,
        target_log_prob,
        grads_target_log_prob,
    ]
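
For intuition, here is a minimal sketch of what that default construction produces, using only the public `tfp.distributions` API; the `state_parts` and `batch_rank` values are illustrative, not taken from the code above.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Two state parts sharing batch shape [3]: a scalar-event part and a
# length-2 vector part.
state_parts = [tf.zeros([3]), tf.zeros([3, 2])]
batch_rank = 1

momentum_distribution = tfd.JointDistributionSequential([
    tfd.Independent(
        tfd.Normal(tf.zeros_like(part), 1.),
        # Everything beyond the batch dims counts as event dims.
        reinterpreted_batch_ndims=len(part.shape) - batch_rank)
    for part in state_parts
])

draw = momentum_distribution.sample(seed=42)
# draw[0].shape == [3], draw[1].shape == [3, 2]: one standard-normal
# momentum per state element.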
def _maybe_build_joint_distribution(structure_of_distributions):
    """Turns a (potentially nested) structure of dists into a single dist."""
    # Base case: if we already have a Distribution, return it.
    if dist_util.is_distribution_instance(structure_of_distributions):
        return structure_of_distributions

    # Otherwise, recursively convert all interior nested structures into JDs.
    outer_structure = tf.nest.map_structure(_maybe_build_joint_distribution,
                                            structure_of_distributions)
    if (hasattr(outer_structure, '_asdict')
            or isinstance(outer_structure, collections.abc.Mapping)):
        return joint_distribution_named.JointDistributionNamed(outer_structure)
    else:
        return joint_distribution_sequential.JointDistributionSequential(
            outer_structure)
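
In effect the helper maps Python structure onto `JointDistribution` flavors: namedtuples and mappings become `JointDistributionNamed`, other sequences become `JointDistributionSequential`. A sketch of that dispatch with public `tfp.distributions` names (the module aliases in the code above are TFP-internal):

import tensorflow_probability as tfp

tfd = tfp.distributions

# A dict of leaves becomes a JointDistributionNamed...
jd_named = tfd.JointDistributionNamed(
    dict(loc=tfd.Normal(0., 1.), scale=tfd.LogNormal(0., 1.)))

# ...while a list of leaves becomes a JointDistributionSequential.
jd_seq = tfd.JointDistributionSequential(
    [tfd.Normal(0., 1.), tfd.LogNormal(0., 1.)])

assert set(jd_named.sample(seed=0).keys()) == {'loc', 'scale'}
assert len(jd_seq.sample(seed=0)) == 2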
Example #3
def independent_joint_distribution_from_structure(structure_of_distributions,
                                                  validate_args=False):
    """Turns a (potentially nested) structure of dists into a single dist.

    Args:
      structure_of_distributions: instance of `tfd.Distribution`, or nested
        structure (tuple, list, dict, etc.) in which all leaves are
        `tfd.Distribution` instances.
      validate_args: Python `bool`. Whether the joint distribution should
        validate input with asserts. This imposes a runtime cost. If
        `validate_args` is `False`, and the inputs are invalid, correct
        behavior is not guaranteed.
        Default value: `False`.

    Returns:
      distribution: instance of `tfd.Distribution` such that
        `distribution.sample()` is equivalent to
        `tf.nest.map_structure(lambda d: d.sample(), structure_of_distributions)`.
        If `structure_of_distributions` was indeed a structure (as opposed to
        a single `Distribution` instance), this will be a `JointDistribution`
        with the corresponding structure.

    Raises:
      TypeError: if any leaves of the input structure are not
        `tfd.Distribution` instances.
    """
    # If input is already a Distribution, just return it.
    if dist_util.is_distribution_instance(structure_of_distributions):
        return structure_of_distributions

    # If this structure contains other structures (i.e., it has elements at
    # depth > 1), recursively turn them into JDs.
    element_depths = nest.map_structure_with_tuple_paths(
        lambda path, x: len(path), structure_of_distributions)
    if max(tf.nest.flatten(element_depths)) > 1:
        next_level_shallow_structure = nest.get_traverse_shallow_structure(
            traverse_fn=lambda x: min(tf.nest.flatten(x)) <= 1,
            structure=element_depths)
        structure_of_distributions = nest.map_structure_up_to(
            next_level_shallow_structure,
            independent_joint_distribution_from_structure,
            structure_of_distributions)

    # Otherwise, build a JD from the current structure.
    if (hasattr(structure_of_distributions, '_asdict')
            or isinstance(structure_of_distributions, collections.abc.Mapping)):
        return joint_distribution_named.JointDistributionNamed(
            structure_of_distributions, validate_args=validate_args)
    return joint_distribution_sequential.JointDistributionSequential(
        structure_of_distributions, validate_args=validate_args)
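
A usage sketch, assuming the TFP-internal imports this function relies on (`dist_util`, `nest`, and the `joint_distribution_*` modules) are available; the nested input collapses into a single distribution whose samples mirror the original structure:

import tensorflow_probability as tfp

tfd = tfp.distributions

dist = independent_joint_distribution_from_structure(
    {'x': tfd.Normal(0., 1.),
     'yz': [tfd.Gamma(1., 1.), tfd.Bernoulli(probs=0.5)]})

sample = dist.sample(seed=1)
# `sample` is a dict: sample['x'] is a scalar draw and sample['yz'] is a
# two-element list, matching the input structure.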
Example #4
  def __init__(self,
               distributions,
               dtype_override=None,
               validate_args=False,
               allow_nan_stats=False,
               name='Blockwise'):
    """Construct the `Blockwise` distribution.

    Args:
      distributions: Python `list` of `tfp.distributions.Distribution`
        instances. All distribution instances must have the same `batch_shape`
        and all must have `event_ndims==1`, i.e., be vector-variate
        distributions.
      dtype_override: samples of `distributions` will be cast to this `dtype`.
        If unspecified, all `distributions` must have the same `dtype`.
        Default value: `None` (i.e., do not cast).
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `False`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or more
        of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      self._distributions = distributions
      if dtype_override is not None:
        distributions = tf.nest.map_structure(
            lambda d: _Cast(d, dtype_override), distributions)
      if _is_iterable(distributions):
        self._distribution = (
            joint_distribution_sequential.JointDistributionSequential(
                list(distributions)))
      else:
        self._distribution = distributions

      # Need to cache these for JointDistributions as the batch shape of that
      # distribution can change after `_sample` calls.
      self._cached_batch_shape_tensor = self._distribution.batch_shape_tensor()
      self._cached_batch_shape = self._distribution.batch_shape

      if dtype_override is not None:
        dtype = dtype_override
      else:
        dtype = set(
            dtype_util.base_dtype(dtype)
            for dtype in tf.nest.flatten(self._distribution.dtype)
            if dtype is not None)
        if len(dtype) == 0:  # pylint: disable=g-explicit-length-test
          dtype = tf.float32
        elif len(dtype) == 1:
          dtype = dtype.pop()
        else:
          raise TypeError(
              'Distributions must have same dtype; found: {}.'.format(
                  self._distribution.dtype))

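      # If every component reports the same reparameterization type, inherit
      # it; any disagreement degrades to NOT_REPARAMETERIZED.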
      reparameterization_type = set(
          tf.nest.flatten(self._distribution.reparameterization_type))
      reparameterization_type = (
          reparameterization_type.pop() if len(reparameterization_type) == 1
          else reparameterization.NOT_REPARAMETERIZED)

      super(Blockwise, self).__init__(
          dtype=dtype,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          reparameterization_type=reparameterization_type,
          parameters=parameters,
          name=name)
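
`Blockwise` is available publicly as `tfd.Blockwise`; a small usage sketch with illustrative shapes:

import tensorflow_probability as tfp

tfd = tfp.distributions

# Two vector-variate blocks with matching (scalar) batch shape: a 2-vector
# MVN and three independent Normals treated as a 3-vector.
block = tfd.Blockwise([
    tfd.MultivariateNormalDiag(loc=[0., 0.], scale_diag=[1., 2.]),
    tfd.Independent(tfd.Normal(loc=[0., 0., 0.], scale=1.),
                    reinterpreted_batch_ndims=1),
])

x = block.sample(seed=7)
# x.shape == [5]: the blocks' events are concatenated along the last axis.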
Example #5
def _prepare_args(target_log_prob_fn,
                  state,
                  step_size,
                  momentum_distribution,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  maybe_expand=False,
                  state_gradients_are_stopped=False):
  """Helper which processes input args to meet list-like assumptions."""
  state_parts, _ = mcmc_util.prepare_state_parts(state, name='current_state')
  if state_gradients_are_stopped:
    state_parts = [tf.stop_gradient(x) for x in state_parts]
  target_log_prob, grads_target_log_prob = mcmc_util.maybe_call_fn_and_grads(
      target_log_prob_fn, state_parts, target_log_prob, grads_target_log_prob)
  step_sizes, _ = mcmc_util.prepare_state_parts(
      step_size, dtype=target_log_prob.dtype, name='step_size')

  # Default momentum distribution is None, but if `store_parameters_in_results`
  # is true, then `momentum_distribution` defaults to DefaultStandardNormal().
  if (momentum_distribution is None or
      isinstance(momentum_distribution, DefaultStandardNormal)):
    batch_rank = ps.rank(target_log_prob)
    def _batched_isotropic_normal_like(state_part):
      return sample.Sample(
          normal.Normal(ps.zeros([], dtype=state_part.dtype), 1.),
          ps.shape(state_part)[batch_rank:])

    momentum_distribution = jds.JointDistributionSequential(
        [_batched_isotropic_normal_like(state_part)
         for state_part in state_parts])

  # The momentum will get "maybe listified" to zip with the state parts,
  # and this step makes sure that the momentum distribution will have the
  # same "maybe listified" underlying shape.
  if not mcmc_util.is_list_like(momentum_distribution.dtype):
    momentum_distribution = jds.JointDistributionSequential(
        [momentum_distribution])

  # If all underlying distributions are independent, we can offer some help.
  # This code will also trigger for the output of the two blocks above.
  if (isinstance(momentum_distribution, jds.JointDistributionSequential) and
      not any(callable(dist_fn) for dist_fn in momentum_distribution.model)):
    batch_shape = ps.shape(target_log_prob)
    momentum_distribution = momentum_distribution.copy(model=[
        batch_broadcast.BatchBroadcast(md, to_shape=batch_shape)
        for md in momentum_distribution.model
    ])

  if len(step_sizes) == 1:
    step_sizes *= len(state_parts)
  if len(state_parts) != len(step_sizes):
    raise ValueError('There should be exactly one `step_size`, or it should '
                     'have the same length as `current_state`.')
  def maybe_flatten(x):
    return x if maybe_expand or mcmc_util.is_list_like(state) else x[0]
  return [
      maybe_flatten(state_parts),
      maybe_flatten(step_sizes),
      momentum_distribution,
      target_log_prob,
      grads_target_log_prob,
  ]
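
This helper is called by the preconditioned HMC kernel; below is a hedged sketch of supplying a custom `momentum_distribution` via `tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo` (argument names as in recent TFP releases; the inverse-scale choice is the usual mass-matrix heuristic, not something taken from the code above):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

target = tfd.MultivariateNormalDiag(loc=tf.zeros(2), scale_diag=[1., 10.])

# Precondition with momentum scales set to the inverse of the target's
# scales, the standard mass-matrix heuristic.
kernel = tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
    target_log_prob_fn=target.log_prob,
    step_size=0.3,
    num_leapfrog_steps=10,
    momentum_distribution=tfd.MultivariateNormalDiag(
        scale_diag=[1., 0.1]))

samples = tfp.mcmc.sample_chain(
    num_results=100,
    current_state=tf.zeros(2),
    kernel=kernel,
    trace_fn=None,
    seed=(0, 0))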