コード例 #1
0
def test_mixture_logprob(
    distribution: Distribution,
    values_outside_support: Tensor,
    distribution_output: DistributionOutput,
) -> None:
    """Check that a mixture stays well-behaved on the given values.

    Three properties are verified:

    1. ``distribution.log_prob(values_outside_support)`` contains no NaNs.
    2. For an equal-weight mixture of a Gaussian and ``distribution``, the
       mixture log-probability equals ``log(p) + gaussian.log_prob(x)``
       within ``atol=1e-6``.
    3. After fitting a Gaussian/``distribution_output`` mixture on the same
       values for 3 epochs, none of the component arguments is NaN.

    Parameters
    ----------
    distribution
        Component distribution under test.
    values_outside_support
        Values to evaluate; presumably outside the support of
        ``distribution`` (inferred from the name and the first assertion
        message -- confirm against the parametrization).
    distribution_output
        Output class paired with a ``GaussianOutput`` when fitting.
    """
    assert np.all(
        ~np.isnan(distribution.log_prob(values_outside_support).asnumpy())
    ), f"{distribution} should return -inf log_probs instead of NaNs"

    p = 0.5
    gaussian = Gaussian(mu=mx.nd.array([0]), sigma=mx.nd.array([2.0]))
    mixture = MixtureDistribution(
        mixture_probs=mx.nd.array([[p, 1 - p]]),
        components=[gaussian, distribution],
    )
    lp = mixture.log_prob(values_outside_support)
    assert np.allclose(
        lp.asnumpy(),
        np.log(p) + gaussian.log_prob(values_outside_support).asnumpy(),
        atol=1e-6,
    ), "log_prob(x) should be equal to log(p)+gaussian.log_prob(x)"

    fit_mixture = fit_mixture_distribution(
        values_outside_support,
        MixtureDistributionOutput([GaussianOutput(), distribution_output]),
        variate_dimensionality=1,
        epochs=3,
    )
    for component in fit_mixture.components:
        for arg in component.args:
            # Reduce with .any() so the check also works for parameters with
            # more than one element; the truth value of a multi-element
            # boolean array is ambiguous and would raise instead of failing
            # with the intended message.
            assert not np.isnan(
                arg.asnumpy()
            ).any(), f"NaN gradients led to {component}"
コード例 #2
0
ファイル: test_mixture.py プロジェクト: szhengac/gluon-ts
def test_mixture_output(distribution_outputs, serialize_fn) -> None:
    """Check sample and log-prob shapes of a mixture distribution output.

    Builds a ``MixtureDistributionOutput`` from ``distribution_outputs``,
    projects an all-ones ``(512, 30)`` input to distribution arguments,
    passes the resulting distribution through ``serialize_fn`` and then
    asserts that:

    * drawing ``NUM_SAMPLES`` samples prepends one axis of that size to the
      shape of a single draw, and
    * ``log_prob`` of a single draw has the distribution's ``batch_shape``.
    """
    mixture_output = MixtureDistributionOutput(*distribution_outputs)

    projection = mixture_output.get_args_proj()
    projection.initialize()

    features = mx.nd.ones(shape=(512, 30))
    projected_args = projection(features)

    distr = serialize_fn(mixture_output.distribution(projected_args))

    # Keep the original draw order: the multi-sample draw comes first.
    batched_draws = distr.sample(num_samples=NUM_SAMPLES)
    single_draw = distr.sample()
    assert batched_draws.shape == (NUM_SAMPLES, *single_draw.shape)

    single_draw_log_prob = distr.log_prob(single_draw)
    assert single_draw_log_prob.shape == distr.batch_shape
コード例 #3
0
ファイル: test_mixture.py プロジェクト: szhengac/gluon-ts
def test_mixture_inference() -> None:
    """Fit a two-Gaussian mixture output with SGD and compare histograms.

    The argument projection of a ``MixtureDistributionOutput`` made of two
    ``GaussianOutput`` components is trained for 1000 steps (SGD, learning
    rate 0.02) on ``np_samples`` using ``d.loss``. Afterwards a sample from
    the fitted distribution is histogrammed and must differ from
    ``EXPECTED_HIST`` by less than 0.5 according to ``diff``.
    """
    mdo = MixtureDistributionOutput([GaussianOutput(), GaussianOutput()])

    args_proj = mdo.get_args_proj()
    args_proj.initialize()
    args_proj.hybridize()

    features = mx.nd.ones((BATCH_SIZE, 1))

    # Forward pass before the trainer is created, kept from the original.
    # NOTE(review): presumably this finalizes the projection's parameter
    # shapes -- confirm before removing.
    distr_args = args_proj(features)
    d = mdo.distribution(distr_args)

    # plot_samples(d.sample())

    trainer = mx.gluon.Trainer(
        args_proj.collect_params(), "sgd", {"learning_rate": 0.02}
    )

    mixture_samples = mx.nd.array(np_samples)

    num_steps = 1000
    # tqdm reads the total from len(range(...)); no list copy is needed.
    progress = tqdm(range(num_steps))
    for _ in progress:
        with mx.autograd.record():
            distr_args = args_proj(features)
            d = mdo.distribution(distr_args)
            loss = d.loss(mixture_samples)
        loss.backward()
        loss_value = loss.mean().asnumpy()
        progress.set_postfix({"loss": loss_value})
        trainer.step(BATCH_SIZE)

    # Rebuild the distribution from the trained projection before sampling.
    distr_args = args_proj(features)
    d = mdo.distribution(distr_args)

    obtained_hist = histogram(d.sample().asnumpy())

    # uncomment to see histograms
    # pl.plot(obtained_hist)
    # pl.plot(EXPECTED_HIST)
    # pl.show()
    assert diff(obtained_hist, EXPECTED_HIST) < 0.5
コード例 #4
0
    d = mdo.distribution(distr_args)
    return d


@pytest.mark.parametrize(
    "mixture_distribution, mixture_distribution_output, epochs",
    [
        (
            MixtureDistribution(
                mixture_probs=mx.nd.array([[0.6, 0.4]]),
                components=[
                    Gaussian(mu=mx.nd.array([-1.0]), sigma=mx.nd.array([0.2])),
                    Gamma(alpha=mx.nd.array([2.0]), beta=mx.nd.array([0.5])),
                ],
            ),
            MixtureDistributionOutput([GaussianOutput(), GammaOutput()]),
            2_000,
        ),
        (
            MixtureDistribution(
                mixture_probs=mx.nd.array([[0.7, 0.3]]),
                components=[
                    Gaussian(mu=mx.nd.array([-1.0]), sigma=mx.nd.array([0.2])),
                    GenPareto(xi=mx.nd.array([0.6]), beta=mx.nd.array([1.0])),
                ],
            ),
            MixtureDistributionOutput([GaussianOutput(), GenParetoOutput()]),
            2_000,
        ),
    ],
)
コード例 #5
0
     mx.nd.random.normal(shape=(3, 4, 5, 6)),
     [None, mx.nd.ones(shape=(3, 4, 5))],
     [None, mx.nd.ones(shape=(3, 4, 5))],
     (3, 4, 5),
     (),
 ),
 (
     PiecewiseLinearOutput(num_pieces=3),
     mx.nd.random.normal(shape=(3, 4, 5, 6)),
     [None, mx.nd.ones(shape=(3, 4, 5))],
     [None, mx.nd.ones(shape=(3, 4, 5))],
     (3, 4, 5),
     (),
 ),
 (
     MixtureDistributionOutput([GaussianOutput(),
                                StudentTOutput()]),
     mx.nd.random.normal(shape=(3, 4, 5, 6)),
     [None, mx.nd.ones(shape=(3, 4, 5))],
     [None, mx.nd.ones(shape=(3, 4, 5))],
     (3, 4, 5),
     (),
 ),
 (
     MixtureDistributionOutput([
         MultivariateGaussianOutput(dim=5),
         MultivariateGaussianOutput(dim=5),
     ]),
     mx.nd.random.normal(shape=(3, 4, 10)),
     [None, mx.nd.ones(shape=(3, 4, 5))],
     [None, mx.nd.ones(shape=(3, 4, 5))],
     (3, 4),