def _reduce_log_probs_over_dists(self, lps):
  """Combines per-component log-probabilities into one joint log-probability.

  `lps` is passed to `self._reduce_measure_over_dists` with a reducer chosen
  by `self._experimental_use_kahan_sum`:
    * `tfp_math.reduce_kahan_sum` when the flag is set;
    * `tf.reduce_sum` otherwise.
  The reduced pieces then go through `jd_lib.maybe_check_wont_broadcast` and
  are added together with the builtin `sum`.

  Args:
    lps: Per-component log-probabilities. NOTE(review): presumably one tensor
      per component distribution, in model order -- confirm against callers.

  Returns:
    The summed log-probability. On the Kahan path only the `.total` field of
    the accumulated result is returned.
  """
  use_kahan = self._experimental_use_kahan_sum
  if use_kahan:
    reducer = tfp_math.reduce_kahan_sum
  else:
    reducer = tf.reduce_sum
  reduced = self._reduce_measure_over_dists(lps, reduce_fn=reducer)
  # NOTE(review): judging by its name, this helper presumably guards against
  # the pieces silently broadcasting together when validate_args is on --
  # confirm in the joint-distribution lib.
  checked = jd_lib.maybe_check_wont_broadcast(reduced, self.validate_args)
  accumulated = sum(checked)
  # The builtin `sum` starts from 0, so on the Kahan path the
  # reduce_kahan_sum results must support `0 + x` and `x + y`. Only their
  # `.total` field is handed back to the caller.
  return accumulated.total if use_kahan else accumulated
def _entropy(self):
  """Shannon entropy in nats.

  Only supported when no component depends on another. In that case the
  joint entropy is the sum of the component entropies.

  Returns:
    Sum of `d().entropy()` over every wrapped component factory `d` in
    `self._dist_fn_wrapped`.

  Raises:
    ValueError: if any entry of `self._dist_fn_args` is truthy.
  """
  # A truthy entry presumably marks a component that takes other components'
  # values as arguments. The error message treats that as "not independent".
  has_dependent_component = any(self._dist_fn_args)
  if has_dependent_component:
    raise ValueError(
        'Can only compute entropy when all distributions are independent.')
  # Each wrapped entry is called with no arguments to build its distribution.
  component_entropies = (
      make_dist().entropy() for make_dist in self._dist_fn_wrapped)
  return sum(joint_distribution_lib.maybe_check_wont_broadcast(
      component_entropies, self.validate_args))
def _entropy(self):
  """Shannon entropy in nats, summed over independent components.

  Returns:
    Sum of `dist.entropy()` for each element obtained by iterating
    `self.distribution_fn`.

  Raises:
    ValueError: unless every flag in `self._is_distribution_instance` is
      truthy. Presumably this means each component is a concrete distribution
      rather than a callable depending on other components -- confirm.
  """
  if any(not is_instance for is_instance in self._is_distribution_instance):
    raise ValueError(
        'Can only compute entropy when all distributions are independent.'
    )
  # NOTE(review): iterating `distribution_fn` directly assumes it is a
  # sequence of distributions (a mapping would yield its keys) -- confirm.
  component_entropies = (dist.entropy() for dist in self.distribution_fn)
  return sum(
      joint_distribution_lib.maybe_check_wont_broadcast(
          component_entropies, self.validate_args))
def _cross_entropy(self, other):
  """Sum of component-wise cross entropies between `self` and `other`.

  Components are paired by position. Each pair contributes
  `self_component.cross_entropy(other_component)`, where each component is
  built by calling its wrapped factory with no arguments.

  Args:
    other: Another `JointDistributionSequential` whose `model` has the same
      length as this one's.

  Returns:
    The sum of the per-component cross entropies.

  Raises:
    ValueError: if `other` is not a `JointDistributionSequential`, if the two
      `model`s differ in length, or if either distribution has a truthy
      entry in `_dist_fn_args` (i.e. a non-independent component).
  """
  # `and` short-circuits, so `other.model` is only read once `other` is known
  # to be a JointDistributionSequential.
  comparable = (isinstance(other, JointDistributionSequential) and
                len(self.model) == len(other.model))
  if not comparable:
    raise ValueError(
        'Can only compute cross entropy between `JointDistributionSequential`s '
        'with the same number of component distributions.')
  dependent = (any(self._dist_fn_args) or
               any(other._dist_fn_args))  # pylint: disable=protected-access
  if dependent:
    raise ValueError(
        'Can only compute cross entropy when all component distributions '
        'are independent.')
  pairwise = (
      mine().cross_entropy(theirs())
      for mine, theirs in zip(
          self._dist_fn_wrapped,
          other._dist_fn_wrapped))  # pylint: disable=protected-access
  return sum(joint_distribution_lib.maybe_check_wont_broadcast(
      pairwise, self.validate_args))