コード例 #1
0
def test_gamma_likelihood(alpha: float, beta: float, hybridize: bool) -> None:
    """
    Test to check that maximizing the likelihood recovers the parameters.

    Samples are drawn from ``Gamma(alpha, beta)``, a ``GammaOutput`` is fitted
    to them with ``maximum_likelihood_estimate_sgd``, and both estimates are
    required to lie within a relative tolerance ``TOL`` of the true values.

    Parameters
    ----------
    alpha
        True ``alpha`` parameter of the Gamma distribution being sampled.
    beta
        True ``beta`` parameter of the Gamma distribution being sampled.
    hybridize
        Forwarded unchanged to ``maximum_likelihood_estimate_sgd``.
    """

    # generate samples: one draw per entry of the constant parameter vectors
    alphas = mx.nd.zeros((NUM_SAMPLES,)) + alpha
    betas = mx.nd.zeros((NUM_SAMPLES,)) + beta

    distr = Gamma(alphas, betas)
    samples = distr.sample()

    # Start the optimisation slightly below the true values (relative offset
    # of START_TOL_MULTIPLE * TOL) so the fit has something to recover.
    # NOTE(review): inv_softplus presumably maps the positive parameter back
    # to the pre-activation bias expected by GammaOutput -- confirm.
    init_biases = [
        inv_softplus(alpha - START_TOL_MULTIPLE * TOL * alpha),
        inv_softplus(beta - START_TOL_MULTIPLE * TOL * beta),
    ]

    alpha_hat, beta_hat = maximum_likelihood_estimate_sgd(
        GammaOutput(),
        samples,
        init_biases=init_biases,
        hybridize=hybridize,
        learning_rate=PositiveFloat(0.05),
        num_epochs=PositiveInt(5),
    )

    # Tolerances are relative to the true parameter value.
    assert (
        np.abs(alpha_hat - alpha) < TOL * alpha
    ), f"alpha did not match: alpha = {alpha}, alpha_hat = {alpha_hat}"
    assert (
        np.abs(beta_hat - beta) < TOL * beta
    ), f"beta did not match: beta = {beta}, beta_hat = {beta_hat}"
コード例 #2
0

@pytest.mark.parametrize(
    "distr, expected_batch_shape, expected_event_shape",
    [
        (
            Gaussian(
                mu=mx.nd.zeros(shape=(3, 4, 5)),
                sigma=mx.nd.ones(shape=(3, 4, 5)),
            ),
            (3, 4, 5),
            (),
        ),
        (
            Gamma(
                alpha=mx.nd.ones(shape=(3, 4, 5)),
                beta=mx.nd.ones(shape=(3, 4, 5)),
            ),
            (3, 4, 5),
            (),
        ),
        (
            Beta(
                alpha=mx.nd.ones(shape=(3, 4, 5)),
                beta=mx.nd.ones(shape=(3, 4, 5)),
            ),
            (3, 4, 5),
            (),
        ),
        (
            StudentT(
                mu=mx.nd.zeros(shape=(3, 4, 5)),
コード例 #3
0
         bin_centers=mx.nd.array(np.logspace(-1, 1, 23))
         + mx.nd.zeros(BATCH_SHAPE + (23,)),
     ),
     [
         bij.AffineTransformation(
             scale=1e-1 + mx.nd.random.uniform(shape=BATCH_SHAPE)
         ),
         bij.softrelu,
     ],
 ),
 Gaussian(
     mu=mx.nd.zeros(shape=BATCH_SHAPE),
     sigma=mx.nd.ones(shape=BATCH_SHAPE),
 ),
 Gamma(
     alpha=mx.nd.ones(shape=BATCH_SHAPE),
     beta=mx.nd.ones(shape=BATCH_SHAPE),
 ),
 Beta(
     alpha=0.5 * mx.nd.ones(shape=BATCH_SHAPE),
     beta=0.5 * mx.nd.ones(shape=BATCH_SHAPE),
 ),
 StudentT(
     mu=mx.nd.zeros(shape=BATCH_SHAPE),
     sigma=mx.nd.ones(shape=BATCH_SHAPE),
     nu=mx.nd.ones(shape=BATCH_SHAPE),
 ),
 Dirichlet(alpha=mx.nd.ones(shape=BATCH_SHAPE)),
 Laplace(
     mu=mx.nd.zeros(shape=BATCH_SHAPE), b=mx.nd.ones(shape=BATCH_SHAPE)
 ),
 NegativeBinomial(