def build_losses(self, labels, model_outputs, metrics,
                 aux_losses=None) -> tf.Tensor:
  metrics = dict([(metric.name, metric) for metric in metrics])
  lm_output = tf.nn.log_softmax(
      tf.cast(model_outputs['lm_output'], tf.float32), axis=-1)
  mlm_loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
      labels=labels['masked_lm_ids'],
      predictions=lm_output,
      weights=labels['masked_lm_weights'])
  metrics['lm_example_loss'].update_state(mlm_loss)
  if 'next_sentence_labels' in labels:
    sentence_labels = labels['next_sentence_labels']
    sentence_outputs = tf.cast(
        model_outputs['next_sentence'], dtype=tf.float32)
    sentence_loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
        labels=sentence_labels,
        predictions=tf.nn.log_softmax(sentence_outputs, axis=-1))
    metrics['next_sentence_loss'].update_state(sentence_loss)
    total_loss = mlm_loss + sentence_loss
  else:
    total_loss = mlm_loss
  if aux_losses:
    total_loss += tf.add_n(aux_losses)
  return total_loss
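# Hedged usage sketch (not the library's code): a standalone reimplementation
# of the weighted sparse categorical crossentropy consumed above, which takes
# log-probabilities (tf.nn.log_softmax outputs) rather than raw logits. All
# names here are illustrative assumptions.
import tensorflow as tf

def _demo_weighted_sparse_xent(labels, log_probs, weights=None):
  # Negative log-likelihood gathered at the gold ids: log_probs[i, j, labels[i, j]].
  per_example = -tf.gather(log_probs, labels, batch_dims=2)
  if weights is None:
    return tf.reduce_mean(per_example)
  weights = tf.cast(weights, tf.float32)
  # Weighted mean; the small epsilon guards against an all-zero weight mask.
  return tf.reduce_sum(per_example * weights) / (tf.reduce_sum(weights) + 1e-5)

_logits = tf.random.normal([2, 3, 10])                    # [batch, preds, vocab]
_labels = tf.random.uniform([2, 3], maxval=10, dtype=tf.int32)
_weights = tf.constant([[1., 1., 0.], [1., 0., 0.]])
print(_demo_weighted_sparse_xent(
    _labels, tf.nn.log_softmax(_logits, axis=-1), _weights))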
def call(self, lm_output, sentence_output, lm_label_ids, lm_label_weights,
         sentence_labels=None):
  """Implements call() for the layer."""
  lm_label_weights = tf.cast(lm_label_weights, tf.float32)
  lm_output = tf.cast(lm_output, tf.float32)
  mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
      labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)
  if sentence_labels is not None:
    sentence_output = tf.cast(sentence_output, tf.float32)
    sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
        labels=sentence_labels, predictions=sentence_output)
    loss = mask_label_loss + sentence_loss
  else:
    sentence_loss = None
    loss = mask_label_loss
  batch_shape = tf.slice(tf.shape(lm_label_ids), [0], [1])
  # TODO(hongkuny): Avoid the hack and switch to add_loss.
  final_loss = tf.fill(batch_shape, loss)
  self._add_metrics(lm_output, lm_label_ids, lm_label_weights,
                    mask_label_loss, sentence_output, sentence_labels,
                    sentence_loss)
  return final_loss
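# Hedged mini-demo of the tf.fill hack above (standalone, illustrative only):
# Keras expects a per-example loss of shape [batch_size], so the scalar loss
# is tiled across the batch dimension taken from the label tensor's shape.
import tensorflow as tf

scalar_loss = tf.constant(2.5)
lm_label_ids_demo = tf.zeros([4, 20], dtype=tf.int32)          # [batch, num_preds]
batch_shape = tf.slice(tf.shape(lm_label_ids_demo), [0], [1])  # -> [4]
print(tf.fill(batch_shape, scalar_loss))                       # [2.5, 2.5, 2.5, 2.5]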
def call(self, inputs):
  """Implements call() for the layer."""
  unpacked_inputs = tf_utils.unpack_inputs(inputs)
  lm_output = unpacked_inputs[0]
  sentence_output = unpacked_inputs[1]
  lm_label_ids = unpacked_inputs[2]
  lm_label_weights = tf.keras.backend.cast(unpacked_inputs[3], tf.float32)
  sentence_labels = unpacked_inputs[4]
  mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
      labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)
  sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
      labels=sentence_labels, predictions=sentence_output)
  loss = mask_label_loss + sentence_loss
  batch_shape = tf.slice(tf.keras.backend.shape(sentence_labels), [0], [1])
  # TODO(hongkuny): Avoid the hack and switch to add_loss.
  final_loss = tf.fill(batch_shape, loss)
  self._add_metrics(lm_output, lm_label_ids, lm_label_weights,
                    mask_label_loss, sentence_output, sentence_labels,
                    sentence_loss)
  return final_loss
def build_losses(self, features, model_outputs, metrics,
                 aux_losses=None) -> tf.Tensor:
  metrics = dict([(metric.name, metric) for metric in metrics])
  lm_output = tf.nn.log_softmax(model_outputs['lm_output'], axis=-1)
  mlm_loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
      labels=features['masked_lm_ids'],
      predictions=lm_output,
      weights=features['masked_lm_weights'])
  metrics['lm_example_loss'].update_state(mlm_loss)
  if 'next_sentence_labels' in features:
    policy = tf.keras.mixed_precision.experimental.global_policy()
    if policy.name == 'mixed_bfloat16':  # b/158514794: bf16 is not stable.
      policy = tf.float32
    predictions = tf.keras.layers.Activation(
        tf.nn.log_softmax, dtype=policy)(model_outputs['next_sentence'])
    sentence_labels = features['next_sentence_labels']
    sentence_loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
        labels=sentence_labels, predictions=predictions)
    metrics['next_sentence_loss'].update_state(sentence_loss)
    total_loss = mlm_loss + sentence_loss
  else:
    total_loss = mlm_loss
  if aux_losses:
    total_loss += tf.add_n(aux_losses)
  return total_loss
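# Hedged sketch of the dtype override above (illustrative, not the task's
# code): wrapping log_softmax in an Activation layer with dtype=tf.float32
# casts bfloat16 logits up to float32 before the numerically sensitive
# log-softmax, which is the workaround referenced by b/158514794.
import tensorflow as tf

bf16_logits = tf.cast(tf.random.normal([2, 2]), tf.bfloat16)
log_probs = tf.keras.layers.Activation(
    tf.nn.log_softmax, dtype=tf.float32)(bf16_logits)
print(log_probs.dtype)  # float32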
def call(self, tag_logits, tag_labels, input_mask, labels_mask,
         point_logits=None, point_labels=None):
  """Implements call() for the layer.

  Args:
    tag_logits: [batch_size, seq_length, vocab_size] tensor with tag logits.
    tag_labels: [batch_size, seq_length] tensor with gold outputs.
    input_mask: [batch_size, seq_length] tensor with mask (1s or 0s).
    labels_mask: [batch_size, seq_length] mask for labels, may be a binary
      mask or a weighted float mask.
    point_logits: [batch_size, seq_length, seq_length] optional tensor with
      point logits.
    point_labels: [batch_size, seq_length] optional tensor with gold outputs.

  Returns:
    Scalar loss of the model.
  """
  tag_logits = tf.cast(tag_logits, tf.float32)
  labels_mask = tf.cast(labels_mask, tf.float32) * tf.math.reduce_sum(
      tf.cast(input_mask, tf.float32), axis=-1, keepdims=True)
  tag_logits_loss = losses.weighted_sparse_categorical_crossentropy_loss(
      labels=tag_labels,
      predictions=tag_logits,
      weights=labels_mask,
      from_logits=True)
  if self._use_pointing:
    point_logits_loss = losses.weighted_sparse_categorical_crossentropy_loss(
        labels=point_labels,
        predictions=point_logits,
        weights=tf.cast(input_mask, tf.float32),
        from_logits=True)
    total_loss = tag_logits_loss + tf.cast(
        tf.constant(self._pointing_weight), tf.float32) * point_logits_loss
    self._add_metrics(tag_logits, tag_labels, tag_logits_loss, input_mask,
                      labels_mask, total_loss, point_logits, point_labels,
                      point_logits_loss)
  else:
    total_loss = tag_logits_loss
    self._add_metrics(tag_logits, tag_labels, tag_logits_loss, input_mask,
                      labels_mask, total_loss)
  return total_loss
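# Hedged mini-demo of the labels_mask rescaling above (standalone values,
# illustrative only): each label weight is multiplied by the count of real
# (non-padding) tokens in its sequence, taken from input_mask.
import tensorflow as tf

input_mask_demo = tf.constant([[1, 1, 1, 0]])             # 3 real tokens
labels_mask_demo = tf.constant([[1., 0., 1., 0.]])
scaled = labels_mask_demo * tf.math.reduce_sum(
    tf.cast(input_mask_demo, tf.float32), axis=-1, keepdims=True)
print(scaled)  # [[3., 0., 3., 0.]]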
def call(self, lm_output, lm_label_ids, lm_label_weights):
  """Implements call() for the layer.

  Args:
    lm_output: [batch_size, max_predictions_per_seq, vocab_size] tensor with
      language model logits.
    lm_label_ids: [batch_size, max_predictions_per_seq] tensor with gold
      outputs.
    lm_label_weights: [batch_size, max_predictions_per_seq] tensor with
      per-token weights.

  Returns:
    final_loss: scalar MLM loss.
  """
  lm_label_weights = tf.cast(lm_label_weights, tf.float32)
  lm_output = tf.cast(lm_output, tf.float32)
  mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
      labels=lm_label_ids,
      predictions=lm_output,
      weights=lm_label_weights,
      from_logits=True)
  self._add_metrics(lm_output, lm_label_ids, lm_label_weights,
                    mask_label_loss)
  return mask_label_loss
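# Hedged equivalent of the weighted MLM loss above, written with stock TF ops
# (an illustration of the from_logits=True path, not the losses module's
# code): per-token crossentropy is taken directly from logits, then
# weight-averaged over the non-padding prediction positions.
import tensorflow as tf

logits_demo = tf.random.normal([2, 4, 100])               # [batch, preds, vocab]
label_ids_demo = tf.random.uniform([2, 4], maxval=100, dtype=tf.int32)
weights_demo = tf.constant([[1., 1., 1., 0.], [1., 1., 0., 0.]])
per_token = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=label_ids_demo, logits=logits_demo)
mlm_loss_demo = tf.reduce_sum(per_token * weights_demo) / (
    tf.reduce_sum(weights_demo) + 1e-5)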
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
  loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
      labels=labels,
      predictions=tf.nn.log_softmax(
          model_outputs['sentence_prediction'], axis=-1))
  if aux_losses:
    loss += tf.add_n(aux_losses)
  return loss
def call(self, lm_output, lm_label_ids, lm_label_weights, discrim_output,
         discrim_labels):
  """Implements call() for the layer."""
  weights = tf.cast(lm_label_weights, tf.float32)
  lm_output = tf.cast(lm_output, tf.float32)
  discrim_output = tf.cast(discrim_output, tf.float32)
  mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
      labels=lm_label_ids, predictions=lm_output, weights=weights)
  discrim_ind_loss = tf.nn.sigmoid_cross_entropy_with_logits(
      logits=discrim_output, labels=tf.cast(discrim_labels, tf.float32))
  discrim_loss = tf.reduce_sum(discrim_ind_loss)
  loss = mask_label_loss + self.config["discrim_rate"] * discrim_loss
  self._add_metrics(lm_output, lm_label_ids, lm_label_weights,
                    mask_label_loss, discrim_output, discrim_labels,
                    discrim_loss, loss)
  return loss
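# Hedged mini-demo of the discriminator term above (standalone, illustrative
# only): per-token sigmoid crossentropy against binary replaced-token labels,
# summed; the caller then scales it by discrim_rate before adding the MLM loss.
import tensorflow as tf

discrim_logits_demo = tf.random.normal([2, 5])            # [batch, seq_length]
discrim_labels_demo = tf.constant([[0., 1., 0., 0., 1.],
                                   [1., 0., 0., 0., 0.]])
per_token_demo = tf.nn.sigmoid_cross_entropy_with_logits(
    labels=discrim_labels_demo, logits=discrim_logits_demo)
print(tf.reduce_sum(per_token_demo))                      # scalar discriminator loss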