def test_beta_likelihood(alpha: float, beta: float, hybridize: bool) -> None:
    """
    Check that SGD maximum-likelihood fitting of a Beta output recovers the
    parameters of the Beta distribution the data was sampled from.

    NUM_SAMPLES draws are taken from Beta(alpha, beta). The fit starts from
    biases offset below the true values by START_TOL_MULTIPLE * TOL
    (relative). The fitted parameters must then lie within a relative
    tolerance of TOL of the true ones.
    """
    # Constant parameter vectors of length NUM_SAMPLES (one entry per draw).
    alpha_vec = mx.nd.zeros((NUM_SAMPLES, )) + alpha
    beta_vec = mx.nd.zeros((NUM_SAMPLES, )) + beta
    observations = Beta(alpha_vec, beta_vec).sample()

    # Start the optimiser at a relative offset below the true parameters.
    # The biases live in inverse-softplus space.
    shift = START_TOL_MULTIPLE * TOL
    starting_biases = [inv_softplus(p - shift * p) for p in (alpha, beta)]

    output_layer = BetaOutput()
    alpha_hat, beta_hat = maximum_likelihood_estimate_sgd(
        output_layer,
        observations,
        init_biases=starting_biases,
        hybridize=hybridize,
        learning_rate=PositiveFloat(0.05),
        num_epochs=PositiveInt(10),
    )

    print("ALPHA:", alpha_hat, "BETA:", beta_hat)

    alpha_err = np.abs(alpha_hat - alpha)
    assert alpha_err < TOL * alpha, (
        f"alpha did not match: alpha = {alpha}, alpha_hat = {alpha_hat}"
    )

    beta_err = np.abs(beta_hat - beta)
    assert beta_err < TOL * beta, (
        f"beta did not match: beta = {beta}, beta_hat = {beta_hat}"
    )
     [
         bij.AffineTransformation(
             scale=1e-1 + mx.nd.random.uniform(shape=BATCH_SHAPE)),
         bij.softrelu,
     ],
 ),
 Gaussian(
     mu=mx.nd.zeros(shape=BATCH_SHAPE),
     sigma=mx.nd.ones(shape=BATCH_SHAPE),
 ),
 Gamma(
     alpha=mx.nd.ones(shape=BATCH_SHAPE),
     beta=mx.nd.ones(shape=BATCH_SHAPE),
 ),
 Beta(
     alpha=0.5 * mx.nd.ones(shape=BATCH_SHAPE),
     beta=0.5 * mx.nd.ones(shape=BATCH_SHAPE),
 ),
 StudentT(
     mu=mx.nd.zeros(shape=BATCH_SHAPE),
     sigma=mx.nd.ones(shape=BATCH_SHAPE),
     nu=mx.nd.ones(shape=BATCH_SHAPE),
 ),
 Dirichlet(alpha=mx.nd.ones(shape=BATCH_SHAPE)),
 Laplace(mu=mx.nd.zeros(shape=BATCH_SHAPE),
         b=mx.nd.ones(shape=BATCH_SHAPE)),
 NegativeBinomial(
     mu=mx.nd.zeros(shape=BATCH_SHAPE),
     alpha=mx.nd.ones(shape=BATCH_SHAPE),
 ),
 Poisson(rate=mx.nd.ones(shape=BATCH_SHAPE)),
 Uniform(
示例#3
0
         sigma=mx.nd.ones(shape=(3, 4, 5)),
     ),
     (3, 4, 5),
     (),
 ),
 (
     Gamma(
         alpha=mx.nd.ones(shape=(3, 4, 5)),
         beta=mx.nd.ones(shape=(3, 4, 5)),
     ),
     (3, 4, 5),
     (),
 ),
 (
     Beta(
         alpha=mx.nd.ones(shape=(3, 4, 5)),
         beta=mx.nd.ones(shape=(3, 4, 5)),
     ),
     (3, 4, 5),
     (),
 ),
 (
     StudentT(
         mu=mx.nd.zeros(shape=(3, 4, 5)),
         sigma=mx.nd.ones(shape=(3, 4, 5)),
         nu=mx.nd.ones(shape=(3, 4, 5)),
     ),
     (3, 4, 5),
     (),
 ),
 (
     MultivariateGaussian(