def test_mixture(
    distr1: Distribution, distr2: Distribution, p: Tensor, serialize_fn
) -> None:
    """Check a two-component ``MixtureDistribution`` against manual mixing.

    Reference samples are built by drawing ``NUM_SAMPLES_LARGE`` samples from
    each component. Element-wise, the ``distr1`` sample is kept where a
    uniform draw falls below ``p``, and the ``distr2`` sample otherwise.
    A mixture with weights ``(p, 1 - p)`` over ``[distr1, distr2]`` is
    passed through ``serialize_fn`` and then sampled directly. The test
    compares:

    * the shapes of all sample tensors,
    * the analytic mixture mean against the sample mean (``atol=1e-1``),
    * histograms of mixture vs. reference samples (``diff < 0.05``),
    * when both components are ``Gaussian``, the analytic CDF against the
      empirical CDF (``atol=1e-2``).

    Parameters
    ----------
    distr1, distr2
        Component distributions. Their samples must share one shape, which
        the shape assertion enforces.
    p
        Mixture weight of ``distr1``. ``distr2`` gets ``1 - p``.
    serialize_fn
        Callable applied to the mixture before sampling (presumably a
        serialization round-trip; confirm against the test parametrization).
    """
    # Sample each component and select element-wise with probability p,
    # which yields reference draws from the intended mixture.
    samples1 = distr1.sample(num_samples=NUM_SAMPLES_LARGE)
    samples2 = distr2.sample(num_samples=NUM_SAMPLES_LARGE)
    rand = mx.nd.random.uniform(shape=(NUM_SAMPLES_LARGE, *p.shape))
    choice = (rand < p.expand_dims(axis=0)).broadcast_like(samples1)
    samples_ref = mx.nd.where(choice, samples1, samples2)

    # Construct the mixture distribution and sample from it.
    mixture_probs = mx.nd.stack(p, 1.0 - p, axis=-1)
    mixture = MixtureDistribution(
        mixture_probs=mixture_probs, components=[distr1, distr2]
    )
    mixture = serialize_fn(mixture)
    samples_mix = mixture.sample(num_samples=NUM_SAMPLES_LARGE)

    # Check that shapes are right.
    assert (
        samples1.shape
        == samples2.shape
        == samples_mix.shape
        == samples_ref.shape
    )

    # Check the mean only; the standard deviation is not checked here.
    calc_mean = mixture.mean.asnumpy()
    # Copy the mixture samples to NumPy once and reuse the array in the
    # mean, histogram and CDF checks. This assumes histogram/empirical_cdf
    # do not modify their input in place; confirm if those helpers change.
    samples_mix_np = samples_mix.asnumpy()
    sample_mean = samples_mix_np.mean(axis=0)
    assert np.allclose(calc_mean, sample_mean, atol=1e-1)

    # Check that the histograms of mixture and reference samples are close.
    assert (
        diff(histogram(samples_mix_np), histogram(samples_ref.asnumpy()))
        < 0.05
    )

    # The CDF can currently only be calculated for Gaussian components.
    if isinstance(distr1, Gaussian) and isinstance(distr2, Gaussian):
        emp_cdf, edges = empirical_cdf(samples_mix_np)
        calc_cdf = mixture.cdf(mx.nd.array(edges)).asnumpy()
        # The first row of calc_cdf is dropped to line up with emp_cdf;
        # presumably `edges` has one more row than `emp_cdf` (bin edges vs.
        # bins), so confirm against empirical_cdf.
        assert np.allclose(calc_cdf[1:, :], emp_cdf, atol=1e-2)
def test_mixture(
    distr1: Distribution, distr2: Distribution, p: Tensor
) -> None:
    """Compare direct mixture samples with manually mixed component samples.

    ``NUM_SAMPLES`` draws are taken from each of ``distr1`` and ``distr2``.
    Element-wise, the ``distr1`` draw is kept wherever a uniform variate is
    below ``p``, and the ``distr2`` draw otherwise. A ``MixtureDistribution``
    weighted by ``(p, 1 - p)`` over the same components is then sampled.
    Its samples must match the reference in shape, and their histogram must
    be within 0.05 of the reference histogram as measured by ``diff``.
    """
    # Reference draws: choose per element between the two components.
    draws_a = distr1.sample(num_samples=NUM_SAMPLES)
    draws_b = distr2.sample(num_samples=NUM_SAMPLES)
    uniforms = mx.nd.random.uniform(shape=(NUM_SAMPLES,) + tuple(p.shape))
    threshold = p.expand_dims(axis=0)
    take_a = (uniforms < threshold).broadcast_like(draws_a)
    reference = mx.nd.where(take_a, draws_a, draws_b)

    # Candidate draws: sample the mixture distribution itself.
    weights = mx.nd.stack(p, 1.0 - p, axis=-1)
    mixed = MixtureDistribution(
        mixture_probs=weights, components=[distr1, distr2]
    ).sample(num_samples=NUM_SAMPLES)

    # All four sample tensors must share a single shape.
    assert len(
        {draws_a.shape, draws_b.shape, mixed.shape, reference.shape}
    ) == 1

    # The mixture histogram must be close to the reference histogram.
    hist_mixed = histogram(mixed.asnumpy())
    hist_reference = histogram(reference.asnumpy())
    assert diff(hist_mixed, hist_reference) < 0.05
slopes=mx.nd.ones(shape=(3, 4, 5, 10)), knot_spacings=mx.nd.ones(shape=(3, 4, 5, 10)) / 10, ), (3, 4, 5), (), ), ( MixtureDistribution( mixture_probs=mx.nd.stack( 0.2 * mx.nd.ones(shape=(3, 1, 5)), 0.8 * mx.nd.ones(shape=(3, 1, 5)), axis=-1, ), components=[ Gaussian( mu=mx.nd.zeros(shape=(3, 4, 5)), sigma=mx.nd.ones(shape=(3, 4, 5)), ), StudentT( mu=mx.nd.zeros(shape=(3, 4, 5)), sigma=mx.nd.ones(shape=(3, 4, 5)), nu=mx.nd.ones(shape=(3, 4, 5)), ), ], ), (3, 4, 5), (), ), ( MixtureDistribution( mixture_probs=mx.nd.stack( 0.2 * mx.nd.ones(shape=(3, 4)),
high=mx.nd.ones(shape=BATCH_SHAPE), ), PiecewiseLinear( gamma=mx.nd.ones(shape=BATCH_SHAPE), slopes=mx.nd.ones(shape=(3, 4, 5, 10)), knot_spacings=mx.nd.ones(shape=(3, 4, 5, 10)) / 10, ), MixtureDistribution( mixture_probs=mx.nd.stack( 0.2 * mx.nd.ones(shape=BATCH_SHAPE), 0.8 * mx.nd.ones(shape=BATCH_SHAPE), axis=-1, ), components=[ Gaussian( mu=mx.nd.zeros(shape=BATCH_SHAPE), sigma=mx.nd.ones(shape=BATCH_SHAPE), ), StudentT( mu=mx.nd.zeros(shape=BATCH_SHAPE), sigma=mx.nd.ones(shape=BATCH_SHAPE), nu=mx.nd.ones(shape=BATCH_SHAPE), ), ], ), TransformedDistribution( StudentT( mu=mx.nd.zeros(shape=BATCH_SHAPE), sigma=mx.nd.ones(shape=BATCH_SHAPE), nu=mx.nd.ones(shape=BATCH_SHAPE), ), [