Code example #1
0
    def build_losses(self,
                     labels,
                     model_outputs,
                     metrics,
                     aux_losses=None) -> tf.Tensor:
        """Computes masked-LM loss plus optional next-sentence loss.

        Args:
          labels: dict with 'masked_lm_ids', 'masked_lm_weights' and,
            optionally, 'next_sentence_labels'.
          model_outputs: dict with 'lm_output' logits and (when next-sentence
            labels are present) 'next_sentence' logits.
          metrics: iterable of Keras metrics, looked up by name to record the
            per-component example losses.
          aux_losses: optional list of extra loss tensors to add in.

        Returns:
          Scalar total training loss.
        """
        metric_by_name = {m.name: m for m in metrics}
        # Cast to float32 before log-softmax so mixed-precision logits stay
        # numerically stable.
        lm_log_probs = tf.nn.log_softmax(
            tf.cast(model_outputs['lm_output'], tf.float32), axis=-1)
        masked_lm_loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
            labels=labels['masked_lm_ids'],
            predictions=lm_log_probs,
            weights=labels['masked_lm_weights'])
        metric_by_name['lm_example_loss'].update_state(masked_lm_loss)

        total_loss = masked_lm_loss
        if 'next_sentence_labels' in labels:
            ns_logits = tf.cast(model_outputs['next_sentence'],
                                dtype=tf.float32)
            ns_loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
                labels=labels['next_sentence_labels'],
                predictions=tf.nn.log_softmax(ns_logits, axis=-1))
            metric_by_name['next_sentence_loss'].update_state(ns_loss)
            total_loss = total_loss + ns_loss

        if aux_losses:
            total_loss += tf.add_n(aux_losses)
        return total_loss
Code example #2
0
    def call(self,
             lm_output,
             sentence_output,
             lm_label_ids,
             lm_label_weights,
             sentence_labels=None):
        """Implements call() for the layer.

        Computes the weighted masked-LM loss and, when `sentence_labels` is
        provided, adds the next-sentence loss. The scalar loss is tiled to
        batch shape before being returned.
        """
        # Do all loss math in float32 even if inputs arrive as fp16/bf16.
        weights = tf.cast(lm_label_weights, tf.float32)
        lm_logits = tf.cast(lm_output, tf.float32)

        mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
            labels=lm_label_ids, predictions=lm_logits, weights=weights)

        if sentence_labels is None:
            sentence_loss = None
            loss = mask_label_loss
        else:
            sentence_output = tf.cast(sentence_output, tf.float32)
            sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
                labels=sentence_labels, predictions=sentence_output)
            loss = mask_label_loss + sentence_loss

        batch_shape = tf.slice(tf.shape(lm_label_ids), [0], [1])
        # TODO(hongkuny): Avoids the hack and switches add_loss.
        final_loss = tf.fill(batch_shape, loss)

        self._add_metrics(lm_logits, lm_label_ids, weights,
                          mask_label_loss, sentence_output, sentence_labels,
                          sentence_loss)
        return final_loss
Code example #3
0
    def call(self, inputs):
        """Implements call() for the layer.

        Expects `inputs` to unpack to (lm_output, sentence_output,
        lm_label_ids, lm_label_weights, sentence_labels); returns the combined
        MLM + next-sentence loss tiled to batch shape.
        """
        unpacked = tf_utils.unpack_inputs(inputs)
        lm_output = unpacked[0]
        sentence_output = unpacked[1]
        lm_label_ids = unpacked[2]
        # Weights may arrive in a lower-precision dtype; force float32.
        lm_label_weights = tf.keras.backend.cast(unpacked[3], tf.float32)
        sentence_labels = unpacked[4]

        mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
            labels=lm_label_ids,
            predictions=lm_output,
            weights=lm_label_weights)
        sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
            labels=sentence_labels, predictions=sentence_output)
        loss = mask_label_loss + sentence_loss

        batch_shape = tf.slice(tf.keras.backend.shape(sentence_labels), [0],
                               [1])
        # TODO(hongkuny): Avoids the hack and switches add_loss.
        final_loss = tf.fill(batch_shape, loss)

        self._add_metrics(lm_output, lm_label_ids, lm_label_weights,
                          mask_label_loss, sentence_output, sentence_labels,
                          sentence_loss)
        return final_loss
Code example #4
0
    def build_losses(self,
                     features,
                     model_outputs,
                     metrics,
                     aux_losses=None) -> tf.Tensor:
        """Computes masked-LM loss plus optional next-sentence loss.

        Args:
          features: dict with 'masked_lm_ids', 'masked_lm_weights' and,
            optionally, 'next_sentence_labels'.
          model_outputs: dict with 'lm_output' logits and (when next-sentence
            labels are present) 'next_sentence' logits.
          metrics: iterable of Keras metrics, looked up by name.
          aux_losses: optional list of extra loss tensors to add in.

        Returns:
          Scalar total training loss.
        """
        metric_by_name = {m.name: m for m in metrics}
        lm_log_probs = tf.nn.log_softmax(model_outputs['lm_output'], axis=-1)
        masked_lm_loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
            labels=features['masked_lm_ids'],
            predictions=lm_log_probs,
            weights=features['masked_lm_weights'])
        metric_by_name['lm_example_loss'].update_state(masked_lm_loss)

        total_loss = masked_lm_loss
        if 'next_sentence_labels' in features:
            policy = tf.keras.mixed_precision.experimental.global_policy()
            if policy.name == 'mixed_bfloat16':  # b/158514794: bf16 is not stable.
                policy = tf.float32
            # Run log-softmax through an Activation layer so its compute
            # dtype follows the (possibly overridden) policy.
            ns_log_probs = tf.keras.layers.Activation(
                tf.nn.log_softmax,
                dtype=policy)(model_outputs['next_sentence'])

            ns_loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
                labels=features['next_sentence_labels'],
                predictions=ns_log_probs)
            metric_by_name['next_sentence_loss'].update_state(ns_loss)
            total_loss = total_loss + ns_loss

        if aux_losses:
            total_loss += tf.add_n(aux_losses)
        return total_loss
Code example #5
0
    def call(self,
             tag_logits,
             tag_labels,
             input_mask,
             labels_mask,
             point_logits=None,
             point_labels=None):
        """Implements call() for the layer.

    Args:
      tag_logits: [batch_size, seq_length, vocab_size] tensor with tag logits.
      tag_labels: [batch_size, seq_length] tensor with gold outputs.
      input_mask: [batch_size, seq_length] tensor with mask (1s or 0s).
      labels_mask: [batch_size, seq_length] mask for labels, may be a binary
        mask or a weighted float mask.
      point_logits: [batch_size, seq_length, seq_length] optional tensor with
        point logits.
      point_labels: [batch_size, seq_length] optional tensor with gold outputs.

    Returns:
      Scalar loss of the model.
    """
        tag_logits = tf.cast(tag_logits, tf.float32)
        input_mask_f = tf.cast(input_mask, tf.float32)
        # Scale the label mask by each example's token count so longer
        # sequences contribute proportionally more weight.
        labels_mask = tf.cast(labels_mask, tf.float32) * tf.math.reduce_sum(
            input_mask_f, axis=-1, keepdims=True)
        tag_logits_loss = losses.weighted_sparse_categorical_crossentropy_loss(
            labels=tag_labels,
            predictions=tag_logits,
            weights=labels_mask,
            from_logits=True)
        if not self._use_pointing:
            total_loss = tag_logits_loss
            self._add_metrics(tag_logits, tag_labels, tag_logits_loss,
                              input_mask, labels_mask, total_loss)
            return total_loss

        point_logits_loss = losses.weighted_sparse_categorical_crossentropy_loss(
            labels=point_labels,
            predictions=point_logits,
            weights=input_mask_f,
            from_logits=True)
        pointing_weight = tf.cast(tf.constant(self._pointing_weight),
                                  tf.float32)
        total_loss = tag_logits_loss + pointing_weight * point_logits_loss
        self._add_metrics(tag_logits, tag_labels, tag_logits_loss,
                          input_mask, labels_mask, total_loss,
                          point_logits, point_labels, point_logits_loss)
        return total_loss
Code example #6
0
  def call(self, lm_output, lm_label_ids, lm_label_weights):
    """Implements call() for the layer.

    Args:
      lm_output: [batch_size, max_predictions_per_seq, vocab_size] tensor with
        language model logits.
      lm_label_ids: [batch_size, max_predictions_per_seq] tensor with gold
        outputs.
      lm_label_weights: [batch_size, max_predictions_per_seq] tensor with
        per-token weights.

    Returns:
      final_loss: scalar MLM loss.
    """
    # Compute in float32 regardless of the incoming (possibly fp16) dtypes.
    weights = tf.cast(lm_label_weights, tf.float32)
    logits = tf.cast(lm_output, tf.float32)

    mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
        labels=lm_label_ids,
        predictions=logits,
        weights=weights,
        from_logits=True)

    self._add_metrics(logits, lm_label_ids, weights, mask_label_loss)
    return mask_label_loss
Code example #7
0
File: sentence_prediction.py  Project: zzf2014/models
    def build_losses(self,
                     labels,
                     model_outputs,
                     aux_losses=None) -> tf.Tensor:
        """Sentence-prediction cross-entropy plus any auxiliary losses.

        Args:
          labels: gold class ids for the sentence-prediction head.
          model_outputs: dict containing 'sentence_prediction' logits.
          aux_losses: optional list of extra loss tensors to add in.

        Returns:
          Scalar loss tensor.
        """
        log_probs = tf.nn.log_softmax(model_outputs['sentence_prediction'],
                                      axis=-1)
        total = loss_lib.weighted_sparse_categorical_crossentropy_loss(
            labels=labels, predictions=log_probs)
        if aux_losses:
            total += tf.add_n(aux_losses)
        return total
Code example #8
0
File: electra_model.py  Project: supersteph/models
  def call(self, lm_output, lm_label_ids, lm_label_weights,
           discrim_output, discrim_labels):
    """Implements call() for the layer.

    Combines the generator's weighted MLM loss with the discriminator's
    sigmoid cross-entropy, scaled by the configured discriminator rate.
    """
    # Cast everything to float32 before any loss math.
    weights = tf.cast(lm_label_weights, tf.float32)
    lm_logits = tf.cast(lm_output, tf.float32)
    discrim_logits = tf.cast(discrim_output, tf.float32)

    mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
        labels=lm_label_ids, predictions=lm_logits, weights=weights)
    # Per-token replaced/original detection loss, summed over all positions.
    discrim_loss = tf.reduce_sum(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=discrim_logits,
            labels=tf.cast(discrim_labels, tf.float32)))
    loss = mask_label_loss + self.config["discrim_rate"] * discrim_loss

    self._add_metrics(lm_logits, lm_label_ids, lm_label_weights,
                      mask_label_loss, discrim_logits, discrim_labels,
                      discrim_loss, loss)
    return loss