# Shared imports for the test snippets below. NUM_SAMPLES, TOL,
# START_TOL_MULTIPLE, inv_softplus, maximum_likelihood_estimate_sgd,
# NegativeBinomialOutput, ZeroInflatedNegativeBinomial and assert_close are
# assumed to be provided by the surrounding test modules; the tests are
# presumably driven by pytest parametrization.
from typing import Tuple

import numpy as np
import torch
from torch.distributions import NegativeBinomial


def test_neg_binomial(total_count: float, logit: float) -> None:
    """
    Test to check that maximizing the likelihood recovers the parameters
    """
    # generate samples
    total_counts = torch.zeros((NUM_SAMPLES,)) + total_count
    logits = torch.zeros((NUM_SAMPLES,)) + logit

    neg_bin_distr = NegativeBinomial(total_count=total_counts, logits=logits)
    samples = neg_bin_distr.sample()

    # start the optimizer slightly off the true values; total_count is
    # softplus-constrained to be positive, so its bias goes through the
    # inverse transform, while the logit is unconstrained
    init_biases = [
        inv_softplus(total_count - START_TOL_MULTIPLE * TOL * total_count),
        logit - START_TOL_MULTIPLE * TOL * logit,
    ]

    total_count_hat, logit_hat = maximum_likelihood_estimate_sgd(
        NegativeBinomialOutput(),
        samples,
        init_biases=init_biases,
        num_epochs=15,
    )

    assert (
        np.abs(total_count_hat - total_count) < TOL * total_count
    ), f"total_count did not match: total_count = {total_count}, total_count_hat = {total_count_hat}"
    assert (
        np.abs(logit_hat - logit) < TOL * logit
    ), f"logit did not match: logit = {logit}, logit_hat = {logit_hat}"
def test_neg_binomial_mu_alpha(mu_alpha: Tuple[float, float]) -> None:
    """
    Same check as above, with the NegativeBinomial parametrized by its mean
    mu and dispersion alpha. (Renamed from test_neg_binomial so both tests
    can live in one module.)
    """
    # test instance
    mu, alpha = mu_alpha

    # generate samples: map (mu, alpha) to torch's (total_count, probs)
    # parametrization via r = 1/alpha and p = mu*alpha / (1 + mu*alpha)
    mus = torch.zeros((NUM_SAMPLES,)) + mu
    alphas = torch.zeros((NUM_SAMPLES,)) + alpha

    neg_bin_distr = NegativeBinomial(
        total_count=1.0 / alphas,
        probs=mus * alphas / (1.0 + mus * alphas),
    )
    samples = neg_bin_distr.sample()

    # start the optimizer slightly off the true values; both parameters are
    # softplus-constrained to be positive
    init_biases = [
        inv_softplus(mu - START_TOL_MULTIPLE * TOL * mu),
        inv_softplus(alpha + START_TOL_MULTIPLE * TOL * alpha),
    ]

    mu_hat, alpha_hat = maximum_likelihood_estimate_sgd(
        NegativeBinomialOutput(),
        samples,
        init_biases=init_biases,
        num_epochs=15,
    )

    assert (
        np.abs(mu_hat - mu) < TOL * mu
    ), f"mu did not match: mu = {mu}, mu_hat = {mu_hat}"
    assert (
        np.abs(alpha_hat - alpha) < TOL * alpha
    ), f"alpha did not match: alpha = {alpha}, alpha_hat = {alpha_hat}"
def test_zinb_0_gate(total_count, probs):
    # With gate == 0 there is no zero inflation, so ZINB reduces to a plain
    # NegativeBinomial: the two must assign the same log-prob to NB samples.
    # ZeroInflatedNegativeBinomial and assert_close are assumed to come from
    # Pyro and its test utilities; gate is passed by keyword for clarity.
    zinb_ = ZeroInflatedNegativeBinomial(
        total_count=torch.tensor(total_count),
        probs=torch.tensor(probs),
        gate=torch.zeros(1),
    )
    neg_bin = NegativeBinomial(torch.tensor(total_count), probs=torch.tensor(probs))
    s = neg_bin.sample((20,))
    zinb_prob = zinb_.log_prob(s)
    neg_bin_prob = neg_bin.log_prob(s)
    assert_close(zinb_prob, neg_bin_prob, atol=1e-6)
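

# Complementary sketch (ours, mirroring the test above; assumes Pyro's
# ZeroInflatedNegativeBinomial): with gate == 1 every draw comes from the
# zero-inflation component, so all samples are exactly zero.
def _zinb_all_zero_sketch(total_count: float = 2.0, probs: float = 0.3) -> None:
    zinb_ = ZeroInflatedNegativeBinomial(
        total_count=torch.tensor(total_count),
        probs=torch.tensor(probs),
        gate=torch.ones(1),
    )
    assert (zinb_.sample((100,)) == 0).all()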