Example #1
0
 def testGuidedAlignmentCostUnderDistributionStrategy(
         self, cost_type, with_length):
     """Smoke-test guided_alignment_cost inside a MirroredStrategy scope.

     Random attention probabilities and random gold alignments, both of
     shape [2, 5, 6], are fed to ``losses.guided_alignment_cost`` with the
     given ``cost_type``. When ``with_length`` is true, the per-example
     lengths ``[4, 5]`` are also passed. The test only checks that the call
     does not raise; the returned cost is not inspected.

     NOTE(review): ``cost_type`` and ``with_length`` are presumably supplied
     by a parameterized-test decorator outside this view — confirm.
     """
     mirrored = tf.distribute.MirroredStrategy(devices=["/cpu:0"])
     probs = tf.random.uniform([2, 5, 6])
     gold = tf.random.uniform([2, 5, 6])
     lengths = tf.constant([4, 5], dtype=tf.int32) if with_length else None
     # The call must succeed when made under a distribution strategy scope.
     with mirrored.scope():
         losses.guided_alignment_cost(
             probs,
             gold,
             sequence_length=lengths,
             cost_type=cost_type,
         )
 def compute_loss(self, outputs, labels, training=True):
     """Compute the loss for a batch of model outputs.

     Args:
       outputs: Either a dict of model outputs or a bare logits value. A
         bare value is treated as ``{"logits": outputs}``. The dict must
         contain ``"logits"``. It may also contain ``"noisy_logits"`` and
         ``"attention"``.
       labels: Dict of label features. ``"ids_out"`` and ``"length"`` are
         always read. ``"weight"`` and ``"alignment"`` are optional.
         ``"noisy_ids_out"`` and ``"noisy_length"`` are read only on the
         contrastive-learning path.
       training: Forwarded to ``losses.cross_entropy_sequence_loss``. When
         false, the guided alignment term is never added.

     Returns:
       If ``outputs`` has noisy logits and ``params["contrastive_learning"]``
       is truthy, whatever ``losses.max_margin_loss`` returns. Otherwise the
       ``(loss, loss_normalizer, loss_token_normalizer)`` triple from
       ``losses.cross_entropy_sequence_loss``. During training, when gold
       alignments, a ``guided_alignment_type`` and attention are all
       available, the guided alignment cost is added to ``loss``.
     """
     params = self.params
     model_outputs = (
         outputs if isinstance(outputs, dict) else dict(logits=outputs)
     )
     logits = model_outputs["logits"]
     noisy_logits = model_outputs.get("noisy_logits")
     attention = model_outputs.get("attention")

     # The contrastive objective replaces cross-entropy entirely.
     if noisy_logits is not None and params.get("contrastive_learning"):
         return losses.max_margin_loss(
             logits,
             labels["ids_out"],
             labels["length"],
             noisy_logits,
             labels["noisy_ids_out"],
             labels["noisy_length"],
             eta=params.get("max_margin_eta", 0.1),
         )

     loss, loss_normalizer, loss_token_normalizer = (
         losses.cross_entropy_sequence_loss(
             logits,
             labels["ids_out"],
             sequence_length=labels["length"],
             sequence_weight=labels.get("weight"),
             label_smoothing=params.get("label_smoothing", 0.0),
             average_in_time=params.get("average_loss_in_time", False),
             training=training,
         )
     )

     # Guided alignment is a training-only regularizer.
     if not training:
         return loss, loss_normalizer, loss_token_normalizer

     gold_alignments = labels.get("alignment")
     alignment_cost_type = params.get("guided_alignment_type")
     if gold_alignments is None or alignment_cost_type is None:
         return loss, loss_normalizer, loss_token_normalizer

     if attention is None:
         tf.get_logger().warning(
             "This model did not return attention vectors; "
             "guided alignment will not be applied"
         )
     else:
         # Do not constrain the last attention timestep.
         constrained_attention = attention[:, :-1]
         target_length = self.labels_inputter.get_length(
             labels, ignore_special_tokens=True
         )
         loss += losses.guided_alignment_cost(
             constrained_attention,
             gold_alignments,
             sequence_length=target_length,
             cost_type=alignment_cost_type,
             weight=params.get("guided_alignment_weight", 1),
         )
     return loss, loss_normalizer, loss_token_normalizer