def _log_abs_determinant(self):
  log_det = tfp_math.reduce_kahan_sum(
      tf.math.log(tf.math.abs(self._diag)), axis=[-1]).total
  if dtype_util.is_complex(self.dtype):
    log_det = tf.cast(log_det, dtype=self.dtype)
  return log_det
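# Illustration (not from the TFP source): why the Kahan-summed reduction above
# matters. With many near-unit diagonal entries, each log|d_i| term is tiny,
# and a naive float32 sum loses precision that `reduce_kahan_sum` retains.
# A minimal sketch using only public TFP APIs; the values are made up.
import tensorflow as tf
import tensorflow_probability as tfp

diag = tf.fill([100_000], tf.constant(1.0001, tf.float32))
log_terms = tf.math.log(tf.math.abs(diag))

naive = tf.reduce_sum(log_terms, axis=-1)
kahan = tfp.math.reduce_kahan_sum(log_terms, axis=[-1]).total

# Float64 reference sum of the same float32 terms.
reference = tf.reduce_sum(tf.cast(log_terms, tf.float64), axis=-1)
print(abs(float(naive) - float(reference)))  # typically the larger error
print(abs(float(kahan) - float(reference)))  # typically the smaller error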
def inverse_log_det_jacobian_ratio(p, x, q, y, event_ndims,
                                   use_kahan_sum=True):
  """Computes `p.ildj(x, ndims) - q.ildj(y, ndims)`, numerically stably.

  Args:
    p: A bijector instance.
    x: A tensor from the image of `p.forward`.
    q: A bijector instance of the same type as `p`, with matching shape.
    y: A tensor from the image of `q.forward`.
    event_ndims: The number of right-hand dimensions comprising the event
      shapes of `x` and `y`.
    use_kahan_sum: When `True`, the reduction of any remaining `event_ndims`
      beyond the minimum is done using Kahan summation. This requires
      statically known ranks.

  Returns:
    ildj_ratio: `log ((abs o det o jac p^-1)(x) / (abs o det o jac q^-1)(y))`,
      i.e. in TFP code, `p.inverse_log_det_jacobian(x, event_ndims) -
      q.inverse_log_det_jacobian(y, event_ndims)`. In some cases this will be
      computed with better than naive numerical precision, e.g. by moving
      differences inside of a sum reduction.
  """
  assert type(p) == type(q)  # pylint: disable=unidiomatic-typecheck
  min_event_ndims = p.inverse_min_event_ndims

  def default_ildj_ratio_fn(p, x, q, y):
    return (p.inverse_log_det_jacobian(x, event_ndims=min_event_ndims) -
            q.inverse_log_det_jacobian(y, event_ndims=min_event_ndims))

  ildj_ratio_fn = None
  fldj_ratio_fn = None
  for cls in inspect.getmro(type(p)):
    if cls in _ildj_ratio_registry:
      ildj_ratio_fn = _ildj_ratio_registry[cls]
    if cls in _fldj_ratio_registry:
      fldj_ratio_fn = _fldj_ratio_registry[cls]

  if ildj_ratio_fn is None:
    if fldj_ratio_fn is None:
      ildj_ratio_fn = default_ildj_ratio_fn
    else:
      # p.ildj(x) - q.ildj(y) = q.fldj(q^-1(y)) - p.fldj(p^-1(x))
      ildj_ratio_fn = (lambda p, x, q, y:
                       fldj_ratio_fn(q, q.inverse(y), p, p.inverse(x)))

  if use_kahan_sum:
    sum_fn = lambda x, axis: tfp_math.reduce_kahan_sum(x, axis=axis).total
  else:
    sum_fn = tf.reduce_sum
  return sum_fn(ildj_ratio_fn(p, x, q, y),
                axis=-1 - ps.range(event_ndims - min_event_ndims))
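# Usage sketch (illustrative; assumes `inverse_log_det_jacobian_ratio` from
# this module is in scope). For two `Scale` bijectors whose parameters differ
# only slightly, the ratio path subtracts per-element ildj terms before the
# single (Kahan) reduction, instead of subtracting two large reduced sums.
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors
p = tfb.Scale(scale=tf.constant([1.0 + 1e-7, 2.0]))
q = tfb.Scale(scale=tf.constant([1.0, 2.0]))
x = tf.zeros([2])
y = tf.zeros([2])

naive = (p.inverse_log_det_jacobian(x, event_ndims=1) -
         q.inverse_log_det_jacobian(y, event_ndims=1))
stable = inverse_log_det_jacobian_ratio(p, x, q, y, event_ndims=1)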
def _independent_log_prob_ratio(p, x, q, y, name=None):
  """Sum-of-diffs log(p(x)/q(y)) for `Independent`s."""
  with tf.name_scope(name or 'independent_log_prob_ratio'):
    checks = []
    if p.validate_args or q.validate_args:
      checks.append(tf.debugging.assert_equal(
          p.reinterpreted_batch_ndims, q.reinterpreted_batch_ndims))
    if p._experimental_use_kahan_sum or q._experimental_use_kahan_sum:  # pylint: disable=protected-access
      sum_fn = lambda x, axis: tfp_math.reduce_kahan_sum(x, axis).total
    else:
      sum_fn = tf.reduce_sum
    with tf.control_dependencies(checks):
      return sum_fn(
          log_prob_ratio.log_prob_ratio(p.distribution, x, q.distribution, y),
          axis=-1 - ps.range(p.reinterpreted_batch_ndims))
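# Usage sketch (illustrative values; assumes `Independent` accepts the
# `experimental_use_kahan_sum` constructor flag and that `log_prob_ratio` is
# exposed under `tfp.experimental.distributions`, as in recent TFP versions).
# The dispatch lands on the function above, which differences per-component
# log-probs before the event reduction.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
p = tfd.Independent(tfd.Normal(loc=tf.zeros([10_000]), scale=1.),
                    reinterpreted_batch_ndims=1,
                    experimental_use_kahan_sum=True)
q = tfd.Independent(tfd.Normal(loc=tf.fill([10_000], 1e-6), scale=1.),
                    reinterpreted_batch_ndims=1,
                    experimental_use_kahan_sum=True)
x = tf.zeros([10_000])

ratio = tfp.experimental.distributions.log_prob_ratio(p, x, q, x)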
def _markov_chain_log_prob_ratio(p, x, q, y, name=None):
  """Implements `log_prob_ratio` for tfd.MarkovChain."""
  with tf.name_scope(name or 'markov_chain_log_prob_ratio'):
    # TODO(davmre): In the case where `p` and `q` have components of the same
    # families (in addition to just both being MarkovChains), we might prefer
    # to recursively call `log_prob_ratio` instead of just subtracting log
    # probs.
    p_prior_lp, p_transition_lps = p._log_prob_parts(x)
    q_prior_lp, q_transition_lps = q._log_prob_parts(y)
    prior_lp_ratio = p_prior_lp - q_prior_lp
    transition_lp_ratios = p_transition_lps - q_transition_lps
    if (p._experimental_use_kahan_sum or q._experimental_use_kahan_sum):
      transition_lp_ratio = tfp_math.reduce_kahan_sum(
          transition_lp_ratios, axis=0).total
    else:
      transition_lp_ratio = tf.reduce_sum(transition_lp_ratios, axis=0)
    return prior_lp_ratio + transition_lp_ratio
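# Illustration (standalone, not from the TFP source): why differencing before
# the reduction helps. When `p` and `q` are close, each per-step difference is
# small and nearly exact in float32, while two separately rounded sums of
# magnitude ~1e4 can no longer resolve a ~1e-3 difference between them.
import tensorflow as tf

p_lps = tf.fill([10_000], tf.constant(-1.2345678, tf.float32))
q_lps = p_lps + 1e-7  # per-step log-probs differ by a tiny amount

sum_then_subtract = tf.reduce_sum(p_lps) - tf.reduce_sum(q_lps)
subtract_then_sum = tf.reduce_sum(p_lps - q_lps)
# `subtract_then_sum` lands near -1e-3; `sum_then_subtract` is quantized to
# multiples of the ~1e-3 float32 spacing at magnitude ~1.2e4.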
def unnormalized_log_prob_parts(self, value, name=None):
  """Unnormalized log probability density/mass function, part-wise.

  Args:
    value: `list` of `Tensor`s in `distribution_fn` order for which we
      compute the `unnormalized_log_prob_parts` and to parameterize other
      ("downstream") distributions.
    name: name prepended to ops created by this function.
      Default value: `"unnormalized_log_prob_parts"`.

  Returns:
    log_prob_parts: a `tuple` of `Tensor`s representing the `log_prob` for
      each `distribution_fn` evaluated at each corresponding `value`.
  """
  with self._name_and_control_scope(name or 'unnormalized_log_prob_parts'):
    sum_fn = tf.reduce_sum
    if self._experimental_use_kahan_sum:
      sum_fn = lambda x, axis: tfp_math.reduce_kahan_sum(x, axis=axis).total
    return self._model_unflatten(
        self._reduce_measure_over_dists(
            self._map_measure_over_dists('unnormalized_log_prob', value),
            sum_fn))
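# Usage sketch (model and names are illustrative; assumes the joint
# distribution constructors accept `experimental_use_kahan_sum` and that
# `unnormalized_log_prob_parts` is callable on the resulting instance). Each
# part's event reduction is then routed through the Kahan-summed `sum_fn`
# above.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
joint = tfd.JointDistributionNamed(
    dict(scale=tfd.HalfNormal(1.),
         obs=lambda scale: tfd.Sample(tfd.Normal(0., scale), 10_000)),
    experimental_use_kahan_sum=True)

sample = joint.sample(seed=42)
parts = joint.unnormalized_log_prob_parts(sample)  # one Tensor per part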
def _sum_fn(self):
  if self._experimental_use_kahan_sum:
    return lambda x, axis: tfp_math.reduce_kahan_sum(x, axis).total
  return tf.math.reduce_sum
def _log_abs_determinant(self):
  return tfp_math.reduce_kahan_sum(
      tf.math.log(tf.math.abs(self._get_diag())), axis=[-1]).total