Example No. 1
def f(model_output, targets, weights):  # pylint: disable=invalid-name
    predictions = jnp.argmax(model_output, axis=-1)
    shapes.assert_same_shape(predictions, targets)
    position_is_padding = jnp.equal(weights, 0)
    position_is_accurate = jnp.logical_or(jnp.equal(predictions, targets),
                                          position_is_padding)
    sequence_is_accurate = jnp.all(position_is_accurate, axis=-1)
    return jnp.average(sequence_is_accurate)
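As written, `f` expects `jnp` and `shapes` from trax to be in scope. A minimal toy call, sketched with made-up shapes and values, shows that only sequences correct at every non-padding position count:

from trax import shapes            # provides shapes.assert_same_shape used above
from trax.fastmath import numpy as jnp

# Toy batch (made up): 2 sequences of length 3, vocabulary of 2; weight 0 marks padding.
model_output = jnp.array([[[0.1, 0.9], [0.8, 0.2], [0.5, 0.5]],
                          [[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]]])
targets = jnp.array([[1, 0, 0],
                     [1, 1, 0]])
weights = jnp.array([[1., 1., 0.],
                     [1., 1., 0.]])
print(f(model_output, targets, weights))  # 0.5 -- only the first sequence is fully correct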
Example No. 2
 def forward(self, x):
     rng = self.rng
     base_weights, start_vec = self.weights
     batch_size, length = x.shape[0], x.shape[1]
     max_pos = min(self._bases)**self._n_digits
     rng1, rng2, rng3 = fastmath.random.split(rng, 3)
     assert length < max_pos, 'length (%d) >= max_pos (%d)' % (length,
                                                               max_pos)
     positions = jnp.arange(0, length)[None, :]
     # In training we'll randomize starts for better generalization.
     # We use the trainable start_vec to compensate and give model a way
     # to learn what is the starting position in a sequence.
     if self._mode == 'train':
         # In 1% of training cases still start from 0 to be exactly as in eval.
         start_from_nonzero = fastmath.random.randint(
             rng1, (batch_size, ), 0, self._start_from_zero_one_in)
         start_from_nonzero = jnp.minimum(1, start_from_nonzero)
         random_start = fastmath.random.randint(rng2, (batch_size, ), 0,
                                                max_pos - length)
         random_start *= start_from_nonzero
         positions += random_start[:, None]
     if self._mode == 'predict':
         positions += self.state
     res = []
     for bn, base in enumerate(self._bases):
         pos_embeddings = []
         cur_positions = positions
         for i in range(self._n_digits):
             cur_indices = jnp.mod(cur_positions, base)
             cur_positions = cur_positions // base
             s = base_weights[bn][i]
             pos_embeddings.append(
                 cur_indices.astype(jnp.float32)[:, :, None] * s)
         embeddings = jnp.concatenate(pos_embeddings, axis=-1)
         if self._mode == 'train':
             base_dropout = fastmath.random.randint(
                 rng3, (batch_size, ), 0, self._base_dropout_one_in)
             base_dropout = jnp.minimum(1, base_dropout).astype(jnp.float32)
             embeddings *= base_dropout[:, None, None]
         res.append(embeddings)
     res = sum(res)  # Sum embeddings from all bases.
     # Add start_vec to the first position only to mark it as starting.
     res0 = res[:, 0, :][:, None, :]
     start_pos = res0 + start_vec
     if self._mode == 'predict':
         start_pos = jnp.where(jnp.equal(self.state, 0), start_pos, res0)
         self.state += length  # Add input length to state.
     res = jnp.concatenate([start_pos, res[:, 1:, :]], axis=1)
     return x + res
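The heart of this forward pass is the inner loop that writes each position as digits in several bases and embeds each digit separately. Below is a minimal sketch of just that decomposition; the helper name `digits_of_positions` and the toy base/digit counts are made up for illustration, whereas the layer above takes them from `self._bases` and `self._n_digits`:

from trax.fastmath import numpy as jnp

def digits_of_positions(positions, base, n_digits):
    # Collect the n_digits least-significant digits of each position in `base`,
    # mirroring the loop over `self._n_digits` in the forward pass above.
    digits = []
    cur = positions
    for _ in range(n_digits):
        digits.append(jnp.mod(cur, base))  # current least-significant digit
        cur = cur // base                  # shift one digit to the right
    return jnp.stack(digits, axis=-1)

positions = jnp.arange(0, 6)[None, :]      # one batch of positions 0..5 (toy values)
print(digits_of_positions(positions, base=2, n_digits=3)[0])
# rows are positions 0..5 written in base 2, least-significant digit first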
Example No. 3
def test_model(preds, target):
    """Function to test the model.

    Args:
        preds (jax.interpreters.xla.DeviceArray): Predictions of a list of batches of tensors corresponding to lines of text.
        target (jax.interpreters.xla.DeviceArray): Actual list of batches of tensors corresponding to lines of text.

    Returns:
        float: log_perplexity of the model.
    """
    ### START CODE HERE ###
    # Log-probability of the correct token at each position: one-hot the targets
    # and use them to select entries of `preds` (assumed to be log-probabilities).
    total_log_ppx = np.sum(preds * tl.one_hot(target, preds.shape[-1]), axis=-1)

    # Padding tokens are encoded as 0, so build a mask that is 1 on real tokens.
    non_pad = 1.0 - np.equal(target, 0)
    ppx = total_log_ppx * non_pad  # get rid of the padding contributions

    # Average log-probability over the non-padding positions only.
    log_ppx = np.sum(ppx) / np.sum(non_pad)
    ### END CODE HERE ###

    return -log_ppx
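A quick sanity check on the masking is to call it on a tiny hand-made batch. Here `np` and `tl` are assumed to be `trax.fastmath.numpy` and `trax.layers`, as elsewhere in these examples, and the values are made up:

from trax import layers as tl
from trax.fastmath import numpy as np

# Toy input: one sequence of length 4 over a 3-token vocabulary; token 0 is padding.
preds = np.log(np.array([[[0.1, 0.8, 0.1],
                          [0.2, 0.2, 0.6],
                          [1/3, 1/3, 1/3],
                          [1/3, 1/3, 1/3]]]))  # log-probabilities, as the model would emit
target = np.array([[1, 2, 0, 0]])              # the last two positions are padding

print(test_model(preds, target))  # -(log 0.8 + log 0.6) / 2, roughly 0.367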
Example No. 4
def f(model_output, targets):  # pylint: disable=invalid-name
    predictions = jnp.argmax(model_output, axis=-1)
    shapes.assert_same_shape(predictions, targets)
    n_total = predictions.size
    n_correct = jnp.sum(jnp.equal(predictions, targets))
    return n_correct / n_total
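With `jnp` and `shapes` imported from trax, a toy call (values made up) shows that every position enters the count, since there is no padding mask in this variant:

from trax import shapes
from trax.fastmath import numpy as jnp

model_output = jnp.array([[[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]]])  # 1 sequence, length 3
targets = jnp.array([[0, 1, 1]])
print(f(model_output, targets))  # 2/3 -- two of the three positions match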
Example No. 5
def f(predicted_category, target_category):  # pylint: disable=invalid-name
    # TODO(pkozakowski): This assertion breaks some tests. Fix and uncomment.
    # shapes.assert_same_shape(predicted_category, target_category)
    return jnp.equal(predicted_category,
                     target_category).astype(jnp.float32)
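Unlike the aggregated metrics above, this one returns a value per example rather than a single number. A toy call with made-up category ids:

from trax.fastmath import numpy as jnp

print(f(jnp.array([3, 1, 2]), jnp.array([3, 0, 2])))  # [1. 0. 1.] -- one hit/miss per example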
Example No. 6
def f(model_output, targets, weights):  # pylint: disable=invalid-name
    predictions = jnp.argmax(model_output, axis=-1)
    shapes.assert_same_shape(predictions, targets)
    ones_and_zeros = jnp.equal(predictions, targets)
    return jnp.sum(ones_and_zeros * weights) / jnp.sum(weights)
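Reusing the toy batch from Example No. 4, a weight of 0 on the last position removes it from both the numerator and the denominator (again a made-up sketch, with `jnp` and `shapes` from trax in scope):

from trax import shapes
from trax.fastmath import numpy as jnp

model_output = jnp.array([[[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]]])
targets = jnp.array([[0, 1, 1]])
weights = jnp.array([[1., 1., 0.]])       # mask out the last position
print(f(model_output, targets, weights))  # 1.0, versus 2/3 without the mask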
Example No. 7
# Cast to jax.interpreters.xla.DeviceArray
predictions = np.array(predictions)
targets = np.array(targets)

reshaped_targets = tl.one_hot(targets, predictions.shape[-1])
# trax's one_hot function takes the input as one_hot(x, n_categories, dtype=optional)
print(f'reshaped_targets has shape: {reshaped_targets.shape}')

# Total Log Perplexity
total_log_ppx = np.sum(predictions * reshaped_targets, axis=-1)

# Now you need to account for the padding so this metric is not artificially
# deflated (a lower perplexity means a better model). To identify which elements
# are padding and which are not, use np.equal(targets, 0) and subtract the result
# from 1 to get a tensor with 1s at the positions of actual values and 0s at the
# padding positions.

non_pad = 1.0 - np.equal(targets, 0)
print(f'non_pad has shape: {non_pad.shape}\n')
print(f'non_pad looks like this: \n\n {non_pad}')

real_log_ppx = total_log_ppx * non_pad
print(f'real perplexity still has shape: {real_log_ppx.shape}')

# Expected output: real perplexity still has shape: (32, 64)
print(f'log perplexity tensor before filtering padding: \n\n {total_log_ppx}\n')
print(f'log perplexity tensor after filtering padding: \n\n {real_log_ppx}')

log_ppx = np.sum(real_log_ppx) / np.sum(non_pad)
log_ppx = -log_ppx
print(f'The log perplexity and perplexity of the model are respectively: {log_ppx} and {np.exp(log_ppx)}')
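As a sanity check on this pipeline, a model that assigns uniform probability 1/V to every token should come out with log perplexity log(V) and perplexity exactly V. The tensors below are made up for that check (V = 4, one sequence with a single padding position):

from trax import layers as tl
from trax.fastmath import numpy as np

V = 4
targets = np.array([[2, 1, 3, 0]])                 # token 0 is padding
predictions = np.log(np.full((1, 4, V), 1.0 / V))  # uniform log-probabilities

total_log_ppx = np.sum(predictions * tl.one_hot(targets, V), axis=-1)
non_pad = 1.0 - np.equal(targets, 0)
log_ppx = -np.sum(total_log_ppx * non_pad) / np.sum(non_pad)
print(log_ppx, np.exp(log_ppx))                    # log(4) is about 1.386, and exp gives 4.0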