Example #1
def train_model(rng,
                rng_suite,
                model,
                guide,
                data,
                batch_size,
                num_data,
                dp_scale,
                num_epochs,
                clipping_threshold=1.):
    """ trains a given model using DPSVI and the globally defined parameters and data """

    optimizer = Adam(1e-3)

    svi = DPSVI(model,
                guide,
                optimizer,
                Trace_ELBO(),
                num_obs_total=num_data,
                clipping_threshold=clipping_threshold,
                dp_scale=dp_scale,
                rng_suite=rng_suite)

    return _train_model(rng, rng_suite, svi, data, batch_size, num_data,
                        num_epochs)
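A minimal usage sketch for train_model above; every name other than train_model (rng_suite, model, guide, train_data, num_data) is an assumption standing in for objects defined elsewhere in the project:

# Hypothetical invocation: the rng key type and rng_suite must match the
# rng suite the project actually uses; model/guide/train_data are assumed
# to be defined as in the surrounding example code.
rng = jax.random.PRNGKey(0)
result = train_model(rng, rng_suite, model, guide, train_data,
                     batch_size=64, num_data=num_data,
                     dp_scale=1.0, num_epochs=10)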
Example #2
def setUp(self):
    self.rng = jax.random.PRNGKey(9782346)
    self.batch_size = 10
    self.num_obs_total = 100
    self.px_grads = ((jnp.zeros(
        (self.batch_size, 10000)), jnp.zeros((self.batch_size, 10000))))
    self.px_grads_list, self.tree_def = jax.tree_flatten(self.px_grads)
    self.px_loss = jnp.arange(self.batch_size, dtype=jnp.float32)
    self.dp_scale = 1.
    self.clipping_threshold = 2.
    self.svi = DPSVI(None,
                     None,
                     None,
                     None,
                     self.clipping_threshold,
                     self.dp_scale,
                     num_obs_total=self.num_obs_total)
Example #3
def main(args):
    rng = PRNGKey(123)
    rng, toy_data_rng = jax.random.split(rng)

    train_data, test_data, true_params = create_toy_data(
        toy_data_rng, args.num_samples, args.dimensions)

    train_init, train_fetch = subsample_batchify_data(
        train_data, batch_size=args.batch_size)
    test_init, test_fetch = split_batchify_data(test_data,
                                                batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    svi = DPSVI(model,
                guide,
                optimizer,
                ELBO(),
                dp_scale=0.01,
                clipping_threshold=20.,
                num_obs_total=args.num_samples)

    rng, svi_init_rng, data_fetch_rng = random.split(rng, 3)
    _, batchifier_state = train_init(rng_key=data_fetch_rng)
    sample_batch = train_fetch(0, batchifier_state)

    svi_state = svi.init(svi_init_rng, *sample_batch)

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            batch_X, batch_Y = batch

            svi_state, batch_loss = svi.update(svi_state, batch_X, batch_Y)
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch, rng):
        params = svi.get_params(svi_state)

        def body_fn(i, val):
            loss_sum, acc_sum = val

            batch = test_fetch(i, batchifier_state)
            batch_X, batch_Y = batch

            loss = svi.evaluate(svi_state, batch_X, batch_Y)
            loss_sum += loss / (args.num_samples * num_batch)

            acc_rng = jax.random.fold_in(rng, i)
            acc = estimate_accuracy(batch_X, batch_Y, params, acc_rng, 1)
            acc_sum += acc / num_batch

            return loss_sum, acc_sum

        return lax.fori_loop(0, num_batch, body_fn, (0., 0.))

    ## Train model

    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)

        num_train_batches, train_batchifier_state = train_init(
            rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(svi_state, train_batchifier_state,
                                            num_train_batches)
        train_loss.block_until_ready()
        t_end = time.time()

        if (i % (args.num_epochs // 10)) == 0:
            rng, test_rng, test_fetch_rng = random.split(rng, 3)
            num_test_batches, test_batchifier_state = test_init(
                rng_key=test_fetch_rng)
            test_loss, test_acc = eval_test(svi_state, test_batchifier_state,
                                            num_test_batches, test_rng)
            print(
                "Epoch {}: loss = {}, acc = {} (loss on training set: {}) ({:.2f} s.)"
                .format(i, test_loss, test_acc, train_loss, t_end - t_start))

    # parameters for logistic regression may be scaled arbitrarily. normalize
    #   w (and scale intercept accordingly) for comparison
    w_true = normalize(true_params[0])
    scale_true = jnp.linalg.norm(true_params[0])
    intercept_true = true_params[1] / scale_true

    params = svi.get_params(svi_state)
    w_post = normalize(params['w_loc'])
    scale_post = jnp.linalg.norm(params['w_loc'])
    intercept_post = params['intercept_loc'] / scale_post

    print("w_loc: {}\nexpected: {}\nerror: {}".format(
        w_post, w_true, jnp.linalg.norm(w_post - w_true)))
    print("w_std: {}".format(jnp.exp(params['w_std_log'])))
    print("")
    print("intercept_loc: {}\nexpected: {}\nerror: {}".format(
        intercept_post, intercept_true,
        jnp.abs(intercept_post - intercept_true)))
    print("intercept_std: {}".format(jnp.exp(params['intercept_std_log'])))
    print("")

    X_test, y_test = test_data
    rng, rng_acc_true, rng_acc_post = jax.random.split(rng, 3)
    # For evaluating accuracy with the true parameters, we scale them to the
    #   same scale as the found posterior; this gives better results than the
    #   normalized parameters, probably due to numerical instabilities.
    acc_true = estimate_accuracy_fixed_params(X_test, y_test, w_true,
                                              intercept_true, rng_acc_true, 10)
    acc_post = estimate_accuracy(X_test, y_test, params, rng_acc_post, 10)

    print(
        "avg accuracy on test set:  with true parameters: {} ; with found posterior: {}\n"
        .format(acc_true, acc_post))
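A small side note on the normalization step near the end of this example: multiplying w and the intercept by the same positive constant does not change the predicted labels sign(X @ w + intercept), which is why both the true and the posterior parameters can be normalized before comparison. A quick illustrative check with made-up values (not from the example):

import jax.numpy as jnp

w, b = jnp.array([1.5, -0.5]), 0.3
X = jnp.array([[1.0, 2.0], [-3.0, 0.5]])
print(jnp.sign(X @ w + b))             # [ 1. -1.]
print(jnp.sign(X @ (2 * w) + 2 * b))   # same signs, predictions unchanged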
Example #4
def main(args):
    N = args.num_samples
    k = args.num_components
    d = args.dimensions

    rng = PRNGKey(1234)
    rng, toy_data_rng = jax.random.split(rng, 2)

    X_train, X_test, latent_vals = create_toy_data(toy_data_rng, N, d)
    train_init, train_fetch = subsample_batchify_data((X_train,), batch_size=args.batch_size)
    test_init, test_fetch = split_batchify_data((X_test,), batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    # note(lumip): fix the parameters in the models
    def fix_params(model_fn, k):
        def fixed_params_fn(obs, **kwargs):
            return model_fn(k, obs, **kwargs)
        return fixed_params_fn

    model_fixed = fix_params(model, k)
    guide_fixed = fix_params(guide, k)

    svi = DPSVI(
        model_fixed, guide_fixed, optimizer, ELBO(),
        dp_scale=0.01,  clipping_threshold=20., num_obs_total=args.num_samples
    )

    rng, svi_init_rng, fetch_rng = random.split(rng, 3)
    _, batchifier_state = train_init(fetch_rng)
    batch = train_fetch(0, batchifier_state)
    svi_state = svi.init(svi_init_rng, *batch)

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            svi_state, batch_loss = svi.update(
                svi_state, *batch
            )
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch):
        def body_fn(i, loss_sum):
            batch = test_fetch(i, batchifier_state)
            loss = svi.evaluate(svi_state, *batch)
            loss_sum += loss / (args.num_samples * num_batch)
            return loss_sum

        return lax.fori_loop(0, num_batch, body_fn, 0.)

    ## Train model
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)

        num_train_batches, train_batchifier_state = train_init(rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(
            svi_state, train_batchifier_state, num_train_batches
        )
        train_loss.block_until_ready()
        t_end = time.time()

        if i % 100 == 0:
            rng, test_fetch_rng = random.split(rng, 2)
            num_test_batches, test_batchifier_state = test_init(rng_key=test_fetch_rng)
            test_loss = eval_test(
                svi_state, test_batchifier_state, num_test_batches
            )

            print("Epoch {}: loss = {} (on training set = {}) ({:.2f} s.)".format(
                    i, test_loss, train_loss, t_end - t_start
                ))

    params = svi.get_params(svi_state)
    print(params)
    posterior_modes = params['mus_loc']
    posterior_pis = dist.Dirichlet(jnp.exp(params['alpha_log'])).mean
    print("MAP estimate of mixture weights: {}".format(posterior_pis))
    print("MAP estimate of mixture modes  : {}".format(posterior_modes))

    acc = compute_assignment_accuracy(
        X_test, latent_vals[1], latent_vals[2], posterior_modes, posterior_pis
    )
    print("assignment accuracy: {}".format(acc))
Example #5
def main(args):
    rng = PRNGKey(1234)
    rng, toy_data_rng = jax.random.split(rng, 2)
    X_train, X_test, mu_true = create_toy_data(toy_data_rng, args.num_samples,
                                               args.dimensions)

    train_init, train_fetch = subsample_batchify_data(
        (X_train, ), batch_size=args.batch_size)
    test_init, test_fetch = split_batchify_data((X_test, ),
                                                batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    svi = DPSVI(model,
                guide,
                optimizer,
                ELBO(),
                dp_scale=args.sigma,
                clipping_threshold=args.clip_threshold,
                d=args.dimensions,
                num_obs_total=args.num_samples)

    rng, svi_init_rng, batchifier_rng = random.split(rng, 3)
    _, batchifier_state = train_init(rng_key=batchifier_rng)
    batch = train_fetch(0, batchifier_state)
    svi_state = svi.init(svi_init_rng, *batch)

    q = args.batch_size / args.num_samples
    eps = svi.get_epsilon(args.delta, q, num_epochs=args.num_epochs)
    print("Privacy epsilon {} (for sigma: {}, delta: {}, C: {}, q: {})".format(
        eps, args.sigma, args.clip_threshold, args.delta, q))

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            svi_state, batch_loss = svi.update(svi_state, *batch)
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch):
        def body_fn(i, loss_sum):
            batch = test_fetch(i, batchifier_state)
            loss = svi.evaluate(svi_state, *batch)
            loss_sum += loss / (args.num_samples * num_batch)

            return loss_sum

        return lax.fori_loop(0, num_batch, body_fn, 0.)

    ## Train model

    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)

        num_train_batches, train_batchifier_state = train_init(
            rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(svi_state, train_batchifier_state,
                                            num_train_batches)
        train_loss.block_until_ready()
        t_end = time.time()

        if (i % (args.num_epochs // 10) == 0):
            rng, test_fetch_rng = random.split(rng, 2)
            num_test_batches, test_batchifier_state = test_init(
                rng_key=test_fetch_rng)
            test_loss = eval_test(svi_state, test_batchifier_state,
                                  num_test_batches)

            print(
                "Epoch {}: loss = {} (on training set: {}) ({:.2f} s.)".format(
                    i, test_loss, train_loss, t_end - t_start))

    params = svi.get_params(svi_state)
    mu_loc = params['mu_loc']
    mu_std = jnp.exp(params['mu_std_log'])
    print("### expected: {}".format(mu_true))
    print("### svi result\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))
    mu_loc, mu_std = analytical_solution(X_train)
    print("### analytical solution\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))
    mu_loc, mu_std = ml_estimate(X_train)
    print("### ml estimate\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))
Example #6
File: vae.py Project: byzhang/d3p
def main(args):
    # loading data
    (train_init, train_fetch_plain), num_samples = load_dataset(
        MNIST,
        batch_size=args.batch_size,
        split='train',
        batchifier=subsample_batchify_data)
    (test_init,
     test_fetch_plain), _ = load_dataset(MNIST,
                                         batch_size=args.batch_size,
                                         split='test',
                                         batchifier=split_batchify_data)

    def binarize_fetch(fetch_fn):
        @jit
        def fetch_binarized(batch_nr, batchifier_state, binarize_rng):
            batch = fetch_fn(batch_nr, batchifier_state)
            return binarize(binarize_rng, batch[0]), batch[1]

        return fetch_binarized

    train_fetch = binarize_fetch(train_fetch_plain)
    test_fetch = binarize_fetch(test_fetch_plain)

    # setting up optimizer
    optimizer = optimizers.Adam(args.learning_rate)

    # the minibatch environment in our model scales individual
    # records' contributions to the loss up by num_samples.
    # This can cause numerical instabilities so we scale down
    # the loss by 1/num_samples here.
    sc_model = scale(model, scale=1 / num_samples)
    sc_guide = scale(guide, scale=1 / num_samples)

    if args.no_dp:
        svi = SVI(sc_model,
                  sc_guide,
                  optimizer,
                  ELBO(),
                  num_obs_total=num_samples,
                  z_dim=args.z_dim,
                  hidden_dim=args.hidden_dim)
    else:
        q = args.batch_size / num_samples
        target_eps = args.epsilon
        dp_scale, act_eps, _ = approximate_sigma(target_eps=target_eps,
                                                 delta=1 / num_samples,
                                                 q=q,
                                                 num_iter=int(1 / q) *
                                                 args.num_epochs,
                                                 force_smaller=True)
        print(
            f"using noise scale {dp_scale} for epsilon of {act_eps} (targeted: {target_eps})"
        )

        svi = DPSVI(sc_model,
                    sc_guide,
                    optimizer,
                    ELBO(),
                    dp_scale=dp_scale,
                    clipping_threshold=10.,
                    num_obs_total=num_samples,
                    z_dim=args.z_dim,
                    hidden_dim=args.hidden_dim)

    # preparing random number generators and initializing svi
    rng = PRNGKey(0)
    rng, binarize_rng, svi_init_rng, batchifier_rng = random.split(rng, 4)
    _, batchifier_state = train_init(rng_key=batchifier_rng)
    sample_batch = train_fetch(0, batchifier_state, binarize_rng)[0]
    svi_state = svi.init(svi_init_rng, sample_batch)

    # functions for training tasks
    @jit
    def epoch_train(svi_state, batchifier_state, num_batches, rng):
        """Trains one epoch

        :param svi_state: current state of the optimizer
        :param rng: rng key

        :return: overall training loss over the epoch
        """
        def body_fn(i, val):
            svi_state, loss = val
            binarize_rng = random.fold_in(rng, i)
            batch = train_fetch(i, batchifier_state, binarize_rng)[0]
            svi_state, batch_loss = svi.update(svi_state, batch)
            loss += batch_loss / num_batches
            return svi_state, loss

        svi_state, loss = lax.fori_loop(0, num_batches, body_fn,
                                        (svi_state, 0.))
        return svi_state, loss

    @jit
    def eval_test(svi_state, batchifier_state, num_batches, rng):
        """Evaluates current model state on test data.

        :param svi_state: current state of the optimizer
        :param rng: rng key

        :return: loss over the test split
        """
        def body_fn(i, loss_sum):
            binarize_rng = random.fold_in(rng, i)
            batch = test_fetch(i, batchifier_state, binarize_rng)[0]
            batch_loss = svi.evaluate(svi_state, batch)
            loss_sum += batch_loss / num_batches
            return loss_sum

        return lax.fori_loop(0, num_batches, body_fn, 0.)

    def reconstruct_img(epoch, num_epochs, batchifier_state, svi_state, rng):
        """Reconstructs an image for the given epoch

        Obtains a sample from the testing data set and passes it through the
        VAE. Stores the result as image file 'epoch_{epoch}_recons.png' and
        the original input as 'epoch_{epoch}_original.png' in folder '.results'.

        :param epoch: Number of the current epoch
        :param num_epochs: Number of total epochs
        :param svi_state: Current state of the optimizer
        :param rng: rng key
        """
        assert (num_epochs > 0)
        img = test_fetch_plain(0, batchifier_state)[0][0]
        plt.imsave(os.path.join(
            RESULTS_DIR, "epoch_{:0{}d}_original.png".format(
                epoch, (int(jnp.log10(num_epochs)) + 1))),
                   img,
                   cmap='gray')
        rng, rng_binarize = random.split(rng, 2)
        test_sample = binarize(rng_binarize, img)
        test_sample = jnp.reshape(test_sample, (1, *jnp.shape(test_sample)))
        params = svi.get_params(svi_state)

        samples = sample_multi_posterior_predictive(
            rng, 10, model,
            (1, args.z_dim, args.hidden_dim, np.prod(test_sample.shape[1:])),
            guide, (test_sample, args.z_dim, args.hidden_dim), params)

        img_loc = samples['obs'][0].reshape([28, 28])
        avg_img_loc = jnp.mean(samples['obs'], axis=0).reshape([28, 28])
        plt.imsave(os.path.join(
            RESULTS_DIR, "epoch_{:0{}d}_recons_single.png".format(
                epoch, (int(jnp.log10(num_epochs)) + 1))),
                   img_loc,
                   cmap='gray')
        plt.imsave(os.path.join(
            RESULTS_DIR, "epoch_{:0{}d}_recons_avg.png".format(
                epoch, (int(jnp.log10(num_epochs)) + 1))),
                   avg_img_loc,
                   cmap='gray')

    # main training loop
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng, train_rng = random.split(rng, 3)
        num_train_batches, train_batchifier_state = train_init(
            rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(svi_state, train_batchifier_state,
                                            num_train_batches, train_rng)

        rng, test_fetch_rng, test_rng, recons_rng = random.split(rng, 4)
        num_test_batches, test_batchifier_state = test_init(
            rng_key=test_fetch_rng)
        test_loss = eval_test(svi_state, test_batchifier_state,
                              num_test_batches, test_rng)

        reconstruct_img(i, args.num_epochs, test_batchifier_state, svi_state,
                        recons_rng)
        print("Epoch {}: loss = {} (on training set: {}) ({:.2f} s.)".format(
            i, test_loss, train_loss,
            time.time() - t_start))
Example #7
class DPSVITest(unittest.TestCase):
    def setUp(self):
        self.rng = jax.random.PRNGKey(9782346)
        self.batch_size = 10
        self.num_obs_total = 100
        self.px_grads = ((jnp.zeros(
            (self.batch_size, 10000)), jnp.zeros((self.batch_size, 10000))))
        self.px_grads_list, self.tree_def = jax.tree_flatten(self.px_grads)
        self.px_loss = jnp.arange(self.batch_size, dtype=jnp.float32)
        self.dp_scale = 1.
        self.clipping_threshold = 2.
        self.svi = DPSVI(None,
                         None,
                         None,
                         None,
                         self.clipping_threshold,
                         self.dp_scale,
                         num_obs_total=self.num_obs_total)

    def test_px_gradient_aggregation(self):
        svi_state = DPSVIState(None, self.rng, .3)

        np.random.seed(0)
        px_grads_list, _ = jax.tree_flatten(
            (np.random.normal(1, 1, size=(self.batch_size, 10000)),
             np.random.normal(1, 1, size=(self.batch_size, 10000))))

        expected_grads_list = [
            jnp.mean(px_grads, axis=0) for px_grads in px_grads_list
        ]
        expected_loss = jnp.mean(self.px_loss)

        loss, grads_list = self.svi._combine_gradients(px_grads_list,
                                                       self.px_loss)

        self.assertTrue(jnp.allclose(expected_loss, loss),
                        f"expected loss {expected_loss} but was {loss}")
        self.assertTrue(
            jnp.allclose(expected_grads_list, grads_list),
            f"expected gradients {expected_grads_list} but was {grads_list}")

    def test_dp_noise_perturbation(self):
        svi_state = DPSVIState(None, self.rng, .3)

        grads_list = [
            jnp.mean(px_grads, axis=0) for px_grads in self.px_grads_list
        ]

        new_svi_state, grads = \
            self.svi._perturb_and_reassemble_gradients(
                svi_state, grads_list, self.batch_size, self.tree_def
            )

        self.assertIs(svi_state.optim_state, new_svi_state.optim_state)
        self.assertFalse(jnp.allclose(svi_state.rng_key,
                                      new_svi_state.rng_key))
        self.assertEqual(self.tree_def, jax.tree_structure(grads))

        expected_std = (self.dp_scale * self.clipping_threshold /
                        self.batch_size) * svi_state.observation_scale
        for site, px_site in zip(jax.tree_leaves(grads),
                                 jax.tree_leaves(self.px_grads_list)):
            self.assertEqual(px_site.shape[1:], site.shape)
            self.assertTrue(
                jnp.allclose(expected_std, jnp.std(site), atol=1e-2),
                f"expected stdev {expected_std} but was {jnp.std(site)}")
            self.assertTrue(jnp.allclose(0., jnp.mean(site), atol=1e-2))

    def test_dp_noise_perturbation_not_deterministic_over_calls(self):
        """ verifies that different randomness is used in subsequent calls """
        svi_state = DPSVIState(None, self.rng, .3)

        grads_list = [
            jnp.mean(px_grads, axis=0) for px_grads in self.px_grads_list
        ]

        new_svi_state, first_grads = \
            self.svi._perturb_and_reassemble_gradients(
                svi_state, grads_list, self.batch_size, self.tree_def
            )

        _, second_grads = \
            self.svi._perturb_and_reassemble_gradients(
                new_svi_state, grads_list, self.batch_size, self.tree_def
            )

        some_gradient_noise_is_equal = reduce(
            lambda acc, is_equal: acc or is_equal,
            jax.tree_leaves(
                jax.tree_multimap(lambda x, y: jnp.allclose(x, y), first_grads,
                                  second_grads)))
        self.assertFalse(some_gradient_noise_is_equal)

    def test_dp_noise_perturbation_not_deterministic_over_sites(self):
        """ verifies that different randomness is used for different gradient sites """
        svi_state = DPSVIState(None, self.rng, .3)

        grads_list = [
            jnp.mean(px_grads, axis=0) for px_grads in self.px_grads_list
        ]

        _, grads = \
            self.svi._perturb_and_reassemble_gradients(
                svi_state, grads_list, self.batch_size, self.tree_def
            )

        noise_sites = jax.tree_leaves(grads)

        self.assertFalse(jnp.allclose(noise_sites[0], noise_sites[1]))
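For reference, the expected_std checked in test_dp_noise_perturbation follows directly from the values set in setUp; a worked check, not part of the test file:

# expected noise std = dp_scale * clipping_threshold / batch_size * observation_scale
dp_scale, clipping_threshold = 1.0, 2.0
batch_size, observation_scale = 10, 0.3
print(dp_scale * clipping_threshold / batch_size * observation_scale)  # 0.06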