Example #1
import mxnet as mx
from tqdm import tqdm

# Import path for recent GluonTS releases; older versions expose these
# classes under gluonts.distribution instead.
from gluonts.mx.distribution import GaussianOutput, MixtureDistributionOutput


# Fits a two-component Gaussian mixture to reference samples by minimizing
# the negative log-likelihood, then compares sample histograms.
def test_mixture_inference() -> None:
    mdo = MixtureDistributionOutput([GaussianOutput(), GaussianOutput()])

    # Gluon block projecting the network output onto the mixture parameters
    # (component weights plus the args of each component distribution).
    args_proj = mdo.get_args_proj()
    args_proj.initialize()
    args_proj.hybridize()

    # BATCH_SIZE is a module-level constant in the original test file.
    input = mx.nd.ones((BATCH_SIZE, 1))

    distr_args = args_proj(input)
    d = mdo.distribution(distr_args)

    # plot_samples(d.sample())

    trainer = mx.gluon.Trainer(
        args_proj.collect_params(), "sgd", {"learning_rate": 0.02}
    )

    # np_samples (draws from the target mixture) is built at module level in
    # the original test file.
    mixture_samples = mx.nd.array(np_samples)

    # Fit the mixture to the target samples by minimizing the negative
    # log-likelihood for N steps.
    N = 1000
    t = tqdm(list(range(N)))
    for i in t:
        with mx.autograd.record():
            distr_args = args_proj(input)
            d = mdo.distribution(distr_args)
            loss = d.loss(mixture_samples)
        loss.backward()
        loss_value = loss.mean().asnumpy()
        t.set_postfix({"loss": loss_value})
        trainer.step(BATCH_SIZE)

    distr_args = args_proj(input)
    d = mdo.distribution(distr_args)

    # histogram, diff and EXPECTED_HIST are module-level helpers in the
    # original test file.
    obtained_hist = histogram(d.sample().asnumpy())

    # uncomment to see histograms
    # pl.plot(obtained_hist)
    # pl.plot(EXPECTED_HIST)
    # pl.show()
    assert diff(obtained_hist, EXPECTED_HIST) < 0.5
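The test above relies on several module-level fixtures that this listing omits. Below is a minimal sketch of what they could look like; the constant values, the bin grid, and the target mixture are illustrative assumptions, not the original definitions.

import numpy as np

# Hypothetical fixtures; the original test module defines its own versions.
BATCH_SIZE = 2000
BINS = np.linspace(-5, 5, 100)

# Target data: draws from a two-component Gaussian mixture, which the SGD
# loop above tries to recover.
np_samples = np.concatenate(
    [
        np.random.normal(loc=-2.0, scale=0.5, size=BATCH_SIZE // 2),
        np.random.normal(loc=2.0, scale=1.0, size=BATCH_SIZE - BATCH_SIZE // 2),
    ]
)


def histogram(samples: np.ndarray) -> np.ndarray:
    # Normalized histogram on a fixed bin grid, so that two histograms are
    # directly comparable bin by bin.
    h, _ = np.histogram(samples, bins=BINS, density=True)
    return h


def diff(x: np.ndarray, y: np.ndarray) -> float:
    # Mean absolute difference between two histograms.
    return float(np.mean(np.abs(x - y)))


EXPECTED_HIST = histogram(np_samples)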
Example #2
import mxnet as mx

from gluonts.mx.distribution import MixtureDistributionOutput


# Checks that sampling and log_prob of a mixture distribution produce the
# expected shapes. distribution_outputs is injected by a
# pytest.mark.parametrize decorator in the original test module (see the
# sketch after this example).
def test_mixture_output(distribution_outputs) -> None:
    mdo = MixtureDistributionOutput(*distribution_outputs)

    args_proj = mdo.get_args_proj()
    args_proj.initialize()

    input = mx.nd.ones(shape=(512, 30))

    distr_args = args_proj(input)
    d = mdo.distribution(distr_args)

    # NUM_SAMPLES is a module-level constant in the original test file.
    samples = d.sample(num_samples=NUM_SAMPLES)

    sample = d.sample()

    # Passing num_samples prepends a sample axis to the batch shape.
    assert samples.shape == (NUM_SAMPLES, *sample.shape)

    log_prob = d.log_prob(sample)

    # log_prob is computed per batch element.
    assert log_prob.shape == d.batch_shape
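For context, here is a sketch of how the test above could be parametrized; the concrete component combinations are illustrative assumptions, not the original parametrization.

import pytest
from gluonts.mx.distribution import (
    GaussianOutput,
    LaplaceOutput,
    StudentTOutput,
)

NUM_SAMPLES = 1000  # assumed value for the module-level constant


# Each entry is a 1-tuple so that *distribution_outputs in the test body
# unpacks to the tuple of component outputs.
@pytest.mark.parametrize(
    "distribution_outputs",
    [
        ((GaussianOutput(), GaussianOutput()),),
        ((GaussianOutput(), StudentTOutput(), LaplaceOutput()),),
    ],
)
def test_mixture_output(distribution_outputs) -> None:
    ...  # body as in Example #2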