def test_beta_likelihood(alpha: float, beta: float, hybridize: bool) -> None:
    """
    Check that maximum-likelihood fitting of ``BetaOutput`` recovers the
    ``alpha`` and ``beta`` parameters that generated the data.

    ``NUM_SAMPLES`` draws are taken from ``Beta(alpha, beta)``. SGD is then
    started from values shifted below the truth by a relative amount
    ``START_TOL_MULTIPLE * TOL``. Each fitted parameter must end up within a
    relative tolerance ``TOL`` of its true value.
    """
    sample_shape = (NUM_SAMPLES, )

    # Broadcast each scalar parameter to one entry per sample, then draw data.
    true_distr = Beta(
        mx.nd.zeros(sample_shape) + alpha,
        mx.nd.zeros(sample_shape) + beta,
    )
    observations = true_distr.sample()

    # Deliberately start the optimisation away from the true values. The
    # biases are passed through inv_softplus, presumably to undo a softplus
    # positivity transform inside BetaOutput -- TODO confirm.
    starting_biases = [
        inv_softplus(value - START_TOL_MULTIPLE * TOL * value)
        for value in (alpha, beta)
    ]

    alpha_est, beta_est = maximum_likelihood_estimate_sgd(
        BetaOutput(),
        observations,
        init_biases=starting_biases,
        hybridize=hybridize,
        learning_rate=PositiveFloat(0.05),
        num_epochs=PositiveInt(10),
    )

    print("ALPHA:", alpha_est, "BETA:", beta_est)

    # Relative-error checks, alpha first and then beta.
    alpha_err = np.abs(alpha_est - alpha)
    assert (
        alpha_err < TOL * alpha
    ), f"alpha did not match: alpha = {alpha}, alpha_hat = {alpha_est}"

    beta_err = np.abs(beta_est - beta)
    assert (
        beta_err < TOL * beta
    ), f"beta did not match: beta = {beta}, beta_hat = {beta_est}"
[ bij.AffineTransformation( scale=1e-1 + mx.nd.random.uniform(shape=BATCH_SHAPE)), bij.softrelu, ], ), Gaussian( mu=mx.nd.zeros(shape=BATCH_SHAPE), sigma=mx.nd.ones(shape=BATCH_SHAPE), ), Gamma( alpha=mx.nd.ones(shape=BATCH_SHAPE), beta=mx.nd.ones(shape=BATCH_SHAPE), ), Beta( alpha=0.5 * mx.nd.ones(shape=BATCH_SHAPE), beta=0.5 * mx.nd.ones(shape=BATCH_SHAPE), ), StudentT( mu=mx.nd.zeros(shape=BATCH_SHAPE), sigma=mx.nd.ones(shape=BATCH_SHAPE), nu=mx.nd.ones(shape=BATCH_SHAPE), ), Dirichlet(alpha=mx.nd.ones(shape=BATCH_SHAPE)), Laplace(mu=mx.nd.zeros(shape=BATCH_SHAPE), b=mx.nd.ones(shape=BATCH_SHAPE)), NegativeBinomial( mu=mx.nd.zeros(shape=BATCH_SHAPE), alpha=mx.nd.ones(shape=BATCH_SHAPE), ), Poisson(rate=mx.nd.ones(shape=BATCH_SHAPE)), Uniform(
sigma=mx.nd.ones(shape=(3, 4, 5)), ), (3, 4, 5), (), ), ( Gamma( alpha=mx.nd.ones(shape=(3, 4, 5)), beta=mx.nd.ones(shape=(3, 4, 5)), ), (3, 4, 5), (), ), ( Beta( alpha=mx.nd.ones(shape=(3, 4, 5)), beta=mx.nd.ones(shape=(3, 4, 5)), ), (3, 4, 5), (), ), ( StudentT( mu=mx.nd.zeros(shape=(3, 4, 5)), sigma=mx.nd.ones(shape=(3, 4, 5)), nu=mx.nd.ones(shape=(3, 4, 5)), ), (3, 4, 5), (), ), ( MultivariateGaussian(