Example #1
    def learning(self, t, observations, responses, prior):
        # parametric learning and inference
        p_cfm, params = prior
        obs = nn.one_hot(observations, 4)

        p_aco = params / params.sum(-1, keepdims=True)

        p_c = einsum('nco,no->nc', p_aco[self.slc, responses], obs)

        m = jnp.expand_dims(self.mask[t], -1)
        p_c = m * p_c + (1 - m) / 2.

        post = einsum('nc,ncfm->ncfm', p_c, p_cfm)

        norm = post.reshape(post.shape[:-3] + (-1, )).sum(-1)[..., None, None,
                                                              None]
        post = post / norm

        resp = nn.one_hot(responses, 3)
        post_c = post.reshape(post.shape[:-2] + (-1, )).sum(-1)

        params_new = params + einsum('na,nc,no,n->naco', resp, post_c, obs,
                                     self.mask[t])
        pred = einsum('fcz,mfg,ncfm->nzgm', self.p_fcc, self.p_mff, post)

        if self.dyn_pref:
            self.lam += jnp.expand_dims(self.eta, -1) * obs

        return (pred, params_new)
Example #2
    def testOneHot(self):
        actual = nn.one_hot(np.array([0, 1, 2]), 3)
        expected = np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
        self.assertAllClose(actual, expected, check_dtypes=True)

        actual = nn.one_hot(np.array([1, 2, 0]), 3)
        expected = np.array([[0., 1., 0.], [0., 0., 1.], [1., 0., 0.]])
        self.assertAllClose(actual, expected, check_dtypes=True)
Example #3
    def testOneHot(self):
        actual = nn.one_hot(jnp.array([0, 1, 2]), 3)
        expected = jnp.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
        self.assertAllClose(actual, expected)

        actual = nn.one_hot(jnp.array([1, 2, 0]), 3)
        expected = jnp.array([[0., 1., 0.], [0., 0., 1.], [1., 0., 0.]])
        self.assertAllClose(actual, expected)
Example #4
    def testOneHotAxis(self):
        expected = jnp.array([[0., 1., 0.], [0., 0., 1.], [1., 0., 0.]]).T

        actual = nn.one_hot(jnp.array([1, 2, 0]), 3, axis=0)
        self.assertAllClose(actual, expected)

        actual = nn.one_hot(jnp.array([1, 2, 0]), 3, axis=-2)
        self.assertAllClose(actual, expected)
Example #5
    def transition_fn(carry, t):
        lam_prev, x_prev = carry

        gamma = npyro.deterministic('gamma', nn.softplus(mu + x_prev))

        U = jnp.log(lam_prev) - jnp.log(lam_prev.sum(-1, keepdims=True))

        logs = logits((beliefs[0][:, t], beliefs[1][:, t]),
                      jnp.expand_dims(gamma, -1), jnp.expand_dims(U, -2))

        lam_next = npyro.deterministic(
            'lams', lam_prev +
            nn.one_hot(beliefs[2][t], 4) * jnp.expand_dims(mask[t] * eta, -1))

        mixing_dist = dist.CategoricalProbs(weights)
        component_dist = dist.CategoricalLogits(logs.swapaxes(0, 1)).mask(
            mask[t][..., None])
        with npyro.plate('subjects', N):
            y = npyro.sample(
                'y', dist.MixtureSameFamily(mixing_dist, component_dist))
            noise = npyro.sample('dw', dist.Normal(0., 1.))

        x_next = rho * x_prev + sigma * noise

        return (lam_next, x_next), None
Example #6
        def sample_scan(params, tup, x):
            """ Perform single step update of the network """
            # Unpack GRU parameters: update gate, reset gate, candidate/output
            # gate, and the softmax output projection.
            (_, (update_W, update_U, update_b),
             (reset_W, reset_U, reset_b),
             (out_W, out_U, out_b),
             (sm_W, sm_b)) = params
            key, inp, logP, hidden = tup

            update_gate = sigmoid(
                np.dot(inp, update_W) + np.dot(hidden, update_U) + update_b)
            reset_gate = sigmoid(
                np.dot(inp, reset_W) + np.dot(hidden, reset_U) + reset_b)
            output_gate = np.tanh(
                np.dot(inp, out_W) +
                np.dot(np.multiply(reset_gate, hidden), out_U) + out_b)
            output = np.multiply(update_gate, hidden) + np.multiply(
                1 - update_gate, output_gate)
            hidden = output
            logits = np.dot(hidden, sm_W) + sm_b

            key, subkey = random.split(key)

            samples = random.categorical(
                subkey, logits, axis=1, shape=None)  # sampling the conditional
            samples = one_hot(
                samples, sm_b.shape[0])  # convert to one hot encoded vector
            log_P_new = np.sum(samples * log_softmax(logits), axis=1)
            log_P_new = log_P_new + logP  # update the value of the logP of the sample

            return (key, samples, log_P_new, output), samples
Example #7
    def get_boundaries(self, encodings, segment_id, lengths, training):
        """Get boundaries (b) for a single segment in batch."""
        if segment_id == self.max_num_segments - 1:
            # Last boundary is always placed on last sequence element.
            logits_b = None
            # sample_b = jnp.zeros_like(encodings[:, :, 0]).scatter_(
            #     1, jnp.expand_dims(lengths, -1) - 1, 1)
            sample_b = jnp.zeros_like(encodings[:, :, 0])
            sample_b = jax.ops.index_update(
                sample_b, jax.ops.index[jnp.arange(len(lengths)), lengths - 1],
                1)
        else:
            hidden = nn.relu(self.head_b_1(encodings))
            logits_b = jnp.squeeze(self.head_b_2(hidden), -1)
            # Mask out first position with large neg. value.
            neg_inf = jnp.ones((encodings.shape[0], 1)) * utils.NEG_INF
            # TODO(tkipf): Mask out padded positions with large neg. value.
            logits_b = jnp.concatenate([neg_inf, logits_b[:, 1:]], axis=1)
            if training:
                sample_b = utils.gumbel_softmax_sample(hk.next_rng_key(),
                                                       logits_b,
                                                       temp=self.temp_b)
            else:
                sample_b_idx = jnp.argmax(logits_b, axis=1)
                sample_b = nn.one_hot(sample_b_idx, logits_b.shape[1])

        return logits_b, sample_b
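
Note that jax.ops.index_update and jax.ops.index (used above and in several later examples) have been removed from recent JAX releases; an equivalent sketch of the same scatter with the current .at[...] indexed-update syntax:

# Equivalent indexed update with the current JAX API (sketch):
sample_b = sample_b.at[jnp.arange(len(lengths)), lengths - 1].set(1)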
Example #8
    def evaluate_batch(self, flax_module, batch_stats, batch):
        """Evaluates cross_entopy on the given batch."""

        # TODO(ankugarg): Augment with other metrics like log-perplexity.
        with nn.stateful(batch_stats, mutable=False):
            logits = flax_module(batch['inputs'],
                                 batch['targets'],
                                 batch.get('inputs_positions'),
                                 batch.get('targets_positions'),
                                 batch.get('inputs_segmentation'),
                                 batch.get('targets_segmentation'),
                                 train=False)

        weights = batch.get('weights')
        targets = batch['targets']
        if self.dataset_meta_data['apply_one_hot_in_loss']:
            targets = one_hot(batch['targets'], logits.shape[-1])

        # Add log-perplexity metric.
        evaluated_metrics = {}
        for key in self.metrics_bundle:
            per_example_metrics = self.metrics_bundle[key](logits, targets,
                                                           weights)
            evaluated_metrics[key] = jnp.sum(
                lax.psum(per_example_metrics, axis_name='batch'))

        return evaluated_metrics
Example #9
    def training_cost(self, flax_module, batch_stats, batch, dropout_rng):
        """Return cross entropy loss with (optional) L2 penalty on the weights."""

        with nn.stateful(batch_stats) as new_batch_stats:
            with nn.stochastic(dropout_rng):
                # inputs/targets positions and segmentations are required
                # when we have packed examples.
                logits = flax_module(batch['inputs'],
                                     batch['targets'],
                                     batch.get('inputs_positions'),
                                     batch.get('targets_positions'),
                                     batch.get('inputs_segmentation'),
                                     batch.get('targets_segmentation'),
                                     train=True)

        weights = batch.get('weights')
        targets = batch['targets']

        if self.dataset_meta_data['apply_one_hot_in_loss']:
            targets = one_hot(batch['targets'], logits.shape[-1])
        # Optionally apply label smoothing.
        if self.hps.get('label_smoothing') is not None:
            targets = model_utils.apply_label_smoothing(
                targets, self.hps.get('label_smoothing'))
        total_loss = self.loss_fn(logits, targets, weights)

        if self.hps.get('l2_decay_factor'):
            l2_loss = model_utils.l2_regularization(
                flax_module.params, self.hps.l2_decay_rank_threshold)
            total_loss += 0.5 * self.hps.l2_decay_factor * l2_loss
        return total_loss, (new_batch_stats)
Example #10
 def viterbi_backward(state, t):
     state = jnp.where(
         t <= length,
         jnp.sum(most_likely_sources[t] * one_hot(state, n_states)).astype(
             jnp.int64), state)
     most_likely = jnp.where(t <= length, state, -1)
     return state, most_likely
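
The sum-times-one_hot pattern above is a gather written as a dot product: multiplying a row by one_hot(state, n_states) and summing selects the entry at index state. A minimal standalone sketch of the idiom (my own illustration, not from the source):

import jax.numpy as jnp
from jax import nn

table = jnp.array([[10., 11., 12.],
                   [20., 21., 22.]])
state = 2
# Multiplying by a one-hot vector and summing picks out a single entry...
selected = jnp.sum(table[0] * nn.one_hot(state, 3))
# ...which matches direct indexing.
assert selected == table[0, state]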
Example #11
    def evaluate_batch(self, params, batch_stats, batch):
        """Evaluates cross_entopy on the given batch."""

        # TODO(ankugarg): Augment with other metrics like log-perplexity.
        logits = self.flax_module.apply(
            {
                'params': params,
                'batch_stats': batch_stats
            },
            batch['inputs'],
            batch['targets'],
            inputs_positions=batch.get('inputs_positions'),
            targets_positions=batch.get('targets_positions'),
            inputs_segmentation=batch.get('inputs_segmentation'),
            targets_segmentation=batch.get('targets_segmentation'),
            train=False)

        weights = batch.get('weights')
        targets = batch['targets']
        if self.dataset_meta_data['apply_one_hot_in_loss']:
            targets = one_hot(batch['targets'], logits.shape[-1])

        # Add log-perplexity metric.
        return self.metrics_bundle.gather_from_model_output(logits=logits,
                                                            targets=targets,
                                                            weights=weights,
                                                            axis_name='batch')
Example #12
def mace(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 3 of reference [1].
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("annotator", num_annotators):
        epsilon = numpyro.sample("epsilon",
                                 dist.Dirichlet(jnp.full(num_classes, 10)))
        theta = numpyro.sample("theta", dist.Beta(0.5, 0.5))

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample(
            "c",
            dist.DiscreteUniform(0, num_classes - 1),
            infer={"enumerate": "parallel"},
        )

        with numpyro.plate("position", num_positions):
            s = numpyro.sample(
                "s",
                dist.Bernoulli(1 - theta[positions]),
                infer={"enumerate": "parallel"},
            )
            probs = jnp.where(s[..., None] == 0, nn.one_hot(c, num_classes),
                              epsilon[positions])
            numpyro.sample("y", dist.Categorical(probs), obs=annotations)
Example #13
def MCMC_SpinFlip(h, e, seq_1hot, H, key):

    (L, q) = h.shape

    # advance RNG
    key, subkey_site, subkey_char, subkey_accept = random.split(key, 4)

    # choose site to flip
    i = random.choice(subkey_site, np.arange(L))

    # pick random character for chosen site
    a = random.choice(subkey_char, np.arange(q))

    # make proposal sequence
    seq_1hot_tmp = index_update(seq_1hot, index[i, :], nn.one_hot(a, q))

    H_tmp = Potts_ScoreSeqCore(h, e, seq_1hot_tmp)

    accept_prob = np.exp(H_tmp - H)

    flip = np.zeros((L, q), dtype=np.bool_)
    accept_draw = random.uniform(subkey_accept)
    flip = index_update(flip, index[i, :], accept_draw < accept_prob)

    seq_1hot_new = np.where(flip, seq_1hot_tmp, seq_1hot)
    H_new = np.where(accept_draw < accept_prob, H_tmp, H)
    return seq_1hot_new, H_new, key
Example #14
 def sample(self,
            rng: Generator,
            shape: Optional[Shape] = None) -> RealArray:
     if shape is not None:
         shape += self.shape
     return one_hot(
         jax.random.categorical(rng.key, self.log_odds, shape=shape),
         self.log_odds.shape[-1])
Example #15
 def loss(params, batch):
     inputs, targets = batch
     one_hot_targets = one_hot(targets, num_classes=num_symbols)
     preds, _ = transformer(params,
                            inputs,
                            out_seq_length=seq_length,
                            num_symbols=num_symbols)
     return -jnp.mean(
         jnp.sum(jnp.sum(preds * one_hot_targets, axis=-1), axis=-1))
Example #16
def cross_entropy_loss(outputs, actions):
    '''
    Calculates cross entropy loss (-sum(actions * log(outputs)))
    '''
    l = len(actions[0])
    for i in range(len(actions)):
        # one hot actions
        actions[i] = nn.one_hot(actions[i], l)
    return -1 * jnp.sum(actions * jnp.log(outputs))
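
For comparison, a vectorized sketch of the same formula that avoids the Python loop; this is my own variant (not from the source repository), assuming outputs holds probabilities of shape (batch, num_classes) and actions holds integer class indices of shape (batch,):

import jax.numpy as jnp
from jax import nn

def cross_entropy_loss_vectorized(outputs, actions):
    '''
    Hypothetical variant: -sum(one_hot(actions) * log(outputs)).
    '''
    one_hot_actions = nn.one_hot(actions, outputs.shape[-1])
    return -jnp.sum(one_hot_actions * jnp.log(outputs))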
Example #17
def transformer(params, inputs, out_seq_length, num_symbols):
    one_hot_inputs = one_hot(inputs, num_classes=num_symbols)

    one_hot_outputs = jnp.zeros(
        [inputs.shape[0], out_seq_length + 1, num_symbols + 1])
    #one_hot_outputs[:, 0, -1] = 1.  # start symbol
    one_hot_outputs = ops.index_update(one_hot_outputs, ops.index[:, 0, -1],
                                       1.)
    output_logits = jnp.zeros([inputs.shape[0], out_seq_length, num_symbols])

    (input_embeddings, output_embeddings, encoder_params_list,
     decoder_params_list, output_params) = params
    output_w, output_b = output_params

    encoder_results = jnp.dot(one_hot_inputs, input_embeddings)
    encoder_results += positional_encodings(encoder_results.shape[-2],
                                            encoder_results.shape[-1])
    for encoder_layer_i, this_enc_params in enumerate(encoder_params_list):
        encoder_results = encoder_layer(this_enc_params, encoder_results)

    for i in range(out_seq_length):
        decoder_results = jnp.dot(one_hot_outputs[:, :i + 1, :],
                                  output_embeddings)
        decoder_results += positional_encodings(decoder_results.shape[-2],
                                                decoder_results.shape[-1])
        for decoder_layer_i, this_dec_params in enumerate(decoder_params_list):
            decoder_results = decoder_layer(this_dec_params, decoder_results,
                                            encoder_results)

        this_step_results = jnp.dot(decoder_results, output_w) + output_b
        this_step_results = this_step_results[:, -1, :]  # previous outputs already produced

        output_logits = ops.index_update(output_logits, ops.index[:, i, :],
                                         this_step_results)
        one_hot_outputs = ops.index_update(
            one_hot_outputs, ops.index[:, i + 1, :],
            one_hot(jnp.argmax(this_step_results, axis=-1),
                    num_classes=num_symbols + 1))

    one_hot_outputs = one_hot_outputs[:, 1:, :-1]  # chop start symbol
    return output_logits, one_hot_outputs
Example #18
def _evaluate_batch(flax_module, batch_stats, batch, metrics_bundle,
                    apply_one_hot_in_loss):
    """Evaluates metrics on the given batch.

  Currently we assume each metric_fn in metrics_bundle has the API:
    metric_fn(logits, targets, weights)
  and returns an array of shape [batch_size]. We also assume that to compute
  the aggregate metric, one should sum across all batches, then divide by the
  total samples seen (calculated by the 'denominator' metric). In this way we
  currently only support metrics of the form 1/N sum f(inputs, targets). Note, the
  caller is responsible for dividing by metrics['denominator'] when computing
  the mean of each metric.

  Args:
    flax_module: A flax.nn.Module
    batch_stats: A flax.nn.Collection object tracking batch_stats.
    batch: A dictionary with keys 'inputs', 'targets', 'weights'.
    metrics_bundle: A group of metrics to use for evaluation.
    apply_one_hot_in_loss: Indicates whether or not the targets are one hot
      encoded.

  Returns:
    A dictionary with the same keys as metrics, but mapping to the summed metric
    across the sharded batch_dim.

  """
    with nn.stateful(batch_stats, mutable=False):
        logits = flax_module(batch['inputs'], train=False)
    targets = batch['targets']

    if apply_one_hot_in_loss:
        targets = one_hot(batch['targets'], logits.shape[-1])

    # map the dict values (which are functions) to function(targets, logits)
    weights = batch.get('weights')  # Weights might not be defined.
    eval_batch_size = targets.shape[0]
    if weights is None:
        weights = jnp.ones(eval_batch_size)

    # This psum is required to correctly evaluate with multihost. Only host 0
    # will report the metrics, so we must aggregate across all hosts. The psum
    # will map an array of shape [n_global_devices, batch_size] -> [batch_size]
    # by summing across the devices dimension. The outer sum then sums across the
    # batch dim. The result is that we have summed across all samples in the
    # sharded batch.

    evaluated_metrics = {}
    for key in metrics_bundle:
        per_example_metrics = metrics_bundle[key](logits, targets, weights)
        evaluated_metrics[key] = jnp.sum(
            lax.psum(per_example_metrics, axis_name='batch'))

    return evaluated_metrics
Example #19
    def transition_fn(carry, t):
        lam_prev = carry

        U = jnp.log(lam_prev) - jnp.log(lam_prev.sum(-1, keepdims=True))

        logs = logits((beliefs[0][t], beliefs[1][t]),
                      jnp.expand_dims(gamma, -1), jnp.expand_dims(U, -2))

        lam_next = npyro.deterministic(
            'lams', lam_prev +
            nn.one_hot(beliefs[2][t], 4) * jnp.expand_dims(mask[t] * eta, -1))

        npyro.sample('y', dist.CategoricalLogits(logs).mask(mask[t]))

        return lam_next, None
Example #20
    def _predict(self, params, base_preds, context, return_probs, target=None):
        # Base logits
        base_preds = jnp.clip(base_preds,
                              a_min=self.pred_clipping,
                              a_max=(1.0 - self.pred_clipping))
        logits = jsp.special.logit(base_preds)
        logits = jnp.expand_dims(logits, axis=1)
        if self.num_classes == 2:
            logits = jnp.tile(logits, reps=(1, 1, 1))
        else:
            logits = jnp.tile(logits, reps=(1, self.num_classes, 1))

        # Turn target class into one-hot
        if target is not None:
            target = jnn.one_hot(target, num_classes=self.num_classes)
            if self.num_classes == 2:
                target = target[:, 1:]

        # Layers
        if target is None:
            for n, layer in enumerate(self.layers):
                logits = layer.predict(params=params[f'layer{n}'],
                                       logits=logits,
                                       context=context)
        else:
            for n, layer in enumerate(self.layers):
                params[f'layer{n}'], logits = layer.predict(
                    params=params[f'layer{n}'],
                    logits=logits,
                    context=context,
                    target=target)

        logits = jnp.squeeze(logits, axis=-1)
        if self.num_classes == 2:
            logits = jnp.squeeze(logits, axis=1)

        # Output prediction
        if return_probs:
            prediction = jnn.sigmoid(logits)
        elif self.num_classes == 2:
            prediction = logits > 0.0
        else:
            prediction = jnp.argmax(logits, axis=1)

        if target is None:
            return prediction
        else:
            return params, prediction
Example #21
def image_iterator(data,
                   rescale,
                   output_shape,
                   is_one_hot,
                   autoencoder,
                   shuffle_rng=None,
                   augment_fn=None,
                   include_example_keys=False):
    """Preprocesses the batch data arrays in the data generator.

  Rescales inputs. One-hot encodes targets if is_one_hot is true.
  Sets targets to inputs reshaped to output_shape if autoencoder is true.

  Args:
    data: An iterator generating dicts of input and target data arrays.
    rescale: A lambda function preprocessing input data arrays.
    output_shape: Shape of network output.
    is_one_hot: If true, targets are one hot encoded.
    autoencoder: If true, targets are set to inputs.
    shuffle_rng: jax.random.PRNGKey
    augment_fn: Optional augmentation function, applied as augment_fn(rng, inputs, targets).
    include_example_keys: If True, then the tfds_id will be exposed in each
      batch dict of the validation set under the key `example_key`.

  Yields:
    A dictionary mapping keys ('image', 'label') to preprocessed data arrays.
  """
    for batch_index, batch in enumerate(iterator_as_numpy(iter(data))):
        inputs = batch['image']
        targets = batch['label']
        if is_one_hot:
            targets = one_hot(batch['label'], output_shape[-1])
        if augment_fn:
            batch_rng = jax.random.fold_in(shuffle_rng, batch_index)
            inputs, targets = augment_fn(batch_rng, inputs, targets)
        inputs = rescale(inputs)
        if autoencoder:
            batch_output_shape = tuple([inputs.shape[0]] + list(output_shape))
            targets = inputs.reshape(batch_output_shape)
        if include_example_keys:
            yield {
                'inputs': inputs,
                'targets': targets,
                'tfds_id': batch['tfds_id']
            }
        else:
            yield {'inputs': inputs, 'targets': targets}
Example #22
def MCMC_SeqEmit(h, e, key, nflip):
    (L, q) = h.shape

    seq_1hot = nn.one_hot(
        random.choice(key, np.arange(0, q), shape=(1, L))[0], q)
    H = Potts_ScoreSeqCore(h, e, seq_1hot)

    @jit
    def loop_fun_scan(loop_carry, i):
        h, e, seq_1hot, H, key = loop_carry
        seq_1hot, H, key = MCMC_SpinFlip(h, e, seq_1hot, H, key)
        return (h, e, seq_1hot, H, key), i

    (h, e, seq_1hot, H, key), i = lax.scan(loop_fun_scan,
                                           (h, e, seq_1hot, H, key),
                                           None,
                                           length=nflip)

    return seq_1hot, H, key
Example #23
    def transition_fn(carry, t):
        lam_prev, x_prev = carry

        gamma = npyro.deterministic('gamma', nn.softplus(mu + x_prev))

        U = jnp.log(lam_prev) - jnp.log(lam_prev.sum(-1, keepdims=True))

        logs = logits((beliefs[0][t], beliefs[1][t]),
                      jnp.expand_dims(gamma, -1), jnp.expand_dims(U, -2))

        lam_next = npyro.deterministic(
            'lams', lam_prev +
            nn.one_hot(beliefs[2][t], 4) * jnp.expand_dims(mask[t] * eta, -1))

        npyro.sample('y', dist.CategoricalLogits(logs).mask(mask[t]))
        noise = npyro.sample('dw', dist.Normal(0., 1.))

        x_next = rho * x_prev + sigma * noise

        return (lam_next, x_next), None
Example #24
def mace(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 3 of reference [1].
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("annotator", num_annotators):
        epsilon = numpyro.sample("epsilon", dist.Dirichlet(jnp.full(num_classes, 10)))
        theta = numpyro.sample("theta", dist.Beta(0.5, 0.5))

    with numpyro.plate("item", num_items, dim=-2):
        # NB: using constant logits for discrete uniform prior
        # (NumPyro does not have DiscreteUniform distribution yet)
        c = numpyro.sample("c", dist.Categorical(logits=jnp.zeros(num_classes)))

        with numpyro.plate("position", num_positions):
            s = numpyro.sample("s", dist.Bernoulli(1 - theta[positions]))
            probs = jnp.where(s[..., None] == 0, nn.one_hot(c, num_classes), epsilon[positions])
            numpyro.sample("y", dist.Categorical(probs), obs=annotations)
Example #25
    def training_cost(self, flax_module, batch_stats, batch, dropout_rng):
        """Return loss with an L2 penalty on the weights."""
        with nn.stateful(batch_stats) as new_batch_stats:
            with nn.stochastic(dropout_rng):
                logits = flax_module(batch['inputs'], train=True)
        weights = batch.get('weights')
        targets = batch['targets']
        if self.dataset_meta_data['apply_one_hot_in_loss']:
            targets = one_hot(targets, logits.shape[-1])
        # Optionally apply label smoothing.
        if self.hps.get('label_smoothing') is not None:
            targets = model_utils.apply_label_smoothing(
                targets, self.hps.get('label_smoothing'))
        total_loss = self.loss_fn(logits, targets, weights)

        if self.hps.get('l2_decay_factor'):
            l2_loss = model_utils.l2_regularization(
                flax_module.params, self.hps.l2_decay_rank_threshold)
            total_loss += 0.5 * self.hps.l2_decay_factor * l2_loss

        return total_loss, (new_batch_stats)
Example #26
def thompson_sampling_step(model_params, state, model, environment):
    """
    Contextual implementation of the Thompson sampling algorithm.
    This implementation performs a single step.

    Parameters
    ----------
    model_params: dict
    state: tuple of (jax.random.PRNGKey, context array)
    model: instance of a Bandit model
    environment: function
    """
    key, context = state
    key_sample, key_reward = random.split(key)
    # Sample model parameters and choose an action
    params = model.sample(key_sample, model_params, context)
    pred_rewards = model.predict_rewards(params, context)
    action = pred_rewards.argmax()
    # environment reward
    reward = environment(key_reward, action, context)
    model_params = model.update(action, model_params, context, reward)
    
    arm_reward = one_hot(action, K) * reward
    return model_params, (model_params, arm_reward)
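
A hypothetical driver for the step above (not from the source), assuming a Bandit model, an environment function, initial init_params, a (T, context_dim) array contexts, and that K (the number of arms) is defined in the enclosing scope:

from functools import partial
from jax import lax, random

keys = random.split(random.PRNGKey(0), contexts.shape[0])
step = partial(thompson_sampling_step, model=model, environment=environment)
final_params, (params_history, arm_rewards) = lax.scan(
    step, init_params, (keys, contexts))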
Example #27
    def training_cost(self, params, batch, batch_stats=None, dropout_rng=None):
        """Return cross entropy loss with (optional) L2 penalty on the weights."""

        # inputs/targets positions and segmentations are required when we have
        # packed examples.
        logits, new_batch_stats = self.flax_module.apply(
            {
                'params': params,
                'batch_stats': batch_stats
            },
            batch['inputs'],
            batch['targets'],
            batch.get('inputs_positions'),
            batch.get('targets_positions'),
            batch.get('inputs_segmentation'),
            batch.get('targets_segmentation'),
            mutable=['batch_stats'],
            rngs={'dropout': dropout_rng},
            train=True)

        weights = batch.get('weights')
        targets = batch['targets']

        if self.dataset_meta_data['apply_one_hot_in_loss']:
            targets = one_hot(batch['targets'], logits.shape[-1])
        # Optionally apply label smoothing.
        if self.hps.get('label_smoothing') is not None:
            targets = model_utils.apply_label_smoothing(
                targets, self.hps.get('label_smoothing'))
        total_loss = self.loss_fn(logits, targets, weights)

        if self.hps.get('l2_decay_factor'):
            l2_loss = model_utils.l2_regularization(
                params, self.hps.l2_decay_rank_threshold)
            total_loss += 0.5 * self.hps.l2_decay_factor * l2_loss
        return total_loss, (new_batch_stats)
Example #28
 def testOneHotCustomDtype(self):
     actual = nn.one_hot(jnp.array([0, 1, 2]), 3, dtype=jnp.bool_)
     expected = jnp.array([[True, False, False], [False, True, False],
                           [False, False, True]])
     self.assertAllClose(actual, expected)
Example #29
 def testOneHotNonArrayInput(self):
     actual = nn.one_hot([0, 1, 2], 3)
     expected = jnp.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
     self.assertAllClose(actual, expected)
Example #30
 def testOneHotOutOfBound(self):
     actual = nn.one_hot(jnp.array([-1, 3]), 3)
     expected = jnp.array([[0., 0., 0.], [0., 0., 0.]])
     self.assertAllClose(actual, expected)
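
Because out-of-range indices map to all-zero rows (as this test shows), one_hot can be used to drop padded positions from a loss; a minimal sketch, assuming padding labels are marked with -1:

import jax.numpy as jnp
from jax import nn

def padded_cross_entropy(logits, labels):
    # labels == -1 mark padding; nn.one_hot maps them to all-zero rows,
    # so those positions contribute nothing to the sum.
    targets = nn.one_hot(labels, logits.shape[-1])
    return -jnp.sum(targets * nn.log_softmax(logits))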