def _replicated_step(inputs):
      """Replicated training step."""

      inputs, labels = inputs
      with tf.GradientTape() as tape:
        model_outputs = model(inputs, training=True)
        loss = loss_fn(labels, model_outputs)
        # Raw loss is used for reporting in metrics/logs.
        raw_loss = loss
        if scale_loss:
          # Scale down the loss so the summed gradients are invariant to the number of replicas.
          loss = loss / strategy.num_replicas_in_sync

      if explicit_allreduce:
        grad_utils.minimize_using_explicit_allreduce(tape, optimizer, loss,
                                                     training_vars,
                                                     pre_allreduce_callbacks,
                                                     post_allreduce_callbacks)
      else:
        if isinstance(optimizer,
                      tf.keras.mixed_precision.experimental.LossScaleOptimizer):
          with tape:
            scaled_loss = optimizer.get_scaled_loss(loss)
          scaled_grads = tape.gradient(scaled_loss, training_vars)
          grads = optimizer.get_unscaled_gradients(scaled_grads)
        else:
          grads = tape.gradient(loss, training_vars)
        optimizer.apply_gradients(zip(grads, training_vars))
      # For reporting, the metric takes the mean of losses.
      train_loss_metric.update_state(raw_loss)
      for metric in train_metrics:
        metric.update_state(labels, model_outputs)
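
In typical tf.distribute usage, a replicated step like the one above is dispatched with strategy.run from inside a tf.function. The sketch below shows only that wiring; `dataset` and `steps_per_epoch` are placeholder names and are not part of the original snippet.

# Minimal driver sketch for the replicated step above. `dataset` and
# `steps_per_epoch` are placeholders; a real loop also handles
# checkpointing, summaries, and evaluation.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

@tf.function
def train_step(iterator):
  # Run the replicated step once per replica on the next batch.
  strategy.run(_replicated_step, args=(next(iterator),))

iterator = iter(strategy.experimental_distribute_dataset(dataset))
for _ in range(steps_per_epoch):
  train_step(iterator)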
Example #2
        def step_fn(inputs):
            """Function to run on the device."""
            images, labels = inputs
            with tf.GradientTape() as tape:
                logits = self.model(images, training=True)

                prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
                    labels, logits)
                loss = tf.reduce_sum(prediction_loss) * (
                    1.0 / self.flags_obj.batch_size)
                num_replicas = self.strategy.num_replicas_in_sync

                if self.flags_obj.single_l2_loss_op:
                    l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.add_n([
                        tf.nn.l2_loss(v)
                        for v in self.model.trainable_variables
                        if 'bn' not in v.name
                    ])

                    loss += (l2_loss / num_replicas)
                else:
                    loss += (tf.reduce_sum(self.model.losses) / num_replicas)

            grad_utils.minimize_using_explicit_allreduce(
                tape, self.optimizer, loss, self.model.trainable_variables)
            self.train_loss.update_state(loss)
            self.train_accuracy.update_state(labels, logits)
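
The single_l2_loss_op branch above builds weight decay by hand: it sums tf.nn.l2_loss over every trainable variable whose name does not contain 'bn', so batch-normalization parameters (which the ResNet model names with 'bn') are exempt; the factor of 2 compensates for tf.nn.l2_loss computing sum(t ** 2) / 2, and the division by num_replicas keeps the decay term from being counted once per replica when per-replica losses are combined. A standalone illustration follows; the toy model and the 1e-4 decay value are assumptions, not taken from the original file.

# Standalone illustration of the hand-written L2 term; the toy model and
# the decay value are made up for the example.
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation='relu', input_shape=(4,)),
    tf.keras.layers.BatchNormalization(name='bn_1'),
    tf.keras.layers.Dense(2),
])
weight_decay = 1e-4
# Batch-norm variables are excluded by the 'bn' name filter, which relies
# on BN layers being named with 'bn' as in the ResNet model.
l2_loss = weight_decay * tf.add_n([
    tf.nn.l2_loss(v)
    for v in model.trainable_variables
    if 'bn' not in v.name
])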
Example #3
    def step_fn(inputs):
      """Function to run on the device."""
      images, labels = inputs
      if self.one_hot:
        labels = tf.cast(labels, tf.int32)
        labels = tf.one_hot(labels, 1001)
        labels = tf.squeeze(labels)

      with tf.GradientTape() as tape:
        logits = self.model(images, training=True)

        prediction_loss = self.get_prediction_loss(labels, logits)

        loss = tf.reduce_sum(prediction_loss) * (1.0 /
                                                 self.flags_obj.batch_size)

        if not self.use_lars_optimizer:
          num_replicas = self.strategy.num_replicas_in_sync

          if self.flags_obj.single_l2_loss_op:
            l2_loss = self.flags_obj.weight_decay * tf.add_n([
                tf.nn.l2_loss(v)
                for v in self.model.trainable_variables
                if ('bn' not in v.name)
            ])

            loss += (l2_loss / num_replicas)
          else:
            loss += (tf.reduce_sum(self.model.losses) / num_replicas)

      if horovod_enabled():
        tape = hvd.DistributedGradientTape(tape)
        grads = tape.gradient(loss, self.model.trainable_variables)
        grads_and_vars = zip(grads, self.model.trainable_variables)

        self.optimizer.apply_gradients(
          grads_and_vars, experimental_aggregate_gradients=False)

        tf.cond(self.global_step == 1,
          lambda: hvd.broadcast_variables(self.model.variables + self.optimizer.variables(),
                                          root_rank=0),
          lambda: tf.constant(True))
      else:
        grad_utils.minimize_using_explicit_allreduce(
          tape, self.optimizer, loss, self.model.trainable_variables)

      self.train_loss.update_state(loss)
      self.train_accuracy.update_state(labels, logits)
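
When horovod_enabled() is true (a helper from the surrounding repository), the snippet wraps the tape in hvd.DistributedGradientTape, which already averages gradients across workers, so apply_gradients is called with experimental_aggregate_gradients=False, and variables are broadcast from rank 0 on the first step so every worker starts from identical state. Below is a minimal sketch of the process-level setup such a branch assumes; the GPU pinning shown is the common Horovod pattern and may differ from the original repository's device configuration.

# Generic Horovod bootstrap assumed by the branch above (a sketch only).
import horovod.tensorflow as hvd
import tensorflow as tf

hvd.init()
# Pin each process to one local accelerator (standard Horovod pattern).
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')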
Example #4
def train_step(images, labels, step):
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss = loss_object(labels, predictions)

    if horovod_enabled():
        tape = hvd.DistributedGradientTape(tape)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables),
                                  experimental_aggregate_gradients=True)
        tf.cond(
            step == 0, lambda: hvd.broadcast_variables(
                model.variables + optimizer.variables(), root_rank=0),
            lambda: tf.constant(True))
    else:
        grad_utils.minimize_using_explicit_allreduce(tape, optimizer, loss,
                                                     model.trainable_variables)

    train_loss(loss)
    train_accuracy(labels, predictions)
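
train_step above closes over loss_object, optimizer, train_loss, and train_accuracy. A plausible setup for those names is sketched below, assuming a from-logits classification model; the concrete classes are assumptions, not taken from the original file. Calling a Keras metric, as the snippet does with train_loss(loss), updates its state and returns the current result.

# Hypothetical definitions of the objects the snippet closes over; the
# names follow the snippet, the concrete classes are assumptions.
import tensorflow as tf

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')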
Example #5
        def step_fn(inputs):
            """Function to run on the device."""
            images, labels = inputs
            if self.one_hot:
                labels = tf.cast(labels, tf.int32)
                labels = tf.one_hot(labels, 1001)
                labels = tf.squeeze(labels)

            with tf.GradientTape() as tape:
                logits = self.model(images, training=True)

                prediction_loss = self.get_prediction_loss(labels, logits)

                loss = tf.reduce_sum(prediction_loss) * (
                    1.0 / self.flags_obj.batch_size)

                if not self.use_lars_optimizer:
                    num_replicas = self.strategy.num_replicas_in_sync

                    if self.flags_obj.single_l2_loss_op:
                        l2_loss = self.flags_obj.weight_decay * tf.add_n([
                            tf.nn.l2_loss(v)
                            for v in self.model.trainable_variables
                            if ('bn' not in v.name)
                        ])

                        loss += (l2_loss / num_replicas)
                    else:
                        loss += (tf.reduce_sum(self.model.losses) /
                                 num_replicas)

            if horovod_enabled():
                tape = hvd.DistributedGradientTape(tape)
                grads = tape.gradient(loss, self.model.trainable_variables)
                grads_and_vars = zip(grads, self.model.trainable_variables)

                self.optimizer.apply_gradients(
                    grads_and_vars, experimental_aggregate_gradients=False)

                tf.cond(
                    self.global_step == 1, lambda: hvd.broadcast_variables(
                        self.model.variables + self.optimizer.variables(),
                        root_rank=0), lambda: tf.constant(True))
            else:
                grad_utils.minimize_using_explicit_allreduce(
                    tape, self.optimizer, loss, self.model.trainable_variables)

            if self.flags_obj.modeling:
                sess = tf.compat.v1.Session()
                # pbtxt generation
                tf.io.write_graph(sess.graph.as_graph_def(add_shapes=True),
                                  self.flags_obj.model_dir, 'graph.pbtxt')
                # meta graph generation
                tf.compat.v1.train.export_meta_graph(
                    filename='checkpoint_model.meta',
                    meta_info_def=None,
                    graph_def=None,
                    saver_def=None,
                    collection_list=None,
                    as_text=False,
                    graph=None,
                    export_scope=None,
                    clear_devices=False,
                    clear_extraneous_savers=False,
                    strip_default_attrs=False,
                    save_debug_info=False)

            if self.train_loss:
                self.train_loss.update_state(loss)
            if self.train_accuracy:
                self.train_accuracy.update_state(labels, logits)
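
For reference, the one-hot branch near the top of this step_fn turns integer labels into dense 1001-way targets. The small standalone check below makes the shape handling explicit; the batch size and the [batch, 1] label shape are assumptions about the input pipeline.

# Standalone check of the one-hot label path; batch size and label shape
# are assumptions about the input pipeline, not taken from the snippet.
import tensorflow as tf

labels = tf.constant([[3], [7]])      # integer labels, shape [2, 1]
labels = tf.cast(labels, tf.int32)
labels = tf.one_hot(labels, 1001)     # shape [2, 1, 1001]
labels = tf.squeeze(labels)           # shape [2, 1001]
print(labels.shape)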