def evaluate_example(self, example: SingleExample, prediction: SinglePrediction) -> MeanStat:
    """Computes token accuracy for a single sequence example.

    The predicted class is the argmax over the last axis of the prediction,
    after ``self.logits_mask`` (when set) has been added to it. Each position
    is weighted by ``get_target_weight(target, self.masked_target_values)``.

    Args:
      example: One example; the target is read from
        ``example[self.target_key]`` and holds class ids in
        [0, num_classes). NOTE(review): documented as shape [1], but
        ``per_position`` suggests per-token targets are also expected; confirm.
      prediction: Unnormalized prediction for ``example`` of shape
        [num_classes], or a mapping holding it under ``self.pred_key`` when
        that attribute is not None.

    Returns:
      MeanStat for token accuracy for a single sequence example, or at each
      token position if ``per_position`` is ``True``.
    """
    target = example[self.target_key]
    pred = prediction if self.pred_key is None else prediction[
        self.pred_key]
    if self.logits_mask is not None:
        logits_mask = jnp.array(self.logits_mask)
        # NOTE(review): if ``pred`` is a mutable NumPy array, this ``+=`` edits
        # the caller's prediction in place; with a JAX array it only rebinds
        # the local name.
        pred += logits_mask
    target_weight = get_target_weight(target, self.masked_target_values)
    # 1.0 where the argmax class equals the target, else 0.0.
    correct = (target == jnp.argmax(pred, axis=-1)).astype(jnp.float32)
    if self.per_position:
        # NOTE(review): both return statements are garbled in this copy. The
        # callee and first operand are missing, presumably
        # ``MeanStat.new(correct * target_weight, target_weight)`` and its
        # ``jnp.sum`` counterpart. Restore from upstream before use.
        return * target_weight, target_weight)
    return * target_weight), jnp.sum(target_weight))
def test_argmax_consistent(self):
    """``argmax(key, x[, axis])`` must agree with ``jnp.argmax`` on random data.

    Covers a vector (scalar result), a matrix and a rank-3 tensor, using both
    the default (last) axis and explicit leading axes.
    """
    keys = PRNGSequence(13)
    vector = jax.random.normal(next(keys), shape=(5,))
    matrix = jax.random.normal(next(keys), shape=(3, 5))
    tensor = jax.random.normal(next(keys), shape=(3, 5, 7))

    # A 1-D input reduces to a scalar, so plain equality is enough.
    self.assertEqual(
        argmax(next(keys), vector), jnp.argmax(vector, axis=-1))

    # (array, axis) pairs. None means "call argmax without an axis argument
    # and compare against the last axis".
    cases = (
        (matrix, None),
        (matrix, 0),
        (tensor, None),
        (tensor, 0),
        (tensor, 1),
    )
    for array, axis in cases:
        if axis is None:
            self.assertArrayAlmostEqual(
                argmax(next(keys), array), jnp.argmax(array, axis=-1))
        else:
            self.assertArrayAlmostEqual(
                argmax(next(keys), array, axis=axis),
                jnp.argmax(array, axis=axis))
def extract_best_search_space(scores_dict, centrality_key, search_spaces_reduced):
    """Pick the search space with the highest score under ``centrality_key``.

    Args:
      scores_dict: mapping from a centrality name to an array of scores,
        presumably one per candidate search space (``argmax`` here flattens).
      centrality_key: the key of ``scores_dict`` whose scores are maximised.
      search_spaces_reduced: array indexed by candidate along its first axis;
        it must have at least three dimensions.

    Returns:
      ``search_spaces_reduced[i, :, :]`` for the first index ``i`` that attains
      the maximum score.
    """
    centrality_scores = scores_dict[centrality_key]
    winner = jnp.argmax(centrality_scores)
    return search_spaces_reduced[winner, :, :]
mdbM = moldb.MdbExomol('.database/H2O/1H2-16O/POKAZATEL', nus, crit=1.e-45) # loading molecular dat molmassM = molinfo.molmass('H2O') # molecular mass (H2O) q = mdbM.qr_interp(1500.0) S = SijT(1500.0, mdbM.logsij0, mdbM.nu_lines, mdbM.elower, q) mask = S > 1.e-25 mdbM.masking(mask) Tarr = jnp.logspace(jnp.log10(800), jnp.log10(1600), 100) qt = vmap(mdbM.qr_interp)(Tarr) SijM = jit(vmap(SijT, (0, None, None, None, 0)))(Tarr, mdbM.logsij0, mdbM.nu_lines, mdbM.elower, qt) imax = jnp.argmax(SijM, axis=0) Tmax = Tarr[imax] print(jnp.min(Tmax)) pl = planck.piBarr(jnp.array([1100.0, 1000.0]), nus) print(pl[1] / pl[0]) pl = planck.piBarr(jnp.array([1400.0, 1200.0]), nus) print(pl[1] / pl[0]) lsa = ['solid', 'dashed', 'dotted', 'dashdot'] lab = ['A', 'B', 'C'] fac = 1.e22 fig = plt.figure(figsize=(12, 6)) # for j,i in enumerate(range(len(mdbM.A))): # for j,i in enumerate([56,72,141,147,173,236,259,211,290]):
def _per_batch(inputs, labels):
    """Fraction of examples whose argmax prediction matches the one-hot label.

    Uses ``predict`` and ``params`` from the enclosing scope.
    """
    true_ids = jnp.argmax(labels, axis=1)
    logits = predict(params, inputs)
    hits = jnp.argmax(logits, axis=1) == true_ids
    return jnp.mean(hits)
def accuracy(params, mask):
    """Masked accuracy of ``model`` on ``graph`` against one-hot ``labels``.

    ``model``, ``graph`` and ``labels`` come from the enclosing scope; ``mask``
    weights each node's correctness before averaging.
    """
    logits = model.apply(params, graph, train=False)
    is_correct = jnp.argmax(logits, axis=-1) == jnp.argmax(labels, axis=-1)
    weighted_hits = jnp.sum(is_correct * mask)
    return weighted_hits / jnp.sum(mask)
def ground_truth_label(sentences, lengths):
    """Label each sentence with the index of its best ``batch_score`` column."""
    per_class_scores = batch_score(sentences, lengths)
    best_class = jnp.argmax(per_class_scores, axis=1)
    return best_class
def accuracy(params, batch, model_predict):
    """Calculate accuracy.

    ``batch`` is an ``(inputs, targets)`` pair where ``targets`` holds integer
    class ids. Predictions are the argmax over axis 1 of
    ``model_predict(params, inputs)``.
    """
    inputs, labels = batch
    scores = model_predict(params, inputs)
    hits = np.argmax(scores, axis=1) == labels
    return np.mean(hits)
def loss(
    params: networks_lib.Params,
    target_params: networks_lib.Params,
    key_grad: networks_lib.PRNGKey,
    sample: reverb.ReplaySample
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Computes mean transformed N-step loss for a batch of sequences.

    Uses double Q-learning: actions are selected with the online network and
    evaluated with the target network. Configuration (``use_core_state``,
    ``initial_state``, ``burn_in_length``, ``unroll``, ``discount``,
    ``clip_rewards``, ``max_abs_reward``, ``bootstrap_n``, ``tx_pair``,
    ``importance_sampling_exponent``, ``max_priority_weight``) is read from the
    enclosing scope.

    NOTE(review): three expressions below are garbled in this copy (marked);
    the function does not parse as written.

    Args:
      params: online network parameters.
      target_params: target network parameters.
      key_grad: PRNG key, split for the burn-in and training unrolls.
      sample: replay sample; its data is converted to sequence-major
        [T, B, ...] layout.

    Returns:
      ``(mean_loss, priorities)``: the importance-weighted mean over the batch
      of the per-sequence loss (summed over time), and per-sequence priorities
      mixing the max and mean absolute TD error over time.
    """
    # Convert sample data to sequence-major format [T, B, ...].
    # NOTE(review): garbled. The argument and closing parenthesis of this call
    # are missing (presumably ``sample.data)``).
    data = utils.batch_to_sequence(
    # Get core state & warm it up on observations for a burn-in period.
    if use_core_state:
        # Replay core state.
        online_state = jax.tree_map(lambda x: x[0], data.extras['core_state'])
    else:
        online_state = initial_state
    target_state = online_state

    # Maybe burn the core state in.
    if burn_in_length:
        burn_obs = jax.tree_map(lambda x: x[:burn_in_length], data.observation)
        key_grad, key1, key2 = jax.random.split(key_grad, 3)
        _, online_state = unroll.apply(params, key1, burn_obs, online_state)
        _, target_state = unroll.apply(target_params, key2, burn_obs, target_state)

    # Only get data to learn on from after the end of the burn in period.
    # (Slicing with burn_in_length == 0 leaves the data unchanged.)
    data = jax.tree_map(lambda seq: seq[burn_in_length:], data)

    # Unroll on sequences to get online and target Q-Values.
    key1, key2 = jax.random.split(key_grad)
    online_q, _ = unroll.apply(params, key1, data.observation, online_state)
    target_q, _ = unroll.apply(target_params, key2, data.observation, target_state)

    # Get value-selector actions from online Q-values for double Q-learning.
    selector_actions = jnp.argmax(online_q, axis=-1)
    # Preprocess discounts & rewards.
    # NOTE(review): garbled. The left operand is missing (presumably
    # ``data.discount``).
    discounts = ( * discount).astype(online_q.dtype)
    rewards = data.reward
    if clip_rewards:
        rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward)
    rewards = rewards.astype(online_q.dtype)

    # Get N-step transformed TD error and loss.
    # vmap over axis 1 (the batch axis of the [T, B, ...] layout).
    batch_td_error_fn = jax.vmap(functools.partial(
        rlax.transformed_n_step_q_learning, n=bootstrap_n, tx_pair=tx_pair),
        in_axes=1, out_axes=1)
    # TODO(b/183945808): when this bug is fixed, truncations of actions,
    # rewards, and discounts will no longer be necessary.
    batch_td_error = batch_td_error_fn(online_q[:-1], data.action[:-1], target_q[1:], selector_actions[1:], rewards[:-1], discounts[:-1])
    # Per-sequence loss: squared TD error summed over time.
    batch_loss = 0.5 * jnp.square(batch_td_error).sum(axis=0)

    # Importance weighting.
    # NOTE(review): garbled. The right-hand side of ``probs =`` is missing
    # (presumably ``sample.info.probability``). As written this parses as a
    # chained assignment that reads ``probs`` before it is bound.
    probs = importance_weights = (1. / (probs + 1e-6)).astype(online_q.dtype)
    importance_weights **= importance_sampling_exponent
    # Normalise so the largest weight is 1.
    importance_weights /= jnp.max(importance_weights)
    mean_loss = jnp.mean(importance_weights * batch_loss)

    # Calculate priorities as a mixture of max and mean sequence errors.
    abs_td_error = jnp.abs(batch_td_error).astype(online_q.dtype)
    max_priority = max_priority_weight * jnp.max(abs_td_error, axis=0)
    mean_priority = (1 - max_priority_weight) * jnp.mean(abs_td_error, axis=0)
    priorities = (max_priority + mean_priority)

    return mean_loss, priorities
def hmm_viterbi_log(params, obs_seq, length=None):
    '''
    Computes the most likely sequence of hidden states (the Viterbi path) of a
    Hidden Markov Model given a sequence of observations, working in log space.

    Parameters
    ----------
    params : HMM
        Hidden Markov Model. ``params.trans_dist`` and ``params.init_dist`` must
        expose ``logits``; ``params.obs_dist.log_prob(o)`` must return the
        per-state log-likelihood of observation ``o``.

    obs_seq: array(seq_len)
        History of observed states. Entries at or beyond ``length`` are padding
        and never influence the result.

    length: int, optional
        Number of valid observations at the start of ``obs_seq``.
        Defaults to ``len(obs_seq)``.

    Returns
    -------
    * array(seq_len)
        Most likely hidden state at each time step. Positions at or beyond
        ``length`` are -1.
    '''
    seq_len = len(obs_seq)

    if length is None:
        length = seq_len

    trans_dist, obs_dist, init_dist = params.trans_dist, params.obs_dist, params.init_dist

    trans_log_probs = log_softmax(trans_dist.logits)
    init_log_probs = log_softmax(init_dist.logits)

    first_log_prob = init_log_probs + obs_dist.log_prob(obs_seq[0])

    if seq_len == 1:
        return jnp.expand_dims(jnp.argmax(first_log_prob), axis=0)

    def viterbi_forward(prev_logp, t):
        # Only steps t < length hold real observations. The previous code used
        # ``t <= length``, which let the padded observation obs_seq[length]
        # shape the decoded path. Padded steps keep the previous scores and
        # record -1 back-pointers.
        valid = t < length
        obs_logp = obs_dist.log_prob(obs_seq[t])
        # logp[i, j]: best score of a path in state i at t-1 that moves to j at t.
        logp = prev_logp[..., None] + trans_log_probs + obs_logp[..., None, :]
        max_logp = jnp.where(valid, jnp.max(logp, axis=-2), prev_logp)
        best_prev = jnp.where(valid, jnp.argmax(logp, axis=-2), -1)
        return max_logp, best_prev

    ts = jnp.arange(1, seq_len)
    final_log_prob, backpointers = lax.scan(viterbi_forward, first_log_prob, ts)
    # backpointers[t - 1, j] is the best predecessor of state j at time t.

    def viterbi_backward(state, t):
        # On entry, `state` is the decoded state at time t for t < length, and
        # the last valid state (time length - 1) for t >= length. Emits the
        # state at time t - 1, or -1 when t - 1 is padding.
        prev_state = jnp.where(t < length, backpointers[t - 1][state], state)
        most_likely = jnp.where(t <= length, prev_state, -1)
        return prev_state, most_likely

    final_state = jnp.argmax(final_log_prob)
    _, most_likely_path = lax.scan(viterbi_backward, final_state, ts, reverse=True)
    # The last time step belongs to the path only when the sequence is unpadded.
    final_state = jnp.where(length == seq_len, final_state, -1)
    return jnp.append(most_likely_path, final_state)
def categorical_logits(key, logits, shape=()):
    """Sample category indices from unnormalised ``logits`` via Gumbel-max.

    Args:
      key: PRNG key.
      logits: array whose last axis enumerates the categories.
      shape: batch shape of the draw. When empty/falsy, it defaults to
        ``logits.shape[:-1]``.

    Returns:
      Integer array of shape ``shape`` (or ``logits.shape[:-1]``).
    """
    batch_shape = shape if shape else logits.shape[:-1]
    noise = random.gumbel(key, batch_shape + logits.shape[-1:], logits.dtype)
    perturbed = noise + logits
    return np.argmax(perturbed, axis=-1)
def loss_fn(
    model,
    padded_example_and_rng,
    static_metadata,
    regularization_weights = None,
    reinforce_weight = 1.0,
    baseline_weight = 0.001,
):
    """Loss function for multi-pointer task.

    The base loss is the negative joint log-probability of the labelled
    (bug location, repair) pair. Two optional terms are added:
    - scalar side outputs matching ``regularization_weights``, each weighted;
    - a baseline-corrected score-function term, when the model emitted
      ``one_sample_*`` side outputs.

    NOTE(review): this copy is garbled in several places (marked below). The
    function does not parse as written.

    Args:
      model: The model to evaluate.
      padded_example_and_rng: Padded example to evaluate on, with a PRNGKey.
      static_metadata: Padding configuration for the example, since this may
        vary for different examples.
      regularization_weights: Associates side output key regexes with
        regularization penalties.
      reinforce_weight: Weight to give to the reinforce term.
      baseline_weight: Weight to give to the baseline.

    Returns:
      Tuple of loss and metrics. Array leaves of the metrics dict are cast to
      float32 before returning.

    Raises:
      ValueError: if a regularization query matches no collected side output.
    """
    padded_example, rng = padded_example_and_rng
    # Run the model.
    with side_outputs.collect_side_outputs() as collected_side_outputs:
        with flax.nn.stochastic(rng):
            joint_log_probs = model(padded_example, static_metadata)

    # Computing the loss:
    # Extract logits for the correct location.
    log_probs_at_bug = joint_log_probs[padded_example.bug_node_index, :]
    # Compute p(repair) = sum[ p(node) p(repair | node) ]
    # -> log p(repair) = logsumexp[ log p(node) + log p (repair | node) ]
    # Adding log(mask) sends entries whose mask is 0 to -inf.
    log_prob_joint = jax.scipy.special.logsumexp(
        log_probs_at_bug + jnp.log(padded_example.repair_node_mask))

    # Metrics:
    # Marginal log probabilities:
    log_prob_bug = jax.scipy.special.logsumexp(log_probs_at_bug)
    log_prob_repair = jax.scipy.special.logsumexp(
        jax.scipy.special.logsumexp(joint_log_probs, axis=0) +
        jnp.log(padded_example.repair_node_mask))

    # Conditional log probabilities:
    log_prob_repair_given_bug = log_prob_joint - log_prob_bug
    log_prob_bug_given_repair = log_prob_joint - log_prob_repair

    # Majority accuracy (1 if we assign the correct tuple > 50%):
    # (note that this is easier to compute, since we can't currently aggregate
    # probability separately for each candidate.)
    log_half = jnp.log(0.5)
    majority_acc_joint = log_prob_joint > log_half

    # Probabilities associated with each node.
    node_node_probs = jnp.exp(joint_log_probs)
    # Accumulate across unique candidates by identifier. This has the same shape,
    # but only the first few values will be populated.
    node_candidate_probs = padded_example.unique_candidate_operator.apply_add(
        in_array=node_node_probs,
        out_array=jnp.zeros_like(node_node_probs),
        in_dims=[1],
        out_dims=[1])

    # Classify: 50% decision boundary
    # NOTE(review): garbled. The array being updated is missing (presumably
    # ``node_candidate_probs.at``). The intent is to zero row 0 and column 0.
    # Node 0 denotes "no bug" (see ``actual_nobug`` below).
    only_buggy_probs =[0, :].set(0).at[:, 0].set(0)
    p_buggy = jnp.sum(only_buggy_probs)
    pred_nobug = p_buggy <= 0.5

    # Localize/repair: take most likely bug position, conditioned on being buggy
    pred_bug_loc, pred_cand_id = jnp.unravel_index(
        jnp.argmax(only_buggy_probs), only_buggy_probs.shape)

    actual_nobug = jnp.array(padded_example.bug_node_index == 0)

    actual_bug = jnp.logical_not(actual_nobug)
    pred_bug = jnp.logical_not(pred_nobug)

    metrics = {
        'nll/joint':
            -log_prob_joint,
        'nll/marginal_bug':
            -log_prob_bug,
        'nll/marginal_repair':
            -log_prob_repair,
        'nll/repair_given_bug':
            -log_prob_repair_given_bug,
        'nll/bug_given_repair':
            -log_prob_bug_given_repair,
        'inaccuracy/legacy_overall':
            1 - majority_acc_joint,
        'inaccuracy/overall':
            (~((actual_nobug & pred_nobug) |
               (actual_bug & pred_bug &
                (pred_bug_loc == padded_example.bug_node_index) &
                (pred_cand_id == padded_example.repair_id)))),
        'inaccuracy/classification_overall': (actual_nobug != pred_nobug),
        'inaccuracy/classification_given_nobug':
            train_util.RatioMetric(
                numerator=(actual_nobug & ~pred_nobug), denominator=actual_nobug),
        'inaccuracy/classification_given_bug':
            train_util.RatioMetric(
                numerator=(actual_bug & ~pred_bug), denominator=actual_bug),
        'inaccuracy/localized_given_bug':
            train_util.RatioMetric(
                numerator=(actual_bug
                           & ~(pred_bug_loc == padded_example.bug_node_index)),
                denominator=actual_bug),
        'inaccuracy/repaired_given_bug':
            train_util.RatioMetric(
                numerator=(actual_bug
                           & ~(pred_cand_id == padded_example.repair_id)),
                denominator=actual_bug),
        'inaccuracy/localized_repaired_given_bug':
            train_util.RatioMetric(
                numerator=(actual_bug
                           & ~((pred_bug_loc == padded_example.bug_node_index) &
                               (pred_cand_id == padded_example.repair_id))),
                denominator=actual_bug),
        'inaccuracy/overall_given_bug':
            train_util.RatioMetric(
                numerator=(actual_bug & ~(pred_bug & (
                    pred_bug_loc == padded_example.bug_node_index) &
                                          (pred_cand_id == padded_example.repair_id))),
                denominator=actual_bug),
    }

    loss = -log_prob_joint

    # Report every scalar side output as a metric.
    for k, v in collected_side_outputs.items():
        # Flax collection keys will start with "/".
        if v.shape == ():  # pylint: disable=g-explicit-bool-comparison
            metrics['side' + k] = v

    if regularization_weights:
        total_regularization = 0
        for query, weight in regularization_weights.items():
            # NOTE(review): garbled. A logging call is missing before this
            # string (presumably ``logging.info(``).
            'Regularizing side outputs matching query %s', query)
            found = False
            for k, v in collected_side_outputs.items():
                # NOTE(review): garbled. The regex test is missing (presumably
                # ``re.search(query``, since keys are matched as regexes).
                if, k):
                    found = True
                    # NOTE(review): garbled. Presumably ``logging.info(`` is missing.
                    'Regularizing %s with weight %f', k, weight)
                    total_regularization += weight * v
            if not found:
                raise ValueError(
                    f'Regularization query {query} did not match any side output. '
                    f'Side outputs were {set(collected_side_outputs.keys())}')

        loss = loss + total_regularization

    is_single_sample = any(
        k.endswith('one_sample_log_prob_per_edge_per_node')
        for k in collected_side_outputs)
    if is_single_sample:
        # Exactly one matching side output of each kind is expected; tuple
        # unpacking raises otherwise.
        log_prob, = [
            v for k, v in collected_side_outputs.items()
            if k.endswith('one_sample_log_prob_per_edge_per_node')
        ]
        baseline, = [
            v for k, v in collected_side_outputs.items()
            if k.endswith('one_sample_reward_baseline')
        ]

        # Zero out log-probs of padding nodes before summing.
        num_real_nodes = padded_example.input_graph.bundle.graph_metadata.num_nodes
        valid_mask = (
            jnp.arange(static_metadata.bundle_padding.static_max_metadata.num_nodes)
            < num_real_nodes)
        log_prob = jnp.where(valid_mask[None, :], log_prob, 0)
        total_log_prob = jnp.sum(log_prob)

        # Score-function term: its gradient flows only through total_log_prob.
        reinforce_virtual_cost = (
            total_log_prob * jax.lax.stop_gradient(loss - baseline))
        baseline_penalty = jnp.square(loss - baseline)

        # Subtracting its own stop_gradient makes the term's value 0 while
        # keeping its gradient, so the reported loss is unchanged.
        reinforce_virtual_cost_zeroed = reinforce_virtual_cost - jax.lax.stop_gradient(
            reinforce_virtual_cost)

        loss = (
            loss + reinforce_weight * reinforce_virtual_cost_zeroed +
            baseline_weight * baseline_penalty)
        metrics['reinforce_virtual_cost'] = reinforce_virtual_cost
        metrics['baseline_penalty'] = baseline_penalty
        metrics['baseline'] = baseline
        metrics['total_log_prob'] = total_log_prob

    metrics = jax.tree_map(lambda x: x.astype(jnp.float32), metrics)
    return loss, metrics
bias_coef=args.bias_coef, activation=args.activation, norm=args.norm) _, params = net_init(rng=random.PRNGKey(42), input_shape=(-1, 2)) else: raise ValueError # loss functions if args.dataset == 'sinusoid': loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2) elif args.dataset in ['omniglot', 'circle']: loss = lambda fx, targets: -np.sum(logsoftmax(fx) * targets ) / targets.shape[0] acc = lambda fx, targets: np.mean( np.argmax(logsoftmax(fx), axis=-1) == np.argmax(targets, axis=-1)) param_acc = jit(lambda p, x, y: acc(f(p, x), y)) else: raise ValueError grad_loss = jit(grad(lambda p, x, y: loss(f(p, x), y))) param_loss = jit(lambda p, x, y: loss(f(p, x), y)) # optimizers #TODO: separate optimizers for nonlinear and linear? outer_opt_init, outer_opt_update, outer_get_params = select_opt( args.outer_opt_alg, args.outer_step_size)() inner_opt_init, inner_opt_update, inner_get_params = select_opt( args.inner_opt_alg, args.inner_step_size)() # consistent task for plotting eval if args.dataset == 'sinusoid':
def collect_trajectories(env,
                         policy_net_apply,
                         policy_net_params,
                         num_trajectories=1,
                         policy="greedy",
                         max_timestep=None,
                         epsilon=0.1):
    """Collect trajectories with the given policy net and behaviour.

    Args:
      env: environment with ``reset()`` and ``step(action)``. ``step`` returns
        ``(observation, reward, done, info)``.
      policy_net_apply: callable ``(observation_history, params)`` returning
        per-time-step action log-probabilities shaped (1, t, A).
      policy_net_params: parameters passed to ``policy_net_apply``.
      num_trajectories: number of episodes to roll out.
      policy: one of "greedy", "epsilon-greedy" or "categorical-sampling".
      max_timestep: if set, stop once the observation history holds this many
        time steps, even if the episode is not done.
      epsilon: probability of a uniformly random action under "epsilon-greedy".

    Returns:
      A list of ``(observations, actions, rewards)`` tuples, one per trajectory.
      ``observations`` is (T+1,) + OBS; ``actions`` and ``rewards`` have length T.

    Raises:
      ValueError: if ``policy`` is not recognised.
      TypeError: if the chosen action cannot be converted to an int.
    """
    trajectories = []

    for t in range(num_trajectories):
        t_start = time.time()
        rewards = []
        actions = []
        done = False

        observation = env.reset()

        # This is currently shaped (1, 1) + OBS, but new observations will keep
        # getting added to it, making it eventually (1, T+1) + OBS
        observation_history = observation[np.newaxis, np.newaxis, :]

        # Run either till we're done OR if max_timestep is defined only till that
        # timestep.
        ts = 0
        while ((not done) and
               (not max_timestep or
                observation_history.shape[1] < max_timestep)):
            ts_start = time.time()
            # Run the policy, to pick an action, shape is (1, t, A) because
            # observation_history is shaped (1, t) + OBS
            predictions = policy_net_apply(observation_history, policy_net_params)

            # We need the predictions for the last time-step, so squeeze the batch
            # dimension and take the last time-step.
            predictions = np.squeeze(predictions, axis=0)[-1]

            # Policy can be run in one of the following ways:
            #  - Greedy
            #  - Epsilon-Greedy
            #  - Categorical-Sampling
            action = None
            if policy == "greedy":
                action = np.argmax(predictions)
            elif policy == "epsilon-greedy":
                # epsilon is fixed here; no decay schedule is applied.
                if onp.random.random() < epsilon:
                    # Choose an action at random.
                    action = onp.random.randint(0, high=len(predictions))
                else:
                    # Return the best action.
                    action = np.argmax(predictions)
            elif policy == "categorical-sampling":
                # NOTE: The predictions aren't probabilities but log-probabilities
                # instead, since they were computed with LogSoftmax.
                # So just np.exp them to make them probabilities.
                predictions = np.exp(predictions)
                # multinomial() raises ValueError when pvals sum to more than 1.
                # float32 rounding of exp(log_softmax) can do that, so sample
                # from a float64, explicitly renormalised copy.
                probs = onp.array(predictions, dtype=onp.float64)
                probs /= probs.sum()
                # One draw yields a one-hot count vector; its argmax is the
                # sampled action (a scalar, unlike argwhere's (1, 1) array).
                action = onp.argmax(onp.random.multinomial(1, probs))
            else:
                raise ValueError("Unknown policy: %s" % policy)

            # NOTE: Assumption, single batch.
            try:
                action = int(action)
            except TypeError as err:
                # Let's dump some information before we die off.
                logging.error("Cannot convert action into an integer: [%s]", err)
                logging.error("action.shape: [%s]", action.shape)
                logging.error("action: [%s]", action)
                logging.error("predictions.shape: [%s]", predictions.shape)
                logging.error("predictions: [%s]", predictions)
                logging.error("observation_history: [%s]", observation_history)
                logging.error("policy_net_params: [%s]", policy_net_params)
                log_params(policy_net_params, "policy_net_params")
                raise

            observation, reward, done, _ = env.step(action)

            # observation is of shape OBS, so add extra dims and concatenate on the
            # time dimension.
            observation_history = np.concatenate(
                [observation_history, observation[np.newaxis, np.newaxis, :]], axis=1)

            rewards.append(reward)
            actions.append(action)

            ts += 1
            logging.vlog(
                2, " Collected time-step[ %5d] of trajectory[ %5d] in [%0.2f] msec.",
                ts, t, get_time(ts_start))
        logging.vlog(
            2, " Collected trajectory[ %5d] in [%0.2f] msec.", t, get_time(t_start))

        # Either the episode finished or it was cut off at max_timestep.
        assert done or (
            max_timestep and max_timestep >= observation_history.shape[1])
        # observation_history is (1, T+1) + OBS, lets squeeze out the batch dim.
        observation_history = np.squeeze(observation_history, axis=0)
        trajectories.append(
            (observation_history, np.stack(actions), np.stack(rewards)))

    return trajectories
def accuracy(params: hk.Params, batch: Batch) -> jnp.ndarray:
    """Mean top-1 accuracy of ``net`` on ``batch`` against ``batch["label"]``."""
    logits = net.apply(params, batch)
    predicted = jnp.argmax(logits, axis=-1)
    return jnp.mean(predicted == batch["label"])
def __call__(self, inputs, is_training):
    """Connects the module to some inputs.

    Args:
      inputs: Tensor, final dimension must be equal to ``embedding_dim``. All
        other leading dimensions will be flattened and treated as a large batch.
      is_training: boolean, whether this connection is to training data. When
        this is set to ``False``, the internal moving average statistics will
        not be updated.

    Returns:
      dict: Dictionary containing the following keys and values:
        * ``quantize``: Tensor containing the quantized version of the input,
          with straight-through gradients to ``inputs``.
        * ``loss``: Tensor containing the loss to optimize (commitment cost
          times the MSE between ``inputs`` and the stop-gradient quantization).
        * ``perplexity``: Tensor containing the perplexity of the encodings.
        * ``encodings``: Tensor containing the discrete encodings, ie which
          element of the quantized space each input element was mapped to, as
          flattened one-hot rows of shape [N, num_embeddings], where N is the
          product of the leading input dimensions.
        * ``encoding_indices``: Tensor containing the discrete encoding indices,
          ie which element of the quantized space each input element was mapped
          to, reshaped to ``inputs.shape[:-1]``.
        * ``distances``: Tensor of squared Euclidean distances between every
          flattened input and every embedding, shape [N, num_embeddings].
    """
    flat_inputs = jnp.reshape(inputs, [-1, self.embedding_dim])
    embeddings = self.embeddings

    # ||x||^2 - 2 x.e + ||e||^2 for every flattened input x (rows) and every
    # codebook column e; embeddings is laid out [embedding_dim, num_embeddings].
    distances = (jnp.sum(flat_inputs**2, 1, keepdims=True) -
                 2 * jnp.matmul(flat_inputs, embeddings) +
                 jnp.sum(embeddings**2, 0, keepdims=True))

    # Nearest codebook entry per input row.
    encoding_indices = jnp.argmax(-distances, 1)
    encodings = jax.nn.one_hot(encoding_indices,
                               self.num_embeddings,
                               dtype=distances.dtype)

    # NB: if your code crashes with a reshape error on the line below about a
    # Tensor containing the wrong number of values, then the most likely cause
    # is that the input passed in does not have a final dimension equal to
    # self.embedding_dim. Ideally we would catch this with an Assert but that
    # creates various other problems related to device placement / TPUs.
    encoding_indices = jnp.reshape(encoding_indices, inputs.shape[:-1])
    quantized = self.quantize(encoding_indices)
    e_latent_loss = jnp.mean((jax.lax.stop_gradient(quantized) - inputs)**2)

    if is_training:
        # Per-code assignment counts, summed across replicas when configured.
        cluster_size = jnp.sum(encodings, axis=0)
        if self.cross_replica_axis:
            cluster_size = jax.lax.psum(
                cluster_size, axis_name=self.cross_replica_axis)
        updated_ema_cluster_size = self.ema_cluster_size(cluster_size)

        # Per-code sum of assigned inputs, shape [embedding_dim, num_embeddings].
        dw = jnp.matmul(flat_inputs.T, encodings)
        if self.cross_replica_axis:
            dw = jax.lax.psum(dw, axis_name=self.cross_replica_axis)
        updated_ema_dw = self.ema_dw(dw)

        # Laplace smoothing of the cluster sizes so unused codes never divide
        # by zero; the total mass n is preserved.
        n = jnp.sum(updated_ema_cluster_size)
        updated_ema_cluster_size = ((updated_ema_cluster_size + self.epsilon) /
                                    (n + self.num_embeddings * self.epsilon) * n)

        normalised_updated_ema_w = (
            updated_ema_dw / jnp.reshape(updated_ema_cluster_size, [1, -1]))

        # `quantized` above was computed from the pre-update codebook.
        hk.set_state("embeddings", normalised_updated_ema_w)

    # The codebook is updated by EMA rather than by gradients, so only the
    # commitment term is optimised, both in training and in evaluation.
    loss = self.commitment_cost * e_latent_loss

    # Straight Through Estimator
    quantized = inputs + jax.lax.stop_gradient(quantized - inputs)
    avg_probs = jnp.mean(encodings, 0)
    if self.cross_replica_axis:
        avg_probs = jax.lax.pmean(avg_probs, axis_name=self.cross_replica_axis)
    perplexity = jnp.exp(-jnp.sum(avg_probs * jnp.log(avg_probs + 1e-10)))

    return {
        "quantize": quantized,
        "loss": loss,
        "perplexity": perplexity,
        "encodings": encodings,
        "encoding_indices": encoding_indices,
        "distances": distances,
    }
logits = predict(params, inputs) preds = stax.logsoftmax(logits) return -np.mean(np.sum(preds * targets, axis=1)) #set up of index tl = test_labels index7 = tl.tolist().index([0, 0, 0, 0, 0, 0, 0, 1, 0, 0]) print(test_labels[index7]) #computing process to the new x input_image, input_label = shape_as_image(test_images[index7], test_labels[index7]) grad_newx = grad(computation, 1)(params, input_image, input_label) newx = input_image + hyper * np.sign(grad_newx) #start plot and its predicted vector target_class = np.argmax(input_label) predicted_class = np.argmax(predict(params, newx)) #predicted vector predict_vector = predict(params, newx) print('the target class is :', target_class) print('the predict class is :', predicted_class) print('the predicted vector is :', predict_vector) image = np.array(newx) image = image * 255 image = image.reshape(28, 28) plt.imshow(image) """## From here is Part 2""" #initial setup images = np.array(test_images[0:1000])
def accuracy(y_true, y_pred):
    """Share of rows whose argmax over the last axis of ``y_pred`` equals ``y_true``."""
    predicted_labels = jnp.argmax(y_pred, axis=-1)
    matches = predicted_labels == y_true
    return jnp.mean(matches)
num_complete_batches, leftover = divmod(X.shape[0], batch_size) num_batches = num_complete_batches + bool(leftover) while True: temp, rng = random.split(rng) perm = random.permutation(temp, X.shape[0]) for i in range(num_batches): batch_idx = perm[i * batch_size:(i + 1) * batch_size] yield X[batch_idx], y[batch_idx] if __name__ == "__main__": rng = random.PRNGKey(0) X, y, X_test, y_test = mnist() X, X_test = X.reshape(-1, 28, 28, 1), X_test.reshape(-1, 28, 28, 1) y, y_test = (np.argmax(y, 1) % 2 == 1).astype( np.float32), (np.argmax(y_test, 1) % 1 == 1).astype(np.float32) temp, rng = random.split(rng) params, predict = model(temp) def loss(params, batch, l2=0.05): X, y = batch y_hat = predict(params, X).reshape(-1) return -np.mean(y * np.log(y_hat) + (1. - y) * np.log(1. - y_hat)) @jit def update(i, opt_state, batch): params = get_params(opt_state) return opt_update(i, grad(loss)(params, batch), opt_state)
def accuracy(trainable_params, untrainable_params, batch):
    """Fraction of ``batch`` examples whose predicted class matches the one-hot target.

    The two parameter groups are combined with ``merge_params`` before calling
    ``net.apply`` (both from the enclosing scope).
    """
    inputs, targets = batch
    true_labels = jnp.argmax(targets, axis=1)
    full_params = merge_params(trainable_params, untrainable_params)
    logits = net.apply(full_params, inputs)
    return jnp.mean(jnp.argmax(logits, axis=1) == true_labels)
def test_change_point_x64():
    """Checks that NUTS recovers the change point of a count time series.

    Fits a Poisson model whose rate switches from ``lambda1`` to ``lambda2`` at
    fraction ``tau`` of the series. Asserts that the posterior mode of the
    switch index is 44 and, when ``JAX_ENABLE_X64`` is set in the environment,
    that all samples are float64.
    """
    # Ref: (the reference URL is missing from this copy)
    warmup_steps, num_samples = 500, 3000

    def model(data):
        # Exponential priors whose mean (1 / alpha) equals the empirical mean count.
        alpha = 1 / np.mean(data)
        lambda1 = numpyro.sample('lambda1', dist.Exponential(alpha))
        lambda2 = numpyro.sample('lambda2', dist.Exponential(alpha))
        tau = numpyro.sample('tau', dist.Uniform(0, 1))
        # Rate is lambda1 before the switch index tau * len(data), lambda2 after.
        lambda12 = np.where(
            np.arange(len(data)) < tau * len(data), lambda1, lambda2)
        numpyro.sample('obs', dist.Poisson(lambda12), obs=data)

    count_data = np.array([
        13, 24, 8, 24, 7, 35, 14, 11, 15, 11, 22, 22, 11, 57, 11, 19, 29, 6, 19, 12,
        22, 12, 18, 72, 32, 9, 7, 13, 19, 23, 27, 20, 6, 17, 13, 10, 14, 6, 16, 15,
        7, 2, 15, 15, 19, 70, 49, 7, 53, 22, 21, 31, 19, 11, 18, 20, 12, 35, 17, 23,
        17, 4, 2, 31, 30, 13, 27, 0, 39, 37, 5, 14, 13, 22,
    ])
    kernel = NUTS(model=model)
    # NOTE(review): garbled. The sampling call after the constructor is missing
    # (presumably ``mcmc.run(<rng key>, count_data)``); as written this line
    # is a syntax error.
    mcmc = MCMC(kernel, warmup_steps, num_samples), count_data)
    samples = mcmc.get_samples()
    # Convert sampled fractions into integer switch indices and take the mode.
    tau_posterior = (samples['tau'] * len(count_data)).astype(np.int32)
    tau_values, counts = onp.unique(tau_posterior, return_counts=True)
    mode_ind = np.argmax(counts)
    mode = tau_values[mode_ind]
    assert mode == 44

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['lambda1'].dtype == np.float64
        assert samples['lambda2'].dtype == np.float64
        assert samples['tau'].dtype == np.float64
def sample_categorical(key, logits):
    """Draw one category per row of ``logits`` (last axis) via Gumbel-max.

    Uniform draws start at the dtype's smallest positive normal number, and
    the same ``tiny`` is added inside the outer log, so neither logarithm
    receives 0.
    """
    tiny = jnp.finfo(logits.dtype).tiny
    u = jax.random.uniform(key, logits.shape, minval=tiny, maxval=1.0)
    noise = -jnp.log(-jnp.log(u) + tiny)
    return jnp.argmax(logits + noise, axis=-1)
def gumbel_max_sampler(logits_1, logits_2, rng):
    """Samples from a Gumbel-max coupling.

    One Gumbel noise draw perturbs both sets of logits, so the two argmax
    samples are coupled.

    Args:
      logits_1: logits of the first categorical (argmax is taken over the
        flattened array).
      logits_2: logits of the second categorical, with the same shape as
        ``logits_1``.
      rng: JAX PRNG key.

    Returns:
      A float array of shape ``[n, n]``, where ``n = logits_1.size``. It is
      one-hot at ``(x, y)``, the pair of sampled categories.
    """
    gumbels = jax.random.gumbel(rng, logits_1.shape)
    x = jnp.argmax(gumbels + logits_1)
    y = jnp.argmax(gumbels + logits_2)
    # Size the joint table from the inputs instead of a hard-coded 10. JAX
    # silently drops out-of-bounds scatter updates, so with a fixed [10, 10]
    # table any sample at index >= 10 vanished and the result was all zeros.
    num_categories = gumbels.size
    return jnp.zeros([num_categories, num_categories]).at[x, y].set(1.)
def main(_):
    """Fine-tunes a checkpointed classifier on uncertain points from a pool split.

    Steps:
      1. Loads the checkpoint at ``{FLAGS.root_dir}/{FLAGS.ckpt_idx}``. When
         ``FLAGS.pretrained_dir`` is set, the first 7 parameter groups of that
         pretrained state are prepended and the optimizer is re-initialised.
      2. Scores the pool split and picks ``FLAGS.n_extra`` points by
         ``FLAGS.uncertain``: highest entropy, smallest top-2 probability
         margin, or random.
      3. Fine-tunes on those points for ``FLAGS.epochs`` epochs, using
         ``private_update`` when ``FLAGS.dpsgd`` is set.
      4. Writes test loss and accuracy after each epoch to
         ``{FLAGS.work_dir}/stdout.log``.

    NOTE(review): this copy is garbled. Statements of the form
    ``'<message>', args)`` have lost their logging call (presumably
    ``logging.info(``). Further concerns are marked inline.
    """
    sns.set()
    sns.set_palette(sns.color_palette('hls', 10))
    npr.seed(FLAGS.seed)
    # NOTE(review): garbled logging call.
    'Starting experiment.')

    # Create model folder for outputs
    try:
        gfile.MakeDirs(FLAGS.work_dir)
    except gfile.GOSError:
        pass
    stdout_log = gfile.Open('{}/stdout.log'.format(FLAGS.work_dir), 'w+')

    # use mean/std of svhn train
    train_images, _, _ = datasets.get_dataset_split(
        name=FLAGS.train_split.split('-')[0],
        split=FLAGS.train_split.split('-')[1],
        shuffle=False)
    train_mu, train_std = onp.mean(train_images), onp.std(train_images)
    del train_images

    # BEGIN: fetch test data and candidate pool
    test_images, test_labels, _ = datasets.get_dataset_split(
        name=FLAGS.test_split.split('-')[0],
        split=FLAGS.test_split.split('-')[1],
        shuffle=False)
    pool_images, pool_labels, _ = datasets.get_dataset_split(
        name=FLAGS.pool_split.split('-')[0],
        split=FLAGS.pool_split.split('-')[1],
        shuffle=False)
    n_pool = len(pool_images)

    test_images = (test_images - train_mu) / train_std  # normalize w train mu/std
    pool_images = (pool_images - train_mu) / train_std  # normalize w train mu/std

    # augmentation for train/pool data
    if FLAGS.augment_data:
        augmentation = data.chain_transforms(data.RandomHorizontalFlip(0.5),
                                             data.RandomCrop(4), data.ToDevice)
    else:
        augmentation = None
    # END: fetch test data and candidate pool

    # BEGIN: load ckpt
    opt_init, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)

    if FLAGS.pretrained_dir is not None:
        # NOTE(review): pickle.load on checkpoint files is only safe for
        # trusted paths.
        with gfile.Open(FLAGS.pretrained_dir, 'rb') as fpre:
            pretrained_opt_state = optimizers.pack_optimizer_state(
                pickle.load(fpre))
        fixed_params = get_params(pretrained_opt_state)[:7]

        ckpt_dir = '{}/{}'.format(FLAGS.root_dir, FLAGS.ckpt_idx)
        # NOTE(review): mode 'wr' for a file that is then read with
        # pickle.load looks wrong; presumably 'rb' was intended. Confirm.
        with gfile.Open(ckpt_dir, 'wr') as fckpt:
            opt_state = optimizers.pack_optimizer_state(pickle.load(fckpt))
        params = get_params(opt_state)
        # combine fixed pretrained params and dpsgd trained last layers
        params = fixed_params + params
        # Fresh optimizer state over the combined params; all groups are
        # updated during fine-tuning.
        opt_state = opt_init(params)
    else:
        ckpt_dir = '{}/{}'.format(FLAGS.root_dir, FLAGS.ckpt_idx)
        # NOTE(review): same 'wr' mode concern as above.
        with gfile.Open(ckpt_dir, 'wr') as fckpt:
            opt_state = optimizers.pack_optimizer_state(pickle.load(fckpt))
        params = get_params(opt_state)

    stdout_log.write('finetune from: {}\n'.format(ckpt_dir))
    # NOTE(review): garbled logging call.
    'finetune from: %s', ckpt_dir)
    test_acc, test_pred = accuracy(params, shape_as_image(test_images, test_labels), return_predicted_class=True)
    # NOTE(review): garbled logging call.
    'test accuracy: %.2f', test_acc)
    stdout_log.write('test accuracy: {}\n'.format(test_acc))
    stdout_log.flush()
    # END: load ckpt

    # BEGIN: setup for dp model
    @jit
    def update(_, i, opt_state, batch):
        # Plain SGD step; the leading rng argument is ignored so that the
        # signature matches private_update.
        params = get_params(opt_state)
        return opt_update(i, grad_loss(params, batch), opt_state)

    @jit
    def private_update(rng, i, opt_state, batch):
        params = get_params(opt_state)
        rng = random.fold_in(rng, i)  # get new key for new random numbers
        return opt_update(
            i,
            private_grad(params, batch, rng, FLAGS.l2_norm_clip,
                         FLAGS.noise_multiplier, FLAGS.batch_size), opt_state)
    # END: setup for dp model

    ### BEGIN: prepare extra points picked from pool data
    # BEGIN: on pool data
    # Embed the pool with all but the last parameter group, in batches, then
    # apply the last group to get logits.
    pool_embeddings = [apply_fn_0(params[:-1],
                                  pool_images[b_i:b_i + FLAGS.batch_size]) \
                       for b_i in range(0, n_pool, FLAGS.batch_size)]
    pool_embeddings = np.concatenate(pool_embeddings, axis=0)

    pool_logits = apply_fn_1(params[-1:], pool_embeddings)

    pool_true_labels = np.argmax(pool_labels, axis=1)
    pool_predicted_labels = np.argmax(pool_logits, axis=1)
    pool_correct_indices = \
        onp.where(pool_true_labels == pool_predicted_labels)[0]
    pool_incorrect_indices = \
        onp.where(pool_true_labels != pool_predicted_labels)[0]
    assert len(pool_correct_indices) + \
        len(pool_incorrect_indices) == len(pool_labels)

    pool_probs = stax.softmax(pool_logits)

    # NOTE(review): if FLAGS.uncertain matches none of the branches below,
    # pool_uncertain_indices is never bound and a NameError follows.
    if FLAGS.uncertain == 0 or FLAGS.uncertain == 'entropy':
        # NOTE(review): a probability that underflows to 0 makes 0 * log(0)
        # NaN here.
        pool_entropy = -onp.sum(pool_probs * onp.log(pool_probs), axis=1)
        stdout_log.write('all {} entropy: min {}, max {}\n'.format(
            len(pool_entropy), onp.min(pool_entropy), onp.max(pool_entropy)))

        pool_entropy_sorted_indices = onp.argsort(pool_entropy)
        # take the n_extra most uncertain points
        pool_uncertain_indices = \
            pool_entropy_sorted_indices[::-1][:FLAGS.n_extra]
        stdout_log.write('uncertain {} entropy: min {}, max {}\n'.format(
            len(pool_entropy[pool_uncertain_indices]),
            onp.min(pool_entropy[pool_uncertain_indices]),
            onp.max(pool_entropy[pool_uncertain_indices])))

    elif FLAGS.uncertain == 1 or FLAGS.uncertain == 'difference':
        # 1st_prob - 2nd_prob
        assert len(pool_probs.shape) == 2
        sorted_pool_probs = onp.sort(pool_probs, axis=1)
        pool_probs_diff = sorted_pool_probs[:, -1] - sorted_pool_probs[:, -2]
        # NOTE(review): fails whenever a point's top-2 probabilities tie.
        assert min(pool_probs_diff) > 0.
        stdout_log.write('all {} difference: min {}, max {}\n'.format(
            len(pool_probs_diff), onp.min(pool_probs_diff),
            onp.max(pool_probs_diff)))

        # Smallest margins first = most uncertain.
        pool_uncertain_indices = onp.argsort(pool_probs_diff)[:FLAGS.n_extra]
        stdout_log.write('uncertain {} difference: min {}, max {}\n'.format(
            len(pool_probs_diff[pool_uncertain_indices]),
            onp.min(pool_probs_diff[pool_uncertain_indices]),
            onp.max(pool_probs_diff[pool_uncertain_indices])))

    elif FLAGS.uncertain == 2 or FLAGS.uncertain == 'random':
        pool_uncertain_indices = npr.permutation(n_pool)[:FLAGS.n_extra]
    # END: on pool data
    ### END: prepare extra points picked from pool data

    finetune_images = copy.deepcopy(pool_images[pool_uncertain_indices])
    finetune_labels = copy.deepcopy(pool_labels[pool_uncertain_indices])

    stdout_log.write('Starting fine-tuning...\n')
    # NOTE(review): garbled logging call.
    'Starting fine-tuning...')
    stdout_log.flush()

    stdout_log.write('{} points picked via {}\n'.format(
        len(finetune_images), FLAGS.uncertain))
    # NOTE(review): garbled logging call.
    '%d points picked via %s', len(finetune_images), FLAGS.uncertain)
    assert FLAGS.n_extra == len(finetune_images)

    for epoch in range(1, FLAGS.epochs + 1):

        # BEGIN: finetune model with extra data, evaluate and save
        num_extra = len(finetune_images)
        num_complete_batches, leftover = divmod(num_extra, FLAGS.batch_size)
        num_batches = num_complete_batches + bool(leftover)

        finetune = data.DataChunk(X=finetune_images, Y=finetune_labels, image_size=32, image_channels=3, label_dim=1, label_format='numeric')

        batches = data.minibatcher(finetune, FLAGS.batch_size, transform=augmentation)

        # NOTE(review): both are re-created every epoch, so the step index
        # restarts at 0 and private_update folds the same (key, i) pairs into
        # its noise each epoch. Confirm this is intended.
        itercount = itertools.count()
        key = random.PRNGKey(FLAGS.seed)

        start_time = time.time()

        for _ in range(num_batches):
            # tmp_time = time.time()
            b = next(batches)
            if FLAGS.dpsgd:
                opt_state = private_update(
                    key, next(itercount), opt_state,
                    shape_as_image(b.X, b.Y, dummy_dim=True))
            else:
                opt_state = update(key, next(itercount), opt_state, shape_as_image(b.X, b.Y))
            # stdout_log.write('single update in {:.2f} sec\n'.format(
            #     time.time() - tmp_time))

        epoch_time = time.time() - start_time
        stdout_log.write('Epoch {} in {:.2f} sec\n'.format(epoch, epoch_time))
        # NOTE(review): garbled logging call.
        'Epoch %d in %.2f sec', epoch, epoch_time)

        # accuracy on test data
        params = get_params(opt_state)

        # Keep the previous epoch's predictions for the visual diff below.
        test_pred_0 = test_pred
        test_acc, test_pred = accuracy(params, shape_as_image(test_images, test_labels), return_predicted_class=True)
        test_loss = loss(params, shape_as_image(test_images, test_labels))
        stdout_log.write(
            'Eval set loss, accuracy (%): ({:.2f}, {:.2f})\n'.format(
                test_loss, 100 * test_acc))
        # NOTE(review): garbled logging call.
        'Eval set loss, accuracy: (%.2f, %.2f)', test_loss, 100 * test_acc)
        stdout_log.flush()

        # visualize prediction difference between 2 checkpoints.
        if FLAGS.visualize:
            utils.visualize_ckpt_difference(test_images, np.argmax(test_labels, axis=1), test_pred_0, test_pred, epoch - 1, epoch, FLAGS.work_dir, mu=train_mu, sigma=train_std)

    # END: finetune model with extra data, evaluate and save

    stdout_log.close()
def accuracy(params, batch):
    """Fraction of ``batch`` examples whose predicted class matches the one-hot target.

    ``predict`` comes from the enclosing scope.
    """
    inputs, targets = batch
    true_labels = np.argmax(targets, axis=1)
    logits = predict(params, inputs)
    guesses = np.argmax(logits, axis=1)
    return np.mean(guesses == true_labels)
def _accuracy(y, y_hat): """Compute the accuracy of the predictions with respect to one-hot labels.""" return np.mean(np.argmax(y, axis=1) == np.argmax(y_hat, axis=1))
def _gumbel_max(rng, logit_probs): return np.argmax(random.gumbel(rng, logit_probs.shape, logit_probs.dtype) + logit_probs, axis=0)
def mode(self):
    """Return the most probable category index along the last axis of ``_probs``."""
    probabilities = self._probs
    most_likely = jp.argmax(probabilities, axis=-1)
    return most_likely
def accuracy(params, images, targets):
    """Fraction of ``images`` whose predicted class matches the one-hot ``targets``.

    ``batched_predict`` comes from the enclosing scope.
    """
    true_labels = jnp.argmax(targets, axis=1)
    logits = batched_predict(params, images)
    guesses = jnp.argmax(logits, axis=1)
    return jnp.mean(guesses == true_labels)
def sample(self, sample_shape, seed):
    """Draw categorical samples from ``self._logits`` with the Gumbel-max trick.

    Args:
      sample_shape: tuple of leading sample dimensions.
      seed: JAX PRNG key.

    Returns:
      Integer array of shape ``sample_shape + self._logits.shape[:-1]``.
    """
    noise_shape = sample_shape + self._logits.shape
    noise = jax.random.gumbel(seed, shape=noise_shape)
    return jp.argmax(self._logits + noise, axis=-1)