def test_poisson_likelihood(rate: float, hybridize: bool) -> None:
    """
    Test that maximizing the Poisson likelihood recovers the true rate.

    Draws ``NUM_SAMPLES`` samples from a Poisson distribution with the given
    ``rate``. It then fits a ``PoissonOutput`` by SGD, starting from a rate
    offset from the truth by a relative amount ``START_TOL_MULTIPLE * TOL``.
    Finally, it asserts that the estimate lies within a relative tolerance of
    ``TOL`` of ``rate``.

    Parameters
    ----------
    rate
        True Poisson rate used to generate the samples.
    hybridize
        Forwarded unchanged to ``maximum_likelihood_estimate_sgd``.
    """
    # Generate samples: a constant-rate vector of length NUM_SAMPLES, so that
    # sampling yields one Poisson variate per entry.
    rates = mx.nd.zeros(NUM_SAMPLES) + rate

    distr = Poisson(rates)
    samples = distr.sample()

    # Start the optimizer away from the true rate so the test exercises real
    # learning rather than succeeding at initialization.
    # NOTE(review): inv_softplus presumably maps the rate into the
    # pre-activation bias space used by PoissonOutput — confirm.
    init_biases = [inv_softplus(rate - START_TOL_MULTIPLE * TOL * rate)]

    rate_hat = maximum_likelihood_estimate_sgd(
        PoissonOutput(),
        samples,
        init_biases=init_biases,
        hybridize=hybridize,
        learning_rate=PositiveFloat(0.05),
        num_epochs=PositiveInt(20),
    )

    print("rate:", rate_hat)
    # Relative-error check. rate_hat[0] is the first estimated parameter,
    # presumably the only one for PoissonOutput.
    assert np.abs(rate_hat[0] - rate) < TOL * rate, (
        f"rate did not match: rate = {rate}, rate_hat = {rate_hat}"
    )
예제 #2
0
     alpha=0.5 * mx.nd.ones(shape=BATCH_SHAPE),
     beta=0.5 * mx.nd.ones(shape=BATCH_SHAPE),
 ),
 StudentT(
     mu=mx.nd.zeros(shape=BATCH_SHAPE),
     sigma=mx.nd.ones(shape=BATCH_SHAPE),
     nu=mx.nd.ones(shape=BATCH_SHAPE),
 ),
 Dirichlet(alpha=mx.nd.ones(shape=BATCH_SHAPE)),
 Laplace(mu=mx.nd.zeros(shape=BATCH_SHAPE),
         b=mx.nd.ones(shape=BATCH_SHAPE)),
 NegativeBinomial(
     mu=mx.nd.zeros(shape=BATCH_SHAPE),
     alpha=mx.nd.ones(shape=BATCH_SHAPE),
 ),
 Poisson(rate=mx.nd.ones(shape=BATCH_SHAPE)),
 Uniform(
     low=-mx.nd.ones(shape=BATCH_SHAPE),
     high=mx.nd.ones(shape=BATCH_SHAPE),
 ),
 PiecewiseLinear(
     gamma=mx.nd.ones(shape=BATCH_SHAPE),
     slopes=mx.nd.ones(shape=(3, 4, 5, 10)),
     knot_spacings=mx.nd.ones(shape=(3, 4, 5, 10)) / 10,
 ),
 MixtureDistribution(
     mixture_probs=mx.nd.stack(
         0.2 * mx.nd.ones(shape=BATCH_SHAPE),
         0.8 * mx.nd.ones(shape=BATCH_SHAPE),
         axis=-1,
     ),
예제 #3
0
 ),
 (
     Laplace(mu=mx.nd.zeros(shape=(3, 4, 5)),
             b=mx.nd.ones(shape=(3, 4, 5))),
     (3, 4, 5),
     (),
 ),
 (
     NegativeBinomial(
         mu=mx.nd.zeros(shape=(3, 4, 5)),
         alpha=mx.nd.ones(shape=(3, 4, 5)),
     ),
     (3, 4, 5),
     (),
 ),
 (Poisson(rate=mx.nd.ones(shape=(3, 4, 5))), (3, 4, 5), ()),
 (
     Uniform(
         low=-mx.nd.ones(shape=(3, 4, 5)),
         high=mx.nd.ones(shape=(3, 4, 5)),
     ),
     (3, 4, 5),
     (),
 ),
 (
     PiecewiseLinear(
         gamma=mx.nd.ones(shape=(3, 4, 5)),
         slopes=mx.nd.ones(shape=(3, 4, 5, 10)),
         knot_spacings=mx.nd.ones(shape=(3, 4, 5, 10)) / 10,
     ),
     (3, 4, 5),