bin_log_probs=mx.nd.uniform(shape=BATCH_SHAPE + (23, )),
         bin_centers=mx.nd.array(np.logspace(-1, 1, 23)) +
         mx.nd.zeros(BATCH_SHAPE + (23, )),
     ),
     [
         bij.AffineTransformation(
             scale=1e-1 + mx.nd.random.uniform(shape=BATCH_SHAPE)),
         bij.softrelu,
     ],
 ),
 Gaussian(
     mu=mx.nd.zeros(shape=BATCH_SHAPE),
     sigma=mx.nd.ones(shape=BATCH_SHAPE),
 ),
 Gamma(
     alpha=mx.nd.ones(shape=BATCH_SHAPE),
     beta=mx.nd.ones(shape=BATCH_SHAPE),
 ),
 Beta(
     alpha=0.5 * mx.nd.ones(shape=BATCH_SHAPE),
     beta=0.5 * mx.nd.ones(shape=BATCH_SHAPE),
 ),
 StudentT(
     mu=mx.nd.zeros(shape=BATCH_SHAPE),
     sigma=mx.nd.ones(shape=BATCH_SHAPE),
     nu=mx.nd.ones(shape=BATCH_SHAPE),
 ),
 Dirichlet(alpha=mx.nd.ones(shape=BATCH_SHAPE)),
 Laplace(mu=mx.nd.zeros(shape=BATCH_SHAPE),
         b=mx.nd.ones(shape=BATCH_SHAPE)),
 NegativeBinomial(
     mu=mx.nd.zeros(shape=BATCH_SHAPE),
Example #2
0
        trainer.step(1)

    distr_args = args_proj(input)
    d = mdo.distribution(distr_args)
    return d


@pytest.mark.parametrize(
    "mixture_distribution, mixture_distribution_output, epochs",
    [
        (
            MixtureDistribution(
                mixture_probs=mx.nd.array([[0.6, 0.4]]),
                components=[
                    Gaussian(mu=mx.nd.array([-1.0]), sigma=mx.nd.array([0.2])),
                    Gamma(alpha=mx.nd.array([2.0]), beta=mx.nd.array([0.5])),
                ],
            ),
            MixtureDistributionOutput([GaussianOutput(),
                                       GammaOutput()]),
            2_000,
        ),
        (
            MixtureDistribution(
                mixture_probs=mx.nd.array([[0.7, 0.3]]),
                components=[
                    Gaussian(mu=mx.nd.array([-1.0]), sigma=mx.nd.array([0.2])),
                    GenPareto(xi=mx.nd.array([0.6]), beta=mx.nd.array([1.0])),
                ],
            ),
            MixtureDistributionOutput([GaussianOutput(),
Example #3
0

@pytest.mark.parametrize(
    "distr, expected_batch_shape, expected_event_shape",
    [
        (
            Gaussian(
                mu=mx.nd.zeros(shape=(3, 4, 5)),
                sigma=mx.nd.ones(shape=(3, 4, 5)),
            ),
            (3, 4, 5),
            (),
        ),
        (
            Gamma(
                alpha=mx.nd.ones(shape=(3, 4, 5)),
                beta=mx.nd.ones(shape=(3, 4, 5)),
            ),
            (3, 4, 5),
            (),
        ),
        (
            Beta(
                alpha=mx.nd.ones(shape=(3, 4, 5)),
                beta=mx.nd.ones(shape=(3, 4, 5)),
            ),
            (3, 4, 5),
            (),
        ),
        (
            StudentT(
                mu=mx.nd.zeros(shape=(3, 4, 5)),