def test_neg_binomial(mu_alpha: Tuple[float, float], hybridize: bool) -> None:
    """
    Check that maximizing the likelihood recovers NegativeBinomial parameters.

    Draws NUM_SAMPLES observations from a NegativeBinomial with the given
    ``(mu, alpha)``. It then fits ``NegativeBinomialOutput`` to them with
    ``maximum_likelihood_estimate_sgd``, starting from deliberately shifted
    initial biases. Each recovered parameter must lie within a relative
    tolerance ``TOL`` of its true value.
    """
    true_mu, true_alpha = mu_alpha

    def _filled(value: float):
        # Length-NUM_SAMPLES NDArray with every entry equal to ``value``.
        return mx.nd.zeros((NUM_SAMPLES, )) + value

    # Ground-truth distribution and the observations used for fitting.
    truth = NegativeBinomial(mu=_filled(true_mu), alpha=_filled(true_alpha))
    observations = truth.sample()

    # Start the optimizer away from the truth: mu below it, alpha above it.
    # The values are passed through inv_softplus, presumably because the
    # output layer maps its biases through softplus (confirm in
    # NegativeBinomialOutput).
    shift = START_TOL_MULTIPLE * TOL
    starting_biases = [
        inv_softplus(true_mu - shift * true_mu),
        inv_softplus(true_alpha + shift * true_alpha),
    ]

    estimated_mu, estimated_alpha = maximum_likelihood_estimate_sgd(
        NegativeBinomialOutput(),
        observations,
        hybridize=hybridize,
        init_biases=starting_biases,
        num_epochs=PositiveInt(15),
    )

    # Relative-error checks against the true parameters.
    assert np.abs(estimated_mu - true_mu) < TOL * true_mu, (
        f"mu did not match: mu = {true_mu}, mu_hat = {estimated_mu}"
    )
    assert np.abs(estimated_alpha - true_alpha) < TOL * true_alpha, (
        f"alpha did not match: alpha = {true_alpha}, alpha_hat = {estimated_alpha}"
    )
(Dirichlet(alpha=mx.nd.ones(shape=(3, 4, 5))), (3, 4), (5, )), ( DirichletMultinomial( dim=5, n_trials=9, alpha=mx.nd.ones(shape=(3, 4, 5))), (3, 4), (5, ), ), ( Laplace(mu=mx.nd.zeros(shape=(3, 4, 5)), b=mx.nd.ones(shape=(3, 4, 5))), (3, 4, 5), (), ), ( NegativeBinomial( mu=mx.nd.zeros(shape=(3, 4, 5)), alpha=mx.nd.ones(shape=(3, 4, 5)), ), (3, 4, 5), (), ), ( Uniform( low=-mx.nd.ones(shape=(3, 4, 5)), high=mx.nd.ones(shape=(3, 4, 5)), ), (3, 4, 5), (), ), ( PiecewiseLinear( gamma=mx.nd.ones(shape=(3, 4, 5)),
), Beta( alpha=0.5 * mx.nd.ones(shape=BATCH_SHAPE), beta=0.5 * mx.nd.ones(shape=BATCH_SHAPE), ), StudentT( mu=mx.nd.zeros(shape=BATCH_SHAPE), sigma=mx.nd.ones(shape=BATCH_SHAPE), nu=mx.nd.ones(shape=BATCH_SHAPE), ), Dirichlet(alpha=mx.nd.ones(shape=BATCH_SHAPE)), Laplace( mu=mx.nd.zeros(shape=BATCH_SHAPE), b=mx.nd.ones(shape=BATCH_SHAPE) ), NegativeBinomial( mu=mx.nd.zeros(shape=BATCH_SHAPE), alpha=mx.nd.ones(shape=BATCH_SHAPE), ), Poisson(rate=mx.nd.ones(shape=BATCH_SHAPE)), Uniform( low=-mx.nd.ones(shape=BATCH_SHAPE), high=mx.nd.ones(shape=BATCH_SHAPE), ), PiecewiseLinear( gamma=mx.nd.ones(shape=BATCH_SHAPE), slopes=mx.nd.ones(shape=(3, 4, 5, 10)), knot_spacings=mx.nd.ones(shape=(3, 4, 5, 10)) / 10, ), MixtureDistribution( mixture_probs=mx.nd.stack( 0.2 * mx.nd.ones(shape=BATCH_SHAPE), 0.8 * mx.nd.ones(shape=BATCH_SHAPE),