def _test_step_fn(inputs):
    """Run one replicated evaluation step; return targets and beam logits.

    Performs two forward passes: a teacher-forced "train"-mode pass whose
    logits feed the accuracy/perplexity metric layer, and an "eval"-mode
    pass whose top-beam outputs are returned for BLEU/ROUGE computation.

    Args:
        inputs: Feature dict; must contain "target_ids".
            (assumes the usual transformer input schema — TODO confirm
            against the caller.)

    Returns:
        Tuple of (targets with SOS removed, eval-mode model output).
    """
    # Drop the start-of-sequence token so targets align with model outputs.
    ground_truth = models.remove_sos_from_seq(
        inputs["target_ids"], params.pad_token_id)

    # Teacher-forced pass: logits against the ground-truth sequence are
    # used for the accuracy and perplexity metrics.
    forced_logits, _, _ = model(inputs, training=False, mode="train")
    metric_layer([forced_logits, ground_truth])

    # Decoding pass: top beam-search results are used for BLEU and ROUGE.
    beam_output = model(inputs, training=False, mode="eval")
    return ground_truth, beam_output
def train_step(self, inputs):
    """Run a single optimization step and report loss and learning rate.

    Forward-passes the model under a gradient tape, computes the
    label-smoothed transformer loss against the SOS-stripped targets,
    scales the loss by the replica count (so backprop uses the mean loss
    across replicas), and applies the resulting gradients.

    Args:
        inputs: Feature dict; must contain "target_ids".

    Returns:
        Dict with "training_loss" (unscaled loss) and "learning_rate"
        (the optimizer's current decayed learning rate).
    """
    with tf.GradientTape() as tape:
        predictions, _, _ = self(inputs, mode="train", training=True)
        labels = models.remove_sos_from_seq(
            inputs["target_ids"], self.params.pad_token_id)
        loss = transformer_metrics.transformer_loss(
            predictions, labels, self.params.label_smoothing,
            self.params.vocab_size)
        # Dividing by the replica count makes the summed per-replica
        # gradients equal the gradient of the average loss.
        replica_scaled_loss = loss / self._num_replicas_in_sync

    trainable = self.trainable_variables
    gradients = tape.gradient(replica_scaled_loss, trainable)
    self.optimizer.apply_gradients(list(zip(gradients, trainable)))

    # NOTE(review): _decayed_lr is a private optimizer API — confirm it
    # exists on the optimizer in use before upgrading TF/Keras.
    return {
        "training_loss": loss,
        "learning_rate": self.optimizer._decayed_lr(var_dtype=tf.float32),
    }