def dot_product_attention(query, key, value, mask, dropout, mode, rng):
  """Scaled dot-product self-attention.

  Args:
    query: array of representations.
    key: array of representations.
    value: array of representations.
    mask: attention mask gating which positions may attend; None disables it.
    dropout: float keep-probability for attention-weight dropout; None
      disables dropout.
    mode: 'eval' or 'train'; dropout is only applied in 'train' mode.
    rng: JAX PRNGKey, consumed only on the dropout path.

  Returns:
    The attention output for the given q, k, v arrays.
  """
  feature_depth = np.shape(query)[-1]
  scores = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(feature_depth)
  if mask is not None:
    # Masked-out positions get a large negative score so softmax zeroes them.
    scores = np.where(mask, scores, -1e9)
  weights = stax.softmax(scores, axis=-1)
  if dropout is not None and mode == 'train':
    # Inverted dropout: survivors are rescaled by 1/keep_prob.
    survivors = random.bernoulli(rng, dropout, weights.shape)
    weights = np.where(survivors, weights / dropout, 0)
  return np.matmul(weights, value)
def objective(x, params):
  """Expected return of the softmax policy induced by logits `x`.

  Solves the policy-evaluation linear system directly:
  v = (I - gamma * P_pi)^{-1} r_pi, then takes the initial-state expectation.
  """
  del params  # Unused; kept for the optimizer's callback signature.
  pi = softmax((1. / temperature) * x)
  transition_under_pi = np.einsum('ast,sa->st', true_transition, pi)
  reward_under_pi = np.einsum('sa,sa->s', true_reward, pi)
  state_values = np.linalg.solve(
      np.eye(true_transition.shape[-1]) - true_discount * transition_under_pi,
      reward_under_pi)
  return initial_distribution @ state_values
def discriminator_loss(discriminator_logits, traj_model, traj_expert):
  """GAIL-style discriminator loss averaged over a trajectory.

  Sums log D(s, a) over model transitions and log(1 - D(s, a)) over expert
  transitions, normalized by the (module-level) trajectory length.
  """
  discriminator = softmax((1. / temperature) * discriminator_logits)
  total = 0
  for step in range(traj_len):
    s_model, a_model = traj_model[step]
    s_expert, a_expert = traj_expert[step]
    total += jnp.log(discriminator[s_model][a_model])
    total += jnp.log(1 - discriminator[s_expert][a_expert])
  return total / traj_len
def value_estimation_first_visit(discriminator_logits, model_logits):
  """Monte-Carlo estimate of the values of states 0 and 1.

  Averages discounted returns over 10 rollouts per start state, using the
  discriminator as the reward source.
  """
  discriminator = softmax((1. / temperature) * discriminator_logits)
  policy = softmax((1. / temperature) * model_logits)
  samples = {0: [], 1: []}
  for _ in range(10):
    # Keep the sampling order (state 0 then state 1) identical per iteration.
    for start_state in (0, 1):
      rewards, _ = sample_rewards(policy,
                                  discriminator,
                                  initialization=True,
                                  initial_state=start_state)
      samples[start_state].append(discounted_reward(rewards, 0))
  return jnp.array([jnp.array(samples[0]).mean(),
                    jnp.array(samples[1]).mean()])
def objective(x, params):
  """Expected discounted log-likelihood of the expert under the learned model.

  Builds a log-likelihood "cumulent" from the softmax policy and the true
  transitions, evaluates it under the expert policy, and returns its
  initial-state expectation. Prints the policy as a debugging side effect.
  """
  del params  # Unused; kept for the optimizer's callback signature.
  policy = softmax((1. / temperature) * x)  # [2, 2]
  log_joint = np.log(np.einsum('sa,ast->sat', policy, true_transition))
  cumulent = np.einsum('sat,ast->sa', log_joint, true_transition)
  likelihood = policy_evaluation(true_transition, cumulent, true_discount,
                                 policy_expert)
  print("policy", policy)
  return initial_distribution @ likelihood
def f(decision_variables):
  """Smooth Bellman optimality operator evaluated at the current model.

  `decision_variables` is (x, theta) where theta holds transition logits and
  the reward estimate; the logits are mapped through a tempered softmax.
  """
  x, theta = decision_variables
  transition_logits, reward_hat = theta
  transition_hat = softmax(
      (1. / arguments.temperature_transition) * transition_logits)
  operator_params = (transition_hat, reward_hat, true_discount,
                     arguments.temperature)
  return smooth_bellman_optimality_operator(x, operator_params)
def random_categorical(logits, num_samples, seed):
  """Returns a sample from a categorical distribution. `logits` must be 2D.

  Draws by inverting the CDF: uniform variates are located in the cumulative
  softmax via a vmapped searchsorted.
  """
  cdf = np.cumsum(stax.softmax(logits), axis=-1)
  draws = random.uniform(seed, (num_samples,) + cdf.shape[:-1])
  tiled_cdf = np.broadcast_to(cdf, (num_samples,) + cdf.shape)
  flat_cdf = tiled_cdf.reshape([-1, tiled_cdf.shape[-1]])
  flat_draws = draws.reshape([-1])
  samples = jax.vmap(_searchsorted)(flat_cdf, flat_draws)
  return samples.reshape(draws.shape).T
def value_approximiation(discriminator_logits, model_logits, batch=5,
                         threshold=1e-2):
  """Iteratively fits a 2-state value estimate from sampled rollouts.

  Runs up to 30 Adam steps that pull the current value estimate toward fresh
  Monte-Carlo return samples, stopping early once the estimate stops moving.

  Bug fixes vs. the previous version:
  - the update loop read `get_params(opt_state3)` (an unrelated outer-scope
    getter) instead of `get_params3`, so the optimizer state being advanced
    was never actually read back;
  - convergence used `abs(jnp.max(value - prev))`, which can report
    convergence when one component moved a lot in the negative direction;
    it now uses the max absolute componentwise change.

  Args:
    discriminator_logits: logits defining the reward-producing discriminator.
    batch: unused; kept for interface compatibility.
    threshold: max absolute componentwise change that counts as converged.

  Returns:
    jnp array [V(0), V(1)].
  """
  discriminator = softmax((1. / temperature) * discriminator_logits)
  policy_model = softmax((1. / temperature) * model_logits)
  v0_rewards, _ = sample_rewards(policy_model, 0, discriminator)
  v1_rewards, _ = sample_rewards(policy_model, 1, discriminator)
  value = jnp.array(
      [discounted_reward(v0_rewards, 0), discounted_reward(v1_rewards, 0)])
  opt_init3, opt_update3, get_params3 = optimizers.adam(step_size=0.01)
  opt_state3 = opt_init3(value)
  prev = value
  for i in range(30):
    rewards, _ = sample_rewards(policy_model, 0, discriminator)
    v0 = discounted_reward(rewards, 0)
    rewards, _ = sample_rewards(policy_model, 1, discriminator)
    v1 = discounted_reward(rewards, 0)
    print("check", v0, v1)
    grads = jax.grad(value_loss, 0)(value, jnp.array([v0, v1]))
    opt_state3 = opt_update3(i, grads, opt_state3)
    value = get_params3(opt_state3)  # bug fix: was the outer `get_params`
    print("value: ", value.flatten(), "prev: ", prev.flatten())
    # bug fix: compare max |delta| componentwise, not |max(delta)|.
    if i > 0 and jnp.max(jnp.abs(value - prev)) <= threshold:
      print(value - prev)
      print("converged in ", i)
      return value
    prev = value  # jnp arrays are immutable; no deepcopy needed.
  print("************not converged")
  return value
def discriminator_loss(discriminator_logits, states_model, actions_model, states_expert, actions_expert, traj, discriminator_temperature=1e-2): traj_len = len(traj) states_model = states_model[:traj_len] actions_model = actions_model[:traj_len] states_expert = states_expert[:traj_len] actions_expert = actions_expert[:traj_len] discriminator = softmax((1. / discriminator_temperature) * discriminator_logits) negative = jnp.log(discriminator[(states_model, actions_model)] + 1e-10).sum() positive = jnp.log(1 - discriminator[(states_expert, actions_expert)] + 1e-10).sum() # gradient gradient_penalty return (negative + positive) / traj_len
def save_solution(params, arguments, prefix='solution'):
  """Pickles the recovered (transition, reward) model plus run arguments.

  The transition logits are converted to probabilities with a tempered
  softmax and both arrays are materialized as plain numpy before pickling.
  Output file: '<prefix>-<timestamp>.pkl' in the working directory.
  """
  transition_logits, reward_hat = params
  transition_hat = softmax(
      (1. / arguments.temperature_transition) * transition_logits)
  payload = {
      'solution': (onp.asarray(transition_hat), onp.asarray(reward_hat)),
      'args': arguments,
  }
  stamp = time.strftime("%Y%m%d-%H%M%S")
  with open(f"{prefix}-{stamp}.pkl", 'wb') as file:
    pickle.dump(payload, file)
def get_policy_grad_naive(discriminator_logits, model_logits, traj_model):
  """REINFORCE gradient using discriminator log-scores as per-step rewards.

  Builds log D(s, a) along the model trajectory, turns each suffix into a
  discounted return (gamma=0.9), and hands everything to `reinforce`.
  """
  discriminator = softmax((1. / temperature) * discriminator_logits)
  log_scores = []
  for step in range(traj_len):
    s_model, a_model = traj_model[step]
    log_scores.append(jnp.log(discriminator[s_model][a_model]))
  returns = [discounted_reward(log_scores, t, gamma=0.9)
             for t in range(traj_len)]
  return reinforce(discriminator_logits, model_logits, returns, traj_model)
def value_estimation(model_logits, key, traj_len, discriminator,
                     policy_temperature=1e-1):
  """Monte-Carlo values of states 0 and 1 under the model's softmax policy.

  Samples 5 half-length trajectories per start state, scores them with
  log D (plus 1e-8 for numerical safety), and averages discounted returns.
  The threaded PRNG `key` is returned alongside the estimates.
  """
  policy = softmax((1. / policy_temperature) * model_logits)
  half_horizon = int(traj_len / 2)
  placeholder = jnp.ones((half_horizon))
  samples = ([], [])
  for _ in range(5):
    # Keep the per-iteration sampling order (state 0 then state 1) so the
    # PRNG key is consumed identically.
    for start_state in (0, 1):
      states, actions, key = sample_trajectory(policy, key, placeholder,
                                               init_state=True,
                                               my_s=start_state)
      rewards = jnp.log(discriminator[(states, actions)] + 1e-8)
      samples[start_state].append(discounted_reward(rewards, 0, traj_len))
  v0 = jnp.array(samples[0]).mean()
  v1 = jnp.array(samples[1]).mean()
  return jnp.array([v0, v1]), key
def get_policy_grad_gae(discriminator_logits, model_logits, traj_model):
  """REINFORCE gradient with GAE advantages from discriminator rewards.

  Rewards are log D(s, a) along the model trajectory; baseline values come
  from `value_estimation_first_visit`, and per-step advantages from `gae`.
  """
  discriminator = softmax((1. / temperature) * discriminator_logits)
  rewards = []
  for step in range(traj_len):
    s_model, a_model = traj_model[step]
    rewards.append(jnp.log(discriminator[s_model][a_model]))
  values = value_estimation_first_visit(discriminator_logits, model_logits)
  advantages = [gae(values, rewards, traj_model, t) for t in range(traj_len)]
  return reinforce(discriminator_logits, model_logits, advantages, traj_model)
def cumulative_hazard(self, params, t):
  """Cumulative hazard at time `t` of a 3-component mixture survival model.

  The mixture combines a Weibull component, a log-logistic component, and a
  constant (cure-fraction-like) component. `params` packs, in order: three
  weight logits, log Weibull (scale, shape), log log-logistic (alpha, beta).
  Computed as -log of the mixture survival via logsumexp for stability.
  """
  # Mixture weights, clipped away from zero so their logs stay finite.
  weight1, weight2, weight3 = np.maximum(stax.softmax(params[:3]), 1e-25)
  # Weibull parameters (stored as logs).
  weibull_scale, weibull_shape = np.exp(params[3]), np.exp(params[4])
  # Log-logistic parameters (stored as logs).
  alpha, beta = np.exp(params[5]), np.exp(params[6])
  log_survival_terms = np.hstack((
      np.log(weight1) - (t / weibull_scale)**weibull_shape,
      np.log(weight2) - sp.special.logsumexp(
          np.hstack((0, beta * np.log(t) - beta * np.log(alpha)))),
      np.log(weight3),
  ))
  return -sp.special.logsumexp(log_survival_terms)
def update_fun(step, grads, state):
  """Apply one step of the learned (gradient-history attention) optimizer."""
  del step  # Unused.
  params, grad_seq = state
  # Append the newest gradient to the rolling history.
  grad_seq = append_to_sequence(grad_seq, grads)
  # Gram matrix of the history, normalized by gradient norms (eps-stabilized).
  gram = innerprod(grad_seq, grad_seq)
  norm_vec = norms(grad_seq)
  gram = gram / (jnp.outer(norm_vec, norm_vec) + 1e-6)
  # Update = attention-weighted mix of past gradients + plain gradient term.
  attention = jnp.dot(stax.softmax(gram, axis=0), theta_gram)
  attention_term = jnp.tensordot(attention, grad_seq, axes=1)
  gradient_term = jnp.tensordot(theta_grad, grad_seq, axes=1)
  new_params = params - (gradient_term + attention_term)
  return (new_params, grad_seq)
def main(_):
  """Runs the DP fine-tuning experiment configured entirely through FLAGS.

  Pipeline visible below: load an SGD checkpoint; embed the train split and
  fit (DP-)PCA; pick the most uncertain train and pool points; vote each
  uncertain train point to its nearest pool candidate (optionally with
  Laplace noise on the counts); then fine-tune on the selected pool points
  (optionally with DP-SGD) while logging and saving the best test accuracy.
  """
  sns.set()
  sns.set_palette(sns.color_palette('hls', 10))
  npr.seed(FLAGS.seed)
  logging.info('Starting experiment.')
  # Create model folder for outputs
  try:
    gfile.MakeDirs(FLAGS.work_dir)
  except gfile.GOSError:
    pass
  stdout_log = gfile.Open('{}/stdout.log'.format(FLAGS.work_dir), 'w+')
  # BEGIN: set up optimizer and load params
  _, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)
  ckpt_dir = '{}/{}'.format(FLAGS.root_dir, FLAGS.ckpt_idx)
  with gfile.Open(ckpt_dir, 'rb') as fckpt:
    opt_state = optimizers.pack_optimizer_state(pickle.load(fckpt))
  params = get_params(opt_state)
  stdout_log.write('finetune from: {}\n'.format(ckpt_dir))
  logging.info('finetune from: %s', ckpt_dir)
  # END: set up optimizer and load params
  # BEGIN: train data emb, PCA, uncertain
  train_images, train_labels, _ = datasets.get_dataset_split(
      name=FLAGS.train_split.split('-')[0],
      split=FLAGS.train_split.split('-')[1],
      shuffle=False)
  n_train = len(train_images)
  # use mean/std of svhn train
  train_mu, train_std = 128., 128.
  train_images = (train_images - train_mu) / train_std
  # embeddings of all the training points
  train_embeddings = [apply_fn_0(params[:-1],
                                 train_images[b_i:b_i + FLAGS.batch_size])
                      for b_i in range(0, n_train, FLAGS.batch_size)]
  train_embeddings = np.concatenate(train_embeddings, axis=0)
  # fit PCA onto embeddings of all the training points
  if FLAGS.dppca_eps is not None:
    # Differentially private PCA; e_vals are reused later as distance weights.
    pc_cols, e_vals = dp_pca.dp_pca(train_embeddings,
                                    train_embeddings.shape[1],
                                    epsilon=FLAGS.dppca_eps,
                                    delta=1e-5,
                                    sigma=None)
  else:
    big_pca = sklearn.decomposition.PCA(
        n_components=train_embeddings.shape[1])
    big_pca.random_state = FLAGS.seed
    big_pca.fit(train_embeddings)
  # filter out uncertain train points
  n_uncertain = FLAGS.n_extra + FLAGS.uncertain_extra
  train_probs = stax.softmax(apply_fn_1(params[-1:], train_embeddings))
  train_acc = np.mean(
      np.argmax(train_probs, axis=1) == np.argmax(train_labels, axis=1))
  logging.info('initial train acc: %.2f', train_acc)
  if FLAGS.uncertain == 0 or FLAGS.uncertain == 'entropy':
    # entropy: most uncertain = highest predictive entropy
    train_entropy = -onp.sum(train_probs * onp.log(train_probs), axis=1)
    train_uncertain_indices = \
        onp.argsort(train_entropy)[::-1][:n_uncertain]
  elif FLAGS.uncertain == 1 or FLAGS.uncertain == 'difference':
    # 1st_prob - 2nd_prob: most uncertain = smallest top-2 margin
    assert len(train_probs.shape) == 2
    sorted_train_probs = onp.sort(train_probs, axis=1)
    train_probs_diff = sorted_train_probs[:, -1] - sorted_train_probs[:, -2]
    assert min(train_probs_diff) > 0.
    train_uncertain_indices = onp.argsort(train_probs_diff)[:n_uncertain]
  if FLAGS.dppca_eps is not None:
    train_uncertain_projected_embeddings, _ = utils.project_embeddings(
        train_embeddings[train_uncertain_indices],
        pca_object=None,
        n_components=FLAGS.k_components,
        pc_cols=pc_cols)
  else:
    train_uncertain_projected_embeddings, _ = utils.project_embeddings(
        train_embeddings[train_uncertain_indices], big_pca,
        FLAGS.k_components)
  logging.info('projected embeddings of uncertain train data')
  del train_images, train_labels, train_embeddings
  # END: train data emb, PCA, uncertain
  # BEGIN: pool data emb
  pool_images, pool_labels, _ = datasets.get_dataset_split(
      name=FLAGS.pool_split.split('-')[0],
      split=FLAGS.pool_split.split('-')[1],
      shuffle=False)
  n_pool = len(pool_images)
  pool_images = (pool_images - train_mu) / train_std  # normalize w train mu/std
  pool_embeddings = [apply_fn_0(params[:-1],
                                pool_images[b_i:b_i + FLAGS.batch_size])
                     for b_i in range(0, n_pool, FLAGS.batch_size)]
  pool_embeddings = np.concatenate(pool_embeddings, axis=0)
  # filter out uncertain pool points
  pool_probs = stax.softmax(apply_fn_1(params[-1:], pool_embeddings))
  if FLAGS.uncertain == 0 or FLAGS.uncertain == 'entropy':
    # entropy
    pool_entropy = -onp.sum(pool_probs * onp.log(pool_probs), axis=1)
    pool_uncertain_indices = onp.argsort(pool_entropy)[::-1][:n_uncertain]
  elif FLAGS.uncertain == 1 or FLAGS.uncertain == 'difference':
    # 1st_prob - 2nd_prob
    assert len(pool_probs.shape) == 2
    sorted_pool_probs = onp.sort(pool_probs, axis=1)
    pool_probs_diff = sorted_pool_probs[:, -1] - sorted_pool_probs[:, -2]
    assert min(pool_probs_diff) > 0.
    pool_uncertain_indices = onp.argsort(pool_probs_diff)[:n_uncertain]
  # constrain pool candidates to ONLY uncertain ones
  pool_images = pool_images[pool_uncertain_indices]
  pool_labels = pool_labels[pool_uncertain_indices]
  pool_embeddings = pool_embeddings[pool_uncertain_indices]
  n_pool = len(pool_uncertain_indices)
  if FLAGS.dppca_eps is not None:
    pool_projected_embeddings, _ = utils.project_embeddings(
        pool_embeddings,
        pca_object=None,
        n_components=FLAGS.k_components,
        pc_cols=pc_cols)
  else:
    pool_projected_embeddings, _ = utils.project_embeddings(
        pool_embeddings, big_pca, FLAGS.k_components)
  del pool_embeddings
  logging.info('projected embeddings of pool data')
  # END: pool data emb
  # BEGIN: assign train_uncertain_projected_embeddings to ONLY 1 point/cluster
  # assign uncertain train to closest pool point/cluster
  pool_index_histogram = onp.zeros(n_pool)
  for i in range(len(train_uncertain_projected_embeddings)):
    # t0 = time.time()
    train_uncertain_point = \
        train_uncertain_projected_embeddings[i].reshape(1, -1)
    if FLAGS.distance == 0 or FLAGS.distance == 'euclidean':
      cluster_distances = euclidean_distances(
          pool_projected_embeddings, train_uncertain_point).reshape(-1)
    elif FLAGS.distance == 1 or FLAGS.distance == 'weighted_euclidean':
      # Components are weighted by (DP-)PCA spectrum when available.
      weights = e_vals[:FLAGS.k_components] if FLAGS.dppca_eps is not None \
          else big_pca.singular_values_[:FLAGS.k_components]
      cluster_distances = weighted_euclidean_distances(
          pool_projected_embeddings, train_uncertain_point, weights)
    pool_index = onp.argmin(cluster_distances)
    pool_index_histogram[pool_index] += 1.
    # t1 = time.time()
    # logging.info('%d uncertain train, %s second', i, str(t1 - t0))
    del cluster_distances
  # add Laplacian noise onto #neighors
  # NOTE(review): scale = extra_eps - dppca_eps crashes if dppca_eps is None
  # while extra_eps is set — presumably both flags are set together; confirm.
  if FLAGS.extra_eps is not None:
    pool_index_histogram += npr.laplace(
        scale=FLAGS.extra_eps - FLAGS.dppca_eps,
        size=pool_index_histogram.shape)
  pool_picked_indices = onp.argsort(
      pool_index_histogram)[::-1][:FLAGS.n_extra]
  logging.info('%d extra pool data picked', len(pool_picked_indices))
  # END: assign train_uncertain_projected_embeddings to ONLY 1 cluster
  # load test data
  test_images, test_labels, _ = datasets.get_dataset_split(
      name=FLAGS.test_split.split('-')[0],
      split=FLAGS.test_split.split('-')[1],
      shuffle=False)
  test_images = (test_images - train_mu) / train_std  # normalize w train mu/std
  # augmentation for train/pool data
  if FLAGS.augment_data:
    augmentation = data.chain_transforms(data.RandomHorizontalFlip(0.5),
                                         data.RandomCrop(4), data.ToDevice)
  else:
    augmentation = None
  test_acc, test_pred = accuracy(params,
                                 shape_as_image(test_images, test_labels),
                                 return_predicted_class=True)
  logging.info('test accuracy: %.2f', test_acc)
  stdout_log.write('test accuracy: {}\n'.format(test_acc))
  stdout_log.flush()
  worst_test_acc, best_test_acc, best_epoch = test_acc, test_acc, 0
  # BEGIN: setup for dp model
  @jit
  def update(_, i, opt_state, batch):
    # Plain (non-private) SGD step.
    params = get_params(opt_state)
    return opt_update(i, grad_loss(params, batch), opt_state)

  @jit
  def private_update(rng, i, opt_state, batch):
    # DP-SGD step: per-example clipping + noise inside private_grad.
    params = get_params(opt_state)
    rng = random.fold_in(rng, i)  # get new key for new random numbers
    return opt_update(
        i,
        private_grad(params, batch, rng, FLAGS.l2_norm_clip,
                     FLAGS.noise_multiplier, FLAGS.batch_size), opt_state)
  # END: setup for dp model
  finetune_images = copy.deepcopy(pool_images[pool_picked_indices])
  finetune_labels = copy.deepcopy(pool_labels[pool_picked_indices])
  logging.info('Starting fine-tuning...')
  stdout_log.write('Starting fine-tuning...\n')
  stdout_log.flush()
  # BEGIN: gather points to be used for finetuning
  stdout_log.write('{} points picked via {}\n'.format(
      len(finetune_images), FLAGS.uncertain))
  logging.info('%d points picked via %s', len(finetune_images),
               FLAGS.uncertain)
  assert FLAGS.n_extra == len(finetune_images)
  # END: gather points to be used for finetuning
  for epoch in range(1, FLAGS.epochs + 1):
    # BEGIN: finetune model with extra data, evaluate and save
    num_extra = len(finetune_images)
    num_complete_batches, leftover = divmod(num_extra, FLAGS.batch_size)
    num_batches = num_complete_batches + bool(leftover)
    finetune = data.DataChunk(X=finetune_images,
                              Y=finetune_labels,
                              image_size=28,
                              image_channels=1,
                              label_dim=1,
                              label_format='numeric')
    batches = data.minibatcher(finetune, FLAGS.batch_size,
                               transform=augmentation)
    itercount = itertools.count()
    key = random.PRNGKey(FLAGS.seed)
    start_time = time.time()
    for _ in range(num_batches):
      # tmp_time = time.time()
      b = next(batches)
      if FLAGS.dpsgd:
        opt_state = private_update(
            key, next(itercount), opt_state,
            shape_as_image(b.X, b.Y, dummy_dim=True))
      else:
        opt_state = update(key, next(itercount), opt_state,
                           shape_as_image(b.X, b.Y))
      # stdout_log.write('single update in {:.2f} sec\n'.format(
      #     time.time() - tmp_time))
    epoch_time = time.time() - start_time
    stdout_log.write('Epoch {} in {:.2f} sec\n'.format(epoch, epoch_time))
    logging.info('Epoch %d in %.2f sec', epoch, epoch_time)
    # accuracy on test data
    params = get_params(opt_state)
    test_pred_0 = test_pred
    test_acc, test_pred = accuracy(params,
                                   shape_as_image(test_images, test_labels),
                                   return_predicted_class=True)
    test_loss = loss(params, shape_as_image(test_images, test_labels))
    stdout_log.write(
        'Eval set loss, accuracy (%): ({:.2f}, {:.2f})\n'.format(
            test_loss, 100 * test_acc))
    logging.info('Eval set loss, accuracy: (%.2f, %.2f)', test_loss,
                 100 * test_acc)
    stdout_log.flush()
    # visualize prediction difference between 2 checkpoints.
    if FLAGS.visualize:
      utils.visualize_ckpt_difference(test_images,
                                      np.argmax(test_labels, axis=1),
                                      test_pred_0,
                                      test_pred,
                                      epoch - 1,
                                      epoch,
                                      FLAGS.work_dir,
                                      mu=train_mu,
                                      sigma=train_std)
    worst_test_acc = min(test_acc, worst_test_acc)
    if test_acc > best_test_acc:
      best_test_acc, best_epoch = test_acc, epoch
      # save opt_state
      # NOTE(review): checkpoint save placed inside the new-best branch —
      # confirm this matches the original indentation.
      with gfile.Open('{}/acc_ckpt'.format(FLAGS.work_dir), 'wb') as fckpt:
        pickle.dump(optimizers.unpack_optimizer_state(opt_state), fckpt)
    # END: finetune model with extra data, evaluate and save
  stdout_log.write('best test acc {} @E {}\n'.format(best_test_acc,
                                                     best_epoch))
  stdout_log.close()
def _rvs(self, n, p):
  """Draw multinomial samples, mapping logits to probabilities when needed."""
  probs = softmax(p) if self.is_logits else p
  return multinomial_rvs(self._random_state, probs, n, self._size)
def compute_entropy(log_probs): """Compute entropy of a set of log_probs.""" return -jnp.mean(jnp.mean(stax.softmax(log_probs) * log_probs, axis=-1))
def omd_objective(decision_variables):
  """OMD objective: negated expected return of the induced softmax policy."""
  logits, _ = decision_variables
  pi = softmax((1. / arguments.temperature) * logits)
  return -expected_return(mdp, initial_distribution, pi)
def get_log_policy(logits, state, action, policy_temperature=1e-1):
  """Log-probability of `action` in `state` under the tempered softmax policy.

  The 1e-8 floor keeps the log finite for near-zero probabilities.
  """
  probs = softmax((1. / policy_temperature) * logits)
  return jnp.log(probs[state][action] + 1e-8)
def equality_constraints(x, params):
  """Fixed-point residual of the smooth Bellman optimality operator.

  Zero exactly when `x` is a fixed point of the operator under the model
  implied by the tempered-softmax transitions and `reward_hat`.
  """
  transition_logits, reward_hat = params
  transition_hat = softmax((1. / temperature) * transition_logits)
  operator_args = (transition_hat, reward_hat, true_discount, temperature)
  return smooth_bellman_optimality_operator(x, operator_args) - x
model_logits = get_params(opt_state) return opt_state, model_logits # initialization discriminator_logits = jnp.ones((2, 2)) model_logits = jnp.ones((2, 2)) opt_init, opt_update, get_params = optimizers.adam(step_size=0.001) opt_state = opt_init(discriminator_logits) opt_init2, opt_update2, get_params2 = optimizers.adam(step_size=0.001) opt_state2 = opt_init2(model_logits) for i in range(50): policy_model = softmax((1. / temperature) * model_logits) traj_model = sample_trajectory(policy_model) traj_expert = sample_trajectory(policy_expert) opt_state, discriminator_logits = update_discriminator( i, discriminator_logits, traj_model, traj_expert, opt_state, opt_update, get_params) opt_state2, model_logits = update_policy(i, discriminator_logits, model_logits, traj_model, opt_state2, opt_update2, get_params2) print("discriminator_logits: \n", discriminator_logits) print("model_logits: \n", model_logits) print("policy: \n", softmax((1. / temperature) * model_logits)) print("")
def get_log_policy(model_logits, s, a):
  """Log-probability of action `a` in state `s` under the model's policy."""
  probs = softmax((1. / temperature) * model_logits)
  return jnp.log(probs[s][a])
def main(_):
  """Runs the clustering-based DP fine-tuning experiment driven by FLAGS.

  Pipeline visible below: load a checkpoint, embed the candidate pool, pick
  the most uncertain pool points, cluster their PCA-projected embeddings,
  take the points closest to each cluster center (topping up randomly to
  n_extra), then fine-tune (optionally with DP-SGD) on the selection.
  """
  sns.set()
  sns.set_palette(sns.color_palette('hls', 10))
  npr.seed(FLAGS.seed)
  logging.info('Starting experiment.')
  # Create model folder for outputs
  try:
    gfile.MakeDirs(FLAGS.work_dir)
  except gfile.GOSError:
    pass
  stdout_log = gfile.Open('{}/stdout.log'.format(FLAGS.work_dir), 'w+')
  # BEGIN: fetch test data and candidate pool
  test_images, test_labels, _ = datasets.get_dataset_split(
      name=FLAGS.test_split.split('-')[0],
      split=FLAGS.test_split.split('-')[1],
      shuffle=False)
  pool_images, pool_labels, _ = datasets.get_dataset_split(
      name=FLAGS.pool_split.split('-')[0],
      split=FLAGS.pool_split.split('-')[1],
      shuffle=False)
  n_pool = len(pool_images)
  # normalize to range [-1.0, 127./128]
  test_images = test_images / np.float32(128.0) - np.float32(1.0)
  pool_images = pool_images / np.float32(128.0) - np.float32(1.0)
  # augmentation for train/pool data
  if FLAGS.augment_data:
    augmentation = data.chain_transforms(data.RandomHorizontalFlip(0.5),
                                         data.RandomCrop(4), data.ToDevice)
  else:
    augmentation = None
  # END: fetch test data and candidate pool
  _, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)
  # BEGIN: load ckpt
  ckpt_dir = '{}/{}'.format(FLAGS.root_dir, FLAGS.ckpt_idx)
  with gfile.Open(ckpt_dir, 'rb') as fckpt:
    opt_state = optimizers.pack_optimizer_state(pickle.load(fckpt))
  params = get_params(opt_state)
  stdout_log.write('finetune from: {}\n'.format(ckpt_dir))
  logging.info('finetune from: %s', ckpt_dir)
  test_acc, test_pred = accuracy(params,
                                 shape_as_image(test_images, test_labels),
                                 return_predicted_class=True)
  logging.info('test accuracy: %.2f', test_acc)
  stdout_log.write('test accuracy: {}\n'.format(test_acc))
  stdout_log.flush()
  # END: load ckpt
  # BEGIN: setup for dp model
  @jit
  def update(_, i, opt_state, batch):
    # Plain (non-private) SGD step.
    params = get_params(opt_state)
    return opt_update(i, grad_loss(params, batch), opt_state)

  @jit
  def private_update(rng, i, opt_state, batch):
    # DP-SGD step: per-example clipping + noise inside private_grad.
    params = get_params(opt_state)
    rng = random.fold_in(rng, i)  # get new key for new random numbers
    return opt_update(
        i,
        private_grad(params, batch, rng, FLAGS.l2_norm_clip,
                     FLAGS.noise_multiplier, FLAGS.batch_size), opt_state)
  # END: setup for dp model
  n_uncertain = FLAGS.n_extra + FLAGS.uncertain_extra
  ### BEGIN: prepare extra points picked from pool data
  # BEGIN: on pool data
  pool_embeddings = [apply_fn_0(params[:-1],
                                pool_images[b_i:b_i + FLAGS.batch_size])
                     for b_i in range(0, n_pool, FLAGS.batch_size)]
  pool_embeddings = np.concatenate(pool_embeddings, axis=0)
  pool_logits = apply_fn_1(params[-1:], pool_embeddings)
  pool_true_labels = np.argmax(pool_labels, axis=1)
  pool_predicted_labels = np.argmax(pool_logits, axis=1)
  pool_correct_indices = \
      onp.where(pool_true_labels == pool_predicted_labels)[0]
  pool_incorrect_indices = \
      onp.where(pool_true_labels != pool_predicted_labels)[0]
  assert len(pool_correct_indices) + \
      len(pool_incorrect_indices) == len(pool_labels)
  pool_probs = stax.softmax(pool_logits)
  if FLAGS.uncertain == 0 or FLAGS.uncertain == 'entropy':
    # entropy: most uncertain = highest predictive entropy
    pool_entropy = -onp.sum(pool_probs * onp.log(pool_probs), axis=1)
    stdout_log.write('all {} entropy: min {}, max {}\n'.format(
        len(pool_entropy), onp.min(pool_entropy), onp.max(pool_entropy)))
    pool_entropy_sorted_indices = onp.argsort(pool_entropy)
    # take the n_uncertain most uncertain points
    pool_uncertain_indices = \
        pool_entropy_sorted_indices[::-1][:n_uncertain]
    stdout_log.write('uncertain {} entropy: min {}, max {}\n'.format(
        len(pool_entropy[pool_uncertain_indices]),
        onp.min(pool_entropy[pool_uncertain_indices]),
        onp.max(pool_entropy[pool_uncertain_indices])))
  elif FLAGS.uncertain == 1 or FLAGS.uncertain == 'difference':
    # 1st_prob - 2nd_prob: most uncertain = smallest top-2 margin
    assert len(pool_probs.shape) == 2
    sorted_pool_probs = onp.sort(pool_probs, axis=1)
    pool_probs_diff = sorted_pool_probs[:, -1] - sorted_pool_probs[:, -2]
    assert min(pool_probs_diff) > 0.
    pool_uncertain_indices = onp.argsort(pool_probs_diff)[:n_uncertain]
  # END: on pool data
  # BEGIN: cluster uncertain pool points
  big_pca = sklearn.decomposition.PCA(
      n_components=pool_embeddings.shape[1])
  big_pca.random_state = FLAGS.seed
  # fit PCA onto embeddings of all the pool points
  big_pca.fit(pool_embeddings)
  # For uncertain points, project embeddings onto the first K components
  pool_uncertain_projected_embeddings, _ = utils.project_embeddings(
      pool_embeddings[pool_uncertain_indices], big_pca, FLAGS.k_components)
  n_cluster = int(FLAGS.n_extra / FLAGS.ppc)
  cluster_method = get_cluster_method('{}_nc-{}'.format(
      FLAGS.clustering, n_cluster))
  cluster_method.random_state = FLAGS.seed
  pool_uncertain_cluster_labels = cluster_method.fit_predict(
      pool_uncertain_projected_embeddings)
  pool_uncertain_cluster_label_indices = {
      x: [] for x in set(pool_uncertain_cluster_labels)
  }
  # local i within n_uncertain
  for i, c_label in enumerate(pool_uncertain_cluster_labels):
    pool_uncertain_cluster_label_indices[c_label].append(i)
  # find center of each cluster
  # aka, the most representative point of each 'tough' cluster
  pool_picked_indices = []
  pool_uncertain_cluster_label_pick = {}
  for c_label, indices in pool_uncertain_cluster_label_indices.items():
    cluster_projected_embeddings = \
        pool_uncertain_projected_embeddings[indices]
    cluster_center = onp.mean(cluster_projected_embeddings,
                              axis=0,
                              keepdims=True)
    if FLAGS.distance == 0 or FLAGS.distance == 'euclidean':
      cluster_distances = euclidean_distances(
          cluster_projected_embeddings, cluster_center).reshape(-1)
    elif FLAGS.distance == 1 or FLAGS.distance == 'weighted_euclidean':
      # Components are weighted by the PCA singular-value spectrum.
      cluster_distances = weighted_euclidean_distances(
          cluster_projected_embeddings, cluster_center,
          big_pca.singular_values_[:FLAGS.k_components])
    sorted_is = onp.argsort(cluster_distances)
    sorted_indices = onp.array(indices)[sorted_is]
    pool_uncertain_cluster_label_indices[c_label] = sorted_indices
    center_i = sorted_indices[0]
    # center_i in 3000
    pool_uncertain_cluster_label_pick[c_label] = center_i
    pool_picked_indices.extend(
        pool_uncertain_indices[sorted_indices[:FLAGS.ppc]])
    # BEGIN: visualize cluster of picked uncertain pool
    if FLAGS.visualize:
      this_cluster = []
      for i in sorted_indices:
        idx = pool_uncertain_indices[i]
        img = pool_images[idx]
        # Green border = model was right on this point, red = wrong.
        if idx in pool_correct_indices:
          border_color = 'green'
        else:
          border_color = 'red'
        img = utils.mark_labels(img, pool_predicted_labels[idx],
                                pool_true_labels[idx])
        img = utils.denormalize(img, 128., 128.)
        img = np.squeeze(utils.to_rgb(np.expand_dims(img, 0)))
        img = utils.add_border(img, width=2, color=border_color)
        this_cluster.append(img)
      utils.tile_image_list(
          this_cluster, '{}/picked_uncertain_pool_cid-{}'.format(
              FLAGS.work_dir, c_label))
    # END: visualize cluster of picked uncertain pool
  # END: cluster uncertain pool points
  pool_picked_indices = list(set(pool_picked_indices))
  # Top up with random uncertain points if the clusters overlapped.
  n_gap = FLAGS.n_extra - len(pool_picked_indices)
  gap_indices = list(set(pool_uncertain_indices) - set(pool_picked_indices))
  pool_picked_indices.extend(npr.choice(gap_indices, n_gap, replace=False))
  stdout_log.write('n_gap: {}\n'.format(n_gap))
  ### END: prepare extra points picked from pool data
  finetune_images = copy.deepcopy(pool_images[pool_picked_indices])
  finetune_labels = copy.deepcopy(pool_labels[pool_picked_indices])
  stdout_log.write('{} points picked via {}\n'.format(
      len(finetune_images), FLAGS.uncertain))
  logging.info('%d points picked via %s', len(finetune_images),
               FLAGS.uncertain)
  assert FLAGS.n_extra == len(finetune_images)
  # END: gather points to be used for finetuning
  stdout_log.write('Starting fine-tuning...\n')
  logging.info('Starting fine-tuning...')
  stdout_log.flush()
  for epoch in range(1, FLAGS.epochs + 1):
    # BEGIN: finetune model with extra data, evaluate and save
    num_extra = len(finetune_images)
    num_complete_batches, leftover = divmod(num_extra, FLAGS.batch_size)
    num_batches = num_complete_batches + bool(leftover)
    finetune = data.DataChunk(X=finetune_images,
                              Y=finetune_labels,
                              image_size=28,
                              image_channels=1,
                              label_dim=1,
                              label_format='numeric')
    batches = data.minibatcher(finetune, FLAGS.batch_size,
                               transform=augmentation)
    itercount = itertools.count()
    key = random.PRNGKey(FLAGS.seed)
    start_time = time.time()
    for _ in range(num_batches):
      # tmp_time = time.time()
      b = next(batches)
      if FLAGS.dpsgd:
        opt_state = private_update(
            key, next(itercount), opt_state,
            shape_as_image(b.X, b.Y, dummy_dim=True))
      else:
        opt_state = update(key, next(itercount), opt_state,
                           shape_as_image(b.X, b.Y))
      # stdout_log.write('single update in {:.2f} sec\n'.format(
      #     time.time() - tmp_time))
    epoch_time = time.time() - start_time
    stdout_log.write('Epoch {} in {:.2f} sec\n'.format(epoch, epoch_time))
    logging.info('Epoch %d in %.2f sec', epoch, epoch_time)
    # accuracy on test data
    params = get_params(opt_state)
    test_pred_0 = test_pred
    test_acc, test_pred = accuracy(params,
                                   shape_as_image(test_images, test_labels),
                                   return_predicted_class=True)
    test_loss = loss(params, shape_as_image(test_images, test_labels))
    stdout_log.write(
        'Eval set loss, accuracy (%): ({:.2f}, {:.2f})\n'.format(
            test_loss, 100 * test_acc))
    logging.info('Eval set loss, accuracy: (%.2f, %.2f)', test_loss,
                 100 * test_acc)
    stdout_log.flush()
    # visualize prediction difference between 2 checkpoints.
    if FLAGS.visualize:
      utils.visualize_ckpt_difference(test_images,
                                      np.argmax(test_labels, axis=1),
                                      test_pred_0, test_pred, epoch - 1,
                                      epoch, FLAGS.work_dir,
                                      mu=128., sigma=128.)
    # END: finetune model with extra data, evaluate and save
  stdout_log.close()
win_acc_keys = ['acc_train', 'acc_test', 'acc_train_lin', 'acc_test_lin'] win_acc_eval_keys = [ 'acc_train_eval', 'acc_test_eval', 'acc_train_eval_lin', 'acc_test_eval_lin' ] every_iter_plot_keys = every_iter_plot_keys + win_acc_keys + win_acc_eval_keys log = Log(keys=['update'] + every_iter_plot_keys + win_rank_eval_keys + win_spectrum_eval_keys + win_spectrum_slope_keys + win_quadratic_keys + win_cosine_keys) outer_state = outer_opt_init(params) outer_state_lin = outer_opt_init(params) rmse = jit(lambda fx, fx_lin: np.sqrt(np.mean(fx - fx_lin)**2)) tvd = jit( lambda fx, fx_lin: 0.5 * np.sum(np.abs(softmax(fx) - softmax(fx_lin)))) plotter = VisdomPlotter(viz) if args.dataset == 'sinusoid': task_fn = partial(sinusoid_task, n_support=args.n_support, n_query=args.n_query, noise_std=args.noise_std) elif args.dataset == 'omniglot': task_fn = partial(omniglot_task, split_dict=omniglot_splits['train'], n_way=args.n_way, n_support=args.n_support, n_query=args.n_query) elif args.dataset == 'circle':
def sep_objective(decision_variables):
  """Negated expected discounted log-likelihood of the expert policy."""
  logits, _ = decision_variables
  pi = softmax((1. / arguments.temperature) * logits)
  return -expected_discounted_loglikelihood(mdp, initial_distribution, pi,
                                            expert_policy)
def main(_):
  """Finetune a checkpointed model on uncertain pool points, with optional DP-SGD.

  Pipeline: compute train-split mean/std for normalization, fetch test and
  candidate-pool splits, restore an optimizer checkpoint (optionally stacking
  fixed pretrained layers under privately-trained last layers), rank pool
  points by an uncertainty criterion (entropy / top-2 probability gap /
  random), then finetune on the selected points for FLAGS.epochs epochs,
  evaluating on the test set after each epoch and optionally visualizing
  prediction differences between consecutive epochs.
  """
  sns.set()
  sns.set_palette(sns.color_palette('hls', 10))
  npr.seed(FLAGS.seed)
  logging.info('Starting experiment.')

  # Create model folder for outputs; ignore "already exists" style errors.
  try:
    gfile.MakeDirs(FLAGS.work_dir)
  except gfile.GOSError:
    pass
  stdout_log = gfile.Open('{}/stdout.log'.format(FLAGS.work_dir), 'w+')

  # use mean/std of svhn train
  train_images, _, _ = datasets.get_dataset_split(
      name=FLAGS.train_split.split('-')[0],
      split=FLAGS.train_split.split('-')[1],
      shuffle=False)
  train_mu, train_std = onp.mean(train_images), onp.std(train_images)
  del train_images  # only the statistics are needed; free the memory

  # BEGIN: fetch test data and candidate pool
  test_images, test_labels, _ = datasets.get_dataset_split(
      name=FLAGS.test_split.split('-')[0],
      split=FLAGS.test_split.split('-')[1],
      shuffle=False)
  pool_images, pool_labels, _ = datasets.get_dataset_split(
      name=FLAGS.pool_split.split('-')[0],
      split=FLAGS.pool_split.split('-')[1],
      shuffle=False)
  n_pool = len(pool_images)
  test_images = (test_images - train_mu) / train_std  # normalize w train mu/std
  pool_images = (pool_images - train_mu) / train_std  # normalize w train mu/std

  # augmentation for train/pool data
  if FLAGS.augment_data:
    augmentation = data.chain_transforms(data.RandomHorizontalFlip(0.5),
                                         data.RandomCrop(4), data.ToDevice)
  else:
    augmentation = None
  # END: fetch test data and candidate pool

  # BEGIN: load ckpt
  opt_init, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)
  if FLAGS.pretrained_dir is not None:
    with gfile.Open(FLAGS.pretrained_dir, 'rb') as fpre:
      pretrained_opt_state = optimizers.pack_optimizer_state(
          pickle.load(fpre))
    # First 7 parameter groups come from the pretrained model and stay fixed.
    fixed_params = get_params(pretrained_opt_state)[:7]
    ckpt_dir = '{}/{}'.format(FLAGS.root_dir, FLAGS.ckpt_idx)
    # NOTE(review): mode 'wr' looks odd for reading a pickle — the pretrained
    # checkpoint above uses 'rb'; confirm gfile.Open semantics before changing.
    with gfile.Open(ckpt_dir, 'wr') as fckpt:
      opt_state = optimizers.pack_optimizer_state(pickle.load(fckpt))
    params = get_params(opt_state)
    # combine fixed pretrained params and dpsgd trained last layers
    params = fixed_params + params
    opt_state = opt_init(params)
  else:
    ckpt_dir = '{}/{}'.format(FLAGS.root_dir, FLAGS.ckpt_idx)
    with gfile.Open(ckpt_dir, 'wr') as fckpt:
      opt_state = optimizers.pack_optimizer_state(pickle.load(fckpt))
    params = get_params(opt_state)

  stdout_log.write('finetune from: {}\n'.format(ckpt_dir))
  logging.info('finetune from: %s', ckpt_dir)
  test_acc, test_pred = accuracy(params,
                                 shape_as_image(test_images, test_labels),
                                 return_predicted_class=True)
  logging.info('test accuracy: %.2f', test_acc)
  stdout_log.write('test accuracy: {}\n'.format(test_acc))
  stdout_log.flush()
  # END: load ckpt

  # BEGIN: setup for dp model
  @jit
  def update(_, i, opt_state, batch):
    # Plain (non-private) SGD step.
    params = get_params(opt_state)
    return opt_update(i, grad_loss(params, batch), opt_state)

  @jit
  def private_update(rng, i, opt_state, batch):
    # DP-SGD step: clipped per-example gradients plus calibrated noise.
    params = get_params(opt_state)
    rng = random.fold_in(rng, i)  # get new key for new random numbers
    return opt_update(
        i,
        private_grad(params, batch, rng, FLAGS.l2_norm_clip,
                     FLAGS.noise_multiplier, FLAGS.batch_size), opt_state)
  # END: setup for dp model

  ### BEGIN: prepare extra points picked from pool data
  # BEGIN: on pool data
  # Embed the pool batch-by-batch with all layers but the last, then apply
  # the final layer to obtain logits.
  pool_embeddings = [apply_fn_0(params[:-1],
                                pool_images[b_i:b_i + FLAGS.batch_size])
                     for b_i in range(0, n_pool, FLAGS.batch_size)]
  pool_embeddings = np.concatenate(pool_embeddings, axis=0)
  pool_logits = apply_fn_1(params[-1:], pool_embeddings)

  pool_true_labels = np.argmax(pool_labels, axis=1)
  pool_predicted_labels = np.argmax(pool_logits, axis=1)
  pool_correct_indices = \
      onp.where(pool_true_labels == pool_predicted_labels)[0]
  pool_incorrect_indices = \
      onp.where(pool_true_labels != pool_predicted_labels)[0]
  # Sanity check: correct/incorrect partition the whole pool.
  assert len(pool_correct_indices) + \
      len(pool_incorrect_indices) == len(pool_labels)

  pool_probs = stax.softmax(pool_logits)

  # Select FLAGS.n_extra pool points by the configured uncertainty measure.
  if FLAGS.uncertain == 0 or FLAGS.uncertain == 'entropy':
    pool_entropy = -onp.sum(pool_probs * onp.log(pool_probs), axis=1)
    stdout_log.write('all {} entropy: min {}, max {}\n'.format(
        len(pool_entropy), onp.min(pool_entropy), onp.max(pool_entropy)))

    pool_entropy_sorted_indices = onp.argsort(pool_entropy)
    # take the n_extra most uncertain points
    pool_uncertain_indices = \
        pool_entropy_sorted_indices[::-1][:FLAGS.n_extra]
    stdout_log.write('uncertain {} entropy: min {}, max {}\n'.format(
        len(pool_entropy[pool_uncertain_indices]),
        onp.min(pool_entropy[pool_uncertain_indices]),
        onp.max(pool_entropy[pool_uncertain_indices])))

  elif FLAGS.uncertain == 1 or FLAGS.uncertain == 'difference':
    # 1st_prob - 2nd_prob
    assert len(pool_probs.shape) == 2
    sorted_pool_probs = onp.sort(pool_probs, axis=1)
    pool_probs_diff = sorted_pool_probs[:, -1] - sorted_pool_probs[:, -2]
    assert min(pool_probs_diff) > 0.
    stdout_log.write('all {} difference: min {}, max {}\n'.format(
        len(pool_probs_diff), onp.min(pool_probs_diff),
        onp.max(pool_probs_diff)))

    # Smallest top-2 probability gap == most uncertain.
    pool_uncertain_indices = onp.argsort(pool_probs_diff)[:FLAGS.n_extra]
    stdout_log.write('uncertain {} difference: min {}, max {}\n'.format(
        len(pool_probs_diff[pool_uncertain_indices]),
        onp.min(pool_probs_diff[pool_uncertain_indices]),
        onp.max(pool_probs_diff[pool_uncertain_indices])))

  elif FLAGS.uncertain == 2 or FLAGS.uncertain == 'random':
    pool_uncertain_indices = npr.permutation(n_pool)[:FLAGS.n_extra]
  # END: on pool data
  ### END: prepare extra points picked from pool data

  finetune_images = copy.deepcopy(pool_images[pool_uncertain_indices])
  finetune_labels = copy.deepcopy(pool_labels[pool_uncertain_indices])

  stdout_log.write('Starting fine-tuning...\n')
  logging.info('Starting fine-tuning...')
  stdout_log.flush()
  stdout_log.write('{} points picked via {}\n'.format(
      len(finetune_images), FLAGS.uncertain))
  logging.info('%d points picked via %s', len(finetune_images),
               FLAGS.uncertain)
  assert FLAGS.n_extra == len(finetune_images)

  for epoch in range(1, FLAGS.epochs + 1):
    # BEGIN: finetune model with extra data, evaluate and save
    num_extra = len(finetune_images)
    num_complete_batches, leftover = divmod(num_extra, FLAGS.batch_size)
    num_batches = num_complete_batches + bool(leftover)

    finetune = data.DataChunk(X=finetune_images,
                              Y=finetune_labels,
                              image_size=32,
                              image_channels=3,
                              label_dim=1,
                              label_format='numeric')
    batches = data.minibatcher(finetune, FLAGS.batch_size,
                               transform=augmentation)

    itercount = itertools.count()
    key = random.PRNGKey(FLAGS.seed)

    start_time = time.time()
    for _ in range(num_batches):
      # tmp_time = time.time()
      b = next(batches)
      if FLAGS.dpsgd:
        opt_state = private_update(
            key, next(itercount), opt_state,
            shape_as_image(b.X, b.Y, dummy_dim=True))
      else:
        opt_state = update(key, next(itercount), opt_state,
                           shape_as_image(b.X, b.Y))
      # stdout_log.write('single update in {:.2f} sec\n'.format(
      #     time.time() - tmp_time))
    epoch_time = time.time() - start_time
    stdout_log.write('Epoch {} in {:.2f} sec\n'.format(epoch, epoch_time))
    logging.info('Epoch %d in %.2f sec', epoch, epoch_time)

    # accuracy on test data
    params = get_params(opt_state)
    test_pred_0 = test_pred  # previous epoch's predictions, for diffing below
    test_acc, test_pred = accuracy(params,
                                   shape_as_image(test_images, test_labels),
                                   return_predicted_class=True)
    test_loss = loss(params, shape_as_image(test_images, test_labels))
    stdout_log.write(
        'Eval set loss, accuracy (%): ({:.2f}, {:.2f})\n'.format(
            test_loss, 100 * test_acc))
    logging.info('Eval set loss, accuracy: (%.2f, %.2f)', test_loss,
                 100 * test_acc)
    stdout_log.flush()

    # visualize prediction difference between 2 checkpoints.
    if FLAGS.visualize:
      utils.visualize_ckpt_difference(test_images,
                                      np.argmax(test_labels, axis=1),
                                      test_pred_0,
                                      test_pred,
                                      epoch - 1,
                                      epoch,
                                      FLAGS.work_dir,
                                      mu=train_mu,
                                      sigma=train_std)
    # END: finetune model with extra data, evaluate and save

  stdout_log.close()
def _planner(params):
    """Boltzmann policy from smooth value iteration under a learned model.

    Args:
        params: pair (transition_hat, reward_hat) — the model estimate.

    Returns:
        Softmax policy over the converged action values, scaled by
        1/temperature.
    """
    transition_est, reward_est = params
    initial_q = np.zeros_like(reward_est)
    result = smooth_value_iteration(initial_q, (transition_est, reward_est))
    return softmax((1. / temperature) * result.value)
def _rvs(self, p):
    """Draw categorical samples from probabilities or logits.

    Args:
        p: class probabilities, or logits when ``self.is_logits`` is set
           (converted to probabilities via softmax before sampling).

    Returns:
        Categorical samples produced by ``categorical_rvs`` using this
        distribution's random state and size.
    """
    probs = softmax(p) if self.is_logits else p
    return categorical_rvs(self._random_state, probs, self._size)
def _objective(params):
    """Negative expected return of the policy planned under the model.

    Args:
        params: pair (transition_logits, reward_hat); the transition logits
            are mapped to a stochastic transition model by a
            temperature-scaled softmax.

    Returns:
        Scalar: minus the expected return, suitable for minimization.
    """
    transition_logits, reward_est = params
    transition_est = softmax((1. / temperature_logits) * transition_logits)
    policy = planner((transition_est, reward_est))
    return -expected_return(mdp, initial_distribution, policy)