Ejemplo n.º 1
0
 def testUnpackPackRoundTrip(self):
     opt_init, _, _ = optimizers.momentum(0.1, mass=0.9)
     params = [{'w': onp.random.randn(1, 2), 'bias': onp.random.randn(2)}]
     expected = opt_init(params)
     ans = optimizers.pack_optimizer_state(
         optimizers.unpack_optimizer_state(expected))
     self.assertEqual(ans, expected)
Ejemplo n.º 2
0
def save_checkpoint(filename, step, optimizer_state, history):
    """Serialize the current training state to disk and mirror it to wandb.

    The JAX optimizer state is first converted to a plain picklable pytree
    (raw optimizer states contain objects pickle cannot handle).

    Returns the path the checkpoint was written to.
    """
    state_pytree = optimizers.unpack_optimizer_state(optimizer_state)
    path = expmgr.get_result_path(filename)
    payload = dict(step=step, state=state_pytree, history=history)
    with path.open('wb') as f:
        pickle.dump(payload, f)
    expmgr.safe_wandb_save(path)
    return path
Ejemplo n.º 3
0
 def __getstate__(self):
     """Build a picklable snapshot of this object's state.

     _opt_init and _opt_update are deliberately excluded (their PyCapsule
     internals cannot be pickled), and opt_state is converted to a plain
     pytree for the same reason.
     """
     # TODO: this hardcoded attribute list isn't scalable; replace it.
     picklable_attrs = (
         "get_weights",
         "optimizer",
         "weight_l2",
         "smoothing_l2",
         "is_compiled",
         "callbacks",
         "topology",
         "loss",
         "_optimizer_kwargs",
         "_predict",
     )
     state = {name: getattr(self, name) for name in picklable_attrs}
     state["opt_state"] = unpack_optimizer_state(self.opt_state)
     return state
Ejemplo n.º 4
0
def main(_):
    """Fine-tune a pretrained classifier on extra pool points.

    Pipeline: load a checkpoint, embed train/pool data with the frozen
    feature extractor, pick uncertain points, project embeddings via
    (optionally differentially-private) PCA, vote pool candidates by
    nearest-uncertain-train-point, then fine-tune on the picked points
    (optionally with DP-SGD), evaluating and checkpointing each epoch.
    """
    sns.set()
    sns.set_palette(sns.color_palette('hls', 10))
    npr.seed(FLAGS.seed)

    logging.info('Starting experiment.')

    # Create model folder for outputs
    try:
        gfile.MakeDirs(FLAGS.work_dir)
    except gfile.GOSError:
        pass
    stdout_log = gfile.Open('{}/stdout.log'.format(FLAGS.work_dir), 'w+')

    # BEGIN: set up optimizer and load params
    _, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)

    ckpt_dir = '{}/{}'.format(FLAGS.root_dir, FLAGS.ckpt_idx)
    with gfile.Open(ckpt_dir, 'rb') as fckpt:
        # NOTE(review): pickle.load on a checkpoint is unsafe for untrusted
        # files; assumed trusted here.
        opt_state = optimizers.pack_optimizer_state(pickle.load(fckpt))
    params = get_params(opt_state)

    stdout_log.write('finetune from: {}\n'.format(ckpt_dir))
    logging.info('finetune from: %s', ckpt_dir)
    # END: set up optimizer and load params

    # BEGIN: train data emb, PCA, uncertain
    # FLAGS.train_split is expected to look like "<name>-<split>".
    train_images, train_labels, _ = datasets.get_dataset_split(
        name=FLAGS.train_split.split('-')[0],
        split=FLAGS.train_split.split('-')[1],
        shuffle=False)
    n_train = len(train_images)
    # use mean/std of svhn train
    train_mu, train_std = 128., 128.
    train_images = (train_images - train_mu) / train_std

    # embeddings of all the training points
    # params[:-1] = feature extractor (all but the final layer).
    train_embeddings = [apply_fn_0(params[:-1],
                                   train_images[b_i:b_i + FLAGS.batch_size]) \
                        for b_i in range(0, n_train, FLAGS.batch_size)]
    train_embeddings = np.concatenate(train_embeddings, axis=0)

    # fit PCA onto embeddings of all the training points
    if FLAGS.dppca_eps is not None:
        # Differentially-private PCA: keep all components, select later.
        pc_cols, e_vals = dp_pca.dp_pca(train_embeddings,
                                        train_embeddings.shape[1],
                                        epsilon=FLAGS.dppca_eps,
                                        delta=1e-5,
                                        sigma=None)
    else:
        big_pca = sklearn.decomposition.PCA(
            n_components=train_embeddings.shape[1])
        big_pca.random_state = FLAGS.seed
        big_pca.fit(train_embeddings)

    # filter out uncertain train points
    n_uncertain = FLAGS.n_extra + FLAGS.uncertain_extra

    # params[-1:] = final classification layer applied to embeddings.
    train_probs = stax.softmax(apply_fn_1(params[-1:], train_embeddings))
    train_acc = np.mean(
        np.argmax(train_probs, axis=1) == np.argmax(train_labels, axis=1))
    logging.info('initial train acc: %.2f', train_acc)

    # NOTE(review): if FLAGS.uncertain matches neither branch below,
    # train_uncertain_indices is never bound and later use raises
    # NameError — confirm the flag is validated upstream.
    if FLAGS.uncertain == 0 or FLAGS.uncertain == 'entropy':
        # entropy
        train_entropy = -onp.sum(train_probs * onp.log(train_probs), axis=1)
        train_uncertain_indices = \
            onp.argsort(train_entropy)[::-1][:n_uncertain]
    elif FLAGS.uncertain == 1 or FLAGS.uncertain == 'difference':
        # 1st_prob - 2nd_prob
        assert len(train_probs.shape) == 2
        sorted_train_probs = onp.sort(train_probs, axis=1)
        train_probs_diff = sorted_train_probs[:, -1] - sorted_train_probs[:,
                                                                          -2]
        assert min(train_probs_diff) > 0.
        train_uncertain_indices = onp.argsort(train_probs_diff)[:n_uncertain]

    if FLAGS.dppca_eps is not None:
        train_uncertain_projected_embeddings, _ = utils.project_embeddings(
            train_embeddings[train_uncertain_indices],
            pca_object=None,
            n_components=FLAGS.k_components,
            pc_cols=pc_cols)
    else:
        train_uncertain_projected_embeddings, _ = utils.project_embeddings(
            train_embeddings[train_uncertain_indices], big_pca,
            FLAGS.k_components)
    logging.info('projected embeddings of uncertain train data')

    # Free large arrays no longer needed before embedding the pool set.
    del train_images, train_labels, train_embeddings
    # END: train data emb, PCA, uncertain

    # BEGIN: pool data emb
    pool_images, pool_labels, _ = datasets.get_dataset_split(
        name=FLAGS.pool_split.split('-')[0],
        split=FLAGS.pool_split.split('-')[1],
        shuffle=False)
    n_pool = len(pool_images)
    pool_images = (pool_images -
                   train_mu) / train_std  # normalize w train mu/std

    pool_embeddings = [apply_fn_0(params[:-1],
                                  pool_images[b_i:b_i + FLAGS.batch_size]) \
                       for b_i in range(0, n_pool, FLAGS.batch_size)]
    pool_embeddings = np.concatenate(pool_embeddings, axis=0)

    # filter out uncertain pool points
    pool_probs = stax.softmax(apply_fn_1(params[-1:], pool_embeddings))
    # NOTE(review): same unbound-variable risk as the train branch above
    # if FLAGS.uncertain matches neither case.
    if FLAGS.uncertain == 0 or FLAGS.uncertain == 'entropy':
        # entropy
        pool_entropy = -onp.sum(pool_probs * onp.log(pool_probs), axis=1)
        pool_uncertain_indices = onp.argsort(pool_entropy)[::-1][:n_uncertain]
    elif FLAGS.uncertain == 1 or FLAGS.uncertain == 'difference':
        # 1st_prob - 2nd_prob
        assert len(pool_probs.shape) == 2
        sorted_pool_probs = onp.sort(pool_probs, axis=1)
        pool_probs_diff = sorted_pool_probs[:, -1] - sorted_pool_probs[:, -2]
        assert min(pool_probs_diff) > 0.
        pool_uncertain_indices = onp.argsort(pool_probs_diff)[:n_uncertain]

    # constrain pool candidates to ONLY uncertain ones
    pool_images = pool_images[pool_uncertain_indices]
    pool_labels = pool_labels[pool_uncertain_indices]
    pool_embeddings = pool_embeddings[pool_uncertain_indices]
    n_pool = len(pool_uncertain_indices)

    if FLAGS.dppca_eps is not None:
        pool_projected_embeddings, _ = utils.project_embeddings(
            pool_embeddings,
            pca_object=None,
            n_components=FLAGS.k_components,
            pc_cols=pc_cols)
    else:
        pool_projected_embeddings, _ = utils.project_embeddings(
            pool_embeddings, big_pca, FLAGS.k_components)

    del pool_embeddings
    logging.info('projected embeddings of pool data')
    # END: pool data emb

    # BEGIN: assign train_uncertain_projected_embeddings to ONLY 1 point/cluster
    # assign uncertain train to closest pool point/cluster
    # Each uncertain train point votes for its nearest pool candidate.
    pool_index_histogram = onp.zeros(n_pool)
    for i in range(len(train_uncertain_projected_embeddings)):
        # t0 = time.time()
        train_uncertain_point = \
            train_uncertain_projected_embeddings[i].reshape(1, -1)

        if FLAGS.distance == 0 or FLAGS.distance == 'euclidean':
            cluster_distances = euclidean_distances(
                pool_projected_embeddings, train_uncertain_point).reshape(-1)
        elif FLAGS.distance == 1 or FLAGS.distance == 'weighted_euclidean':
            # Weight each PCA component by its (singular) value.
            weights = e_vals[:FLAGS.k_components] if FLAGS.dppca_eps is not None \
                else big_pca.singular_values_[:FLAGS.k_components]
            cluster_distances = weighted_euclidean_distances(
                pool_projected_embeddings, train_uncertain_point, weights)

        pool_index = onp.argmin(cluster_distances)
        pool_index_histogram[pool_index] += 1.
        # t1 = time.time()
        # logging.info('%d uncertain train, %s second', i, str(t1 - t0))
    # NOTE(review): raises NameError if the loop above ran zero iterations —
    # presumably n_uncertain > 0 always; confirm.
    del cluster_distances

    # add Laplacian noise onto #neighors
    if FLAGS.extra_eps is not None:
        # NOTE(review): this scale assumes FLAGS.dppca_eps is also set;
        # if dppca_eps is None the subtraction raises TypeError — confirm
        # the two flags are always set together.
        pool_index_histogram += npr.laplace(scale=FLAGS.extra_eps -
                                            FLAGS.dppca_eps,
                                            size=pool_index_histogram.shape)

    # Keep the n_extra most-voted pool candidates.
    pool_picked_indices = onp.argsort(
        pool_index_histogram)[::-1][:FLAGS.n_extra]

    logging.info('%d extra pool data picked', len(pool_picked_indices))
    # END: assign train_uncertain_projected_embeddings to ONLY 1 cluster

    # load test data
    test_images, test_labels, _ = datasets.get_dataset_split(
        name=FLAGS.test_split.split('-')[0],
        split=FLAGS.test_split.split('-')[1],
        shuffle=False)
    test_images = (test_images -
                   train_mu) / train_std  # normalize w train mu/std

    # augmentation for train/pool data
    if FLAGS.augment_data:
        augmentation = data.chain_transforms(data.RandomHorizontalFlip(0.5),
                                             data.RandomCrop(4), data.ToDevice)
    else:
        augmentation = None

    # Baseline test accuracy before any fine-tuning.
    test_acc, test_pred = accuracy(params,
                                   shape_as_image(test_images, test_labels),
                                   return_predicted_class=True)
    logging.info('test accuracy: %.2f', test_acc)
    stdout_log.write('test accuracy: {}\n'.format(test_acc))
    stdout_log.flush()
    worst_test_acc, best_test_acc, best_epoch = test_acc, test_acc, 0

    # BEGIN: setup for dp model
    @jit
    def update(_, i, opt_state, batch):
        # Plain (non-private) SGD step.
        params = get_params(opt_state)
        return opt_update(i, grad_loss(params, batch), opt_state)

    @jit
    def private_update(rng, i, opt_state, batch):
        # DP-SGD step: per-example clipped, noised gradients.
        params = get_params(opt_state)
        rng = random.fold_in(rng, i)  # get new key for new random numbers
        return opt_update(
            i,
            private_grad(params, batch, rng, FLAGS.l2_norm_clip,
                         FLAGS.noise_multiplier, FLAGS.batch_size), opt_state)

    # END: setup for dp model

    finetune_images = copy.deepcopy(pool_images[pool_picked_indices])
    finetune_labels = copy.deepcopy(pool_labels[pool_picked_indices])

    logging.info('Starting fine-tuning...')
    stdout_log.write('Starting fine-tuning...\n')
    stdout_log.flush()

    # BEGIN: gather points to be used for finetuning
    stdout_log.write('{} points picked via {}\n'.format(
        len(finetune_images), FLAGS.uncertain))
    logging.info('%d points picked via %s', len(finetune_images),
                 FLAGS.uncertain)
    assert FLAGS.n_extra == len(finetune_images)
    # END: gather points to be used for finetuning

    for epoch in range(1, FLAGS.epochs + 1):

        # BEGIN: finetune model with extra data, evaluate and save
        num_extra = len(finetune_images)
        num_complete_batches, leftover = divmod(num_extra, FLAGS.batch_size)
        num_batches = num_complete_batches + bool(leftover)

        finetune = data.DataChunk(X=finetune_images,
                                  Y=finetune_labels,
                                  image_size=28,
                                  image_channels=1,
                                  label_dim=1,
                                  label_format='numeric')

        batches = data.minibatcher(finetune,
                                   FLAGS.batch_size,
                                   transform=augmentation)

        itercount = itertools.count()
        key = random.PRNGKey(FLAGS.seed)

        start_time = time.time()

        for _ in range(num_batches):
            # tmp_time = time.time()
            b = next(batches)
            if FLAGS.dpsgd:
                opt_state = private_update(
                    key, next(itercount), opt_state,
                    shape_as_image(b.X, b.Y, dummy_dim=True))
            else:
                opt_state = update(key, next(itercount), opt_state,
                                   shape_as_image(b.X, b.Y))
            # stdout_log.write('single update in {:.2f} sec\n'.format(
            #     time.time() - tmp_time))

        epoch_time = time.time() - start_time
        stdout_log.write('Epoch {} in {:.2f} sec\n'.format(epoch, epoch_time))
        logging.info('Epoch %d in %.2f sec', epoch, epoch_time)

        # accuracy on test data
        params = get_params(opt_state)

        # Keep the previous epoch's predictions for visualization diffs.
        test_pred_0 = test_pred
        test_acc, test_pred = accuracy(params,
                                       shape_as_image(test_images,
                                                      test_labels),
                                       return_predicted_class=True)
        test_loss = loss(params, shape_as_image(test_images, test_labels))
        stdout_log.write(
            'Eval set loss, accuracy (%): ({:.2f}, {:.2f})\n'.format(
                test_loss, 100 * test_acc))
        logging.info('Eval set loss, accuracy: (%.2f, %.2f)', test_loss,
                     100 * test_acc)
        stdout_log.flush()

        # visualize prediction difference between 2 checkpoints.
        if FLAGS.visualize:
            utils.visualize_ckpt_difference(test_images,
                                            np.argmax(test_labels, axis=1),
                                            test_pred_0,
                                            test_pred,
                                            epoch - 1,
                                            epoch,
                                            FLAGS.work_dir,
                                            mu=train_mu,
                                            sigma=train_std)

        worst_test_acc = min(test_acc, worst_test_acc)
        if test_acc > best_test_acc:
            best_test_acc, best_epoch = test_acc, epoch
            # save opt_state
            with gfile.Open('{}/acc_ckpt'.format(FLAGS.work_dir),
                            'wb') as fckpt:
                pickle.dump(optimizers.unpack_optimizer_state(opt_state),
                            fckpt)
    # END: finetune model with extra data, evaluate and save

    stdout_log.write('best test acc {} @E {}\n'.format(best_test_acc,
                                                       best_epoch))
    stdout_log.close()
Ejemplo n.º 5
0
def main(_):
    """Train a CIFAR-style classifier (optionally with DP-SGD).

    Loads train/test splits, builds the configured architecture (wide
    resnet or cnn), trains for `configs.train_steps` steps, checkpoints
    the optimizer state once per epoch, and logs metrics to a stdout.log
    file in FLAGS.exp_dir.
    """
    logging.info('Starting experiment.')
    configs = FLAGS.config

    # Create model folder for outputs
    try:
        gfile.MakeDirs(FLAGS.exp_dir)
    except gfile.GOSError:
        pass
    stdout_log = gfile.Open('{}/stdout.log'.format(FLAGS.exp_dir), 'w+')

    logging.info('Loading data.')
    tic = time.time()

    train_images, train_labels, _ = datasets.get_dataset_split(
        FLAGS.dataset, 'train')
    n_train = len(train_images)
    train_mu, train_std = onp.mean(train_images), onp.std(train_images)
    train = data.DataChunk(X=(train_images - train_mu) / train_std,
                           Y=train_labels,
                           image_size=32,
                           image_channels=3,
                           label_dim=1,
                           label_format='numeric')

    test_images, test_labels, _ = datasets.get_dataset_split(
        FLAGS.dataset, 'test')
    test = data.DataChunk(
        X=(test_images - train_mu) / train_std,  # normalize w train mean/std
        Y=test_labels,
        image_size=32,
        image_channels=3,
        label_dim=1,
        label_format='numeric')

    # Data augmentation
    if configs.augment_data:
        augmentation = data.chain_transforms(data.RandomHorizontalFlip(0.5),
                                             data.RandomCrop(4), data.ToDevice)
    else:
        augmentation = None
    batch = data.minibatcher(train, configs.batch_size, transform=augmentation)

    # Model architecture
    if configs.architect == 'wrn':
        init_random_params, predict = wide_resnet(configs.block_size,
                                                  configs.channel_multiplier,
                                                  10)
    elif configs.architect == 'cnn':
        init_random_params, predict = cnn()
    else:
        raise ValueError('Model architecture not implemented.')

    # Seed the PRNG from the config when given, else from wall-clock time.
    if configs.seed is not None:
        key = random.PRNGKey(configs.seed)
    else:
        key = random.PRNGKey(int(time.time()))
    _, params = init_random_params(key, (-1, 32, 32, 3))

    # count params of JAX model
    def count_parameters(params):
        return tree_util.tree_reduce(
            operator.add, tree_util.tree_map(lambda x: np.prod(x.shape),
                                             params))

    logging.info('Number of parameters: %d', count_parameters(params))
    stdout_log.write('Number of params: {}\n'.format(count_parameters(params)))

    # loss functions
    def cross_entropy_loss(params, x_img, y_lbl):
        return -np.mean(stax.logsoftmax(predict(params, x_img)) * y_lbl)

    def mse_loss(params, x_img, y_lbl):
        return 0.5 * np.mean((y_lbl - predict(params, x_img))**2)

    def accuracy(y_lbl_hat, y_lbl):
        target_class = np.argmax(y_lbl, axis=1)
        predicted_class = np.argmax(y_lbl_hat, axis=1)
        return np.mean(predicted_class == target_class)

    # Loss and gradient
    if configs.loss == 'xent':
        loss = cross_entropy_loss
    elif configs.loss == 'mse':
        loss = mse_loss
    else:
        raise ValueError('Loss function not implemented.')
    grad_loss = jit(grad(loss))

    # learning rate schedule and optimizer
    def cosine(initial_step_size, train_steps):
        # Quarter-period cosine decay: lr goes from initial_step_size to 0
        # over train_steps.
        k = np.pi / (2.0 * train_steps)

        def schedule(i):
            return initial_step_size * np.cos(k * i)

        return schedule

    if configs.optimization == 'sgd':
        lr_schedule = optimizers.make_schedule(configs.learning_rate)
        opt_init, opt_update, get_params = optimizers.sgd(lr_schedule)
    elif configs.optimization == 'momentum':
        lr_schedule = cosine(configs.learning_rate, configs.train_steps)
        opt_init, opt_update, get_params = optimizers.momentum(
            lr_schedule, 0.9)
    else:
        raise ValueError('Optimizer not implemented.')

    opt_state = opt_init(params)

    def private_grad(params, batch, rng, l2_norm_clip, noise_multiplier,
                     batch_size):
        """Return differentially private gradients of params, evaluated on batch."""
        def _clipped_grad(params, single_example_batch):
            """Evaluate gradient for a single-example batch and clip its grad norm."""
            grads = grad_loss(params, single_example_batch[0].reshape(
                (-1, 32, 32, 3)), single_example_batch[1])

            nonempty_grads, tree_def = tree_util.tree_flatten(grads)
            total_grad_norm = np.linalg.norm(
                [np.linalg.norm(neg.ravel()) for neg in nonempty_grads])
            # Scale down only when the global norm exceeds the clip bound.
            divisor = stop_gradient(
                np.amax((total_grad_norm / l2_norm_clip, 1.)))
            normalized_nonempty_grads = [
                neg / divisor for neg in nonempty_grads
            ]
            return tree_util.tree_unflatten(tree_def,
                                            normalized_nonempty_grads)

        # Vectorize per-example clipping over the batch dimension.
        px_clipped_grad_fn = vmap(partial(_clipped_grad, params))
        std_dev = l2_norm_clip * noise_multiplier
        noise_ = lambda n: n + std_dev * random.normal(rng, n.shape)
        normalize_ = lambda n: n / float(batch_size)
        sum_ = lambda n: np.sum(n, 0)  # aggregate
        aggregated_clipped_grads = tree_util.tree_map(
            sum_, px_clipped_grad_fn(batch))
        noised_aggregated_clipped_grads = tree_util.tree_map(
            noise_, aggregated_clipped_grads)
        normalized_noised_aggregated_clipped_grads = (tree_util.tree_map(
            normalize_, noised_aggregated_clipped_grads))
        return normalized_noised_aggregated_clipped_grads

    # summarize measurements
    steps_per_epoch = n_train // configs.batch_size

    def summarize(step, params):
        """Compute measurements in a zipped way."""
        set_entries = [train, test]
        set_bsizes = [configs.train_eval_bsize, configs.test_eval_bsize]
        set_names, loss_dict, acc_dict = ['train', 'test'], {}, {}

        for set_entry, set_bsize, set_name in zip(set_entries, set_bsizes,
                                                  set_names):
            # Accumulate point-weighted sums so partial final batches are
            # averaged correctly.
            temp_loss, temp_acc, points = 0.0, 0.0, 0
            for b in data.batch(set_entry, set_bsize):
                temp_loss += loss(params, b.X, b.Y) * b.X.shape[0]
                temp_acc += accuracy(predict(params, b.X), b.Y) * b.X.shape[0]
                points += b.X.shape[0]
            loss_dict[set_name] = temp_loss / float(points)
            acc_dict[set_name] = temp_acc / float(points)

        logging.info('Step: %s', str(step))
        logging.info('Train acc : %.4f', acc_dict['train'])
        logging.info('Train loss: %.4f', loss_dict['train'])
        logging.info('Test acc  : %.4f', acc_dict['test'])
        logging.info('Test loss : %.4f', loss_dict['test'])

        stdout_log.write('Step: {}\n'.format(step))
        stdout_log.write('Train acc : {}\n'.format(acc_dict['train']))
        stdout_log.write('Train loss: {}\n'.format(loss_dict['train']))
        stdout_log.write('Test acc  : {}\n'.format(acc_dict['test']))
        stdout_log.write('Test loss : {}\n'.format(loss_dict['test']))

        return acc_dict['test']

    toc = time.time()
    logging.info('Elapsed SETUP time: %s', str(toc - tic))
    stdout_log.write('Elapsed SETUP time: {}\n'.format(toc - tic))

    # BEGIN: training steps
    logging.info('Training network.')
    tic = time.time()
    t = time.time()

    for s in range(configs.train_steps):
        b = next(batch)
        params = get_params(opt_state)

        # t0 = time.time()
        if FLAGS.dpsgd:
            key = random.fold_in(key, s)  # get new key for new random numbers
            opt_state = opt_update(
                s,
                private_grad(params, (b.X.reshape(
                    (-1, 1, 32, 32, 3)), b.Y), key, configs.l2_norm_clip,
                             configs.noise_multiplier, configs.batch_size),
                opt_state)
        else:
            opt_state = opt_update(s, grad_loss(params, b.X, b.Y), opt_state)
        # t1 = time.time()
        # logging.info('batch update time: %s', str(t1 - t0))

        if s % steps_per_epoch == 0:
            # BUG FIX: mode was 'wr', which is not a valid open mode and is
            # not binary; pickle requires a binary stream ('wb').
            with gfile.Open(
                    '{}/ckpt_{}'.format(FLAGS.exp_dir,
                                        int(s / steps_per_epoch)),
                    'wb') as fckpt:
                pickle.dump(optimizers.unpack_optimizer_state(opt_state),
                            fckpt)

            if FLAGS.dpsgd:
                # Report the current privacy budget spent so far.
                eps = compute_epsilon(s, configs.batch_size, n_train,
                                      configs.target_delta,
                                      configs.noise_multiplier)
                stdout_log.write(
                    'For delta={:.0e}, current epsilon is: {:.2f}\n'.format(
                        configs.target_delta, eps))

            logging.info('Elapsed EPOCH time: %s', str(time.time() - t))
            stdout_log.write('Elapsed EPOCH time: {}\n'.format(time.time() - t))
            stdout_log.flush()
            t = time.time()

    toc = time.time()
    summarize(configs.train_steps, params)
    logging.info('Elapsed TRAIN time: %s', str(toc - tic))
    stdout_log.write('Elapsed TRAIN time: {}\n'.format(toc - tic))
    stdout_log.close()
Ejemplo n.º 6
0
 def serialise(self):
     """Return this state pickled, with opt_state made picklable first.

     The raw optimizer state contains JAX internals that pickle cannot
     handle, so it is swapped for its unpacked pytree representation.
     """
     picklable_opt_state = optimizers.unpack_optimizer_state(self.opt_state)
     snapshot = self._replace(opt_state=picklable_opt_state)
     return pickle.dumps(snapshot)