Example #1
def test_nanmixture_gaussian_inference() -> None:
    # Wrap a Gaussian output layer in NanMixtureOutput: NaN observations are
    # modelled by a separate missing-value probability (nan_prob).
    nmdo = NanMixtureOutput(GaussianOutput())

    args_proj = nmdo.get_args_proj()
    args_proj.initialize()
    args_proj.hybridize()

    # Dummy network output; the projection maps it to distribution arguments.
    input = mx.nd.ones((NUM_SAMPLES,))

    trainer = mx.gluon.Trainer(
        args_proj.collect_params(), "sgd", {"learning_rate": 0.00001}
    )

    # np_samples: Gaussian draws with a fraction of entries set to NaN
    # (module-level fixture; see the sketch after this example).
    mixture_samples = mx.nd.array(np_samples)

    # Fit the projection by minimizing the negative log-likelihood with SGD.
    N = 1000
    t = tqdm(list(range(N)))
    for _ in t:
        with mx.autograd.record():
            distr_args = args_proj(input)
            d = nmdo.distribution(distr_args)
            loss = d.loss(mixture_samples)
        loss.backward()

        loss_value = loss.mean().asnumpy()
        t.set_postfix({"loss": loss_value})
        trainer.step(NUM_SAMPLES)

    # Read the fitted parameters from the last distribution object and compare
    # them with the values used to generate the samples.
    mu_hat = d.distribution.mu.asnumpy()
    sigma_hat = d.distribution.sigma.asnumpy()
    nan_prob_hat = d.nan_prob.asnumpy()

    assert (
        np.abs(mu - mu_hat) < TOL
    ), f"mu did not match: mu = {mu}, mu_hat = {mu_hat}"
    assert (
        np.abs(sigma - sigma_hat) < TOL
    ), f"sigma did not match: sigma = {sigma}, sigma_hat = {sigma_hat}"
    assert (
        np.abs(nan_prob - nan_prob_hat) < TOL
    ), f"nan_prob did not match: nan_prob = {nan_prob}, nan_prob_hat = {nan_prob_hat}"
Example #2
def test_nanmixture_categorical_inference() -> None:
    # Same setup as the Gaussian case, but the wrapped distribution is a
    # three-class categorical.
    nmdo = NanMixtureOutput(CategoricalOutput(3))

    args_proj = nmdo.get_args_proj()
    args_proj.initialize()
    args_proj.hybridize()

    input = mx.nd.ones((NUM_SAMPLES,))

    trainer = mx.gluon.Trainer(
        args_proj.collect_params(), "sgd", {"learning_rate": 0.000002}
    )

    # cat_samples: categorical draws with a fraction of entries set to NaN
    # (module-level fixture; see the sketch after this example).
    mixture_samples = mx.nd.array(cat_samples)

    # Fit by minimizing the negative log-likelihood; this case uses more
    # iterations and a smaller learning rate than the Gaussian test.
    N = 3000
    t = tqdm(list(range(N)))
    for _ in t:
        with mx.autograd.record():
            distr_args = args_proj(input)
            d = nmdo.distribution(distr_args)
            loss = d.loss(mixture_samples)
        loss.backward()
        loss_value = loss.mean().asnumpy()
        t.set_postfix({"loss": loss_value})
        trainer.step(NUM_SAMPLES)

    # Re-project once more outside autograd to read off the fitted parameters.
    distr_args = args_proj(input)
    d = nmdo.distribution(distr_args)

    cat_probs_hat = d.distribution.probs.asnumpy()
    nan_prob_hat = d.nan_prob.asnumpy()

    assert np.allclose(
        cat_probs, cat_probs_hat, atol=TOL
    ), f"categorical dist: cat_probs did not match: cat_probs = {cat_probs}, cat_probs_hat = {cat_probs_hat}"
    assert (
        np.abs(nan_prob - nan_prob_hat) < TOL
    ), f"categorical dist: nan_prob did not match: nan_prob = {nan_prob}, nan_prob_hat = {nan_prob_hat}"
Example #3
def test_nanmixture_output(distribution_output, serialize_fn) -> None:
    # Shape checks for NanMixtureOutput; distribution_output and serialize_fn
    # are injected by pytest parametrization (see the sketch after this example).
    nmdo = NanMixtureOutput(distribution_output)

    args_proj = nmdo.get_args_proj()
    args_proj.initialize()

    input = mx.nd.ones(shape=(3, 2))

    distr_args = args_proj(input)

    d = nmdo.distribution(distr_args)
    d = serialize_fn(d)

    # Drawing num_samples prepends a sample axis to the shape of a single draw.
    samples = d.sample(num_samples=NUM_SAMPLES)

    sample = d.sample()

    assert samples.shape == (NUM_SAMPLES, *sample.shape)

    # log_prob is evaluated element-wise over the batch.
    log_prob = d.log_prob(sample)

    assert log_prob.shape == d.batch_shape
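
The distribution_output and serialize_fn arguments come from pytest parametrization in the original test module. A minimal sketch of how that could look, assuming the outputs from the earlier examples and simple round-trip serializers (identity, and pickle, assuming the distribution object pickles cleanly):

import pickle

import pytest

from gluonts.mx.distribution import CategoricalOutput, GaussianOutput


# Hypothetical parametrization; the original test module defines its own
# lists of distribution outputs and serialization helpers.
@pytest.mark.parametrize(
    "distribution_output", [GaussianOutput(), CategoricalOutput(3)]
)
@pytest.mark.parametrize(
    "serialize_fn", [lambda d: d, lambda d: pickle.loads(pickle.dumps(d))]
)
def test_nanmixture_output(distribution_output, serialize_fn) -> None:
    ...  # body as in Example #3 above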