def test_dirichlet_multinomial(hybridize: bool) -> None:
    """
    Check that SGD maximum-likelihood fitting recovers ``alpha``.

    Samples are drawn from a ``DirichletMultinomial`` built with a known
    ``alpha``; ``maximum_likelihood_estimate_sgd`` is then run on those
    samples. Both the estimated ``alpha`` and the ``variance`` of a
    distribution rebuilt from that estimate are compared against the
    reference values with ``np.allclose``.

    Parameters
    ----------
    hybridize
        Forwarded unchanged to ``maximum_likelihood_estimate_sgd``.
    """
    dim = 3
    n_trials = 500
    num_samples = 8000
    alpha = np.array([1.0, 2.0, 3.0])

    # Same absolute/relative tolerance is used for both comparisons below.
    tolerance = {"atol": 0.1, "rtol": 0.1}

    # Reference distribution built from the known alpha.
    reference_distr = DirichletMultinomial(
        dim=dim, n_trials=n_trials, alpha=mx.nd.array(alpha)
    )
    cov = reference_distr.variance.asnumpy()

    samples = reference_distr.sample(num_samples)

    alpha_hat = maximum_likelihood_estimate_sgd(
        DirichletMultinomialOutput(dim=dim, n_trials=n_trials),
        samples,
        init_biases=None,
        hybridize=hybridize,
        learning_rate=PositiveFloat(0.05),
        num_epochs=PositiveInt(10),
    )

    # Distribution rebuilt from the estimate, so its variance can be compared
    # with the reference one.
    fitted_distr = DirichletMultinomial(
        dim=dim, n_trials=n_trials, alpha=mx.nd.array(alpha_hat)
    )
    cov_hat = fitted_distr.variance.asnumpy()

    alpha_matches = np.allclose(alpha_hat, alpha, **tolerance)
    assert (
        alpha_matches
    ), f"alpha did not match: alpha = {alpha}, alpha_hat = {alpha_hat}"

    cov_matches = np.allclose(cov_hat, cov, **tolerance)
    assert (
        cov_matches
    ), f"Covariance did not match: cov = {cov}, cov_hat = {cov_hat}"
# Example #2
# 0
         nu=mx.nd.ones(shape=(3, 4, 5)),
     ),
     (3, 4, 5),
     (),
 ),
 (
     MultivariateGaussian(
         mu=mx.nd.zeros(shape=(3, 4, 5)),
         L=make_nd_diag(F=mx.nd, x=mx.nd.ones(shape=(3, 4, 5)), d=5),
     ),
     (3, 4),
     (5, ),
 ),
 (Dirichlet(alpha=mx.nd.ones(shape=(3, 4, 5))), (3, 4), (5, )),
 (
     DirichletMultinomial(
         dim=5, n_trials=9, alpha=mx.nd.ones(shape=(3, 4, 5))),
     (3, 4),
     (5, ),
 ),
 (
     Laplace(mu=mx.nd.zeros(shape=(3, 4, 5)),
             b=mx.nd.ones(shape=(3, 4, 5))),
     (3, 4, 5),
     (),
 ),
 (
     NegativeBinomial(
         mu=mx.nd.zeros(shape=(3, 4, 5)),
         alpha=mx.nd.ones(shape=(3, 4, 5)),
     ),
     (3, 4, 5),