def preprocess_causal(batch, bos_token, pad_token, mode):
  """Preprocessing for causal language modeling. Right shifts and shards.

  Builds (inputs, targets, weights) for next-token prediction. Targets are
  the batch itself; inputs are the batch shifted right with a BOS token
  prepended (teacher forcing), except at sampling time where the raw batch
  is fed directly. No random token masking is performed.

  Args:
    batch: [batch x length] tokens.
    bos_token: Int ID to use as beginning of sentence token.
    pad_token: Padding token which should be masked out in loss. If None,
      all positions receive weight 1.
    mode: Mode value; Mode.sample skips the right shift.

  Returns:
    Tuple of [batch x length] inputs, targets, per position weights.
    Weights are 0 at positions equal to pad_token and 1 elsewhere.
  """
  if mode == Mode.sample:
    # At sampling time the tokens are consumed as-is; the shift happens
    # implicitly as the model generates one position at a time.
    inputs = batch
  else:
    # Teacher forcing: predict position t from tokens < t.
    inputs = modules.shift_right(batch, bos_token=bos_token)
  targets = batch
  # Mask out PAD in loss.
  if pad_token is None:
    weights = jnp.ones_like(targets)
  else:
    weights = jnp.where(targets != pad_token, 1, 0)
  return inputs, targets, weights
def eval_step(model, inputs, bos_token):
  """Runs one evaluation step of a causal LM and returns metrics.

  Args:
    model: Callable model; invoked with train=False and no cache.
    inputs: [batch x length] token batch; also serves as the targets.
    bos_token: Int ID prepended when right-shifting the inputs.

  Returns:
    Metrics as produced by utils.compute_metrics over the logits.
  """
  # NOTE(review): positions equal to bos_token get zero loss weight —
  # presumably bos_token doubles as padding here; confirm with the caller.
  loss_weights = jnp.where(inputs == bos_token, 0, 1)
  targets = inputs
  # Shift right before feeding the model (done here at test time).
  shifted = modules.shift_right(targets, bos_token=bos_token)
  logits = model(shifted, train=False, cache=None)
  return utils.compute_metrics(logits, targets, loss_weights)
def predict_step(model, inputs, bos_token, output_head='logits'):
  """Single forward pass for prediction (no sampling, no cache).

  Args:
    model: Callable model; invoked with train=False and no cache.
    inputs: [batch x length] token batch.
    bos_token: Int ID prepended when right-shifting the inputs.
    output_head: Which model output head to return (default 'logits').

  Returns:
    The selected model output head for the shifted inputs.
  """
  # Shift right before feeding the model (done here at test time).
  shifted = modules.shift_right(inputs, bos_token=bos_token)
  return model(shifted, train=False, cache=None, output_head=output_head)