Ejemplo n.º 1
0
 def cross_entropy_loss(logits, labels):
     """Averages two heads' softmax cross-entropy and returns the mean.

     `logits` and `labels` are indexed in pairs: element 0 feeds the first
     head and element 1 the second (presumably span start and end). Each head
     is scored against one-hot labels over `num_labels` classes, a name
     resolved from the enclosing scope.
     """
     def _head_loss(head):
         targets = onehot(labels[head], num_classes=num_labels)
         return optax.softmax_cross_entropy(logits[head], targets)

     per_item = (_head_loss(0) + _head_loss(1)) / 2.0
     return jnp.mean(per_item)
Ejemplo n.º 2
0
    def compute_metrics(
        masked_lm_logits: jnp.ndarray,
        next_sentence_logits: jnp.ndarray,
        masked_lm_labels: jnp.ndarray,
        masked_lm_weights: jnp.ndarray,
        next_sentence_labels: jnp.ndarray,
    ):
        """Computes the pre-training loss and its components.

        Returns:
          Dict with "masked_lm_loss" (negative log-likelihood averaged with
          `masked_lm_weights` as weights), "next_sentence_loss" (mean negative
          log-likelihood) and "loss" (their sum).
        """
        mlm_log_probs = nn.log_softmax(masked_lm_logits)
        vocab_size = mlm_log_probs.shape[-1]
        mlm_targets = onehot(masked_lm_labels.reshape((-1, )), vocab_size)
        mlm_weights = masked_lm_weights.reshape((-1, ))
        mlm_token_ll = jnp.sum(mlm_log_probs * mlm_targets, axis=-1)
        # Weighted average: a zero weight removes a position from both the
        # numerator and the denominator.
        mlm_loss = -jnp.sum(mlm_token_ll * mlm_weights) / jnp.sum(mlm_weights)

        nsp_log_probs = nn.log_softmax(next_sentence_logits)
        nsp_labels = next_sentence_labels.reshape((-1, ))
        nsp_targets = onehot(nsp_labels, nsp_log_probs.shape[-1])
        nsp_loss = -jnp.mean(jnp.sum(nsp_targets * nsp_log_probs, axis=-1))

        return {
            "loss": mlm_loss + nsp_loss,
            "masked_lm_loss": mlm_loss,
            "next_sentence_loss": nsp_loss,
        }
Ejemplo n.º 3
0
    def apply(self,
              input_ids,
              input_mask,
              type_ids,
              labels=None,
              *,
              config,
              n_classes,
              deterministic=False):
        """Applies BERT for sequence classification.

        Returns:
          The classification logits when `labels` is None. Otherwise
          {"loss": ...}: mean squared error when the projection has a single
          output (regression), else mean softmax cross-entropy.
        """
        _, pooled = BertModel(
            input_ids,
            input_mask,
            type_ids,
            config=config,
            deterministic=deterministic,
            name="bert")
        # TODO(kitaev): I think I'm missing dropout here
        logits = layers.OutputProjection(pooled,
                                         n_out=n_classes,
                                         kernel_init=kernel_initializer,
                                         name="classification")
        if labels is None:
            return logits

        if logits.shape[-1] == 1:
            # Regression task: squared error against the single output.
            squared_error = (logits[..., 0] - labels)**2
            return {"loss": jnp.mean(squared_error)}

        # Classification task.
        log_probs = nn.log_softmax(logits)
        targets = onehot(labels, log_probs.shape[-1])
        return {"loss": -jnp.mean(jnp.sum(targets * log_probs, axis=-1))}
Ejemplo n.º 4
0
def sample(inputs, optimizer, *, num_steps=200, hidden_size=512,
           num_layers=3):
    """Greedily decodes `num_steps` steps from `optimizer.target`.

    Each step runs the model once, records the argmax of its output over the
    last axis, and feeds the one-hot encoding of the argmax back in as the
    next input.

    Args:
      inputs: initial model input.
      optimizer: object whose `.target` is the model to call.
      num_steps: number of decoding steps (defaults to the previous 200).
      hidden_size: size passed to `nn.LSTMCell.initialize_carry` for each
        carry (defaults to the previous 512).
      num_layers: number of carries passed to the model as `carry_pred`; must
        match what the model expects (defaults to the previous 3).

    Returns:
      A list with one argmax array per step.
    """
    batch_size = 1
    # Every carry is initialized identically from the same fixed key.
    carry = [
        nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0), (batch_size, ),
                                     hidden_size)
        for _ in range(num_layers)
    ]
    model = optimizer.target
    next_inputs = inputs
    output = []
    for _ in range(num_steps):
        carry, rnn_output = model(inputs=next_inputs,
                                  train=False,
                                  carry_pred=carry)
        output.append(jnp.argmax(rnn_output, axis=-1))
        # Select the argmax as the next input. jnp.argmax without an axis
        # works on the flattened output; this matches the per-row argmax only
        # when the output holds a single row (batch_size is 1 here).
        next_inputs = jnp.expand_dims(
            common_utils.onehot(jnp.argmax(rnn_output),
                                params['vocab_length']),
            axis=0)
    return output
Ejemplo n.º 5
0
    def __call__(
        self,
        input_ids: jnp.ndarray,
        input_mask: jnp.ndarray,
        type_ids: jnp.ndarray,
        labels: jnp.ndarray = None,
        *,
        deterministic: bool = False,
    ):
        """Applies BERT for sequence classification.

        The pooled BERT output passes through dropout (rate
        `config.hidden_dropout_prob`) and an output projection with
        `n_classes` outputs.

        Returns:
          The logits when `labels` is None. Otherwise {"loss": ...} with the
          mean squared error for a single-output head (regression) or the mean
          softmax cross-entropy (classification).
        """
        cfg = self.config
        _, pooled = BertModel(config=cfg, name="bert")(
            input_ids, input_mask, type_ids, deterministic=deterministic)
        pooled = nn.Dropout(rate=cfg.hidden_dropout_prob,
                            deterministic=deterministic)(pooled)
        projection = layers.OutputProjection(
            n_out=self.n_classes,
            kernel_init=get_kernel_init(cfg),
            name="classification",
        )
        logits = projection(pooled)

        if labels is None:
            return logits
        if logits.shape[-1] == 1:
            # Regression task
            return {"loss": jnp.mean((logits[..., 0] - labels)**2)}
        # Classification task
        log_probs = nn.log_softmax(logits)
        per_example = jnp.sum(onehot(labels, log_probs.shape[-1]) * log_probs,
                              axis=-1)
        return {"loss": -jnp.mean(per_example)}
Ejemplo n.º 6
0
  def loss_fn(params):
    """loss function used for training.

    Runs the Transformer with `params` (dropout driven by `dropout_rng`),
    scores the logits against label-smoothed one-hot `targets`, subtracts the
    entropy of the smoothed target distribution (its minimum attainable
    value), weights each position by `weights` and divides by their sum.

    Returns:
      (mean_loss, logits)
    """
    logits = models.Transformer(config).apply(
        {"params": params},
        inputs,
        targets,
        inputs_positions=inputs_positions,
        targets_positions=targets_positions,
        inputs_segmentation=inputs_segmentation,
        targets_segmentation=targets_segmentation,
        rngs={"dropout": dropout_rng})

    num_classes = logits.shape[-1]
    on_value = 1.0 - label_smoothing
    off_value = (1.0 - on_value) / (num_classes - 1)
    # Entropy of the smoothed targets; the 1e-20 keeps log() finite when
    # off_value is 0.
    target_entropy = -(
        on_value * jnp.log(on_value) +
        (num_classes - 1) * off_value * jnp.log(off_value + 1e-20))
    smoothed_targets = common_utils.onehot(
        targets, num_classes, on_value=on_value, off_value=off_value)

    token_loss = -jnp.sum(smoothed_targets * nn.log_softmax(logits), axis=-1)
    token_loss = (token_loss - target_entropy) * weights
    return token_loss.sum() / weights.sum(), logits
Ejemplo n.º 7
0
def cross_entropy(logits, targets, weights=None, label_smoothing=0.0):
    """Label-smoothed cross entropy summed over all positions.

    Args:
     logits: [batch, length, num_classes] float array.
     targets: categorical targets [batch, length] int array.
     weights: None or array of shape [batch, length]
     label_smoothing: label smoothing constant, used to determine the on and off values.
    Returns:
      Tuple of (summed loss, normalizing factor). The factor is the sum of
      `weights` when given, otherwise the number of target elements.
    Raises:
      ValueError: if logits.ndim != targets.ndim + 1.
    """
    if logits.ndim != targets.ndim + 1:
        raise ValueError(
            "Incorrect shapes. Got shape %s logits and %s targets" % (str(logits.shape), str(targets.shape))
        )

    num_classes = logits.shape[-1]
    on_value = 1.0 - label_smoothing
    off_value = (1.0 - on_value) / (num_classes - 1)
    # Entropy of the smoothed target distribution (1e-20 keeps log() finite).
    target_entropy = -(
        on_value * jnp.log(on_value) + (num_classes - 1) * off_value * jnp.log(off_value + 1e-20)
    )
    smoothed = common_utils.onehot(targets, num_classes, on_value=on_value, off_value=off_value)

    per_token = -jnp.sum(smoothed * log_softmax(logits), axis=-1) - target_entropy

    if weights is None:
        return per_token.sum(), np.prod(targets.shape)
    return (per_token * weights).sum(), weights.sum()
Ejemplo n.º 8
0
def compute_weighted_cross_entropy(logits, targets, weights=None):
    """Compute weighted cross entropy and entropy for log probs and targets.

    Args:
      logits: `[batch, length, num_classes]` float array. If its length axis
        is longer than the targets', it is truncated to match.
      targets: categorical targets `[batch, length]` int array.
      weights: None, or an array of shape [batch, length] or [batch, length, 1]
        (a trailing singleton axis is dropped before weighting).

    Returns:
      Tuple of scalar loss and batch normalizing factor (sum of `weights`, or
      the number of target elements when `weights` is None).

    Raises:
      ValueError: if logits.ndim != targets.ndim + 1.
    """
    if logits.ndim != targets.ndim + 1:
        raise ValueError(
            'Incorrect shapes. Got shape %s logits and %s targets' %
            (str(logits.shape), str(targets.shape)))
    if logits.shape[1] != targets.shape[1]:  # Truncate logits.
        logits = logits[:, :targets.shape[1]]

    onehot_targets = common_utils.onehot(targets, logits.shape[-1])
    loss = -jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)
    normalizing_factor = jnp.prod(jnp.asarray(targets.shape))
    if weights is not None:
        if weights.ndim == loss.ndim + 1 and weights.shape[-1] == 1:
            # `loss` is [batch, length]; multiplying it by [batch, length, 1]
            # weights would broadcast to a wrong shape (or fail), so drop the
            # trailing singleton axis first.
            weights = weights[..., 0]
        loss = loss * weights
        normalizing_factor = weights.sum()

    return loss.sum(), normalizing_factor
    def test_small_byt5_integration_test(self):
        """Compares the Flax ByT5-small score of a fixed pair to a reference.

        For comparison run:
        >>> import t5  # pip install t5==0.9.1

        >>> path_to_byt5_small_checkpoint = '<fill_in>'
        >>> t5_model = t5.models.MtfModel(model_dir=path_to_byt5_small_checkpoint, batch_size=1, tpu=None)
        >>> vocab = t5.data.ByteVocabulary()
        >>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
        """

        model = FlaxT5ForConditionalGeneration.from_pretrained(
            "google/byt5-small")
        tokenizer = ByT5Tokenizer.from_pretrained("google/byt5-small")

        input_ids = tokenizer("Hello there", return_tensors="np").input_ids
        labels = tokenizer("Hi I am", return_tensors="np").input_ids

        decoder_input_ids = shift_tokens_right(
            labels, model.config.pad_token_id,
            model.config.decoder_start_token_id)

        logits = model(input_ids, decoder_input_ids=decoder_input_ids).logits
        loss = optax.softmax_cross_entropy(logits,
                                           onehot(labels,
                                                  logits.shape[-1])).mean()

        # `loss` is a mean over the label tokens (batch size 1). Scaling by
        # the label length and negating gives a summed log-likelihood,
        # presumably the quantity MtfModel.score reports (EXPECTED_SCORE was
        # taken from it).
        mtf_score = -(labels.shape[-1] * loss.item())

        EXPECTED_SCORE = -60.7397
        # assertLess reports the actual difference on failure; assertTrue
        # would only report "False is not true".
        self.assertLess(abs(mtf_score - EXPECTED_SCORE), 1e-4)
Ejemplo n.º 10
0
def compute_contrastive_loss(
    quantized_features, transformer_features, negative_indices, mask_time_indices, logits_temp, num_negatives
):
    """Temperature-scaled contrastive loss over a positive and sampled negatives.

    Candidate 0 is the true quantized feature. Candidates 1..num_negatives are
    gathered from the flattened quantized features with `negative_indices`.
    The logits are the cosine similarities between `transformer_features` and
    each candidate, divided by `logits_temp`. A negative that exactly equals
    the positive is pushed to -1e9. A time step contributes only when
    `mask_time_indices` is 1 there (assuming 0/1 values): its target is class 0.
    Every other step gets target -100 and a zero mask.

    Returns:
      The summed (not averaged) contrastive loss.
    """
    batch_size, sequence_length, hidden_size = quantized_features.shape

    # Gather negatives into (num_negatives, batch, seq, hidden).
    flat_features = quantized_features.reshape(-1, hidden_size)
    negatives = flat_features[negative_indices.reshape(-1)]
    negatives = negatives.reshape(batch_size, sequence_length, num_negatives, hidden_size)
    negatives = negatives.transpose(2, 0, 1, 3)

    candidates = jnp.concatenate([quantized_features[None, :], negatives], axis=0)
    logits = optax.cosine_similarity(transformer_features, candidates) / logits_temp

    # A sampled negative identical to the positive must not act as a negative.
    duplicates = (quantized_features == negatives).all(-1)
    positive_row = jnp.full((1,) + logits.shape[1:], False)
    duplicates = jnp.concatenate([positive_row, duplicates], axis=0)
    logits = jnp.where(duplicates, -1e9, logits)

    predictions = logits.transpose(2, 1, 0).reshape(-1, logits.shape[0])
    targets = ((1 - mask_time_indices) * -100).transpose(1, 0).flatten()
    target_mask = jnp.where(targets >= 0, 1.0, 0.0)

    per_step = optax.softmax_cross_entropy(predictions, onehot(targets, predictions.shape[-1]))
    return (per_step * target_mask).sum()
Ejemplo n.º 11
0
    def eval_step(params, batch):
        """Evaluates one batch: loss, accuracy and head specialization.

        `batch` must contain "labels"; every other entry is forwarded to
        `model`. The dict itself is left unmodified.

        Returns:
          Dict with the mean "loss", mean "accuracy" and "specialization",
          each averaged across the "batch" device axis with jax.lax.pmean.
        """
        # Read labels instead of popping them. `batch.pop` would delete them
        # from the caller's dict, so a second call on the same batch (e.g.
        # outside jit/pmap) would fail with KeyError.
        labels = batch["labels"]
        model_inputs = {k: v for k, v in batch.items() if k != "labels"}

        outputs = model(**model_inputs,
                        output_attentions=True,
                        params=params,
                        train=False)
        logits = outputs["logits"]

        # compute loss
        loss = optax.softmax_cross_entropy(logits,
                                           onehot(labels, logits.shape[-1]))

        # compute accuracy
        accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels)

        # compute head specialization. The per-layer attentions are stacked
        # and their first two axes swapped (assumes each entry is
        # batch-major, giving [batch, layers, ...] — TODO confirm).
        specialization = compute_specialization_metric(
            jnp.swapaxes(jnp.stack(outputs["encoder_attentions"]), 0, 1))

        # summarize metrics
        metrics = {
            "loss": loss.mean(),
            "accuracy": accuracy.mean(),
            "specialization": specialization
        }
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return metrics
Ejemplo n.º 12
0
    def __call__(self, input_ids, type_ids, labels=None, deterministic=False):
        """Applies model for sequence classification.

        Args:
          input_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] tokenized inputs.
          type_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] Ids partitioning input into
            different types.
          labels: True labels associated with inputs. Generally only required for
            training. Shape depends on task type:
            * Classification: <int>[BATCH_SIZE],
            * Regression: <float>[BATCH_SIZE].
          deterministic: Whether to apply dropout to input.

        Returns:
          * If labels supplied (training mode): ClassificationStats holding the
            mean per-example loss, the number of labels and, for
            classification, the number of correct predictions.
          * If no labels supplied (prediction / evaluation mode): Logits with shape
            <float>[BATCH_SIZE, n_classes].
        """
        encoder_output = EncoderModel(self.config, name="encoder")(
            input_ids, type_ids, deterministic=deterministic)

        # All other classification and regression tasks use the pooled output.
        output = encoder_output.pooled_output
        # TODO(jamesleethorp): For WiC, the original SuperGLUE paper
        #  (https://arxiv.org/abs/1905.00537) concatenates the "CLS" and "word"
        #  output representations. We only use the pooled output.

        logits = layers.OutputProjection(n_out=self.n_classes,
                                         kernel_init=default_kernel_init,
                                         name="classification")(output)

        if labels is None:
            # Code path used during evaluation or prediction; metrics can be computed
            # from logits by the caller.
            return logits

        # Code path used during training.
        regression_datasets = (
            "glue/stsb",  # Regression task
            "super_glue/copa",  # "Regression" task
            "super_glue/record",  # "Regression" task
        )
        if self.config.dataset_name in regression_datasets:
            # Logits have shape: [BATCH_SIZE, 1]; labels: [BATCH_SIZE].
            # The squared difference is already one value per example. Summing
            # it over axis -1 would collapse the batch into a single number and
            # make `batch_loss` a sum instead of a mean over examples.
            per_example_loss = (logits[..., 0] - labels)**2
            batch_loss = jnp.mean(per_example_loss)
            return ClassificationStats(batch_loss=batch_loss,
                                       num_labels=labels.size)

        # Classification task.
        # Logits have shape: [BATCH_SIZE, self.n_classes].
        logits = nn.log_softmax(logits, axis=-1)
        per_example_loss = -jnp.sum(
            onehot(labels, logits.shape[-1]) * logits, axis=-1)
        batch_loss = jnp.mean(per_example_loss)
        correct_predictions = jnp.sum(logits.argmax(-1) == labels)
        return ClassificationStats(batch_loss=batch_loss,
                                   num_labels=labels.size,
                                   correct_predictions=correct_predictions)
Ejemplo n.º 13
0
    def sampling_loop_body_fn(state):
        """Sampling loop state update.

        Performs one decoding step. It computes next-position logits for
        `cur_token` and divides them by `temperature`. It then applies, when
        configured, masking of `masked_tokens`, the repetition penalty and
        top-k filtering. Finally it samples a token and writes it into
        `sequences` at column i + 1.

        Args:
          state: tuple (i, sequences, cache, cur_token, ended, rng,
            tokens_to_logits_state).

        Returns:
          The next state tuple in the same order: i + 1, the updated
          sequences, the new cache, the emitted tokens, the updated `ended`
          flags, a fresh RNG and the new tokens_to_logits state.
        """
        i, sequences, cache, cur_token, ended, rng, tokens_to_logits_state = state

        # Split RNG for sampling: rng1 is consumed now, rng2 is carried forward.
        rng1, rng2 = random.split(rng)

        # Call fast-decoder model on current tokens to get raw next-position logits.
        logits, new_cache, new_tokens_to_logits_state = tokens_to_logits(
            cur_token, cache, internal_state=tokens_to_logits_state)
        logits = logits / temperature

        # Mask out every id in `masked_tokens` (e.g. BOS). Each id becomes a
        # one-hot row with LARGE_NEGATIVE at that id (other entries use
        # onehot's default off_value, presumably 0). The rows are summed and
        # added to every batch element's logits.
        if masked_tokens is not None:
            mask = common_utils.onehot(jnp.array(masked_tokens),
                                       num_classes=logits.shape[-1],
                                       on_value=LARGE_NEGATIVE)
            mask = jnp.sum(mask,
                           axis=0)[None, :]  # Combine multiple masks together
            logits = logits + mask

        # Apply the repetition penalty (skipped when the penalty is exactly 1).
        if repetition_penalty != 1:
            logits = apply_repetition_penalty(
                sequences,
                logits,
                i,
                repetition_penalty=repetition_penalty,
                repetition_window=repetition_window,
                repetition_penalty_normalize=repetition_penalty_normalize)

        # Mask out everything but the top-k entries.
        if top_k is not None:
            # Compute top_k_index and top_k_threshold with shapes (batch_size, 1).
            # argsort is ascending; reversing and taking column top_k - 1
            # selects the index of the k-th largest logit.
            # NOTE(review): this sorts the full vocabulary each step;
            # lax.top_k would give the same threshold value more cheaply.
            top_k_index = jnp.argsort(logits,
                                      axis=-1)[:, ::-1][:, top_k - 1:top_k]
            top_k_threshold = jnp.take_along_axis(logits, top_k_index, axis=-1)
            # Only logits strictly below the threshold are suppressed, so ties
            # with the k-th value survive (more than top_k tokens may remain).
            logits = jnp.where(logits < top_k_threshold,
                               jnp.full_like(logits, LARGE_NEGATIVE), logits)
        # Sample next token from logits.
        sample = multinomial(rng1, logits)
        next_token = sample.astype(jnp.int32)
        # Keep the existing token in column i + 1 (the prompt) unless that slot
        # holds out_of_prompt_marker; only then use the sampled token.
        out_of_prompt = (sequences[:, i + 1] == out_of_prompt_marker)
        next_token = (next_token * out_of_prompt +
                      sequences[:, i + 1] * ~out_of_prompt)
        # If end-marker reached for batch item, only emit padding tokens.
        next_token = next_token[:, None]
        next_token_or_endpad = jnp.where(ended,
                                         jnp.full_like(next_token, pad_token),
                                         next_token)
        # A row is marked ended once it emits end_marker; later steps then pad it.
        ended |= (next_token_or_endpad == end_marker)
        # Add current sampled tokens to recorded sequences.
        new_sequences = lax.dynamic_update_slice(sequences,
                                                 next_token_or_endpad,
                                                 (0, i + 1))
        return (i + 1, new_sequences, new_cache, next_token_or_endpad, ended,
                rng2, new_tokens_to_logits_state)
Ejemplo n.º 14
0
        def loss_fn(params):
            """Mean softmax cross-entropy of the model on `batch`.

            `batch` comes from the enclosing scope and must contain "labels";
            the other entries are forwarded to `state.apply_fn`. The dict is
            only read, so calling this function more than once is safe. The
            previous `batch.pop` removed the labels and made a second call
            raise KeyError.
            """
            labels = batch["labels"]
            model_inputs = {k: v for k, v in batch.items() if k != "labels"}

            logits = state.apply_fn(**model_inputs, params=params, dropout_rng=dropout_rng, train=True)[0]

            # compute loss
            loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()

            return loss
Ejemplo n.º 15
0
  def __call__(
      self,
      input_ids,
      input_mask,
      type_ids,
      labels = None,
      deterministic = False
  ):
    """Applies model for sequence classification.

    Args:
      input_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] tokenized inputs.
      input_mask: <bool>[BATCH_SIZE, MAX_SEQ_LENGTH] mask separating actual
        inputs from padding. Only used by BERT.
      type_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] Ids partitioning input into
        different types.
      labels: True labels associated with inputs. Generally only required for
        training. Shape depends on task type:
        * Classification: <int>[BATCH_SIZE]
        * Regression: <float>[BATCH_SIZE]
      deterministic: Whether or not to apply dropout to input.

    Returns:
      * If labels supplied (training mode): a dict with "loss" and
        "num_labels", plus "correct_predictions" for classification.
      * If no labels supplied (prediction / evaluation mode): Logits of shape
        <float>[BATCH_SIZE, n_classes].
    """
    encoder = EncoderModel(self.config, name="encoder")
    _, pooled_output = encoder(
        input_ids, input_mask, type_ids, deterministic=deterministic)
    classifier = layers.OutputProjection(
        n_out=self.n_classes,
        kernel_init=default_kernel_init,
        name="classification")
    logits = classifier(pooled_output)

    if labels is None:
      # Evaluation / prediction: the caller derives metrics from the logits.
      return logits

    # Training path.
    if self.config.dataset_name != "glue/stsb":
      # Classification task.
      log_probs = nn.log_softmax(logits)
      one_hot = onehot(labels, log_probs.shape[-1])
      loss = -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))
      return {
          "loss": loss,
          "correct_predictions": jnp.sum(log_probs.argmax(-1) == labels),
          "num_labels": labels.size
      }

    # Regression task (STS-B).
    return {
        "loss": jnp.mean((logits[..., 0] - labels)**2),
        "num_labels": labels.size,
    }
Ejemplo n.º 16
0
def cross_entropy_loss(log_softmax_logits, labels):
    """Returns the cross-entropy classification loss.

    This is the negative log-likelihood of `labels` under
    `log_softmax_logits`, summed over all elements and divided by
    `labels.size`.

    Args:
      log_softmax_logits: The log of the softmax of the logits for the mini-batch,
        e.g. as output by jax.nn.log_softmax(logits).
      labels: The labels for the mini-batch.
    """
    targets = common_utils.onehot(labels, log_softmax_logits.shape[-1])
    total_nll = -jnp.sum(targets * log_softmax_logits)
    return total_nll / labels.size
Ejemplo n.º 17
0
 def loss(params):
     """Mean negative log-likelihood of `x` over 10 classes.

     The model is applied to `shift_right(x)` with "permute" and "dropout"
     RNG streams. Its logits are scored against the unshifted `x`.
     """
     rngs = {"permute": permute_key, "dropout": dropout_key}
     logits = model.apply(params, shift_right(x), rngs=rngs)
     targets = onehot(x, num_classes=10)
     per_position_nll = -jnp.sum(targets * nn.log_softmax(logits), axis=-1)
     return jnp.mean(per_position_nll)
Ejemplo n.º 18
0
    def sample(self, masked_inputs, rng):
        """Fill in MASK positions in inputs.

        Scores `masked_inputs` and adds `sampling.LARGE_NEGATIVE` to the
        MASK token's logit so MASK itself is effectively never drawn. It then
        draws one categorical sample per position. The original token is kept
        wherever the input was not MASK.
        """
        mask_id = self.domain.vocab.mask
        is_masked = masked_inputs == mask_id
        logits = self.score(masked_inputs)

        # Mask out MASK token.
        forbid_mask = common_utils.onehot(jnp.array([mask_id]),
                                          num_classes=logits.shape[-1],
                                          on_value=sampling.LARGE_NEGATIVE)
        draws = jax.random.categorical(rng, logits=logits + forbid_mask)
        return onp.where(is_masked, draws, masked_inputs)
Ejemplo n.º 19
0
def cross_entropy_loss(logprobs, label, num_classes):
    """Computes the cross entropy loss for one datapoint.

    Args:
      logprobs: log probabilities predicted by the model
      label: true class label
      num_classes: number of classes in the task

    Returns:
      loss: the negated sum of `logprobs` masked by the one-hot encoding of
        `label`.
    """
    target = common_utils.onehot(label, num_classes=num_classes)
    log_likelihood = jnp.sum(target * logprobs)
    return -log_likelihood
Ejemplo n.º 20
0
        def loss_fn(params):
            """Masked-token cross-entropy averaged over labels > 0.

            `batch` comes from the enclosing scope and must contain "labels";
            the other entries are forwarded to `state.apply_fn`. The dict is
            only read, so calling this function more than once is safe. The
            previous `batch.pop` removed the labels and made a second call
            raise KeyError.
            """
            labels = batch["labels"]
            model_inputs = {k: v for k, v in batch.items() if k != "labels"}

            logits = state.apply_fn(**model_inputs, params=params, dropout_rng=dropout_rng, train=True)[0]

            # compute loss, ignore padded input tokens (labels <= 0 get zero weight)
            label_mask = jnp.where(labels > 0, 1.0, 0.0)
            loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask

            # take average
            # NOTE(review): yields NaN if no label in the batch is > 0.
            loss = loss.sum() / label_mask.sum()

            return loss
Ejemplo n.º 21
0
    def metrics_fn(self, logits, batch):
        """Calculates metrics for the classification task.

        Args:
          logits: float array; Output of the model->[batch, length, num_classes].
          batch: dict; Batch of data that has 'label' and optionally 'weights'.

        Returns:
          a dict of metrics. Each value is a (value, normalizer) pair summed
          over the 'batch' axis with jax.lax.psum. There is one entry per key
          of self._METRICS, plus one for every batch key containing 'factor'.
        """
        labels = batch['label']
        if logits.shape == labels.shape:
            # Labels already arrive one-hot encoded.
            targets = labels
        else:
            targets = common_utils.onehot(labels, logits.shape[-1])

        if self.dataset.meta_data['num_classes'] == 1:
            # If this is a binary classification task, make sure the shape of labels
            # is (bs, 1) and is the same as the shape of logits.
            targets = jnp.reshape(targets, logits.shape)

        class_indices = self.task_params.get('class_indices')
        if class_indices:
            targets = targets[:, class_indices]
            logits = logits[:, class_indices]

        weights = batch.get('weights')  # weights might not be defined

        def _psum_pair(value, normalizer):
            return (jax.lax.psum(value, 'batch'),
                    jax.lax.psum(normalizer, 'batch'))

        metrics_dic = {}
        for name, metric in self._METRICS.items():
            value, normalizer = metric(logits, targets, weights)
            metrics_dic[name] = _psum_pair(value, normalizer)

        # Store dataset related factors.
        for key, factors in batch.items():
            if 'factor' not in key:
                continue
            if weights is None:
                val, norm = jnp.sum(factors), len(factors)
            else:
                val = jnp.sum(metrics.apply_weights(factors, weights))
                norm = jnp.sum(weights)
            metrics_dic[key] = _psum_pair(val, norm)

        return metrics_dic
Ejemplo n.º 22
0
    def loss_fn(logits, labels, z_loss=0):
        """Next-token cross-entropy with an optional z-loss penalty.

        The logits at position t are scored against the label at t + 1. The
        log-softmax is computed by hand so the log-partition log Z is
        available for the auxiliary penalty `1e-4 * z_loss * (log Z)**2`.

        Args:
          logits: [..., length, vocab] float array.
          labels: [..., length] integer ids.
          z_loss: multiplier for the z-loss term (0 disables it).

        Returns:
          The mean loss over all predicted positions.
        """
        shift_logits = logits[..., :-1, :]
        shift_labels = labels[..., 1:]

        shift_labels = onehot(shift_labels, shift_logits.shape[-1])

        # Subtract the (gradient-free) row max for numerical stability.
        row_max = jax.lax.stop_gradient(shift_logits.max(axis=-1, keepdims=True))
        shift_logits = shift_logits - row_max
        log_z_shifted = jnp.log(jnp.sum(jnp.exp(shift_logits), axis=-1, keepdims=True))
        log_softmax = shift_logits - log_z_shifted
        loss = -jnp.sum(shift_labels * log_softmax, axis=-1)

        # z-loss must penalize the log-partition of the raw logits. The
        # shifted value is bounded in [0, log(vocab)] and cannot stop logit
        # drift, so the max is added back. Its gradient is still softmax
        # because the max term is stop_gradient'ed.
        log_z = (log_z_shifted + row_max).squeeze(-1)
        loss += (1e-4 * jnp.square(log_z)) * z_loss

        return loss.mean()
Ejemplo n.º 23
0
    def _compute_metrics(self, masked_lm_logits, next_sentence_logits,
                         masked_lm_labels, masked_lm_weights,
                         next_sentence_labels, **unused_kwargs):
        """Computes the pre-training loss and its components.

        Extra keyword arguments are accepted and ignored.

        Returns:
          Dict with 'masked_lm_loss' (log-likelihood loss averaged with
          `masked_lm_weights` as weights), 'next_sentence_loss' (mean
          log-likelihood loss) and their sum under 'loss'.
        """
        def _token_log_likelihood(logits, flat_labels):
            # Log-probability assigned to each label, one value per row.
            log_probs = nn.log_softmax(logits)
            targets = onehot(flat_labels, log_probs.shape[-1])
            return jnp.sum(log_probs * targets, axis=-1)

        mlm_ll = _token_log_likelihood(masked_lm_logits,
                                       masked_lm_labels.reshape((-1, )))
        mlm_weights = masked_lm_weights.reshape((-1, ))
        masked_lm_loss = -jnp.sum(mlm_ll * mlm_weights) / jnp.sum(mlm_weights)

        nsp_ll = _token_log_likelihood(next_sentence_logits,
                                       next_sentence_labels.reshape((-1, )))
        next_sentence_loss = -jnp.mean(nsp_ll)

        return {
            'loss': masked_lm_loss + next_sentence_loss,
            'masked_lm_loss': masked_lm_loss,
            'next_sentence_loss': next_sentence_loss,
        }
Ejemplo n.º 24
0
def get_masked_lm_output(logits, label_ids, label_weights):
  """Calculate masked_lm loss for pretrain task.

  Args:
    logits: model outputs whose last axis is the vocabulary. Its leading dims
      must broadcast against the flattened labels (presumably
      [num_predictions, vocab_size] — confirm with caller).
    label_ids: target ids; flattened before use.
    label_weights: per-label weights; flattened before use.

  Returns:
    (loss, per_example_loss, log_probs). `loss` is the weighted mean of
    `per_example_loss`, with 1e-5 added to the weight sum so an all-zero
    weight vector does not divide by zero.
  """
  log_probs = nn.log_softmax(logits)
  flat_ids = jnp.reshape(label_ids, (-1))
  flat_weights = jnp.reshape(label_weights, (-1))
  targets = common_utils.onehot(
      flat_ids, logits.shape[-1], on_value=1.0, off_value=0.0)
  per_example_loss = -jnp.sum(log_probs * targets, axis=-1)

  weighted_total = jnp.sum(flat_weights * per_example_loss)
  loss = weighted_total / (jnp.sum(flat_weights) + 1e-5)
  return loss, per_example_loss, log_probs
Ejemplo n.º 25
0
def compute_weighted_cross_entropy(logits,
                                   labels):
  """Compute cross entropy for log probs and labels.

  Despite the name, no weighting is applied: every label contributes equally
  to the mean.

  Args:
   logits: float array whose last axis holds the class scores. Labels are
     flattened, so the leading dims of `logits` must broadcast against
     [labels.size]. In practice that means [num_labels, num_classes]; a
     [batch, length, num_classes] array only works when batch is 1.
   labels: categorical int targets of any shape; flattened to [-1].
  Returns:
    (loss, per_example_loss): the scalar mean loss and the per-label losses.
  """
  num_classes = logits.shape[-1]
  log_probs = nn.log_softmax(logits)
  labels = jnp.reshape(labels, [-1])
  # Take the class count from the logits rather than hard-coding 2, so any
  # number of classes works (binary inputs behave exactly as before).
  one_hot_labels = common_utils.onehot(labels, num_classes=num_classes)
  per_example_loss = -jnp.sum(one_hot_labels * log_probs, axis=-1)
  loss = jnp.mean(per_example_loss)
  return (loss, per_example_loss)
Ejemplo n.º 26
0
    def eval_step(params, batch):
        """Computes mean loss and accuracy on one batch.

        `batch` must contain "labels"; every other entry is forwarded to
        `model`. The dict itself is left unmodified.

        Returns:
          Dict with "loss" and "accuracy", averaged across the "batch" device
          axis with jax.lax.pmean.
        """
        # Read labels instead of popping them. `batch.pop` would delete them
        # from the caller's dict, so a second call on the same batch (e.g.
        # outside jit/pmap) would fail with KeyError.
        labels = batch["labels"]
        model_inputs = {k: v for k, v in batch.items() if k != "labels"}

        logits = model(**model_inputs, params=params, train=False)[0]

        # compute loss
        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))

        # compute accuracy
        accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels)

        # summarize metrics
        metrics = {"loss": loss.mean(), "accuracy": accuracy.mean()}
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return metrics
Ejemplo n.º 27
0
def cross_entropy(logits, targets, weights = None, label_smoothing = 0.0):
    """Label-smoothed cross entropy summed over all positions.

    Args:
      logits: float array; the last axis holds the class scores.
      targets: integer class ids shaped like logits.shape[:-1].
      weights: optional array multiplied into the per-position loss.
      label_smoothing: probability mass moved from the true class and spread
        evenly over the other `vocab_size - 1` classes.

    Returns:
      (loss_sum, normalizing_factor): the summed loss, and either
      `weights.sum()` or the number of target elements (in logits.dtype).
    """
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    # Entropy of the smoothed target distribution. Subtracting it makes the
    # minimum attainable loss 0. The 1e-20 keeps log() finite when
    # low_confidence is 0.
    normalizing_constant = -(
            confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
    ).astype(logits.dtype)
    # Targets must use the same smoothed on/off values as normalizing_constant.
    # Plain one-hot targets would leave smoothing unapplied while still
    # subtracting the smoothed entropy, giving a wrong loss whenever
    # label_smoothing > 0.
    soft_targets = common_utils.onehot(
        targets, vocab_size, on_value=confidence, off_value=low_confidence)
    loss = -jnp.sum(soft_targets*nn.log_softmax(logits),axis=-1,dtype=logits.dtype)
    loss -= normalizing_constant
    if weights is not None:
        loss *= weights
        normalizing_factor = weights.sum()
    else:
        normalizing_factor = np.prod(targets.shape,dtype=logits.dtype)
    return loss.sum(),normalizing_factor
Ejemplo n.º 28
0
        def loss_fn(params):
            """Mean softmax cross-entropy plus the stacked encoder attentions.

            `batch` comes from the enclosing scope and must contain "labels";
            the other entries are forwarded to `state.apply_fn`. The dict is
            only read, so calling this function more than once is safe. The
            previous `batch.pop` removed the labels and made a second call
            raise KeyError.

            Returns:
              (loss, attentions): the attentions are the stacked
              "encoder_attentions" with their first two axes swapped.
            """
            labels = batch["labels"]
            model_inputs = {k: v for k, v in batch.items() if k != "labels"}

            outputs = state.apply_fn(**model_inputs,
                                     output_attentions=True,
                                     params=params,
                                     dropout_rng=dropout_rng,
                                     train=True)
            logits = outputs["logits"]

            # compute loss
            loss = optax.softmax_cross_entropy(
                logits, onehot(labels, logits.shape[-1])).mean()

            return loss, jnp.swapaxes(jnp.stack(outputs["encoder_attentions"]),
                                      0, 1)
Ejemplo n.º 29
0
    def eval_step(params, batch):
        """Sums masked loss and accuracy over labels > 0.

        Positions with label <= 0 get zero weight. The returned "loss",
        "accuracy" and "normalizer" (number of counted labels) are summed
        across the "batch" device axis with jax.lax.psum; divide by
        "normalizer" to obtain means. `batch` must contain "labels" and is
        not modified.
        """
        # Read labels instead of popping them. `batch.pop` would delete them
        # from the caller's dict, so a second call on the same batch (e.g.
        # outside jit/pmap) would fail with KeyError.
        labels = batch["labels"]
        model_inputs = {k: v for k, v in batch.items() if k != "labels"}

        logits = model(**model_inputs, params=params, train=False)[0]

        # compute loss, ignore padded input tokens
        label_mask = jnp.where(labels > 0, 1.0, 0.0)
        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask

        # compute accuracy
        accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask

        # summarize metrics
        metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
        metrics = jax.lax.psum(metrics, axis_name="batch")

        return metrics
Ejemplo n.º 30
0
def compute_weighted_cross_entropy(logits,
                                   targets,
                                   weights=None,
                                   label_smoothing=0.0,
                                   z_loss=0.0):
    """Compute weighted cross entropy and entropy for log probs and targets.

    Args:
      logits: [..., num_classes] float array, e.g. [batch, length, num_classes]
        or the already-flattened [batch * length, num_classes]. Leading dims
        are flattened so they line up with the flattened targets.
      targets: integer class ids (NOT one-hot) of any shape, e.g.
        [batch, length]; flattened to [-1].
      weights: None or array of shape [batch, length]; flattened to [-1].
      label_smoothing: label smoothing constant, used to determine the on and
        off values.
      z_loss: coefficient for the auxiliary z-loss term, forwarded to
        `cross_entropy_with_logits`.

    Returns:
      Tuple of scalar loss and batch normalizing factor (sum of `weights`, or
      the number of targets when `weights` is None).

    Raises:
      ValueError: if `logits` has no class axis to pair with the targets.
    """
    targets = targets.reshape((-1))
    if logits.ndim > 2:
        # Flatten [batch, length, num_classes] to match the flattened targets.
        # Without this, the documented 3-D layout always failed the check below.
        logits = logits.reshape((-1, logits.shape[-1]))
    if logits.ndim != targets.ndim + 1:
        raise ValueError(
            'Incorrect shapes. Got shape %s logits and %s targets' %
            (str(logits.shape), str(targets.shape)))
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    # Entropy of the smoothed target distribution (1e-20 keeps log() finite).
    normalizing_constant = -(
        confidence * jnp.log(confidence) +
        (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20))
    soft_targets = common_utils.onehot(targets,
                                       vocab_size,
                                       on_value=confidence,
                                       off_value=low_confidence)
    loss = cross_entropy_with_logits(logits, soft_targets, z_loss=z_loss)
    loss = loss - normalizing_constant

    normalizing_factor = np.prod(targets.shape)
    if weights is not None:
        weights = weights.reshape((-1))
        loss = loss * weights
        normalizing_factor = jnp.sum(weights)

    # HACK T5's "loss_denominator" correction for batchsize 2048 * 114 targetlen..
    # normalizing_factor = 233472.0

    return jnp.sum(loss), normalizing_factor