def test_dirichlet_multinomial(hybridize: bool) -> None:
    """Fit a DirichletMultinomial to its own samples and compare parameters.

    Samples are drawn from a ``DirichletMultinomial`` with a known ``alpha``.
    The ``alpha`` returned by ``maximum_likelihood_estimate_sgd`` on those
    samples, and the ``variance`` of the distribution rebuilt from it, must
    both agree with the reference values within ``atol=0.1, rtol=0.1``.

    Parameters
    ----------
    hybridize
        Forwarded unchanged to ``maximum_likelihood_estimate_sgd``.
    """
    num_samples = 8000
    dim = 3
    n_trials = 500

    alpha = np.array([1.0, 2.0, 3.0])

    # Use the `dim` local everywhere (instead of repeating the literal 3) so
    # the reference distribution, the fitted output and the re-built
    # distribution cannot drift apart if the dimension is ever changed.
    true_distr = DirichletMultinomial(
        dim=dim, n_trials=n_trials, alpha=mx.nd.array(alpha)
    )
    cov = true_distr.variance.asnumpy()

    samples = true_distr.sample(num_samples)

    alpha_hat = maximum_likelihood_estimate_sgd(
        DirichletMultinomialOutput(dim=dim, n_trials=n_trials),
        samples,
        init_biases=None,
        hybridize=hybridize,
        learning_rate=PositiveFloat(0.05),
        num_epochs=PositiveInt(10),
    )

    # Rebuild the distribution from the estimate so its `variance` can be
    # compared against the reference one.
    fitted_distr = DirichletMultinomial(
        dim=dim, n_trials=n_trials, alpha=mx.nd.array(alpha_hat)
    )
    cov_hat = fitted_distr.variance.asnumpy()

    assert np.allclose(
        alpha_hat, alpha, atol=0.1, rtol=0.1
    ), f"alpha did not match: alpha = {alpha}, alpha_hat = {alpha_hat}"

    assert np.allclose(
        cov_hat, cov, atol=0.1, rtol=0.1
    ), f"Covariance did not match: cov = {cov}, cov_hat = {cov_hat}"
nu=mx.nd.ones(shape=(3, 4, 5)), ), (3, 4, 5), (), ), ( MultivariateGaussian( mu=mx.nd.zeros(shape=(3, 4, 5)), L=make_nd_diag(F=mx.nd, x=mx.nd.ones(shape=(3, 4, 5)), d=5), ), (3, 4), (5, ), ), (Dirichlet(alpha=mx.nd.ones(shape=(3, 4, 5))), (3, 4), (5, )), ( DirichletMultinomial( dim=5, n_trials=9, alpha=mx.nd.ones(shape=(3, 4, 5))), (3, 4), (5, ), ), ( Laplace(mu=mx.nd.zeros(shape=(3, 4, 5)), b=mx.nd.ones(shape=(3, 4, 5))), (3, 4, 5), (), ), ( NegativeBinomial( mu=mx.nd.zeros(shape=(3, 4, 5)), alpha=mx.nd.ones(shape=(3, 4, 5)), ), (3, 4, 5),