def test_deterministic_output(value: float, model_output_shape):
    """Check that a ``DeterministicOutput`` distribution always yields ``value``.

    A projection is built from ``DeterministicOutput(value)``, initialized,
    applied to an all-ones input of ``model_output_shape``, and the resulting
    arguments are turned into a distribution. Both a single draw and a batch
    of 10 draws must have the expected shape, equal ``value`` everywhere, and
    have probability 1.

    Args:
        value: The constant the distribution is expected to produce.
        model_output_shape: Shape tuple of the input fed to the projection
            (it must be a tuple, since ``(10, ) + model_output_shape[:-1]`` is
            used). The expected sample shape is this shape without its last
            axis -- presumably the feature axis consumed by the projection;
            TODO confirm.
    """
    do = DeterministicOutput(value)
    x = mx.nd.ones(model_output_shape)

    args_proj = do.get_args_proj()
    args_proj.initialize()
    args = args_proj(x)
    distr = do.distribution(args)

    batch_shape = model_output_shape[:-1]

    # NDArray ``==`` broadcasts its operands, so a sample with a wrong but
    # broadcast-compatible shape (e.g. a single element) would still compare
    # equal. Pin the shape explicitly before checking the values.
    s = distr.sample()
    assert s.shape == batch_shape
    assert (s == value * mx.nd.ones(shape=batch_shape)).asnumpy().all()
    assert (distr.prob(s) == 1.0).asnumpy().all()

    # Requesting 10 samples prepends a sample axis of length 10.
    samples_shape = (10, ) + batch_shape
    s10 = distr.sample(10)
    assert s10.shape == samples_shape
    assert (s10 == value * mx.nd.ones(shape=samples_shape)).asnumpy().all()
    assert (distr.prob(s10) == 1.0).asnumpy().all()
            mx.nd.random.normal(shape=(3, 4, 10)),
            [None, mx.nd.ones(shape=(3, 4, 5))],
            [None, mx.nd.ones(shape=(3, 4, 5))],
            (3, 4),
            (5, ),
        ),
        (
            PoissonOutput(),
            mx.nd.random.normal(shape=(3, 4, 5, 6)),
            [None],
            [None, mx.nd.ones(shape=(3, 4, 5))],
            (3, 4, 5),
            (),
        ),
        (
            DeterministicOutput(42.0),
            mx.nd.random.normal(shape=(3, 4, 5, 6)),
            [None, mx.nd.ones(shape=(3, 4, 5))],
            [None, mx.nd.ones(shape=(3, 4, 5))],
            (3, 4, 5),
            (),
        ),
    ],
)
def test_distribution_output_shapes(
    distr_out: DistributionOutput,
    data: Tensor,
    loc: List[Union[None, Tensor]],
    scale: List[Union[None, Tensor]],
    expected_batch_shape: Tuple,
    expected_event_shape: Tuple,
Example #3
0
    PoissonOutput,
    StudentTOutput,
    UniformOutput,
    ZeroAndOneInflatedBetaOutput,
    ZeroInflatedBetaOutput,
    ZeroInflatedNegativeBinomialOutput,
    ZeroInflatedPoissonOutput,
)


@pytest.mark.parametrize(
    "distr_output",
    [
        BetaOutput(),
        CategoricalOutput(num_cats=3),
        DeterministicOutput(value=42.0),
        DirichletMultinomialOutput(dim=3, n_trials=5),
        DirichletOutput(dim=4),
        EmpiricalDistributionOutput(num_samples=10,
                                    distr_output=GaussianOutput()),
        GammaOutput(),
        GaussianOutput(),
        GenParetoOutput(),
        LaplaceOutput(),
        LogitNormalOutput(),
        LoglogisticOutput(),
        LowrankMultivariateGaussianOutput(dim=5, rank=2),
        MultivariateGaussianOutput(dim=4),
        NegativeBinomialOutput(),
        OneInflatedBetaOutput(),
        PiecewiseLinearOutput(num_pieces=10),