def main(_):
    sns.set()
    sns.set_palette(sns.color_palette('hls', 10))
    npr.seed(FLAGS.seed)

    logging.info('Starting experiment.')

    # Create model folder for outputs
    try:
        gfile.MakeDirs(FLAGS.work_dir)
    except gfile.GOSError:
        pass
    stdout_log = gfile.Open('{}/stdout.log'.format(FLAGS.work_dir), 'w+')

    # BEGIN: fetch test data and candidate pool
    test_images, test_labels, _ = datasets.get_dataset_split(
        name=FLAGS.test_split.split('-')[0],
        split=FLAGS.test_split.split('-')[1],
        shuffle=False)
    pool_images, pool_labels, _ = datasets.get_dataset_split(
        name=FLAGS.pool_split.split('-')[0],
        split=FLAGS.pool_split.split('-')[1],
        shuffle=False)

    n_pool = len(pool_images)
    # normalize to range [-1.0, 127./128]
    test_images = test_images / np.float32(128.0) - np.float32(1.0)
    pool_images = pool_images / np.float32(128.0) - np.float32(1.0)

    # augmentation for train/pool data
    if FLAGS.augment_data:
        augmentation = data.chain_transforms(data.RandomHorizontalFlip(0.5),
                                             data.RandomCrop(4), data.ToDevice)
    else:
        augmentation = None
    # END: fetch test data and candidate pool

    _, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)

    # BEGIN: load ckpt
    ckpt_dir = '{}/{}'.format(FLAGS.root_dir, FLAGS.ckpt_idx)
    with gfile.Open(ckpt_dir, 'rb') as fckpt:
        opt_state = optimizers.pack_optimizer_state(pickle.load(fckpt))
    params = get_params(opt_state)

    stdout_log.write('finetune from: {}\n'.format(ckpt_dir))
    logging.info('finetune from: %s', ckpt_dir)
    test_acc, test_pred = accuracy(params,
                                   shape_as_image(test_images, test_labels),
                                   return_predicted_class=True)
    logging.info('test accuracy: %.2f', test_acc)
    stdout_log.write('test accuracy: {}\n'.format(test_acc))
    stdout_log.flush()
    # END: load ckpt

    # BEGIN: setup for dp model
    @jit
    def update(_, i, opt_state, batch):
        params = get_params(opt_state)
        return opt_update(i, grad_loss(params, batch), opt_state)

    @jit
    def private_update(rng, i, opt_state, batch):
        params = get_params(opt_state)
        rng = random.fold_in(rng, i)  # get new key for new random numbers
        return opt_update(
            i,
            private_grad(params, batch, rng, FLAGS.l2_norm_clip,
                         FLAGS.noise_multiplier, FLAGS.batch_size), opt_state)

    # END: setup for dp model

    n_uncertain = FLAGS.n_extra + FLAGS.uncertain_extra

    ### BEGIN: prepare extra points picked from pool data
    # BEGIN: on pool data
    pool_embeddings = [apply_fn_0(params[:-1],
                                  pool_images[b_i:b_i + FLAGS.batch_size]) \
                       for b_i in range(0, n_pool, FLAGS.batch_size)]
    pool_embeddings = np.concatenate(pool_embeddings, axis=0)

    pool_logits = apply_fn_1(params[-1:], pool_embeddings)

    pool_true_labels = np.argmax(pool_labels, axis=1)
    pool_predicted_labels = np.argmax(pool_logits, axis=1)
    pool_correct_indices = \
        onp.where(pool_true_labels == pool_predicted_labels)[0]
    pool_incorrect_indices = \
        onp.where(pool_true_labels != pool_predicted_labels)[0]
    assert len(pool_correct_indices) + \
        len(pool_incorrect_indices) == len(pool_labels)

    pool_probs = stax.softmax(pool_logits)

    if FLAGS.uncertain == 0 or FLAGS.uncertain == 'entropy':
        pool_entropy = -onp.sum(pool_probs * onp.log(pool_probs), axis=1)
        stdout_log.write('all {} entropy: min {}, max {}\n'.format(
            len(pool_entropy), onp.min(pool_entropy), onp.max(pool_entropy)))

        pool_entropy_sorted_indices = onp.argsort(pool_entropy)
        # take the n_uncertain most uncertain points
        pool_uncertain_indices = \
            pool_entropy_sorted_indices[::-1][:n_uncertain]
        stdout_log.write('uncertain {} entropy: min {}, max {}\n'.format(
            len(pool_entropy[pool_uncertain_indices]),
            onp.min(pool_entropy[pool_uncertain_indices]),
            onp.max(pool_entropy[pool_uncertain_indices])))

    elif FLAGS.uncertain == 1 or FLAGS.uncertain == 'difference':
        # 1st_prob - 2nd_prob
        assert len(pool_probs.shape) == 2
        sorted_pool_probs = onp.sort(pool_probs, axis=1)
        pool_probs_diff = sorted_pool_probs[:, -1] - sorted_pool_probs[:, -2]
        assert min(pool_probs_diff) > 0.
        pool_uncertain_indices = onp.argsort(pool_probs_diff)[:n_uncertain]
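
    # Both criteria rank points the same way: high entropy and a small gap
    # between the top-2 class probabilities both flag inputs the classifier is
    # unsure about; entropy uses the full distribution, while the difference
    # score only looks at the two largest probabilities.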

    # END: on pool data

    # BEGIN: cluster uncertain pool points
    # fit PCA onto embeddings of all the pool points
    big_pca = sklearn.decomposition.PCA(
        n_components=pool_embeddings.shape[1], random_state=FLAGS.seed)
    big_pca.fit(pool_embeddings)

    # For uncertain points, project embeddings onto the first K components
    pool_uncertain_projected_embeddings, _ = utils.project_embeddings(
        pool_embeddings[pool_uncertain_indices], big_pca, FLAGS.k_components)

    n_cluster = int(FLAGS.n_extra / FLAGS.ppc)
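    # get_cluster_method is assumed to parse a spec string of the form
    # '<algorithm>_nc-<n_clusters>' (e.g. 'kmeans_nc-10') into a configured
    # sklearn-style clustering object.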
    cluster_method = get_cluster_method('{}_nc-{}'.format(
        FLAGS.clustering, n_cluster))
    cluster_method.random_state = FLAGS.seed
    pool_uncertain_cluster_labels = cluster_method.fit_predict(
        pool_uncertain_projected_embeddings)
    pool_uncertain_cluster_label_indices = {
        x: []
        for x in set(pool_uncertain_cluster_labels)
    }  # local i within n_uncertain
    for i, c_label in enumerate(pool_uncertain_cluster_labels):
        pool_uncertain_cluster_label_indices[c_label].append(i)

    # find center of each cluster
    # aka, the most representative point of each 'tough' cluster
    pool_picked_indices = []
    pool_uncertain_cluster_label_pick = {}
    for c_label, indices in pool_uncertain_cluster_label_indices.items():
        cluster_projected_embeddings = \
            pool_uncertain_projected_embeddings[indices]
        cluster_center = onp.mean(cluster_projected_embeddings,
                                  axis=0,
                                  keepdims=True)
        if FLAGS.distance == 0 or FLAGS.distance == 'euclidean':
            cluster_distances = euclidean_distances(
                cluster_projected_embeddings, cluster_center).reshape(-1)
        elif FLAGS.distance == 1 or FLAGS.distance == 'weighted_euclidean':
            cluster_distances = weighted_euclidean_distances(
                cluster_projected_embeddings, cluster_center,
                big_pca.singular_values_[:FLAGS.k_components])

        sorted_is = onp.argsort(cluster_distances)
        sorted_indices = onp.array(indices)[sorted_is]
        pool_uncertain_cluster_label_indices[c_label] = sorted_indices
        center_i = sorted_indices[0]  # local index within the n_uncertain set
        pool_uncertain_cluster_label_pick[c_label] = center_i
        pool_picked_indices.extend(
            pool_uncertain_indices[sorted_indices[:FLAGS.ppc]])

        # BEGIN: visualize cluster of picked uncertain pool
        if FLAGS.visualize:
            this_cluster = []
            for i in sorted_indices:
                idx = pool_uncertain_indices[i]
                img = pool_images[idx]
                if idx in pool_correct_indices:
                    border_color = 'green'
                else:
                    border_color = 'red'
                    img = utils.mark_labels(img, pool_predicted_labels[idx],
                                            pool_true_labels[idx])
                img = utils.denormalize(img, 128., 128.)
                img = np.squeeze(utils.to_rgb(np.expand_dims(img, 0)))
                img = utils.add_border(img, width=2, color=border_color)
                this_cluster.append(img)
            utils.tile_image_list(
                this_cluster, '{}/picked_uncertain_pool_cid-{}'.format(
                    FLAGS.work_dir, c_label))
        # END: visualize cluster of picked uncertain pool

    # END: cluster uncertain pool points

    pool_picked_indices = list(set(pool_picked_indices))

    n_gap = FLAGS.n_extra - len(pool_picked_indices)
    gap_indices = list(set(pool_uncertain_indices) - set(pool_picked_indices))
    pool_picked_indices.extend(npr.choice(gap_indices, n_gap, replace=False))
    stdout_log.write('n_gap: {}\n'.format(n_gap))
    ### END: prepare extra points picked from pool data

    # BEGIN: gather points to be used for finetuning
    finetune_images = copy.deepcopy(pool_images[pool_picked_indices])
    finetune_labels = copy.deepcopy(pool_labels[pool_picked_indices])

    stdout_log.write('{} points picked via {}\n'.format(
        len(finetune_images), FLAGS.uncertain))
    logging.info('%d points picked via %s', len(finetune_images),
                 FLAGS.uncertain)
    assert FLAGS.n_extra == len(finetune_images)
    # END: gather points to be used for finetuning

    stdout_log.write('Starting fine-tuning...\n')
    logging.info('Starting fine-tuning...')
    stdout_log.flush()

    for epoch in range(1, FLAGS.epochs + 1):

        # BEGIN: finetune model with extra data, evaluate and save
        num_extra = len(finetune_images)
        num_complete_batches, leftover = divmod(num_extra, FLAGS.batch_size)
        num_batches = num_complete_batches + bool(leftover)

        finetune = data.DataChunk(X=finetune_images,
                                  Y=finetune_labels,
                                  image_size=28,
                                  image_channels=1,
                                  label_dim=1,
                                  label_format='numeric')

        batches = data.minibatcher(finetune,
                                   FLAGS.batch_size,
                                   transform=augmentation)

        itercount = itertools.count()
        # fold the epoch index into the key so DP noise differs across epochs
        key = random.fold_in(random.PRNGKey(FLAGS.seed), epoch)

        start_time = time.time()

        for _ in range(num_batches):
            # tmp_time = time.time()
            b = next(batches)
            if FLAGS.dpsgd:
                opt_state = private_update(
                    key, next(itercount), opt_state,
                    shape_as_image(b.X, b.Y, dummy_dim=True))
            else:
                opt_state = update(key, next(itercount), opt_state,
                                   shape_as_image(b.X, b.Y))
            # stdout_log.write('single update in {:.2f} sec\n'.format(
            #     time.time() - tmp_time))

        epoch_time = time.time() - start_time
        stdout_log.write('Epoch {} in {:.2f} sec\n'.format(epoch, epoch_time))
        logging.info('Epoch %d in %.2f sec', epoch, epoch_time)

        # accuracy on test data
        params = get_params(opt_state)

        test_pred_0 = test_pred
        test_acc, test_pred = accuracy(params,
                                       shape_as_image(test_images,
                                                      test_labels),
                                       return_predicted_class=True)
        test_loss = loss(params, shape_as_image(test_images, test_labels))
        stdout_log.write(
            'Eval set loss, accuracy (%): ({:.2f}, {:.2f})\n'.format(
                test_loss, 100 * test_acc))
        logging.info('Eval set loss, accuracy: (%.2f, %.2f)', test_loss,
                     100 * test_acc)
        stdout_log.flush()

        # visualize prediction difference between 2 checkpoints.
        if FLAGS.visualize:
            utils.visualize_ckpt_difference(test_images,
                                            np.argmax(test_labels, axis=1),
                                            test_pred_0,
                                            test_pred,
                                            epoch - 1,
                                            epoch,
                                            FLAGS.work_dir,
                                            mu=128.,
                                            sigma=128.)

    # END: finetune model with extra data, evaluate and save
    stdout_log.close()
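
A minimal, self-contained sketch (not part of the example above) of the two
uncertainty scores it selects by, applied to toy softmax outputs; only numpy
is assumed:

import numpy as onp

def entropy_score(probs):
    # high entropy = uncertain; maximal at the uniform distribution
    return -onp.sum(probs * onp.log(probs), axis=1)

def top2_difference_score(probs):
    # small gap between the two largest probabilities = uncertain
    sorted_probs = onp.sort(probs, axis=1)
    return sorted_probs[:, -1] - sorted_probs[:, -2]

probs = onp.array([[0.98, 0.01, 0.01],   # confident
                   [0.40, 0.35, 0.25],   # uncertain
                   [0.34, 0.33, 0.33]])  # most uncertain
n_uncertain = 2
print(onp.argsort(entropy_score(probs))[::-1][:n_uncertain])    # -> [2 1]
print(onp.argsort(top2_difference_score(probs))[:n_uncertain])  # -> [2 1]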
Example #2
def main(_):
    sns.set()
    sns.set_palette(sns.color_palette('hls', 10))
    npr.seed(FLAGS.seed)

    logging.info('Starting experiment.')

    # Create model folder for outputs
    try:
        gfile.MakeDirs(FLAGS.work_dir)
    except gfile.GOSError:
        pass
    stdout_log = gfile.Open('{}/stdout.log'.format(FLAGS.work_dir), 'w+')

    # BEGIN: set up optimizer and load params
    _, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)

    ckpt_dir = '{}/{}'.format(FLAGS.root_dir, FLAGS.ckpt_idx)
    with gfile.Open(ckpt_dir, 'rb') as fckpt:
        opt_state = optimizers.pack_optimizer_state(pickle.load(fckpt))
    params = get_params(opt_state)

    stdout_log.write('finetune from: {}\n'.format(ckpt_dir))
    logging.info('finetune from: %s', ckpt_dir)
    # END: set up optimizer and load params

    # BEGIN: train data emb, PCA, uncertain
    train_images, train_labels, _ = datasets.get_dataset_split(
        name=FLAGS.train_split.split('-')[0],
        split=FLAGS.train_split.split('-')[1],
        shuffle=False)
    n_train = len(train_images)
    # use mean/std of svhn train
    train_mu, train_std = 128., 128.
    train_images = (train_images - train_mu) / train_std

    # embeddings of all the training points
    train_embeddings = [apply_fn_0(params[:-1],
                                   train_images[b_i:b_i + FLAGS.batch_size]) \
                        for b_i in range(0, n_train, FLAGS.batch_size)]
    train_embeddings = np.concatenate(train_embeddings, axis=0)

    # fit PCA onto embeddings of all the training points
    if FLAGS.dppca_eps is not None:
        pc_cols, e_vals = dp_pca.dp_pca(train_embeddings,
                                        train_embeddings.shape[1],
                                        epsilon=FLAGS.dppca_eps,
                                        delta=1e-5,
                                        sigma=None)
    else:
        big_pca = sklearn.decomposition.PCA(
            n_components=train_embeddings.shape[1], random_state=FLAGS.seed)
        big_pca.fit(train_embeddings)

    # filter out uncertain train points
    n_uncertain = FLAGS.n_extra + FLAGS.uncertain_extra

    train_probs = stax.softmax(apply_fn_1(params[-1:], train_embeddings))
    train_acc = np.mean(
        np.argmax(train_probs, axis=1) == np.argmax(train_labels, axis=1))
    logging.info('initial train acc: %.2f', train_acc)

    if FLAGS.uncertain == 0 or FLAGS.uncertain == 'entropy':
        # entropy
        train_entropy = -onp.sum(train_probs * onp.log(train_probs), axis=1)
        train_uncertain_indices = \
            onp.argsort(train_entropy)[::-1][:n_uncertain]
    elif FLAGS.uncertain == 1 or FLAGS.uncertain == 'difference':
        # 1st_prob - 2nd_prob
        assert len(train_probs.shape) == 2
        sorted_train_probs = onp.sort(train_probs, axis=1)
        train_probs_diff = (sorted_train_probs[:, -1] -
                            sorted_train_probs[:, -2])
        assert min(train_probs_diff) > 0.
        train_uncertain_indices = onp.argsort(train_probs_diff)[:n_uncertain]

    if FLAGS.dppca_eps is not None:
        train_uncertain_projected_embeddings, _ = utils.project_embeddings(
            train_embeddings[train_uncertain_indices],
            pca_object=None,
            n_components=FLAGS.k_components,
            pc_cols=pc_cols)
    else:
        train_uncertain_projected_embeddings, _ = utils.project_embeddings(
            train_embeddings[train_uncertain_indices], big_pca,
            FLAGS.k_components)
    logging.info('projected embeddings of uncertain train data')

    del train_images, train_labels, train_embeddings
    # END: train data emb, PCA, uncertain

    # BEGIN: pool data emb
    pool_images, pool_labels, _ = datasets.get_dataset_split(
        name=FLAGS.pool_split.split('-')[0],
        split=FLAGS.pool_split.split('-')[1],
        shuffle=False)
    n_pool = len(pool_images)
    # normalize with train mu/std
    pool_images = (pool_images - train_mu) / train_std

    pool_embeddings = [apply_fn_0(params[:-1],
                                  pool_images[b_i:b_i + FLAGS.batch_size]) \
                       for b_i in range(0, n_pool, FLAGS.batch_size)]
    pool_embeddings = np.concatenate(pool_embeddings, axis=0)

    # filter out uncertain pool points
    pool_probs = stax.softmax(apply_fn_1(params[-1:], pool_embeddings))
    if FLAGS.uncertain == 0 or FLAGS.uncertain == 'entropy':
        # entropy
        pool_entropy = -onp.sum(pool_probs * onp.log(pool_probs), axis=1)
        pool_uncertain_indices = onp.argsort(pool_entropy)[::-1][:n_uncertain]
    elif FLAGS.uncertain == 1 or FLAGS.uncertain == 'difference':
        # 1st_prob - 2nd_prob
        assert len(pool_probs.shape) == 2
        sorted_pool_probs = onp.sort(pool_probs, axis=1)
        pool_probs_diff = sorted_pool_probs[:, -1] - sorted_pool_probs[:, -2]
        assert min(pool_probs_diff) > 0.
        pool_uncertain_indices = onp.argsort(pool_probs_diff)[:n_uncertain]

    # constrain pool candidates to ONLY uncertain ones
    pool_images = pool_images[pool_uncertain_indices]
    pool_labels = pool_labels[pool_uncertain_indices]
    pool_embeddings = pool_embeddings[pool_uncertain_indices]
    n_pool = len(pool_uncertain_indices)

    if FLAGS.dppca_eps is not None:
        pool_projected_embeddings, _ = utils.project_embeddings(
            pool_embeddings,
            pca_object=None,
            n_components=FLAGS.k_components,
            pc_cols=pc_cols)
    else:
        pool_projected_embeddings, _ = utils.project_embeddings(
            pool_embeddings, big_pca, FLAGS.k_components)

    del pool_embeddings
    logging.info('projected embeddings of pool data')
    # END: pool data emb

    # BEGIN: assign train_uncertain_projected_embeddings to ONLY 1 point/cluster
    # assign uncertain train to closest pool point/cluster
    pool_index_histogram = onp.zeros(n_pool)
    for i in range(len(train_uncertain_projected_embeddings)):
        # t0 = time.time()
        train_uncertain_point = \
            train_uncertain_projected_embeddings[i].reshape(1, -1)

        if FLAGS.distance == 0 or FLAGS.distance == 'euclidean':
            cluster_distances = euclidean_distances(
                pool_projected_embeddings, train_uncertain_point).reshape(-1)
        elif FLAGS.distance == 1 or FLAGS.distance == 'weighted_euclidean':
            weights = (e_vals[:FLAGS.k_components]
                       if FLAGS.dppca_eps is not None else
                       big_pca.singular_values_[:FLAGS.k_components])
            cluster_distances = weighted_euclidean_distances(
                pool_projected_embeddings, train_uncertain_point, weights)

        pool_index = onp.argmin(cluster_distances)
        pool_index_histogram[pool_index] += 1.
        # t1 = time.time()
        # logging.info('%d uncertain train, %s second', i, str(t1 - t0))
    del cluster_distances

    # add Laplace noise onto the #neighbors histogram
    if FLAGS.extra_eps is not None:
        pool_index_histogram += npr.laplace(
            scale=FLAGS.extra_eps - (FLAGS.dppca_eps or 0.),
            size=pool_index_histogram.shape)

    pool_picked_indices = onp.argsort(
        pool_index_histogram)[::-1][:FLAGS.n_extra]

    logging.info('%d extra pool data picked', len(pool_picked_indices))
    # END: assign train_uncertain_projected_embeddings to ONLY 1 cluster

    # load test data
    test_images, test_labels, _ = datasets.get_dataset_split(
        name=FLAGS.test_split.split('-')[0],
        split=FLAGS.test_split.split('-')[1],
        shuffle=False)
    # normalize with train mu/std
    test_images = (test_images - train_mu) / train_std

    # augmentation for train/pool data
    if FLAGS.augment_data:
        augmentation = data.chain_transforms(data.RandomHorizontalFlip(0.5),
                                             data.RandomCrop(4), data.ToDevice)
    else:
        augmentation = None

    test_acc, test_pred = accuracy(params,
                                   shape_as_image(test_images, test_labels),
                                   return_predicted_class=True)
    logging.info('test accuracy: %.2f', test_acc)
    stdout_log.write('test accuracy: {}\n'.format(test_acc))
    stdout_log.flush()
    worst_test_acc, best_test_acc, best_epoch = test_acc, test_acc, 0

    # BEGIN: setup for dp model
    @jit
    def update(_, i, opt_state, batch):
        params = get_params(opt_state)
        return opt_update(i, grad_loss(params, batch), opt_state)

    @jit
    def private_update(rng, i, opt_state, batch):
        params = get_params(opt_state)
        rng = random.fold_in(rng, i)  # get new key for new random numbers
        return opt_update(
            i,
            private_grad(params, batch, rng, FLAGS.l2_norm_clip,
                         FLAGS.noise_multiplier, FLAGS.batch_size), opt_state)

    # END: setup for dp model

    finetune_images = copy.deepcopy(pool_images[pool_picked_indices])
    finetune_labels = copy.deepcopy(pool_labels[pool_picked_indices])

    logging.info('Starting fine-tuning...')
    stdout_log.write('Starting fine-tuning...\n')
    stdout_log.flush()

    # BEGIN: gather points to be used for finetuning
    stdout_log.write('{} points picked via {}\n'.format(
        len(finetune_images), FLAGS.uncertain))
    logging.info('%d points picked via %s', len(finetune_images),
                 FLAGS.uncertain)
    assert FLAGS.n_extra == len(finetune_images)
    # END: gather points to be used for finetuning

    for epoch in range(1, FLAGS.epochs + 1):

        # BEGIN: finetune model with extra data, evaluate and save
        num_extra = len(finetune_images)
        num_complete_batches, leftover = divmod(num_extra, FLAGS.batch_size)
        num_batches = num_complete_batches + bool(leftover)

        finetune = data.DataChunk(X=finetune_images,
                                  Y=finetune_labels,
                                  image_size=28,
                                  image_channels=1,
                                  label_dim=1,
                                  label_format='numeric')

        batches = data.minibatcher(finetune,
                                   FLAGS.batch_size,
                                   transform=augmentation)

        itercount = itertools.count()
        # fold the epoch index into the key so DP noise differs across epochs
        key = random.fold_in(random.PRNGKey(FLAGS.seed), epoch)

        start_time = time.time()

        for _ in range(num_batches):
            # tmp_time = time.time()
            b = next(batches)
            if FLAGS.dpsgd:
                opt_state = private_update(
                    key, next(itercount), opt_state,
                    shape_as_image(b.X, b.Y, dummy_dim=True))
            else:
                opt_state = update(key, next(itercount), opt_state,
                                   shape_as_image(b.X, b.Y))
            # stdout_log.write('single update in {:.2f} sec\n'.format(
            #     time.time() - tmp_time))

        epoch_time = time.time() - start_time
        stdout_log.write('Epoch {} in {:.2f} sec\n'.format(epoch, epoch_time))
        logging.info('Epoch %d in %.2f sec', epoch, epoch_time)

        # accuracy on test data
        params = get_params(opt_state)

        test_pred_0 = test_pred
        test_acc, test_pred = accuracy(params,
                                       shape_as_image(test_images,
                                                      test_labels),
                                       return_predicted_class=True)
        test_loss = loss(params, shape_as_image(test_images, test_labels))
        stdout_log.write(
            'Eval set loss, accuracy (%): ({:.2f}, {:.2f})\n'.format(
                test_loss, 100 * test_acc))
        logging.info('Eval set loss, accuracy: (%.2f, %.2f)', test_loss,
                     100 * test_acc)
        stdout_log.flush()

        # visualize prediction difference between 2 checkpoints.
        if FLAGS.visualize:
            utils.visualize_ckpt_difference(test_images,
                                            np.argmax(test_labels, axis=1),
                                            test_pred_0,
                                            test_pred,
                                            epoch - 1,
                                            epoch,
                                            FLAGS.work_dir,
                                            mu=train_mu,
                                            sigma=train_std)

        worst_test_acc = min(test_acc, worst_test_acc)
        if test_acc > best_test_acc:
            best_test_acc, best_epoch = test_acc, epoch
            # save opt_state
            with gfile.Open('{}/acc_ckpt'.format(FLAGS.work_dir),
                            'wb') as fckpt:
                pickle.dump(optimizers.unpack_optimizer_state(opt_state),
                            fckpt)
    # END: finetune model with extra data, evaluate and save

    stdout_log.write('best test acc {} @E {}\n'.format(best_test_acc,
                                                       best_epoch))
    stdout_log.close()
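
A minimal numpy sketch of the selection mechanism in this example: each
uncertain train point votes for its nearest pool point in the projected
space, Laplace noise is added to the vote histogram, and the highest-voted
pool points are picked. The shapes and the noise scale are illustrative
assumptions, not the values used above:

import numpy as onp

rng = onp.random.RandomState(0)
train_emb = rng.randn(50, 8)  # projected embeddings of uncertain train points
pool_emb = rng.randn(20, 8)   # projected embeddings of uncertain pool points

# each train point votes for its nearest pool point (Euclidean distance)
d2 = ((train_emb[:, None, :] - pool_emb[None, :, :])**2).sum(-1)
votes = onp.bincount(d2.argmin(axis=1),
                     minlength=len(pool_emb)).astype(float)

# Laplace noise on the counts (scale chosen only for illustration)
votes += rng.laplace(scale=1.0, size=votes.shape)

n_extra = 5
picked = onp.argsort(votes)[::-1][:n_extra]
print(picked)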
Example #3
def cluster_pcs(big_pca,
                embeddings,
                cluster_method,
                eval_images_numpy,
                fpath,
                checkpoint_idx,
                k_components=16,
                m_demos=8,
                incorrect_indices=None,
                visualize_clusters=True):
    """Cluster projected embeddings onto pcs.

  Visualize examples in each cluster, sorted by distance from cluster center.
  If border in red and green, red means predicted incorrectly.
  """
    if incorrect_indices:

        def color_from_idx(index):
            return 'red' if index in incorrect_indices else 'green'
    else:

        def color_from_idx(_):
            return 'red'

    # Create folder for outputs
    try:
        gfile.MakeDirs(fpath)
    except gfile.GOSError:
        pass

    # Project embeddings onto the first K components
    projected_embeddings, _ = project_embeddings(embeddings, big_pca,
                                                 k_components)

    # Cluster projected embeddings
    cluster_labels = cluster_method.fit_predict(projected_embeddings)

    cluster_label_indices = {x: [] for x in set(cluster_labels)}
    for idx, c_label in enumerate(cluster_labels):
        cluster_label_indices[c_label].append(idx)
    n_clusters = len(cluster_label_indices.keys())

    # adjust m_demos wrt the smallest cluster
    incorrect_set = set(incorrect_indices or [])
    for k in cluster_label_indices.keys():
        m_demos = min(m_demos, len(cluster_label_indices[k]) // 2)
        print('cluster {}, count {}, incorrect {}'.format(
            k, len(cluster_label_indices[k]),
            len(incorrect_set & set(cluster_label_indices[k]))))

    fname = '{}/ckpt-{}_npc-{}_nc-{}_nd-{}'.format(fpath, checkpoint_idx,
                                                   k_components, n_clusters,
                                                   m_demos)

    # Prepare each row/cluster gradually
    big_list = []

    for cid, c_label in enumerate(cluster_label_indices.keys()):

        original_indices = cluster_label_indices[c_label]

        cluster_projected_embeddings = projected_embeddings[original_indices]

        cluster_center = np.mean(cluster_projected_embeddings,
                                 axis=0,
                                 keepdims=True)

        cluster_distances = list(
            euclidean_distances(cluster_projected_embeddings,
                                cluster_center).reshape(-1))

        cluster_indices_distances = zip(original_indices, cluster_distances)
        # sort as from the outside to the center of cluster
        cluster_indices_distances = sorted(cluster_indices_distances,
                                           key=lambda x: x[1],
                                           reverse=True)

        sorted_indices = [index for index, _ in cluster_indices_distances]

        min_indices = sorted_indices[:m_demos]  # farthest from the center
        max_indices = sorted_indices[-m_demos:]  # closest to the center

        min_imgs = eval_images_numpy[min_indices]
        max_imgs = eval_images_numpy[max_indices]
        this_row = ([min_imgs[m, :] for m in range(m_demos)] +
                    [max_imgs[m, :] for m in range(m_demos)])
        this_idx_row = min_indices + max_indices

        this_row = [
            make_images_viewable(np.expand_dims(x, 0)) for x in this_row
        ]
        this_row = [to_rgb(x) for x in this_row]
        this_row = [
            add_border(np.squeeze(x), width=2, color=color_from_idx(index))
            for index, x in zip(this_idx_row, this_row)
        ]
        big_list += this_row

        if visualize_clusters:

            this_cluster = [
                make_images_viewable(np.expand_dims(x, 0))
                for x in eval_images_numpy[sorted_indices]
            ]
            this_cluster = [to_rgb(x) for x in this_cluster]
            this_cluster = [
                add_border(np.squeeze(x), width=2, color=color_from_idx(index))
                for index, x in zip(sorted_indices, this_cluster)
            ]

            tile_image_list(this_cluster, fname + '_cid-{}'.format(cid))

    grid_image = tile_image_list(big_list, fname, n_rows=n_clusters)
    return grid_image
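
A hypothetical invocation of cluster_pcs, assuming the helpers it relies on
(project_embeddings, tile_image_list, make_images_viewable, to_rgb,
add_border) are importable from the same module; the shapes and paths are
illustrative only:

import numpy as onp
import sklearn.cluster
import sklearn.decomposition

embeddings = onp.random.randn(500, 64)         # penultimate-layer features
eval_images = onp.random.rand(500, 28, 28, 1)  # matching eval images
incorrect = set(onp.random.choice(500, 40, replace=False).tolist())

big_pca = sklearn.decomposition.PCA(n_components=64).fit(embeddings)
kmeans = sklearn.cluster.KMeans(n_clusters=10, random_state=0)

grid = cluster_pcs(big_pca, embeddings, kmeans, eval_images,
                   fpath='/tmp/clusters', checkpoint_idx=0,
                   k_components=16, m_demos=8,
                   incorrect_indices=incorrect)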
Example #4
def main(_):
    logging.info('Starting experiment.')
    configs = FLAGS.config

    # Create model folder for outputs
    try:
        gfile.MakeDirs(FLAGS.exp_dir)
    except gfile.GOSError:
        pass
    stdout_log = gfile.Open('{}/stdout.log'.format(FLAGS.exp_dir), 'w+')

    logging.info('Loading data.')
    tic = time.time()

    train_images, train_labels, _ = datasets.get_dataset_split(
        FLAGS.dataset, 'train')
    n_train = len(train_images)
    train_mu, train_std = onp.mean(train_images), onp.std(train_images)
    train = data.DataChunk(X=(train_images - train_mu) / train_std,
                           Y=train_labels,
                           image_size=32,
                           image_channels=3,
                           label_dim=1,
                           label_format='numeric')

    test_images, test_labels, _ = datasets.get_dataset_split(
        FLAGS.dataset, 'test')
    test = data.DataChunk(
        X=(test_images - train_mu) / train_std,  # normalize w train mean/std
        Y=test_labels,
        image_size=32,
        image_channels=3,
        label_dim=1,
        label_format='numeric')

    # Data augmentation
    if configs.augment_data:
        augmentation = data.chain_transforms(data.RandomHorizontalFlip(0.5),
                                             data.RandomCrop(4), data.ToDevice)
    else:
        augmentation = None
    batch = data.minibatcher(train, configs.batch_size, transform=augmentation)

    # Model architecture
    if configs.architect == 'wrn':
        init_random_params, predict = wide_resnet(configs.block_size,
                                                  configs.channel_multiplier,
                                                  10)
    elif configs.architect == 'cnn':
        init_random_params, predict = cnn()
    else:
        raise ValueError('Model architecture not implemented.')

    if configs.seed is not None:
        key = random.PRNGKey(configs.seed)
    else:
        key = random.PRNGKey(int(time.time()))
    _, params = init_random_params(key, (-1, 32, 32, 3))

    # count params of JAX model
    def count_parameters(params):
        return tree_util.tree_reduce(
            operator.add, tree_util.tree_map(lambda x: np.prod(x.shape),
                                             params))

    logging.info('Number of parameters: %d', count_parameters(params))
    stdout_log.write('Number of params: {}\n'.format(count_parameters(params)))

    # loss functions
    def cross_entropy_loss(params, x_img, y_lbl):
        return -np.mean(stax.logsoftmax(predict(params, x_img)) * y_lbl)

    def mse_loss(params, x_img, y_lbl):
        return 0.5 * np.mean((y_lbl - predict(params, x_img))**2)

    def accuracy(y_lbl_hat, y_lbl):
        target_class = np.argmax(y_lbl, axis=1)
        predicted_class = np.argmax(y_lbl_hat, axis=1)
        return np.mean(predicted_class == target_class)

    # Loss and gradient
    if configs.loss == 'xent':
        loss = cross_entropy_loss
    elif configs.loss == 'mse':
        loss = mse_loss
    else:
        raise ValueError('Loss function not implemented.')
    grad_loss = jit(grad(loss))

    # learning rate schedule and optimizer
    def cosine(initial_step_size, train_steps):
        k = np.pi / (2.0 * train_steps)

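        # schedule(i) = initial_step_size * cos(pi/2 * i / train_steps), i.e. a
        # smooth decay from the full rate at i=0 to zero at i=train_steps.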
        def schedule(i):
            return initial_step_size * np.cos(k * i)

        return schedule

    if configs.optimization == 'sgd':
        lr_schedule = optimizers.make_schedule(configs.learning_rate)
        opt_init, opt_update, get_params = optimizers.sgd(lr_schedule)
    elif configs.optimization == 'momentum':
        lr_schedule = cosine(configs.learning_rate, configs.train_steps)
        opt_init, opt_update, get_params = optimizers.momentum(
            lr_schedule, 0.9)
    else:
        raise ValueError('Optimizer not implemented.')

    opt_state = opt_init(params)

    def private_grad(params, batch, rng, l2_norm_clip, noise_multiplier,
                     batch_size):
        """Return differentially private gradients of params, evaluated on batch."""
        def _clipped_grad(params, single_example_batch):
            """Evaluate gradient for a single-example batch and clip its grad norm."""
            grads = grad_loss(params, single_example_batch[0].reshape(
                (-1, 32, 32, 3)), single_example_batch[1])

            nonempty_grads, tree_def = tree_util.tree_flatten(grads)
            total_grad_norm = np.linalg.norm(
                [np.linalg.norm(neg.ravel()) for neg in nonempty_grads])
            divisor = stop_gradient(
                np.amax((total_grad_norm / l2_norm_clip, 1.)))
            normalized_nonempty_grads = [
                neg / divisor for neg in nonempty_grads
            ]
            return tree_util.tree_unflatten(tree_def,
                                            normalized_nonempty_grads)

        px_clipped_grad_fn = vmap(partial(_clipped_grad, params))
        std_dev = l2_norm_clip * noise_multiplier
        normalize_ = lambda n: n / float(batch_size)
        sum_ = lambda n: np.sum(n, 0)  # aggregate over the batch dimension
        aggregated_clipped_grads = tree_util.tree_map(
            sum_, px_clipped_grad_fn(batch))
        # Split the rng so every parameter tensor gets an independent noise
        # draw; reusing one key would correlate noise across tensors.
        leaves, grads_tree_def = tree_util.tree_flatten(
            aggregated_clipped_grads)
        rngs = random.split(rng, len(leaves))
        noised_leaves = [
            g + std_dev * random.normal(r, g.shape)
            for g, r in zip(leaves, rngs)
        ]
        noised_aggregated_clipped_grads = tree_util.tree_unflatten(
            grads_tree_def, noised_leaves)
        return tree_util.tree_map(normalize_,
                                  noised_aggregated_clipped_grads)

    # summarize measurements
    steps_per_epoch = n_train // configs.batch_size

    def summarize(step, params):
        """Compute measurements in a zipped way."""
        set_entries = [train, test]
        set_bsizes = [configs.train_eval_bsize, configs.test_eval_bsize]
        set_names, loss_dict, acc_dict = ['train', 'test'], {}, {}

        for set_entry, set_bsize, set_name in zip(set_entries, set_bsizes,
                                                  set_names):
            temp_loss, temp_acc, points = 0.0, 0.0, 0
            for b in data.batch(set_entry, set_bsize):
                temp_loss += loss(params, b.X, b.Y) * b.X.shape[0]
                temp_acc += accuracy(predict(params, b.X), b.Y) * b.X.shape[0]
                points += b.X.shape[0]
            loss_dict[set_name] = temp_loss / float(points)
            acc_dict[set_name] = temp_acc / float(points)

        logging.info('Step: %s', str(step))
        logging.info('Train acc : %.4f', acc_dict['train'])
        logging.info('Train loss: %.4f', loss_dict['train'])
        logging.info('Test acc  : %.4f', acc_dict['test'])
        logging.info('Test loss : %.4f', loss_dict['test'])

        stdout_log.write('Step: {}\n'.format(step))
        stdout_log.write('Train acc : {}\n'.format(acc_dict['train']))
        stdout_log.write('Train loss: {}\n'.format(loss_dict['train']))
        stdout_log.write('Test acc  : {}\n'.format(acc_dict['test']))
        stdout_log.write('Test loss : {}\n'.format(loss_dict['test']))

        return acc_dict['test']

    toc = time.time()
    logging.info('Elapsed SETUP time: %s', str(toc - tic))
    stdout_log.write('Elapsed SETUP time: {}\n'.format(toc - tic))

    # BEGIN: training steps
    logging.info('Training network.')
    tic = time.time()
    t = time.time()

    for s in range(configs.train_steps):
        b = next(batch)
        params = get_params(opt_state)

        # t0 = time.time()
        if FLAGS.dpsgd:
            key = random.fold_in(key, s)  # get new key for new random numbers
            opt_state = opt_update(
                s,
                private_grad(params, (b.X.reshape(
                    (-1, 1, 32, 32, 3)), b.Y), key, configs.l2_norm_clip,
                             configs.noise_multiplier, configs.batch_size),
                opt_state)
        else:
            opt_state = opt_update(s, grad_loss(params, b.X, b.Y), opt_state)
        # t1 = time.time()
        # logging.info('batch update time: %s', str(t1 - t0))

        if s % steps_per_epoch == 0:
            with gfile.Open(
                    '{}/ckpt_{}'.format(FLAGS.exp_dir,
                                        s // steps_per_epoch),
                    'wb') as fckpt:
                pickle.dump(optimizers.unpack_optimizer_state(opt_state),
                            fckpt)

            if FLAGS.dpsgd:
                eps = compute_epsilon(s, configs.batch_size, n_train,
                                      configs.target_delta,
                                      configs.noise_multiplier)
                stdout_log.write(
                    'For delta={:.0e}, current epsilon is: {:.2f}\n'.format(
                        configs.target_delta, eps))

            logging.info('Elapsed EPOCH time: %s', str(time.time() - t))
            stdout_log.write('Elapsed EPOCH time: {}\n'.format(time.time() -
                                                               t))
            stdout_log.flush()
            t = time.time()

    toc = time.time()
    params = get_params(opt_state)  # refresh params after the final update
    summarize(configs.train_steps, params)
    logging.info('Elapsed TRAIN time: %s', str(toc - tic))
    stdout_log.write('Elapsed TRAIN time: {}\n'.format(toc - tic))
    stdout_log.close()
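
A self-contained numpy sketch of the clip-sum-noise-average recipe that
private_grad implements above: clip each per-example gradient to L2 norm C,
sum over the batch, add Gaussian noise with standard deviation
noise_multiplier * C, then average:

import numpy as onp

def dp_average(per_example_grads, l2_norm_clip, noise_multiplier, rng):
    # per_example_grads: (batch, dim), one flattened gradient per example
    norms = onp.linalg.norm(per_example_grads, axis=1, keepdims=True)
    divisors = onp.maximum(norms / l2_norm_clip, 1.0)  # clip, never scale up
    clipped = per_example_grads / divisors
    summed = clipped.sum(axis=0)
    noised = summed + l2_norm_clip * noise_multiplier * rng.normal(
        size=summed.shape)
    return noised / len(per_example_grads)

rng = onp.random.RandomState(0)
grads = 5.0 * rng.randn(32, 10)
print(dp_average(grads, l2_norm_clip=1.0, noise_multiplier=1.1, rng=rng))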
Example #5
def main(_):
    sns.set()
    sns.set_palette(sns.color_palette('hls', 10))
    npr.seed(FLAGS.seed)

    logging.info('Starting experiment.')

    # Create model folder for outputs
    try:
        gfile.MakeDirs(FLAGS.work_dir)
    except gfile.GOSError:
        pass
    stdout_log = gfile.Open('{}/stdout.log'.format(FLAGS.work_dir), 'w+')

    # BEGIN: fetch test data and candidate pool
    test_images, test_labels, _ = datasets.get_dataset_split(
        name=FLAGS.test_split.split('-')[0],
        split=FLAGS.test_split.split('-')[1],
        shuffle=False)
    pool_images, pool_labels, _ = datasets.get_dataset_split(
        name=FLAGS.pool_split.split('-')[0],
        split=FLAGS.pool_split.split('-')[1],
        shuffle=False)

    n_pool = len(pool_images)
    # normalize to range [-1.0, 127./128]
    test_images = test_images / np.float32(128.0) - np.float32(1.0)
    pool_images = pool_images / np.float32(128.0) - np.float32(1.0)

    # augmentation for train/pool data
    if FLAGS.augment_data:
        augmentation = data.chain_transforms(data.RandomHorizontalFlip(0.5),
                                             data.RandomCrop(4), data.ToDevice)
    else:
        augmentation = None
    # END: fetch test data and candidate pool

    _, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)

    # BEGIN: load ckpt
    ckpt_dir = '{}/{}'.format(FLAGS.root_dir, FLAGS.ckpt_idx)
    with gfile.Open(ckpt_dir, 'rb') as fckpt:  # read the pickled checkpoint
        opt_state = optimizers.pack_optimizer_state(pickle.load(fckpt))
    params = get_params(opt_state)

    stdout_log.write('finetune from: {}\n'.format(ckpt_dir))
    logging.info('finetune from: %s', ckpt_dir)
    test_acc, test_pred = accuracy(params,
                                   shape_as_image(test_images, test_labels),
                                   return_predicted_class=True)
    logging.info('test accuracy: %.2f', test_acc)
    stdout_log.write('test accuracy: {}\n'.format(test_acc))
    stdout_log.flush()
    # END: load ckpt

    # BEGIN: setup for dp model
    @jit
    def update(_, i, opt_state, batch):
        params = get_params(opt_state)
        return opt_update(i, grad_loss(params, batch), opt_state)

    @jit
    def private_update(rng, i, opt_state, batch):
        params = get_params(opt_state)
        rng = random.fold_in(rng, i)  # get new key for new random numbers
        return opt_update(
            i,
            private_grad(params, batch, rng, FLAGS.l2_norm_clip,
                         FLAGS.noise_multiplier, FLAGS.batch_size), opt_state)

    # END: setup for dp model

    ### BEGIN: prepare extra points picked from pool data
    # BEGIN: on pool data
    pool_embeddings = [apply_fn_0(params[:-1],
                                  pool_images[b_i:b_i + FLAGS.batch_size]) \
                       for b_i in range(0, n_pool, FLAGS.batch_size)]
    pool_embeddings = np.concatenate(pool_embeddings, axis=0)

    pool_logits = apply_fn_1(params[-1:], pool_embeddings)

    pool_true_labels = np.argmax(pool_labels, axis=1)
    pool_predicted_labels = np.argmax(pool_logits, axis=1)
    pool_correct_indices = \
        onp.where(pool_true_labels == pool_predicted_labels)[0]
    pool_incorrect_indices = \
        onp.where(pool_true_labels != pool_predicted_labels)[0]
    assert len(pool_correct_indices) + \
        len(pool_incorrect_indices) == len(pool_labels)

    pool_probs = stax.softmax(pool_logits)

    if FLAGS.uncertain == 0 or FLAGS.uncertain == 'entropy':
        pool_entropy = -onp.sum(pool_probs * onp.log(pool_probs), axis=1)
        stdout_log.write('all {} entropy: min {}, max {}\n'.format(
            len(pool_entropy), onp.min(pool_entropy), onp.max(pool_entropy)))

        pool_entropy_sorted_indices = onp.argsort(pool_entropy)
        # take the n_extra most uncertain points
        pool_uncertain_indices = \
            pool_entropy_sorted_indices[::-1][:FLAGS.n_extra]
        stdout_log.write('uncertain {} entropy: min {}, max {}\n'.format(
            len(pool_entropy[pool_uncertain_indices]),
            onp.min(pool_entropy[pool_uncertain_indices]),
            onp.max(pool_entropy[pool_uncertain_indices])))

    elif FLAGS.uncertain == 1 or FLAGS.uncertain == 'difference':
        # 1st_prob - 2nd_prob
        assert len(pool_probs.shape) == 2
        sorted_pool_probs = onp.sort(pool_probs, axis=1)
        pool_probs_diff = sorted_pool_probs[:, -1] - sorted_pool_probs[:, -2]
        assert min(pool_probs_diff) > 0.
        stdout_log.write('all {} difference: min {}, max {}\n'.format(
            len(pool_probs_diff), onp.min(pool_probs_diff),
            onp.max(pool_probs_diff)))

        pool_uncertain_indices = onp.argsort(pool_probs_diff)[:FLAGS.n_extra]
        stdout_log.write('uncertain {} difference: min {}, max {}\n'.format(
            len(pool_probs_diff[pool_uncertain_indices]),
            onp.min(pool_probs_diff[pool_uncertain_indices]),
            onp.max(pool_probs_diff[pool_uncertain_indices])))

    elif FLAGS.uncertain == 2 or FLAGS.uncertain == 'random':
        pool_uncertain_indices = npr.permutation(n_pool)[:FLAGS.n_extra]

    # END: on pool data

    ### END: prepare extra points picked from pool data

    finetune_images = copy.deepcopy(pool_images[pool_uncertain_indices])
    finetune_labels = copy.deepcopy(pool_labels[pool_uncertain_indices])

    stdout_log.write('Starting fine-tuning...\n')
    logging.info('Starting fine-tuning...')
    stdout_log.flush()

    stdout_log.write('{} points picked via {}\n'.format(
        len(finetune_images), FLAGS.uncertain))
    logging.info('%d points picked via %s', len(finetune_images),
                 FLAGS.uncertain)
    assert FLAGS.n_extra == len(finetune_images)

    for epoch in range(1, FLAGS.epochs + 1):

        # BEGIN: finetune model with extra data, evaluate and save
        num_extra = len(finetune_images)
        num_complete_batches, leftover = divmod(num_extra, FLAGS.batch_size)
        num_batches = num_complete_batches + bool(leftover)

        finetune = data.DataChunk(X=finetune_images,
                                  Y=finetune_labels,
                                  image_size=28,
                                  image_channels=1,
                                  label_dim=1,
                                  label_format='numeric')

        batches = data.minibatcher(finetune,
                                   FLAGS.batch_size,
                                   transform=augmentation)

        itercount = itertools.count()
        # fold the epoch index into the key so DP noise differs across epochs
        key = random.fold_in(random.PRNGKey(FLAGS.seed), epoch)

        start_time = time.time()

        for _ in range(num_batches):
            # tmp_time = time.time()
            b = next(batches)
            if FLAGS.dpsgd:
                opt_state = private_update(
                    key, next(itercount), opt_state,
                    shape_as_image(b.X, b.Y, dummy_dim=True))
            else:
                opt_state = update(key, next(itercount), opt_state,
                                   shape_as_image(b.X, b.Y))
            # stdout_log.write('single update in {:.2f} sec\n'.format(
            #     time.time() - tmp_time))

        epoch_time = time.time() - start_time
        stdout_log.write('Epoch {} in {:.2f} sec\n'.format(epoch, epoch_time))
        logging.info('Epoch %d in %.2f sec', epoch, epoch_time)

        # accuracy on test data
        params = get_params(opt_state)

        test_pred_0 = test_pred
        test_acc, test_pred = accuracy(params,
                                       shape_as_image(test_images,
                                                      test_labels),
                                       return_predicted_class=True)
        test_loss = loss(params, shape_as_image(test_images, test_labels))
        stdout_log.write(
            'Eval set loss, accuracy (%): ({:.2f}, {:.2f})\n'.format(
                test_loss, 100 * test_acc))
        logging.info('Eval set loss, accuracy: (%.2f, %.2f)', test_loss,
                     100 * test_acc)
        stdout_log.flush()

        # visualize prediction difference between 2 checkpoints.
        if FLAGS.visualize:
            utils.visualize_ckpt_difference(test_images,
                                            np.argmax(test_labels, axis=1),
                                            test_pred_0,
                                            test_pred,
                                            epoch - 1,
                                            epoch,
                                            FLAGS.work_dir,
                                            mu=128.,
                                            sigma=128.)

    # END: finetune model with extra data, evaluate and save

    stdout_log.close()