def _reduce_log_l2_exp(loga, logb, axis=-1):
    """Log of the squared L2 distance between `exp(loga)` and `exp(logb)`.

    Works entirely in log space. `loga` and `logb` are stacked on a new
    trailing axis, and a weighted logsumexp with weights `[1, -1]` yields
    `log|exp(loga) - exp(logb)|` elementwise. Doubling that gives
    `log((exp(loga) - exp(logb)) ** 2)`. A final logsumexp over `axis` sums
    the squared differences.

    Args:
      loga: Log of the first operand. NOTE(review): presumably broadcastable
        with `logb`; `tf.stack` requires matching shapes — confirm at callers.
      logb: Log of the second operand.
      axis: Axis (or axes) passed to the outer `tf.math.reduce_logsumexp`.

    Returns:
      `log(sum((exp(loga) - exp(logb)) ** 2, axis=axis))`. Despite the name,
      this is the log of the *squared* L2 distance.
    """
    # With return_sign left at its default, reduce_weighted_logsumexp returns
    # log|sum(w * exp(x))|, so the sign of the difference is dropped. That is
    # safe because the value is squared next.
    log_abs_diff = tfp_math.reduce_weighted_logsumexp(
        tf.stack([loga, logb], axis=-1), w=[1, -1], axis=-1)
    log_sq_diff = 2 * log_abs_diff
    return tf.math.reduce_logsumexp(log_sq_diff, axis=axis)
def _variance(self):
    """Variance as `E[X**2] - E[X]**2`, evaluated via log-space moments.

    Both moments come from `self._log_moment`, which is presumably
    `log E[X**k]` for the given order `k` (NOTE(review): confirm against its
    definition). Subtracting them is done with a signed weighted logsumexp
    so the difference never leaves log space until the final `exp`.

    Returns:
      `sign * exp(log|E[X**2] - E[X]**2|)`.
    """
    # Convert each parameter once and share the tensors between both moment
    # evaluations. The dict literal preserves the original conversion order
    # (concentration1, then concentration0).
    moment_kwargs = dict(
        concentration1=tf.convert_to_tensor(self.concentration1),
        concentration0=tf.convert_to_tensor(self.concentration0),
    )
    log_second_moment = self._log_moment(2, **moment_kwargs)
    log_first_moment = self._log_moment(1, **moment_kwargs)

    # log(E[X]**2) == 2 * log E[X]. With weights [1., -1] this gives
    # log|E[X**2] - E[X]**2| together with the sign of the difference.
    stacked = tf.stack([log_second_moment, 2 * log_first_moment], axis=-1)
    log_abs_var, var_sign = tfp_math.reduce_weighted_logsumexp(
        stacked, [1., -1], axis=-1, return_sign=True)
    # NOTE(review): floating-point cancellation can make var_sign == -1 for
    # near-zero variances, yielding a tiny negative result; this is left
    # unclamped as in the original.
    return var_sign * tf.exp(log_abs_var)