def learning(self, t, observations, responses, prior):  # parametric learning and inference
    p_cfm, params = prior
    obs = nn.one_hot(observations, 4)
    p_aco = params / params.sum(-1, keepdims=True)
    p_c = einsum('nco,no->nc', p_aco[self.slc, responses], obs)

    m = jnp.expand_dims(self.mask[t], -1)
    p_c = m * p_c + (1 - m) / 2.

    post = einsum('nc,ncfm->ncfm', p_c, p_cfm)
    norm = post.reshape(post.shape[:-3] + (-1,)).sum(-1)[..., None, None, None]
    post = post / norm

    resp = nn.one_hot(responses, 3)
    post_c = post.reshape(post.shape[:-2] + (-1,)).sum(-1)
    params_new = params + einsum('na,nc,no,n->naco', resp, post_c, obs, self.mask[t])

    pred = einsum('fcz,mfg,ncfm->nzgm', self.p_fcc, self.p_mff, post)

    if self.dyn_pref:
        self.lam += jnp.expand_dims(self.eta, -1) * obs

    return (pred, params_new)
def testOneHot(self):
    actual = nn.one_hot(np.array([0, 1, 2]), 3)
    expected = np.array([[1., 0., 0.],
                         [0., 1., 0.],
                         [0., 0., 1.]])
    self.assertAllClose(actual, expected, check_dtypes=True)

    actual = nn.one_hot(np.array([1, 2, 0]), 3)
    expected = np.array([[0., 1., 0.],
                         [0., 0., 1.],
                         [1., 0., 0.]])
    self.assertAllClose(actual, expected, check_dtypes=True)
def testOneHot(self):
    actual = nn.one_hot(jnp.array([0, 1, 2]), 3)
    expected = jnp.array([[1., 0., 0.],
                          [0., 1., 0.],
                          [0., 0., 1.]])
    self.assertAllClose(actual, expected)

    actual = nn.one_hot(jnp.array([1, 2, 0]), 3)
    expected = jnp.array([[0., 1., 0.],
                          [0., 0., 1.],
                          [1., 0., 0.]])
    self.assertAllClose(actual, expected)
def testOneHotAxis(self):
    expected = jnp.array([[0., 1., 0.],
                          [0., 0., 1.],
                          [1., 0., 0.]]).T

    actual = nn.one_hot(jnp.array([1, 2, 0]), 3, axis=0)
    self.assertAllClose(actual, expected)

    actual = nn.one_hot(jnp.array([1, 2, 0]), 3, axis=-2)
    self.assertAllClose(actual, expected)
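# A minimal, standalone sketch of the jax.nn.one_hot behavior exercised by the
# tests above (basic encoding and the `axis` argument); the numbers here are
# purely illustrative.
import jax.numpy as jnp
from jax import nn

indices = jnp.array([1, 2, 0])

# Default: classes are laid out along the last axis, giving shape (3, 3).
print(nn.one_hot(indices, 3))
# [[0. 1. 0.]
#  [0. 0. 1.]
#  [1. 0. 0.]]

# axis=0 places the class dimension first, i.e. the transpose of the default.
print(nn.one_hot(indices, 3, axis=0))
# [[0. 0. 1.]
#  [1. 0. 0.]
#  [0. 1. 0.]]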
def transition_fn(carry, t):
    lam_prev, x_prev = carry

    gamma = npyro.deterministic('gamma', nn.softplus(mu + x_prev))

    U = jnp.log(lam_prev) - jnp.log(lam_prev.sum(-1, keepdims=True))
    logs = logits((beliefs[0][:, t], beliefs[1][:, t]),
                  jnp.expand_dims(gamma, -1),
                  jnp.expand_dims(U, -2))

    lam_next = npyro.deterministic(
        'lams',
        lam_prev + nn.one_hot(beliefs[2][t], 4) * jnp.expand_dims(mask[t] * eta, -1))

    mixing_dist = dist.CategoricalProbs(weights)
    component_dist = dist.CategoricalLogits(logs.swapaxes(0, 1)).mask(mask[t][..., None])

    with npyro.plate('subjects', N):
        y = npyro.sample('y', dist.MixtureSameFamily(mixing_dist, component_dist))
        noise = npyro.sample('dw', dist.Normal(0., 1.))

    x_next = rho * x_prev + sigma * noise

    return (lam_next, x_next), None
def sample_scan(params, tup, x):
    """Perform single step update of the network."""
    _, (update_W, update_U, update_b), (reset_W, reset_U, reset_b), \
        (out_W, out_U, out_b), (sm_W, sm_b) = params
    hidden = tup[3]
    logP = tup[2]
    key = tup[0]
    inp = tup[1]

    update_gate = sigmoid(
        np.dot(inp, update_W) + np.dot(hidden, update_U) + update_b)
    reset_gate = sigmoid(
        np.dot(inp, reset_W) + np.dot(hidden, reset_U) + reset_b)
    output_gate = np.tanh(
        np.dot(inp, out_W) + np.dot(np.multiply(reset_gate, hidden), out_U) + out_b)
    output = np.multiply(update_gate, hidden) + np.multiply(1 - update_gate, output_gate)
    hidden = output

    logits = np.dot(hidden, sm_W) + sm_b

    key, subkey = random.split(key)
    samples = random.categorical(
        subkey, logits, axis=1, shape=None)  # sampling the conditional
    samples = one_hot(samples, sm_b.shape[0])  # convert to one hot encoded vector
    log_P_new = np.sum(samples * log_softmax(logits), axis=1)
    log_P_new = log_P_new + logP  # update the value of the logP of the sample

    return (key, samples, log_P_new, output), samples
def get_boundaries(self, encodings, segment_id, lengths, training):
    """Get boundaries (b) for a single segment in batch."""
    if segment_id == self.max_num_segments - 1:
        # Last boundary is always placed on last sequence element.
        logits_b = None
        # sample_b = jnp.zeros_like(encodings[:, :, 0]).scatter_(
        #     1, jnp.expand_dims(lengths, -1) - 1, 1)
        sample_b = jnp.zeros_like(encodings[:, :, 0])
        sample_b = jax.ops.index_update(
            sample_b, jax.ops.index[jnp.arange(len(lengths)), lengths - 1], 1)
    else:
        hidden = nn.relu(self.head_b_1(encodings))
        logits_b = jnp.squeeze(self.head_b_2(hidden), -1)

        # Mask out first position with large neg. value.
        neg_inf = jnp.ones((encodings.shape[0], 1)) * utils.NEG_INF
        # TODO(tkipf): Mask out padded positions with large neg. value.
        logits_b = jnp.concatenate([neg_inf, logits_b[:, 1:]], axis=1)

        if training:
            sample_b = utils.gumbel_softmax_sample(
                hk.next_rng_key(), logits_b, temp=self.temp_b)
        else:
            sample_b_idx = jnp.argmax(logits_b, axis=1)
            sample_b = nn.one_hot(sample_b_idx, logits_b.shape[1])

    return logits_b, sample_b
def evaluate_batch(self, flax_module, batch_stats, batch):
    """Evaluates cross_entropy on the given batch."""
    # TODO(ankugarg): Augment with other metrics like log-perplexity.
    with nn.stateful(batch_stats, mutable=False):
        logits = flax_module(batch['inputs'],
                             batch['targets'],
                             batch.get('inputs_positions'),
                             batch.get('targets_positions'),
                             batch.get('inputs_segmentation'),
                             batch.get('targets_segmentation'),
                             train=False)

    weights = batch.get('weights')
    targets = batch['targets']
    if self.dataset_meta_data['apply_one_hot_in_loss']:
        targets = one_hot(batch['targets'], logits.shape[-1])

    # Add log-perplexity metric.
    evaluated_metrics = {}
    for key in self.metrics_bundle:
        per_example_metrics = self.metrics_bundle[key](logits, targets, weights)
        evaluated_metrics[key] = jnp.sum(
            lax.psum(per_example_metrics, axis_name='batch'))

    return evaluated_metrics
def training_cost(self, flax_module, batch_stats, batch, dropout_rng): """Return cross entropy loss with (optional) L2 penalty on the weights.""" with nn.stateful(batch_stats) as new_batch_stats: with nn.stochastic(dropout_rng): # inputs/targets positions and segmentations are required # when we have packed examples. logits = flax_module(batch['inputs'], batch['targets'], batch.get('inputs_positions'), batch.get('targets_positions'), batch.get('inputs_segmentation'), batch.get('targets_segmentation'), train=True) weights = batch.get('weights') targets = batch['targets'] if self.dataset_meta_data['apply_one_hot_in_loss']: targets = one_hot(batch['targets'], logits.shape[-1]) # Optionally apply label smoothing. if self.hps.get('label_smoothing') is not None: targets = model_utils.apply_label_smoothing( targets, self.hps.get('label_smoothing')) total_loss = self.loss_fn(logits, targets, weights) if self.hps.get('l2_decay_factor'): l2_loss = model_utils.l2_regularization( flax_module.params, self.hps.l2_decay_rank_threshold) total_loss += 0.5 * self.hps.l2_decay_factor * l2_loss return total_loss, (new_batch_stats)
def viterbi_backward(state, t):
    state = jnp.where(
        t <= length,
        jnp.sum(most_likely_sources[t] * one_hot(state, n_states)).astype(jnp.int64),
        state)
    most_likely = jnp.where(t <= length, state, -1)
    return state, most_likely
def evaluate_batch(self, params, batch_stats, batch):
    """Evaluates cross_entropy on the given batch."""
    # TODO(ankugarg): Augment with other metrics like log-perplexity.
    logits = self.flax_module.apply(
        {
            'params': params,
            'batch_stats': batch_stats
        },
        batch['inputs'],
        batch['targets'],
        inputs_positions=batch.get('inputs_positions'),
        targets_positions=batch.get('targets_positions'),
        inputs_segmentation=batch.get('inputs_segmentation'),
        targets_segmentation=batch.get('targets_segmentation'),
        train=False)

    weights = batch.get('weights')
    targets = batch['targets']
    if self.dataset_meta_data['apply_one_hot_in_loss']:
        targets = one_hot(batch['targets'], logits.shape[-1])

    # Add log-perplexity metric.
    return self.metrics_bundle.gather_from_model_output(
        logits=logits, targets=targets, weights=weights, axis_name='batch')
def mace(positions, annotations): """ This model corresponds to the plate diagram in Figure 3 of reference [1]. """ num_annotators = int(np.max(positions)) + 1 num_classes = int(np.max(annotations)) + 1 num_items, num_positions = annotations.shape with numpyro.plate("annotator", num_annotators): epsilon = numpyro.sample("epsilon", dist.Dirichlet(jnp.full(num_classes, 10))) theta = numpyro.sample("theta", dist.Beta(0.5, 0.5)) with numpyro.plate("item", num_items, dim=-2): c = numpyro.sample( "c", dist.DiscreteUniform(0, num_classes - 1), infer={"enumerate": "parallel"}, ) with numpyro.plate("position", num_positions): s = numpyro.sample( "s", dist.Bernoulli(1 - theta[positions]), infer={"enumerate": "parallel"}, ) probs = jnp.where(s[..., None] == 0, nn.one_hot(c, num_classes), epsilon[positions]) numpyro.sample("y", dist.Categorical(probs), obs=annotations)
def MCMC_SpinFlip(h, e, seq_1hot, H, key):
    (L, q) = h.shape

    # advance RNG
    key, subkey_site, subkey_char, subkey_accept = random.split(key, 4)

    # choose site to flip
    i = random.choice(subkey_site, np.arange(L))

    # pick random character for chosen site
    a = random.choice(subkey_char, np.arange(q))

    # make proposal sequence
    seq_1hot_tmp = index_update(seq_1hot, index[i, :], nn.one_hot(a, q))
    H_tmp = Potts_ScoreSeqCore(h, e, seq_1hot_tmp)

    accept_prob = np.exp(H_tmp - H)
    flip = np.zeros((L, q), dtype=np.bool_)
    accept_draw = random.uniform(subkey_accept)
    flip = index_update(flip, index[i, :], accept_draw < accept_prob)

    return (np.where(flip, seq_1hot_tmp, seq_1hot),
            np.where(accept_draw < accept_prob, H_tmp, H),
            key)
def sample(self, rng: Generator, shape: Optional[Shape] = None) -> RealArray:
    if shape is not None:
        shape += self.shape
    return one_hot(
        jax.random.categorical(rng.key, self.log_odds, shape=shape),
        self.log_odds.shape[-1])
def loss(params, batch):
    inputs, targets = batch
    one_hot_targets = one_hot(targets, num_classes=num_symbols)
    preds, _ = transformer(params, inputs,
                           out_seq_length=seq_length,
                           num_symbols=num_symbols)
    return -jnp.mean(
        jnp.sum(jnp.sum(preds * one_hot_targets, axis=-1), axis=-1))
def cross_entropy_loss(outputs, actions):
    '''Calculates cross entropy loss (-sum(actions * log(outputs)))'''
    l = len(actions[0])
    for i in range(len(actions)):
        # one hot actions
        actions[i] = nn.one_hot(actions[i], l)
    return -1 * jnp.sum(actions * jnp.log(outputs))
def transformer(params, inputs, out_seq_length, num_symbols):
    one_hot_inputs = one_hot(inputs, num_classes=num_symbols)
    one_hot_outputs = jnp.zeros(
        [inputs.shape[0], out_seq_length + 1, num_symbols + 1])
    # one_hot_outputs[:, 0, -1] = 1.  # start symbol
    one_hot_outputs = ops.index_update(one_hot_outputs, ops.index[:, 0, -1], 1.)
    output_logits = jnp.zeros([inputs.shape[0], out_seq_length, num_symbols])

    (input_embeddings, output_embeddings, encoder_params_list,
     decoder_params_list, output_params) = params
    output_w, output_b = output_params

    encoder_results = jnp.dot(one_hot_inputs, input_embeddings)
    encoder_results += positional_encodings(encoder_results.shape[-2],
                                            encoder_results.shape[-1])
    for encoder_layer_i, this_enc_params in enumerate(encoder_params_list):
        encoder_results = encoder_layer(this_enc_params, encoder_results)

    for i in range(out_seq_length):
        decoder_results = jnp.dot(one_hot_outputs[:, :i + 1, :], output_embeddings)
        decoder_results += positional_encodings(decoder_results.shape[-2],
                                                decoder_results.shape[-1])
        for decoder_layer_i, this_dec_params in enumerate(decoder_params_list):
            decoder_results = decoder_layer(this_dec_params, decoder_results,
                                            encoder_results)
        this_step_results = jnp.dot(decoder_results, output_w) + output_b
        this_step_results = this_step_results[:, -1, :]  # previous outputs already produced
        output_logits = ops.index_update(output_logits, ops.index[:, i, :],
                                         this_step_results)
        one_hot_outputs = ops.index_update(
            one_hot_outputs, ops.index[:, i + 1, :],
            one_hot(jnp.argmax(this_step_results, axis=-1),
                    num_classes=num_symbols + 1))

    one_hot_outputs = one_hot_outputs[:, 1:, :-1]  # chop start symbol
    return output_logits, one_hot_outputs
def _evaluate_batch(flax_module, batch_stats, batch, metrics_bundle,
                    apply_one_hot_in_loss):
    """Evaluates metrics on the given batch.

    Currently we assume each metric_fn in metrics_bundle has the API:
      metric_fn(logits, targets, weights)
    and returns an array of shape [batch_size]. We also assume that to compute
    the aggregate metric, one should sum across all batches, then divide by the
    total samples seen (calculated by the 'denominator' metric). In this way we
    currently only support metrics of the form 1/N sum f(inputs, targets).

    Note, the caller is responsible for dividing by metrics['denominator'] when
    computing the mean of each metric.

    Args:
      flax_module: A flax.nn.Module
      batch_stats: A flax.nn.Collection object tracking batch_stats.
      batch: A dictionary with keys 'inputs', 'targets', 'weights'.
      metrics_bundle: A group of metrics to use for evaluation.
      apply_one_hot_in_loss: Indicates whether or not the targets are one hot
        encoded.

    Returns:
      A dictionary with the same keys as metrics, but mapping to the summed
      metric across the sharded batch_dim.
    """
    with nn.stateful(batch_stats, mutable=False):
        logits = flax_module(batch['inputs'], train=False)

    targets = batch['targets']
    if apply_one_hot_in_loss:
        targets = one_hot(batch['targets'], logits.shape[-1])

    # Map the dict values (which are functions) to function(targets, logits).
    weights = batch.get('weights')  # Weights might not be defined.
    eval_batch_size = targets.shape[0]
    if weights is None:
        weights = jnp.ones(eval_batch_size)

    # This psum is required to correctly evaluate with multihost. Only host 0
    # will report the metrics, so we must aggregate across all hosts. The psum
    # will map an array of shape [n_global_devices, batch_size] -> [batch_size]
    # by summing across the devices dimension. The outer sum then sums across
    # the batch dim. The result is that we have summed across all samples in
    # the sharded batch.
    evaluated_metrics = {}
    for key in metrics_bundle:
        per_example_metrics = metrics_bundle[key](logits, targets, weights)
        evaluated_metrics[key] = jnp.sum(
            lax.psum(per_example_metrics, axis_name='batch'))

    return evaluated_metrics
def transition_fn(carry, t):
    lam_prev = carry

    U = jnp.log(lam_prev) - jnp.log(lam_prev.sum(-1, keepdims=True))
    logs = logits((beliefs[0][t], beliefs[1][t]),
                  jnp.expand_dims(gamma, -1),
                  jnp.expand_dims(U, -2))

    lam_next = npyro.deterministic(
        'lams',
        lam_prev + nn.one_hot(beliefs[2][t], 4) * jnp.expand_dims(mask[t] * eta, -1))

    npyro.sample('y', dist.CategoricalLogits(logs).mask(mask[t]))

    return lam_next, None
def _predict(self, params, base_preds, context, return_probs, target=None):
    # Base logits
    base_preds = jnp.clip(base_preds,
                          a_min=self.pred_clipping,
                          a_max=(1.0 - self.pred_clipping))
    logits = jsp.special.logit(base_preds)
    logits = jnp.expand_dims(logits, axis=1)
    if self.num_classes == 2:
        logits = jnp.tile(logits, reps=(1, 1, 1))
    else:
        logits = jnp.tile(logits, reps=(1, self.num_classes, 1))

    # Turn target class into one-hot
    if target is not None:
        target = jnn.one_hot(target, num_classes=self.num_classes)
        if self.num_classes == 2:
            target = target[:, 1:]

    # Layers
    if target is None:
        for n, layer in enumerate(self.layers):
            logits = layer.predict(params=params[f'layer{n}'],
                                   logits=logits,
                                   context=context)
    else:
        for n, layer in enumerate(self.layers):
            params[f'layer{n}'], logits = layer.predict(
                params=params[f'layer{n}'],
                logits=logits,
                context=context,
                target=target)

    logits = jnp.squeeze(logits, axis=-1)
    if self.num_classes == 2:
        logits = jnp.squeeze(logits, axis=1)

    # Output prediction
    if return_probs:
        prediction = jnn.sigmoid(logits)
    elif self.num_classes == 2:
        prediction = logits > 0.0
    else:
        prediction = jnp.argmax(logits, axis=1)

    if target is None:
        return prediction
    else:
        return params, prediction
def image_iterator(data,
                   rescale,
                   output_shape,
                   is_one_hot,
                   autoencoder,
                   shuffle_rng=None,
                   augment_fn=None,
                   include_example_keys=False):
    """Preprocesses the batch data arrays in the data generator.

    Rescales inputs. One hot encodes targets if is_one_hot is true. Sets targets
    to inputs of output_shape if autoencoder is true.

    Args:
      data: An iterator generating dicts of input and target data arrays.
      rescale: A lambda function preprocessing input data arrays.
      output_shape: Shape of network output.
      is_one_hot: If true, targets are one hot encoded.
      autoencoder: If true, targets are set to inputs.
      shuffle_rng: jax.random.PRNGKey
      augment_fn: An optional augmentation function, called as
        augment_fn(rng, inputs, targets) and returning augmented
        (inputs, targets).
      include_example_keys: If True, then the tfds_id will be exposed in each
        batch dict of the validation set under the key `example_key`.

    Yields:
      A dictionary mapping keys ('inputs', 'targets') to preprocessed data
      arrays.
    """
    for batch_index, batch in enumerate(iterator_as_numpy(iter(data))):
        inputs = batch['image']
        targets = batch['label']
        if is_one_hot:
            targets = one_hot(batch['label'], output_shape[-1])

        if augment_fn:
            batch_rng = jax.random.fold_in(shuffle_rng, batch_index)
            inputs, targets = augment_fn(batch_rng, inputs, targets)

        inputs = rescale(inputs)
        if autoencoder:
            batch_output_shape = tuple([inputs.shape[0]] + list(output_shape))
            targets = inputs.reshape(batch_output_shape)

        if include_example_keys:
            yield {
                'inputs': inputs,
                'targets': targets,
                'tfds_id': batch['tfds_id']
            }
        else:
            yield {'inputs': inputs, 'targets': targets}
def MCMC_SeqEmit(h, e, key, nflip):
    (L, q) = h.shape
    seq_1hot = nn.one_hot(
        random.choice(key, np.arange(0, q), shape=(1, L))[0], q)
    H = Potts_ScoreSeqCore(h, e, seq_1hot)

    @jit
    def loop_fun_scan(loop_carry, i):
        h, e, seq_1hot, H, key = loop_carry
        seq_1hot, H, key = MCMC_SpinFlip(h, e, seq_1hot, H, key)
        return (h, e, seq_1hot, H, key), i

    (h, e, seq_1hot, H, key), i = lax.scan(
        loop_fun_scan, (h, e, seq_1hot, H, key), None, length=nflip)

    return seq_1hot, H, key
def transition_fn(carry, t):
    lam_prev, x_prev = carry

    gamma = npyro.deterministic('gamma', nn.softplus(mu + x_prev))

    U = jnp.log(lam_prev) - jnp.log(lam_prev.sum(-1, keepdims=True))
    logs = logits((beliefs[0][t], beliefs[1][t]),
                  jnp.expand_dims(gamma, -1),
                  jnp.expand_dims(U, -2))

    lam_next = npyro.deterministic(
        'lams',
        lam_prev + nn.one_hot(beliefs[2][t], 4) * jnp.expand_dims(mask[t] * eta, -1))

    npyro.sample('y', dist.CategoricalLogits(logs).mask(mask[t]))

    noise = npyro.sample('dw', dist.Normal(0., 1.))
    x_next = rho * x_prev + sigma * noise

    return (lam_next, x_next), None
def mace(positions, annotations): """ This model corresponds to the plate diagram in Figure 3 of reference [1]. """ num_annotators = int(np.max(positions)) + 1 num_classes = int(np.max(annotations)) + 1 num_items, num_positions = annotations.shape with numpyro.plate("annotator", num_annotators): epsilon = numpyro.sample("epsilon", dist.Dirichlet(jnp.full(num_classes, 10))) theta = numpyro.sample("theta", dist.Beta(0.5, 0.5)) with numpyro.plate("item", num_items, dim=-2): # NB: using constant logits for discrete uniform prior # (NumPyro does not have DiscreteUniform distribution yet) c = numpyro.sample("c", dist.Categorical(logits=jnp.zeros(num_classes))) with numpyro.plate("position", num_positions): s = numpyro.sample("s", dist.Bernoulli(1 - theta[positions])) probs = jnp.where(s[..., None] == 0, nn.one_hot(c, num_classes), epsilon[positions]) numpyro.sample("y", dist.Categorical(probs), obs=annotations)
def training_cost(self, flax_module, batch_stats, batch, dropout_rng): """Return loss with an L2 penalty on the weights.""" with nn.stateful(batch_stats) as new_batch_stats: with nn.stochastic(dropout_rng): logits = flax_module(batch['inputs'], train=True) weights = batch.get('weights') targets = batch['targets'] if self.dataset_meta_data['apply_one_hot_in_loss']: targets = one_hot(targets, logits.shape[-1]) # Optionally apply label smoothing. if self.hps.get('label_smoothing') is not None: targets = model_utils.apply_label_smoothing( targets, self.hps.get('label_smoothing')) total_loss = self.loss_fn(logits, targets, weights) if self.hps.get('l2_decay_factor'): l2_loss = model_utils.l2_regularization( flax_module.params, self.hps.l2_decay_rank_threshold) total_loss += 0.5 * self.hps.l2_decay_factor * l2_loss return total_loss, (new_batch_stats)
def thompson_sampling_step(model_params, state, model, environment):
    """
    Contextual implementation of the Thompson sampling algorithm.
    This implementation considers a single step.

    Parameters
    ----------
    model_params: dict
    state: tuple of (jax.random.PRNGKey, context)
    model: instance of a Bandit model
    environment: function
    """
    key, context = state
    key_sample, key_reward = random.split(key)

    # Sample and choose an action
    params = model.sample(key_sample, model_params, context)
    pred_rewards = model.predict_rewards(params, context)
    action = pred_rewards.argmax()

    # environment reward
    reward = environment(key_reward, action, context)
    model_params = model.update(action, model_params, context, reward)
    arm_reward = one_hot(action, K) * reward

    return model_params, (model_params, arm_reward)
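# A minimal sketch of rolling thompson_sampling_step out over T rounds with
# jax.lax.scan. The Beta-Bernoulli "model" and "environment" below are
# hypothetical stand-ins (the real Bandit model class is not shown here), and
# K is assumed to be a module-level constant for the number of arms, matching
# the one_hot(action, K) call above.
import jax
import jax.numpy as jnp
from jax import random
from jax.nn import one_hot

K, T = 3, 100

class BetaBernoulliModel:
    # Context-free Thompson sampler with a Beta posterior per arm.
    def sample(self, key, params, context):
        return random.beta(key, params['alpha'], params['beta'])
    def predict_rewards(self, sampled_means, context):
        return sampled_means  # sampled arm means are the predicted rewards
    def update(self, action, params, context, reward):
        chosen = one_hot(action, K)
        return {'alpha': params['alpha'] + chosen * reward,
                'beta': params['beta'] + chosen * (1. - reward)}

true_probs = jnp.array([0.2, 0.5, 0.8])

def environment(key, action, context):
    # Bernoulli reward for the chosen arm.
    return random.bernoulli(key, true_probs[action]).astype(jnp.float32)

keys = random.split(random.PRNGKey(0), T)
contexts = jnp.zeros((T, 1))  # dummy contexts; the stub model ignores them
init_params = {'alpha': jnp.ones(K), 'beta': jnp.ones(K)}
model = BetaBernoulliModel()

step = lambda params, state: thompson_sampling_step(params, state, model, environment)
final_params, (_, arm_rewards) = jax.lax.scan(step, init_params, (keys, contexts))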
def training_cost(self, params, batch, batch_stats=None, dropout_rng=None):
    """Return cross entropy loss with (optional) L2 penalty on the weights."""
    # inputs/targets positions and segmentations are required when we have
    # packed examples.
    logits, new_batch_stats = self.flax_module.apply(
        {
            'params': params,
            'batch_stats': batch_stats
        },
        batch['inputs'],
        batch['targets'],
        batch.get('inputs_positions'),
        batch.get('targets_positions'),
        batch.get('inputs_segmentation'),
        batch.get('targets_segmentation'),
        mutable=['batch_stats'],
        rngs={'dropout': dropout_rng},
        train=True)

    weights = batch.get('weights')
    targets = batch['targets']
    if self.dataset_meta_data['apply_one_hot_in_loss']:
        targets = one_hot(batch['targets'], logits.shape[-1])

    # Optionally apply label smoothing.
    if self.hps.get('label_smoothing') is not None:
        targets = model_utils.apply_label_smoothing(
            targets, self.hps.get('label_smoothing'))

    total_loss = self.loss_fn(logits, targets, weights)

    if self.hps.get('l2_decay_factor'):
        l2_loss = model_utils.l2_regularization(
            params, self.hps.l2_decay_rank_threshold)
        total_loss += 0.5 * self.hps.l2_decay_factor * l2_loss

    return total_loss, (new_batch_stats)
def testOneHotCustomDtype(self):
    actual = nn.one_hot(jnp.array([0, 1, 2]), 3, dtype=jnp.bool_)
    expected = jnp.array([[True, False, False],
                          [False, True, False],
                          [False, False, True]])
    self.assertAllClose(actual, expected)
def testOneHotNonArrayInput(self):
    actual = nn.one_hot([0, 1, 2], 3)
    expected = jnp.array([[1., 0., 0.],
                          [0., 1., 0.],
                          [0., 0., 1.]])
    self.assertAllClose(actual, expected)
def testOneHotOutOfBound(self):
    actual = nn.one_hot(jnp.array([-1, 3]), 3)
    expected = jnp.array([[0., 0., 0.],
                          [0., 0., 0.]])
    self.assertAllClose(actual, expected)