Exemple #1
0
def test_mixture(distr1: Distribution, distr2: Distribution, p: Tensor,
                 serialize_fn) -> None:
    """Compare samples of a two-component mixture against a reference.

    The reference is built by drawing from both components and picking,
    element-wise, the draw from ``distr1`` wherever a uniform random number
    falls below ``p`` and the draw from ``distr2`` otherwise. The mixture
    itself is built with probabilities ``(p, 1 - p)`` and passed through
    ``serialize_fn`` before it is sampled.

    Checks performed: all sample shapes agree, the mixture's ``mean``
    matches the empirical sample mean, the sample histograms are close and,
    when both components are ``Gaussian``, the mixture's ``cdf`` matches the
    empirical CDF.
    """
    # Draw from each component, then choose between the draws element-wise.
    draws_a = distr1.sample(num_samples=NUM_SAMPLES_LARGE)
    draws_b = distr2.sample(num_samples=NUM_SAMPLES_LARGE)

    uniform_noise = mx.nd.random.uniform(shape=(NUM_SAMPLES_LARGE, *p.shape))
    pick_first = uniform_noise < p.expand_dims(axis=0)
    pick_first = pick_first.broadcast_like(draws_a)
    reference = mx.nd.where(pick_first, draws_a, draws_b)

    # Build the mixture, round-trip it through serialize_fn, and sample.
    weights = mx.nd.stack(p, 1.0 - p, axis=-1)
    components = [distr1, distr2]
    mixture = serialize_fn(
        MixtureDistribution(mixture_probs=weights, components=components)
    )
    mixed = mixture.sample(num_samples=NUM_SAMPLES_LARGE)

    # Every sample tensor must have the same shape.
    expected_shape = draws_a.shape
    assert all(
        t.shape == expected_shape for t in (draws_b, mixed, reference)
    )

    # The analytic mean must agree with the empirical mean over samples.
    analytic_mean = mixture.mean.asnumpy()
    empirical_mean = mixed.asnumpy().mean(axis=0)
    assert np.allclose(analytic_mean, empirical_mean, atol=1e-1)

    # The mixture and reference histograms must be close.
    hist_mixed = histogram(mixed.asnumpy())
    hist_reference = histogram(reference.asnumpy())
    assert diff(hist_mixed, hist_reference) < 0.05

    # The CDF comparison is only performed for Gaussian components.
    if not all(isinstance(d, Gaussian) for d in (distr1, distr2)):
        return

    empirical, bin_edges = empirical_cdf(mixed.asnumpy())
    analytic = mixture.cdf(mx.nd.array(bin_edges)).asnumpy()
    # The first row of the analytic CDF is dropped before comparing.
    assert np.allclose(analytic[1:, :], empirical, atol=1e-2)
Exemple #2
0
def test_mixture(
    distr1: Distribution, distr2: Distribution, p: Tensor
) -> None:
    """Compare samples of a two-component mixture against a reference.

    The reference is built by drawing from both components and picking,
    element-wise, the draw from ``distr1`` wherever a uniform random number
    falls below ``p`` and the draw from ``distr2`` otherwise. The mixture is
    built with probabilities ``(p, 1 - p)``.

    Checks performed: all sample shapes agree and the histograms of the
    mixture samples and the reference samples are close.
    """
    # Draw from each component, then choose between the draws element-wise.
    draws_a = distr1.sample(num_samples=NUM_SAMPLES)
    draws_b = distr2.sample(num_samples=NUM_SAMPLES)

    uniform_noise = mx.nd.random.uniform(shape=(NUM_SAMPLES, *p.shape))
    pick_first = uniform_noise < p.expand_dims(axis=0)
    pick_first = pick_first.broadcast_like(draws_a)
    reference = mx.nd.where(pick_first, draws_a, draws_b)

    # Build the mixture distribution and sample from it.
    weights = mx.nd.stack(p, 1.0 - p, axis=-1)
    components = [distr1, distr2]
    mixture = MixtureDistribution(
        mixture_probs=weights, components=components
    )
    mixed = mixture.sample(num_samples=NUM_SAMPLES)

    # Every sample tensor must have the same shape.
    expected_shape = draws_a.shape
    assert all(
        t.shape == expected_shape for t in (draws_b, mixed, reference)
    )

    # The mixture and reference histograms must be close.
    hist_mixed = histogram(mixed.asnumpy())
    hist_reference = histogram(reference.asnumpy())
    assert diff(hist_mixed, hist_reference) < 0.05
Exemple #3
0
         slopes=mx.nd.ones(shape=(3, 4, 5, 10)),
         knot_spacings=mx.nd.ones(shape=(3, 4, 5, 10)) / 10,
     ),
     (3, 4, 5),
     (),
 ),
 (
     MixtureDistribution(
         mixture_probs=mx.nd.stack(
             0.2 * mx.nd.ones(shape=(3, 1, 5)),
             0.8 * mx.nd.ones(shape=(3, 1, 5)),
             axis=-1,
         ),
         components=[
             Gaussian(
                 mu=mx.nd.zeros(shape=(3, 4, 5)),
                 sigma=mx.nd.ones(shape=(3, 4, 5)),
             ),
             StudentT(
                 mu=mx.nd.zeros(shape=(3, 4, 5)),
                 sigma=mx.nd.ones(shape=(3, 4, 5)),
                 nu=mx.nd.ones(shape=(3, 4, 5)),
             ),
         ],
     ),
     (3, 4, 5),
     (),
 ),
 (
     MixtureDistribution(
         mixture_probs=mx.nd.stack(
             0.2 * mx.nd.ones(shape=(3, 4)),
Exemple #4
0
     high=mx.nd.ones(shape=BATCH_SHAPE),
 ),
 PiecewiseLinear(
     gamma=mx.nd.ones(shape=BATCH_SHAPE),
     slopes=mx.nd.ones(shape=(3, 4, 5, 10)),
     knot_spacings=mx.nd.ones(shape=(3, 4, 5, 10)) / 10,
 ),
 MixtureDistribution(
     mixture_probs=mx.nd.stack(
         0.2 * mx.nd.ones(shape=BATCH_SHAPE),
         0.8 * mx.nd.ones(shape=BATCH_SHAPE),
         axis=-1,
     ),
     components=[
         Gaussian(
             mu=mx.nd.zeros(shape=BATCH_SHAPE),
             sigma=mx.nd.ones(shape=BATCH_SHAPE),
         ),
         StudentT(
             mu=mx.nd.zeros(shape=BATCH_SHAPE),
             sigma=mx.nd.ones(shape=BATCH_SHAPE),
             nu=mx.nd.ones(shape=BATCH_SHAPE),
         ),
     ],
 ),
 TransformedDistribution(
     StudentT(
         mu=mx.nd.zeros(shape=BATCH_SHAPE),
         sigma=mx.nd.ones(shape=BATCH_SHAPE),
         nu=mx.nd.ones(shape=BATCH_SHAPE),
     ),
     [