Beispiel #1
0
def test_log_binomial_stirling(tol):
    """Compare ``log_binomial`` against an exact lgamma-based reference.

    Builds a 200 x 200 grid of (n, k) pairs by broadcasting a row of success
    counts ``k = 0..199`` against a column of failure counts ``n - k = 0..199``.
    It then requires that the largest absolute error of
    ``log_binomial(n, k, tol=tol)`` is strictly below ``tol``. (The test name
    suggests ``log_binomial`` uses a Stirling-type approximation governed by
    ``tol`` -- not verified here.)
    """
    successes = torch.arange(200.)
    # Shape (200, 1): broadcasting against `successes` (shape (200,)) yields a grid.
    failures = successes.unsqueeze(-1)
    trials = successes + failures

    # Exact reference: log C(n, k) = lgamma(n + 1) - lgamma(k + 1) - lgamma(n - k + 1).
    log_fact_trials = (trials + 1).lgamma()
    log_fact_successes = (successes + 1).lgamma()
    log_fact_failures = (failures + 1).lgamma()
    reference = log_fact_trials - log_fact_successes - log_fact_failures

    approx = log_binomial(trials, successes, tol=tol)

    worst_error = (approx - reference).abs().max()
    assert worst_error < tol
Beispiel #2
0
    def log_prob(self, value):
        """Return the log probability of ``value`` (the success count k).

        With n = ``self.total_count``, a = ``self.concentration1`` and
        b = ``self.concentration0``, this returns::

            log_binomial(n, k) + log_beta(k + a, n - k + b) - log_beta(a, b)

        This is the beta-binomial log-pmf, assuming the helpers compute
        log C(n, k) and the log Beta function respectively (not shown here).
        ``self.approx_log_prob_tol`` is forwarded as the tolerance argument to
        every helper call. When ``self._validate_args`` is set, ``value`` is
        first checked with ``self._validate_sample``.
        """
        if self._validate_args:
            self._validate_sample(value)

        trials, successes = self.total_count, value
        alpha, beta = self.concentration1, self.concentration0
        tolerance = self.approx_log_prob_tol

        log_coefficient = log_binomial(trials, successes, tolerance)
        log_numerator = log_beta(successes + alpha, trials - successes + beta, tolerance)
        log_denominator = log_beta(alpha, beta, tolerance)
        return log_coefficient + log_numerator - log_denominator
Beispiel #3
0
    def log_prob(self, value):
        """Return the log probability of ``value`` (the success count k).

        Uses n = ``self.total_count`` trials and the success log-odds
        ``self.logits``. The result is::

            k * logit - n * log(1 + e^logit) + log_binomial(n, k)

        The normalizer is evaluated in the overflow-safe form
        ``max(logit, 0) + log1p(e^-|logit|)``, with ``_clamp_by_zero``
        supplying the ``max(logit, 0)`` term. ``self.approx_log_prob_tol`` is
        forwarded as ``tol`` to ``log_binomial``. When ``self._validate_args``
        is set, ``value`` is first checked with ``self._validate_sample``.
        """
        if self._validate_args:
            self._validate_sample(value)

        trials = self.total_count
        successes = value
        logits = self.logits

        # Derivation of the normalizer (p = sigmoid(logit)):
        #   k*log(p) + (n-k)*log(1-p) = k*logit + n*log(1-p)
        #   logit < 0:  n*log(1-p) = -n*log1p(e^logit)
        #   logit > 0:  n*log(1-p) = -n*logit - n*log1p(e^-logit)
        #   combined:   n*log(1-p) = -n*(max(logit, 0) + log1p(e^-|logit|))
        # The combined form never exponentiates a positive number, so it cannot overflow.
        log1p_exp_logits = _clamp_by_zero(logits) + (-logits.abs()).exp().log1p()
        log_normalizer = trials * log1p_exp_logits

        return (successes * logits - log_normalizer
                + log_binomial(trials, successes, tol=self.approx_log_prob_tol))