def _replicated_step(inputs):
    """Run a single training step on one replica.

    Performs the forward pass, computes the loss, applies gradients to
    ``training_vars`` and updates the training metrics.

    Args:
        inputs: A two-element sequence that unpacks into
            ``(model_inputs, labels)``.
    """
    features, targets = inputs

    with tf.GradientTape() as tape:
        model_outputs = model(features, training=True)
        # The unscaled loss is what gets reported in metrics/logs.
        raw_loss = loss_fn(targets, model_outputs)
        # Dividing by the replica count keeps gradients invariant to the
        # number of replicas. This happens under the tape so the division
        # is recorded and gradients flow through it.
        loss = (
            raw_loss / strategy.num_replicas_in_sync if scale_loss else raw_loss
        )

    if explicit_allreduce:
        grad_utils.minimize_using_explicit_allreduce(
            tape,
            optimizer,
            loss,
            training_vars,
            pre_allreduce_callbacks,
            post_allreduce_callbacks,
        )
    else:
        if isinstance(
            optimizer,
            tf.keras.mixed_precision.experimental.LossScaleOptimizer,
        ):
            # Re-enter the tape so the loss-scaling op is recorded too.
            with tape:
                scaled_loss = optimizer.get_scaled_loss(loss)
            grads = optimizer.get_unscaled_gradients(
                tape.gradient(scaled_loss, training_vars)
            )
        else:
            grads = tape.gradient(loss, training_vars)
        optimizer.apply_gradients(zip(grads, training_vars))

    # The loss metric is fed the raw (unscaled) loss for reporting.
    train_loss_metric.update_state(raw_loss)
    for metric in train_metrics:
        metric.update_state(targets, model_outputs)
def step_fn(inputs):
    """Run one training step on the device.

    Computes the classification loss plus an L2 regularization term,
    minimizes it through ``grad_utils.minimize_using_explicit_allreduce``
    and updates ``self.train_loss`` / ``self.train_accuracy``.

    Args:
        inputs: A two-element sequence that unpacks into
            ``(images, labels)``.
    """
    images, labels = inputs

    with tf.GradientTape() as tape:
        logits = self.model(images, training=True)

        per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, logits
        )
        # Sum over the examples seen here, scaled by 1 / flags batch size.
        # NOTE(review): presumably batch_size is the global batch size, so
        # per-replica losses sum to the global mean -- confirm.
        loss = tf.reduce_sum(per_example_loss) * (
            1.0 / self.flags_obj.batch_size
        )

        num_replicas = self.strategy.num_replicas_in_sync
        l2_weight_decay = 1e-4
        if self.flags_obj.single_l2_loss_op:
            # Single fused L2 term over the trainable variables, skipping
            # any variable whose name contains 'bn' (presumably batch-norm
            # parameters -- confirm against the model definition).
            l2_terms = [
                tf.nn.l2_loss(var)
                for var in self.model.trainable_variables
                if 'bn' not in var.name
            ]
            regularization = l2_weight_decay * 2 * tf.add_n(l2_terms)
        else:
            # Use the regularization losses collected by the model itself.
            regularization = tf.reduce_sum(self.model.losses)

        # The regularization term is split evenly across the replicas.
        loss = loss + (regularization / num_replicas)

    grad_utils.minimize_using_explicit_allreduce(
        tape, self.optimizer, loss, self.model.trainable_variables
    )
    self.train_loss.update_state(loss)
    self.train_accuracy.update_state(labels, logits)