def test_dirichlet(hybridize: bool) -> None:
    """Check that SGD maximum-likelihood fitting recovers a Dirichlet.

    Samples are drawn from a ``Dirichlet`` with a known concentration vector
    ``alpha``. They are passed, together with a ``DirichletOutput``, to
    ``maximum_likelihood_estimate_sgd`` (a helper defined elsewhere in this
    module/project). The test then asserts that two quantities are close to
    the reference values within ``atol=0.1`` / ``rtol=0.1``:

    - the estimated parameters ``alpha_hat``;
    - the ``variance`` property of a ``Dirichlet`` rebuilt from
      ``alpha_hat``.

    Args:
        hybridize: forwarded unchanged to ``maximum_likelihood_estimate_sgd``.
    """
    alpha = np.array([1.0, 2.0, 3.0])
    # One output dimension per concentration parameter (3 here).
    event_dim = len(alpha)
    n_draws = 2000
    # Shared absolute/relative tolerances for both comparisons below.
    tolerance = {"atol": 0.1, "rtol": 0.1}

    # Reference distribution: record its `variance` before sampling from it.
    true_distr = Dirichlet(alpha=mx.nd.array(alpha))
    cov = true_distr.variance.asnumpy()
    draws = true_distr.sample(n_draws)

    alpha_hat = maximum_likelihood_estimate_sgd(
        DirichletOutput(dim=event_dim),
        draws,
        init_biases=None,
        hybridize=hybridize,
        learning_rate=PositiveFloat(0.05),
        num_epochs=PositiveInt(10),
    )

    # Rebuild a distribution from the estimate to compare second moments.
    fitted_distr = Dirichlet(alpha=mx.nd.array(alpha_hat))
    cov_hat = fitted_distr.variance.asnumpy()

    assert np.allclose(
        alpha_hat, alpha, **tolerance
    ), f"alpha did not match: alpha = {alpha}, alpha_hat = {alpha_hat}"
    assert np.allclose(
        cov_hat, cov, **tolerance
    ), f"Covariance did not match: cov = {cov}, cov_hat = {cov_hat}"
sigma=mx.nd.ones(shape=BATCH_SHAPE), ), Gamma( alpha=mx.nd.ones(shape=BATCH_SHAPE), beta=mx.nd.ones(shape=BATCH_SHAPE), ), Beta( alpha=0.5 * mx.nd.ones(shape=BATCH_SHAPE), beta=0.5 * mx.nd.ones(shape=BATCH_SHAPE), ), StudentT( mu=mx.nd.zeros(shape=BATCH_SHAPE), sigma=mx.nd.ones(shape=BATCH_SHAPE), nu=mx.nd.ones(shape=BATCH_SHAPE), ), Dirichlet(alpha=mx.nd.ones(shape=BATCH_SHAPE)), Laplace(mu=mx.nd.zeros(shape=BATCH_SHAPE), b=mx.nd.ones(shape=BATCH_SHAPE)), NegativeBinomial( mu=mx.nd.zeros(shape=BATCH_SHAPE), alpha=mx.nd.ones(shape=BATCH_SHAPE), ), Poisson(rate=mx.nd.ones(shape=BATCH_SHAPE)), Uniform( low=-mx.nd.ones(shape=BATCH_SHAPE), high=mx.nd.ones(shape=BATCH_SHAPE), ), PiecewiseLinear( gamma=mx.nd.ones(shape=BATCH_SHAPE), slopes=mx.nd.ones(shape=(3, 4, 5, 10)), knot_spacings=mx.nd.ones(shape=(3, 4, 5, 10)) / 10,
mu=mx.nd.zeros(shape=(3, 4, 5)), sigma=mx.nd.ones(shape=(3, 4, 5)), nu=mx.nd.ones(shape=(3, 4, 5)), ), (3, 4, 5), (), ), ( MultivariateGaussian( mu=mx.nd.zeros(shape=(3, 4, 5)), L=make_nd_diag(F=mx.nd, x=mx.nd.ones(shape=(3, 4, 5)), d=5), ), (3, 4), (5, ), ), (Dirichlet(alpha=mx.nd.ones(shape=(3, 4, 5))), (3, 4), (5, )), ( DirichletMultinomial( dim=5, n_trials=9, alpha=mx.nd.ones(shape=(3, 4, 5))), (3, 4), (5, ), ), ( Laplace(mu=mx.nd.zeros(shape=(3, 4, 5)), b=mx.nd.ones(shape=(3, 4, 5))), (3, 4, 5), (), ), ( NegativeBinomial( mu=mx.nd.zeros(shape=(3, 4, 5)),