def _log_abs_determinant(self):
  log_det = tfp_math.reduce_kahan_sum(
      tf.math.log(tf.math.abs(self._diag)), axis=[-1]).total
  if dtype_util.is_complex(self.dtype):
    log_det = tf.cast(log_det, dtype=self.dtype)
  return log_det
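
Every snippet on this page funnels its reduction through `tfp_math.reduce_kahan_sum(...).total`. As a minimal sketch of why compensated summation matters (plain NumPy, not the TFP implementation; the helper name `kahan_sum` is ours):

import numpy as np

def kahan_sum(values):
  # Compensated (Kahan) summation: carry the rounding error of each add in
  # `compensation` and feed it back into the next term.
  total = np.float32(0.)
  compensation = np.float32(0.)
  for v in values:
    y = v - compensation
    t = total + y
    compensation = (t - total) - y  # Negated rounding error of `total + y`.
    total = t
  return total

vals = np.concatenate([[1.], np.full(100_000, 1e-8)]).astype(np.float32)
print(sum(vals, np.float32(0.)))  # Naive float32 sum: 1.0; the tiny terms vanish.
print(kahan_sum(vals))            # Compensated: ~1.001, the true sum.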
def inverse_log_det_jacobian_ratio(p, x, q, y, event_ndims,
                                   use_kahan_sum=True):
  """Computes `p.ildj(x, ndims) - q.ildj(y, ndims)`, numerically stably.

  Args:
    p: A bijector instance.
    x: A tensor from the image of `p.forward`.
    q: A bijector instance of the same type as `p`, with matching shape.
    y: A tensor from the image of `q.forward`.
    event_ndims: The number of right-hand dimensions comprising the event
      shapes of `x` and `y`.
    use_kahan_sum: When `True`, the reduction of any remaining `event_ndims`
      beyond the minimum is done using Kahan summation. This requires
      statically known ranks.

  Returns:
    ildj_ratio: `log ((abs o det o jac p^-1)(x) / (abs o det o jac q^-1)(y))`,
      i.e. in TFP code, `p.inverse_log_det_jacobian(x, event_ndims) -
      q.inverse_log_det_jacobian(y, event_ndims)`. In some cases this will be
      computed with better than naive numerical precision, e.g. by moving
      differences inside of a sum reduction.
  """
  assert type(p) == type(q)  # pylint: disable=unidiomatic-typecheck

  min_event_ndims = p.inverse_min_event_ndims

  def default_ildj_ratio_fn(p, x, q, y):
    return (p.inverse_log_det_jacobian(x, event_ndims=min_event_ndims) -
            q.inverse_log_det_jacobian(y, event_ndims=min_event_ndims))

  ildj_ratio_fn = None
  fldj_ratio_fn = None
  for cls in inspect.getmro(type(p)):
    if cls in _ildj_ratio_registry:
      ildj_ratio_fn = _ildj_ratio_registry[cls]
    if cls in _fldj_ratio_registry:
      fldj_ratio_fn = _fldj_ratio_registry[cls]

  if ildj_ratio_fn is None:
    if fldj_ratio_fn is None:
      ildj_ratio_fn = default_ildj_ratio_fn
    else:
      # p.ildj(x) - q.ildj(y) = q.fldj(q^-1(y)) - p.fldj(p^-1(x))
      ildj_ratio_fn = (
          lambda p, x, q, y: fldj_ratio_fn(q, q.inverse(y), p, p.inverse(x)))

  if use_kahan_sum:
    sum_fn = lambda x, axis: tfp_math.reduce_kahan_sum(x, axis=axis).total
  else:
    sum_fn = tf.reduce_sum
  return sum_fn(ildj_ratio_fn(p, x, q, y),
                axis=-1 - ps.range(event_ndims - min_event_ndims))
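
As a usage sketch (standard TFP bijector API; only the helper above is assumed to be in scope), the value below is what `inverse_log_det_jacobian_ratio(p, x, q, y, event_ndims=1)` computes, except that the helper can move the subtraction inside the sum reduction:

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors
p, q = tfb.Scale(2.), tfb.Scale(3.)
x = tf.ones([5])
y = tf.ones([5])
# Naive reference: two separate reductions, then a subtraction.
naive = (p.inverse_log_det_jacobian(x, event_ndims=1)
         - q.inverse_log_det_jacobian(y, event_ndims=1))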
Example #3
def _independent_log_prob_ratio(p, x, q, y, name=None):
  """Sum-of-diffs log(p(x)/q(y)) for `Independent`s."""
  with tf.name_scope(name or 'independent_log_prob_ratio'):
    checks = []
    if p.validate_args or q.validate_args:
      checks.append(tf.debugging.assert_equal(
          p.reinterpreted_batch_ndims, q.reinterpreted_batch_ndims))
    if p._experimental_use_kahan_sum or q._experimental_use_kahan_sum:  # pylint: disable=protected-access
      sum_fn = lambda x, axis: tfp_math.reduce_kahan_sum(x, axis).total
    else:
      sum_fn = tf.reduce_sum
    with tf.control_dependencies(checks):
      return sum_fn(
          log_prob_ratio.log_prob_ratio(p.distribution, x, q.distribution, y),
          axis=-1 - ps.range(p.reinterpreted_batch_ndims))
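
A minimal sketch of the quantity this registration computes (standard `tfd.Independent` and `tfd.Normal` API; the function above is dispatched through `log_prob_ratio`, whose public import path may vary by TFP version):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
p = tfd.Independent(tfd.Normal(loc=tf.zeros([4]), scale=1.),
                    reinterpreted_batch_ndims=1)
q = tfd.Independent(tfd.Normal(loc=tf.ones([4]), scale=1.),
                    reinterpreted_batch_ndims=1)
x = tf.zeros([4])
y = tf.zeros([4])
# Naive reference: difference of two fully reduced log-probs. The function
# above instead sums per-component differences, which is stabler when the
# two log-probs are large and nearly cancel.
naive = p.log_prob(x) - q.log_prob(y)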
Example #4
def _markov_chain_log_prob_ratio(p, x, q, y, name=None):
  """Implements `log_prob_ratio` for tfd.MarkovChain."""
  with tf.name_scope(name or 'markov_chain_log_prob_ratio'):
    # TODO(davmre): In the case where `p` and `q` have components of the same
    # families (in addition to just both being MarkovChains), we might prefer
    # to recursively call `log_prob_ratio` instead of just subtracting log
    # probs.
    p_prior_lp, p_transition_lps = p._log_prob_parts(x)
    q_prior_lp, q_transition_lps = q._log_prob_parts(y)
    prior_lp_ratio = p_prior_lp - q_prior_lp
    transition_lp_ratios = p_transition_lps - q_transition_lps
    if p._experimental_use_kahan_sum or q._experimental_use_kahan_sum:
      transition_lp_ratio = tfp_math.reduce_kahan_sum(
          transition_lp_ratios, axis=0).total
    else:
      transition_lp_ratio = tf.reduce_sum(transition_lp_ratios, axis=0)
    return prior_lp_ratio + transition_lp_ratio
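
For orientation, a sketch of two chains this could be called on (`tfd.MarkovChain` is assumed available, as in recent TFP releases); the function above differences the prior terms and each transition term before summing over steps:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

def random_walk(loc):
  # A 10-step Gaussian random walk.
  return tfd.MarkovChain(
      initial_state_prior=tfd.Normal(loc=loc, scale=1.),
      transition_fn=lambda _, state: tfd.Normal(loc=state, scale=1.),
      num_steps=10)

p, q = random_walk(0.), random_walk(1.)
x, y = p.sample(), q.sample()
# Naive reference for the ratio computed above.
naive = p.log_prob(x) - q.log_prob(y)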
Example #5
  def unnormalized_log_prob_parts(self, value, name=None):
    """Unnormalized log probability density/mass function, part-wise.

    Args:
      value: `list` of `Tensor`s in `distribution_fn` order for which we compute
        the `unnormalized_log_prob_parts` and to parameterize other
        ("downstream") distributions.
      name: name prepended to ops created by this function.
        Default value: `"unnormalized_log_prob_parts"`.

    Returns:
      log_prob_parts: a `tuple` of `Tensor`s representing the `log_prob` for
        each `distribution_fn` evaluated at each corresponding `value`.
    """
    with self._name_and_control_scope(name or 'unnormalized_log_prob_parts'):
      sum_fn = tf.reduce_sum
      if self._experimental_use_kahan_sum:
        sum_fn = lambda x, axis: tfp_math.reduce_kahan_sum(x, axis=axis).total
      return self._model_unflatten(
          self._reduce_measure_over_dists(
              self._map_measure_over_dists('unnormalized_log_prob', value),
              sum_fn))
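
A sketch of a joint model this method applies to. The `experimental_use_kahan_sum` constructor flag is taken from the snippet above and is supported on the autobatched `JointDistribution`s in recent TFP releases; exact availability varies by version:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
# With the flag set, the per-part reductions above use compensated summation.
joint = tfd.JointDistributionSequentialAutoBatched(
    [tfd.Normal(loc=0., scale=1.),                        # z
     lambda z: tfd.Normal(loc=z, scale=tf.ones([1000]))   # x | z
    ],
    experimental_use_kahan_sum=True)
z, x = joint.sample()
parts = joint.unnormalized_log_prob_parts([z, x])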
Example #6
def _sum_fn(self):
  if self._experimental_use_kahan_sum:
    return lambda x, axis: tfp_math.reduce_kahan_sum(x, axis).total
  return tf.math.reduce_sum
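
For reference, the contract this relies on: `tfp.math.reduce_kahan_sum` returns a namedtuple whose first field is the compensated sum, so the lambda adapts it to the same `(x, axis)` interface as `tf.reduce_sum`:

import tensorflow as tf
import tensorflow_probability as tfp

kahan = tfp.math.reduce_kahan_sum(tf.fill([4], 0.1), axis=-1)
total = kahan.total            # Compensated sum, same dtype as the input.
correction = kahan.correction  # Running error term tracked alongside it.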
Example #7
def _log_abs_determinant(self):
  return tfp_math.reduce_kahan_sum(
      tf.math.log(tf.math.abs(self._get_diag())), axis=[-1]).total