def _replicated_step(inputs):
  """Replicated training step.

  Relies on names from the enclosing scope: `model`, `loss_fn`, `optimizer`,
  `training_vars`, `strategy`, `scale_loss`, `explicit_allreduce`, the
  allreduce callbacks, and the training metrics.
  """
  inputs, labels = inputs
  with tf.GradientTape() as tape:
    model_outputs = model(inputs, training=True)
    loss = loss_fn(labels, model_outputs)
    # The raw (unscaled) loss is what gets reported in metrics/logs.
    raw_loss = loss
    if scale_loss:
      # Scale down the loss so gradients are invariant to the number of
      # replicas: cross-replica aggregation sums per-replica gradients.
      loss = loss / strategy.num_replicas_in_sync

  if explicit_allreduce:
    grad_utils.minimize_using_explicit_allreduce(tape, optimizer, loss,
                                                 training_vars,
                                                 pre_allreduce_callbacks,
                                                 post_allreduce_callbacks)
  else:
    if isinstance(optimizer,
                  tf.keras.mixed_precision.experimental.LossScaleOptimizer):
      # Mixed precision: scale the loss up before differentiating to avoid
      # float16 gradient underflow, then unscale the resulting gradients.
      with tape:
        scaled_loss = optimizer.get_scaled_loss(loss)
      scaled_grads = tape.gradient(scaled_loss, training_vars)
      grads = optimizer.get_unscaled_gradients(scaled_grads)
    else:
      grads = tape.gradient(loss, training_vars)
    optimizer.apply_gradients(zip(grads, training_vars))
  # For reporting, the metric takes the mean of the per-step raw losses.
  train_loss_metric.update_state(raw_loss)
  for metric in train_metrics:
    metric.update_state(labels, model_outputs)
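Not shown above is how the replicated step is launched. A minimal sketch, assuming the same `strategy`, `_replicated_step`, and a distributed dataset iterator from the enclosing scope (the `steps` loop shape here is illustrative, not the author's exact code):

import tensorflow as tf

@tf.function
def train_steps(iterator, steps):
  # strategy.run executes _replicated_step once per replica, each
  # replica receiving its own shard of the distributed batch.
  for _ in tf.range(steps):
    strategy.run(_replicated_step, args=(next(iterator),))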
Example #2
def step_fn(inputs):
    """Per-replica training step to run on each device."""
    images, labels = inputs
    with tf.GradientTape() as tape:
        logits = self.model(images, training=True)

        prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, logits)
        # Average over the global batch: each replica sees only
        # batch_size / num_replicas examples, and cross-replica gradient
        # aggregation sums, so dividing by the global batch size yields
        # the correct mean.
        loss = tf.reduce_sum(prediction_loss) * (
            1.0 / self.flags_obj.batch_size)
        num_replicas = self.strategy.num_replicas_in_sync
        l2_weight_decay = 1e-4
        if self.flags_obj.single_l2_loss_op:
            # Single fused L2 penalty over all non-BatchNorm variables.
            # tf.nn.l2_loss(v) returns sum(v**2) / 2, hence the factor
            # of 2 to recover l2_weight_decay * sum(v**2).
            l2_loss = l2_weight_decay * 2 * tf.add_n([
                tf.nn.l2_loss(v)
                for v in self.model.trainable_variables
                if 'bn' not in v.name
            ])
            # The penalty is added on every replica, so divide by the
            # replica count to keep its total contribution constant.
            loss += (l2_loss / num_replicas)
        else:
            # Per-layer regularization losses collected by Keras.
            loss += (tf.reduce_sum(self.model.losses) / num_replicas)

    grad_utils.minimize_using_explicit_allreduce(
        tape, self.optimizer, loss, self.model.trainable_variables)
    self.train_loss.update_state(loss)
    self.train_accuracy.update_state(labels, logits)
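In both examples, `grad_utils.minimize_using_explicit_allreduce` (from the TF Model Garden) handles the gradient computation, cross-replica reduction, and optimizer update in one call. A simplified sketch of the idea (a hypothetical helper, not the library's actual implementation):

import tensorflow as tf

def minimize_with_explicit_allreduce(tape, optimizer, loss, variables):
    # Compute per-replica gradients from the tape.
    grads = tape.gradient(loss, variables)
    # Reduce the gradients across replicas explicitly (sum), rather
    # than relying on the optimizer's implicit aggregation.
    replica_ctx = tf.distribute.get_replica_context()
    reduced_grads = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM, grads)
    # Apply the already-reduced gradients; disable the second
    # aggregation pass the optimizer would otherwise perform.
    optimizer.apply_gradients(
        zip(reduced_grads, variables),
        experimental_aggregate_gradients=False)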