def main(_):
  sns.set()
  sns.set_palette(sns.color_palette('hls', 10))
  npr.seed(FLAGS.seed)

  logging.info('Starting experiment.')

  # Create model folder for outputs
  try:
    gfile.MakeDirs(FLAGS.work_dir)
  except gfile.GOSError:
    pass
  stdout_log = gfile.Open('{}/stdout.log'.format(FLAGS.work_dir), 'w+')

  # BEGIN: fetch test data and candidate pool
  test_images, test_labels, _ = datasets.get_dataset_split(
      name=FLAGS.test_split.split('-')[0],
      split=FLAGS.test_split.split('-')[1],
      shuffle=False)
  pool_images, pool_labels, _ = datasets.get_dataset_split(
      name=FLAGS.pool_split.split('-')[0],
      split=FLAGS.pool_split.split('-')[1],
      shuffle=False)
  n_pool = len(pool_images)

  # normalize to range [-1.0, 127./128]
  test_images = test_images / np.float32(128.0) - np.float32(1.0)
  pool_images = pool_images / np.float32(128.0) - np.float32(1.0)

  # augmentation for train/pool data
  if FLAGS.augment_data:
    augmentation = data.chain_transforms(data.RandomHorizontalFlip(0.5),
                                         data.RandomCrop(4), data.ToDevice)
  else:
    augmentation = None
  # END: fetch test data and candidate pool

  _, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)

  # BEGIN: load ckpt
  ckpt_dir = '{}/{}'.format(FLAGS.root_dir, FLAGS.ckpt_idx)
  with gfile.Open(ckpt_dir, 'rb') as fckpt:
    opt_state = optimizers.pack_optimizer_state(pickle.load(fckpt))
  params = get_params(opt_state)

  stdout_log.write('finetune from: {}\n'.format(ckpt_dir))
  logging.info('finetune from: %s', ckpt_dir)

  test_acc, test_pred = accuracy(params,
                                 shape_as_image(test_images, test_labels),
                                 return_predicted_class=True)
  logging.info('test accuracy: %.2f', test_acc)
  stdout_log.write('test accuracy: {}\n'.format(test_acc))
  stdout_log.flush()
  # END: load ckpt

  # BEGIN: setup for dp model
  @jit
  def update(_, i, opt_state, batch):
    params = get_params(opt_state)
    return opt_update(i, grad_loss(params, batch), opt_state)

  @jit
  def private_update(rng, i, opt_state, batch):
    params = get_params(opt_state)
    rng = random.fold_in(rng, i)  # get new key for new random numbers
    return opt_update(
        i,
        private_grad(params, batch, rng, FLAGS.l2_norm_clip,
                     FLAGS.noise_multiplier, FLAGS.batch_size), opt_state)
  # END: setup for dp model

  n_uncertain = FLAGS.n_extra + FLAGS.uncertain_extra

  ### BEGIN: prepare extra points picked from pool data
  # BEGIN: on pool data
  pool_embeddings = [
      apply_fn_0(params[:-1], pool_images[b_i:b_i + FLAGS.batch_size])
      for b_i in range(0, n_pool, FLAGS.batch_size)
  ]
  pool_embeddings = np.concatenate(pool_embeddings, axis=0)
  pool_logits = apply_fn_1(params[-1:], pool_embeddings)

  pool_true_labels = np.argmax(pool_labels, axis=1)
  pool_predicted_labels = np.argmax(pool_logits, axis=1)
  pool_correct_indices = \
      onp.where(pool_true_labels == pool_predicted_labels)[0]
  pool_incorrect_indices = \
      onp.where(pool_true_labels != pool_predicted_labels)[0]
  assert len(pool_correct_indices) + \
      len(pool_incorrect_indices) == len(pool_labels)

  pool_probs = stax.softmax(pool_logits)

  if FLAGS.uncertain == 0 or FLAGS.uncertain == 'entropy':
    pool_entropy = -onp.sum(pool_probs * onp.log(pool_probs), axis=1)
    stdout_log.write('all {} entropy: min {}, max {}\n'.format(
        len(pool_entropy), onp.min(pool_entropy), onp.max(pool_entropy)))

    pool_entropy_sorted_indices = onp.argsort(pool_entropy)
    # take the n_uncertain most uncertain points
    pool_uncertain_indices = \
        pool_entropy_sorted_indices[::-1][:n_uncertain]
    stdout_log.write('uncertain {} entropy: min {}, max {}\n'.format(
        len(pool_entropy[pool_uncertain_indices]),
        onp.min(pool_entropy[pool_uncertain_indices]),
        onp.max(pool_entropy[pool_uncertain_indices])))

  elif FLAGS.uncertain == 1 or FLAGS.uncertain == 'difference':
    # 1st_prob - 2nd_prob
    assert len(pool_probs.shape) == 2
    sorted_pool_probs = onp.sort(pool_probs, axis=1)
    pool_probs_diff = sorted_pool_probs[:, -1] - sorted_pool_probs[:, -2]
    assert min(pool_probs_diff) > 0.
    pool_uncertain_indices = onp.argsort(pool_probs_diff)[:n_uncertain]
  # END: on pool data

  # BEGIN: cluster uncertain pool points
  big_pca = sklearn.decomposition.PCA(n_components=pool_embeddings.shape[1])
  big_pca.random_state = FLAGS.seed
  # fit PCA onto embeddings of all the pool points
  big_pca.fit(pool_embeddings)

  # For uncertain points, project embeddings onto the first K components
  pool_uncertain_projected_embeddings, _ = utils.project_embeddings(
      pool_embeddings[pool_uncertain_indices], big_pca, FLAGS.k_components)

  n_cluster = int(FLAGS.n_extra / FLAGS.ppc)
  cluster_method = get_cluster_method('{}_nc-{}'.format(
      FLAGS.clustering, n_cluster))
  cluster_method.random_state = FLAGS.seed
  pool_uncertain_cluster_labels = cluster_method.fit_predict(
      pool_uncertain_projected_embeddings)

  pool_uncertain_cluster_label_indices = {
      x: [] for x in set(pool_uncertain_cluster_labels)
  }
  # local i within n_uncertain
  for i, c_label in enumerate(pool_uncertain_cluster_labels):
    pool_uncertain_cluster_label_indices[c_label].append(i)

  # find center of each cluster
  # aka, the most representative point of each 'tough' cluster
  pool_picked_indices = []
  pool_uncertain_cluster_label_pick = {}
  for c_label, indices in pool_uncertain_cluster_label_indices.items():
    cluster_projected_embeddings = \
        pool_uncertain_projected_embeddings[indices]
    cluster_center = onp.mean(cluster_projected_embeddings,
                              axis=0,
                              keepdims=True)
    if FLAGS.distance == 0 or FLAGS.distance == 'euclidean':
      cluster_distances = euclidean_distances(cluster_projected_embeddings,
                                              cluster_center).reshape(-1)
    elif FLAGS.distance == 1 or FLAGS.distance == 'weighted_euclidean':
      cluster_distances = weighted_euclidean_distances(
          cluster_projected_embeddings, cluster_center,
          big_pca.singular_values_[:FLAGS.k_components])

    sorted_is = onp.argsort(cluster_distances)
    sorted_indices = onp.array(indices)[sorted_is]
    pool_uncertain_cluster_label_indices[c_label] = sorted_indices

    center_i = sorted_indices[0]  # center_i in 3000
    pool_uncertain_cluster_label_pick[c_label] = center_i
    pool_picked_indices.extend(
        pool_uncertain_indices[sorted_indices[:FLAGS.ppc]])

    # BEGIN: visualize cluster of picked uncertain pool
    if FLAGS.visualize:
      this_cluster = []
      for i in sorted_indices:
        idx = pool_uncertain_indices[i]
        img = pool_images[idx]
        if idx in pool_correct_indices:
          border_color = 'green'
        else:
          border_color = 'red'
        img = utils.mark_labels(img, pool_predicted_labels[idx],
                                pool_true_labels[idx])
        img = utils.denormalize(img, 128., 128.)
        img = np.squeeze(utils.to_rgb(np.expand_dims(img, 0)))
        img = utils.add_border(img, width=2, color=border_color)
        this_cluster.append(img)
      utils.tile_image_list(
          this_cluster,
          '{}/picked_uncertain_pool_cid-{}'.format(FLAGS.work_dir, c_label))
    # END: visualize cluster of picked uncertain pool
  # END: cluster uncertain pool points

  pool_picked_indices = list(set(pool_picked_indices))
  n_gap = FLAGS.n_extra - len(pool_picked_indices)
  gap_indices = list(set(pool_uncertain_indices) - set(pool_picked_indices))
  pool_picked_indices.extend(npr.choice(gap_indices, n_gap, replace=False))
  stdout_log.write('n_gap: {}\n'.format(n_gap))
  ### END: prepare extra points picked from pool data

  finetune_images = copy.deepcopy(pool_images[pool_picked_indices])
  finetune_labels = copy.deepcopy(pool_labels[pool_picked_indices])

  stdout_log.write('{} points picked via {}\n'.format(
      len(finetune_images), FLAGS.uncertain))
  logging.info('%d points picked via %s', len(finetune_images),
               FLAGS.uncertain)
  assert FLAGS.n_extra == len(finetune_images)
  # END: gather points to be used for finetuning

  stdout_log.write('Starting fine-tuning...\n')
  logging.info('Starting fine-tuning...')
  stdout_log.flush()

  for epoch in range(1, FLAGS.epochs + 1):
    # BEGIN: finetune model with extra data, evaluate and save
    num_extra = len(finetune_images)
    num_complete_batches, leftover = divmod(num_extra, FLAGS.batch_size)
    num_batches = num_complete_batches + bool(leftover)

    finetune = data.DataChunk(X=finetune_images,
                              Y=finetune_labels,
                              image_size=28,
                              image_channels=1,
                              label_dim=1,
                              label_format='numeric')
    batches = data.minibatcher(finetune, FLAGS.batch_size,
                               transform=augmentation)

    itercount = itertools.count()
    key = random.PRNGKey(FLAGS.seed)

    start_time = time.time()
    for _ in range(num_batches):
      # tmp_time = time.time()
      b = next(batches)
      if FLAGS.dpsgd:
        opt_state = private_update(
            key, next(itercount), opt_state,
            shape_as_image(b.X, b.Y, dummy_dim=True))
      else:
        opt_state = update(key, next(itercount), opt_state,
                           shape_as_image(b.X, b.Y))
      # stdout_log.write('single update in {:.2f} sec\n'.format(
      #     time.time() - tmp_time))
    epoch_time = time.time() - start_time
    stdout_log.write('Epoch {} in {:.2f} sec\n'.format(epoch, epoch_time))
    logging.info('Epoch %d in %.2f sec', epoch, epoch_time)

    # accuracy on test data
    params = get_params(opt_state)
    test_pred_0 = test_pred
    test_acc, test_pred = accuracy(params,
                                   shape_as_image(test_images, test_labels),
                                   return_predicted_class=True)
    test_loss = loss(params, shape_as_image(test_images, test_labels))
    stdout_log.write(
        'Eval set loss, accuracy (%): ({:.2f}, {:.2f})\n'.format(
            test_loss, 100 * test_acc))
    logging.info('Eval set loss, accuracy: (%.2f, %.2f)', test_loss,
                 100 * test_acc)
    stdout_log.flush()

    # visualize prediction difference between 2 checkpoints.
    if FLAGS.visualize:
      utils.visualize_ckpt_difference(test_images,
                                      np.argmax(test_labels, axis=1),
                                      test_pred_0, test_pred,
                                      epoch - 1, epoch, FLAGS.work_dir,
                                      mu=128., sigma=128.)
    # END: finetune model with extra data, evaluate and save

  stdout_log.close()
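
# NOTE: `weighted_euclidean_distances` is called above but is defined
# elsewhere in the project. The sketch below illustrates one plausible
# implementation, assuming the intent is a per-component weighting of the
# Euclidean distance in the projected PCA space (e.g. by singular values,
# matching the `big_pca.singular_values_[:FLAGS.k_components]` argument at
# the call sites). The name, weighting scheme, and return shape here are
# assumptions, not the project's actual helper.
def weighted_euclidean_distances_sketch(points, center, component_weights):
  """Weighted distance from each row of `points` to `center`.

  Args:
    points: array of shape (n, k), projected embeddings.
    center: array of shape (1, k), a cluster center in the same space.
    component_weights: array of shape (k,), e.g. PCA singular values.

  Returns:
    Array of shape (n,), one weighted distance per point.
  """
  diff = points - center                          # (n, k)
  weighted_sq = (diff ** 2) * component_weights   # scale each component
  return onp.sqrt(onp.sum(weighted_sq, axis=1))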
def main(_):
  sns.set()
  sns.set_palette(sns.color_palette('hls', 10))
  npr.seed(FLAGS.seed)

  logging.info('Starting experiment.')

  # Create model folder for outputs
  try:
    gfile.MakeDirs(FLAGS.work_dir)
  except gfile.GOSError:
    pass
  stdout_log = gfile.Open('{}/stdout.log'.format(FLAGS.work_dir), 'w+')

  # BEGIN: set up optimizer and load params
  _, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)

  ckpt_dir = '{}/{}'.format(FLAGS.root_dir, FLAGS.ckpt_idx)
  with gfile.Open(ckpt_dir, 'rb') as fckpt:
    opt_state = optimizers.pack_optimizer_state(pickle.load(fckpt))
  params = get_params(opt_state)

  stdout_log.write('finetune from: {}\n'.format(ckpt_dir))
  logging.info('finetune from: %s', ckpt_dir)
  # END: set up optimizer and load params

  # BEGIN: train data emb, PCA, uncertain
  train_images, train_labels, _ = datasets.get_dataset_split(
      name=FLAGS.train_split.split('-')[0],
      split=FLAGS.train_split.split('-')[1],
      shuffle=False)
  n_train = len(train_images)

  # use mean/std of svhn train
  train_mu, train_std = 128., 128.
  train_images = (train_images - train_mu) / train_std

  # embeddings of all the training points
  train_embeddings = [
      apply_fn_0(params[:-1], train_images[b_i:b_i + FLAGS.batch_size])
      for b_i in range(0, n_train, FLAGS.batch_size)
  ]
  train_embeddings = np.concatenate(train_embeddings, axis=0)

  # fit PCA onto embeddings of all the training points
  if FLAGS.dppca_eps is not None:
    pc_cols, e_vals = dp_pca.dp_pca(train_embeddings,
                                    train_embeddings.shape[1],
                                    epsilon=FLAGS.dppca_eps,
                                    delta=1e-5,
                                    sigma=None)
  else:
    big_pca = sklearn.decomposition.PCA(
        n_components=train_embeddings.shape[1])
    big_pca.random_state = FLAGS.seed
    big_pca.fit(train_embeddings)

  # filter out uncertain train points
  n_uncertain = FLAGS.n_extra + FLAGS.uncertain_extra
  train_probs = stax.softmax(apply_fn_1(params[-1:], train_embeddings))
  train_acc = np.mean(
      np.argmax(train_probs, axis=1) == np.argmax(train_labels, axis=1))
  logging.info('initial train acc: %.2f', train_acc)

  if FLAGS.uncertain == 0 or FLAGS.uncertain == 'entropy':
    # entropy
    train_entropy = -onp.sum(train_probs * onp.log(train_probs), axis=1)
    train_uncertain_indices = \
        onp.argsort(train_entropy)[::-1][:n_uncertain]
  elif FLAGS.uncertain == 1 or FLAGS.uncertain == 'difference':
    # 1st_prob - 2nd_prob
    assert len(train_probs.shape) == 2
    sorted_train_probs = onp.sort(train_probs, axis=1)
    train_probs_diff = sorted_train_probs[:, -1] - sorted_train_probs[:, -2]
    assert min(train_probs_diff) > 0.
    train_uncertain_indices = onp.argsort(train_probs_diff)[:n_uncertain]

  if FLAGS.dppca_eps is not None:
    train_uncertain_projected_embeddings, _ = utils.project_embeddings(
        train_embeddings[train_uncertain_indices],
        pca_object=None,
        n_components=FLAGS.k_components,
        pc_cols=pc_cols)
  else:
    train_uncertain_projected_embeddings, _ = utils.project_embeddings(
        train_embeddings[train_uncertain_indices], big_pca,
        FLAGS.k_components)
  logging.info('projected embeddings of uncertain train data')

  del train_images, train_labels, train_embeddings
  # END: train data emb, PCA, uncertain

  # BEGIN: pool data emb
  pool_images, pool_labels, _ = datasets.get_dataset_split(
      name=FLAGS.pool_split.split('-')[0],
      split=FLAGS.pool_split.split('-')[1],
      shuffle=False)
  n_pool = len(pool_images)
  pool_images = (pool_images - train_mu) / train_std  # normalize w train mu/std

  pool_embeddings = [
      apply_fn_0(params[:-1], pool_images[b_i:b_i + FLAGS.batch_size])
      for b_i in range(0, n_pool, FLAGS.batch_size)
  ]
  pool_embeddings = np.concatenate(pool_embeddings, axis=0)

  # filter out uncertain pool points
  pool_probs = stax.softmax(apply_fn_1(params[-1:], pool_embeddings))

  if FLAGS.uncertain == 0 or FLAGS.uncertain == 'entropy':
    # entropy
    pool_entropy = -onp.sum(pool_probs * onp.log(pool_probs), axis=1)
    pool_uncertain_indices = onp.argsort(pool_entropy)[::-1][:n_uncertain]
  elif FLAGS.uncertain == 1 or FLAGS.uncertain == 'difference':
    # 1st_prob - 2nd_prob
    assert len(pool_probs.shape) == 2
    sorted_pool_probs = onp.sort(pool_probs, axis=1)
    pool_probs_diff = sorted_pool_probs[:, -1] - sorted_pool_probs[:, -2]
    assert min(pool_probs_diff) > 0.
    pool_uncertain_indices = onp.argsort(pool_probs_diff)[:n_uncertain]

  # constrain pool candidates to ONLY uncertain ones
  pool_images = pool_images[pool_uncertain_indices]
  pool_labels = pool_labels[pool_uncertain_indices]
  pool_embeddings = pool_embeddings[pool_uncertain_indices]
  n_pool = len(pool_uncertain_indices)

  if FLAGS.dppca_eps is not None:
    pool_projected_embeddings, _ = utils.project_embeddings(
        pool_embeddings,
        pca_object=None,
        n_components=FLAGS.k_components,
        pc_cols=pc_cols)
  else:
    pool_projected_embeddings, _ = utils.project_embeddings(
        pool_embeddings, big_pca, FLAGS.k_components)
  del pool_embeddings
  logging.info('projected embeddings of pool data')
  # END: pool data emb

  # BEGIN: assign train_uncertain_projected_embeddings to ONLY 1 point/cluster
  # assign uncertain train to closest pool point/cluster
  pool_index_histogram = onp.zeros(n_pool)
  for i in range(len(train_uncertain_projected_embeddings)):
    # t0 = time.time()
    train_uncertain_point = \
        train_uncertain_projected_embeddings[i].reshape(1, -1)
    if FLAGS.distance == 0 or FLAGS.distance == 'euclidean':
      cluster_distances = euclidean_distances(
          pool_projected_embeddings, train_uncertain_point).reshape(-1)
    elif FLAGS.distance == 1 or FLAGS.distance == 'weighted_euclidean':
      weights = e_vals[:FLAGS.k_components] if FLAGS.dppca_eps is not None \
          else big_pca.singular_values_[:FLAGS.k_components]
      cluster_distances = weighted_euclidean_distances(
          pool_projected_embeddings, train_uncertain_point, weights)
    pool_index = onp.argmin(cluster_distances)
    pool_index_histogram[pool_index] += 1.
    # t1 = time.time()
    # logging.info('%d uncertain train, %s second', i, str(t1 - t0))
    del cluster_distances

  # add Laplacian noise onto the neighbor counts
  if FLAGS.extra_eps is not None:
    pool_index_histogram += npr.laplace(
        scale=FLAGS.extra_eps - FLAGS.dppca_eps,
        size=pool_index_histogram.shape)

  pool_picked_indices = onp.argsort(
      pool_index_histogram)[::-1][:FLAGS.n_extra]
  logging.info('%d extra pool data picked', len(pool_picked_indices))
  # END: assign train_uncertain_projected_embeddings to ONLY 1 cluster

  # load test data
  test_images, test_labels, _ = datasets.get_dataset_split(
      name=FLAGS.test_split.split('-')[0],
      split=FLAGS.test_split.split('-')[1],
      shuffle=False)
  test_images = (test_images - train_mu) / train_std  # normalize w train mu/std

  # augmentation for train/pool data
  if FLAGS.augment_data:
    augmentation = data.chain_transforms(data.RandomHorizontalFlip(0.5),
                                         data.RandomCrop(4), data.ToDevice)
  else:
    augmentation = None

  test_acc, test_pred = accuracy(params,
                                 shape_as_image(test_images, test_labels),
                                 return_predicted_class=True)
  logging.info('test accuracy: %.2f', test_acc)
  stdout_log.write('test accuracy: {}\n'.format(test_acc))
  stdout_log.flush()

  worst_test_acc, best_test_acc, best_epoch = test_acc, test_acc, 0

  # BEGIN: setup for dp model
  @jit
  def update(_, i, opt_state, batch):
    params = get_params(opt_state)
    return opt_update(i, grad_loss(params, batch), opt_state)

  @jit
  def private_update(rng, i, opt_state, batch):
    params = get_params(opt_state)
    rng = random.fold_in(rng, i)  # get new key for new random numbers
    return opt_update(
        i,
        private_grad(params, batch, rng, FLAGS.l2_norm_clip,
                     FLAGS.noise_multiplier, FLAGS.batch_size), opt_state)
  # END: setup for dp model

  finetune_images = copy.deepcopy(pool_images[pool_picked_indices])
  finetune_labels = copy.deepcopy(pool_labels[pool_picked_indices])

  logging.info('Starting fine-tuning...')
  stdout_log.write('Starting fine-tuning...\n')
  stdout_log.flush()

  # BEGIN: gather points to be used for finetuning
  stdout_log.write('{} points picked via {}\n'.format(
      len(finetune_images), FLAGS.uncertain))
  logging.info('%d points picked via %s', len(finetune_images),
               FLAGS.uncertain)
  assert FLAGS.n_extra == len(finetune_images)
  # END: gather points to be used for finetuning

  for epoch in range(1, FLAGS.epochs + 1):
    # BEGIN: finetune model with extra data, evaluate and save
    num_extra = len(finetune_images)
    num_complete_batches, leftover = divmod(num_extra, FLAGS.batch_size)
    num_batches = num_complete_batches + bool(leftover)

    finetune = data.DataChunk(X=finetune_images,
                              Y=finetune_labels,
                              image_size=28,
                              image_channels=1,
                              label_dim=1,
                              label_format='numeric')
    batches = data.minibatcher(finetune, FLAGS.batch_size,
                               transform=augmentation)

    itercount = itertools.count()
    key = random.PRNGKey(FLAGS.seed)

    start_time = time.time()
    for _ in range(num_batches):
      # tmp_time = time.time()
      b = next(batches)
      if FLAGS.dpsgd:
        opt_state = private_update(
            key, next(itercount), opt_state,
            shape_as_image(b.X, b.Y, dummy_dim=True))
      else:
        opt_state = update(key, next(itercount), opt_state,
                           shape_as_image(b.X, b.Y))
      # stdout_log.write('single update in {:.2f} sec\n'.format(
      #     time.time() - tmp_time))
    epoch_time = time.time() - start_time
    stdout_log.write('Epoch {} in {:.2f} sec\n'.format(epoch, epoch_time))
    logging.info('Epoch %d in %.2f sec', epoch, epoch_time)

    # accuracy on test data
    params = get_params(opt_state)
    test_pred_0 = test_pred
    test_acc, test_pred = accuracy(params,
                                   shape_as_image(test_images, test_labels),
                                   return_predicted_class=True)
    test_loss = loss(params, shape_as_image(test_images, test_labels))
    stdout_log.write(
        'Eval set loss, accuracy (%): ({:.2f}, {:.2f})\n'.format(
            test_loss, 100 * test_acc))
    logging.info('Eval set loss, accuracy: (%.2f, %.2f)', test_loss,
                 100 * test_acc)
    stdout_log.flush()

    # visualize prediction difference between 2 checkpoints.
    if FLAGS.visualize:
      utils.visualize_ckpt_difference(test_images,
                                      np.argmax(test_labels, axis=1),
                                      test_pred_0, test_pred,
                                      epoch - 1, epoch, FLAGS.work_dir,
                                      mu=train_mu, sigma=train_std)

    worst_test_acc = min(test_acc, worst_test_acc)
    if test_acc > best_test_acc:
      best_test_acc, best_epoch = test_acc, epoch
      # save opt_state
      with gfile.Open('{}/acc_ckpt'.format(FLAGS.work_dir), 'wb') as fckpt:
        pickle.dump(optimizers.unpack_optimizer_state(opt_state), fckpt)
    # END: finetune model with extra data, evaluate and save

  stdout_log.write('best test acc {} @E {}\n'.format(best_test_acc,
                                                     best_epoch))
  stdout_log.close()
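
# NOTE: `utils.project_embeddings` is used above with two calling conventions
# (a fitted sklearn PCA object, or explicit DP-PCA principal-component columns
# `pc_cols`), but is defined elsewhere. The sketch below shows one way such a
# projection could work; the second return value is assumed to be the
# components used. This is an illustration under those assumptions, not the
# project's actual helper.
def project_embeddings_sketch(embeddings, pca_object=None, n_components=16,
                              pc_cols=None):
  """Project `embeddings` onto the first `n_components` principal components."""
  if pca_object is not None:
    # sklearn PCA stores components as rows, shape (n_components, dim).
    components = pca_object.components_[:n_components]   # (k, dim)
    projected = (embeddings - pca_object.mean_) @ components.T  # (n, k)
  else:
    # DP-PCA path: assume `pc_cols` holds components as columns, shape (dim, k').
    components = pc_cols[:, :n_components].T              # (k, dim)
    projected = embeddings @ components.T                 # (n, k)
  return projected, components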
def cluster_pcs(big_pca,
                embeddings,
                cluster_method,
                eval_images_numpy,
                fpath,
                checkpoint_idx,
                k_components=16,
                m_demos=8,
                incorrect_indices=None,
                visualize_clusters=True):
  """Cluster embeddings projected onto principal components.

  Visualize examples in each cluster, sorted by distance from the cluster
  center. Borders are colored green/red; red means the example was
  predicted incorrectly.
  """
  if incorrect_indices:

    def color_from_idx(index):
      return 'red' if index in incorrect_indices else 'green'
  else:

    def color_from_idx(_):
      return 'red'

  # Create folder for outputs
  try:
    gfile.MakeDirs(fpath)
  except gfile.GOSError:
    pass

  # Project embeddings onto the first K components
  projected_embeddings, _ = project_embeddings(embeddings, big_pca,
                                               k_components)

  # Cluster projected embeddings
  cluster_labels = cluster_method.fit_predict(projected_embeddings)
  cluster_label_indices = {x: [] for x in set(cluster_labels)}
  for idx, c_label in enumerate(cluster_labels):
    cluster_label_indices[c_label].append(idx)
  n_clusters = len(cluster_label_indices.keys())

  # adjust m_demos wrt the smallest cluster
  for k in cluster_label_indices.keys():
    m_demos = min(m_demos, int(len(cluster_label_indices[k]) / 2))
    print('cluster {}, count {}, incorrect {}'.format(
        k, len(cluster_label_indices[k]),
        len(set(incorrect_indices or []) & set(cluster_label_indices[k]))))

  fname = '{}/ckpt-{}_npc-{}_nc-{}_nd-{}'.format(fpath, checkpoint_idx,
                                                 k_components, n_clusters,
                                                 m_demos)

  # Prepare each row/cluster gradually
  big_list = []
  for cid, c_label in enumerate(cluster_label_indices.keys()):
    original_indices = cluster_label_indices[c_label]
    cluster_projected_embeddings = projected_embeddings[original_indices]
    cluster_center = np.mean(cluster_projected_embeddings,
                             axis=0,
                             keepdims=True)
    cluster_distances = list(
        euclidean_distances(cluster_projected_embeddings,
                            cluster_center).reshape(-1))
    cluster_indices_distances = zip(original_indices, cluster_distances)
    # sort from the outside of the cluster toward its center
    cluster_indices_distances = sorted(cluster_indices_distances,
                                       key=lambda x: x[1],
                                       reverse=True)
    sorted_indices = [index for index, _ in cluster_indices_distances]

    min_indices = sorted_indices[:m_demos]  # outer part of cluster
    max_indices = sorted_indices[-m_demos:]  # inner part of cluster
    min_imgs = eval_images_numpy[min_indices]
    max_imgs = eval_images_numpy[max_indices]

    this_row = ([min_imgs[m, :] for m in range(m_demos)] +
                [max_imgs[m, :] for m in range(m_demos)])
    this_idx_row = min_indices + max_indices
    this_row = [make_images_viewable(np.expand_dims(x, 0)) for x in this_row]
    this_row = [to_rgb(x) for x in this_row]
    this_row = [
        add_border(np.squeeze(x), width=2, color=color_from_idx(index))
        for index, x in zip(this_idx_row, this_row)
    ]
    big_list += this_row

    if visualize_clusters:
      this_cluster = [
          make_images_viewable(np.expand_dims(x, 0))
          for x in eval_images_numpy[sorted_indices]
      ]
      this_cluster = [to_rgb(x) for x in this_cluster]
      this_cluster = [
          add_border(np.squeeze(x), width=2, color=color_from_idx(index))
          for index, x in zip(sorted_indices, this_cluster)
      ]
      tile_image_list(this_cluster, fname + '_cid-{}'.format(cid))

  grid_image = tile_image_list(big_list, fname, n_rows=n_clusters)
  return grid_image
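
# Illustrative usage sketch for `cluster_pcs` (not part of the original
# pipeline): fit a full-dimensional PCA and a KMeans clusterer on a set of
# embeddings, then render the per-cluster grid. The embeddings, image array,
# output directory and KMeans choice are all caller-supplied placeholders.
def demo_cluster_pcs(embeddings, eval_images_numpy, out_dir, n_clusters=8):
  import sklearn.cluster
  import sklearn.decomposition

  pca = sklearn.decomposition.PCA(n_components=embeddings.shape[1])
  pca.fit(embeddings)
  kmeans = sklearn.cluster.KMeans(n_clusters=n_clusters, random_state=0)
  return cluster_pcs(pca, embeddings, kmeans, eval_images_numpy,
                     fpath=out_dir, checkpoint_idx=0, k_components=16,
                     m_demos=8, incorrect_indices=None,
                     visualize_clusters=False)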
def main(_):
  logging.info('Starting experiment.')
  configs = FLAGS.config

  # Create model folder for outputs
  try:
    gfile.MakeDirs(FLAGS.exp_dir)
  except gfile.GOSError:
    pass
  stdout_log = gfile.Open('{}/stdout.log'.format(FLAGS.exp_dir), 'w+')

  logging.info('Loading data.')
  tic = time.time()

  train_images, train_labels, _ = datasets.get_dataset_split(
      FLAGS.dataset, 'train')
  n_train = len(train_images)
  train_mu, train_std = onp.mean(train_images), onp.std(train_images)
  train = data.DataChunk(X=(train_images - train_mu) / train_std,
                         Y=train_labels,
                         image_size=32,
                         image_channels=3,
                         label_dim=1,
                         label_format='numeric')

  test_images, test_labels, _ = datasets.get_dataset_split(
      FLAGS.dataset, 'test')
  test = data.DataChunk(
      X=(test_images - train_mu) / train_std,  # normalize w train mean/std
      Y=test_labels,
      image_size=32,
      image_channels=3,
      label_dim=1,
      label_format='numeric')

  # Data augmentation
  if configs.augment_data:
    augmentation = data.chain_transforms(data.RandomHorizontalFlip(0.5),
                                         data.RandomCrop(4), data.ToDevice)
  else:
    augmentation = None
  batch = data.minibatcher(train, configs.batch_size, transform=augmentation)

  # Model architecture
  if configs.architect == 'wrn':
    init_random_params, predict = wide_resnet(configs.block_size,
                                               configs.channel_multiplier, 10)
  elif configs.architect == 'cnn':
    init_random_params, predict = cnn()
  else:
    raise ValueError('Model architecture not implemented.')

  if configs.seed is not None:
    key = random.PRNGKey(configs.seed)
  else:
    key = random.PRNGKey(int(time.time()))
  _, params = init_random_params(key, (-1, 32, 32, 3))

  # count params of JAX model
  def count_parameters(params):
    return tree_util.tree_reduce(
        operator.add, tree_util.tree_map(lambda x: np.prod(x.shape), params))

  logging.info('Number of parameters: %d', count_parameters(params))
  stdout_log.write('Number of params: {}\n'.format(count_parameters(params)))

  # loss functions
  def cross_entropy_loss(params, x_img, y_lbl):
    return -np.mean(stax.logsoftmax(predict(params, x_img)) * y_lbl)

  def mse_loss(params, x_img, y_lbl):
    return 0.5 * np.mean((y_lbl - predict(params, x_img))**2)

  def accuracy(y_lbl_hat, y_lbl):
    target_class = np.argmax(y_lbl, axis=1)
    predicted_class = np.argmax(y_lbl_hat, axis=1)
    return np.mean(predicted_class == target_class)

  # Loss and gradient
  if configs.loss == 'xent':
    loss = cross_entropy_loss
  elif configs.loss == 'mse':
    loss = mse_loss
  else:
    raise ValueError('Loss function not implemented.')
  grad_loss = jit(grad(loss))

  # learning rate schedule and optimizer
  def cosine(initial_step_size, train_steps):
    k = np.pi / (2.0 * train_steps)

    def schedule(i):
      return initial_step_size * np.cos(k * i)

    return schedule

  if configs.optimization == 'sgd':
    lr_schedule = optimizers.make_schedule(configs.learning_rate)
    opt_init, opt_update, get_params = optimizers.sgd(lr_schedule)
  elif configs.optimization == 'momentum':
    lr_schedule = cosine(configs.learning_rate, configs.train_steps)
    opt_init, opt_update, get_params = optimizers.momentum(lr_schedule, 0.9)
  else:
    raise ValueError('Optimizer not implemented.')
  opt_state = opt_init(params)

  def private_grad(params, batch, rng, l2_norm_clip, noise_multiplier,
                   batch_size):
    """Return differentially private gradients of params, evaluated on batch."""

    def _clipped_grad(params, single_example_batch):
      """Evaluate gradient for a single-example batch and clip its grad norm."""
      grads = grad_loss(params,
                        single_example_batch[0].reshape((-1, 32, 32, 3)),
                        single_example_batch[1])
      nonempty_grads, tree_def = tree_util.tree_flatten(grads)
      total_grad_norm = np.linalg.norm(
          [np.linalg.norm(neg.ravel()) for neg in nonempty_grads])
      divisor = stop_gradient(np.amax((total_grad_norm / l2_norm_clip, 1.)))
      normalized_nonempty_grads = [neg / divisor for neg in nonempty_grads]
      return tree_util.tree_unflatten(tree_def, normalized_nonempty_grads)

    px_clipped_grad_fn = vmap(partial(_clipped_grad, params))
    std_dev = l2_norm_clip * noise_multiplier
    noise_ = lambda n: n + std_dev * random.normal(rng, n.shape)
    normalize_ = lambda n: n / float(batch_size)
    sum_ = lambda n: np.sum(n, 0)  # aggregate

    aggregated_clipped_grads = tree_util.tree_map(sum_,
                                                  px_clipped_grad_fn(batch))
    noised_aggregated_clipped_grads = tree_util.tree_map(
        noise_, aggregated_clipped_grads)
    normalized_noised_aggregated_clipped_grads = (tree_util.tree_map(
        normalize_, noised_aggregated_clipped_grads))
    return normalized_noised_aggregated_clipped_grads

  # summarize measurements
  steps_per_epoch = n_train // configs.batch_size

  def summarize(step, params):
    """Compute measurements in a zipped way."""
    set_entries = [train, test]
    set_bsizes = [configs.train_eval_bsize, configs.test_eval_bsize]
    set_names, loss_dict, acc_dict = ['train', 'test'], {}, {}

    for set_entry, set_bsize, set_name in zip(set_entries, set_bsizes,
                                              set_names):
      temp_loss, temp_acc, points = 0.0, 0.0, 0
      for b in data.batch(set_entry, set_bsize):
        temp_loss += loss(params, b.X, b.Y) * b.X.shape[0]
        temp_acc += accuracy(predict(params, b.X), b.Y) * b.X.shape[0]
        points += b.X.shape[0]
      loss_dict[set_name] = temp_loss / float(points)
      acc_dict[set_name] = temp_acc / float(points)

    logging.info('Step: %s', str(step))
    logging.info('Train acc : %.4f', acc_dict['train'])
    logging.info('Train loss: %.4f', loss_dict['train'])
    logging.info('Test acc : %.4f', acc_dict['test'])
    logging.info('Test loss : %.4f', loss_dict['test'])

    stdout_log.write('Step: {}\n'.format(step))
    stdout_log.write('Train acc : {}\n'.format(acc_dict['train']))
    stdout_log.write('Train loss: {}\n'.format(loss_dict['train']))
    stdout_log.write('Test acc : {}\n'.format(acc_dict['test']))
    stdout_log.write('Test loss : {}\n'.format(loss_dict['test']))

    return acc_dict['test']

  toc = time.time()
  logging.info('Elapsed SETUP time: %s', str(toc - tic))
  stdout_log.write('Elapsed SETUP time: {}\n'.format(toc - tic))

  # BEGIN: training steps
  logging.info('Training network.')
  tic = time.time()
  t = time.time()

  for s in range(configs.train_steps):
    b = next(batch)
    params = get_params(opt_state)

    # t0 = time.time()
    if FLAGS.dpsgd:
      key = random.fold_in(key, s)  # get new key for new random numbers
      opt_state = opt_update(
          s,
          private_grad(params, (b.X.reshape((-1, 1, 32, 32, 3)), b.Y), key,
                       configs.l2_norm_clip, configs.noise_multiplier,
                       configs.batch_size), opt_state)
    else:
      opt_state = opt_update(s, grad_loss(params, b.X, b.Y), opt_state)
    # t1 = time.time()
    # logging.info('batch update time: %s', str(t1 - t0))

    if s % steps_per_epoch == 0:
      with gfile.Open(
          '{}/ckpt_{}'.format(FLAGS.exp_dir, int(s / steps_per_epoch)),
          'wb') as fckpt:
        pickle.dump(optimizers.unpack_optimizer_state(opt_state), fckpt)

      if FLAGS.dpsgd:
        eps = compute_epsilon(s, configs.batch_size, n_train,
                              configs.target_delta, configs.noise_multiplier)
        stdout_log.write(
            'For delta={:.0e}, current epsilon is: {:.2f}\n'.format(
                configs.target_delta, eps))

      logging.info('Elapsed EPOCH time: %s', str(time.time() - t))
      stdout_log.write('Elapsed EPOCH time: {}\n'.format(time.time() - t))
      stdout_log.flush()
      t = time.time()

  toc = time.time()
  summarize(configs.train_steps, params)
  logging.info('Elapsed TRAIN time: %s', str(toc - tic))
  stdout_log.write('Elapsed TRAIN time: {}\n'.format(toc - tic))
  stdout_log.close()
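
# NOTE: `compute_epsilon` is called in the training loop above but defined
# elsewhere. The sketch below shows one standard way such an accountant could
# be implemented with the RDP accountant from tensorflow_privacy, which is an
# assumed dependency here; the argument order mirrors the call site above.
# This is an illustration, not the project's actual accountant.
def compute_epsilon_sketch(steps, batch_size, num_examples, target_delta,
                           noise_multiplier):
  """RDP-accountant epsilon for DP-SGD after `steps` noisy updates."""
  from tensorflow_privacy.privacy.analysis.rdp_accountant import (
      compute_rdp, get_privacy_spent)

  q = batch_size / float(num_examples)  # per-step sampling rate
  orders = list(onp.linspace(1.1, 10.9, 99)) + list(range(11, 64))
  rdp = compute_rdp(q=q, noise_multiplier=noise_multiplier,
                    steps=steps, orders=orders)
  eps, _, _ = get_privacy_spent(orders, rdp, target_delta=target_delta)
  return eps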
def main(_):
  sns.set()
  sns.set_palette(sns.color_palette('hls', 10))
  npr.seed(FLAGS.seed)

  logging.info('Starting experiment.')

  # Create model folder for outputs
  try:
    gfile.MakeDirs(FLAGS.work_dir)
  except gfile.GOSError:
    pass
  stdout_log = gfile.Open('{}/stdout.log'.format(FLAGS.work_dir), 'w+')

  # BEGIN: fetch test data and candidate pool
  test_images, test_labels, _ = datasets.get_dataset_split(
      name=FLAGS.test_split.split('-')[0],
      split=FLAGS.test_split.split('-')[1],
      shuffle=False)
  pool_images, pool_labels, _ = datasets.get_dataset_split(
      name=FLAGS.pool_split.split('-')[0],
      split=FLAGS.pool_split.split('-')[1],
      shuffle=False)
  n_pool = len(pool_images)

  # normalize to range [-1.0, 127./128]
  test_images = test_images / np.float32(128.0) - np.float32(1.0)
  pool_images = pool_images / np.float32(128.0) - np.float32(1.0)

  # augmentation for train/pool data
  if FLAGS.augment_data:
    augmentation = data.chain_transforms(data.RandomHorizontalFlip(0.5),
                                         data.RandomCrop(4), data.ToDevice)
  else:
    augmentation = None
  # END: fetch test data and candidate pool

  _, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)

  # BEGIN: load ckpt
  ckpt_dir = '{}/{}'.format(FLAGS.root_dir, FLAGS.ckpt_idx)
  with gfile.Open(ckpt_dir, 'rb') as fckpt:
    opt_state = optimizers.pack_optimizer_state(pickle.load(fckpt))
  params = get_params(opt_state)

  stdout_log.write('finetune from: {}\n'.format(ckpt_dir))
  logging.info('finetune from: %s', ckpt_dir)

  test_acc, test_pred = accuracy(params,
                                 shape_as_image(test_images, test_labels),
                                 return_predicted_class=True)
  logging.info('test accuracy: %.2f', test_acc)
  stdout_log.write('test accuracy: {}\n'.format(test_acc))
  stdout_log.flush()
  # END: load ckpt

  # BEGIN: setup for dp model
  @jit
  def update(_, i, opt_state, batch):
    params = get_params(opt_state)
    return opt_update(i, grad_loss(params, batch), opt_state)

  @jit
  def private_update(rng, i, opt_state, batch):
    params = get_params(opt_state)
    rng = random.fold_in(rng, i)  # get new key for new random numbers
    return opt_update(
        i,
        private_grad(params, batch, rng, FLAGS.l2_norm_clip,
                     FLAGS.noise_multiplier, FLAGS.batch_size), opt_state)
  # END: setup for dp model

  ### BEGIN: prepare extra points picked from pool data
  # BEGIN: on pool data
  pool_embeddings = [
      apply_fn_0(params[:-1], pool_images[b_i:b_i + FLAGS.batch_size])
      for b_i in range(0, n_pool, FLAGS.batch_size)
  ]
  pool_embeddings = np.concatenate(pool_embeddings, axis=0)
  pool_logits = apply_fn_1(params[-1:], pool_embeddings)

  pool_true_labels = np.argmax(pool_labels, axis=1)
  pool_predicted_labels = np.argmax(pool_logits, axis=1)
  pool_correct_indices = \
      onp.where(pool_true_labels == pool_predicted_labels)[0]
  pool_incorrect_indices = \
      onp.where(pool_true_labels != pool_predicted_labels)[0]
  assert len(pool_correct_indices) + \
      len(pool_incorrect_indices) == len(pool_labels)

  pool_probs = stax.softmax(pool_logits)

  if FLAGS.uncertain == 0 or FLAGS.uncertain == 'entropy':
    pool_entropy = -onp.sum(pool_probs * onp.log(pool_probs), axis=1)
    stdout_log.write('all {} entropy: min {}, max {}\n'.format(
        len(pool_entropy), onp.min(pool_entropy), onp.max(pool_entropy)))

    pool_entropy_sorted_indices = onp.argsort(pool_entropy)
    # take the n_extra most uncertain points
    pool_uncertain_indices = \
        pool_entropy_sorted_indices[::-1][:FLAGS.n_extra]
    stdout_log.write('uncertain {} entropy: min {}, max {}\n'.format(
        len(pool_entropy[pool_uncertain_indices]),
        onp.min(pool_entropy[pool_uncertain_indices]),
        onp.max(pool_entropy[pool_uncertain_indices])))

  elif FLAGS.uncertain == 1 or FLAGS.uncertain == 'difference':
    # 1st_prob - 2nd_prob
    assert len(pool_probs.shape) == 2
    sorted_pool_probs = onp.sort(pool_probs, axis=1)
    pool_probs_diff = sorted_pool_probs[:, -1] - sorted_pool_probs[:, -2]
    assert min(pool_probs_diff) > 0.
    stdout_log.write('all {} difference: min {}, max {}\n'.format(
        len(pool_probs_diff), onp.min(pool_probs_diff),
        onp.max(pool_probs_diff)))

    pool_uncertain_indices = onp.argsort(pool_probs_diff)[:FLAGS.n_extra]
    stdout_log.write('uncertain {} difference: min {}, max {}\n'.format(
        len(pool_probs_diff[pool_uncertain_indices]),
        onp.min(pool_probs_diff[pool_uncertain_indices]),
        onp.max(pool_probs_diff[pool_uncertain_indices])))

  elif FLAGS.uncertain == 2 or FLAGS.uncertain == 'random':
    pool_uncertain_indices = npr.permutation(n_pool)[:FLAGS.n_extra]
  # END: on pool data
  ### END: prepare extra points picked from pool data

  finetune_images = copy.deepcopy(pool_images[pool_uncertain_indices])
  finetune_labels = copy.deepcopy(pool_labels[pool_uncertain_indices])

  stdout_log.write('Starting fine-tuning...\n')
  logging.info('Starting fine-tuning...')
  stdout_log.flush()

  stdout_log.write('{} points picked via {}\n'.format(
      len(finetune_images), FLAGS.uncertain))
  logging.info('%d points picked via %s', len(finetune_images),
               FLAGS.uncertain)
  assert FLAGS.n_extra == len(finetune_images)

  for epoch in range(1, FLAGS.epochs + 1):
    # BEGIN: finetune model with extra data, evaluate and save
    num_extra = len(finetune_images)
    num_complete_batches, leftover = divmod(num_extra, FLAGS.batch_size)
    num_batches = num_complete_batches + bool(leftover)

    finetune = data.DataChunk(X=finetune_images,
                              Y=finetune_labels,
                              image_size=28,
                              image_channels=1,
                              label_dim=1,
                              label_format='numeric')
    batches = data.minibatcher(finetune, FLAGS.batch_size,
                               transform=augmentation)

    itercount = itertools.count()
    key = random.PRNGKey(FLAGS.seed)

    start_time = time.time()
    for _ in range(num_batches):
      # tmp_time = time.time()
      b = next(batches)
      if FLAGS.dpsgd:
        opt_state = private_update(
            key, next(itercount), opt_state,
            shape_as_image(b.X, b.Y, dummy_dim=True))
      else:
        opt_state = update(key, next(itercount), opt_state,
                           shape_as_image(b.X, b.Y))
      # stdout_log.write('single update in {:.2f} sec\n'.format(
      #     time.time() - tmp_time))
    epoch_time = time.time() - start_time
    stdout_log.write('Epoch {} in {:.2f} sec\n'.format(epoch, epoch_time))
    logging.info('Epoch %d in %.2f sec', epoch, epoch_time)

    # accuracy on test data
    params = get_params(opt_state)
    test_pred_0 = test_pred
    test_acc, test_pred = accuracy(params,
                                   shape_as_image(test_images, test_labels),
                                   return_predicted_class=True)
    test_loss = loss(params, shape_as_image(test_images, test_labels))
    stdout_log.write(
        'Eval set loss, accuracy (%): ({:.2f}, {:.2f})\n'.format(
            test_loss, 100 * test_acc))
    logging.info('Eval set loss, accuracy: (%.2f, %.2f)', test_loss,
                 100 * test_acc)
    stdout_log.flush()

    # visualize prediction difference between 2 checkpoints.
    if FLAGS.visualize:
      utils.visualize_ckpt_difference(test_images,
                                      np.argmax(test_labels, axis=1),
                                      test_pred_0, test_pred,
                                      epoch - 1, epoch, FLAGS.work_dir,
                                      mu=128., sigma=128.)
    # END: finetune model with extra data, evaluate and save

  stdout_log.close()
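
# NOTE: `shape_as_image` is used throughout the fine-tuning code above but
# defined elsewhere. A minimal sketch is given below, assuming 28x28x1 inputs
# (matching the DataChunk image_size/image_channels above); `dummy_dim=True`
# is assumed to add the leading per-example axis consumed by the vmapped
# per-example clipping in `private_grad`. This illustrates the assumed shape
# convention only, not the project's actual helper.
def shape_as_image_sketch(images, labels, dummy_dim=False):
  target_shape = (-1, 1, 28, 28, 1) if dummy_dim else (-1, 28, 28, 1)
  return np.reshape(images, target_shape), labels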