Ejemplo n.º 1
0
    def _test_step_fn(inputs):
      """Per-replica evaluation step: accumulates metrics and returns targets/logits."""
      # Strip the start-of-sequence token so targets align with model outputs.
      gold_targets = models.remove_sos_from_seq(inputs["target_ids"],
                                                params.pad_token_id)

      # Teacher-forced pass: logits against ground-truth sequences feed the
      # accuracy / perplexity metric layer.
      train_logits, _, _ = model(inputs, training=False, mode="train")
      metric_layer([train_logits, gold_targets])

      # Beam-search (eval-mode) pass supplies the outputs used for
      # BLEU / ROUGE scoring downstream.
      decoded_logits = model(inputs, training=False, mode="eval")

      return gold_targets, decoded_logits
Ejemplo n.º 2
0
    def train_step(self, inputs):
        """Runs a single optimization step on one batch of `inputs`.

        Computes the transformer loss against SOS-stripped targets, scales it
        by the replica count so backprop uses the cross-replica average, and
        applies the resulting gradients.

        Returns:
            A dict with the unscaled "training_loss" and the optimizer's
            current decayed "learning_rate".
        """
        with tf.GradientTape() as tape:
            model_logits, _, _ = self(inputs, mode="train", training=True)
            # Drop the start-of-sequence token so targets line up with logits.
            shifted_targets = models.remove_sos_from_seq(
                inputs["target_ids"], self.params.pad_token_id)
            batch_loss = transformer_metrics.transformer_loss(
                model_logits, shifted_targets, self.params.label_smoothing,
                self.params.vocab_size)
            # Dividing by the replica count makes the summed per-replica
            # gradients equal the gradient of the mean loss.
            replica_loss = batch_loss / self._num_replicas_in_sync

        variables = self.trainable_variables
        gradients = tape.gradient(replica_loss, variables)
        self.optimizer.apply_gradients(list(zip(gradients, variables)))

        # NOTE(review): `_decayed_lr` is a private Keras optimizer API —
        # kept as-is to preserve the reported value; verify on TF upgrades.
        return {
            "training_loss": batch_loss,
            "learning_rate": self.optimizer._decayed_lr(var_dtype=tf.float32)
        }