Example #1
def test_beta_bernoulli(auto_class):
    data = jnp.array([[1.0] * 8 + [0.0] * 2, [1.0] * 4 + [0.0] * 6]).T

    def model(data):
        f = numpyro.sample('beta', dist.Beta(jnp.ones(2), jnp.ones(2)))
        numpyro.sample('obs', dist.Bernoulli(f), obs=data)

    adam = optim.Adam(0.01)
    guide = auto_class(model, init_strategy=init_strategy)
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(random.PRNGKey(1), data)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, 3000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    true_coefs = (jnp.sum(data, axis=0) + 1) / (data.shape[0] + 2)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1),
                                               params,
                                               sample_shape=(1000, ))
    assert_allclose(jnp.mean(posterior_samples['beta'], 0),
                    true_coefs,
                    atol=0.05)
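
These snippets omit their import preamble. A minimal set of imports that makes most of them runnable (a sketch assuming an older NumPyro release, in which ELBO was still exported from numpyro.infer before being renamed Trace_ELBO) would be:

# Assumed common imports for these examples (not part of the original snippets).
import jax.numpy as jnp
from jax import jit, lax, random
from jax.lax import fori_loop
from jax.random import PRNGKey
from numpy.testing import assert_allclose

import numpyro
import numpyro.distributions as dist
from numpyro import optim
from numpyro.distributions import constraints
from numpyro.infer import ELBO, SVI, Predictive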
Example #2
    def __init__(self,
                 model: Model,
                 guide: Guide,
                 rng_key: int = 0,
                 *,
                 loss: ELBO = ELBO(num_particles=1),
                 optim_builder: optim.optimizers.optimizer = optim.Adam):
        """Handling the model and guide for training and prediction

        Args:
            model: function holding the numpyro model
            guide: function holding the numpyro guide
            rng_key: random key as int
            loss: loss to optimize
            optim_builder: builder for an optimizer
        """
        self.model = model
        self.guide = guide
        self.rng_key = random.PRNGKey(rng_key)  # current random key
        self.loss = loss
        self.optim_builder = optim_builder
        self.svi = None
        self.svi_state = None
        self.optim = None
        self.log_func = print  # overwrite e.g. logger.info(...)
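
For context, here is a hypothetical instantiation of this wrapper (the class name SVIHandler and the model/guide names are assumptions, since the snippet only shows __init__):

# Hypothetical usage of the wrapper above; `SVIHandler`, `my_model`, and
# `my_guide` are assumed names, not part of the original snippet.
handler = SVIHandler(my_model, my_guide, rng_key=42,
                     loss=ELBO(num_particles=2),
                     optim_builder=optim.Adam)
handler.log_func = print  # or e.g. logger.info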
Example #3
def test_uniform_normal():
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000, ))

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
            loc = numpyro.sample('loc', dist.Uniform(0, alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(rng_key_init, data)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, 1000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    median = guide.median(params)
    assert_allclose(median['loc'], true_coef, rtol=0.05)
    # test the .quantiles method
    quantiles = guide.quantiles(params, [0.2, 0.5])
    assert_allclose(quantiles['loc'][1], true_coef, rtol=0.1)
Example #4
def main(args):
    encoder_nn = encoder(args.hidden_dim, args.z_dim)
    decoder_nn = decoder(args.hidden_dim, 28 * 28)
    adam = optim.Adam(args.learning_rate)
    svi = SVI(model, guide, adam, ELBO(), hidden_dim=args.hidden_dim, z_dim=args.z_dim)
    rng_key = PRNGKey(0)
    train_init, train_fetch = load_dataset(MNIST, batch_size=args.batch_size, split='train')
    test_init, test_fetch = load_dataset(MNIST, batch_size=args.batch_size, split='test')
    num_train, train_idx = train_init()
    rng_key, rng_key_binarize, rng_key_init = random.split(rng_key, 3)
    sample_batch = binarize(rng_key_binarize, train_fetch(0, train_idx)[0])
    svi_state = svi.init(rng_key_init, sample_batch)

    @jit
    def epoch_train(svi_state, rng_key):
        def body_fn(i, val):
            loss_sum, svi_state = val
            rng_key_binarize = random.fold_in(rng_key, i)
            batch = binarize(rng_key_binarize, train_fetch(i, train_idx)[0])
            svi_state, loss = svi.update(svi_state, batch)
            loss_sum += loss
            return loss_sum, svi_state

        return lax.fori_loop(0, num_train, body_fn, (0., svi_state))

    @jit
    def eval_test(svi_state, rng_key):
        def body_fun(i, loss_sum):
            rng_key_binarize = random.fold_in(rng_key, i)
            batch = binarize(rng_key_binarize, test_fetch(i, test_idx)[0])
            # FIXME: does this lead to a requirement for an rng_key arg in svi_eval?
            loss = svi.evaluate(svi_state, batch) / len(batch)
            loss_sum += loss
            return loss_sum

        loss = lax.fori_loop(0, num_test, body_fun, 0.)
        loss = loss / num_test
        return loss

    def reconstruct_img(epoch, rng_key):
        img = test_fetch(0, test_idx)[0][0]
        plt.imsave(os.path.join(RESULTS_DIR, 'original_epoch={}.png'.format(epoch)), img, cmap='gray')
        rng_key_binarize, rng_key_sample = random.split(rng_key)
        test_sample = binarize(rng_key_binarize, img)
        params = svi.get_params(svi_state)
        z_mean, z_var = encoder_nn[1](params['encoder$params'], test_sample.reshape([1, -1]))
        z = dist.Normal(z_mean, z_var).sample(rng_key_sample)
        img_loc = decoder_nn[1](params['decoder$params'], z).reshape([28, 28])
        plt.imsave(os.path.join(RESULTS_DIR, 'recons_epoch={}.png'.format(epoch)), img_loc, cmap='gray')

    for i in range(args.num_epochs):
        rng_key, rng_key_train, rng_key_test, rng_key_reconstruct = random.split(rng_key, 4)
        t_start = time.time()
        num_train, train_idx = train_init()
        _, svi_state = epoch_train(svi_state, rng_key_train)
        rng_key, rng_key_test, rng_key_reconstruct = random.split(rng_key, 3)
        num_test, test_idx = test_init()
        test_loss = eval_test(svi_state, rng_key_test)
        reconstruct_img(i, rng_key_reconstruct)
        print("Epoch {}: loss = {} ({:.2f} s.)".format(i, test_loss, time.time() - t_start))
Example #5
def test_predictive_with_guide():
    data = jnp.array([1] * 8 + [0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1., 1.))
        with numpyro.plate("plate", 10):
            numpyro.deterministic("beta_sq", f ** 2)
            numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q", 1.0,
                                constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0,
                               constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    svi = SVI(model, guide, optim.Adam(0.1), ELBO())
    svi_state = svi.init(random.PRNGKey(1), data)

    def body_fn(i, val):
        svi_state, _ = svi.update(val, data)
        return svi_state

    svi_state = lax.fori_loop(0, 1000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    predictive = Predictive(model, guide=guide, params=params, num_samples=1000)(random.PRNGKey(2), data=None)
    assert predictive["beta_sq"].shape == (1000,)
    obs_pred = predictive["obs"]
    assert_allclose(jnp.mean(obs_pred), 0.8, atol=0.05)
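
Passing guide= and params= makes Predictive draw the latent sites from the trained guide before running the model. A roughly equivalent two-step sketch (assuming the same model, guide, and params as above):

# Sketch: sample latent sites from the guide, then condition the model on them.
guide_predictive = Predictive(guide, params=params, num_samples=1000)
latent_samples = guide_predictive(random.PRNGKey(2), data=None)
predictive = Predictive(model, latent_samples)(random.PRNGKey(2), data=None)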
Example #6
def test_elbo_dynamic_support():
    x_prior = dist.TransformedDistribution(
        dist.Normal(),
        [AffineTransform(0, 2),
         SigmoidTransform(),
         AffineTransform(0, 3)])
    x_guide = dist.Uniform(0, 3)

    def model():
        numpyro.sample('x', x_prior)

    def guide():
        numpyro.sample('x', x_guide)

    adam = optim.Adam(0.01)
    # set the base value of x_guide to 0.9
    x_base = 0.9
    guide = substitute(guide, base_param_map={'x': x_base})
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(random.PRNGKey(0))
    actual_loss = svi.evaluate(svi_state)
    assert np.isfinite(actual_loss)
    x, _ = x_guide.transform_with_intermediates(x_base)
    expected_loss = x_guide.log_prob(x) - x_prior.log_prob(x)
    assert_allclose(actual_loss, expected_loss)
Example #7
def main(args):
    # Generate some data.
    data = random.normal(PRNGKey(0), shape=(100,)) + 3.0

    # Construct an SVI object so we can do variational inference on our
    # model/guide pair.
    adam = optim.Adam(args.learning_rate)

    svi = SVI(model, guide, adam, ELBO(num_particles=100))
    svi_state = svi.init(PRNGKey(0), data)

    # Training loop
    def body_fn(i, val):
        svi_state, loss = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, args.num_steps, body_fn, svi_state)

    # Report the final values of the variational parameters
    # in the guide after training.
    params = svi.get_params(svi_state)
    for name, value in params.items():
        print("{} = {}".format(name, value))

    # For this simple (conjugate) model we know the exact posterior. In
    # particular we know that the variational distribution should be
    # centered near 3.0. So let's check this explicitly.
    assert np.abs(params["guide_loc"] - 3.0) < 0.1
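
The model and guide referenced above are defined elsewhere; a minimal pair consistent with the guide_loc parameter checked in the final assertion might look like this (the names guide_loc and guide_scale are assumptions based on that assertion):

# A sketch of a model/guide pair matching the snippet above.
def model(data):
    loc = numpyro.sample("loc", dist.Normal(0., 1.))
    with numpyro.plate("N", data.shape[0]):
        numpyro.sample("obs", dist.Normal(loc, 1.), obs=data)

def guide(data):
    guide_loc = numpyro.param("guide_loc", 0.)
    guide_scale = numpyro.param("guide_scale", 0.1,
                                constraint=constraints.positive)
    numpyro.sample("loc", dist.Normal(guide_loc, guide_scale))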
Example #8
def test_logistic_regression(auto_class):
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = jnp.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = auto_class(model, init_strategy=init_strategy)
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(rng_key_init, data, labels)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data, labels)
        return svi_state

    svi_state = fori_loop(0, 2000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    if auto_class not in (AutoIAFNormal, AutoBNAFNormal):
        median = guide.median(params)
        assert_allclose(median['coefs'], true_coefs, rtol=0.1)
        # test the .quantiles method
        quantiles = guide.quantiles(params, [0.2, 0.5])
        assert_allclose(quantiles['coefs'][1], true_coefs, rtol=0.1)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), params, sample_shape=(1000,))
    assert_allclose(jnp.mean(posterior_samples['coefs'], 0), true_coefs, rtol=0.1)
Example #9
def test_dynamic_supports():
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000, ))

    def actual_model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
            loc = numpyro.sample('loc', dist.Uniform(0, alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    def expected_model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.sample('loc', dist.Uniform(0, 1)) * alpha
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)

    guide = AutoDiagonalNormal(actual_model)
    svi = SVI(actual_model, guide, adam, ELBO())
    svi_state = svi.init(rng_key_init, data)
    actual_opt_params = adam.get_params(svi_state.optim_state)
    actual_params = svi.get_params(svi_state)
    actual_values = guide.median(actual_params)
    actual_loss = svi.evaluate(svi_state, data)

    guide = AutoDiagonalNormal(expected_model)
    svi = SVI(expected_model, guide, adam, ELBO())
    svi_state = svi.init(rng_key_init, data)
    expected_opt_params = adam.get_params(svi_state.optim_state)
    expected_params = svi.get_params(svi_state)
    expected_values = guide.median(expected_params)
    expected_loss = svi.evaluate(svi_state, data)

    # test auto_loc, auto_scale
    check_eq(actual_opt_params, expected_opt_params)
    check_eq(actual_params, expected_params)
    # test latent values
    assert_allclose(actual_values['alpha'], expected_values['alpha'])
    assert_allclose(actual_values['loc_base'], expected_values['loc'])
    assert_allclose(actual_loss, expected_loss)
Example #10
def svi(model, guide, num_steps, lr, rng_key, X, Y):
    """
    Helper function for doing SVI inference.
    """
    svi = SVI(model, guide, optim.Adam(lr), ELBO(num_particles=1), X=X, Y=Y)

    svi_state = svi.init(rng_key)
    print('Optimizing...')
    state, loss = lax.scan(lambda x, i: svi.update(x), svi_state,
                           np.zeros(num_steps))

    return loss, svi.get_params(state)
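
A hypothetical call of this helper, assuming model, guide, and training data X, Y are defined elsewhere; note that lax.scan stacks the per-step losses, so the first return value is an array of length num_steps:

# Hypothetical invocation of the helper above.
losses, params = svi(model, guide, num_steps=1000, lr=0.01,
                     rng_key=random.PRNGKey(0), X=X, Y=Y)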
Example #11
def test_iaf():
    # test for substitute logic for exposed methods `sample_posterior` and `get_transforms`
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample('coefs',
                               dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        offset = numpyro.sample('offset', dist.Uniform(-1, 1))
        logits = offset + jnp.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(rng_key_init, data, labels)
    params = svi.get_params(svi_state)

    x = random.normal(random.PRNGKey(0), (dim + 1, ))
    rng_key = random.PRNGKey(1)
    actual_sample = guide.sample_posterior(rng_key, params)
    actual_output = guide._unpack_latent(guide.get_transform(params)(x))

    flows = []
    for i in range(guide.num_flows):
        if i > 0:
            flows.append(transforms.PermuteTransform(
                jnp.arange(dim + 1)[::-1]))
        arn_init, arn_apply = AutoregressiveNN(
            dim + 1, [dim + 1, dim + 1],
            permutation=jnp.arange(dim + 1),
            skip_connections=guide._skip_connections,
            nonlinearity=guide._nonlinearity)
        arn = partial(arn_apply, params['auto_arn__{}$params'.format(i)])
        flows.append(InverseAutoregressiveTransform(arn))
    flows.append(guide._unpack_latent)

    transform = transforms.ComposeTransform(flows)
    _, rng_key_sample = random.split(rng_key)
    expected_sample = transform(
        dist.Normal(jnp.zeros(dim + 1), 1).sample(rng_key_sample))
    expected_output = transform(x)
    assert_allclose(actual_sample['coefs'], expected_sample['coefs'])
    assert_allclose(
        actual_sample['offset'],
        transforms.biject_to(constraints.interval(-1, 1))(
            expected_sample['offset']))
    check_eq(actual_output, expected_output)
Example #12
def test_reparam_log_joint(model, kwargs):
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, Adam(1e-10), ELBO(), **kwargs)
    svi_state = svi.init(random.PRNGKey(0))
    params = svi.get_params(svi_state)
    neutra = NeuTraReparam(guide, params)
    reparam_model = neutra.reparam(model)
    _, pe_fn, _, _ = initialize_model(random.PRNGKey(1), model, model_kwargs=kwargs)
    init_params, pe_fn_neutra, _, _ = initialize_model(random.PRNGKey(2), reparam_model, model_kwargs=kwargs)
    latent_x = list(init_params[0].values())[0]
    pe_transformed = pe_fn_neutra(init_params[0])
    latent_y = neutra.transform(latent_x)
    log_det_jacobian = neutra.transform.log_abs_det_jacobian(latent_x, latent_y)
    pe = pe_fn(guide._unpack_latent(latent_y))
    assert_allclose(pe_transformed, pe - log_det_jacobian)
Example #13
def test_laplace_approximation_warning():
    def model(x, y):
        a = numpyro.sample("a", dist.Normal(0, 10))
        b = numpyro.sample("b", dist.Normal(0, 10), sample_shape=(3,))
        mu = a + b[0] * x + b[1] * x ** 2 + b[2] * x ** 3
        numpyro.sample("y", dist.Normal(mu, 0.001), obs=y)

    x = random.normal(random.PRNGKey(0), (3,))
    y = 1 + 2 * x + 3 * x ** 2 + 4 * x ** 3
    guide = AutoLaplaceApproximation(model)
    svi = SVI(model, guide, optim.Adam(0.1), ELBO(), x=x, y=y)
    init_state = svi.init(random.PRNGKey(0))
    svi_state = fori_loop(0, 10000, lambda i, val: svi.update(val)[0], init_state)
    params = svi.get_params(svi_state)
    with pytest.warns(UserWarning, match="Hessian of log posterior"):
        guide.sample_posterior(random.PRNGKey(1), params)
Example #14
def test_improper():
    y = random.normal(random.PRNGKey(0), (100, ))

    def model(y):
        lambda1 = numpyro.sample(
            'lambda1', dist.ImproperUniform(dist.constraints.real, (), ()))
        lambda2 = numpyro.sample(
            'lambda2', dist.ImproperUniform(dist.constraints.real, (), ()))
        sigma = numpyro.sample(
            'sigma', dist.ImproperUniform(dist.constraints.positive, (), ()))
        mu = numpyro.deterministic('mu', lambda1 + lambda2)
        numpyro.sample('y', dist.Normal(mu, sigma), obs=y)

    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, optim.Adam(0.003), ELBO(), y=y)
    svi_state = svi.init(random.PRNGKey(2))
    lax.scan(lambda state, i: svi.update(state), svi_state, jnp.zeros(10000))
Example #15
def test_autoguide(deterministic):
    GLOBAL["count"] = 0
    guide = AutoDiagonalNormal(model)
    svi = SVI(model,
              guide,
              optim.Adam(0.1),
              ELBO(),
              deterministic=deterministic)
    svi_state = svi.init(random.PRNGKey(0))
    svi_state = lax.fori_loop(0, 100, lambda i, val: svi.update(val)[0],
                              svi_state)
    params = svi.get_params(svi_state)
    guide.sample_posterior(random.PRNGKey(1), params, sample_shape=(100, ))

    if deterministic:
        assert GLOBAL["count"] == 5
    else:
        assert GLOBAL["count"] == 4
Example #16
def test_param():
    # this tests the validity of model/guide sites whose param
    # constraints contain composed transforms
    rng_keys = random.split(random.PRNGKey(0), 5)
    a_minval = 1
    c_minval = -2
    c_maxval = -1
    a_init = jnp.exp(random.normal(rng_keys[0])) + a_minval
    b_init = jnp.exp(random.normal(rng_keys[1]))
    c_init = random.uniform(rng_keys[2], minval=c_minval, maxval=c_maxval)
    d_init = random.uniform(rng_keys[3])
    obs = random.normal(rng_keys[4])

    def model():
        a = numpyro.param('a',
                          a_init,
                          constraint=constraints.greater_than(a_minval))
        b = numpyro.param('b', b_init, constraint=constraints.positive)
        numpyro.sample('x', dist.Normal(a, b), obs=obs)

    def guide():
        c = numpyro.param('c',
                          c_init,
                          constraint=constraints.interval(c_minval, c_maxval))
        d = numpyro.param('d', d_init, constraint=constraints.unit_interval)
        numpyro.sample('y', dist.Normal(c, d), obs=obs)

    adam = optim.Adam(0.01)
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(random.PRNGKey(0))

    params = svi.get_params(svi_state)
    assert_allclose(params['a'], a_init)
    assert_allclose(params['b'], b_init)
    assert_allclose(params['c'], c_init)
    assert_allclose(params['d'], d_init)

    actual_loss = svi.evaluate(svi_state)
    assert jnp.isfinite(actual_loss)
    expected_loss = dist.Normal(c_init, d_init).log_prob(obs) - dist.Normal(
        a_init, b_init).log_prob(obs)
    # not exactly equal because of the transform / inverse-transform round trip
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
Example #17
def test_jitted_update_fn():
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1., 1.))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q",
                                1.0,
                                constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    adam = optim.Adam(0.05)
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(random.PRNGKey(1), data)
    expected = svi.get_params(svi.update(svi_state, data)[0])

    actual = svi.get_params(jit(svi.update)(svi_state, data=data)[0])
    check_close(actual, expected, atol=1e-5)
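
Because svi.update is a pure function of its inputs, the jitted step can also be compiled once and reused across iterations, e.g.:

# Sketch: compile the update step once and reuse it, avoiding retracing.
jitted_update = jit(svi.update)
for _ in range(100):
    svi_state, loss = jitted_update(svi_state, data)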
Example #18
def test_param():
    # this tests the validity of a model whose param sites
    # contain composed transformed constraints
    rng_keys = random.split(random.PRNGKey(0), 3)
    a_minval = 1
    a_init = jnp.exp(random.normal(rng_keys[0])) + a_minval
    b_init = jnp.exp(random.normal(rng_keys[1]))
    x_init = random.normal(rng_keys[2])

    def model():
        a = numpyro.param('a',
                          a_init,
                          constraint=constraints.greater_than(a_minval))
        b = numpyro.param('b', b_init, constraint=constraints.positive)
        numpyro.sample('x', dist.Normal(a, b))

    # this class is used to force the init value of `x` to x_init
    class _AutoGuide(AutoDiagonalNormal):
        def __call__(self, *args, **kwargs):
            return substitute(
                super(_AutoGuide, self).__call__,
                {'_auto_latent': x_init})(*args, **kwargs)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = _AutoGuide(model)
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(rng_key_init)

    params = svi.get_params(svi_state)
    assert_allclose(params['a'], a_init)
    assert_allclose(params['b'], b_init)
    assert_allclose(params['auto_loc'], guide._init_latent)
    assert_allclose(params['auto_scale'], jnp.ones(1) * guide._init_scale)

    actual_loss = svi.evaluate(svi_state)
    assert jnp.isfinite(actual_loss)
    expected_loss = dist.Normal(guide._init_latent, guide._init_scale).log_prob(x_init) \
        - dist.Normal(a_init, b_init).log_prob(x_init)
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
Example #19
def test_elbo_dynamic_support():
    x_prior = dist.TransformedDistribution(
        dist.Normal(),
        [AffineTransform(0, 2),
         SigmoidTransform(),
         AffineTransform(0, 3)])
    x_guide = dist.Uniform(0, 3)

    def model():
        numpyro.sample('x', x_prior)

    def guide():
        numpyro.sample('x', x_guide)

    adam = optim.Adam(0.01)
    x = 2.
    guide = substitute(guide, data={'x': x})
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(random.PRNGKey(0))
    actual_loss = svi.evaluate(svi_state)
    assert jnp.isfinite(actual_loss)
    expected_loss = x_guide.log_prob(x) - x_prior.log_prob(x)
    assert_allclose(actual_loss, expected_loss)
Example #20
def test_neals_funnel_smoke():
    dim = 10

    guide = AutoIAFNormal(neals_funnel)
    svi = SVI(neals_funnel, guide, Adam(1e-10), ELBO())
    svi_state = svi.init(random.PRNGKey(0), dim)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, dim)
        return svi_state

    svi_state = lax.fori_loop(0, 1000, body_fn, svi_state)
    params = svi.get_params(svi_state)

    neutra = NeuTraReparam(guide, params)
    model = neutra.reparam(neals_funnel)
    nuts = NUTS(model)
    mcmc = MCMC(nuts, num_warmup=50, num_samples=50)
    mcmc.run(random.PRNGKey(1), dim)
    samples = mcmc.get_samples()
    transformed_samples = neutra.transform_sample(samples['auto_shared_latent'])
    assert 'x' in transformed_samples
    assert 'y' in transformed_samples
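
The neals_funnel model is defined elsewhere; a definition consistent with the x and y sites asserted above (an assumption, mirroring the classic Neal's funnel geometry) would be:

# Plausible definition of `neals_funnel` (an assumption).
def neals_funnel(dim):
    y = numpyro.sample('y', dist.Normal(0., 3.))
    with numpyro.plate('D', dim):
        numpyro.sample('x', dist.Normal(0., jnp.exp(y / 2)))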
Example #21
def main(args):
    rng = PRNGKey(1234)
    rng, toy_data_rng = jax.random.split(rng, 2)
    X_train, X_test, mu_true = create_toy_data(toy_data_rng, args.num_samples,
                                               args.dimensions)

    train_init, train_fetch = subsample_batchify_data(
        (X_train, ), batch_size=args.batch_size)
    test_init, test_fetch = split_batchify_data((X_test, ),
                                                batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    svi = DPSVI(model,
                guide,
                optimizer,
                ELBO(),
                dp_scale=args.sigma,
                clipping_threshold=args.clip_threshold,
                d=args.dimensions,
                num_obs_total=args.num_samples)

    rng, svi_init_rng, batchifier_rng = random.split(rng, 3)
    _, batchifier_state = train_init(rng_key=batchifier_rng)
    batch = train_fetch(0, batchifier_state)
    svi_state = svi.init(svi_init_rng, *batch)

    q = args.batch_size / args.num_samples
    eps = svi.get_epsilon(args.delta, q, num_epochs=args.num_epochs)
    print("Privacy epsilon {} (for sigma: {}, delta: {}, C: {}, q: {})".format(
        eps, args.sigma, args.clip_threshold, args.delta, q))

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            svi_state, batch_loss = svi.update(svi_state, *batch)
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch):
        def body_fn(i, loss_sum):
            batch = test_fetch(i, batchifier_state)
            loss = svi.evaluate(svi_state, *batch)
            loss_sum += loss / (args.num_samples * num_batch)

            return loss_sum

        return lax.fori_loop(0, num_batch, body_fn, 0.)

    ## Train model

    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)

        num_train_batches, train_batchifier_state = train_init(
            rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(svi_state, train_batchifier_state,
                                            num_train_batches)
        train_loss.block_until_ready()
        t_end = time.time()

        if (i % (args.num_epochs // 10) == 0):
            rng, test_fetch_rng = random.split(rng, 2)
            num_test_batches, test_batchifier_state = test_init(
                rng_key=test_fetch_rng)
            test_loss = eval_test(svi_state, test_batchifier_state,
                                  num_test_batches)

            print(
                "Epoch {}: loss = {} (on training set: {}) ({:.2f} s.)".format(
                    i, test_loss, train_loss, t_end - t_start))

    params = svi.get_params(svi_state)
    mu_loc = params['mu_loc']
    mu_std = jnp.exp(params['mu_std_log'])
    print("### expected: {}".format(mu_true))
    print("### svi result\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))
    mu_loc, mu_std = analytical_solution(X_train)
    print("### analytical solution\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))
    mu_loc, mu_std = ml_estimate(X_train)
    print("### ml estimate\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))
Example #22
def main(args):
    print("Start vanilla HMC...")
    nuts_kernel = NUTS(dual_moon_model)
    mcmc = MCMC(
        nuts_kernel,
        args.num_warmup,
        args.num_samples,
        num_chains=args.num_chains,
        progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    mcmc.run(random.PRNGKey(0))
    mcmc.print_summary()
    vanilla_samples = mcmc.get_samples()['x'].copy()

    guide = AutoBNAFNormal(
        dual_moon_model,
        hidden_factors=[args.hidden_factor, args.hidden_factor])
    svi = SVI(dual_moon_model, guide, optim.Adam(0.003), ELBO())
    svi_state = svi.init(random.PRNGKey(1))

    print("Start training guide...")
    last_state, losses = lax.scan(lambda state, i: svi.update(state),
                                  svi_state, jnp.zeros(args.num_iters))
    params = svi.get_params(last_state)
    print("Finish training guide. Extract samples...")
    guide_samples = guide.sample_posterior(
        random.PRNGKey(2), params,
        sample_shape=(args.num_samples, ))['x'].copy()

    print("\nStart NeuTra HMC...")
    neutra = NeuTraReparam(guide, params)
    neutra_model = neutra.reparam(dual_moon_model)
    nuts_kernel = NUTS(neutra_model)
    mcmc = MCMC(
        nuts_kernel,
        args.num_warmup,
        args.num_samples,
        num_chains=args.num_chains,
        progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    mcmc.run(random.PRNGKey(3))
    mcmc.print_summary()
    zs = mcmc.get_samples(group_by_chain=True)["auto_shared_latent"]
    print("Transform samples into unwarped space...")
    samples = neutra.transform_sample(zs)
    print_summary(samples)
    zs = zs.reshape(-1, 2)
    samples = samples['x'].reshape(-1, 2).copy()

    # make plots

    # guide samples (for plotting)
    guide_base_samples = dist.Normal(jnp.zeros(2),
                                     1.).sample(random.PRNGKey(4), (1000, ))
    guide_trans_samples = neutra.transform_sample(guide_base_samples)['x']

    x1 = jnp.linspace(-3, 3, 100)
    x2 = jnp.linspace(-3, 3, 100)
    X1, X2 = jnp.meshgrid(x1, x2)
    P = jnp.exp(DualMoonDistribution().log_prob(jnp.stack([X1, X2], axis=-1)))

    fig = plt.figure(figsize=(12, 8), constrained_layout=True)
    gs = GridSpec(2, 3, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[1, 0])
    ax3 = fig.add_subplot(gs[0, 1])
    ax4 = fig.add_subplot(gs[1, 1])
    ax5 = fig.add_subplot(gs[0, 2])
    ax6 = fig.add_subplot(gs[1, 2])

    ax1.plot(losses[1000:])
    ax1.set_title('Autoguide training loss\n(after 1000 steps)')

    ax2.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], n_levels=30, ax=ax2)
    ax2.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using\nAutoBNAFNormal guide')

    sns.scatterplot(guide_base_samples[:, 0],
                    guide_base_samples[:, 1],
                    ax=ax3,
                    hue=guide_trans_samples[:, 0] < 0.)
    ax3.set(
        xlim=[-3, 3],
        ylim=[-3, 3],
        xlabel='x0',
        ylabel='x1',
        title='AutoBNAFNormal base samples\n(True=left moon; False=right moon)'
    )

    ax4.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(vanilla_samples[:, 0],
                vanilla_samples[:, 1],
                n_levels=30,
                ax=ax4)
    ax4.plot(vanilla_samples[-50:, 0],
             vanilla_samples[-50:, 1],
             'bo-',
             alpha=0.5)
    ax4.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using\nvanilla HMC sampler')

    sns.scatterplot(zs[:, 0],
                    zs[:, 1],
                    ax=ax5,
                    hue=samples[:, 0] < 0.,
                    s=30,
                    alpha=0.5,
                    edgecolor="none")
    ax5.set(xlim=[-5, 5],
            ylim=[-5, 5],
            xlabel='x0',
            ylabel='x1',
            title='Samples from the\nwarped posterior - p(z)')

    ax6.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(samples[:, 0], samples[:, 1], n_levels=30, ax=ax6)
    ax6.plot(samples[-50:, 0], samples[-50:, 1], 'bo-', alpha=0.2)
    ax6.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using\nNeuTra HMC sampler')

    plt.savefig("neutra.pdf")
Example #23
File: vae.py Project: byzhang/d3p
def main(args):
    # loading data
    (train_init, train_fetch_plain), num_samples = load_dataset(
        MNIST,
        batch_size=args.batch_size,
        split='train',
        batchifier=subsample_batchify_data)
    (test_init,
     test_fetch_plain), _ = load_dataset(MNIST,
                                         batch_size=args.batch_size,
                                         split='test',
                                         batchifier=split_batchify_data)

    def binarize_fetch(fetch_fn):
        @jit
        def fetch_binarized(batch_nr, batchifier_state, binarize_rng):
            batch = fetch_fn(batch_nr, batchifier_state)
            return binarize(binarize_rng, batch[0]), batch[1]

        return fetch_binarized

    train_fetch = binarize_fetch(train_fetch_plain)
    test_fetch = binarize_fetch(test_fetch_plain)

    # setting up optimizer
    optimizer = optimizers.Adam(args.learning_rate)

    # the minibatch environment in our model scales individual
    # records' contributions to the loss up by num_samples.
    # This can cause numerical instabilities so we scale down
    # the loss by 1/num_samples here.
    sc_model = scale(model, scale=1 / num_samples)
    sc_guide = scale(guide, scale=1 / num_samples)

    if args.no_dp:
        svi = SVI(sc_model,
                  sc_guide,
                  optimizer,
                  ELBO(),
                  num_obs_total=num_samples,
                  z_dim=args.z_dim,
                  hidden_dim=args.hidden_dim)
    else:
        q = args.batch_size / num_samples
        target_eps = args.epsilon
        dp_scale, act_eps, _ = approximate_sigma(target_eps=target_eps,
                                                 delta=1 / num_samples,
                                                 q=q,
                                                 num_iter=int(1 / q) *
                                                 args.num_epochs,
                                                 force_smaller=True)
        print(
            f"using noise scale {dp_scale} for epsilon of {act_eps} (targeted: {target_eps})"
        )

        svi = DPSVI(sc_model,
                    sc_guide,
                    optimizer,
                    ELBO(),
                    dp_scale=dp_scale,
                    clipping_threshold=10.,
                    num_obs_total=num_samples,
                    z_dim=args.z_dim,
                    hidden_dim=args.hidden_dim)

    # preparing random number generators and initializing svi
    rng = PRNGKey(0)
    rng, binarize_rng, svi_init_rng, batchifier_rng = random.split(rng, 4)
    _, batchifier_state = train_init(rng_key=batchifier_rng)
    sample_batch = train_fetch(0, batchifier_state, binarize_rng)[0]
    svi_state = svi.init(svi_init_rng, sample_batch)

    # functions for training tasks
    @jit
    def epoch_train(svi_state, batchifier_state, num_batches, rng):
        """Trains one epoch

        :param svi_state: current state of the optimizer
        :param rng: rng key

        :return: overall training loss over the epoch
        """
        def body_fn(i, val):
            svi_state, loss = val
            binarize_rng = random.fold_in(rng, i)
            batch = train_fetch(i, batchifier_state, binarize_rng)[0]
            svi_state, batch_loss = svi.update(svi_state, batch)
            loss += batch_loss / num_batches
            return svi_state, loss

        svi_state, loss = lax.fori_loop(0, num_batches, body_fn,
                                        (svi_state, 0.))
        return svi_state, loss

    @jit
    def eval_test(svi_state, batchifier_state, num_batches, rng):
        """Evaluates current model state on test data.

        :param svi_state: current state of the optimizer
        :param rng: rng key

        :return: loss over the test split
        """
        def body_fn(i, loss_sum):
            binarize_rng = random.fold_in(rng, i)
            batch = test_fetch(i, batchifier_state, binarize_rng)[0]
            batch_loss = svi.evaluate(svi_state, batch)
            loss_sum += batch_loss / num_batches
            return loss_sum

        return lax.fori_loop(0, num_batches, body_fn, 0.)

    def reconstruct_img(epoch, num_epochs, batchifier_state, svi_state, rng):
        """Reconstructs an image for the given epoch

        Obtains a sample from the testing data set and passes it through the
        VAE. Stores the result as image file 'epoch_{epoch}_recons.png' and
        the original input as 'epoch_{epoch}_original.png' in folder '.results'.

        :param epoch: Number of the current epoch
        :param num_epochs: Number of total epochs
        :param batchifier_state: Current state of the batchifier
        :param svi_state: Current state of the SVI
        :param rng: rng key
        """
        assert (num_epochs > 0)
        img = test_fetch_plain(0, batchifier_state)[0][0]
        plt.imsave(os.path.join(
            RESULTS_DIR, "epoch_{:0{}d}_original.png".format(
                epoch, (int(jnp.log10(num_epochs)) + 1))),
                   img,
                   cmap='gray')
        rng, rng_binarize = random.split(rng, 2)
        test_sample = binarize(rng_binarize, img)
        test_sample = jnp.reshape(test_sample, (1, *jnp.shape(test_sample)))
        params = svi.get_params(svi_state)

        samples = sample_multi_posterior_predictive(
            rng, 10, model,
            (1, args.z_dim, args.hidden_dim, np.prod(test_sample.shape[1:])),
            guide, (test_sample, args.z_dim, args.hidden_dim), params)

        img_loc = samples['obs'][0].reshape([28, 28])
        avg_img_loc = jnp.mean(samples['obs'], axis=0).reshape([28, 28])
        plt.imsave(os.path.join(
            RESULTS_DIR, "epoch_{:0{}d}_recons_single.png".format(
                epoch, (int(jnp.log10(num_epochs)) + 1))),
                   img_loc,
                   cmap='gray')
        plt.imsave(os.path.join(
            RESULTS_DIR, "epoch_{:0{}d}_recons_avg.png".format(
                epoch, (int(jnp.log10(num_epochs)) + 1))),
                   avg_img_loc,
                   cmap='gray')

    # main training loop
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng, train_rng = random.split(rng, 3)
        num_train_batches, train_batchifier_state, = train_init(
            rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(svi_state, train_batchifier_state,
                                            num_train_batches, train_rng)

        rng, test_fetch_rng, test_rng, recons_rng = random.split(rng, 4)
        num_test_batches, test_batchifier_state = test_init(
            rng_key=test_fetch_rng)
        test_loss = eval_test(svi_state, test_batchifier_state,
                              num_test_batches, test_rng)

        reconstruct_img(i, args.num_epochs, test_batchifier_state, svi_state,
                        recons_rng)
        print("Epoch {}: loss = {} (on training set: {}) ({:.2f} s.)".format(
            i, test_loss, train_loss,
            time.time() - t_start))
Example #24
def elbo_loss_fn(x):
    return ELBO().loss(random.PRNGKey(0), {}, model, guide, x)
Example #25
    def elbo_loss_fn(x):
        return ELBO().loss(random.PRNGKey(0), {}, model, guide, x)

    def renyi_loss_fn(x):
        return RenyiELBO(alpha=alpha,
                         num_particles=10).loss(random.PRNGKey(0), {}, model,
                                                guide, x)

    elbo_loss, elbo_grad = value_and_grad(elbo_loss_fn)(2.)
    renyi_loss, renyi_grad = value_and_grad(renyi_loss_fn)(2.)
    assert_allclose(elbo_loss, renyi_loss, rtol=1e-6)
    assert_allclose(elbo_grad, renyi_grad, rtol=1e-6)


@pytest.mark.parametrize('elbo', [
    ELBO(),
    RenyiELBO(num_particles=10),
])
def test_beta_bernoulli(elbo):
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1., 1.))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q",
                                1.0,
                                constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))
Example #26
def main(args):
    rng = PRNGKey(123)
    rng, toy_data_rng = jax.random.split(rng)

    train_data, test_data, true_params = create_toy_data(
        toy_data_rng, args.num_samples, args.dimensions)

    train_init, train_fetch = subsample_batchify_data(
        train_data, batch_size=args.batch_size)
    test_init, test_fetch = split_batchify_data(test_data,
                                                batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    svi = DPSVI(model,
                guide,
                optimizer,
                ELBO(),
                dp_scale=0.01,
                clipping_threshold=20.,
                num_obs_total=args.num_samples)

    rng, svi_init_rng, data_fetch_rng = random.split(rng, 3)
    _, batchifier_state = train_init(rng_key=data_fetch_rng)
    sample_batch = train_fetch(0, batchifier_state)

    svi_state = svi.init(svi_init_rng, *sample_batch)

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            batch_X, batch_Y = batch

            svi_state, batch_loss = svi.update(svi_state, batch_X, batch_Y)
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch, rng):
        params = svi.get_params(svi_state)

        def body_fn(i, val):
            loss_sum, acc_sum = val

            batch = test_fetch(i, batchifier_state)
            batch_X, batch_Y = batch

            loss = svi.evaluate(svi_state, batch_X, batch_Y)
            loss_sum += loss / (args.num_samples * num_batch)

            acc_rng = jax.random.fold_in(rng, i)
            acc = estimate_accuracy(batch_X, batch_Y, params, acc_rng, 1)
            acc_sum += acc / num_batch

            return loss_sum, acc_sum

        return lax.fori_loop(0, num_batch, body_fn, (0., 0.))

    ## Train model

    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)

        num_train_batches, train_batchifier_state = train_init(
            rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(svi_state, train_batchifier_state,
                                            num_train_batches)
        train_loss.block_until_ready()
        t_end = time.time()

        if (i % (args.num_epochs // 10)) == 0:
            rng, test_rng, test_fetch_rng = random.split(rng, 3)
            num_test_batches, test_batchifier_state = test_init(
                rng_key=test_fetch_rng)
            test_loss, test_acc = eval_test(svi_state, test_batchifier_state,
                                            num_test_batches, test_rng)
            print(
                "Epoch {}: loss = {}, acc = {} (loss on training set: {}) ({:.2f} s.)"
                .format(i, test_loss, test_acc, train_loss, t_end - t_start))

    # parameters for logistic regression may be scaled arbitrarily. normalize
    #   w (and scale intercept accordingly) for comparison
    w_true = normalize(true_params[0])
    scale_true = jnp.linalg.norm(true_params[0])
    intercept_true = true_params[1] / scale_true

    params = svi.get_params(svi_state)
    w_post = normalize(params['w_loc'])
    scale_post = jnp.linalg.norm(params['w_loc'])
    intercept_post = params['intercept_loc'] / scale_post

    print("w_loc: {}\nexpected: {}\nerror: {}".format(
        w_post, w_true, jnp.linalg.norm(w_post - w_true)))
    print("w_std: {}".format(jnp.exp(params['w_std_log'])))
    print("")
    print("intercept_loc: {}\nexpected: {}\nerror: {}".format(
        intercept_post, intercept_true,
        jnp.abs(intercept_post - intercept_true)))
    print("intercept_std: {}".format(jnp.exp(params['intercept_std_log'])))
    print("")

    X_test, y_test = test_data
    rng, rng_acc_true, rng_acc_post = jax.random.split(rng, 3)
    # for evaluating accuracy with the true parameters, we scale them to the
    #   same scale as the found posterior; this gives better results than
    #   normalized parameters, probably due to numerical instabilities
    acc_true = estimate_accuracy_fixed_params(X_test, y_test, w_true,
                                              intercept_true, rng_acc_true, 10)
    acc_post = estimate_accuracy(X_test, y_test, params, rng_acc_post, 10)

    print(
        "avg accuracy on test set:  with true parameters: {} ; with found posterior: {}\n"
        .format(acc_true, acc_post))
Example #27
def main(args):
    N = args.num_samples
    k = args.num_components
    d = args.dimensions

    rng = PRNGKey(1234)
    rng, toy_data_rng = jax.random.split(rng, 2)

    X_train, X_test, latent_vals = create_toy_data(toy_data_rng, N, d)
    train_init, train_fetch = subsample_batchify_data((X_train,), batch_size=args.batch_size)
    test_init, test_fetch = split_batchify_data((X_test,), batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    # note(lumip): fix the parameters in the models
    def fix_params(model_fn, k):
        def fixed_params_fn(obs, **kwargs):
            return model_fn(k, obs, **kwargs)
        return fixed_params_fn

    model_fixed = fix_params(model, k)
    guide_fixed = fix_params(guide, k)

    svi = DPSVI(
        model_fixed, guide_fixed, optimizer, ELBO(),
        dp_scale=0.01,  clipping_threshold=20., num_obs_total=args.num_samples
    )

    rng, svi_init_rng, fetch_rng = random.split(rng, 3)
    _, batchifier_state = train_init(fetch_rng)
    batch = train_fetch(0, batchifier_state)
    svi_state = svi.init(svi_init_rng, *batch)

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            svi_state, batch_loss = svi.update(
                svi_state, *batch
            )
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch):
        def body_fn(i, loss_sum):
            batch = test_fetch(i, batchifier_state)
            loss = svi.evaluate(svi_state, *batch)
            loss_sum += loss / (args.num_samples * num_batch)
            return loss_sum

        return lax.fori_loop(0, num_batch, body_fn, 0.)

    ## Train model
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)

        num_train_batches, train_batchifier_state = train_init(rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(
            svi_state, train_batchifier_state, num_train_batches
        )
        train_loss.block_until_ready()
        t_end = time.time()

        if i % 100 == 0:
            rng, test_fetch_rng = random.split(rng, 2)
            num_test_batches, test_batchifier_state = test_init(rng_key=test_fetch_rng)
            test_loss = eval_test(
                svi_state, test_batchifier_state, num_test_batches
            )

            print("Epoch {}: loss = {} (on training set = {}) ({:.2f} s.)".format(
                    i, test_loss, train_loss, t_end - t_start
                ))

    params = svi.get_params(svi_state)
    print(params)
    posterior_modes = params['mus_loc']
    posterior_pis = dist.Dirichlet(jnp.exp(params['alpha_log'])).mean
    print("MAP estimate of mixture weights: {}".format(posterior_pis))
    print("MAP estimate of mixture modes  : {}".format(posterior_modes))

    acc = compute_assignment_accuracy(
        X_test, latent_vals[1], latent_vals[2], posterior_modes, posterior_pis
    )
    print("assignment accuracy: {}".format(acc))