Example #1
    def test_split_batchify_fetch(self):
        data = np.arange(105) + 100
        init, fetch = split_batchify_data((data, ), 10)
        batchifier_state = jax.random.permutation(jax.random.PRNGKey(0),
                                                  jnp.arange(0, 105))
        num_batches = 10

        counts = np.zeros(105)
        for i in range(num_batches):
            batch = fetch(i, batchifier_state)
            batch = batch[0]
            unq_idxs, unq_counts = np.unique(batch, return_counts=True)
            counts[unq_idxs - 100] += unq_counts  # accumulate so duplicates across batches are counted
            self.assertTrue(
                np.all(unq_counts <= 1)
            )  # ensure each item occurs at most once in the batch
            self.assertTrue(
                np.all(batch >= 100) and np.all(batch < 205)
            )  # ensure batch was plausibly drawn from data

        self.assertTrue(np.all(
            counts <= 1))  # ensure each item occurs at most once in the epoch
        self.assertEqual(
            100, np.sum(counts)
        )  # ensure that amount of elements in batches cover an epoch worth of data
Example #2
    def test_split_batchify_fetch_correct_shape(self):
        data = np.random.normal(size=(105, 3))
        init, fetch = split_batchify_data((data, ), 10)
        batchifier_state = jax.random.permutation(jax.random.PRNGKey(0),
                                                  jnp.arange(0, 105))

        batch = fetch(6, batchifier_state)
        batch = batch[0]
        self.assertEqual((10, 3), jnp.shape(batch))
Example #3
    def test_split_batchify_batches_differ(self):
        data = np.arange(105) + 100
        init, fetch = split_batchify_data((data, ), 10)
        num_batches, batchifier_state = init(jax.random.PRNGKey(10))

        batch_0 = fetch(3, batchifier_state)
        batch_1 = fetch(8, batchifier_state)
        self.assertFalse(jnp.allclose(batch_0,
                                      batch_1))  # ensure batches are different
Example #4
    def test_split_batchify_init_non_divisible_size(self):
        data = jnp.arange(0, 105)
        init, fetch = split_batchify_data((data, ), 10)

        rng_key = jax.random.PRNGKey(0)
        num_batches, batchifier_state = init(rng_key)

        self.assertEqual(10, num_batches)
        self.assertTrue(
            np.all(np.unique(batchifier_state, return_counts=True)[1] < 2))
Example #5
    def test_split_batchify_init(self):
        data = jnp.arange(0, 100)
        init, fetch = split_batchify_data((data, ), 10)

        rng_key = jax.random.PRNGKey(0)
        num_batches, batchifier_state = init(rng_key)

        self.assertEqual(10, num_batches)
        self.assertEqual(jnp.size(data), jnp.size(batchifier_state))
        self.assertTrue(np.allclose(np.unique(batchifier_state), data))
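The five tests above all exercise the same two-function contract: split_batchify_data returns an init function (taking an RNG key and yielding the number of batches plus an epoch permutation as batchifier state) and a fetch function (taking a batch index and that state and yielding the batch as a tuple of arrays). The following is a minimal sketch of how such a batchifier could be implemented; toy_split_batchify_data is a hypothetical stand-in for illustration, not the library's actual implementation.

import jax
import jax.numpy as jnp

def toy_split_batchify_data(arrays, batch_size):
    # hypothetical sketch: shuffle once per epoch, then split the permuted
    #   indices into disjoint batches; surplus elements are dropped
    n = jnp.shape(arrays[0])[0]
    num_batches = n // batch_size

    def init(rng_key):
        # the batchifier state is a random permutation of all data indices
        return num_batches, jax.random.permutation(rng_key, jnp.arange(n))

    def fetch(i, batchifier_state):
        # slice out the i-th block of permuted indices and gather the data
        idxs = jax.lax.dynamic_slice_in_dim(
            batchifier_state, i * batch_size, batch_size)
        return tuple(jnp.take(a, idxs, axis=0) for a in arrays)

    return init, fetch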
Example #6
def main(args):
    rng = PRNGKey(123)
    rng, toy_data_rng = jax.random.split(rng)

    train_data, test_data, true_params = create_toy_data(
        toy_data_rng, args.num_samples, args.dimensions)

    train_init, train_fetch = subsample_batchify_data(
        train_data, batch_size=args.batch_size)
    test_init, test_fetch = split_batchify_data(test_data,
                                                batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    svi = DPSVI(model,
                guide,
                optimizer,
                ELBO(),
                dp_scale=0.01,
                clipping_threshold=20.,
                num_obs_total=args.num_samples)

    rng, svi_init_rng, data_fetch_rng = random.split(rng, 3)
    _, batchifier_state = train_init(rng_key=data_fetch_rng)
    sample_batch = train_fetch(0, batchifier_state)

    svi_state = svi.init(svi_init_rng, *sample_batch)

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            batch_X, batch_Y = batch

            svi_state, batch_loss = svi.update(svi_state, batch_X, batch_Y)
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch, rng):
        params = svi.get_params(svi_state)

        def body_fn(i, val):
            loss_sum, acc_sum = val

            batch = test_fetch(i, batchifier_state)
            batch_X, batch_Y = batch

            loss = svi.evaluate(svi_state, batch_X, batch_Y)
            loss_sum += loss / (args.num_samples * num_batch)

            acc_rng = jax.random.fold_in(rng, i)
            acc = estimate_accuracy(batch_X, batch_Y, params, acc_rng, 1)
            acc_sum += acc / num_batch

            return loss_sum, acc_sum

        return lax.fori_loop(0, num_batch, body_fn, (0., 0.))

    ## Train model
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)

        num_train_batches, train_batchifier_state = train_init(
            rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(svi_state, train_batchifier_state,
                                            num_train_batches)
        train_loss.block_until_ready()
        t_end = time.time()

        if i % (args.num_epochs // 10) == 0:
            rng, test_rng, test_fetch_rng = random.split(rng, 3)
            num_test_batches, test_batchifier_state = test_init(
                rng_key=test_fetch_rng)
            test_loss, test_acc = eval_test(svi_state, test_batchifier_state,
                                            num_test_batches, test_rng)
            print(
                "Epoch {}: loss = {}, acc = {} (loss on training set: {}) ({:.2f} s.)"
                .format(i, test_loss, test_acc, train_loss, t_end - t_start))

    # parameters for logistic regression may be scaled arbitrarily. normalize
    #   w (and scale intercept accordingly) for comparison
    w_true = normalize(true_params[0])
    scale_true = jnp.linalg.norm(true_params[0])
    intercept_true = true_params[1] / scale_true

    params = svi.get_params(svi_state)
    w_post = normalize(params['w_loc'])
    scale_post = jnp.linalg.norm(params['w_loc'])
    intercept_post = params['intercept_loc'] / scale_post

    print("w_loc: {}\nexpected: {}\nerror: {}".format(
        w_post, w_true, jnp.linalg.norm(w_post - w_true)))
    print("w_std: {}".format(jnp.exp(params['w_std_log'])))
    print("")
    print("intercept_loc: {}\nexpected: {}\nerror: {}".format(
        intercept_post, intercept_true,
        jnp.abs(intercept_post - intercept_true)))
    print("intercept_std: {}".format(jnp.exp(params['intercept_std_log'])))
    print("")

    X_test, y_test = test_data
    rng, rng_acc_true, rng_acc_post = jax.random.split(rng, 3)
    # for evaluating accuracy with the true parameters, we scale them to the
    #   same scale as the found posterior (this gives better results than the
    #   normalized parameters, probably due to numerical instabilities)
    acc_true = estimate_accuracy_fixed_params(X_test, y_test, w_true,
                                              intercept_true, rng_acc_true, 10)
    acc_post = estimate_accuracy(X_test, y_test, params, rng_acc_post, 10)

    print(
        "avg accuracy on test set:  with true parameters: {} ; with found posterior: {}\n"
        .format(acc_true, acc_post))
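Examples #6 to #8 pair subsample_batchify_data for training with split_batchify_data for evaluation. Below is a minimal sketch of what a subsampling batchifier's init/fetch pair might look like; toy_subsample_batchify_data is hypothetical and only illustrates the idea that each batch is drawn as an independent random subsample (the regime typically assumed for DP subsampling amplification), rather than from a single epoch permutation.

import jax
import jax.numpy as jnp

def toy_subsample_batchify_data(arrays, batch_size):
    # hypothetical sketch: every batch is an independent uniform subsample
    n = jnp.shape(arrays[0])[0]
    num_batches = n // batch_size

    def init(rng_key):
        # the batchifier state is just the epoch's RNG key
        return num_batches, rng_key

    def fetch(i, batchifier_state):
        # derive a per-batch key so batches are independent but reproducible
        batch_rng = jax.random.fold_in(batchifier_state, i)
        idxs = jax.random.choice(batch_rng, n, shape=(batch_size,),
                                 replace=False)
        return tuple(jnp.take(a, idxs, axis=0) for a in arrays)

    return init, fetch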
Example #7
def main(args):
    N = args.num_samples
    k = args.num_components
    d = args.dimensions

    rng = PRNGKey(1234)
    rng, toy_data_rng = jax.random.split(rng, 2)

    X_train, X_test, latent_vals = create_toy_data(toy_data_rng, N, d)
    train_init, train_fetch = subsample_batchify_data((X_train,), batch_size=args.batch_size)
    test_init, test_fetch = split_batchify_data((X_test,), batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    # note(lumip): fix the parameters in the models
    def fix_params(model_fn, k):
        def fixed_params_fn(obs, **kwargs):
            return model_fn(k, obs, **kwargs)
        return fixed_params_fn

    model_fixed = fix_params(model, k)
    guide_fixed = fix_params(guide, k)

    svi = DPSVI(
        model_fixed, guide_fixed, optimizer, ELBO(),
        dp_scale=0.01, clipping_threshold=20., num_obs_total=args.num_samples
    )

    rng, svi_init_rng, fetch_rng = random.split(rng, 3)
    _, batchifier_state = train_init(fetch_rng)
    batch = train_fetch(0, batchifier_state)
    svi_state = svi.init(svi_init_rng, *batch)

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            svi_state, batch_loss = svi.update(
                svi_state, *batch
            )
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch):
        def body_fn(i, loss_sum):
            batch = test_fetch(i, batchifier_state)
            loss = svi.evaluate(svi_state, *batch)
            loss_sum += loss / (args.num_samples * num_batch)
            return loss_sum

        return lax.fori_loop(0, num_batch, body_fn, 0.)

    ## Train model
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)

        num_train_batches, train_batchifier_state = train_init(rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(
            svi_state, train_batchifier_state, num_train_batches
        )
        train_loss.block_until_ready()
        t_end = time.time()

        if i % 100 == 0:
            rng, test_fetch_rng = random.split(rng, 2)
            num_test_batches, test_batchifier_state = test_init(rng_key=test_fetch_rng)
            test_loss = eval_test(
                svi_state, test_batchifier_state, num_test_batches
            )

            print("Epoch {}: loss = {} (on training set = {}) ({:.2f} s.)".format(
                    i, test_loss, train_loss, t_end - t_start
                ))

    params = svi.get_params(svi_state)
    print(params)
    posterior_modes = params['mus_loc']
    posterior_pis = dist.Dirichlet(jnp.exp(params['alpha_log'])).mean
    print("MAP estimate of mixture weights: {}".format(posterior_pis))
    print("MAP estimate of mixture modes  : {}".format(posterior_modes))

    acc = compute_assignment_accuracy(
        X_test, latent_vals[1], latent_vals[2], posterior_modes, posterior_pis
    )
    print("assignment accuracy: {}".format(acc))
Example #8
def main(args):
    rng = PRNGKey(1234)
    rng, toy_data_rng = jax.random.split(rng, 2)
    X_train, X_test, mu_true = create_toy_data(toy_data_rng, args.num_samples,
                                               args.dimensions)

    train_init, train_fetch = subsample_batchify_data(
        (X_train, ), batch_size=args.batch_size)
    test_init, test_fetch = split_batchify_data((X_test, ),
                                                batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    # dp_scale gives the noise scale sigma of the DP perturbation mechanism;
    #   clipping_threshold is the bound C on per-example gradient norms
    svi = DPSVI(model,
                guide,
                optimizer,
                ELBO(),
                dp_scale=args.sigma,
                clipping_threshold=args.clip_threshold,
                d=args.dimensions,
                num_obs_total=args.num_samples)

    rng, svi_init_rng, batchifier_rng = random.split(rng, 3)
    _, batchifier_state = train_init(rng_key=batchifier_rng)
    batch = train_fetch(0, batchifier_state)
    svi_state = svi.init(svi_init_rng, *batch)

    q = args.batch_size / args.num_samples
    eps = svi.get_epsilon(args.delta, q, num_epochs=args.num_epochs)
    print("Privacy epsilon {} (for sigma: {}, delta: {}, C: {}, q: {})".format(
        eps, args.sigma, args.clip_threshold, args.delta, q))

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            svi_state, batch_loss = svi.update(svi_state, *batch)
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch):
        def body_fn(i, loss_sum):
            batch = test_fetch(i, batchifier_state)
            loss = svi.evaluate(svi_state, *batch)
            loss_sum += loss / (args.num_samples * num_batch)

            return loss_sum

        return lax.fori_loop(0, num_batch, body_fn, 0.)

    ## Train model
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)

        num_train_batches, train_batchifier_state = train_init(
            rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(svi_state, train_batchifier_state,
                                            num_train_batches)
        train_loss.block_until_ready()
        t_end = time.time()

        if i % (args.num_epochs // 10) == 0:
            rng, test_fetch_rng = random.split(rng, 2)
            num_test_batches, test_batchifier_state = test_init(
                rng_key=test_fetch_rng)
            test_loss = eval_test(svi_state, test_batchifier_state,
                                  num_test_batches)

            print(
                "Epoch {}: loss = {} (on training set: {}) ({:.2f} s.)".format(
                    i, test_loss, train_loss, t_end - t_start))

    params = svi.get_params(svi_state)
    mu_loc = params['mu_loc']
    mu_std = jnp.exp(params['mu_std_log'])
    print("### expected: {}".format(mu_true))
    print("### svi result\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))
    mu_loc, mu_std = analytical_solution(X_train)
    print("### analytical solution\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))
    mu_loc, mu_std = ml_estimate(X_train)
    print("### ml estimate\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))
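Example #8 compares the DP-SVI result against analytical_solution and ml_estimate baselines, both defined outside this snippet. For a unit-variance Gaussian likelihood with a standard normal prior on mu, such baselines could look like the sketch below; the actual helpers may use different prior and noise variances, so this is an assumption-laden illustration, not the original code.

import jax.numpy as jnp

def toy_analytical_solution(X):
    # conjugate posterior of mu for N(x | mu, I) with prior mu ~ N(0, I):
    #   posterior variance 1 / (1 + n), posterior mean sum(X) / (1 + n)
    n = X.shape[0]
    mu_var = 1. / (1. + n)
    mu_loc = mu_var * jnp.sum(X, axis=0)
    return mu_loc, jnp.sqrt(mu_var)

def toy_ml_estimate(X):
    # maximum-likelihood baseline: sample mean, standard error 1 / sqrt(n)
    n = X.shape[0]
    return jnp.mean(X, axis=0), jnp.sqrt(1. / n)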