Example #1
0
def preprocess_causal(batch, bos_token, pad_token, mode):
  """Builds model inputs, targets and loss weights for causal LM.

  Args:
    batch: [batch x length] tokens.
    bos_token: Int ID forwarded to `modules.shift_right` as `bos_token`.
    pad_token: Token ID that receives zero weight in the loss, or None to
      give every position a weight of one.
    mode: Mode value. When equal to `Mode.sample` the batch is used as the
      inputs unchanged; otherwise the inputs are the result of
      `modules.shift_right(batch, bos_token=bos_token)`.

  Returns:
    Tuple `(inputs, targets, weights)`. `targets` is `batch` itself, and
      `weights` has the same shape as `targets`, holding 1 at positions that
      count towards the loss and 0 at positions equal to `pad_token`.
  """
  targets = batch

  # Sample mode consumes the tokens as given; all other modes feed the model
  # the output of shift_right.
  inputs = (
      batch if mode == Mode.sample
      else modules.shift_right(batch, bos_token=bos_token))

  # Zero out PAD positions in the loss when a pad token is configured.
  if pad_token is not None:
    weights = jnp.where(targets != pad_token, 1, 0)
  else:
    weights = jnp.ones_like(targets)

  return inputs, targets, weights
Example #2
0
def eval_step(model, inputs, bos_token):
    """Runs `model` in non-training mode on `inputs` and computes metrics.

    Args:
      model: Callable invoked as `model(tokens, train=False, cache=None)`.
      inputs: Token array; used unshifted as the prediction targets.
      bos_token: Token ID forwarded to `modules.shift_right`; target positions
        equal to it get zero weight.

    Returns:
      Whatever `utils.compute_metrics(logits, targets, weights)` returns.
    """
    targets = inputs
    # Positions holding bos_token get weight 0, every other position weight 1.
    target_weights = jnp.where(targets != bos_token, 1, 0)
    # The model is fed the shift_right output rather than the raw tokens.
    shifted = modules.shift_right(targets, bos_token=bos_token)
    logits = model(shifted, train=False, cache=None)
    return utils.compute_metrics(logits, targets, target_weights)
Example #3
0
def predict_step(model, inputs, bos_token, output_head='logits'):
    """Runs `model` in non-training mode and returns its raw output.

    Args:
      model: Callable invoked as
        `model(tokens, train=False, cache=None, output_head=output_head)`.
      inputs: Token array; passed through `modules.shift_right` before being
        handed to the model.
      bos_token: Token ID forwarded to `modules.shift_right`.
      output_head: Forwarded unchanged to the model; defaults to 'logits'.

    Returns:
      The value returned by the model call.
    """
    # The model is fed the shift_right output rather than the raw tokens.
    shifted = modules.shift_right(inputs, bos_token=bos_token)
    return model(shifted, train=False, cache=None, output_head=output_head)