def test_mixture_logprob(
    distribution: Distribution,
    values_outside_support: Tensor,
    distribution_output: DistributionOutput,
) -> None:
    assert np.all(
        ~np.isnan(distribution.log_prob(values_outside_support).asnumpy())
    ), f"{distribution} should return -inf log_probs instead of NaNs"

    p = 0.5
    gaussian = Gaussian(mu=mx.nd.array([0]), sigma=mx.nd.array([2.0]))
    mixture = MixtureDistribution(
        mixture_probs=mx.nd.array([[p, 1 - p]]),
        components=[gaussian, distribution],
    )
    lp = mixture.log_prob(values_outside_support)

    # outside the second component's support, only the Gaussian contributes
    assert np.allclose(
        lp.asnumpy(),
        np.log(p) + gaussian.log_prob(values_outside_support).asnumpy(),
        atol=1e-6,
    ), "log_prob(x) should be equal to log(p) + gaussian.log_prob(x)"

    # fitting on out-of-support values must not produce NaN parameters
    fit_mixture = fit_mixture_distribution(
        values_outside_support,
        MixtureDistributionOutput([GaussianOutput(), distribution_output]),
        variate_dimensionality=1,
        epochs=3,
    )
    for ci, c in enumerate(fit_mixture.components):
        for ai, a in enumerate(c.args):
            assert not np.any(np.isnan(a.asnumpy())), f"NaN gradients led to {c}"
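
# Illustrative sketch, not part of the test suite: the assertion above relies
# on the mixture density p * f1(x) + (1 - p) * f2(x) collapsing to p * f1(x)
# when f2 has zero density (log_prob of -inf) outside its support.  The helper
# below is hypothetical and uses a standard-normal first component purely for
# illustration.
def _mixture_logprob_outside_support_sketch() -> None:
    p = 0.5
    x = -1.0  # pretend x lies outside the second component's support
    log_f1 = -0.5 * np.log(2 * np.pi) - 0.5 * x ** 2  # N(0, 1) log-density
    log_f2 = -np.inf  # zero density outside the support
    mixture_lp = np.logaddexp(np.log(p) + log_f1, np.log(1 - p) + log_f2)
    # the -inf term drops out, leaving log(p) + log f1(x)
    assert np.isclose(mixture_lp, np.log(p) + log_f1)
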
def test_mixture(
    distr1: Distribution, distr2: Distribution, p: Tensor, serialize_fn
) -> None:
    # sample from component distributions, and select samples
    samples1 = distr1.sample(num_samples=NUM_SAMPLES_LARGE)
    samples2 = distr2.sample(num_samples=NUM_SAMPLES_LARGE)

    # TODO: for multivariate case, test should not sample elements from
    # different components in the event_dim dimension
    rand = mx.nd.random.uniform(shape=(NUM_SAMPLES_LARGE, *p.shape))
    choice = (rand < p.expand_dims(axis=0)).broadcast_like(samples1)
    samples_ref = mx.nd.where(choice, samples1, samples2)

    # construct mixture distribution and sample from it
    mixture_probs = mx.nd.stack(p, 1.0 - p, axis=-1)
    mixture = MixtureDistribution(
        mixture_probs=mixture_probs, components=[distr1, distr2]
    )
    mixture = serialize_fn(mixture)

    samples_mix = mixture.sample(num_samples=NUM_SAMPLES_LARGE)

    # check that shapes are right
    assert (
        samples1.shape
        == samples2.shape
        == samples_mix.shape
        == samples_ref.shape
    )

    # check mean and stddev
    calc_mean = mixture.mean.asnumpy()
    calc_std = mixture.stddev.asnumpy()
    sample_mean = samples_mix.asnumpy().mean(axis=0)
    sample_std = samples_mix.asnumpy().std(axis=0)

    assert np.allclose(calc_mean, sample_mean, atol=1e-1)
    assert np.allclose(calc_std, sample_std, atol=2e-1)

    # check that histograms are close
    assert (
        diff(
            histogram(samples_mix.asnumpy()), histogram(samples_ref.asnumpy())
        )
        < 0.05
    )

    # the cdf can currently only be evaluated for Gaussian components
    if isinstance(distr1, Gaussian) and isinstance(distr2, Gaussian):
        emp_cdf, edges = empirical_cdf(samples_mix.asnumpy())
        calc_cdf = mixture.cdf(mx.nd.array(edges)).asnumpy()
        assert np.allclose(calc_cdf[1:, :], emp_cdf, atol=1e-2)
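
# Illustrative sketch, not part of the test suite: the Gaussian-only cdf check
# above works because a mixture cdf is the probability-weighted sum of the
# component cdfs, F_mix(x) = p * F1(x) + (1 - p) * F2(x).  The helper below is
# hypothetical; its component parameters and sample size are chosen purely for
# illustration.
def _mixture_cdf_sketch() -> None:
    from math import erf, sqrt

    def normal_cdf(x: float, mu: float, sigma: float) -> float:
        return 0.5 * (1.0 + erf((x - mu) / (sigma * sqrt(2.0))))

    p, x = 0.3, 0.5
    # mixture of N(0, 1) and N(2, 0.5): weighted sum of the component cdfs
    mix_cdf = p * normal_cdf(x, 0.0, 1.0) + (1 - p) * normal_cdf(x, 2.0, 0.5)

    # cross-check against the empirical cdf of samples drawn the same way
    # test_mixture builds samples_ref: pick component 1 with probability p
    rng = np.random.default_rng(0)
    pick_first = rng.uniform(size=100_000) < p
    draws = np.where(
        pick_first,
        rng.normal(0.0, 1.0, 100_000),
        rng.normal(2.0, 0.5, 100_000),
    )
    assert abs((draws <= x).mean() - mix_cdf) < 1e-2
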
def test_inference_mixture_different_families(
    mixture_distribution: MixtureDistribution,
    mixture_distribution_output: MixtureDistributionOutput,
    epochs: int,
    serialize_fn,
) -> None:
    # First sample from the mixture distribution, then confirm that the MLEs
    # are close to the true parameters.
    num_samples = 10_000
    samples = mixture_distribution.sample(num_samples=num_samples)
    variate_dimensionality = (
        mixture_distribution.components[0].args[0].shape[0]
    )
    fitted_dist = fit_mixture_distribution(
        samples,
        mixture_distribution_output,
        variate_dimensionality,
        epochs=epochs,
    )

    assert np.allclose(
        fitted_dist.mixture_probs.asnumpy(),
        mixture_distribution.mixture_probs.asnumpy(),
        atol=1e-1,
    ), (
        f"Mixing probability estimates {fitted_dist.mixture_probs.asnumpy()} "
        f"too far from {mixture_distribution.mixture_probs.asnumpy()}"
    )

    for ci, c in enumerate(mixture_distribution.components):
        for ai, a in enumerate(c.args):
            assert np.allclose(
                fitted_dist.components[ci].args[ai].asnumpy(),
                a.asnumpy(),
                atol=1e-1,
            ), (
                f"Parameter {ai} estimate "
                f"{fitted_dist.components[ci].args[ai].asnumpy()} "
                f"too far from {c}"
            )
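
# Illustrative sketch, not part of the test suite: the test above infers the
# variate dimensionality from the leading axis of the first parameter of the
# first mixture component (components[0].args[0].shape[0]).  The hypothetical
# helper below shows what that expression evaluates to for a two-dimensional
# Gaussian component; the parameter values are purely illustrative.
def _variate_dimensionality_sketch() -> None:
    g = Gaussian(mu=mx.nd.array([0.0, 1.0]), sigma=mx.nd.array([1.0, 2.0]))
    # args[0] is mu, whose leading axis has size 2
    assert g.args[0].shape[0] == 2
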
        t.set_postfix({"loss": loss_value})
        trainer.step(1)

    distr_args = args_proj(input)
    d = mdo.distribution(distr_args)
    return d


@pytest.mark.parametrize(
    "mixture_distribution, mixture_distribution_output, epochs",
    [
        (
            MixtureDistribution(
                mixture_probs=mx.nd.array([[0.6, 0.4]]),
                components=[
                    Gaussian(mu=mx.nd.array([-1.0]), sigma=mx.nd.array([0.2])),
                    Gamma(alpha=mx.nd.array([2.0]), beta=mx.nd.array([0.5])),
                ],
            ),
            MixtureDistributionOutput([GaussianOutput(), GammaOutput()]),
            2_000,
        ),
        (
            MixtureDistribution(
                mixture_probs=mx.nd.array([[0.7, 0.3]]),
                components=[
                    Gaussian(mu=mx.nd.array([-1.0]), sigma=mx.nd.array([0.2])),
                    GenPareto(xi=mx.nd.array([0.6]), beta=mx.nd.array([1.0])),
                ],
            ),
            MixtureDistributionOutput([GaussianOutput(), GenParetoOutput()]),
        high=mx.nd.ones(shape=BATCH_SHAPE),
    ),
    PiecewiseLinear(
        gamma=mx.nd.ones(shape=BATCH_SHAPE),
        slopes=mx.nd.ones(shape=(3, 4, 5, 10)),
        knot_spacings=mx.nd.ones(shape=(3, 4, 5, 10)) / 10,
    ),
    MixtureDistribution(
        mixture_probs=mx.nd.stack(
            0.2 * mx.nd.ones(shape=BATCH_SHAPE),
            0.8 * mx.nd.ones(shape=BATCH_SHAPE),
            axis=-1,
        ),
        components=[
            Gaussian(
                mu=mx.nd.zeros(shape=BATCH_SHAPE),
                sigma=mx.nd.ones(shape=BATCH_SHAPE),
            ),
            StudentT(
                mu=mx.nd.zeros(shape=BATCH_SHAPE),
                sigma=mx.nd.ones(shape=BATCH_SHAPE),
                nu=mx.nd.ones(shape=BATCH_SHAPE),
            ),
        ],
    ),
    TransformedDistribution(
        StudentT(
            mu=mx.nd.zeros(shape=BATCH_SHAPE),
            sigma=mx.nd.ones(shape=BATCH_SHAPE),
            nu=mx.nd.ones(shape=BATCH_SHAPE),
        ),
        [
                slopes=mx.nd.ones(shape=(3, 4, 5, 10)),
                knot_spacings=mx.nd.ones(shape=(3, 4, 5, 10)) / 10,
            ),
            (3, 4, 5),
            (),
        ),
        (
            MixtureDistribution(
                mixture_probs=mx.nd.stack(
                    0.2 * mx.nd.ones(shape=(3, 1, 5)),
                    0.8 * mx.nd.ones(shape=(3, 1, 5)),
                    axis=-1,
                ),
                components=[
                    Gaussian(
                        mu=mx.nd.zeros(shape=(3, 4, 5)),
                        sigma=mx.nd.ones(shape=(3, 4, 5)),
                    ),
                    StudentT(
                        mu=mx.nd.zeros(shape=(3, 4, 5)),
                        sigma=mx.nd.ones(shape=(3, 4, 5)),
                        nu=mx.nd.ones(shape=(3, 4, 5)),
                    ),
                ],
            ),
            (3, 4, 5),
            (),
        ),
        (
            MixtureDistribution(
                mixture_probs=mx.nd.stack(
                    0.2 * mx.nd.ones(shape=(3, 4)),