def test_nanmixture_gaussian_inference() -> None:
    """Fit a NaN-mixture of Gaussians to ``np_samples`` and check recovery.

    The argument projection of ``NanMixtureOutput(GaussianOutput())`` is
    trained with plain SGD on the loss returned by ``d.loss`` for the
    module-level ``np_samples``. The distribution is then rebuilt from the
    trained projection. Its ``mu``, ``sigma`` and ``nan_prob`` must each lie
    within ``TOL`` of the module-level values of the same names. Those values
    are presumably the parameters used to generate ``np_samples``; this is
    defined elsewhere in the module, so confirm there.
    """
    nmdo = NanMixtureOutput(GaussianOutput())

    args_proj = nmdo.get_args_proj()
    args_proj.initialize()
    args_proj.hybridize()

    # Constant all-ones input: the projected distribution arguments depend
    # only on the projection's weights, which are what SGD fits below.
    proj_input = mx.nd.ones((NUM_SAMPLES,))

    trainer = mx.gluon.Trainer(
        args_proj.collect_params(), "sgd", {"learning_rate": 0.00001}
    )

    mixture_samples = mx.nd.array(np_samples)

    N = 1000
    t = tqdm(range(N))
    for _ in t:
        with mx.autograd.record():
            distr_args = args_proj(proj_input)
            d = nmdo.distribution(distr_args)
            loss = d.loss(mixture_samples)
        loss.backward()
        loss_value = loss.mean().asnumpy()
        t.set_postfix({"loss": loss_value})
        # Trainer.step(batch_size) rescales gradients by 1 / batch_size.
        # This assumes ``loss`` holds one entry per sample (TODO confirm).
        trainer.step(NUM_SAMPLES)

    # Re-project with the final weights. The ``d`` built inside the loop
    # predates the last ``trainer.step``, and it would be unbound if N == 0.
    distr_args = args_proj(proj_input)
    d = nmdo.distribution(distr_args)

    mu_hat = d.distribution.mu.asnumpy()
    sigma_hat = d.distribution.sigma.asnumpy()
    nan_prob_hat = d.nan_prob.asnumpy()

    assert (
        np.abs(mu - mu_hat) < TOL
    ), f"mu did not match: mu = {mu}, mu_hat = {mu_hat}"
    assert (
        np.abs(sigma - sigma_hat) < TOL
    ), f"sigma did not match: sigma = {sigma}, sigma_hat = {sigma_hat}"
    assert (
        np.abs(nan_prob - nan_prob_hat) < TOL
    ), f"nan_prob did not match: nan_prob = {nan_prob}, nan_prob_hat = {nan_prob_hat}"
def test_nanmixture_categorical_inference() -> None:
    """Recover categorical and NaN probabilities from ``cat_samples`` via SGD.

    Trains the argument projection of ``NanMixtureOutput(CategoricalOutput(3))``
    for 3000 SGD steps (learning rate 2e-6). The objective is the loss the
    distribution reports for the module-level ``cat_samples``. The test then
    rebuilds the distribution from the trained projection. Its ``probs`` must
    match the module-level ``cat_probs`` within ``TOL``, and its ``nan_prob``
    must match the module-level ``nan_prob`` within ``TOL``.
    """
    nan_mixture_output = NanMixtureOutput(CategoricalOutput(3))

    projection = nan_mixture_output.get_args_proj()
    projection.initialize()
    projection.hybridize()

    # All-ones input: only the projection weights influence its output.
    ones_input = mx.nd.ones((NUM_SAMPLES,))
    observations = mx.nd.array(cat_samples)

    trainer = mx.gluon.Trainer(
        projection.collect_params(), "sgd", {"learning_rate": 0.000002}
    )

    num_steps = 3000
    progress = tqdm(list(range(num_steps)))
    for _ in progress:
        with mx.autograd.record():
            step_loss = nan_mixture_output.distribution(
                projection(ones_input)
            ).loss(observations)
        step_loss.backward()
        progress.set_postfix({"loss": step_loss.mean().asnumpy()})
        # Gradients are normalised by 1 / NUM_SAMPLES inside Trainer.step.
        trainer.step(NUM_SAMPLES)

    # Build a fresh distribution from the fully trained projection.
    fitted = nan_mixture_output.distribution(projection(ones_input))
    cat_probs_hat = fitted.distribution.probs.asnumpy()
    nan_prob_hat = fitted.nan_prob.asnumpy()

    assert np.allclose(
        cat_probs, cat_probs_hat, atol=TOL
    ), f"categorical dist: cat_probs did not match: cat_probs = {cat_probs}, cat_probs_hat = {cat_probs_hat}"
    assert (
        np.abs(nan_prob - nan_prob_hat) < TOL
    ), f"categorical dist: nan_prob did not match: nan_prob = {nan_prob}, nan_prob_hat = {nan_prob_hat}"
def test_nanmixture_output(distribution_output, serialize_fn) -> None:
    """Check sampling and log-prob shapes of a NaN-mixture distribution.

    Wraps ``distribution_output`` in ``NanMixtureOutput`` and projects a
    constant ``(3, 2)`` ones input to distribution arguments. The resulting
    distribution is passed through ``serialize_fn``, presumably a
    serialization round-trip supplied by the test parametrization (confirm
    at the call site). The test then asserts two things:

    * drawing ``NUM_SAMPLES`` samples yields a leading ``NUM_SAMPLES`` axis
      in front of a single draw's shape;
    * ``log_prob`` of a single draw has the distribution's ``batch_shape``.
    """
    mixture_output = NanMixtureOutput(distribution_output)
    projection = mixture_output.get_args_proj()
    projection.initialize()

    nan_mixture = serialize_fn(
        mixture_output.distribution(projection(mx.nd.ones(shape=(3, 2))))
    )

    # Draw the batch of samples first, then a single draw.
    many_draws = nan_mixture.sample(num_samples=NUM_SAMPLES)
    one_draw = nan_mixture.sample()
    assert many_draws.shape == (NUM_SAMPLES, *one_draw.shape)

    assert nan_mixture.log_prob(one_draw).shape == nan_mixture.batch_shape