def test_mixture_logprob(
    distribution: Distribution,
    values_outside_support: Tensor,
    distribution_output: DistributionOutput,
) -> None:
    assert np.all(
        ~np.isnan(distribution.log_prob(values_outside_support).asnumpy())
    ), f"{distribution} should return -inf log_probs instead of NaNs"

    p = 0.5
    gaussian = Gaussian(mu=mx.nd.array([0]), sigma=mx.nd.array([2.0]))
    mixture = MixtureDistribution(
        mixture_probs=mx.nd.array([[p, 1 - p]]),
        components=[gaussian, distribution],
    )
    lp = mixture.log_prob(values_outside_support)
    assert np.allclose(
        lp.asnumpy(),
        np.log(p) + gaussian.log_prob(values_outside_support).asnumpy(),
        atol=1e-6,
    ), "log_prob(x) should be equal to log(p) + gaussian.log_prob(x)"

    # A few epochs of fitting must not yield NaN parameters, which would
    # indicate NaN gradients coming from the values outside the support.
    fit_mixture = fit_mixture_distribution(
        values_outside_support,
        MixtureDistributionOutput([GaussianOutput(), distribution_output]),
        variate_dimensionality=1,
        epochs=3,
    )
    for c in fit_mixture.components:
        for a in c.args:
            assert not np.any(np.isnan(a.asnumpy())), f"NaN gradients led to {c}"
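
# Why the closed-form check above holds: the mixture density is
#     p_mix(x) = p * p_gauss(x) + (1 - p) * p_comp(x),
# and p_comp(x) = 0 outside the component's support, so at such x
#     log p_mix(x) = log(p) + log p_gauss(x).
# A minimal, self-contained numeric check of that algebra (a sketch with
# hypothetical names, plain numpy only; not part of the original suite):
def _mixture_logprob_identity_sketch() -> None:
    p, mu, sigma = 0.5, 0.0, 2.0
    x = -1.0  # a point where the second component has zero density
    gauss_pdf = np.exp(-((x - mu) ** 2) / (2 * sigma**2)) / (
        sigma * np.sqrt(2 * np.pi)
    )
    mix_pdf = p * gauss_pdf + (1 - p) * 0.0  # second component contributes 0
    assert np.isclose(np.log(mix_pdf), np.log(p) + np.log(gauss_pdf))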
def test_mixture_output(distribution_outputs, serialize_fn) -> None:
    mdo = MixtureDistributionOutput(*distribution_outputs)

    args_proj = mdo.get_args_proj()
    args_proj.initialize()

    net_input = mx.nd.ones(shape=(512, 30))
    distr_args = args_proj(net_input)
    d = mdo.distribution(distr_args)
    d = serialize_fn(d)

    samples = d.sample(num_samples=NUM_SAMPLES)
    sample = d.sample()
    assert samples.shape == (NUM_SAMPLES, *sample.shape)

    log_prob = d.log_prob(sample)
    assert log_prob.shape == d.batch_shape
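
# Shape bookkeeping exercised above: for a distribution with batch shape B
# and event shape E, `sample()` returns shape B + E, `sample(num_samples=k)`
# prepends a sample axis giving (k,) + B + E, and `log_prob` reduces the
# event dims, leaving shape B. A minimal numpy sketch of that contract
# (hypothetical names, for illustration only):
def _sample_shape_contract_sketch() -> None:
    batch_shape, event_shape, k = (512,), (), 7
    sample = np.zeros(batch_shape + event_shape)
    samples = np.zeros((k,) + batch_shape + event_shape)
    log_prob = np.zeros(batch_shape)
    assert samples.shape == (k, *sample.shape)
    assert log_prob.shape == batch_shape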
def test_mixture_inference() -> None:
    mdo = MixtureDistributionOutput([GaussianOutput(), GaussianOutput()])

    args_proj = mdo.get_args_proj()
    args_proj.initialize()
    args_proj.hybridize()

    net_input = mx.nd.ones((BATCH_SIZE, 1))
    distr_args = args_proj(net_input)
    d = mdo.distribution(distr_args)
    # plot_samples(d.sample())

    trainer = mx.gluon.Trainer(
        args_proj.collect_params(), "sgd", {"learning_rate": 0.02}
    )

    mixture_samples = mx.nd.array(np_samples)

    # Fit the projection parameters by minimizing the negative
    # log-likelihood of the samples with SGD.
    N = 1000
    t = tqdm(list(range(N)))
    for _ in t:
        with mx.autograd.record():
            distr_args = args_proj(net_input)
            d = mdo.distribution(distr_args)
            loss = d.loss(mixture_samples)
        loss.backward()
        loss_value = loss.mean().asnumpy()
        t.set_postfix({"loss": loss_value})
        trainer.step(BATCH_SIZE)

    distr_args = args_proj(net_input)
    d = mdo.distribution(distr_args)

    obtained_hist = histogram(d.sample().asnumpy())
    # uncomment to see histograms
    # pl.plot(obtained_hist)
    # pl.plot(EXPECTED_HIST)
    # pl.show()
    assert diff(obtained_hist, EXPECTED_HIST) < 0.5
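
# `histogram`, `diff`, `np_samples`, `EXPECTED_HIST`, and `BATCH_SIZE` are
# module-level fixtures defined earlier in this file. A plausible minimal
# shape for the two helpers, consistent with their use above (an assumption
# for illustration, not the original definitions):
#
#     def histogram(samples: np.ndarray) -> np.ndarray:
#         h, _ = np.histogram(samples, bins=100, density=True)
#         return h
#
#     def diff(x: np.ndarray, y: np.ndarray) -> float:
#         return float(np.mean(np.abs(x - y)))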
    d = mdo.distribution(distr_args)
    return d


@pytest.mark.parametrize(
    "mixture_distribution, mixture_distribution_output, epochs",
    [
        (
            MixtureDistribution(
                mixture_probs=mx.nd.array([[0.6, 0.4]]),
                components=[
                    Gaussian(mu=mx.nd.array([-1.0]), sigma=mx.nd.array([0.2])),
                    Gamma(alpha=mx.nd.array([2.0]), beta=mx.nd.array([0.5])),
                ],
            ),
            MixtureDistributionOutput([GaussianOutput(), GammaOutput()]),
            2_000,
        ),
        (
            MixtureDistribution(
                mixture_probs=mx.nd.array([[0.7, 0.3]]),
                components=[
                    Gaussian(mu=mx.nd.array([-1.0]), sigma=mx.nd.array([0.2])),
                    GenPareto(xi=mx.nd.array([0.6]), beta=mx.nd.array([1.0])),
                ],
            ),
            MixtureDistributionOutput([GaussianOutput(), GenParetoOutput()]),
            2_000,
        ),
    ],
)
        mx.nd.random.normal(shape=(3, 4, 5, 6)),
        [None, mx.nd.ones(shape=(3, 4, 5))],
        [None, mx.nd.ones(shape=(3, 4, 5))],
        (3, 4, 5),
        (),
    ),
    (
        PiecewiseLinearOutput(num_pieces=3),
        mx.nd.random.normal(shape=(3, 4, 5, 6)),
        [None, mx.nd.ones(shape=(3, 4, 5))],
        [None, mx.nd.ones(shape=(3, 4, 5))],
        (3, 4, 5),
        (),
    ),
    (
        MixtureDistributionOutput([GaussianOutput(), StudentTOutput()]),
        mx.nd.random.normal(shape=(3, 4, 5, 6)),
        [None, mx.nd.ones(shape=(3, 4, 5))],
        [None, mx.nd.ones(shape=(3, 4, 5))],
        (3, 4, 5),
        (),
    ),
    (
        MixtureDistributionOutput(
            [
                MultivariateGaussianOutput(dim=5),
                MultivariateGaussianOutput(dim=5),
            ]
        ),
        mx.nd.random.normal(shape=(3, 4, 10)),
        [None, mx.nd.ones(shape=(3, 4, 5))],
        [None, mx.nd.ones(shape=(3, 4, 5))],
        (3, 4),