def test_gamma_likelihood(alpha: float, beta: float, hybridize: bool) -> None:
    """
    Check that fitting ``GammaOutput`` with ``maximum_likelihood_estimate_sgd``
    on samples drawn from ``Gamma(alpha, beta)`` recovers both parameters to
    within a relative tolerance of ``TOL``.
    """
    # Broadcast the scalar parameters into constant vectors of length
    # NUM_SAMPLES, then draw the observations from that Gamma distribution.
    alpha_vec, beta_vec = (
        mx.nd.zeros((NUM_SAMPLES,)) + value for value in (alpha, beta)
    )
    observations = Gamma(alpha_vec, beta_vec).sample()

    # Start the optimizer from a deliberately perturbed guess: each true
    # parameter is lowered by a relative amount START_TOL_MULTIPLE * TOL.
    # NOTE(review): the inverse softplus presumably undoes a softplus that
    # GammaOutput applies to its raw outputs; confirm against GammaOutput.
    starting_biases = [
        inv_softplus(true_value - START_TOL_MULTIPLE * TOL * true_value)
        for true_value in (alpha, beta)
    ]

    estimated = maximum_likelihood_estimate_sgd(
        GammaOutput(),
        observations,
        init_biases=starting_biases,
        hybridize=hybridize,
        learning_rate=PositiveFloat(0.05),
        num_epochs=PositiveInt(5),
    )
    alpha_hat, beta_hat = estimated

    # Each estimate must lie within TOL (relative) of its true value.
    alpha_error = np.abs(alpha_hat - alpha)
    assert (
        alpha_error < TOL * alpha
    ), f"alpha did not match: alpha = {alpha}, alpha_hat = {alpha_hat}"
    beta_error = np.abs(beta_hat - beta)
    assert (
        beta_error < TOL * beta
    ), f"beta did not match: beta = {beta}, beta_hat = {beta_hat}"
d = mdo.distribution(distr_args) return d @pytest.mark.parametrize( "mixture_distribution, mixture_distribution_output, epochs", [ ( MixtureDistribution( mixture_probs=mx.nd.array([[0.6, 0.4]]), components=[ Gaussian(mu=mx.nd.array([-1.0]), sigma=mx.nd.array([0.2])), Gamma(alpha=mx.nd.array([2.0]), beta=mx.nd.array([0.5])), ], ), MixtureDistributionOutput([GaussianOutput(), GammaOutput()]), 2_000, ), ( MixtureDistribution( mixture_probs=mx.nd.array([[0.7, 0.3]]), components=[ Gaussian(mu=mx.nd.array([-1.0]), sigma=mx.nd.array([0.2])), GenPareto(xi=mx.nd.array([0.6]), beta=mx.nd.array([1.0])), ], ), MixtureDistributionOutput([GaussianOutput(), GenParetoOutput()]), 2_000, ), ], )
return d @pytest.mark.parametrize( "mixture_distribution, mixture_distribution_output, epochs", [ ( MixtureDistribution( mixture_probs=mx.nd.array([[0.6, 0.4]]), components=[ Gaussian(mu=mx.nd.array([-1.0]), sigma=mx.nd.array([0.2])), Gamma(alpha=mx.nd.array([2.0]), beta=mx.nd.array([0.5])), ], ), MixtureDistributionOutput([GaussianOutput(), GammaOutput()]), 2_000, ), ( MixtureDistribution( mixture_probs=mx.nd.array([[0.7, 0.3]]), components=[ Gaussian(mu=mx.nd.array([-1.0]), sigma=mx.nd.array([0.2])), GenPareto(xi=mx.nd.array([0.6]), beta=mx.nd.array([1.0])), ], ), MixtureDistributionOutput([GaussianOutput(), GenParetoOutput()]), 2_000, ), ],
mx.nd.random.normal(shape=(3, 4, 5, 6)), [None, mx.nd.ones(shape=(3, 4, 5))], [None, mx.nd.ones(shape=(3, 4, 5))], (3, 4, 5), (), ), ( StudentTOutput(), mx.nd.random.normal(shape=(3, 4, 5, 6)), [None, mx.nd.ones(shape=(3, 4, 5))], [None, mx.nd.ones(shape=(3, 4, 5))], (3, 4, 5), (), ), ( GammaOutput(), mx.nd.random.gamma(shape=(3, 4, 5, 6)), [None, mx.nd.ones(shape=(3, 4, 5))], [None, mx.nd.ones(shape=(3, 4, 5))], (3, 4, 5), (), ), ( BetaOutput(), mx.nd.random.gamma(shape=(3, 4, 5, 6)), [None, mx.nd.ones(shape=(3, 4, 5))], [None, mx.nd.ones(shape=(3, 4, 5))], (3, 4, 5), (), ), (