def test_laplace(mu_b: Tuple[float, float], hybridize: bool) -> None:
    """
    Test that maximizing the likelihood recovers the true parameters.
    """
    # test instance
    mu, b = mu_b

    # generate samples
    mus = mx.nd.zeros((NUM_SAMPLES, )) + mu
    bs = mx.nd.zeros((NUM_SAMPLES, )) + b

    laplace_distr = Laplace(mu=mus, b=bs)
    samples = laplace_distr.sample()

    init_biases = [
        mu - START_TOL_MULTIPLE * TOL * mu,
        inv_softplus(b + START_TOL_MULTIPLE * TOL * b),
    ]

    mu_hat, b_hat = maximum_likelihood_estimate_sgd(LaplaceOutput(),
                                                    samples,
                                                    hybridize=hybridize,
                                                    init_biases=init_biases)

    assert (np.abs(mu_hat - mu) <
            TOL * mu), f"mu did not match: mu = {mu}, mu_hat = {mu_hat}"
    assert (np.abs(b_hat - b) <
            TOL * b), f"b did not match: b = {b}, b_hat = {b_hat}"
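
# Illustrative sketch, not part of the original snippet: `inv_softplus` above
# inverts the softplus mapping that keeps the scale `b` positive. Assuming
# softplus(y) = log(1 + exp(y)), a minimal stand-alone equivalent is:

import numpy as np


def _inv_softplus_sketch(x: float) -> float:
    # Returns y such that log(1 + exp(y)) == x, i.e. log(exp(x) - 1).
    return float(np.log(np.expm1(x)))


# NUM_SAMPLES, TOL and START_TOL_MULTIPLE are module-level constants of the
# original test file (sample count and relative tolerances, respectively).
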
Example 2

    # the cdf can currently only be computed for Gaussian components
    if isinstance(distr1, Gaussian) and isinstance(distr2, Gaussian):
        emp_cdf, edges = empirical_cdf(samples_mix.asnumpy())
        calc_cdf = mixture.cdf(mx.nd.array(edges)).asnumpy()
        assert np.allclose(calc_cdf[1:, :], emp_cdf, atol=1e-2)
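
# Illustrative sketch, not part of the original snippet: `empirical_cdf` above
# is a helper from the original test module. A minimal stand-alone version,
# assuming it returns per-column empirical CDF values evaluated at shared bin
# edges (the original helper's exact shape conventions may differ):

import numpy as np


def _empirical_cdf_sketch(samples: np.ndarray, num_bins: int = 100):
    # samples: array of shape (num_samples, dim)
    edges = np.linspace(samples.min(), samples.max(), num_bins + 1)
    # fraction of samples <= each edge, computed column by column
    emp_cdf = np.stack(
        [
            np.searchsorted(np.sort(col), edges, side="right") / len(col)
            for col in samples.T
        ],
        axis=-1,
    )
    return emp_cdf, edges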


@pytest.mark.parametrize(
    "distribution_outputs",
    [
        ((GaussianOutput(), GaussianOutput()),),
        ((GaussianOutput(), StudentTOutput(), LaplaceOutput()),),
        ((MultivariateGaussianOutput(3), MultivariateGaussianOutput(3)),),
    ],
)
@pytest.mark.parametrize("serialize_fn", serialize_fn_list)
def test_mixture_output(distribution_outputs, serialize_fn) -> None:
    mdo = MixtureDistributionOutput(*distribution_outputs)

    args_proj = mdo.get_args_proj()
    args_proj.initialize()

    input = mx.nd.ones(shape=(512, 30))

    distr_args = args_proj(input)
    d = mdo.distribution(distr_args)
    d = serialize_fn(d)
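
    # Illustrative continuation, not part of the original snippet (which is
    # truncated here): for any gluonts Distribution, samples drawn without
    # `num_samples` have shape batch_shape + event_shape, and log_prob
    # collapses the event dimensions, leaving batch_shape.
    samples = d.sample()
    assert samples.shape == d.batch_shape + d.event_shape
    assert d.log_prob(samples).shape == d.batch_shape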
     mx.nd.random.gamma(shape=(3, 4, 5)),
     [None],
     [None],
     (3, 4),
     (5, ),
 ),
 (
     DirichletMultinomialOutput(dim=5, n_trials=10),
     mx.nd.random.gamma(shape=(3, 4, 5)),
     [None],
     [None],
     (3, 4),
     (5, ),
 ),
 (
     LaplaceOutput(),
     mx.nd.random.normal(shape=(3, 4, 5, 6)),
     [None, mx.nd.ones(shape=(3, 4, 5))],
     [None, mx.nd.ones(shape=(3, 4, 5))],
     (3, 4, 5),
     (),
 ),
 (
     NegativeBinomialOutput(),
     mx.nd.random.normal(shape=(3, 4, 5, 6)),
     [None],
     [None, mx.nd.ones(shape=(3, 4, 5))],
     (3, 4, 5),
     (),
 ),
 (