def test_mixture_inference() -> None:
    mdo = MixtureDistributionOutput([GaussianOutput(), GaussianOutput()])

    args_proj = mdo.get_args_proj()
    args_proj.initialize()
    args_proj.hybridize()

    input = mx.nd.ones((BATCH_SIZE, 1))

    distr_args = args_proj(input)
    d = mdo.distribution(distr_args)

    # plot_samples(d.sample())

    trainer = mx.gluon.Trainer(
        args_proj.collect_params(), "sgd", {"learning_rate": 0.02}
    )

    mixture_samples = mx.nd.array(np_samples)

    # fit the projection parameters by SGD on the mixture negative log-likelihood
    N = 1000
    t = tqdm(list(range(N)))
    for i in t:
        with mx.autograd.record():
            distr_args = args_proj(input)
            d = mdo.distribution(distr_args)
            loss = d.loss(mixture_samples)
        loss.backward()
        loss_value = loss.mean().asnumpy()
        t.set_postfix({"loss": loss_value})
        trainer.step(BATCH_SIZE)

    # compare the sample histogram of the fitted mixture against the expected one
    distr_args = args_proj(input)
    d = mdo.distribution(distr_args)
    obtained_hist = histogram(d.sample().asnumpy())

    # uncomment to see histograms
    # pl.plot(obtained_hist)
    # pl.plot(EXPECTED_HIST)
    # pl.show()
    assert diff(obtained_hist, EXPECTED_HIST) < 0.5


def test_mixture_output(distribution_outputs, serialize_fn) -> None:
    mdo = MixtureDistributionOutput(*distribution_outputs)

    args_proj = mdo.get_args_proj()
    args_proj.initialize()

    input = mx.nd.ones(shape=(512, 30))

    distr_args = args_proj(input)
    d = mdo.distribution(distr_args)
    d = serialize_fn(d)

    samples = d.sample(num_samples=NUM_SAMPLES)
    sample = d.sample()
    assert samples.shape == (NUM_SAMPLES, *sample.shape)

    log_prob = d.log_prob(sample)
    assert log_prob.shape == d.batch_shape