def test_log_binomial_stirling(tol):
    """Check the approximate log_binomial against the exact lgamma formula.

    Builds a 200 x 200 grid of (n, k) pairs via broadcasting, computes
    log C(n, k) exactly as lgamma(n+1) - lgamma(k+1) - lgamma(n-k+1),
    and asserts the approximation is within ``tol`` everywhere.
    """
    ks = torch.arange(200.0)
    n_minus_ks = ks.unsqueeze(-1)
    ns = ks + n_minus_ks  # broadcasts to every (n, k) pair with n >= k

    # Exact log of the binomial coefficient choose(n, k).
    exact = (ns + 1).lgamma() - (ks + 1).lgamma() - (n_minus_ks + 1).lgamma()
    approx = log_binomial(ns, ks, tol=tol)
    assert (approx - exact).abs().max() < tol
def log_prob(self, value):
    """Return the log-probability of ``value`` under a beta-binomial law.

    Uses the identity
        log P(k) = log C(n, k) + log B(k + a, n - k + b) - log B(a, b)
    where ``a``/``b`` are the two beta concentration parameters and both
    ``log_binomial`` and ``log_beta`` accept an approximation tolerance.
    """
    if self._validate_args:
        self._validate_sample(value)
    total = self.total_count
    successes = value
    alpha = self.concentration1
    beta = self.concentration0
    tol = self.approx_log_prob_tol
    coefficient = log_binomial(total, successes, tol)
    beta_ratio = (log_beta(successes + alpha, total - successes + beta, tol)
                  - log_beta(alpha, beta, tol))
    return coefficient + beta_ratio
def log_prob(self, value):
    """Return the log-probability of ``value`` under a binomial law
    parameterized by logits, computed in a numerically stable way.

    Derivation of the stable form:
      k * log(p) + (n - k) * log(1 - p)
        = k * (log(p) - log(1 - p)) + n * log(1 - p)
      (case logit < 0)  = k * logit - n * log1p(e^logit)
      (case logit > 0)  = k * logit - n * logit - n * log1p(e^-logit)
      (merged)          = k * logit - n * max(logit, 0) - n * log1p(e^-|logit|)
    """
    if self._validate_args:
        self._validate_sample(value)
    total = self.total_count
    successes = value
    logits = self.logits
    # n * (max(logit, 0) + log1p(e^-|logit|)) -- the merged-case normalizer.
    log_normalizer = total * (_clamp_by_zero(logits)
                              + logits.abs().neg().exp().log1p())
    return (successes * logits
            - log_normalizer
            + log_binomial(total, successes, tol=self.approx_log_prob_tol))