def test_inflated_poisson_likelihood(
    rate: float,
    hybridize: bool,
    zero_probability: float,
) -> None:
    """
    Check that SGD maximum-likelihood estimation recovers the parameters of a
    zero-inflated Poisson distribution.

    Samples are drawn from a distribution built with known parameters via
    ``ZeroInflatedPoissonOutput``. The estimates returned by
    ``maximum_likelihood_estimate_sgd`` on those samples are then compared
    with the true values using a relative tolerance of ``TOL``.

    Parameters
    ----------
    rate
        True rate passed as the second distribution argument.
    hybridize
        Forwarded unchanged to ``maximum_likelihood_estimate_sgd``.
    zero_probability
        True zero-inflation probability. It is the second entry of the
        ``[1 - zero_probability, zero_probability]`` weight pair. It must be
        strictly positive, because the tolerance is relative to it.
    """
    num_samples = 2000  # large enough for the SGD estimate to converge

    # Ground-truth distribution. The first argument is the weight pair
    # [1 - p, p]. NOTE(review): the meaning of the trailing 0.0 argument is
    # defined by ZeroInflatedPoissonOutput; confirm there.
    true_distr = ZeroInflatedPoissonOutput().distribution(
        distr_args=[
            mx.nd.array([[1 - zero_probability, zero_probability]]),
            mx.nd.array([rate]),
            mx.nd.array([0.0]),
        ]
    )

    # The estimator gets its own output object, separate from the one used
    # to generate the data.
    estimator_output = ZeroInflatedPoissonOutput()

    observations = true_distr.sample(num_samples).squeeze()

    weights_hat, rate_hat, _ = maximum_likelihood_estimate_sgd(
        distr_output=estimator_output,
        samples=observations,
        init_biases=None,
        hybridize=hybridize,
        learning_rate=PositiveFloat(0.15),
        num_epochs=PositiveInt(25),
    )
    # Only the second weight (the zero-inflation probability) is checked.
    _, zero_probability_hat = weights_hat

    # Relative-error checks. A zero true value would make the bound
    # TOL * 0 unattainable, since the comparison is strict.
    zero_probability_error = np.abs(zero_probability_hat - zero_probability)
    assert zero_probability_error < TOL * zero_probability, (
        f"zero_probability did not match: zero_probability = {zero_probability}, zero_probability_hat = {zero_probability_hat}"
    )

    rate_error = np.abs(rate_hat - rate)
    assert rate_error < TOL * rate, (
        f"rate did not match: rate = {rate}, rate_hat = {rate_hat}"
    )
Example #2
0
        DirichletMultinomialOutput(dim=3, n_trials=5),
        DirichletOutput(dim=4),
        EmpiricalDistributionOutput(num_samples=10,
                                    distr_output=GaussianOutput()),
        GammaOutput(),
        GaussianOutput(),
        GenParetoOutput(),
        LaplaceOutput(),
        LogitNormalOutput(),
        LoglogisticOutput(),
        LowrankMultivariateGaussianOutput(dim=5, rank=2),
        MultivariateGaussianOutput(dim=4),
        NegativeBinomialOutput(),
        OneInflatedBetaOutput(),
        PiecewiseLinearOutput(num_pieces=10),
        PoissonOutput(),
        StudentTOutput(),
        UniformOutput(),
        WeibullOutput(),
        ZeroAndOneInflatedBetaOutput(),
        ZeroInflatedBetaOutput(),
        ZeroInflatedNegativeBinomialOutput(),
        ZeroInflatedPoissonOutput(),
    ],
)
def test_distribution_output_serde(distr_output: DistributionOutput):
    """
    Check that a distribution output survives an ``encode``/``decode`` round
    trip.

    The decoded object must be an instance of the original's type, and
    ``dump_json`` must produce the same JSON for both objects.
    """
    round_tripped = decode(encode(distr_output))

    assert isinstance(round_tripped, type(distr_output))
    assert (
        dump_json(round_tripped)
        == dump_json(distr_output)
    )