def test_dirichlet_multinomial(hybridize: bool) -> None:
    """Check that SGD maximum-likelihood recovers DirichletMultinomial params.

    Draws ``num_samples`` samples from a ``DirichletMultinomial`` with a known
    concentration vector ``alpha``, fits ``alpha_hat`` with
    ``maximum_likelihood_estimate_sgd``, and asserts that both the fitted
    concentration and the covariance (``variance``) of the distribution
    rebuilt from it are close to the ground truth.

    Parameters
    ----------
    hybridize
        Forwarded unchanged to ``maximum_likelihood_estimate_sgd``;
        presumably toggles MXNet hybridization of the fitted block —
        TODO confirm against that helper.
    """
    num_samples = 8000
    n_trials = 500

    alpha = np.array([1.0, 2.0, 3.0])
    # Derive the dimension from ``alpha`` (instead of hard-coding it in each
    # constructor) so the true distribution, the fitted output and the
    # rebuilt distribution always agree on it.
    dim = len(alpha)

    distr = DirichletMultinomial(
        dim=dim, n_trials=n_trials, alpha=mx.nd.array(alpha)
    )
    cov = distr.variance.asnumpy()

    samples = distr.sample(num_samples)

    alpha_hat = maximum_likelihood_estimate_sgd(
        DirichletMultinomialOutput(dim=dim, n_trials=n_trials),
        samples,
        init_biases=None,
        hybridize=hybridize,
        learning_rate=PositiveFloat(0.05),
        num_epochs=PositiveInt(10),
    )

    # Rebuild the distribution from the estimate so its second moments can
    # be compared with those of the true distribution.
    distr_hat = DirichletMultinomial(
        dim=dim, n_trials=n_trials, alpha=mx.nd.array(alpha_hat)
    )
    cov_hat = distr_hat.variance.asnumpy()

    assert np.allclose(
        alpha_hat, alpha, atol=0.1, rtol=0.1
    ), f"alpha did not match: alpha = {alpha}, alpha_hat = {alpha_hat}"
    assert np.allclose(
        cov_hat, cov, atol=0.1, rtol=0.1
    ), f"Covariance did not match: cov = {cov}, cov_hat = {cov_hat}"
     mx.nd.random.normal(shape=(3, 4, 10)),
     [None, mx.nd.ones(shape=(3, 4, 5))],
     [None, mx.nd.ones(shape=(3, 4, 5))],
     (3, 4),
     (5, ),
 ),
 (
     DirichletOutput(dim=5),
     mx.nd.random.gamma(shape=(3, 4, 5)),
     [None],
     [None],
     (3, 4),
     (5, ),
 ),
 (
     DirichletMultinomialOutput(dim=5, n_trials=10),
     mx.nd.random.gamma(shape=(3, 4, 5)),
     [None],
     [None],
     (3, 4),
     (5, ),
 ),
 (
     LaplaceOutput(),
     mx.nd.random.normal(shape=(3, 4, 5, 6)),
     [None, mx.nd.ones(shape=(3, 4, 5))],
     [None, mx.nd.ones(shape=(3, 4, 5))],
     (3, 4, 5),
     (),
 ),
 (
Exemplo n.º 3
0
    StudentTOutput,
    UniformOutput,
    ZeroAndOneInflatedBetaOutput,
    ZeroInflatedBetaOutput,
    ZeroInflatedNegativeBinomialOutput,
    ZeroInflatedPoissonOutput,
)


@pytest.mark.parametrize(
    "distr_output",
    [
        BetaOutput(),
        CategoricalOutput(num_cats=3),
        DeterministicOutput(value=42.0),
        DirichletMultinomialOutput(dim=3, n_trials=5),
        DirichletOutput(dim=4),
        EmpiricalDistributionOutput(num_samples=10,
                                    distr_output=GaussianOutput()),
        GammaOutput(),
        GaussianOutput(),
        GenParetoOutput(),
        LaplaceOutput(),
        LogitNormalOutput(),
        LoglogisticOutput(),
        LowrankMultivariateGaussianOutput(dim=5, rank=2),
        MultivariateGaussianOutput(dim=4),
        NegativeBinomialOutput(),
        OneInflatedBetaOutput(),
        PiecewiseLinearOutput(num_pieces=10),
        PoissonOutput(),