def _log_variance(self):
     # Following calculation is based on law of total variance:
     #
     # Var[Z] = E[Var[Z | V]] + Var[E[Z | V]]
     #
     # where,
     #
     # Z|v ~ interpolate_affine[v](dist)
     # V ~ mixture_dist
     #
     # thus,
     #
     # E[Var[Z | V]] = sum{ prob[d] Var[d] : d=0, ..., deg-1 }
     # Var[E[Z | V]] = sum{ prob[d] (Mean[d] - Mean)**2 : d=0, ..., deg-1 }
     distributions = self.poisson_and_mixture_distributions()
     dist, mixture_dist = distributions
     v = tf.stack(
         [
             # log(dist.variance()) = log(Var[d]) = log(rate[d])
             dist.log_rate,
             # log((Mean[d] - Mean)**2)
             2. * tf.math.log(
                 tf.abs(dist.mean() - self._mean(
                     distributions=distributions)[..., tf.newaxis])),
         ],
         axis=-1)
     return tf.reduce_logsumexp(mixture_dist.logits[..., tf.newaxis] + v,
                                axis=[-2, -1])
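
A minimal standalone check of the log-space total-variance trick above, assuming made-up mixture weights and Poisson rates (nothing here is taken from the class itself):

```python
import tensorflow as tf

# Hypothetical normalized mixture log-weights and Poisson log-rates; these are
# illustrative values only, not drawn from the distribution above.
log_probs = tf.math.log([0.2, 0.5, 0.3])
log_rate = tf.math.log([1.0, 4.0, 10.0])

# Mixture mean: E[Z] = sum_d prob[d] * rate[d], accumulated in log space.
mean = tf.exp(tf.reduce_logsumexp(log_probs + log_rate))

# Law of total variance, accumulated exactly as in `_log_variance`:
#   Var[Z] = sum_d prob[d] * rate[d] + sum_d prob[d] * (rate[d] - mean)**2
v = tf.stack([
    log_rate,                                           # log Var[d] of a Poisson
    2. * tf.math.log(tf.abs(tf.exp(log_rate) - mean)),  # log (Mean[d] - Mean)**2
], axis=-1)
log_var = tf.reduce_logsumexp(log_probs[..., tf.newaxis] + v, axis=[-2, -1])

# Direct (non-log-space) computation for comparison.
probs, rate = tf.exp(log_probs), tf.exp(log_rate)
direct_var = tf.reduce_sum(probs * (rate + (rate - mean) ** 2))
print(tf.exp(log_var).numpy(), direct_var.numpy())  # the two should agree
```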
Example 2
def reduce_logmeanexp(input_tensor, axis=None, keepdims=False, name=None):
    """Computes `log(mean(exp(input_tensor)))`.

  Reduces `input_tensor` along the dimensions given in `axis`.  Unless
  `keepdims` is true, the rank of the tensor is reduced by 1 for each entry in
  `axis`. If `keepdims` is true, the reduced dimensions are retained with length
  1.

  If `axis` has no entries, all dimensions are reduced, and a tensor with a
  single element is returned.

  This function is more numerically stable than `log(reduce_mean(exp(input)))`.
  It avoids overflows caused by taking the exp of large inputs and underflows
  caused by taking the log of small inputs.

  Args:
    input_tensor: The tensor to reduce. Should have numeric type.
    axis: The dimensions to reduce. If `None` (the default), reduces all
      dimensions. Must be in the range `[-rank(input_tensor),
      rank(input_tensor))`.
    keepdims:  Boolean.  Whether to keep the axis as singleton dimensions.
      Default value: `False` (i.e., squeeze the reduced dimensions).
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., `'reduce_logmeanexp'`).

  Returns:
    log_mean_exp: The reduced tensor.
  """
    with tf.name_scope(name or 'reduce_logmeanexp'):
        lse = tf.reduce_logsumexp(input_tensor, axis=axis, keepdims=keepdims)
        n = prefer_static.size(input_tensor) // prefer_static.size(lse)
        log_n = tf.math.log(tf.cast(n, lse.dtype))
        return lse - log_n
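
A quick usage sketch with made-up inputs, showing why the logsumexp route matters: the naive `log(mean(exp(x)))` overflows for large entries, while subtracting `log(n)` from `reduce_logsumexp` stays finite.

```python
import tensorflow as tf

x = tf.constant([1000., 1001., 1002.])  # made-up values large enough to overflow exp()

naive = tf.math.log(tf.reduce_mean(tf.exp(x)))      # inf: exp(1000.) overflows
stable = tf.reduce_logsumexp(x) - tf.math.log(3.)   # what the function above computes
print(naive.numpy(), stable.numpy())                # inf vs. ~1001.31
```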
Example 3
    def _entropy(self):
        if self._logits is None:
            # If we only have probs, there's not much we can do to ensure numerical
            # precision.
            probs = tf.convert_to_tensor(self._probs)
            return -tf.reduce_sum(
                tf.math.multiply_no_nan(tf.math.log(probs), probs), axis=-1)

        # The following result can be derived as follows. Write log(p[i]) as
        # s[i] - m - lse(s - m), where m = max(s). Then sum_i p[i] log(p[i]) is:
        #   sum_i exp(s[i]-m-lse(s-m)) (s[i] - m - lse(s-m))
        #   = -m - lse(s-m) + sum_i s[i] exp(s[i]-m-lse(s-m))
        #   = -m - lse(s-m) + (1/exp(lse(s-m))) sum_i s[i] exp(s[i]-m)
        #   = -m - lse(s-m) + (1/sumexp(s-m)) sum_i s[i] exp(s[i]-m)
        # Writing x[i] = s[i] - m, this becomes:
        #   -m - lse(x) + (1/sum_exp(x)) sum_i s[i] exp(x[i])
        # Negating this result gives the Shannon (discrete) entropy.
        logits = tf.convert_to_tensor(self._logits)
        m = tf.reduce_max(logits, axis=-1, keepdims=True)
        x = logits - m
        lse_logits = m[..., 0] + tf.reduce_logsumexp(x, axis=-1)
        sum_exp_x = tf.reduce_sum(tf.math.exp(x), axis=-1)
        return lse_logits - tf.reduce_sum(
            tf.math.multiply_no_nan(logits, tf.math.exp(x)),
            axis=-1) / sum_exp_x
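
A small sanity check, with hypothetical logits, that the logits-only route above agrees with the textbook `-sum(p * log(p))` definition:

```python
import tensorflow as tf

logits = tf.constant([2.0, -1.0, 0.5])  # made-up unnormalized logits

# Reference: normalize, then take -sum p * log(p).
probs = tf.nn.softmax(logits)
reference = -tf.reduce_sum(probs * tf.math.log(probs))

# Logits-only route, mirroring `_entropy` above.
m = tf.reduce_max(logits)
x = logits - m
lse_logits = m + tf.reduce_logsumexp(x)
entropy = lse_logits - tf.reduce_sum(logits * tf.exp(x)) / tf.reduce_sum(tf.exp(x))
print(reference.numpy(), entropy.numpy())  # the two should match
```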
Example 4
 def _assert_valid_sample(self, x):
   if not self.validate_args:
     return x
   return distribution_util.with_dependencies([
       assert_util.assert_non_positive(x),
       assert_util.assert_near(
           tf.zeros([], dtype=self.dtype), tf.reduce_logsumexp(x, axis=[-1])),
   ], x)
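
For context, a valid sample here is a vector of log-probabilities: every entry is non-positive and `reduce_logsumexp` over the last axis is approximately zero. A tiny illustration with made-up vectors:

```python
import tensorflow as tf

good = tf.math.log([0.2, 0.3, 0.5])  # probabilities sum to 1
bad = tf.math.log([0.2, 0.3, 0.9])   # probabilities sum to 1.4

print(tf.reduce_logsumexp(good, axis=-1).numpy())  # ~0.0: passes the assertion
print(tf.reduce_logsumexp(bad, axis=-1).numpy())   # ~0.34: would trip assert_near
```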
Example 5
 def _log_prob(self, x):
     with tf.control_dependencies(self._runtime_assertions):
         x = self._pad_sample_dims(x)
         log_prob_x = self.components_distribution.log_prob(x)  # [S, B, k]
         log_mix_prob = tf.math.log_softmax(
             self.mixture_distribution.logits_parameter(),
             axis=-1)  # [B, k]
         return tf.reduce_logsumexp(log_prob_x + log_mix_prob,
                                    axis=-1)  # [S, B]
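
The same computation can be reproduced by hand. A minimal sketch with a hypothetical two-component Gaussian mixture (all names and values made up), compared against `tfd.MixtureSameFamily`:

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

mix_logits = tf.constant([0.3, -0.3])                       # made-up mixture logits
components = tfd.Normal(loc=[-2.0, 3.0], scale=[1.0, 0.5])  # made-up components
x = tf.constant(0.7)

log_prob_x = components.log_prob(x[..., tf.newaxis])             # [k]
log_mix_prob = tf.math.log_softmax(mix_logits, axis=-1)          # [k]
manual = tf.reduce_logsumexp(log_prob_x + log_mix_prob, axis=-1)

mixture = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(logits=mix_logits),
    components_distribution=components)
print(manual.numpy(), mixture.log_prob(x).numpy())  # should agree
```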
Example 6
 def _log_cdf(self, x):
   with tf.control_dependencies(self._assertions):
     x = tf.convert_to_tensor(x, name="x")
     distribution_log_cdfs = [d.log_cdf(x) for d in self.components]
     cat_log_probs = self._cat_probs(log_probs=True)
     final_log_cdfs = [
         cat_lp + d_lcdf
         for (cat_lp, d_lcdf) in zip(cat_log_probs, distribution_log_cdfs)
     ]
     concatted_log_cdfs = tf.stack(final_log_cdfs, axis=0)
     mixture_log_cdf = tf.reduce_logsumexp(concatted_log_cdfs, axis=[0])
     return mixture_log_cdf
    def _log_prob(self, value):
        with tf.control_dependencies(self._runtime_assertions):
            # The argument `value` is a tensor of sequences of observations.
            # `observation_batch_shape` is the shape of that tensor with the
            # sequence part removed.
            # `observation_batch_shape` is then broadcast to the full batch shape
            # to give the `batch_shape` that defines the shape of the result.

            observation_tensor_shape = tf.shape(value)
            observation_batch_shape = observation_tensor_shape[
                :-1 - self._underlying_event_rank]
            # value :: observation_batch_shape num_steps observation_event_shape
            batch_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                                     self.batch_shape_tensor())
            log_init = tf.broadcast_to(
                self._log_init,
                tf.concat([batch_shape, [self._num_states]], axis=0))
            # log_init :: batch_shape num_states
            log_transition = self._log_trans

            # `observation_event_shape` is the shape of each sequence of observations
            # emitted by the model.
            observation_event_shape = observation_tensor_shape[
                -1 - self._underlying_event_rank:]
            working_obs = tf.broadcast_to(
                value, tf.concat([batch_shape, observation_event_shape],
                                 axis=0))
            # working_obs :: batch_shape observation_event_shape
            r = self._underlying_event_rank

            # Move index into sequence of observations to front so we can apply
            # tf.foldl
            working_obs = distribution_util.move_dimension(
                working_obs, -1 - r, 0)[..., tf.newaxis]
            # working_obs :: num_steps batch_shape underlying_event_shape
            observation_probs = (
                self._observation_distribution.log_prob(working_obs))

            def forward_step(log_prev_step, log_prob_observation):
                return _log_vector_matrix(
                    log_prev_step, log_transition) + log_prob_observation

            fwd_prob = tf.foldl(forward_step,
                                observation_probs,
                                initializer=log_init)
            # fwd_prob :: batch_shape num_states

            log_prob = tf.reduce_logsumexp(fwd_prob, axis=-1)
            # log_prob :: batch_shape

            return log_prob
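
A self-contained sketch of the same forward recursion on a tiny, made-up 2-state, 3-step chain. `emit_lp[t, k]` stands in for `log p(x[t] | z[t] = k)` for one fixed observation sequence, and the result is checked against brute-force enumeration of all hidden paths:

```python
import itertools

import numpy as np
import tensorflow as tf

log_init = tf.math.log([0.6, 0.4])       # made-up initial distribution
log_trans = tf.math.log([[0.7, 0.3],
                         [0.2, 0.8]])    # made-up transition matrix
emit_lp = tf.math.log([[0.9, 0.1],
                       [0.4, 0.6],
                       [0.5, 0.5]])      # stand-in per-step emission log-probs

def _log_vector_matrix(vs, ms):
  return tf.reduce_logsumexp(vs[..., tf.newaxis] + ms, axis=-2)

# Forward algorithm, folding over the per-step emission log-probs as above.
fwd = tf.foldl(
    lambda prev, obs_lp: _log_vector_matrix(prev, log_trans) + obs_lp,
    emit_lp[1:],
    initializer=log_init + emit_lp[0])
forward_ll = tf.reduce_logsumexp(fwd, axis=-1)

# Brute force: sum the probability of every hidden path.
total = 0.0
for path in itertools.product(range(2), repeat=3):
  lp = log_init[path[0]] + emit_lp[0, path[0]]
  for t in range(1, 3):
    lp += log_trans[path[t - 1], path[t]] + emit_lp[t, path[t]]
  total += np.exp(lp.numpy())
print(forward_ll.numpy(), np.log(total))  # should agree
```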
Example 8
 def _forward_log_det_jacobian(self, x):
     # This code is similar to tf.math.log_softmax but different because we have
     # an implicit zero column to handle. I.e., instead of:
     #   reduce_sum(logits - log(reduce_sum(exp(logits), dim)))
     # we must do:
     #   log_normalization = log(1 + reduce_sum(exp(logits)))
     #   -log_normalization + reduce_sum(logits - log_normalization)
     n = prefer_static.shape(x)[-1]
     log_normalization = tf.math.softplus(
         tf.reduce_logsumexp(x, axis=-1, keepdims=True))
     return tf.squeeze(
         (-log_normalization +
          tf.reduce_sum(x - log_normalization, axis=-1, keepdims=True)),
         axis=-1) + 0.5 * tf.math.log(tf.cast(n + 1, dtype=x.dtype))
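
The `softplus(reduce_logsumexp(x))` term above is just `log(1 + sum(exp(x)))` for the implicit zero column, computed without materializing the sum of exponentials. A quick check with made-up coordinates:

```python
import tensorflow as tf

x = tf.constant([0.1, -2.0, 3.0])  # made-up unconstrained coordinates

via_softplus = tf.math.softplus(tf.reduce_logsumexp(x))
direct = tf.math.log(1. + tf.reduce_sum(tf.exp(x)))
print(via_softplus.numpy(), direct.numpy())  # should match
```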
 def _log_prob(self, x):
     # By convention, we always put the grid points right-most.
     y = tf.stack([aff.inverse(x) for aff in self.interpolated_affine],
                  axis=-1)
     log_prob = tf.reduce_sum(self.distribution.log_prob(y), axis=-2)
     # Because the affine transformation has a constant Jacobian, it is the case
     # that `affine.fldj(x) = -affine.ildj(x)`. This is not true in general.
     fldj = tf.stack(
         [aff.forward_log_det_jacobian(
             x, event_ndims=tf.rank(self.event_shape_tensor()))
          for aff in self.interpolated_affine],
         axis=-1)
     return tf.reduce_logsumexp(
         self.mixture_distribution.logits - fldj + log_prob, axis=-1)
    def _log_prob(self, y, **kwargs):
        distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)

        # For caching to work, it is imperative that the bijector is the first to
        # modify the input.
        x = self.bijector.inverse(y, **bijector_kwargs)
        event_ndims = self._maybe_get_static_event_ndims()

        ildj = self.bijector.inverse_log_det_jacobian(y,
                                                      event_ndims=event_ndims,
                                                      **bijector_kwargs)
        if self.bijector._is_injective:  # pylint: disable=protected-access
            return self._finish_log_prob_for_one_fiber(y, x, ildj, event_ndims,
                                                       **distribution_kwargs)

        lp_on_fibers = [
            self._finish_log_prob_for_one_fiber(y, x_i, ildj_i, event_ndims,
                                                **distribution_kwargs)
            for x_i, ildj_i in zip(x, ildj)
        ]
        return tf.reduce_logsumexp(tf.stack(lp_on_fibers), axis=0)
Example 11
 def _sample_3d(self, n, seed=None):
   """Specialized inversion sampler for 3D."""
   seed = SeedStream(seed, salt='von_mises_fisher_3d')
   u_shape = tf.concat([[n], self._batch_shape_tensor()], axis=0)
   z = tf.random.uniform(u_shape, seed=seed(), dtype=self.dtype)
   # TODO(bjp): Higher-order odd dim analytic CDFs are available in [1], could
   # be bisected for bounded sampling runtime (i.e. not rejection sampling).
   # [1]: Inversion sampler via: https://ieeexplore.ieee.org/document/7347705/
   # The inversion is: u = 1 + log(z + (1-z)*exp(-2*kappa)) / kappa
   # We must protect against both kappa and z being zero.
   safe_conc = tf.where(self.concentration > 0, self.concentration,
                        tf.ones_like(self.concentration))
   safe_z = tf.where(z > 0, z, tf.ones_like(z))
   safe_u = 1 + tf.reduce_logsumexp(
       [tf.math.log(safe_z),
        tf.math.log1p(-safe_z) - 2 * safe_conc], axis=0) / safe_conc
   # Limit of the above expression as kappa->0 is 2*z-1
   u = tf.where(self.concentration > tf.zeros_like(safe_u), safe_u, 2 * z - 1)
   # Limit of the expression as z->0 is -1.
   u = tf.where(tf.equal(z, 0), -tf.ones_like(u), u)
   if not self._allow_nan_stats:
     u = tf.debugging.check_numerics(u, 'u in _sample_3d')
   return u[..., tf.newaxis]
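
The stacked `reduce_logsumexp` in the sampler evaluates `log(z + (1 - z) * exp(-2 * kappa))` entirely in log space, so `exp(-2 * kappa)` is never formed directly. A small equivalence check with made-up `z` and `kappa`:

```python
import tensorflow as tf

z = tf.constant(0.3)      # made-up uniform draw
kappa = tf.constant(5.0)  # made-up concentration

log_term = tf.reduce_logsumexp(
    [tf.math.log(z), tf.math.log1p(-z) - 2. * kappa], axis=0)
u_stable = 1. + log_term / kappa
u_naive = 1. + tf.math.log(z + (1. - z) * tf.exp(-2. * kappa)) / kappa
print(u_stable.numpy(), u_naive.numpy())  # should agree
```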
def _log_matrix_vector(ms, vs):
    """Multiply tensor of matrices by vectors assuming values stored are logs."""

    return tf.reduce_logsumexp(ms + vs[..., tf.newaxis, :], axis=-1)
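
Exponentiating the helper recovers an ordinary matrix-vector product, which is an easy way to confirm the axis bookkeeping. A sketch with a made-up matrix and vector:

```python
import tensorflow as tf

m = tf.constant([[0.5, 0.2],
                 [0.1, 0.7]])   # made-up (positive) matrix
v = tf.constant([0.3, 0.9])     # made-up (positive) vector

log_result = tf.reduce_logsumexp(
    tf.math.log(m) + tf.math.log(v)[..., tf.newaxis, :], axis=-1)
print(tf.exp(log_result).numpy())      # == m @ v
print(tf.linalg.matvec(m, v).numpy())  # reference: [0.33, 0.66]
```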
Example 13
 def _mean(self, distributions=None):
     if distributions is None:
         distributions = self.poisson_and_mixture_distributions()
     dist, mixture_dist = distributions
     return tf.exp(
         tf.reduce_logsumexp(mixture_dist.logits + dist.log_rate, axis=-1))
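
Assuming the mixture logits are normalized log-probabilities, the expression above is simply `sum_d prob[d] * rate[d]` evaluated in log space. A check with made-up weights and rates:

```python
import tensorflow as tf

log_probs = tf.math.log([0.25, 0.75])  # made-up normalized mixture log-weights
log_rate = tf.math.log([2.0, 6.0])     # made-up Poisson log-rates

mean_logspace = tf.exp(tf.reduce_logsumexp(log_probs + log_rate, axis=-1))
mean_direct = tf.reduce_sum(tf.exp(log_probs) * tf.exp(log_rate))
print(mean_logspace.numpy(), mean_direct.numpy())  # 5.0, 5.0
```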
Example 14
 def _log_prob(self, x):
     dist, mixture_dist = self.poisson_and_mixture_distributions()
     return tf.reduce_logsumexp(
         (mixture_dist.logits + dist.log_prob(x[..., tf.newaxis])), axis=-1)
Example 15
 def _log_cdf(self, x):
     x = self._pad_sample_dims(x)
     log_cdf_x = self.components_distribution.log_cdf(x)  # [S, B, k]
     log_mix_prob = tf.math.log_softmax(
         self.mixture_distribution.logits_parameter(), axis=-1)  # [B, k]
     return tf.reduce_logsumexp(log_cdf_x + log_mix_prob, axis=-1)  # [S, B]
    def posterior_marginals(self, observations, mask=None, name=None):
        """Compute marginal posterior distribution for each state.

    This function computes, for each time step, the marginal
    conditional probability that the hidden Markov model was in
    each possible state given the observations that were made
    at each time step.
    So if the hidden states are `z[0],...,z[num_steps - 1]` and
    the observations are `x[0], ..., x[num_steps - 1]`, then
    this function computes `P(z[i] | x[0], ..., x[num_steps - 1])`
    for all `i` from `0` to `num_steps - 1`.

    This operation is sometimes called smoothing. It uses a form
    of the forward-backward algorithm.

    Note: the behavior of this function is undefined if the
    `observations` argument represents impossible observations
    from the model.

    Args:
      observations: A tensor representing a batch of observations
        made on the hidden Markov model.  The rightmost dimension of this tensor
        gives the steps in a sequence of observations from a single sample from
        the hidden Markov model. The size of this dimension should match the
        `num_steps` parameter of the hidden Markov model object. The other
        dimensions are the dimensions of the batch and these are broadcast with
        the hidden Markov model's parameters.
      mask: optional bool-type `tensor` with rightmost dimension matching
        `num_steps` indicating which observations the result of this
        function should be conditioned on. When the mask has value
        `True` the corresponding observations aren't used.
        If `mask` is `None` then all of the observations are used.
        The `mask` dimensions left of the last are broadcast with the
        HMM batch as well as with the observations.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: "posterior_marginals".

    Returns:
      posterior_marginal: A `Categorical` distribution object representing the
        marginal probability of the hidden Markov model being in each state at
        each step. The rightmost dimension of the `Categorical` distributions
        batch will equal the `num_steps` parameter providing one marginal
        distribution for each step. The other dimensions are the dimensions
        corresponding to the batch of observations.

    Raises:
      ValueError: if rightmost dimension of `observations` does not
        have size `num_steps`.
    """

        with tf.name_scope(name or "posterior_marginals"):
            with tf.control_dependencies(self._runtime_assertions):
                observation_tensor_shape = tf.shape(observations)
                mask_tensor_shape = tf.shape(
                    mask) if mask is not None else None

                with self._observation_mask_shape_preconditions(
                        observation_tensor_shape, mask_tensor_shape):
                    observation_log_probs = self._observation_log_probs(
                        observations, mask)
                    log_prob = self._log_init + observation_log_probs[0]
                    log_transition = self._log_trans
                    log_adjoint_prob = tf.zeros_like(log_prob)

                    def _scan_multiple_steps_forwards():
                        def forward_step(log_previous_step,
                                         log_prob_observation):
                            return _log_vector_matrix(
                                log_previous_step,
                                log_transition) + log_prob_observation

                        forward_log_probs = tf.scan(forward_step,
                                                    observation_log_probs[1:],
                                                    initializer=log_prob,
                                                    name="forward_log_probs")
                        return tf.concat([[log_prob], forward_log_probs],
                                         axis=0)

                    forward_log_probs = prefer_static.cond(
                        self._num_steps > 1, _scan_multiple_steps_forwards,
                        lambda: tf.convert_to_tensor([log_prob]))

                    total_log_prob = tf.reduce_logsumexp(forward_log_probs[-1],
                                                         axis=-1)

                    def _scan_multiple_steps_backwards():
                        """Perform `scan` operation when `num_steps` > 1."""
                        def backward_step(log_previous_step,
                                          log_prob_observation):
                            return _log_matrix_vector(
                                log_transition,
                                log_prob_observation + log_previous_step)

                        backward_log_adjoint_probs = tf.scan(
                            backward_step,
                            observation_log_probs[1:],
                            initializer=log_adjoint_prob,
                            reverse=True,
                            name="backward_log_adjoint_probs")

                        return tf.concat(
                            [backward_log_adjoint_probs, [log_adjoint_prob]],
                            axis=0)

                    backward_log_adjoint_probs = prefer_static.cond(
                        self._num_steps > 1, _scan_multiple_steps_backwards,
                        lambda: tf.convert_to_tensor([log_adjoint_prob]))

                    log_likelihoods = forward_log_probs + backward_log_adjoint_probs

                    marginal_log_probs = distribution_util.move_dimension(
                        log_likelihoods - total_log_prob[..., tf.newaxis], 0,
                        -2)

                    return categorical.Categorical(logits=marginal_log_probs)
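
A self-contained forward-backward sketch on a tiny, made-up 2-state, 3-step chain, reusing the two log-space helpers and checking the `alpha + beta - total` marginals against brute-force enumeration of all hidden paths. As above, `emit_lp[t, k]` stands in for `log p(x[t] | z[t] = k)` for one fixed observation sequence:

```python
import itertools

import numpy as np
import tensorflow as tf

log_init = tf.math.log([0.6, 0.4])       # made-up initial distribution
log_trans = tf.math.log([[0.7, 0.3],
                         [0.2, 0.8]])    # made-up transition matrix
emit_lp = tf.math.log([[0.9, 0.1],
                       [0.4, 0.6],
                       [0.5, 0.5]])      # stand-in per-step emission log-probs

def _log_vector_matrix(vs, ms):
  return tf.reduce_logsumexp(vs[..., tf.newaxis] + ms, axis=-2)

def _log_matrix_vector(ms, vs):
  return tf.reduce_logsumexp(ms + vs[..., tf.newaxis, :], axis=-1)

# Forward pass (alpha) and backward pass (beta), as in posterior_marginals.
alpha = [log_init + emit_lp[0]]
for t in range(1, 3):
  alpha.append(_log_vector_matrix(alpha[-1], log_trans) + emit_lp[t])
beta = [tf.zeros([2])]
for t in range(1, -1, -1):
  beta.insert(0, _log_matrix_vector(log_trans, emit_lp[t + 1] + beta[0]))

alpha, beta = tf.stack(alpha), tf.stack(beta)    # each of shape [3, 2]
total = tf.reduce_logsumexp(alpha[-1], axis=-1)  # log p(x)
marginals = tf.exp(alpha + beta - total)         # P(z[t] = k | x)

# Brute force over all 2**3 hidden paths for comparison.
brute = np.zeros((3, 2))
for path in itertools.product(range(2), repeat=3):
  lp = log_init[path[0]] + emit_lp[0, path[0]]
  for t in range(1, 3):
    lp += log_trans[path[t - 1], path[t]] + emit_lp[t, path[t]]
  for t, k in enumerate(path):
    brute[t, k] += np.exp(lp.numpy())
brute /= brute.sum(axis=-1, keepdims=True)
print(marginals.numpy())
print(brute)  # the two tables should agree
```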
def _log_vector_matrix(vs, ms):
    """Multiply tensor of vectors by matrices assuming values stored are logs."""

    return tf.reduce_logsumexp(vs[..., tf.newaxis] + ms, axis=-2)
Example 18
def reduce_weighted_logsumexp(logx,
                              w=None,
                              axis=None,
                              keep_dims=False,
                              return_sign=False,
                              name=None):
    """Computes `log(abs(sum(weight * exp(elements across tensor dimensions))))`.

  If all weights `w` are known to be positive, it is more efficient to directly
  use `reduce_logsumexp`, i.e., `tf.reduce_logsumexp(logx + tf.math.log(w))` is more
  efficient than `du.reduce_weighted_logsumexp(logx, w)`.

  Reduces `input_tensor` along the dimensions given in `axis`.
  Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
  entry in `axis`. If `keep_dims` is true, the reduced dimensions
  are retained with length 1.

  If `axis` has no entries, all dimensions are reduced, and a
  tensor with a single element is returned.

  This function is more numerically stable than log(sum(w * exp(input))). It
  avoids overflows caused by taking the exp of large inputs and underflows
  caused by taking the log of small inputs.

  For example:

  ```python
  x = tf.constant([[0., 0, 0],
                   [0, 0, 0]])

  w = tf.constant([[-1., 1, 1],
                   [1, 1, 1]])

  du.reduce_weighted_logsumexp(x, w)
  # ==> log(-1*1 + 1*1 + 1*1 + 1*1 + 1*1 + 1*1) = log(4)

  du.reduce_weighted_logsumexp(x, w, axis=0)
  # ==> [log(-1+1), log(1+1), log(1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1)
  # ==> [log(-1+1+1), log(1+1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1, keep_dims=True)
  # ==> [[log(-1+1+1)], [log(1+1+1)]]

  du.reduce_weighted_logsumexp(x, w, axis=[0, 1])
  # ==> log(-1+5)
  ```

  Args:
    logx: The tensor to reduce. Should have numeric type.
    w: The weight tensor. Should have numeric type identical to `logx`.
    axis: The dimensions to reduce. If `None` (the default), reduces all
      dimensions. Must be in the range `[-rank(input_tensor),
      rank(input_tensor))`.
    keep_dims: If true, retains reduced dimensions with length 1.
    return_sign: If `True`, returns the sign of the result.
    name: A name for the operation (optional).

  Returns:
    lswe: The `log(abs(sum(weight * exp(x))))` reduced tensor.
    sign: (Optional) The sign of `sum(weight * exp(x))`.
  """
    with tf.name_scope(name or 'reduce_weighted_logsumexp'):
        logx = tf.convert_to_tensor(logx, name='logx')
        if w is None:
            lswe = tf.reduce_logsumexp(logx, axis=axis, keepdims=keep_dims)
            if return_sign:
                sgn = tf.ones_like(lswe)
                return lswe, sgn
            return lswe
        w = tf.convert_to_tensor(w, dtype=logx.dtype, name='w')
        log_absw_x = logx + tf.math.log(tf.abs(w))
        max_log_absw_x = tf.reduce_max(log_absw_x, axis=axis, keepdims=True)
        # If the largest element is `-inf` or `inf` then we don't bother subtracting
        # off the max. We do this because otherwise we'd get `inf - inf = NaN`. That
        # this is ok follows from the fact that we're actually free to subtract any
        # value we like, so long as we add it back after taking the `log(sum(...))`.
        max_log_absw_x = tf.where(tf.math.is_inf(max_log_absw_x),
                                  tf.zeros([], max_log_absw_x.dtype),
                                  max_log_absw_x)
        wx_over_max_absw_x = (tf.sign(w) * tf.exp(log_absw_x - max_log_absw_x))
        sum_wx_over_max_absw_x = tf.reduce_sum(wx_over_max_absw_x,
                                               axis=axis,
                                               keepdims=keep_dims)
        if not keep_dims:
            max_log_absw_x = tf.squeeze(max_log_absw_x, axis)
        sgn = tf.sign(sum_wx_over_max_absw_x)
        lswe = max_log_absw_x + tf.math.log(sgn * sum_wx_over_max_absw_x)
        if return_sign:
            return lswe, sgn
        return lswe
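
One common use of the signed variant is computing `log(exp(a) - exp(b))` for `a > b` without leaving log space. A usage sketch with made-up values, calling the function defined above:

```python
import tensorflow as tf

a = tf.constant(10.0)  # made-up value
b = tf.constant(9.5)   # made-up value, smaller than a

lswe, sign = reduce_weighted_logsumexp(
    tf.stack([a, b]), w=tf.constant([1.0, -1.0]), return_sign=True)
direct = tf.math.log(tf.exp(a) - tf.exp(b))
print(lswe.numpy(), sign.numpy(), direct.numpy())  # lswe ~= direct, sign = 1
```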