def test_dirichlet_multinomial(hybridize: bool) -> None:
    """Check that SGD maximum-likelihood recovers DirichletMultinomial params.

    Samples are drawn from a ``DirichletMultinomial`` with a known ``alpha``.
    ``alpha_hat`` is then fitted on those samples with
    ``maximum_likelihood_estimate_sgd``. The test asserts that both
    ``alpha_hat`` and the ``variance`` of the distribution built from it
    (``cov_hat``) match the ground truth (``alpha``, ``cov``) within
    ``atol=0.1`` / ``rtol=0.1``.

    Parameters
    ----------
    hybridize
        Forwarded unchanged to ``maximum_likelihood_estimate_sgd``.
    """
    num_samples = 8000
    n_trials = 500
    alpha = np.array([1.0, 2.0, 3.0])
    # Derive the dimension from alpha so that the ground-truth distribution,
    # the output layer and the fitted distribution can never disagree.
    dim = len(alpha)

    distr = DirichletMultinomial(
        dim=dim, n_trials=n_trials, alpha=mx.nd.array(alpha)
    )
    cov = distr.variance.asnumpy()

    samples = distr.sample(num_samples)

    alpha_hat = maximum_likelihood_estimate_sgd(
        DirichletMultinomialOutput(dim=dim, n_trials=n_trials),
        samples,
        init_biases=None,
        hybridize=hybridize,
        learning_rate=PositiveFloat(0.05),
        num_epochs=PositiveInt(10),
    )

    # Rebuild a distribution from the fitted parameters so its variance can
    # be compared against the ground-truth one.
    distr_hat = DirichletMultinomial(
        dim=dim, n_trials=n_trials, alpha=mx.nd.array(alpha_hat)
    )
    cov_hat = distr_hat.variance.asnumpy()

    assert np.allclose(
        alpha_hat, alpha, atol=0.1, rtol=0.1
    ), f"alpha did not match: alpha = {alpha}, alpha_hat = {alpha_hat}"
    assert np.allclose(
        cov_hat, cov, atol=0.1, rtol=0.1
    ), f"Covariance did not match: cov = {cov}, cov_hat = {cov_hat}"
mx.nd.random.normal(shape=(3, 4, 10)), [None, mx.nd.ones(shape=(3, 4, 5))], [None, mx.nd.ones(shape=(3, 4, 5))], (3, 4), (5, ), ), ( DirichletOutput(dim=5), mx.nd.random.gamma(shape=(3, 4, 5)), [None], [None], (3, 4), (5, ), ), ( DirichletMultinomialOutput(dim=5, n_trials=10), mx.nd.random.gamma(shape=(3, 4, 5)), [None], [None], (3, 4), (5, ), ), ( LaplaceOutput(), mx.nd.random.normal(shape=(3, 4, 5, 6)), [None, mx.nd.ones(shape=(3, 4, 5))], [None, mx.nd.ones(shape=(3, 4, 5))], (3, 4, 5), (), ), (
StudentTOutput, UniformOutput, ZeroAndOneInflatedBetaOutput, ZeroInflatedBetaOutput, ZeroInflatedNegativeBinomialOutput, ZeroInflatedPoissonOutput, ) @pytest.mark.parametrize( "distr_output", [ BetaOutput(), CategoricalOutput(num_cats=3), DeterministicOutput(value=42.0), DirichletMultinomialOutput(dim=3, n_trials=5), DirichletOutput(dim=4), EmpiricalDistributionOutput(num_samples=10, distr_output=GaussianOutput()), GammaOutput(), GaussianOutput(), GenParetoOutput(), LaplaceOutput(), LogitNormalOutput(), LoglogisticOutput(), LowrankMultivariateGaussianOutput(dim=5, rank=2), MultivariateGaussianOutput(dim=4), NegativeBinomialOutput(), OneInflatedBetaOutput(), PiecewiseLinearOutput(num_pieces=10), PoissonOutput(),