def test_inflated_beta_likelihood(
    alpha: float,
    beta: float,
    hybridize: bool,
    inflated_at: str,
    zero_probability: float,
    one_probability: float,
) -> None:
    """
    Test to check that maximizing the likelihood recovers the parameters.

    Samples are drawn from a zero-, one-, or zero-and-one-inflated Beta
    distribution built from the given parameters. The matching
    ``DistributionOutput`` is then fitted to those samples with
    ``maximum_likelihood_estimate_sgd``. Every estimated parameter must lie
    within a relative tolerance of ``TOL`` of its true value.

    Parameters
    ----------
    alpha, beta
        True parameters of the Beta component.
    hybridize
        Forwarded unchanged to ``maximum_likelihood_estimate_sgd``.
    inflated_at
        ``"zero"`` selects ``ZeroInflatedBeta`` and ``"one"`` selects
        ``OneInflatedBeta``. Any other value selects
        ``ZeroAndOneInflatedBeta``.
    zero_probability, one_probability
        True inflation probabilities. Each is only checked for the variants
        that carry that point mass.
    """

    def assert_close(name: str, true_value: float, estimate) -> None:
        # Shared relative-tolerance check so that every parameter is compared
        # the same way. ``np.all`` accepts both scalar and array-valued
        # estimates, instead of indexing some estimates with ``[0]`` and
        # comparing others as whole arrays.
        assert np.all(
            np.abs(estimate - true_value) < TOL * true_value
        ), f"{name} did not match: {name} = {true_value}, {name}_hat = {estimate}"

    # generate samples: constant parameter arrays of length NUM_SAMPLES
    alphas = mx.nd.zeros((NUM_SAMPLES, )) + alpha
    betas = mx.nd.zeros((NUM_SAMPLES, )) + beta

    zero_probabilities = mx.nd.zeros((NUM_SAMPLES, )) + zero_probability
    one_probabilities = mx.nd.zeros((NUM_SAMPLES, )) + one_probability
    if inflated_at == "zero":
        distr = ZeroInflatedBeta(alphas,
                                 betas,
                                 zero_probability=zero_probabilities)
        distr_output = ZeroInflatedBetaOutput()
    elif inflated_at == "one":
        distr = OneInflatedBeta(alphas,
                                betas,
                                one_probability=one_probabilities)
        distr_output = OneInflatedBetaOutput()
    else:
        distr = ZeroAndOneInflatedBeta(
            alphas,
            betas,
            zero_probability=zero_probabilities,
            one_probability=one_probabilities,
        )
        distr_output = ZeroAndOneInflatedBetaOutput()

    samples = distr.sample()

    # Start the optimisation below the true alpha/beta, at a relative offset
    # of START_TOL_MULTIPLE * TOL, so it does not begin at the answer.
    # inv_softplus presumably maps these into the projection's bias space;
    # TODO confirm against the DistributionOutput's domain map.
    init_biases = [
        inv_softplus(alpha - START_TOL_MULTIPLE * TOL * alpha),
        inv_softplus(beta - START_TOL_MULTIPLE * TOL * beta),
    ]

    parameters = maximum_likelihood_estimate_sgd(
        distr_output,
        samples,
        init_biases=init_biases,
        hybridize=hybridize,
        learning_rate=PositiveFloat(0.05),
        num_epochs=PositiveInt(10),
    )

    # The number and order of the estimated parameters depend on the variant.
    if inflated_at == "zero":
        alpha_hat, beta_hat, zero_probability_hat = parameters
        assert_close("zero_probability", zero_probability,
                     zero_probability_hat)
    elif inflated_at == "one":
        alpha_hat, beta_hat, one_probability_hat = parameters
        assert_close("one_probability", one_probability, one_probability_hat)
    else:
        (
            alpha_hat,
            beta_hat,
            zero_probability_hat,
            one_probability_hat,
        ) = parameters
        assert_close("zero_probability", zero_probability,
                     zero_probability_hat)
        assert_close("one_probability", one_probability, one_probability_hat)

    assert_close("alpha", alpha, alpha_hat)
    assert_close("beta", beta, beta_hat)
# Example #2 (0)
        DirichletMultinomialOutput(dim=3, n_trials=5),
        DirichletOutput(dim=4),
        EmpiricalDistributionOutput(num_samples=10,
                                    distr_output=GaussianOutput()),
        GammaOutput(),
        GaussianOutput(),
        GenParetoOutput(),
        LaplaceOutput(),
        LogitNormalOutput(),
        LoglogisticOutput(),
        LowrankMultivariateGaussianOutput(dim=5, rank=2),
        MultivariateGaussianOutput(dim=4),
        NegativeBinomialOutput(),
        OneInflatedBetaOutput(),
        PiecewiseLinearOutput(num_pieces=10),
        PoissonOutput(),
        StudentTOutput(),
        UniformOutput(),
        WeibullOutput(),
        ZeroAndOneInflatedBetaOutput(),
        ZeroInflatedBetaOutput(),
        ZeroInflatedNegativeBinomialOutput(),
        ZeroInflatedPoissonOutput(),
    ],
)
def test_distribution_output_serde(distr_output: DistributionOutput):
    """
    Check that a ``DistributionOutput`` survives an ``encode``/``decode``
    round trip.

    The decoded object must be an instance of the original's type and must
    produce the same ``dump_json`` output as the original.
    """
    round_tripped = decode(encode(distr_output))
    expected_type = type(distr_output)

    assert isinstance(round_tripped, expected_type)

    round_tripped_json = dump_json(round_tripped)
    assert round_tripped_json == dump_json(distr_output)