Example #1
def covariance(x,
               y=None,
               sample_axis=0,
               event_axis=-1,
               keepdims=False,
               name=None):
    """Sample covariance between observations indexed by `event_axis`.

  Given `N` samples of scalar random variables `X` and `Y`, covariance may be
  estimated as

  ```none
  Cov[X, Y] := N^{-1} sum_{n=1}^N (X_n - Xbar) Conj{(Y_n - Ybar)}
  Xbar := N^{-1} sum_{n=1}^N X_n
  Ybar := N^{-1} sum_{n=1}^N Y_n
  ```

  For vector-variate random variables `X = (X1, ..., Xd)`, `Y = (Y1, ..., Yd)`,
  one is often interested in the covariance matrix, `C_{ij} := Cov[Xi, Yj]`.

  ```python
  x = tf.random.normal(shape=(100, 2, 3))
  y = tf.random.normal(shape=(100, 2, 3))

  # cov[i, j] is the sample covariance between x[:, i, j] and y[:, i, j].
  cov = tfp.stats.covariance(x, y, sample_axis=0, event_axis=None)

  # cov_matrix[i, m, n] is the sample covariance of x[:, i, m] and y[:, i, n]
  cov_matrix = tfp.stats.covariance(x, y, sample_axis=0, event_axis=-1)
  ```

  Notice we divide by `N` (the default in `np.std`/`np.var`), which does not
  create `NaN` when `N = 1`, but is slightly biased.

  Args:
    x:  A numeric `Tensor` holding samples.
    y:  Optional `Tensor` with same `dtype` and `shape` as `x`.
      Default value: `None` (`y` is effectively set to `x`).
    sample_axis: Scalar or vector `Tensor` designating axis holding samples, or
      `None` (meaning all axes hold samples).
      Default value: `0` (leftmost dimension).
    event_axis:  Scalar or vector `Tensor`, or `None` (scalar events).
      Axis indexing random events, whose covariance we are interested in.
      If a vector, entries must form a contiguous block of dims. `sample_axis`
      and `event_axis` should not intersect.
      Default value: `-1` (rightmost axis holds events).
    keepdims:  Boolean.  Whether to keep the sample axis as singletons.
    name: Python `str` name prefixed to Ops created by this function.
          Default value: `None` (i.e., `'covariance'`).

  Returns:
    cov: A `Tensor` of the same `dtype` as `x`. With `keepdims=False` its rank
      is `rank(x) - len(sample_axis) + len(event_axis)`; with `keepdims=True`
      the sample axes are retained as singleton dimensions.

  Raises:
    AssertionError:  If `x` and `y` are found to have different shape.
    ValueError:  If `sample_axis` and `event_axis` are found to overlap.
    ValueError:  If `event_axis` is found to not be contiguous.
  """

    with tf.name_scope(name or 'covariance'):
        x = tf.convert_to_tensor(x, name='x')
        # Covariance *only* uses the centered versions of x (and y).
        x -= tf.reduce_mean(x, axis=sample_axis, keepdims=True)

        if y is None:
            y = x
        else:
            y = tf.convert_to_tensor(y, name='y', dtype=x.dtype)
            # If x and y have different shape, sample_axis and event_axis will likely
            # be wrong for one of them!
            tensorshape_util.assert_is_compatible_with(x.shape, y.shape)
            y -= tf.reduce_mean(y, axis=sample_axis, keepdims=True)

        if event_axis is None:
            return tf.reduce_mean(x * tf.math.conj(y),
                                  axis=sample_axis,
                                  keepdims=keepdims)

        if sample_axis is None:
            raise ValueError(
                'sample_axis was None, which means all axes hold samples, and '
                'this overlaps with event_axis ({})'.format(event_axis))

        event_axis = _make_positive_axis(event_axis, ps.rank(x))
        sample_axis = _make_positive_axis(sample_axis, ps.rank(x))

        # If we get lucky and axis is statically defined, we can do some checks.
        if _is_list_like(event_axis) and _is_list_like(sample_axis):
            event_axis = tuple(map(int, event_axis))
            sample_axis = tuple(map(int, sample_axis))
            if set(event_axis).intersection(sample_axis):
                raise ValueError(
                    'sample_axis ({}) and event_axis ({}) overlapped'.format(
                        sample_axis, event_axis))
            if (np.diff(np.array(sorted(event_axis))) > 1).any():
                raise ValueError(
                    'event_axis must be contiguous. Found: {}'.format(
                        event_axis))
            batch_axis = list(
                sorted(
                    set(range(tensorshape_util.rank(
                        x.shape))).difference(sample_axis + event_axis)))
        else:
            batch_axis = ps.setdiff1d(ps.range(0, ps.rank(x)),
                                      ps.concat((sample_axis, event_axis), 0))

        event_axis = ps.cast(event_axis, dtype=tf.int32)
        sample_axis = ps.cast(sample_axis, dtype=tf.int32)
        batch_axis = ps.cast(batch_axis, dtype=tf.int32)

        # Permute x/y until shape = B + E + S
        perm_for_xy = ps.concat((batch_axis, event_axis, sample_axis), 0)
        x_permed = tf.transpose(a=x, perm=perm_for_xy)
        y_permed = tf.transpose(a=y, perm=perm_for_xy)

        batch_ndims = ps.size(batch_axis)
        batch_shape = ps.shape(x_permed)[:batch_ndims]
        event_ndims = ps.size(event_axis)
        event_shape = ps.shape(x_permed)[batch_ndims:batch_ndims + event_ndims]
        sample_shape = ps.shape(x_permed)[batch_ndims + event_ndims:]
        sample_ndims = ps.size(sample_shape)
        n_samples = ps.reduce_prod(sample_shape)
        n_events = ps.reduce_prod(event_shape)

        # Flatten the sample dims into one long [n_samples] dim and the event
        # dims into a single [n_events] dim, so one matmul below computes the
        # covariance.
        x_permed_flat = tf.reshape(
            x_permed, ps.concat((batch_shape, [n_events], [n_samples]), 0))
        y_permed_flat = tf.reshape(
            y_permed, ps.concat((batch_shape, [n_events], [n_samples]), 0))

        # After matmul, cov.shape = batch_shape + [n_events, n_events]
        cov = tf.matmul(x_permed_flat, y_permed_flat,
                        adjoint_b=True) / ps.cast(n_samples, x.dtype)

        # Insert some singletons to make
        # cov.shape = batch_shape + event_shape**2 + [1,...,1]
        # This is just like x_permed.shape, except the sample_axis is all 1's, and
        # the [n_events] became event_shape**2.
        cov = tf.reshape(
            cov,
            ps.concat(
                (
                    batch_shape,
                    # event_shape**2 used here because it is the same length as
                    # event_shape, and has the same number of elements as one
                    # batch of covariance.
                    event_shape**2,
                    ps.ones([sample_ndims], tf.int32)),
                0))
        # Permuting by the argsort inverts the permutation, making
        # cov.shape have ones in the position where there were samples, and
        # [n_events * n_events] in the event position.
        cov = tf.transpose(a=cov, perm=ps.invert_permutation(perm_for_xy))

        # Now expand event_shape**2 into event_shape + event_shape.
        # We here use (for the first time) the fact that we require event_axis to be
        # contiguous.
        e_start = event_axis[0]
        e_len = 1 + event_axis[-1] - event_axis[0]
        cov = tf.reshape(
            cov,
            ps.concat((ps.shape(cov)[:e_start], event_shape, event_shape,
                       ps.shape(cov)[e_start + e_len:]), 0))

        # tf.squeeze requires python ints for axis, not Tensor.  This is enough to
        # require our axis args to be constants.
        if not keepdims:
            squeeze_axis = ps.where(sample_axis < e_start, sample_axis,
                                    sample_axis + e_len)
            cov = _squeeze(cov, axis=squeeze_axis)

        return cov
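
A quick sanity check of the divide-by-`N` estimator above against NumPy (a sketch; it assumes eager-mode TensorFlow and an installed `tensorflow_probability`):

```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

x = tf.random.normal(shape=(1000, 3))
# Defaults sample_axis=0, event_axis=-1 give a [3, 3] covariance matrix.
cov = tfp.stats.covariance(x)
# np.cov with bias=True also divides by N; rowvar=False treats rows as samples.
np_cov = np.cov(x.numpy(), rowvar=False, bias=True)
print(np.allclose(cov.numpy(), np_cov, atol=1e-5))
```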
Example #2
    def _one_step_part(self,
                       step_size,
                       state,
                       error_sum,
                       log_averaging_step,
                       shrinkage_target,
                       log_accept_prob_rank=None,
                       log_accept_prob=None,
                       target_accept_prob=None,
                       previous_kernel_results=None):
        """Compute new step sizes for each step size part.

    If a step size part has smaller rank than the corresponding state part, the
    difference is averaged away in the log accept prob.

    Example:

      state_part has shape      [2, 3, 4, 5]
      step_size_part has shape     [1, 4, 1]
      log_accept_prob has shape [2, 3, 4]

    Since step size has 1 rank fewer than the state, we reduce away the leading
    dimension of `log_accept_prob` to get a Tensor with shape [3, 4]. Next,
    since `log_accept_prob` must broadcast into step_size_part on the left, we
    reduce the dimensions where their shapes differ, to get a Tensor with shape
    [1, 4], which now is compatible with the leading dimensions of
    step_size_part.

    There is a subtlety here in that `step_size_parts` might be a length-1 list,
    which means that we'll be "structure-broadcasting" it for all the state
    parts (see logic in, e.g., hmc.py). In this case we must assume that the
    lone step size provided broadcasts with the event dims of each state part.
    This means that either the step size has no dimensions corresponding to
    chain dimensions, or all states are of the same shape. For the former, we
    want to reduce over all chain dimensions. For the latter, we want to use
    the same logic as in the non-structure-broadcasted case.

    It turns out we can compute the reduction dimensions for both cases
    uniformly by taking the rank of any state part. This obviously works in
    the second case (where all state ranks are the same). In the first case,
    all state parts have rank L + D_i + B, where L is the rank of
    log_accept_prob, D_i is the number of dimensions unique to state part i,
    and B is the number of dimensions shared by all the states, which equals
    the rank of the step size. When we subtract B, we always get a number
    >= L, which means we'll get the full reduction we want.

    Args:
      step_size: Previous step's step_size.
      state: Previous step's state value.
      error_sum: Previous step's error accumulator.
      log_averaging_step: Previous step's log_averaging_step.
      shrinkage_target: Floating point scalar `Tensor`. Arbitrary value the
        exploration step size is biased towards.
      log_accept_prob_rank: Rank of `log_accept_prob`.
      log_accept_prob: Floating point `Tensor`. Log probability of accepting
        the proposed state.
      target_accept_prob: A floating point `Tensor` representing desired
        acceptance probability. Must be a positive number less than 1.
      previous_kernel_results: Results struct from previous step.

    Returns:
      new_step_size: Updated `step_size`.
      new_log_averaging_step: Updated `log_averaging_step`.
      new_error_sum: Updated `error_sum`.
    """
        num_reduce_dims = prefer_static.minimum(
            log_accept_prob_rank,
            (prefer_static.rank(state) - prefer_static.rank(step_size)))
        reduced_log_accept_prob = reduce_logmeanexp(
            log_accept_prob, axis=prefer_static.range(num_reduce_dims))

        # reduced_log_accept_prob must broadcast into step_size on the
        # left, so we do an additional reduction over dimensions where their
        # shapes differ.
        reduce_indices = _get_differing_dims(reduced_log_accept_prob,
                                             step_size)
        reduced_log_accept_prob = reduce_logmeanexp(reduced_log_accept_prob,
                                                    axis=reduce_indices,
                                                    keepdims=True)
        new_error_sum = (error_sum + target_accept_prob -
                         tf.math.exp(reduced_log_accept_prob))
        num_ones_to_pad = prefer_static.maximum(
            prefer_static.rank(shrinkage_target) -
            prefer_static.rank(new_error_sum), 0)
        new_error_sum_extend = tf.reshape(
            new_error_sum,
            shape=prefer_static.pad(prefer_static.shape(new_error_sum),
                                    paddings=[[0, num_ones_to_pad]],
                                    constant_values=1))

        step_count_smoothing = previous_kernel_results.step_count_smoothing
        step = tf.cast(previous_kernel_results.step,
                       step_count_smoothing.dtype) + 1.
        soft_t = step_count_smoothing + step

        new_log_step = (shrinkage_target - (
            (tf.cast(new_error_sum_extend, step.dtype) * tf.math.sqrt(step)) /
            (soft_t * previous_kernel_results.exploration_shrinkage)))

        eta = step**(-previous_kernel_results.decay_rate)
        new_log_averaging_step = (eta * new_log_step +
                                  (1. - eta) * log_averaging_step)

        # - If still adapting, return an exploring step size,
        # - If just finished, return the averaging step size
        # - Otherwise, do not update
        new_step_size = tf.where(
            previous_kernel_results.step < self.num_adaptation_steps,
            tf.math.exp(new_log_step),
            tf.where(previous_kernel_results.step > self.num_adaptation_steps,
                     step_size, tf.math.exp(new_log_averaging_step)))
        new_log_averaging_step = tf.where(
            previous_kernel_results.step > self.num_adaptation_steps,
            log_averaging_step, new_log_averaging_step)
        return new_step_size, new_log_averaging_step, new_error_sum
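
A plain-NumPy sketch of the dual-averaging recursion implemented above; `mu`, `gamma`, `t0`, and `kappa` stand in for `shrinkage_target`, `exploration_shrinkage`, `step_count_smoothing`, and `decay_rate`, and the toy acceptance model `exp(-step_size)` replaces a real chain:

```python
import numpy as np

target, mu, gamma, t0, kappa = 0.75, np.log(10 * 0.1), 0.05, 10., 0.75
error_sum, log_step, log_avg_step = 0., np.log(0.1), np.log(0.1)
for t in range(1, 501):
    accept_prob = np.exp(-np.exp(log_step))   # toy stand-in for a real chain
    error_sum += target - accept_prob         # accumulate acceptance error
    log_step = mu - np.sqrt(t) / (gamma * (t0 + t)) * error_sum
    eta = t**(-kappa)
    log_avg_step = eta * log_step + (1. - eta) * log_avg_step
# exp(-exp(log_avg_step)) drifts toward `target` as t grows.
```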
Example #3
    def __init__(self,
                 distribution,
                 bijector,
                 batch_shape=None,
                 event_shape=None,
                 kwargs_split_fn=_default_kwargs_split_fn,
                 validate_args=False,
                 parameters=None,
                 name=None):
        """Construct a Transformed Distribution.

    Args:
      distribution: The base distribution instance to transform. Typically an
        instance of `Distribution`.
      bijector: The object responsible for calculating the transformation.
        Typically an instance of `Bijector`.
      batch_shape: `integer` vector `Tensor` which overrides `distribution`
        `batch_shape`; valid only if `distribution.is_scalar_batch()`.
      event_shape: `integer` vector `Tensor` which overrides `distribution`
        `event_shape`; valid only if `distribution.is_scalar_event()`.
      kwargs_split_fn: Python `callable` which takes a kwargs `dict` and returns
        a tuple of kwargs `dict`s for each of the `distribution` and `bijector`
        parameters respectively.
        Default value: `_default_kwargs_split_fn` (i.e.,
            `lambda kwargs: (kwargs.get('distribution_kwargs', {}),
                             kwargs.get('bijector_kwargs', {}))`)
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      parameters: Locals dict captured by subclass constructor, to be used for
        copy/slice re-instantiation operations.
      name: Python `str` name prefixed to Ops created by this class. Default:
        `bijector.name + distribution.name`.
    """
        parameters = dict(locals()) if parameters is None else parameters
        name = name or (("" if bijector is None else bijector.name) +
                        (distribution.name or ""))
        with tf.name_scope(name) as name:
            self._kwargs_split_fn = (_default_kwargs_split_fn
                                     if kwargs_split_fn is None else
                                     kwargs_split_fn)
            # For convenience we define some handy constants.
            self._zero = tf.constant(0, dtype=tf.int32, name="zero")
            self._empty = tf.constant([], dtype=tf.int32, name="empty")

            # We will keep track of a static and dynamic version of
            # self._is_{batch,event}_override. This way we can do more prior to graph
            # execution, including possibly raising Python exceptions.

            self._override_batch_shape = self._maybe_validate_shape_override(
                batch_shape, distribution.is_scalar_batch(), validate_args,
                "batch_shape")
            self._is_batch_override = prefer_static.logical_not(
                prefer_static.equal(
                    prefer_static.rank_from_shape(self._override_batch_shape),
                    self._zero))
            self._is_maybe_batch_override = bool(
                tf.get_static_value(self._override_batch_shape) is None
                or tf.get_static_value(self._override_batch_shape).size != 0)

            self._override_event_shape = self._maybe_validate_shape_override(
                event_shape, distribution.is_scalar_event(), validate_args,
                "event_shape")
            self._is_event_override = prefer_static.logical_not(
                prefer_static.equal(
                    prefer_static.rank_from_shape(self._override_event_shape),
                    self._zero))
            self._is_maybe_event_override = bool(
                tf.get_static_value(self._override_event_shape) is None
                or tf.get_static_value(self._override_event_shape).size != 0)

            # To convert a scalar distribution into a multivariate distribution we
            # will draw dims from the sample dims, which are otherwise iid. This is
            # easy to do except in the case that the base distribution has batch dims
            # and we're overriding event shape. When that case happens the event dims
            # will incorrectly be to the left of the batch dims. In this case we'll
            # cyclically permute left the new dims.
            self._needs_rotation = prefer_static.reduce_all([
                self._is_event_override,
                prefer_static.logical_not(self._is_batch_override),
                prefer_static.logical_not(distribution.is_scalar_batch())
            ])
            override_event_ndims = prefer_static.rank_from_shape(
                self._override_event_shape)
            self._rotate_ndims = _pick_scalar_condition(
                self._needs_rotation, override_event_ndims, 0)
            # We'll be reducing the head dims (if at all), i.e., this will be []
            # if we don't need to reduce.
            self._reduce_event_indices = prefer_static.range(
                self._rotate_ndims - override_event_ndims, self._rotate_ndims)

        self._distribution = distribution
        self._bijector = bijector
        super(TransformedDistribution, self).__init__(
            dtype=self._distribution.dtype,
            reparameterization_type=self._distribution.reparameterization_type,
            validate_args=validate_args,
            allow_nan_stats=self._distribution.allow_nan_stats,
            parameters=parameters,
            # We let TransformedDistribution access _graph_parents since this class
            # is more like a baseclass than derived.
            graph_parents=(
                distribution._graph_parents +  # pylint: disable=protected-access
                bijector.graph_parents),
            name=name)
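
A minimal usage sketch (omitting the `batch_shape`/`event_shape` overrides): exp-transforming a standard normal yields a standard log-normal.

```python
import tensorflow_probability as tfp
tfd, tfb = tfp.distributions, tfp.bijectors

log_normal = tfd.TransformedDistribution(
    distribution=tfd.Normal(loc=0., scale=1.),
    bijector=tfb.Exp())
samples = log_normal.sample(5)
lp = log_normal.log_prob(samples)
```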
Example #4
def _gather_history(structure, step, num_steps):
  """Gather up to `num_steps` of history from a nested structure."""
  initial_step = prefer_static.maximum(0, step - num_steps)
  return tf.nest.map_structure(
      lambda x: tf.gather(x, prefer_static.range(initial_step, step)),
      structure)
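
A sketch of the same gather logic on a single (non-nested) tensor, assuming one scalar of history per timestep:

```python
import tensorflow as tf

history = tf.constant([10., 11., 12., 13., 14., 15.])
# At step 5, gathering up to 3 steps of history returns steps 2, 3, 4.
recent = tf.gather(history, tf.range(tf.maximum(0, 5 - 3), 5))
# recent == [12., 13., 14.]
```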
Example #5
    def estimate_parameters(self,
                            observations,
                            num_iterations,
                            num_particles,
                            initial_perturbation_scale,
                            cooling_schedule,
                            seed=None,
                            name=None,
                            **kwargs):
        """Runs multiple iterations of filtering following a cooling schedule.

    Args:
      observations: observed `Tensor` value(s) on which to condition the
        parameter estimate.
      num_iterations: int `Tensor` number of filtering iterations to run.
      num_particles: scalar int `Tensor` number of particles to use.
      initial_perturbation_scale: scalar float `Tensor`, or any structure of
        float `Tensor`s broadcasting to the same shape as the (unconstrained)
        parameters, specifying the scale (standard deviation) of Gaussian
        perturbations to each parameter at the first timestep.
      cooling_schedule: callable with signature
        `cooling_factor = cooling_schedule(iteration)` for `iteration` in
        `[0, ..., num_iterations - 1]`. The filter is
        invoked with perturbations of scale
        `initial_perturbation_scale * cooling_schedule(iteration)`.
      seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
      name: `str` name for ops constructed by this method.
      **kwargs: additional keyword arguments passed to
        `tfp.experimental.mcmc.infer_trajectories`.
    Returns:
      final_parameter_particles: structure of `Tensor`s matching
        `self.parameter_prior`, each with batch shape
        `[num_iterations, num_particles]`. These are the populations
        of particles representing the parameter estimate after each iteration
        of filtering.
    """
        with self._name_scope(name or 'estimate_parameters'):

            step_seed, initial_seed = samplers.split_seed(seed)
            initial_perturbation_scale = tf.convert_to_tensor(
                initial_perturbation_scale, name='initial_perturbation_scale')

            # Get initial parameter particles from the first filtering iteration.
            initial_unconstrained_parameters = self.one_step(
                observations=observations,
                num_particles=num_particles,
                perturbation_scale=initial_perturbation_scale,
                seed=step_seed,
                **kwargs)

            # Run the remaining iterations and accumulate the results.
            @tf.function(autograph=False)
            def loop_body(unconstrained_parameters_seed, cooling_fraction):
                unconstrained_parameters, seed = unconstrained_parameters_seed
                step_seed, seed = samplers.split_seed(seed)
                return (self.one_step(
                    observations=observations,
                    num_particles=num_particles,
                    perturbation_scale=tf.nest.map_structure(
                        lambda s: cooling_fraction * s,
                        initial_perturbation_scale),
                    initial_unconstrained_parameters=unconstrained_parameters,
                    seed=step_seed,
                    **kwargs), seed)

            estimated_unconstrained_parameters, _ = tf.scan(
                fn=loop_body,
                elems=cooling_schedule(ps.range(1, num_iterations)),
                initializer=(initial_unconstrained_parameters, initial_seed))

            return self.parameter_constraining_bijector.forward(
                estimated_unconstrained_parameters)
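
The `cooling_schedule` argument is just a callable from iteration index to a multiplicative factor. A minimal geometric schedule as a sketch (the 0.9 decay rate is an arbitrary choice):

```python
import tensorflow as tf

cooling_schedule = lambda iteration: 0.9**tf.cast(iteration, tf.float32)
# With initial_perturbation_scale=0.1, iteration 3 perturbs with scale
# 0.1 * 0.9**3 ~= 0.0729.
```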
Example #6
def im2row_index(input_shape,
                 block_shape,
                 rank=2,
                 slice_step=(1, 1),
                 dilations=(1, 1),
                 dtype=tf.int32,
                 transpose=False,
                 validate_args=False,
                 name=None):
    """Computes indexes into a flattened image for building `im2row`."""
    with tf.name_scope(name or 'im2row_index'):
        if tf.get_static_value(rank) != 2:
            raise NotImplementedError(
                'Argument `rank` currently only supports `2`; '
                'saw "{}".'.format(rank))
        fh, fw = prepare_tuple_argument(block_shape,
                                        n=rank,
                                        arg_name='block_shape',
                                        validate_args=validate_args)
        sh, sw = prepare_tuple_argument(slice_step,
                                        n=rank,
                                        arg_name='slice_step',
                                        validate_args=validate_args)
        dh, dw = prepare_tuple_argument(dilations,
                                        n=rank,
                                        arg_name='dilations',
                                        validate_args=validate_args)

        # 1) Process input arguments.
        batch_shape, h, w, c = ps.split(ps.reshape(ps.cast(input_shape,
                                                           dtype=dtype),
                                                   shape=[-1]),
                                        num_or_size_splits=[-1, 1, 1, 1])
        h, w, c = h[0], w[0], c[0]

        tot_fh = dh * (fh - 1) + 1
        tot_fw = dw * (fw - 1) + 1

        # 2) Assemble all block start positions as indexes into the flattened image.
        # start_idx.shape = [fh, fw, c]
        if transpose:
            last_element = lambda size, step: size - (size - 1) % step - 1
            w_step = c * dw
            h_step = c * w * dh
            last_w = last_element(c * tot_fw, w_step)
            last_h = last_element(c * w * tot_fh, h_step)
            start_idx = cartesian_add([
                ps.range(last_h, -1, delta=-h_step, dtype=dtype),
                ps.range(last_w, -1, delta=-w_step, dtype=dtype),
                ps.range(c, delta=1, dtype=dtype),
            ])
        else:
            start_idx = cartesian_add([
                ps.range(c * w * tot_fh, delta=c * w * dh, dtype=dtype),
                ps.range(c * tot_fw, delta=c * dw, dtype=dtype),
                ps.range(c, delta=1, dtype=dtype),
            ])

        # 3) Assemble all block offsets (into flattened image).
        eh = h - tot_fh + 1
        ew = w - tot_fw + 1

        offset_idx = cartesian_add([
            ps.range(w * eh, delta=w * sh, dtype=dtype),
            ps.range(ew, delta=sw, dtype=dtype),
        ])

        offset_idx = offset_idx * c
        oh = (eh - 1) // sh + 1  # out height
        ow = (ew - 1) // sw + 1  # out width

        # 4) Combine block start/offset pairs.
        # shape = [oh * ow, fh * fw * c]
        idx = cartesian_add([offset_idx, start_idx])
        new_shape = ps.concat(
            [batch_shape,
             ps.convert_to_shape_tensor([oh, ow, fh * fw * c])],
            axis=0)
        return idx, new_shape
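
For intuition, here is a tiny NumPy sketch of the same index arithmetic for a single-channel 4x4 image with 2x2 blocks, stride 1, and no dilation; all names are illustrative and not part of the library:

```python
import numpy as np

h, w, c, fh, fw = 4, 4, 1, 2, 2
image = np.arange(h * w * c).reshape(h, w, c)
flat = image.reshape(-1)

# Offsets of each element within one block, as indices into the flat image.
start_idx = (np.arange(fh)[:, None, None] * w * c
             + np.arange(fw)[None, :, None] * c
             + np.arange(c)[None, None, :]).reshape(-1)
# Top-left corners of every block, also as flat indices.
offset_idx = (np.arange(h - fh + 1)[:, None] * w * c
              + np.arange(w - fw + 1)[None, :] * c).reshape(-1)

rows = flat[offset_idx[:, None] + start_idx[None, :]]  # shape [9, 4]; each row is one 2x2 patch
```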
Example #7
def chees_criterion(previous_state,
                    proposed_state,
                    accept_prob,
                    validate_args=False,
                    experimental_shard_axis_names=None,
                    experimental_reduce_chain_axis_names=None):
  """The ChEES criterion from [1].

  ChEES stands for Change in the Estimator of the Expected Square.

  ```none
  ChEES = 1/4 E[(||x' - E[x]||**2 - ||x - E[x]||**2)**2],
  ```

  where `x` is the previous chain state, `x'` is the next chain state, and
  `||.||` is the L2 norm. Both expectations are with respect to the chain's
  stationary distribution. In practice, the inner expectation is replaced by the
  empirical mean across chains, so computing this criterion requires that at
  least 2 chains are present. The outer expectation is computed by the caller
  (e.g. in the `GradientBasedTrajectoryLengthAdaptation` kernel).

  This can be thought of as the standard expected squared jump distance (ESJD)
  criterion, except that the jump distance is computed in the space of centered
  squared L2 norms.

  Unlike ChEES, regular ESJD is maximized by perfectly anticorrelated proposals,
  which can give excellent mean estimates but terrible variance estimates;
  maximizing ChEES should give good estimates across a wider range of types of
  posterior expectations.

  Args:
    previous_state: (Possibly nested) floating point `Tensor`. The previous
      state of the HMC chain.
    proposed_state: (Possibly nested) floating point `Tensor`. The proposed
      state of the HMC chain.
    accept_prob: Floating `Tensor`. Probability of accepting the proposed state.
    validate_args: Whether to perform non-static argument validation.
    experimental_shard_axis_names: A structure of string names indicating how
      members of the state are sharded.
    experimental_reduce_chain_axis_names: A string or list of string names
      indicating which named chain axes to reduce over when computing the
      criterion.

  Returns:
    chees: The value of the ChEES criterion.

  Raises:
    ValueError: If `accept_prob` indicates that there are fewer than 2 chains.

  #### References

  [1]: Hoffman, M., Radul, A., & Sountsov, P. (2020). An Adaptive MCMC Scheme
       for Setting Trajectory Lengths in Hamiltonian Monte Carlo. In
       preparation.

  """
  batch_ndims = ps.rank(accept_prob)
  batch_axes = ps.range(batch_ndims, dtype=tf.int32)
  reduce_chain_axis_names = distribute_lib.canonicalize_named_axis(
      experimental_reduce_chain_axis_names)
  # Number of total chains is local batch size * distributed axis size
  local_axis_size = ps.maximum(ps.size(accept_prob), 1)
  distributed_axis_size = int(ps.reduce_prod([
      distribute_lib.get_axis_size(a) for a in reduce_chain_axis_names]))
  num_chains = local_axis_size * distributed_axis_size
  num_chains_ = tf.get_static_value(num_chains)
  if num_chains_ is not None:
    if num_chains_ < 2:
      raise ValueError(
          'chees_criterion requires at least 2 chains. Got: {}'.format(
              num_chains_))
  elif validate_args:
    with tf.control_dependencies([
        assert_util.assert_greater_equal(
            num_chains, 2, 'chees_criterion requires at least 2 chains.')
    ]):
      previous_state = tf.nest.map_structure(tf.identity, previous_state)

  def _center_previous_state(x):
    # The empirical mean here is a stand-in for the true mean, so we drop the
    # gradient that flows through this term.
    x_mean = _reduce_mean_with_axes(
        x, batch_axes, reduce_chain_axis_names)
    return x - tf.stop_gradient(x_mean)

  def _center_proposed_state(x):
    # The empirical mean here is a stand-in for the true mean, so we drop the
    # gradient that flows through this term. The goal here is to get a reliable
    # diagnostic of the underlying dynamics, rather than incorporating the
    # effect of the MetropolisHastings correction.
    # TODO(mhoffman): Needs more experimentation.
    expanded_accept_prob = bu.left_justified_expand_dims_like(
        accept_prob, x)

    # accept_prob is zero when x is NaN, but we still want to sanitize such
    # values.
    x_safe = tf.where(tf.math.is_finite(x), x, tf.zeros_like(x))
    # If all accept_prob's are zero, the x_center will have a nonsense value,
    # but we'll discard the resultant gradients later on, so it's fine.
    x_center = (
        _reduce_sum_with_axes(expanded_accept_prob * x_safe, batch_axes,
                              reduce_chain_axis_names) /
        (_reduce_sum_with_axes(expanded_accept_prob, batch_axes,
                               reduce_chain_axis_names) + 1e-20))

    return x - tf.stop_gradient(x_center)

  def _sum_event_part(x, shard_axes=None):
    event_axes = ps.range(batch_ndims, ps.rank(x))
    return distribute_lib.psum(tf.reduce_sum(x, axis=event_axes), shard_axes)

  def _sum_event(x):
    event_parts = _map_structure_up_to_with_axes(
        x, _sum_event_part, x,
        experimental_shard_axis_names=experimental_shard_axis_names)
    return sum(tf.nest.flatten(event_parts))

  def _square(x):
    return tf.nest.map_structure(tf.square, x)

  def _sub(x, y):
    return tf.nest.map_structure(lambda x, y: x - y, x, y)

  previous_state = tf.nest.map_structure(_center_previous_state, previous_state)
  proposed_state = tf.nest.map_structure(_center_proposed_state, proposed_state)
  chees = 0.25 * tf.square(
      _sum_event(_sub(_square(proposed_state), _square(previous_state))))
  return chees
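
A plain-NumPy sketch of the per-chain ChEES computation; it omits the acceptance-probability weighting and shard reductions used above, and the shapes and values are made up:

```python
import numpy as np

rng = np.random.default_rng(0)
previous = rng.normal(size=(8, 3))   # [num_chains, dim]
proposed = rng.normal(size=(8, 3))
prev_c = previous - previous.mean(axis=0)   # center by the across-chain mean
prop_c = proposed - proposed.mean(axis=0)
chees_per_chain = 0.25 * np.square(
    np.sum(prop_c**2, axis=-1) - np.sum(prev_c**2, axis=-1))
# The caller would average `chees_per_chain` over chains (the outer E[.]).
```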
Example #8
  def _log_prob(self, x):
    if self.input_output_cholesky:
      x_sqrt = x
    else:
      # Complexity: O(nbk**3)
      x_sqrt = tf.linalg.cholesky(x)

    df = tf.convert_to_tensor(self.df)
    batch_shape = self._batch_shape_tensor(df)
    event_shape = self._event_shape_tensor()
    dimension = self._dimension()
    x_ndims = ps.rank(x_sqrt)
    num_singleton_axes_to_prepend = (
        ps.maximum(ps.size(batch_shape) + 2, x_ndims) - x_ndims)
    x_with_prepended_singletons_shape = ps.concat([
        ps.ones([num_singleton_axes_to_prepend], dtype=tf.int32),
        ps.shape(x_sqrt)
    ], 0)
    x_sqrt = tf.reshape(x_sqrt, x_with_prepended_singletons_shape)
    ndims = ps.rank(x_sqrt)
    # sample_ndims = ndims - batch_ndims - event_ndims
    sample_ndims = ndims - ps.size(batch_shape) - 2
    sample_shape = ps.shape(x_sqrt)[:sample_ndims]

    # We need to be able to pre-multiply each matrix by its corresponding
    # batch scale matrix. Since a Distribution Tensor supports multiple
    # samples per batch, this means we need to reshape the input matrix `x`
    # so that the first b dimensions are batch dimensions and the last two
    # are of shape [dimension, dimension * number_of_samples]. Doing these
    # gymnastics allows us to do a batch_solve.
    #
    # After we're done with sqrt_solve (the batch operation) we need to undo
    # this reshaping so what we're left with is a Tensor partitionable by
    # sample, batch, event dimensions.

    # Complexity: O(nbk**2) since transpose must access every element.
    scale_sqrt_inv_x_sqrt = x_sqrt
    perm = ps.concat([ps.range(sample_ndims, ndims),
                      ps.range(0, sample_ndims)], 0)
    scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt, perm=perm)
    last_dim_size = (
        ps.cast(dimension, dtype=tf.int32) *
        ps.reduce_prod(x_with_prepended_singletons_shape[:sample_ndims]))
    shape = ps.concat(
        [x_with_prepended_singletons_shape[sample_ndims:-2],
         [ps.cast(dimension, dtype=tf.int32), last_dim_size]],
        axis=0)
    scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)

    # Complexity: O(nbM*k) where M is the complexity of the operator solving a
    # vector system. For LinearOperatorLowerTriangular, each solve is O(k**2) so
    # this step has complexity O(nbk**3).
    scale_sqrt_inv_x_sqrt = self._scale.solve(scale_sqrt_inv_x_sqrt)

    # Undo the reshape that made the tensor batch-op ready.
    # Complexity: O(nbk**2)
    shape = ps.concat(
        [ps.shape(scale_sqrt_inv_x_sqrt)[:-2], event_shape, sample_shape],
        axis=0)
    scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)
    perm = ps.concat([
        ps.range(ndims - sample_ndims, ndims),
        ps.range(0, ndims - sample_ndims)
    ], 0)
    scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt, perm=perm)

    # Write V = SS', X = LL'. Then:
    # tr[inv(V) X] = tr[inv(S)' inv(S) L L']
    #              = tr[inv(S) L L' inv(S)']
    #              = tr[(inv(S) L) (inv(S) L)']
    #              = sum_{ik} (inv(S) L)_{ik}**2
    # The second equality follows from the cyclic permutation property.
    # Complexity: O(nbk**2)
    trace_scale_inv_x = tf.reduce_sum(
        tf.square(scale_sqrt_inv_x_sqrt), axis=[-2, -1])

    # Complexity: O(nbk)
    half_log_det_x = tf.reduce_sum(
        tf.math.log(tf.linalg.diag_part(x_sqrt)), axis=[-1])

    # Complexity: O(nbk**2)
    log_prob = ((df - dimension - 1.) * half_log_det_x -
                0.5 * trace_scale_inv_x -
                self._log_normalization(df=df, scale=self._scale))

    # Set shape hints.
    # Try to merge what we know from the input x with what we know from the
    # parameters of this distribution.
    if tensorshape_util.rank(x.shape) is not None and tensorshape_util.rank(
        self.batch_shape) is not None:
      tensorshape_util.set_shape(
          log_prob,
          tf.broadcast_static_shape(x.shape[:-2], self.batch_shape))

    return log_prob
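
The trace identity in the comments above is easy to verify numerically; a NumPy sketch with arbitrary lower-triangular factors:

```python
import numpy as np

rng = np.random.default_rng(0)
S = np.tril(rng.normal(size=(3, 3))) + 3. * np.eye(3)   # Cholesky-like factor of V
L = np.tril(rng.normal(size=(3, 3))) + 3. * np.eye(3)   # Cholesky-like factor of X
V, X = S @ S.T, L @ L.T
lhs = np.trace(np.linalg.solve(V, X))        # tr[inv(V) X]
rhs = np.sum(np.linalg.solve(S, L)**2)       # sum_{ik} (inv(S) L)_{ik}**2
print(np.allclose(lhs, rhs))                 # True
```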
Example #9
    def _bootstrap_from_inner_results(self, init_state, inner_results):
        step_size = self.step_size_getter_fn(inner_results)

        log_accept_prob = self.log_accept_prob_getter_fn(inner_results)

        state_parts = tf.nest.flatten(init_state)
        step_size_parts = tf.nest.flatten(step_size)

        if self._parameters['shrinkage_target'] is None:
            shrinkage_target_parts = [None] * len(step_size_parts)
        else:
            shrinkage_target_parts = tf.nest.flatten(
                self._parameters['shrinkage_target'])
            if len(shrinkage_target_parts) not in [1, len(step_size_parts)]:
                raise ValueError(
                    '`shrinkage_target` should be a Tensor or list of tensors of '
                    'same length as `step_size`. Found len(`step_size`) = {} and '
                    'len(shrinkage_target) = {}'.format(
                        len(step_size_parts), len(shrinkage_target_parts)))
            if len(shrinkage_target_parts) < len(step_size_parts):
                shrinkage_target_parts *= len(step_size_parts)

        dtype = dtype_util.common_dtype(step_size_parts, tf.float32)
        error_sum, log_averaging_step, log_shrinkage_target = [], [], []
        for state_part, step_size_part, shrinkage_target_part in zip(
                state_parts, step_size_parts, shrinkage_target_parts):
            num_reduce_dims = ps.minimum(
                ps.rank(log_accept_prob),
                ps.rank(state_part) - ps.rank(step_size_part))
            reduced_log_accept_prob = reduce_logmeanexp(
                log_accept_prob,
                axis=ps.range(num_reduce_dims),
                experimental_named_axis=self.
                experimental_reduce_chain_axis_names)
            reduce_indices = get_differing_dims(reduced_log_accept_prob,
                                                step_size_part)
            reduced_log_accept_prob = reduce_logmeanexp(
                reduced_log_accept_prob, axis=reduce_indices, keepdims=True)
            error_sum.append(
                tf.zeros_like(reduced_log_accept_prob, dtype=dtype))
            log_averaging_step.append(
                tf.zeros_like(step_size_part, dtype=dtype))

            if shrinkage_target_part is None:
                log_shrinkage_target.append(
                    float(np.log(10.)) + tf.math.log(step_size_part))
            else:
                log_shrinkage_target.append(
                    tf.math.log(tf.cast(shrinkage_target_part, dtype)))

        return DualAveragingStepSizeAdaptationResults(
            inner_results=inner_results,
            step=tf.constant(0, dtype=tf.int32),
            target_accept_prob=tf.cast(self.parameters['target_accept_prob'],
                                       log_accept_prob.dtype),
            log_shrinkage_target=log_shrinkage_target,
            exploration_shrinkage=tf.cast(
                self.parameters['exploration_shrinkage'], dtype),
            step_count_smoothing=tf.cast(
                self.parameters['step_count_smoothing'], dtype),
            decay_rate=tf.cast(self.parameters['decay_rate'], dtype),
            error_sum=error_sum,
            log_averaging_step=log_averaging_step,
            new_step_size=step_size,
            num_adaptation_steps=tf.cast(self.num_adaptation_steps,
                                         dtype=tf.int32))
Example #10
 def log_gamma_log_prob(x):
   counter['target_calls'] += 1
   event_dims = ps.range(independent_chain_ndims, ps.rank(x))
   return self._log_gamma_log_prob(x, event_dims)
Example #11
 def log_gamma_log_prob(x):
   event_dims = ps.range(independent_chain_ndims, ps.rank(x))
   return self._log_gamma_log_prob(x, event_dims)
Example #12
def particle_filter(
        observations,
        initial_state_prior,
        transition_fn,
        observation_fn,
        num_particles,
        initial_state_proposal=None,
        proposal_fn=None,
        resample_criterion_fn=ess_below_threshold,
        rejuvenation_kernel_fn=None,  # TODO(davmre): not yet supported. pylint: disable=unused-argument
        num_transitions_per_observation=1,
        num_steps_state_history_to_pass=None,
        num_steps_observation_history_to_pass=None,
        seed=None,
        name=None):  # pylint: disable=g-doc-args
    """Samples a series of particles representing filtered latent states.

  The particle filter samples from the sequence of "filtering" distributions
  `p(state[t] | observations[:t])` over latent
  states: at each point in time, this is the distribution conditioned on all
  observations *up to that time*. Because particles may be resampled, a particle
  at time `t` may be different from the particle with the same index at time
  `t + 1`. To reconstruct trajectories by tracing back through the resampling
  process, see `tfp.experimental.mcmc.reconstruct_trajectories`.

  ${particle_filter_arg_str}
  Returns:
    particles: a (structure of) Tensor(s) matching the latent state, each
      of shape
      `concat([[num_timesteps, num_particles, b1, ..., bN], event_shape])`,
      representing (possibly weighted) samples from the series of filtering
      distributions `p(latent_states[t] | observations[:t])`.
    log_weights: `float` `Tensor` of shape
      `[num_timesteps, num_particles, b1, ..., bN]`, such that
      `log_weights[t, :]` are the logarithms of normalized importance weights
      (such that `exp(reduce_logsumexp(log_weights, axis=-1)) == 1.`) of
      the particles at time `t`. These may be used in conjunction with
      `particles` to compute expectations under the series of filtering
      distributions.
    parent_indices: `int` `Tensor` of shape
      `[num_timesteps, num_particles, b1, ..., bN]`,
      such that `parent_indices[t, k]` gives the index of the particle at
      time `t - 1` that the `k`th particle at time `t` is immediately descended
      from. See also
      `tfp.experimental.mcmc.reconstruct_trajectories`.
    step_log_marginal_likelihoods: float `Tensor` of shape
      `[num_observation_steps, b1, ..., bN]`,
      giving the natural logarithm of an unbiased estimate of
      `p(observations[t] | observations[:t])` at each observed timestep `t`.
      Note that (by [Jensen's inequality](
      https://en.wikipedia.org/wiki/Jensen%27s_inequality))
      this is *smaller* in expectation than the true
      `log p(observations[t] | observations[:t])`.

  ${non_markovian_specification_str}
  """
    seed = SeedStream(seed, 'particle_filter')
    with tf.name_scope(name or 'particle_filter'):
        num_observation_steps = prefer_static.shape(
            tf.nest.flatten(observations)[0])[0]
        num_timesteps = (1 + num_transitions_per_observation *
                         (num_observation_steps - 1))

        # If no criterion is specified, default is to resample at every step.
        if not resample_criterion_fn:
            resample_criterion_fn = lambda _: True

        # Dress up the prior and prior proposal as a fake `transition_fn` and
        # `proposal_fn` respectively.
        prior_fn = lambda _1, _2: SampleParticles(  # pylint: disable=g-long-lambda
            initial_state_prior, num_particles)
        prior_proposal_fn = (
            None if initial_state_proposal is None else
            lambda _1, _2: SampleParticles(  # pylint: disable=g-long-lambda
                initial_state_proposal, num_particles))

        # Initially the particles all have the same weight, `1. / num_particles`.
        broadcast_batch_shape = tf.convert_to_tensor(functools.reduce(
            prefer_static.broadcast_shape,
            tf.nest.flatten(initial_state_prior.batch_shape_tensor()), []),
                                                     dtype=tf.int32)
        log_uniform_weights = prefer_static.zeros(
            prefer_static.concat([[num_particles], broadcast_batch_shape],
                                 axis=0),
            dtype=tf.float32) - prefer_static.log(num_particles)

        # Initialize from the prior, and incorporate the first observation.
        initial_step_results = _filter_one_step(
            step=0,
            # `previous_particles` at the first step is a dummy quantity, used only
            # to convey state structure and num_particles to an optional
            # proposal fn.
            previous_particles=prior_fn(0, []).sample(),
            log_weights=log_uniform_weights,
            observation=tf.nest.map_structure(lambda x: tf.gather(x, 0),
                                              observations),
            transition_fn=prior_fn,
            observation_fn=observation_fn,
            proposal_fn=prior_proposal_fn,
            resample_criterion_fn=resample_criterion_fn,
            seed=seed)

        def _loop_body(step, previous_step_results, accumulated_step_results,
                       state_history):
            """Take one step in dynamics and accumulate marginal likelihood."""

            step_has_observation = (
                # The second of these conditions subsumes the first, but both are
                # useful because the first can often be evaluated statically.
                prefer_static.equal(num_transitions_per_observation, 1) |
                prefer_static.equal(step % num_transitions_per_observation, 0))
            observation_idx = step // num_transitions_per_observation
            current_observation = tf.nest.map_structure(
                lambda x, step=step: tf.gather(x, observation_idx),
                observations)

            history_to_pass_into_fns = {}
            if num_steps_observation_history_to_pass:
                history_to_pass_into_fns[
                    'observation_history'] = _gather_history(
                        observations, observation_idx,
                        num_steps_observation_history_to_pass)
            if num_steps_state_history_to_pass:
                history_to_pass_into_fns['state_history'] = state_history

            new_step_results = _filter_one_step(
                step=step,
                previous_particles=previous_step_results.particles,
                log_weights=previous_step_results.log_weights,
                observation=current_observation,
                transition_fn=functools.partial(transition_fn,
                                                **history_to_pass_into_fns),
                observation_fn=functools.partial(observation_fn,
                                                 **history_to_pass_into_fns),
                proposal_fn=(None
                             if proposal_fn is None else functools.partial(
                                 proposal_fn, **history_to_pass_into_fns)),
                resample_criterion_fn=resample_criterion_fn,
                has_observation=step_has_observation,
                seed=seed)

            return _update_loop_variables(step, new_step_results,
                                          accumulated_step_results,
                                          state_history)

        loop_results = tf.while_loop(
            cond=lambda step, *_: step < num_timesteps,
            body=_loop_body,
            loop_vars=_initialize_loop_variables(
                initial_step_results, num_steps_state_history_to_pass,
                num_timesteps))

        results = tf.nest.map_structure(lambda ta: ta.stack(),
                                        loop_results.accumulated_step_results)
        if num_transitions_per_observation != 1:
            # Return a log-prob for each observed step.
            observed_steps = prefer_static.range(
                0, num_timesteps, num_transitions_per_observation)
            results = results._replace(step_log_marginal_likelihood=tf.gather(
                results.step_log_marginal_likelihood, observed_steps))
        return results
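
A minimal usage sketch for a 1-D random-walk model; argument names follow the docstring above, the exact return structure may differ across TFP versions, and the scales here are arbitrary:

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

observations = tf.random.normal([20])
particles, log_weights, parent_indices, step_lps = (
    tfp.experimental.mcmc.particle_filter(
        observations=observations,
        initial_state_prior=tfd.Normal(loc=0., scale=1.),
        transition_fn=lambda _, x: tfd.Normal(loc=x, scale=0.1),
        observation_fn=lambda _, x: tfd.Normal(loc=x, scale=0.5),
        num_particles=512))
```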
Example #13
def _compute_log_acceptance_correction(current_state_parts,
                                       proposed_state_parts,
                                       current_volatility_parts,
                                       proposed_volatility_parts,
                                       current_drift_parts,
                                       proposed_drift_parts,
                                       step_size_parts,
                                       independent_chain_ndims,
                                       experimental_shard_axis_names=None,
                                       name=None):
  r"""Helper to `kernel` which computes the log acceptance-correction.

  Computes `log_acceptance_correction` as described in `MetropolisHastings`
  class. The proposal density is normal. More specifically,

  ```none
  q(proposed_state | current_state) \sim N(current_state + current_drift,
  step_size * current_volatility**2)

  q(current_state | proposed_state) \sim N(proposed_state + proposed_drift,
  step_size * proposed_volatility**2)
  ```

  The `log_acceptance_correction` is then

  ```none
  log_acceptance_correction = log q(current_state | proposed_state)
                            - log q(proposed_state | current_state)
  ```

  Args:
    current_state_parts: Python `list` of `Tensor`s representing the value(s) of
      the current state of the chain.
    proposed_state_parts:  Python `list` of `Tensor`s representing the value(s)
      of the proposed state of the chain. Must broadcast with the shape of
      `current_state_parts`.
    current_volatility_parts: Python `list` of `Tensor`s representing the value
      of `volatility_fn(*current_volatility_parts)`. Must broadcast with the
      shape of `current_state_parts`.
    proposed_volatility_parts: Python `list` of `Tensor`s representing the value
      of `volatility_fn(*proposed_volatility_parts)`. Must broadcast with the
      shape of `current_state_parts`.
    current_drift_parts: Python `list` of `Tensor`s representing value of the
      drift `_get_drift(*current_state_parts, ..)`. Must broadcast with the
      shape of `current_state_parts`.
    proposed_drift_parts: Python `list` of `Tensor`s representing value of the
      drift `_get_drift(*proposed_drift_parts, ..)`. Must broadcast with the
      shape of `current_state_parts`.
    step_size_parts: Python `list` of `Tensor`s representing the step size for
      Euler-Maruyama method. Must broadcast with the shape of
      `current_state_parts`.
    independent_chain_ndims: Scalar `int` `Tensor` representing the number of
      leftmost `Tensor` dimensions which index independent chains.
    experimental_shard_axis_names: A structure of string names indicating how
      members of the state are sharded.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'compute_log_acceptance_correction').

  Returns:
    log_acceptance_correction: `Tensor` representing the `log`
      acceptance-correction.  (See docstring for mathematical definition.)
  """

  with tf.name_scope(name or 'compute_log_acceptance_correction'):

    proposed_log_density_parts = []
    dual_log_density_parts = []

    if experimental_shard_axis_names is None:
      experimental_shard_axis_names = [None] * len(current_state_parts)

    for [
        current_state,
        proposed_state,
        current_volatility,
        proposed_volatility,
        current_drift,
        proposed_drift,
        step_size,
        shard_axes
    ] in zip(
        current_state_parts,
        proposed_state_parts,
        current_volatility_parts,
        proposed_volatility_parts,
        current_drift_parts,
        proposed_drift_parts,
        step_size_parts,
        experimental_shard_axis_names
    ):
      axis = ps.range(independent_chain_ndims, ps.rank(current_state))

      state_diff = proposed_state - current_state

      current_volatility *= tf.sqrt(step_size)

      proposed_energy = (state_diff - current_drift) / current_volatility

      proposed_volatility *= tf.sqrt(step_size)
      # Compute part of `q(proposed_state | current_state)`
      def reduce_sum(shard_axes, x, axis=None):
        x = tf.reduce_sum(x, axis)
        if shard_axes is not None:
          x = distribute_lib.psum(x, shard_axes)
        return x
      proposed_energy = (
          reduce_sum(
              shard_axes,
              mcmc_util.safe_sum(
                  [tf.math.log(current_volatility),
                   0.5 * (proposed_energy**2)]),
              axis=axis))
      proposed_log_density_parts.append(-proposed_energy)

      # Compute part of `q(current_state | proposed_state)`
      dual_energy = (state_diff + proposed_drift) / proposed_volatility
      dual_energy = (
          reduce_sum(
              shard_axes,
              mcmc_util.safe_sum(
                  [tf.math.log(proposed_volatility), 0.5 * (dual_energy**2)]),
              axis=axis))
      dual_log_density_parts.append(-dual_energy)

    # Compute `q(proposed_state | current_state)`
    proposed_log_density_reduce = tf.add_n(proposed_log_density_parts)
    # Compute `q(current_state | proposed_state)`
    dual_log_density_reduce = tf.add_n(dual_log_density_parts)

    return mcmc_util.safe_sum([
        dual_log_density_reduce, -proposed_log_density_reduce])
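
For a scalar state, the correction reduces to a difference of Normal log-densities; the constant `-0.5 * log(2 * pi)` terms cancel, which is why the code above can drop them. A NumPy sketch with hypothetical drifts and volatilities:

```python
import numpy as np

def normal_logpdf(v, loc, scale):
    return -np.log(scale) - 0.5 * np.log(2. * np.pi) - 0.5 * ((v - loc) / scale)**2

x, x_new = 0.3, 1.1                 # current and proposed state (made up)
drift_fwd, drift_bwd = 0.05, -0.02  # drifts at x and at x_new (made up)
vol_fwd, vol_bwd = 1.0, 1.2         # volatilities at x and at x_new (made up)
sig_fwd = np.sqrt(0.25) * vol_fwd   # step_size = 0.25
sig_bwd = np.sqrt(0.25) * vol_bwd

correction = (normal_logpdf(x, loc=x_new + drift_bwd, scale=sig_bwd)      # log q(current | proposed)
              - normal_logpdf(x_new, loc=x + drift_fwd, scale=sig_fwd))   # log q(proposed | current)
```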
Example #14
def auto_correlation(x,
                     axis=-1,
                     max_lags=None,
                     center=True,
                     normalize=True,
                     name='auto_correlation'):
    """Auto correlation along one axis.

  Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto correlation
  `RXX` may be defined as  (with `E` expectation and `Conj` complex conjugate)

  ```none
  RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
  W[n]   := (X[n] - MU) / S,
  MU     := E{ X[0] },
  S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.
  ```

  This function takes the viewpoint that `x` is (along one axis) a finite
  sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an
  estimate of `RXX[m]` as follows:

  After extending `x` from length `L` to `inf` by zero padding, the auto
  correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as

  ```none
  rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
  w[n]   := (x[n] - mu) / s,
  mu     := L**-1 sum_n x[n],
  s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)
  ```

  The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so users
  often set `max_lags` small enough so that the entire output is meaningful.

  Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide by
  `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto correlation
  contains a slight bias, which goes to zero as `len(x) - m --> infinity`.

  Args:
    x:  `float32` or `complex64` `Tensor`.
    axis:  Python `int`. The axis number along which to compute correlation.
      Other dimensions index different batch members.
    max_lags:  Positive `int` tensor.  The maximum value of `m` to consider (in
      equation above).  If `max_lags >= x.shape[axis]`, we effectively re-set
      `max_lags` to `x.shape[axis] - 1`.
    center:  Python `bool`.  If `False`, do not subtract the mean estimate `mu`
      from `x[n]` when forming `w[n]`.
    normalize:  Python `bool`.  If `False`, do not divide by the variance
      estimate `s**2` when forming `w[n]`.
    name:  `String` name to prepend to created ops.

  Returns:
    `rxx`: `Tensor` of same `dtype` as `x`.  `rxx.shape[i] = x.shape[i]` for
      `i != axis`, and `rxx.shape[axis] = max_lags + 1`.

  Raises:
    TypeError:  If `x` is not a supported type.
  """
    # Implementation details:
    # Extend length N / 2 1-D array x to length N by zero padding onto the end.
    # Then, set
    #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }.
    # It is not hard to see that
    #   F[x]_k Conj(F[x]_k) = F[R]_k, where
    #   R_m := sum_n x_n Conj(x_{(n - m) mod N}).
    # One can also check that R_m / (N / 2 - m) is an unbiased estimate of RXX[m].

    # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT
    # based version of estimating RXX.
    # Note that this is a special case of the Wiener-Khinchin Theorem.
    with tf.name_scope(name):
        x = tf.convert_to_tensor(x, name='x')

        # Rotate dimensions of x in order to put axis at the rightmost dim.
        # FFT op requires this.
        rank = ps.rank(x)
        if axis < 0:
            axis = rank + axis
        shift = rank - 1 - axis
        # Suppose x.shape[axis] = T, so there are T 'time' steps.
        #   ==> x_rotated.shape = B + [T],
        # where B is x_rotated's batch shape.
        x_rotated = distribution_util.rotate_transpose(x, shift)

        if center:
            x_rotated = x_rotated - tf.reduce_mean(
                x_rotated, axis=-1, keepdims=True)

        # x_len = N / 2 from above explanation.  The length of x along axis.
        # Get a value for x_len that works in all cases.
        x_len = ps.shape(x_rotated)[-1]

        # TODO(langmore) Investigate whether this zero padding helps or hurts.  At
        # the moment it is necessary so that all FFT implementations work.
        # Zero pad to the next power of 2 greater than 2 * x_len, which equals
        # 2**(ceil(Log_2(2 * x_len))).  Note: Log_2(X) = Log_e(X) / Log_e(2).
        x_len_float64 = ps.cast(x_len, np.float64)
        target_length = ps.pow(np.float64(2.),
                               ps.ceil(ps.log(x_len_float64 * 2) / np.log(2.)))
        pad_length = ps.cast(target_length - x_len_float64, np.int32)

        # We should have:
        # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length]
        #                     = B + [T + pad_length]
        x_rotated_pad = distribution_util.pad(x_rotated,
                                              axis=-1,
                                              back=True,
                                              count=pad_length)

        dtype = x.dtype
        if not dtype_util.is_complex(dtype):
            if not dtype_util.is_floating(dtype):
                raise TypeError(
                    'Argument x must have either float or complex dtype'
                    ' found: {}'.format(dtype))
            x_rotated_pad = tf.complex(
                x_rotated_pad,
                dtype_util.as_numpy_dtype(dtype_util.real_dtype(dtype))(0.))

        # Autocorrelation is IFFT of power-spectral density (up to some scaling).
        fft_x_rotated_pad = tf.signal.fft(x_rotated_pad)
        spectral_density = fft_x_rotated_pad * tf.math.conj(fft_x_rotated_pad)
        # shifted_product is R[m] from above detailed explanation.
        # It is the inner product sum_n X[n] * Conj(X[n - m]).
        shifted_product = tf.signal.ifft(spectral_density)

        # Cast back to real-valued if x was real to begin with.
        shifted_product = tf.cast(shifted_product, dtype)

        # Figure out if we can deduce the final static shape, and set max_lags.
        # Use x_rotated as a reference, because it has the time dimension in the far
        # right, and was created before we performed all sorts of crazy shape
        # manipulations.
        know_static_shape = True
        if not tensorshape_util.is_fully_defined(x_rotated.shape):
            know_static_shape = False
        if max_lags is None:
            max_lags = x_len - 1
        else:
            max_lags = tf.convert_to_tensor(max_lags, name='max_lags')
            max_lags_ = tf.get_static_value(max_lags)
            if max_lags_ is None or not know_static_shape:
                know_static_shape = False
                max_lags = tf.minimum(x_len - 1, max_lags)
            else:
                max_lags = min(x_len - 1, max_lags_)

        # Chop off the padding.
        # We allow users to provide a huge max_lags, but cut it off here.
        # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags]
        shifted_product_chopped = shifted_product[..., :max_lags + 1]

        # If possible, set shape.
        if know_static_shape:
            chopped_shape = tensorshape_util.as_list(x_rotated.shape)
            chopped_shape[-1] = min(x_len, max_lags + 1)
            tensorshape_util.set_shape(shifted_product_chopped, chopped_shape)

        # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]).  The
        # other terms were zeros arising only due to zero padding.
        # `denominator = (N / 2 - m)` (defined below) is the proper term to
        # divide by to make this an unbiased estimate of the expectation
        # E[X[n] Conj(X[n - m])].
        x_len = ps.cast(x_len, dtype_util.real_dtype(dtype))
        max_lags = ps.cast(max_lags, dtype_util.real_dtype(dtype))
        denominator = x_len - ps.range(0., max_lags + 1.)
        denominator = ps.cast(denominator, dtype)
        shifted_product_rotated = shifted_product_chopped / denominator

        if normalize:
            shifted_product_rotated /= shifted_product_rotated[..., :1]

        # Transpose dimensions back to those of x.
        return distribution_util.rotate_transpose(shifted_product_rotated,
                                                  -shift)
Ejemplo n.º 15
0
def _sample_next(target_log_prob_fn,
                 current_state_parts,
                 step_sizes,
                 max_doublings,
                 current_target_log_prob,
                 batch_rank,
                 seed=None,
                 experimental_shard_axis_names=None,
                 name=None):
  """Applies a single iteration of slice sampling update.

  Applies hit and run style slice sampling. Chooses a uniform random direction
  on the unit sphere in the event space. Applies the one dimensional slice
  sampling update along that direction.

  Args:
    target_log_prob_fn: Python callable which takes an argument like
      `*current_state_parts` and returns its (possibly unnormalized) log-density
      under the target distribution.
    current_state_parts: Python `list` of `Tensor`s representing the current
      state(s) of the Markov chain(s). The first `independent_chain_ndims` of
      the `Tensor`(s) index different chains.
    step_sizes: Python `list` of `Tensor`s. Provides a measure of the width
      of the density. Used to find the slice bounds. Must broadcast with the
      shape of `current_state_parts`.
    max_doublings: Integer number of doublings to allow while locating the slice
      boundaries.
    current_target_log_prob: `Tensor` representing the value of
      `target_log_prob_fn(*current_state_parts)`. The only reason to specify
      this argument is to reduce TF graph size.
    batch_rank: Integer. The number of axes in the state that correspond to
      independent batches.
    seed: Tensor seed pair.
    experimental_shard_axis_names: A structure of string names indicating how
      members of the state are sharded.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'sample_next').

  Returns:
    proposed_state_parts: Tensor or Python list of `Tensor`s representing the
      state(s) of the Markov chain(s) at each result step. Has same shape as
      input `current_state_parts`.
    proposed_target_log_prob: `Tensor` representing the value of
      `target_log_prob_fn` at `next_state`.
    bounds_satisfied: Boolean `Tensor` of the same shape as the log density.
      True indicates whether an interval containing the slice for that
      batch was found successfully.
    direction: `Tensor` or Python list of `Tensor`s representing the direction
      along which the slice was sampled. Has the same shape and dtype(s) as
      `current_state_parts`.
    upper_bounds: `Tensor` of batch shape and the dtype of the input state. The
      upper bounds of the slices along the sampling direction.
    lower_bounds: `Tensor` of batch shape and the dtype of the input state. The
      lower bounds of the slices along the sampling direction.
  """
  direction_seed, slice_seed = samplers.split_seed(seed)
  with tf.name_scope(name or 'sample_next'):
    # First step: Choose a random direction.
    # Direction is a list of tensors. The i'th tensor should have the same shape
    # as the i'th state part.
    direction = _choose_random_direction(
        current_state_parts,
        batch_rank=batch_rank,
        seed=direction_seed,
        experimental_shard_axis_names=experimental_shard_axis_names)

    # Interpolates the step sizes for the chosen direction.
    # Applies an ellipsoidal interpolation to compute the step size for
    # the chosen direction. Suppose we are given step sizes for each direction.
    # Label these s_1, s_2, ... s_k. These are the step sizes to use if moving
    # in a direction parallel to one of the axes. Consider an ellipsoid which
    # intercepts the i'th axis at s_i. The step size for a direction specified
    # by the unit vector (n_1, n_2 ...n_k) is then defined as the intersection
    # of the line through this vector with this ellipsoid.
    #
    # One can show that the length of the vector from the origin to the
    # intersection point is given by:
    # 1 / sqrt(n_1^2 / s_1^2  + n_2^2 / s_2^2  + ...).
    #
    # Proof:
    # The equation of the ellipsoid is:
    # Sum_i [x_i^2 / s_i^2 ] = 1. Let n be a unit direction vector. Points
    # along the line given by n may be parameterized as alpha*n where alpha is
    # the distance along the vector. Plugging this into the equation for the
    # ellipsoid, we get:
    # alpha^2 ( n_1^2 / s_1^2 + n_2^2 / s_2^2 + ...) = 1
    # so alpha = \sqrt { \frac{1} { ( n_1^2 / s_1^2 + n_2^2 / s_2^2 + ...) } }
    reduce_axes = [ps.range(batch_rank, ps.rank(dirn_part))
                   for dirn_part in direction]
    experimental_shard_axis_names = (experimental_shard_axis_names
                                     or ([None] * len(direction)))

    def reduce_sum(v, axis, shard_axes):
      out = tf.reduce_sum(v, axis=axis)
      if shard_axes is not None:
        out = distribute_lib.psum(out, shard_axes)
      return out

    components = [
        reduce_sum((dirn_part / step_size)**2, reduce_axes[i], shard_axes)
        for i, (step_size, dirn_part, shard_axes) in enumerate(zip(
            step_sizes, direction, experimental_shard_axis_names))
    ]
    step_size = tf.math.rsqrt(tf.add_n(components))
    # Computes the rank of a tensor. Uses the static rank if possible.
    state_part_ranks = [ps.rank(part)
                        for part in current_state_parts]

    def _step_along_direction(alpha):
      """Converts the scalar alpha into an n-dim vector with full state info.

      Computes x_0 + alpha * direction where x_0 is the current state and
      direction is the direction chosen above.

      Args:
        alpha: A tensor of shape equal to the batch dimensions of
          `current_state_parts`.

      Returns:
        state_parts: Tensor or Python list of `Tensor`s representing the
          state(s) of the Markov chain(s) for a given alpha and a given chosen
          direction. Has the same shape as `current_state_parts`.
      """
      padded_alphas = [_right_pad(alpha, final_rank=part_rank)
                       for part_rank in state_part_ranks]

      state_parts = [state_part + padded_alpha * direction_part
                     for state_part, direction_part, padded_alpha in
                     zip(current_state_parts, direction, padded_alphas)]
      return state_parts

    def projected_target_log_prob_fn(alpha):
      """The target log density projected along the chosen direction.

      Args:
        alpha: A tensor of shape equal to the batch dimensions of
          `current_state_parts`.

      Returns:
        Target log density evaluated at x_0 + alpha * direction where x_0 is the
        current state and direction is the direction chosen above. Has the same
        shape as `alpha`.
      """
      return target_log_prob_fn(*_step_along_direction(alpha))

    alpha_init = tf.zeros_like(current_target_log_prob,
                               dtype=current_state_parts[0].dtype)
    [
        next_alpha,
        next_target_log_prob,
        bounds_satisfied,
        upper_bounds,
        lower_bounds
    ] = ssu.slice_sampler_one_dim(projected_target_log_prob_fn,
                                  x_initial=alpha_init,
                                  max_doublings=max_doublings,
                                  step_size=step_size, seed=slice_seed)
    return [
        _step_along_direction(next_alpha),
        next_target_log_prob,
        bounds_satisfied,
        direction,
        upper_bounds,
        lower_bounds
    ]
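# --- Standalone sketch (plain NumPy, assumed names) of the ellipsoidal
# step-size interpolation described in the comments above: for per-axis step
# sizes s_i and a unit direction n, the step along n is
# 1 / sqrt(sum_i n_i**2 / s_i**2).
import numpy as np

def ellipsoidal_step_size(direction, step_sizes):
  direction = direction / np.linalg.norm(direction)           # unit vector n
  return 1. / np.sqrt(np.sum(direction**2 / np.asarray(step_sizes)**2))

print(ellipsoidal_step_size(np.array([1., 0.]), [0.5, 2.0]))  # 0.5: along axis 0
print(ellipsoidal_step_size(np.array([1., 1.]), [0.5, 2.0]))  # ~0.69: in between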
Ejemplo n.º 16
0
def _compute_log_acceptance_correction(current_momentums,
                                       proposed_momentums,
                                       independent_chain_ndims,
                                       name=None):
    """Helper to `kernel` which computes the log acceptance-correction.

  A sufficient but not necessary condition for the existence of a stationary
  distribution, `p(x)`, is "detailed balance", i.e.:

  ```none
  p(x'|x) p(x) = p(x|x') p(x')
  ```

  In the Metropolis-Hastings algorithm, a state is proposed according to
  `g(x'|x)` and accepted according to `a(x'|x)`, hence
  `p(x'|x) = g(x'|x) a(x'|x)`.

  Inserting this into the detailed balance equation implies:

  ```none
      g(x'|x) a(x'|x) p(x) = g(x|x') a(x|x') p(x')
  ==> a(x'|x) / a(x|x') = p(x') / p(x) [g(x|x') / g(x'|x)]    (*)
  ```

  One definition of `a(x'|x)` which satisfies (*) is:

  ```none
  a(x'|x) = min(1, p(x') / p(x) [g(x|x') / g(x'|x)])
  ```

  (To see that this satisfies (*), notice that under this definition at most
  one of `a(x'|x)` and `a(x|x')` can be other than one.)

  We call the bracketed term the "acceptance correction".

  In the case of UncalibratedHMC, the log acceptance-correction is not the log
  proposal-ratio. UncalibratedHMC augments the state-space with momentum, z.
  Assuming a standard Gaussian distribution for momentums, the chain eventually
  converges to:

  ```none
  p([x, z]) propto= target_prob(x) exp(-0.5 z**2)
  ```

  Relating this back to Metropolis-Hastings parlance, for HMC we have:

  ```none
  p([x, z]) propto= target_prob(x) exp(-0.5 z**2)
  g([x, z] | [x', z']) = g([x', z'] | [x, z])
  ```

  In other words, the MH bracketed term is `1`. However, because we desire to
  use a general MH framework, we can place the momentum probability ratio inside
  the metropolis-correction factor, thus getting an acceptance probability:

  ```none
                       target_prob(x')
  accept_prob(x'|x) = -----------------  [exp(-0.5 z**2) / exp(-0.5 z'**2)]
                       target_prob(x)
  ```

  (Note: we actually need to handle the kinetic energy change at each leapfrog
  step, but this is the idea.)

  Args:
    current_momentums: `Tensor` representing the value(s) of the current
      momentum(s) of the state (parts).
    proposed_momentums: `Tensor` representing the value(s) of the proposed
      momentum(s) of the state (parts).
    independent_chain_ndims: Scalar `int` `Tensor` representing the number of
      leftmost `Tensor` dimensions which index independent chains.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'compute_log_acceptance_correction').

  Returns:
    log_acceptance_correction: `Tensor` representing the `log`
      acceptance-correction.  (See docstring for mathematical definition.)
  """
    with tf.name_scope(name or 'compute_log_acceptance_correction'):
        sum_sq = lambda v: tf.reduce_sum(
            v**2.,
            axis=ps.range(  # pylint: disable=g-long-lambda
                independent_chain_ndims, ps.rank(v)))
        current_kinetic = tf.add_n([sum_sq(v) for v in current_momentums])
        proposed_kinetic = tf.add_n([sum_sq(v) for v in proposed_momentums])
        return 0.5 * mcmc_util.safe_sum([current_kinetic, -proposed_kinetic])
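# --- Hedged numeric sketch (assumed values, not library code): with a standard
# Gaussian kinetic energy the correction above is 0.5 * (sum(z**2) - sum(z'**2)),
# summed over all non-chain dimensions, computed here directly.
import tensorflow as tf

current = tf.constant([[1., 2.], [0., 1.]])    # 2 chains, 2-dim momenta
proposed = tf.constant([[1., 1.], [1., 1.]])
correction = 0.5 * (tf.reduce_sum(current**2, axis=-1)
                    - tf.reduce_sum(proposed**2, axis=-1))
print(correction)  # [0.5*(5-2), 0.5*(1-2)] = [1.5, -0.5]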
Ejemplo n.º 17
0
        def op(x, kernel):
            input_dtype = dtype_util.common_dtype([x, kernel],
                                                  dtype_hint=tf.float32)
            x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
            kernel = tf.convert_to_tensor(kernel,
                                          dtype=input_dtype,
                                          name='kernel')

            batch_shape, event_shape = ps.split(ps.shape(x),
                                                num_or_size_splits=[-1, 3])
            xh, xw, c_in = ps.unstack(event_shape, num=3)

            kernel_shape = ps.shape(kernel)
            c_out = kernel_shape[-1]
            kernel_batch = kernel_shape[:-2]
            assertions = _maybe_validate_input_shapes(
                kernel_shape,
                channels_in=c_in,
                filter_height=fh,
                filter_width=fw,
                validate_args=validate_args)

            with tf.control_dependencies(assertions):

                # If the kernel does not have batch shape, fall back to
                # `conv2d_transpose` (unless dilations > 1, which is not implemented in
                # `conv2d_transpose`).
                if (tf.get_static_value(ps.rank(kernel)) == 2
                        and all(d == 1 for d in dilations)):
                    return _call_conv2d_transpose(x,
                                                  kernel=kernel,
                                                  filter_shape=filter_shape,
                                                  strides=(strides, ) * rank,
                                                  padding=padding,
                                                  dilations=dilations,
                                                  c_out=c_out,
                                                  batch_shape=batch_shape,
                                                  event_shape=event_shape)

                n = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(padding_vals,
                                  paddings=[[n, 1], [0, 0]],
                                  constant_values=0)

                x_pad = tf.pad(x, paddings=paddings, constant_values=0)
                x_pad_shape = ps.shape(x_pad)[:-3]
                flat_shape = ps.pad(x_pad_shape,
                                    paddings=[[0, 1]],
                                    constant_values=-1)
                flat_x = tf.reshape(x_pad, shape=flat_shape)

                idx, s = im2row_index(
                    (xh + tf.reduce_sum(padding_vals[0]),
                     xw + tf.reduce_sum(padding_vals[1]), c_in),
                    block_shape=(sub_fh, sub_fw),
                    slice_step=(1, 1),
                    dilations=dilations)

                x_ = tf.gather(flat_x, indices=idx, axis=-1)
                im_x = tf.reshape(x_,
                                  shape=ps.concat([x_pad_shape, s], axis=0))

                # Add channels to subkernel indices
                idx_event = event_ind * [[c_in, 1]]
                idx_event_channels = (idx_event[tf.newaxis] + tf.stack(
                    [ps.range(c_in),
                     tf.zeros(
                         (c_in, ), dtype=dtype)], axis=-1)[:, tf.newaxis, :])
                idx_event = tf.squeeze(tf.batch_to_space(idx_event_channels,
                                                         block_shape=[c_in],
                                                         crops=[[0, 0]]),
                                       axis=0)
                idx_event_broadcast = tf.broadcast_to(
                    idx_event,
                    shape=ps.concat(
                        [kernel_batch, ps.shape(idx_event)], axis=0))

                # Add cartesian product of batch indices, since scatter_nd can only be
                # applied to leading dimensions.
                idx_batch = tf.stack(tf.meshgrid(*[
                    ps.range(b_, delta=1, dtype=dtype)
                    for b_ in tf.unstack(kernel_batch)
                ],
                                                 indexing='ij'),
                                     axis=ps.size(kernel_batch))

                idx_batch = tf.cast(idx_batch,
                                    dtype=dtype)  # empty tensor is float

                idx_batch_broadcast = idx_batch[..., tf.newaxis, :] + tf.zeros(
                    (ps.shape(idx_event)[0], 1), dtype=dtype)
                idx_kernel = tf.concat(
                    [idx_batch_broadcast, idx_event_broadcast], axis=-1)

                kernel_mat = tf.scatter_nd(
                    idx_kernel,
                    updates=kernel,
                    shape=ps.cast(ps.concat([
                        kernel_batch,
                        [sub_fh * sub_fw * c_in, strides**2, c_out]
                    ],
                                            axis=0),
                                  dtype=dtype))

                kernel_mat = tf.reshape(
                    kernel_mat,
                    shape=ps.concat(
                        [ps.shape(kernel_mat)[:-2], [strides**2 * c_out]],
                        axis=0))

                kernel_mat = kernel_mat[..., tf.newaxis, :, :]
                out = tf.matmul(im_x, kernel_mat)
                broadcast_batch_shape = ps.broadcast_shape(
                    batch_shape, kernel_batch)

                if strides > 1:
                    tot_size = tf.reduce_prod(broadcast_batch_shape)
                    flat_out = tf.reshape(out,
                                          shape=ps.concat([[tot_size],
                                                           ps.shape(out)[-3:]],
                                                          axis=0))
                    out = tf.nn.depth_to_space(flat_out, block_size=strides)

                out_height = _deconv_output_length(xh,
                                                   filter_size=fh,
                                                   padding=padding,
                                                   output_padding=None,
                                                   stride=strides,
                                                   dilation=dh)
                out_width = _deconv_output_length(xw,
                                                  filter_size=fw,
                                                  padding=padding,
                                                  output_padding=None,
                                                  stride=strides,
                                                  dilation=dw)

                out = out[..., truncate_top:truncate_top + out_height,
                          truncate_left:truncate_left + out_width, :]
                out = tf.reshape(
                    out,
                    shape=ps.concat([
                        broadcast_batch_shape, [out_height, out_width, c_out]
                    ],
                                    axis=0))
                return out
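# --- Hedged sketch of the output-length bookkeeping used above. The helper
# below is hypothetical (named `deconv_output_length` here) and mirrors the
# standard Keras-style formula for a transposed convolution with
# `output_padding=None`; the library's `_deconv_output_length` may differ.
def deconv_output_length(input_length, filter_size, padding, stride, dilation=1):
  filter_size = filter_size + (filter_size - 1) * (dilation - 1)  # dilated size
  if padding == 'VALID':
    return input_length * stride + max(filter_size - stride, 0)
  if padding == 'SAME':
    return input_length * stride
  raise ValueError('Unsupported padding: {}'.format(padding))

print(deconv_output_length(4, filter_size=3, padding='VALID', stride=2))  # 9
print(deconv_output_length(4, filter_size=3, padding='SAME', stride=2))   # 8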
Ejemplo n.º 18
0
def convolution_batch(x,
                      kernel,
                      rank,
                      strides,
                      padding,
                      data_format=None,
                      dilations=None,
                      name=None):
    """Like `tf.nn.conv2d` except applies batch of kernels to batch of `x`."""
    if rank != 2:
        raise NotImplementedError(
            'Argument `rank` currently only supports `2`; '
            'saw "{}".'.format(rank))
    if data_format is not None and data_format.upper() != 'NHWBC':
        raise ValueError(
            'Argument `data_format` currently only supports "NHWBC"; '
            'saw "{}".'.format(data_format))
    with tf.name_scope(name or 'conv2d_nhwbc'):
        # Prepare arguments.
        [
            rank,
            _,  # strides
            padding,
            dilations,
            data_format,
        ] = prepare_conv_args(rank, strides, padding, dilations)
        strides = prepare_tuple_argument(strides, rank + 2, arg_name='strides')

        dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
        x = tf.convert_to_tensor(x, dtype=dtype, name='x')
        kernel = tf.convert_to_tensor(kernel, dtype=dtype, name='kernel')

        # Step 1: Transpose and double flatten kernel.
        # kernel.shape = B + F + [c, c']. Eg: [b, fh, fw, c, c']
        kernel_shape = prefer_static.shape(kernel)
        kernel_batch_shape, kernel_event_shape = prefer_static.split(
            kernel_shape, num_or_size_splits=[-1, rank + 2])
        kernel_batch_size = prefer_static.reduce_prod(kernel_batch_shape)
        kernel_ndims = prefer_static.rank(kernel)
        kernel_batch_ndims = kernel_ndims - rank - 2
        perm = prefer_static.concat([
            prefer_static.range(kernel_batch_ndims, kernel_batch_ndims + rank),
            prefer_static.range(0, kernel_batch_ndims),
            prefer_static.range(kernel_batch_ndims + rank, kernel_ndims),
        ],
                                    axis=0)  # Eg, [1, 2, 0, 3, 4]
        kernel = tf.transpose(kernel, perm=perm)  # F + B + [c, c']
        kernel = tf.reshape(kernel,
                            shape=prefer_static.concat([
                                kernel_event_shape[:rank],
                                [
                                    kernel_batch_size * kernel_event_shape[-2],
                                    kernel_event_shape[-1]
                                ],
                            ],
                                                       axis=0))  # F + [bc, c']

        # Step 2: Double flatten x.
        # x.shape = N + D + B + [c]
        x_shape = prefer_static.shape(x)
        [
            x_sample_shape,
            x_rank_shape,
            x_batch_shape,
            x_channel_shape,
        ] = prefer_static.split(
            x_shape, num_or_size_splits=[-1, rank, kernel_batch_ndims, 1])
        x = tf.reshape(
            x,  # N + D + B + [c]
            shape=prefer_static.concat([
                [prefer_static.reduce_prod(x_sample_shape)],
                x_rank_shape,
                [
                    prefer_static.reduce_prod(x_batch_shape) *
                    prefer_static.reduce_prod(x_channel_shape)
                ],
            ],
                                       axis=0))  # [n] + D + [bc]

        # Step 3: Apply convolution.
        y = tf.nn.depthwise_conv2d(x,
                                   kernel,
                                   strides=strides,
                                   padding=padding,
                                   data_format='NHWC',
                                   dilations=dilations)
        #  SAME: y.shape = [n, h,      w,      bcc']
        # VALID: y.shape = [n, h-fh+1, w-fw+1, bcc']

        # Step 4: Reshape/reduce for output.
        y_shape = prefer_static.shape(y)
        y = tf.reshape(y,
                       shape=prefer_static.concat(
                           [
                               x_sample_shape,
                               y_shape[1:-1],
                               kernel_batch_shape,
                               kernel_event_shape[-2:],
                           ],
                           axis=0))  # N + D' + B + [c, c']
        y = tf.reduce_sum(y, axis=-2)  # N + D' + B + [c']

        return y
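# --- Shape-focused usage sketch (assumed values; requires the surrounding
# module helpers such as `prepare_conv_args` to be importable). `x` uses the
# 'NHWBC' layout [n, h, w, b, c] and `kernel` has shape B + [fh, fw, c, c'],
# so each batch member gets its own kernel.
import tensorflow as tf

x = tf.random.normal([7, 5, 5, 3, 2])       # [n=7, h=5, w=5, b=3, c=2]
kernel = tf.random.normal([3, 4, 4, 2, 6])  # [b=3, fh=4, fw=4, c=2, c'=6]
y = convolution_batch(x, kernel, rank=2, strides=1,
                      padding='SAME', data_format='NHWBC')
print(y.shape)  # expected: [7, 5, 5, 3, 6]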
Ejemplo n.º 19
0
    def __init__(self,
                 paddings=((0, 1), ),
                 mode='CONSTANT',
                 constant_values=0,
                 axis=None,
                 validate_args=False,
                 name=None):
        """Initializes the `Pad` bijector.

    Args:
      paddings: A vector-shaped `Tensor` of `int` pairs representing the number
        of elements to pad on the left and right, respectively.
        Default value: `((0, 1),)`.
      mode: One of `'CONSTANT'`, `'REFLECT'`, or `'SYMMETRIC'`
        (case-insensitive). For more details, see `tf.pad`.
      constant_values: In "CONSTANT" mode, the scalar pad value to use. Must be
        same type as `tensor`. For more details, see `tf.pad`.
      axis: The dimensions for which `paddings` are applied. Must be 1:1 with
        `paddings` or `None`.
        Default value: `None` (i.e., `tf.range(start=-len(paddings), limit=0)`).
      validate_args: Python `bool` indicating whether arguments should
        be checked for correctness.
        Default value: `False`.
      name: Python `str`, name given to ops managed by this object.
        Default value: `None` (i.e., `'pad'`).
    """
        parameters = dict(locals())
        with tf.name_scope(name or 'pad') as name:
            paddings = tensor_util.convert_nonref_to_tensor(
                paddings,
                dtype_hint=tf.int32,
                name='paddings',
                as_shape_tensor=True)
            if axis is None:
                axis = ps.range(start=-ps.size0(paddings),
                                limit=0,
                                dtype=tf.int32,
                                name='axis')
            else:
                axis = tensor_util.convert_nonref_to_tensor(
                    axis,
                    dtype_hint=tf.int32,
                    name='axis',
                    as_shape_tensor=True)
            axis_ = tf.get_static_value(axis)
            if axis_ is None:
                raise NotImplementedError(
                    'Argument `axis` must be known statically. If you need this '
                    'feature, please contact `[email protected]`.')
            self._axis = axis
            self._paddings = paddings
            self._mode = mode
            self._constant_values = tensor_util.convert_nonref_to_tensor(
                constant_values, dtype_hint=tf.float32, name='constant_values')
            min_event_ndims_ = int(-np.min(
                np.pad(np.reshape(axis_, newshape=[-1]),
                       mode='constant',
                       pad_width=[[0, 1]])))
            super(Pad, self).__init__(forward_min_event_ndims=min_event_ndims_,
                                      inverse_min_event_ndims=min_event_ndims_,
                                      is_constant_jacobian=True,
                                      validate_args=validate_args,
                                      parameters=parameters,
                                      name=name)
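# --- Minimal usage sketch, assuming this initializer belongs to
# `tfp.bijectors.Pad`: the default `paddings=((0, 1),)` appends one zero to
# the rightmost event dimension on `forward`, and `inverse` removes it.
import tensorflow as tf
import tensorflow_probability as tfp

pad = tfp.bijectors.Pad()                  # pad the last axis with one zero
x = tf.constant([[1., 2.], [3., 4.]])
y = pad.forward(x)
print(y)               # [[1., 2., 0.], [3., 4., 0.]]
print(pad.inverse(y))  # [[1., 2.], [3., 4.]]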
Ejemplo n.º 20
0
 def _transpose_around_bijector_fn(self,
                                   bijector_fn,
                                   arg,
                                   src_event_ndims,
                                   dest_event_ndims=None,
                                   fn_reduces_event=False,
                                   **kwargs):
     # This function moves the axes corresponding to `self.sample_shape` to the
     # left of the batch shape, then applies `bijector_fn`, then moves the axes
     # corresponding to `self.sample_shape` back to the event part of the shape.
     #
     # `src_event_ndims` and `dest_event_ndims` indicate the expected event rank
     # (omitting `self.sample_shape`) before and after applying `bijector_fn`.
     #
     # This function arose because forward and inverse ended up being quite
     # similar. It was then only a small generalization to also support {F/I}LDJ.
     batch_ndims = ps.rank_from_shape(self.distribution.batch_shape_tensor,
                                      self.distribution.batch_shape)
     extra_sample_ndims = ps.rank_from_shape(self.sample_shape)
     arg_ndims = ps.rank(arg)
     # (1) Expand arg's dims.
     d = arg_ndims - batch_ndims - extra_sample_ndims - src_event_ndims
     arg = tf.reshape(arg,
                      shape=ps.pad(ps.shape(arg),
                                   paddings=[[ps.maximum(0, -d), 0]],
                                   constant_values=1))
     arg_ndims = ps.rank(arg)
     sample_ndims = ps.maximum(0, d)
     # (2) Transpose arg's dims.
     sample_dims = ps.range(0, sample_ndims)
     batch_dims = ps.range(sample_ndims, sample_ndims + batch_ndims)
     extra_sample_dims = ps.range(
         sample_ndims + batch_ndims,
         sample_ndims + batch_ndims + extra_sample_ndims)
     event_dims = ps.range(sample_ndims + batch_ndims + extra_sample_ndims,
                           arg_ndims)
     perm = ps.concat(
         [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
     arg = tf.transpose(arg, perm=perm)
     # (3) Apply underlying bijector.
     result = bijector_fn(arg, **kwargs)
     # (4) Transpose sample_shape from the sample to the event shape.
     result_ndims = ps.rank(result)
     if fn_reduces_event:
         dest_event_ndims = 0
     d = result_ndims - batch_ndims - extra_sample_ndims - dest_event_ndims
     if fn_reduces_event:
         # In some cases, fn may reduce event too far, i.e. ildj may return a
         # scalar `0.`, which won't work with the transpose we do below.
         result = tf.reshape(result,
                             shape=ps.pad(ps.shape(result),
                                          paddings=[[ps.maximum(0, -d), 0]],
                                          constant_values=1))
         result_ndims = ps.rank(result)
     sample_ndims = ps.maximum(0, d)
     sample_dims = ps.range(0, sample_ndims)
     extra_sample_dims = ps.range(sample_ndims,
                                  sample_ndims + extra_sample_ndims)
     batch_dims = ps.range(sample_ndims + extra_sample_ndims,
                           sample_ndims + extra_sample_ndims + batch_ndims)
     event_dims = ps.range(sample_ndims + extra_sample_ndims + batch_ndims,
                           result_ndims)
     perm = ps.concat(
         [sample_dims, batch_dims, extra_sample_dims, event_dims], axis=0)
     return tf.transpose(result, perm=perm)
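# --- Standalone sketch (plain NumPy, assumed ranks) of the permutation built
# in step (2) above: with sample_ndims=1, batch_ndims=2, extra_sample_ndims=1
# and event_ndims=1, the `self.sample_shape` axis is moved to sit directly
# after the sample axis, i.e. to the left of the batch axes.
import numpy as np

sample_ndims, batch_ndims, extra_sample_ndims, event_ndims = 1, 2, 1, 1
sample_dims = np.arange(0, sample_ndims)
batch_dims = np.arange(sample_ndims, sample_ndims + batch_ndims)
extra_sample_dims = np.arange(sample_ndims + batch_ndims,
                              sample_ndims + batch_ndims + extra_sample_ndims)
event_dims = np.arange(sample_ndims + batch_ndims + extra_sample_ndims,
                       sample_ndims + batch_ndims + extra_sample_ndims
                       + event_ndims)
perm = np.concatenate([sample_dims, extra_sample_dims, batch_dims, event_dims])
print(perm)  # [0 3 1 2 4]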
Ejemplo n.º 21
0
 def _sum_event_part(x, shard_axes=None):
   event_axes = ps.range(batch_ndims, ps.rank(x))
   return distribute_lib.psum(tf.reduce_sum(x, axis=event_axes), shard_axes)
Ejemplo n.º 22
0
    def one_step(self, current_state, previous_kernel_results):
        with tf.name_scope(
                mcmc_util.make_name(self.name, 'simple_step_size_adaptation',
                                    'one_step')):
            # Set the step_size.
            inner_results = self.step_size_setter_fn(
                previous_kernel_results.inner_results,
                previous_kernel_results.new_step_size)

            # Step the inner kernel.
            new_state, new_inner_results = self.inner_kernel.one_step(
                current_state, inner_results)

            # Get the new step size.
            log_accept_prob = self.log_accept_prob_getter_fn(new_inner_results)
            log_target_accept_prob = tf.math.log(
                tf.cast(previous_kernel_results.target_accept_prob,
                        dtype=log_accept_prob.dtype))

            state_parts = tf.nest.flatten(current_state)
            step_size = self.step_size_getter_fn(new_inner_results)
            step_size_parts = tf.nest.flatten(step_size)
            log_accept_prob_rank = prefer_static.rank(log_accept_prob)

            new_step_size_parts = []
            for step_size_part, state_part in zip(step_size_parts,
                                                  state_parts):
                # Compute new step sizes for each step size part. If step size part has
                # smaller rank than the corresponding state part, then the difference is
                # averaged away in the log accept prob.
                #
                # Example:
                #
                # state_part has shape      [2, 3, 4, 5]
                # step_size_part has shape     [1, 4, 1]
                # log_accept_prob has shape [2, 3, 4]
                #
                # Since step size has 1 rank fewer than the state, we reduce away the
                # leading dimension of log_accept_prob to get a Tensor with shape [3,
                # 4]. Next, since log_accept_prob must broadcast into step_size_part on
                # the left, we reduce the dimensions where their shapes differ, to get a
                # Tensor with shape [1, 4], which now is compatible with the leading
                # dimensions of step_size_part.
                #
                # There is a subtlety here in that step_size_parts might be a length-1
                # list, which means that we'll be "structure-broadcasting" it for all
                # the state parts (see logic in, e.g., hmc.py). In this case we must
                # assume that the lone step size provided broadcasts with the event
                # dims of each state part. This means that either step size has no
                # dimensions corresponding to chain dimensions, or all states are of the
                # same shape. For the former, we want to reduce over all chain
                # dimensions. For the latter, we want to use the same logic as in the
                # non-structure-broadcasted case.
                #
                # It turns out we can compute the reduction dimensions for both cases
                # uniformly by taking the rank of any state part. This obviously works
                # in the second case (where all state ranks are the same). In the first
                # case, all state parts have the rank L + D_i + B, where L is the rank
                # of log_accept_prob, D_i is the non-shared dimensions amongst all
                # states, and B are the shared dimensions of all the states, which are
                # equal to the step size. When we subtract B, we will always get a
                # number >= L, which means we'll get the full reduction we want.
                num_reduce_dims = prefer_static.minimum(
                    log_accept_prob_rank,
                    prefer_static.rank(state_part) -
                    prefer_static.rank(step_size_part))
                reduced_log_accept_prob = reduce_logmeanexp(
                    log_accept_prob, axis=prefer_static.range(num_reduce_dims))
                # reduced_log_accept_prob must broadcast into step_size_part on the
                # left, so we do an additional reduction over dimensions where their
                # shapes differ.
                reduce_indices = get_differing_dims(reduced_log_accept_prob,
                                                    step_size_part)
                reduced_log_accept_prob = reduce_logmeanexp(
                    reduced_log_accept_prob,
                    axis=reduce_indices,
                    keepdims=True)

                one_plus_adaptation_rate = 1. + tf.cast(
                    previous_kernel_results.adaptation_rate,
                    dtype=step_size_part.dtype)
                new_step_size_part = mcmc_util.choose(
                    reduced_log_accept_prob > log_target_accept_prob,
                    step_size_part * one_plus_adaptation_rate,
                    step_size_part / one_plus_adaptation_rate)

                new_step_size_parts.append(
                    tf.where(
                        previous_kernel_results.step <
                        self.num_adaptation_steps, new_step_size_part,
                        step_size_part))
            new_step_size = tf.nest.pack_sequence_as(step_size,
                                                     new_step_size_parts)

            return new_state, previous_kernel_results._replace(
                inner_results=new_inner_results,
                step=1 + previous_kernel_results.step,
                new_step_size=new_step_size)
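# --- Hedged shape-tracing sketch (plain NumPy, using the shapes from the
# comment above; `mean` stands in for `reduce_logmeanexp`):
# state_part [2, 3, 4, 5], step_size_part [1, 4, 1], log_accept_prob [2, 3, 4].
import numpy as np

log_accept_prob = np.zeros([2, 3, 4])
state_part = np.zeros([2, 3, 4, 5])
step_size_part = np.zeros([1, 4, 1])

num_reduce_dims = min(log_accept_prob.ndim,
                      state_part.ndim - step_size_part.ndim)        # 1
reduced = log_accept_prob.mean(axis=tuple(range(num_reduce_dims)))  # shape [3, 4]

# Left-aligned dims that differ from step_size_part are reduced away as well.
differing = tuple(i for i in range(reduced.ndim)
                  if reduced.shape[i] != step_size_part.shape[i])   # (0,)
reduced = reduced.mean(axis=differing, keepdims=True)
print(num_reduce_dims, reduced.shape)  # 1 (1, 4): matches step_size_part's leading dims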
Ejemplo n.º 23
0
def remc_thermodynamic_integrals(
    inverse_temperatures,
    potential_energy,
    iid_chain_ndims=0,
):
    """Estimate thermodynamic integrals using results of ReplicaExchangeMC.

  Write the density, when tempering with inverse temperature `b`, as
  `p_b(x) = exp(-b * U(x)) f(x) / Z_b`. Here `Z_b` is a normalizing constant,
  and `U(x)` is the potential energy. `f(x)` is the untempered part, if any.

  Let `E_b[U(X)]` be the expected potential energy when `X ~ p_b`. Then,
  `-1 * integral_c^d E_b[U(X)] db = log[Z_d / Z_c]`, the log normalizing
  constant ratio.

  Let `Var_b[U(X)]` be the variance of potential energy when `X ~ p_b(x)`. Then,
  `integral_c^d Var_b[U(X)] db = E_d[U(X)] - E_c[U(X)]`, the cross entropy
  difference.

  Integration is done via the trapezoidal rule. Assume `E_b[U(X)]` and
  `Var_b[U(X)]` have bounded second derivatives, uniform in `b`. Then, the
  bias due to approximation of the integral by a summation is `O(1 / K^2)`.

  Suppose `U(X)`, `X ~ p_b` has bounded fourth moment, uniform in `b`. Suppose
  further that the swap acceptance rate between every adjacent pair is greater
  than `C_s > 0`.  If we have `N` effective samples from each of the `n_replica`
  replicas, then the standard error of the summation is
  `O(1 / Sqrt(n_replica * N))`.

  Args:
    inverse_temperatures: `Tensor` of shape `[n_replica, ...]`, used to temper
      `n_replica` replicas. Assumed to be decreasing with respect to the replica
      index.
    potential_energy: The `potential_energy` field of
      `ReplicaExchangeMCKernelResults`, shape `[n_samples, n_replica, ...]`.
      If the kth replica has density `p_k(x) = exp(-beta_k * U(x)) * f_k(x)`,
      then `potential_energy[k]` is `U(X)`, where `X ~ p_k`.
    iid_chain_ndims: Number of dimensions in `potential_energy`, to the
      right of the replica dimension, that index independent identically
      distributed chains. In particular, the temperature for these chains should
      be identical. The sample means will be computed over these dimensions.

  Returns:
    ReplicaExchangeMCThermodynamicIntegrals namedtuple.
  """
    dtype = dtype_util.common_dtype([inverse_temperatures, potential_energy],
                                    dtype_hint=tf.float32)
    inverse_temperatures = tf.convert_to_tensor(inverse_temperatures,
                                                dtype=dtype)
    potential_energy = tf.convert_to_tensor(potential_energy, dtype=dtype)

    # mean is E[U(beta)].
    # Reduction is over samples and (possibly) independent chains.
    # Squeeze out the singleton left over from samples in axis=0.
    # Keepdims so we can broadcast with inverse_temperatures, which *may* have
    # additional batch dimensions.
    iid_axis = ps.concat([[0], ps.range(2, 2 + iid_chain_ndims)], axis=0)
    mean = tf.reduce_mean(potential_energy, axis=iid_axis, keepdims=True)[0]
    var = sample_stats.variance(potential_energy,
                                sample_axis=iid_axis,
                                keepdims=True)[0]

    # Integrate over the single temperature dimension.
    # dx[k] = beta_k - beta_{k+1} > 0.
    dx = bu.left_justified_expand_dims_like(
        inverse_temperatures[:-1] - inverse_temperatures[1:], mean)

    def _trapz(y):
        avg_y = 0.5 * (y[:-1] + y[1:])
        return tf.reduce_sum(avg_y * dx, axis=0)

    def _squeeze_chains(x):
        # Squeeze with a reshape, since squeeze can't use tensors.
        return tf.reshape(x, ps.shape(x)[iid_chain_ndims:])

    return ReplicaExchangeMCThermodynamicIntegrals(
        log_normalizing_constant_ratio=-_squeeze_chains(_trapz(mean)),
        cross_entropy_difference=_squeeze_chains(_trapz(var)),
    )
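# --- Standalone sketch (plain NumPy, assumed values) of the trapezoidal rule
# applied by `_trapz` above: integral_c^d y(b) db ~= sum_k 0.5*(y[k] + y[k+1])*dx[k]
# with dx[k] = beta_k - beta_{k+1} for decreasing inverse temperatures.
import numpy as np

inverse_temperatures = np.array([1.0, 0.5, 0.25])          # n_replica = 3
y = np.array([2.0, 3.0, 5.0])                              # e.g. E_b[U(X)] per replica
dx = inverse_temperatures[:-1] - inverse_temperatures[1:]  # [0.5, 0.25]
print(np.sum(0.5 * (y[:-1] + y[1:]) * dx))                 # 0.5*5*0.5 + 0.5*8*0.25 = 2.25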
Ejemplo n.º 24
0
 def _reduce(self, op, stat):
     axis = 1 + prefer_static.range(self._get_reinterpreted_batch_ndims())
     return op(stat, axis=-axis)
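# --- Hedged sketch (assumed values): with reinterpreted_batch_ndims = 2,
# `axis = 1 + range(2) = [1, 2]`, so passing `axis=-axis = [-1, -2]` reduces
# the two rightmost (reinterpreted) dimensions of `stat`.
import tensorflow as tf

stat = tf.ones([4, 2, 3])                # batch [4], reinterpreted dims [2, 3]
axis = 1 + tf.range(2)                   # [1, 2]
print(tf.reduce_sum(stat, axis=-axis))   # shape [4], each entry 2 * 3 = 6.0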
Ejemplo n.º 25
0
def _convolution_batch_nhwbc(x, kernel, rank, strides, padding, dilations,
                             name):
    """Specialization of batch conv to NHWBC data format."""
    with tf.name_scope(name or 'conv2d_nhwbc'):
        # Prepare arguments.
        [
            rank,
            _,  # strides
            padding,
            dilations,
            _,  # data_format
        ] = prepare_conv_args(rank, strides, padding, dilations)
        strides = prepare_strides(strides, rank + 2, arg_name='strides')

        dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
        x = tf.convert_to_tensor(x, dtype=dtype, name='x')
        kernel = tf.convert_to_tensor(kernel, dtype=dtype, name='kernel')

        # Step 1: Transpose and double flatten kernel.
        # kernel.shape = B + F + [c, c']. Eg: [b, fh, fw, c, c']
        kernel_shape = prefer_static.shape(kernel)
        kernel_batch_shape, kernel_event_shape = prefer_static.split(
            kernel_shape, num_or_size_splits=[-1, rank + 2])
        kernel_batch_size = prefer_static.reduce_prod(kernel_batch_shape)
        kernel_ndims = prefer_static.rank(kernel)
        kernel_batch_ndims = kernel_ndims - rank - 2
        perm = prefer_static.concat([
            prefer_static.range(kernel_batch_ndims, kernel_batch_ndims + rank),
            prefer_static.range(0, kernel_batch_ndims),
            prefer_static.range(kernel_batch_ndims + rank, kernel_ndims),
        ],
                                    axis=0)  # Eg, [1, 2, 0, 3, 4]
        kernel = tf.transpose(kernel, perm=perm)  # F + B + [c, c']
        kernel = tf.reshape(kernel,
                            shape=prefer_static.concat([
                                kernel_event_shape[:rank],
                                [
                                    kernel_batch_size * kernel_event_shape[-2],
                                    kernel_event_shape[-1]
                                ],
                            ],
                                                       axis=0))  # F + [bc, c']

        # Step 2: Double flatten x.
        # x.shape = N + D + B + [c]
        x_shape = prefer_static.shape(x)
        [
            x_sample_shape,
            x_rank_shape,
            x_batch_shape,
            x_channel_shape,
        ] = prefer_static.split(
            x_shape, num_or_size_splits=[-1, rank, kernel_batch_ndims, 1])
        x = tf.reshape(
            x,  # N + D + B + [c]
            shape=prefer_static.concat([
                [prefer_static.reduce_prod(x_sample_shape)],
                x_rank_shape,
                [
                    prefer_static.reduce_prod(x_batch_shape) *
                    prefer_static.reduce_prod(x_channel_shape)
                ],
            ],
                                       axis=0))  # [n] + D + [bc]

        # Step 3: Apply convolution.
        y = tf.nn.depthwise_conv2d(x,
                                   kernel,
                                   strides=strides,
                                   padding=padding,
                                   data_format='NHWC',
                                   dilations=dilations)
        #  SAME: y.shape = [n, h,      w,      bcc']
        # VALID: y.shape = [n, h-fh+1, w-fw+1, bcc']

        # Step 4: Reshape/reduce for output.
        y_shape = prefer_static.shape(y)
        y = tf.reshape(y,
                       shape=prefer_static.concat(
                           [
                               x_sample_shape,
                               y_shape[1:-1],
                               kernel_batch_shape,
                               kernel_event_shape[-2:],
                           ],
                           axis=0))  # N + D' + B + [c, c']
        y = tf.reduce_sum(y, axis=-2)  # N + D' + B + [c']

        return y
Ejemplo n.º 26
0
def _kl_independent(a, b, name='kl_independent'):
    """Batched KL divergence `KL(a || b)` for Independent distributions.

  We can leverage the fact that
  ```
  KL(Independent(a) || Independent(b)) = sum(KL(a || b))
  ```
  where the sum is over the `reinterpreted_batch_ndims`.

  Args:
    a: Instance of `Independent`.
    b: Instance of `Independent`.
    name: (optional) name to use for created ops. Default 'kl_independent'.

  Returns:
    Batchwise `KL(a || b)`.

  Raises:
    ValueError: If the event space for `a` and `b`, or their underlying
      distributions don't match.
  """
    p = a.distribution
    q = b.distribution

    # The KL between any two (non)-batched distributions is a scalar.
    # Given that the KL between two factored distributions is the sum, i.e.
    # KL(p1(x)p2(y) || q1(x)q2(y)) = KL(p1 || q1) + KL(q1 || q2), we compute
    # KL(p || q) and do a `reduce_sum` on the reinterpreted batch dimensions.
    if (tensorshape_util.is_fully_defined(a.event_shape)
            and tensorshape_util.is_fully_defined(b.event_shape)):
        if a.event_shape == b.event_shape:
            if p.event_shape == q.event_shape:
                num_reduce_dims = (tensorshape_util.rank(a.event_shape) -
                                   tensorshape_util.rank(p.event_shape))
                reduce_dims = [-i - 1 for i in range(0, num_reduce_dims)]

                return tf.reduce_sum(kullback_leibler.kl_divergence(p,
                                                                    q,
                                                                    name=name),
                                     axis=reduce_dims)
            else:
                raise NotImplementedError(
                    'KL between Independents with different '
                    'event shapes not supported.')
        else:
            raise ValueError('Event shapes do not match.')
    else:
        p_event_shape_tensor = p.event_shape_tensor()
        q_event_shape_tensor = q.event_shape_tensor()
        # NOTE: We could optimize by passing the event_shape_tensor of p and q
        # to a.event_shape_tensor() and b.event_shape_tensor().
        a_event_shape_tensor = a.event_shape_tensor()
        b_event_shape_tensor = b.event_shape_tensor()
        with tf.control_dependencies([
                assert_util.assert_equal(a_event_shape_tensor,
                                         b_event_shape_tensor,
                                         message='Event shapes do not match.'),
                assert_util.assert_equal(p_event_shape_tensor,
                                         q_event_shape_tensor,
                                         message='Event shapes do not match.'),
        ]):
            num_reduce_dims = (prefer_static.rank_from_shape(
                a_event_shape_tensor, a.event_shape) -
                               prefer_static.rank_from_shape(
                                   p_event_shape_tensor, p.event_shape))
            reduce_dims = prefer_static.range(-num_reduce_dims, 0, 1)
            return tf.reduce_sum(kullback_leibler.kl_divergence(p,
                                                                q,
                                                                name=name),
                                 axis=reduce_dims)
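# --- Minimal usage sketch, assuming the function above backs
# `tfd.kl_divergence` for `tfd.Independent`: the result equals the sum of the
# component KLs over the reinterpreted batch dimensions.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
a = tfd.Independent(tfd.Normal(loc=tf.zeros([3]), scale=1.),
                    reinterpreted_batch_ndims=1)
b = tfd.Independent(tfd.Normal(loc=tf.ones([3]), scale=1.),
                    reinterpreted_batch_ndims=1)
print(tfd.kl_divergence(a, b))                                           # 1.5
print(tf.reduce_sum(tfd.kl_divergence(a.distribution, b.distribution)))  # 1.5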
Ejemplo n.º 27
0
def pivoted_cholesky(matrix, max_rank, diag_rtol=1e-3, name=None):
    """Computes the (partial) pivoted cholesky decomposition of `matrix`.

  The pivoted Cholesky is a low rank approximation of the Cholesky decomposition
  of `matrix`, i.e. as described in [(Harbrecht et al., 2012)][1]. The
  currently-worst-approximated diagonal element is selected as the pivot at each
  iteration. This yields from a `[B1...Bn, N, N]` shaped `matrix` a `[B1...Bn,
  N, K]` shaped rank-`K` approximation `lr` such that `lr @ lr.T ~= matrix`.
  Note that, unlike the Cholesky decomposition, `lr` is not triangular even in
  a rectangular-matrix sense. However, under a permutation it could be made
  triangular (it has one more zero in each column as you move to the right).

  Such a matrix can be useful as a preconditioner for conjugate gradient
  optimization, i.e. as in [(Wang et al. 2019)][2], as matmuls and solves can be
  cheaply done via the Woodbury matrix identity, as implemented by
  `tf.linalg.LinearOperatorLowRankUpdate`.

  Args:
    matrix: Floating point `Tensor` batch of symmetric, positive definite
      matrices.
    max_rank: Scalar `int` `Tensor`, the rank at which to truncate the
      approximation.
    diag_rtol: Scalar floating point `Tensor` (same dtype as `matrix`). If the
      errors of all diagonal elements of `lr @ lr.T` are each lower than
      `element * diag_rtol`, iteration is permitted to terminate early.
    name: Optional name for the op.

  Returns:
    lr: Low rank pivoted Cholesky approximation of `matrix`.

  #### References

  [1]: H Harbrecht, M Peters, R Schneider. On the low-rank approximation by the
       pivoted Cholesky decomposition. _Applied numerical mathematics_,
       62(4):428-440, 2012.

  [2]: K. A. Wang et al. Exact Gaussian Processes on a Million Data Points.
       _arXiv preprint arXiv:1903.08114_, 2019. https://arxiv.org/abs/1903.08114
  """
    with tf.name_scope(name or 'pivoted_cholesky'):
        dtype = dtype_util.common_dtype([matrix, diag_rtol],
                                        dtype_hint=tf.float32)
        if not isinstance(matrix, tf.linalg.LinearOperator):
            matrix = tf.convert_to_tensor(matrix, name='matrix', dtype=dtype)
        if tensorshape_util.rank(matrix.shape) is None:
            raise NotImplementedError(
                'Rank of `matrix` must be known statically')
        if isinstance(matrix, tf.linalg.LinearOperator):
            matrix_shape = tf.cast(matrix.shape_tensor(), tf.int64)
        else:
            matrix_shape = ps.shape(matrix, out_type=tf.int64)

        max_rank = tf.convert_to_tensor(max_rank,
                                        name='max_rank',
                                        dtype=tf.int64)
        max_rank = tf.minimum(max_rank, matrix_shape[-1])
        diag_rtol = tf.convert_to_tensor(diag_rtol,
                                         dtype=dtype,
                                         name='diag_rtol')
        matrix_diag = tf.linalg.diag_part(matrix)
        # matrix is P.D., therefore all matrix_diag > 0, so we don't need abs.
        orig_error = tf.reduce_max(matrix_diag, axis=-1)

        def cond(m, pchol, perm, matrix_diag):
            """Condition for `tf.while_loop` continuation."""
            del pchol
            del perm
            error = tf.linalg.norm(matrix_diag, ord=1, axis=-1)
            max_err = tf.reduce_max(error / orig_error)
            return (m < max_rank) & (tf.equal(m, 0) | (max_err > diag_rtol))

        batch_dims = tensorshape_util.rank(matrix.shape) - 2

        def batch_gather(params, indices, axis=-1):
            return tf.gather(params, indices, axis=axis, batch_dims=batch_dims)

        def body(m, pchol, perm, matrix_diag):
            """Body of a single `tf.while_loop` iteration."""
            # Here is roughly a numpy, non-batched version of what's going to happen.
            # (See also Algorithm 1 of Harbrecht et al.)
            # 1: maxi = np.argmax(matrix_diag[perm[m:]]) + m
            # 2: maxval = matrix_diag[perm][maxi]
            # 3: perm[m], perm[maxi] = perm[maxi], perm[m]
            # 4: row = matrix[perm[m]][perm[m + 1:]]
            # 5: row -= np.sum(pchol[:m][perm[m + 1:]] * pchol[:m][perm[m]], axis=-2)
            # 6: pivot = np.sqrt(maxval); row /= pivot
            # 7: row = np.concatenate([[[pivot]], row], -1)
            # 8: matrix_diag[perm[m:]] -= row**2
            # 9: pchol[m, perm[m:]] = row

            # Find the maximal position of the (remaining) permuted diagonal.
            # Steps 1, 2 above.
            permuted_diag = batch_gather(matrix_diag, perm[..., m:])
            maxi = tf.argmax(permuted_diag, axis=-1,
                             output_type=tf.int64)[..., tf.newaxis]
            maxval = batch_gather(permuted_diag, maxi)
            maxi = maxi + m
            maxval = maxval[..., 0]
            # Update perm: Swap perm[...,m] with perm[...,maxi]. Step 3 above.
            perm = _swap_m_with_i(perm, m, maxi)
            # Step 4.
            if callable(getattr(matrix, 'row', None)):
                row = matrix.row(perm[..., m])[..., tf.newaxis, :]
            else:
                row = batch_gather(matrix, perm[..., m:m + 1], axis=-2)
            row = batch_gather(row, perm[..., m + 1:])
            # Step 5.
            prev_rows = pchol[..., :m, :]
            prev_rows_perm_m_onward = batch_gather(prev_rows, perm[...,
                                                                   m + 1:])
            prev_rows_pivot_col = batch_gather(prev_rows, perm[..., m:m + 1])
            row -= tf.reduce_sum(prev_rows_perm_m_onward * prev_rows_pivot_col,
                                 axis=-2)[..., tf.newaxis, :]
            # Step 6.
            pivot = tf.sqrt(maxval)[..., tf.newaxis, tf.newaxis]
            # Step 7.
            row = tf.concat([pivot, row / pivot], axis=-1)
            # TODO(b/130899118): Pad grad fails with int64 paddings.
            # Step 8.
            paddings = tf.concat([
                tf.zeros([ps.rank(pchol) - 1, 2], dtype=tf.int32),
                [[tf.cast(m, tf.int32), 0]]
            ],
                                 axis=0)
            diag_update = tf.pad(row**2, paddings=paddings)[..., 0, :]
            reverse_perm = _invert_permutation(perm)
            matrix_diag -= batch_gather(diag_update, reverse_perm)
            # Step 9.
            row = tf.pad(row, paddings=paddings)
            # TODO(bjp): Defer the reverse permutation all-at-once at the end?
            row = batch_gather(row, reverse_perm)
            pchol_shape = pchol.shape
            pchol = tf.concat([pchol[..., :m, :], row, pchol[..., m + 1:, :]],
                              axis=-2)
            tensorshape_util.set_shape(pchol, pchol_shape)
            return m + 1, pchol, perm, matrix_diag

        m = np.int64(0)
        pchol = tf.zeros(matrix_shape, dtype=matrix.dtype)[..., :max_rank, :]
        perm = tf.broadcast_to(ps.range(matrix_shape[-1]), matrix_shape[:-1])
        _, pchol, _, _ = tf.while_loop(cond=cond,
                                       body=body,
                                       loop_vars=(m, pchol, perm, matrix_diag))
        pchol = tf.linalg.matrix_transpose(pchol)
        tensorshape_util.set_shape(
            pchol, tensorshape_util.concatenate(matrix_diag.shape, [None]))
        return pchol
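
The numbered comments in `body` sketch the non-batched algorithm. Below is a minimal NumPy version of that same pivoted-Cholesky loop, offered only as an illustration: it assumes a dense positive-definite `matrix`, ignores the `diag_rtol` early exit, and is not part of the TFP implementation.

import numpy as np

def pivoted_cholesky_np(matrix, max_rank):
    """Non-batched pivoted Cholesky sketch; follows steps 1-9 above."""
    n = matrix.shape[-1]
    pchol = np.zeros((max_rank, n))
    perm = np.arange(n)
    matrix_diag = np.diag(matrix).copy()
    for m in range(max_rank):
        # Steps 1-3: pivot on the largest remaining permuted diagonal entry.
        maxi = np.argmax(matrix_diag[perm[m:]]) + m
        maxval = matrix_diag[perm[maxi]]
        perm[m], perm[maxi] = perm[maxi], perm[m]
        # Steps 4-5: extract the pivot row and remove prior rows' contribution.
        row = matrix[perm[m], perm[m + 1:]]
        row = row - np.sum(
            pchol[:m, perm[m + 1:]] * pchol[:m, perm[m]][:, np.newaxis], axis=0)
        # Steps 6-7: scale by the pivot and prepend it.
        pivot = np.sqrt(maxval)
        row = np.concatenate([[pivot], row / pivot])
        # Steps 8-9: update the residual diagonal and store the new row.
        matrix_diag[perm[m:]] -= row**2
        pchol[m, perm[m:]] = row
    return pchol.T  # Transposed to match the TF version's return shape.

# With max_rank equal to the full dimension, the factor reproduces the matrix.
a = np.array([[4., 2., 1.], [2., 3., 0.5], [1., 0.5, 2.]])
low_rank = pivoted_cholesky_np(a, max_rank=3)
assert np.allclose(low_rank @ low_rank.T, a)
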
Example No. 28
def slice_bounds_by_doubling(x_initial,
                             target_log_prob,
                             log_slice_heights,
                             max_doublings,
                             step_size,
                             seed=None,
                             name=None):
    """Returns the bounds of the slice at each stage of doubling procedure.

  Precomputes the x coordinates of the left (L) and right (R) endpoints of the
  interval `I` produced in the "doubling" algorithm [Neal 2003][1], p. 713. Note
  that we simultaneously compute all possible doubling values for each chain,
  because at small to medium densities the gains from parallel evaluation may
  yield a speed-up; this remains to be benchmarked against a while-loop
  implementation.

  Args:
    x_initial: `tf.Tensor` of any shape and any real dtype consumable by
      `target_log_prob`. The initial points.
    target_log_prob: A callable taking a `tf.Tensor` of shape and dtype as
      `x_initial` and returning a tensor of the same shape. The log density of
      the target distribution.
    log_slice_heights: `tf.Tensor` with the same shape as `x_initial` and the
      same dtype as returned by `target_log_prob`. The log of the height of the
      slice for each chain. The values must be bounded above by
      `target_log_prob(x_initial)`.
    max_doublings: Scalar positive int32 `tf.Tensor`. The maximum number of
      doublings to consider.
    step_size: `tf.Tensor` with same dtype as and shape compatible with
      `x_initial`. The size of the initial interval.
    seed: (Optional) positive int or Tensor seed pair. The random seed.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'slice_bounds_by_doubling').

  Returns:
    upper_bounds: A tensor of same shape and dtype as `x_initial`. Slice upper
      bounds for each chain.
    lower_bounds: A tensor of same shape and dtype as `x_initial`. Slice lower
      bounds for each chain.
    both_ok: A boolean tensor of the same shape as `x_initial`. Indicates
      whether both the chosen upper and lower bounds lie outside of the slice.

  #### References

  [1]: Radford M. Neal. Slice Sampling. The Annals of Statistics. 2003, Vol 31,
       No. 3, 705-767.
       https://projecteuclid.org/download/pdf_1/euclid.aos/1056562461
  """
    with tf.name_scope(name or 'slice_bounds_by_doubling'):
        left_seed, increments_seed = samplers.split_seed(
            seed, salt='slice_bounds_by_doubling')
        x_initial = tf.convert_to_tensor(value=x_initial)
        batch_shape = ps.shape(x_initial)
        dtype = dtype_util.base_dtype(step_size.dtype)
        left_endpoints = x_initial + step_size * samplers.uniform(
            batch_shape, minval=-1.0, maxval=0.0, dtype=dtype, seed=left_seed)

        # Compute the increments by which we need to step the upper and lower
        # bounds as part of the doubling procedure.
        left_increments, widths = _left_doubling_increments(
            batch_shape, max_doublings, step_size, seed=increments_seed)
        # The left and right end points. Shape (max_doublings+1,) + batch_shape.
        left_endpoints = left_endpoints - left_increments
        right_endpoints = left_endpoints + widths

        # Test whether these endpoints lie outside of the slice, i.e. whether
        # the log density at each endpoint falls below the log slice height.
        left_ep_values = tf.map_fn(target_log_prob, left_endpoints)
        right_ep_values = tf.map_fn(target_log_prob, right_endpoints)
        left_ok = left_ep_values < log_slice_heights
        right_ok = right_ep_values < log_slice_heights
        both_ok = left_ok & right_ok

        both_ok_f = tf.reshape(both_ok, [max_doublings + 1, -1])

        best_interval_idx = _find_best_interval_idx(
            tf.cast(both_ok_f, dtype=tf.int32))

        # Formats the above index as required to use with gather_nd.
        point_index_gather = tf.stack(
            [best_interval_idx,
             ps.range(ps.size(best_interval_idx))],
            axis=1,
            name='point_index_gather')
        left_ep_f = tf.reshape(left_endpoints, [max_doublings + 1, -1])
        right_ep_f = tf.reshape(right_endpoints, [max_doublings + 1, -1])
        # The x values of the upper and lower bounds of the slices for each chain.
        lower_bounds = tf.reshape(tf.gather_nd(left_ep_f, point_index_gather),
                                  batch_shape)
        upper_bounds = tf.reshape(tf.gather_nd(right_ep_f, point_index_gather),
                                  batch_shape)
        both_ok = tf.reduce_any(both_ok, axis=0)
        return upper_bounds, lower_bounds, both_ok
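
For comparison with the vectorized precomputation above, here is a minimal single-chain NumPy sketch of Neal's sequential doubling procedure (Fig. 4 of [1]). It assumes a cheap scalar `target_log_prob` callable and is illustrative only, not part of the TFP implementation.

import numpy as np

def doubling_bounds_single_chain(x_initial, target_log_prob, log_slice_height,
                                 max_doublings, step_size, rng=None):
    """Sequentially doubles an interval around x_initial (illustration only)."""
    rng = rng or np.random.default_rng()
    # Place an interval of width `step_size` uniformly at random around x_initial.
    left = x_initial - step_size * rng.uniform()
    right = left + step_size
    for _ in range(max_doublings):
        # Stop once both endpoints lie outside the slice, i.e. their log
        # densities fall below the log slice height.
        if (target_log_prob(left) < log_slice_height and
                target_log_prob(right) < log_slice_height):
            break
        # Otherwise double the interval by extending a randomly chosen side.
        if rng.uniform() < 0.5:
            left -= right - left
        else:
            right += right - left
    return right, left

# Example with a standard normal target (up to an additive constant).
upper, lower = doubling_bounds_single_chain(
    x_initial=0.1, target_log_prob=lambda x: -0.5 * x**2,
    log_slice_height=-2.0, max_doublings=5, step_size=1.0)
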
def make_flow_posterior(prior,
                        num_hidden_units,
                        invert=True,
                        num_flow_layers=2):
    """Make a MAF/IAF surrogate posterior.

  Args:
    prior: tfd.JointDistribution instance of the prior.
    num_hidden_units: int value. Specifies the number of hidden units.
    invert: Optional Boolean value. If `True`, produces inverse autoregressive
      flow. If `False`, produces a masked autoregressive flow.
      Default value: `True`.
    num_flow_layers: Optional int value. Specifies the number of flow layers.
      Default value: `2`.

  Returns:
    surrogate_posterior: A `tfd.TransformedDistribution` instance
      whose samples have shape and structure matching that of `prior`.
  """

    event_shape = prior.event_shape_tensor()
    event_space_bijector = prior.experimental_default_event_space_bijector()
    flat_event_shape = tf.nest.flatten(event_shape)
    flat_event_size = [tf.reduce_prod(s) for s in flat_event_shape]

    ndims = tf.reduce_sum(flat_event_size)
    dtype = tf.nest.flatten(prior.dtype)[0]

    make_swap = lambda: tfb.Permute(ps.range(ndims - 1, -1, -1))

    def make_maf():
        net = tfb.AutoregressiveNetwork(
            2,
            hidden_units=[num_hidden_units, num_hidden_units],
            activation=tf.tanh,
            dtype=dtype)

        maf = tfb.MaskedAutoregressiveFlow(bijector_fn=lambda x: tfb.Chain([
            tfb.Shift(net(x)[Ellipsis, 0]),  # pylint: disable=g-long-lambda
            tfb.Scale(log_scale=net(x)[Ellipsis, 1])
        ]))
        if invert:
            maf = tfb.Invert(maf)
        # Attach the network so its variables are tracked by the bijector.
        maf._net = net  # pylint: disable=protected-access
        return maf

    dist = tfd.Sample(tfd.Normal(tf.zeros([], dtype=dtype), 1.),
                      sample_shape=[ndims])

    bijectors = [
        event_space_bijector,
        tfb.Restructure(
            tf.nest.pack_sequence_as(event_shape,
                                     range(len(flat_event_shape)))),
        tfb.JointMap(tf.nest.map_structure(tfb.Reshape, flat_event_shape)),
        tfb.Split(flat_event_size),
    ]
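    # Note: tfb.Chain applies its list right-to-left, so the flow layers
    # appended below act on the base vector first; the result is then split,
    # reshaped per event part, restructured to match `prior`, and finally
    # pushed through the event-space bijector.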
    bijectors.append(make_maf())

    for _ in range(num_flow_layers - 1):
        bijectors.extend([make_swap(), make_maf()])

    return tfd.TransformedDistribution(dist, tfb.Chain(bijectors))
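
A minimal usage sketch for `make_flow_posterior`, assuming the module-level aliases used in the listing (`tfd`, `tfb`, `ps`) are bound as shown below and using a hypothetical toy prior:

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import prefer_static as ps

tfb = tfp.bijectors
tfd = tfp.distributions

# Hypothetical toy prior: a positive scale and a 3-vector of weights.
prior = tfd.JointDistributionNamed(dict(
    scale=tfd.HalfNormal(1.),
    weights=tfd.Sample(tfd.Normal(0., 1.), sample_shape=[3])))

surrogate_posterior = make_flow_posterior(prior, num_hidden_units=8)
# Samples have the same structure as samples from `prior`.
samples = surrogate_posterior.sample(2)
# The surrogate's trainable variables can then be fit with, e.g.,
# tfp.vi.fit_surrogate_posterior.
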
Example No. 30
def _get_reduction_axes(x, nd):
    """Enumerates the final `nd` axis indices of `x`."""
    x_rank = prefer_static.rank_from_shape(prefer_static.shape(x))
    return prefer_static.range(x_rank - 1, x_rank - nd - 1, -1)
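
For concreteness, a plain-TF re-derivation of the indices `_get_reduction_axes` produces: for a rank-4 input and `nd=2` the result is `[3, 2]`, i.e. the final two axes in descending order.

import tensorflow as tf

x = tf.zeros([2, 3, 4, 5])
nd = 2
x_rank = tf.rank(x)
# Mirrors prefer_static.range(x_rank - 1, x_rank - nd - 1, -1).
print(tf.range(x_rank - 1, x_rank - nd - 1, -1))  # -> [3 2]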